(2024,attention,可并行计算的 RNN,并行前缀扫描)将注意力当作 RNN

Attention as an RNN

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 方法 

3.1 注意力作为一种(多对一的)RNN

3.2 注意力作为(多对多)RNN

3.3 Aaren: 将注意力当作 RNN

4. 实验 

附录

F. 计算时间


0. 摘要

随着 Transformer 的出现,序列建模领域取得了重大突破,为利用 GPU 并行性提供了高性能的架构。然而,Transformer 在推理时计算开销较大,限制了其在低资源环境(例如移动和嵌入式设备)中的应用。为了解决这个问题,我们采取了以下步骤:

  • 首先,我们展示了注意力机制可以被视为一种特殊的递归神经网络(RNN),能够高效地计算其多对一的 RNN 输出。
  • 接着,我们证明了流行的基于注意力的模型(如 Transformer)可以视为 RNN 的变体。然而,与传统的 RNN(如 LSTM)不同,这些模型无法高效地用新 token 更新,这在序列建模中是一个重要特性。
  • 为了解决这个问题,我们引入了一种基于并行前缀(prefix)扫描算法的新的高效计算注意力多对多 RNN 输出的方法。
  • 基于新的注意力公式,我们推出了 Attention as a recurrent neural network(Aaren),这是一种基于注意力的模块,不仅可以像 Transformer 一样并行训练,还能像传统 RNN 一样高效地用新 token 进行更新,只需要固定内存来进行推理。

在实验中,我们展示了 Aaren 在四个流行的序列问题设置中的 38 个数据集上,表现与 Transformer 相当,但在时间和内存效率上更优。这些问题设置包括强化学习、事件预测、时间序列分类和时间序列预测任务。

3. 方法 

3.1 注意力作为一种(多对一的)RNN

对一个查询向量 q 的注意力可以被看作一个函数,该函数将 N 个上下文 token x_(1:N) 通过它们的键和值 {(ki, vi)}^N_(i=1) 映射到一个单一的输出 o_N = Attention(q, k_(1:N), v_(1:N))。给定 s_i = dot(q, ki),输出 o_N 表示为: 

将注意力视为一种 RNN,我们可以将其迭代计算为滚动求和

其中 k=1,…,N。然而,在实践中,这是一种不稳定的实现,会由于有限精度表示和可能非常小或非常大的指数(即,exp(s))而遇到数值问题。为缓解这一问题,我们通过累积最大值项

重写递归关系,改为计算

值得注意的是,最终结果是相同的,即 o_N = ^a_N / ^c_N= a_N /c_N。因此,a_k、c_k 和 m_k 递归地计算如下: 

通过将 a_k​、c_k 和 m_k 的递归计算从 a_(k-1)​、c_(k-1) 和 m_(k-1) 中封装起来,我们引入了一个 RNN 单元,用于迭代计算注意力的输出(见图 2)。注意力的 RNN 单元以 (a_(k-1), c_(k-1), m_(k-1)​, q) 作为输入,并计算 (ak​,ck​,mk​,q)。请注意,查询向量 q 在 RNN 单元中被传递。注意力的 RNN 的初始隐藏状态为 (a_0, c_0, m_0, q) = (0, 0, 0, q)。

计算注意力的方法。通过将注意力视为一种 RNN,我们可以看到有不同的计算注意力的方法:

  • 在 O(1) 内存中逐 token 递归地(即,顺序地)
  • 以传统方式(即,并行地)在需要线性 O(N) 内存的情况下。由于注意力可以被看作是一种 RNN,计算注意力的传统方法也可以被看作是计算注意力的一对多 RNN 输出的有效方法,即,RNN 的输出将多个上下文 token 作为输入,但在 RNN 结束时只输出一个 token(见图 1a)
  • 一种按块处理 token 的 RNN,需要 O(b) 内存,其中 b 是块的大小。然而,这种方法超出了本工作的范围。因此,按块处理的 RNN 的描述包含在附录 A 中。

将现有基于注意力的模型视为 RNN。通过将注意力视为一种 RNN,现有的基于注意力的模型也可以被视为 RNN 的变体。例如,Transformer 的自注意力是 RNN(图 1b),其上下文 token 作为其初始隐藏状态。

