【MindSpore学习打卡】应用实践-热门LLM及其他AI应用-使用MindSpore实现K近邻算法对红酒数据集进行聚类分析

在机器学习领域,K近邻算法(K-Nearest Neighbor, KNN)是最基础且常用的算法之一。无论是分类任务还是回归任务,KNN都能通过简单直观的方式实现高效的预测。在这篇博客中,我们将基于MindSpore框架,使用KNN算法对Wine数据集进行聚类实验。通过这个实验,我们不仅能够深入了解KNN算法的基本概念和原理,还能学习如何在MindSpore中实现和应用该算法。

K近邻算法原理

基本概念

K近邻算法(K-Nearest-Neighbor, KNN)是一种用于分类和回归的非参数统计方法。其核心思想是:要确定一个样本的类别,可以计算它与所有训练样本的距离,然后找出和该样本最接近的k个样本,统计这些样本的类别并进行投票,票数最多的那个类就是分类的结果。

KNN的三个基本要素

  1. K值:K值越小,模型越容易受噪声影响;K值越大,类别之间的界限会变得模糊。
  2. 距离度量:常用的有欧氏距离、曼哈顿距离、海明距离等。
  3. 分类决策规则:通常是多数表决,或者基于距离加权的多数表决。

分类问题与回归问题

  • 分类问题:找到距离待测样本最近的k个样本,统计每一类样本的个数,最终的分类结果为出现次数最多的那个类。
  • 回归问题:对样本的回归预测输出值为所有邻居标签的均值,或者带权重的均值。

距离的定义

KNN算法的实现依赖于样本之间的距离,最常用的距离函数是欧氏距离。公式如下:

[ d(x,y) = \sqrt{\sum_{i=1}{n}{(x_{i}-y_{i})2}} ]

为了减少特征值尺度范围不同带来的干扰,使用欧氏距离时应将特征向量的每个分量归一化。

数据处理

数据准备

Wine数据集是模式识别中最著名的数据集之一,包含了对来自意大利同一地区但来自三个不同品种的葡萄酒进行化学分析的结果。数据集分析了三种葡萄酒中每种所含13种成分的量。

数据读取与处理

首先,我们需要导入所需的Python库,并读取Wine数据集。

解释:首先,我们读取Wine数据集并将其转换为NumPy数组。X包含了所有样本的13个属性,而Y则包含了样本的类别标签。这样做的目的是将数据转换为易于处理的格式,以便后续的模型训练和预测。

import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import nn, ops

ms.set_context(device_target="CPU")

# 读取数据
with open('wine.data') as csv_file:
    data = list(csv.reader(csv_file, delimiter=','))
print(data[56:62]+data[130:133])

# 处理数据
X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
Y = np.array([s[0] for s in data[:178]], np.int32)

在这里插入图片描述

数据可视化

我们可以取样本的某两个属性进行二维可视化,观察样本的分布情况及可分性。

解释:我们通过取样本的某两个属性进行二维可视化,观察样本的分布情况及可分性。这样做的目的是帮助我们直观地理解数据的分布情况,从而更好地选择特征和参数。

attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols',
         'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue',
         'OD280/OD315 of diluted wines', 'Proline']
plt.figure(figsize=(10, 8))
for i in range(0, 4):
    plt.subplot(2, 2, i+1)
    a1, a2 = 2 * i, 2 * i + 1
    plt.scatter(X[:59, a1], X[:59, a2], label='1')
    plt.scatter(X[59:130, a1], X[59:130, a2], label='2')
    plt.scatter(X[130:, a1], X[130:, a2], label='3')
    plt.xlabel(attrs[a1])
    plt.ylabel(attrs[a2])
    plt.legend()
plt.show()

在这里插入图片描述

数据集划分

将数据集按128:50划分为训练集(已知类别样本)和验证集(待验证样本)。

解释:我们将数据集按128:50划分为训练集和验证集。训练集用于模型的训练,验证集用于评估模型的性能。这样做的目的是确保模型在未见过的数据上也能表现良好,避免过拟合。

train_idx = np.random.choice(178, 128, replace=False)
test_idx = np.array(list(set(range(178)) - set(train_idx)))
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]

模型构建

