paper:DaViT: Dual Attention Vision Transformers
official implementation:https://github.com/dingmyu/davit
third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/davit.py
出发点
现有的视觉Transformer在捕捉全局上下文和计算效率之间存在权衡。传统的方法在像素级或patch级别进行自注意力计算,要么带来高计算开销,要么丧失全局上下文信息。本文提出了一种新的自注意力机制,旨在解决这一问题。
创新点
DaViT通过引入“空间token”和“通道token”来同时捕捉全局上下文和局部信息,并保持计算效率。通过交替使用这两种自注意力机制,DaViT能够有效地处理高分辨率图像,同时保持计算成本的线性增长。
- 双重注意力机制:引入了空间窗口自注意力和通道组自注意力。这两种注意力机制交替使用,既捕捉了局部精细结构信息,又捕捉了全局上下文。
- 通道token:通过对token矩阵进行转置,定义通道token,使每个通道token在空间维度上具有全局性,包含整个图像的抽象表示。
- 计算效率:通过分组注意力,将通道维度的计算复杂度降低为线性,从而在保持全局信息捕捉能力的同时,显著降低计算成本。
方法介绍
DaViT由spatial window self-attention和channel group self-attention组成,其中前者在swin-transformer等多个网络中都已经使用,这里不再过多介绍。这里主要介绍下本文提出的通道组自注意力,如图1(b)和图3(c)所示。
之前的自注意力将pixels或patches定义为token,并沿空间维度收集信息。通道注意力将注意力应用于patch-level token的转置上,为了获得空间维度的全局信息,将heads数设置为1。作者认为每个转置的token都概括了全局信息。这样通道token以线性的空间复杂度与通道维度上的全局信息进行交互。如图1(b)所示。
为了进一步降低复杂度,作者将通道分成多个group,并在每个group内计算self-attention,类似于在空间维度分成多个window并在每个window内计算self-attention。具体来说我们用 \(N_g\) 表示group数量,\(C_g\) 表示每个group内的通道数量,我们有 \(C=N_g*C_g\)。这样channel group attention是全局的,定义如下
其中 \(Q_i,K_i,V_i\in \mathbb{R}^{P\times C_g}\) 是grouped channel-wise image-level的queries、keys和values。
代码解析
下面是timm中的channel attention实现代码,其中输入shape=(1, 3136, 96),1代表batch_size,3136=56x56是空间分辨率,96是通道数。其中k转置后与v相乘得到attention矩阵,然后取softmax再与q的转置相乘。个人理解其实这里q、k、v的符号不重要,毕竟三个张量的维度都是一致的,比如k的转置与q相乘后取softmax再与v的转置相乘也是一样的。重要的是张量中sequence_length或者是token的数量与特征维度的定义,在spatial attention中spatial维度是seq_len,而channel维度是feature dim。这里channel attention中反过来了,channel维度作为seq_len,而spatial维度作为feature_dim,因此需要进行转置。第一步attention矩阵的维度应该是(seq_len, seq_len)的,在普通的attention中是(3136, 3136),这里就是(32, 32)因为num_heads=3,所以feat_dim=96/3=32。
这里特别需要注意的是,前面说过为了保持spatial维度的全局信息,设置num_heads=1,而这里为什么num_heads=3。因为这里的self.qkv的输入是原始的x而不是转置后的x,因此是将96分成3份,对应的是前面说的group,而spatial维度的3136保持不变,即保持了全局的信息。
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads # 3
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor): # (1,3136,96)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (1,3136,288)->(1,3136,3,3,32)->(3,1,3,3136,32)
q, k, v = qkv.unbind(0) # (1,3,3136,32)
k = k * self.scale
attention = k.transpose(-1, -2) @ v # (1,3,32,3136) @ (1,3,3136,32) -> (1,3,32,32)
attention = attention.softmax(dim=-1)
# (1,3,32,32) @ (1,3,32,3136)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) # (1,3,32,3136)->(1,3,3136,32)
x = x.transpose(1, 2).reshape(B, N, C) # (1,3136,3,32)->(1,3136,96)
x = self.proj(x)
return x
实验结果
不同大小的模型配置如下
其中 \(L\) 表示层数,\(N_g\) 是channel attention中的group数,\(N_h\) 是window attention中的head数。
在ImageNet上的结果如表1所示