Vision Transformer模型架构详解

🎀个人主页: https://zhangxiaoshu.blog.csdn.net
📢欢迎大家:关注🔍+点赞👍+评论📝+收藏⭐️,如有错误敬请指正!
💕未来很长,值得我们全力奔赴更美好的生活!

前言

2019年开始,自然语言处理(NLP)领域抛弃了循环神经网络(RNN)序列依赖的问题,开始采用Attention is All you need的Transformer结构[1],其中的Attention是一种可以让模型专注于重要的信息并能够充分学习和吸收的技术。在NLP领域中,伴随着各种语言Transformer模型的提出使得多项语言处理任务的精度和模型深度开始飞速提升。由于基于Transformer的预训练语言模型非常成功,研究者开始探索其在视觉领域的应用。2020年10月,Google创新性的设计了用于分类的Vision Transformer模型[2]—ViT。此后视觉Transformer模型的研究进入了快车道,本文主要对Vision Transformer模型架构进行详细介绍以及在pytorch中的使用方法进行介绍。


文章目录

  • 前言
  • 一、Vision Transformer模型架构
    • 1. Embedding层结构
    • 2. Transformer Encoder结构
      • (1)层归一化(Layer Norm)
      • (2)多头注意力机制(Multi-Head Attention)
      • (3)Dropout/DropPath
      • (4)MLP Block
    • 3. MLP Head结构
  • 二、PyTorch实现
    • 1. 首先安装vit-pytorch库:
    • 2.导入进行调用:
  • 总结


一、Vision Transformer模型架构

下图是原论文中作者给出的关于Vision Transformer的模型总体框架图:

在这里插入图片描述
从图中可以看出,Vision Transformer模型主要由三部分组成:第一部分为Linear Projection of Flattened Patches,也被称为Embedding层,主要用于将输入的图片数据转化为适合Transformer结构处理的形式。第二部分为Transformer Encoder部分,它是整个ViT模型的核心板块,在图右侧给出了更加详细的结构,它主要由层归一化(Layer Norm)、多头注意力机制(Multi-Head Attention)、Dropout/DropPath、MLP Block四部分组成用于学习输入图像数据的特征。第三部分为MLP Head,它是最终用于分类的层结构。下面本设计将对每一个组成部分进行一个详细介绍。

1. Embedding层结构

在视觉Transform模型中,其Transformer Encoder模块的输入形式是一个向量(token)序列,即一个二维矩阵[num_token, token_dim]的形式,如上图所示,输入的粉色小块token0-9对应的都是向量序列。

但是,图像处理和语言处理不一样,它的数据格式和Transformer Encoder输入格式是不一样,而是一个三维矩阵[H, W, C]的形式。所以在视觉Transform模型中首先加入了一个Embedding层结构用于将数据变化为向量序列。其主要过程为:首先将输入的图片形式数据按照模型定义的切割大小切割成多个小块(Patches),然后将切割的小块通过维度变化映射成向量形式。以常见的ViT-B/16为例,它首先将输入图片( 224 × 224 224\times224 224×224)按照 16 × 16 16\times16 16×16的大小进行切分得到196个Patches,接着通过线性映射将每一个Patches(16, 16, 3)映射成一个长度为768的向量。

在具体实现代码时,可以通过一个卷积层和Flatten层来直接实现。以ViT-B/16为例,如图所示其卷积层的参数为:卷积核大小是16x16、步距是16、卷积核的个数是768。数据通过卷积层后维度从(224, 224, 3)变化为(14, 14, 768),接着,将H和W两个维度展平即Flatten操作即可变化为(196, 768)这样的二维矩阵形式,这正是Transformer Encoder的输入格式。
在这里插入图片描述
除了将输入数据的形式变化为Transformer Encoder的输入格式,模型还在输入Transformer Encoder之前加入了[class]token以及Position Embedding,如下图所示。[class]token是参考了BERT所设计的,它是一个可以学习的参数,用于拼接到tokens中专门用于图像数据的分类。以ViT-B/16为例,就是让一个768长度的向量,与从Flatten层输出的数据拼接在一起,即,Cat((1, 768),(196, 768))—>(197, 768)。Position Embedding也是一个可以学习的参数。它是直接叠加在tokens上的(Add),因为对于图像数据而言,每一块和每一块在都有一定的位置依赖关系,所以Position Embedding主要用于表达Patches之间的位置关系。以ViT-B/16为例,就是让一个(197, 768)的向量与之前得到的(197, 768)向量相加。
在这里插入图片描述

