ResUNet原理与实现

简述

  • ResNet是一种非常成功的深度卷积神经网络结构,其具有较强的特征表达能力和较浅的网络深度,使得其在图像分类等任务中表现出了出色的性能。因此,将ResNet作为encoder替换U-Net原始结构,可以使U-Net在图像分割任务中获得更好的性能表现。

  • U-Net是一种经典的深度卷积神经网络结构,特别适用于图像分割任务。U-Net提出的时间较早,当时并没有像ResNet等网络结构和大规模预训练权重这样的资源可用。但是,U-Net的下采样和上采样的设计思路和现在许多成熟的网络结构相似,因此,可以看作是先驱性的工作。

  • U-Net的下采样和上采样的设计与许多现在成熟的网络结构异曲同工。具体地,U-Net的下采样部分使用卷积和池化操作来逐渐减小特征图的尺寸和通道数,提取低级别特征。特征图每一层的尺寸会降低一半,而通道数会翻倍。这种设计与现在许多成熟的网络结构(如ResNet、VGG等)的下采样部分使用卷积和池化操作来逐渐减小特征图的尺寸和通道数的设计思路相似。

  • 成熟的网络结构和ImageNet预训练权重可以用来finetuning我们的U-Net。因为ImageNet是一个大规模的图像分类数据集,ImageNet预训练权重可以帮助我们在U-Net的训练中使用更好的初始化权重,加快网络的收敛速度并提高网络的泛化能力。通过finetuning,我们可以进一步优化U-Net网络的效果。

  • ResUNet是一种基于残差连接的深度学习模型,用于图像分割任务。它结合了ResNet和U-Net的优点,能够更好地解决梯度消失问题和语义信息缺失问题。下面将介绍ResUNet的原理。

原理

ResNet

Residual Network,是一种深度卷积神经网络
基于残差连接的思想,使得网络更容易训练
残差连接:跨层连接方法,可以使得网络更好地学习到低频信息

U-Net

一种全卷积网络,用于图像分割任务
包含编码器和解码器,能够提取全局和局部特征
上下文信息融合,可以缓解语义信息缺失问题

ResUNet

在U-Net的基础上加入了残差连接
编码器和解码器中的每个模块都包含了多个残差连接
在每个残差块中,引入了shortcut(或者称为skip connection)实现跨层连接
ResUNet的优点

残差连接可以缓解深度网络的梯度消失问题
编码器和解码器中的残差块可以更好地提取低频信息
上下文信息融合可以缓解语义信息缺失问题
实验结果表明,与U-Net相比,ResUNet在分割准确率上有显著提高。
总之,ResUNet是一种基于残差连接的深度学习模型,结合了ResNet和U-Net的优点。通过在编码器和解码器中引入多个残差块和shortcut,可以更好地提取低频信息、缓解语义信息缺失问题和梯度消失问题,从而提高图像分割的准确率。

示意图

在这里插入图片描述
可以看到,ResUNet包括了下采样、上采样和跳跃连接三个部分。

下采样部分使用卷积和池化操作逐渐减小图像尺寸和特征数量,提取低级别特征。

上采样部分使用转置卷积操作逐渐增大图像尺寸和特征数量,同时进行特征融合,生成高级别特征。

跳跃连接部分将下采样和上采样过程中相同分辨率的特征进行连接,帮助网络更好地捕捉多尺度信息,提高图像分割性能。

整个网络结构包括了ResNet的残差连接和U-Net的上下采样和跳跃连接思想,可以更好地平衡特征的丰富性和细节的保留性,在图像分割任务中表现出较好的性能。

代码实现

from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary

# 定义解码器中的卷积块
class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()

        # 卷积块的结构
        self.block = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, d, e=None):
        # 上采样
        d = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
        # 拼接
        if e is not None:
            cat = torch.cat([e, d], dim=1)
            out = self.block(cat)
        else:
            out = self.block(d)
        return out

# 定义最后一层卷积块
def final_block(in_channels, out_channels):
    block = nn.Sequential(
        nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    )
    return block
