Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

1、红酒数据介绍

经典的红酒分类数据集是指UCI机器学习库中的Wine数据集。该数据集包含178个样本,每个样本有13个特征,可以用于分类任务。

具体每个字段的含义如下:
alcohol:酒精含量百分比
malic_acid:苹果酸含量(克/升)
ash:灰分含量(克/升)
alcalinity_of_ash:灰分碱度(以mEq/L为单位)
magnesium:镁含量(毫克/升)
total_phenols:总酚含量(以毫克/升为单位)
flavanoids:类黄酮含量(以毫克/升为单位)
nonflavanoid_phenols:非类黄酮酚含量(以毫克/升为单位)
proanthocyanins:原花青素含量(以毫克/升为单位)
color_intensity:颜色强度(以 absorbance 为单位,对应于 1cm 路径长度处的相对宽度)
hue:色调,即色彩的倾向性或相似性(在 1 至 10 之间的一个数字)
od280/od315_of_diluted_wines:稀释葡萄酒样品的光密度比值,用于测量葡萄酒中各种化合物的浓度
proline:脯氨酸含量(以毫克/升为单位),是一种天然氨基酸,与葡萄酒的品质和口感有关。

2、红酒数据集分析

2.1 加载红酒数据集

# 加载红酒数据集
wineBunch = load_wine()
type(wineBunch)

sklearn.utils.Bunch
sklearn.utils.Bunch是Scikit-learn库中的一个数据容器,类似于Python字典(dictionary),
它可以存储任意数量和类型的数据,并且可以使用点(.)操作符来访问数据。Bunch常用于存储机器学习模型的数据集,
例如描述特征矩阵的数据、相关联的目标向量、特征名称等等,以便于组织和传递这些数据到模型中进行训练或预测。

2.2 红酒数据集形状

len(wineBunch.data),len(wineBunch.target)

(178, 178)

2.3 红酒数据集打印前5行和后5行

featuresDf = pd.DataFrame(data=wineBunch.data, columns=wineBunch.feature_names)   # 特征数据
labelDf = pd.DataFrame(data=wineBunch.target, columns=["target"])               # 标签数据
wineDf = pd.concat([featuresDf, labelDf], axis=1)  # 横向拼接
wineDf.head(5).append(wineDf.tail(5))              # 打印首尾5行

在这里插入图片描述

2.4 红酒数据集列名

wineDf.columns

Index([‘alcohol’, ‘malic_acid’, ‘ash’, ‘alcalinity_of_ash’, ‘magnesium’,
‘total_phenols’, ‘flavanoids’, ‘nonflavanoid_phenols’,
‘proanthocyanins’, ‘color_intensity’, ‘hue’,
‘od280/od315_of_diluted_wines’, ‘proline’, ‘target’],
dtype=‘object’)

2.5 红酒数据集目标标签

print(wineDf.target.unique())
[0 1 2]

3、Transformer对红酒进行分类

3.1 Transformer介绍

Transformer是一种基于注意力机制的神经网络结构,主要用于自然语言处理领域中的序列到序列转换任务,比如机器翻译、文本摘要等。它在2017年被Google提出,并被成功应用于Google Translate中。

Transformer的主要特点在于使用了完全基于注意力机制的编码器-解码器结构,避免了传统循环神经网络(如LSTM)中存在的长序列依赖问题和梯度消失问题。此外,Transformer还使用了残差连接和层归一化等技术,增强了模型的训练效果和泛化能力。

在Transformer模型中,输入序列和输出序列都被表示为固定长度的向量,称为词向量,由多个嵌入层和多个编码器和解码器层组成。其中,编码器和解码器层包括多头注意力机制、前馈神经网络和残差连接等模块,以实现对序列的有效建模和转换。

3.2 引入依赖库

import random         # 导入 random 模块,用于随机数生成
import torch          # 导入 PyTorch 模块,用于深度学习任务
import numpy as np    # 导入 numpy 模块,用于数值计算
from torch import nn  # 从 PyTorch 中导入神经网络模块
from sklearn import datasets  # 从sklearn引入数据集
from sklearn.model_selection import train_test_split  # 导入 sklearn 库中的 train_test_split 函数,用于数据划分
from sklearn.preprocessing import StandardScaler     # 导入 sklearn 库中的 StandardScaler 类,用于数据标准化

3.3 设置随机种子

