机器学习实战1-kNN最近邻算法

文章目录

  • 机器学习基础
    • 机器学习的关键术语
  • k-近邻算法(KNN)
    • 准备:使用python导入数据
    • 实施kNN分类算法
    • 示例:使用kNN改进约会网站的配对效果
      • 准备数据:从文本文件中解析数据
      • 分析数据
      • 准备数据:归一化数值
      • 测试算法:作为完整程序验证分类器
    • 手写识别系统

机器学习基础

机器学习的关键术语

1、属性:将一种事务分类的特征值称为属性,例如我们在做鸟类分类时,我们可以将体重、翼展、脚蹼、后背颜色作为特征,特征通常时训练样本的列,它们是独立测量得到的结果,多个特征联系在一起共同组成一个训练样本
2、目标变量:就是我们要分类的那个结果
3、训练集和测试集:训练集作为算法的输入,用于训练模型,测试集用于检验训练的效果

k-近邻算法(KNN)

主要思想:我们先将已知标签的数据以及对应的标签输入,当输入未知标签的数据时,我们希望根据输入的特征值来判断该数据的特征值,我们先计算该数据与我们已知标签的数据的距离,并将距离排序,取前k个数据,根据前k个数据中出现次数最多的数据的标签作为新数据标签的分类

kNN算法主要是用于分类的一种算法

屏幕截图 2023-08-04 174500.png

准备:使用python导入数据

from numpy import *
# kNN排序时将使用这个模块提供好的函数
import operator

def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

实施kNN分类算法

1.png

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1) - dataSet)
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis = 1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

   sortedClassCount = sorted(classCount.items(),
                              key = operator.itemgetter(1), reverse= True)
	return sortedClassCount[0][0]

这里先说一下shape函数,只做简单说明,shape函数用于确定array的维度比如

group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
print(group.shape)

这里输出的结果是(4,2)

也就是说返回的是矩阵或者数组每一维的长度,返回的结果是一个元组(tuple),元组和例表的区别不能忘记,元组不可修改,列表可以修改

tile()函数,tile是numpy模块中的一个函数,用于矩阵的复制,tile(A, reps), A表示我们要操作的矩阵,reps是我们复制的参数,可以是一个数也可以是一个矩阵(4, 2),tile(A, (4, 2))表示将A矩阵的列复制4次,行复制两次

argsort()方法,对数组进行排序,这里返回的是排序后的下标这和C++中的sort()方法不同

argsort()实现倒序排序

group = array([2, 3, 5, 4])
x = argsort(-group)
print(x)

字典中的get()方法

python中对于非数值型数据进行排序,例如字典

sorted(iterable, cmp=None, key=None, reverse=False)

iterable是一个迭代器,
cmp是比较的函数,这个具有两个参数,参数的值都是从可迭代对象中取出,此函数必须遵守的规则为,大于则返回1,小于则返回-1,等于则返回0。
key – 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。
reverse – 排序规则,reverse = True 降序 , reverse = False 升序(默认)。

 sortedClassCount = sorted(classCount.iteritems(),
                              key = operator.itemgetter(1), reverse= True)

python中的items()返回的是一个列表,iteritems()返回一个迭代器, itemgetter()方法可用于指定关键字排序,operator.itemgetter(1)是按字典中的值进行排序,reverse= True按降序排序,python3已经不支持iteritems(),这里用items()即可。

字典中的get()方法

dict_name.get(key, default = None)

key是我们要查找字典中的key,如果存在则返回对应的值,如果不存在就返回第二个我们设置的参数,当我们没设置时,默认返回None

示例:使用kNN改进约会网站的配对效果

2.png

准备数据:从文本文件中解析数据

from numpy import *

def file2matrix(filename):
    fr = open(filename)
    arrarOLines = fr.readlines()
    numberOfLines = len(arrarOLines)
    returnMat = zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrarOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        # 将数据的前三行直接存入特征矩阵
        returnMat[index,:] = listFromLine[0:3]
        # 将字符串映射成数字
        if listFromLine[-1] == 'didntLike':
            classLabelVector.append(1)
        elif listFromLine[-1] == 'smallDoses':
            classLabelVector.append(2)
        elif listFromLine[-1] == 'largeDoses':
            classLabelVector.append(3)
        index += 1
    return returnMat, classLabelVector

