文章目录
- k近邻算法
- 算法原理
- k值的选取
- 特征数据的归一化
- 距离的度量
- 分类原则的制定
- 鸢尾花分类
k近邻算法
k近邻算法是经典的监督学习算法,我们这里主要介绍k近邻算法的基本内容和如何应用
算法原理
k近邻算法的基本原理其实很简单
首先k近邻算法是一个分类算法,在我们进行分类之前,需要先定义“距离”,或者我们可以形象的理解为样本点之间的相似程度
训练集是已经准备好的,也已经完全分好了类别,接下来拿出每一个需要分类的测试点,找到训练集中和他最近的k个点,也就是最相似的k个点,如果这k个点都属于同一个类别,我们就有把握认为这个测试点也是这个类别的了
这个过程完成其实算法也就结束了,更加形象的说法可能是投票,比如说这个人有多少分像自己,然后取出最高的k个得分,如果这k个得分是同一类,就说明这个人是这一类的
k近邻算法是一种基于实例的学习,也是惰性学习的代表,没有显示的训练过程,因为我们一开始就没有对训练集做处理,而是直接使用,所以训练时间为0。与之对应的是急切学习,就是需要从训练集中建模
其次这个过程实际上是属于少数付出多数的,那么k值的选取和距离的定义就显得尤为重要
接下来我们会对这几个规则进行介绍
k值的选取
正如我们之前所说,k值的选取对分类的准确性有很大影响,当k比较小时,模型对于训练集就非常敏感了,一旦出现了扰动,模型很可能就不识别了,也就发生了过拟合
当k值比较大时,初期的时候分类错误率会有所下降,但是随着进一步增大,又会提升,因为样本逐渐趋向于全局的样本了
一般我们采用交叉验证的方式来选取最优的k值,也就是说对每一个k值都做若干次验证,计算他们各自的平均误差,然后选其中误差最小的
特征数据的归一化
如果读者学习过统计学相关的内容,就知道归一化的必要性之大
举个简单的例子,温度从10摄氏度到20摄氏度和分数从60到70,这是两个概念,虽然差距一样,一方面是因为单位不同,另一方面这两个10占所在区域的占比不同
一般情况下是会将所有特征值映射到0到1的范围内处理
归一化的方法有很多,假设一个值是x,他所在范围的最大值最小值已知,求他的归一化值可以用这个公式 x ′ = x − M I N M A X − M I N x^{'}=\frac{x-MIN}{MAX-MIN} x′=MAX−MINx−MIN
我们可以用numpy或者sklearn进行数据的预处理,需要注意的是,无论如何处理数据,我们都需要将数据转换成对应框架所需要的形式
距离的度量
一般来说有欧氏距离,这种距离如果不做归一化处理会受到量纲的影响,马氏距离很方便的表示数据的协方差距离,如果是文字变量就是海明距离,他表示将一个字符串变成另一个字符串所需要替换的字符个数
无论使用何种距离,只要能反映两个变量的相似程度即可,需要注意的是,不同的距离选取标准也会影响到分类的效果
分类原则的制定
如果我们仅仅采用一人一票,来贴标签,是不够恰当的,我们认为距离会赋予这些票的权重,也就是距离越近,越相似,权重也就越大,类似于最开始的评分,如果是一人一票就是0和1,而有权重就是从0到100
鸢尾花分类
sklearn已经内置了这个数据集,我们首先给出代码,然后再对代码做解释
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X,y)
y_predict_on_train = knn.predict(X_train)
y_predict_on_test = knn.predict(X_test)
print("训练集的准确率为:{:.2f}%".format(100*accuracy_score(y_train, y_predict_on_train)))
print("测试集的准确率为:{:.2f}%".format(100*accuracy_score(y_test, y_predict_on_test)))
其实整体分为了几个部分,载入数据集,初始化数据集,分割,训练,预测
最后我们可以给出最后预测的准确率,当然我们也可以自己手动调整一些参数,或者是knn的一些参数,查阅官方文档