用pytorch实现Resnet

     ResNet(Residual Network)是一种深度卷积神经网络架构,由Kaiming He等人于2015年提出。它在计算机视觉领域引起了革命性的变革,使得训练更深的神经网络成为可能,超越了传统网络架构的限制。
     ResNet的主要创新在于残差学习的概念。传统神经网络存在梯度消失的问题,即随着梯度在多个层传播,其数值变得指数级小,从而阻碍了学习过程,限制了网络的有效训练深度。

      ResNet通过引入跳跃连接或快捷连接来解决这个问题。与直接拟合期望的映射不同,ResNet学习拟合残差映射,即期望输出与输入之间的差异。这些跳跃连接允许梯度直接在多个层之间流动,并缓解了梯度消失的问题。
       ResNet的核心构建块是残差块,它由两个卷积层、批归一化和ReLU激活函数组成。残差块接收输入张量,通过这些层,然后将输入张量和卷积层的输出相加。加法操作将原始输入与学到的残差相结合,创建了梯度流的快捷路径


    残差块

residual结构使用了一种shortcut的连接方式,也可理解为捷径。让特征矩阵隔层相加,注意F(X)和X形状要相同,所谓相加是特征矩阵相同位置上的数字进行相加

 残差块里首先有2个有相同输出通道数的3*3卷积层。 每个卷积层后接一个批量规范化层和ReLU激活函数。 然后我们通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。 如果想改变通道数,就需要引入一个额外的1*1卷积层来将输入变换成需要的形状后再做相加运算

 

import torch
import torch.nn as nn


# Residual Block (Basic Building Block of ResNet)
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out


# ResNet Architecture
class ResNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride=1):
        layers = []
        layers.append(ResidualBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels

        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)

        return out


# Create an instance of the ResNet model
model = ResNet(num_classes=1000)

# Print the model architecture
print(model)

 ResNet通常由多个堆叠的残差块组成,深度逐渐增加。网络架构包括不同的变体,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152,其中数字表示网络中总层数。较深的变体在图像分类、目标检测和分割等各种计算机视觉任务中表现出更好的性能。
ResNet的一个显著优势是可以训练非常深的网络而不降低性能。它使得可以训练超过100层的网络,同时仍然保持准确性和收敛性。此外,跳跃连接使得轻松实现恒等映射,意味着可以将浅层网络转变为更深的网络而不降低性能。
      ResNet对深度学习领域产生了重大影响,并成为各种计算机视觉应用中广泛采用的架构。它的残差学习概念也启发了其他使用跳跃连接的架构的发展,如DenseNet和Highway Networks。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/79632.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第G1周:生成对抗网络(GAN)入门

📌 基础任务:了解什么是生成对抗网络(GAN) 学习本文代码,并跑通代码 🎈进阶任务: 调用训练好的模型生成新图像 目录 一、理论基础 1.1生成器 1.2判别器 1.3基本原理 二、前期准备工作 2.…

基于CNN卷积神经网络的口罩检测识别系统matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ............................................................ % 循环处理每张输入图像 for…

汽车租赁管理系统/汽车租赁网站的设计与实现

摘 要 租赁汽车走进社区,走进生活,成为当今生活中不可缺少的一部分。随着汽车租赁业的发展,加强管理和规范管理司促进汽车租赁业健康发展的重要推动力。汽车租赁业为道路运输车辆一种新的融资服务形式、广大人民群众一种新的出行消费方式和…

Centos7 配置Docker镜像加速器

docker实战(一):centos7 yum安装docker docker实战(二):基础命令篇 docker实战(三):docker网络模式(超详细) docker实战(四):docker架构原理 docker实战(五):docker镜像及仓库配置 docker实战(六):docker 网络及数据卷设置 docker实战(七):docker 性质及版本选择 认知升…

区块链中slot、epoch、以及在slot和epoch中的出块机制,分叉原理(自己备用)

以太坊2.0中有两个时间概念:时隙槽slot 和 时段(周期)epoch。其中一个slot为12秒,而每个 epoch 由 32 个 slots 组成,所以每个epoch共384秒,也就是 6.4 分钟。 对于每个epoch,使用RANDAO伪随机…

Unity小项目__打砖块

//1.添加地面 1)创建一个平面,命名为Ground。 2)创建一个Materials文件夹,并在其中创建一个Ground材质,左键拖动其赋给平面Plane。 3)根据喜好设置Ground材质和Ground平面的属性。 // 2.创建墙体 1)创建一个Cube&…

基于Simulink的Chaos混沌电路设计与仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 07_001m 4.算法理论概述 混沌电路是一类特殊的非线性电路,其输出信号表现出无规律…

