1. 语言模型之精简RNN结构
近期关注到,Yoshua Bengio发布了一篇论文《Were RNNs All We Needed?》,提出简化版RNN(minLSTM和minGRU)。该工作的初始缘由:Transformer 在序列长度方面的扩展性限制重新引发了对可在训练期间并行化的循环序列模型的兴趣。最近一段时间,许多新的循环架构,如 S4、Mamba 和 Aaren,已经被提出并取得了相当的性能。
因此在Bengio这项工作中,重新对传统循环神经网络 (RNN)进行了分析和调整,LSTMs(1997年)和 GRUs(2014年)这些模型由于需要通过时间反向传播(BPTT)而运行缓慢,但实验结果表明,通过移除其输入、遗忘和更新门上的隐藏状态依赖,LSTMs 和 GRUs 不再需要 BPTT,并且可以高效并行训练。在此基础上,论文中引入了最简版本的 minLSTMs 和 minGRUs,(1) 它们使用的参数明显少于传统模型,(2) 在训练期间可以完全并行化(对于长度为 512 的序列,速度提高了 175 倍)。最后展示了这些简化版的 RNN 能够与近期的序列模型在性能上匹敌。
论文考虑了一种语言建模任务。在这种设置中,使用 nanoGPT(Karpathy,2022)框架,在莎士比亚作品上训练字符级别的 GPT。在上图 中,绘制了交叉熵损失的学习曲线,比较了所提出的最简 LSTM 和 GRU(minLSTM 和 minGRU)与 Mamba 和 Transformer 的表现。可以发现 minGRU、minLSTM、Mamba 和 Transformer 分别达到了 1.548、1.555、1.575 和 1.547 的可比测试损失。Mamba 的表现稍微逊色于其他模型,但在早期阶段训练得更快,在 400 步时达到了最佳性能,而 minGRU 和 minLSTM 分别在 575 和 625 步时继续训练。相比之下,Transformer 的训练速度明显更慢,需要 2000 步(大约 2.5 倍于 minGRU)才能达到相似的性能,这使得它的训练显著更慢且资源消耗更大(相较于 minGRU、minLSTM 和 Mamba 的线性复杂度,Transformer 的复杂度是二次的)。
2. 对传统RNN模型进行精简
常见的循环神经网络结构主要有这三种:
2.1 LSTM网络结构
Hochreiter 和 Schmidhuber (1997) 引入了长短期记忆网络(LSTM)。LSTM 是增强型的循环神经网络(RNN),旨在缓解梯度消失问题,从而允许模型学习长期依赖关系。LSTM 的计算方式如下:
其中,⊙ 表示向量的逐元素相乘,t 为当前时间步, 为输出的隐藏状态, 表示将输入 与上一时间步的隐藏状态 进行拼接, 是隐藏状态的维度, 是在整个序列中维持信息的细胞状态,是候选的细胞状态, 和 分别是控制输入、遗忘和输出的门机制。输入门 控制从候选细胞状态中添加多少新信息,遗忘门 决定丢弃细胞状态中的多少信息,输出门 决定细胞状态中的哪些信息应该输出。σ 和 tanh 用于缩放,以确保输出不会爆炸或消失。一个 LSTM 模块同时维护细胞状态和隐藏状态,总共包含 O(4dh(dx + dh)) 个参数。
解决梯度弥散问题
在RNN中,由于隐藏状态的逐步传递,导致梯度在每个时间步的反向传播中可能逐渐衰减,从而出现梯度弥散问题。LSTM通过以下方式避免这一问题:
长时间的梯度传递: 由于细胞状态 直接通过 传递,LSTM能够在较长的序列中保持较大的梯度。
门控机制的影响:
由于 和 控制着信息的流动和保留,LSTM能够有效地调整梯度的大小,防止它们在传播过程中逐渐消失。短路连接的作用: 短路连接使得梯度在反向传播时可以直接传递,这样避免了通过多个层次传递梯度时的衰减。LSTM中的细胞状态 ,它在时间步之间传递。细胞状态通过短路连接直接传递,不受激活函数的影响。这允许信息在时间步之间以恒定的比例传递,防止信息的消失或爆炸。
在这个公式中: 是遗忘门,决定了上一个细胞状态 中的信息在当前细胞状态中的保留程度。 是输入门,决定了新候选状态 在当前细胞状态中的贡献。
2.2 GRU网络结构
为了简化 LSTM,Cho 等人(2014 年)提出了门控循环单元(GRU),它仅使用两个门控机制和一个状态,取代了 LSTM 的三个门控机制和两个状态(隐藏状态和细胞状态)。GRU 的简化使其在许多任务中能够以更快的训练和推理时间实现竞争性的性能。GRU 的计算方式如下:
其中,是候选的隐藏状态,代表隐藏状态可能的新值。GRU 将 LSTM 的遗忘门和输入门合并为一个更新门 ,它决定了要携带多少过去的信息(即 ),以及从候选隐藏状态中添加多少新信息(即 )。此外,GRU 移除了 LSTM 的输出门,取而代之的是增加了一个重置门 ,用于控制在计算候选隐藏状态时使用多少过去的信息。GRU 减少了参数和计算量,仅需 O(3dh(dx + dh)) 个参数。然而,GRU 和 LSTM 只能顺序计算。因此,在训练期间,它们需要通过时间反向传播(BPTT)其梯度,导致线性训练时间,极大地限制了其扩展到长序列的能力。
关于BPTT: 反向传播通过时间(Backpropagation Through Time,BPTT)是一种用于训练递归神经网络(RNN)的算法,特别是LSTM和GRU等变种。BPTT的主要思想是将RNN展开为一个具有时间步的深度前馈网络,然后通过标准的反向传播算法来计算梯度。
2.3 循环神经网络结构的问题分析
从LSTM与GRU结构可以看出,由于只能顺序计算,因此训练性能受限。正因这一限制,Transformer 取代了 LSTM 和 GRU,成为多年来事实上的序列建模方法,因为它可以在训练期间实现并行化。然而,Transformer 的复杂度与序列长度呈二次关系,限制了其在长序列上下文中的扩展能力。最近,许多新的循环模型作为 Transformer 的替代方案被提出,这些模型不仅性能相当,还可以并行训练,并且避免了传统 RNN(如 LSTM 和 GRU)面临的时间反向传播(BPTT)问题。尽管提出了许多不同的架构,其中许多模型都可以使用并行前缀扫描算法(Blelloch,1990)高效训练。并行扫描算法是一种用于通过关联运算符 ⊕(例如加法 "+" 和乘法 "×")从 N 个顺序数据点计算 N 个前缀计算结果的并行计算方法。可以将并行扫描方法应用于高效计算一个常见的函数族:,其中 、 和 属于实数域 R,且 (Heinsen,2023)。该方法以 和 作为输入,通过并行扫描计算出 。
2.4 网络精简
上述算法也可以扩展到向量形式:,其中 是元素级乘法。可以看到 GRU 和 LSTM 的状态递归类似于这种向量形式。通过简化并移除它们在各种门控机制中的一些隐藏状态依赖关系,可以使用并行扫描训练 GRU 和 LSTM。在此基础上,进一步简化了这些 RNN,移除了它们对输出范围的限制(例如,tanh),并确保输出在时间上的尺度无关性。结合这些步骤,论文提出了 GRU 和 LSTM 的简化版本(minGRU 和 minLSTM),它们可以通过并行扫描进行训练,且性能与 Transformer 及最近提出的序列模型相当。
2.4.1 简化的 GRU:minGRU
2.4.1.1 第一步:去除门控机制中对先前隐藏状态的依赖
回顾 GRU 的隐藏状态递归,它的计算如下:
可以观察到该递归类似于前述的并行扫描公式,其中 ,,而 。然而, 和 依赖于先前的隐藏状态 ,即:
因此,不能直接将并行扫描算法应用于这种情况,因为算法的输入 和 是条件化的,需要知道输出 。可以通过简化 GRU 来解决这一问题,移除其对先前隐藏状态 的依赖。具体的修改如下:
通过移除候选隐藏状态 中对 的依赖,控制 权重的重置门 也不再需要,因此被移除。没有对先前隐藏状态的依赖后,算法的输入 和 都可以轻松地并行计算,从而可以通过并行扫描高效地计算 。
2.4.1.2 第二步:去除候选状态的范围限制
在 GRU 的隐藏状态递归中,从先前隐藏状态继承的比例 和为新的候选隐藏状态 添加的量相加为 1。因此,GRU 的隐藏状态值的规模是与时间无关的。相反,其隐藏状态的规模依赖于候选隐藏状态 的规模。双曲正切函数(tanh)在 LSTM 和 GRU 中起着关键作用,限制了(候选)隐藏状态的范围,即。tanh 帮助稳定训练,并缓解由对隐藏状态的线性变换应用 sigmoid()激活所导致的梯度消失问题(例如,。
在前一步中,移除了这些隐藏状态的依赖性。因此,可以进一步简化 GRU,移除(候选)隐藏状态上的范围限制(tanh),如下所示:
2.4.1.3 minGRU
结合这两步简化,得到 GRU 的最简版本(minGRU):
最终得到的模型相比原始 GRU 显著更高效:(1)仅需 的参数,而 GRU 则需 的参数,其中 和 分别对应 和 的大小。就训练而言,minGRU可以通过并行扫描算法进行并行训练,大大加快了训练速度。在 T4 GPU 上,针对长度为 512 的序列,训练步数加速达到了 175 倍。参数效率的提升也相当显著。通常,在 RNN 中会进行状态扩展(即 ,其中 ),以便模型更容易从输入中学习特征。当 或 时,minGRU 分别仅使用 GRU 参数的约 33%、22%、17% 或 13%。
2.4.2 简化版 LSTM: minLSTM
2.4.2.1 第一步:移除门控中对先前隐藏状态的依赖
回顾 LSTM 中的单元状态递归,其计算如下:
类似于 GRU 的隐藏状态,可以看到 LSTM 的单元状态递归与之前提到的并行扫描公式类似:
,其中 , , 且 。
然而,、 和依赖于先前的隐藏状态 ,因此 LSTM 的单元状态递归无法直接应用并行扫描算法。可以通过类似 GRU 的方式移除对隐藏状态的依赖,具体如下:
简化为:
2.4.2.2 第二步:去掉候选状态的范围限制
与 GRU 类似,LSTM 使用双曲正切函数(tanh)将状态值限制在 (-1, 1) 之间。LSTM 在两处应用了范围限制:一次是在计算候选单元状态时,另一次是在计算隐藏状态时。在此步骤中,去掉这两处限制:
简化为:
2.4.2.3 第三步:确保输出的尺度时间无关性
在许多序列建模任务中(如文本生成),优化目标的尺度是时间无关的。回顾 LSTM 的单元状态递归公式:
,其中、,
以及 GRU 的隐藏状态递归:
,
其中 。
GRU 通过保持 与 之和为 1,确保其输出(即隐藏状态)的尺度与时间无关。相比之下,LSTM 的遗忘门和输入门是独立计算的(例如, 或 ),这使得其单元状态的尺度随时间变化,增加了优化难度。
因此,论文提出通过对两个门进行归一化来确保 LSTM 输出的尺度与时间无关。具体做法如下:
确保,从而使 LSTM 的单元状态的尺度与时间无关。此外,去掉了对隐藏状态进行尺度调整的输出门 。没有输出门后,归一化的隐藏状态等于单元状态,即:
因此隐藏状态和单元状态的同时存在变得不再必要,故移除了单元状态。最终的简化如下:
在这里,"与时间无关"(time-independent) 的意思是指模型输出的尺度(即输出值的大小范围)不随着时间变化而变化。这与模型在序列数据上处理信息时,如何保留和更新状态有关。
具体来说,LSTM 和 GRU 模型的状态更新涉及在每个时间步对当前和之前的隐藏状态(或单元状态)进行加权组合。对于 GRU,它通过控制参数 和 的总和为 1 来确保在每个时间步中,上一时间步的状态和当前时间步的候选状态按照比例加权组合,这样输出的尺度在每个时间步都不会出现大的波动,保持了一致性(即与时间无关)。
相比之下,传统的 LSTM 并没有这种明确的限制。LSTM 的遗忘门和输入门 是独立计算的,它们的值并不要求加和为1。这意味着某些时间步上,LSTM 可能会完全忽略上一时间步的状态(即 ),或者完全忽略当前时间步的新信息(即 )。这样,LSTM 的输出尺度可能随着时间步的不同而变化,导致模型的输出和状态的变化幅度变得不稳定,这种时间依赖性增加了优化的难度。
为了使 LSTM 的输出尺度也像 GRU 一样不依赖时间,论文中通过对遗忘门和输入门进行归一化处理,使它们的和总是等于1。这就类似于 GRU 中 和 的关系,确保了在每个时间步上,上一时间步的状态和当前时间步的状态按照固定比例组合,从而输出的尺度不会随时间变化。这就是"与时间无关"的含义。
简单来说,时间无关性意味着:在整个序列的处理过程中,模型的输出值大小不会因时间步的变化而出现较大的波动,保持一致的输出范围,这有助于模型的稳定性和训练过程
2.4.2.4 minLSTM
结合前述三步,形成最简化版的 LSTM(minLSTM):
最简化版的 minLSTM 在效率上显著提升:首先,它只需要 个参数,而相比之下,传统 LSTM 需要 个参数。此外,minLSTM 可以使用并行扫描算法进行训练,大大加速了训练过程。对于长度为 512 的序列,minLSTM 在 T4 GPU 上的训练速度比传统 LSTM 提升了 235 倍。就参数效率而言,当 = 1, 2, 3, 4 且 时,minLSTM 只使用了传统 LSTM 参数量的 38%、25%、19% 或 15%。
3. 伪代码及torch代码
3.1 minGRU
3.2 minLSTM
3.3 Parallel scan
并行扫描:对数空间实现。
并行扫描的目标是计算 ,其中 。在代码中,原始的并行扫描函数输入参数为:系数 和值 ,输出为 。为了数值稳定性,论文考虑了对数空间的实现,输入改为 和 ,输出则仍为 。下方提供了基于 Heinsen(2023)的代码实现的对数空间并行扫描函数。
4. 参考材料
【1】Were RNNs All We Needed?