(2023|EMNLP,RWKV,Transformer,RNN,AFT,时间依赖 Softmax,线性复杂度)

RWKV: Reinventing RNNs for the Transformer Era

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

2. 背景

2.1 循环神经网络 (RNN)

2.2 Transformer 和 AFT

3. RWKV

3.1 架构

3.1.1 Token 移位

3.1.2 WKV 算子

3.1.3 输出门控制

3.2 类 Transformer 训练

3.3 类 RNN 的推理

3.4 附加优化

3.5 实现

4. 实验 

7. 未来工作

8. 结论

9. 限制


0. 摘要

Transformer 已经彻底改变了几乎所有自然语言处理(NLP)任务,但其内存和计算复杂度随序列长度呈二次增长。与之相对,比起 Transformer,循环神经网络(RNN)的内存和计算需求呈线性增长,但由于在并行化和可扩展性方面的限制,RNN 难以达到 Transformer 的性能。

我们提出了一种新颖的模型架构,称为接收加权键值(Receptance Weighted Key Value,RWKV),它结合了 Transformer 的高效并行训练和 RNN 的高效推理。我们的方法利用线性注意力机制,使得模型可以被表述为 Transformer 或 RNN,从而在训练过程中并行计算,并在推理过程中保持恒定的计算和内存复杂度。

我们将模型扩展到多达 140 亿个参数,成为迄今为止训练过的最大的密集 RNN,并发现 RWKV 的性能与类似规模的 Transformer 相当。这表明未来的研究可以利用这一架构创建更高效的模型。这项工作在解决序列处理任务中的计算效率和模型性能之间的权衡方面迈出了重要的一步。

项目页面:https://github.com/BlinkDL/RWKV-LM

2. 背景

这里我们简要回顾 RNN 和 Transformer 的基本原理。

2.1 循环神经网络 (RNN)

流行的 RNN 架构如 LSTM(Hochreiter和Schmidhuber,1997)和 GRU(Chung等,2014)具有以下特征(以 LSTM 为例,其他可以类似推理):

尽管 RNN 可以分解为两个线性模块(W 和 U)和一个 RNN 特定模块(1)–(6),如 Bradbury等(2017)所指出的,依赖于先前时间步的数据依赖性禁止了对这些典型RNN的并行化处理。

(2024,LSTM,Transformer,指数门控,归一化器状态,多头内存混合)xLSTM:扩展的 LSTM 

2.2 Transformer 和 AFT

由 Vaswani 等(2017)引入的 Transformer 是一类神经网络,已成为多个 NLP 任务的主导架构。与逐步处理序列的 RNN 不同,Transformer 依靠注意力机制来捕捉所有输入和所有输出标记之间的关系: 

为了方便起见,多头机制和缩放因子 1 / √d_k 被省略了。核心的 QK^T 乘法是序列中每个标记之间的成对注意力得分的集合,可以分解为向量运算: 

无注意力 Transformer(Attention Free Transformer,AFT)(Zhai 等,2021)将注意力机制替代性地表述为:

其中 {w_(t,i)} ∈ R^(T×T) 是学习到的成对位置 bias,每个 w_(t,i) 是一个标量。

受 AFT 的启发,RWKV 采用了类似的方法。然而,为了简化,它修改了交互权重,以便可以将其转化为 RNN。RWKV 中的每个 w_(t,i) 是一个按通道的时间衰减向量,当衰减时乘以相对位置并从当前时间向后追溯:

其中 d 是通道数。我们要求 w 为非负,以确保 e^(w_(t,i)) ≤ 1,并且每通道的权重在时间上向后衰减。

3. RWKV

RWKV 模型架构由四个基本元素定义,这些元素本质上是时间混合和通道混合模块的组成部分:

  • R: 接收向量,用于接收过去的信息。
  • W: 权重,表示位置权重衰减向量,是模型中的一个可训练参数。
  • K: 键向量,其作用类似于传统注意力机制中的 K。
  • V: 值向量,其功能类似于传统注意力机制中的 V。

