图解Vit 2:Vision Transformer——视觉问题中的注意力机制

文章目录

  • Patch Embedding 回顾
  • Seq2Seq中的attention
  • Transformer中的attention

Patch Embedding 回顾

上节回顾在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

Seq2Seq中的attention

在Transformer之前的RNN,其实已经用到了注意力机制。Seq2Seq。
在这里插入图片描述
对于Original RNN,每个RNN的输入,都是对应一个输出。对于original RNN,他的输入和输出必须是一样的。
在处理不是一对一的问题时,提出了RNN Seq2Seq。也就是在前面先输入整体,然后再依次把对应的输出出来。
在这里插入图片描述
虽然Seq2Seq解决了输入和输出不定是相同长度的问题,但是我们所有信息都存在模型的一定地方,我们叫上下文,或者叫hidden state。又由于输入的都是同一个模型,每次都更新同一个位置,那么当我们的句子很长,或者是一个段落时,可能这个上下文就不会work,因为我们decoder的所有信息都是来自上下文的。效果不够好。把很多信息输入,就是encoder。后面把上下文信息解析出来,就是decoder。
很多学者想了办法去改进它。希望把前面时间段的信息,传递给解码decoder的时候。如下图所示。
在这里插入图片描述
我们除了在encoder的部分传递h,还会多存一份p,直接传给decoder。
RNN每次都是传递一个h。h是隐变量,高层的语义信息。c就是attention。它等于前面所有时间点的语义信息分别乘以a,再sum。c看到了前面的所有时间点,如果是一个句子,就是看到了句子里的所有token。c看到了h1-hn。那它看到的谁更重要,是a1-an控制的。那么a应该怎么设置呢?最好的办法是,是让a可学习。即通过大量的句子数据训练一个网络,让c1明白,他应该更关注前面句子里的哪个token,哪个token的a就是最大的。

Transformer中的attention

上述是RNN中的attention机制,下面来论述attention在Transformer中是如何工作的。

x1-x3都是image token,也就是patch embedding后的token特征。他们首先会做一个projection。这里使用了神经网络,或升维,或降维。得到Vector v。v和权重a相乘再相加,得到了attention c。注意,这里的c1是给x1做的attention。
在这里插入图片描述
我们又在x的这里另开了一个网络,对x进行另一个网络的projection Projk。得到另一个feature vector k1。他和v1可能维度不同,也可以相同。
x,k,v都是feature vector。a是通过两个k的点积得到的。这里a是一个scalar数值。k1会分别和k1,k2,k3进行点击,得到a1,a2,a3,a称为attention weights。注意里面的Projk是同一个,并且是可学习的。Projk可学习,就相当于a是可学习的,也相当于c是可学习的。
在这里插入图片描述
现在我们多出来一个Wq分支。q即query,也就是查询。q和k做点积,和上面讲的k与k做点积并没有不同,也就是我们不再通过k去做点积,而是通过q。query的作用就是去查询与key的相关性。比如q1,当它和其它k1,k2,k3点积加权得到attention c1时,c1就是表示x1与其它x的相关性。上述方法也就是让query和key进行了分离,key作索引功能,query作查询功能。给谁算attention,就用谁的query点积别人的key(包括自己的)。
在这里插入图片描述
在这里插入图片描述
这里x是vetor,就是一个patch feature,projection其实也是embedding,即提取特征。所以W的列其实就是embedd_dim(本质即卷积操作)。我们的attention是针对每一个token的。x1,x2,x3就是每一个单词token,或者图像的patch。Wq,Wk,Wv参数是不同的,他们都是可学习的。
在这里插入图片描述
p是attention weight。attention是表达,是token通过Transformer计算出来的,一个feature vector和其它vector的相关性,就是通过p表达的。

在这里插入图片描述
我们为什么要除以dk。variance方差表示数据的离散程度。如果variance值很大时,对于softmax,他会更偏向更大的那个值。如果variance更小,softmax波动就没有那么大。为了避免softmax在更大值上,我们需要把variance拉回来一点,让我们的attention更稳定一点,不能只盯着一个人看,让注意力更均衡一点,雨露均沾。这里是softmax写错了,要写最外面。

也就是我们输入多少个token,我们输出的attention z还是多少个。以上内容就是全部的self-attention了。

下面我们再讲讲multi-head self-attention。
在这里插入图片描述
multi-head self-attention就是你看你的,我看我的。让不同的人去看相同的序列信息,qkv进行复制,但是他们的参数是不同的,然后最后集众家所长。大家最后统一一下意见。这个统一也是learnable的。z输出也是不变维度,还是n行(和token个数对应)。

下面说下如何进行高效的attention计算,用矩阵计算。
在这里插入图片描述
将Multi-Head Attention带回Vit整体结构中,如下图所示。
在这里插入图片描述
下面是Attention class的代码:

import paddle
import paddle.nn as nn

paddle.set_device('cpu')

class Attention(nn.Layer):
    def __init__(self, embed_dim, num_heads, qkv_bias, qk_scale, dropout=0., attention_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = int(embed_dim / num_heads)
        self.all_head_dim = self.head_dim * num_heads # 避免不能整除
        self.qkv = nn.Linear(embed_dim, 
                            self.all_head_dim * 3,
                            bias_attr=False if qkv_bias is False else None) # bias=None,在paddle里是默认给0
        self.scale = self.head_dim ** -0.5 if qk_scale is None else qk_scale
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(self.all_head_dim, embed_dim)

    def transpose_multi_head(self, x):
        new_shape = x.shape[:-1] + [self.num_heads, self.head_dim]
        x = x.reshape(new_shape) # [B, N, num_heads, head_dim]
        x = x.transpose([0, 2, 1, 3]) # [B, num_heads, N, head_dim]
        return x

    def forward(self, x):

        # [B, N, all_head_dim] * 3
        B, N, _ = x.shape
        qkv = self.qkv(x).chunk(3, -1) # [B, N, all_head_dim] * 3
        q, k, v = map(self.transpose_multi_head, qkv) # q,k,v: [B, num_heads, N, head_dim]

        attn = paddle.matmul(q, k, transpose_y=True) # q * k^t
        attn = self.scale * attn
        attn = self.softmax(attn) # [B, num_heads, N]
        attn_weight = attn
        # dropout
        # attn:[B, num_heads, N, N]
        

        out = paddle.matmul(attn, v) 
        out = out.transpose([0, 2, 1, 3]) # attn:[B, N, num_heads, head_dim]
        out = out.reshape([B, N, -1])
        out = self.proj(out)
        # dropout

        return out, attn_weight



def main():
    t = paddle.randn([4, 16, 96])
    print('input shape = ', t.shape)

    model = Attention(embed_dim=96, num_heads=8, 
                      qkv_bias=False, qk_scale=None, dropout=0., attention_dropout=0.)
    print(model)
    out, attn_weights = model(t)
    print(out.shape) # [4, 16, 96]
    print(attn_weights.shape) # [4, 8, 16, 16] 8是num_heads,8个人去看; N(num_patch)=16, 16个img token互相看


if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/40156.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

电路分析 day01 一种使能控制电路

本次分析的电路为 一种使能控制电路 (站在别人的肩膀上学习) 资料来源 : 洛阳隆盛科技有限责任公司的专利 申请号:CN202022418360.7 1.首先查看资料了解本次电路 1.1 电路名称: 一种使能控制电路 1.2 电路功能…

Midjourney助力交互设计师设计网站主页

Midjourney的一大核心优势是提供创意设计,这个功能也可以用在网站主页设计上,使用Midjourney prompt 应尽量简单,只需要以"web design for..." or "modern web design for..."开头即可 比如设计一个通用SAAS服务的初创企…

vue进阶-消息的订阅与发布

📖vue基础学习-组件 介绍了嵌套组件间父子组件通过 props 属性进行传参。子组件传递数据给父组件通过 $emit() 返回自定义事件,父组件调用自定义事件接收子组件返回参数。 📖vue进阶-vue-route 介绍了路由组件传参,两种方式&…

C#线性插值,三角插值

什么是插值? 是什么?很简单! 已知两点,推断中间的每一点的过程。 有什么用? 很简单!位置从30到40耗时3秒求每一时刻的位置! 1.线性插值 设v是结果,start是开始,end是结…

数据分析之Matplotlib

文章目录 1. 认识数据可视化和Matplotlib安装2. 类型目录3. 图标名称title4. 处理图形中的中文问题5. 设置坐标轴名称,字体大小,线条粗细6. 绘制多个线条和zorder叠加顺序7. 设置x轴刻度xticks和y轴刻度yticks8. 显示图表show9. 设置图例legend10. 显示线…

MQTT协议在物联网环境中的应用及代码实现解析(一)

MQTT协议全称是Message Queuing Telemetry Transport,翻译过来就是消息队列遥测传输协议,它是物联网常用的应用层协议,运行在TCP/IP中的应用层中,依赖TCP协议,因此它具有非常高的可靠性,同时它是基于TCP协议…

C# HTTP Error 500.19

解决办法&#xff1a; .vs configapplicationhost.config 修改<section name"windowsAuthenticationnurununoverrideModeDefault"Allow”/>

【SCI一区】【电动车】基于ADMM双层凸优化的燃料电池混合动力汽车研究(Matlab代码实现)

目录 &#x1f4a5;1 概述 1.2 电动车动力学方程 1.3 电池模型 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f308;4 Matlab代码、数据、文章讲解 &#x1f4a5;1 概述 文献来源&#xff1a; 随着车辆互联性的出现&#xff0c;互联汽车 (CVs) 在增强道路安全、改…

Lazygit贴合 neovim

功能性要比gitui 好用&#xff0c;vim 的键位习惯 > 嵌入式数据库 &#xff0c;python 的性能够用了 … … ,分析差异&#xff0c;选择 备份和升级

stm32 报错 dev_target_not_halted

烧录stm32H743&#xff0c;在cubeprogrammer里面点击connect&#xff0c;报错dev_target_not_halted 解决方法&#xff1a;先把H743的boot0引脚接到高电平上&#xff0c;然后少上电&#xff0c;此时会停止内核的运行&#xff0c;再点击connect即可 H743管脚&#xff1a; 在芯…

ES系列--es进阶

一、系统架构 一个运行中的 Elasticsearch 实例称为一个节点&#xff0c;而集群是由一个或者多个拥有相同 cluster.name 配置的节点组成&#xff0c; 它们共同承担数据和负载的压力。当有节点加入集群中或者 从集群中移除节点时&#xff0c;集群将会重新平均分布所有的数据。 …

「深度学习之优化算法」(十四)麻雀搜索算法

1. 麻雀搜索算法简介 (以下描述,均不是学术用语,仅供大家快乐的阅读)   麻雀搜索算法(sparrow search algorithm)是根据麻雀觅食并逃避捕食者的行为而提出的群智能优化算法。提出时间是2020年,相关的论文和研究还比较少,有可能还有一些正在发表中,受疫情影响需要论…

Kubernetes Volume及其类型(NFS、SAN) - PV - PVC - PV与PVC与Pod的关系

目录 volume 卷 官方文档&#xff1a;卷 | Kubernetes 一、emptyDir&#xff08;临时卷&#xff09; 二、hostPath卷 type字段参数 hostPath 实验&#xff1a; 三、第3方提供的存储卷&#xff08;百度云、阿里云、亚马逊云、谷歌云等&#xff09; 四、local卷 五、NF…

回归预测 | MATLAB实现Attention-GRU多输入单输出回归预测(注意力机制融合门控循环单元,TPA-GRU)

回归预测 | MATLAB实现Attention-GRU多输入单输出回归预测----注意力机制融合门控循环单元&#xff0c;即TPA-GRU&#xff0c;时间注意力机制结合门控循环单元 目录 回归预测 | MATLAB实现Attention-GRU多输入单输出回归预测----注意力机制融合门控循环单元&#xff0c;即TPA-G…

如何在无人机支持下完成自然灾害风险评估的原理和方法

对灾害的损失进行估算与测算是制定防灾、抗灾、救灾 及灾后重建方案的重要依据。 自然灾害评估按灾害客观地发展过程可分三种&#xff1a;一是灾前预评估&#xff0c;二是灾期跟踪或监测性评估&#xff0c;三是灾后实测评估。 灾前预评估要考虑三个因素&#xff0c;第一是未来…

真正的理解WPF中的TemplatedParent

童鞋们在WPF中经常看到 TemplatedParent ,或者经常看到下面的用法: {Binding RelativeSource={RelativeSource TemplatedParent}, Path=Content} 是不是看的一脸蒙圈? 先看官方文档: 意思是 和这个控件的 模板上的 父亲,如果这个控件不是模板创建的,那么这个值就…

Python爬虫学习笔记(三)————urllib

目录 1.使用urllib来获取百度首页的源码 2.下载网页图片视频 3.总结-1 4.请求对象的定制&#xff08;解决第一种反爬&#xff09; 5.编解码 &#xff08;1&#xff09;get请求方式&#xff1a;urllib.parse.quote&#xff08;&#xff09; &#xff08;2&#xff09;get请求…

使用 YOLOv8 和 Streamlit 构建实时对象检测和跟踪应用程序:第 1 部分-介绍和设置

示例:图像上的对象检测 介绍 实时视频中的目标检测和跟踪是计算机视觉的一个重要领域,在监控、汽车和机器人等各个领域都有广泛的应用。 由于需要能够识别和跟踪对象、确定其位置并对它们进行实时分类的自动化系统,对视频帧中的实时对象检测和跟踪的需求日益增加。 在这…

迁移 Gitee 仓库到 Github

Step1: 在Gitee找到你要迁移的仓库, 并复制 克隆|下载 链接 Step2: 打开 Github, 找到 按钮选择 Import Step3: 打开 Github, 找到 按钮选择 Import Step4: Waiting... 等待导入成功 Over~ 还有一种镜像更新的方案, Gitee 支持镜像同步, 但是我使用时无法获取到仓库名,…

初识vue3/setup/ ref()/ computed/watch/生命周期/父传子

创建项目先不着急学 main.js变了 新加setup reactive ref() computed watch 生命周期 父传子 子传父 ref/模板引用 暴露子组件属性 跨层传数据 defineOptions