SAM2 Code Walkthrough: Image Encoder (Hiera)
Preface
I have seen everyone's comments under SAM2 Code Walkthrough: Image Encoder (FpnNeck), and I appreciate the support. For parts that are hard to explain in writing, I may switch to video in the future. I am also preparing for the autumn recruitment season, so updates may be much slower; I hope you will understand.
Back to the topic. In SAM2 Code Walkthrough: Image Encoder (FpnNeck) we already covered SAM2's overall architecture and introduced the FpnNeck component of the Image Encoder. In this post we take a detailed look at the Image Encoder's backbone: Hiera.
Introduction to Hiera
Hiera is a hierarchical vision Transformer proposed in the paper Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles. It can process not only images but also video. Hiera is a pure, simple hierarchical ViT: there are no convolutions, shifted windows, or cross-shaped windows, only Transformer components. Across multiple model sizes, domains, and tasks it is both faster and more accurate than prior work.
Hiera and MAE (Masked Autoencoder)
MAE (Masked Autoencoder)
The image MAE was proposed in Masked Autoencoders Are Scalable Vision Learners, which shows that masked autoencoders are scalable self-supervised learners for computer vision. The approach is simple: mask random patches of the input image and reconstruct the missing pixels. It rests on two core designs. First, the authors develop an asymmetric encoder-decoder architecture, in which the encoder operates only on the visible subset of patches (without mask tokens), while a lightweight decoder reconstructs the original image from the latent representation and the mask tokens. Second, they find that masking a high proportion of the input image (e.g., 75%) yields a nontrivial and meaningful self-supervised task. Combining these two designs enables efficient and effective training of large models: training is accelerated (3x or more) and accuracy improves. The scalable approach allows learning high-capacity models that generalize well: for example, a vanilla ViT-Huge model reaches 87.8% accuracy among methods that use only ImageNet-1K data. Transfer performance on downstream tasks surpasses supervised pre-training and shows promising scaling behavior.
Hiera is trained in exactly this MAE fashion.
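To make the masking step concrete, here is a minimal sketch of MAE-style random patch masking. This is my own illustration rather than code from the MAE or SAM2 repositories; the function name and the 75% ratio are just the paper's usual defaults:

import torch

def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    # patches: (B, N, D) sequence of patch embeddings
    B, N, D = patches.shape
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)                     # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)    # random permutation of patch indices
    ids_keep = ids_shuffle[:, :len_keep]         # indices of the visible patches
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return visible, ids_keep                     # the encoder only ever sees `visible`

patches = torch.randn(2, 196, 768)               # e.g. 14x14 patches from a 224x224 image
visible, ids_keep = random_masking(patches)
print(visible.shape)                             # torch.Size([2, 49, 768])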
The Hiera Architecture
The model is taught with a strong pretext task, MAE (as shown in the figure). Hiera is composed entirely of standard ViT blocks. For efficiency, the first two stages use local attention within "mask units", while the remaining stages use global attention. At each stage transition, the channel dimension of the queries (Q) and of the skip-connection features is doubled by a linear layer, and the spatial resolution is halved with 2x2 max pooling.
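A tiny sketch of what such a stage transition does to a (B, H, W, C) token map. This is illustrative only; the 56x56x96 input and the standalone layers are made up for the example, not taken from the Hiera code:

import torch
import torch.nn as nn

x = torch.randn(1, 56, 56, 96)                  # (B, H, W, C) tokens entering a stage transition
proj = nn.Linear(96, 192)                       # linear layer doubles the channel dimension
pool = nn.MaxPool2d(kernel_size=2, stride=2)    # 2x2 max pooling halves the spatial dims

y = proj(x)                                          # (1, 56, 56, 192)
y = pool(y.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # (1, 28, 28, 192)
print(y.shape)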
The Hiera (HieraDet) Implementation in SAM2
class Hiera(nn.Module):
"""
Reference: https://arxiv.org/abs/2306.00989
"""
def __init__(self, ...):
...
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
...
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
embed_dim = dim_out
self.blocks.append(block)
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile(
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
)
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add pos embed
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.stage_ends[-1]) or (
i in self.stage_ends and self.return_interm_layers
):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs
First, Hiera splits the image into patches and projects them into patch embeddings (the self.patch_embed(x) call at the top of forward), then computes the positional embedding and adds it to the patch embeddings (x = x + self._get_pos_embed(x.shape[1:3])). Notably, for the positional embedding SAM2's Hiera follows the paper Window Attention is Bugged: How not to Interpolate Position Embeddings, which found that interpolating position embeddings while using window attention is wrong, and that both Hiera and ViTDet indeed contain this bug. The authors therefore propose a simple absolute window position embedding strategy, which removes the bug from Hiera outright and improves both the speed and the accuracy of the ViTDet models.
The loop over self.blocks in forward is the main body of Hiera: a stack of ViT blocks. The only blocks that deserve special attention are the ones with Q pooling, which is implemented in MultiScaleBlock.
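Before diving into the submodules, it helps to keep the output shapes in mind. A rough sketch of the multi-scale outputs, assuming SAM2's 1024x1024 input, the default patch stride of 4, an initial embed_dim of 96 (as in the smaller configs), and 2x2 Q pooling at each of the three stage transitions; the exact numbers depend on the model config:

# spatial size and channel count of the feature map emitted at each stage end
hw, dim = 1024 // 4, 96              # PatchEmbed downsamples the 1024x1024 input by 4
for stage in range(4):
    print(f"stage {stage}: (B, {dim}, {hw}, {hw})")  # shape after permute(0, 3, 1, 2)
    hw, dim = hw // 2, dim * 2       # 2x2 Q pooling halves H/W, the transition doubles C

# stage 0: (B, 96, 256, 256)
# stage 1: (B, 192, 128, 128)
# stage 2: (B, 384, 64, 64)
# stage 3: (B, 768, 32, 32)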
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, ...] = (7, 7),
stride: Tuple[int, ...] = (4, 4),
padding: Tuple[int, ...] = (3, 3),
in_chans: int = 3,
embed_dim: int = 768,
):
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
The PatchEmbed module converts the image from (B, C, H, W) into the channels-last layout (B, H, W, C) that the rest of the network works with, since the ViT blocks later expect a (B, L, C) view. The convolutional projection follows ViT's original recipe: by choosing kernel_size and stride appropriately, the convolution implicitly partitions the image into patches (here overlapping, since the 7x7 kernel is larger than the stride of 4) and applies the linear projection that produces the patch embeddings in a single step.
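A quick shape check using the class above (a minimal sketch; embed_dim=96 and the 1024x1024 input are example values roughly matching the smaller SAM2 configs):

import torch

patch_embed = PatchEmbed(embed_dim=96)    # 7x7 conv, stride 4, padding 3 (the defaults above)
x = torch.randn(1, 3, 1024, 1024)
print(patch_embed(x).shape)               # torch.Size([1, 256, 256, 96])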
The computation of the positional embedding deserves particular attention:
class Hiera(nn.Module):
def __init__(self, ...):
super().__init__()
...
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
)
self.pos_embed_window = nn.Parameter(
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
)
...
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile(
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
)
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
self.pos_embed is the learnable global positional embedding, which _get_pos_embed interpolates to the current feature resolution. The window_embed.tile(...) term on the right-hand side of the addition is the local positional encoding inside each window, and every window shares the same local encoding. A matplotlib visualization may make this easier to picture. Since the parameters are zero-initialized in the code and would show nothing, the example below uses random initialization instead:
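A minimal sketch of such a visualization (the sizes, e.g. a 14x14 global embedding, an 8x8 window embedding, and a 64x64 feature map, are made up for illustration and are not SAM2's defaults):

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

h, win = 64, 8                                   # feature-map size and window size (h divisible by win)
pos_embed = torch.randn(1, 1, 14, 14)            # global embedding, randomly initialized for display
pos_embed_window = torch.randn(1, 1, win, win)   # per-window embedding

global_pe = F.interpolate(pos_embed, size=(h, h), mode="bicubic")
local_pe = pos_embed_window.tile(
    [x // y for x, y in zip(global_pe.shape, pos_embed_window.shape)]
)
final_pe = global_pe + local_pe

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, img, title in zip(axes, [global_pe, local_pe, final_pe], ["global", "window (tiled)", "sum"]):
    ax.imshow(img[0, 0].numpy())
    ax.set_title(title)
    ax.axis("off")
plt.show()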
From left to right: the global embedding, the local (per-window) embedding, and the final positional embedding.
Next, let's look at the implementation of MultiScaleBlock:
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: Union[nn.Module, str] = "LayerNorm",
q_stride: Tuple[int, int] = None,
act_layer: nn.Module = nn.GELU,
window_size: int = 0,
):
super().__init__()
if isinstance(norm_layer, str):
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.window_size = window_size
self.pool, self.q_stride = None, q_stride
if self.q_stride:
self.pool = nn.MaxPool2d(
kernel_size=q_stride, stride=q_stride, ceil_mode=False
)
self.attn = MultiScaleAttention(
dim,
dim_out,
num_heads=num_heads,
q_pool=self.pool,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = MLP(
dim_out,
int(dim_out * mlp_ratio),
dim_out,
num_layers=2,
activation=act_layer,
)
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x # B, H, W, C
x = self.norm1(x)
# Skip connection
if self.dim != self.dim_out:
shortcut = do_pool(self.proj(x), self.pool)
# Window partition
window_size = self.window_size
if window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, window_size)
# Window Attention + Q Pooling (if stage change)
x = self.attn(x)
if self.q_stride:
# Shapes have changed due to Q pooling
window_size = self.window_size // self.q_stride[0]
H, W = shortcut.shape[1:3]
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
pad_hw = (H + pad_h, W + pad_w)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
# MLP
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def window_partition(x, window_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp)
def do_pool(x: torch.Tensor, pool: nn.Module) -> torch.Tensor:
    # (B, H, W, C) -> (B, H', W', C): permute to channels-first, pool, permute back
    if pool is None:
        return x
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    x = x.permute(0, 2, 3, 1)
    return x
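Before walking through the block, here is a quick shape check of window_partition (a small sketch with made-up sizes, using the function defined above):

import torch

x = torch.randn(1, 56, 56, 96)                # (B, H, W, C) feature map, divisible by the window size
windows, (Hp, Wp) = window_partition(x, 8)
print(windows.shape, (Hp, Wp))                # torch.Size([49, 8, 8, 96]) (56, 56)

x = torch.randn(1, 60, 60, 96)                # not divisible: padded up to 64x64 first
windows, (Hp, Wp) = window_partition(x, 8)
print(windows.shape, (Hp, Wp))                # torch.Size([64, 8, 8, 96]) (64, 64)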
MultiScaleBlock is built from MultiScaleAttention and an MLP; anyone who has seen attention plus an MLP will recognize it as a Transformer block. The window-partition step in forward (the `if window_size > 0` branch calling window_partition) splits the tokens into windows using the window size configured for each stage. In addition, at every stage boundary Q pooling is applied, which is implemented inside MultiScaleAttention.
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
q_pool: nn.Module = None,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (B, H * W, 3, nHead, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
# q, k, v with shape (B, H * W, nheads, C)
q, k, v = torch.unbind(qkv, 2)
# Q pooling (for downsample at stage changes)
if self.q_pool:
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
H, W = q.shape[1:3] # downsampled shape
q = q.reshape(B, H * W, self.num_heads, -1)
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
x = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
)
# Transpose back
x = x.transpose(1, 2)
x = x.reshape(B, H, W, -1)
x = self.proj(x)
return x
The qkv projection at the top of forward, the scaled_dot_product_attention call, and the final output projection are just the standard self-attention computation.
The so-called Q pooling is the `if self.q_pool:` branch: the query tensor is reshaped from (B, H*W, ...) back into a spatial layout (B, H, W, ...), max-pooled, and then flattened again. Keep in mind that after window partitioning, H and W here are really the window size we specified earlier (and B is the batch size times the number of windows).
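To close, a small sketch of what Q pooling does to the shapes (made-up sizes; the permute-pool-permute mirrors what do_pool does in the code above):

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)
q = torch.randn(49, 8, 8, 96)                          # (B * num_windows, window, window, C) queries
q = pool(q.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)    # each 8x8 window of queries becomes 4x4
print(q.shape)                                         # torch.Size([49, 4, 4, 96])

Correspondingly, MultiScaleBlock.forward updates window_size to self.window_size // self.q_stride[0] (here 8 // 2 = 4) before reversing the window partition.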