分析数据

from numpy import *
# kNN排序时将使用这个模块提供好的函数
import operator
import matplotlib
import matplotlib.pyplot as plt

def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis = 1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    sortedClassCount = sorted(classCount.items(),
                              key = operator.itemgetter(1), reverse= True)
    return sortedClassCount[0][0]

# [group, labels] = createDataSet()
# m = classify0([0, 0], group, labels, 2)
# print(m)


def file2matrix(filename):
    fr = open(filename)
    arrarOLines = fr.readlines()
    numberOfLines = len(arrarOLines)
    returnMat = zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrarOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        # 将数据的前三行直接存入特征矩阵
        returnMat[index,:] = listFromLine[0:3]
        # 将字符串映射成数字
        if listFromLine[-1] == 'didntLike':
            classLabelVector.append(1)
        elif listFromLine[-1] == 'smallDoses':
            classLabelVector.append(2)
        elif listFromLine[-1] == 'largeDoses':
            classLabelVector.append(3)
        index += 1
    return returnMat, classLabelVector


datingDataMat, datingLabels = file2matrix('C:/Users/cxy/OneDrive/桌面/datingTestSet.txt')

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
plt.show()

结果截图:
3.png

add_subplot(x)中参数的含义:
这里前两个表示几*几的网格,最后一个表示第几子图
可能说的有点绕口,下面上程序作图一看说明就明白

import matplotlib.pyplot as plt
fig = plt.figure(figsize = (5,5)) 
ax = fig.add_subplot(221)
ax = fig.add_subplot(222)
ax = fig.add_subplot(223)
ax = fig.add_subplot(224)

4.png

scatter()方法
matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs)
x,y:长度相同的数组,也就是我们即将绘制散点图的数据点,输入数据。
s:点的大小,默认 20,也可以是个数组,数组每个参数为对应点的大小。
c:点的颜色,默认蓝色 ‘b’,也可以是个 RGB 或 RGBA 二维行数组。
marker:点的样式,默认小圆圈 ‘o’。
cmap:Colormap,默认 None,标量或者是一个 colormap 的名字,只有 c 是一个浮点数数组的时才使用。如果没有申明就是 image.cmap。
norm:Normalize,默认 None,数据亮度在 0-1 之间,只有 c 是一个浮点数的数组的时才使用。
vmin,vmax::亮度设置,在 norm 参数存在时会忽略。
alpha::透明度设置,0-1 之间,默认 None,即不透明。
linewidths::标记点的长度。
edgecolors::颜色或颜色序列,默认为 ‘face’,可选值有 ‘face’, ‘none’, None。
plotnonfinite::布尔值,设置是否使用非限定的 c ( inf, -inf 或 nan) 绘制点。
**kwargs::其他参数。
我们主要用到的是前四个参数,第一个参数是我们要画散点图的横坐标,第二个是纵坐标,第三个散点图中点的颜色,第四个散点图中点的大小

准备数据:归一化数值

def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m, 1))
    normDataSet = normDataSet / tile(ranges, (m, 1))
    return normDataSet, ranges, minVals

normMat, ranges, minVals = autoNorm(datingDataMat)
print(normMat)

min()、max()方法
minVals = dataSet.min(0) 返回dataSet中每一列中的最小值数组
minVals = dataSet.min(1) 返回dataSet中每一行中的最小值数组

测试算法:作为完整程序验证分类器

def datingClassTest():
    hoRatio = 0.10
    datingDataMat, datingLabels = file2matrix('C:/Users/cxy/OneDrive/桌面/datingTestSet.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
        print(f"the classifier came back with: {classifierResult}, the real answer is : {datingLabels[i]}")
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print(f"the total error rate is : {errorCount / float(numTestVecs)}")

datingClassTest();

手写识别系统

from numpy import *
# kNN排序时将使用这个模块提供好的函数
import operator
import matplotlib
import matplotlib.pyplot as plt

def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis = 1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    sortedClassCount = sorted(classCount.items(),
                              key = operator.itemgetter(1), reverse= True)
    return sortedClassCount[0][0]

# [group, labels] = createDataSet()
# m = classify0([0, 0], group, labels, 2)
# print(m)


def file2matrix(filename):
    fr = open(filename)
    arrarOLines = fr.readlines()
    numberOfLines = len(arrarOLines)
    returnMat = zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrarOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        # 将数据的前三行直接存入特征矩阵
        returnMat[index,:] = listFromLine[0:3]
        # 将字符串映射成数字
        if listFromLine[-1] == 'didntLike':
            classLabelVector.append(1)
        elif listFromLine[-1] == 'smallDoses':
            classLabelVector.append(2)
        elif listFromLine[-1] == 'largeDoses':
            classLabelVector.append(3)
        index += 1
    return returnMat, classLabelVector


# datingDataMat, datingLabels = file2matrix('C:/Users/cxy/OneDrive/桌面/datingTestSet.txt')
# print(datingDataMat)
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
# plt.show()

def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m, 1))
    normDataSet = normDataSet / tile(ranges, (m, 1))
    return normDataSet, ranges, minVals

