Pytorch从零开始实战16

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):
    group_list = []
    # 分组进行卷积
    for c in range(groups):
        # 分组取出数据
        x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)
        # 分组进行卷积
        x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)
        # 存入list
        group_list.append(x)
    # 合并list中的数据
    group_merage = concatenate(group_list, axis=3)
    x = BatchNormalization(epsilon=1.001e-5)(group_merage)
    x = ReLU()(x)
    return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):

    if conv_shortcut:
        shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)
        # epsilon为BN公式中防止分母为零的值
        shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)
    else:
        # identity_shortcut
        shortcut = x
        
    # 三层卷积层
    x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 计算每组的通道数
    g_channels = int(filters / groups)
    # 进行分组卷积
    x = grouped_convolution_block(x, strides, groups, g_channels)

    x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):
    # 每个stack的第一个block的残差连接都需要使用1*1卷积升维
    x = block(x, filters, strides=strides, groups=groups)
    for i in range(blocks):
        x = block(x, filters, groups=groups, conv_shortcut=False)
    return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):
    inputs = Input(shape=input_shape)
    # 填充3圈0,[224,224,3]->[230,230,3]
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 填充1圈0
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)
    # 堆叠残差结构
    x = stack(x, filters=128, blocks=2, strides=1)
    x = stack(x, filters=256, blocks=3, strides=2)
    x = stack(x, filters=512, blocks=5, strides=2)
    x = stack(x, filters=1024, blocks=2, strides=2)
    # 根据特征图大小进行全局平均池化
    x = GlobalAvgPool2D()(x)
    x = Dense(num_classes, activation='softmax')(x)
    # 定义模型
    model = Model(inputs=inputs, outputs=x)
    return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):
    # 每个stack的第一个block的残差连接都需要使用1*1卷积升维
    x = block(x, filters, strides=strides, groups=groups)
    for i in range(blocks):
        x = block(x, filters, groups=groups, conv_shortcut=False)
    return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/308405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

字体图标操作步骤

