李沐23_LeNet——自学笔记

手写的数字识别

知名度最高的数据集:MNIST
1.训练数据:50000

2.测试数据:50000

3.图像大小:28✖28

4.10类

总结

1.LeNet是早期成功的神经网络

2.先使用卷积层来学习图片空间信息

3.使用全连接层来转换到类别空间

代码实现

LeNet由两部分组成:卷积编码器和全连接层密集块

import torch
from torch import nn
from d2l import torch as d2l

class Reshape(torch.nn.Module):
  def forward(self,x):
    return x.view(-1,1,28,28) # 原图:28✖28,填充后是32✖32

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(), # 6个通道
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

我们将一个大小为28✖28的单通道(黑白)图像通过LeNet。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。

第一个卷积层使用2个像素的填充,来补偿5✖5卷积核导致的特征减少。

相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 下载fashion_MNIST数据集
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 9258920.33it/s] 


Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 171125.91it/s]


Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3169968.34it/s]


Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 4336669.41it/s]

Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw




/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
def evaluate_accuracy_gpu(net, data_iter, device=None): #计算模型精度
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

训练函数

1.训练函数train_ch6也类似于3.6节中定义的train_ch3。

2.使用高级API创建的模型作为输入,并进行相应的优化。

3.Xavier随机初始化模型参数。
4.使用交叉熵损失函数和小批量随机梯度下降。

#用GPU训练模型,比第三章多了device
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], # 动画效果
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.461, train acc 0.827, test acc 0.793
35664.6 examples/sec on cuda:0

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/530886.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习记录:bazel和cmake运行终端指令

Bazel和CMake都是用于构建软件项目的工具&#xff0c;但它们之间有一些重要的区别和特点&#xff1a; Bazel&#xff1a; Bazel是由Google开发的构建和测试工具&#xff0c;用于构建大规模的软件项目。它采用一种称为“基于规则”的构建系统&#xff0c;它利用构建规则和依赖关…

Android 属性动画及自定义3D旋转动画

Android 动画框架 其中包括&#xff0c;帧动画、视图动画&#xff08;补间动画&#xff09;、属性动画。 在Android3.0之前&#xff0c;视图动画一家独大&#xff0c;之后属性动画框架被推出。属性动画框架&#xff0c;基本可以实现所有的视图动画效果。 视图动画的效率较高…

第十届蓝桥杯大赛个人赛省赛(软件类) CC++ 研究生组-RSA解密

先把p&#xff0c;q求出来 #include<iostream> #include<cmath> using namespace std; typedef long long ll; int main(){ll n 1001733993063167141LL, sqr sqrt(n);for(ll i 2; i < sqr; i){if(n % i 0){printf("%lld ", i);if(i * i ! n) pri…

关于VMware安装win系统的磁盘扩容与缩减

使用VMware虚拟机安装虚拟windows系统时&#xff0c;如果创建虚拟磁盘的空间预留不足&#xff08;特别是C判空间&#xff09;&#xff0c;安装win系统后&#xff0c;由于默认win系统在安装时分配的healthy健康盘位于系统C盘临近区域&#xff0c;此时如果需要增加C盘虚拟空间&am…

张驰咨询:深圳六西格玛绿带培训5天专业能力提升课程

张驰咨询即将在深圳开设的六西格玛绿带5天培训班&#xff0c;是针对希望在质量管理、项目管理等领域提升自己能力的专业人士的一次重要机会。六西格玛作为一种旨在减少缺陷、提高效率和质量的方法论&#xff0c;已经被全球众多企业采用。绿带认证作为进入这一领域的门槛之一&am…

【产品】ADW300 无线计量仪表 用于计量低压网络的三相有功电能

1 概述 ADW300 无线计量仪表主要用于计量低压网络的三相有功电能&#xff0c;具有体积小、精度高、功能丰富等优点&#xff0c;并且可选通讯方式多&#xff0c;可支持 RS485 通讯和 Lora、2G、NB、4G 等无线通讯方式&#xff0c;增加了外置互感器的电流采样模式&#xff0c;从…

利用AI开源引擎平台:构建文本、图片及视频内容审核系统|可本地部署

网络空间的信息量呈现出爆炸式增长。在这个信息多元化的时代&#xff0c;内容审核系统成为了维护网络秩序、保护用户免受有害信息侵害的重要工具。本文将探讨内容审核系统的核心优势、技术实现以及在不同场景下的应用。 开源项目介绍(可本地部署&#xff0c;支持国产化) 思通数…

法拉电容Farad capacitor与锂电池的区别和对比!

法拉电容也称为超级电容。超级电容器是介于传统电容器和充电电池之间的一种新型环保储能装置&#xff0c;其容量可达0.1F至>10000F法拉&#xff0c;与传统电容器相比&#xff1a;它具有较大的容量、较高的能量、较宽的工作温度范围和极长的使用寿命&#xff1b;而与蓄电池相…

