Visual Transformer (ViT)模型详解

1 Vit简介

1.1 Vit的由来

ViT是2020年Google团队提出的将Transformer应用在图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(scalable,模型越大效果越好),成为了transformer在CV领域应用的里程碑著作,也引爆了后续相关研究。

论文地址:https://arxiv.org/pdf/2010.11929.pdf

Visual Transformer (ViT) 出自于论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在视觉领域的开篇之作。ViT模型是基于Transformer Encoder模型的。

1.2 Vit如何工作

我们知道Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。但是视觉领域处理的是图像数据,因此将Transformer模型应用到图像数据上面临着诸多挑战,理由如下:

  • 与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。

  • 如果按照处理文本的方式来处理图像,即逐像素处理的话,即使是目前的硬件条件也很难。

  • Transformer缺少CNNs的归纳偏差,比如平移不变性和局部受限感受野。

  • CNNs是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNNs更大。

  • Transformer无法直接用于处理基于网格的数据,比如图像数据。

为了解决上述问题,Google的研究团队提出了ViT模型,它的本质其实也很简单,既然Transformer只能处理序列数据,那么我们就把图像数据转换成序列数据就可以了呗。下面来看下ViT是如何做的。

1.3 ViT模型架构

我们先结合下面的动图来粗略地分析一下ViT的工作流程,如下:

  • 将一张图片分成patches

  • 将patches铺平

  • 将铺平后的patches的线性映射到更低维的空间

  • 添加位置embedding编码信息

  • 将图像序列数据送入标准Transformer encoder中去

  • 在较大的数据集上预训练

  • 在下游数据集上微调用于图像分类

 模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder
  • MLP Head(最终用于分类的层结构)

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。

对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

​​
MLP Head

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。


2 ViT工作原理

我们将上图展示的过程近一步分解为6步,接下来一步一步地来解析它的原理。如下图:

2.1 步骤1、将图片转换成patches序列

这一步很关键,为了让Transformer能够处理图像数据,第一步必须先将图像数据转换成序列数据,但是怎么做呢?假如我们有一张图片x\epsilon R^{H*W*C},patch大小为p,那么我们可以创建N个图像patches,可以表示为x_{p}\epsilon R^{N*(p^{2}c)},其中N=\frac{HW}{P^{2}}N就是序列的长度,类似一个句子中单词的个数。在上面的图中,可以看到图片被分为了9个patches。

2.2 步骤2、将Patches铺平

在原论文中,作者选用的patch大小为16,那么一个patch的shape为(3,16,16),维度为3,将它铺平之后大小为3x16x16=768。即一个patch变为长度为768的向量。不过这看起来还是有点大,此时可以使用加一个Linear transformation,即添加一个线性映射层,将patch的维度映射到我们指定的embedding的维度,这样就和NLP中的词向量类似了。

2.3 步骤3、添加Position embedding

与CNNs不同,此时模型并不知道序列数据中的patches的位置信息。所以这些patches必须先追加一个位置信息,也就是图中的带数字的向量。实验表明,不同的位置编码embedding对最终的结果影响不大,在Transformer原论文中使用的是固定位置编码,在ViT中使用的可学习的位置embedding 向量,将它们加到对应的输出patch embeddings上。

2.4 步骤4、添加class token

在输入到Transformer Encoder之前,还需要添加一个特殊的class token,这一点主要是借鉴了BERT模型。添加这个class token的目的是因为,ViT模型将这个class token在Transformer Encoder的输出当做是模型对输入图片的编码特征,用于后续输入MLP模块中与图片label进行loss计算。

2.5 步骤5、输入Transformer Encoder

将patch embedding和class token拼接起来输入标准的Transformer Encoder中。 Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

2.6 步骤6、分类

注意Transformer Encoder的输出其实也是一个序列,但是在ViT模型中只使用了class token的输出,将其送入MLP模块中,去输出最终的分类结果。

3 模型搭建参数

在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16的外还有32x32的。

其中:

Layers就是Transformer Encoder中重复堆叠Encoder Block的次数 L。
Hidden Size就是对应通过Embedding层(Patch Embedding + Class Embedding + Position Embedding)后每个token的dim(序列向量的长度)
MLP Size是Transformer Encoder中MLP Block第一个全连接的节点个数(是token长度的4倍)
Heads代表Transformer中Multi-Head Attention的heads数。
 

4 结果分析

上表是论文用来对比ViT,Resnet(和刚刚讲的一样,使用的卷积层和Norm层都进行了修改)以及Hybrid模型的效果。通过对比可得出结论:

  • 在训练epoch较少时Hybrid优于ViT -> Epoch小选Hybrid
  • 当epoch增大后ViT优于Hybrid -> Epoch大选ViT
     

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/283041.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

strlen的自定义函数的三种实现方法

文章目录 一,概要二,题目三,方法1,方法一2,方法二3,方法三 四,总结 一,概要 在求字符串长度的时候我们会遇到求解字符串的长度,此时我们想到的是strlen,头文件…

如何跨系统构建docker镜像?

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 1. 前言 docker镜像有基于amd64系统的,也有基于arm64系统的。 前段时间用了一个在x86_64的centos7服务器上构建的i…

SpringBoot自动配置原理和自定义启动器

1、自动配置的原理 项目在加载上下文时,会根据SpringBootApplication注解运行。该注解中有一个CompoentScan注解,会扫描和加载当前启动类所在的目录,以及所有的子目录;还有一个是EnableAutoConfiguration注解,这个注解…

huggingface的tokenizer解读

