6.6 实现卷积神经网络LeNet训练并预测手写体数字

模型架构

在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),#padding=2补偿5x5卷积核导致的特征减少。
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.Sigmoid(),
    nn.Linear(120,84),nn.Sigmoid(),
    nn.Linear(84,10)
)
'''定义X,并打印模型的形状'''
# 第一个参数是样本
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
# 输出如下:
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
'''定义训练批次并加载训练集和测试集'''
batch_size = 256
# 按照batch_size把数据集取出来。取出来之后是放到内存中的,后面要把它加载到GPU中
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 计算预测正确的个数
def accuracy(y_hat,y):
    '''计算预测正确的数量'''
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        # y_hat是下标表示类别,值是该类别的概率。模型结果是预测10个类的概率,谁的概率最大,就取谁的下标
        y_hat = y_hat.argmax(axis=1)
    #  y_hat.type(y.dtype):因为==对数据类型很敏感,因此我们将y_hat的数据类型转换为与y的数据类型一致。
    #  y_hat.type(y.dtype) == y,将预测值y_hat与真实值y比较,返回一个包含 0和1的张量,赋值给cam,最后求和会得到正确预测的数量。
    cam = y_hat.type(y.dtype) == y
    return float(cam.type(y.dtype).sum())
def evaluate_accuracy_gpu(net,data_iter,device=None):
    if isinstance(net,nn.Module):
        net.eval() # 将模型设置为评估模式
        if not device:
            '''
                iter(net.parameters())是将参数集合转换为迭代器,并获取其中的第一个元素
                next(iter(net.parameters())).device ,指取到net.parameters()的第一个元素,获取该元素的设备。
            '''
            device = next(iter(net.parameters())).device
        # Accumulator用于对多个变量进行累加,d2l.Accumulator(2) 是在Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都随着时间的推移而累加。
        metric = d2l.Accumulator(2) # 正确预测数,预测总数
        with torch.no_grad():
            for X,y in data_iter:
                if isinstance(X,list): # 详见文章最下面的补充内容
                    X = [x.to(device) for x in X] # 令X使用设备device
                else:
                    X = X.to(device)
                y = y.to(device)
                # y.numel()是批次中样本的数量,accuracy(net(X),y)是用于计算模型在输入数据X上的输出结果与标签Y之间的准确率。
                # metric.add函数将正确预测的数量 和 样本数量作为参数传递进去,用于记录和累计这些指标的值。
                metric.add(accuracy(net(X),y),y.numel())
        return metric[0]/metric[1] # 返回准确率,其中metric[0]存放的是正确预测的个数,metric[1]存放的是样本数量,
