机器学习-KNN分类算法

1.1 KNN分类

        KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法。它是概念极其简单,而效果又很优秀的分类算法。1967年由Cover T和Hart P提出。

        KNN分类算法的核心思想:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

        如图,假设已经获取一些动物的特征,且已知这些动物的类别。现在需要识别一只新动物,判断它是哪类动物。首先找到与这个物体最接近的k个动物。假设k=3,则可以找到2只猫和1只狗。由于找到的结果中大多数是猫,则把这个新动物划分为猫类。

KNN方法有三个核心要素:

1.K值

        如果k取值太小,好处是近似误差会减小。但同时预测结果对近邻的样本点非常敏感,仅由非常近的训练样本决定预测结果。使模型变得复杂,容易过拟合。如果k值太大,学习的近似误差会增大,导致分类模糊,即欠拟合。

        下面举例看k值对预测结果的影响。对图5.2中的动物进行分类,当k=3时,分类结果为“猫:狗=2:1”,所以属于猫;当k=5时,表决结果为“猫:狗:熊猫=2:3:1”,所以判断目标动物为狗。

 

        那么K值到底怎么选取呢?涉及到距离的度量问题。

2.距离的度量

        不同的距离所确定的近邻点不同。平面上比较常用的是欧式距离。此外还有曼哈顿距离、余弦距离、球面距离等。

可以得到距离如下所示

3.分类决策规则

        分类结果的确定往往采用多数表决原则,即由输入实例的k个最邻近的训练实例中的多数类决定输入实例的类别。

1.2 初识KNN——鸢尾花分类

1.查看数据

SKlearn中的iris数据集有5个key,分别如下:

  • target_names : 分类名称,包括setosa、versicolor和virginica类。
  • data : 特征数据值。
  • target:分类(150个)。
  • DESCR: 数据集的简介。
  • feature_names: 特征名称。

【例】查看鸢尾花iris数据集。

#【例1.1】对鸢尾花iris数据集进行调用,查看数据的各方面特征。
from sklearn.datasets import load_iris
iris_dataset = load_iris()
#下面是查看数据的各项属性
print("数据集的Keys:\n",iris_dataset.keys())     #查看数据集的keys。
print("特征名:\n",iris_dataset['feature_names'])  #查看数据集的特征名称
print("数据类型:\n",type(iris_dataset['data']))    #查看数据类型
print("数据维度:\n",iris_dataset['data'].shape)    #查看数据的结构
print("前五条数据:\n{}".format(iris_dataset['data'][:5]))  #查看前5条数据
#查看分类信息
print("标记名:\n",iris_dataset['target_names']) 
print("标记类型:\n",type(iris_dataset['target']))
print("标记维度:\n",iris_dataset['target'].shape)
print("标记值:\n",iris_dataset['target'])
#查看数据集的简介
print('数据集简介:\n',iris_dataset['DESCR'][:20] + "\n.......")  #数据集简介前20个字符
运行结果: 
数据集的Keys:
 dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
特征名:
 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
数据类型:
 <class 'numpy.ndarray'>
数据维度:
 (150, 4)
前五条数据:
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]
标记名:
 ['setosa' 'versicolor' 'virginica']
标记类型:
 <class 'numpy.ndarray'>
标记维度:
 (150,)
标记值:
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
数据集简介:
 Iris Plants Database
.......
 2.数据集拆分

        使用train_test_split函数。train_test_split函数属于sklearn.model_selection类中的交叉验证功能,能随机地将样本数据集合拆分成训练集和测试集。

【例】对iris数据集进行拆分,并查看拆分结果。

