Attention is all you need

这篇文章最大的亮点就是提出了一种Transformer的结构,是完全依赖注意力机制来刻画输入和输出之间的全局依赖关系,而不使用递归运算的RNN网络了。这样的好处就是第一可以有效的防止RNN存在的梯度消失的问题,第二是允许所有的字全部同时训练(RNN的训练是迭代的,一个接着一个的来,当前这个字过完,才可以进入下一个字),即训练并行,大大加快了计算效率。

Transformer使用了位置嵌入来理解语言的顺序,使用了多头注意力机制和全连接层等进行计算,还有跳远机制,LayerNorm机制,Encoder-Decoder架构等。

这篇文章我们主要讲下下面的Transformer,先看总体结构:

从这个结构的宏观角度上,我们可以看到Transformer模型也是用了Encoder-Decoder结构,编码器部分负责把自然语言序列映射成为隐藏层,含有自然语言序列的数学表达,然后解码器把隐藏层再映射为自然语言序列,从而使我们可以解决各种问题,比如情感分类、命名实体识别、语义关系抽取、机器翻译等等。

先讲下编码器部分的工作细节

1.编码器部分的工作细节

看上面结构我们发现编码器部分是由Nx个transformer block堆叠而成的,我们就拿一个transformer block来进一步观察,每一个transformer block又有两个子层,第一个是多头注意力部分,第二个是feed-forward部分。

我们输入句子:Why do we work?的时候,它的编码流程进一步细化:

1.首先输入进来之后,经过input embedding层每个字进行embedding编码,然后再编入位置信息(position encoding),形成带有位置信息的embedding编码。

2.然后进入多头注意力部分,这部分是多角度的self-attention部分,在里面每个字的信息会依据权重进行交换融合,这样每一个字会带上其他字的信息(信息多少依据权重决定),然后进入feed-forward部分进行进一步的计算,最后就会得到输入句子的数学表示了。

下面再详细说下每个部分的细节。

1.1 位置嵌入

由于transformer模型没有循环神经网络的迭代操作,所以我们必须提供每个字的位置信息给transformer,才能识别出语言中的顺序关系。

现在定义一个位置嵌入的概念,即positional encoding,位置嵌入的维度为[max sequence length, embedding dimension],嵌入的维度同词向量的维度,max sequence length属于超参数,指的是限定的最大单个句长。

注意,我们一般以字为单位训练transformer模型,也就是说我们不用分词了,首先我们要初始化字向量为[vocab size, embedding dimension], vocab size为总共的字库数量,embedding dimension为字向量的维度,也是每个字的数学表达。具体来看个例子。

这里论文里面使用了sin和cos函数的线性变换来提供给模型的位置信息

上式中pos指的是句中字的位置,取值范围是[0, max sequence length),i指的是词向量的维度,取值范围是[0, embedding dimension),上面有sin和cos一组公式,也就是对应着embedding dimension维度的一组奇数和偶数的序号的维度,例如0,1一组,2,3一组,分别用上面的sin和cos函数做处理,从而产生不同的周期性变化,而位置嵌入在embedding dimension维度上随着维度序号增大,周期变化会越来越慢,而产生一种包含位置信息的纹理,位置嵌入函数的周期从2Π到1000*2Π变化,而每个位置在embedding dimension维度上都会得到不同周期的sin和cos函数的取值组合,从而产生独一的纹理位置信息,模型从而学到位置之间的依赖关系和自然语言的时序特性。

还是拿例子举例,我们看看输入Why do we work?的位置编码怎么编码的?

可视化一下,最后得到这样的结果:

所以,会得到Why do we work这四个词的位置信息,然后embedding矩阵和位置矩阵的加和作为带有位置信息的新X,Xembedding_pos

这里再补充两个问题,第一个就是为啥要用这种方式编码呢?

作者这里这么设计的原因是考虑到NLP任务中,除了单词的绝对位置,单词的相对位置也非常重要,根据公式
sin(α+β)=sinαcosβ+cosαsinβ以及cos(α+β)=cosαcosβ-sinαsinβ,这表明位置k+p的位置向量可以表示为位置k的特征向量的
线性变换,为模型捕捉单词之间的相对位置关系提供了非常大的便利。

第二个问题,就是这里为啥单词embedding和位置embedding能直接相加呢?论文里面提到了维度相同,应该还有更深的原因吧。