# 定义 Resnet34_Unet 类
class Resnet34_Unet(nn.Module):
    # 定义初始化函数
    def __init__(self, in_channel, out_channel, pretrained=False):
        # 调用 nn.Module 的初始化函数
        super(Resnet34_Unet, self).__init__()
        
        # 创建 ResNet34 模型
        self.resnet = models.resnet34(pretrained=pretrained)
        # 定义 layer0,包括 ResNet34 的第一层卷积、批归一化、ReLU 和最大池化层
        self.layer0 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            self.resnet.maxpool
        )

        # 定义 Encode 部分,包括 ResNet34 的 layer1、layer2、layer3 和 layer4
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4

        # 定义 Bottleneck 部分,包括两个卷积层、ReLU、批归一化和最大池化层
        self.bottleneck = torch.nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        )

        # 定义 Decode 部分,包括四个 expansive_block 和一个 final_block
        self.conv_decode4 = expansive_block(1024+512, 512, 512)
        self.conv_decode3 = expansive_block(512+256, 256, 256)
        self.conv_decode2 = expansive_block(256+128, 128, 128)
        self.conv_decode1 = expansive_block(128+64, 64, 64)
        self.conv_decode0 = expansive_block(64, 32, 32)
        self.final_layer = final_block(32, out_channel)

    # 定义前向传播函数
    def forward(self, x):
        # 执行 layer0
        x = self.layer0(x)
        # 执行 Encode
        encode_block1 = self.layer1(x)
        encode_block2 = self.layer2(encode_block1)
        encode_block3 = self.layer3(encode_block2)
        encode_block4 = self.layer4(encode_block3)

        # 执行 Bottleneck
        bottleneck = self.bottleneck(encode_block4)

        # 执行 Decode
        decode_block4 = self.conv_decode4(bottleneck, encode_block4)
        decode_block3 = self.conv_decode3(decode_block4, encode_block3)
        decode_block2 = self.conv_decode2(decode_block3, encode_block2)
        decode_block1 = self.conv_decode1(decode_block2, encode_block1)
        decode_block0 = self.conv_decode0(decode_block1)
final_layer = self.final_layer(decode_block0)

        return final_layer


flag = 0

if flag:
    image = torch.rand(1, 3, 572, 572)
    Resnet34_Unet = Resnet34_Unet(in_channel=3, out_channel=1)
    mask = Resnet34_Unet(image)
    print(mask.shape)

# 测试网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Resnet34_Unet(in_channel=1, out_channel=1, pretrained=True).to(device)
summary(model, input_size=(3, 512, 512))

该代码定义了一个基于ResNet34的Unet网络,用于语义分割任务。主要包括以下几个部分:

  1. expansive_block:扩张块,由两个卷积层、ReLU和BatchNorm组成。用于解码过程中对图像进行上采样和特征融合操作。
  2. final_block:最终块,由一个卷积层、ReLU和BatchNorm组成。用于将解码后的特征图转换为最终的输出图像。
  3. Resnet34_Unet:整个网络的主体部分。首先使用ResNet34作为编码器,对输入图像进行特征提取。然后通过一个卷积层和ReLU,将编码器的输出进行特征扩张。接下来进行解码操作,使用扩张块和编码器的特征图进行上采样和特征融合,直到得到与原始输入图像大小相同的特征图。最后通过最终块将特征图转换为输出图像
  4. flag:用于测试代码,如果设置为1,则会生成一个随机输入图像,并输出对应的分割结果。
  5. 测试网络:实例化Resnet34_Unet网络,并使用torchsummary库输出网络结构的信息,包括每一层的输出形状和参数数量等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/69156.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python爬虫实战(1)--爬取新闻数据

想要每天看到新闻数据又不想占用太多时间去整理,萌生自己抓取新闻网站的想法。 1. 准备工作 使用python语言可以快速实现,调用BeautifulSoup包里面的方法 安装BeautifulSoup pip install BeautifulSoup完成以后引入项目 2. 开发 定义请求头&#xf…

【Windows】Windows开机密码重置

