ResNet网络分析与demo实例

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

 ResNet 详解

原论文地址 [1512.03385] Deep Residual Learning for Image Recognition (arxiv.org)

ResNet 网络是在 2015年 由微软实验室提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。

在ResNet网络的创新点:

  • 提出 Residual 结构(残差结构),并搭建超深的网络结构(可突破1000层)
  • 使用 Batch Normalization 加速训练(丢弃dropout)

下图是ResNet34层模型的结构简图:

在ResNet网络提出之前,传统的卷积神经网络都是通过将一系列卷积层与池化层进行堆叠得到的。

一般我们会觉得网络越深,特征信息越丰富,模型效果应该越好。但是实验证明,当网络堆叠到一定深度时,会出现两个问题

梯度消失或梯度爆炸

退化问题

如下图所示,20层网络 反而比 56层网络 的误差更小:

对于梯度消失或梯度爆炸问题,ResNet论文提出通过数据的预处理以及在网络中使用

 BN(Batch Normalization)层来解决。

对于退化问题,ResNet论文提出了 residual结构残差结构)来减轻退化问题,下图是使用residual结构的卷积网络,可以看到随着网络的不断加深,效果并没有变差,而是变的更好了。(虚线是train error,实线是test error)

为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。

残差网络由许多隔层相连的神经元子模块组成,我们称之为 残差块 Residual block。单个残差块的结构如下图所示:

原文的表注中已说明,conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层残差结构都是虚线残差结构。因为这一系列残差结构的第一层都有调整输入特征矩阵shape的使命(将特征矩阵的高和宽缩减为原来的一半,将深度channel调整成下一层残差结构所需要的channel)

需要注意的是,对于ResNet50/101/152,其实conv2_x所对应的一系列残差结构的第一层也是虚线残差结构,因为它需要调整输入特征矩阵的channel。根据表格可知通过3x3的max pool之后输出的特征矩阵shape应该是[56, 56, 64],但conv2_x所对应的一系列残差结构中的实线残差结构它们期望的输入特征矩阵shape是[56, 56, 256](因为这样才能保证输入输出特征矩阵shape相同,才能将捷径分支的输出与主分支的输出进行相加)。所以第一层残差结构需要将shape从[56, 56, 64] --> [56, 56, 256]。注意,这里只调整channel维度,高和宽不变(而conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层虚线残差结构不仅要调整channel还要将高和宽缩减为原来的一半)。

下面是 ResNet 18/34 和 ResNet 50/101/152 具体的实线/虚线残差结构图:
 

ResNet 18/34

ResNet 50/101/152s

在迁移学习中,我们希望利用源任务(Source Task)学到的知识帮助学习目标任务 (Target Task)。例如,一个训练好的图像分类网络能够被用于另一个图像相关的任务。再比如,一个网络在仿真环境学习的知识可以被迁移到真实环境的网络。迁移学习一个典型的例子就是载入训练好VGG网络,这个大规模分类网络能将图像分到1000个类别,然后把这个网络用于另一个任务,如医学图像分类。

为什么可以这么做呢?如下图所示,神经网络逐层提取图像的深层信息,这样,预训练网络就相当于一个特征提取器。

model.py

  • 定义ResNet18/34的残差结构,即BasicBlock
  • 定义ResNet50/101/152的残差结构,即Bottleneck
  • 定义ResNet网络结构
  • 定义resnet18/34/50/101/152

import torch.nn as nn
import torch


# ResNet18/34的残差结构,用的是2个3x3的卷积
class BasicBlock(nn.Module):
    expansion = 1  # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample对应虚线残差结构
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:  # 虚线残差结构,需要下采样
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积
class Bottleneck(nn.Module):
    expansion = 4  # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # block_num为残差结构中conv2_x~conv5_x中残差块个数,是一个列表
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])             # conv2_x
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_x
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # channel为残差结构中第一层卷积核个数
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None

        # ResNet50/101/152的残差结构,block.expansion=4
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

train.py

由于ResNet网络较深,直接训练的话会非常耗时,因此用迁移学习的方法导入预训练好的模型参数:
在pycharm中输入import torchvision.models.resnet,ctrl+左键resnet跳转到pytorch官方实现resnet的源码中,下载预训练的模型参数:
 

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/271412.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python、PHP/JAVA/C#电商评论数据采集与分析

引言 在电商竞争日益激烈的情况下,商家既要提高产品质量,又要洞悉客户的想法和需求,关注客户购买商品后的评论,而第三方商家获取商品评价主要依赖于人工收集,不但效率低,而且准确度得不到保障。通过使用Py…

【数据结构和算法】找到最高海拔

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、题目描述 二、题解 2.1 前缀和的解题模板 2.1.1 最长递增子序列长度 2.1.2 寻找数组中第 k 大的元素 2.1.3 最长公共子序列…

利用MATLAB设计一个(2,1,7)卷积码编译码器

1、条件: 输入数字信号,可以随机产生,也可手动输入 2、要求: (1)能显示编码树、网格图或状态转移图三者之一; (2)根据输入数字信号编码生成卷积码并显示&#xf…

如何进行块存储管理

目录 块存储概念 块存储(云盘)扩容 方式一:直接扩容现有云盘 方式二:创建一块新数据盘 方式三:在更换操作系统时,同时更换系统盘 块存储(云盘)变配 云盘变配操作步骤 块存储概…

索引进阶 | 再谈 MySQL 的慢 SQL 优化

索引可以提高数据检索的效率,降低数据库的IO成本。 MySQL在300万条记录左右性能开始逐渐下降,虽然官方文档说500~800w记录,所以大数据量建立索引是非常有必要的。 MySQL提供了Explain,用于显示SQL执行的详细信息,可以…

