人工智能知识分享第二天-机器学习之KNN算法

KNN算法

KNN算法简介

KNN算法思想

K-近邻算法(K Nearest Neighbor,简称KNN)。比如:根据你的“邻居”来推断出你的类别
KNN算法思想:如果一个样本在特征空间中的 k 个最相似的样本中的大多数属于某一个类别,则该样本也属于这个类别

样本相似性:样本都是属于一个任务数据集的。样本距离越近则越相似。

在这里插入图片描述

K值的选择

在这里插入图片描述

【知道】KNN的应用方式

  • 解决问题:分类问题、回归问题

  • 算法思想:若一个样本在特征空间中的 k 个最相似的样本大多数属于某一个类别,则该样本也属于这个类别

  • 相似性:欧氏距离

  • 分类问题的处理流程:

  • 在这里插入图片描述
    1.计算未知样本到每一个训练样本的距离

    2.将训练样本根据距离大小升序排列

    3.取出距离最近的 K 个训练样本

    4.进行多数表决,统计 K 个样本中哪个类别的样本个数最多

    5.将未知的样本归属到出现次数最多的类别

回归问题的处理流程:
在这里插入图片描述

1.计算未知样本到每一个训练样本的距离

2.将训练样本根据距离大小升序排列

3.取出距离最近的 K 个训练样本

4.把这个 K 个样本的目标值计算其平均值

5.作为将未知的样本预测的值

API介绍

【实操】分类API

KNN分类API:

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5) 

​ n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数

【实操】回归API
KNN分类API:

sklearn.neighbors.KNeighborsRegressor(n_neighbors=5)
# 1.工具包
from sklearn.neighbors import KNeighborsClassifier,KNeighborsRegressor
# from sklearn.neighbors import KNeighborsRegressor

# 2.数据(特征工程)
# 分类
# x = [[0,2,3],[1,3,4],[3,5,6],[4,7,8],[2,3,4]]
# y = [0,0,1,1,0]
x = [[0,1,2],[1,2,3],[2,3,4],[3,4,5]]
y = [0.1,0.2,0.3,0.4]

# 3.实例化
# model =KNeighborsClassifier(n_neighbors=3)
model =KNeighborsRegressor(n_neighbors=3)

# 4.训练
model.fit(x,y)

# 5.预测
print(model.predict([[4,4,5]]))

距离度量方法

欧式距离

在这里插入图片描述

曼哈顿距离

在这里插入图片描述

切比雪夫距离

在这里插入图片描述

闵氏距离

在这里插入图片描述

特征预处理

为什么进行归一化、标准化

特征的单位或者大小相差较大,或者某特征的方差相比其他的特征要大出几个数量级容易影响(支配)目标结果,使得一些模型(算法)无法学习到其它的特征
在这里插入图片描述

归一化

通过对原始数据进行变换把数据映射到【mi,mx】(默认为[0,1])之间
在这里插入图片描述
数据归一化的API实现

sklearn.preprocessing.MinMaxScaler (feature_range=(0,1))

feature_range 缩放区间

  • 调用 fit_transform(X) 将特征进行归一化缩放
  • 归一化受到最大值与最小值的影响,这种方法容易受到异常数据的影响, 鲁棒性较差,适合传统精确小数据场景

标准化

通过对原始数据进行标准化,转换为均值为0标准差为1的标准正态分布的数据
在这里插入图片描述

  • mean 为特征的平均值
  • σ 为特征的标准差

数据标准化的API实现

sklearn.preprocessing. StandardScaler()

调用 fit_transform(X) 将特征进行归一化缩放

# 1.导入工具包
from sklearn.preprocessing import MinMaxScaler,StandardScaler

# 2.数据(只有特征)
x = [[90, 2, 10, 40], [60, 4, 15, 45], [75, 3, 13, 46]]

# 3.实例化(归一化,标准化)
# process =MinMaxScaler()
process =StandardScaler()

# 4.fit_transform 处理1
data =process.fit_transform(x)
# print(data)

print(process.mean_)
print(process.var_)

对于标准化来说,如果出现异常点,由于具有一定数据量,少量的异常点对于平均值的影响并不大