维度相同是基础,但能相加的原因就是不同的位置,这个embedding肯定是不一样的,而对于词语来说,不同的词语,embedding肯定也是
不一样,那么这样相加,肯定能区分开词语和位置,这就类似于,每个位置one-hot编码,每个词one-hot编码,然后对应位置和对应词
one-hot相加,然后再取相应的embedding是一个道理。

1.2 多头注意力机制

这一步为了学到多重语意含义的表达,进行多头注意力机制的运算。我们先宏观看一下这个注意力机制到底在做什么?拿单头注意力机制举例:

左边的红框就是我们现在讲的部分,右图就是单头注意力机制做的事情,拿句子

The animal didn’t cross the street, because it was too tired.

我们看it这个词最后得到的R矩阵里面,就会表示出这个it到底是指的什么,可以看到R1和R2和it最相关,就可以认为it表示的是The animal.

也就是说,每个字经过映射之后都会对应一个R矩阵,这个R矩阵就是表示这个字与其他字之间某个角度上的关联性信息,这叫做单头注意力机制

下面看一下多头注意力宏观上到底干了什么事情:

左边这个是两头的注意力机制,上面说到这个橙色的这个注意力反映了it这个词指代的信息。而这个绿色的这个注意力,反映了it这个词的状态信息,可以看到it经过这个绿色的注意力机制后,tired这个词与it关联最大,就是说it,映射过去,会更关注tired这个词,因为这个正好是它的一个状态。它累了。

这样是不是就可以明白多头注意力的意义了,每个字经过多头注意力机制之后会得到一个R矩阵,这个R矩阵表示这个字与其他字在N个角度上(比如指代,状态...)的一个关联信息,这个角度就是用多个头的注意力矩阵体现的。这就是每个字多重语义的含义。

具体看一下其是怎么实现的。

我们的目标是把我们的输入Xembedding_pos通过多头注意力机制(系列线性变换)先得到Z。然后Z通过前馈神经网络得到R。这个R矩阵表示这个字与其他字在N个角度上(比如指代,状态...)的一个关联信息。

先看看怎么得到这个Z:在Xembedding_pos->Z的过程中到底发生了什么呢?

这就是整个过程的变换,首先Xembedding_pos会做三次线性变化得到Q,K,V三个矩阵,然后里面Attention机制,把Q,K,V三个矩阵进行运算,最后把Attention矩阵和Xembedding_pos加起来就是最后的Z。

可是为什么要这么做呢?Q,K,V又分别表示什么意思呢?

我们先说第二个问题,Q,K,V这三个矩阵分别是什么意思,Q表示query,K表示key,V表示Value.之所以引入了三个矩阵,是借鉴了搜索查询的思想,比如我们有一些信息是键值对(key->value)的形式存到了数据库,(5G->华为,4G->诺基亚),比如我们输入的query是5G,那么去搜索的时候,会对比一下Query和Key,把与Query最相似的Key对应的值返回给我们。这里同样的思想,我们最后想要的Attention,就是V的一个线性组合,只不过根据Q和K的相似性加了一个权重并softmax了一下而已,这里比较巧妙的是Q,K,V都是这个Xembedding_pos而已。下面具体来看一下:

上面图中有8个head,我们这里拿一个head来看一下做了什么事情:(注意这里head的个数一定要能够被embedding dimension整除才可以,上面的embedding dimension是512,head个数是8,那么每一个head的维度是(4,512/8)

怎么得到Q1和K1的相似度呢?我们想到了点积运算,点积运算的几何意义是两个向量越相似,他们的点积就越大,反而就越小。

我们看下Q1*K1的转置表达的是什么意思。

c1,c2,...,c6这些就代表我们的输入的每一个字,每一行代表每一个字的特征信息,那么Q1的c1行和K1转置的各个列做点积运算得到第一个字和其他几个字的相似度。这样最后的结果每一行表示的这个字和其他哪几个字比较相关,这个矩阵就是head1角度的注意力矩阵。自注意力的巧妙之处就在于这里,每个词向量两两之间内积,就能得到当前词与其他词的相似关系,有了相似关系,再通过softmax映射出权重,再把这个权重反乘到各自词语的embedding身上,再加权求和,就相当于融于了其他词的相关信息。

这里还有一个问题是QK^T除以了sqrt(dk)的操作,这个原因具体看参考博客的内容(自然语言处理之Attention大详解(Attention is all you need))

然后对每一行使用softmax归一化变成某个字与其他字的注意力的概率分布(使每一个字跟其他所有字的权重和为1).

这时候,我们从注意力矩阵取出一行(和为1),然后依次点乘V的列,因为矩阵V的每一行代表着每一个字向量的数学表达,这样操作,得到的正是注意力权重进行数学表达的加权线性组合,从而使每个字向量都含有当前句子的所有字向量的信息。这样就得到了新的X_attention(这个X_attention中每一个字都含有其他字的信息)。

用这个加上之前的Xembedding_pos得到残差连接,训练的时候可以使得梯度直接走捷径反传到最初层,不易消失。另外,这个用残差还能够保留原始的一些信息。

再经过一个LayerNormalization操作就可以得到Z。LayerNormalization的作用是把神经网络中隐藏层归一化为标准正态分布,起到加快训练速度,加快收敛的作用。

所以多头注意力机制系列总结起来就是下面这个图了:

注意,图里面有个地方表达错了,dk不是注意力的头数,而是拼接起来的那个最终维度,这里指的是512,另外就是,这里多个头直接拼接的操作,相当于默认了每个头或者说每个子空间的重要性是一样的,在每个子空间里面学习到的相似性的重要度是一样的。

1.3前馈神经网络

这一块就很简单了,我们上面通过多头注意力机制得到了Z,下面就是把Z再做两层线性变换,然后relu激活就得到了最后的R矩阵了。(相当于一个两层的神经网络)。

1.4 Layer Normalization和残差连接

1)残差连接:

我们在上一步得到了经过注意力矩阵加权之后的V,也就是Attention(Q,K,V),我们对它进行一下转置,使其和Xembedding维度一致,也就是[batch size, sequence length, embedding dimension],然后把他们加起来做残差连接,直接进行元素相加,因为他们的维度一致:

在之后的运算里,每经过一个模块的运算,都要把运算之前的值和运算之后的值相加,从而得到残差连接,训练的时候可以使梯度直接走捷径反传到最初始层。

2)LayerNorm

LayerNormalization的作用是把神经网络中隐藏层归一化为标准正态分布,也就是i.i.d独立同分布,以起到加快训练速度,加快收敛的作用。

然后用每一行的每一个元素减去这行的均值,再除以这行的标准差,从而得到归一化后的数值, � 是为了防止除0.

之后引入两个可训练参数α,β来弥补归一化过程中损失掉的信息,注意⊙表示元素相乘而不是点积,我们一般初始化α为全1,而β为全0.

所以一个Transformer编码块做的事情如下:

下面再说两个细节就可以把编码器的部分结束了。

  • 第一个细节就是上面只展示了一句话经过一个Transformer编码块之后的状态和维度,但我们实际工作中,不会只有一句话和一个Transformer编码块,所以对于输入来的维度一般是[batch_size, seq_len, embedding_dim],而编码块的个数一般也是多个,不过每一个的工作过程和上面一致,无非就是第一块的输出作为第二块的输入,然后再操作。论文里面是用的6个块进行的堆叠。
  • Attention Mask的问题。 因为如果有多句话的时候,句子都不一定一样长,而我们的seqlen肯定是以最长的那个为标准,不够长的句子一般用0来补充到最大长度,这个过程叫做padding.

但这时再进行softmax的时候就会产生问题。回顾softmax函数

e^0是1,是有值的,这样的话softmax中被padding的部分就参与了运算,就等于是让无效的部分参与了运算,会产生很大隐患,这时就需要做一个mask让这些无效区域不参与运算,我们一般给无效区域加一个很大的负数的偏置,也就是:

经过上面的masking我们使无效区域经过softmax计算之后还几乎为0,这样就避免了无效区域参与计算。Transformer里面有两种mask方式,分别是padding mask和sequence mask,上面这个就是padding mask,这种mask在scaled dot-product attention里面都需要用到,而sequence mask只有在Decoder的self-attention里面用到。

实际实现的时候,padding mask实际上是一个张量,每个值都是一个Boolean,值为false的地方就是我们要进行处理的地方(加负无穷的地方)

最后通过上面的梳理,我们解决了Transformer编码器部分,下面看看Transformer Encoder整体的计算过程。

1.字向量与位置编码:

X=EmbeddingLookup(X) + PostionalEncoding

X:[batch_size, seq_len, embed_dim]

2.自注意力机制:

Q=Linear(X)=XWq

K=Linear(X)=XWk

V=Linear(X)=XWv

Xattention=SelfAttention(Q,K,V)

3. 残差连接与Layer Normalization

Xattention = X + Xattention

4. Feedforword,其实就是两层线性映射并用激活函数激活,比如说Relu .