质量免费吗?

本文首发于个人网站「BY林子」,转载请参考版权声明。 两个场景 场景一:有限经费与质量改进 “要写自动化的单元测试、E2E测试,就会需要更多的钱,可是我们经费有限暂时做不了。” “CI上配置SonarQube扫描,对于扫描出来…

godot 报错Unable to initialize Vulkan video driver解决

版本 godot 4.2.1 现象 godot4.2.1 默认使用vulkan驱动,如果再不支持vulkan驱动的主机上,进入引擎编辑器将报错如下 解决 启动参数添加 –rendering-driver opengl3 即可进入引擎编辑器 此时运行项目仍然会报错无法初始化驱动 在项目设置中配置编…

Vue-Pinina基本教程

前言 官网地址:Pinia | The intuitive store for Vue.js (vuejs.org) 看以下内容,需要有vuex的基础,下面很多概念会直接省略,比如state、actions、getters用处含义等 1、什么是Pinina Pinia 是 Vue 的存储库,它允许您跨…

储能:东风已至,破浪在即——安科瑞 顾烊宇

今年的各省政府工作报告已经陆续发布,新能源是各省能源工作的重点,从目前31个省(区、市)相继公布的2022年经济增长数据来看,一些提前布局新能源产业的省市纷纷交出不错的成绩单,新能源成为当地GDP增速的重要…

饥荒Mod 开发(二三):显示物品栏详细信息

饥荒Mod 开发(二二):显示物品信息 源码 前一篇介绍了如何获取 鼠标悬浮物品的信息,这一片介绍如何获取 物品栏的详细信息。 拦截 inventorybar 和 itemtile等设置字符串方法 在modmain.lua 文件中放入下面代码即可实现鼠标悬浮到 物品栏显示物品详细信…

微信小程序云开发-下载云存储中的文件

一、前言 很多时候我们需要实现用户在客户端下载服务端的文件(图片、视频、pdf等)到用户本地并保存起来,小程序也经常需要实现这样的需求。 在传统服务器开发下网上已经有很多关于小程序下载服务端文件的资料了,但是基于云开发的…

苹果怎么备份QQ的聊天记录?这3招教你快速备份!

QQ聊天记录是我们与好友之间的重要互动和沟通记录。但是,有时可能会由于各种原因,比如系统崩溃、更换手机、自身误操作、QQ闪退等,可能会导致聊天记录丢失。 因此,备份QQ聊天记录显得尤为重要。那么,苹果手机怎么备份…

SAP CO系统配置-与PS集成相关配置(机器人制造项目实例)

维护分配结构 配置路径 IMG菜单路径:控制>内部订单>实际过帐>结算>维护分配结构 事务代码 OKO6 维护结算参数文件 定义利润分析码

ZED-Mini 标定完全指南(应该是最详细的吧)

标定 ZED-Mini 相机主要为了跑 VINS-Fusion 以及后期的联合标定相关事宜 双目相机标定 出厂标定数据 关于ZED相机的内参,使用出厂标定的数据就好了,如果安装ZED的SDK时使用的是默认的安装路径,可以在/usr/local/zed/settings下面找到一个SN…

漏洞处理-未设置X-Frame-Options

漏洞名称&#xff1a;iFrame注入 风险描述&#xff1a;系统未设置x-frame-options头 风险等级&#xff1a;低 整改建议&#xff1a;为系统添加x-frame-options头 知识 X-Frame-Options 响应头 X-Frame-Options HTTP 响应头是用来给浏览器指示允许一个页面可否在 <fram…

通过 Bytebase API 做数据库 Schema 变更

Bytebase 是一款数据库 DevOps 和 CI/CD 工具&#xff0c;适用于开发人员、DBA 和平台工程团队。 它提供了一个直观的图形用户界面来管理数据库 Schema 变更。另一方面&#xff0c;一些团队可能希望将 Bytebase 集成到现有的内部 DevOps 研发平台中。这需要调用 Bytebase API。…

搭建Nginx文件下载站点

一、下载Nginx 首先&#xff0c;确保你的服务器上已经安装了Nginx&#xff0c;使用编译安装&#xff0c;下载最新版Nginx。 wget https://nginx.org/download/nginx-1.25.3.tar.gz tar -xf nginx-1.25.3.tar.gz二、安装Fancyindex和Nginx-Fancyindex-Theme模块 # 下载Fancyin…

外贸中的很多跟想的不一样的事情

说说最近遇到的几个客户情况&#xff0c;以及对一些事情刷新的认知。 第一个客户姑且称为A吧&#xff0c;这个客户在询价的时候&#xff0c;产品的名称以及数量以还有走货的方式写的很清楚&#xff0c;客户A要的产品不是很多&#xff0c; 顶多算是个样品单。 一般情况下&…

腾讯云2核4G服务器CVM标准型S5实例5年优惠价格表

腾讯云服务器续费贵所以一次性买3年或5年&#xff0c;腾讯云轻量应用服务器3年价格有优惠&#xff0c;CVM云服务器5年有特价&#xff0c;腾讯云3年轻量和5年云服务器CVM优惠活动入口&#xff0c;3年轻量应用服务器配置可选2核2G4M和2核4G5M带宽&#xff0c;5年CVM云服务器可以选…

学习笔记11——Spring的XML配置

学习笔记系列开头惯例发布一些寻亲消息 链接&#xff1a;https://www.baobeihuijia.com/bbhj/contents/3/192584.html SSM框架——IOC基础【BeanSetter注入加载xml】 框架总览 Spring Framework 谈谈我对Spring的理解 - 知乎 (zhihu.com)java - 【架构视角】一篇文章带你彻底…