Column introduction: YOLOv9 Improvement Series | covering the latest deep-learning innovations for efficient, high-impact accuracy gains!
1. Introduction
This post contains only the code and a brief introduction to each attention module. For the tutorial on adding them to YOLOv9, see:
YOLOv9 Effective Improvements | Adding SE, CBAM, ECA, SimAM and dozens of other attention mechanisms (Part 1)
BiFormer:《BiFormer: Vision Transformer with Bi-Level Routing Attention》
BiFormer proposes a dynamic, content-aware sparse attention mechanism implemented through bi-level routing, which allocates computation flexibly. Coarse region-level routing first filters out irrelevant key-value pairs, and fine-grained token-to-token attention is then applied only within the retained regions. Using this mechanism, the authors build BiFormer, a general-purpose vision transformer that adaptively attends to the tokens relevant to each query without being distracted by irrelevant ones. The module code follows; a short usage sketch is given after the code.
"""
Core of BiFormer, Bi-Level Routing Attention.
To be refactored.
author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor
class TopkRouting(nn.Module):
"""
differentiable topk routing with scaling
Args:
qk_dim: int, feature dimension of query and key
topk: int, the 'topk'
qk_scale: int or None, temperature (multiply) of softmax activation
        param_routing: bool, whether to incorporate learnable params in the routing unit
        diff_routing: bool, whether to make the routing differentiable
        soft_routing: bool, whether the gathered values are multiplied by the routing weights (see KVGather)
"""
def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
super().__init__()
self.topk = topk
self.qk_dim = qk_dim
self.scale = qk_scale or qk_dim ** -0.5
self.diff_routing = diff_routing
# TODO: norm layer before/after linear?
self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
# routing activation
self.routing_act = nn.Softmax(dim=-1)
    def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor, Tensor]:
"""
Args:
q, k: (n, p^2, c) tensor
Return:
r_weight, topk_index: (n, p^2, topk) tensor
"""
if not self.diff_routing:
query, key = query.detach(), key.detach()
query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c)
attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)
topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)
r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)
return r_weight, topk_index
class KVGather(nn.Module):
def __init__(self, mul_weight='none'):
super().__init__()
assert mul_weight in ['none', 'soft', 'hard']
self.mul_weight = mul_weight
def forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):
"""
r_idx: (n, p^2, topk) tensor
r_weight: (n, p^2, topk) tensor
kv: (n, p^2, w^2, c_kq+c_v)
Return:
(n, p^2, topk, w^2, c_kq+c_v) tensor
"""
# select kv according to routing index
n, p2, w2, c_kv = kv.size()
topk = r_idx.size(-1)
# print(r_idx.size(), r_weight.size())
# FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpy
dim=2,
index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv)
)
if self.mul_weight == 'soft':
topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)
elif self.mul_weight == 'hard':
raise NotImplementedError('differentiable hard routing TBA')
# else: #'none'
# topk_kv = topk_kv # do nothing
return topk_kv
class QKVLinear(nn.Module):
def __init__(self, dim, qk_dim, bias=True):
super().__init__()
self.dim = dim
self.qk_dim = qk_dim
self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)
def forward(self, x):
q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)
return q, kv
# q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
# return q, k, v
class BiLevelRoutingAttention(nn.Module):
"""
n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
topk: topk for window filtering
param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
param_routing: extra linear for routing
    diff_routing: whether to make the routing differentiable
    soft_routing: whether to multiply the gathered values by the soft routing weights
"""
def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
auto_pad=True):
super().__init__()
# local attention setting
self.dim = dim
self.n_win = n_win # Wh, Ww
self.num_heads = num_heads
self.qk_dim = qk_dim or dim
assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
self.scale = qk_scale or self.qk_dim ** -0.5
################side_dwconv (i.e. LCE in ShuntedTransformer)###########
self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
lambda x: torch.zeros_like(x)
################ global routing setting #################
self.topk = topk
self.param_routing = param_routing
self.diff_routing = diff_routing
self.soft_routing = soft_routing
# router
assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
self.router = TopkRouting(qk_dim=self.qk_dim,
qk_scale=self.scale,
topk=self.topk,
diff_routing=self.diff_routing,
param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always differentiable (if no detach)
mul_weight = 'soft'
elif self.diff_routing: # hard differentiable routing
mul_weight = 'hard'
else: # hard non-differentiable routing
mul_weight = 'none'
self.kv_gather = KVGather(mul_weight=mul_weight)
# qkv mapping (shared by both global routing and local attention)
self.param_attention = param_attention
if self.param_attention == 'qkvo':
self.qkv = QKVLinear(self.dim, self.qk_dim)
self.wo = nn.Linear(dim, dim)
elif self.param_attention == 'qkv':
self.qkv = QKVLinear(self.dim, self.qk_dim)
self.wo = nn.Identity()
else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')
self.kv_downsample_mode = kv_downsample_mode
self.kv_per_win = kv_per_win
self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kernel = kv_downsample_kernel
if self.kv_downsample_mode == 'ada_avgpool':
assert self.kv_per_win is not None
self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
elif self.kv_downsample_mode == 'ada_maxpool':
assert self.kv_per_win is not None
self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
elif self.kv_downsample_mode == 'maxpool':
assert self.kv_downsample_ratio is not None
self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
elif self.kv_downsample_mode == 'avgpool':
assert self.kv_downsample_ratio is not None
self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
elif self.kv_downsample_mode == 'identity': # no kv downsampling
self.kv_down = nn.Identity()
elif self.kv_downsample_mode == 'fracpool':
# assert self.kv_downsample_ratio is not None
# assert self.kv_downsample_kenel is not None
# TODO: fracpool
# 1. kernel size should be input size dependent
# 2. there is a random factor, need to avoid independent sampling for k and v
raise NotImplementedError('fracpool policy is not implemented yet!')
elif kv_downsample_mode == 'conv':
# TODO: need to consider the case where k != v so that need two downsample modules
raise NotImplementedError('conv policy is not implemented yet!')
else:
            raise ValueError(f'kv_downsample_mode {self.kv_downsample_mode} is not supported!')
# softmax for local attention
self.attn_act = nn.Softmax(dim=-1)
self.auto_pad=auto_pad
def forward(self, x, ret_attn_mask=False):
"""
        x: NCHW tensor
        Return:
            NCHW tensor
"""
x = rearrange(x, "n c h w -> n h w c")
# NOTE: use padding for semantic segmentation
###################################################
if self.auto_pad:
N, H_in, W_in, C = x.size()
pad_l = pad_t = 0
pad_r = (self.n_win - W_in % self.n_win) % self.n_win
pad_b = (self.n_win - H_in % self.n_win) % self.n_win
x = F.pad(x, (0, 0, # dim=-1
pad_l, pad_r, # dim=-2
pad_t, pad_b)) # dim=-3
_, H, W, _ = x.size() # padded size
else:
N, H, W, C = x.size()
assert H%self.n_win == 0 and W%self.n_win == 0 #
###################################################
# patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)
#################qkv projection###################
# q: (n, p^2, w, w, c_qk)
# kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there is a memory leak issue caused by gather
q, kv = self.qkv(x)
# pixel-wise qkv
# q_pix: (n, p^2, w^2, c_qk)
# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)
q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
##################side_dwconv(lepe)##################
# NOTE: call contiguous to avoid gradient warning when using ddp
lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)
############ gather q dependent k/v #################
r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors
kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)
k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
# v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)
######### do attention as normal ####################
k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)
# param-free multihead attention
attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
attn_weight = self.attn_act(attn_weight)
out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
h=H//self.n_win, w=W//self.n_win)
out = out + lepe
# output linear
out = self.wo(out)
# NOTE: use padding for semantic segmentation
# crop padded region
if self.auto_pad and (pad_r > 0 or pad_b > 0):
out = out[:, :H_in, :W_in, :].contiguous()
if ret_attn_mask:
return out, r_weight, r_idx, attn_weight
else:
return rearrange(out, "n h w c -> n c h w")
class Attention(nn.Module):
"""
vanilla attention
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
"""
args:
x: NCHW tensor
return:
NCHW tensor
"""
_, _, H, W = x.size()
x = rearrange(x, 'n c h w -> n (h w) c')
#######################################
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
#######################################
x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
return x
class AttentionLePE(nn.Module):
"""
    vanilla attention with LePE (a depth-wise convolution branch for locally-enhanced positional encoding)
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
lambda x: torch.zeros_like(x)
def forward(self, x):
"""
args:
x: NCHW tensor
return:
NCHW tensor
"""
_, _, H, W = x.size()
x = rearrange(x, 'n c h w -> n (h w) c')
#######################################
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))
lepe = rearrange(lepe, 'n c h w -> n (h w) c')
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = x + lepe
x = self.proj(x)
x = self.proj_drop(x)
#######################################
x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
return x
def _grid2seq(x:Tensor, region_size:Tuple[int], num_heads:int):
"""
Args:
x: BCHW tensor
        region_size: (region_h, region_w), the region/window size
num_heads: number of attention heads
Return:
out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
region_h, region_w: number of regions per col/row
"""
B, C, H, W = x.size()
region_h, region_w = H//region_size[0], W//region_size[1]
x = x.view(B, num_heads, C//num_heads, region_h, region_size[0], region_w, region_size[1])
x = torch.einsum('bmdhpwq->bmhwpqd', x).flatten(2, 3).flatten(-3, -2) # (bs, nhead, nregion, reg_size, head_dim)
return x, region_h, region_w
def _seq2grid(x:Tensor, region_h:int, region_w:int, region_size:Tuple[int]):
"""
Args:
x: (bs, nhead, nregion, reg_size^2, head_dim)
Return:
x: (bs, C, H, W)
"""
bs, nhead, nregion, reg_size_square, head_dim = x.size()
x = x.view(bs, nhead, region_h, region_w, region_size[0], region_size[1], head_dim)
x = torch.einsum('bmhwpqd->bmdhpwq', x).reshape(bs, nhead*head_dim,
region_h*region_size[0], region_w*region_size[1])
return x
def regional_routing_attention_torch(
query:Tensor, key:Tensor, value:Tensor, scale:float,
region_graph:LongTensor, region_size:Tuple[int],
kv_region_size:Optional[Tuple[int]]=None,
auto_pad=True)->Tensor:
"""
Args:
query, key, value: (B, C, H, W) tensor
scale: the scale/temperature for dot product attention
region_graph: (B, nhead, h_q*w_q, topk) tensor, topk <= h_k*w_k
region_size: region/window size for queries, (rh, rw)
        kv_region_size: optional, if None, kv_region_size = region_size
auto_pad: required to be true if the input sizes are not divisible by the region_size
Return:
output: (B, C, H, W) tensor
attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix
"""
kv_region_size = kv_region_size or region_size
bs, nhead, q_nregion, topk = region_graph.size()
# Auto pad to deal with any input size
q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
if auto_pad:
_, _, Hq, Wq = query.size()
q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
if (q_pad_b > 0 or q_pad_r > 0):
query = F.pad(query, (0, q_pad_r, 0, q_pad_b)) # zero padding
_, _, Hk, Wk = key.size()
kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
if (kv_pad_r > 0 or kv_pad_b > 0):
key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
# to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)
# gather key and values.
    # TODO: is separate gathering slower than the fused one (our old version)?
# torch.gather does not support broadcasting, hence we do it manually
bs, nhead, kv_nregion, kv_region_size, head_dim = key.size()
broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\
expand(-1, -1, -1, -1, kv_region_size, head_dim)
key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
# token-to-token attention
# (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size)
# -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size)
# TODO: mask padding region
attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
attn = torch.softmax(attn, dim=-1)
# (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim)
# -> (bs, nhead, q_nregion, reg_size, head_dim)
output = attn @ value_g.flatten(-3, -2)
# to BCHW format
output = _seq2grid(output, region_h=q_region_h, region_w=q_region_w, region_size=region_size)
# remove paddings if needed
if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
output = output[:, :, :Hq, :Wq]
return output, attn
class BiLevelRoutingAttention_nchw(nn.Module):
"""Bi-Level Routing Attention that takes nchw input
Compared to legacy version, this implementation:
* removes unused args and components
* uses nchw input format to avoid frequent permutation
    When the size of the inputs is not divisible by the region size, there is also a numerical difference
    from the legacy implementation, due to:
    * a different way of padding the input feature map (padding after the linear projection)
    * different pooling behavior (count_include_pad=False)
    The current implementation is more reasonable, hence we do not keep backward numerical compatibility.
"""
def __init__(self, dim, num_heads=8, n_win=7, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'):
super().__init__()
# local attention setting
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
self.head_dim = self.dim // self.num_heads
self.scale = qk_scale or self.dim ** -0.5 # NOTE: to be consistent with old models.
################side_dwconv (i.e. LCE in Shunted Transformer)###########
self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
lambda x: torch.zeros_like(x)
################ regional routing setting #################
self.topk = topk
self.n_win = n_win # number of windows per row/col
##########################################
self.qkv_linear = nn.Conv2d(self.dim, 3*self.dim, kernel_size=1)
self.output_linear = nn.Conv2d(self.dim, self.dim, kernel_size=1)
if attn_backend == 'torch':
self.attn_fn = regional_routing_attention_torch
else:
raise ValueError('CUDA implementation is not available yet. Please stay tuned.')
def forward(self, x:Tensor, ret_attn_mask=False):
"""
Args:
x: NCHW tensor, better to be channel_last (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
Return:
NCHW tensor
"""
N, C, H, W = x.size()
region_size = (H//self.n_win, W//self.n_win)
# STEP 1: linear projection
qkv = self.qkv_linear.forward(x) # ncHW
q, k, v = qkv.chunk(3, dim=1) # ncHW
# STEP 2: region-to-region routing
# NOTE: ceil_mode=True, count_include_pad=False = auto padding
# NOTE: gradients backward through token-to-token attention. See Appendix A for the intuition.
q_r = F.avg_pool2d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
k_r = F.avg_pool2d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False) # nchw
q_r:Tensor = q_r.permute(0, 2, 3, 1).flatten(1, 2) # n(hw)c
k_r:Tensor = k_r.flatten(2, 3) # nc(hw)
a_r = q_r @ k_r # n(hw)(hw), adj matrix of regional graph
_, idx_r = torch.topk(a_r, k=self.topk, dim=-1) # n(hw)k long tensor
idx_r:LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)
# STEP 3: token to token attention (non-parametric function)
output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
region_graph=idx_r, region_size=region_size
)
output = output + self.lepe(v) # ncHW
output = self.output_linear(output) # ncHW
if ret_attn_mask:
return output, attn_mat
return output
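To sanity-check the modules above, here is a minimal, hypothetical usage sketch on a dummy feature map. The channel count, window count and top-k values below are arbitrary illustration choices, not settings from the paper or from YOLOv9; the feature-map size is chosen to be divisible by n_win.
if __name__ == '__main__':
    x = torch.randn(2, 64, 28, 28)  # dummy NCHW feature map, 28 is divisible by n_win=7

    # legacy variant: this modified version takes NCHW input and returns NCHW output
    attn = BiLevelRoutingAttention(dim=64, num_heads=8, n_win=7, topk=4)
    print(attn(x).shape)  # torch.Size([2, 64, 28, 28])

    # refactored NCHW variant built on regional_routing_attention_torch
    attn_nchw = BiLevelRoutingAttention_nchw(dim=64, num_heads=8, n_win=7, topk=4)
    print(attn_nchw(x).shape)  # torch.Size([2, 64, 28, 28])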
SeaFormer:《SEAFORMER: SQUEEZE-ENHANCED AXIAL TRANSFORMER FOR MOBILE SEMANTIC SEGMENTATION》
SeaFormer is a new approach to semantic segmentation on mobile devices. It designs a generic attention block that can be used to build a family of backbone architectures with superior cost-effectiveness; paired with a lightweight segmentation head, it achieves the best trade-off between segmentation accuracy and latency on ARM-based mobile devices.
This module has not been integrated and tested here yet; official code: https://github.com/fudan-zvg/SeaFormer. A rough sketch of the core squeeze-axial idea follows.
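Since the official module is not reproduced in this post, the block below is only a rough, self-contained sketch of the squeeze-axial idea, assuming a single head and mean pooling as the squeeze operation: the feature map is squeezed along each spatial axis, attention is computed over the two squeezed sequences, and the axial outputs are broadcast back onto the 2D map. The class name and layer choices are hypothetical; the official implementation (see the repository above) additionally uses multi-head projections and a detail-enhancement convolution branch.
import torch
from torch import nn

class SqueezeAxialAttentionSketch(nn.Module):
    """Illustrative sketch only: single-head squeeze-axial attention, not the official SeaFormer module."""
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.scale = dim ** -0.5

    def _axial_attn(self, q, k, v):
        # q, k, v: (B, L, C) sequences obtained by squeezing one spatial axis
        attn = torch.softmax((q * self.scale) @ k.transpose(-2, -1), dim=-1)
        return attn @ v

    def forward(self, x):  # x: (B, C, H, W)
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        # squeeze along W -> vertical sequence of length H; squeeze along H -> horizontal sequence of length W
        out_h = self._axial_attn(q.mean(3).transpose(1, 2), k.mean(3).transpose(1, 2), v.mean(3).transpose(1, 2))
        out_w = self._axial_attn(q.mean(2).transpose(1, 2), k.mean(2).transpose(1, 2), v.mean(2).transpose(1, 2))
        # broadcast the two axial outputs back to the 2D map and fuse them
        out = out_h.transpose(1, 2).unsqueeze(-1) + out_w.transpose(1, 2).unsqueeze(-2)
        return self.proj(out)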
EMA:《Efficient Multi-Scale Attention Module with Cross-Spatial Learning》
EMA is a new efficient multi-scale attention module with cross-spatial learning that preserves the information of each channel while reducing computational overhead. It reshapes part of the channels into the batch dimension and groups the channel dimension into multiple sub-features, so that spatial semantic features are well distributed within each feature group. It also encodes global information to recalibrate the channel weights in each parallel branch, and further aggregates the output features of the two parallel branches to capture pixel-level pairwise relationships. The module code is below, followed by a short usage sketch.
import torch
from torch import nn
class EMA(nn.Module):
def __init__(self, channels, factor=8):
super(EMA, self).__init__()
self.groups = factor
assert channels // self.groups > 0
self.softmax = nn.Softmax(-1)
self.agp = nn.AdaptiveAvgPool2d((1, 1))
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)
def forward(self, x):
b, c, h, w = x.size()
group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,w
x_h = self.pool_h(group_x)
x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
x_h, x_w = torch.split(hw, [h, w], dim=2)
x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
x2 = self.conv3x3(group_x)
x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
return (group_x * weights.sigmoid()).reshape(b, c, h, w)
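A minimal usage sketch with hypothetical values; the only constraint is that the channel count must be divisible by the chosen factor (here 8):
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)    # NCHW feature map
    ema = EMA(channels=64, factor=8)  # 64 channels split into 8 groups of 8
    print(ema(x).shape)               # torch.Size([2, 64, 32, 32]), shape is preserved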
Effective SE (eSE):《CenterMask : Real-Time Anchor-Free Instance Segmentation》
CenterMask is an efficient real-time anchor-free instance segmentation method that combines the FCOS detector with a spatial attention-guided mask branch to improve detection performance. The paper also introduces an improved VoVNetV2 backbone, which uses the effective Squeeze-Excitation (eSE) channel attention provided below, for further gains. CenterMask and CenterMask-Lite target large and small models respectively; CenterMask achieves 38.3%, surpassing all previous state-of-the-art methods while running faster.
import torch
from torch import nn as nn
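# NOTE: on recent timm releases the import path has changed; "from timm.layers import create_act_layer" may be needed instead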
from timm.models.layers.create_act import create_act_layer
class EffectiveSEModule(nn.Module):
def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.add_maxpool = add_maxpool
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
if self.add_maxpool:
# experimental codepath, may remove or change
x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate(x_se)
if __name__ == '__main__':
    x = torch.randn(50, 512, 7, 7)
    ese = EffectiveSEModule(512)
    output = ese(x)
    print(output.shape)