Xhidden = Activate(Linear(Linear(Xattention)))

5.重复3:

Xhidden = Xattention+Xhidden

Xhidden = LayerNorm(Xhidden)

Xhidden:[batch_size, seq_len, embed_dim]

6.这样一个Transformer编码块就执行完了,得到了X_hidden之后,就可以作为下一个Transformer编码块的输入,然后重复2-5执行,直到Nx个编码块。

好了,编码器部分结束,下面进入解码器部分

解码器部分先放上一张图,这边先不做解释。

解码器最终会输出一个实数向量。我们如何把浮点数变成一个单词,这个就是最终的线性变换和softmax层要做的事情

2. 最终的线性变换和softmax层

解码组件最后会输出一个实数向量。线性变换和softmax层将其变成一个单词。

线性变换层是一个简单的全连接神经网络,它可以把解码组件产生的向量投射到一个比它大得多的,被称作对数几率的向量里面。

不妨假设我们的模型从训练集中学习一万个不同的英语单词(我们模型的"输出词表")。因此对数几率向量为一万个单元格长度的向量--每个单元格对应每个单词的分数。

接下来的softmax层便会把那些分数变成概率(都为正数,上限1.0).概率最高的单元格被选中,并且它对应的单词被作为这个时间步的输出。

Transformer作为NLP领域的一大神器,还是值得多去做些了解的,先记录这些,其余再补充。放上参考博客。

自然语言处理之Attention大详解(Attention is all you need)

深度学习中的注意力机制

Self-Attention与Transformer

BERT大火却不懂Transformer?读这一篇就够了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/31677.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker 数据卷

1、什么是数据卷 通过镜像创建一个容器。容器一旦被销毁,则容器内的数据将一并被删除。但有些情况下,通过服务器上传的图片出会丢失。容器中的数据不是持久化状态的。这个时候可以通过数据卷来解决这个问题。 数据卷是一个可供一个或多个容器使用的特殊目…

解决不允许一个用户使用一个以上用户名与一个服务器或共享资源的多重连接的问题

问题概述: 用windows server 2012 r2 vl x64搭了个文件服务器,在使用时有个问题,老是用户登录有问题,提示“不允许一个用户使用一个以上用户名与一个服务器或共享资源的多重连接”。出现的原因不详,网上也没查到合理的…

typescript找不到模块‘vue‘ ‘vue-router‘