2. Transformer Encoder结构

Transformer Encoder其实就是将Encoder Block 重复堆叠L次, Encoder Block结构图如下图2.4所示,主要由层归一化(Layer Norm)多头注意力机制(Multi-Head Attention)Dropout/DropPathMLP Block四部分组成。

(1)层归一化(Layer Norm)

层归一化(Layer Norm):这是一种主要针对NLP领域提出的归一化方法,这里是对每个token进行归一化处理。目前的归一化层主要有BN、LN、IN、GN和SN五种方法,它解决了深度神经网络内部协方差偏移问题,是一种将深度神经网络之间的数据进行归一化的算法,使得深度学习的训练过程中梯度变化趋于稳定,从而使网络在训练时达到快速收敛的目的。将输入的图像shape记为[N, C, H, W],这些方法的主要不同之处是,BatchNorm是在Batch上进行的,对NHW做归一化,对于较小的Batch Size没有太大的作用;LayerNorm是在通道方向上进行的,对CHW归一化,对RNN有很大的作用;InstanceNorm是在图像的像素上进行的,对HW做归一化,主要用在风格化迁移等方面;GroupNorm首先将Channel进行分组,然后再做归一化;SwitchableNorm是将BN、LN、IN结合并给予权重,让网络自己去学习归一化层应当使用的方法。

*有关BN、LN、IN、GN归一化方法的详细介绍可以看我这篇文章:神经网络常用归一化和正则化方法解析(一);

在这里插入图片描述
Layer Norm即层归一化针对神经网络的某一层的所有输入按照以下公式进行归一化操作:

H H H是某一层中隐藏结点的数量, l l l表示层数,可以计算得到Layer Norm的归一化统计量 μ l \mu^l μl σ l \sigma^l σl,如下式:

μ l = 1 H ∑ i = 1 H a i l \mu^l=\frac{1}{H}\sum_{i=1}^{H}a_i^l μl=H1i=1Hail

σ l = 1 H ∑ i = 1 H ( a l − μ l ) 2 \sigma^l=\sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a^l-\mu^l\right)^2} σl=H1i=1H(alμl)2

其中 a l a^l al表示一个中间输出结果的总和。上面的统计量和样本数没有关系,而是和隐藏层的结点数有关,我们甚至可以使 Batch Size = 1。于是,我们可以根据约定的统计量进行归一化处理,

a ^ l = a l − μ l ( σ l ) 2 + ε {\hat{a}}^l=\frac{a^l-\mu^l}{\sqrt{\left(\sigma^l\right)^2+\varepsilon}} a^l=(σl)2+ε alμl

同样,在Layer Norm中常使用参数增益(gain)和偏置(bias)这两个参数来保障归一化操作不会破坏之前的信息,同BatchNorm中的 γ \gamma γ β \beta β

y i = γ a ^ l + β y_i=\gamma{\hat{a}}^l+\beta yi=γa^l+β

从以上公式可以看到, LN中同层神经元输入拥有相同的均值和方差,不同的输入样本有不同的均值和方差。所以,LN与Batch的大小无关,也不取决于输入Sequence的深度,所以可以在batchsize为1和RNN中对边长的输入Sequence进行Normalize操作。

(2)多头注意力机制(Multi-Head Attention)

多头注意力机制(Multi-Head Attention):通过多个注意力机制的并行组合,将独立的注意力输出串联起来,预期维度得到线性地转化。直观看来,多个注意头允许对序列的不同部分进行注意力运算

对于Self-Attention来说,假设输入的token长度为 L L L,则输入为 [ x 1 , x 2 . . . x L , ] [x_1,x_2...x_L,] [x1,x2...xL,],然后分别将 x 1 x 2 . . . x L x_1x_2...x_L x1x2...xL分别通过三个变化矩阵 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv(这三个参数是可训练的、共享的)得到对应的 q i , k i , v i , q^i,k^i,v^i, qi,ki,vi, 并将 q , k , v q,k,v q,k,v向量序列记为 Q , K , V Q,K,V Q,K,V。计算过程如下式所示,具体实现时使用全连接层实现。

( Q , K , V ) = ( q i , k i , v i ) = x i ∙ ( W q , W k , W v ) (Q,K,V)=\left(q^i,k^i,v^i\right)=x_i\bullet\left(W_q,W_k,W_v\right) (Q,K,V)=(qi,ki,vi)=xi(Wq,Wk,Wv)