解释:我们利用MindSpore提供的算子,通过矩阵运算的方式计算输入样本和训练样本的距离,并找出top k近邻。这样做的目的是提高计算效率,通过批量操作加速距离计算。

class KnnNet(nn.Cell):
    def __init__(self, k):
        super(KnnNet, self).__init__()
        self.k = k

    def construct(self, x, X_train):
        x_tile = ops.tile(x, (128, 1))
        square_diff = ops.square(x_tile - X_train)
        square_dist = ops.sum(square_diff, 1)
        dist = ops.sqrt(square_dist)
        values, indices = ops.topk(-dist, self.k)
        return indices

def knn(knn_net, x, X_train, Y_train):
    x, X_train = ms.Tensor(x), ms.Tensor(X_train)
    indices = knn_net(x, X_train)
    topk_cls = [0]*len(indices.asnumpy())
    for idx in indices.asnumpy():
        topk_cls[Y_train[idx]] += 1
    cls = np.argmax(topk_cls)
    return cls

模型预测

解释:我们在验证集上验证KNN算法的有效性,取k=5,计算验证集上的精度。这样做的目的是评估模型的性能,确保其在实际应用中的有效性。

在验证集上验证KNN算法的有效性,取 k = 5 k = 5 k=5,验证精度接近80%。

acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):
    pred = knn(knn_net, x, X_train, Y_train)
    acc += (pred == y)
    print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/787930.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

for nested data item, row-key is required.报错解决

今天差点被一个不起眼的bug搞到吐,就是在给表格设置row-key的时候,一直设置不成功,一直报错缺少row-key,一共就那两行代码 实在是找不到还存在什么问题... 先看下报错截图... 看下代码 我在展开行里面用到了一个表格 并且存放表格…

【算法】代码随想录之数组(更新中)

