神经网络(五):U2Net图像分割网络

文章目录

  • 一、网络结构
    • 1.1第一种block结构
    • 1.2第二种block结构
    • 1.3特征图融合模块
    • 1.4损失函数
    • 1.5总体网络架构
    • 1.6代码汇总
    • 1.7普通残差块与RSU对比
  • 二、代码复现


  参考论文:U2-Net: Going deeper with nested U-structure for salient object detection
  这篇文章基于显著目标检测任务提出,显著目标检测是指将图像中最吸引人的目标或者区域分割出来,因此只有前景和背景两个类别,相当于语义分割中的二分类任务。例如,下图中展示了三张图片进行显著目标检测后的结果:
在这里插入图片描述
其中,白色区域代表前景 ,即最吸引人的目标或区域,而黑色区域代表背景。
  在 U 2 N e t U^2Net U2Net被提出前,显著目标检测领域主要面临两个问题:

  • 1.现有的SOD网络大多基于当时已有的网络架构(主干网络)进行深度特征的提取,如 A l e x N e t 、 V G G 、 R e s N e t 、 R e s N e X t AlexNet、VGG、ResNet、ResNeXt AlexNetVGGResNetResNeXt等。这些网络最初是为图像分类设计的,它们提取代表语义含义的特征(如某一具体的猫、狗),而非局部细节和全局对比度信息(SOD的主要任务是将图像划分为前景与背景,而不注重某一具体特征的提取),使得这些模型在进行SOD任务时效率低下。
  • 2.当时的SOD网络架构不断通过向现有的主干网络中添加特征聚合模块以提取多级显著特征,使得模型过于复杂。且这些图像分类模型往往通常通过牺牲特征图的高分辨率来实现更深的架构,即特征图在早期阶段会被缩小到较低的分辨率,如ResNet和DenseNet会使用步长为2的卷积和步长为2的最大池化将特征图大小减小到输入图的四分之一。但是,高分辨率在图像分割中有着重要作用,这也使得这些模型并不适用SOD任务。

  为解决上述问题,提出了 U 2 N e t U^2Net U2Net网络结构:

  • 一种两级嵌套的U形结构,专为SOD 设计,无需使用任何来自图像分类的预训练主干。
  • 提出U型残差块RSU,能在不降低特征图分辨率的情况下提取阶段内多尺度特征。

一、网络结构

   U 2 − N e t U^2-Net U2Net网络基于 U N e t UNet UNet网络设计而来。事实上,该网络的整体结构与 U N e t UNet UNet网络几乎相同,但所使用的上采样、下采样模块变成了小型的 U N e t UNet UNet网络,即, U 2 N e t U^2Net U2Net本质上是 U N e t UNet UNet网络的嵌套。而 U 2 − N e t U^2-Net U2Net网络的核心就是这些作为模块的小型 U N e t UNet UNet网络,并将其起名为 R e S i d u a l U − b l o c k ( R S U ,残差 U 块 ) ReSidual U-block(RSU,残差U块) ReSidualUblock(RSU,残差U)。网络结构如下:
在这里插入图片描述
这些模块其实可以分为两种, E n c o d e r 1   E n c o d e r 4 、 D e c o d e r 1   D e c o d e r 4 Encoder1~Encoder4、Decoder1~Decoder4 Encoder1 Encoder4Decoder1 Decoder4采用的是同一种结构的残差块,只不过深度不同,而Encoder5、Encoder6、Decoder5 采用的是另一种结构的残差块。整体流程可概况为:

  • Encoder阶段:每通过一个模块后都会下采样两倍,使用的是torch.nn.MaxPool2d
  • Decoder阶段:每通过一个模块后都会上采用两倍,使用的是torch.nn.functional.interpolate()
  • 跳跃链接:与 U N e t UNet UNet网络思路相同,将编码器的输出与解码器输出的特征图进行拼接,最后得到分割后的图像。

