动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet

23卷积神经网络LeNet

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt

# 定义一个卷积神经网络
net = nn.Sequential(
    nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2
    nn.ReLU(), # 激活函数
    nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2

    nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5
    nn.ReLU(), 
    nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2

    nn.Flatten(), # 展平层:将多维输入展平为1维
    nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120
    nn.ReLU(),
    nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84
    nn.ReLU(), 
    nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)

# 通过在每一层打印输出的形状,我们可以检查模型
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32) # 随机生成一个形状为(1,1,28,28)的张量,作为输入
for layer in net:
    X = layer(X) # 将输入依次通过每一层
    print(layer.__class__.__name__, 'output shape: \t', X.shape) # 打印每一层的输出形状
"""
Conv2d output shape:     torch.Size([1, 6, 28, 28])
ReLU output shape:       torch.Size([1, 6, 28, 28])
AvgPool2d output shape:          torch.Size([1, 6, 14, 14])
Conv2d output shape:     torch.Size([1, 16, 10, 10])
ReLU output shape:       torch.Size([1, 16, 10, 10])
AvgPool2d output shape:          torch.Size([1, 16, 5, 5])
Flatten output shape:    torch.Size([1, 400])
Linear output shape:     torch.Size([1, 120])
ReLU output shape:       torch.Size([1, 120])
Linear output shape:     torch.Size([1, 84])
ReLU output shape:       torch.Size([1, 84])
Linear output shape:     torch.Size([1, 10])
"""
# 模型训练
batch_size = 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size) # 加载Fashion-MNIST数据集


#分类精度
def accuracy(y_hat,y): #@save
    """计算预测正确的数量"""
    #判断y_hat.shape是否为二维以上的矩阵
    #并且列数大于1
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        #axis = 1 表示按照每一行
        #argmax(axis = 1)得到每行最大值的下标
        y_hat = y_hat.argmax(axis = 1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval() # 将模型设置为评估模式
    metric = lp.Accumulator(2) # 正确预测数、预测总数
    with torch.no_grad(): # 禁用梯度计算
        for X, y in data_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(accuracy(net(X), y), y.numel()) # 累加正确预测数和样本总数
    return metric[0] / metric[1] # 返回精度

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight) # 初始化权重
    net.apply(init_weights) # 对网络应用权重初始化
    print('training on', device)
    net.to(device) # 将模型加载到设备上
    optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用随机梯度下降优化器
    loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数
    animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],
                           legend=['train loss', 'train acc', 'test acc']) # 动画工具,绘制训练曲线
    timer, num_batches = lp.Timer(), len(train_iter) # 计时器和批次数

    for epoch in range(num_epochs):
        metric = lp.Accumulator(3) # 训练损失之和,训练准确率之和,样本数
        net.train() # 训练模式
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将数据加载到设备上
            y_hat = net(X) # 前向传播
            l = loss(y_hat, y) # 计算损失
            l.backward() # 反向传播
            optimizer.step() # 更新参数
            with torch.no_grad(): # 禁用梯度计算
                metric.add(l * X.shape[0], lp.accuracy(y_hat, y), X.shape[0]) # 累加损失、准确率和样本数
            timer.stop()
            train_l = metric[0] / metric[2] # 计算平均训练损失
            train_acc = metric[1] / metric[2] # 计算平均训练准确率
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None)) # 更新动画
        test_acc = evaluate_accuracy_gpu(net, test_iter, device) # 计算测试集上的准确率
        animator.add(epoch + 1, (None, None, test_acc)) # 更新动画
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

lr, num_epochs = 0.5, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu()) # 训练模型
# d2l.plt.show() # 显示训练曲线
plt.show() # 显示训练曲线

# lr = 0.9,Sigmoid()
# loss 0.466, train acc 0.825, test acc 0.808

# lr = 0.1,Sigmoid()
# loss 1.277, train acc 0.551, test acc 0.568

# lr = 0.1,ReLU()
# loss 0.339, train acc 0.874, test acc 0.803

