如何利用Python實現SVM模型

2021-09-30 15:15:14 字數 4622 閱讀 3999

1樓:匿名使用者

我先直觀地闡述我對svm的理解,這其中不會涉及數學公式,然後給出python**。

svm是一種二分類模型,處理的資料可以分為三類:

線性可分,通過硬間隔最大化,學習線性分類器

近似線性可分,通過軟間隔最大化,學習線性分類器

線性不可分,通過核函式以及軟間隔最大化,學習非線性分類器

線性分類器,在平面上對應直線;非線性分類器,在平面上對應曲線。

硬間隔對應於線性可分資料集,可以將所有樣本正確分類,也正因為如此,受噪聲樣本影響很大,不推薦。

軟間隔對應於通常情況下的資料集(近似線性可分或線性不可分),允許一些超平面附近的樣本被錯誤分類,從而提升了泛化效能。

如下圖:

實線是由硬間隔最大化得到的,**能力顯然不及由軟間隔最大化得到的虛線。

對於線性不可分的資料集,如下圖:

我們直觀上覺得這時線性分類器,也就是直線,不能很好的分開紅點和藍點。

但是可以用一個介於紅點與藍點之間的類似圓的曲線將二者分開,如下圖:

我們將x^2對映為x,y^2對映為y,那麼超平面變成了x+y=1。

那麼原空間的線性不可分問題,就變成了新空間的(近似)線性可分問題。

此時就可以運用處理(近似)線性可分問題的方法去解決線性不可分資料集的分類問題。

以上我用最簡單的語言粗略地解釋了svm,沒有用到任何數學知識。但是沒有數學,就體會不到svm的精髓。因此接下來我會用盡量簡潔的語言敘述svm的數學思想,如果沒有看過svm推導過程的朋友完全可以跳過下面這段。

對於求解(近似)線性可分問題:

由最大間隔法,得到凸二次規劃問題,這類問題是有最優解的(理論上可以直接呼叫二次規劃計算包,得出最優解)

我們得到以上凸優化問題的對偶問題,一是因為對偶問題更容易求解,二是引入核函式,推廣到非線性問題。

求解對偶問題得到原始問題的解,進而確定分離超平面和分類決策函式。由於對偶問題裡目標函式和分類決策函式只涉及例項與例項之間的內積,即。我們引入核函式的概念。

拓展到求解線性不可分問題:

如之前的例子,對於線性不可分的資料集的任意兩個例項:xi,xj。當我們取某個特定對映f之後,f(xi)與f(xj)在高維空間中線性可分,運用上述的求解(近似)線性可分問題的方法,我們看到目標函式和分類決策函式只涉及內積。

由於高維空間中的內積計算非常複雜,我們可以引入核函式k(xi,xj)=,因此內積問題變成了求函式值問題。最有趣的是,我們根本不需要知道對映f。精彩!

我不準備在這裡放推導過程,因為已經有很多非常好的學習資料,如果有興趣,可以看:cs229 lecture notes

最後就是smo演算法求解svm問題,有興趣的話直接看作者**:sequential minimal optimization:a fast algorithm for training support vector machines

我直接給出**:smo+svm

核函式用了高斯核,取了不同的sigma

sigma=1,有189個支援向量,相當於用整個資料集進行分類。

sigma=10,有20個支援向量,邊界曲線能較好的擬合資料集特點。

2樓:螞蟻學

python使用tensorflow讀取csv資料訓練dnn深度學習模型

python svm 怎麼訓練模型

3樓:

支援向量機svm(support vector machine)是有監督的分類**模型,本篇文章使用機器學習庫scikit-learn中的手寫數字資料集介紹使用python對svm模型進行訓練並對手寫數字進行識別的過程。

準備工作

手寫數字識別的原理是將數字的**分割為8x8的灰度值矩陣,將這64個灰度值作為每個數字的訓練集對模型進行訓練。手寫數字所對應的真實數字作為分類結果。在機器學習sklearn庫中已經包含了不同數字的8x8灰度值矩陣,因此我們首先匯入sklearn庫自帶的datasets資料集。

然後是交叉驗證庫,svm分類演算法庫,繪製圖表庫等。

12345678910

#匯入自帶資料集from sklearn import datasets#匯入交叉驗證庫from sklearn import cross_validation#匯入svm分類演算法庫from sklearn import svm#匯入圖表庫import matplotlib.pyplot as plt#生成**結果準確率的混淆矩陣from sklearn import metrics

讀取並檢視數字矩陣

從sklearn庫自帶的datasets資料集中讀取數字的8x8矩陣資訊並賦值給digits。

12#讀取自帶資料集並賦值給digitsdigits = datasets.load_digits()

檢視其中的數字9可以發現,手寫的數字9以64個灰度值儲存。從下面的8×8矩陣中很難看出這是數字9。

12#檢視資料集中數字9的矩陣digits.data[9]

以灰度值的方式輸出手寫數字9的影象,可以看出個大概輪廓。這就是經過切割並以灰度儲存的手寫數字9。它所對應的64個灰度值就是模型的訓練集,而真實的數字9是目標分類。