网站 直接点击 进去后长这样,点免费的添加 保存下载 保存后解压 把fonts文件夹复制粘贴到我们自己项目 可以放在同images的路径下 引入 来源于 再style中粘贴 font-face {font-family: icomoon;src: url(fonts/icomoon.eot?jyg4cp);src: url(fonts/icomoo…

Java微服务系列之 ShardingSphere - ShardingSphere-JDBC

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 系列专栏目录 [Java项…

使用Go语言通过API获取代理IP并使用获取到的代理IP

目录 前言 【步骤一:获取代理IP列表】 【步骤二:使用代理IP发送请求】 【完整代码】 【总结】 前言 在网络爬虫、数据抓取等场景中,经常需要使用代理IP来隐藏真实的IP地址,以及增加请求的稳定性和安全性。本文将介绍如何使用…

数据结构实验4:链表的基本操作

目录 一、实验目的 二、实验原理 1. 节点 2. 指针 3.链表的类型 3.1 单向链表 3.2 双向链表 3.3 单向循环链表 3.4 双向循环链表 4. 单链表的插入 4.1 头插法 4.2 尾插法 4.3 在指定位置插入元素 5. 单链表的删除 5.1 删除指定数值的节点 5.2 删除指定位置的节点 …

软件测试|Python Selenium 库安装使用指南

简介 Selenium 是一个用于自动化浏览器操作的强大工具,它可以模拟用户在浏览器中的行为,例如点击、填写表单、导航等。在本指南中,我们将详细介绍如何安装和使用 Python 的 Selenium 库。 安装 Selenium 库 使用以下命令可以通过 pip 安装…

Matlab 分段函数(piecewise)

语法 pw piecewise(cond1,val1,cond2,val2,...) pw piecewise(cond1,val1,cond2,val2,...,otherwiseVal)描述 pw piecewise(cond1,val1,cond2,val2,...) 返回分段表达式或函数pw,当cond1为True时,其值为val1,当cond2为True时&#xff0…

优化CentOS 7.6的HTTP隧道代理网络性能

在CentOS 7.6上,通过HTTP隧道代理优化网络性能是一项复杂且细致的任务。首先,我们要了解HTTP隧道代理的工作原理:通过建立一个安全的隧道,HTTP隧道代理允许用户绕过某些网络限制,提高数据传输的速度和安全性。然而&…

PyCharm 设置新建Python文件时自动在文章开头添加固定注释的方法

在实际项目开发时,为了让编写的每个代码文件易读、易于维护或方便协同开发时,我们都会在每一个代码文件的开头做一些注释,如作者,文档编写时间,文档的功能说明等。 利用PyCharm 编辑器,我们只需设置相关设…

MS-DETR论文解读

文章目录 前言一、摘要二、引言三、贡献四、MS-DETR模型方法1、模型整体结构解读2、模型改善结构解读3、一对多监督原理 五、实验结果1、实验比较2、论文链接 总结 前言 今天,偶然看到MS-DETR论文,以为又有什么高逼格论文诞生了。于是,我想查…

12.1SPI驱动框架

SPI硬件基础 总线拓扑结构 引脚含义 DO(MOSI):Master Output, Slave Input, SPI主控用来发出数据,SPI从设备用来接收数据 DI(MISO) :Master Input, Slave Output, SPI主控用来发出数据,SPI从设备用来接收…

IPv6路由协议---IPv6动态路由(OSPFv3-4)

OSPFv3的链路状态通告LSA类型 链路状态通告是OSPFv3进行路由计算的关键依据,链路状态通告包含链路状态类型、链路状态ID、通告路由器三元组唯一地标识了一个LSA。 OSPFv3的LSA头仍然保持20字节,但是内容变化了。在LSA头中,OSPFv2的LS age、Advertising Router、LS Sequence…

Java中输入和输出处理(三)二进制篇

叮咚!加油!马上学完 读写二进制文件Data DataInputStream类 FilFeInputStream的子类 与FileInputStream类结合使用读取二进制文件 DataOutputStream类 FileOutputStream的子类 与FileOutputStream类结合使用写二进制文件 读写二进制代码 package 面…

DAPP和APP的区别在哪?

随着科技的飞速发展,我们每天都在与各种应用程序打交道。然而,你是否真正了解DAPP和APP之间的区别呢?本文将为你揭示这两者的核心差异,让你在自媒体平台上脱颖而出。 一、定义与起源 APP,即应用程序,通常指…

使用KubeSphere轻松部署Bookinfo应用

Bookinfo 应用 这个示例部署了一个用于演示多种 Istio 特性的应用,该应用由四个单独的微服务构成。 如安装了 Istio,说明已安装 Bookinfo。 这个应用模仿在线书店的一个分类,显示一本书的信息。 页面上会显示一本书的描述,书籍…

C/C++ 位段

目录 什么是位段? 位段的内存分配 位段的跨平台问题 什么是位段? 位段的声明与结构是类似的,但是有两个不同: 位段的成员必须是 int、unsigned int 或signed int 等整型家族。位段的成员名后边有一个冒号和一个数字 这是一个…

深度学习笔记(二)——Tensorflow环境的安装

本篇文章只做基本的流程概述,不阐述具体每个软件的详细安装流程,具体的流程网上教程已经非常丰富。主要是给出完整的安装流程,以供参考 环境很重要 一个好的算法环境往往能够帮助开发者事半功倍,入门学习的时候往往搭建好环境就已…

【ELISA检测】酶联免疫吸附实验概述-卡梅德生物

酶联免疫吸附实验(Enzyme linked immunosorbent assay,ELISA)是将抗原或抗体结合在固相载体表面,利用抗原抗体的特异性结合以及抗体或者抗原上标记的酶催化特定底物发生显色反应,实现目标物检测的免疫分析方法&#xf…

24/1/10 qt work

1. 完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到其他界面 如果账号和密码不匹配&…

11 个 Python全栈开发工具集

前言 以下是专注于全栈开发不同方面的 Python 库;有些专注于 Web 应用程序开发,有些专注于后端,而另一些则两者兼而有之。 1. Taipy Taipy 是一个开源的 Python 库,用于构建生产就绪的应用程序前端和后端。 它旨在加快应用程序开发&#xf…

Appium + ios环境搭建过程Mac

前提: 已经搭建好NodeJavaPythonAppium...环境 见下面的文章: ok的话按照下面的步骤搭建IOs的自动化 1. 安装Xcode 官方下载 (Downloads and Resources - Xcode - Apple Developer 1)AppStore 下载安装最新版本 2. 依赖工具 工具名描述libimobile…