【AI】计算机视觉VIT文章(Transformer)源码解析

论文:Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020

源码的Pytorch版:https://github.com/lucidrains/vit-pytorch

0.前言

Transformer提出后在NLP领域中取得了极好的效果,其全Attention的结构,不仅增强了特征提取能力,还保持了并行计算的特点,可以又快又好的完成NLP领域内几乎所有任务,极大地推动自然语言处理的发展。
VIT这篇文章就是将Transformer模型应用在了CV领域,它将图像处理成Transformer模型可以应用的形式,沿用NLP领域中Transformer的方法,直接验证了其精度可以和ResNet不相上下,展示了在计算机视觉中使用纯Transformer结构的可能,为Transformer在CV领域的应用打开了大门。

直接读文章通常比较抽象,英文的原文更能劝退一大部分人,但对于程序员来说,代码是通行于世界的语言,理解起来就比较简单,结合源码理解论文中的结构,就比较事半功倍。

在这里插入图片描述

上图是VIT文章中的结构,我们看图提问题,从数据的流向来看:图像怎么切分重排的?Linear Projection of Flattened Patches对图像作了什么,怎么让图像变成Transformer能够输入的格式?
Position Embedding是怎么做图像位置编码的,为什么会多出来一个0的位置编码?Transformer Encoder中的各个结构分别代表什么,是怎么实现的?输出的类别是什么?

1.图像是如何适配Transformer输入的?

Transformer的输入是一个一维的向量,而我们的图像是二维的,需要把图像拉伸成一维的,最简单的方式就是沿着x轴展开,将所有的行拼接在一起,也就是Flatten的操作,但是这样处理会导致向量维度比较大,而且同一张图片也只能生成一个Embedding,不能适配Multi-head Embedding(这种解释有点牵强,有点拿结果去解释原因)。

1.1 图像切分重排

如结构图所示,VIT采取了图像切分重排的方式,将一个完整的图像按照行列的方向切分成小块,然后再进行后续处理。切分重排是怎么实现的,我们看代码:

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), #图片切分重排
            nn.Linear(patch_dim, dim), # Linear Projection of Flattened Patches
        )

这里的Rearrange使用到了einops的库,相关的介绍可以查看文档
这里主要是对b c (h p1) (w p2) -> b (h w) (p1 p2 c)表达式的理解,b是batchsize,c是chanel,h和w是图像的高度和宽度,表达式前边的(h p1)表示将输入的h重新拆分成一个h*p1的向量,同理(w p2)是将输入的w拆分成w*p2的向量,这里的p1和p2是模型定义的patch_height和patch_width,可以理解为切分后小图的高和宽;
表达式右边表示输出向量的维度,(p1 p2 c)表示这3个数相乘,表示的是切分后一个小图的一维向量大小,(h w)则表示总共有多少个切分后的小图所生成的向量。这里的h和w的值跟输入的h和w值不同,表示的是原有h和w除以patch_height或patch_width的值,也就是在高和宽上各能切分出几个小图。
通过这样的处理,就将输入的c个channel的h*w的二维向量,转换成小图展开后的一维向量。

1.2 构建patch0

图像在输入Transformer之前,还连接了一个patch0,这一步的操作可以认为是延续了nlp的操作,在VIT这边的操作,可以认为是将分类类别拼接上去。

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)

1.3 Positional Embedding

为了在模型中引入位置信息,VIT引入了位置编码的形式,也就是图中的0、1、2、3…

x += self.pos_embedding[:, :(n + 1)]

然后将position_embedding与图像的Embedding相加。

2.Transformer Encoder中的各个结构分别代表什么,是怎么实现的?

结构中的Norm可以认为是归一化处理,MLP是多个全连接层,理解起来都比较简单。其实最主要的结构就是Multi-head Attention。
在这里插入图片描述

2.1 Attention 定义

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        #得到qkv
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

2.2 Attention 的理解

VIT不同于之前RNN的地方就是引入了Attention机制,其本质可以认为是相似度计算,是计算每一个输入值与其他输入值的相似度,然后带入到计算中,如图所示,q是输入的查询值,k是关键词,v是计算值,计算每一个q与其他k的相似度,然后再带入计算中去。
在这里插入图片描述

