Pytorch实战(一):LeNet神经网络

文章目录

  • 一、模型实现
    • 1.1数据集的下载
    • 1.2加载数据集
    • 1.3模型训练
    • 1.4模型预测


  LeNet神经网络是第一个卷积神经网络(CNN),首次采用了卷积层、池化层这两个全新的神经网络组件,接收灰度图像,并输出其中包含的手写数字,在手写字符识别任务上取得了瞩目的准确率。LeNet网络的一系列的版本,以LeNet-5版本最为著名,也是LeNet系列中效果最佳的版本。LeNet神经网络输入图像大小必须为32x32,且所用卷积核大小固定为5x5,模型结构如下:
在这里插入图片描述

模型参数:

  • INPUT(输入层):输入图像尺寸为32x32,且是单通道灰色图像。
  • C1(卷积层):使用6个5x5大小的卷积核,步长为1,卷积后得到6张28×28的特征图。
  • S2(池化层):使用了6个2×2 的平均池化,池化后得到6张14×14的特征图。
  • C3(卷积层):使用了16个大小为5×5的卷积核,步长为1,得到 16 张10×10的特征图。
  • S4(池化层):使用16个2×2的平均池化,池化后得到16张5×5 的特征图。
  • C5(卷积层):使用120个大小为5×5的卷积核,步长为1,卷积后得到120张1×1的特征图。
  • F6(全连接层):输入维度120,输出维度是84(对应7x12 的比特图)。
  • OUTPUT(输出层):使用高斯核函数,输入维度84,输出维度是10(对应数字 0 到 9)。

该模型有如下特点:

  • 1.首次提出卷积神经网络基本框架: 卷积层,池化层,全连接层。
  • 2.卷积层的权重共享,相较于全连接层使用更少参数,节省了计算量与内存空间。
  • 3.卷积层的局部连接,保证图像的空间相关性。
  • 4.使用映射到空间均值下采样,减少特征数量。
  • 5.使用双曲线(tanh)或S型(sigmoid)形式的非线性激活函数。

一、模型实现

1.1数据集的下载

  使用torchversion内置的MNIST数据集,训练集大小60000,测试集大小10000,图像大小是1×28×28,包括数字0~9共10个类。

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torchvision
# 下载训练、测试数据集
mnist_train = torchvision.datasets.MNIST(root='./dataset/',
                                         train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./dataset/',
                                        train=False, download=True, transform=transforms.ToTensor())
print('mnist_train基本信息为:',mnist_train)
print('-----------------------------------------')
print('mnist_test基本信息为:',mnist_test)
print('-----------------------------------------')
img,label=mnist_train[0]
print('mnist_train[0]图像大小及标签为:',img.shape,label)

在这里插入图片描述

1.2加载数据集

trainDataLoader = DataLoader(mnist_train, batch_size=64, num_workers=5, shuffle=True)
testDataLoader = DataLoader(mnist_test, batch_size=64, num_workers=0, shuffle=True)
write = SummaryWriter('./log')
step = 0
for images, labels in testDataLoader:
    write.add_images(tag='train', images, global_step=step)
    step += 1
write.close()

  注意不能使用for images, labels in testDataLoader.datasettestDataLoader.dataset[0]是保存图像(28
,28)和对应标签的元组,而Tensorboardadd_images只能输入NCHW格式对象,使用该代码会报错:

size of input tensor and input format are different. tensor shape: (1, 28, 28), input_format: NCHW

数据加载器按batch_size对数据及标签进行封装名,可直接作为输入。查看封装的元组:

for data in testDataLoader:
    print('type(data):',type(data))
    img,label=data
    print('type(img):',type(img),'img.shape:',img.shape)
    print('type(label):',type(label),'label.shape:',label.shape)

在这里插入图片描述

1.3模型训练

  LeNet模型的输入为(32,32)的图片,而MNIST数据集为(28,28)的图片,故需对原图片进行填充。搭建模型:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.model = nn.Sequential(  #MNIST数据集图像大小为28x28,而LeNet输入为32x32,故需填充
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),  #C1层共六个卷积核,故out_channels=6
            nn.AvgPool2d(kernel_size=2, stride=2),  #C2层使用平均池化
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Conv2d(in_channels=16 * 5 * 5, out_channels=120),
            nn.Linear(in_features=120, out_features=84),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        return self.model(x)