这些核心元素在每个时间步长上进行乘法交互,如图 2 所示。

3.1 架构

RWKV 模型由堆叠的残差块组成。每个块包含一个时间混合和一个通道混合子块,体现了利用过去信息的递归结构。该模型使用了一种独特的类似注意力的得分更新过程,包括一个时间依赖的softmax 操作,提高了数值稳定性并减轻了梯度消失问题(严格证明见附录 H)。这确保了梯度沿着最相关的路径传播。此外,架构中结合的层归一化(Ba 等,2016)有助于稳定梯度,有效解决了梯度消失和梯度爆炸问题。这些设计元素不仅增强了深度神经网络的训练动态,还促进了多层堆叠,捕捉不同抽象层次上的复杂模式,从而比传统 RNN 模型表现更优(另见附录 I)。

3.1.1 Token 移位

在此架构中,涉及计算的所有线性投影向量(时间混合中的 R、K、V 以及通道混合中的 R'、K')都是通过当前时间步和先前时间步输入之间的线性插值生成的,便于 token 移位(shift)。用于时间混合计算的向量是当前和先前输入的线性组合的线性投影

通道混合的输入也是如此: 

token 移位在每个块的时间维度上通过简单的 offset 实现,使用 PyTorch(Paszke等,2019)库中的 nn.ZeroPad2d((0,0,1,-1))

3.1.2 WKV 算子

我们模型中 WKV 算子的计算方法类似于 Attention Free Transformer(AFT)(Zhai等,2021)中使用的方法。然而,与 AFT 中 W 为成对矩阵不同,我们的模型将 W 视为一个按通道的向量,并根据相对位置进行修改。在我们的模型中,这种递归行为由 WKV 向量的时间依赖更新定义,形式化为以下方程:

为了规避 W 的任何潜在降级,我们引入了一个向量 U,它关注当前 token。关于此的更多信息可以在附录 I 中找到。

3.1.3 输出门控制

输出门控制在时间混合和通道混合块中都使用接收向量的 sigmoid 函数 σ(r) 实现。WKV 算子后的输出向量 o_t 由以下公式给出: 

在通道混合块中,执行类似的操作:

这里我们采用了平方 ReLU 激活函数(So 等,2021)。

3.2 类 Transformer 训练

RWKV 可以通过一种称为时间并行模式的技术进行高效并行化,类似于 Transformer。在单层中处理一批序列的时间复杂度为 O(BTd^2),主要由矩阵乘法 W_λ 组成,其中 λ ∈ {r, k, v, o}(假设 B 个序列,T 个最大标记,d 个通道)。相比之下,更新注意力分数 wkv_t 涉及串行扫描(详见附录 D 以获取更多细节),具有复杂度 O(BTd)。

矩阵乘法可以类似于传统 Transformer 中 的 W_λ 进行并行化,其中 λ ∈ {Q, K, V, O} 。元素级的 WKV 计算是时间依赖的,但可以沿其他两个维度(Lei 等,2018)轻松并行化。

3.3 类 RNN 的推理

RNN 通常利用状态 t 处的输出作为状态 t + 1 处的输入。这种用法也观察到在语言模型的自回归解码推理中,其中每个标记必须在传递到下一步之前计算。RWKV 利用了这种类似 RNN 的结构,称为时间顺序模式。在这种情况下,RWKV 可以在推理期间方便地递归地进行编码,如附录 D 所示。

3.4 附加优化

自定义内核。为了解决使用标准深度学习框架时由任务的顺序性质引起的 WKV 计算中的低效率,我们开发了一个定制的 CUDA 内核。该内核使得在训练加速器上执行单个计算内核成为可能,而模型的所有其他部分,如矩阵乘法和逐点操作,已经是固有的可并行化和高效的。