#【例1.2】对iris数据集进行拆分,并查看拆分结果。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split( iris_dataset['data'], iris_dataset['target'], random_state=2)
print("X_train",X_train)
print("y_train",y_train)
print("X_test",X_test)
print("y_test",y_test)
print("X_train shape: {}".format(X_train.shape))
print("X_test shape: {}".format(X_test.shape))
运行结果:
X_train [[5.5 2.3 4.  1.3]
 [6.9 3.1 5.1 2.3]
 [6.  2.9 4.5 1.5]
 [6.2 2.9 4.3 1.3]
 [6.8 3.2 5.9 2.3]
 [5.  2.3 3.3 1. ]
 [4.8 3.4 1.6 0.2]
 [6.1 2.6 5.6 1.4]
 [5.2 3.4 1.4 0.2]
 [6.7 3.1 4.4 1.4]
 [5.1 3.5 1.4 0.2]
 [5.2 3.5 1.5 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 2.5 4.5 1.7]
 [6.2 3.4 5.4 2.3]
 [7.9 3.8 6.4 2. ]
 [5.4 3.4 1.7 0.2]
 [6.7 3.1 5.6 2.4]
 [6.3 3.4 5.6 2.4]
 [7.6 3.  6.6 2.1]
 [6.  2.2 5.  1.5]
 [4.3 3.  1.1 0.1]
 [4.8 3.1 1.6 0.2]
 [5.8 2.7 5.1 1.9]
 [5.7 2.8 4.1 1.3]
 [5.2 2.7 3.9 1.4]
 [7.7 3.  6.1 2.3]
 [6.3 2.7 4.9 1.8]
 [6.1 2.8 4.  1.3]
 [5.1 3.7 1.5 0.4]
 [5.7 2.8 4.5 1.3]
 [5.4 3.9 1.3 0.4]
 [5.8 2.8 5.1 2.4]
 [5.8 2.6 4.  1.2]
 [5.1 2.5 3.  1.1]
 [5.7 3.8 1.7 0.3]
 [5.5 2.4 3.7 1. ]
 [5.9 3.  4.2 1.5]
 [6.7 3.1 4.7 1.5]
 [7.7 2.8 6.7 2. ]
 [4.9 3.  1.4 0.2]
 [6.3 3.3 4.7 1.6]
 [5.1 3.8 1.5 0.3]
 [5.8 2.7 3.9 1.2]
 [6.9 3.2 5.7 2.3]
 [4.9 3.1 1.5 0.1]
 [5.  2.  3.5 1. ]
 [4.9 3.1 1.5 0.1]
 [5.  3.5 1.3 0.3]
 [5.4 3.7 1.5 0.2]
 [6.8 3.  5.5 2.1]
 [6.3 3.3 6.  2.5]
 [5.  3.4 1.6 0.4]
 [5.2 4.1 1.5 0.1]
 [6.3 2.5 5.  1.9]
 [7.7 2.6 6.9 2.3]
 [6.  2.2 4.  1. ]
 [7.2 3.6 6.1 2.5]
 [4.9 2.4 3.3 1. ]
 [6.1 2.8 4.7 1.2]
 [6.5 3.  5.2 2. ]
 [5.1 3.5 1.4 0.3]
 [7.4 2.8 6.1 1.9]
 [5.9 3.  5.1 1.8]
 [6.4 2.7 5.3 1.9]
 [4.4 2.9 1.4 0.2]
 [5.6 2.8 4.9 2. ]
 [5.1 3.4 1.5 0.2]
 [5.  3.3 1.4 0.2]
 [5.7 2.6 3.5 1. ]
 [6.9 3.1 5.4 2.1]
 [5.5 2.6 4.4 1.2]
 [6.3 2.8 5.1 1.5]
 [7.  3.2 4.7 1.4]
 [6.8 2.8 4.8 1.4]
 [6.5 3.2 5.1 2. ]
 [6.9 3.1 4.9 1.5]
 [5.5 2.4 3.8 1.1]
 [5.6 3.  4.5 1.5]
 [6.  3.  4.8 1.8]
 [6.  2.7 5.1 1.6]
 [5.8 2.7 5.1 1.9]
 [5.9 3.2 4.8 1.8]
 [5.1 3.8 1.6 0.2]
 [6.2 2.2 4.5 1.5]
 [5.6 3.  4.1 1.3]
 [5.6 2.5 3.9 1.1]
 [5.8 2.7 4.1 1. ]
 [6.4 3.1 5.5 1.8]
 [6.6 2.9 4.6 1.3]
 [5.5 4.2 1.4 0.2]
 [4.4 3.  1.3 0.2]
 [6.3 2.9 5.6 1.8]
 [6.4 3.2 4.5 1.5]
 [7.3 2.9 6.3 1.8]
 [5.  3.6 1.4 0.2]
 [7.1 3.  5.9 2.1]
 [4.9 3.1 1.5 0.1]
 [6.5 3.  5.5 1.8]
 [6.7 3.3 5.7 2.1]
 [5.4 3.4 1.5 0.4]
 [6.1 2.9 4.7 1.4]
 [4.6 3.2 1.4 0.2]
 [6.7 3.  5.2 2.3]
 [5.7 3.  4.2 1.2]
 [5.  3.4 1.5 0.2]
 [6.5 3.  5.8 2.2]
 [6.6 3.  4.4 1.4]
 [5.  3.5 1.6 0.6]
 [4.6 3.6 1.  0.2]
 [6.3 2.5 4.9 1.5]
 [5.7 4.4 1.5 0.4]]