文章目录 前言 一、二分查找法(LeetCode--704) 二、移除元素(LeetCode--27) 前言 跟随代码随想录,学习数组相关的算法题目,记录学习过程中的tips。 一、二分查找法(LeetCode--704&#xff0…

WEB安全基础:网络安全常用术语

一、攻击类别 漏洞:硬件、软件、协议,代码层次的缺陷。 后⻔:方便后续进行系统留下的隐蔽后⻔程序。 病毒:一种可以自我复制并传播,感染计算机和网络系统的恶意软件(Malware),它能损害数据、系统功能或拦…

实战 | YOLOv8使用TensorRT加速推理教程(步骤 + 代码)

导 读 本文主要介绍如何使用TensorRT加速YOLOv8模型推理的详细步骤与演示。 YOLOv8推理加速的方法有哪些? YOLOv8模型推理加速可以通过多种技术和方法实现,下面是一些主要的策略: 1. 模型结构优化 网络剪枝:移除模型中不重要的神经元或连接,减少模型复杂度。 模型精…

大模型lora微调中,rank参数代表什么,怎么选择合适的rank参数

在大模型的LoRA(Low-Rank Adaptation)微调中,rank参数(秩)是一个关键的超参数,它决定了微调过程中引入的低秩矩阵的维度。具体来说,rank参数r表示将原始权重矩阵分解成两个低秩矩阵的维度&#…

突破传统,实时语音技术的革命。Livekit 开源代理框架来袭

🚀 突破传统,实时语音技术的革命!Livekit 开源代理框架来袭! 在数字化时代,实时通信已成为我们日常生活的一部分。但你是否曾想象过,一个能够轻松处理音视频流的代理框架,会如何改变我们的沟通方式?今天,我们就来一探究竟! 🌟 什么是 Livekit 代理框架? Live…

从零开始搭建互联网医院系统:技术与案例解析

随着信息技术的飞速发展和人们对医疗服务需求的增加,互联网医院逐渐成为医疗服务的重要模式。本文将详细介绍从零开始搭建互联网医院系统的关键技术和具体案例,帮助读者理解如何构建一个高效、可靠的互联网医院系统。 一、互联网医院系统的核心技术 1…

ESLint: Delete `␍`(prettier/prettier)解决问题补充

如果你是克隆的,参考这位大佬的文章 vue.js - Delete ␍eslint(prettier/prettier) 错误的解决方案 - 个人文章 - SegmentFault 思否 如果你是个人在本地实现,且改为 仍旧报错,我解决的方案: 改为,同时勾选和我配置一…

Error:sql: expected 1 arguments, got 2

一 背景 在测试一个API接口时,看到日志里面突然抛出一个错误:Error:sql: expected 1 arguments, got 2 看了下,对应的表里面是有相关数据的,sql语句放在mysql里面执行也是没问题!那奇了怪了,为啥会产生这样…

TensorFlow系列:第二讲:准备工作

1.创建项目,选择虚拟环境 项目结构如下: data中的数据集需要提前准备好,数据分为测试集,训练集和验证集。以下是数据集的下载平台:kaggle 2.随便选择一个和水果相关的数据集,下载到本地,导入的项…

ARM裸机:一步步点亮LED(汇编)

硬件工作原理及原理图查阅 LED物理特性介绍 LED本身有2个接线点,一个是LED的正极,一个是LED的负极。LED这个硬件的功能就是点亮或者不亮,物理上想要点亮一颗LED只需要给他的正负极上加正电压即可,要熄灭一颗LED只需要去掉电压即可…

字节码编程javassist之生成带有注解的类

写在前面 本文看下如何使用javassist生成带有注解的类。 1:程序 测试类 package com.dahuyou.javassist.huohuo.cc;import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import ja…

JVM原理(二四):JVM虚拟机锁优化

高效并发是从JDK 5升级到JDK 6后一项重要的改进项,HotSpot虛 拟机开发团队在这个版本上花费了大量的资源去实现各种锁优化技术,如适应性自旋( Adaptive Spinning)、锁消除( Lock Elimination)、锁膨胀(Lock Coarsening)、轻量级锁(Lightweight Locking)、…

了解PPO算法(Proximal Policy Optimization)

Proximal Policy Optimization (PPO) 是一种强化学习算法,由 OpenAI 提出,旨在解决传统策略梯度方法中策略更新过大的问题。PPO 通过引入限制策略更新范围的机制,在保证收敛性的同时提高了算法的稳定性和效率。 PPO算法原理 PPO 算法的核心…

LAMP万字详解(概念、构建步骤)

目录 LAMP Apache 起源 主要特点 软件版本 编译安装httpd服务器 编译安装的优点 操作步骤 准备工作 编译 安装 优化执行路径 添加服务 守护进程 配置httpd 查看 Web 站点的访问情况 虚拟主机 类型 部署基于域名的虚拟主机 为虚拟主机提供域名解析&#xff…

ESP32的I2S引脚及支持的音频标准使用说明

ESP32 I2S 接口 ESP32 有 2 个标准 I2S 接口。这 2 个接口可以以主机或从机模式,在全双工或半双工模式下工作,并且可被配置为 8/16/32/48/64-bit 的输入输出通道,支持频率从 10 kHz 到 40 MHz 的 BCK 时钟。当 1 个或 2 个 被配置为主机模式…

db期末复习自用[应试向 附习题]

第一章 数据库系统实现整体数据的结构化,主要特征之一,是db区别于文件系统的本质区别。 数据库系统三个阶段:人工、文件、数据库系统。 数据库管理系统的功能:数据库定义、操纵 、(保护、存储、维护)、数…

大模型/NLP/算法面试题总结2——transformer流程//多头//clip//对比学习//对比学习损失函数

用语言介绍一下Transformer的整体流程 1. 输入嵌入(Input Embedding) 输入序列(如句子中的单词)首先通过嵌入层转化为高维度的向量表示。嵌入层的输出是一个矩阵,每一行对应一个输入单词的嵌入向量。 2. 位置编码&…

020-GeoGebra中级篇-几何对象之点与向量

本文概述了在GeoGebra中如何使用笛卡尔或极坐标系输入点和向量。用户可以通过指令栏输入数字和角度,使用工具或指令创建点和向量。在笛卡尔坐标系中,示例如“P(1,0)”;在极坐标系中,示例如“P(1;0)”或“v(5;90)”。文章还介绍了点…

SpringBoot + MyBatisPlus 实现多租户分库

一、引言 在如今的软件开发中,多租户(Multi-Tenancy)应用已经变得越来越常见。多租户是一种软件架构技术,它允许一个应用程序实例为多个租户提供服务。每个租户都有自己的数据和配置,但应用程序实例是共享的。而在我们的Spring Boot MyBati…