RWKV: Reinventing RNNs for the Transformer Era
公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
0. 摘要
2. 背景
2.1 循环神经网络 (RNN)
2.2 Transformer 和 AFT
3. RWKV
3.1 架构
3.1.1 Token 移位
3.1.2 WKV 算子
3.1.3 输出门控制
3.2 类 Transformer 训练
3.3 类 RNN 的推理
3.4 附加优化
3.5 实现
4. 实验
7. 未来工作
8. 结论
9. 限制
0. 摘要
Transformer 已经彻底改变了几乎所有自然语言处理(NLP)任务,但其内存和计算复杂度随序列长度呈二次增长。与之相对,比起 Transformer,循环神经网络(RNN)的内存和计算需求呈线性增长,但由于在并行化和可扩展性方面的限制,RNN 难以达到 Transformer 的性能。
我们提出了一种新颖的模型架构,称为接收加权键值(Receptance Weighted Key Value,RWKV),它结合了 Transformer 的高效并行训练和 RNN 的高效推理。我们的方法利用线性注意力机制,使得模型可以被表述为 Transformer 或 RNN,从而在训练过程中并行计算,并在推理过程中保持恒定的计算和内存复杂度。
我们将模型扩展到多达 140 亿个参数,成为迄今为止训练过的最大的密集 RNN,并发现 RWKV 的性能与类似规模的 Transformer 相当。这表明未来的研究可以利用这一架构创建更高效的模型。这项工作在解决序列处理任务中的计算效率和模型性能之间的权衡方面迈出了重要的一步。
项目页面:https://github.com/BlinkDL/RWKV-LM
2. 背景
这里我们简要回顾 RNN 和 Transformer 的基本原理。
2.1 循环神经网络 (RNN)
流行的 RNN 架构如 LSTM(Hochreiter和Schmidhuber,1997)和 GRU(Chung等,2014)具有以下特征(以 LSTM 为例,其他可以类似推理):
尽管 RNN 可以分解为两个线性模块(W 和 U)和一个 RNN 特定模块(1)–(6),如 Bradbury等(2017)所指出的,依赖于先前时间步的数据依赖性禁止了对这些典型RNN的并行化处理。
(2024,LSTM,Transformer,指数门控,归一化器状态,多头内存混合)xLSTM:扩展的 LSTM
2.2 Transformer 和 AFT
由 Vaswani 等(2017)引入的 Transformer 是一类神经网络,已成为多个 NLP 任务的主导架构。与逐步处理序列的 RNN 不同,Transformer 依靠注意力机制来捕捉所有输入和所有输出标记之间的关系:
为了方便起见,多头机制和缩放因子 1 / √d_k 被省略了。核心的 QK^T 乘法是序列中每个标记之间的成对注意力得分的集合,可以分解为向量运算:
无注意力 Transformer(Attention Free Transformer,AFT)(Zhai 等,2021)将注意力机制替代性地表述为:
其中 {w_(t,i)} ∈ R^(T×T) 是学习到的成对位置 bias,每个 w_(t,i) 是一个标量。
受 AFT 的启发,RWKV 采用了类似的方法。然而,为了简化,它修改了交互权重,以便可以将其转化为 RNN。RWKV 中的每个 w_(t,i) 是一个按通道的时间衰减向量,当衰减时乘以相对位置并从当前时间向后追溯:
其中 d 是通道数。我们要求 w 为非负,以确保 e^(w_(t,i)) ≤ 1,并且每通道的权重在时间上向后衰减。
3. RWKV
RWKV 模型架构由四个基本元素定义,这些元素本质上是时间混合和通道混合模块的组成部分:
- R: 接收向量,用于接收过去的信息。
- W: 权重,表示位置权重衰减向量,是模型中的一个可训练参数。
- K: 键向量,其作用类似于传统注意力机制中的 K。
- V: 值向量,其功能类似于传统注意力机制中的 V。
这些核心元素在每个时间步长上进行乘法交互,如图 2 所示。
3.1 架构
RWKV 模型由堆叠的残差块组成。每个块包含一个时间混合和一个通道混合子块,体现了利用过去信息的递归结构。该模型使用了一种独特的类似注意力的得分更新过程,包括一个时间依赖的softmax 操作,提高了数值稳定性并减轻了梯度消失问题(严格证明见附录 H)。这确保了梯度沿着最相关的路径传播。此外,架构中结合的层归一化(Ba 等,2016)有助于稳定梯度,有效解决了梯度消失和梯度爆炸问题。这些设计元素不仅增强了深度神经网络的训练动态,还促进了多层堆叠,捕捉不同抽象层次上的复杂模式,从而比传统 RNN 模型表现更优(另见附录 I)。
3.1.1 Token 移位
在此架构中,涉及计算的所有线性投影向量(时间混合中的 R、K、V 以及通道混合中的 R'、K')都是通过当前时间步和先前时间步输入之间的线性插值生成的,便于 token 移位(shift)。用于时间混合计算的向量是当前和先前输入的线性组合的线性投影
通道混合的输入也是如此:
token 移位在每个块的时间维度上通过简单的 offset 实现,使用 PyTorch(Paszke等,2019)库中的 nn.ZeroPad2d((0,0,1,-1))
。
3.1.2 WKV 算子
我们模型中 WKV 算子的计算方法类似于 Attention Free Transformer(AFT)(Zhai等,2021)中使用的方法。然而,与 AFT 中 W 为成对矩阵不同,我们的模型将 W 视为一个按通道的向量,并根据相对位置进行修改。在我们的模型中,这种递归行为由 WKV 向量的时间依赖更新定义,形式化为以下方程:
为了规避 W 的任何潜在降级,我们引入了一个向量 U,它关注当前 token。关于此的更多信息可以在附录 I 中找到。
3.1.3 输出门控制
输出门控制在时间混合和通道混合块中都使用接收向量的 sigmoid 函数 σ(r) 实现。WKV 算子后的输出向量 o_t 由以下公式给出:
在通道混合块中,执行类似的操作:
这里我们采用了平方 ReLU 激活函数(So 等,2021)。
3.2 类 Transformer 训练
RWKV 可以通过一种称为时间并行模式的技术进行高效并行化,类似于 Transformer。在单层中处理一批序列的时间复杂度为 O(BTd^2),主要由矩阵乘法 W_λ 组成,其中 λ ∈ {r, k, v, o}(假设 B 个序列,T 个最大标记,d 个通道)。相比之下,更新注意力分数 wkv_t 涉及串行扫描(详见附录 D 以获取更多细节),具有复杂度 O(BTd)。
矩阵乘法可以类似于传统 Transformer 中 的 W_λ 进行并行化,其中 λ ∈ {Q, K, V, O} 。元素级的 WKV 计算是时间依赖的,但可以沿其他两个维度(Lei 等,2018)轻松并行化。
3.3 类 RNN 的推理
RNN 通常利用状态 t 处的输出作为状态 t + 1 处的输入。这种用法也观察到在语言模型的自回归解码推理中,其中每个标记必须在传递到下一步之前计算。RWKV 利用了这种类似 RNN 的结构,称为时间顺序模式。在这种情况下,RWKV 可以在推理期间方便地递归地进行编码,如附录 D 所示。
3.4 附加优化
自定义内核。为了解决使用标准深度学习框架时由任务的顺序性质引起的 WKV 计算中的低效率,我们开发了一个定制的 CUDA 内核。该内核使得在训练加速器上执行单个计算内核成为可能,而模型的所有其他部分,如矩阵乘法和逐点操作,已经是固有的可并行化和高效的。
小初始化嵌入(Small Init Embedding)。在训练 Transformer 模型(Vaswani 等,2017)的初始阶段,我们观察到嵌入矩阵的变化速度较慢,这对模型摆脱初始噪声嵌入状态构成了挑战。为了解决这个问题,我们提出了一种方法,该方法涉及使用小值初始化嵌入矩阵,随后应用额外的 LayerNorm 操作。这加速和稳定了训练过程,允许使用后 LN 组件训练深层架构。这种方法的有效性在图 9 中得到了证明,该图说明通过使模型迅速摆脱初始小嵌入状态,实现了改进的收敛。这是通过在单个步骤中发生的小变化实现的,随后在 LayerNorm 操作之后导致了方向上的实质性变化和进一步的显着变化。
自定义初始化。建立在先前工作(He等,2016;Jumper等,2021)的原理之上,我们采用了一种初始化策略,其中参数被设置为类似于标识映射的值,同时打破对称性以建立清晰的信息流。大多数权重被初始化为零,线性层不使用偏差。附录E中给出了详细的公式。我们观察到初始化的选择在收敛的速度和质量方面起着至关重要的作用(有关更多详细信息,请参阅附录 F)。
3.5 实现
RWKV 使用 PyTorch 深度学习库(Paszke等,2019)实现。我们将 DeepSpeed(Rasley 等,2020)启发的附加优化策略集成到系统中,提高了其效率和可扩展性。模型从一个嵌入层开始,如第 3.4 节所述。随后是若干相同的残差块按顺序排列。这些在图 2 和图 3 中描述,并且符合第 3.1.1 节中概述的原则。在最后一个块之后,使用简单的输出投影头进行逻辑生成,该头包括一个LayerNorm(Ba等,2016)和一个线性投影,用于下一个 token 的预测和在训练期间计算交叉熵损失。
4. 实验
7. 未来工作
对于 RWKV 架构的未来工作有几个有希望的方向。可以通过增强时间衰减公式和探索初始模型状态的方式来增加模型的表达能力(expressivity),同时保持效率。可以通过在 wkv_t 步骤中应用并行扫描来进一步提高 RWKV 的计算效率,从而将计算成本降低到 O(B log(T)d)。
RWKV 所使用的机制可以应用于编码器-解码器架构,潜在地替代交叉注意力机制。这可能适用于 seq2seq 或多模态设置,从而增强训练和推理过程的效率。
可以利用 RWKV 的状态(或上下文)来提高序列数据的可解释性、可预测性和安全性。操纵隐藏状态也可以引导行为,并通过提示调整实现更大的可定制性。
RWKV 架构并不完美,可以通过修改公式或实现更大的内部状态等方面进行改进。更大的状态可以增强模型对先前上下文的记忆,并提高在各种任务上的性能。
8. 结论
我们介绍了 RWKV,这是一种利用基于时间的混合组件潜力的 RNN 模型的新方法。RWKV 引入了几个关键策略,使其能够捕捉局部性和长程依赖性,同时通过以下方式解决了当前架构的局限性:(1) 将二次的 QK 注意力替换为线性成本下的标量公式,(2) 重新定义了循环和顺序归纳偏差,以实现有效的训练并行化和有效的推理,(3) 使用自定义初始化增强训练动态。我们在各种 NLP 任务中对所提出的架构进行了基准测试,并展示了与 SoTA 相当的性能以及降低的成本。对表达能力、可解释性和扩展性的进一步实验展示了模型的能力,并在 RWKV 和其他 LLM 之间绘制了行为的类比。
RWKV 为在序列数据中建模复杂关系提供了一条可扩展和高效的新路。虽然已经提出了许多与 Transformer 类似的替代方案,但我们的方法是第一个通过拥有数十亿个参数的预训练模型来支持这些主张的方法。
9. 限制
虽然我们提出的 RWKV 模型在训练和推理期间的内存效率方面表现出了有希望的结果,但未来的工作中应该承认并解决一些限制。
首先,RWKV 的线性注意力带来了显著的效率提升,但也可能限制模型在需要在非常长的上下文中回忆细微信息的任务中的性能。这是由于信息通过单个向量表示在许多时间步中传递,与标准 Transformer 的二次注意力中维持的完整信息相比。换句话说,模型的循环架构固有地限制了其查看先前 token 的能力,与传统的自注意机制相比。虽然学习的时间衰减有助于防止信息的丢失,但与完全的自注意相比,它在机制上存在限制。
此外,与标准 Transformer 模型相比,对提示工程的重视程度增加是本文的另一个限制。RWKV 中使用的线性注意力机制限制了从提示中传递到模型后续部分的信息。因此,精心设计的提示对模型在任务中表现良好可能更加关键。
上述 RWKV 属性通过附录 L 中提出的提示工程研究得到了确认。通过改变信息片段的顺序,我们甚至能够将某些任务的 RWKV 性能几乎提高一倍。