# 初始化模型对象
myLeNet = LeNet()

  设置损失函数、优化器并训练模型:

# 设置损失函数为交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 设置优化器,使用Adam优化算法
learning_rate = 1e-2
optimizer = torch.optim.Adam(myLeNet.parameters(), lr=learning_rate)
total_train_step = 0  # 总训练次数
epoch = 10  # 训练轮数
writer = SummaryWriter(log_dir='./runs/LeNet/')
for i in range(epoch):
    print("-----第{}轮训练开始-----".format(i + 1))
    myLeNet.train()  # 训练模式
    train_loss = 0
    for data in trainDataLoader:
        imgs, labels = data
        imgs = imgs.to(device)  # 适配GPU/CPU
        labels = labels.to(device)
        outputs = myLeNet(imgs)
        loss = loss_fn(outputs, labels)#计算损失函数
        optimizer.zero_grad()  # 清空之前梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        total_train_step += 1  # 更新步数
        train_loss += loss.item()
        writer.add_scalar("train_loss_detail", loss.item(), total_train_step)
    writer.add_scalar("train_loss_total", train_loss, i + 1)
    
writer.close()

1.4模型预测

myLeNet.eval() 
total_test_loss = 0  # 当前轮次模型测试所得损失
total_accuracy = 0  # 当前轮次精确率
with torch.no_grad():  # 关闭梯度反向传播
    for data in testDataLoader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = myLeNet(imgs)
        loss = loss_fn(outputs, targets)
        total_test_loss = total_test_loss + loss.item()
        accuracy = (outputs.argmax(1) == targets).sum()
        total_accuracy = total_accuracy + accuracy
writer.add_scalar("test_loss", total_test_loss, i+1)
writer.add_scalar("test_accuracy", total_accuracy/len(mnist_test), i+1)

https://blog.csdn.net/qq_43307074/article/details/126022041?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171938503416800186515588%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171938503416800186515588&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_click~default-2-126022041-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

https://blog.csdn.net/hellocsz/article/details/80764804?ops_request_misc=&request_id=&biz_id=102&utm_term=LeNet&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-80764804.142v100pc_search_result_base3&spm=1018.2226.3001.4187

https://blog.csdn.net/qq_45034708/article/details/128319241?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171936257316800222847105%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171936257316800222847105&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-128319241-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/748224.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32之IIC(软件)

介绍 IIC ( 又称为 I2C 或 IC )是一种串行通信协议, IIC使用两根线路来进行通信: 串行数据线(SDA) 和 串行时钟线(SCL) 。 SDA 线上的数据在 SCL 线的时钟信号下进行 同步传输。 主…

安宝特方案 | AR术者培养:AR眼镜如何帮助医生从“看”到“做”?

每一种新药品的上市都需要通过大量的临床试验,而每一种新的手术工具在普及使用之前也需要经过反复的实践和验证。医疗器械公司都面临着这样的挑战:如何促使保守谨慎的医生从仅仅观察新工具在手术中的应用,转变为在实际手术中实操这项工具。安…

centos7迁移部分成功

早闻CentOS不再维护的消息,确实有些遗憾,毕竟这个系统好用又简单,已经成为了我们工作中的一种习惯。然而,2024年6月30日这一天如约而至,CentOS 7停止维护后,随之而来的安全漏洞又该如何防范?系统…

Stirling-PDF 安装和使用教程

PDF (便携式文档格式) 目前已经成为了文档交换和存储的标准。然而,找到一个功能全面、安全可靠、且完全本地化的 PDF 处理工具并不容易。很多在线 PDF 工具存在隐私和安全风险,而桌面软件往往价格昂贵或功能有限。那么,有没有一种解决方案能够…

Linux安装JDk教程

📖Linux安装JDk教程 ✅下载✅安装 ✅下载 官方Oracle地址:https://www.oracle.com/java/technologies/downloads/archive/ 123云盘:https://www.123pan.com/s/4brbVv-JdmWA.html ✅安装 1.上传安装包jdk-17_linux-x64_bin.tar.gz到指定位…

java易错题型(复习必看)

java易错题型: 下列符号中,哪个用于分隔throws关键字抛出的多个异常 逗号, Java中用来声明一个方法可能抛出某种异常的关键字是throw 对于catch子句的排列,下列哪种是正确的:子类异常在先,父类异常在后&a…

解决“Duplicate keys detected: ‘ ‘.This may cause an update error.”问题

