分类神经网络2:ResNet模型复现

目录

ResNet网络架构

ResNet部分实现代码


ResNet网络架构

论文原址:https://arxiv.org/pdf/1512.03385.pdf

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的,通过引入残差学习解决了深度网络训练中的退化问题,同时该网络结构特点主要表现为3点,超深的网络结构(超过1000层)、提出residual(残差结构)模块和使用Batch Normalization 加速训练(丢弃dropout)。ResNet网络的模型结构如下:

ResNet网络通过“跳跃连接”,将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。即神经网络学习到该层的参数是冗余的时候,它可以选择直接走这条“跳接”曲线(shortcut connection),跳过这个冗余层,而不需要再去拟合参数使得H(x)=F(x)=x。同时通过这种连接方式不仅保护了信息的完整性(避免卷积层堆叠存在的信息丢失),整个网络也只需要学习输入、输出差别的部分,这克服了由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题(梯度消失或梯度爆炸)。

残差模块如下图示:

残差结构有两种,常规残差和瓶颈残差常规残差:由2个3x3卷积层堆叠而成,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度(随着网络深度的加深,这种残差模块在实践中并不十分有效);瓶颈残差:依次由1x1 、3x3 、1x1个卷积层构成,这里1x1卷积,能够对通道数channel起到升维或降维的作用,从而令3x3 的卷积,以相对较低维度的输入进行卷积运算,提高计算效率。

ResNet网络的具体配置信息如下:

在构建神经网络时,首先采用了步长为2的卷积层进行图像尺寸缩减,即下采样操作,紧接着是多个残差结构,在网络架构的末端,引入了一个全局平均池化层,用于整合特征信息,最后是一个包含1000个类别的全连接层,并在该层后应用了softmax激活函数以进行多分类任务。值得注意的是,通过引入残差连接模块,其最深的网络结构达到了152层,同时在50层后均使用的是瓶颈残差结构。

ResNet部分实现代码

老样子,直接上代码,建议大家阅读代码时结合网络结构理解

import torch
import torch.nn as nn

