用于语言建模的 Transformers 的替代方案
Transformer 架构一直是大型语言模型 (LLMs) 成功的主要组成部分。它已被用于当今几乎所有LLMs正在使用的产品,从 Mistral 等开源模型到 ChatGPT 等闭源模型。
为了进一步改进LLMs,开发了新的架构,甚至可能超过Transformer架构。其中一种方法是 Mamba,一种状态空间模型。
状态空间模型的基本体系结构。
Mamba 是在论文 Mamba: Linear-Time Sequence Modeling with Selective State Spaces 中提出的。您可以在其存储库中找到其官方实现和模型检查点。
在这篇文章中,我将介绍语言建模背景下的状态空间模型领域,并逐一探索概念,以发展对该领域的直觉。然后,我们将介绍曼巴如何挑战变形金刚架构。
作为视觉指南,期待许多可视化来发展对曼巴和状态空间模型的直觉!
第 1 部分:变压器的问题
为了说明为什么 Mamba 是一个如此有趣的架构,让我们先对 transformers 做一个简短的回顾,并探讨它的缺点之一。
Transformer 将任何文本输入视为由标记组成的序列。
变形金刚的一个主要好处是,无论它接收到什么输入,它都可以回顾序列中的任何早期标记来推导出它的表示。
变压器的核心部件
请记住,Transformer 由两个结构组成,一组用于表示文本的编码器块和一组用于生成文本的解码器块。这些结构一起可用于多项任务,包括翻译。
我们可以采用这种结构来创建仅使用解码器的生成模型。这个基于 Transformer 的模型,即生成式预训练转换器 (GPT),使用解码器块来完成一些输入文本。
让我们来看看它是如何工作的!
训练的祝福...
单个解码器模块由两个主要组件组成,一个是掩蔽的自注意力,然后是前馈神经网络。
自我关注是这些模型如此有效的主要原因。它可以通过快速训练实现整个序列的未压缩视图。
那么它是如何工作的呢?
它创建一个矩阵,将每个令牌与之前的每个令牌进行比较。矩阵中的权重取决于令牌对之间的相关性。
在训练期间,此矩阵是一次性创建的。在计算“name”和“is”之间的注意力之前,不需要先计算“My”和“name”之间的注意力。
它支持并行化,从而大大加快了训练速度!
还有推理的诅咒!
然而,有一个缺陷。在生成下一个标记时,即使我们已经生成了一些标记,我们也需要重新计算整个序列的注意力。
为长度为 L 的序列生成令牌大约需要 L² 计算,如果序列长度增加,成本可能会很高。
重新计算整个序列的需求是 Transformer 架构的主要瓶颈。
让我们看看一种“经典”技术,递归神经网络,如何解决这个缓慢推理的问题。
RNN是一种解决方案吗?
递归神经网络 (RNN) 是一种基于序列的网络。它需要序列中每个时间步长的两个输入,即时间步长t的输入和前一个时间步长t-1的隐藏状态,以生成下一个隐藏状态并预测输出。
RNN 具有循环机制,允许它们将信息从上一步传递到下一步。我们可以“展开”这个可视化,使其更加明确。
在生成输出时,RNN 只需要考虑之前的隐藏状态和当前输入。它可以防止重新计算所有以前的隐藏状态,而这正是 Transformer 会做的事情。
换句话说,RNN 可以快速进行推理,因为它与序列长度呈线性关系!从理论上讲,它甚至可以具有无限的上下文长度。
为了说明这一点,让我们将 RNN 应用于我们之前使用的输入文本。
每个隐藏状态都是所有先前隐藏状态的聚合,通常是一个压缩视图。
但是,有一个问题......
请注意,在生成名称“Maarten”时,最后一个隐藏状态不再包含有关单词“Hello”的信息。随着时间的流逝,RNN 往往会忘记信息,因为它们只考虑一个先前的状态。
RNN的这种顺序性带来了另一个问题。训练不能并行进行,因为它需要按顺序一次完成每个步骤。
与变形金刚相比,RNN 的问题完全相反!它的推理速度非常快,但不可并行化。
我们能否以某种方式找到一种架构,既能像 Transformers 一样并行训练,同时仍能执行随序列长度线性缩放的推理?
是的!这就是 Mamba 提供的功能,但在深入研究其架构之前,让我们先探索一下 State Space Models 的世界。
第 2 部分:状态空间模型 (SSM)
状态空间模型 (SSM) 与 Transformer 和 RNN 一样,处理信息序列,如文本和信号。在本节中,我们将介绍 SSM 的基础知识以及它们与文本数据的关系。
什么是状态空间?
状态空间包含完全描述系统的最小变量数。它是一种通过定义系统的可能状态来以数学方式表示问题的方法。
让我们稍微简化一下。想象一下,我们正在迷宫中航行。“状态空间”是所有可能位置(状态)的地图。每个点都代表迷宫中一个独特的位置,并带有特定的细节,例如您离出口有多远。
“状态空间表示”是此地图的简化描述。它显示您所处的位置(当前状态)、下一步可以转到的位置(可能的未来状态)以及哪些更改将您带到下一个状态(向右或向左)。
尽管状态空间模型使用方程和矩阵来跟踪这种行为,但它只是一种跟踪您在哪里、可以去哪里以及如何到达那里的方法。
描述状态的变量,在我们的示例中,X 和 Y 坐标,以及到出口的距离,可以表示为“状态向量”。
听起来很熟悉?这是因为语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,当前位置的向量(状态向量)可能看起来有点像这样:
就神经网络而言,系统的“状态”通常是其隐藏状态,在大型语言模型的上下文中,这是生成新令牌的最重要方面之一。
什么是状态空间模型?
SSM 是用于描述这些状态表示的模型,并根据某些输入预测它们的下一个状态。
传统上,在时间 t 时,SSM:
- 映射输入序列 x(t) — (例如,在迷宫中向左和向下移动)
- 到潜在状态表示 h(t) — (例如,到出口的距离和 x/y 坐标)
- 并推导出预测的输出序列 y(t) — (例如,再次向左移动以更快地到达出口)
但是,它不是使用离散序列(例如向左移动一次),而是将连续序列作为输入并预测输出序列。
SSM 假设动态系统,例如在 3D 空间中移动的物体,可以通过两个方程从其在时间 t 的状态预测。
通过求解这些方程,我们假设我们可以揭示统计原理,根据观察到的数据(输入序列和先前状态)预测系统的状态。
它的目标是找到这个状态表示 h(t),以便我们可以从一个输入序列到一个输出序列。
这两个方程是状态空间模型的核心。
本指南将引用这两个方程。为了使它们更直观,它们被颜色编码,以便您可以快速引用它们。
状态方程根据输入对状态的影响(通过矩阵 B)来描述状态如何变化(通过矩阵 A)。
正如我们之前所看到的,h(t) 指的是我们在任何给定时间 t 下的潜在状态表示,而 x(t) 指的是一些输入。
输出方程描述了状态如何转换为输出(通过矩阵 C)以及输入如何影响输出(通过矩阵 D)。
注意:矩阵 A、B、C 和 D 通常也被称为参数,因为它们是可学习的。
可视化这两个方程为我们提供了以下架构:
让我们逐步了解一般技术,以了解这些矩阵如何影响学习过程。
假设我们有一些输入信号 x(t),该信号首先乘以矩阵 B,矩阵 B 描述了输入如何影响系统。
更新的状态(类似于神经网络的隐藏状态)是一个潜在空间,其中包含环境的核心“知识”。我们将状态乘以矩阵 A,矩阵 A 描述了所有内部状态是如何连接的,因为它们代表了系统的基本动态。
您可能已经注意到,矩阵 A 在创建状态表示之前应用,并在状态表示更新后更新。
然后,我们使用矩阵 C 来描述如何将状态转换为输出。
最后,我们可以利用矩阵 D 来提供从输入到输出的直接信号。这通常也称为跳过连接。
由于矩阵 D 类似于跳跃连接,因此没有跳跃连接的 SSM 通常被视为以下内容。
回到我们的简化视角,我们现在可以将矩阵 A、B 和 C 作为 SSM 的核心。
我们可以像以前一样更新原始方程(并添加一些漂亮的颜色)来表示每个矩阵的用途。
这两个方程旨在从观测数据中预测系统的状态。由于输入应该是连续的,因此 SSM 的主要表示是连续时间表示。
从连续信号到离散信号
如果有一个连续的信号,那么找到状态表示 h(t) 在分析上是具有挑战性的。此外,由于我们通常有一个离散的输入(如文本序列),我们希望对模型进行离散化。
为此,我们使用了零阶保持技术。它的工作原理如下。首先,每次我们接收到离散信号时,我们都会保持其值,直到我们收到新的离散信号。此过程将创建一个 SSM 可以使用的连续信号:
我们保持该值的时间由一个新的可学习参数表示,称为步长∆。它表示输入的分辨率。
现在我们的输入有一个连续信号,我们可以生成一个连续输出,并且只根据输入的时间步长对值进行采样。
这些采样值是我们的离散输出!
在数学上,我们可以按如下方式应用零阶保持:
总之,它们允许我们从连续 SSM 转变为离散 SSM,该公式由一个公式表示,该公式不再是函数到函数 x(t) → y(t),而是序列到序列 xk → yk:
在这里,矩阵 A 和 B 现在表示模型的离散化参数。
我们使用 k 而不是 t 来表示离散化的时间步长,并在我们引用连续 SSM 与离散 SSM 时使其更加清晰。
注意:在训练期间,我们仍在保存矩阵 A 的连续形式,而不是离散化版本。在训练期间,连续表示被离散化。
现在我们已经有了离散表示的公式,让我们来探讨一下如何实际计算模型。
经常性表示
我们的离散化 SSM 允许我们以特定的时间步长而不是连续信号来表述问题。正如我们之前在 RNN 中看到的那样,一种循环方法在这里非常有用。
如果我们考虑离散时间步长而不是连续信号,我们可以用时间步长重新表述问题:
在每个时间步,我们计算当前输入 (Bxk) 如何影响先前状态 (Ahk₋₁),然后计算预测输出 (Chk)。
这种表示方式可能已经有点熟悉了!我们可以像之前看到的那样处理 RNN 的方式处理它。
我们可以这样展开(或展开):
请注意,我们如何使用 RNN 的底层方法来使用这个离散化版本。
这种技术为我们提供了 RNN 的优点和缺点,即快速推理和慢训练。
卷积表示
我们可以用于 SSM 的另一种表示是卷积。请记住,在经典的图像识别任务中,我们应用了过滤器(内核)来派生聚合特征:
由于我们处理的是文本而不是图像,因此我们需要一个一维视角:
我们用来表示这个“过滤器”的内核是从 SSM 公式派生而来的:
让我们来探讨一下这个内核在实践中是如何工作的。与卷积一样,我们可以使用 SSM 内核遍历每组令牌并计算输出:
这也说明了填充可能对输出产生的影响。我更改了填充的顺序以改善可视化效果,但我们经常在句子末尾应用它。
在下一步中,内核移动一次以执行计算的下一步:
在最后一步中,我们可以看到内核的完整效果:
将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络 (CNN) 一样并行训练。然而,由于内核大小是固定的,它们的推理不像 RNN 那样快速和无界。
三种表现形式
这三种表示形式,连续的、递归的和卷积的,都有不同的优点和缺点:
有趣的是,我们现在可以使用递归 SSM 进行有效的推理,并使用卷积 SSM 进行可并行化训练。
有了这些表示,我们可以使用一个巧妙的技巧,即根据任务选择表示。在训练过程中,我们使用可以并行化的卷积表示,在推理过程中,我们使用高效的循环表示:
此模型称为线性状态空间层 (LSSL)。
这些表示具有一个重要的属性,即线性时间不变性 (LTI) 的属性。LTI 指出,SSM 参数 A、B 和 C 对于所有时间步长都是固定的。这意味着 SSM 生成的每个令牌的矩阵 A、B 和 C 都是相同的。
换言之,无论您给 SSM 提供什么序列,A、B 和 C 的值都保持不变。我们有一个不感知内容的静态表示。
在我们探讨 Mamba 如何解决这个问题之前,让我们先来探讨一下拼图的最后一块,矩阵 A。
矩阵 A 的重要性
可以说,SSM公式中最重要的方面之一是矩阵A。正如我们之前在循环表示中看到的那样,它捕获有关先前状态的信息以构建新状态。
从本质上讲,矩阵 A 产生隐藏状态:
因此,创建矩阵 A 可能是只记住几个以前的标记和捕获我们迄今为止看到的每个标记之间的区别。特别是在递归表示的上下文中,因为它只回顾以前的状态。
那么,我们如何才能以保留大内存(上下文大小)的方式创建矩阵 A?
我们使用 Hungry Hungry Hippo!或用于高阶多项式投影算子的 HiPPO。HiPPO试图将迄今为止看到的所有输入信号压缩成一个系数向量。
它使用矩阵 A 来构建一个状态表示,该表示可以很好地捕获最近的标记并衰减较旧的标记。其公式可以表示如下:
假设我们有一个方阵 A,这给了我们:
使用 HiPPO 构建矩阵 A 被证明比将其初始化为随机矩阵要好得多。因此,与旧信号(初始代币)相比,它更准确地重建了较新的信号(最近的代币)。
HiPPO矩阵背后的想法是,它会产生一种隐藏状态,可以记住它的历史。
在数学上,它通过跟踪勒让德多项式的系数来实现,这允许它近似所有以前的历史。
然后将 HiPPO 应用于我们之前看到的循环和卷积表示,以处理长程依赖关系。结果是序列的结构化状态空间 (S4),这是一类可以有效处理长序列的 SSM。
它由三部分组成:
- 状态空间模型
- 用于处理长距离依赖关系的 HiPPO
- 用于创建递归和卷积表示的离散化
此类 SSM 有几个优点,具体取决于您选择的表示形式(递归与卷积)。它还可以通过基于 HiPPO 矩阵来处理长文本序列并有效地存储内存。
注意:如果您想深入了解有关如何计算 HiPPO 矩阵并自己构建 S4 模型的更多技术细节,我强烈建议您阅读带注释的 S4。
第 3 部分:Mamba — 选择性 SSM
我们终于涵盖了了解曼巴岛的特别之处所需的所有基础知识。状态空间模型可用于对文本序列进行建模,但仍然有一系列我们想要防止的缺点。
在本节中,我们将介绍 Mamba 的两个主要贡献:
- 一种选择性扫描算法,允许模型过滤(不)相关信息
- 一种硬件感知算法,允许通过并行扫描、内核融合和重新计算来有效存储(中间)结果。
它们共同创建了选择性的 SSM 或 S6 模型,这些模型可以像自注意力一样用于创建 Mamba 块。
在探讨这两个主要贡献之前,让我们先探讨一下为什么它们是必要的。
它试图解决什么问题?
状态空间模型,甚至 S4(结构化状态空间模型),在某些对语言建模和生成至关重要的任务上表现不佳,即专注于或忽略特定输入的能力。
我们可以用两个合成任务来说明这一点,即选择性复制和感应头。
在选择性复制任务中,SSM 的目标是复制输入的各个部分并按顺序输出它们:
但是,(递归/卷积)SSM 在此任务中表现不佳,因为它是线性时间不变的。正如我们之前所看到的,SSM 生成的每个令牌的矩阵 A、B 和 C 都是相同的。
因此,SSM 无法执行内容感知推理,因为它将每个令牌视为固定 A、B 和 C 矩阵的结果。这是一个问题,因为我们希望 SSM 对输入(提示)进行推理。
SSM 表现不佳的第二项任务是感应头,其目标是重现输入中的模式:
在上面的例子中,我们基本上是在执行一次性提示,我们试图“教导”模型在每个“Q:”之后提供“A:”响应。但是,由于 SSM 是时间不变的,因此它无法从其历史记录中选择要调用的先前令牌。
让我们通过关注矩阵 B 来说明这一点。 无论输入 x 是什么,矩阵 B 都保持完全相同,因此与 x 无关:
同样,无论输入如何,A 和 C 也保持固定。这证明了我们迄今为止所看到的SSM的静态性质。
相比之下,这些任务对变形金刚来说相对容易,因为它们会根据输入序列动态地改变注意力。它们可以选择性地“查看”或“参加”序列的不同部分。
SSM 在这些任务上的糟糕性能说明了时间不变 SSM 的潜在问题,矩阵 A、B 和 C 的静态性质导致内容感知问题。
有选择地保留信息
SSM 的循环表示创建了一个非常有效的小状态,因为它压缩了整个历史记录。然而,与不压缩历史记录(通过注意力矩阵)的 Transformer 模型相比,它的强大功能要小得多。
曼巴的目标是两全其美。一个与 Transformer 状态一样强大的小状态:
如上所述,它通过有选择地将数据压缩到状态中来实现。当你有一个输入句子时,通常有一些信息,如停用词,没有太多意义。
为了有选择地压缩信息,我们需要参数依赖于输入。为此,让我们首先在训练期间探索 SSM 中输入和输出的维度:
在结构化状态空间模型 (S4) 中,矩阵 A、B 和 C 与输入无关,因为它们的维度 N 和 D 是静态的,不会改变。
取而代之的是,Mamba 通过合并输入的序列长度和批量大小,使矩阵 B 和 C 甚至步长∆取决于输入:
这意味着对于每个输入令牌,我们现在都有不同的 B 和 C 矩阵,这解决了内容感知的问题!
注意:矩阵 A 保持不变,因为我们希望状态本身保持静态,但受其影响的方式(通过 B 和 C)是动态的。
它们一起有选择地选择要保持隐藏状态的内容和要忽略的内容,因为它们现在依赖于输入。
较小的步长∆会导致忽略特定单词,而是更多地使用以前的上下文,而较大的步长∆更关注输入单词而不是上下文:
扫描操作
由于这些矩阵现在是动态的,因此无法使用卷积表示来计算它们,因为它假设有一个固定的内核。我们只能使用递归表示,而失去卷积提供的并行化。
为了启用并行化,让我们探讨如何使用 recurrency 计算输出:
每个状态是前一个状态(乘以 A)加上当前输入(乘以 B)的总和。这称为扫描操作,可以很容易地用for循环来计算。
相比之下,并行化似乎是不可能的,因为只有当我们拥有前一个状态时,才能计算每个状态。然而,Mamba 通过并行扫描算法使这成为可能。
它假定我们通过 associate 属性执行操作的顺序无关紧要。因此,我们可以计算部分的序列并迭代组合它们:
动态矩阵 B 和 C 以及并行扫描算法共同创建了选择性扫描算法,以表示使用递归表示的动态和快速性质。
硬件感知算法
最近 GPU 的一个缺点是,它们在小型但高效的 SRAM 和大型但效率稍低的 DRAM 之间的传输 (IO) 速度有限。频繁地在SRAM和DRAM之间复制信息成为瓶颈。
Mamba 和 Flash Attention 一样,试图限制我们从 DRAM 到 SRAM 所需的次数,反之亦然。它通过内核融合来实现,这允许模型防止写入中间结果并持续执行计算,直到完成。
我们可以通过可视化 Mamba 的基本架构来查看 DRAM 和 SRAM 分配的具体实例:
在这里,将以下内容融合到一个内核中:
- 步长∆的离散化步长
- 选择性扫描算法
- 用 C 乘法
硬件感知算法的最后一部分是重新计算。
中间状态不会保存,但对于向后传递计算梯度是必需的。相反,作者在向后传递期间重新计算这些中间状态。
虽然这看起来效率低下,但它比从相对较慢的 DRAM 中读取所有这些中间状态的成本要低得多。
我们现在已经涵盖了其架构的所有组件,该组件使用其文章中的下图进行描述:
选择性 SSM。取自:Gu、Albert 和 Tri Dao。“Mamba: Linear-time sequence modeling with selective state spaces.” arXiv 预印本 arXiv:2312.00752 (2023).
此架构通常称为选择性 SSM 或 S6 模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。
The Mamba Block(曼巴街区酒店)
到目前为止,我们已经探索的选择性SSM可以作为一个块来实现,就像我们可以在解码器块中表示自注意力一样。
与解码器一样,我们可以堆叠多个 Mamba 块,并使用它们的输出作为下一个 Mamba 块的输入:
它从线性投影开始,以扩展输入嵌入。然后,在选择性 SSM 之前应用卷积以防止独立的令牌计算。
选择性 SSM 具有以下属性:
- 通过离散化创建的循环 SSM
- 矩阵 A 上的 HiPPO 初始化以捕获长程依赖关系
- 选择性扫描算法,选择性地压缩信息
- 硬件感知算法,加快计算速度
在查看代码实现时,我们可以进一步扩展此体系结构,并探索端到端示例的样子:
请注意一些更改,例如包含规范化层和用于选择输出令牌的 softmax。
当我们把所有东西放在一起时,我们既可以快速推理和训练,甚至可以获得无限的上下文!
使用这种架构,作者发现它与相同尺寸的 Transformer 模型的性能相匹配,有时甚至超过!
其他资源
希望这是对曼巴和国家空间模型的无障碍介绍。如果您想更深入地了解,我建议您使用以下资源:
- Annotated S4 是 JAX 实现和指导 S4 模型,强烈建议您使用!
- 一个很棒的 YouTube 视频,通过基础论文来介绍 Mamba。
- 在 Hugging Face 上带有检查点的 Mamba 存储库。
- 一系列介绍 S4 模型的精彩博客文章(1、2、3)。
- Mamba No5 (A Little Bit Of...) 博客文章是深入研究有关 Mamba 的更多技术细节的下一步,但仍然从令人惊讶的直观角度出发。
- 当然,还有曼巴纸!它甚至被用于DNA建模和语音生成。