def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d: # 对神经网络中的线性层和卷积层的权重进行初始化
            nn.init.xavier_uniform_(m.weight) #用于初始化权重的函数,
    net.apply(init_weights)
    print('training on',device)
    net.to(device) # 设置模型使用device
    optimizer = torch.optim.SGD(net.parameters(),lr=lr)
    loss = nn.CrossEntropyLoss()
    '''
        该代码创建了一个名为animator的动画器,用于在训练过程中可视化损失函数和准确率的变化情况
    '''
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss','train acc','test acc'])
    timer,num_batches = d2l.Timer(),len(train_iter)
    for epoch in range(num_epochs):
        # 创建 Accumulator类,统计训练损失之和,正确预测个数之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i,(X,y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X,y = X.to(device),y.to(device)
            y_hat = net(X)
            l = loss(y_hat,y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l*X.shape[0], d2l.accuracy(y_hat,y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2] # 损失之和 / 样本数
            train_acc = metric[1] / metric[2] # 正确预测个数 / 样本数
            if (i+1) % (num_batches//5)==0 or i == num_epochs-1:
                animator.add(epoch + (i+1)/num_epochs,(train_l,train_acc,None))
        test_acc = evaluate_accuracy_gpu(net,test_iter)
        animator.add(epoch+1,(None,None,test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
# 定义学习率和批次 开始训练
lr,num_epochs = 0.9,10
train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

练习

把平均汇聚层改为最大汇聚层

在这里插入图片描述

把平均池化改为最大池和把激活函数改为RelU之后的效果

net = nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),nn.ReLU(),#padding=2补偿5x5卷积核导致的特征减少。
    nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),nn.ReLU(),
    nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.ReLU(),
    nn.Linear(120,84),nn.Sigmoid(), #注意,此处不能改为RelU,此处的sigmoid是把预测结果映射成概率
    nn.Linear(84,10)
)

在这里插入图片描述

使用训练好的模型进行预测

y_hat = net(x)

补充:

isinstance(net,nn.Module)

isinstance(net,nn.Module)是Python的内置函数,用于判断一个对象是否属于制定类或其子类的实例。如果net是nn.Module类或子类的实例,那么表达式返回True,否则返回False. nn.Module是pytorch中用于构建神经网络模型的基类,其他神经网络都会继承它,因此使用 isinstance(net,nn.Module),可以确定Net对象是否为一个有效的神经网络模型。

`nn.init.xavier_uniform_(m.weight)

nn.init.xavier_uniform_(m.weight) 是一个用于初始化权重的函数,采用的是 Xavier 均匀分布初始化方法。

在神经网络中,权重的初始化非常重要,合适的初始化可以帮助网络更好地学习和收敛。Xavier 初始化方法是一种常用的权重初始化方法之一,旨在使权重在前向传播过程中保持方差不变。

具体而言,nn.init.xavier_uniform_() 函数会对输入的权重张量 m.weight 进行操作,将其初始化为一个均匀分布中的随机值。这个均匀分布的范围根据权重张量的形状进行调整,以保持前向传播过程中特征的方差稳定。

通过使用 Xavier 初始化方法,可以加速神经网络的训练过程,并且有助于避免梯度消失或梯度爆炸等问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/69096.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

竞赛项目 深度学习实现语义分割算法系统 - 机器视觉

文章目录 1 前言2 概念介绍2.1 什么是图像语义分割 3 条件随机场的深度学习模型3\. 1 多尺度特征融合 4 语义分割开发过程4.1 建立4.2 下载CamVid数据集4.3 加载CamVid图像4.4 加载CamVid像素标签图像 5 PyTorch 实现语义分割5.1 数据集准备5.2 训练基准模型5.3 损失函数5.4 归…

IT运维:使用数据分析平台监控PowerStore存储(进阶)

概述 本文基于《IT运维:使用鸿鹄监控PowerStore存储》(以下简称原文)文章进行了优化。主要优化部分包括存储日志进入到鸿鹄后,如何进行字段抽取,以及图表的展示。 字段抽取:由原来采用视图的方式&#xff0…

干货丨学完网络安全专业,我掌握了哪些技能?

andy Ng在我校完成网络防御与司法大专(Diploma in Network Defense and Forensic Countermeasures)之后,顺利升入我校的网络安全本科课程,目前她就职于一家金融机构并担任安全操作中心的分析专员。在进入我校就读之前,Sandy在建筑行业领域工作…

数学建模—多元线性回归分析(+lasso回归的操作)

第一部分:回归分析的介绍 定义:回归分析是数据分析中最基础也是最重要的分析工具,绝大多数的数据分析问题,都可以使用回归的思想来解决。回归分析的人数就是,通过研究自变量X和因变量Y的相关关系,尝试去解释…

解决遥感技术在生态、能源、大气等领域的碳排放监测及模拟问题

以全球变暖为主要特征的气候变化已成为全球性环境问题,对全球可持续发展带来严峻挑战。2015年多国在《巴黎协定》上明确提出缔约方应尽快实现碳达峰和碳中和目标。2019年第49届 IPCC全会明确增加了基于卫星遥感的排放清单校验方法。随着碳中和目标以及全球碳盘点的现…

2000-2022年全国各地级市绿色金融指数数据

2000-2022年全国各地级市绿色金融指数数据 1、时间:2000-2022年 2、来源:来源:统计局、科技部、中国人民银行等权威机构网站及各种权威统计年鉴,包括全国及各省市统计年鉴、环境状况公报及一些专业统计年鉴,如 《中国…

【深度学习笔记】TensorFlow 基础

在 TensorFlow 2.0 及之后的版本中,默认采用 Eager Execution 的方式,不再使用 1.0 版本的 Session 创建会话。Eager Execution 使用更自然地方式组织代码,无需构建计算图,可以立即进行数学计算,简化了代码调试的过程。…

KAFKA第二课之生产者(面试重点)

生产者学习 1.1 生产者消息发送流程 在消息发送的过程中,涉及到了两个线程——main线程和Sender线程。在main线程中创建了一个双端队列RecordAccumulator。main线程将消息发送给RecordAccumulator,Sender线程不断从RecordAccumulator中拉取消息发送到K…

泰国的区块链和NFT市场调研

泰国的区块链和NFT市场调研 基本介绍 参考: https://zh.wikipedia.org/zh-hans/%E6%B3%B0%E5%9B%BD参考: https://hktdc.infogram.com/thsc–1h7k2303zo75v2x zz制度: 君主立宪制(议会制) 国王: 玛哈哇集拉…

基于vue3+webpack5+qiankun实现微前端

一 主应用改造(又称基座改造) 1 在主应用中安装qiankun(npm i qiankun -S) 2 在src下新建micro-app.js文件,用于存放所有子应用。 const microApps [// 当匹配到activeRule 的时候,请求获取entry资源,渲染到containe…

JVM内存管理

文章目录 1、运行时数据区域1.1 程序计数器(线程私有)1.2 JAVA虚拟机栈(线程私有)1.3 本地方法栈1.4 Java堆(线程共享)1.5 方法区(线程共享)1.6 直接内存(非运行时数据区…

拥抱AIGC浪潮,亚信科技将如何把握时代新增量?

去年底,由ChatGPT带起的AIGC浪潮以迅雷不及掩耳之势席卷全球。 当互联网技术的人口红利逐渐消退之际,AIGC就像打开通用人工智能大门的那把秘钥,加速开启数智化时代的到来。正如OpenAI CEO Sam Altman所言:一个全新的摩尔定律可能…

560. 和为 K 的子数组

思路 本题的主要思路为创建一个哈希表记录每个0~i的和,在遍历这个数组的时候查询有没有sum-k的值在哈希表中,如果有,说明有个位置到当前位置的和为k。   有可能不止一个,哈希表负责记录有几个sum-k,将和记录下来。这…

10个问题,带你重新认识smardaten企业级无代码

很多新客户在接触数睿数据,或者在初步认识smardaten企业级无代码的时候,大家更多地以为只是个普通的无代码工具。在交流过程中,大家也提出了很多疑惑: smardaten无代码平台包括哪些能力? 适合开发哪些应用&#xff1f…

AI自动驾驶

AI自动驾驶 一、自动驾驶的原理二、自动驾驶的分类三、自动驾驶的挑战四、自动驾驶的前景五、关键技术六、自动驾驶的安全问题七、AI数据与自动驾驶八、自动驾驶的AI算法总结 自动驾驶技术是近年来备受关注的热门话题。它代表了人工智能和机器学习在汽车行业的重要应用。本文将…

web集群学习:源码安装nginx配置启动服务脚本、IP、端口、域名的虚拟主机

目录 1、源码安装nginx,并提供服务脚本。 2、配置基于ip地址的虚拟主机 3、配置基于端口的虚拟主机 4、配置基于域名的虚拟主机 1、源码安装nginx,并提供服务脚本。 1、源码安装会有一些软件依赖 (1)检查并安装 Nginx 基础依赖…

PHP智能人才招聘网站mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP智能人才招聘网站 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 下载地址 https://download.csdn.net/download/qq_41221322/88199392 视频演示 PH…

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用 TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包, 融合并改进了NLP和CV中的多种知识蒸馏技术&#xff0…

GB28181智慧可视化指挥控制系统之执法记录仪设计探讨

什么是智慧可视化指挥控制系统? 智慧可视化指挥控制平台通过4G/5G网络、WIFI实时传输视音频数据至指挥中心,特别是在有突发情况时,可以指定一台执法仪为现场视频监控器,实时传输当前画面到指挥中心,指挥中心工作人员可…

支持对接鸿蒙系统的无线模块及其常见应用介绍

近距离的无线通信得益于万物互联网的快速发展,基于集成部近距离无线连接,为固定和移动设备建立通信的蓝牙技术也已经广泛应用于汽车领域、工业生产及医疗领域。为协助物联网企业终端产品能快速接入鸿蒙生态系统,SKYLAB联手国产芯片厂家研发推…