机器学习 - multi-class 数据集训练 (含代码)

直接上代码

# Multi-class dataset

import numpy as np
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K, D))
y = np.zeros(N*K, dtype='uint8')
for j in range(K):
  ix = range(N*j, N*(j+1))
  r = np.linspace(0.0, 1, N)  # radius
  t = np.linspace(j*4, (j+1)*4, N) + np.random.randn(N)*0.2
  X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  y[ix] = j
plt.scatter(X[:, 0], X[:, 1], c = y, s = 40, cmap=plt.cm.RdYlBu)
plt.show()

# Turn data into tensors
X = torch.from_numpy(X).type(torch.float)
y = torch.from_numpy(y).type(torch.LongTensor)

# Create train and test splits
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = RANDOM_SEED)

#!pip -q install torchmetrics
from torchmetrics import Accuracy
acc_fn = Accuracy(task = "multiclass", num_classes = 3).to("cpu")

class SpiralModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(in_features = 2, out_features = 10)
    self.linear2 = nn.Linear(in_features = 10, out_features = 10)
    self.linear3 = nn.Linear(in_features = 10, out_features = 3)
    self.relu = nn.ReLU()

  def forward(self, x):
    return self.linear3(self.relu(self.linear2(self.relu(self.linear1(x)))))

model_1 = SpiralModel().to("cpu")

# Setup data to be device agnostic
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")

# Setup Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_1.parameters(),
                             lr = 0.02)

# Build a training loop for the model
epochs = 1000

# Loop over data
for epoch in range(epochs):
  ## Training
  model_1.train()

  y_logits = model_1(X_train)
  y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)

  loss = loss_fn(y_logits, y_train)
  acc = acc_fn(y_pred, y_train)

  optimizer.zero_grad()

  loss.backward()

  optimizer.step()

  ## Training
  model_1.eval()
  with torch.inference_mode():
    test_logits = model_1(X_test)
    test_pred = torch.softmax(test_logits, dim=1).argmax(dim=1)

    test_loss = loss_fn(test_logits, y_test)
    test_acc = acc_fn(test_pred, y_test)
  
  if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss:.2f} Acc: {acc:.2f} | Test loss: {test_loss:.2f} Test acc: {test_acc:.2f}")

# Plot decision boundaries for training and test sets 
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_1, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_1, X_test, y_test)

结果如下

Epoch: 0 | Loss: 1.13 Acc: 0.32 | Test loss: 1.12 Test acc: 0.37
Epoch: 100 | Loss: 0.09 Acc: 0.98 | Test loss: 0.06 Test acc: 1.00
Epoch: 200 | Loss: 0.04 Acc: 0.99 | Test loss: 0.01 Test acc: 1.00
Epoch: 300 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 400 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 500 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 600 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 700 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 800 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 900 | Loss: 0.01 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00

结果1
结果2

给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/520765.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

多线程的入门(二)线程实现与初步使用

1.实现Runable接口 实现Runable接口,实现run方法; 这种方式创建的线程实现类执行时需要创建Thread实例去运行该任务 示例如下: package com.example.springbootdamo.Thread;import org.apache.logging.log4j.LogManager; import org.apach…

三子棋游戏----C语言版【超级详细 + 视频演示 + 完整源码】

㊙️小明博客主页:➡️ 敲键盘的小明 ㊙️ ✅关注小明了解更多知识☝️ 文章目录 前言一、三子棋的实现思路二、三子棋的实现步骤2.1 先显示游戏的菜单2.2 游戏的具体实现2.2.1 棋盘的初始化2.2.2 展示棋盘2.2.3 下棋🔴玩家下棋🔴电脑下棋2.2…

二叉树进阶——手撕二叉搜索树

troop主页:troop 手撕二叉搜索树 1.二叉搜索树的定义2.实现(非递归)补充结构2.1查找2.2插入2.3删除(重要)情况1(无孩子&&一个孩子) 3.二叉搜索树的应用3.1K模型3.2KV模型3.2.1KV模型的实现 总结二叉…

RUST语言值所有权之内存复制与移动

1.RUST中每个值都有一个所有者,每次只能有一个所有者 String::from函数会为字符串hello分配一块内存 内存示例如下: 在内存分配前调用s1正常输出 在分配s1给s2后调用报错 因为s1分配给s2后,s1的指向自动失效 s1被move到s2 s1自动释放 字符串克隆使用

I2C驱动实验:读取AP3216C设备中寄存器的数据

一. 简介 经过前面几篇文章的学习,已经完成了I2C驱动框架,字符设备驱动框架,编写了 读写 I2C设备中寄存器的数据的代码,文章如下: I2C驱动实验:实现读/写I2C设备寄存器的函数-CSDN博客 本文在此基础上&a…

C#开发中一些常用的工具类分享

