(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 线性 Transformers

3.1. Transformer

3.2. 线性注意力机制

3.2.1. 特征映射与计算成本

3.3. 因果掩码

3.3.1. 梯度计算

3.3.2. 训练和推理

3.4. transformer 是 RNN

4. 实验


0. 摘要

Transformer 在多项任务中表现出色,但由于其对输入长度的二次复杂度,对于非常长的序列来说,速度极慢。为了解决这一限制,我们将自注意力表示为核特征映射(kernel feature maps)的线性点积,并利用矩阵乘积的结合性将复杂度从 O(N^2) 降低到 O(N),其中 N 是序列长度。我们证明了这种表达方式允许一种迭代实现,大大加速了自回归 Transformer,并揭示了它们与递归神经网络的关系。我们的线性 Transformer 在性能上与普通 Transformer 相似,并且在非常长序列的自回归预测中速度快达 4000 倍。 

3. 线性 Transformers

在本节中,我们提出了线性 Transformer。我们展示了将传统的 softmax 注意力机制改为基于特征映射的点积注意力,可以改善时间和内存复杂度,并且可以实现类似于 RNN 的线性时间序列生成模型。

3.1. Transformer

3.2. 线性注意力机制

公式 2 中的注意力定义是通用的,可以用于定义多种其他注意力实现,例如多项式注意力或 RBF 核注意力(Tsai等人,2019)。注意,为了使公式 3 定义的注意力函数有效,我们需要对 sim(·) 施加的唯一约束是非负性。这包括所有核函数 k(x, y): R^(2 × F) → R_+。

给定具有特征表示 ϕ(x) 的核函数,我们可以将公式 2 重写为:

然后利用矩阵乘法的结合性进一步简化为:

当分子以向量形式书写时,上述公式更容易理解,如下所示:

注意,特征映射 ϕ(·) 是逐行应用于矩阵 Q 和 K 的。

从公式 2 可以看出,softmax 注意力的计算成本随 O(N^2) 缩放,其中 N 表示序列长度。内存需求也是如此,因为必须存储完整的注意力矩阵以计算查询、键和值的梯度。相比之下,我们在公式 5 中提出的线性 transformer 具有 O(N) 的时间和内存复杂度,因为我们可以计算

一次,并在每个查询中重复使用它们。

3.2.1. 特征映射与计算成本

对于 softmax 注意力,就乘法和加法的总成本而言,随着 O(N^2·max(D, M)) 缩放,其中 D 是查询和键的维度,M 是值的维度。相反,对于线性注意力,我们首先计算维度为 C 的特征映射。随后,计算新值需要 O(NCM) 次加法和乘法。

上述分析未考虑核函数和特征函数的选择。需要注意的是,对应于指数核的特征函数是无限维的,这使得精确 softmax 注意力的线性化不可行。另一方面,例如多项式核具有精确的有限维特征映射,并且已证明与指数或 RBF 核(Tsai等人,2019)同样有效。线性化多项式 transformer 的计算成本为 O(N·D^2·M)。当 N > D^2 时,这使得计算复杂度更具优势。实际上,由于我们希望能够处理成千上万元素的序列,这一情况是成立的。

对于我们的实验,处理较小的序列,我们采用了一个结果为正相似函数的特征映射,如下定义:

其中 elu(·) 表示指数线性单元(Clevert等人,2015)的激活函数。我们更喜欢 elu(·) 而不是relu(·),以避免在 x 为负时将梯度设置为 0。这种特征映射导致的注意力函数需要 O(NDM) 次乘法和加法。在我们的实验部分,我们展示了公式 7 的特征映射在性能上与完整 transformer 相当,同时显著减少了计算和内存需求。

3.3. 因果掩码

transformer  架构可以通过掩蔽(masking)注意力计算来高效地训练自回归模型,使得第 i 个位置只能被第 j 个位置影响当且仅当 j ≤ i,即一个位置不能被后续位置影响。形式上,这种因果掩码将公式 3 修改如下:

按照3.2节的推理,我们如下所述对掩码注意力进行线性化:

通过引入 Si 和 Zi 如下所示:

我们可以将公式 9 简化为:

注意,Si 和 Zi 可以从 S_(i-1) 和 Z_(i-1) 在固定时间内计算得出,因此使得具有因果掩码的线性 transformer 的计算复杂度相对于序列长度为线性。

3.3.1. 梯度计算

在任何深度学习框架中,公式 12 的朴素实现需要存储所有中间值 Si,以计算梯度。这会增加max(D, M) 倍的内存消耗,从而阻碍因果线性注意力在更长序列或更深模型中的应用。为了解决这个问题,我们将公式 9 中的分子(numerator)的梯度导出为累积和。这使我们能够在线性时间和固定内存中计算因果线性注意力的前向和后向传播。详细推导见附录材料。

给定分子 ¯V_i 和标量损失函数相对于分子的梯度

推导可得:

累计和项在公式 9 和 13-15 中以线性时间计算,并且相对于序列长度需要常量内存。这导致的算法在给定维度为 C 的特征映射下,其计算复杂度为 O(NCM),内存复杂度为 O(N·max (C, M))。算法 1 是分子部分前向和后向传播的伪代码实现。

3.3.2. 训练和推理

在训练自回归 transformer 模型时,可以使用完整的真实序列。这使得公式 1 中的函数 φ(·) 和注意力计算都可以进行分层并行化。因此,transformer 比 RNN 更高效地进行训练。然而,在推理过程中,时间步 i 的输出是时间步 i + 1 的输入。这使得自回归模型无法并行化。此外,transformer 每个时间步的成本不是常量,而是随着当前序列长度的平方增长,因为必须为所有先前的时间步计算注意力。

我们提出的线性 transformer 模型结合了这两者的优点。在训练时,计算可以并行化并充分利用 GPU 或其他加速器。在推理时,我们模型的每次预测在时间和内存上的成本是常量的。这意味着我们可以简单地将

矩阵存储为内部状态,并在每个时间步像递归神经网络一样更新它。这使得推理速度比其他 transformer 模型快数千倍。

3.4. transformer 是 RNN

在文献中,transformer 模型被认为是一种与递归神经网络(RNN)根本不同的方法。然而,从 3.3 节中的因果掩码公式和前一节的讨论可以看出,任何具有因果掩码的 transformer 层都可以被表示为一种模型,该模型在给定输入后修改内部状态,然后预测输出,即 RNN。注意,与通用变压器(Universal Transformers)(Dehghani等人,2018)不同,我们考虑的是时间上的递归,而不是深度上的递归。

在以下公式中,我们将公式 1 的 Transformer 层形式化为 RNN。所得的 RNN 有两个隐藏状态,即注意力记忆 s 和归一化记忆 z。我们用下标表示递归中的时间步。

在上述公式中,x_i 表示特定 Transformer 层的第 i 个输入,y_i 表示第 i 个输出。需要注意的是,我们的公式对特征函数没有任何约束,因此可以用于表示任何 Transformer 模型,理论上甚至包括使用 softmax 注意力的模型。这一公式是更好理解 Transformer 与流行的 RNN(Hochreiter & Schmidhuber, 1997)及其存储和检索信息过程之间关系的第一步。 

4. 实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/655936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue:快速上手

一、简介 Vue (发音为 /vjuː/,类似 view) 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS 和 JavaScript 构建,并提供了一套声明式的、组件化的编程模型,帮助你高效地开发用户界面。无论是简单还是复杂的界面,…

第八次javaweb作业

我们小组课程设计的题目是:超市管理系统,我认领的模块是:商品信息管理 controller package com.example.supermarker.controller;import com.example.supermarker.pojo.MerchInfo; import com.example.supermarker.pojo.PageBean; import c…

pillow学习3

Pillow库中,图像的模式代表了图像的颜色空间。以下是一些常见的图像模式及其含义: L(灰度图):L模式表示图像是灰度图像,每个像素用8位表示(范围为0-255),0表示黑色&#…

TTime:截图翻译/OCR

日常网页翻译Translate Web Pages完全足够,TTime最重要的功能,还是截图翻译,还有个厉害的功能,就是静默OCR,相比之前的分享的识字精灵效率更高。 软件使用 打开软件,点击翻译源设置,建议勾选一…

grafana大盘展示node_expod节点

node_expod添加lables标签 Prometheus查询 语句查询 node_exporter_build_infografna添加变量查询 正常有值 切换其他的是有值的 我的报错原因 因为有多个数据源,我选择错了,因为修改的lable标签是其他数据源,所以获取不到 查询语句 我的变量是 $app node_filesyste…

养老院管理系统基于springboot的养老院管理系统java项目

文章目录 养老院管理系统一、项目演示二、项目介绍三、系统部分功能截图四、部分代码展示五、底部获取项目源码(9.9¥带走) 养老院管理系统 一、项目演示 养老院管理系统 二、项目介绍 基于springboot的养老院管理系统 角色:超级…

Python代码:十七、生成列表

1、题目 描述: 一串连续的数据用什么记录最合适,牛牛认为在Python中非列表(list)莫属了。现输入牛牛朋友们的名字,请使用list函数与split函数将它们封装成列表,再整个输出列表。 输入描述: …

011-Linux磁盘管理

文章目录 前言 一、du:查看文件和目录占用的磁盘空间 二、df:查看文件系统的整体磁盘使用情况 三、lsblk:查看设备挂载情况 四、fdisk:磁盘分区 4.1、查看磁盘分区列表 4.2、磁盘分区 4.2.1、交互命令的功能 4.2.2、对/d…

详细分析Element中的Drawer(附Demo)

目录 前言1. 基本知识2. Demo2.1 基本用法2.2 不同方向2.3 自定义大小2.4 嵌入表单2.5 嵌套抽屉 3. 实战4. Element Plus(Drawer) 前言 对于该组件针对Vue2比较多,而Element Plus中的Drawer针对Vue3比较多 此处的Demo主要偏向Vue2 后续的El…

【学习笔记】计算机组成原理(七)

指令系统 文章目录 指令系统7.1 机器指令7.1.1 指令的一般格式7.1.2 指令字长 7.2 操作数类型和操作类型7.2.1 操作数类型7.2.2 数据在存储器中的存放方式7.2.3 操作类型 7.3 寻址方式7.3.1 指令寻址7.3.1.1 顺序寻址7.3.1.2 跳跃寻址 7.3.2 数据寻址7.3.2.1 立即寻址7.3.2.2 直…

【数据结构与算法】七大排序算法(上)

【数据结构与算法】七大排序算法(上) 🥕个人主页:开敲🍉 🔥所属专栏:数据结构与算法🍅 🌼文章目录🌼 1. 排序的概念及应用 1.1 排序的概念 1.2 排序的应用 1.3 常见排序算法 2. 常…

Spring MVC+mybatis 项目入门:旅游网(二) dispatcher与controller与Spring MVC

个人博客:Spring MVCmybatis 项目入门:旅游网(二)dispatcher与controller与Spring MVC | iwtss blog 先看这个! 这是18年的文章,回收站里恢复的,现阶段看基本是没有参考意义的,技术老旧脱离时代…

中国上市企业行业异质性数据分析

数据简介:企业行业异质性数据是指不同行业的企业在运营、管理、财务等方面的差异性数据。这些数据可以反映不同行业企业的特点、优势和劣势,以及行业间的异质性对企业经营和投资的影响。通过对企业行业异质性数据的分析,投资者可以更好地了解…

杀死那个进程

一、场景 eclipse在启动tomcat时,出现端口被占用的情况。我寻思着“任务管理器”没出现相应程序在跑啊。 1.1问题:端口和进程的关系 端口和进程之间存在着一种关系,端口是一个逻辑概念,它用于标识网络通信中的一个终点&#xff0…

基于Java实现震中附近风景区预警可视化分析实践

目录 前言 一、空间数据说明 1、表结构信息展示 2、空间范围查询 二、Java后台开发实现 1、模型层设计与实现 2、控制层设计与实现 三、Leaflet地图开发 1、地震震中位置展示 2、百公里风景区列表展示 3、风景区列表展示 4、附近风景区展示 四、总结 前言 地震这类…

为表格添加背景色:\rowcolor, \columncolor,\cellcolor

设置行的背景 \rowcolor 是 LaTeX 中用于设置表格行的背景色的命令。它可以使表格更加美观和易于阅读。rowcolor 命令通常与 colortbl 宏包一起使用。 语法如下&#xff1a; \rowcolor{<color>}其中 表示要设置的背景色&#xff0c;可以是预定义的颜色名称&#xff08…

C++算术运算和自增自减运算

一 引言 表示运算的符号称为运算符。 算术运算&#xff1b; 比较运算&#xff1b; 逻辑运算&#xff1b; 位运算&#xff1b; 1 算术运算 算术运算包括加、减、乘、除、乘方、指数、对数、三角函数、求余函数&#xff0c;这些都是算术运算。 C中用、-、*、/、%分别表示加、减…

Redis 中 List 数据结构详解

目录 List 用法 1. 增 2. 删 3. 查 内部编码 应用场景 前言 Redis 中的 List 和 Set 数据结构各有特点&#xff0c;适用于不同的应用场景。List 提供了有序的列表结构&#xff0c;适合用于消息队列和任务列表等场景&#xff1b;Set 提供了无序且不重复的集合结构&#…

9.Docker网络

文章目录 1、Docker网络简介2、常用基本命令3、网络模式对比举例3.1、bridge模式3.2、host模式3.3、none模式3.4、container模式3.5、自定义网络 1、Docker网络简介 作用&#xff1a; 容器间的互联和通信以及端口映射容器IP变动时候可以通过服务名直接进行网络通信而不受到影…

module ‘plotting‘ has no attribute ‘EpisodeStats‘

plotting.py 的版本不同&#xff0c;可以使用下列版本 reinforcement-learning/lib/plotting.py at master dennybritz/reinforcement-learning GitHubImplementation of Reinforcement Learning Algorithms. Python, OpenAI Gym, Tensorflow. Exercises and Solutions to a…