Perceiver 的交叉注意力是 RNN(图 1c),其上下文相关潜变量作为其初始隐藏状态。通过利用它们注意力机制的 RNN 形式,这些现有模型可以有效地计算它们的输出内存。将注意力视为 RNN 存在的挑战。然而,将现有的基于注意力的模型,如 Transformer,视为 RNN 时,这些模型缺乏传统 RNN(如 LSTM 和 GRU)中常见的重要属性。值得注意的是,LSTM 和 GRU 能够以仅为  O(1) 的固定内存和计算高效地更新自己,这对于序列建模是一个重要特性,其中数据以流的形式接收。相比之下,将 Transformer 的 RNN 视图(见图 1b)处理新 token 的方式是添加一个新的RNN,新 token 作为其初始状态。新的 RNN 处理所有先前的 token,需要 O(N) token 数量的线性计算。在 Perceiver 中,由于其体系结构,潜变量(图 1c 中的 L_i)是与输入相关的,这意味着当接收到新 token 时,它们的值会改变。由于它们的 RNN 的初始隐藏状态(即,潜变量)会改变,因此 Perceiver 需要从头重新计算它们的 RNN,需要 O(NL)  token 数量(N)和潜变量数量(L)的线性计算。

3.2 注意力作为(多对多)RNN

针对这些限制,我们提议开发一种基于注意力的模型,能够利用 RNN 形式的能力进行高效更新。为此,我们首先引入了一种高效的并行化方法,用于将注意力计算为多对多的 RNN,即,一种并行方法来计算

为此,我们利用并行前缀扫描算法(Blelloch,1990)(见算法 1),这是一种并行计算方法,用于通过关联运算符 ⊕ 从 N 个连续数据点计算 N 个前缀计算。该算法可以从 {xk}^N_(k=1) 高效地计算

回想一下,Attention(q, x(1:k)) = ok = ak / ck。为了高效计算 Attention(q, x(1:k)),我们可以通过并行扫描算法计算 {ak}^N_(k=1),{ck}^N_(k=1) 和 {mk}^N_(k=1),然后将 ak 和 ck 结合起来计算 Attention(q, x(1:k))。

为此,我们提出以下关联运算符 ⊕,它作用于形式为(m_A, u_A, w_A)的三元组,其中 A 是索引集合,

mA = max_(i∈A) si,uA = ∑_(i∈A) exp(si − mA),wA = ∑_(i∈A) exp(si − mA)vi

并行扫描算法将 {(m{i}, u{i}, w{i})}^N_(i=1) = {(si, 1, vi)}^N_(i=1) 作为输入。该算法递归地应用运算符 ⊕,其工作方式如下:

(mA, uA, wA)⊕(mB, uB, wB) = (mA∪B, uA∪B, wA∪B)

mA∪B = max(mA, mB),uA∪B = uA·exp(mA − mA∪B) + uB exp(mB − mA∪B)

wA∪B = wA· exp(mA−mA∪B) + wB·exp(mB−mA∪B)

在递归地应用运算符完成后,该算法输出

{(m{1,...,k}, u{1,...,k}, w{1,...,k})}^N_(k=1) = {(mk, ∑_(i=1) exp(si − mk), ∑_(i=1) exp(si − mk)vi)}^N_(k=1)

也称为 {(mk, ck, ak)}^N_(k=1)。

通过组合输出元组的最后两个值,我们得到 Attention(q, x(1:k)) = ok = ak / ck,从而实现了一种高效的并行化方法,用于计算注意力作为多对多的 RNN(见图 3)。

3.3 Aaren: 将注意力当作 RNN

利用注意力的并行化多对多形式,我们提出了Aaren。Aaren 的接口与 Transformer 相同,将 N 个输入映射到 N 个输出,其中第 i 个输出是从第 1 到第 i 个输入的聚合。因此,Aaren也是自然可堆叠的,且能够为每个序列 token 计算单独的损失项

然而,与使用因果自注意力的 Transformer 不同,Aaren 使用了前述的将注意力计算为多对多 RNN 的方法,使其更加高效。Aaren 的功能如下:

与 Transformer 中的查询是注意力的一个输入 token 不同,Aaren 的查询 token q 是通过反向传播在训练过程中学习的。在图 4 中,我们包含了一个具有输入上下文 token x_(1​:3) 和输出 y_(1:3) 的堆叠 Aaren 模型的示例。值得注意的是,由于 Aaren 利用了注意力的 RNN 形式,Aaren 的堆叠也是 RNN 的堆叠。因此,Aarens 也能够以高效的方式与新 token 进行更新,即,仅需要固定计算量的 y_k 的迭代计算,因为它仅依赖于 h_(k-1) 和 x_k。与基于 Transformer 的模型不同,它们在使用 KV 缓存时需要线性内存,并且需要存储所有先前的 token,包括中间 Transformer 层中的 token,而基于 Aaren 的模型仅需要固定内存,并且不需要存储所有先前的 token,使得 Aaren 比 Transformer 更加高效。

4. 实验 

附录

F. 计算时间

我们的实验是在 Nvidia GTX 1080 Ti(12 GB)和 Nvidia Tesla P100(16 GB)GPU 的混合环境中运行的。分析是在 Nvidia GTX 1080 Ti(12 GB)上执行的。