数字后端笔试题(1)DCG后congestion问题

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 已知某模块的DCG结果显示存在congestion,有congestion部分逻辑结构如下图: 问题1: 如何分析该电路有congestion问题的原因? 答:data selecti…

基于STM32+FreeRTOS的四轴机械臂

目录 项目概述: 一 准备阶段(都是些废话) 二 裸机测试功能 1.摇杆控制 接线: CubeMX配置: 代码: 2.蓝牙控制 接线: CubeMX配置 代码: 3.示教器控制 4.记录动作信息 5.执…

el-table 多个表格切换多选框显示bug

今天写了个功能,点击左侧的树做判断,一级树节点显示系统页面,二级树节点显示数据库页面,三级树节点显示表页面。 数据库页面和表页面分别有2个el-table ,上面的没有多选框,下面的有多选框 现在出现bug,在…

Linux学习之iptables过滤规则的使用

cat /etc/redhat-release看到操作系统是CentOS Linux release 7.6.1810,uname -r看到内核版本是3.10.0-957.el7.x86_64,iptables --version可以看到iptables版本是v1.4.21。 iptables -t filter -A INPUT -s 10.0.0.8 -j ACCEPT会在最后一行插入。 10…

winform 封装unity web player 用户控件

环境: VS2015Unity 5.3.6f1 (64-bit) 目的: Unity官方提供的UnityWebPlayer控件在嵌入Winform时要求读取的.unity3d文件路径(Src)必须是绝对路径,如果移动代码到另一台电脑,需要重新修改src。于是考虑使…

Hadoop学习:深入解析MapReduce的大数据魔力之数据压缩(四)

Hadoop学习:深入解析MapReduce的大数据魔力之数据压缩(四) 4.1 概述1)压缩的好处和坏处2)压缩原则 4.2 MR 支持的压缩编码4.3 压缩方式选择4.3.1 Gzip 压缩4.3.2 Bzip2 压缩4.3.3 Lzo 压缩4.3.4 Snappy 压缩4.3.5 压缩…

Apache JMeter

下载 Apache JMeter 并安装 java链接 打开 apache-jmeter-5.4.1\bin 找到jmeter.bat 双击打开 或者 ApacheJMeter.jar 双击打开 设置中文 找到 options 》choose Language 》chinese 新建 计划 创建线程组 添加Http请求 配置元件添加请求头参数(content-type&…

腾讯云 CODING 荣获 TiD 质量竞争力大会 2023 软件研发优秀案例

点击链接了解详情 8 月 13-16 日,由中关村智联软件服务业质量创新联盟主办的第十届 TiD 2023 质量竞争力大会在北京国家会议中心召开。本次大会以“聚焦数字化转型 探索智能软件研发”为主题,聚焦智能化测试工程、数据要素、元宇宙、数字化转型、产融合作…

报名开启 | HarmonyOS第一课“营”在暑期系列直播

<HarmonyOS第一课>2023年再次启航&#xff01; 特邀HarmonyOS布道师云集华为开发者联盟直播间 聚焦HarmonyOS 4版本新特性 邀您一同学习赢好礼&#xff01; 你准备好了吗&#xff1f; ↓↓↓预约报名↓↓↓ 点击关注了解更多资讯&#xff0c;报名学习

CS:GO升级 Linux不再是“法外之地”

在前天的VAC大规模封禁中&#xff0c;有不少Linux平台的作弊玩家也迎来了“迟到”的VAC封禁。   一直以来&#xff0c;Linux就是VAC封禁的法外之地。虽然大部分玩家都使用Windows平台进行游戏。但实际上&#xff0c;使用Linux畅玩CS:GO的玩家也不在少数。 以前V社主要打击W…

LVS - DR

LVS-DR 数据流向 客户端发送请求到 Director Server&#xff08;负载均衡器&#xff09;&#xff0c;请求的数据报文&#xff08;源 IP 是 CIP,目标 IP 是 VIP&#xff09;到达内核空间。Director Server 和 Real Server 在同一个网络中&#xff0c;数据通过二层数据链路层来传…

商城-学习整理-高级-商城业务-商品上架es(十)

目录 一、商品上架1、sku在ES中存储模型分析2、nested数据类型场景3、构造基本数据&#xff08;商品上架&#xff09; 二、首页1、项目介绍2、整合thymeleaf&#xff08;spring-boot下模板引擎&#xff09;渲染页面3、页面修改不重启服务器实时更新4、渲染二级三级数据 三、搭建…

「UG/NX」Block UI 面收集器FaceCollector

✨博客主页何曾参静谧的博客📌文章专栏「UG/NX」BlockUI集合📚全部专栏「UG/NX」NX二次开发「UG/NX」BlockUI集合「VS」Visual Studio「QT」QT5程序设计「C/C+&#