1.1第一种block结构

  本地和全局上下文信息对于显著对象检测和图像分割任务都非常重要,现代CNN网络设计中VGG、ResNet、DenseNet 等,一般使用1x1或3x3的小型卷积核提取特征。但在SOD任务中,由于它们的感受野太小而无法捕捉全局信息,使得浅层的输出特征图仅包含局部特征。在下图(图 ( a ) − ( c ) (a)-(c) (a)(c))中给出了具有小感受野的典型现有卷积块。为从浅层获得高分辨率特征图的更多全局信息,最直接的想法是扩大感受野,图 ( e ) (e) (e)是一种双向消息传递模块(见论文ieee),它试图通过使用扩张卷积扩大感受野来提取局部和非局部特征,以原始分辨率对输入特征图进行多次扩张卷积(尤其是在早期阶段)需要太多的计算和内存资源。
  为解决上述问题,本文提出了RSU模块(图 ( e ) (e) (e),L表示RSU的深度,图中L=7):

  • 一个输入卷积层,用于将尺寸为 ( H , W , C i n ) (H,W,C_{in}) (H,W,Cin)输入特征图x添加到中间映射 F 1 ( x ) F_1(x) F1(x)中,这是一个用于局部特征提取的普通卷积层,通道数为 C o u t C_{out} Cout
  • 高度为L的类似 U N e t UNet UNet的对称编码器-解码器结构,采用 F 1 ( x ) F_1(x) F1(x)作为输入,学习多尺度上下文信息。用 U ( x ) U(x) U(x)表示该类似 U N e t UNet UNet的结构,则提取的信息可表示为 U ( F 1 ( x ) ) U(F_1(x)) U(F1(x))。较大的 L 会生成更深的RSU、更多的池化操作、更大的感受野范围以及更丰富的局部和全局特征。配置此参数可以从具有任意空间分辨率的输入特征图中提取多尺度特征。从逐渐下采样的特征图中提取多尺度特征,并通过渐进式上采样、连接和卷积,将特征图还原为高分辨率特征图。此过程可减轻因使用大比例的直接上采样而导致的精细细节损失。
  • 通过求和融合局部特征和多尺度特征的残差连接: F 1 ( x ) + U ( F 1 ( x ) ) F_1(x)+U(F_1(x)) F1(x)+U(F1(x))

在这里插入图片描述

   R S U − 7 RSU-7 RSU7真实结构如下图所示:
在这里插入图片描述
  回到 U 2 − N e t U^2-Net U2Net结构,该RSU的使用场景有:

  • Encoder1 和 Decoder1 采用的是 RSU-7 结构。
  • Encoder2 和 Decoder2 采用的是 RSU-6 结构。
  • Encoder3 和 Decoder3 采用的是 RSU-5 结构。
  • Encoder4 和 Decoder4 采用的是 RSU-4 结构。

可见,相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。代码实现如下:


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):    #实现conv2d+BN+ReLU操作                                                      
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        # dilation用于实现空洞卷积
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
 
def _upsample_like(src,tar):
    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     
    return src
 
 
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   
 
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2
 
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             
 
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
 
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
 
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
 
        hx = x
        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
 
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
 
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
 
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
 
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
 
        hx6 = self.rebnconv6(hx)
 
        hx7 = self.rebnconv7(hx6)                                  
 		#实现残差连接
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
 
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
 
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
 
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
 
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
 
        hx2dup = _upsample_like(hx2d,hx1)
 
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
 
        return hx1d + hxin

在这里插入图片描述

1.2第二种block结构

  数据经过 E n 1 − E n 4 En_1-En4 En1En4下采样处理后对应特征图的分辨率就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息。为保留上下文信息,在 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中将原始RSU中的上采样、下采样结构换成了空洞卷积操作,从而得到了 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本。此时 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。在这里插入图片描述
需要注意,在 E n c o d e r 5 Encoder5 Encoder5中特征图大小已经到了18*18,非常小(也因此不需要再下采样),故采用了空洞卷积操作,目的在不改变特征图大小的情况下增大感受野。故在代码中使用了dalition=2、4、8 E n c o d e r 6 、 D e c o d e r 5 Encoder6、Decoder5 Encoder6Decoder5同理。这一特点在原图中显示为使用了虚线构成的长方体块。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):                                                          #CBL
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
        
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
 
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
 
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
        hx = x

        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
 
        hx4 = self.rebnconv4(hx3)
 
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
 
        return hx1d + hxin