强化学习实验大致需要相同的时间:每个实验约 2 到 4 小时。

事件预测实验的时间因数据集而异:

  • MIMIC 约需 0.5 小时
  • Wiki 约需 0.75 小时
  • Reddit 约需 3.5 小时
  • Mooc 约需 8 小时
  • StackOverflow 约需 3.5 小时
  • Sin 约需 1.5 小时
  • Uber 约需 3 小时
  • Taxi 约需1.5 小时

时间序列预测实验以 T ∈ {96, 192, 336, 720} 的单个脚本运行。实验时间因数据集而异:

  • Weather 约需 6 小时
  • Exchange 约需 0.5 小时
  • Traffic约需 1 小时
  • ECL 约需 4 小时
  • ETTh1 约需 0.75 小时
  • ETTm1 约需 11 小时
  • ETTh2 约需 0.75 小时
  • ETTm2 约需 11 小时

时间序列分类实验作为单个脚本运行。在所有数据集上运行实验总共约需 1 小时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/647411.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux C++ Socket 套接字、select、poll、epoll 实例

文章目录 1. 概述2. TCP 网络编程实例2.1 服务器端2.2 客户端2.3 运行截图 3. I/O 模型3.1 阻塞式I/O模型3.2 非阻塞I/O模型3.3 I/O 复用模型3.4 信号驱动式I/O3.5 异步I/O模型 4. I/O复用之 select4.1 select 函数描述4.2 服务端代码4.3 客户端代码4.4 运行截图 5. I/O复用之 …

音视频开发9 FFmpeg 解复用框架--如何将一个影音文件(mp4文件/wav文件) 最终播放起来

一,播放器框架 二 常用音视频术语 容器/文件(Conainer/File): 即特定格式的多媒体文件, 比如mp4、flv、mkv等。 媒体流(Stream): 表示时间轴上的一段连续数据&#xff0…

jmeter之测试计划

一、测试计划作用 测试计划是jmeter的默认控件所有线程组都是测试计划的下级控件测试计划可以配置用户自定义的变量测试计划可以配置线程组的串行或并行 二、查看界面 名称:可以修改自定义的名称注释:解释测试计划是用来做什么的用户自定义的变量&…

23种设计模式顺口溜