一、配置文件读写类 用于在开发时候C#操作配置文件读写信息 1、工具类 ReadIni 代码 using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks;namesp…

不同设备使用同一个Git账号

想要在公司和家里的电脑上用同一个git账号来pull, push代码 1. 查看原设备的用户名和邮箱 第1种方法, 依次输入 git config user.name git config user.email第2种方法, 输入 cat ~/.gitconfig2. 配置新设备的用户名和邮箱 用户名和邮箱与原设备保持…

高效学习方法:冥想背诵,看一句念一句,再每个词分析位置及语法等合理性,忘记哪个词再看猜下为什么会忘,跟自己的表达哪里不一样。

原则:易学则易行,则效果最好。《易经》 你提到的这种学习方法结合了多种记忆和理解技巧,可以帮助提高学习效率。下面是对这种方法的一个详细解释和一些建议: 冥想背诵:通过冥想来集中注意力,可以帮助你在没…

数据如何才能供得出、流得动、用得好、还安全

众所周知,数据要素已经列入基本生产要素,同时成立国家数据局进行工作统筹。目前数据要素如何发挥其价值,全国掀起了一浪一浪的热潮。 随着国外大语言模型的袭来,国内在大语言模型领域的应用也大放异彩,与此同时&#x…

使用YOLOv8训练自己的【目标检测】数据集

文章目录 1.收集数据集1.1 使用开源已标记数据集1.2 爬取网络图像1.3 自己拍摄数据集1.4 使用数据增强生成数据集1.5 使用算法合成图像 2.标注数据集2.1确认标注格式2.2 开始标注 3.划分数据集4.配置训练环境4.1获取代码4.2安装环境 5.训练模型5.1新建一个数据集yaml文件5.2预测…

java中的正则表达式和异常

正则表达式: 作用一:用来校验数据格式是否合法 作用二:在文本中查找满足要求的内容 不用正则表达式:检验QQ号是否合法,要求全部是数字,长度在6-20,不能以0开头 public class test {public stat…

手机扫码查看视频如何实现?扫描二维码在线看视频的制作技巧

现在的学校或者幼儿园会需要拍摄学生的视频,然后展示给其他人查看,为了能够方便用户能够快速的获取文件内容,所以经常会通过生成视频二维码的方法,将二维码分享之后手机扫码来获取视频内容,有效提升用户获取内容的体验…

PTA C 1050 螺旋矩阵(思路与优化)

本题要求将给定的 N 个正整数按非递增的顺序,填入“螺旋矩阵”。所谓“螺旋矩阵”,是指从左上角第 1 个格子开始,按顺时针螺旋方向填充。要求矩阵的规模为 m 行 n 列,满足条件:mn 等于 N;m≥n;且…

160.相交链表

题目描述 解题思路 ————看评论区大神的思路———— 设「第一个公共节点」为 node ,「链表 headA」的节点数量为 aaa ,「链表 headB」的节点数量为 bbb ,「两链表的公共尾部」的节点数量为 ccc ,则有: 头节点 …

CSS设置字体样式

目录 前言: 1.font-family: 2.font-style: 3.font-weight: 4.font-size: 5.font-variant: 6.font: 前言: 在网页中字体是重要的组成部分,使用好字体可以让网页更…

第一次在msf控制台中运行search命令提示Module database cache not built yet问题解决

0x00 问题描述 在新装的kali虚拟机中使用msfconsole执行search命令时提示Module database cache not built yet问题,显然,是我们相关的数据库缓存存在问题。 故障现象: 0x01 启动数据库服务 msf中的search功能是基于postgresql来实现的&am…

python学习25:python中的元组(tuple)

python中的元组(tuple) 1.什么是元组? 元组也是容器数据类型的一种,同列表几乎是一样的,都是可以在里面封装多个,不同类型的元素在内;与列表最大的不同就是: 元组一旦被定义,就不能修改 2.元组…

物理层习题及其相关知识(谁看谁不迷糊呢)

1. 对于带宽为50k Hz的信道,若有4种不同的物理状态来表示数据,信噪比为20dB 。(1) 按奈奎斯特定理,信道的最大传输数据速率是多少?(2) 按香农定理,信道的最大传输数据速度…

Java设计模式—享元(FlyWeight)模式

享元模式(Flyweight),运用共享技术有效地支持大量细粒度的对象 public abstract class Piece {protected PieceColor m_color;protected PiecePos m_pos;public Piece(PieceColor color ,PiecePos pos){m_color color;m_pos pos;}public ab…

Java笔试总结

. 操作系统中关于竞争和死锁的关系下面描述正确的是? A 竞争一定会导致死锁 B 死锁一定由竞争引起 C 竞争可能引起死锁 D 预防死锁可以防止竞争 答案: C 进程的控制信息和描述信息存放在()。 A JCB B PCB C AFT D SFT 答案: B 当系统发生抖动(thrash…