# 设置随机种子,让模型每次输出的结果都一样
seed_value = 42
random.seed(seed_value)                         # 设置 random 模块的随机种子
np.random.seed(seed_value)                      # 设置 numpy 模块的随机种子
torch.manual_seed(seed_value)                   # 设置 PyTorch 中 CPU 的随机种子
#tf.random.set_seed(seed_value)                 # 设置 Tensorflow 中随机种子
if torch.cuda.is_available():                   # 如果可以使用 CUDA,设置随机种子
    torch.cuda.manual_seed(seed_value)          # 设置 PyTorch 中 GPU 的随机种子
    torch.backends.cudnn.deterministic = True   # 使用确定性算法,使每次运行结果一样
    torch.backends.cudnn.benchmark = False      # 不使用自动寻找最优算法加速运算

3.4 检测GPU是否可用

# 检测GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3.5 加载数据集

# 加载红酒数据集
wine = datasets.load_wine()
X = wine.data
y = wine.target

3.6 拆分训练集和测试集

# 拆分成训练集和测试集,训练集80%和测试集20%
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

3.7 缩放数据

# 缩放数据
scaler = StandardScaler() # 创建一个标准化转换器的实例
X_train = scaler.fit_transform(X_train) # 对训练集进行拟合(计算平均值和标准差)
X_test = scaler.transform(X_test) # 对测试集进行标准化转换,使用与训练集相同的平均值和标准差

3.8 转化成pytorch张量

# 将训练集转换为 PyTorch 张量,并转换为浮点数类型,如果 GPU 可用,则将张量移动到 GPU 上
X_train = torch.tensor(X_train).float().to(device)
# 将训练集转换为 PyTorch 张量,并转换为长整型,如果 GPU 可用,则将张量移动到 GPU 上
y_train = torch.tensor(y_train).long().to(device)
X_test = torch.tensor(X_test).float().to(device)
y_test = torch.tensor(y_test).long().to(device)

3.9 定义Transformer模型

定义 Transformer 模型

class TransformerModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerModel, self).__init__()
        # 构建Transformer编码层,参数包括输入维度、注意力头数
        # 其中d_model要和模型输入维度相同
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size,  # 输入维度
                                                        nhead=1)             # 注意力头数
        # 构建Transformer编码器,参数包括编码层和层数
        self.encoder = nn.TransformerEncoder(self.encoder_layer,             # 编码层
                                             num_layers=1)                   # 层数
        # 构建线性层,参数包括输入维度和输出维度(num_classes)
        self.fc = nn.Linear(input_size,                                      # 输入维度
                            num_classes)                                     # 输出维度

    def forward(self, x):
        #print("A:", x.shape)  # torch.Size([142, 13])
        x = x.unsqueeze(1)    # 增加一个维度,变成(batch_size, 1, input_size)的形状
        #print("B:", x.shape)  # torch.Size([142, 1, 13])
        x = self.encoder(x)   # 输入Transformer编码器进行编码
        #print("C:", x.shape)  # torch.Size([142, 1, 13])
        x = x.squeeze(1)      # 压缩第1维,变成(batch_size, input_size)的形状
        #print("D:", x.shape)  # torch.Size([142, 13])
        x = self.fc(x)        # 输入线性层进行分类预测
        #print("E:", x.shape)  # torch.Size([142, 3])
        return x
# 初始化Transformer模型,并移动到GPU
model = TransformerModel(input_size=13,             # 输入维度
                         num_classes=3).to(device)  # 输出维度

3.10 定义损失函数和优化器

定义损失函数和优化器

criterion = nn.CrossEntropyLoss() # 定义损失函数-交叉熵损失函数

定义优化器

optimizer = torch.optim.Adam(model.parameters(), # 模型参数
lr=0.01) # 学习率

3.11 训练模型

# 训练模型
num_epochs = 100     # 训练100轮
for epoch in range(num_epochs):
    # 正向传播:将训练数据放到模型中,得到模型的输出
    outputs = model(X_train)
    loss = criterion(outputs, y_train)  # 计算交叉熵损失

    # 反向传播和优化:清零梯度、反向传播计算梯度,并根据梯度更新模型参数
    optimizer.zero_grad()  # 清零梯度
    loss.backward()        # 反向传播计算梯度
    optimizer.step()       # 根据梯度更新模型参数

    # 每10轮打印一次损失值,查看模型训练的效果
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

3.12 测试模型