1.3特征图融合模块

  在通过编码、解码器的运算后,最后通过特征图融合模块(红框标出)将 D e 1 、 D e 2 、 D e 3 、 D e 4 、 D e 5 、 E n 6 De_1、De_2、De_3、De_4、De_5、En_6 De1De2De3De4De5En6模块的输出分别通过一个3x3的卷积层(卷积层的卷积核个数均为1),并通过双线性插值将得到的特征图还原回输入图像的大小,之后将得到的6个特征图进行拼接(Concatenation),最后再经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的图像。
在这里插入图片描述

1.4损失函数

   U 2 N e t U^2Net U2Net使用多监督算法构建损失函数。网络输出不仅仅包含最终特征图,还包含前面6个不同尺度的特征图,即,不仅要监督网络输出,还要监督中间融合特征图。 损失函数计算公式:
在这里插入图片描述
其中, M = 1 , 2 , 3 , . . . , 6 M=1,2,3,...,6 M=1,2,3,...,6 l s i d e m l_{side}^{m} lsidem表示特征图 S u p 1 、 S u p 2 、 . . . 、 S u p 6 Sup1、Sup2、...、Sup6 Sup1Sup2...Sup6的损失,而 l f u s e l_{fuse} lfuse表示最终特征图的损失, w w w则表示两种损失的权重参数(论文给出的源码中全为1)。 l s i d e l_{side} lside l f u s e l_{fuse} lfuse采用二值交叉熵(standard binary cross-entropy)进行计算:
在这里插入图片描述
其中, ( r , c ) (r,c) (r,c)表示像素坐标值, ( H , W ) (H,W) (H,W)表示图像高度和宽度, P G ( r , c ) P_{G(r,c)} PG(r,c)表示标签图像素灰度值, P S ( r , c ) P_{S(r,c)} PS(r,c)表示预测的图像素灰度值。

1.5总体网络架构

在这里插入图片描述
   U 2 N e t U^2Net U2Net主要由三部分组成:

  • 一个六级编码器:在 E n 1 、 E n 2 、 E n 3 、 E n 4 En_1、En_2、En_3、En_4 En1En2En3En4中分别使用 R S U 7 、 R S U 6 、 R S U 5 、 R S U 4 RSU7、RSU6、RSU5、RSU4 RSU7RSU6RSU5RSU4 L L L通常根据输入特征图的空间分辨率进行配置。对于高宽较大的特征图,使用更大的 L L L来捕获更多的大比例度信息。 E n 5 En_5 En5 E n 6 En_6 En6中特征图的分辨率相对较低,进一步降低这些特征图的采样会导致有用的上下文丢失。因此,在 E n 5 En_5 En5 E n 6 En_6 En6阶段,都使用 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本,其用膨胀卷积代替池化和上采样操作这意味着 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。
  • 一个五级解码器: D e 5 De_5 De5阶段同样使用 R S U − 4 F RSU-4F RSU4F,并且每个解码器都使用前一阶段的上采样特征图和来自其对称编码器阶段的特征图的串联作为输入(跳跃连接)。
  • 特征图融合模块:将六个侧输出显著性特征图上采样到输入图像的尺寸,之后使用融合操作,并通过1x1卷积层和sigmoid函数生成最终的显著性特征图。

  研究中将3320320的图像裁剪为3288288大小输入模型,最终得到1288288的图像分割结果(二值图像):
在这里插入图片描述

1.6代码汇总

import torch
import torch.nn as nn
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

1.7普通残差块与RSU对比

在这里插入图片描述
  普通残差块的操作可概况为 H ( x ) = F 2 ( F 1 ( x ) ) + x H(x)=F_2(F_1(x))+x H(x)=F2(F1(x))+x,其中, F 1 、 F 2 F_1、F_2 F1F2代表权重层,此处设为卷积运算。RSU 和残差块之间的主要设计区别在于,RSU 用类似 U N e t UNet UNet的结构替换了普通的单流卷积,并将原始特征替换为由权重层转换的局部特征: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x))+F_1(x) HRSU(x)=U(F1(x))+F1(x),其中 U U U表示多层U型结构。种设计更改使网络能够直接从每个残差块中提取来自多个尺度的特征。更值得注意的是,由于 U 结构导致的计算开销很小,因为大多数操作都应用于下采样的特征图。下图中给出了RSU中 F 1 、 U F_1、U F1U的含义:
在这里插入图片描述

  残差块性能比较:
在这里插入图片描述

  • PLN:普通卷积块。
  • RES:残差块。
  • DSE:密集块。
  • INC:初始块。
  • RSU:U型残差块。

二、代码复现

https://github.com/xuebinqin/U-2-Net/tree/master

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/883375.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

机器学习之非监督学习(二)异常检测(基于高斯概率密度)

机器学习之非监督学习(二)异常检测(基于高斯概率密度) 0. 文章传送1.案例引入2.高斯正态分布3.异常检测算法4.异常检测 vs 监督学习5.算法优化6.代码实现 0. 文章传送 机器学习之监督学习(一)线性回归、多…

【mac开发入坑指南】能让你的终端好用一万倍的神仙组合iTerm2 + oh-my-zsh

介绍 iTerm2 iTerm2是默认终端的替代品,也是目前Mac系统下最好用的终端工具,集颜值和效率于一身。 Oh-My-Zsh Oh My Zsh 是一款社区驱动的命令行工具,正如它的主页上说的,Oh My Zsh 是一种生活方式。 它基于Zsh 命令行&#xff0c…

N诺计算机考研-错题

D 我们熟知的Windows XP、Linux、Mac OS X等都是多用户多任务分时操作系统。 C 分布式系统:由一组独立的计算机组成的系统,这些计算机通过网络相互连接,并且对外界用户来说,它们共同工作就像是一个单一的、统一的计算平台或服务。分布式系统的关键特征: 透明性:用户和应…

FLStudio21Mac版flstudio v21.2.1.3430简体中文版下载(含Win/Mac)

给大家介绍了许多FL21版本,今天给大家介绍一款FL Studio21Mac版本,如果是Mac电脑的朋友请千万不要错过,当然我也不会忽略掉Win系统的FL,链接我会放在文章,供大家下载与分享,如果有其他问题,欢迎…

[图解]静态关系和动态关系

1 00:00:01,060 --> 00:00:04,370 首先我们来看静态关系和动态关系 2 00:00:06,160 --> 00:00:10,040 我们要尽量基于静态关系来建立动态关系 3 00:00:11,740 --> 00:00:13,740 不能够在没有这个的基础上 4 00:00:14,220 --> 00:00:17,370 没有这个的情况下就胡…

Python第一篇:Python解释器

一:python解释器 python解释器是一款程序,用于解释、执行Python源代码。 一般python解释器都是c python使用c编写的,还有j python用java编写的。 二:python下载 三:使用示例 python进入控制台,python。 三…

【2.使用VBA自动填充Excel工作表】

目录 前言什么是VBA如何使用Excel中的VBA简单基础入门控制台输出信息定义过程(功能)定义变量常用的数据类型Set循环For To 我的需求开发过程效果演示文件情况测试填充源文件测试填充目标文件 全部完整的代码sheet1中的代码,对应A公司工作表Us…

2024年最新Redis内存数据库主从复制、哨兵模式、集群部署等详细教程(更新中)

Centos 安装 Redis 检查安装 GCC 环境 [rootVM-4-17-centos ~]# gcc --version gcc (GCC) 8.5.0 20210514 (Red Hat 8.5.0-4) Copyright (C) 2018 Free Software Foundation, Inc. This is free software; see the source for copying conditions. There is NO warranty; no…

Cisco Secure Firewall Threat Defense Virtual 7.6.0 发布下载,新增功能概览

Cisco Secure Firewall Threat Defense Virtual 7.6.0 - 思科下一代防火墙虚拟设备 (FTDv) Firepower Threat Defense (FTD) Software for ESXi & KVM 请访问原文链接:https://sysin.org/blog/cisco-firepower-7/,查看最新版。原创作品&#xff0c…

