Transformer中的Self-Attention和Multi-Head Attention

2017 Google 在Computation and Language发表

当时主要针对于自然语言处理(之前的RNN模型记忆长度有限且无法并行化,只有计算完ti时刻后的数据才能计算ti+1时刻的数据,但Transformer都可以做到)

文章提出Self-Attention概念,在此基础上提出Multi-Head Atterntion

下面借鉴霹雳吧啦博主的视频进行学习:


Self-Attention

假设输入的序列长度为2,输入就两个节点x1,x2,然后通过Input Embedding也就是图中的f(x)将输入映射到a1,a2。紧接着分别将a1,a2分别通过三个变换矩阵Wq,Wk,Wv(这三个参数是可训练的,是共享的)得到对应的q^{i},k^{i},v^{i}(直接使用全连接层实现)。

其中:

q代表query,后续会去和每一个k进行匹配

k代表key,后续会被每个q匹配

v代表从a中提取得到的信息

后续q和k匹配的过程可以理解成计算两者的相关性,相关性越大对应v的权重也越大。

假设a_{1}=(1,1),a_{2}=(1,0),W^{q}=\begin{pmatrix} 1,&1 \\ 0,&1 \end{pmatrix}

那么q^{1}=(1,1)\begin{pmatrix} 1, &1 \\ 0, & 1 \end{pmatrix}=(1,2), q^{2}=(1,0)\begin{pmatrix} 1, &1 \\ 0, & 1 \end{pmatrix}=(1,1)

因为Transformer是并行化的,可以直接写成:

\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}=\begin{pmatrix} 1, &1 \\ 1, &0 \end{pmatrix}\begin{pmatrix} 1, &1 \\ 0, &1 \end{pmatrix}=\begin{pmatrix} 1, &2 \\ 1, &1 \end{pmatrix}

同理可以得到\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix},那么求得的\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}就是原论文中的Q,\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}是K,\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix}是V。接着q^{1}和每个k进行match,点乘操作,接着除以\sqrt{d}得到对应的\alpha,其中d代表向量k^{i}的长度,除以\sqrt{d}的原因是在论文中的解释“进行点乘后数值很大,导致通过softmax后梯度变得很小”,所以通过\sqrt{d}进行缩放。

\alpha _{1,1}=\frac{q^{1}\cdot k^{1}}{\sqrt{d}}=\frac{1*1+2*0}{\sqrt{2}}=0.71\\ \alpha _{1,2}=\frac{q^{1}\cdot k^{2}}{\sqrt{d}}=\frac{1*0+2*1}{\sqrt{2}}=1.41

同理q^{2}去匹配所有的k能得到\alpha _{2,i},统一写成乘法矩阵形式:

\begin{pmatrix} \alpha _{1,1} & \alpha _{1,2} \\ \alpha _{2,1} & \alpha _{2,2} \end{pmatrix}=\frac{\begin{pmatrix} q^{1}\\ q^{2} \end{pmatrix}\begin{pmatrix} k^{1}\\ k^{2} \end{pmatrix}^{T}}{\sqrt{d}}

接着对每一行即(\alpha _{1,1},\alpha _{1,2}),(\alpha _{2,1},\alpha _{2,2})分别进行softmax处理得到(\widehat{\alpha} _{1,1},\widehat{\alpha} _{1,2}),(\widehat{\alpha} _{2,1},\widehat{\alpha} _{2,2}),这里的\widehat{\alpha }相当于计算得到针对每个v的权重。到这里完成了Attention(Q,K,V)公式中的softmax(\frac{QK^{T}}{\sqrt{d_{k}}})部分。

上面已经计算得到\alpha,即针对每个v的权重,接着进行加权得到最终结果

b_{1}=\widehat{\alpha }_{1,1}\times v^{1}+\widehat{\alpha }_{1,2}\times v^{2}=(0.33,0.67)\\ b_{2}=\widehat{\alpha }_{2,1}\times v^{1}+\widehat{\alpha }_{2,2}\times v^{2}=(0.50,0.50)

统一写成矩阵乘法形式:

\begin{pmatrix} b_{1}\\ b_{2} \end{pmatrix}=\begin{pmatrix} \widehat{\alpha }_{1,1} & \widehat{\alpha }_{1,2}\\ \widehat{\alpha }_{2,1}& \widehat{\alpha }_{2,2} \end{pmatrix}\begin{pmatrix} v^{1}\\ v^{2} \end{pmatrix}

Self-Attention的内容就结束了,总结下来就是论文中一个公式:

 Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V


Multi-Head Attention

多头注意力机制能联合来自不同head部分学习到的信息。

首先还是和Self-Attention模块一样将a_{i}分别通过W^{q},W^{k},W^{v}得到对应的q^{i},k^{i},v^{i},然后再根据使用的head的数目h进一步把得到的q^{i},k^{i},v^{i}均分成h份。比如下图中假设的h=2然后q^{1}拆分成q^{1,1},q^{1,2},那么q^{1,1}就属于head1,q^{1,2}属于head2。

论文中写的通过W_{i}^{Q},W_{i}^{K},W_{i}^{V}映射得到每个head的Q_{i},K_{i},V_{i}:

head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

其实简单的均分也可以将W_{i}^{Q},W_{i}^{K},W_{i}^{V}设置成对应值来实现均分,比如下图中的Q通过W_{1}^{Q}就能得到均分后的Q_{1}

通过上述方法就能得到每个headi对应的Q_{i},K_{i},V_{i}参数,接下来针对每个head使用Self-Atttention中相同的方法即可得到对应的结果。

Attention(Q_{i},K_{i},V_{i})=softmax(\frac{Q_{i}K_{i}^{T}}{\sqrt{d_{k}}})V_{i}

接着将每个head得到的结果进行concat拼接,比如下图中b1,1(head1得到的b1)和b1,2(head2得到的b1)拼接在一起,b2,1(head得到的b2)和b2,2(head得到的b2)拼接在一起。

接着将拼接后的结果通过W^{O}(可学习的参数)进行融合,如下图,融合后得到最终的结果b1,b2

到这,总结下来就是论文中的两个公式:

MultiHead(Q,K,V)=Concat(head_{1},...,heah_{h})W^{O}\\ where head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

import torch
from fvcore.nn import FlopCountAnalysis

def main():
    #Self-Attention
    a1 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=1)
    a1.proj = torch.nn.Identity() #removr Wo

    #Multi-Head Attention
    a2 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8)

    #[batch_szie,num_tokens,total_embed_dim]
    t = torch.rand(32, 1024, 512)

    flops1 = FlopCountAnalysis(a1, t)
    print("Self-Attention FLOPs:", flops1.total())

    flops2 = FlopCountAnalysis(a2, t)
    print("Multi-Head Attention FLOPs:",flops2.total())

if __name__ == '__main__':
    main()
Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 68719476736

其实两者FLOPs的差异只是在最后的W^{O}上,如果把Multi-Head Attentio的W^{O}也删除(即把a2的proj也设置成Identity),可以看出两者FLOPs是一样的:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 60129542144

Positional Encoding

刚才计算是没有考虑到位置信息的。假设在Self-Attention模块中,输入a1,a2,a3得到b1,b2,b3。对于a1而言,a2和a3离它都是一样近且没有先后顺序。假设将输入的顺序改为a1,a2,a3,对结果b1是没有任何影响的。下面是Pytorch的实验,首先使用nn.MultiheadAttention创建一个Self-Attention模块(num_heads=1),注意这里在正向传播过程中直接传入QKV,接着创建两个顺序不同的QKV变量t1和t2(主要是将q2,k2,v2和q3,k3,v3的顺序换了下),分别将这两个变量输入Self-Attention模块进行正向传播。

import torch
import torch.nn as nn

m = nn.MultiheadAttention(embed_dim=2, num_heads=1)

t1 = [[[1., 2.], #q1,k1,v1
            [2., 3.], #q2,k2,v2
            [3., 4.]]] #q3,k3,v3

t2 = [[[1., 2.], #q1,k1,v1
            [3., 4.], #q3,k3,v3
            [2., 3.]]] #q2,k2,v2

q, k, v  = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result:\n", m(q, k, v))

q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2:\n", m(q, k , v))

即使调换了qkv顺序,但对b1是没有影响的。