# 测试模型,在没有梯度更新的情况下,对测试集进行推断
with torch.no_grad():
    outputs = model(X_test)   # 使用训练好的模型对测试集进行预测
    _, predicted = torch.max(outputs.data, 1)  # 对输出的结果取 argmax,得到预测概率最大的类别
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)  # 计算模型在测试集上的准确率
    print(f'Test Accuracy: {accuracy:.2f}')   # 打印测试集准确率

3.13 控制输出

Epoch [10/100], Loss: 0.1346
Epoch [20/100], Loss: 0.0325
Epoch [30/100], Loss: 0.0116
Epoch [40/100], Loss: 0.0064
Epoch [50/100], Loss: 0.0040
Epoch [60/100], Loss: 0.0029
Epoch [70/100], Loss: 0.0026
Epoch [80/100], Loss: 0.0021
Epoch [90/100], Loss: 0.0019
Epoch [100/100], Loss: 0.0019
Test Accuracy: 1.00

Process finished with exit code 0

正确率:100%

3.14 保存模型

# 保存模型
PATH = "model.pt"
torch.save(model.state_dict(), PATH)

3.15 加载模型

加载模型

model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/15772.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用手机APP操作使用井用采样器更省时省力

井用采样器的主要功能特点就是:机身小巧,方便操作。可用于井下作业,手机APP可实时查看采样数据,节省人力。 利用自动采样器进行水样采集可以说节省很大的人力物力,但是有时为了采到更具代表性的水样,我们需…

JAVA-6-[Spring框架]Bean的作用域和生命周期

1 Spring Bean 1、Spring有两种类型bean,一种普通bean,另外一种工厂bean(FactoryBean)。 2、普通bean:在配置文件中定义的bean类型就是返回的类型。 3、工厂bean:在配置文件中定义的bean类型可以和返回类型不一样。 第一步 创建类…

React超级简单易懂全面的有关问题回答(面试)

目录 React事件机制: 2、React的事件和普通的HTML有什么不同: - 事件命名的规则不同,原生事件采用全小写,react事件采用小驼峰 3、React组件中怎么做事件代理?他的原理是什么? 4、React高阶组件、Rend…

mysql 如何避免索引失效