import { createRouter, createWebHashHistory, createWebHistory } from vue-router 提示:找不到模块“vue-router”。你的意思是要将 "moduleResolution" 选项设置为 "node",还是要将别名添加到 "paths" 选项中?ts(27…

【HISI IC萌新虚拟项目】ppu模块基于spyglass的lint清理环境搭建与lint清理

关于整个虚拟项目,请参考: 【HISI IC萌新虚拟项目】Package Process Unit项目全流程目录_尼德兰的喵的博客-CSDN博客 前言 RTL代码在交付给验证同时进行功能验证时,可以同步进行lint的清理工作。一般而言影响编译和仿真的问题会在vcs的error和warning中被修正,因此清理lin…

Windows下Nacos的配置与使用

一、什么是 Nacos 以下引用来自 nacos.io Nacos /nɑ:kəʊs/ 是 Dynamic Naming and Configuration Service的首字母简称,一个更易于构建云原生应用的动态服务发现、配置管理和服务管理平台。 Nacos 致力于帮助您发现、配置和管理微服务。Nacos 提供了一组简单易用…

【Rust】2、实战:文件、网络、时间、进程-线程-容器、内核、信号-中断-异常

文章目录 七、文件和存储7.2 serde 与 bincode 序列化7.3 实现一个 hexdump7.4 操作文件7.4.1 打开文件7.4.2 用 std::fs::Path 交互 7.5 基于 append 模式实现 kv数据库7.5.1 kv 模型7.5.2 命令行接口 7.6 前端代码7.6.1 用条件编译定制要编译的内容 7.7 核心:LIBA…

详解Spring Cloud版本问题

目录 1.让人头疼的多版本号体系 2.目录关系 3.为什么会有多个版本号体系 1.让人头疼的多版本号体系 由于历史原因,spring cloud分为了Alibaba和Netflix两个体系。 想要了解原因以及整个spring cloud体系的来龙去脉的同学可以去看我的另一篇文章: S…

信息系统之网络安全方案 — “3保1评”

信息系统之网络安全方案 — “3保1评” 序:什么是“3评1保”?一、网络安全等级保护1.1 概念1.2等保发展1.3法律要求1.4分级及工作流程 二、涉密信息系统分级保护2.1概念2.2法律要求2.3分级及工作流程 三、关键信息基础设施保护3.1概念3.2关保的发展3.3法…

(UE4/UE5)Unreal Engine中使用HLOD

本教程将详细介绍在Unreal Engine的不同版本(4.20-4.24、4.25-4.26、5.2)中如何使用Hierarchical Level of Detail (HLOD)。注意,每个版本中使用HLOD的方法可能会有所不同。 一、预先生成LOD 步骤一:预先生成LOD打开UE4.21&…

C语言贪吃蛇课程设计实验报告(包含贪吃蛇项目源码)

文末有贪吃蛇代码全览,代码有十分细致的注释!!!文末有贪吃蛇代码全览,代码有十分细致的注释!!!文末有贪吃蛇代码全览,代码有十分细致的注释!!! 码文不易,给个免费的小星星和免费的赞吧,关注也行呀(⑅•͈ᴗ•͈).:*♡ 不要白嫖哇(⁍̥̥̥᷄д⁍̥̥…

解决Vuex刷新页面数据丢失的问题

参考: https://blog.csdn.net/qq_51441159/article/details/128047610 方法一(不使用插件): 1、直接在vuex修改数据方法中将数据存储到浏览器本地存储中 import Vue from vue; import Vuex from vuex;Vue.use(Vuex);export defa…

Java——《面试题——MyBatis篇》

前文 java——《面试题——基础篇》 Java——《面试题——JVM篇》 Java——《面试题——多线程&并发篇》 Java——《面试题——Spring篇》 目录 前文 1、什么是MyBatis 2、说说MyBatis的优点和缺点 3、#{}和${}的区别是什么? 4、当实体类中的属性名和…

Nginx-Goaccess(实时日志服务)

goaccess的功能 1、使用webscoket协议传输(双向传输协议)2、基于终端的快速日志分析器3、通过access.log快速分析和查看web服务的统计信息、PV、UV4、安装简单、操作简易、界面炫酷5、按照日志统计访问次数、独立访客数量、累计消耗的带宽6、统计请求次…

Redis知识点

Redis是一个数据库,不过与传统数据库不同的是Redis的数据库是存在内存中,所以读写速度非常快,因此 Redis被广泛应用于缓存方向。 除此之外,Redis也经常用来做分布式锁,Redis提供了多种数据类型来支持不同的业务场景。除…

Ansible剧本(playbook)

一、playbooks 概述以及实例操作 1、playbooks 的组成 playbooks 本身由以下各部分组成 (1)Tasks:任务,即通过 task 调用 ansible 的模板将多个操作组织在一个 playbook 中运行 (2)Variables&#xff1…

Ubuntu 22 服务器端安装图形化界面

文章目录 前言一、什么是图形化界面二、操作步骤1、更新安装工具2、开始安装3、重启 总结 前言 Ubuntu 系统做得是越来越好了,从CentOS 不再提供维护后,越来越多的企业和公司从CentOS转到Ubuntu服务器系统,转了之后才发现,它比Ce…

MySQL数据库日志管理、备份与恢复

目录 一、MySQL 日志管理 二、数据备份的重要性 造成数据丢失的原因 三、数据库备份的分类 1 、从物理与逻辑的角度 (1)备份划分 (2) 物理备份方法 2、 从数据库的备份策略角度 四、常见的备份方法 1、物理冷备 2、专用备…

解决安卓12限制32个线程

Android 12及以上用户在使用Termux时,有时会显示[Process completed (signal 9) - press Enter],这是因为Android 12的PhantomProcesskiller限制了应用的子进程,最大允许应用有32个子进程。 这里以ColorOS 12.1为例(其他系统操作略…

前端vue入门(纯代码)14

内容创作不易,各位帅哥美女,求个小小的赞!!! 【15.给todoList案例添加编辑按钮】 本篇内容在TodoList案例的基础上添加个编辑按钮,要求: (1)点击编辑按钮后&#xff0c…

数据挖掘——宁县(区、市)农村居民人均可支配收入影响因子分析(论文)

《数据挖掘与分析》课程论文 题目:宁县(区、市)农村居民人均可支配收入影响因子分析 xx学院xx班:xxx 2022年6月 摘要:农村居民人均可支配收入可能被农作物产量、牲畜存栏、农作物播种数量等诸多因素影响。为此&#…