y_train [1 2 1 1 2 1 0 2 0 1 0 0 0 2 2 2 0 2 2 2 2 0 0 2 1 1 2 2 1 0 1 0 2 1 1 0 1
 1 1 2 0 1 0 1 2 0 1 0 0 0 2 2 0 0 2 2 1 2 1 1 2 0 2 2 2 0 2 0 0 1 2 1 2 1
 1 2 1 1 1 2 1 2 1 0 1 1 1 1 2 1 0 0 2 1 2 0 2 0 2 2 0 1 0 2 1 0 2 1 0 0 1
 0]
X_test [[4.6 3.4 1.4 0.3]
 [4.6 3.1 1.5 0.2]
 [5.7 2.5 5.  2. ]
 [4.8 3.  1.4 0.1]
 [4.8 3.4 1.9 0.2]
 [7.2 3.  5.8 1.6]
 [5.  3.  1.6 0.2]
 [6.7 2.5 5.8 1.8]
 [6.4 2.8 5.6 2.1]
 [4.8 3.  1.4 0.3]
 [5.3 3.7 1.5 0.2]
 [4.4 3.2 1.3 0.2]
 [5.  3.2 1.2 0.2]
 [5.4 3.9 1.7 0.4]
 [6.  3.4 4.5 1.6]
 [6.5 2.8 4.6 1.5]
 [4.5 2.3 1.3 0.3]
 [5.7 2.9 4.2 1.3]
 [6.7 3.3 5.7 2.5]
 [5.5 2.5 4.  1.3]
 [6.7 3.  5.  1.7]
 [6.4 2.9 4.3 1.3]
 [6.4 3.2 5.3 2.3]
 [5.6 2.7 4.2 1.3]
 [6.3 2.3 4.4 1.3]
 [4.7 3.2 1.6 0.2]
 [4.7 3.2 1.3 0.2]
 [6.1 3.  4.9 1.8]
 [5.1 3.8 1.9 0.4]
 [7.2 3.2 6.  1.8]
 [6.2 2.8 4.8 1.8]
 [5.1 3.3 1.7 0.5]
 [5.6 2.9 3.6 1.3]
 [7.7 3.8 6.7 2.2]
 [5.4 3.  4.5 1.5]
 [5.8 4.  1.2 0.2]
 [6.4 2.8 5.6 2.2]
 [6.1 3.  4.6 1.4]]
y_test [0 0 2 0 0 2 0 2 2 0 0 0 0 0 1 1 0 1 2 1 1 1 2 1 1 0 0 2 0 2 2 0 1 2 1 0 2
 1]
X_train shape: (112, 4)
X_test shape: (38, 4)
3.使用散点矩阵查看数据特征关系

        在数据分析中,同时观察一组变量的散点图是很有意义的,这也被称为散点图矩阵(scatter plot matrix)。创建这样的图表工作量巨大,可以使用scatter_matrix函数。scatter_matrix函数是Pandas提供了一个能从DataFrame创建散点图矩阵的函数。

【例】对鸢尾花数据结果,使用scatter_matrix显示训练集与测试集的散点图矩阵。

#【例5.3】使用scatter_matrix显示训练集与测试集。
import pandas as pd
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# 创建一个scatter matrix,颜色值来自y_train
pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8)
运行结果:

4.建立KNN模型

        在Python中,实现KNN方法使用的是KNeighborsClassifier类,KNeighborsClassifier类属于Scikit-learn的neighbors包。

核心操作包括以下三步:

  1. 创建KNeighborsClassifier对象,并进行初始化
  2. 调用fit()方法,对数据集进行训练
  3. 调用predict()函数,对测试集进行预测