# lr = 0.5,ReLU()
# loss 0.302, train acc 0.887, test acc 0.857

# lr = 0.6,ReLU()
# loss 0.316, train acc 0.878, test acc 0.861

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/734576.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024 Testing Expo China – Automotive I 风丘与您相约上海世博馆

2024汽车测试及质量监控博览会(中国)——(Testing Expo China – Automotive)是面向整车、零部件和系统开发的各种技术和服务的盛会,展示了汽车测试、开发和验证技术的各个方面,每年在上海举行,…

如何识别商业电子邮件诈骗

复制此链接到微信打开阅读全部已发布文章 不要关闭它标签!我知道很少有词组比商业、电子邮件和妥协更无趣。 但这不是一篇无聊的文章:这是一篇关于电子邮件骗子的文章,根据联邦调查局的说法,他们每年通过诈骗人们赚取 260 亿美元…

Java程序之动物声音“模拟器”

题目: 设计一个“动物模拟器”,希望模拟器可以模拟许多动物的叫声和行为,要求如下: 编写接口Animal,该接口有两个抽象方法cry()和getAnimalName(),即要求实现该接口的各种具体的动物类给出自己的叫声和种类…

33.LengthFieldBasedFrameDecoder四个参数

第一种方式: 消息的长度和内容一起发送。 数据分为两部分,一部分是数据的长度,另一部分是数据内容本身。 构造方法参数 lengthFieldOffset 表示整个消息体内,消息长度字段的偏移量。就是记录消息长度的字节的开始位置。 lengthFieldLength 表示长度字段的长度。就是用…

02_ESP32+MicroPython 点亮LED灯

书接第1篇《01_ESP32 MicroPython开发环境搭建_eps32开发板-CSDN博客》 想要让一个引脚输出高电平,只需要找到对应的GPIO然后通过on()或者value(1)操作就可以,同理如果想要输出低电平让LED灯灭,只需要调用off()或者value(0)就行。 一、点亮…

Android开发实用必备的几款插件,提高你的开发速度

1.GsonFormat 使用方法:快捷键AltS也可以使用AltInsert选择GsonFormat,作用:速将json字符串转换成一个Java Bean,免去我们根据json字符串手写对应Java Bean的过程。 2.ButterKnife Zelezny 又叫黄油刀 使用方法:CtrlS…

Android面试题之动画+事件处理篇

1、Android 中的动画有哪几类 帧动画、补间动画、属性动画 2、动画能组合在一起使用么? 可以将动画组合在一起使用AnimatorSet, AnimatorSet.play() 播放当前动画的同时可以 .with() :将现有动画和传入的动画同时执行 .after() &#xff1a…

Android:知道类加载过程面试还是卡壳?干货总结,一网打净“类”的基础知识!

多线程进行类的初始化会出问题吗&#xff1f; 类的实例化触发时机。 <clinit>()方法和<init>()方法区别。 在类都没有初始化完毕之前&#xff0c;能直接进行实例化相应的对象吗? 类的初始化过程与类的实例化过程的异同&#xff1f; 一个实例变量在对象初始化…

禅道身份认证绕过漏洞(QVD-2024-15263)复现

禅道项目管理系统在开源版、企业版、旗舰版的部分版本中都存在此安全漏洞。攻击者可利用该漏洞创建任意账号实现未授权登录。 1.漏洞级别 高危 2.漏洞搜索 fofa: title"禅道"3.影响范围 v16.x < 禅道 < v18.12 &#xff08;开源版&#xff09; v6.x <…

自动驾驶仿真测试用例(完善版本)

进一步完善上述的测试用例&#xff0c;并根据不同的测试准备、车辆准备、车辆状态、车辆场景、车辆执行、可变因素、具体信号状态、通过标准和预期标准来详细描述每个测试用例。 用例编号测试类型测试项目测试描述车辆准备车辆状态车辆场景车辆执行可变因素具体信号状态通过标准…

