机器学习深度学习——卷积神经网络(LeNet)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——池化层
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

卷积神经网络(LeNet)

  • 引言
  • LeNet
  • 模型训练
  • 小结

引言

之前的内容中曾经将softmax回归模型和多层感知机应用于Fashion-MNIST数据集中的服装图片。为了能应用他们,我们首先就把图像展平成了一维向量,然后用全连接层对其进行处理。
而现在已经学习过了卷积层的处理方法,我们就可以在图像中保留空间结构。同时,用卷积层代替全连接层的另一个好处是:模型更简单,所需参数更少。
LeNet是最早发布的卷积神经网络之一,之前出来的目的是为了识别图像中的手写数字。

LeNet

总体看,由两个部分组成:
1、卷积编码器:由两个卷积层组成
2、全连接层密集快:由三个全连接层组成
在这里插入图片描述
上图中就是LeNet的数据流图示,其中汇聚层也就是池化层。
最终输出的大小是10,也就是10个可能结果(0-9)。
每个卷积块的基本单元是一个卷积层、一个sigmoid激活函数和平均池化层(当年没有ReLU和最大池化层)。每个卷积层使用5×5卷积核和一个sigmoid激活函数。
这些层的作用就是将输入映射到多个二维特征输出,通常同时增加通道的数量。(从上图容易看出:第一卷积层有6个输出通道,而第二个卷积层有16个输出通道;每个2×2池操作(步幅也为2)通过空间下采样将维数减少4倍)。卷积的输出形状那是由批量大小、通道数、高度、宽度决定。
为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本(也就是把四维的输入转换为全连接层期望的二维输入,第一维索引小批量中的样本,第二维给出给个样本的平面向量表示)。
LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量(代表0-9是个结果)。
深度学习框架实现此类模型非常简单,用一个Sequential块把需要的层连接在一个就可以了,我们对原始模型做一个小改动,去掉最后一层的高斯激活:

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
    # 输入图像和输出图像都是28×28,因此我们要先进行填充2格
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

上面的模型图示就为:
在这里插入图片描述
我们可以先检查模型,在每一层打印输出的形状:

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

输出结果:

Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])

模型训练

既然已经实现了LeNet,现在可以查看它在Fashion-MNIST数据集上的表现:

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

计算成本较高,因此使用GPU来加快训练。为了进行评估,对之前的evaluate_accuracy进行修改,由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
            # BERT微调所需(后面内容)
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

