第J2周:ResNet50V2算法实战与解析

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第J2周:ResNet50V2算法实战与解析
  • 🍖 原作者:K同学啊|接辅导、项目定制

目录

    • 一、论文解读
      • 1. ResNetV2结构与ResNet结构对比
      • 2. 关于残差结构的不同尝试
      • 3. 关于激活的尝试
    • 二、模型复现
      • 1. Residual Block
      • 2. 堆叠Residual Block
      • 3. ResNet50V2架构复现
      • 4. 在cifar10上训练

📌 本周任务:
●1.请根据本文 TensorFlow 代码,编写出相应的 Pytorch 代码(建议使用上周的数据测试一下模型是否构建正确)
●2.了解ResNetV2与ResNetV的区别
●3.改进思路是否可以迁移到其他地方呢(自由探索)

一、论文解读

1. ResNetV2结构与ResNet结构对比

在这里插入图片描述

实线表示测试误差(右边的y轴),虚线表示训练损失(左边的y轴),Iterations 表示迭代次数

🧲 改进点:(a)original 表示原始的 ResNet 的残差结构,(b)proposed 表示新的 ResNet 的残差结构。主要差别就是(a)结构先卷积后进行 BN 和激活函数计算,最后执行 addition 后再进行ReLU 计算; (b)结构先进行 BN 和激活函数计算后卷积,把 addition 后的 ReLU 计算放到了残差结构内部。

📌 改进结果:作者使用这两种不同的结构在 CIFAR-10 数据集上做测试,模型用的是 1001层的 ResNet 模型。从图中结果我们可以看出,(b)proposed 的测试集错误率明显更低一些,达到了 4.92%的错误率,(a)original 的测试集错误率是 7.61%。

2. 关于残差结构的不同尝试

在这里插入图片描述
(b-f)中的快捷连接被不同的组件阻碍。为了简化插图,我们不显示BN层,这里所有单位均采用权值层之后的BN层。图中(a-f)都是作者对残差结构的 shortcut 部分进行的不同尝试 ,作者对不同 shortcut 结构的尝试结果如下表所示 。

在这里插入图片描述

使用ResNet-110在CIFAR-10测试集上的分类错误,对所有残差单元应用了不同类型的shortcut connections。当测试误差高于20%时,标注为“fail”。

作者用不同 shortcut 结构的 ResNet-110 在 CIFAR-10 数据集上做测试,发现最原始的(a)original 结构是最好的,也就是 identity mapping 恒等映射是最好的。

3. 关于激活的尝试

在这里插入图片描述
在这里插入图片描述

使用不同激活函数的CIFAR-10测试集上的分类误差(%)。

最好的结果是(e)full pre-activation,其次到(a)original。

二、模型复现

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
from torchsummary import summary


#忽略警告信息
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. Residual Block

class Block2(nn.Module):
    def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):
        super(Block2, self).__init__()
        self.preact = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(True)
        )

        self.shortcut = conv_shortcut
        if self.shortcut:
            self.short = nn.Conv2d(in_channel, 4*filters, 1, stride=stride, padding=0, bias=False)
        elif stride>1:
            self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
        else:
            self.short = nn.Identity()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv3 = nn.Conv2d(filters, 4*filters, 1, stride=1, bias=False)

    def forward(self, x):
        x1 = self.preact(x)
        if self.shortcut:
            x2 = self.short(x1)
        else:
            x2 = self.short(x)
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x2
        return x

2. 堆叠Residual Block

class Stack2(nn.Module):
    def __init__(self, in_channel, filters, blocks, stride=2):
        super(Stack2, self).__init__()
        self.conv = nn.Sequential()
        self.conv.add_module(str(0), Block2(in_channel, filters, conv_shortcut=True))
        for i in range(1, blocks-1):
            self.conv.add_module(str(i), Block2(4*filters, filters))
        self.conv.add_module(str(blocks-1), Block2(4*filters, filters, stride=stride))

    def forward(self, x):
        x = self.conv(x)
        return x

3. ResNet50V2架构复现

在这里插入图片描述