小初始化嵌入(Small Init Embedding)。在训练 Transformer 模型(Vaswani 等,2017)的初始阶段,我们观察到嵌入矩阵的变化速度较慢,这对模型摆脱初始噪声嵌入状态构成了挑战。为了解决这个问题,我们提出了一种方法,该方法涉及使用小值初始化嵌入矩阵,随后应用额外的 LayerNorm 操作。这加速和稳定了训练过程,允许使用后 LN 组件训练深层架构。这种方法的有效性在图 9 中得到了证明,该图说明通过使模型迅速摆脱初始小嵌入状态,实现了改进的收敛。这是通过在单个步骤中发生的小变化实现的,随后在 LayerNorm 操作之后导致了方向上的实质性变化和进一步的显着变化。

自定义初始化。建立在先前工作(He等,2016;Jumper等,2021)的原理之上,我们采用了一种初始化策略,其中参数被设置为类似于标识映射的值,同时打破对称性以建立清晰的信息流。大多数权重被初始化为零,线性层不使用偏差。附录E中给出了详细的公式。我们观察到初始化的选择在收敛的速度和质量方面起着至关重要的作用(有关更多详细信息,请参阅附录 F)。

3.5 实现

RWKV 使用 PyTorch 深度学习库(Paszke等,2019)实现。我们将 DeepSpeed(Rasley 等,2020)启发的附加优化策略集成到系统中,提高了其效率和可扩展性。模型从一个嵌入层开始,如第 3.4 节所述。随后是若干相同的残差块按顺序排列。这些在图 2 和图 3 中描述,并且符合第 3.1.1 节中概述的原则。在最后一个块之后,使用简单的输出投影头进行逻辑生成,该头包括一个LayerNorm(Ba等,2016)和一个线性投影,用于下一个 token 的预测和在训练期间计算交叉熵损失。

4. 实验 

7. 未来工作

对于 RWKV 架构的未来工作有几个有希望的方向。可以通过增强时间衰减公式和探索初始模型状态的方式来增加模型的表达能力(expressivity),同时保持效率。可以通过在 wkv_t 步骤中应用并行扫描来进一步提高 RWKV 的计算效率,从而将计算成本降低到 O(B log(T)d)。

RWKV 所使用的机制可以应用于编码器-解码器架构,潜在地替代交叉注意力机制。这可能适用于 seq2seq 或多模态设置,从而增强训练和推理过程的效率。

可以利用 RWKV 的状态(或上下文)来提高序列数据的可解释性、可预测性和安全性。操纵隐藏状态也可以引导行为,并通过提示调整实现更大的可定制性。

RWKV 架构并不完美,可以通过修改公式或实现更大的内部状态等方面进行改进。更大的状态可以增强模型对先前上下文的记忆,并提高在各种任务上的性能。

8. 结论

我们介绍了 RWKV,这是一种利用基于时间的混合组件潜力的 RNN 模型的新方法。RWKV 引入了几个关键策略,使其能够捕捉局部性和长程依赖性,同时通过以下方式解决了当前架构的局限性:(1) 将二次的 QK 注意力替换为线性成本下的标量公式,(2) 重新定义了循环和顺序归纳偏差,以实现有效的训练并行化和有效的推理,(3) 使用自定义初始化增强训练动态。我们在各种 NLP 任务中对所提出的架构进行了基准测试,并展示了与 SoTA 相当的性能以及降低的成本。对表达能力、可解释性和扩展性的进一步实验展示了模型的能力,并在 RWKV 和其他 LLM 之间绘制了行为的类比。

RWKV 为在序列数据中建模复杂关系提供了一条可扩展和高效的新路。虽然已经提出了许多与 Transformer 类似的替代方案,但我们的方法是第一个通过拥有数十亿个参数的预训练模型来支持这些主张的方法。

9. 限制