其对应的计算公式为:
在这里插入图片描述
相似度计算是用的点积的形式,除以根号下dk是为了抑制极端值,保证softmax之后数值不至于丧失梯度。对应的代码如下:

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

然后过一个softmax

attn = self.attend(dots) # self.attend = nn.Softmax(dim = -1)
attn = self.dropout(attn)

然后跟value值相乘

out = torch.matmul(attn, v)

2.3 Multihead Attention

多头注意力机制可以认为是定义多个Attention(每个attention关注的重点不同)分别来对数据进行处理,这里从源码中的循环结构可以体现出来:

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

3.使用示例

这里可以参考官方的readme文档

$ pip install vit-pytorch

使用示例

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/274182.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

QT、C++实验室管理系统

一、需求介绍: 题目:基于Qt的实验室管理系统的设计 项目命名以LabSystem姓名拼音首字母(例如: LabSystemwXC) 功能要求: 一,基本必要功能: 1,使用QSQLITE数据库完成数据库的设计。 2,注册功能:包含学生注册&#xff0…

鸿蒙项目二—— 注册和登录

此部分和上篇文章是连续剧 ,如果需要,请查看 一、注册 import http from ohos.net.http; Entry Component struct Reg {// 定义数据:State username: string "";State userpass: string "";State userpass2: string …

【网络技术】【Kali Linux】Wireshark嗅探(一)ping和ICMP

一、实验目的 本次实验使用wireshark流量分析工具进行网络嗅探,旨在了解ping命令的原理及过程。 二、网络环境设置 本系列实验均使用虚拟机完成,主机操作系统为Windows 11,虚拟化平台选择Oracle VM VirtualBox,组网模式选择“N…

python嵌套异常处理器

1 python嵌套异常处理器 python的异常处理器支持嵌套。 1.1 嵌套的try/except处理器 用法 def f1():raise E def f2():try:f1()except E:pass try:f2() except E:pass描述 嵌套的try/except处理器,发生异常时,控制权会跳回具有相符的except分句、最近…

SAP PP 配置学习(三)

Classification 分类 关联特征值 – (省市联动) 关联特征显示 一个特征是否输入,根据另一个特征来判断。如:只有输入了省份,才需要输入城市。没输省份前,城 市这个特征是不可见的。 修改【城市】特征. 在【城市】特征值中&#xf…

无人职守自动安装linux操作系统

无人职守自动安装linux操作系统 1. 大规模部署案例2. PXE 技术3. Kickstart 技术4. 配置安装服务器4.1 DHCP服务4.2 TFTP 服务4.3 NFS服务 5. 示例5.1 搭建server1. 启动dhcp并设为开机自启2. 设置并启动tftp3. 将客户端所需启动文件复制到TFTP服务器4. 创建Kickstart自动应答文…

C语言操作符逻辑与,逻辑或面试真题(2)

各位少年&#xff1a; 今天给大家分享几个代码示例&#xff0c;希望能帮助能从学习的方面&#xff0c;帮助大家。 #include<stdio.h> int main() { int i0,a0,b2,c3,d4; ia&&b&&d; printf("a%d\n b%d\n c%d\nd%d",a,b,c,d); return 0; } 大…

虚函数的讲解

文章目录 虚函数的声明与定义代码演示基类Person派生类Man派生类Woman 测试代码动态绑定静态绑定访问私有虚函数总结一下通过成员函数指针调用函数的方式 虚函数的声明与定义 虚函数存在于C的类、结构体等中&#xff0c;不能存在于全局函数中&#xff0c;只能作为成员函数存在…

Cucumber-JVM的示例和运行解析

Cucumber-JVM 是一个支持 Behavior-Driven Development (BDD) 的 Java 框架。在 BDD 中&#xff0c;可以编写可读的描述来表达软件功能的行为&#xff0c;而这些描述也可以作为自动化测试。 Cucumber-JVM 的最小化环境 Cucumber-JVM是BDD的框架&#xff0c; 提供了GWT语法的相…

Servlet获取前端请求的参数和中文乱码的解决方案

目录 1.Servlet获取前端请求的参数 1.1创建jsp 1.2构建servlet实例 1.3配置web.xml 2.中文乱码的解决方案 2.1请求时候的乱码问题 2.2响应时候中文乱码的问题 学好Servlet必须紧紧围绕着请求和响应这两个概念。 下面开始写在请求的时候前端带数据到servlet里面&#xff…

单词接龙[中等]

一、题目 字典wordList中从单词beginWord和endWord的 转换序列 是一个按下述规格形成的序列beginWord -> s1 -> s2 -> ... -> sk&#xff1a; 1、每一对相邻的单词只差一个字母。 2、对于1 < i < k时&#xff0c;每个si都在wordList中。注意&#xff0c;beg…

ubuntu 在线安装 python3 pip

ubuntu 在线安装 python3 pip 安装 python3 pip sudo apt -y install python3 python3-pip升级 pip python3 -m pip install --upgrade pip

【CVPR2023】可持续检测的Transformer用于增量对象检测

目录 导读 动机 本文贡献 相关工作/概念 本文方法 实验 实验结果 消融实验 可视化结果 结论 代码已开源&#xff1a;https://github.com/yaoyao-liu/CL-DETR 导读 本文旨在解决增量目标检测&#xff08;IOD&#xff09;问题&#xff0c;模型需要逐步学习新的目标类别…

数据驱动与数据安全,自动驾驶看得见的门槛和看不见的天花板

作者 |田水 编辑 |德新 尽管心理有所准备&#xff0c;2023年智能驾驶赛道的内卷程度还是超出了大多数人的预期。 这一年&#xff0c;汽车价格战突然开打&#xff0c;主机厂将来自销售终端的价格压力&#xff0c;传导到下游智驾供应商&#xff0c;于是&#xff0c;市面上出现…

功能测试知识超详细总结

一、测试项目启动与研读需求文档 &#xff08;一&#xff09; 组建测试团队 1、测试团队中的角色 2、测试团队的基本责任 尽早地发现软件程序、系统或产品中所有的问题。督促和协助开发人员尽快地解决程序中的缺陷。帮助项目管理人员制定合理的开发和测试计划。对缺陷进行跟…

Python+OpenCV 零基础学习笔记(1):anaconda+vscode+jupyter环境配置

文章目录 前言相关链接环境配置&#xff1a;AnacondaPython配置OpenCVOpencv-contrib:Opencv扩展 Notebook:python代码笔记vscode配置配置AnacondaJupyter文件导出 前言 作为一个C# 上位机&#xff0c;我认为上位机的终点就是机器视觉运动控制。最近学了会Halcon发现机器视觉还…

基于SpringBoot + Vue的图书管理系统的设计与实现

基于SpringBoot Vue的图书管理系统的设计与实现 目录 1 引言 1.1 编写目的 1.2 项目背景 1.3 参考资料 2 总体设计 2.1 需求概述 2.2 软件结构 3 模块设计 3.1 模块基本信息 3.2 功能概述 3.3 算法 3.4 模块处理逻辑 4 数据库设计 4.1 ER图表 4.…

网页乱码问题(edge浏览器)

网页乱码问题&#xff08;edge&#xff09; 文章目录 网页乱码问题&#xff08;edge&#xff09;前言一、网页乱码问题1.是什么&#xff1a;&#xff08;描述&#xff09;2.解决方法&#xff1a;&#xff08;针对edge浏览器&#xff09;&#xff08;1&#xff09;下载charset插…

【力扣题解】P404-左叶子之和-Java题解

&#x1f468;‍&#x1f4bb;博客主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 收录于专栏 【力扣题解】 文章目录 【力扣题解】P404-左叶子之和-Java题解&#x1f30f;题目描述&#x1f4a1;题解&#x1f30f;总结…

量子密码学简介

量子密码学&#xff08;英语&#xff1a;Quantum cryptography&#xff09;泛指利用量子力学的特性来加密的科学。量子密码学最著名的例子是量子密钥分发&#xff0c;而量子密钥分发提供了通信两方安全传递密钥的方法&#xff0c;且该方法的安全性可被信息论所证明。目前所使用…