pytorch-01

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()  # 拉平图像矩阵
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层
            torch.nn.ReLU(),   # 激活函数层
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度
        )
    def forward(self, input):   # 构建网络
        x = self.flatten(input)  #拉平矩阵为1维
        logits = self.linear_relu_stack(x) # 多层感知机

        return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as np


batch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数

#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:
    l = 1
    print('该层结构:'+str(list(i.size())))
    for j in i.size():
        l*=j
    print('该层参数和:'+str(l))
    k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/756467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【JVM-01】引言

【JVM-01】引言 1. 什么是JVM?2. JDK、JRE、JVM比较3.常用的JVM有那些4.学习路线 1. 什么是JVM? JVM即 Java Virtual Machine(Java虚拟机),是Java程序运行的环境(Java 二进制字节码运行环境)。 好处: 一次编写,到处…

【自然语言处理系列】掌握jieba分词器:从基础到实战,深入文本分析与词云图展示

本文旨在全面介绍jieba分词器的功能与应用,从分词器的基本情况入手,逐步解析全模式与精确模式的不同应用场景。文章进一步指导读者如何通过添加自定义词典优化分词效果,以及如何利用jieba分词器进行关键词抽取和词性标注,为后续的…

Python数据分析案例48——二手房价格影响因素分析

案例背景 房价影响因素也是人们一直关注的问题,本次案例也适合各种学科的同学,无论你是经济管理类还是数学统计,还是电商物流类,都可以使用回归分析。通过数据分析回归分析分组聚合可视化等方法进行研究房价影响因素。 数据介绍 …

【C++】数组、字符串

六、数组、字符串 讨论数组离不开指针,指针基本上就是数组的一切的基础,数组和指针的相关内容参考我的C系列博文:【C语言学习笔记】四、指针_通过变量名访问内存单元中的数据缺点-CSDN博客【C语言学习笔记】三、数组-CSDN博客 1、数组就是&…

高精密机械设备中滚珠导轨的表面处理工艺有哪些?

滚珠导轨是机床传动和定位的传动元件,其表面处理方式对机床性能和使用寿命起着决定性的作用,不同的表面处理方法可以提高导轨的耐磨性、抗腐蚀性和整体性能。那么,滚珠导轨的表面处理方式有哪几种呢? 1、磨削法:磨削技…

数据结构--队列(图文)

队列是一种特殊的线性表,其核心特点是先进先出。 概念及特点: 概念 队列是一种只允许在一端进行插入数据操作,在另一端进行删除数据操作的线性表。进行插入操作的一端被称为队尾(Tail 或 Rear),进行删除…

测定分子结构丨核磁共振(NMR)测试原理、制样要求以及常见问题深度解密!...

✨【元素魔方学术俱乐部】✨ 👩‍🏫👨‍🏫我们创建了一个学术交流群 给全国各地以及各种研究方向的硕博 和老师们提供一个交流的平台📚🧪 感兴趣的话欢迎加入 📲本公众号中回复“社群” 会自动发…

短剧APP平台开发,互联网下的市场发展前景

近几年,短剧作为当下的热门行业之一,具有非常大的发展空间,市场前景一片大好,成为了一个新的的蓝海赛道。 随着互联网科技的不断发展,行业数字化发展已经成为了必然趋势。短剧APP系统的开发让观众拥有了一个在线观看短…

Windows 中的 Hosts 文件是什么?如何找到并修改它?

什么是 Hosts 文件 Hosts 文件是一个纯文本文件,存在于几乎所有的操作系统中,用于将主机名映射到 IP 地址。在域名系统(DNS)尚未普及之前,Hosts 文件是计算机网络中唯一用于主机名解析的方式。随着网络规模的扩大和 D…

Mac可以读取NTFS吗 Mac NTFS软件哪个好 mac ntfs读写工具免费

在跨操作系统环境下使用外部存储设备时,特别是当Windows系统的U盘被连接到Mac电脑时,常常会遇到文件系统兼容性的问题。由于Mac OS原生并不完全支持对NTFS格式磁盘的读写操作,导致用户无法直接在Mac上向NTFS格式的U盘或硬盘写入数据。下面我们…

架构师必知的绝活-JVM调优

前言 为什么要学JVM? 首先:面试需要 了解JVM能帮助回答面试中的复杂问题。面试中涉及到的JVM相关问题层出不穷,难道每次面试都靠背几百上千条面试八股? 其次:基础知识决定上层建筑 自己写的代码都不知道是怎么回事&a…

[ROS 系列学习教程] 建模与仿真 - 使用 ros_control 控制差速轮式机器人

ROS 系列学习教程(总目录) 本文目录 一、差速轮式机器人二、差速驱动机器人运动学模型三、对外接口3.1 输入接口3.2 输出接口 四、控制器参数五、配置控制器参数六、编写硬件抽象接口七、控制机器人移动八、源码 ros_control 提供了多种控制器,其中 diff_drive_cont…

揭示隐藏的模式:秩和检验和单因素方差分析的实战指南【考题】

1.研究一种新方法对于某实验结果准确性提高的效果,并将其与原有方法进行比较,结果见下表,请评价两者是否有不同? (行无序,列有序)-->单方向有序-->两独立样本的秩和检验) 如下图所示,先将相关数据导入spss。 图…

springboot注解@ComponentScan注解作用

一 ComponentScan作用 1.1 注解作用 项目会默认扫描SpringBootApplication注解所在路径的同级和下级的所有子包,使用ComponentScan后他会取代掉默认扫描。 ComponentScan 是Spring框架的注解,它的作用是扫描指定的包路径下的标有 Component、Service、…

docker-mysql主从复制

MySQL主从复制 安装docker和拉取镜像不再赘述 一.主服务器 1.新建主服务器容器-3307 (这里设置的密码可能不生效,若未生效请看问题中的2) docker run -p 3307:3306 --name mysql-master \ -v /mydata/mysql-master/log:/var/log/mysql \…

mindspore打卡第9天 transformer的encoder和decoder部分

mindspore打卡第9天 transformer的encoder和decoder部分 import mindspore from mindspore import nn from mindspore import ops from mindspore import Tensor from mindspore import dtype as mstypeclass ScaledDotProductAttention(nn.Cell):def __init__(self, dropout_…

13017.win10安装WSL2及CUDA开发环境

文章目录 1 win10版本1.1 关键项不能忽略 2 安装WSL2 ubuntu20.042.1 打开控制面板,开启虚拟子系统功能2.2 离线安装ubuntu2.2 WSL2 启动 ubuntu2.3 修改默认启动用户 3 ubuntu中安装vscode-server3.1 win10 中安装vscode3.2 ubuntu中安装vscode-server3.3 启动WSL2…

嵌入式学习(Day 51:ARM指令/汇编与c语言函数相互调用)

1.Supervisor模式与SVC模式 Supervisor模式是ARM处理器的一个特权工作模式,允许执行特权指令和访问特权资源。SVC模式(Supervisor Call)是与Supervisor模式相关的一个功能或指令,用于从用户模式切换到Supervisor模式,…

【C++ | 类型转换】转换构造函数、类型转换运算符 详解及例子源码

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 ⏰发布时间⏰: 本文未经允许…

认识100种电路之稳压电路

在电子电路中,稳压电路扮演着至关重要的角色。那么,为什么电路需要稳压?稳压的原理又是什么?以及稳压需要用到哪些元器件,数量又有多少呢?今天,就让我们一同揭开稳压电路的神秘面纱。 【电路为什…