Android应用--简、美音乐播放器添加电话监听

3. 控制音量 4. 获取专辑图片 5. 在线下载歌词 6. 在线搜索音乐 7. 在线下载音乐 8. 实现有趣功能–甩歌 9. 界面美化–实现专辑倒影 10.实现左右界面切换 11.实现在通知栏显示播放状态 12.实现音乐播放的桌面小控件 暂时想到这些功能&#xff0c;如果朋友们有什么建…

Pikachu靶场--越权漏洞

参考借鉴 pikachu之越权漏洞_pikachu越权漏洞-CSDN博客 水平越权 需要输入username和password进行登录 查看提示&#xff0c;获取username和password 输入其中一组账号信息进行登录 可以查看到个人信息 在URL中更改username的值-->回车 成功越权&#xff0c;登录到其他账号…

2024-06-23 编译原理实验4——中间代码生成

文章目录 一、实验要求二、实验设计三、实验结果四、附完整代码 补录与分享本科实验&#xff0c;以示纪念。 一、实验要求 在词法分析、语法分析和语义分析程序的基础上&#xff0c;将C−−源代码翻译为中间代码。 要求将中间代码输出成线性结构&#xff08;三地址代码&#…

CTFHUB-SSRF-POST请求

通过file协议访问flag.php文件内容 ?urlfile:///var/www/html/flag.php 右键查看页面源代码 需要从内部携带key发送post数据包即可获得flag ?urlhttp://127.0.0.1/flag.php 得到了key 构造POST请求数据包&#xff0c;进行url编码&#xff08;新建一个txt文件&#xff0c;…

【vue3|第12期】Vue3的Props详解:组件通信

日期&#xff1a;2024年6月19日 作者&#xff1a;Commas 签名&#xff1a;(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释&#xff1a;如果您觉得有所帮助&#xff0c;帮忙点个赞&#xff0c;也可以关注我&#xff0c;我们一起成长&#xff1b;如果有不对的地方&#xf…

微服务——重复消费(幂等解决方案)

目录 一、唯一ID机制二、幂等性设计三、状态检查机制四、利用缓存和消息队列五、分布式锁总结 在微服务中&#xff0c;防止重复消费的核心思想是通过设计使得操作一次与多次产生相同的效果&#xff0c;并为每次操作生成唯一的ID。这样&#xff0c;即使在消息被重复发送的情况下…

Stable Diffusion 插件安装与推荐,助力你的AI绘图

在上一篇文章我们安装了Stable Diffusion &#xff0c;这篇文章我们来安装Stable Diffusion的插件 Stable Diffusion的插件是绘画中重要的一环&#xff0c;好的插件可以让你的绘画更加得心应手 中英双语插件 为什么要安装中英双语插件呢&#xff0c;不能只安装中文插件吗&…

计算机组成原理 | 数据的表示、运算和校验(1)数值型数据

有了一个二进制代码&#xff0c;首先要知道他是带符号的还是不带符号的&#xff0c;接着要知道他是原码还是补码还是反码&#xff0c;最终才能确定他的真值。 补码和移码&#xff1a;符号相反、数值位相同 表示范围不理解 数的定点表示法 对于反码而言&#xff1a;10000000表示…

Android蓝牙开发(一)之打开蓝牙和设备搜索

private BluetoothManager bluetoothmanger; private​ BluetoothAdapter bluetoothadapter; /** 判断设备是否支持蓝牙 */ bluetoothmanger (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); bluetoothadapter bluetoothmanger.getAdapter(); if (bl…

46、基于自组织映射神经网络的鸢尾花聚类(matlab)

1、自组织映射神经网络的鸢尾花聚类的原理及流程 自组织映射神经网络&#xff08;Self-Organizing Map, SOM&#xff09;是一种用于聚类和数据可视化的人工神经网络模型。在鸢尾花聚类中&#xff0c;SOM 可以用来将鸢尾花数据集分成不同的类别&#xff0c;同时保留数据间的拓扑…