linux网络知识

七层模型 应用层 为操作系统或者网络应用程序提供网络服务的接口 表示层 解决不同系统之间的通信问题&#xff0c;负责数据格式的转换 会话层 自动收发包&#xff0c;自动寻址&#xff0c;负责建立和断开连接 传输层 将上层数据分段并提供端到端的…

机器学习 —— 使用机器学习进行情感分析 演示版

机器学习 —— 使用机器学习进行情感分析 详细介绍版 机器学习 —— 使用机器学习进行情感分析 演示版 一、项目构想 在现代互联网时代里&#xff0c;人们的意见、评论和建议已成为政治科学和企业的宝贵资源。借助现代技术&#xff0c;我们现在能够最有效地收集和分析此类数据。…

第十五届蓝桥杯测试组模拟赛两期

文章目录 功能测试一期-场景法-登录功能一期-等价类-边界值-添加用户账号输入框一期-登录-缺陷报告一期- UI自动化测试一期-单元测试-路径覆盖二期-正交法-搜索条件组合二期-测试用例二期-缺陷报告二期-自动化测试二期-单元测试-基本路径覆盖 功能测试 一期-场景法-登录功能 …

什么是 DNS 记录?

DNS记录是存储在DNS服务器上的文本指令。它们表明与一个域名相关的IP地址&#xff0c;也可以提供其他信息。DNS记录是计算机用语&#xff0c;指域名系统&#xff08;Domain Name System&#xff0c;简称DNS&#xff09;中的一条记录&#xff0c;这条记录存储于DNS服务器中。每一…

html基础(2)(链接、图像、表格、列表、id、块)

1、链接 <a href"https://www.example.com" target"_blank" title"Example Link">Click here</a> 在上示例中&#xff0c;定义了一个链接&#xff0c;在网页中显示为Click here&#xff0c;鼠标悬停指示为Example Link&#xff0c…

Bootstrap 5 保姆级教程(一):容器 网格系统

一、容器 1.1 固定宽度&#xff08;.container&#xff09; .container 类用于固定宽度并支持响应式布局的容器。 以下实例中&#xff0c;我们可以尝试调整浏览器窗口的大小来查看容器宽度在不同屏幕中等变化&#xff1a; <!doctype html> <html lang"en&quo…

4.2.4 理解路由器数据包过程

1、实验目的 通过本实验可以掌握&#xff1a; 了解IP路由原理了解数据包封装和解封装的概念了解路由器路由和交换过程 2、实验拓扑 观察路由器路由数据包过程的实验拓扑如图4-3所示&#xff0c;设备接口地址信息如表4-2所示。 图4-3 观察路由器路由数据包过程的实验拓扑 本…

文件批量重命名,繁体中文秒变简体中文,轻松实现高效翻译

在数字化时代&#xff0c;我们的工作、学习和生活都离不开电脑文件。随着时间的推移&#xff0c;文件数量不断增加&#xff0c;管理起来变得越来越困难。你是否曾经为如何高效、有序地管理文件而烦恼&#xff1f;现在&#xff0c;有一款强大的文件批量重命名工具&#xff0c;它…

CAXA3D工艺图表2019版 下载地址及安装教程

CAXA 3D工艺图表是CAXA 3D软件中的一个功能模块&#xff0c;专门用于创建和展示工艺流程和工艺图。它为工程师和设计师提供了一种直观和清晰的方式来表示和记录产品的工艺过程。 使用CAXA 3D工艺图表&#xff0c;用户可以创建具有层次结构的工艺流程图&#xff0c;显示产品的各…

Aigtek:高压放大器常见问题及解决方法有哪些

高压放大器是一种关键的电子设备&#xff0c;常用于实验室、工业和科学研究中&#xff0c;用于放大电压信号。然而&#xff0c;像所有的电子设备一样&#xff0c;高压放大器也可能遇到各种常见问题。下面安泰电子将介绍一些高压放大器常见的问题以及相应的解决方法。 一、电源问…

深度学习实践(一)基于Transformer英译汉模型

本文目录 前述一、环境依赖二、数据准备1. 数据加载程序解析word_tokenize()将字符串分割为一个个的单词&#xff0c;并由列表保存。 2. 构建单词表程序解析&#xff08;1&#xff09;将列表里每个子列表的所有单词合并到一个新列表&#xff08;没有子列表&#xff09;中。&…

4.9日总结

1.MySQL概述 1.数据库基本概念&#xff1a;存储数据的仓库&#xff0c;数据是有组织的进行存储 2.数据库管理系统&#xff1a;操纵和管理数据库的大型软件 3.SQL&#xff1a;操作关系型数据库的编程语言&#xff0c;定义了一套操作型数据库统一标准 2.MySQL数据库 关系型数…