Vit将纯transformer结构引入到CV的基本任务——图像分类中
VIT
- 1.输入端适配
- 1.1 图像切分重排
- 1.2构造Patch0
- 1.3 positional enbedding
- ViT 结构的数据流
- 完整模型代码
1.输入端适配
1.1 图像切分重排
图像切分之后进行拉平,Flatten可能导致维度过高,假设是三通道图像,切分后图像大小是32,则维度是32323,因此需要经过Linear Projection 层进行降低维度
1.2构造Patch0
从CV角度理解patch0,它可能是用来整合信息的
patch0如何整合信息? 本质是一个动态池化层
cls_token的shape是(1,1,dim) 使用dim是为了与后面的维度拼接到一起,所以维度相同,第一个1是batchsize,第二个1是指有一个cls_token,即每张图只有一个patch0
repeat是干什么呢?
()表示占位,从一个变成了b个
把cls_token复制b份 ,假如一个batch中有8张图,那么需要8个cls_token,这八个cls_token是一样的,做的任务是一样的
patch0做的都是一样的 ,就是问这张图片,图片中的是什么物体,或者什么是最突出的?比较重要的
然后得到一个attention map
1.3 positional enbedding
ViT 结构的数据流
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])#初始化一个空的模块列表,用于存储Transformer的各个层
for _ in range(depth):#这个循环用于添加depth数量的Transformer层到self.layers中
self.layers.append(nn.ModuleList([ #每一层都包含以下两部分:
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
#这是带有预归一化(PreNorm)的多头注意力(Attention)模块。
#预归一化是一种常见的技术,用于在进入注意力机制之前对输入进行归一化,有助于稳定和加速训练。
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
#这是带有预归一化的前馈网络(FeedForward)模块。
# 前馈网络是Transformer层中另一个关键组成部分,它在多头注意力模块之后进一步处理数据。
]))
def forward(self, x): #输入x会依次通过所有的Transformer层
#对于每一层,首先执行多头注意力机制,
# 然后将结果与原始输入相加(残差连接),
# 接着通过前馈网络,并再次与输入相加(另一个残差连接)。
# 这样的设计有助于避免在深层网络中发生梯度消失或爆炸的问题。
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
x = attn(x) + x 实现了一个基本的残差连接,其中 attn(x) 代表对输入 x 应用了注意力机制,然后将这个变换后的输出与原始输入 x 直接相加。简而言之,它让数据经过一个变换层后,再将变换前的原数据加回来,以此保持信息的完整性并帮助梯度流通,这样做有助于模型学习和训练过程的稳定性。
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):#dim_head: 每个头在处理数据时的维度
super().__init__()
inner_dim = dim_head * heads #计算得到的内部维度,为每个头的维度乘以头的数量
project_out = not (heads == 1 and dim_head == dim)
#决定是否需要对输出进行线性变换的布尔值。
# 如果使用多头并且每头的维度与输入维度不同,则需要进行变换。
self.heads = heads
self.scale = dim_head ** -0.5#缩放因子,
# 用于调整query和key相乘的结果,通常用于稳定训练过程
self.attend = nn.Softmax(dim=-1)#一个Softmax函数,用于计算注意力权重
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
#一个线性变换,用于将输入转换为query、key和value
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
#如果需要对输出进行项目化,则使用一个线性变换加上dropout;如果不需要,则使用恒等变换。
def forward(self, x):
b, n, _, h = *x.shape, self.heads
#输入 x 的维度被拆分为批次大小 b、序列长度 n,
# 并且考虑了头的数量 h
qkv = self.to_qkv(x).chunk(3, dim=-1)#通过 to_qkv
# 线性变换将 x 映射到query、key和value上,并且这三者被均等地分割
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
#使用 rearrange 函数,将query、key和value重排为多头注意力所需的形状。
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
#计算query和key的点乘,然后乘以缩放因子 scale
attn = self.attend(dots)
#计算query和key的点乘,然后乘以缩放因子 scale
out = einsum('b h i j, b h j d -> b h i d', attn, v)
#使用点乘将注意力权重与value相乘,获取加权的value,即注意力机制的输出。
out = rearrange(out, 'b h n d -> b n (h d)')
#将多头输出重排回原始的形状,并通过 to_out 进行可能的线性变换和dropout。
return self.to_out(out)
将输入映射到query、key和value,计算注意力权重,并使用这些权重来加权value,最终可能对输出进行进一步的线性变换。这种机制是Transformer架构的核心部分,用于捕捉序列内的长距离依赖关系。
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
#hidden_dim: 中间层的维度,通常大于输入输出维度,以允许网络捕获更复杂的特征。
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),#一个线性层,将输入维度 dim 映射到隐藏维度 hidden_dim
nn.GELU(),#GELU激活函数,用于引入非线性,是Transformer模型常用的激活函数之一
nn.Dropout(dropout),#Dropout层,用于减少过拟合风险
nn.Linear(hidden_dim, dim),#另一个线性层,将隐藏层维度 hidden_dim 映射回原始输入维度 dim
nn.Dropout(dropout)#又一个Dropout层,增加模型的泛化能力
)
def forward(self, x):
return self.net(x)#输入 x 通过 self.net 顺序模型进行前向传播。
# self.net 中的层依次处理数据:
# 首先是线性变换到更高的维度,
# 然后是非线性激活,
# 接着是dropout正则化,
# 再是线性变换回原维度,
# 最后再次应用dropout。
前馈网络模块是Transformer内部的一个标准组成部分,用于在多头注意力层之后进一步处理数据。它通常包含两个线性层和一个非线性激活函数,以及用于正则化的dropout层。通过这种方式,前馈网络可以在保持输入和输出维度不变的同时,通过较高维度的隐藏层捕获更复杂的特征和关系。
class PreNorm(nn.Module):#实现了一个预归一化(Pre-Normalization)的结构
def __init__(self, dim, fn):
#dim: 这个参数指定了归一化层期望的特征维度。在Transformer中,这通常是嵌入向量的维度。
#fn: 一个函数或者是一个神经网络模块,它代表了在归一化后将要应用的操作。
# 这可以是任何类型的模块,比如多头注意力模块、前馈神经网络模块等。
super().__init__()
self.norm = nn.LayerNorm(dim)#使用了 nn.LayerNorm 来创建一个层归一化实例
#层归一化是通过对输入特征的每个样本在特定维度(这里是 dim)上进行归一化,
#以确保网络的训练更加稳定和快速。
self.fn = fn#这里存储了传入的函数或模块,以便之后在前向传播中使用。
def forward(self, x, **kwargs):#输入 x 和任何额外的关键字参数 **kwargs
return self.fn(self.norm(x), **kwargs)
#self.norm(x): 首先,输入 x 通过层归一化,
# 这有助于减少训练过程中的内部协变量偏移,提高训练稳定性。
# self.fn(...): 然后,归一化后的数据被传递给 fn 函数/模块进行进一步处理。
# 这里的 **kwargs 使得可以灵活地传递额外的参数给 fn。
PreNorm 类实现了预归一化的模式,这个模式通过在执行主要操作(如注意力或前馈网络)之前先对输入进行归一化,帮助改进了模型的训练动态。这种方法与后归一化(Post-Normalization)形成对比,后者在执行主要操作后进行归一化。预归一化在一些情况下被发现可以提高Transformer模型的性能和稳定性。
完整模型代码
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
# pair 将输入参数 t 转换成一个元组,
# 如果参数本身已经是元组类型,则保持不变,
# 否则将参数重复两次组成一个新的元组返回
return t if isinstance(t, tuple) else (t, t)
# classes
#PreNorm 类实现了预归一化的模式,这个模式通过在执行主要操作(如注意力或前馈网络)之前
# 先对输入进行归一化,帮助改进了模型的训练动态。
# 这种方法与后归一化(Post-Normalization)形成对比,后者在执行主要操作后进行归一化。
# 预归一化在一些情况下被发现可以提高Transformer模型的性能和稳定性。
class PreNorm(nn.Module):#实现了一个预归一化(Pre-Normalization)的结构
def __init__(self, dim, fn):
#dim: 这个参数指定了归一化层期望的特征维度。在Transformer中,这通常是嵌入向量的维度。
#fn: 一个函数或者是一个神经网络模块,它代表了在归一化后将要应用的操作。
# 这可以是任何类型的模块,比如多头注意力模块、前馈神经网络模块等。
super().__init__()
self.norm = nn.LayerNorm(dim)#使用了 nn.LayerNorm 来创建一个层归一化实例
#层归一化是通过对输入特征的每个样本在特定维度(这里是 dim)上进行归一化,
#以确保网络的训练更加稳定和快速。
self.fn = fn#这里存储了传入的函数或模块,以便之后在前向传播中使用。
def forward(self, x, **kwargs):#输入 x 和任何额外的关键字参数 **kwargs
return self.fn(self.norm(x), **kwargs)
#self.norm(x): 首先,输入 x 通过层归一化,
# 这有助于减少训练过程中的内部协变量偏移,提高训练稳定性。
# self.fn(...): 然后,归一化后的数据被传递给 fn 函数/模块进行进一步处理。
# 这里的 **kwargs 使得可以灵活地传递额外的参数给 fn。
#前馈网络模块是Transformer内部的一个标准组成部分,
# 用于在多头注意力层之后进一步处理数据。
# 它通常包含两个线性层和一个非线性激活函数,以及用于正则化的dropout层。
# 通过这种方式,前馈网络可以在保持输入和输出维度不变的同时,通过较高维度的隐藏层捕获更复杂的特征和关系。
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
#hidden_dim: 中间层的维度,通常大于输入输出维度,以允许网络捕获更复杂的特征。
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),#一个线性层,将输入维度 dim 映射到隐藏维度 hidden_dim
nn.GELU(),#GELU激活函数,用于引入非线性,是Transformer模型常用的激活函数之一
nn.Dropout(dropout),#Dropout层,用于减少过拟合风险
nn.Linear(hidden_dim, dim),#另一个线性层,将隐藏层维度 hidden_dim 映射回原始输入维度 dim
nn.Dropout(dropout)#又一个Dropout层,增加模型的泛化能力
)
def forward(self, x):
return self.net(x)#输入 x 通过 self.net 顺序模型进行前向传播。
# self.net 中的层依次处理数据:
# 首先是线性变换到更高的维度,
# 然后是非线性激活,
# 接着是dropout正则化,
# 再是线性变换回原维度,
# 最后再次应用dropout。
#将输入映射到query、key和value,
# 计算注意力权重,
# 并使用这些权重来加权value,
# 最终可能对输出进行进一步的线性变换。
# 这种机制是Transformer架构的核心部分,用于捕捉序列内的长距离依赖关系。
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):#dim_head: 每个头在处理数据时的维度
super().__init__()
inner_dim = dim_head * heads #计算得到的内部维度,为每个头的维度乘以头的数量
project_out = not (heads == 1 and dim_head == dim)
#决定是否需要对输出进行线性变换的布尔值。
# 如果使用多头并且每头的维度与输入维度不同,则需要进行变换。
self.heads = heads
self.scale = dim_head ** -0.5#缩放因子,
# 用于调整query和key相乘的结果,通常用于稳定训练过程
self.attend = nn.Softmax(dim=-1)#一个Softmax函数,用于计算注意力权重
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
#一个线性变换,用于将输入转换为query、key和value
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
#如果需要对输出进行项目化,则使用一个线性变换加上dropout;如果不需要,则使用恒等变换。
def forward(self, x):
b, n, _, h = *x.shape, self.heads
#输入 x 的维度被拆分为批次大小 b、序列长度 n,
# 并且考虑了头的数量 h
qkv = self.to_qkv(x).chunk(3, dim=-1)#通过 to_qkv
# 线性变换将 x 映射到query、key和value上,并且这三者被均等地分割
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
#使用 rearrange 函数,将query、key和value重排为多头注意力所需的形状。
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
#计算query和key的点乘,然后乘以缩放因子 scale
attn = self.attend(dots)
#计算query和key的点乘,然后乘以缩放因子 scale
out = einsum('b h i j, b h j d -> b h i d', attn, v)
#使用点乘将注意力权重与value相乘,获取加权的value,即注意力机制的输出。
out = rearrange(out, 'b h n d -> b n (h d)')
#假设有八个head 就有八个attention map
#将多头输出重排回原始的形状,并通过 to_out 进行可能的线性变换和dropout。
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])#初始化一个空的模块列表,用于存储Transformer的各个层
for _ in range(depth):#这个循环用于添加depth数量的Transformer层到self.layers中
self.layers.append(nn.ModuleList([ #每一层都包含以下两部分:
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
#这是带有预归一化(PreNorm)的多头注意力(Attention)模块。
#预归一化是一种常见的技术,用于在进入注意力机制之前对输入进行归一化,有助于稳定和加速训练。
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
#这是带有预归一化的前馈网络(FeedForward)模块。
# 前馈网络是Transformer层中另一个关键组成部分,它在多头注意力模块之后进一步处理数据。
]))
def forward(self, x): #输入x会依次通过所有的Transformer层
#对于每一层,首先执行多头注意力机制,
# 然后将结果与原始输入相加(残差连接),
# 接着通过前馈网络,并再次与输入相加(另一个残差连接)。
# 这样的设计有助于避免在深层网络中发生梯度消失或爆炸的问题。
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
dim_head=64, dropout=0., emb_dropout=0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity() #它不会对输入进行任何处理,只是将输入作为输出返回
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
print("img:", img.shape)
x = self.to_patch_embedding(img)
# print("x:", x.shape)
b, n, _ = x.shape
# print("self.cls_token:", self.cls_token.shape)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)#将 self.cls_token 这个张量在批量维度上重复 b 次,
# 构造出一个形状为 (b, n, d) 的新张量,其中 n 是重复的次数,d 是 self.cls_token 的维度。
#一个图片有一个cls_token
# 这通常用于为每个样本添加一个类别标记(通常用于 Transformer 架构中)
print("cls_token:", cls_tokens.shape)
x = torch.cat((cls_tokens, x), dim=1)
print("x:", x.shape)
# x += self.pos_embedding
# [:, :(n + 1)][:, :(n + 1)] 选择了该张量的第一个维度(通常是批量维度)的所有元素,
# 并且在第二个维度(通常是补丁数量维度)上选择从索引 0 到索引 n 的部分。
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
import torch
#from vit_pytorch import ViT
v = ViT(
image_size = 256,#图像大小,以像素为单位。
# 如果图像是矩形的,则应该选择宽度和高度中较大的那个作为图像大小。
patch_size = 32,# 补丁大小,指图像被切分成的小块大小。
# 图像大小必须能被补丁大小整除。补丁的数量由 (image_size // patch_size) ** 2 计算得到,同时这个数量必须大于 16。
num_classes = 1000,#分类的类别数量,即模型需要将图像分为多少个类别。
dim = 1024,# 线性转换后输出张量的最后一个维度大小,
# 通常用于指定 nn.Linear(..., dim) 中的输出维度。
depth = 12,# Transformer 模块的堆叠层数,
# 即模型中包含多少个 Transformer 块
heads =12,#多头注意力(Multi-head Attention)层中的头的数量,
# 用于增加模型对不同特征的关注度。
mlp_dim = 3072,#MLP(全连接前馈)层的维度大小,用于提取特征。
dropout = 0.1,#在模型训练过程中随机丢弃神经元的比例,用于防止过拟合。
# 取值范围为 [0, 1],表示丢弃的比例。
emb_dropout = 0.1#嵌入(Embedding)层的 dropout 比例,
# 用于在输入嵌入时进行随机丢弃。
)
img = torch.randn(1, 3, 256, 256)
print(v)
preds = v(img) # (1, 1000)
print(preds.shape)
运行结果(将depth修改为2的结果)
ViT(
(to_patch_embedding): Sequential(
(0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
(1): Linear(in_features=3072, out_features=1024, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
(transformer): Transformer(
(layers): ModuleList(
(0): ModuleList(
(0): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(attend): Softmax(dim=-1)
(to_qkv): Linear(in_features=1024, out_features=2304, bias=False)
(to_out): Sequential(
(0): Linear(in_features=768, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=3072, bias=True)
(1): GELU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=1024, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
)
(1): ModuleList(
(0): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(attend): Softmax(dim=-1)
(to_qkv): Linear(in_features=1024, out_features=2304, bias=False)
(to_out): Sequential(
(0): Linear(in_features=768, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=3072, bias=True)
(1): GELU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=1024, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
(to_latent): Identity()
(mlp_head): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1000, bias=True)
)
)
img: torch.Size([1, 3, 256, 256])
cls_token: torch.Size([1, 1, 1024])
x: torch.Size([1, 65, 1024])
torch.Size([1, 1000])