李沐29_残差网络ResNet——自学笔记

残差网络

残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。

残差块

串联一个层改变函数类,我们希望扩大函数类,残差块加入快速通道来得到f(x)=x+g(x)的结果

ResNet块

1.高宽减半的ResNet块(步幅2)

2.后接多个高宽不变的ResNet块

ResNet架构

1.类似VGG和GoogLeNet总体架构

2.但替换成ResNet块

总结

残差块使得很深的网络更加容易训练,甚至可以训练一千层的网络

代码实现

ResNet沿用了VGG完整的3X3卷积层设计。 残差块里首先有2个有相同输出通道数的3X3卷积层。 每个卷积层后接一个批量规范化层和ReLU激活函数。 然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。 如果想改变通道数,就需要引入一个额外的
1X1卷积层来将输入变换成需要的形状后再做相加运算。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Residual(nn.Module):  #save
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。 另一种是当use_1x1conv=True时,添加通过1X1卷积调整通道和分辨率。

输入与输出形状一致

blk = Residual(3,3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape
torch.Size([4, 3, 6, 6])

增加输出通道的同时,减半高和宽

blk = Residual(3,6, use_1x1conv=True, strides=2)
blk(X).shape # batch size, channel, h, w
torch.Size([4, 6, 3, 3])

ResNet模型

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

观察ResNet的不同模块的输入形状是如何变化。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])

训练模型

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.016, train acc 0.995, test acc 0.915
1553.6 examples/sec on cuda:0

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/535676.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android开发之移除权限

Android开发之移除权限 在Android开发结束后,我们将build好的App Bundle上传至Google Play。这时Google可能会提示一些需要解决的错误,比如xx权限需要限制,需要解释为什么需要这些权限,需要添加文字描述和视频链接,但…

MicroCinema与CCEdit:让文生视频兼具创造性与可控性

编者按:随着视频生成技术的飞速进步,我们见证了人工智能技术在视频清晰度、长视频连贯性以及对物理变化理解和镜头转换处理能力方面的显著提升。不过,这些高质量的生成结果是否完全符合我们的需求呢?显然,并非总是如此…

如何借助AI高效完成写作提纲

AI变革力量:未来数据中心的智能化之旅! 在当今这个信息爆炸的时代,人工智能(AI)在众多领域展现出了它的能力,特别是在写作领域。AI写作工具不仅能够帮助我们高效地生成内容,还能在一定程度上提升…

利用弱监督定位的高分辨率乳腺癌筛查图像的可解释分类器

标准筛查乳房 X 光检查由每个乳房的两张高分辨率 X 射线组成,分别从侧面(“中外侧”或 MLO 视图)和上方(“头尾”或 CC 视图)拍摄,总共四张图像。 乳房X光照片的分辨率非常高,而大多数无症状癌症…

WebGL异步绘制多点

异步绘制线段 1.先画一个点 2.一秒钟后&#xff0c;在左下角画一个点 3.两秒钟后&#xff0c;我再画一条线段 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"…

[2024年4月最新]Python安装教程

一、Python下载 1、进入Python官网 官网地址&#xff1a;https://www.python.org 2、点击【Downloads】展开后点击【Windows】跳转到下载python版本页面&#xff0c;选择"Stable Releases"稳定版本&#xff0c;我下载的是Python 3.10.10版本&#xff0c;所以找到【…

IDEA中无法保存设置 Cannot Save Settings

确定原因: 在IDEA中父工程不应该存在有子工程的相关东西 首先,这是我的DCYJ项目(观察右侧的Content Root) 其次,这是我的EAPOFode项目(观察右侧的Content Root爆红处) 最后我将DCYJ项目右侧的Content Root全部删掉

共享单车到底是什么通信原理

我们经常骑的共享单车到底是什么通信原理&#xff0c;有人了解过吗&#xff1f; 一、智能车锁 共享单车最核心的硬件是智能车锁&#xff0c;主要用于实现控制和定位功能。 车锁内集成了嵌入式芯片&#xff08;通信模块&#xff09;&#xff0c;GPS模块和物联网SIM卡。 智能锁制…

DP例题详解(二)最短编辑距离和编辑距离

902. 最短编辑距离 - AcWing题库 #include<bits/stdc.h>using namespace std;const int N1010;int n,m; char a[N],b[N]; int f[N][N];int main() {cin>>n>>a1;cin>>m>>b1;for(int i0;i<m;i)f[0][i]i;for(int i0;i<n;i)f[i][0]i;//先初始化…