使用KNN对鸢尾花iris数据集进行分类的完整代码如下:

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
#导入鸢尾花数据并查看数据特征
iris = datasets.load_iris()
print('数据集结构:',iris.data.shape)
# 获取属性
iris_X = iris.data
# 获取类别
iris_y = iris.target
# 划分成测试集和训练集
iris_train_X,iris_test_X,iris_train_y,iris_test_y=train_test_split(iris_X,iris_y,test_size=0.2, random_state=0)
#分类器初始化
knn = KNeighborsClassifier()
#对训练集进行训练
knn.fit(iris_train_X, iris_train_y)
#对测试集数据的鸢尾花类型进行预测
predict_result = knn.predict(iris_test_X)
print('测试集大小:',iris_test_X.shape)
print('真实结果:',iris_test_y)
print('预测结果:',predict_result)
#显示预测精确率
print('预测精确率:',knn.score(iris_test_X, iris_test_y))

运行结果:

数据集结构: (150, 4)
测试集大小: (30, 4)
真实结果: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0]
预测结果: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 2 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0]
预测精确率: 0.9666666666666667

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/885662.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

YOLO11关键改进与网络结构图

目录 前言&#xff1a;一、YOLO11的优势二、YOLO11网络结构图三、C3k2作用分析四、总结 前言&#xff1a; 对于一个科研人来说&#xff0c;发表论文水平的高低和你所掌握的信息差有着极大的关系&#xff0c;所以趁着YOLO11刚刚发布&#xff0c;趁热了解&#xff0c;先人一步对…

如何从huggingface下载

我尝试了一下若干步骤&#xff0c;莫名奇妙就成功了 命令行代理 如果有使用魔法上网&#xff0c;可以使用命令行代码&#xff0c;解决所有命令行连不上外网的问题&#xff1a; #配置http git config --global http.proxy 127.0.0.1:xxxx git config --global https.proxy 127…

Linux递归找出目录下最近被修改文件(最近一段时间内被修改过的最新文件)(最近修改文件、最新文件、查找文件)(监控目录、监控mysql文件)

文章目录 命令1&#xff1a;找出目录下最近60分钟内修改的最新文件命令解析&#xff1a; 命令2&#xff1a;找出目录下最近60分钟内修改的最新n个文件 命令1&#xff1a;找出目录下最近60分钟内修改的最新文件 find /ky_data/mysql -type f -mmin -60 -exec ls -ltr {} | tai…

Linux驱动开发(速记版)--平台总线

第四十七章 平台总线模型介绍 47.1 什么是平台总线&#xff1f; 平台总线是Linux内核中的一种虚拟机制&#xff0c;用于连接和匹配平台设备与对应的平台驱动。它简化了设备与驱动之间的绑定过程&#xff0c;提高了系统对硬件的适配性和扩展性。 当设备或驱动被注册时&#xff…

完整网络模型训练(一)

文章目录 一、网络模型的搭建二、网络模型正确性检验三、创建网络函数 一、网络模型的搭建 以CIFAR10数据集作为训练例子 准备数据集&#xff1a; #因为CIFAR10是属于PRL的数据集&#xff0c;所以需要转化成tensor数据集 train_data torchvision.datasets.CIFAR10(root&quo…

YOLO11震撼发布!

非常高兴地向大家介绍 Ultralytics YOLO系列的新模型&#xff1a; YOLO11&#xff01; YOLO11 在以往 YOLO 模型基础上带来了一系列强大的功能和优化&#xff0c;使其速度更快、更准确、用途更广泛。主要改进包括 增强了特征提取功能&#xff0c;从而可以更精确地捕捉细节以更…

[云]Kubernetes 的基础知识

目标&#xff1a; 实践实验室涵盖 Kubernetes 的基础知识&#xff08;这个句子的意思是在实验室中通过实践学习 Kubernetes 的基本概念&#xff09; 在此过程中理解 Kubernetes 概念&#xff08;这个句子的意思是在学习的过程中理解 Kubernetes 的相关概念&#xff09; 议程&…

【无人机设计与技术】四旋翼无人机的建模

摘要 本项目的目标是通过 Simulink 建模和仿真&#xff0c;研究四旋翼无人机的建模、姿态控制、定点位置控制及航点规划功能。无人机建模包含了动力单元模型、控制效率模型和刚体模型&#xff0c;并运用这些模型实现了姿态控制和位置控制。姿态控制为无人机的平稳飞行提供基础…

OpenCV normalize() 函数详解及用法示例

