论文标题:DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable Atrous Convolution
论文地址:https://arxiv.org/pdf/2006.02334
代码地址:https://github.com/joe-siyuan-qiao/DetectoRS/blob/612916ba89ad6452b07ae52d3a6ec8d34a792608/mmdet/ops/saconv.py


代码:
def forward(self, x):
    """Switchable Atrous Convolution (SAC) forward pass.

    Steps, as implemented below:
      1. Pre-context: globally average-pool ``x`` to 1x1, transform it with
         ``self.pre_context``, broadcast it back to ``x``'s shape and add it.
      2. Switch: take a 5x5, stride-1 local average of ``x`` (reflect-padded
         by 2 so the spatial size is unchanged) and feed it to
         ``self.switch`` to obtain the gate ``switch``.
      3. SAC: run the convolution twice --
         * ``out_s``: with the layer's own ``padding`` / ``dilation`` and the
           weight returned by ``self._get_weight(self.weight)``;
         * ``out_l``: with ``padding`` and ``dilation`` tripled and the
           weight plus ``self.weight_diff``;
         then blend them as ``switch * out_s + (1 - switch) * out_l``.
      4. Post-context: the same global-context addition as step 1, applied
         to the blended output through ``self.post_context``.

    When ``self.use_deform`` is true, both branches use ``deform_conv``.
    Their offsets are predicted from the locally averaged input by
    ``self.offset_s`` / ``self.offset_l``. Otherwise
    ``super().conv2d_forward`` is used.

    Args:
        x: input feature map. Presumably a 4-D ``(N, C, H, W)`` tensor --
            the 4-element ``pad`` and ``avg_pool2d`` below require it;
            TODO confirm against callers.

    Returns:
        The gated, context-augmented convolution output.
    """
    # --- pre-context: add a broadcast global summary of the input --------
    avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)
    avg_x = self.pre_context(avg_x)
    avg_x = avg_x.expand_as(x)
    x = x + avg_x

    # --- switch: gate computed from a local (5x5) average ----------------
    # Reflect-pad by 2 on each side so the 5x5 / stride-1 pooling keeps H, W.
    avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")
    avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
    switch = self.switch(avg_x)

    # --- SAC, small branch: the layer's own padding / dilation -----------
    # NOTE(review): presumably _get_weight applies a weight reparameterisation
    # (e.g. weight standardisation) -- its definition is not visible here.
    weight = self._get_weight(self.weight)
    if self.use_deform:
        offset = self.offset_s(avg_x)
        out_s = deform_conv(
            x,
            offset,
            weight,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            1)
    else:
        # NOTE(review): ``conv2d_forward`` is the pre-1.5 PyTorch name of
        # ``nn.Conv2d``'s internal conv helper (later ``_conv_forward``);
        # confirm it exists for the installed torch version.
        out_s = super().conv2d_forward(x, weight)

    # --- SAC, large branch: padding and dilation tripled -----------------
    # ``conv2d_forward`` takes no padding/dilation arguments, so the large
    # branch is produced by temporarily overwriting them on ``self``. The
    # ``finally`` guarantees they are restored even if this branch raises
    # (e.g. CUDA OOM); otherwise the module would stay permanently
    # configured with the tripled values and every later call would be wrong.
    ori_p = self.padding
    ori_d = self.dilation
    self.padding = tuple(3 * p for p in ori_p)
    self.dilation = tuple(3 * d for d in ori_d)
    try:
        # The large branch shares the base weight plus a learned difference.
        weight = weight + self.weight_diff
        if self.use_deform:
            offset = self.offset_l(avg_x)
            out_l = deform_conv(
                x,
                offset,
                weight,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
                1)
        else:
            out_l = super().conv2d_forward(x, weight)
    finally:
        self.padding = ori_p
        self.dilation = ori_d

    # Gate: switch -> 1 favours the small branch, switch -> 0 the large one.
    out = switch * out_s + (1 - switch) * out_l

    # --- post-context: same global-context addition on the output --------
    avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)
    avg_x = self.post_context(avg_x)
    avg_x = avg_x.expand_as(out)
    out = out + avg_x
    return out
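其中 `out = switch * out_s + (1 - switch) * out_l` 实现门控（gating）操作：`out_s` 使用卷积层原有的 padding 和 dilation，`out_l` 使用 3 倍的 padding 和 dilation（权重额外加上 `weight_diff`）；`switch` 越接近 1，输出越偏向 `out_s`，越接近 0 则越偏向 `out_l`。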