虽然我们提出的 RWKV 模型在训练和推理期间的内存效率方面表现出了有希望的结果,但未来的工作中应该承认并解决一些限制。

首先,RWKV 的线性注意力带来了显著的效率提升,但也可能限制模型在需要在非常长的上下文中回忆细微信息的任务中的性能。这是由于信息通过单个向量表示在许多时间步中传递,与标准 Transformer 的二次注意力中维持的完整信息相比。换句话说,模型的循环架构固有地限制了其查看先前 token 的能力,与传统的自注意机制相比。虽然学习的时间衰减有助于防止信息的丢失,但与完全的自注意相比,它在机制上存在限制。

此外,与标准 Transformer 模型相比,对提示工程的重视程度增加是本文的另一个限制。RWKV 中使用的线性注意力机制限制了从提示中传递到模型后续部分的信息。因此,精心设计的提示对模型在任务中表现良好可能更加关键。

上述 RWKV 属性通过附录 L 中提出的提示工程研究得到了确认。通过改变信息片段的顺序,我们甚至能够将某些任务的 RWKV 性能几乎提高一倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/653672.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

零拷贝(Zero Copy)

目录 零拷贝(Zero Copy) 1.什么是Zero Copy? 2.物理内存和虚拟内存 3.内核空间和用户空间 4.Linux的I/O读写方式 4.1 I/O中断原理 4.2 DMA传输原理 5.传统I/O方式 5.1传统读操作 5.2传统写操作 6.零拷贝 6.1.用户态直接IO 6.2.mmapwrite …

The First项目报告:解读去中心化衍生品交易所AVEO

2023 年12月8日凌晨,Solana 生态 MEV 基础设施开发商 Jito Labs 开放了 JTO 空投申领窗口,JTO 的价格在开盘短暂震荡后迅速攀高,一度触及 4.94 美元。 JTO 是加密社区这两日关注的热门标的,而在这场讨论中,除 Solana …

unity接入live2d

在bilibili上找到一个教程,首先注意一点,你直接导入那个sdk,并且打开示例,显示的模型是有问题的,你需要调整模型上脚本的一个枚举值,调整它的渲染顺序是front z to我看教程时候,很多老师都没有提…

python max_min标准化

python max_min标准化 max_min标准化sklearn实现max_min标准化手动实现max_min标准化 max_min标准化 Max-Min标准化(也称为归一化或Min-Max Scaling)是一种将数据缩放到特定范围(通常是0到1)的标准化方法。这种方法通过线性变换将…

【软考】下篇 第14章 云原生架构设计与理论实践

目录 一、云原生架构定义二、云原生架构原则三、云原生架构主要架构模式3.1 服务化架构模式3.2 Mesh化架构模式3.3 Serverless模式3.4 存储计算分离模式3.5 分布式事务模式4.6 可观测架构3.7 事件驱动架构 四、云原生架构反模式五、云原生架构技术5.1 容器技术容器编排K8S 5.2 …

Elasticsearch 分析器的高级用法二(停用词,拼音搜索)

Elasticsearch 分析器的高级用法二(停用词,拼音搜索) 停用词简介停用词分词过滤器自定义停用词分词过滤器内置分析器的停用词过滤器注意,有一个细节 拼音搜索安装使用相关配置 停用词 简介 停用词是指,在被分词后的词…

【umi-max】初识 antd pro

修改端口号 根目录下的 .env 文件: PORT8888目录结构 (umijs.org) 新增页面 在 umirc.ts 中进行配置。 新增页面 - Ant Design Pro 这里有一个配置 icon:string,可以在菜单加 icon 图标,默认使用 antd 的 icon 名,默认不适用二…

Yourpassword does not satisfy the current policyrequirements

mysql 新增数据库用户失败 解决方法: 修改校验密码策略等级 set global validate_password.policyLOW;

【K8s】专题四(1):Kubernetes 控制器简介

