paper:Learning Spatial Fusion for Single-Shot Object Detection
official implementation:https://github.com/GOATmessi7/ASFF
背景
金字塔特征表示pyramid feature representation是解决目标检测中尺度变化挑战的常用方法。特征金字塔的一个主要缺点是在不同尺度上的不一致性,特别是对于one-shot detector。具体来说,当用feature pyramid检测物体时,采用一种启发式引导的特征选择:大目标与深层特征图关联,小目标与浅层特征图关联。当一个对象在某一层的特征图中被视为positive时,其它level的特征图中相应的区域就被视为背景。因此,如果一个图像既包含小物体又包含大物体,不同层次特征之间的冲突往往占据特征金字塔的主要部分。这种不一致性干扰了训练过程中的梯度计算,并降低了特征金字塔的有效性。
本文的创新点
本文提出了一种新的数据驱动的金字塔特征融合策略,称为自适应空间特征融合(adaptively spatial feature fusion, ASFF)。它学习对冲突信息进行空间过滤以抑制不一致性的方法,从而提高了特征的尺度不变性,并几乎没有增加额外开销。
该方法使网络能够直接学习如何在其它层对特征进行空间过滤,从而只保留有用信息进行融合。对于某个level的特征,首先将其它level的特征整合并调整到相同的分辨率,然后通过训练找到最优融合。在每个空间位置,不同level的特征被自适应的融合,即一些特征可能因为在该位置带有矛盾的信息而被过滤掉,一些特征因为带有更具差异性的信息而在该位置占主导。
ASFF具有以下几点优势
- 由于搜索最优融合的操作是可微的,因此可以在反向传播中学习
- 它对于backbone是不可知的,并应用于带有特征金字塔结构的单阶段目标检测模型中
- 实现简单,增加的计算成本可以忽略
方法介绍
Feature Resizing
我们将 \(l\) 层的特征记为 \(\mathbf{x}^{l}\)(对于YOLOv3 \(l\in\{1,2,3\}\))。对于层 \(l\),我们将其它层 \(n(n\ne l)\) 的特征 \(\mathbf{x}^n\) resize到和 \(\mathbf{x}^{l}\) 一样的分辨率。由于YOLOv3中三个level的特征具有不同的分辨率和不同的通道数,因此我们相应地修改了每个尺度的上采样和下采样策略。对于上采样,我们首先通过1x1卷积将特征的通道数压缩到 \(l\) 层对应的通道数,然后通过插值进行上采样。对于1/2比例的下采样,我们通过stride=2的3x3卷积通过修改分辨率和通道数。对于1/4的下采样,我们在2-stride卷积之前添加一个2-stride的max pooling层。
Adaptive Fusion
我们用 \(\mathbf{x}_{ij}^{n\to l}\) 表示从 \(n\) 层resize到 \(l\) 层的特征图上位置 \((i,j)\) 处的特征向量,我们在 \(l\) 层上按下式进行融合
其中 \(\mathbf{y}_{ij}^l\) 表示输出特征图 \(\mathbf{y}^l\) 沿通道的第 \((i,j)\) 个向量。\(\alpha_{ij}^l,\beta_{ij}^l,\gamma_{ij}^l\) 表示三个不同level对level \(l\) 特征图的空间重要性权重,它们是网络自适应学习到的。注意 \(\alpha_{ij}^l,\beta_{ij}^l,\gamma_{ij}^l\) 可以是标量变量,所有通道共享。我们使 \(\alpha_{ij}^l+\beta_{ij}^l+\gamma_{ij}^l=1\),\(\alpha_{ij}^l,\beta_{ij}^l,\gamma_{ij}^l\in [0,1]\),并定义
这里 \(\alpha_{ij}^l,\beta_{ij}^l,\gamma_{ij}^l\) 通过对 \(\lambda^{l}_{\alpha_{ij}},\lambda^{l}_{\beta_{ij}},\lambda^{l}_{\gamma_{ij}}\) 使用softmax计算得到,我们使用1x1卷积从 \(\mathbf{x}^{1\to l},\mathbf{x}^{2\to l},\mathbf{x}^{3\to l}\) 得到权重标量 \(\lambda^{l}_{\alpha},\lambda^{l}_{\beta},\lambda^{l}_{\gamma}\)。
使用这种方法,所有lelvel的特征都在每个尺度上自适应地聚合。输出 \(\{\mathbf{y}^1,\mathbf{y}^2,\mathbf{y}^3\}\) 按照YOLOv3相同的pipeline用于目标检测。
实验结果
代码
class ASFF(nn.Module):
def __init__(self, level, rfb=False, vis=False):
super(ASFF, self).__init__()
self.level = level
self.dim = [512, 256, 256]
self.inter_dim = self.dim[self.level]
if level == 0:
self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2)
self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
self.expand = add_conv(self.inter_dim, 1024, 3, 1)
elif level == 1:
self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
self.expand = add_conv(self.inter_dim, 512, 3, 1)
elif level == 2:
self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
self.expand = add_conv(self.inter_dim, 256, 3, 1)
compress_c = 8 if rfb else 16 # when adding rfb, we use half number of channels to save memory
self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1)
self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1)
self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1)
self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0)
self.vis = vis
def forward(self, x_level_0, x_level_1, x_level_2): # (b,512,13,13),(b,256,26,26),(b,256,52,52)
if self.level == 0:
level_0_resized = x_level_0
level_1_resized = self.stride_level_1(x_level_1) # (b,512,13,13)
level_2_downsampled_inter = F.max_pool2d(x_level_2, 3, stride=2, padding=1) # (b,256,26,26)
level_2_resized = self.stride_level_2(level_2_downsampled_inter) # (b,512,13,13)
elif self.level == 1:
level_0_compressed = self.compress_level_0(x_level_0) # (b,256,13,13)
level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') # (b,256,26,26)
level_1_resized = x_level_1
level_2_resized = self.stride_level_2(x_level_2) # (b,256,26,26)
elif self.level == 2:
level_0_compressed = self.compress_level_0(x_level_0) # (b,256,13,13)
level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') # (b,256,52,52)
level_1_resized = F.interpolate(x_level_1, scale_factor=2, mode='nearest') # (b,256,52,52)
level_2_resized = x_level_2
level_0_weight_v = self.weight_level_0(level_0_resized) # (b,16,13,13)
level_1_weight_v = self.weight_level_1(level_1_resized) # (b,16,13,13)
level_2_weight_v = self.weight_level_2(level_2_resized) # (b,16,13,13)
levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1) # (b,48,13,13)
levels_weight = self.weight_levels(levels_weight_v) # (b,3,13,13)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
level_1_resized * levels_weight[:, 1:2, :, :] + \
level_2_resized * levels_weight[:, 2:, :, :]
out = self.expand(fused_out_reduced) # (b,1024,13,13)
if self.vis:
return out, levels_weight, fused_out_reduced.sum(dim=1)
else:
return out