多粒度特征融合(细粒度图像分类)
- 摘要
- Abstract
- 1. 多粒度特征融合
- 1.1 文献摘要
- 1.2 研究背景
- 1.3 创新点
- 1.4 模型方法
- 1.4.1 Swin-Transformer
- 1.4.2 多粒度特征融合模块
- 1.4.3 自注意力
- 1.4.4 通道注意力
- 1.4.5 图卷积网络
- 1.4.6 基于Vision-Transformer的两阶段分类
- 1.5 实验
- 1.5.1 数据集
- 1.5.2 实施细节
- 1.5.3 实验结果
- 2. Swin-Transformer代码实现
- 总结
摘要
本周阅读了 Two-stage fine-grained image classification model based on
multi-granularity feature fusion 这篇 sci 2区论文,本文提出了一种基于Transformer多粒度特征融合的细粒度图像分类模型。 该方法采用目前比较先进的Swin-Transformer模型来提取特征并选择不同分辨率的特征图。 通过多粒度特征融合模块,融合不同粒度的特征。 并利用注意力机制增强通道和空间二维上的特征。 融合后的特征既有高语义的全局信息,又有低语义的局部信息。 此外,使用Vision-Transformer作为辅助模型,以非常小的成本定位图像中物体的位置。 图像处理后,最大程度地将物体与背景分离,以减少对分类结果的影响。 多粒度特征融合模块是一个即插即用的模块,可以与当前流行的Transformer和传统CNN网络相结合。 本文将详细介绍此架构。
Abstract
This week read Two-stage fine-grained image classification model based on
multi-granularity feature fusion This sci zone 2 paper presents a fine-grained image classification model based on Transformer multi-granularity feature fusion. The method adopts the current state-of-the-art Swin-Transformer model to extract features and select feature maps with different resolutions. Through the multi-granularity feature fusion module, features of different granularity are fused. And the attention mechanism is utilized to enhance the features on channel and spatial two dimensions. The fused features have both high semantic global information and low semantic local information. In addition, Vision-Transformer is used as an auxiliary model to localize the position of objects in the image at a very small cost. After image processing, objects are maximally separated from the background to minimize the impact on the classification results. The multi-granularity feature fusion module is a plug-and-play module that can be combined with the current popular Transformer and conventional CNN networks. This architecture is described in detail in this paper.
文献来源:Two-stage fine-grained image classification model based on
multi-granularity feature fusion
1. 多粒度特征融合
1.1 文献摘要
大多数现有方法直接使用网络的最终输出,该输出始终包含具有高级语义信息的全局特征。 然而,细粒度图像之间的差异反映在经常出现在网络前面的细微局部区域。 当背景和物体的纹理相似或背景所占比例过大时,预测会受到很大影响。
为了解决上述问题,本文提出多粒度特征融合模块(MGFF)和基于Vision-Transformer(ViT)的两阶段分类。 前者通过融合不同粒度的特征来综合表示图像,从而避免了单尺度特征的局限性。 后者利用ViT模型以非常小的成本将物体与背景分离,从而提高预测的准确性。 作者进行了全面的实验,并在 CUB-200-2011 和 NA-Birds 上的两个细粒度任务中取得了最佳性能。
1.2 研究背景
以往的方法:
- 通过使用附加注释或分析响应图来定位有区别的部分
- 分别提取这些部分的特征,并将这些特征连接起来进行分类
对于细粒度图像,判别部分的特征更为重要。 这就是自注意力机制发挥作用的地方。 这种机制通过对判别性特征施加更高的权重,使模型更加关注这部分的信息。 然而,上述方法存在一些缺点。 他们使用单尺度网络输出,这使得很容易忽略在预测中发挥关键作用的微妙特征。 另外,当背景所占比例过大或者纹理与物体相似时,最终的预测将会受到严重干扰。
为了解决上述问题,作者提出一种基于多粒度特征融合的两级细粒度图像分类模型。 该模型由两部分组成:多粒度特征融合模块和基于Vision-Transformer的两阶段分类。 作者对四个常用的细粒度数据集(CUB-200-2011、NA-Birds、Stanford Cars、Pets)进行了广泛的实验,并在两个数据集上实现了最先进的性能。
1.3 创新点
- 作者提出多粒度特征融合模块来解决单尺度特征的局限性。 该模块提取主干网络中不同尺度的特征并将其融合生成可以综合表示图像的特征。
- 作者提出基于 Vision-Transformer 的两阶段分类,以减少背景对预测的干扰。 借助ViT模型,可以将物体与背景分离,并放大细节,更有利于最终的预测。
大量的实验和最佳的性能可以证明该模型的优越性。 可视化结果表明,作者的两阶段分类可以准确地定位对象并促进正确的预测。
1.4 模型方法
我们提出的模型如下图所示。该模型由两部分组成:用于特征提取的骨干网络、多粒度特征融合模块。
该模型可以融合多粒度特征并增强判别性特征,生成代表整个图像的特征向量。 该模型首先使用先进的SwinTransformer模型对原始图像进行特征提取,提取了四层不同尺度的特征。 通过多粒度特征融合模块对特征进行处理:首先从空间和通道两个维度对特征进行增强,然后进行特征合并,最后利用图卷积进行特征融合,生成语义丰富的特征向量 代表整个图像进行分类。
1.4.1 Swin-Transformer
该模型在ViT的基础上进行改进,解决了ViT模型分辨率单一、计算量大的问题。 该模型被设计为适用于多种类型的视觉任务,例如图像分类、对象检测和语义分割。 该模型的效果远优于传统的纯卷积网络。 首先将图像分割成不重叠的补丁,然后通过线性连接对补丁进行编码,最后将提取的初始特征集作为Swin Transformer Blocks的输入进行更深层的特征提取。 为了生成分层表示,每组 2 × 2 相邻特征都通过线性映射进行合并,以减少特征数量。 这相当于将特征图下采样 2 倍,并且输出维度设置为双倍。 特征合并和特征变换的过程是共同重复的。 每个阶段都会产生不同分辨率的特征图,总共产生四个不同分辨率的特征图。
Swin-Transformer详细介绍请参见:第二十五周:文献阅读笔记(swin transformer)
1.4.2 多粒度特征融合模块
使用Swin-Transformer提取特征图后,需要将每一层变换为每个特征的统一尺寸(最后一层除外)。 如下图所示,总共需要处理三个特征图。 它们的大小分别为 48 × 48 × 256、24 × 24 × 512 和 12 × 12 × 1024。使用内核大小为 4 × 4、2 × 2 和 1 × 1 的卷积。所有特征图的统一大小为 12 × 12 × 1024。然后对于每个特征图,使用自注意力和通道注意力来增强重要信息。 最后,四个特征图在通道维度上连接起来,作为图卷积网络的输入,生成语义丰富的特征,最终代表整个图像。 该模块不限于单一主干网络,它可以与当前流行的Transformer网络和传统CNN网络相结合。 因此,它可以被视为即插即用模块。
1.4.3 自注意力
使用主干网络提取特征后,所有特征具有相同的重要性。 对于细粒度图像分类,并非所有特征都同等重要。 一些判别性特征应该具有较高的权重,这有利于模型的最终预测。 如何自适应地发现这些特征并增强它们,同时抑制其他特征,这就是注意力发挥作用的地方。
自注意力机制连接不同位置的特征,捕获长距离依赖关系,并增强重要特征。 将上述过程表述如下:首先,将特征投影到全连接的新空间,然后我们可以得到三个稠密矩阵:查询矩阵(𝑄)、键矩阵(𝐾)和值矩阵(𝑉):
其中
W
k
W_k
Wk、
W
q
W_q
Wq 和
W
v
W_v
Wv 是全连接对应的权重矩阵,F 是特征集,𝐾、𝑄、𝑉是新特征空间中的特征集。然后,自注意力机制使用点积相似度来描述任意两个特征之间的相关性,通过将查询矩阵与转置的关键矩阵相乘得到得分矩阵,进行softmax归一化得到注意力图:
其中
d
k
d_k
dk 是 K 的维度,𝐸𝑝𝑜𝑠 是相对位置嵌入,可以保留特征的位置信息。最后,通过计算注意力图 𝐴′ 和 𝑉 得到新的特征集 𝐼:
如下图所示,使用self-attention机制处理图像后,物体的可辨别部分明显增强,而其他区域则明显受到抑制。
对于自注意力的详细推导过程请参见:文献阅读笔记(Transformer)
1.4.4 通道注意力
为了从通道维度执行特征的自适应增强,引入了挤压和激励(SE)模块。 该模块通过捕获通道之间的相互依赖关系来增强特征,能够学习全局信息,选择性地强调信息丰富的特征,并抑制不太有用的特征。SE模块的结构如图3所示。对于任何给定的输入特征
X
∈
R
H
×
W
×
C
X\in R^{H\times W\times C}
X∈RH×W×C,通过使用全局均值池将全局空间信息压缩为通道描述符。 形式上,统计量
Z
∈
R
∣
X
∣
×
C
Z\in R^{\left | X \right | \times C}
Z∈R∣X∣×C 是通过压缩 𝑋 的空间维度产生的。 对于通道元素
Z
c
Z_c
Zc, 生成方法如下:
其中 𝑍𝑐 表示 𝑍 的第 𝑐 个元素,
x
c
x_c
xc(𝑖,𝑗)表示输入的第 𝑐 通道的元素。 为了利用聚合信息,需要第二次操作来完全捕获与通道相关的依赖关系。
其中 𝜎 是 Sigmoid 函数,𝛿 是 ReLU 函数,𝐹𝑒𝑥 表示由 𝑍 计算的通道注意力,
W
1
∈
R
C
r
×
C
W_1\in R^{\frac{C}{r}\times C }
W1∈RrC×C 和
W
2
∈
R
(
C
r
)
×
C
W_2\in R^{(\frac{C}{r})\times C }
W2∈R(rC)×C是矩阵权重,𝑟 控制的复杂性模块。 使用通道注意力图对原始特征进行缩放,得到最终的输出。 对应的特定通道可以通过以下公式得到:
1.4.5 图卷积网络
图卷积网络(GCN)是一系列可以自然地对图结构数据进行操作的神经网络。 通过从底层图中提取和利用特征,GCN 可以对这些关联实体做出比孤立考虑单个实体的模型更明智的预测。 特征图上的所有特征都被视为一个图结构,其中节点代表不同空间位置和尺度的特征。 如下图所示,特征图输入到GCN中,网络可以学习不同节点之间的关系。 然后,通过池化层将特征点聚合成多个超级节点。 最后对这些超级节点的特征进行平均,并使用线性分类器来完成预测。 这种方法的优点是可以更有效地集成每个点的特征,而不会破坏主干模型输出的结果。
1.4.6 基于Vision-Transformer的两阶段分类
如果背景占据图像的很大比例,这可能会导致最终结果的预测不准确。 为了解决这个问题,我们以很小的代价从原始图像中提取出目标并将其放大,以便新图像可以再次用于分类。 放大后的图像包含较少的背景,而原来较小的细节被放大,这两者都有利于模型捕获更多有用的信息。 最后将两个阶段的结果合并起来作为最终结果。
ViT模型将原始图像
X
∈
R
H
×
W
×
3
X\in R^{H\times W\times 3 }
X∈RH×W×3 以不重叠的方式划分为多个
p
a
t
c
h
∈
R
16
×
16
×
3
patch\in R^{16\times 16\times 3 }
patch∈R16×16×3,每个patch被线性映射成对应的
t
o
k
e
n
∈
R
1
×
1
×
C
token\in R^{1\times 1\times C }
token∈R1×1×C 。 如下图所示。矩阵的每一列都经过softmax函数处理,结合式
K
=
W
k
∗
F
K=W_{k}*F
K=Wk∗F。 每一列对应的值代表了对应新特征的比例。 那么这个非负值也是衡量它的一个重要指标。 对 𝐴′ 矩阵的行求和即可得到每个标记的重要性。 由于ViT是由多个MHA堆叠而成,因此该过程涉及多个权重矩阵的计算。
其中 𝐿 是 ViT 中 MHA 的数量,
N
=
H
16
×
W
16
N=\frac{H}{16}\times\frac{W}{16}
N=16H×16W 是tokens数量。 Reshape的作用就是将 𝑊𝑖 与对应位置的 token 对齐,使一维矩阵𝑊变成二维矩阵
W
′
∈
R
H
16
×
W
16
W'\in R^{\frac{H}{16}\times\frac{W}{16}}
W′∈R16H×16W 。 通过双线性插值将 𝑊′ 上采样 16 倍,得到与原始图像大小相同的响应图。
1.5 实验
1.5.1 数据集
实验在四个广泛使用的细粒度数据集上进行评估:CUB-200-2011、Stanford Cars、Oxford-IIIT Pets、NA-Birds。 此方法仅使用类标签,因此不会与使用附加信息的方法进行比较。 下表统计了数据集的数据划分和类别数。
1.5.2 实施细节
所有实验中,输入图像尺寸均调整为512×512,然后随机裁剪为384×384。由于不同规格的Swin-Transformer和ViT模型参数和效果不同,因此采用B规格最常用的模型 在这个实验中。 该方法使用动量为0.99的优化器(随机梯度下降,SGD)来优化模型。 初始学习率为2𝑒−4,我们使用函数𝑐𝑜𝑠(𝑥)调整每轮的学习率,训练50轮。 所有实验均使用 PyTorch 和 NVIDIA RTX 3090 GPU (24 GB) 进行。
1.5.3 实验结果
下表显示了作者的方法和最先进的方法在上述数据集上的比较结果。
总的来说,作者的方法在两个数据集上优于最先进的方法,并且在其他数据集上的表现与最先进的方法类似。 作者的方法在CUB-2002011数据集上可以达到92.60%,这是细粒度图像分类中最常用的数据集。 与目前最好的精度CAP相比,提高了0.86%。 与同样使用强大的Transformer模型作为骨干网络的TransFG和FFVT相比,该方法仍然实现了0.96%的精度提升。 TransFG 是 Transformer 模型应用于细粒度图像分类的首次尝试。 它和我们的方法都试图过滤掉重要的部分特征以进行最终分类。 然而,TransFG 仅从相同尺度中选择特征并丢弃其他特征,而我们的方法从多个尺度中选择特征并增强所选特征,同时抑制其他特征。 NA-Birds 是一个包含更多图像和类别的大型鸟类数据集,这进一步挑战了细粒度的视觉分类。 许多模型在小型数据集上表现良好,但在大型数据集上表现不佳。 在表2中,该方法可以获得92.08%,远远超过其他方法。 它展示了我们的方法构建模型的强大能力,该模型可以有效地识别子类别,而无需使用额外的数据集或辅助网络。 在另外两个数据集Stanford Cars和Oxford-IIIT Pets上,虽然我们的方法没有达到最好的精度,但与同样使用Transformer的Deit-B、TransFG等网络相比仍然具有很大的优势。
2. Swin-Transformer代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path_f(x, self.drop_prob, self.training)
def window_partition(x, window_size: int):
"""
将feature map按照window_size划分成一个个没有重叠的window
Args:
x: (B, H, W, C)
window_size (int): window size(M)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
"""
将一个个window还原成一个feature map
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size(M)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_chans = in_c
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
# padding
# 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
if pad_input:
# to pad the last 3 dimensions,
# (W_left, W_right, H_top,H_bottom, C_front, C_back)
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0],
0, 0))
# 下采样patch_size倍
x = self.proj(x)
_, _, H, W = x.shape
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
# 如果输入feature map的H,W不是2的整数倍,需要进行padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # [Mh, Mw]
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask: Optional[torch.Tensor] = None):
"""
Args:
x: input features with shape of (num_windows*B, Mh*Mw, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
# [batch_size*num_windows, Mh*Mw, total_embed_dim]
B_, N, C = x.shape
# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw]
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
# mask: [nW, Mh*Mw, Mh*Mw]
nW = mask.shape[0] # num_windows
# attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
# mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, attn_mask):
H, W = self.H, self.W
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
# 把feature map给pad到window size的整数倍
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C]
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C]
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
# 把前面pad的数据移除掉
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class BasicLayer(nn.Module):
"""
A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.depth = depth
self.window_size = window_size
self.use_checkpoint = use_checkpoint
self.shift_size = window_size // 2
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else self.shift_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def create_mask(self, x, H, W):
# calculate attention mask for SW-MSA
# 保证Hp和Wp是window_size的整数倍
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
# 拥有和feature map一样的通道排列顺序,方便后续window_partition
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1]
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
# [nW, Mh*Mw, Mh*Mw]
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, H, W):
attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw]
for blk in self.blocks:
blk.H, blk.W = H, W
if not torch.jit.is_scripting() and self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x = self.downsample(x, H, W)
H, W = (H + 1) // 2, (W + 1) // 2
return x, H, W
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
# stage4输出特征矩阵的channels
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
# 注意这里构建的stage和论文图中有些差异
# 这里的stage不包含该stage的patch_merging层,包含的是下个stage的
layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layers)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# x: [B, L, C]
x, H, W = self.patch_embed(x)
x = self.pos_drop(x)
for layer in self.layers:
x, H, W = layer(x, H, W)
x = self.norm(x) # [B, L, C]
x = self.avgpool(x.transpose(1, 2)) # [B, C, 1]
x = torch.flatten(x, 1)
x = self.head(x)
return x
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=7,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
num_classes=num_classes,
**kwargs)
return model
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=7,
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
num_classes=num_classes,
**kwargs)
return model
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=7,
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
num_classes=num_classes,
**kwargs)
return model
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
# trained ImageNet-1K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=12,
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
num_classes=num_classes,
**kwargs)
return model
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=7,
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
num_classes=num_classes,
**kwargs)
return model
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=12,
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
num_classes=num_classes,
**kwargs)
return model
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=7,
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
num_classes=num_classes,
**kwargs)
return model
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
# trained ImageNet-22K
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
model = SwinTransformer(in_chans=3,
patch_size=4,
window_size=12,
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
num_classes=num_classes,
**kwargs)
return model
总结
本周阅读多粒度特征融合的架构, 该方法采用目前比较先进的Swin-Transformer模型来提取特征并选择不同分辨率的特征图。 通过多粒度特征融合模块,融合不同粒度的特征。 并利用注意力机制增强通道和空间二维上的特征。下周我将尝试复现该架构。加油~