pytorch复现_UNet

什么是UNet
U-Net由收缩路径和扩张路径组成。收缩路径是一系列卷积层和汇集层,其中要素地图的分辨率逐渐降低。扩展路径是一系列上采样层和卷积层,其中特征地图的分辨率逐渐增加。
在扩展路径中的每一步,来自收缩路径的对应特征地图与当前特征地图级联。
在这里插入图片描述
主干结构解析
左边为特征提取网络(编码器),右边为特征融合网络(解码器)

高分辨率—编码—低分辨率—解码—高分辨率

特征提取网络
高分辨率—编码—低分辨率

前半部分是编码, 它的作用是特征提取(获取局部特征,并做图片级分类),得到抽象语义特征

由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块,一共经过4次这样的操作

特征融合网络
低分辨率—解码—高分辨率

利用前面编码的抽象特征来恢复到原图尺寸的过程, 最终得到分割结果(掩码图片)

代码:

import torch.nn as nn
import torch

# 编码器(论文中称之为收缩路径)的基本单元
def contracting_block(in_channels, out_channels):
    block = torch.nn.Sequential(
        # 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
        nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels, out_channels=out_channels),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )
    return block


# 解码器(论文中称之为扩张路径)的基本单元
class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()

        # 每进行一次反卷积,通道数减半,尺寸扩大2倍
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=(3, 3), stride=2, padding=1,
                                     output_padding=1)
        self.block = nn.Sequential(
            # 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, e, d):
        d = self.up(d)
        # concat
        # e是来自编码器部分的特征图,d是来自解码器部分的特征图,它们的形状都是[B,C,H,W]
        diffY = e.size()[2] - d.size()[2]
        diffX = e.size()[3] - d.size()[3]
        # 裁剪时,先计算e与d在高和宽方向的差距diffY和diffX,然后对e高方向进行裁剪,具体方法是两边分别裁剪diffY的一半,
        # 最后对e宽方向进行裁剪,具体方法是两边分别裁剪diffX的一半,
        # 具体的裁剪过程见下图一
        e = e[:, :, diffY // 2:e.size()[2] - diffY // 2, diffX // 2:e.size()[3] - diffX // 2]
        cat = torch.cat([e, d], dim=1)  # 在特征通道上进行拼接
        out = self.block(cat)
        return out


# 最后的输出卷积层
def final_block(in_channels, out_channels):
    block = nn.Conv2d(kernel_size=(1, 1), in_channels=in_channels, out_channels=out_channels)
    return block


class UNet(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()

        # 编码器 (Encode)
        self.conv_encode1 = contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 编码器与解码器之间的过渡部分(Bottleneck)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )

        # 解码器(Decode)
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)

        self.final_layer = final_block(64, out_channel)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_pool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_pool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_pool3(encode_block3)
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_pool4(encode_block4)

        # Bottleneck
        bottleneck = self.bottleneck(encode_pool4)

        # Decode
        decode_block4 = self.conv_decode4(encode_block4, bottleneck)
        decode_block3 = self.conv_decode3(encode_block3, decode_block4)
        decode_block2 = self.conv_decode2(encode_block2, decode_block3)
        decode_block1 = self.conv_decode1(encode_block1, decode_block2)

        final_layer = self.final_layer(decode_block1)
        return final_layer


if __name__ == '__main__':
    image = torch.rand((1, 3, 572, 572))
    unet = UNet(in_channel=3, out_channel=2)
    mask = unet(image)
    print(mask.shape)
    
    #输出结果:
    torch.Size([1, 2, 388, 388])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/119090.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端框架Vue学习 ——(四)Axios

文章目录 Axios 介绍Axios 入门Vue项目中使用 Axios Axios 介绍 介绍: Axios 对原生的 Ajax 进行了封装,简化书写,快速开发。(异步请求) 官网: https://www.axios-http.cn/ 官网介绍:Axios 是一个基于 promise 网络请…

@Tag和@Operation标签失效问题。SpringDoc 2.2.0(OpenApi 3)和Spring Boot 3.1.1集成

问题 Tag和Operation标签失效 但是Schema标签有效 pom依赖 <!-- 接口文档--><!--引入openapi支持--><dependency><groupId>org.springdoc</groupId><artifactId>springdoc-openapi-starter-webmvc-ui</artifactId><vers…

多测师肖sir_高级金牌讲师_jenkins搭建

jenkins操作手册 一、jenkins介绍 1、持续集成&#xff08;CI&#xff09; Continuous integration 持续集成 团队开发成员每天都有集成他们的工作&#xff0c;通过每个成员每天至少集成一次&#xff0c;也就意味着一天有可 能多次集成。在工作中我们引入持续集成&#xff0c;通…

水利部加快推进小型水库除险加固,大坝安全监测是重点

国务院常务会议明确到2025年前&#xff0c;完成新出现病险水库的除险加固&#xff0c;配套完善重点小型水库雨水情和安全监测设施&#xff0c;实现水库安全鉴定和除险加固常态化。 为加快推进小型水库除险加固前期工作&#xff0c;水利部协调财政部提前下达了2023年度中央补助…

chinese-stable-diffusion中文场景文生图prompt测评集合

腾讯混元大模型文生图操作指南.dochttps://mp.weixin.qq.com/s/u0AGtpwm_LmgnDY7OQhKGg腾讯混元大模型再进化&#xff0c;文生图能力重磅上线&#xff0c;这里是一手实测腾讯混元的文生图在人像真实感、场景真实感上有比较明显的优势&#xff0c;同时&#xff0c;在中国风景、动…

ActiveMq学习⑨__基于zookeeper和LevelDB搭建ActiveMQ集群

引入消息中间件后如何保证其高可用&#xff1f; 基于zookeeper和LevelDB搭建ActiveMQ集群。集群仅提供主备方式的高可用集群功能&#xff0c;避免单点故障。 http://activemq.apache.org/masterslave LevelDB&#xff0c;5.6版本之后推出了LecelDB的持久化引擎&#xff0c;它使…

错误:ERROR Cannot read properties of null (reading ‘type‘)

ERROR Cannot read properties of null (reading ‘type’) TypeError: Cannot read properties of null (reading ‘type’) <template><el-card><el-row :gutter"20" class"header"><el-col :span"7"><el-input pl…

二、Hadoop分布式系统基础架构

1、分布式 分布式体系中&#xff0c;会存在众多服务器&#xff0c;会造成混乱等情况。那如何让众多服务器一起工作&#xff0c;高效且不出现问题呢&#xff1f; 2、调度 &#xff08;1&#xff09;架构 在大数据体系中&#xff0c;分布式的调度主要有2类架构模式&#xff1a…

【Redis】SSM整合Redis注解式缓存的使用

【Redis】SSM整合Redis&注解式缓存的使用 一、SSM整合Redis1.2.配置文件spring-redis.xml1.3.修改applicationContext.xml1.4.配置redis的key生成策略 二、Redis的注解式开发及应用场景2.1.什么是Redis注解式2.实列测试 三、Redis中的击穿、穿透、雪崩的三种场景 一、SSM整…

WebSocket Day03 : SpringMVC整合WebSocket

前言 在现代Web应用程序中&#xff0c;实时性和即时通信变得越来越重要。传统的HTTP请求-响应模式无法满足实时数据传输和双向通信的需求。随着技术的发展&#xff0c;WebSocket成为了一种强大而灵活的解决方案。 WebSocket是HTML5提供的一种新的通信协议&#xff0c;它通过一…

vue.js实现科室无限层选中和回显

一、效果展示&#xff1a; 展示可选层级 查看选中的值 二、实现&#xff1a; <el-form-item label"相关科室:" prop"orgId"><el-cascaderpopper-class"cascader-my":options"orgOptions":show-all-levels"false"…

NVIDIA Jetson SOC 内存分配策略

CPU 是Host, GPU 是Device, 系统内存分配策略如下: 这段话的翻译如下&#xff1a; 集成的GPU会和CPU以及其他Tegra引擎共享DRAM&#xff08;动态随机存储器&#xff09;&#xff0c;并且CPU可以通过将DRAM的内容移动到交换区域&#xff08;SWAP area&#xff09;或者相反来控制…

成功品牌的营销秘诀揭密,营销秘诀,品牌成功

品牌营销是将品牌塑造为消费者心目中有价值的存在&#xff0c;从而提高品牌认知度和价值的过程。品牌营销是任何一家企业成功的关键所在。如果一家企业能够正确地营销其品牌&#xff0c;那么它就能够在行业中发挥更大的作用。接下来&#xff0c;迅推客将深入探讨品牌营销的重要…

rust变量绑定、拷贝、转移、引用

目录 一&#xff0c;clone、copy 1&#xff0c;基本类型 2&#xff0c;类型的clone特征 3&#xff0c;显式声明结构体的clone特征 4&#xff0c;类型的copy特征 5&#xff0c;显式声明结构体的clone特征 5&#xff0c;变量和字面量的特征 6&#xff0c;特征总结 二&am…

技术分享 | 抓包分析 TCP 协议

TCP 协议是在传输层中&#xff0c;一种面向连接的、可靠的、基于字节流的传输层通信协议。 环境准备 对接口测试工具进行分类&#xff0c;可以如下几类&#xff1a; 网络嗅探工具&#xff1a;tcpdump&#xff0c;wireshark代理工具&#xff1a;fiddler&#xff0c;charles&a…

触摸屏通过modbus转profinet网关连接PLC与变频器485modbus通讯案例

通过兴达易控modbus转profinet网关&#xff08;XD-MDPN100&#xff09;的桥接&#xff0c;数据可以以高速、可靠的方式从触摸屏传递到PLC&#xff0c;同时能够实现PLC对变频器的监控和控制。这四台变频器通过485modbus协议与PLC通讯&#xff0c;使得系统能够实现对变频器的高效…

ACmix:卷积和self-attention的结合,YOLOv5改进之ACmix

目录 一、ACmix理论部分 二、代码 三、YOLOv5改进 ACC3 一、ACmix理论部分 论文地址:2111.14556.pdf (arxi

梳理自动驾驶中的各类坐标系

目录 自动驾驶中的坐标系定义 关于坐标系的定义 几大常用坐标系 世界坐标系 自车坐标系 传感器坐标系 激光雷达坐标系 相机坐标系 如何理解坐标转换 机器人基础中的坐标转换概念 左乘右乘的概念 对左乘右乘的理解 再谈自动驾驶中的坐标转换 本节参考文献 自动驾驶…

517-0224-16A-458525 531X303MCPARG1 现代工厂中DCS与PLC的比较

517-0224-16A-458525 531X303MCPARG1 现代工厂中DCS与PLC的比较 分布式控制系统(DCSs)和可编程逻辑控制器(PLC)之间的区别可以归结为一个简单的足球比喻。你的指挥系统是你的船长。团队名单上的第一个名字&#xff0c;你的DCS是可靠的&#xff0c;勤奋的&#xff0c;控制着整个…

django+drf+vue 简单系统搭建 (1) - django创建项目

本系列文章为了记录自己第一个系统生成过程&#xff0c;主要使用django,drf,vue。本人非专业人士&#xff0c;此文只为记录学习&#xff0c;若有部分描述不够准确的地方&#xff0c;烦请指正。 建立这个系统的原因是因为&#xff0c;在生活中&#xff0c;很多觉得可以一两行代码…