我們的模型所要做的就是在已知64個灰度值與每個數字對應關係的情況下,通過對模型進行訓練來對新的手寫數字對應的真實數字進行分類。

1234

#繪製圖表檢視資料集中數字9的影象plt.imshow(digits.images[9], cmap=plt.

cm.gray_r, interpolation='nearest')plt.title('digits.

target[9]')plt.show()

設定模型的特徵x和**目標y

檢視資料集中的分類目標,可以看到一共有10個分類,分佈為0-9。我們將這個分類目標賦值給y,作為模型的**目標。

12#資料集中的目標分類digits.target

12#將資料集中的目標賦給yy=digits.target

手寫數字的64個灰度值作為特徵賦值給x,這裡需要說明的是64個灰度值是以8×8矩陣的形式保持的,因此我們需要使用reshape函式重新調整矩陣的行列數。這裡也就是將8×8的兩維資料轉換為64×1的一維資料。

123#使用reshape函式對矩陣進行轉換,並賦值給xn_samples = len(digits.images)x = digits.images.

reshape((n_samples, 64))

檢視特徵值x和**目標y的行數,共有1797行,也就是說資料集中共有1797個手寫數字的影象,64列是經過我們轉化後的灰度值。

12#檢視x和y的行數x.shape,y.shape

將資料分割為訓練集和測試集

將1797個手寫數字的灰度值採用隨機抽樣的方法分割為訓練集和測試集,其中訓練集為60%,測試集為40%。

12#隨機抽取生成訓練集和測試集,其中訓練集的比例為60%,測試集40%x_train, x_test, y_train, y_test = cross_validation.train_test_split(x, y, test_size=0.4, random_state=0)

檢視分割後的測試集資料,共有1078條資料。這些資料將用來訓練svm模型。

12#檢視訓練集的行數x_train.shape,y_train.shape

對svm模型進行訓練

將訓練集資料x_train和y_train代入到svm模型中,對模型進行訓練。下面是具體的**和結果。

12#生成svm分類模型clf = svm.svc(gamma=0.001)

12#使用訓練集對svm分類模型進行訓練clf.fit(x_train, y_train)

使用測試集測對模型進行測試

使用測試集資料x_test和y_test對訓練後的svm模型進行檢驗,模型對手寫數字分類的準確率為99.3%。這是非常高的準確率。

那麼是否真的這麼靠譜嗎?下面我們來單獨測試下。

12#使用測試集衡量分類模型準確率clf.score(x_test, y_test)

我們使用測試集的特徵x,也就是每個手寫數字的64個灰度值代入到模型中,讓svm模型進行分類。

12#對測試集資料進行**predicted=clf.predict(x_test)

然後檢視前20個手寫數字的分類結果,也就是手寫數字所對應的真實數字。下面是具體的分類結果。

12#檢視前20個測試集的**結果predicted[:20]

再檢視訓練集中前20個分類結果,也就是真實數字的情況,並將之前的分類結果與測試集的真實結果進行對比。

12#檢視測試集中的真實結果expected=y_test

以下是測試集中前20個真實數字的結果,與前面svm模型的分類結果對比,前20個結果是一致的。

12#檢視測試集中前20個真實結果expected[:20]

使用混淆矩陣來看下svm模型對所有測試集資料的**與真實結果的準確率情況,下面是一個10x10的矩陣,左上角第一行第一個數字60表示實際為0,svm模型也**為0的個數,第一行第二個數字表示實際為0,svm模型**為1的數字。第二行第二個數字73表示實際為1,svm模型也**為1的個數。

12#生成準確率的混淆矩陣(confusion matrix)metrics.confusion_matrix(expected, predicted)

從混淆矩陣中可以看到,大部分的數字svm的分類和**都是正確的,但也有個別的數字分類錯誤,例如真實的數字2,svm模型有一次錯誤的分類為1,還有一次錯誤分類為7。

如何利用SparkStreaming實現UV統計

首先實現uv統計需要明確uv的需求判定,也就是定義什麼是uv,接著用各種運算元對讀取的資料進行處理,最後可以將結果存入redis或者其他資料庫中 如何利用spark streaming實現uv統計 首先以一個簡單的示抄例開始 用spark streaming對從tcp連線中接收的文字進行單詞計數。功...

如何讓SVM輸出概率,請問SVM如何輸出概率值

使用libsvm工具箱的話,在函式svmtrain和svmpredict的輸入引數部分加入 b 1 函式svmpredict的輸出引數增加至3個即可 例如 prob estimates svmpredict label test,feature test,svmstruct,b 1 其中prob e...

如何利用虛擬區域網(VLAN)實現共享

各個小辦公室,因為使用的是小交換機,組建的小型區域網,只要把ip設在同一網段,就可以直接共享。vlan不是用來共享的,是劃分不同的小網 隔離廣播域的。如果vlan要互通的話,需要通過第三層,你可以加個路由 照你的描述,你這個不應該是vlan。你說bai 的不具體。無法ping通有很du多原因。首先z...