从最开始听到人用Unet左inpainting,再到自己使用Unet做图像去噪任务,虽然没有用Unet做过分割,但Unet也可以称得上是老朋友了。现在回头再看Unet,温故知新,一些魔鬼真就藏在一些细节之中。
structure
结构由forward函数定义:
def forward(self,x0):
x1=self.model.head(x0)
x2=self.model.down1(x1)
x3=self.model.down2(x2)
x4=self.model.down3(x3)
x=self.model.body(x4)
x=self.model.up3(x)
x=self.model.up2(x)
x=self.model.up1(x)
x=self.model.tail(x)
return x
head,tail是没有分辨率上的变化的。并且在down和up内部,第一步是采样带来的尺寸变化,第二部分也还是有一些正常的卷积操作。
skip connect
在Unet的解码部分,除了对低分辨率的上采样,还会融合Unet左边同分辨率的featuremap,称为Skip connection。这样能达到更精确的分割。
根据融合的方式不同,有直接相加的,也有在通道维度concatenate的。后者计算量更大,但是不会造成信息丢失。FCN使用的就是直接相加。
def forward(self,x0):
x1=self.model.head(x0)
x2=self.model.down1(x1)
x3=self.model.down2(x2)
x4=self.model.down3(x3)
x=self.model.body(x4)
x=self.model.up3(x+x4)
x=self.model.up2(x+x3)
x=self.model.up1(x+x2)
x=self.model.tail(x+x1)
return x
def forward(self,x0):
x1=self.model.head(x0)
x2=self.model.down1(x1)
x3=self.model.down2(x2)
x4=self.model.down3(x3)
x=self.model.body(x4)
x=self.model.up3(torch.cat[x,x4],dim=1)
x=self.model.up2(torch.cat[x,x3],dim=1)
x=self.model.up1(torch.cat[x,x2],dim=1)
x=self.model.tail(torch.cat[x,x1],dim=1) #dum=1表示在维度1上进行cancate,这里是在通道维度
return x
downsample
从分辨率看,左右两部分分别是下降和升高,但从通道维度看,左右两边分别是升高和下降。所以Unet是先将信息转换为通道层面,然后由此再进行重构。
在上面的Unet结构图中,虽然下采样是由最大池化实现的,但在实际应用中,正常的卷积就可以达到下采样的目的,只不过为了分辨率下降为一半,卷积的参数要经过设计。
由卷积层输出尺寸: o = ⌊(i + 2p - k) / s⌋ + 1 ,padding设置为1时,输出尺寸可以完全由stride控制。stride=2时,分辨率下降为一半;当stride=1,分辨率不变化(在Unet右边部分的水平部分,分辨率不需要变化,只是特征提取,就属于这种情况。
upsample
至于上采样操作,现在普遍使用的是nn.ConvTranspose2d。转置卷积就是基于卷积的实现,通过转置的方法实现分辨率的提升。为了达到采样倍率为2的目的,也需要根据公式确定转置卷积的参数。也是(3,2,1)的组合,不过额外有一个output_padding
的参数。
为了增强网络的表达能力,一方面是可以加大Unet的层数,另外就是加深每一层的深度。比如参考ESRGAN使用RRDB。
RDB,在densenet的基础上增加了残差。dense指每个卷积的输入不仅有上一个卷积的输出,还要和前面所有卷积的输出cancat在一起。最终卷积的输出表示残差,残差因子默认为0.2,和最初的输入相加得到最终的输出:
RRDB和RDB相比多了一个R,事实上也是如此,RRDB就是3个RDB串联,输出作为残差:
在构建多个卷积层时,可以使用self.add_module(f'conv{i}',nn.Conv2d(in,out,3,1,1))简化代码,使用时直接使用self.conv1,self.cov2...
reference:
深度学习系列(四)分割网络模型(FCN、Unet、Unet++、SegNet、RefineNet)-腾讯云开发者社区-腾讯云