分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):
    def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):
        super().__init__()

        self.ps = patch_size
        self.c1 = Conv2D(in_c, out_c)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
        self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.r1 = residual_block(out_c, out_c)

    def forward(self, inputs):
        x = self.c1(inputs)

        b, c, h, w = x.shape
        num_patches = (h*w)//(self.ps**2)
        x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))
        x = self.te(x)
        x = torch.reshape(x, (b, c, h, w))

        x = self.c2(x)
        s = self.c3(inputs)
        x = self.relu(x + s)
        x = self.r1(x)
        return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, inputs):
        x = self.conv(inputs)
        s = self.shortcut(inputs)
        return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        backbone = resnet50()
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)
        self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)
        self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)
        self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)


        """ Decoder """
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r1 = residual_transformer_block(64+64, 64, dim=64)
        self.r2 = residual_transformer_block(64+64, 64, dim=256)
        self.r3 = residual_block(64+64, 64)

        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoder """
        x0 = inputs
        x1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]
        x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]
        x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]
        x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]

        e1 = self.e1(x1)
        e2 = self.e2(x2)
        e3 = self.e3(x3)
        e4 = self.e4(x4)

        """ Decoder """
        x = self.up(e4)
        x = torch.cat([x, e3], axis=1)
        x = self.r1(x)

        x = self.up(x)
        x = torch.cat([x, e2], axis=1)
        x = self.r2(x)

        x = self.up(x)
        x = torch.cat([x, e1], axis=1)
        x = self.r3(x)

        x = self.up(x)

        """ Classifier """
        outputs = self.outputs(x)
        return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/444301.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTML入门:属性

你好,我是云桃桃。今天来聊一聊 HTML 属性写法和特点。 HTML 属性是用于向 HTML 标签(也叫 HTML 元素)提供附加信息或配置的特性。 如果说,把HTML 标签比作一个房子,HTML 标签定义了房子的结构和用途,比如…

基于SpringBoot的闲置房屋搜索平台设计与实现

目 录 摘 要 I Abstract II 引 言 1 1相关技术 3 1.1 jQuery技术简介 3 1.2 SpringBoot框架简介 3 1.3 Bootstrap框架简介 4 1.4 ECharts框架简介 4 1.5 百度地图API简介 4 1.6 Ajax技术简介 5 1.7 MySQL数据库简介 5 1.8本章小结 6 2系统分析 7 2.1功能需求 7 2.2非功能需求 …

微软财务GPT Excel Copilot for Finance使用攻略

功能本身不收费,但是这个功能需要微软的商业版office账号才能使用,如果你没有账号,可以直说。 在桌面Excel软件中登录账号后,点击“copilot for finance”按钮,如果没有出现,则点击“加载项”,…

2024 年中国高校大数据挑战赛赛题 D:行业职业技术培训能力评价完整思路以及源代码分享

中国是制造业大国,产业门类齐全,每年需要培养大量的技能娴 熟的技术工人进入工厂。某行业在全国有多所不同类型(如国家级、 省级等)的职业技术培训学校,进行 5 种技能培训。学员入校时需要 进行统一的技能考核&#xf…

简述epoll实现

所有学习笔记:https://github.com/Dusongg/StudyNotes 文章目录 epoll数据结构的选择?以tcp为例,网络io的可读可写如何判断?epoll如何做到线程安全?LT和ET如何实现?tcp状态和io的读写有哪些关系&#xff1…

文本生成视频:从 Write-a-video到 Sora

2024年2月15日,OpenAI 推出了其最新的文本生成视频模型——Sora。Sora 能够根据用户的指令生成一分钟长度的高质量视频内容。这一创新的发布迅速在社会各界引发了广泛关注与深入讨论。本文将围绕本实验室发表于SIGGRAPH AISA 的 Write-a-video和 Sora 展开&#xff…

CPU设计实战-协处理器访问指令的实现

目录 一 协处理器的作用与功能 1.计数寄存器和比较寄存器 2.Status寄存器 3.Cause寄存器(标号为13) 4.EPC寄存器(标号为14) 5.PRId寄存器(标号为15) 6.Config 寄存器(标号为16)-配置寄存器 二 协处理器的实现 三 协处理器访问指令说明 四 具体实现 1.译码阶段 2.执行…

git命令行提交——github

1. 克隆仓库至本地 git clone 右键paste(github仓库地址) cd 仓库路径(进入到仓库内部准备提交文件等操作) 2. 查看main分支 git branch(列出本地仓库中的所有分支) 3. 创建新分支(可省…

Edu18 -- Divide by Three --- 题解

目录 Divide by Three: 题目大意: ​编辑​编辑思路解析: 代码实现: Divide by Three: 题目大意: 思路解析: 一个数字是3的倍数,那么他的数位之和也是3的倍数,所以我…

安信可IDE(AiThinker_IDE)编译ESP8266工程方法

0 工具准备 AiThinker_IDE.exe ESP8266工程源码 1 安信可IDE(AiThinker_IDE)编译ESP8266工程方法 1.1 解压ESP8266工程文件夹 我们这里使用的是NON-OS_SDK,将NON-OS_SDK中的1_UART文件夹解压到工作目录即可 我这里解压到了桌面&#xff0c…

WiFi模块助力少儿编程:创新学习与实践体验

随着科技的飞速发展,少儿编程已经成为培养孩子们创造力和问题解决能力的重要途径之一。在这个过程中,WiFi模块的应用为少儿编程领域注入了新的活力,使得学习编程不再是单一的代码教学,而是一个充满创新与实践的综合性体验。 物联网…

Redis作为缓存的数据一致性问题

背景 使用Reids作为缓存的原因: 在高并发场景下,传统关系型数据库的并发能力相对比较薄弱(QPS不能太大); 使用Redis做一个缓存。让用户请求先打到Redis上而不是直接打到数据库上。 但是如果出现数据更新操作&#xff…

开发指南002-前后端信息交互规范-概述

前后端之间采用restful接口,服务和服务之间使用feign。信息交互遵循如下平台规范: 前端: 建立api目录,按照业务区分建立不同的.js文件,封装对后台的调用操作。其中qlm*.js为平台预制的接口文件,以qlm_user.…

【红外与可见光融合:条件学习:实例归一化(IN)】

Infrared and visible image fusion based on a two-stage class conditioned auto-encoder network (基于两级类条件自编码器网络的红外与可见光图像融合) 现有的基于自动编码器的红外和可见光图像融合方法通常利用共享编码器从不同模态中提取特征&am…

arduino安装索尼spresense开发库

arduino安装索尼spresense开发库 一.库安装二.库文件下载1.直接下载2.git下载1.git加速下载2.git下载加速3.将文件导入arduino 一.库安装 打开arduino点击文件->首选项 将以下链接添加进附加开发板管理器网址 https://github.com/sonydevworld/spresense-arduino-compatib…

什么是数据采集与监视控制系统(SCADA)?

SCADA数据采集是一种用于监控和控制工业过程的系统。它可以实时从现场设备获得数据并将其传输到中央计算机,以便进行监控和控制。SCADA数据采集系统通常使用传感器、仪表和控制器收集各种类型的数据,例如温度、压力、流量等,然后将这些数据汇…

【李沐】动手学习ai思路softmax回归实现

来源:https://www.cnblogs.com/blzm742624643/p/15079086.html 一、从零开始实现 1.1 首先引入Fashion-MNIST数据集 1 import torch 2 from IPython import display 3 from d2l import torch as d2l 4 5 batch_size 256 6 train_iter, test_iter d2l.load_data…

tcp流式服务和粘包问题

目录 1.概念 2.流式服务 3.粘包问题 1.概念 套接字是一个全双工的 使用TCP协议通信的双方必须先建立连接,然后才能开始数据的读写,双方都必须为该连接分配必要的内核资源,以管理连接的状态和连接上数据的传输. TCP连接是全双工的,即双方的数据读写可以通过一个连接进行,完成…

集合框架(一)List系列集合

特点 有序,可重复,有索引。 LIst集合的特有方法 /** 目标:掌握List系列集合的特点,以及其提供的特有方法* */import java.util.ArrayList; import java.util.List;public class ListTest1 {public static void main(String[] arg…

android开发环境搭建

android开发环境搭建 Android 开发环境搭建1.JDK安装与配置1.1 Jdk官方下载1.2 JDK安装1.3 环境变量配置1.4 新建JAVA_HOME1.5 修改Path变量1.6 新建classpath1.7 验证环境是否配置完成 2.开发工具二选一1.如何创建一个工程2.工程的目录结构的了解3.与开发的相关的常规视图4.我…