其中 i = 1 , 2... L i=1,2...L i=1,2...L q q q表示query,后续会去和每一个k进行匹配, k k k代表key,后续会被每个 q q q匹配, v v v代表从 x x x中提取得到的信息value,后续 q q q k k k匹配的过程可以理解成计算两者的相关性,相关性越大对应 v v v的权重也就越大。

接着将 Q Q Q中的每一个 q i q^i qi去和 K K K中的每一个 k j k^j kj进行匹配,即点积操作。然后再除以 L \sqrt L L 得到对应的 α i , j \alpha_{i,j} αi,j,这样做的目的是进行点乘后的数值很大,导致通过Softmax后梯度变的很小,所以通过除以 L \sqrt L L 来进行缩放。具体计算过程如下式所示。

α i , j = q i ( k j ) T L \alpha_{i,j}=\frac{q^i\left(k^j\right)^T}{\sqrt L} αi,j=L qi(kj)T

α i , j \alpha_{i,j} αi,j表示 x i x_i xi x j x_j xj注意程度,然后对每一行分别进行Softmax处理得到 a ^ \hat{a} a^,相当于 x j x_j xj x i x_i xi权重,即对于 v v v的权重。具体计算过程如下式所示。

a ^ i , j = S o f t m a x ( α i , j ) {\hat{a}}_{i,j}=Softmax(α_{i,j}) a^i,j=Softmax(αij)

上面已经计算得到 a ^ i , j {\hat{a}}_{i,j} a^i,j,即针对每个 v v v的权重,接着进行加权得到最终结果,如下式所示。

b i = ∑ j = 1 L a ^ i , j × v j b^i=\sum_{j=1}^{L}{{\hat{a}}_{i,j}\times v^j} bi=j=1La^i,j×vj

其中 b i b^i bi表示 x i x_i xi经过Self-Attention后的结果。以上四式的过程习惯上用以下式来统一表示。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q ( K ) T L ) V Attention(Q,K,V)=Softmax\left(\frac{Q\left(K\right)^T}{\sqrt L}\right)V Attention(Q,K,V)=Softmax(L Q(K)T)V

对于Multi-Head Attention来说, 使用多头注意力机制能够联合来自不同head部分学习到的信息。首先根据使用的head的数目 h h h W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv权值矩阵均分成 h h h份,即 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV, 其中 i = 1 , 2... h i=1,2...h i=1,2...h,然后还是和Self-Attention模块一样将 x i x_i xi分别通过变化矩阵 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV得到对应head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi, 接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。如下式所示。

h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) {head}_i=Attention\left(QW_i^Q,KW_i^K,VW_i^V\right) headi=Attention(QWiQ,KWiK,VWiV)

其中 Q W i Q QW_i^Q QWiQ同前式相比多了一个 W i Q W_i^Q WiQ,表示这里是根据划分的变化矩阵去计算每一个head的结果。即通过 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV映射得到每个head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi,然后计算结果。
最后将每个head得到的结果进行concat拼接,接着将拼接后的结果通 过 W o 过W^o Wo(可学习的参数)进行融合,融合后得到最终的结果 b i b^i bi。如式(2-11)所示。

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , h e a d 1 … h e a d h ) W o MultiHead\left(Q,K,V\right)=Concat\left({head}_1,{head}_1\ldots{head}_h\right)W^o MultiHead(Q,K,V)=Concat(head1,head1headh)Wo

(3)Dropout/DropPath

Dropout/DropPath:在学习深度学习模型时,由于模型的参数过多、样本数量过少,导致了模型的过度拟合。在神经网络的训练中,常常会碰到一些问题。该方法具有较低的训练数据损失,具有较高的训练准确率。但是,测试数据的损失函数比较大,导致预测的准确性不高。

Dropout能在一定程度上减轻过度拟合,并能在某种程度上实现正规化。其基本原理是:在前向传播前进的过程中,使一个神经元的激活值以 p的概率不能工作,这在下面的图中可以看到。停止工作的神经元用虚线表示,与该神经元相连的相应传播过程将不在存在。这使得模型更加一般化,因为它不会依赖于一些局部特征。

DropPath类似于Dropout,不同的是Dropout 是对神经元随机“失效”,而DropPath是随机“失效”模型中的多分支结构。例如如下图右图所示,若 x x x为输入的张量,其通道为[B,C,H,W],那么DropPath的含义为一个Batch_size中,在经过多分支结构时,随机有drop_prob的样本,不经过主干,而直接经过分支(图中虚线)进行恒等映射。这在一定程度上使模型泛化性更强。
在这里插入图片描述