文章目录 前言一、问题描述二、操作步骤2.1 安装DaBaiCai_d14_v6.0_2207_Online.exe2.2 插入U盘2.3 打开大白菜,点击“一键制作USB启动盘”2.4 等待进度条走完2.5 重启电脑,开机按“F12”或者“F8”(具体百度一下,对应品牌电脑开机…

Java 成功实现通过网址URL截图保存

Java 实现通过网址URL截图 1.DjNativeSwing方式 (不好用)2.phantomjs方式 (截图还是有瑕疵)3.selenium方式 (满意,成功实现)maven 引入下载相关浏览器chrome下载相关浏览器chromedriver驱动后端…

代码随想录算法训练营第53天|动态规划part11|123. 买卖股票的最佳时机 III、188.买卖股票的最佳时机IV

代码随想录算法训练营第53天|动态规划part11|123. 买卖股票的最佳时机 III、 188.买卖股票的最佳时机IV 123. 买卖股票的最佳时机 III 123. 买卖股票的最佳时机 III 思路: 相比买股票的最佳时机II,限制了买股票的次数&#xf…

Oracle 开发篇+Java调用OJDBC访问Oracle数据库

标签:JAVA语言、Oracle数据库、Java访问Oracle数据库释义:OJDBC是Oracle公司提供的Java数据库连接驱动程序 ★ 实验环境 ※ Oracle 19c ※ OJDBC8 ※ JDK 8 ★ Java代码案例 package PAC_001; import java.sql.Connection; import java.sql.ResultSet…

gitblit windows部署

1.官网下载 往死慢,我是从百度找的1.9.1,几乎就是最新版 http://www.gitblit.com/ 2.解压 下载下来是一个zip压缩包,直接解压即可 3.配置 3.1.配置资源库路径 找到data文件下的gitblit.properties文件,用Notepad打开 **注意路…

Android Ble蓝牙App(三)特性和属性

Ble蓝牙App(三)特性使用 前言正文一、获取属性列表二、属性适配器三、获取特性名称四、特性适配器五、加载特性六、显示特性和属性七、源码 前言 在上一篇中我们完成了连接和发现服务两个动作,那么再发现服务之后要做什么呢?发现服…

【二】数据库系统

数据库系统的分层抽象DBMS 数据的三个层次从 数据 到 数据的结构----模式数据库系统的三级模式(三级视图)数据库系统的两层映像数据库系统的两个独立性数据库系统的标准结构 数据模型从 模式 到 模式的结构----数据模型三大经典数据模型 数据库的演变与发…

windows使用/服务(13)戴尔电脑怎么设置通电自动开机

戴尔pc机器通电自启动 1、将主机显示器键盘鼠标连接好后,按主机电源键开机 2、在开机过程中按键盘"F12",进入如下界面,选择“BIOS SETUP” 3、选择“Power Management” 4、选择“AC Recovery”,点选“Power On”,点击“…

uniapp 格式化时间刚刚,几分钟前,几小时前,几天前…

效果如图: 根目录下新建utils文件夹,文件夹下新增js文件,文件内容: export const filters {dateTimeSub(data) {if (data undefined) {return;}// 传进来的data必须是日期格式,不能是时间戳//将字符串转换成时间格式…

使用 prometheus client SDK 暴露指标

目录 1. 使用 prometheus client SDK 暴露指标1.1. How Go exposition works1.2. Adding your own metrics1.3. Other Go client features 2. Golang Application monitoring using Prometheus2.1. Metrics and Labels2.2. Metrics Types2.2.1. Counters:2.2.2. Gauges:2.2.3. …

Python测试框架pytest:常用参数、查找子集、参数化、跳过

Pytest是一个基于python的测试框架,用于编写和执行测试代码。pytest主要用于API测试,可以编写代码来测试API、数据库、UI等。 pytest是一个非常成熟的全功能的Python测试框架,主要有以下几个优点: 简单灵活,容易上手。…

前端渲染数据

在前端对接受后端数据处理后返回的接收值的时候&#xff0c;为了解决数据过于庞大&#xff0c;而对数据进行简化处理例如性别&#xff0c;经常会使用1&#xff0c; 0这俩个来代替文字的男&#xff0c;女。以下就是前端渲染的具体实现。 以下是部分代码 <el-table-columnpr…

【MFC】10.MFC六大机制:RTTI(运行时类型识别),动态创建机制,窗口切分,子类化-笔记

运行时类信息&#xff08;RTTI&#xff09; C: ##是拼接 #是替换成字符串 // RTTI.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 // #include <iostream> #include <afxwin.h>#ifdef _DEBUG #define new DEBUG_NEW #endifCWinApp th…

ubuntu 安装 nvidia 驱动

ubuntu 安装 nvidia 驱动 初环境与设备查询型号查询对应的驱动版本安装驱动验证驱动安装结果 本篇文章将介绍ubuntu 安装 nvidia 驱动 初 希望能写一些简单的教程和案例分享给需要的人 环境与设备 系统&#xff1a;ubuntu 设备&#xff1a;Nvidia GeForce RTX 4090 查询型…

Tcp是怎样进行可靠准确的传输数据包的?

概述 很多时候&#xff0c;我们都在说Tcp协议&#xff0c;Tcp协议解决了什么问题&#xff0c;在实际工作中有什么具体的意义&#xff0c;想到了这些我想你的技术会更有所提升&#xff0c;Tcp协议是程序员编程中的最重要的一块基石&#xff0c;Tcp是怎样进行可靠准确的传输数据…

web-ssrf

目录 ssrf介绍 以pikachu靶场为例 curl 访问外网链接 利用file协议查看本地文件 利用dict协议扫描内网主机开放端口 file_get_content 利用file协议查看本地文件&#xff1a; fsockopen() 防御方式: ssrf介绍 服务器端请求伪造&#xff0c;是一种由攻击者构造形成…

CSP复习每日一题(四)

树的重心 给定一颗树&#xff0c;树中包含 n n n 个结点&#xff08;编号 1 ∼ n 1∼n 1∼n&#xff09;和 n − 1 n−1 n−1条无向边。请你找到树的重心&#xff0c;并输出将重心删除后&#xff0c;剩余各个连通块中点数的最大值。 重心定义&#xff1a; 重心是指树中的一…

链式二叉树统计结点个数的方法和bug

方法一&#xff1a; 分治&#xff1a;分而治之 int BTreeSize1(BTNode* root) {if (root NULL) return 0;else return BTreeSize(root->left)BTreeSize(root->right)1; } 方法二&#xff1a; 遍历计数&#xff1a;设置一个计数器&#xff0c;对二叉树正常访问&#…

dubbo之高可用

负载均衡 概述 负载均衡是指在集群中&#xff0c;将多个数据请求分散到不同的单元上执行&#xff0c;主要是为了提高系统的容错能力和对数据的处理能力。 Dubbo 负载均衡机制是决定一次服务调用使用哪个提供者的服务。 策略 在Dubbo中提供了7中负载均衡策略&#xff0c;默…