The figure above shows the overall architecture of the network proposed in the paper. As in ViT, MLP-Mixer, and Swin Transformer, the input image with spatial resolution H×W is first split into non-overlapping patches. The authors use a 4×4 patch size, so each patch is reshaped into a 48-dimensional vector (4×4×3) and then projected by a linear layer to a c-dimensional embedding.
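As a quick sanity check of this patch-embedding step, the sketch below uses a 4×4 convolution with stride 4, which is one common way to implement "split into non-overlapping 4×4 patches + shared linear projection" in PyTorch; the module name PatchEmbed and the embedding dimension c=96 are assumptions made here for illustration, not taken from the paper's code.

import torch
from torch import nn

class PatchEmbed(nn.Module):
    # A 4x4 conv with stride 4 cuts the image into non-overlapping 4x4 patches,
    # flattens each patch to a 48-d vector (4*4*3), and applies a shared linear
    # projection to a c-dimensional embedding.
    def __init__(self, patch_size=4, in_chans=3, c=96):   # c=96 is an assumed value
        super().__init__()
        self.proj = nn.Conv2d(in_chans, c, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):        # x: (B, 3, H, W)
        return self.proj(x)      # (B, c, H/4, W/4)

x = torch.randn(2, 3, 224, 224)
print(PatchEmbed()(x).shape)     # torch.Size([2, 96, 56, 56])

The sMLP block that the network uses for token mixing can then be written as follows: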
import torch
from torch import nn


class sMLPBlock(nn.Module):
    """Sparse MLP block: mixes tokens along rows and columns separately,
    then fuses the two directions with an identity branch channel-wise."""

    def __init__(self, h=224, w=224, c=3):
        super().__init__()
        self.proj_h = nn.Linear(h, h)      # token mixing along the height axis
        self.proj_w = nn.Linear(w, w)      # token mixing along the width axis
        self.fuse = nn.Linear(3 * c, c)    # channel-wise fusion of the three branches

    def forward(self, x):                  # x: (B, C, H, W)
        # mix along H: move H to the last dim, apply the linear layer, move it back
        x_h = self.proj_h(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        # mix along W: W is already the last dim
        x_w = self.proj_w(x)
        # identity branch
        x_id = x
        # concatenate the three branches along the channel dim: (B, 3C, H, W)
        x_fuse = torch.cat([x_h, x_w, x_id], dim=1)
        # fuse channels with a 3C -> C linear layer applied at each spatial position
        out = self.fuse(x_fuse.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return out
if __name__ == '__main__':
    x = torch.randn(2, 3, 224, 224)        # (B, C, H, W)
    smlp = sMLPBlock(h=224, w=224, c=3)
    out = smlp(x)
    print(out.shape)                       # torch.Size([2, 3, 224, 224])
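The point of the "sparse" design is visible in the parameter counts: a dense token-mixing layer that connects every one of the H×W tokens to every other token would need on the order of (HW)² weights, whereas proj_h and proj_w above only mix tokens that share a column or a row and together need roughly H² + W² weights; the 3c → c fuse layer then combines the horizontal, vertical, and identity branches at each spatial position.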