案例演示 建表及初始化数据 CREATE TABLE staffs (id INT PRIMARY KEY AUTO_INCREMENT,NAME VARCHAR(24) NOT NULL DEFAULT ,age INT NOT NULL DEFAULT 0,pos VARCHAR(20) NOT NULL DEFAULT ,#职位add_time TIMESTAMP NOT NULL DEFAULT CURREN…

【源码解析】SpringBoot日志系统源码分析

LoggingApplicationListener 日志组件的处理是LoggingApplicationListener实现的。LoggingApplicationListener#onApplicationEvent&#xff0c;监听事件。如果实现接口GenericApplicationListener&#xff0c;可以允许适配事件类型。 private static final Class<?>[]…

链接sqlite

一.sqlite库函数 1.sqlite3_open()函数 语法&#xff1a;*sqlite3_open(const char *filename, sqlite3 *ppDb) 作用&#xff1a;该例程打开一个指向 SQLite 数据库文件的连接&#xff0c;返回一个用于其他 SQLite 程序的数据库连接对象。 参数1&#xff1a;如果 filename …

8个免费的PNG素材网站推荐

很多设计小白都不知道什么是PNG。事实上&#xff0c;PNG是一种支持透明度的图像格式。当你想在设计中将图像与背景或文本混合时&#xff0c;它就会派上用场。 如果你没有时间为你正在处理的设计创建透明的PNG图像&#xff0c;你也可以使用我收集的PNG素材网站&#xff0c;以便…

Mysql-SQL优化

SQL优化 插入数据优化1&#xff09;批量插入数据2&#xff09;手动控制事务3&#xff09;主键顺序插入大批量插入数据 主键优化order by优化group by优化limit优化count优化update优化小结 插入数据优化 插入数据优化主要是对于insert进行批量插入时优化&#xff0c;下面我们就…

d3.js学习笔记②搭建服务器(含报错解决)

强烈建议自己搭建一个服务器&#xff0c;否则在后续往js里导入本地数据&#xff08;比如csv、json文件等&#xff09;的时候会报错。我用的是Apache服务器&#xff0c;下载、安装过程参考这篇文章&#xff1a;Apache安装配置 在浏览器输入http://localhost/或者http://127.0.0…

excel实战小测第四

【项目背景】 本项目为某招聘网站部分招聘信息&#xff0c;要求对“数据分析师”岗位进行招聘需求分析&#xff0c;通过对城市、行业、学历要求、薪资待遇等不同方向进行相关性分析&#xff0c;加深对数据分析行业的了解。 结合企业真实招聘信息&#xff0c;可以帮助有意转向数…

一文看懂数据分析必备计算功能—内存计算

各位数据的朋友&#xff0c;大家好&#xff0c;我是老周道数据&#xff0c;和你一起&#xff0c;用常人思维数据分析&#xff0c;通过数据讲故事。 接下来&#xff0c;我们准备讲一系列的奥威BI软件的内存计算在各种分析场景中的应用。在正式开讲之前&#xff0c;我先简单说一下…

嵌入式设备显示屏相关概念汇总

嵌入式设备常用的显示屏接口 LCD 接口&#xff1a;是一种常见的数字电路接口&#xff0c;支持多种显示器件&#xff0c;如字符型液晶显示器和点阵型液晶显示器等。 VGA 接口&#xff1a;是一种视频接口标准&#xff0c;用于连接显示器和计算机。该接口提供模拟 RGB 信号&#…

快速多关键字统计

实例需求&#xff1a;在每个章节中统计关键字&#xff08;“√”, “”, “〇”, “空缺”&#xff09;的个数&#xff0c;B列中的章节编号作为章节划分的标识&#xff0c;例如1.1.1 ~ 1.1.5为第1.1章节&#xff0c;对应工作表的12 ~ 16行&#xff0c;其中黄色列为需要统计的数…

IDEA中使用Git提交代码提示:您即将把CRLF行分隔符提交到Gt仓库。 建议将core.autocrlf Git特性设置为trUe,以免发生行分隔符问题。

IDEA中使用Git提交代码提示&#xff1a;您即将把CRLF行分隔符提交到Gt仓库。 建议将core.autocrlf Git特性设置为trUe,以免发生行分隔符问题。 问题背景&#xff1a; 在IDEA中&#xff0c;使用Git提交代码到远程仓库时&#xff0c;结果弹出一个警告窗口 问题原因&#xff1a; …

【测试开发】突破瓶颈必学技能——什么是k8s的核心概念?

目录 Docker 和K8s k8s中的重要概念 Master 节点 Node 节点 集群&#xff08;Cluster&#xff09; 标签&#xff08;Label&#xff09; 命名空间&#xff08;Namespace&#xff09; 容器组&#xff08;Pod&#xff09; 无状态部署&#xff08;Deployment&#xff09;…

【2023五一杯数学建模】 B题 快递需求分析问题 建模方案及MATLAB实现代码

【2023五一杯数学建模】 B题 快递需求分析问题 1 题目 请依据以下提供的附件数据和背景内容&#xff0c;建立数学模型&#xff0c;完成接下来的问题&#xff1a;问题背景是&#xff0c;网络购物作为一种重要的消费方式&#xff0c;带动着快递服务需求飞速增长&#xff0c;为我…

AVL 树(自平衡二叉搜索树) 介绍

AVL 树&#xff08;自平衡二叉搜索树) 介绍 前言 在介绍二叉搜索树的章节中提到&#xff0c;二叉搜索树可能退化为线性链表&#xff0c;失去作为二叉树的各种优势。那么过程中需要维持二叉树的形式&#xff0c;同时左右子树的深度差异可控&#xff0c;如果能实现这两个条件&a…

音视频 FFmpeg

文章目录 前言视频编解码硬件解码(高级)软解码(低级)软、硬解码对比视频解码有四个步骤Android 系统中编解码器的命名方式查看当前设备支持的硬解码 基础知识RGB色彩空间常见的格式对比YUV索引格式分离RGB24像素数据中的R、G、B分量 BMP 文件格式格式组成像素排列顺序RGB24格式…

autosar软件分层架构组成--汽车电子

介绍 autosar是汽车软件协会制定的一套软件标准 本文章所有图片来源于网络 一、分层架构 分层&#xff1a;3层 1.上层应用层&#xff08;Application Layer&#xff09; 2.中间件RTE(Runtime Environment) 3.下层的基础软件&#xff08;Basic Software&#xff09; 中间件R…

倾斜摄影超大场景的三维模型轻量化纹理压缩的关键技术

倾斜摄影超大场景的三维模型轻量化纹理压缩的关键技术 倾斜摄影超大场景的三维模型轻量化处理中纹理压缩是轻量化处理的重要手段之一&#xff0c;可以在保证模型真实感的前提下&#xff0c;减小数据体积、降低传输带宽和提高渲染性能。以下是几个关键的纹理压缩技术&#xff1a…