在学习 ViT 之前,建议先把 Transformer 搞明白:【transformer】入门与理解
ViT 做了哪些改进?
看图就比较明白了:ViT 只用了 Transformer 的 Encoder 部分。它把每张图片切成若干个互不重叠的子图(patch),将每个 patch 展平(flatten)并线性映射成一个向量,当成 NLP 中的一个 token 来处理。
值得注意的是,在 token 序列的最前面还拼接了一个 class_token,维度为 (1, embed_dim=768)。这个 class_token 在预测的时候比较有意思:分类头只使用它在 Encoder 输出中对应位置(第 0 个 token)的向量,见下图:
注意上图中有些细节遗漏,完整流程应该是:先对输入做 patch_embedding 得到 visual tokens,然后与 class_token 拼接,最后加上 position_embedding。
另外需要注意的是,class_token 是一个可学习的参数,而不是每次前向计算时需要额外输入的类别标签。
self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98) # learnable parameter, shape (1, 1, embed_dim), e.g. (1, 1, 768)
代码
其实有了 Transformer 的基础后,直接看代码就知道 ViT 是怎么做的了。
import copy
import torch
import torch.nn as nn
# Minimal nn.Module template (an identity op); it is not needed by the
# model below and can be deleted.
class Identity(nn.Module):
    """Module whose forward pass hands back its argument unchanged.

    ``forward`` returns the very same object it receives (no copy is made).
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        result = x  # nothing to compute
        return result
class Mlp(nn.Module):
    """Feed-forward sub-block of a Transformer encoder layer.

    Computes Linear(embed_dim -> hidden) -> GELU -> Dropout
    -> Linear(hidden -> embed_dim) -> Dropout,
    where ``hidden = int(embed_dim * mlp_ratio)``. Both linear layers act on
    the last dimension, so it is embed_dim on input and on output.

    Args:
        embed_dim: size of the last dimension of the input/output.
        mlp_ratio: expansion factor for the hidden layer width.
        dropout: dropout probability used after GELU and after fc2.
    """

    def __init__(self, embed_dim, mlp_ratio, dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)  # expanded inner width
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class PatchEmbedding(nn.Module):
    """Turn an image batch into the token sequence fed to the encoder.

    Steps:
    1. Split the image into non-overlapping ``patch_size`` x ``patch_size``
       patches and project each to ``embed_dim``. A Conv2d whose kernel size
       and stride both equal ``patch_size`` does this in one shot.
    2. Prepend a learnable class token.
    3. Add a learnable position embedding.
    4. Apply dropout.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of channels of the input image.
        embed_dim: dimension of every output token.
        dropout: dropout probability applied to the final token sequence.
    """

    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.):
        super().__init__()
        # (224 // 16) ** 2 = 196 patches with the default arguments.
        n_patches = (image_size // patch_size) * (image_size // patch_size)
        # kernel == stride == patch_size: each output pixel sees exactly one
        # patch. For a [10, 3, 224, 224] input the output is [10, 768, 14, 14].
        self.patch_embedding = nn.Conv2d(in_channels=in_channels,
                                         out_channels=embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size)
        self.dropout = nn.Dropout(dropout)
        # Learnable class token, shape (1, 1, embed_dim).
        self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)
        # Learnable position embedding: one vector per patch plus one for the
        # class token, shape (1, n_patches + 1, embed_dim).
        self.position_embedding = nn.Parameter(torch.ones(1, n_patches + 1, embed_dim) * 0.98)

    def forward(self, x):
        """Map ``x`` [N, C, H, W] to tokens [N, n_patches + 1, embed_dim]."""
        # Broadcast the single class token over the batch: [N, 1, embed_dim].
        cls_tokens = self.class_token.expand(x.shape[0], -1, -1)
        x = self.patch_embedding(x)            # [N, embed_dim, H', W']
        x = x.flatten(2)                       # [N, embed_dim, H' * W']
        x = x.permute(0, 2, 1)                 # [N, H' * W', embed_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # [N, H' * W' + 1, embed_dim]
        x = x + self.position_embedding
        # The dropout layer was previously constructed but never applied;
        # ViT applies dropout right after adding the position embedding.
        x = self.dropout(x)
        return x
class Attention(nn.Module):
    """Multi-head self-attention.

    Input and output both have shape [N, num_tokens, embed_dim].

    Args:
        embed_dim: token dimension.
        num_heads: number of attention heads; each head gets
            ``int(embed_dim / num_heads)`` channels.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        dropout: dropout probability applied after the output projection.
        attention_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=True, dropout=0., attention_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = int(embed_dim / num_heads)  # e.g. 768 / 4 = 192
        self.all_head_dim = self.head_dim * num_heads
        self.scales = self.head_dim ** -0.5  # 1 / sqrt(head_dim)
        # One fused projection producing q, k and v: embed_dim -> 3 * all_head_dim.
        self.qkv = nn.Linear(embed_dim, self.all_head_dim * 3, bias=qkv_bias)
        # The projection consumes the concatenated heads (all_head_dim), which is
        # smaller than embed_dim when embed_dim is not divisible by num_heads.
        self.proj = nn.Linear(self.all_head_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention_dropout = nn.Dropout(attention_dropout)
        # Normalize over the key axis (last dim of [N, heads, tokens, tokens]).
        # Without an explicit dim, nn.Softmax picks dim=1 for 4-D input, which
        # would normalize across heads instead of keys.
        self.softmax = nn.Softmax(dim=-1)

    def transpose_multihead(self, x):
        """Split channels into heads: [N, T, all_head_dim] -> [N, num_heads, T, head_dim]."""
        new_shape = [x.shape[0], x.shape[1], self.num_heads, self.head_dim]
        x = x.reshape(new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, N, embed_dim]; same shape out."""
        B, N, _ = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)             # 3 x [B, N, all_head_dim]
        q, k, v = map(self.transpose_multihead, qkv)   # each [B, heads, N, head_dim]
        attn = torch.matmul(q, k.transpose(-2, -1))    # [B, heads, N, N]
        attn = attn * self.scales
        attn = self.softmax(attn)
        attn = self.attention_dropout(attn)
        out = torch.matmul(attn, v)                    # [B, heads, N, head_dim]
        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)  # [B, N, all_head_dim]
        out = self.proj(out)                           # [B, N, embed_dim]
        out = self.dropout(out)
        return out
class EncoderModule(nn.Module):
    """One pre-norm Transformer encoder layer.

    Computes ``x = x + Attention(LayerNorm(x))`` followed by
    ``x = x + Mlp(LayerNorm(x))``; the input shape is preserved.

    Args:
        embed_dim: token dimension.
        num_heads: number of attention heads.
        qkv_bias: whether the attention q/k/v projection has a bias.
        mlp_ratio: hidden-width expansion factor of the MLP.
        dropout: dropout probability for the attention output and the MLP.
        attention_dropout: dropout probability on the attention weights.
    """

    def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        # Forward every hyper-parameter; qkv_bias, dropout and
        # attention_dropout used to be accepted here but silently dropped.
        self.attn = Attention(embed_dim, num_heads,
                              qkv_bias=qkv_bias,
                              dropout=dropout,
                              attention_dropout=attention_dropout)
        self.mlp_norm = nn.LayerNorm(embed_dim)
        self.mlp = Mlp(embed_dim, mlp_ratio, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # attention sub-layer + residual
        x = x + self.mlp(self.mlp_norm(x))    # MLP sub-layer + residual
        return x
class Encoder(nn.Module):
    """Stack of ``depth`` EncoderModule layers followed by a final LayerNorm.

    Args:
        embed_dim: token dimension, passed to every layer and to the final norm.
        depth: number of encoder layers.
        num_heads, qkv_bias, mlp_ratio, dropout, attention_dropout:
            forwarded to each EncoderModule. The defaults equal
            EncoderModule's own defaults, so ``Encoder(768, depth)`` builds the
            same layers as before.
    """

    def __init__(self, embed_dim, depth, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        # Each layer used to be built as EncoderModule() with all defaults, so
        # any embed_dim other than 768 failed with a shape mismatch.
        self.Modules = nn.ModuleList(
            EncoderModule(embed_dim=embed_dim,
                          num_heads=num_heads,
                          qkv_bias=qkv_bias,
                          mlp_ratio=mlp_ratio,
                          dropout=dropout,
                          attention_dropout=attention_dropout)
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        for layer in self.Modules:
            x = layer(x)
        return self.norm(x)


class VisualTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: image [N, C, H, W] -> PatchEmbedding -> Encoder -> linear head
    applied to the class-token output -> logits [N, num_classes].
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=3,
                 num_heads=8,
                 ):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        # num_heads used to be accepted but never used (every layer silently
        # got 4 heads); pass it through to the encoder.
        self.encoder = Encoder(embed_dim, depth, num_heads=num_heads)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)  # [N, num_patches + 1, embed_dim], e.g. [10, 197, 768]
        x = self.encoder(x)          # same shape
        # Token 0 is the class token that PatchEmbedding concatenates in front
        # of the patch tokens; only its output feeds the classifier.
        return self.classifier(x[:, 0])
# Quick smoke test: build the default model and push a random batch through it.
vit = VisualTransformer()
print(vit)
input_data = torch.randn(10, 3, 224, 224)  # a batch of 10 RGB images, 224x224
print(vit(input_data).size())  # expected: torch.Size([10, 1000])