# normMat, ranges, minVals = autoNorm(datingDataMat)
# print(normMat)

# def datingClassTest():
#     hoRatio = 0.10
#     datingDataMat, datingLabels = file2matrix('C:/Users/cxy/OneDrive/桌面/datingTestSet.txt')
#     normMat, ranges, minVals = autoNorm(datingDataMat)
#     m = normMat.shape[0]
#     numTestVecs = int(m*hoRatio)
#     errorCount = 0.0
#     for i in range(numTestVecs):
#         classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
#         print(f"the classifier came back with: {classifierResult}, the real answer is : {datingLabels[i]}")
#         if classifierResult != datingLabels[i]:
#             errorCount += 1.0
#     print(f"the total error rate is : {errorCount / float(numTestVecs)}")
#
# datingClassTest();

def classifyPerson():
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix('C:/Users/cxy/OneDrive/桌面/datingTestSet.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    inArr = array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)
    print(f"You will probably like this person: {resultList[classifierResult - 1]}")

classifyPerson()

5.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/64891.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Go语言并发编程(千锋教育)

Go语言并发编程(千锋教育) 视频地址:https://www.bilibili.com/video/BV1t541147Bc?p14 作者B站:https://space.bilibili.com/353694001 源代码:https://github.com/rubyhan1314/go_goroutine 1、基本概念 1.1、…

鉴源论坛·观擎丨浅谈操作系统的适航符合性(上)

作者 | 蔡喁 上海控安可信软件创新研究院副院长 版块 | 鉴源论坛 观擎 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 01 源头和现状​​​​​​​ 在越来越多的国产机载系统研制中,操作系统软件的选择对后续开展研制以及适航举证活动带来…

码云 Gitee + Jenkins 配置教程

安装jdk 安装maven 安装Jenkins https://blog.csdn.net/minihuabei/article/details/132151292?csdn_share_tail%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22132151292%22%2C%22source%22%3A%22minihuabei%22%7D 插件安装 前往 Manage Jen…

基于Windows手动编译openssl和直接安装openssl

零、环境 win10-64位 VS2019 一、手动编译 前言:对于一般的开发人员而言,在 openssl 上下载已经编译好的 openssl 库,然后直接拿去用即可,,不用手动编译,{见下文直接安装}。。。对于一些开发人员&#…

Jmeter录制HTTPS脚本

Jmeter录制HTTPS脚本 文章目录 添加“HTTP代理服务器”设置浏览器代理证书导入存在问题 添加“HTTP代理服务器” 设置浏览器代理 保持端口一致 证书导入 点击一下启动让jmeter自动生成证书,放在bin目录下: 打开jmeter的SSL管理器选择刚刚生成的证书&…

# 关于Linux下的parted分区工具显示起始点为1049kB的问题解释

关于Linux下的parted分区工具显示起始点为1049kB的问题解释 文章目录 关于Linux下的parted分区工具显示起始点为1049kB的问题解释1 问题展示:2 原因3 修改为KiB方式显示4 最后 1 问题展示: kevinTM1701-b38cbc23:~$ sudo parted /dev/nvme1n1 GNU Part…

SAP 开发编辑界面-关闭助手

打开关闭助手时的开发界面如下: 关闭关闭助手后的界面如下: 菜单栏: 编辑--》修改操作--》关闭助手

会这个Python的测试员,工作都不会太差!

Python语言得天独厚的优势使之在业界的火热程度有增无减,尤其是在经历了互联网,物联网,云计算,大数据,人工智能等浪潮的推动下,其关注度,普适度一路走高。 对于测试人员来说,很多人…

【CSS】旋转中的视差效果

效果 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"/><meta http-equiv"X-UA-Compatible" content"IEedge"/><meta name"viewport" content"widthdevice-…

RK3568 lunch新增设备

rk3568 android9.0 &#xff0c;32位平台 1.进入devices/rockchip/rk356x/ 将rk3568_box_32 拷贝一份&#xff0c;命名为hdx6 2.打开vendorsetup.sh,添加lunch选项 add_lunch_combo hdx6-user add_lunch_combo hdx6-userdebug 3.进入hdx6&#xff0c;修改rk3568_box_32.mk…

Linux root用户执行修改密码命令,提示 Permission denied

问题 linux系统中&#xff08;ubuntu20&#xff09;&#xff0c;root用户下执行passwd命令&#xff0c;提示 passwd: Permission denied &#xff0c;如下图&#xff1a; 排查 1.执行 ll /usr/bin/passwd &#xff0c;查看文件权限是否正确&#xff0c;正常情况是 -rwsr-xr…

阿里云二级域名配置

阿里云二级域名配置 首先需要进入阿里云控制台的域名管理 1.选择域名点击解析 2.添加记录 3.选择A类型 4.主机记录设置【可以aa.bb或者aa.bb.cc】 到时候会变成&#xff1a;aa.bb.***.com 5.解析请求来源设置为默认 6.记录值 设置为要解析的服务器的ip地址 7.TTL 默认即…

MyCat水平分表

1.水平拆分案例场景 2.MyCat配置 这个表只是在 schema.xml配置的逻辑表&#xff0c;在具体的数据库里面是没有的 根据id的模确定数据存在哪个节点上&#xff01;&#xff01;

基于图像形态学处理的目标几何形状检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 .................................................... %二进制化图像 Images_bin imbinari…

7个月的测试经验,来面试居然开口要18K,我一问连5K都不值...

2021年8月份我入职了深圳某家创业公司&#xff0c;刚入职还是很兴奋的&#xff0c;到公司一看我傻了&#xff0c;公司除了我一个测试&#xff0c;公司的开发人员就只有3个前端2个后端还有2个UI&#xff0c;在粗略了解公司的业务后才发现是一个从零开始的项目&#xff0c;目前啥…

前端技术基础-css

前端技术基础-css【了解】 一、css理解 概念&#xff1a;CSS&#xff1a;C(cascade) SS(StyleSheet) &#xff0c;级联样式表。作用&#xff1a;对网页提供丰富的视觉效果&#xff0c;进行美化页面(需要在html页面基础上)样式规则&#xff1a;样式1&#xff1a;值1;样式2&…

项目中使用git vscode GitHubDesktopSetup-x64

一、使用git bash 1.使用git bash拉取gitee项目 1.在本地新建一个文件夹&#xff08;这个文件夹是用来存放从gitee上拉下来的项目的&#xff09; 2.在这个文件夹右键选择 git bash here 3.输入命令 git init (创建/初始化一个新的仓库) 4.输入命令 git remote add origin …

51单片机程序烧录教程

STC烧录步骤 &#xff08;1&#xff09;STC单片机烧录方式采用串口进行烧录程序&#xff0c;连接的方式如下图&#xff1a; &#xff08;2&#xff09;所以需要先确保USB转串口驱动是识别到&#xff0c;且驱动运行正常&#xff1b;是否可通过电脑的设备管理器查看驱动是否正常…

AVS3:跨多通道预测PMC

前面的文章中介绍了TSCPM&#xff0c;它是AVS3中用于intra模式的跨通道预测技术&#xff0c;它利用线性模型根据亮度重建像素预测色度像素&#xff0c; 跨通道预测技术用于去除不同通道间的冗余信息&#xff0c;TSCPM可以去除Y-Cb、Y-Cr通道间的冗余&#xff0c;然而却忽略了…

鉴源实验室丨SOME/IP协议安全攻击

作者 | 张昊晖 上海控安可信软件创新研究院工控网络安全组 来源 | 鉴源实验室 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 01 引 言 随着汽车行业对于数据通信的需求不断增加&#xff0c;SOME/IP作为支持汽车以太网进程和设备间通信的一种通信协议应…