更多内容,请关注微信公众号:NLP分享汇
原文链接:xLSTM: Extended Long Short-Term Memory
论文链接:https://arxiv.org/pdf/2405.04517
为什么要在27年后提出新的LSTM呢?
LSTM(长短期记忆网络)自20世纪90年代引入以来,在深度学习领域取得了巨大成功。然而,随着Transformer技术的出现,LSTM在规模化应用中的性能开始落后。
【感兴趣的小伙伴可以去读读之前写的这篇文章】
LSTM介绍
NLP分享汇,公众号:NLP分享汇长短期记忆网络 Long Short-Term Memory
所以,文章提出了一个问题:如果将LSTM扩展到数十亿参数规模,并利用LLMs的最新技术,同时解决LSTM的已知限制,那么在语言建模方面LSTM能走多远?
为了解决LSTM的限制,作者介绍了xLSTM,它通过以下两种主要改进来增强LSTM的能力
-
指数门控(exponential gating)引入了适当的归一化和稳定技术,改进了LSTM的门控机制
-
记忆结构修改(LSTM memory structure)对LSTM的记忆结构进行了修改,引入了两种新的记忆单元
-
sLSTM:具有标量记忆、标量更新和新的记忆混合技术。
-
mLSTM:具有矩阵记忆和协方差更新规则,能够完全并行化。
-
随后通过将这些LSTM扩展集成到残差块中,形成了xLSTM块,然后将这些块残差堆叠形成xLSTM架构。通过引入指数门控和修改记忆结构,xLSTM在性能和规模化方面能够与最先进的Transformer和状态空间模型相媲美。
一张图了解前世今生(xLSTM Family)
图1 xLSTM 家族
图1左侧展示了传统的LSTM记忆单元,包括恒等误差旋转(Constant Error Carousel)和门控(Gating)机制。这些是LSTM的核心组成部分,用于处理和存储信息。在原始LSTM的基础上,引入了两种新的记忆单元sLSTM和mLSTM。然后将sLSTM和mLSTM记忆单元整合到残差块(Residual Blocks)中,形成了xLSTM Blocks。这些块是xLSTM架构的基本构建单元。最后通过残差堆叠(Residual Stacking)这些xLSTM块,形成了完整的xLSTM架构。这种架构设计用于处理长序列数据,并且能够扩展到数十亿参数规模。
记忆单元的特点:
-
恒等误差旋转:表示LSTM中细胞状态的更新方式,通过输入和遗忘门控制。
-
Sigmoid门控:用于控制LSTM中信息的流动,包括输入、遗忘和输出门。
-
循环推断:指的是LSTM在推断阶段的循环计算过程。
-
循环训练:指的是LSTM在训练阶段的循环计算过程。
LSTM的局限性以及xLSTM是如何克服这些限制
图2 三方较量图
图2左侧比较的是最近邻搜索问题(Nearest Neighbor Search problem)。LSTM在处理最近邻搜索问题时的局限性。这个问题要求模型在给定一个参考向量的情况下,顺序扫描序列以找到最相似的向量,并在序列末尾返回其附加值。均方误差(Mean Squared Error, MSE),这是衡量模型预测与实际值差异的指标。LSTM在找到更相似的向量时难以修正已存储的值,而新的xLSTM通过指数门控机制克服了这一限制。
图2右侧比较的是稀有词预测问题(Rare Token Prediction)。LSTM在预测Wikitext-103数据集中稀有词的性能问题。图中使用困惑度(Perplexity, PPL)来衡量不同词频桶(buckets of token frequency)上的预测性能。展示了不同词频下的PPL,LSTM在预测稀有词(即出现频率较低的词)时性能较差,这是因为其有限的存储容量。新的xLSTM通过矩阵记忆机制解决了这个问题,能够更有效地处理稀有词的预测。
Extended Long Short-Term Memory
LSTM
-
LSTM的引入:LSTM最初是为了解决循环神经网络(RNN)中的梯度消失问题而提出的。它通过引入一个称为“恒等误差旋转”(constant error carousel)的机制,以及输入、遗忘和输出门控(gating)来维持长期依赖关系。
-
LSTM的核心方程:文章提供了LSTM的核心更新规则,包括单元状态(cell state)和隐藏状态(hidden state)的计算公式。这些方程定义了LSTM如何在时间步t更新其内部状态。
-
门控机制:LSTM包含三个门控:输入门(input gate)、遗忘门(forget gate)和输出门(output gate)。输入门控制新信息的流入,遗忘门决定哪些信息应该从单元状态中被遗忘,输出门则控制从单元状态到隐藏状态的转换。
-
权重和偏置:LSTM的每个门控都与权重向量(weight vectors)和偏置项(bias terms)相关联。权重向量用于处理输入和隐藏状态之间的交互,而偏置项则为模型提供了一个初始状态。
-
激活函数:LSTM使用特定的激活函数,如tanh,来规范化或压缩单元状态,确保其值保持在一定的范围内。sigmoid函数用于门控机制,以生成介于0和1之间的值,表示门控的开启程度。
sLSTM
-
指数门控(Exponential Gating):为了使LSTM能够修订其存储决策,sLSTM引入了指数激活函数的门控机制。与传统的sigmoid门控不同,指数门控可以产生更大的值,从而允许模型更灵活地更新其内部状态。
-
归一化和稳定化技术:由于指数激活函数可能导致数值溢出,sLSTM引入了一个归一化状态(normalizer state),该状态对输入门和所有未来遗忘门的乘积进行求和。此外,为了稳定门控,sLSTM使用了一个额外的稳定器状态(stabilizer state)来控制门控的值。
-
sLSTM的前向传播方程:文章列出了sLSTM的核心更新规则,包括单元状态(cell state)、归一化状态(normalizer state)和隐藏状态(hidden state)的计算公式。这些方程定义了sLSTM如何在时间步t更新其内部状态。
-
记忆混合(Memory Mixing):sLSTM允许通过循环连接(recurrent connections)进行记忆混合,这在原始的LSTM中是不可能的。这种新的记忆混合技术允许sLSTM在不同的内存单元之间共享信息。
-
多头部结构(Multi-Head Structure):sLSTM可以具有多个头部,每个头部都有自己的记忆混合,但头部之间没有跨头的记忆混合。这种设计为sLSTM提供了一种新的记忆混合方式。
-
门控的激活和稳定化:sLSTM的输入门和遗忘门可以使用指数激活函数,而输出门仍然使用sigmoid函数。为了稳定化这些门控,sLSTM使用了一种广播机制,将原始的LSTM门控技术和偏差项扩展到新的架构中。
-
sLSTM的参数和计算:sLSTM的每个门控都与权重向量和偏置项相关联,这些参数在模型训练过程中进行学习。激活函数和稳定化状态的引入增加了模型的计算复杂性,但同时也提高了其表达能力。
mLSTM
-
矩阵记忆(Matrix Memory):mLSTM的核心创新是将传统的标量记忆单元扩展为矩阵记忆单元。这允许mLSTM以矩阵形式存储更多的信息,从而提高了模型的存储容量。
-
协方差更新规则(Covariance Update Rule):mLSTM使用协方差更新规则来存储和检索信息。这种规则通过将键(key)和值(value)对存储为矩阵的行或列来实现,从而提高了检索的分离度和信号/噪声比。
-
Bidirectional Associative Memories (BAMs):mLSTM的设计灵感来源于双向联想记忆模型,它使用矩阵乘法来实现信息的存储和检索。
-
mLSTM的前向传播方程:文章列出了mLSTM的核心更新规则,包括矩阵记忆状态(cell state)、归一化状态(normalizer state)和隐藏状态(hidden state)的计算公式。这些方程定义了mLSTM如何在时间步t更新其内部状态。
-
并行化(Parallelization):与sLSTM不同,mLSTM的设计允许完全的并行化处理,因为它放弃了隐藏层之间的循环连接(memory mixing)。这使得mLSTM可以更高效地在现代硬件上进行训练和推理。
-
输入门和遗忘门:mLSTM的输入门和遗忘门可以使用指数激活函数,而输出门仍然使用sigmoid函数。这些门控机制允许模型控制信息的流入和遗忘。
-
稳定化技术:为了稳定化mLSTM中的指数门控,文章采用了与sLSTM相同的稳定化技术。
-
多头部和多单元结构:mLSTM可以扩展为多个头部和多个单元,其中多个头部和多个单元在mLSTM中是等效的,因为它没有跨头部的记忆混合。
xLSTM Architectur
图3 xLSTM block
左边是sLSTM残差块,这种块的设计模仿了Transformer架构中的残差连接和上投影(post up-projection)。输入首先进入一个sLSTM层,可能在进入LSTM之前会经过一个卷积层(Convolution)。接着是一个门控的多层感知机(gated MLP),这是Transformer中常见的组件。
右边是mLSTM残差块,这种块的设计类似于状态空间模型(State Space Models),其中上投影发生在mLSTM层之前。mLSTM被包裹在两个多层感知机(MLPs)之间,通过卷积层、可学习的跳跃连接(skip connection)和一个逐元素作用的输出门(output gate)。
在这两种类型的xLSTM块中,都使用了残差连接,这是从原始的LSTM架构中继承而来的特性,有助于训练更深的网络。
xLSTM架构的内存和速度考虑
与Transformer相比,xLSTM网络具有线性的计算复杂性和恒定的内存复杂性,这使得它们在处理长序列时更加高效。xLSTM的内存是压缩的,这使得它非常适合工业应用和在边缘设备上的实现。mLSTM的内存不需要参数,但是由于其d×d的矩阵内存和更新,计算上较为昂贵。尽管mLSTM可以并行化,类似于FlashAttention或其他并行注意力机制,但sLSTM由于内存混合(hidden-hidden connections)而无法并行化。为了解决sLSTM的非并行化问题,作者开发了一个快速的CUDA实现,包括GPU内存优化,使得sLSTM的性能通常不超过mLSTM的两倍。文章还讨论了两种xLSTM块的架构,一种是后上投影(post up-projection),另一种是前上投影(pre up-projection)。这些设计分别适用于sLSTM和mLSTM,以最大化它们的性能。尽管mLSTM的矩阵内存计算复杂,但由于可以在GPU上并行处理,因此对实际的墙钟时间(wall clock time)影响较小。
实验部分
图4 xLSTM’s exponential gating with memory mixing
图4是文章中关键的实验结果之一,它直观地展示了xLSTM在处理需要复杂记忆和状态跟踪的任务时的优势。这些结果证明了xLSTM在语言建模和序列处理任务中的潜力,特别是在那些对传统LSTM构成挑战的任务中。
图4中的任务根据乔姆斯基层级(Chomsky hierarchy)进行了分组,这是一种描述形式语言的表达能力的方式。不同模型在解决这些任务时的准确率被标准化在0到1之间,其中0表示随机猜测,1表示完美解决。
图中比较了几种xLSTM变体,包括只有sLSTM的架构(xLSTM[0:1]),只有mLSTM的架构(xLSTM[1:0]),以及两者结合的架构(xLSTM[1:1])。除了xLSTM,还包括了其他几种模型的性能,如Llama、Mamba、RWKV-4、RWKV-5、RWKV-6、LSTM(Block)和传统的LSTM。
通过比较,可以看出xLSTM在处理需要状态跟踪的任务时的性能优势。例如,Transformer或没有状态跟踪的状态空间模型(SSMs)在解决某些正则文法任务(如奇偶性任务)时表现不佳。
图4的结果表明,xLSTM通过其指数门控和记忆混合机制,能够有效地解决状态跟踪问题,这是其在形式语言任务中表现出色的关键。
图5 memory capacities of different models at the Multi-Query Associative Recall task with context length 2048
图5展示了xLSTM在处理多查询关联回忆(Multi-Query Associative Recall, MQAR)任务时的内存容量性能。MQAR任务要求模型记忆一系列随机选择的键值对(key-value pairs),并在稍后根据给定的键(key)回忆(recall)相应的值(value)。这个任务考验了模型的内存容量,尤其是它能够在多大程度上存储和回忆信息。
图5比较了不同模型在MQAR任务上的性能,包括Llama、Mamba、RWKV-5、RWKV-6、xLSTM[1:1]和xLSTM[1:0]。不同模型维度(Model Dim)下的性能,模型维度可能指的是模型的大小或复杂度。
y轴表示模型在回忆键值对时的准确率,准确率越高,表明模型的内存容量越好。x轴表示模型需要记忆的键值对的数量,从32到512不等。键值对的数量越多,任务的难度越大。
模型在每个键值对数量设置下的性能通过验证准确率来评估。准确率越高,表明模型在处理给定数量的键值对时越有效。xLSTM[1:1]在所有非Transformer模型中表现最佳,即使是在小模型尺寸下。此外,sLSTM块并没有降低内存容量,反而在处理最困难的任务(256键值对)时,它的性能更加明显。
表1 comparison on next token prediction when trained on 15B tokens from SlimPajama
表1提供了在SlimPajama数据集上训练的不同语言模型的比较,特别是在15B个token上训练时的性能。表中的#Params列显示了每个模型的参数数量。表中的“SlimPajama (15B) ppl ↓”列显示了每个模型在SlimPajama数据集的验证集上的困惑度。困惑度是衡量语言模型性能的一个指标,越低表示模型性能越好。结果显示,xLSTM[1:0]和xLSTM[7:1]在验证集困惑度上表现最佳,这意味着它们在这些模型中具有最低的困惑度,因此在给定的数据集上表现最好。
图6 comparison on next token prediction when trained on 15B tokens from SlimPajama
图6展示了不同语言模型在SlimPajama数据集上进行下一个词预测任务时的性能比较,特别是在不同模型大小下的验证集困惑度(Perplexity)。X轴表示模型的参数数量,通常以10亿(109)为单位。模型大小的不同可能会影响其在任务上的表现。Y轴表示验证集上的困惑度,这是一个衡量语言模型性能的指标,用于评估模型对真实数据分布的拟合程度。困惑度越低,表示模型的性能越好。
随着模型大小增加,各模型在验证集上的困惑度如何变化。这些趋势线可以帮助我们理解模型规模对性能的影响。
表2 Ablation studies
表2提供了xLSTM模型的消融研究(Ablation Study)结果,这些研究旨在评估xLSTM中不同组件对整体性能的贡献。
顶部表:xLSTM新组件的消融研究
-
模型修改:列出了对原始LSTM模型所做的修改,以逐步构建xLSTM模型。
-
指数门控(Exponential Gating):指是否在模型中使用了指数门控机制。
-
矩阵内存(Matrix Memory):指是否在模型中使用了矩阵内存结构。
-
参数数量(#Params M):显示了每个修改后的模型的参数数量。
-
SlimPajama (15B) ppl ↓:展示了在SlimPajama数据集上,使用15B个token训练后,每个模型在验证集上的困惑度(Perplexity)。困惑度越低,表示模型性能越好。
-
结果分析:通过比较不同修改后的模型,可以观察到指数门控和矩阵内存对性能提升的贡献。
底部表:不同门控技术的消融研究
-
可学习门控(Learnable Gates):指模型中的门控是否是可学习的。
-
遗忘门(Forget Gate):指模型是否使用了遗忘门。
-
输入门(Input Gate):指模型是否使用了输入门。
-
偏差初始化(Bias Init):展示了门控偏差的初始化方式,这对于模型的学习动态有重要影响。
-
SlimPajama (15B) ppl ↓:同样显示了在SlimPajama数据集上,使用15B个token训练后,每个模型在验证集上的困惑度。
-
结果分析:通过比较不同门控配置的模型,可以了解哪些门控特性对性能提升更为关键。
图7 Sequence extrapolation in language modeling
图7展示了不同大语言模型(LLMs)在序列长度外推(Sequence Length Extrapolation)方面的性能。序列长度外推是指模型在训练时使用一定长度的上下文,然后在测试时评估模型处理更长上下文的能力。这是衡量模型能否有效利用其学习到的表示来处理未见过的长序列的重要指标。X轴表示测试时使用的上下文长度,从2048一直到16384。模型在训练时使用的上下文长度为2048。Y轴表示模型在不同上下文长度下的困惑度。困惑度是衡量语言模型性能的一个指标,用于评估模型对真实数据分布的拟合程度。困惑度越低,表示模型的性能越好。
根据图7,xLSTM模型在所有测试的上下文长度下都保持了较低的困惑度,表明其在序列长度外推任务上的性能优于其他模型。xLSTM在处理比训练时更长的上下文时,仍然能够维持较低的困惑度,这表明xLSTM具有良好的外推能力。
表3 Validation set perplexity and downstream tasks
表3展示了不同语言模型在SlimPajama数据集上的性能比较,特别是在不同模型大小下的验证集困惑度(Perplexity)和下游任务(Downstream Tasks)的性能。表中显示了各个模型在不同的下游任务上的性能。这些任务包括常识推理、文本分类、问答系统等,用于衡量模型在特定任务上的表现。
表4 Performance on PALOMA Language Modeling Tasks
表4展示了不同语言模型在PALOMA语言任务上的下一个词预测性能,特别是在不同模型大小下的困惑度(Perplexity)。表中列显示了每个模型在PALOMA数据集的不同文本域上的下一个词预测困惑度。困惑度越低,表示模型在该文本域上的性能越好。表中最后一列显示了每个模型在所有PALOMA文本域上的平均困惑度。通过比较不同模型在同一文本域或平均困惑度上的性能,可以评估哪种模型结构更适合特定的语言任务。
图8 Scaling Laws
图8在展示了不同语言模型在不同模型大小下的验证集困惑度(Perplexity)与参数数量的关系,这是对模型缩放行为(Scaling Laws)的分析。X轴表示模型的参数数量。这个参数范围可能涵盖了从较小到较大的模型尺寸。Y轴表示模型在验证集上的困惑度,这是一个衡量语言模型性能的指标,用于评估模型对真实数据分布的拟合程度。困惑度越低,表示模型的性能越好。通过观察曲线的斜率和分布,可以分析不同模型在缩放时的性能趋势。一般来说,如果曲线较为平缓,表明模型在增加参数后性能提升的边际效益较低。例如,xLSTM模型可能在所有尺寸下都展现出较低的困惑度,表明其在缩放时保持了较高的效率。
xLSTM局限性
sLSTM由于其内存混合特性,无法进行并行化操作。这与mLSTM形成对比,后者由于放弃了内存混合,可以实现并行处理。尽管为sLSTM开发了一个快速的CUDA实现,其速度仍然大约是并行mLSTM实现的1.5倍。对于mLSTM,当前的CUDA内核未经过优化,导致其性能大约是FlashAttention或Mamba中使用的scan操作的4倍慢。mLSTM的矩阵内存具有较高的计算复杂性,因为它需要处理d×d的矩阵。尽管矩阵内存的更新和检索不使用参数,并且可以并行化,但计算复杂性仍然是一个考虑因素。
遗忘门的初始化必须谨慎选择,以避免训练过程中的不稳定问题。
由于矩阵内存独立于序列长度,增加序列长度可能会在更长的上下文中对内存造成过载。然而,对于长达16k的上下文,这似乎并不是一个限制。
由于大语言模型实验的计算成本,xLSTM的架构和超参数并未完全优化。作者预计,为了使xLSTM发挥其全部潜力,需要一个广泛的优化过程。