(4)MLP Block

MLP Block:如前文中Transformer Encoder结构图右侧所示,MLP Block由全连接层、GELU激活函数、Dropout组成,以ViT-B/16为例,第一个全连接层会把输入节点个数翻4倍(197, 768)—> (197, 3072),第二个全连接层会还原回原节点个数(197, 3072)—> (197, 768)。

3. MLP Head结构

通过Transformer Encoder后输出的维度和输入的维度是保持不变的,以ViT-B/16为例,输入的是(197, 768)输出的还是(197, 768)。这里只需要从[class]token抽取生成的对应结果,即从(197, 768)中抽取出[class]token对应的(1, 768),即为需要的分类信息。然后就可以用 MLP Head进行最后的分类得到结果。原论文中提到,在训练ImageNet21K时MLP Head是由全连接层+tanh激活函数+全连接层组成。但是如果是在ImageNet1K或者自己的数据集上时,只需要使用一个全连接层(Linear)即可,其结构如下图所示。
在这里插入图片描述

二、PyTorch实现

ViT模型共有三个不同的规模,如下所示:
。

1. 首先安装vit-pytorch库:

$ pip install vit-pytorch

2.导入进行调用:

import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
 
imgs = torch.randn(1, 3, 224, 224)
 
preds = model(imgs) # (1, 1000)

总结

以上就是对Vision Transformer模型架构的详细介绍及其适用,Vision Transformer模型作为第一个将Transformer结构应用到计算机视觉上的模型,对近年来计算机视觉的研究具有很大的意义,其常常与swin Transformer(可以理解为FPN结构的ViT)用作其他任务如检测、分割的backbone以及视觉特征提取器。

参考:
Attention is all you need
An image is worth 16x16 words: Transformers for image recognition at scale

文中图片大多来自论文和网络,如有侵权,联系删除,文中有不对的地方欢迎指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/238788.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

搭建商城系统的构架如何选择?

近期有很多网友在csdn、gitee、知乎的评论区留言,搭建商城系统是选择单体架构还是微服务架构,这里先说结论,如果是纯电商的话,商城系统的架构建议选择单体架构。我们分析下微服务和单体架构的优劣势,就知道了。 一、什…

C语言——结构体

