29 残差网络 ResNet【李沐动手学深度学习v2课程笔记】

目录

1. ResNet

1.1 残差块

1.2 ResNet块

1.3 总结

2. 代码实现

2.1 残差块

2.2 ResNet模型

2.3 训练模型

3. 小结


1. ResNet

1.1 残差块

只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。 

1.2 ResNet块

五个stage 只不过是ResNet块

1.3 总结

2. 代码实现

2.1 残差块

 ResNet沿用了VGG完整的3×3卷积层设计。 残差块里首先有2个有相同输出通道数的3×3卷积层。 每个卷积层后接一个批量规范化层和ReLU激活函数。 然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。

3x3卷积层-批量规范化层-ReLU激活函数

3x3卷积层-批量规范化层-ReLU激活函数

如果想改变通道数,就需要引入一个额外的1×1卷积层来将输入变换成需要的形状后再做相加运算。 残差块的实现如下:

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Residual(nn.Module):  #@save
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

此代码生成两种类型的网络: 一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。 另一种是当use_1x1conv=True时,添加通过1×1卷积调整通道和分辨率。

下面我们来查看输入和输出形状一致的情况。

blk = Residual(3,3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape
torch.Size([4, 3, 6, 6])

我们也可以在增加输出通道数的同时,减半输出的高和宽。

blk = Residual(3,6, use_1x1conv=True, strides=2)
blk(X).shape
torch.Size([4, 6, 3, 3])

2.2 ResNet模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7×7卷积层后,接步幅为2的3×3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet在后面接了4个由Inception块组成的模块。 ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。 第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。 之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

下面我们来实现这个模块。注意,我们对第一个模块做了特别处理。

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

接着在ResNet加入所有残差块,这里每个模块使用2个残差块。

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

最后,与GoogLeNet一样,在ResNet中加入全局平均汇聚层,以及全连接层输出。

net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

每个模块有4个卷积层(不包括恒等映射的1×1卷积层)。 加上第一个7×7卷积层和最后一个全连接层,共有18层。 因此,这种模型通常被称为ResNet-18。 通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,例如更深的含152层的ResNet-152。 虽然ResNet的主体架构跟GoogLeNet类似,但ResNet架构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。 图7.6.4描述了完整的ResNet-18。

在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。 在之前所有架构中,分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:      torch.Size([1, 512, 1, 1])
Flatten output shape:        torch.Size([1, 512])
Linear output shape:         torch.Size([1, 10])

2.3 训练模型

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

3. 小结

  • 学习嵌套函数(nested function)是训练神经网络的理想情况。在深层神经网络中,学习另一层作为恒等映射(identity function)较容易(尽管这是一个极端情况)。

  • 残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。

  • 利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播。

  • 残差网络(ResNet)对随后的深层神经网络设计产生了深远影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/450128.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker学习——Dock镜像

什么是Docker镜像 Docker 镜像类似于虚拟机镜像,可以将它理解为一个只读的模板。 一个镜像可以包含一个基本的操作系统环境,里面仅安装了 Apache 应用程序(或 用户需要的其他软件) 可以把它称为一个 Apache 镜像。镜像是创建 Do…

POS 之 提款密钥与验证者密钥

回顾之前的文章 文章标题文章地址🍧ETH网络中的账户https://blog.csdn.net/weixin_29491885/article/details/136318898🍨我们为什么需要助记词https://blog.csdn.net/weixin_29491885/article/details/135860211🧁一组助记词走遍天下也不怕…

嵌入式学习第二十六天!(网络传输:TCP编程、HTTP协议)

TCP通信: 1. TCP发端: socket -> connect -> send -> recv -> close 2. TCP收端: socket -> bind -> listen -> accept -> recv -> send -> close 3. TCP需要用到的函数: 1. co…

excel函数

1.VLOOKUP(B2.[测试部门.xls]结果集!$A:$C,1,FALSE)

2023年第三届中国高校大数据挑战赛第二场赛题D题赛题:行业职业技术培训能力评价(成品论文 代码与思路 视频讲解)

赛题 中国是制造业大国,产业门类齐全,每年需要培养大量的技能娴熟的技术工人进入工厂。某行业在全国有多所不同类型(如国家级、省级等)的职业技术培训学校,进行 5 种技能培训。学员入校时需要进行统一的技能考核&…

node管理器 nvm

需求背景:如果出现多个项目 不同项目使用的node版本不一致 需要切换node版本 如果每次单独下载对应的node版本太过于麻烦,使用nvm可以下载多个node版本,通过nvm切换控制使用哪个node版本 nvm下载地址:https://github.com/coreybut…

0-13 - 准备:智能指针类和异常类

---- 整理自狄泰软件唐佐林老师课程 文章目录 1. 智能指针示例1.1 内存泄漏(臭名昭著的Bug )1.2 当代 C 软件平台中的智能指针1.3 智能指针的设计方案1.4 智能指针的使用军规1.5 编程实验:智能指针示例 2. 异常类构建2.1 现代 C 库必然包含充…

蜂窝物联:智慧畜牧养殖解决方案

我国是一个畜牧大国,在实现畜牧业发展的过程中,面临着企业生产管理水平低、政府监管薄弱、环境污染、行业数据资源分散等问题,阻碍了现代畜牧业的快速发展。 近年来,蜂窝物联针对畜牧业的发展现状,借助新一代物联网和…

华为OD七日集训第2期 - 按算法分类,由易到难,循序渐进,玩转OD

目录 一、适合人群二、本期训练时间三、如何参加四、七日集训第 2 期五、精心挑选21道高频100分经典题目,作为入门。第1天、逻辑分析第2天、字符串处理第3天、矩阵第4天、深度优先搜索dfs算法第5天、回溯法第6天、二分查找第7天、双指针 大家好,我是哪吒…

新闻媒体软文发布,提升企业宣传效果的最佳方法!

在新闻媒体上发布宣传效果确实很高,可以帮助企业提高宣传效果,为企业打开知名度。迅推客新闻媒体软文发布有很多优势。如果写软文推广公司,可以有很多方法,比如用软文推广公司产品,介绍公司产品的亮点,其实…

【QT+QGIS跨平台编译】之七十一:【QGIS_Analysis跨平台编译】—【qgsrastercalclexer.cpp生成】

文章目录 一、Flex二、生成来源三、构建过程一、Flex Flex (fast lexical analyser generator) 是 Lex 的另一个替代品。它经常和自由软件 Bison 语法分析器生成器 一起使用。Flex 最初由 Vern Paxson 于 1987 年用 C 语言写成。 “flex 是一个生成扫描器的工具,能够识别文本中…

一口气看完西汉210年历史

1、刘邦建国 公元前202,刘邦在垓下之战中击败楚王项羽,终结了历时7年的秦末大乱,建立西汉王朝。 西汉全盛时期地图 公元前201年,匈奴单于冒顿引兵攻打太原,异姓诸侯王之一的汉王信战败投降,刘邦被迫亲自率…

Python·算法·每日一题(3月12日) 删除链表的倒数第 N 个结点

题目 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 示例 示例 1: 输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5]示例 2: 输入:head [1], n 1 输出:[]示例…