为了引入位置信息,原论文引入了位置编码positional encoding。如下图所示,位置编码是直接加在输入的a={a1,...,an}中的,即pe={pe1,...,pen}和a={a1,...,an}拥有相同维度大小。关于位置编码在原论文有提出两种方案,一种是原论文中使用的固定编码,即论文中给出的sine and cosine funtions方法,按照该方法可计算出位置编码;另一种是可训练的位置编码。ViT论文中使用的是可训练的位置编码。positional encoding


超参对比

关于Transformer中的一些超参数的实验对比可以参考原论文,其中:

N表示重复堆叠的Transformer Block的次数

dmodel表示Multi-Head Self-Attention输入输出的token维度(向量长度)

dff表示在MLP(feed forward)中隐层的节点个数

h表示Multi-Head Self-Attention中的head的个数

dk,dv表示Multi-Head Self-Attention 中每个head的key(K)以及query(Q)的维度

Pdrop表示dropout层的drop_rate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/729780.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeYOLO 用于目标检测的新型可扩展和高效CNN架构 | 最新轻量化SOTA! 5GFLOP下无对手!

本改进已集成到 YOLOv8-Magic 框架。 论文地址:https://arxiv.org/pdf/2406.14239 代码地址:https://github.com/LilianHollard/LeYOLO/tree/main 在深度神经网络中,计算效率对于目标检测至关重要,尤其是在新型模型更倾向于速度而非计算效率(浮点运算次数,FLOP)的情况下…

Transformer1--self attention

目录 一、 Vector set as 输入二、 模型输出(三种)1 **n-to-n**2 n-to-13 n-to-m 三、self-attention1、问题引入2、self-attention3 self-attention 原理介绍 一、 Vector set as 输入 一段声音讯号: 图结构(graph)…

店员顾客起纠纷?EasyCVR+AI视频监控管理平台,助力连锁门店安全运营

近日,某品牌咖啡店店员与顾客起冲突登上了新闻热搜,一时间引发大量关注。随着门店完整的监控视频录像公开,大家才了解事情的原委,而并非网传的那样。 随着社会的进步和科技的发展,视频监控已成为各行各业不可或缺的安全…

红军九大技战法

一、动态对抗,线上社工持续信息追踪 发起攻击前,发起攻击前,尽可能多的搜集攻击目标信息,做到知己知彼,直击目标最脆弱的地方。攻击者搜集关于目标组织的人员信息、组织架构、网络资产、技术框架及安全措施信息&#x…

一分钟了解中小企业数字化转型如何进行?「建议收藏」

关于“中小企业数字化转型方法论”,其实网上已经有不少文章给出了一些方式方法,那么这里我再系统性的讲解一下。 一、中小企业为什么要实现数字化转型 首先要知道,中小企业为什么要实现数字化转型?当前,世界经济数字化…

社区团购系统智慧门店物流配送系统开发,支持小程序公众号。

目录 前言: 一、为什么要做社区团购小程序? 二、怎么做一个社区团购小程序? 三、制作属于自己的社区团购小程序有什么好处? 总结: 前言: 社区团购是针对小区居民或群体开发的在线购物平台,…

深入解析 Python dataclass:类属性与类方法解释

文章目录 dataclass实例属性和类属性自动设置属性 实例方法静态方法(staticmethod)和 类方法(classmethod)静态方法类方法 dataclass dataclass 是 Python 3.7 引入的一个装饰器,用于简化类的定义。 使用 dataclass …

AcWing 1801:蹄子剪刀布 ← 模拟题

【题目来源】https://www.acwing.com/problem/content/1803/【题目描述】 你可能听说过“石头剪刀布”的游戏。 这个游戏在牛当中同样流行,它们称之为“蹄子剪刀布”。 游戏的规则非常简单,两头牛相互对抗,数到三之后各出一个表示蹄子&#x…

玩玩大模型:总结归纳可以,策划创新拉垮

最近身边的人都在研究大模型。太深入的理解不了,有一些人会讲讲promt提示,学了几招。 比如: #角色 你是一个美食博主 #条件 我只有xxx元,在xxx.... #任务 找一家好吃的当地特色餐馆... 多试几次,有些结果很有参考价值…

函数栈帧的创建和销毁,带动图详细解析,带你大致分析汇编代码