一、结构的基础知识 结构是一些值的集合,这些值称为成员变量结构的,每个成员可以是不同类型的变量。 二、结构的声明 struct tag {member-list; }variable-list; 描述一个学生: typedef struct Student {char name[20]; //姓名int age; …

Linux安装MySQL数据库系统

1、MySQL的编译安装。 1.1、准备工作 (1)为了避免发生端口冲突、程序冲突等现象,建议先查询MySQL软件的安装情况,确认没有使用以RPM方式安装的mysql-server、mysql软件包,否则建议将其卸载。 [rootlocalhost ~]# rp…

关系型数据库-SQLite介绍

优点: 1>sqlite占用的内存和cpu资源较少 2>源代码开源,完全免费 3>检索速度上十几兆、几十兆的数据库sqlite很快,但是上G的时候最慢 4>管理简单,几乎无需管理。灵巧、快速和可靠性高 5>功能简…

【产品设计】软件系统三基座之一:权限管理

不同的员工在公司享有不同的权限,用户可以访问而且只能访问自己被授权的资源。那么,权限管理功能要如何设计呢? 软件系统三基座包含:权限管理、组织架构、用户管理。 何为基座,即是有了这些基础,任一相关的…

边缘计算系统设计与实践

随着科技的飞速发展,物联网和人工智能两大领域的不断突破,我们看到了一种新型的计算模型——边缘计算的崛起。这种计算模型在处理大规模数据、实现实时响应和降低延迟需求方面,展现出了巨大的潜力。本文将深入探讨边缘计算系统的设计原理和实…

13、RockerMQ消息类型之广播与集群消息

RocketMq中提供两种消费模式:集群模式和广播模式。 集群模式 集群模式表示同一个消息会被同一个消费组中的消费者消费一次,消息被负载均衡分配到同一个消费者上的多个实例上。 还有另外一种平均的算法是AllocateMessageQueueAveragelyByCircle&#xff…

windows下docker环境安装

开启硬件虚拟化技术 win10中开启 Hyper-V Win10 下是否开启硬件虚拟化技术,在控制面板,启用 window 功能,找到 Hyper-V 选项,点勾选确认。如图: Windows 11 家庭中文版新增 Hyper-V选项 注意以下的解决方案来自win1…

带你手把手 解读 firejail 沙盒源码(0.9.72版本)目录和组件 (一)

文章目录 关于firejail 的介绍src 目录每个文件夹(组件)的意义文件目录树 关于firejail 的介绍 Firejail 是一个用于 Linux 系统的安全工具,它通过创建轻量级的沙箱环境来运行应用程序。这种沙箱环境将应用程序与系统其余部分隔离&#xff0…

openEuler 20.03 (LTS-SP2) aarch64 cephadm 部署ceph18.2.0【5】 添加osd存储节点

接上篇 openEuler 20.03 (LTS-SP2) aarch64 cephadm 部署ceph18.2.0【1】离线部署 准备基础环境-CSDN博客 openEuler 20.03 (LTS-SP2) aarch64 cephadm 部署ceph18.2.0【2】离线部署 podman配置registries 部署registry私服 准备离线镜像-CSDN博客 openEuler 20.03 (LTS-SP2…

Python手撕kmeans源码

参考了两篇文章 K-Means及K-Means算法Python源码实现-CSDN博客 使用K-means算法进行聚类分析_kmeans聚类分析结果怎么看-CSDN博客 # 定义kmeans类 from copy import deepcopy from sklearn.datasets import make_blobs import numpy as np import matplotlib.pyplot as pltc…

如何充分准备面试,迅速融入团队并在工作中取得卓越成就

首先,关于如何筹备面试,首先需要对所申请公司与职位进行深入的调查了解,并依据可能提出的面试问题预先准备相应的答案,并提前调试面试所需的仪器设备。同时,也要注重自身形象的塑造。更为关键的是 1. 在计算机领域的面…

redis-学习笔记(Jedis)

自定义的 Redis 客户端 咱们可以实现编写出一个自定义的 Redis 客户端 因为 Redis 公开了自己使用的自定义协议 ---- RESP 协议清楚了, 那么通信数据格式就清除了, 就能完成各层次之间的数据传输, 就能开发服务器和客户端 RESP — Redis 的 序列化 协议 特点: 简单好实现快读进…

ETLCloud的应用策略——实时数据处理是关键

一、ETLCloud是什么? ETLCloud又称数据集成(DataOps),是RestCloud旗下的一款数据仓库管理工具,通过自动化数据转换和集成来实现企业内部和外部数据的无缝对接,从而帮助企业快速获取准确的数据信息&#xff…

活动预告 | 微盟技术沙龙 - Elasticsearch 在微盟的实践 12/21/2023

微盟技术沙龙 「微盟技术沙龙」是由微盟研发中心发起并联合各方小伙伴为开发者举办的系列技术沙龙,从用户,产品,技术等方面与开发者进行交流。 微盟技术沙龙关注开发者在实际应用中遇到的问题。提供最真实的干货,以技术会友&…

【SpringBoot】从入门到精通的快速开发指南

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的专栏《SpringBoot》。🎯🎯 &…

[MySQL]SQL优化之sql语句优化

🌈键盘敲烂,年薪30万🌈 目录 一、索引优化 回顾: 📕索引分类: 📕索引失效: 📕设计原则: 📕SQL性能分析 二、SQL优化 语句优化 &#x1f4d…

Gorm 的关联查询

背景介绍 gorm 与 mybatis-plus 、hibernate 等 ORM 框架一样,为了应对查询场景居多的现象,支持原生 sql 和 api 两种方式读数据库。 gorm 原生 sql 参见:https://gorm.io/docs/sql_builder.html。 gorm 提供的 api 支持关联插入、关联查询…

Jmeter入门

一、下载jmeter 官网下载 下载之后解压,在目录/bin下面找到jmeter.bat双击之后即可启动Jmeter。 二、使用 如下左图,选择语言为中文,可以修改测试计划的名称。如下右图,添加线程组 添加线程组 添加http请求 路径传参方式 …

Linux——MySQL数据库系统

一、 MySQL的编译安装 1、准备工作 (1)为了避免发生端口冲突,程序冲突等现象,建议先查询MySQL软件的安装情况,确认没有使用以Rpm方式安装的mysql-server、mysql软件包,否则建议将其卸载 [rootlocalhost ~]…