OpenCV的normalize函数用于对数组&#xff08;图像&#xff09;进行归一化处理&#xff0c;即将数组中的元素缩放到一个指定的范围或具有一个特定的标准&#xff08;如均值和标准差&#xff09;。它有两个原型函数, 如下: Normalize()规范化数组的范数或值范围。当normTypeNORM…

制造企业为何需要PLM系统?PLM系统解决方案对制造业重要性分析

制造企业为何需要PLM系统&#xff1f;PLM系统解决方案对制造业重要性分析 新华社9月23日消息&#xff0c;据全国组织机构统一社会信用代码数据服务中心统计&#xff0c;我国制造业企业总量突破600万家。数据显示&#xff0c;2024年1至8月&#xff0c;我国制造业企业数量呈现稳…

简单线性回归分析-基于R语言

本题中&#xff0c;在不含截距的简单线性回归中&#xff0c;用零假设对统计量进行假设检验。首先&#xff0c;我们使用下面方法生成预测变量x和响应变量y。 set.seed(1) x <- rnorm(100) y <- 2*xrnorm(100) &#xff08;a&#xff09;不含截距的线性回归模型构建。 &…

计算机视觉综述

大家好&#xff0c;今天&#xff0c;我们将一起探讨计算机视觉的基本概念、发展历程、关键技术以及未来趋势。计算机视觉是人工智能的一个重要分支&#xff0c;旨在使计算机能够“看”懂图像和视频&#xff0c;从而完成各种复杂的任务。无论你是对这个领域感兴趣的新手&#xf…

Linux操作系统中MongoDB

1、什么是MongoDB 1、非关系型数据库 NoSQL&#xff0c;泛指非关系型的数据库。随着互联网web2.0网站的兴起&#xff0c;传统的关系数据库在处理web2.0网站&#xff0c;特别是超大规模和高并发的SNS类型的web2.0纯动态网站已经显得力不从心&#xff0c;出现了很多难以克服的问…

SpringBoot整合JPA详解

SpringBoot版本是2.0以上(2.6.13) JDK是1.8 一、依赖 <dependencies><!-- jdbc --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-jdbc</artifactId></dependency><!--…

C# C++ 笔记

第一阶段知识总结 lunix系统操作 1、基础命令 &#xff08;1&#xff09;cd cd /[目录名] 打开指定文件目录 cd .. 返回上一级目录 cd - 返回并显示上一次目录 cd ~ 切换到当前用户的家目录 &#xff08;2&#xff09;pwd pwd 查看当前所在目录路径 pwd -L 打印当前物理…

Unity实战案例全解析:RTS游戏的框选和阵型功能(5)阵型功能 优化

前篇&#xff1a;Unity实战案例全解析&#xff1a;RTS游戏的框选和阵型功能&#xff08;4&#xff09;阵型功能-CSDN博客 本案例来源于unity唐老狮&#xff0c;有兴趣的小伙伴可以去泰克在线观看该课程 我只是对重要功能进行分析和做出笔记分享&#xff0c;并未无师自通&#x…

ARM Process state -- SPSR

Holds the saved process state for the current mode. 保存当前模式的已保存进程状态。 N, bit [31] Set to the value of PSTATE.N on taking an exception to the current mode, and copied to PSTATE.N on executing an exception return operation in the current mod…

袋鼠云数据资产平台:数据模型标准化建表重构升级

数据模型是什么&#xff1f;简单来说&#xff0c;数据模型是用来组织和管理数据的一种方式。它为构建高效且可靠的信息系统提供了基础&#xff0c;不仅决定了如何存储和管理数据&#xff0c;还直接影响系统的性能和可扩展性。 想要建立一个良好的数据模型&#xff0c;设计时需…

链表的基础知识

文章目录 概要整体架构流程 小结 概要 链表是一种常见的数据结构&#xff0c;它通过节点之间的连接关系实现数据的存储和访问。链表由一系列节点&#xff08;Node&#xff09;组成&#xff0c;每个节点包含数据和指向下一个节点的指针。链表的特点是物理存储单元上非连续、非顺…

Qt的互斥量用法

目的 互斥量的概念 互斥量是一个可以处于两态之一的变量:解锁和加锁。这样&#xff0c;只需要一个二进制位表示它&#xff0c;不过实际上&#xff0c;常常使用一个整型量&#xff0c;0表示解锁&#xff0c;而其他所有的值则表示加锁。互斥量使用两个过程。当一个线程(或进程)…