信息安全管理工程师(工信部教育与考试中心)

在信息技术迅猛发展的时代,信息安全已经成为企业乃至国家安全不可或缺的一环。 工信部高级信息安全管理工程师认证,作为软考中的一项顶尖资格认证,对提升信息安全管理人员的专业能力、确保信息安全性具有至关重要的作用。 本文将深入探讨该…

神经网络(四):UNet图像分割网络

文章目录 一、简介二、网络结构2.1编码器部分2.2解码器部分2.3完整代码 三、实战案例 一、简介 UNet网络是一种用于图像分割的卷积神经网络,其特点是采用了U型网络结构,因此称为UNet。该网络具有编码器和解码器结构,两种结构的功能如下&#…

网络安全中的 EDR 是什么:概述和功能

专业知识:EDR、XDR、NDR 和 MDR_xdr edr ndr-CSDN博客 端点检测和响应 (EDR) 是一种先进的安全系统,用于检测、调查和解决端点上的网络攻击。它可以检查事件、检查行为并将系统恢复到攻击前的状态。EDR 使用人工智能、机器学习和威胁情报来避免再次发生攻…

矩阵分析 学习笔记4 内积与Gram矩阵

内积 定义 由于对称,第二变元线性那第一变元也线性了。例如这个:

Tomcat may not be running

一、问题背景 tomcat7运行在JDK1.7上,可启动tomcat,但是停止时报错误,如下: 二、适用条件 JDK1.7/JDK1.8 tomcat7 三、解决方法 1、查找java路径 which java 2、修改文件 找到/usr/lib/jvm/jdk1.7.0_80/jre/lib/security/j…

力扣P1706全排列问题 很好的引入暴力 递归 回溯 dfs

代码思路是受一个洛谷题解里面大佬的启发。应该算是一个dfs和回溯的入门题目&#xff0c;很好的入门题目了下面我会先给我原题解思路我想可以很快了解这个思路。下面是我自己根据力扣大佬写的。 我会进行详细讲解并配上图辅助理解大家请往下看 #include<iostream> #inc…

Java 注解详解:从基础到自定义及解析

注解&#xff1a;概述 目标 能够理解注解在程序中的作用 路径 什么是注解注解的作用 注解 什么是注解&#xff1f; 注解(Annotation)也称为元数据&#xff0c;是一种代码级别的说明注解是JDK1.5版本引入的一个特性&#xff0c;和类、接口是在同一个层次注解可以声明在包…

创意实现!在uni-app小程序商品详情页轮播中嵌入视频播放功能

背景介绍 通过uni-app框架实现商城小程序商品详情页的视频与图片轮播功能&#xff0c;以提升用户体验和增加商品吸引力。通过展示商品视频和图片&#xff0c;用户可以更全面地了解商品细节&#xff0c;从而提高购买决策的便利性和满意度。这种功能适用于各类商品&#xff0c;如…

Redis --- redis事务和分布式事务锁

redis事务基本实现 Redis 可以通过 MULTI&#xff0c;EXEC&#xff0c;DISCARD 和 WATCH 等命令来实现事务(transaction)功能。 > MULTI OK > SET USER "Guide哥" QUEUED > GET USER QUEUED > EXEC 1) OK 2) "Guide哥"使用 MULTI命令后可以输入…

嘉立创EDA-- 线宽、过孔和电流大小对比图

导线宽度和电流大小如何来考虑 1 电流大小需要考虑问题 1、允许的温升&#xff1a;如果能够允许的铜线升高的温度越高&#xff0c;那么允许通过的电流自然也就越高 2、走线的线宽&#xff1a;线越宽 &#xff0c;导线横截面积越大&#xff0c;电阻越小&#xff0c;发热越小&a…

影刀---实现我的第一个抓取数据的机器人

你们要的csdn自动回复机器人在这里文末哦&#xff01; 这个上传的资源要vip下载&#xff0c;如果想了解影刀这个软件的话可以私聊我&#xff0c;我发你 目录 1.网页对象2.网页元素3.相似元素组4.元素操作设置下拉框复选框滚动条获取元素的信息 5.变量6.数据的表达字符串变量列…