【实操】利用KNN算法进行鸢尾花分类

鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa和Virginica
在这里插入图片描述
每个花的特征用如下属性描述:
在这里插入图片描述
代码实现:

# 导入工具包
from sklearn.datasets import load_iris          # 加载鸢尾花测试集的.
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split    # 分割训练集和测试集的
from sklearn.preprocessing import StandardScaler        # 数据标准化的
from sklearn.neighbors import KNeighborsClassifier      # KNN算法 分类对象
from sklearn.metrics import accuracy_score              # 模型评估的, 计算模型预测的准确率


# 1. 定义函数 dm01_loadiris(), 加载数据集.
def dm01_loadiris():
    # 1. 加载数据集, 查看数据
    iris_data = load_iris()
    print(iris_data)           # 字典形式, 键: 属性名, 值: 数据.
    print(iris_data.keys())

    # 1.1 查看数据集
    print(iris_data.data[:5])
    # 1.2 查看目标值.
    print(iris_data.target)
    # 1.3 查看目标值名字.
    print(iris_data.target_names)
    # 1.4 查看特征名.
    print(iris_data.feature_names)
    # 1.5 查看数据集的描述信息.
    print(iris_data.DESCR)
    # 1.6 查看数据文件路径
    print(iris_data.filename)

# 2. 定义函数 dm02_showiris(), 显示鸢尾花数据.
def dm02_showiris():
    # 1. 加载数据集, 查看数据
    iris_data = load_iris()
    # 2. 数据展示
    # 读取数据, 并设置 特征名为列名.
    iris_df = pd.DataFrame(iris_data.data, columns=iris_data.feature_names)
    # print(iris_df.head(5))
    iris_df['label'] = iris_data.target

    # 可视化, x=花瓣长度, y=花瓣宽度, data=iris的df对象, hue=颜色区分, fit_reg=False 不绘制拟合回归线.
    sns.lmplot(x='petal length (cm)', y='petal width (cm)', data=iris_df, hue='label', fit_reg=False)
    plt.title('iris data')
    plt.show()


# 3. 定义函数 dm03_train_test_split(), 实现: 数据集划分
def dm03_train_test_split():
    # 1. 加载数据集, 查看数据
    iris_data = load_iris()
    # 2. 划分数据集, 即: 特征工程(预处理-标准化)
    x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2,
                                                        random_state=22)
    print(f'数据总数量: {len(iris_data.data)}')
    print(f'训练集中的x-特征值: {len(x_train)}')
    print(f'训练集中的y-目标值: {len(y_train)}')
    print(f'测试集中的x-特征值: {len(x_test)}')

# 4. 定义函数 dm04_模型训练和预测(), 实现: 模型训练和预测
def dm04_model_train_and_predict():
    # 1. 加载数据集, 查看数据
    iris_data = load_iris()

    # 2. 划分数据集, 即: 数据基本处理
    x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2, random_state=22)

    # 3. 数据集预处理-数据标准化(即: 标准的正态分布的数据集)
    transfer = StandardScaler()
    # fit_transform(): 适用于首次对数据进行标准化处理的情况,通常用于训练集, 能同时完成 fit() 和 transform()。
    x_train = transfer.fit_transform(x_train)
    # transform(): 适用于对测试集进行标准化处理的情况,通常用于测试集或新的数据. 不需要重新计算统计量。
    x_test = transfer.transform(x_test)

    # 4. 机器学习(模型训练)
    estimator = KNeighborsClassifier(n_neighbors=5)
    estimator.fit(x_train, y_train)

    # 5. 模型评估.
    # 场景1: 对抽取出的测试集做预测.
    # 5.1 模型评估, 对抽取出的测试集做预测.
    y_predict = estimator.predict(x_test)
    print(f'预测结果为: {y_predict}')

    # 场景2: 对新的数据进行预测.
    # 5.2 模型预测, 对测试集进行预测.
    # 5.2.1 定义测试数据集.
    my_data = [[5.1, 3.5, 1.4, 0.2]]
    # 5.2.2 对测试数据进行-数据标准化.
    my_data = transfer.transform(my_data)
    # 5.2.3 模型预测.
    my_predict = estimator.predict(my_data)
    print(f'预测结果为: {my_predict}')

    # 5.2.4 模型预测概率, 返回每个类别的预测概率
    my_predict_proba = estimator.predict_proba(my_data)
    print(f'预测概率为: {my_predict_proba}')

    # 6. 模型预估, 有两种方式, 均可.
    # 6.1 模型预估, 方式1: 直接计算准确率, 100个样本中模型预测正确的个数.
    my_score = estimator.score(x_test, y_test)
    print(my_score)  # 0.9666666666666667

    # 6.2 模型预估, 方式2: 采用预测值和真实值进行对比, 得到准确率.
    print(accuracy_score(y_test, y_predict))