目录 1.什么是函数栈帧 2.理解函数栈帧有什么用? 3.函数栈帧的创建和销毁解析 3.1什么是栈? 3.2 认识相关寄存器和汇编指令 3.3函数栈帧的创建和销毁解析过程 3.4函数的调用 3.5汇编代码 3.5.1函数栈帧的创建 3.5.2main函数部分 3.5.3Add函数…

策略模式编程

接口定义&#xff1a; public interface ProcessParserStrategy { List<ProcessInfo> parser(String osType, String processInfo); String getApp(); } public interface ConfigParserStrategy { List<ConfigInfo> parser(String configInfo); String getConfigT…

谷歌Chrome浏览器排查js内存溢出

1. 打开谷歌浏览器检查台 2. 点击memory 3. 点击开始快照录制&#xff0c;时隔一会儿录一次&#xff0c;多录几次 4. 进行快照对比

vue+element-plus完美实现跨境电商商城网站

目录 一、项目介绍 二、项目截图 1.项目结构图 2.首页 3.中英文样式切换 4.金钱类型切换 5.商品详情 6.购物车 7.登录 ​编辑 8.注册 9.个人中心 三、源码实现 1.项目依赖package.json 2.项目启动 3.购物车页面 四、总结 一、项目介绍 本项目在线预览&am…

[网络安全产品]---EDR

写在前面 前端时间看抖音&#xff0c;刷到周鸿祎介绍360为什么这么厉害&#xff0c;他提到一点就是360是全球第一个提出云查杀概念的公司&#xff0c;相比较传统的基于病毒特征库终端杀毒&#xff0c;360依托积累的庞大的信息数据能有效应对APT攻击。 然后又特意找了一下云查…

世界是软件定义的 - 正如硬件公司所证明的那样

很难相信&#xff0c;马克安德森&#xff08;Marc Andressen&#xff09;在13年前写下了他著名的博客&#xff0c;题为“软件正在吞噬世界”。在这篇文章中&#xff0c;他谈到了现代软件组织对传统企业造成的破坏。 十三年后&#xff0c;即使面对英伟达的平流层估值&#xff0…

openGauss开发者大会、华为云HDC大会举行; PostgreSQL中国技术大会7月杭州开启

重要更新 1. openGauss Developer Day本周五于北京举行&#xff0c;大会聚集了相关行业专家、用户、伙伴和开发者&#xff0c;分享给予openGauss的联合创新成果和实践案例。([2] ) &#xff1b;华为云 HDC 2024本周五于东莞松山湖举行&#xff0c;主题演讲主要覆盖鸿蒙、AI ([3…

IntelliJ IDEA 2024 mac/win版:编程利器,智慧之选

IntelliJ IDEA 2024是一款由JetBrains精心打造的集成开发环境(IDE)&#xff0c;专为Java等编程语言量身打造&#xff0c;同时支持多种其他语言&#xff0c;为开发者提供了卓越的开发体验。 IntelliJ IDEA 2024 mac/win版获取 这款IDE凭借其出色的智能化和高效性&#xff0c;赢…

【Python高级编程】新手小白必须得学会的文本文件操作,资料资源均可分享!

文件读取处理 使用 read()&#xff1a; # 使用 read 方法读取文件的所有内容 with open(resources/training_log.txt, r) as file:content file.read()print(content)# 报错处理版本 # 使用 read 方法读取文件的所有内容 # 使用 utf-8 编码方式打开文件 with open(resources…

车载模块负载基础认识

车载模块负载是指车辆上的各种电子设备和系统&#xff0c;如导航系统、音响系统、空调系统、安全气囊等。这些设备和系统在车辆运行过程中需要消耗一定的电能&#xff0c;以保证其正常工作。车载模块负载的基础认识主要包括以下几个方面&#xff1a; 1. 负载类型&#xff1a;车…

GaussDB关键技术原理:高性能(一)

引言 对数据库性能进行优化是令人激动的&#xff0c;无论是对其进行性能需求分析、性能需求设计、性能问题定个位都是富于变化又充满挑战的工作&#xff0c;本章围绕“数据库性能”进行全面系统化的介绍&#xff0c;首先从数据库在现代软件栈中所处的位置出发&#xff0c;介绍…