server win搭建apache网站服务器+php网站+MY SQL数据库调用电子阅览室

一、适用场景: 1、使用开源的免费数据库Mysql; 2、自己建网站的发布; 3、使用php代码建网站; 4、使用windows server作为服务器; 5、使用apache作为网站服务器。 二、win server 中apache网站服务器搭建 &#xff0…

【v4l2】V4L2框架-videobuf2(二)

系列文章目录 【V4L2】V4L2框架简述 【V4L2】V4L2框架之驱动结构体 【V4L2】V4L2子设备 【V4L2】V4L2框架-media device 【V4L2】V4L2框架-videobuf2 文章目录 系列文章目录用户空间的操作/dev/video 节点与 videobuf2 联系编程注意事项 用户空间的操作 用户空间 stream 操作 …

【rk3229 android7.1.2 替换默认输入法】

问题平台描述 问题描述解决方法 郑重声明:本人原创博文,都是实战,均经过实际项目验证出货的 转载请标明出处:攻城狮2015 Platform: Rockchip CPU:rk3229 OS:Android 7.1.2 Kernel: 3.10 问题描述 国内客户,觉得安卓自带的输入法不好用&#x…

C语言从入门到熟悉------第二阶段

printf的用法 printf的格式有四种: (1)printf("字符串\n"); 其中\n表示换行的意思。其中n是“new line”的缩写,即“新的一行”。此外需要注意的是,printf中的双引号和后面的分号必须是在英文输入法下。双引…

如何选择满足业务需求的CRM系统?六大评估标准全解析!

任何企业在最终部署CRM管理系统前,都会经历一系列决断环节,例如是否要使用CRM、选择什么样的系统、前期投入是多少、预期的投资回报率等等。在挑选CRM系统这个环节,企业更是面临着大量的选择。市场上CRM厂商数量众多,产品宣传让人…

【Python】一文带你了解如何获取 Python模块 安装路径

【Python】一文带你了解如何获取 Python模块 安装路径 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅…

ICC2:function eco / premask eco参考脚本

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 相关文章链接: ICC2:修short参考脚本 eco_netlist -by_verilog -file eco.v -write_changes eco.tcl source eco tcl place_eco_cells -eco_change_cells -no_legalize place_eco_cells -eco_cha…