# 在main方法中测试.
if __name__ == '__main__':
    # 1. 调用函数 dm01_loadiris(), 加载数据集.
    # dm01_loadiris()

    # 2. 调用函数 dm02_showiris(), 显示鸢尾花数据.
    # dm02_showiris()

    # 3. 调用函数 dm03_train_test_split(), 查看: 数据集划分
    # dm03_train_test_split()

    # 4. 调用函数 dm04_模型训练和预测(), 实现: 模型训练和预测
    dm04_model_train_and_predict()

坚持分享 共同进步 如有错误 欢迎指出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/944086.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

b站ip属地评论和主页不一样怎么回事

在浏览B站时,细心的用户可能会发现一个有趣的现象:某些用户的评论IP属地与主页显示的IP属地并不一致。这种差异引发了用户的好奇和猜测,究竟是什么原因导致了这种情况的发生呢?本文将对此进行深入解析,帮助大家揭开这一…

Robyn+Vue3+wangEditor打造个人笔记

Github:https://github.com/gwt805/MYNotes Gitee: https://gitee.com/gwt805/MYNotes GitCode: https://gitcode.com/gwt805/MYNotes/overview

BGP路由反射器,解决路由黑洞问题

路由反射器解决路由黑洞问题 路由反射器解决路由黑洞问题 路由黑洞的产生路由黑洞的解决办法路由反射器解决黑洞问题基本配置配置反射器前查看路由配置路由反射器配置反射器后查看路由路由黑洞的产生 根据BGP建立邻居的规则,只要TCP可达便可建立邻居系。如下图所示: AR2、AR…

JavaFX FXML模式下的布局

常见布局方式概述 在 JavaFX FXML 模式下,有多种布局方式可供选择。这些布局方式可以帮助您有效地组织和排列 UI 组件,以创建出美观且功能良好的用户界面。常用布局容器及布局方式 BorderPane 布局 特点:BorderPane 将空间划分为五个区域&…

二叉树的深搜_求根节点到叶节点数字之和_验证二叉搜索树_二叉树的所有路径

2331. 计算布尔二叉树的值 二叉树遍历可以用递归来解决,该问题的简单子问题是 知道左右孩子的值,再根据| &返回true/false。左子树右子树的值交给dfs解决。 终止条件就是到达叶子节点,即没有左右孩子,因为是完全二叉树 所以只需要判断一个…

Jupyter占用内存高问题排查解决

前言 前段时间我们上线了实例内存预警功能,方便大家更好地管理服务器内存资源。那么,也有同学会问,如果收到系统通知,我该怎么做呢?系统提示交换内存占用过高,但是又不知道是哪些程序占用的,怎么…

python下载,安装,环境配置

下载地址:Python Windows版本下载| Python中文网 官网 选择路径 安装完成 检测安装是否成功 使用 winr 启动运行对话框,输入 cmd 进入命令行。 输入pip list 输入 where python 查看 python.exe 的路径 环境配置 winr 打开运行对话框,输入 …

抓取手机HCI日志

荣耀手机 1、打开开发者模式 2、开启HCI、ADB调试 3、开启AP LOG 拨号界面输入*##2846579##* 4、蓝牙配对 5、抓取log adb pull /data/log/bt ./

IDEA 搭建 SpringBoot 项目之配置 Maven

