手写kNN算法的实现-用余弦相似度来度量距离

设a为预测点,b为其中一个样本点,在向量空间里,它们的形成的夹角为θ,那么θ越小(cosθ的值越接近1),就说明a点越接近b点。所以我们可以通过考察余弦相似度来预测a点的类型。

在这里插入图片描述

在这里插入图片描述

from collections import Counter
import numpy as np

class MyKnn:
    def __init__(self,neighbors):
        self.k = neighbors
    
    def fit(self,X,Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        if self.X.ndim != 2 or self.Y.ndim != 1:
            raise Exception("dimensions are wrong!")
        
        if self.X.shape[0] != self.Y.shape[0]:
            raise Exception("input labels are not correct!")
    
    def predict(self,X_pre):
        
        pre = np.array(X_pre)
        if self.X.ndim != pre.ndim:
            raise Exception("input dimensions are wrong!")
        rs = []
        for p in pre:
            temp = []
            for a in self.X:
                cos = (p @ a)/np.linalg.norm(p)/np.linalg.norm(a)
                temp.append(cos)
            temp = np.array(temp)
            indices = np.argsort(temp)[:-self.k-1:-1]
            ss = np.take(self.Y,indices)
            found = Counter(ss).most_common(1)[0][0]
            print(found)
            rs.append(found)
        return np.array(rs)
        

测试:

# 用鸢尾花数据集来验证我们上面写的算法
from sklearn.datasets import load_iris
# 使用train_test_split对数据集进行拆分,一部分用于训练,一部分用于测试验证
from sklearn.model_selection import train_test_split
# 1.生成一个kNN模型
myknn = MyKnn(5)
# 2.准备数据集:特征集X_train和标签集y_train
X_train,y_train = load_iris(return_X_y=True)
# 留出30%的数据集用于验证测试
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,test_size=0.3)
# 3.训练模型
myknn.fit(X_train,y_train)
# 4.预测,acc就是预测结果
acc = myknn.predict(X_test)
# 计算准确率
(acc == y_test).mean()

其实如果余弦相似度来进行分类,那么根据文章最开头讲到的,其实取余弦值最大的点作为预测类型也可以:

import numpy as np

class MyClassicfication:
    
    def fit(self,X,Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        if self.X.ndim != 2 or self.Y.ndim != 1:
            raise Exception("dimensions are wrong!")
        
        if self.X.shape[0] != self.Y.shape[0]:
            raise Exception("input labels are not correct!")
    
    def predict(self,X_pre):
        
        pre = np.array(X_pre)
        if self.X.ndim != pre.ndim:
            raise Exception("input dimensions are wrong!")
        rs = []
        for p in pre:
            temp = []
            for a in self.X:
                cos = (p @ a)/np.linalg.norm(p)/np.linalg.norm(a)
                temp.append(cos)
            temp = np.array(temp)
            index = np.argsort(temp)[-1]
            found = np.take(self.Y,index)
            rs.append(found)
        return np.array(rs)
        

测试:

# 用鸢尾花数据集来验证我们上面写的算法
from sklearn.datasets import load_iris
# 使用train_test_split对数据集进行拆分,一部分用于训练,一部分用于测试验证
from sklearn.model_selection import train_test_split
# 1.生成一个kNN模型
myCla = MyClassicfication
# 2.准备数据集:特征集X_train和标签集y_train
X_train,y_train = load_iris(return_X_y=True)
# 留出30%的数据集用于验证测试
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,test_size=0.3)
# 3.训练模型
myCla.fit(X_train,y_train)
# 4.预测,acc就是预测结果
acc = myCla.predict(X_test)
# 计算准确率
(acc == y_test).mean()

经测试,上面两种方式的准确率是差不多的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/694218.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

群体优化算法----树蛙优化算法介绍以及应用于资源分配示例

介绍 树蛙优化算法(Tree Frog Optimization Algorithm, TFO)是一种基于群体智能的优化算法,模拟了树蛙在自然环境中的跳跃和觅食行为。该算法通过模拟树蛙在树枝间的跳跃来寻找最优解,属于近年来发展起来的自然启发式算法的一种 …

k8s挂载配置文件(通过ConfigMap方式)

一、ConfigMap简介 K8s中的ConfigMap是一种用于存储配置数据的API对象,属于Kubernetes中的核心对象。它用于将应用程序的配置信息与容器镜像分离,以便在不重新构建镜像的情况下进行配置的修改和更新。ConfigMap可以存储键值对、文本文件或者以特定格式组…

前后端分离

简要介绍 【【2020版】4小时学会Spring BootVue前后端分离开发】https://www.bilibili.com/video/BV137411B7vB?vd_source63e3491e19d2e508b4778a57ebd65ccf 传统模式 前后端分离 前端:负责数据应用和用户交互 后端:负责数据处理接口 前端HTML--&g…

【内网攻防实战】红日靶场(一)续篇_金票与银票

红日靶场(一)续篇_权限维持 前情提要当前位置执行目标 PsExec.exe拿下域控2008rdesktop 远程登录win7msf上传文件kail回连马连上win7upload上传PsExec.exe PsExec.exe把win7 带到 2008(域控hostname:owa)2008开远程、关防火墙Win7…