__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class ConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding:捕捉局部特征和空间相关性,学习更复杂的特征和抽象表示"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_channels, out_channels, stride=1):
    """1x1 convolution:实现降维或升维,调整通道数和执行通道间的线性变换"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)
        self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.conv_down = nn.Sequential(
            conv1x1(in_channels, out_channels * self.expansion, self.stride),
            nn.BatchNorm2d(out_channels * self.expansion),
        )

    def forward(self, x):
        residual = x
        out = self.convbnrelu1(x)
        out = self.convbn1(out)

        if self.downsample:
            residual = self.conv_down(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        groups = 1
        base_width = 64
        dilation = 1

        width = int(out_channels * (base_width / 64.)) * groups   # wide = out_channels
        self.conv1 = conv1x1(in_channels, width)       # 降维通道数
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, out_channels * self.expansion)   # 升维通道数
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.conv_down = nn.Sequential(
            conv1x1(in_channels, out_channels * self.expansion, self.stride),
            nn.BatchNorm2d(out_channels * self.expansion),
        )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            residual = self.conv_down(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()

        self.inplanes = 64
        self.dilation = 1
        replace_stride_with_dilation = [False, False, False]
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = False
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = True

        layers = nn.ModuleList()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return layers

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.layer1:
            x = layer(x)
        for layer in self.layer2:
            x = layer(x)
        for layer in self.layer3:
            x = layer(x)
        for layer in self.layer4:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def resnet18(num_classes, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

def resnet34(num_classes, **kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet50(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet101(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)

def resnet152(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)


if __name__=="__main__":
    import torchsummary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = resnet50(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    torchsummary.summary(net, input_size=(3, 224, 224))
    # Total params: 134,285,380

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562910.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【FFmpeg】视频与图片互相转换 ( 视频与 JPG 静态图片互相转换 | 视频与 GIF 动态图片互相转换 )

文章目录 一、视频与 JPG 静态图片互相转换1、视频转静态图片2、视频转多张静态图片3、多张静态图片转视频 二、视频与 GIF 动态图片互相转换1、视频转成 GIF 动态图片2、 GIF 动态图片转成视频 一、视频与 JPG 静态图片互相转换 1、视频转静态图片 执行 ffmpeg -i input.mp4 …

云原生Kubernetes: K8S 1.29版本 部署ingress-nginx

目录 一、实验 1.环境 2. K8S 1.29版本 部署ingress-nginx 二、问题 1.kubectl 如何强制删除 Pod、Namespace 资源 2.创建pod失败 3.pod报错ImagePullBackOff 4.docker如何将镜像上传到官方仓库 5.创建ingress报错 一、实验 1.环境 (1)主机 表…

数据结构 - 顺序表

一. 线性表的概念 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串... 线性表在逻辑上是线性结构,也就说是连续的…

SpringBoot 集成 WebSocket

前言 最近在做一个 WebSocket 通信服务的软件,所以必须跟着学一学。 1、WebSocket 概述 一般情况下,我们的服务器和服务器之间可以发送请求,但是服务器是不能向浏览器去发送请求的。因为设计之初并没有想到以后会出现服务端频繁向客户端发送…

【代码随想录刷题记录】LeetCode34在排序数组中查找元素的第一个和最后一个位置

题目地址 最近忙活实验,实在没空刷题,这个题对我来说难度还蛮大的,尤其是理解那个找左边界和找右边界的条件,后来我按照自己的理解写了出来(感觉给的答案解释起来有点反认识规律),所以我从0开始…

三轴加速度计LIS2DUX12开发(2)----静态校准

三轴加速度计LIS2DUX12开发.2----静态校准 概述硬件准备视频教学样品申请源码下载六位置法的标定方案旋转加速度计以找到极值计算偏移和灵敏度应用校准参数注意事项串口中断变量定义主程序流程串口发送定义演示 概述 最近在弄ST和瑞萨RA的课程,需要样片的可以加群申…

木马——文件上传

目录 1、WebShell 2.一句话木马 靶场训练 3.蚁剑 虚拟终端 文件管理 ​编辑 数据操作 4.404.php 5.文件上传漏洞 客户端JS检测 右键查看元素,删除检测代码 BP拦截JPG修改为php 服务端检测 1.MIME类型检测 2.文件幻数检测 3.后缀名检测 1、WebShell W…

【Hadoop】-HDFS的Shell操作[3]

目录 前言 一、HDFS集群启停命令 1.一键启停脚本可用 2.独立进程启停可用 二、文件系统操作命令 1、创建文件夹 2、查看指定目录下内容 3、上传文件到HDFS指定目录下 4、查看HDFS文件内容 5、下载HDFS文件 6、拷贝HDFS文件 7、追加数据到HDFS文件中 8、HDFS数据移…

【信号处理】基于CNN的心电(ECG)信号分类典型方法实现(tensorflow)

关于 本实验使用1维卷积神经网络实现心电信号的5分类。由于数据类别不均衡,这里使用典型的上采样方法,实现数据类别的均衡化处理。 工具 方法实现 数据加载 Read the CSV file datasets: NORMAL_LABEL0 , ABNORMAL_LABEL1,2,3,4,5 ptbdb_abnormalpd.…

使用JavaScript收集和发送用户设备信息,后端使用php将数据保存在本地json,便于后期分析数据

js代码部分 <script> // 之前提供的JavaScript代码 fetch(https://api.ipify.org?formatjson).then(response > response.json()).then(data > {const deviceInfo {userAgent: navigator.userAgent,platform: navigator.platform,language: navigator.language,…

[Spring Cloud] (4)搭建Vue2与网关、微服务通信并配置跨域

文章目录 前言gatway网关跨域配置取消微服务跨域配置 创建vue2项目准备一个原始vue2项目安装vue-router创建路由vue.config.js配置修改App.vue修改 添加接口访问安装axios创建request.js创建index.js创建InfoApi.js main.jssecurityUtils.js 前端登录界面登录消息提示框 最终效…

微信小程序vue.js+uniapp服装商城销售管理系统nodejs-java

本技术是java平台的开源应用框架&#xff0c;其目的是简化Sping的初始搭建和开发过程。默认配置了很多框架的使用方式&#xff0c;自动加载Jar包&#xff0c;为了让用户尽可能快的跑起来spring应用程序。 SpinrgBoot的主要优点有&#xff1a; 1、为所有spring开发提供了一个更快…

贝叶斯分类 python

贝叶斯分类 python 贝叶斯分类器是一种基于贝叶斯定理的分类方法&#xff0c;常用于文本分类、垃圾邮件过滤等领域。 在Python中&#xff0c;我们可以使用scikit-learn库来实现贝叶斯分类器。 下面是一个使用Gaussian Naive Bayes(高斯朴素贝叶斯)分类器的简单示例&#xff1…

大数据Hive中的UDF:自定义数据处理的利器(上)

文章目录 1. 前言2. UDF与宏及静态表的对比3. 深入理解UDF4. 实现自定义UDF 1. 前言 在大数据技术栈中&#xff0c;Apache Hive 扮演着数据仓库的关键角色&#xff0c;它提供了丰富的数据操作功能&#xff0c;并通过类似于 SQL 的 HiveQL 语言简化了对 Hadoop 数据的处理。然而…

汇编语言(详解)

汇编语言安装指南 第一步&#xff1a;在github上下载汇编语言的安装包 网址&#xff1a;GitHub - HaiPenglai/bilibili_assembly: B站-汇编语言-pdf、代码、环境等资料B站-汇编语言-pdf、代码、环境等资料. Contribute to HaiPenglai/bilibili_assembly development by creat…

STM32 | USART实战案例

STM32 | 通用同步/异步串行接收/发送器USART带蓝牙(第六天)随着扩展的内容越来越多,很多小伙伴已经忘记了之前的学习内容,然后后面这些都很难理解。STM32合集已在专栏创建,方面大家学习。1、通过电脑串口助手发送数据,控制开发板LED灯 从题目中可以挖掘出,本次使用led、延…

【JVM常见问题总结】

文章目录 jvm介绍jvm内存模型jvm内存分配参数jvm堆中存储对象&#xff1a;对象在堆中创建分配内存过程 jvm 堆垃圾收集器垃圾回收算法标记阶段引用计数算法可达性分析算法 清除阶段标记清除算法复制算法标记压缩算法 实际jvm参数实战jvm调优jvm常用命令常用工具 jvm介绍 Java虚…

C++设计模式:适配器模式(十四)

1、定义与动机 定义&#xff1a;将一个类的接口转换成客户希望的另外一个接口。Adapter模式使得原本由于接口不兼容而不能一起工作的哪些类可以一起工作。 动机&#xff1a; 在软件系统中&#xff0c;由于应用环境的变化&#xff0c;常常需要将“一些现存的对象”放在新的环境…

【Hadoop】- YARN概述[6]

目录 一、YARN & Reduce 二、分布式资源调度 - YARN 1、资源调度 2、YARN的资源调度 总结 一、YARN & Reduce MapReduce是基于YARN运行的&#xff0c;即没有YARN “无法” 运行MapReduce程序。 二、分布式资源调度 - YARN YARN&#xff08;Yet Another Resou…

注意力机制中多层的作用

1.多层的作用 在注意力机制中&#xff0c;多层的作用通常指的是将注意力机制堆叠在多个层上&#xff0c;这在深度学习模型中被称为“深度”或“多层”注意力网络。这种多层结构的作用和实现过程如下&#xff1a; 1. **逐层抽象**&#xff1a;每一层都可以捕捉到输入数据的不同…