口诀: 原型 抽风 ,单独 建造 工厂 (寓意:(这里代指本来很简单的东西,却要干工厂这里复杂的业务) 抽风:抽象工厂单独:单例桥代理组合享元适配器,&#xff0…

驱动开发之新字符设备驱动开发

1.前言 register_chrdev 和 unregister_chrdev 这两个函数是老版本驱动使用的函数,现在新的 字符设备驱动已经不再使用这两个函数,而是使用 Linux 内核推荐的新字符设备驱动 API 函数。 旧版本的接口使用,感兴趣可以看下面这个博客&#…

关于C的\r回车在不同平台的问题

首先我们需要搞明白\r和\n是两回事 \r是回车&#xff0c;前者使光标到行首&#xff0c;&#xff08;carriage return&#xff09; \n是换行&#xff0c;后者使光标下移一格&#xff0c;&#xff08;line feed&#xff09; Linux平台下 #include <stdio.h> int main()…

TiDB学习4:Placement Driver

目录 1. PD架构 2. 路由功能 2. TSO 2.1 TSO 概念 2.2 TSO分配过程 2.3 TSO时间窗口 3. 调度 3.1 信息收集 3.2 生成调度(operator) 3.3 执行调度 4. Label 与高可用 4.1 Label 的配置 5. 小结 1. PD架构 PD是整个TiDB的总控&#xff0c;相当于集群的大脑 PD集成了…

Overleaf中出现文字越界、越下届、没有正确分页、换页的原因和解决方法

在使用overleaf中&#xff0c;我偶尔会遇到如标题所说的情况&#xff0c;也如图所示&#xff1a; 后来发现&#xff0c;是因为这一页前面是一个表格&#xff0c;所以怀疑是表格的格式导致的。所以让chatgpt帮我更换了表格的格式&#xff0c;成功解决问题。 对于问题可能的成因…

20232803 2023-2024-2 《网络攻防实践》实践十报告

目录 1. 实践内容1.1 SEED SQL注入攻击与防御实验1.2 SEED XSS跨站脚本攻击实验(Elgg) 2. 实践过程2.1 SEED SQL注入攻击与防御实验2.1.1 熟悉SQL语句2.1.2 对SELECT语句的SQL注入攻击2.1.3 对UPDATE语句的SQL注入攻击2.1.4 SQL对抗 2.2 SEED XSS跨站脚本攻击实验(Elgg)2.2.1 发…

视频拼接融合产品的产品与架构设计(四)分布式GPU运算合并单元

上一篇如下 视频拼接融合产品的产品与架构设计(三&#xff09;内存和显存单元数据迁移 视频合并单元说明 对下面这张图做些说明&#xff0c;视频接入是比较常见&#xff0c;可以说是普通&#xff0c;但是做到接入后随即进行比较重的算法运算&#xff0c;这个在视频领域并不多…

探索PyImGui:高效可交互图形界面的Python实现

简介 Pyimgui 是一个基于 Cython 的 Python 绑定层&#xff0c;它将功能强大的用户界面库 Dear ImGui 无缝集成到 Python 环境中。它使 Python 开发人员能够轻松地创建交互式图形用户界面 (GUI)&#xff0c;同时充分利用 Dear ImGui 的丰富功能集。 下图为用Dear ImGui开的GU…

Android 使用 ActivityResultLauncher 申请权限

前面介绍了 Android 运行时权限。 其中&#xff0c;申请权限的步骤有些繁琐&#xff0c;需要用到&#xff1a;ActivityCompat.requestPermissions 函数和 onRequestPermissionsResult 回调函数&#xff0c;今天就借助 ActivityResultLauncher 来简化书写。 步骤1&#xff1a;创…

2024年5月26日 (周日) 叶子游戏新闻

资深开发者&#xff1a;3A游戏当前处于一种尴尬的中间地带游戏行业整体&#xff0c;尤其是3A游戏正处于艰难时期。尽管2023年3A游戏佳作频出&#xff0c;广受好评&#xff0c;但居高不下的游戏开发成本&#xff08;传闻《漫威蜘蛛侠2》的制作成本高达3亿美元&#xff09;正严重…

提高Java编程效率:ArrayList类的使用技巧

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。运营社区&#xff1a;C站/掘金/腾讯云&#xff1b;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一…

虹科Pico汽车示波器 | 免拆诊断案例 | 2012 款雪佛兰科鲁兹车偶尔多个故障灯异常点亮

故障现象 一辆2012款雪佛兰科鲁兹车&#xff0c;搭载1.8 L 发动机&#xff0c;累计行驶里程约为9.6万km。该车组合仪表上的发动机故障灯、ABS故障灯及动力转向故障灯偶尔异常点亮&#xff0c;同时发动机转速表和发动机冷却液温度表的指针会突然归零&#xff0c;严重时发动机无…

【Linux 网络】网络基础(三)(网络层协议:IP 协议)

在复杂的网络环境中确定一个合适的路径。 一、TCP 与 IP 的关系 IP 层的核心作用是定位主机&#xff0c;具有将数据从主机 A 发送到主机 B 的能力&#xff0c;但是能力并不能保证一定能够做到&#xff0c;所以这时就需要 TCP 起作用了&#xff0c;TCP 可以通过超时重传、拥塞控…

[Vulnhub]Vulnix 通过NFS挂载+SSH公钥免密登录权限提升

端口扫描 Server IP AddressPorts Open192.168.8.103TCP:22/tcp, 25/tcp, 79/tcp, 110/tcp, 111/tcp, 143/tcp, 512/tcp, 513/tcp, 514/tcp, 993/tcp, 995/tcp, 2049/tcp, 37522/tcp, 42172/tcp, 43219/tcp, 47279/tcp, 54227/tcp $ nmap -p- 192.168.8.103 -sV -sC --min-ra…

Nginx实战(安装部署、常用命令、反向代理、负载均衡、动静分离)

文章目录 1. nginx安装部署1.1 windows安装包1.2 linux-源码编译1.3 linux-docker安装 2. nginx介绍2.1 简介2.2 常用命令2.3 nginx运行原理2.3.1 mater和worker2.3.3 Nginx 的工作原理 2.4 nginx的基本配置文件2.4.1 location指令说明 3. nginx案例3.1 nginx-反向代理案例013.…

python基于深度学习的聊天机器人设计

python基于深度学习的聊天机器人设计 开发语言:Python 数据库&#xff1a;MySQL所用到的知识&#xff1a;Django框架工具&#xff1a;pycharm、Navicat、Maven 系统功能实现 登录注册功能 用户在没有登录自己的用户名之前只能浏览本网站的首页&#xff0c;想要使用其他功能都…

ROCm上来自Transformers的双向编码器表示(BERT)

14.8. 来自Transformers的双向编码器表示&#xff08;BERT&#xff09; — 动手学深度学习 2.0.0 documentation (d2l.ai) 代码 import torch from torch import nn from d2l import torch as d2l#save def get_tokens_and_segments(tokens_a, tokens_bNone):""&qu…