GPT-4欺骗人类的惊人成功率达99.16%!

PNAS重磅研究揭示,LLM推理能力越强欺骗率越高!! 此前,MIT的研究发现,AI在各类游戏中为了达到目的,不择手段,学会用佯装和歪曲偏好等方式欺骗人类。 GPT-4o深夜发布!Plus免费可用&…

五个避免的管理错误:提升团队绩效与发展

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

【Java SE】字符串常量池详解,什么情况下字符串String对象存在常量池,通过==进行判断,字符串创建及截取后是否同一个对象

复习字符串创建方式 字符串的31种构造方法 public String();创建一个空白字符串, 不含有任何内容public String(char[] array);根据字符数组的内容,来创建对应的字符串public String(byte[] array);根据字节数组的内筒,来创建对应的字符串 …

Win10下CodeBlock实现socket TCP server/client

文章目录 1 安装codeblock2 适配libws2_32.a库3 TCP socket工作原理4 代码实现服务端客户端5 运行效果1 安装codeblock 官方免费下载 值得一提的是,安装时,指定安装路径,其他默认安装即可 2 适配libws2_32.a库 默认安装,只有3个库,如果编译socket,需要专门的库libws2…

Maven项目的创建

目录 1、Maven简介配置(1)设置本地仓库(2)修改Maven的jdk版本(3)添加国内镜像源添加到idea中 2、常用命令3、IDEA2023创建Maven项目(1)Maven和Maven Archetype区别(1-1&a…

L48---1637. 两点之间不包含任何点的最宽垂直区域(排序)---Java版

1.题目描述 2.思路 (1)返回两点之间内部不包含任何点的 最宽垂直区域 的宽度。 我的理解是相邻两个点,按照等差数列那样,后一个数减去相邻的前一个数,才能保证两数之间不含其他数字。 (2)所以&…

OmniGlue: Generalizable Feature Matching with Foundation Model Guidance

【引用格式】:Jiang H, Karpur A, Cao B, et al. OmniGlue: Generalizable Feature Matching with Foundation Model Guidance[J]. arXiv preprint arXiv:2405.12979, 2024. 【网址】:https://arxiv.org/pdf/2405.12979 【开源代码】:https…

[ue5]建模场景学习笔记(5)——必修内容可交互的地形,交互沙(3)

1.需求分析: 我们现在已经能够让这片地形出现在任意地方,只要角色走在这片地形上,就能够产生痕迹,但这片区域总是需要人工指定,又无法把这片区域无限扩大(显存爆炸),因此尝试使角色无…

【数据结构】十二、八种常用的排序算法讲解及代码分享

目录 一、插入排序 1)算法思想 2)代码 二、希尔排序 1)算法思想 2)代码 三、选择排序 1)算法思想 2)代码 四、堆排序 1)什么是最大堆 2)如何创建最大堆 3)算法思想 4&a…

电脑回收站清空了怎么恢复回来?分享四个好用数据恢复方法

电脑回收站清空了还能恢复回来吗?在使用电脑过程中,很多小伙伴都不重视电脑的回收站,,有用的没用的文件都往里堆积。等空间不够的时候就去一股脑清空回收站。可有时候会发现自己还需要的文件在回收站里,可回收站已经被清空了……那…

单灯双控开关原理

什么是单灯双控?顾名思义,指的是一个灯具可以通过两个不同的开关或控制器进行控制。 例如客厅的主灯可能会设置成单灯双控,一个开关位于门口,另一个位于房间内的另一侧,这样无论你是从门口进入还是从房间内出来&#x…

Meta首席AI科学家Yann LeCun指出生成式AI的不足

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

sqli-labs 靶场 less-11~14 第十一关、第十二关、第十三关、第十四关详解:联合注入、错误注入

SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它,我们可以学习如何识别和利用不同类型的SQL注入漏洞,并了解如何修复和防范这些漏洞。Less 11 SQLI DUMB SERIES-11判断注入点 尝试在用户名这个字段实施注入,且试出SQL语句闭合方式为单…

插卡式仪器模块:数字万用表模块(插卡式)

• 6 位数字表显示 • 24 位分辨率 • 250 KSPS 采样率 • 电源和数字 I/O 均采用隔离抗噪技术 • 电压、电流、电阻、电感、电容的高精度测量 • 二极管/三极管测试 通道122输入 阻抗 电压10 MΩHigh-Z, 10 MΩ电流10 Ω50 mΩ / 2 Ω / 2 KΩ输入范围电压 5 V0–60 V电流…

Java桥接模式

桥接模式 最重要的是 将 抽象 与 实现 解耦 , 通过组合 在 抽象 与 实现 之间搭建桥梁 ; 【设计模式】桥接模式 ( 简介 | 适用场景 | 优缺点 | 代码示例 )-CSDN博客 桥接模式(Bridge Pattern)-(最通俗易懂的案例)_桥接模式 例子-…

SpringAI(二)

大模型:具有大规模参数和复杂计算结构的机器学习模型.通常由深度神经网络构建而成,拥有数十亿甚至数千亿个参数.其设计目的在于提高模型的表达能力和预测性能,应对复杂的任务和数据. SpringAI是一个AI工程领域的应用程序框架 大概推出时间是2023年7月份(不确定) 目的是将S…