用户状态保持机制-Session

0、业务需求 会话&#xff1a;web应用中的会话是指一个客户端浏览器和服务器之间连续发生的一系列请求和响应的过程。 会话状态&#xff1a;web应用中的会话状态是指web服务器与浏览器在会话过程中产生的状态信息&#xff0c;借助会话状态&#xff0c;服务器能够把属于同一会话…

线性表概念及实现1

文章目录 前言一、线性表1.定义2.特点3.一般线性表的抽象数据类型定义 二、线性表的顺序存储&#xff08;顺序表&#xff09;1.基本概念2.数组实现顺序表3.顺序表中基本操作的具体实现 总结 前言 T_T此专栏用于记录数据结构及算法的&#xff08;痛苦&#xff09;学习历程&#…

纯小白蓝桥杯备赛笔记--DAY14(计算几何)

文章目录 计算几何基础平面几何距离圆的周长和面积圆与圆之间的关系&#xff1a;海伦公式计算三角形面积点到直线的距离 点积和叉积例题&#xff1a; 点和线的关系点的表示形式和代码判断点在直线的那边点到线的垂足点到线的距离例题-1242例题-1240升级--点到线段的距离--1285 …

指定世界TOP名校|医学研究学者公派美国麻省理工学院做博士后

W医生公派博士后条件为&#xff1a;世界TOP100的知名高校&#xff0c;研究方向相符且前沿。最终我们用世界顶级高校-美国麻省理工学院&#xff08;MIT&#xff09;的博士后邀请函助其获得单位资助&#xff0c;顺利通过签证并出国。 W医生背景&#xff1a; 申请类型&#xff1a…

LC 501.二叉搜索树中的众数

501.二叉搜索树中的众数 给你一个含重复值的二叉搜索树&#xff08;BST&#xff09;的根节点 root &#xff0c;找出并返回 BST 中的所有 众数&#xff08;即&#xff0c;出现频率最高的元素&#xff09;。 如果树中有不止一个众数&#xff0c;可以按 任意顺序 返回。 假定 …

全国贫困县DID数据(2008-2022年)

数据来源&#xff1a;国W院扶贫开发领导小组办公室 时间跨度&#xff1a;2008-2022年 数据范围&#xff1a;各县域 数据指标 年份 县域名称 所属地市 所属省份 县域代码 是否贫困县(是为1&#xff0c;否为0) 参考文献&#xff1a; [1]马雯嘉,吴茂祯.从全面脱贫到乡村振兴…

VQ-BeT: Behavior Generation with Latent Actions 代码复现(Mujoco 安装)

代码地址&#xff1a;https://github.com/jayLEE0301/vq_bet_official.git 创建环境 conda create -n vq-bet python3.9 conda activate vq-bet拉取库 git clone https://github.com/jayLEE0301/vq_bet_official.git export PROJ_ROOT$(pwd)安装pytorch conda install pyto…

AI论文速读 | 线性时间序列预测模型分析

论文标题&#xff1a;An Analysis of Linear Time Series Forecasting Models 作者&#xff1a; William Toner&#xff0c; Luke Darlow 机构&#xff1a;爱丁堡大学&#xff08;Edinburgh&#xff09;&#xff0c;华为研究中心&#xff08;爱丁堡&#xff09; 论文链接&am…

开源项目若依放大招了?

前言 鉴于之前写了篇插件式相关的文章&#xff0c;阅读量比起其它文章可不要好太多&#xff0c;所以我决定继续这个主题&#xff01; 以前我们公司用的就是Ruoyi&#xff0c;代码比较简单易懂。但是有些功能确实用不上&#xff0c;比如部门和岗位&#xff0c;每次新项目我拉了…

Docker安装及开启远程访问

这几天有人问我docker是怎么开启远程服务的&#xff1f; 正好之前我做过这件事情&#xff0c;并且写了相关的笔记&#xff0c;现在整理为一篇博客发出来。 安装Docker 首先更新一下自己的yum版本 yum update安装一下所需要的软件包 yum-config-manager --add-repo http://…

Java复习第十七天学习笔记(转发、重定向,GET,POST),附有道云笔记链接

【有道云笔记】十七 4.3 转发、重定向、Get、POST、乱码 https://note.youdao.com/s/GD5TRksQ 一、转发 转发&#xff1a;一般查询了数据之后&#xff0c;转发到一个jsp页面进行展示 req.setAttribute("list", list); req.getRequestDispatcher("student_lis…