class ResNet50V2(nn.Module):
    def __init__(self,
                 include_top=True,  # 是否包含位于网络顶部的全链接层
                 preact=True,  # 是否使用预激活
                 use_bias=True,  # 是否对卷积层使用偏置
                 input_shape=[224, 224, 3],
                 classes=1000,
                 pooling=None):  # 用于分类图像的可选类数
        super(ResNet50V2, self).__init__()

        self.conv1 = nn.Sequential()
        self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn', nn.BatchNorm2d(64))
            self.conv1.add_module('relu', nn.ReLU())
        self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.conv2 = Stack2(64, 64, 3)
        self.conv3 = Stack2(256, 128, 4)
        self.conv4 = Stack2(512, 256, 6)
        self.conv5 = Stack2(1024, 512, 3, stride=1)

        self.post = nn.Sequential()
        if preact:
            self.post.add_module('bn', nn.BatchNorm2d(2048))
            self.post.add_module('relu', nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            self.post.add_module('flatten', nn.Flatten())
            self.post.add_module('fc', nn.Linear(2048, classes))
        else:
            if pooling=='avg':
                self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            elif pooling=='max':
                self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.post(x)
        return x


model = ResNet50V2().to(device)
summary(model, (3, 224, 224))

4. 在cifar10上训练

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义预处理转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将 PIL 图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss,train_acc = 0,0
    for x,y in dataloader:
        x,y = x.to(device),y.to(device)
        pred = model(x)
        loss = loss_fn(pred,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
    train_loss /= num_batches
    train_acc /= size
    return train_loss,train_acc
def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss,test_acc = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model(x)
            loss = loss_fn(pred,y)
            test_loss += loss.item()
            test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
    test_loss /= num_batches
    test_acc /= size
    return test_loss,test_acc

参考文章:http://t.csdn.cn/VhPbf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/49796.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux下查找python路径

本地目前装了几个版本的python,这里记录下查找python路径的方法。 1:whereis命令 whereis python2:which命令 which python与whereis相似,但which会返回第一个找到的执行文件的位置。 3:find命令 find命令可以搜索系…

CenOS设置启动级别

背景知识 init一共分为7个级别,这7个级别的所代表的含义如下 0:停机或者关机(千万不能将initdefault设置为0)1:单用户模式,只root用户进行维护2:多用户模式,不能使用NFS(Net File S…

【Docker】Docker相关基础命令

目录 一、Docker服务相关命令 1、启动docker服务 2、停止docker服务 3、重启docker服务 4、查看docker服务状态 5、开机自启动docker服务 二、Images镜像相关命令 1、查看镜像 2、拉取镜像 3、搜索镜像 4、删除镜像 三、Container容器相关命令 1、创建容器 2、查…

【C++】开源:Linux端ALSA音频处理库

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍Linux端ALSA音频处理库。 无专精则不能成,无涉猎则不能通。。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下&#xff0c…

Windows数据类型LPSTR学习

Windows在C语言的基础之上又定义了一些Windows下的数据类型;下面学习一下LPSTR; LPSTR和LPWSTR是Win32和VC所使用的一种字符串数据类型。LPSTR被定义成是一个指向以NULL(‘\0’)结尾的32位ANSI字符数组指针,而LPWSTR是一个指向以NULL结尾的64…

代码版本管理工具 git

1. 去B站看视频学习,只看前39集: 01-Git概述(Git历史)_哔哩哔哩_bilibili 2.学习Linux系统文本编辑器的使用 vi编辑器操作指令分享 (baidu.com) (13条消息) nano编辑器的使用_SudekiMing的博客-CSDN博客 windows下载安装Git官…

电路原理分析1

d2的作用是提供一个1.25v的电平 r3、r4的作用都是限流 c1是滤波 运放的4、8脚是常规的外围 这个运放是一个运算放大电路 具体计算是这样的: 按照虚短原则,输入的信号Uinu1,输出的信号Uoutu3 按照虚断原则,i1i2i5i5 u1/r2i1i5&#xff…

Longhorn vs Rook vs OpenEBS vs Portworx vs IOMesh:细说 5 款 K8s 持久化存储产品优劣势

云原生时代下,越来越多的企业开始使用 Kubernetes(K8s)承载数据库、消息中间件等“生产级”有状态工作负载。由于这些应用对数据持久保存、性能、容量扩展和快速交付具有较高的要求,企业往往需要采用专为 Kubernetes 环境设计的持…

wxwidgets Ribbon使用简单实例

// RibbonSample.cpp : 定义控制台应用程序的入口点。 // #include "stdafx.h" #include <wx/wx.h> #include "wx/wxprec.h" #include "wx/app.h" #include "wx/frame.h" #include "wx/textctrl.h" #include "…

教雅川学缠论02-K线

传统行情上的K线是下图中这样子的 而在缠论中K线是下面这样子的&#xff0c;它没有上影线和下影线 下图是武汉控股2023年7月的日K线 接下来我们将它转换成缠论K线&#xff08;画图累死我了&#xff09; K线理解了我们才能进行下一步&#xff0c;目前位置应该很好理解的

压缩算法的原理丨基因型vcf文件为什么压缩后发生了什么?

压缩算法的本质 最近碰到一个神奇的现象&#xff0c;一份大小为16GB的xx.vcf.gz文件&#xff0c;解压之后体积变为600GB的vcf文件&#xff0c;为什么一份文件经过压缩后体积缩小了这么多&#xff1f; (work) [bio notes 21:29:40 ~/work/20230726/data]$ ls -lh总用量 620GB-…

OnnxRuntime TensorRT OpenCV::DNN性能对比(YoloV8)实测

1. 前言 之前把ORT的一套推理环境框架搭好了,在项目中也运行得非常愉快,实现了cpu/gpu,fp32/fp16的推理运算,同onnx通用模型在不同推理框架下的性能差异对比贴一下,记录一下自己对各种推理框架的学习状况 YoloV8模型大小 模型名称参数量NANO3.2M...... 2. CPU篇 CPU推理框架性…

机器人状态估计:robot_localization 功能包高级参数详解

机器人状态估计&#xff1a;robot_localization 功能包高级参数详解 前言功能包简介相关参数高级参数 前言 移动机器人的状态估计需要用到很多传感器&#xff0c;因为对单一的传感器来讲&#xff0c;都存在各自的优缺点&#xff0c;所以需要一种多传感器融合技术&#xff0c;将…

扫地机语音提示芯片,智能家居语音交互首选方案,WT588F02B-8S

智能家居已经成为现代家庭不可或缺的一部分&#xff0c;而语音交互技术正是智能家居的核心。在智能家居设备中&#xff0c;扫地机无疑是最受欢迎的产品之一。然而&#xff0c;要实现一个更智能的扫地机&#xff0c;需要一颗语音提示芯片&#xff0c;以提供高质量的语音交互体验…

【MySQL】表的内外连接

目录 一、内连接二、外连接2.1 左外连接2.2 右外连接 三、OJ题 表的连接分为内连和外连 一、内连接 内连接实际上就是利用where子句对两种表形成的笛卡儿积进行筛选&#xff0c;我们前面学习的查询都是内连接&#xff0c;也是在开发过程中使用的最多的连接查询。 语法&#x…

Feign API模块导入的两种方式

说明&#xff1a;在微服务框架中&#xff0c;会把其他微服务用到的FeignClient统一放到一个模块里面&#xff0c;称为FeignAPI&#xff0c;其他微服务需要使用FeignClient&#xff0c;可以引入FeignAPI的Maven坐标。 但是只引入FeignAPI的坐标还不行&#xff0c;FeignAPI中的B…

自动化测试:让软件测试更高效更愉快!

谈谈那些实习测试工程师应该掌握的基础知识&#xff08;一&#xff09;_什么时候才能变强的博客-CSDN博客https://blog.csdn.net/qq_17496235/article/details/131839453谈谈那些实习测试工程师应该掌握的基础知识&#xff08;二&#xff09;_什么时候才能变强的博客-CSDN博客h…

STM32CubeMX v6.9.0 BUG:FLASH_LATENCY设置错误导致初始化失败

背景 今天在调试外设功能时&#xff0c;发现设置了使用外部时钟之后程序运行异常&#xff0c;进行追踪调试并与先前可以正常运行的项目进行对比之后发现这个问题可能是由于新版本的STM32CubeMX配置生成代码时的BUG引起的。 测试环境 MCU: STM32H750VBT6 STM32CubeIDE: Versi…

WEB:Web_python_template_injection

背景知识 python模板注入 ssit 题目 打开题目&#xff0c;发现页面提示&#xff0c;翻译为python模板注入 先测试是否存在注入 可以发现被执行了 先查看所有的子类 payload {{[].__class__.__base__.__subclasses__()}} 利用site.Printer的os模块执行命令 payload {{.__…

路由器工作原理

路由器原理 路由概述 路由&#xff1a;跨越从源主机到目标主机的一个互联网络来转发数据包的过程。&#xff08;为数据包选择路径的过程&#xff09; 作用&#xff1a;路由器是连接不同网段的。 转发依据&#xff1a; 路由表&#xff1a;路径选择全看路由表&#xff0c;根…