问题原因 出现“Duplicate keys detected”的错误,通常表示在v-for指令中使的:key绑定值有重复。 如果前端是静态数据,一般能自我避免:key绑定值有重复。如果前端是绑定的动态数据,那么需要另外提供一个唯一的键。 在这个例子中&#xff0c…

CV每日论文--2024.6.26

1、StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal 中文标题:StableNormal:减少扩散方差以实现稳定且锐利的法线 简介:本文介绍了一种创新解决方案,旨在优化单目彩色输入(包括静态图片与动态…

糖与蛋白质的“隐秘对话”:DeepGlycanSite如何揭示生命之谜

在生命的复杂舞台上,糖类与蛋白质之间的相互作用犹如一场精心编排的舞蹈,其背后的每一个细微动作都可能对生物体的生理与病理过程产生深远影响。然而,糖类分子的多样性和复杂性,使得科学家们对糖-蛋白质结合位点的识别和研究充满了…

数据预处理功能教程,上传文件生成知识库 | Chatopera

如何快速的生成高质量的知识库? 数据预处理功能教程 | Chatopera 云服务低代码定制聊天机器人 关于 Chatopera Chatopera 云服务重新定义聊天机器人,https://bot.chatopera.com 定制智能客服、知识库、AI 助手、智慧家居等智能应用,释放创新…

图形化用户界面-java头歌实训

图形化用户界面 import java.awt.*; import javax.swing.*; public class GraphicsTester extends JFrame { public GraphicsTester() { super("Graphics Demo"); setSize(480, 300); setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); } public void paint…

Node.js 个人博客

关于该博客 这是一个自己搭建的简易的博客,用于记录一些学习笔记和技术分享。在大四毕业时完成了第一个版本,后续会不断完善和更新。欢迎大家提出宝贵意见和建议。 详细介绍在 blog/posts/博客/博客搭建.md 中: https://github.com/ximingx/blog/blob/m…

php goto解密脚本源码

php goto解密脚本源码 源码下载:https://download.csdn.net/download/m0_66047725/89426171 更多资源下载:关注我。

【Java Web】Servlet控制器

目录 一、Servlet简介 二、Servlet运行流程 三、Servlet开发流程 四、Servlet-api.jar包导入和Content-Type问题 4.1 Servlet-api.jar导入问题 4.2 Http报文头中的Content-Type属性 五、Servlet_url-pattern请求映射路径设置 5.1 url-pattern方式 5.2 注解方式配置servlet 六、…

Linux系统之nice命令的基本使用

Linux系统之nice命令的基本使用 一、nice命令介绍1.1 nice命令简介1.2 进程优先级介绍 二、nice命令基本语法2.1 nice命令的help帮助信息2.2 nice命令选项解释 三、nice命令的基本使用3.1 查看进程优先级3.2 使用nice启动进程3.3 提高优先级 四、注意事项 一、nice命令介绍 1.…

【unity笔记】七、Mirror插件使用

一、简介 Mirror 是一个用于 Unity 的开源多人游戏网络框架,它提供了一套简单高效的网络同步机制,特别适用于中小型多人游戏的开发。以下是 Mirror 插件的一些关键特点和组件介绍: 简单高效:Mirror 以其简洁的 API 和高效的网络…

操作系统面试篇一

很多读者抱怨计算操作系统的知识点比较繁杂,自己也没有多少耐心去看,但是面试的时候又经常会遇到。所以,我带着我整理好的操作系统的常见问题来啦!这篇文章总结了一些我觉得比较重要的操作系统相关的问题比如 用户态和内核态、系统…

在OPenFast中.fst文件,.sum文件,.ech文件,.out文件,.outb文件的功能和作用

在OpenFAST中,5MW_Land_DLL_WTurb目录下的这些文件分别有不同的作用,它们用于不同的模块和目的。以下是每个文件的总结及其作用: 5MW_Land_DLL_WTurb.fst 作用:这是OpenFAST主输入文件。内容:该文件包含了整个仿真所需…

.NET 一款支持8种方式维持权限的工具

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

80年代怀旧动画片大全集,90年代老动画片大全集视频少儿经典下载

观看动画片时,儿童注意力的一般都比较稳定,习惯于跟随动画片的变化而变化。所以,动画片可以从儿童熟悉的事物入手,引起儿童的兴趣,调动儿童的积极性;通过动画片的感染力把情感传达给儿童,把儿童…