要使用GPU,我们要在正向和反向传播之前,将每一小批量数据移动到我们GPU上。
如下所示的train_ch6类似于之前定义的train_ch3。以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。
使用Xavier来随机初始化模型参数。有关于Xavier的推导和原理可以看下面的文章:
机器学习&&深度学习——数值稳定性和模型化参数(详细数学推导)
与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降,代码如下:

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):  #@save
    """用GPU训练模型"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc =  metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i+1) / num_batches, (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

此时我们可以开始训练和评估LeNet模型:

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

运行输出(这边我没有用远程的GPU,在自己本地跑了,本地只有CPU):

training on cpu
loss 0.477, train acc 0.820, test acc 0.795
8004.2 examples/sec on cpu

运行图片:
在这里插入图片描述

小结

1、卷积神经网络(CNN)是一类使用卷积层的网络
2、在卷积神经网络中,我们组合使用卷积层、非线性激活函数和池化层
3、为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数
4、传统卷积神经网络中,卷积块编码得到的表征在输出之前需要由一个或多个全连接层进行处理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/63109.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JMeter源码解析之结果收集器

JMeter源码解析之结果收集器 一、JMeter结果收集器概述二、单机模式三、分布式模式四、总结 一、JMeter结果收集器概述 JMeter是在压力领域中最常见的性能测试工具,由于其开源的特点,受到广大测试和开发同学的青睐。但是,在实际应用过程中&a…

数据结构 | 搜索和排序——搜索

目录 一、顺序搜索 二、分析顺序搜索算法 三、二分搜索 四、分析二分搜索算法 五、散列 5.1 散列函数 5.2 处理冲突 5.3 实现映射抽象数据类型 搜索是指从元素集合中找到某个特定元素的算法过程。搜索过程通常返回True或False,分别表示元素是否存在。有时&a…

LeetCode 热题 100 JavaScript--142. 环形链表 II

给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数…

HDFS架构刨析

HDFS架构刨析 概述HDFS架构图整体概述主角色:namenodefsimage内存元数据镜像文件edits log(Journal)编辑日志 从角色:datanode主角色辅助角色:secondarynamenode 重要特性主从架构分块存储机制副本机制namespace元数据…

快速搭建MQTT测试环境,手把手详细教程

文章目录 前言系统架构准备工具代理服务器客户端客户端 TEST-1客户端 TEST-2 验证消息传递订阅主题发布主题 前言 大家好,我是麦叔,之前有小伙伴建议出一期如何快速搭建一个MQTT协议的测试环境,这里要合理地使用现有的工具,其实很…

ubuntu18.04安装docker及docker基本命令的使用

官网安装步骤:https://docs.docker.com/desktop/install/ubuntu/ docker快速入门教程 Ubuntu-Docker安装和使用 docker官网 docker-hub仓库 1、常用指令 (1)镜像操作 # ############################# 以nginx为例 docker images docker p…

五分钟帮您理解Linux网络核心知识点——socket和epoll

关于linux网络相关的基础知识点,最热的两个就是socket和epoll,接下来我就用最简单的方式把他俩说清楚便于大家理解! Socket Socket 是一种进程间通信的方法,它允许位于同一主机(计算机)或使用网络连接起来…

Android SystemServer中Service的创建和启动方式(基于Android13)

Android SystemServer创建和启动方式(基于Android13) SystemServer 简介 Android System Server是Android框架的核心组件,运行在system_server进程中,拥有system权限。它在Android系统中扮演重要角色,提供服务管理和通信。 system …

右键文件夹 ------- 打开 vscode的方法

1、右键vscode点击属性 2、这是地址栏,一会复制即可 3、新建一个txt文件,将这个复制进去 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\VSCode] "Open with Code" "Icon""D:\\Microsoft VS Code\\Code.exe"[HKE…

人到中年不得已,保温杯里泡枸杞--送程序员

目录 一:你现在身体的体能状况如何?你有身体焦虑吗? 二:如何保持规律性运动? 三:你有哪些健康生活的好习惯? 大厂裁员,称35岁以后体能下滑,无法继续高效率地完成工作&…

【数据结构OJ题】删除有序数组中的重复项

原题链接:https://leetcode.cn/problems/remove-duplicates-from-sorted-array/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 用双指针算法,定义两个变量src和dst,一开始让src和dst指向num[ ]数组的第一个元素&a…

Java注解详细介绍

Java注解详细介绍 基于注解(Annotation-based)的Java开发无疑是最新的开发趋势.[译者注: 这是05年的文章,在2014年,毫无疑问,多人合作的开发,使用注解变成很好的合作方式,相互之间的影响和耦合可以很低]. 基于注解的开发将Java开发人员从繁琐笨重的配置文件中解脱出来. Java 5…

hacksudo3 通关详解

环境配置 一开始桥接错网卡了 搞了半天 改回来就行了 信息收集 漏洞发现 扫个目录 大概看了一眼没什么有用的信息 然后对着login.php跑了一下弱口令 sqlmap 都没跑出来 那么利用点应该不在这 考虑到之前有过dirsearch字典太小扫不到东西的经历 换个gobuster扫一下 先看看g…

自然语言处理:长文本场景下的关键词抽取实践

NLP专栏简介:数据增强、智能标注、意图识别算法|多分类算法、文本信息抽取、多模态信息抽取、可解释性分析、性能调优、模型压缩算法等 专栏详细介绍:NLP专栏简介:数据增强、智能标注、意图识别算法|多分类算法、文本信息抽取、多模态信息抽取、可解释性分析、性能调优、模型…

HBase-读流程

创建连接同写流程。 (1)读取本地缓存中的Meta表信息;(第一次启动客户端为空) (2)向ZK发起读取Meta表所在位置的请求; (3)ZK正常返回Meta表所在位置&#x…

SpringBoot使用@Autowired将实现类注入到List或者Map集合中

前言 最近看到RuoYi-Vue-Plus翻译功能 Translation的翻译模块配置类TranslationConfig,其中有一个注入TranslationInterface翻译接口实现类的写法让我感到很新颖,但这种写法在Spring 3.0版本以后就已经支持注入List和Map,平时都没有注意到这…

自制免费 SQL 闯关自学网,代码开源!

大家好,我是鱼皮。 相信很多学编程的同学都学习过 SQL 吧?SQL 作为数据库查询语言,实在是太重要了,可以说是程序员、产品经理、数据分析同学的必备技能。 为了帮助大家自学 SQL,这段时间,我一个人做了个 …

ios_base::out和ios::out、ios_base::in和ios::in、ios_base::app和ios::app等之间有什么区别吗?

2023年8月2日,周三晚上 今天我看到了这样的两行代码: std::ofstream file("example.txt", std::ios_base::out);std::ofstream file("example.txt", std::ios::out);这让我产生了几个疑问: 为什么有时候用ios_base::o…

智慧城市规划新引擎:探秘数字孪生中的二维与三维GIS技术差异

智慧城市作为人类社会发展的新阶段,正日益引领着我们迈向数字化未来的时代。在智慧城市的建设过程中,地理信息系统(GIS)扮演着举足轻重的角色。而在GIS的发展中,二维和三维GIS作为两大核心技术,在城市规划与…

C#使用libmodbus库与工业设备进行读写测试

一.编译libmodbus库供C#使用 如何编译?请移步:https://blog.csdn.net/weixin_42205408/article/details/119530811 上面博主的文章除了所写的modbus.cs内的代码有点问题外(可能上面博主和我的Win 10 64位 Visual Studio 2019平台不一样吧&a…