目录 1?配置 Maven 1.1?打开 settings.xml 文件1.2?配置本地仓库路径1.3?配置中央仓库路径1.4?配置 JDK 版本1.5?重新下载项目依赖 2?配置 idea 2.1?在启动页打开设置2.2?配置 Java Compiler2.3?配置 File Encodings2.4?配置 Maven2.5?配置 Auto Import2.6?配置 C…

算法比赛汇总

数据科学竞赛平台网站整理 | ✨DEEPAI数据分析

深入研究物理光学传播和 ZBF 文件

物理光学传播特征 Zemax 中的物理光学传播 (POP) 是一种用于模拟和分析光在光学系统中传播时的行为的工具。与假设理想化几何射线的射线追踪不同,POP 考虑了光的波动性质,包括衍射和干涉。这使得它对于设计和分析显微镜、激光器等高精度光学系统以及其他…

【Java数据结构】栈和队列

栈(Stack) 栈的概念 栈是一种特殊的线性表,只允许在一端进行插入和删除。栈遵循后进先出,分别在栈顶删除、栈底插入。 栈的常用方法 栈的一些方法,例如:出栈、入栈、取栈顶元素、是否为空、栈中元素个数等…

StarRocks元数据无法合并

一、先说结论 如果您的StarRocks版本在3.1.4及以下,并且使用了metadata_journal_skip_bad_journal_ids来跳过某个异常的journal,结果之后就出现了FE的元数据无法进行Checkpoint的现象,那么选择升级版本到3.1.4以上,就可以解决。 …

使用qrcode.vue生成当前网页的二维码(H5)

使用npm&#xff1a; npm install qrcode.vue 使用yarn&#xff1a; yarn add qrcode.vue package.json&#xff1a; 实现&#xff1a; <template><div class"code"><qrcode-vue :value"currentUrl" :size"size" render-as&…

【STM32】RTT-Studio中HAL库开发教程十:EC800M-4G模块使用

文章目录 一、简介二、模块测试三、OneNet物联网配置四、完整代码五、测试验证 一、简介 EC800M4G是一款4G模块&#xff0c;本次实验主要是进行互联网的测试&#xff0c;模块测试&#xff0c;以及如何配置ONENET设备的相关参数&#xff0c;以及使用STM32F4来测试模块的数据上报…

STM32完全学习——FATFS0.15移植SD卡

一、下载FATFS源码 大家都知道使用CubMAX可以很快的将&#xff0c;FATFS文件管理系统移植到单片机上&#xff0c;但是别的芯片没有这么好用的工具&#xff0c;就需要自己从官网下载源码进行移植。我们首先解决SD卡的驱动问题&#xff0c;然后再移植FATFS文件管理系统。 二、SD…

【知识】cuda检测GPU是否支持P2P通信及一些注意事项

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 代码流程 先检查所有GPU之间是否支持P2P通信&#xff1b;然后尝试启用GPU之间的P2P通信&#xff1b;再次检查所有GPU之间是否支持P2P通信。 test.cu&…

Mysql大数据量表分页查询性能优化

一、模拟场景 1、产品表t_product,数据量500万+ 2、未做任何优化前,cout查询时间大约4秒;LIMIT offset, count 时,offset 值较大时查询时间越久。 count查询 SELECT COUNT(*) AS total FROM t_product WHERE deleted = 0 AND tenant_id = 1 分页查询 SELECT * FROM t_…

go语言的成神之路-筑基篇-对文件的操作

目录 一、对文件的读写 Reader?接口 ?Writer接口 copy接口 bufio的使用 ioutil库? 二、cat命令 三、包 1. 包的声明 2. 导入包 3. 包的可见性 4. 包的初始化 5. 标准库包 6. 第三方包 ?7. 包的组织 8. 包的别名 9. 包的路径 10. 包的版本管理 四、go mo…

Qt 应用程序转换为服务

一、在 Windows 上将 Qt 应用程序转换为服务 方法1&#xff1a; 创建一个 Windows 服务应用程序&#xff1a; Windows 服务应用程序是一个没有用户界面的后台进程&#xff0c;通常由 Win32 Service 模板创建&#xff0c;或者直接编写 main() 函数以实现服务逻辑。 修改 Qt 应…