文章目录 前言一、huggingface的tokenizer含义1、含义2、整体概括 二、加载lmsys/vicuna-7b-v1.5模型的tokenizer三、调用tokernizer方法四、字符串的tokens应用1、tokenizer应用2、tokenizer进行token分词(tokenizer.tokenize)3、tokens转模型识别ids(tokenizer.convert_token…

Centos7:Jenkins+gitlab+node项目启动(2)

Centos7:Jenkinsgitlabnode项目启动(1) Centos7:Jenkinsgitlabnode项目启动(1)-CSDN博客 Centos7:Jenkinsgitlabnode项目启动(2) Centos7:Jenkinsgitlabnode项目启动(2)-CSDN博客 Centos7:Jenkinsgitlabnode项目启…

数据结构【线性表篇】(三)

数据结构【线性表篇】(三) 文章目录 数据结构【线性表篇】(三)前言为什么突然想学算法了?为什么选择码蹄集作为刷题软件? 目录一、双链表二、循环链表三、静态链表 结语 前言 为什么突然想学算法了? > 用较为“官方…

回溯法求不等式的所有整数解

这份代码本来是用来解决这个问题的 但是&#xff0c;修改之后即可用来解决任意多个xi组成的满足不等式的整数解 这里用真代码而不是伪代码来表示 源代码&#xff1a; #include<iostream> using namespace std; const int N1010; int p,q,r,goal,n; int cnt,sum,MIN; i…

ES6之生成器(Generator)

✨ 专栏介绍 在现代Web开发中&#xff0c;JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性&#xff0c;还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言&#xff0c;JavaScript具有广泛的应用场景&#x…

挑战Python100题(9)

100+ Python challenging programming exercises 9 Question 81 Please write a program to randomly print a integer number between 7 and 15 inclusive. Hints: Use random.randrange() to a random integer in a given range. 请编写一个程序,随机打印一个介于7和15之间…

Java学习,一文掌握Java之SpringBoot框架学习文集(1)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

2021-06-25 51蛋骗鸡按键切合LED

缘由ISIS 7 Professional_有问必答-CSDN问答 #include "REG52.h" sbit K1 P3^0; sbit K2 P3^1; sbit K3 P3^2; sbit K4 P3^3; void main() {unsigned char Xd0,xz0,cs0;unsigned int wei0;P1255;while(1){if(K10&&Xd0){P10;while(K10);}if(K20&&…

【2023 CCF 大数据与计算智能大赛】基于TPU平台实现超分辨率重建模型部署 基于预训练ESPCN的轻量化图像超分辨率模型TPU部署方案

2023 CCF 大数据与计算智能大赛 《基于TPU平台实现超分辨率重建模型部署》 作品名&#xff1a;基于预训练ESPCN的轻量化图像超分辨率模型TPU部署方案 队伍名&#xff1a;Absofastlutely 蒋松儒 计算机科学与技术系 硕士 南京大学 中国-江苏 kahsoltqq.com 吕欢欢 计算…

AI产品经理 - 如何做一款软硬协同AI产品

【背景】从0做一款软硬协同的AI产品&#xff0c;以智能医药保温箱 1.以智能医药保温箱 2.调研定义市场方向 地点&#xff1a;医药、实验室 场景&#xff1a;长宽高/装箱/运输/实验室 3.需求挖掘 4.如何进行软硬件AI产品工作 软硬件产品设计&#xff1a;功能/硬件外观设计、…

《数据库开发实践》之存储过程【知识点罗列+例题演练】

一、什么是存储过程&#xff1f; 1.概念理解&#xff1a; 存储过程是一组为了完成特定功能的SQL语句集。通过组成SQL语句和控制语句&#xff0c;提供一种封装任务的方法。因此在创建编译好某个存储过程后&#xff0c;因为存储过程中有可执行操作的sql语句&#xff0c;用户可以…

OFDM——PAPR减小

文章目录 前言一、PAPR 减小二、MATLAB 仿真1、OFDM 信号的 CCDF①、MATLAB 源码②、仿真结果 2、单载波基带/通频带信号的 PAPR①、MATLAB 源码②、仿真结果 3、时域 OFDM 信号和幅度分布①、MATLAB 源码②、仿真结果 4、Chu 序列和 IEEE802.16e 前导的 PAPR①、MATLAB 源码②…

模型 KANO卡诺模型

本系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知。需求分析。 1 卡诺模型的应用 1.1 餐厅需求分析故事 假设你经营一家餐厅&#xff0c;你想了解客户对你的服务质量的满意度。你可以使用卡诺模型来收集客户的反馈&#xff0c;并分析客户的…

MySQL的日志管理以及备份和恢复

MySQL日志管理 mysql的日志默认保存位置为/usr/local/mysql/data vim /etc/my.cnf #开启二进制日志功能 vim /etc/my.cnf [mysqld]##错误日志&#xff0c;用来记录当MySQL启动、停止或运行时发生的错误信息&#xff0c;默认已开启 log-error/usr/local/mysql/data/mysql_…

Python从入门到网络爬虫、自动化

可以创建C、C#、Python、Golang、Java、React、Node、Vue、PHP项目 创建Java项目 创建Python项目 简单if……else……语句 # 简单的if……else……语句 state True if state:print("状态正常") else:print("状态异常")# 复杂的if……elif……语句 score …

基于 LangChain + GLM搭建知识本地库

一种利用 langchain 思想实现的基于本地知识库的问答应用&#xff0c;目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。 受GanymedeNil的项目document.ai和AlexZhangji创建的ChatGLM-6B Pull Request启发&#xff0c;建立了全流程可使用开源模…

【Linux C | 文件I/O】文件数据的同步 | sysc、fsync 和 fdatasync 函数

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…