Module Source
[link] [code] [NIPS 22] SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation
Module Name
Multi-Scale Convolutional Attention (MSCA)
Module Purpose
Multi-scale feature extraction with a larger receptive field.
Module Structure
Module Code
import torch
import torch.nn as nn


class MSCA(nn.Module):
    def __init__(self, dim):
        super(MSCA, self).__init__()
        # 5x5 depthwise conv: aggregates local information
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Branch 0: 1x7 followed by 7x1 depthwise strip convs
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        # Branch 1: 1x11 followed by 11x1 depthwise strip convs
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        # Branch 2: 1x21 followed by 21x1 depthwise strip convs
        self.conv2_1 = nn.Conv2d(
            dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(
            dim, dim, (21, 1), padding=(10, 0), groups=dim)
        # 1x1 conv: mixes information across channels
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()  # keep the input; the attention map reweights it
        attn = self.conv0(x)

        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)

        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)

        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)

        # Sum the identity path and the three strip-conv branches
        attn = attn + attn_0 + attn_1 + attn_2
        attn = self.conv3(attn)

        # Element-wise reweighting of the input by the attention map
        return attn * u


if __name__ == '__main__':
    x = torch.randn([1, 512, 16, 16])
    msca = MSCA(512)
    out = msca(x)
    print(out.shape)  # torch.Size([1, 512, 16, 16])
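To make the "larger receptive field" claim concrete, the sketch below works out, for the kernel sizes used in the MSCA code above (a 5×5 depthwise conv followed by strip pairs of length 7, 11, and 21), the receptive field of each branch and the parameter cost of a strip pair versus a full k×k depthwise conv. It is plain arithmetic under the assumption of stride 1 and dilation 1; the helper names `stacked_rf` and `depthwise_params` are hypothetical and not part of the SegNeXt code.

```python
# Sketch: receptive field and parameter count per MSCA branch.
# Helper names are hypothetical, chosen only for this illustration.

def stacked_rf(kernel_sizes):
    """Receptive field along one axis for stride-1, dilation-1 convs applied in sequence."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1
    return rf


def depthwise_params(channels, kh, kw, bias=True):
    """Parameters of a depthwise Conv2d (groups == channels): one kh x kw filter per channel."""
    return channels * kh * kw + (channels if bias else 0)


dim = 512
for k in (7, 11, 21):
    # Along each axis a branch sees the 5x5 conv plus one length-k strip;
    # the other strip has extent 1 on that axis and adds nothing.
    rf = stacked_rf([5, k])
    strip_pair = depthwise_params(dim, 1, k) + depthwise_params(dim, k, 1)
    full = depthwise_params(dim, k, k)
    print(f"k={k}: receptive field {rf}x{rf}, "
          f"strip pair {strip_pair} params, full {k}x{k} depthwise {full} params")
```

Under these assumptions the k = 21 branch reaches a 25×25 receptive field, and its strip pair needs roughly a tenth of the parameters of a full 21×21 depthwise conv at dim = 512.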
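Description from the Paper
As shown in Fig. 2(a) of the paper, MSCA contains three parts: a depthwise convolution that aggregates local information; multi-branch depthwise strip convolutions that capture multi-scale context; and a 1×1 convolution that models relationships between different channels. The output of the 1×1 convolution is used directly as attention weights to reweight the input of MSCA.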
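The three parts can be written compactly as follows. This is a sketch whose notation is chosen here for illustration: F is the input feature map, DWConv is the 5×5 depthwise convolution, Scale_0 is the identity path, Scale_1 to Scale_3 are the strip-convolution branches of length 7, 11, and 21, and ⊗ is element-wise multiplication.

```latex
\begin{aligned}
\mathrm{Att} &= \mathrm{Conv}_{1\times 1}\!\left(\sum_{i=0}^{3} \mathrm{Scale}_i\big(\mathrm{DWConv}(F)\big)\right),\\
\mathrm{Out} &= \mathrm{Att} \otimes F .
\end{aligned}
```

In the code, the sum corresponds to `attn + attn_0 + attn_1 + attn_2`, the 1×1 convolution to `conv3`, and the final product to `attn * u`.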