以下内容均来自个人笔记并重新梳理,如有错误欢迎指正!如果对您有帮助,烦请点赞、关注、转发!欢迎扫码关注个人公众号! 目录 一、基本概念 二、工作原理 三、常见类型 四、相关特性 一、基本概念 Kubernetes 控制器…

js中金额进行千分以及toFixed()保留两位小数丢失精度的问题

1、金额进行千分 function commafy(num) { if ((num "").trim() "") { return ""; } if (isNaN(num)) { return ""; } num num ""; if (/^.*\..*$/.test(num)) { const pointIndex num.lastIndexOf("."); co…

像素匹配+均值homograph+结果

1. 像素匹配 2. 均值homography 转换前转换后 3. 比较 基准图转换图

Kibana创建ElasticSearch 用户角色

文章目录 1, ES 权限参考2, 某应用的管理员权限:可以open/close/delete/cat/read/write 索引3, 某应用的读写权限:可以cat/read/write 索引 (不能删除索引或数据)4, 某应用的只读权限 1, ES 权限参考 https://www.elastic.co/gui…

Linux——Docker容器虚拟化平台

安装docker 安装 Docker | Docker 从入门到实践https://vuepress.mirror.docker-practice.com/install/ 不需要设置防火墙 docker命令说明 docker images #查看所有本地主机的镜像 docker search 镜像名 #搜索镜像 docker pull 镜像名 [标签] #下载镜像&…

智能奶柜:重塑牛奶零售新篇章

智能奶柜:重塑牛奶零售新篇章 回忆往昔,孩童时代对送奶员每日拜访的期待,那熟悉的一幕——新鲜牛奶被细心放置于家门口的奶箱中,成为了许多人温馨的童年记忆。如今,尽管直接投递袋装牛奶的情景已不多见,但…

机器学习-6-对随机梯度下降算法SGD的理解

参考一文带您了解随机梯度下降(Stochastic Gradient Descent):python代码示例 参考sklearn-SGDClassifier 1 梯度下降 在机器学习领域,梯度下降扮演着至关重要的角色。梯度下降是一种优化算法,通过迭代沿着由梯度定义的最陡下降方向,以最小化函数。类似于图中的场景,可以…

【自动驾驶技术栈学习】2-软件《大话自动驾驶》| 综述要点总结 by.Akaxi

----------------------------------------------------------------------------------------------------------------- 致谢:感谢十一号线人老师的《大话自动驾驶》书籍,收获颇丰 链接:大话自动驾驶 (豆瓣) (douban.com) -------------…

新版idea配置git步骤及项目导入

目录 git安装 下载 打开git Bash 配置全局用户名及邮箱 查看已经配置的用户名和邮箱 在IDEA中设置Git 问题解决 项目导入 git安装 下载 进入官网 Git - Downloads 点击所属本机系统,window如下图 选择64位安装 按照默认步骤一直下一步即可 打开git Bash …

2024下半年BRC-20铭文发展趋势预测分析

自区块链技术诞生以来,其应用场景不断扩展,代币标准也在不断演进。BRC-20铭文作为基于比特币区块链的代币标准,自其推出以来,因其安全性和去中心化特性,受到了广泛关注和使用。随着区块链技术和市场环境的不断变化&…

二零二四充能必读 | 618火热来袭,编程书单助你提升代码力

文章目录 📘 Java领域的经典之作🐍 Python学习者的宝典🌐 前端开发者的权威指南🔒 并发编程的艺术🤖 JVM的深入理解🏗 构建自己的编程语言🧠 编程智慧的结晶🌟 代码效率的提升 亲爱的…

【学习Day1】中央处理单元CPU

✍🏻记录学习过程中的输出,坚持每天学习一点点~ ❤️希望能给大家提供帮助~欢迎点赞👍🏻收藏⭐评论✍🏻指点🙏 中央处理单元CPU 中央处理器(CPU,central processing unit&#xff…