1 介绍
支持向量机(Support Vector Machine,简称 SVM)是一种监督学习算法,主要用于分类和回归问题。SVM 的核心思想是找到一个最优的超平面,将不同类别的数据分开。这个超平面不仅要能够正确分类数据,还要使得两个类别之间的间隔(margin)最大化。
1.1 线性可分
在二维空间上,两类点被一条直线完全分开叫做线性可分。
样本中距离超平面最近的一些点,这些点叫做支持向量。
1.2 软间隔
在实际应用中,完全线性可分的样本是很少的,如果遇到了不能够完全线性可分的样本,我们应该怎么办?比如下面这个:
于是我们就有了软间隔,相比于硬间隔的苛刻条件,我们允许个别样本点出现在间隔带里面,比如:
1.3 线性不可分
我们刚刚讨论的硬间隔和软间隔都是在说样本的完全线性可分或者大部分样本点的线性可分。但我们可能会碰到的一种情况是样本点不是线性可分的,比如:
这种情况的解决方法就是将二维线性不可分样本映射到高维空间中,让样本点在高维空间线性可分,比如:
对于在有限维度向量空间中线性不可分的样本,我们将其映射到更高维度的向量空间里,再通过间隔最大化的方式,学习得到支持向量机,就是非线性 SVM。
1.4 优缺点
优点
- 有严格的数学理论支持,可解释性强,不依靠统计方法,从而简化了通常的分类和回归问题
- 能找出对任务至关重要的关键样本(即:支持向量)
- 采用核技巧之后,可以处理非线性分类/回归任务
- 最终决策函数只由少数的支持向量所确定,计算的复杂性取决于支持向量的数目,而不是样本空间的维数,这在某种意义上避免了“维数灾难”。
缺点
- 训练时间长。当采用 SMO 算法时,由于每次都需要挑选一对参数,因此时间复杂度为 O(N2) ,其中 N 为训练样本的数量;
- 当采用核技巧时,如果需要存储核矩阵,则空间复杂度为 O(N2) ;
- 模型预测时,预测时间与支持向量的个数成正比。当支持向量的数量较大时,预测计算复杂度较高。
因此支持向量机目前只适合小批量样本的任务,无法适应百万甚至上亿样本的任务。
2 使用 Python 实现 SVM
2.1 安装必要的库
首先,确保你已经安装了scikit-learn
库。如果没有安装,可以使用以下命令进行安装:
pip install scikit-learn
2.2 导入库
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
2.3 加载数据集
我们将使用scikit-learn
自带的鸢尾花(Iris)数据集。
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2] # 只使用前两个特征
y = iris.target
2.4 划分训练集和测试集
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
2.5 训练 SVM 模型
# 创建SVM分类器
clf = svm.SVC(kernel='linear') # 使用线性核函数
# 训练模型
clf.fit(X_train, y_train)
2.6 预测与评估
# 在测试集上进行预测
y_pred = clf.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")
2.7 可视化结果
# 绘制决策边界
def plot_decision_boundary(X, y, model):
h = .02 # 网格步长
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('SVM Decision Boundary')
plt.show()
plot_decision_boundary(X_train, y_train, clf)