Paper: Learning Spatial Fusion for Single-Shot Object Detection
Paper link: Paper - ASFF
Official code: GitHub - GOATmessi8/ASFF
Introduction
Multi-scale feature fusion is a key technique for multi-scale object detection. FPN (Feature Pyramid Network) improves detection by simply combining high-level semantic features with low-level detail features through a top-down fusion mechanism. However, because this fusion does not fully account for the representational inconsistency between feature maps at different levels, it can introduce conflicting information, which limits further gains. ASFF (Adaptively Spatial Feature Fusion) instead fuses features adaptively across scales and spatial positions through a dynamic weighting mechanism, effectively suppressing the conflicting information between pyramid levels and improving multi-scale detection. This design reflects the emphasis that feature-fusion theory places on level-wise differences and spatial adaptivity.
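For reference, the fusion rule that the code below implements can be written as in the ASFF paper (notation follows the paper; x^{n→l}_{ij} is the feature from level n resized to the resolution of level l):

y^{l}_{ij} = \alpha^{l}_{ij} \cdot x^{1\to l}_{ij} + \beta^{l}_{ij} \cdot x^{2\to l}_{ij} + \gamma^{l}_{ij} \cdot x^{3\to l}_{ij}, \qquad \alpha^{l}_{ij} + \beta^{l}_{ij} + \gamma^{l}_{ij} = 1, \quad \alpha^{l}_{ij}, \beta^{l}_{ij}, \gamma^{l}_{ij} \in [0, 1]

The weight maps are produced by a softmax over learned per-level control maps, e.g. \alpha^{l}_{ij} = e^{\lambda^{l}_{\alpha,ij}} / (e^{\lambda^{l}_{\alpha,ij}} + e^{\lambda^{l}_{\beta,ij}} + e^{\lambda^{l}_{\gamma,ij}}), which corresponds to the weight_level_* 1x1 convolutions, the weight_levels fusion conv, and the F.softmax call in ASFFV5.forward below.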
Core Code
(1) Fusing both adjacent and non-adjacent levels:
import torch
import torch.nn as nn
from ultralytics.utils.tal import dist2bbox, make_anchors
import math
import torch.nn.functional as F
__all__ = ['ASFF_Detect']
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x))
class DFL(nn.Module):
"""
Integral module of Distribution Focal Loss (DFL).
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
"""
def __init__(self, c1=16):
"""Initialize a convolutional layer with a given number of input channels."""
super().__init__()
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
x = torch.arange(c1, dtype=torch.float)
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
self.c1 = c1
def forward(self, x):
"""Applies a transformer layer on input tensor 'x' and returns a tensor."""
b, c, a = x.shape # batch, channels, anchors
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
class ASFFV5(nn.Module):
def __init__(self, level, ch, multiplier=1, rfb=False, vis=False, act_cfg=True):
"""
ASFF version for YoloV5 .
different than YoloV3
multiplier should be 1, 0.5 which means, the channel of ASFF can be
512, 256, 128 -> multiplier=1
256, 128, 64 -> multiplier=0.5
For even smaller, you need change code manually.
"""
super(ASFFV5, self).__init__()
self.level = level
self.dim = [int(ch[2] * multiplier), int(ch[1] * multiplier),
int(ch[0] * multiplier)]
# print(self.dim)
self.inter_dim = self.dim[self.level]
if level == 0:
self.stride_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 3, 2)
self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(
ch[2] * multiplier), 3, 1)
elif level == 1:
self.compress_level_0 = Conv(
int(ch[2] * multiplier), self.inter_dim, 1, 1)
self.stride_level_2 = Conv(
int(ch[0] * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(ch[1] * multiplier), 3, 1)
elif level == 2:
self.compress_level_0 = Conv(
int(ch[2] * multiplier), self.inter_dim, 1, 1)
self.compress_level_1 = Conv(
int(ch[1] * multiplier), self.inter_dim, 1, 1)
self.expand = Conv(self.inter_dim, int(
ch[0] * multiplier), 3, 1)
        # when rfb is used, halve the number of compression channels to save memory
compress_c = 8 if rfb else 16
self.weight_level_0 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_level_1 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_level_2 = Conv(
self.inter_dim, compress_c, 1, 1)
self.weight_levels = Conv(
compress_c * 3, 3, 1, 1)
self.vis = vis
    def forward(self, x):  # x = [x_small, x_medium, x_large]: shallow high-resolution -> deep low-resolution
        """Fuse the three pyramid levels; channels grow and resolution shrinks from x[0] to x[2] (e.g. 128, 256, 512)."""
x_level_0 = x[2] # l
x_level_1 = x[1] # m
x_level_2 = x[0] # s
# print('x_level_0: ', x_level_0.shape)
# print('x_level_1: ', x_level_1.shape)
# print('x_level_2: ', x_level_2.shape)
if self.level == 0:
level_0_resized = x_level_0
level_1_resized = self.stride_level_1(x_level_1)
level_2_downsampled_inter = F.max_pool2d(
x_level_2, 3, stride=2, padding=1)
level_2_resized = self.stride_level_2(level_2_downsampled_inter)
elif self.level == 1:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(
level_0_compressed, scale_factor=2, mode='nearest')
level_1_resized = x_level_1
level_2_resized = self.stride_level_2(x_level_2)
elif self.level == 2:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(
level_0_compressed, scale_factor=4, mode='nearest')
x_level_1_compressed = self.compress_level_1(x_level_1)
level_1_resized = F.interpolate(
x_level_1_compressed, scale_factor=2, mode='nearest')
level_2_resized = x_level_2
# print('level: {}, l1_resized: {}, l2_resized: {}'.format(self.level,
# level_1_resized.shape, level_2_resized.shape))
level_0_weight_v = self.weight_level_0(level_0_resized)
level_1_weight_v = self.weight_level_1(level_1_resized)
level_2_weight_v = self.weight_level_2(level_2_resized)
# print('level_0_weight_v: ', level_0_weight_v.shape)
# print('level_1_weight_v: ', level_1_weight_v.shape)
# print('level_2_weight_v: ', level_2_weight_v.shape)
levels_weight_v = torch.cat(
(level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
level_1_resized * levels_weight[:, 1:2, :, :] + \
level_2_resized * levels_weight[:, 2:, :, :]
out = self.expand(fused_out_reduced)
if self.vis:
return out, levels_weight, fused_out_reduced.sum(dim=1)
else:
return out
class ASFF_Detect(nn.Module):
"""YOLOv8 Detect head for detection models."""
dynamic = False # force grid reconstruction
export = False # export mode
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
def __init__(self, nc=80, ch=(), multiplier=1, rfb=False):
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
super().__init__()
self.nc = nc # number of classes
self.nl = len(ch) # number of detection layers
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
self.no = nc + self.reg_max * 4 # number of outputs per anchor
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
self.l0_fusion = ASFFV5(level=0, ch=ch, multiplier=multiplier, rfb=rfb)
self.l1_fusion = ASFFV5(level=1, ch=ch, multiplier=multiplier, rfb=rfb)
self.l2_fusion = ASFFV5(level=2, ch=ch, multiplier=multiplier, rfb=rfb)
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
x1 = self.l0_fusion(x)
x2 = self.l1_fusion(x)
x3 = self.l2_fusion(x)
x = [x3, x2, x1]
shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training:
return x
elif self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
box = x_cat[:, :self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4:]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
img_h = shape[2] * self.stride[0]
img_w = shape[3] * self.stride[0]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
dbox /= img_size
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
if __name__ == "__main__":
image1 = torch.rand(1, 128, 160, 160)
image2 = torch.rand(1, 256, 80, 80)
image3 = torch.rand(1, 512, 40, 40)
image = [image1, image2, image3]
channel = (128, 256, 512)
model = ASFF_Detect(nc=80, ch=channel)
out = model(image)
print(out[1].shape)
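A quick note on the test above: with the module left in its default training mode, forward returns the three raw head outputs, so with nc=80 and reg_max=16 (no = 4*16 + 80 = 144 output channels) the print should show torch.Size([1, 144, 80, 80]). To exercise the inference branch you would additionally need model.eval() and a properly initialized stride buffer (it is all zeros here), otherwise the boxes decoded by make_anchors/dist2bbox are not meaningful.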
(2) Fusing adjacent levels only (here level 0 and level 2 each fuse with level 1 only, while level 1 still fuses all three levels):
import torch
import torch.nn as nn
from ultralytics.utils.tal import dist2bbox, make_anchors
import math
import torch.nn.functional as F
__all__ = ['ASFF_Detect']
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x))
class DFL(nn.Module):
"""
Integral module of Distribution Focal Loss (DFL).
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
"""
def __init__(self, c1=16):
"""Initialize a convolutional layer with a given number of input channels."""
super().__init__()
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
x = torch.arange(c1, dtype=torch.float)
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
self.c1 = c1
def forward(self, x):
"""Applies a transformer layer on input tensor 'x' and returns a tensor."""
b, c, a = x.shape # batch, channels, anchors
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
class ASFFV5(nn.Module):
def __init__(self, level, ch, multiplier=1, rfb=False, vis=False, act_cfg=True):
super(ASFFV5, self).__init__()
self.level = level
self.dim = [int(ch[2] * multiplier), int(ch[1] * multiplier), int(ch[0] * multiplier)]
self.inter_dim = self.dim[self.level]
if level == 0:
self.stride_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(ch[2] * multiplier), 3, 1)
elif level == 1:
self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
self.expand = Conv(self.inter_dim, int(ch[1] * multiplier), 3, 1)
elif level == 2:
self.compress_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 1, 1)
self.expand = Conv(self.inter_dim, int(ch[0] * multiplier), 3, 1)
compress_c = 8 if rfb else 16
self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
if level == 1:
self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
else:
self.weight_levels = Conv(compress_c * 2, 2, 1, 1)
self.vis = vis
    def forward(self, x):  # x = [x_small, x_medium, x_large]: shallow high-resolution -> deep low-resolution
        x_level_0 = x[2]  # deepest level, e.g. (1, 512, 40, 40) with the test input below
        x_level_1 = x[1]  # middle level, e.g. (1, 256, 80, 80)
        x_level_2 = x[0]  # shallowest level, e.g. (1, 128, 160, 160)
if self.level == 0:
level_0_resized = x_level_0
level_1_resized = self.stride_level_1(x_level_1)
level_0_weight_v = self.weight_level_0(level_0_resized)
level_1_weight_v = self.weight_level_1(level_1_resized)
levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v), 1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + level_1_resized * levels_weight[:, 1:, :, :]
elif self.level == 1:
level_0_resized = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(level_0_resized, scale_factor=2, mode='nearest')
level_1_resized = x_level_1
level_2_resized = self.stride_level_2(x_level_2)
level_0_weight_v = self.weight_level_0(level_0_resized)
level_1_weight_v = self.weight_level_1(level_1_resized)
level_2_weight_v = self.weight_level_2(level_2_resized)
levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + level_1_resized * levels_weight[:, 1:2, :, :] + level_2_resized * levels_weight[:, 2:, :, :]
elif self.level == 2:
level_1_resized = self.compress_level_1(x_level_1)
level_1_resized = F.interpolate(level_1_resized, scale_factor=2, mode='nearest')
level_2_resized = x_level_2
level_1_weight_v = self.weight_level_1(level_1_resized)
level_2_weight_v = self.weight_level_2(level_2_resized)
levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_1_resized * levels_weight[:, 0:1, :, :] + level_2_resized * levels_weight[:, 1:, :, :]
out = self.expand(fused_out_reduced)
if self.vis:
return out, levels_weight, fused_out_reduced.sum(dim=1)
else:
return out
class ASFF_Detect(nn.Module):
"""YOLOv8 Detect head for detection models."""
dynamic = False # force grid reconstruction
export = False # export mode
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
def __init__(self, nc=80, ch=(), multiplier=1, rfb=False):
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
super().__init__()
self.nc = nc # number of classes
self.nl = len(ch) # number of detection layers
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
self.no = nc + self.reg_max * 4 # number of outputs per anchor
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
self.l0_fusion = ASFFV5(level=0, ch=ch, multiplier=multiplier, rfb=rfb)
self.l1_fusion = ASFFV5(level=1, ch=ch, multiplier=multiplier, rfb=rfb)
self.l2_fusion = ASFFV5(level=2, ch=ch, multiplier=multiplier, rfb=rfb)
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
x1 = self.l0_fusion(x)
x2 = self.l1_fusion(x)
x3 = self.l2_fusion(x)
x = [x3, x2, x1]
shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training:
return x
elif self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
box = x_cat[:, :self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4:]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
img_h = shape[2] * self.stride[0]
img_w = shape[3] * self.stride[0]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
dbox /= img_size
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
if __name__ == "__main__":
image1 = torch.rand(1, 128, 160, 160)
image2 = torch.rand(1, 256, 80, 80)
image3 = torch.rand(1, 512, 40, 40)
image = [image1, image2, image3]
channel = (128, 256, 512)
model = ASFF_Detect(nc=80, ch=channel)
out = model(image)
print(out[1].shape)
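Both listings expose the same ASFF_Detect interface as the stock YOLOv8 Detect head (plus the extra multiplier and rfb arguments), so in principle either variant can be swapped in wherever the original head is constructed; the exact registration steps (model YAML entry, module parser, etc.) depend on your Ultralytics version and are not shown here. The quick test at the bottom behaves the same way as in variant (1) and should print torch.Size([1, 144, 80, 80]) in the default training mode.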