原文地址: A Visual Guide to Mamba and State Space Models
2024 年 2 月 19 日
论文地址:https://arxiv.org/pdf/2312.00752.pdf
这篇论文介绍了一种新型的线性时间序列模型Mamba,它通过选择性状态空间(Selective State Spaces, SSS)来提高处理长序列数据的效率和性能。Mamba模型的关键特点和贡献可以总结如下:
1. **背景与动机**:
- 传统的Transformer模型虽然在许多应用中表现出色,但在处理长序列时面临着计算效率低下的问题。
- 为了解决这个问题,研究者们开发了多种子二次时间复杂度的架构,如线性注意力、门控卷积和循环模型等,但这些模型在某些重要模态(如语言)上的性能仍不如Transformer。
- Mamba模型的提出旨在结合Transformer的强大性能和线性时间复杂度的效率。2. **选择性状态空间(SSS)**:
- Mamba模型引入了SSS机制,允许模型根据输入内容动态地选择性地传播或遗忘信息。
- 通过将SSM参数化为输入的函数,Mamba能够过滤不相关信息,并无限期记住相关信息,从而提高了模型在处理离散和信息密集型数据(如文本)时的性能。3. **硬件感知并行算法**:
- 为了克服选择性SSM在计算上的挑战,Mamba设计了一个硬件感知的并行算法,该算法采用递归模式而非卷积模式来计算模型。
- 该算法通过扫描操作来实现,避免了在GPU内存层次结构中不同级别之间的IO访问,从而提高了计算效率。4. **简化的端到端神经网络架构**:
- Mamba简化了传统的深度序列模型架构,将SSM的设计和Transformer的MLP块结合起来,形成了一个简单且同质的架构设计。
- 这种设计通过重复单一的Mamba块来构建,每个块都包含选择性状态空间,而不是像传统架构那样交替使用不同的模块。5. **跨模态性能**:
- Mamba模型在多个模态上都取得了优异的性能,包括语言、音频和基因组数据。
- 在语言建模任务中,Mamba在预训练和下游评估中都超越了相同规模的Transformer模型,并且与规模是其两倍的Transformer模型性能相当。6. **开源代码和预训练模型**:
- 作者公开了Mamba模型的代码和预训练模型,以便研究者和开发者可以复现和扩展Mamba的工作。
- 这有助于推动社区对Mamba模型的进一步研究和应用。7. **实验验证**:
- 论文通过一系列实验验证了Mamba模型的有效性,包括合成任务、语言建模、DNA建模和音频建模等。
- 实验结果表明,Mamba在处理长序列数据时具有显著的性能优势,并且在多个领域中都能达到最先进的结果。总结来说,Mamba模型通过其创新的选择性状态空间机制、硬件感知的计算策略和简化的架构设计,有效地解决了长序列数据建模中的效率和性能问题,为深度学习中的序列建模提供了一个强有力的新工具。它的跨模态性能和开源的实现进一步增强了其在深度学习序列建模领域的潜力和实用性。
Mamba模型是一种新型的序列建模架构,它通过以下几个关键创新点来提高长序列数据处理的效率和性能:
1. **选择性状态空间(Selective State Spaces, SSS)**:
- Mamba模型引入了选择性状态空间的概念,这是一种允许模型根据当前输入动态调整其状态传播机制的方法。
- 通过将结构化状态空间模型(SSM)的参数设置为输入的函数,Mamba能够有选择地传播或忘记信息,从而有效地过滤掉不相关的输入数据,并保留重要的信息。
- 这种机制使得Mamba在处理具有高信息密度的序列数据(如文本和基因组数据)时表现出色,因为它可以进行基于内容的推理。2. **硬件感知并行算法**:
- 为了解决选择性SSM在计算上的挑战,Mamba设计了一个硬件感知的并行算法,该算法采用递归模式(recurrent mode)而非卷积模式(convolutional mode)来计算模型。
- 该算法通过扫描(scan)操作来实现,避免了在GPU内存层次结构中不同级别之间的IO访问,从而提高了计算效率。
- 通过这种硬件感知的方法,Mamba在现代硬件(如GPU)上实现了快速的计算和内存效率。3. **简化的端到端神经网络架构**:
- Mamba简化了传统的深度序列模型架构,将SSM的设计和Transformer的MLP块结合起来,形成了一个简单且同质的架构设计。
- 这种设计通过重复单一的Mamba块来构建,其中每个块都包含选择性状态空间,而不是像传统架构那样交替使用不同的模块。
- 这种简化的架构不仅提高了模型的性能,而且使得模型更容易训练和部署。4. **跨模态性能**:
- Mamba模型不仅在语言建模任务上取得了与Transformer相当的性能,还在音频和基因组数据的建模任务中超越了先前的最先进的模型。
- 这表明Mamba模型具有很强的泛化能力,能够处理多种类型的序列数据,并且能够在不同的领域中发挥作用。5. **开源代码和预训练模型**:
- 为了促进研究和应用的发展,作者公开了Mamba模型的代码和预训练模型,使得其他研究人员和开发者可以轻松地复现和扩展Mamba的工作。
- 开源的代码库提供了模型的实现细节,包括模型架构、训练过程和评估方法,这有助于社区进一步探索和改进Mamba模型。总结来说,Mamba模型通过其创新的选择性状态空间机制、硬件感知的计算策略和简化的架构设计,为长序列数据的高效处理提供了一种新的解决方案。它的跨模态性能和开源的实现进一步增强了其在深度学习序列建模领域的潜力和实用性。
Mamba是在论文Mamba: Linear-Time Sequence Modeling with Selective State Spaces中提出的。你可以在其官方仓库repository中找到它的实现和模型检查点。
本文旨在深入探讨状态空间模型在语言建模中的应用,并逐步解释相关概念,旨在帮助读者对该领域有一个清晰的认识。随后,我们将探讨Mamba模型如何可能对传统的Transformer架构构成挑战。
作为一本侧重于视觉辅助的指南,本文将提供大量的图表和可视化资料,以便读者更好地理解Mamba和状态空间模型的工作原理。
目录
-
第 1 部分:Transformers的问题
-
Transformers的核心组件
-
A Blessing with Training
-
推理的缺陷
-
RNN 是一种解决方案吗?
-
-
第 2 部分:状态空间模型 (SSM)
-
什么是状态空间?
-
什么是状态空间模型?
-
从连续信号到离散信号
-
循环表征
-
卷积表示
-
三种表征
-
矩阵 A 的重要性
-
-
第 3 部分:Mamba - 选择性 SSM
-
它试图解决什么问题?
-
有选择地保留信息
-
扫描操作
-
硬件感知算法
-
The Mamba Block
-
第 1 部分:Transformers的问题
Transformer模型将所有文本输入视为一系列的标记组成。
Transformer模型的一大优势在于,它能够回顾并利用序列中任何早期标记的信息,以此来生成每个标记的表征。
Transformers的核心组件
Transformer模型由两部分核心结构组成:一是用于理解文本内容的编码器模块,二是用于生成文本输出的解码器模块。这两种结构通常联合使用,以应对包括机器翻译在内的多种语言处理任务。
A Blessing with Training
每个解码器模块主要包含两个核心部分:首先是进行掩码自注意力机制的处理,然后是前馈神经网络的运算。
自注意力机制是这些模型如此高效的主要原因之一。它能够在训练过程中迅速捕捉到整个序列的全局信息。
它是如何实现这一点的呢?
自注意力通过创建一个矩阵来工作,这个矩阵对序列中的每个标记与之前的每个标记进行比较。矩阵中的权重反映了各个标记之间的相互关联度。
在训练过程中,这个矩阵是同时构建的。这意味着在考虑“name”和“is”之间的关系时,并不需要先单独计算“My”和“name”之间的关系。
这种机制支持并行处理,极大地提升了训练速度!
推理的缺陷
然而,这种机制存在一个缺点。在生成序列中的下一个标记时,我们必须重新计算整个序列的注意力,即使某些标记已经被生成。
为长度为L的序列生成标记大约需要L²计算,如果序列长度增加,计算成本可能会很高。
这种需要重新计算整个序列的情况是 Transformer 架构的一个主要瓶颈。
用循环神经网络来解决推理速度慢的问题。
RNN 是一种解决方案吗?
循环神经网络(RNN)是一种按序列顺序处理数据的网络。在每个时间步,它需要两个输入:当前时间步t的输入数据,以及前一个时间步t-1的隐藏状态。这两个输入共同作用,生成新的隐藏状态,并据此预测输出。
RNN通过其循环结构实现信息在序列各步骤间的传递。为了更清晰地理解这一过程,我们可以将RNN的循环机制“展开”,从而更直观地展现信息的流动。
当RNN生成输出时,它只需考虑之前的隐藏状态和当前的输入,这样就避免了重新计算所有先前隐藏状态的需要,而这正是Transformer所做的。
换而言之,RNN在推理时能够快速响应,因为它的计算复杂度与序列长度成线性关系。理论上,它可以处理无限长的上下文。
为了展示这一特性,我们将对之前使用过的输入文本应用RNN进行处理。
每个隐藏状态都是对之前所有隐藏状态的综合,通常呈现为信息的压缩形式。
注意到,在生成名字“Maarten”时,最后一个隐藏状态已经不再包含关于单词“Hello”的信息。随着时间的推移,RNN倾向于遗忘信息,因为它们仅考虑前一个隐藏状态。
RNN的这种顺序处理特性还导致了另一个问题:训练无法并行进行,因为必须按顺序完成每个步骤。
与 Transformer 相比,RNN 的推理速度非常快,但不可并行化。
我们能否设计一种既能够像Transformer那样并行训练,又能在推理时保持与序列长度线性扩展的架构呢?
答案是肯定的!这正是Mamba模型所实现的,但在深入探讨其架构之前,让我们先对状态空间模型这一领域进行一番探索。
第 2 部分:状态空间模型 (SSM)
状态空间模型(SSM),与Transformer和RNN一样,用于处理信息序列,例如文本和信号。在本节中,我们将探讨SSM的基本概念以及它们如何与文本数据相互作用。
什么是状态空间?
状态空间是一组能够完整捕捉系统行为的最少变量集合。它是一种数学建模方法,通过定义系统的所有可能状态来表述问题。
让我们用一个更简单的例子来理解这个概念。想象我们正在走过一个迷宫。这里的“状态空间”就像是迷宫中所有可能位置的集合,即一张地图。地图上的每个点都代表迷宫中的一个特定位置,并包含了该位置的详细信息,比如离出口有多远。
而“状态空间表示”则是对这张地图的抽象描述。它告诉我们当前所处的位置(当前状态)、我们可以移动到哪些位置(未来可能的状态),以及如何从当前位置转移到下一个状态(比如向左转或向右转)。
虽然状态空间模型利用方程和矩阵来记录这种行为,但它们本质上是一种记录当前位置、可能的前进方向以及如何实现这些移动的方法。
描述状态的变量(在我们的迷宫例子中,这些变量可能是X和Y坐标以及与出口的相对距离)被称作“状态向量”。
这个概念听起来是不是有些耳熟?那是因为在语言模型中,我们经常使用嵌入或向量来描述输入序列的“状态”。例如,当前位置的向量(即状态向量)可能如下所示:
在神经网络的语境中,“状态”通常指的是网络的隐藏状态。在大型语言模型的背景下,隐藏状态是生成新标记的一个关键要素。
什么是状态空间模型?
状态空间模型(SSM)是用来描述这些状态表示,并根据给定的输入预测下一个可能状态的模型。
在传统意义上,SSM在时间t的工作方式如下:
- 它将输入序列x(t)(例如,迷宫中的左移和下移)映射到潜在的状态表示h(t)(例如,距离出口的远近以及X/Y坐标)。
- 然后它从这个状态表示中推导出预测的输出序列y(t)(例如,为了更快到达出口而再次左移)。
不过,与传统模型不同的是,SSM不仅处理离散的序列(比如一次向左移动),它还能够处理连续的序列作为输入,并预测输出序列。
状态空间模型(SSM)假定动态系统(比如在三维空间中移动的物体)的状态可以通过两个数学方程来预测,这两个方程描述了系统在时间t时的状态如何随时间演变。
通过解这两个方程,我们假设能够发掘出统计规律,从而根据观察到的数据(包括输入序列和先前的状态)来预测系统的状态。
其目标是确定状态表示h(t),以便我们能够从输入序列映射到输出序列。
这两个方程构成了状态空间模型的核心。
在本指南中,我们将多次提及这两个方程。为了帮助您更快地理解和引用它们,我们使用了颜色编码来突出显示。
状态方程展示了输入如何通过矩阵B影响状态,以及状态如何通过矩阵A随时间变化。
正如我们之前看到的,h(t)指的是任何给定时间t的潜在状态表示,而x(t)指的是某个输入。
输出方程描述了状态如何转换为输出(通过矩阵 C )以及输入如何影响输出(通过矩阵 D )。
注意:矩阵A 、B 、C和D通常也称为参数,因为它们是可学习的。
可视化这两个方程为我们提供了以下架构:
让我们逐步探究这些技术细节,以了解这些矩阵如何在学习过程中发挥作用。
设想我们有一个输入信号x(t),这个信号首先与矩阵B相乘,而矩阵B刻画了输入对系统的影响程度。
更新后的状态(类似于神经网络的隐藏状态)是一个潜在空间,它包含了环境的核心“知识”。我们将这个状态与矩阵A相乘,矩阵A揭示了所有内部状态是如何相互连接的,因为它们代表了系统的基本动态。
您可能已经注意到,矩阵A在创建状态表示之前被应用,并在状态表示更新之后进行更新。
接着,我们利用矩阵C来定义状态如何转换为输出。
最后,我们可以利用矩阵 D提供从输入到输出的直接信号。这通常也称为跳跃连接。
由于矩阵 D类似于跳跃连接,因此在没有跳跃连接的情况下,SSM 通常被视为如下。
回到我们的简化视角,我们现在可以关注矩阵A 、B和C作为 SSM 的核心。
我们可以像之前一样更新原始方程(并添加一些漂亮的颜色)来表示每个矩阵的用途。
这两个方程共同作用,目的是根据观测数据来预测系统的状态。由于输入被假定为连续的,状态空间模型的主要表现形式是连续时间表示。
从连续信号到离散信号
对于连续信号,直接找到状态表示h(t)在分析上可能颇具挑战。此外,由于我们通常处理的都是离散输入(比如文本序列),我们希望将模型转换为离散形式。
为了实现这一点,我们采用了零阶保持(Zero-Order Hold, ZOH)技术。其工作原理如下:每当接收到一个离散信号时,我们就保持该信号值不变,直到下一个离散信号的到来。这个过程实际上创建了一个SSM可以处理的连续信号。
我们保持信号值的时间由一个新的可学习参数表示,这个参数称为步长Δ。它代表了输入信号的分辨率。
现在,由于我们有了连续的输入信号,我们能够生成连续的输出信号,并且只需根据输入信号的时间步长来对输出值进行采样。
这些采样值构成了我们的离散输出。
从数学角度来看,我们可以按照以下方式应用零阶保持技术:
这些技术和方法使我们能够将连续状态空间模型(SSM)转换为离散形式,其公式不再是连续函数到函数的映射 x(t) → y(t),而是离散序列到序列的映射 xₖ → yₖ:
这里,矩阵A和B现在表示模型的离散参数。
我们使用k而不是t来表示离散时间步长,并在我们提到连续 SSM 与离散 SSM 时使其更加清晰。
注意:我们在训练期间仍然保存矩阵 A的连续形式,而不是离散化版本。在训练过程中,连续表示被离散化。
现在我们已经有了离散表示的公式,让我们探索如何实际计算模型。
循环表征
我们的离散 SSM 允许我们以特定的时间步长而不是连续信号来表述问题。正如我们之前在 RNN 中看到的那样,循环方法在这里非常有用。
如果我们考虑离散时间步长而不是连续信号,我们可以用时间步长重新表述问题:
在每个时间步,我们计算当前输入 ( Bx ₖ ) 如何影响先前的状态 ( Ahₖ₋₁ ),然后计算预测输出 ( Ch ₖ )。
这种表示可能已经有点熟悉了!我们可以像之前看到的 RNN 一样来处理它。
我们可以这样展开(或展开):
请注意我们如何使用 RNN 的基础方法来使用这个离散化版本。
这种技术给我们带来了 RNN 的优点和缺点,即快速推理和缓慢训练。
卷积表示
另一种可用于状态空间模型(SSM)的表示形式是卷积。回想一下,在传统的图像识别任务中,我们使用过滤器(或称为内核)来提取图像的聚合特征:
由于我们处理的是文本而不是图像,因此我们需要一维视角:
我们用来表示这个“过滤器”的内核源自 SSM 公式:
让我们探讨一下这个内核在实践中是如何工作的。与卷积一样,我们可以使用 SSM 内核来检查每组标记并计算输出:
这种操作也展示了填充如何影响输出结果。我调整了填充的顺序以优化可视化效果,尽管在实际应用中,我们通常会在句子的末尾添加填充。
在接下来的步骤中,内核将移动一个位置,以进行下一步的计算:
最后一步,我们可以看到内核的完整效果:
将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络 (CNN) 一样进行并行训练。然而,由于内核大小固定,它们的推理不如 RNN 那样快速和无限制。
三种表征
这三种表示法,连续的,循环的,和卷积的都有不同的优点和缺点:
有趣的是,我们现在可以在推理时利用循环SSM的高效性,并在训练时利用卷积SSM的并行处理能力。
借助这些不同的表示形式,我们可以采用一种巧妙的方法,即根据不同的任务需求选择合适的表示。在训练阶段,我们采用可以并行计算的卷积表示,以便加快训练速度;而在推理阶段,我们则切换到高效的循环表示,以优化推理性能:
该模型称为线性状态空间层(LSSL)
这些表示形式都共有一个关键特性,即线性时不变性(LTI)。LTI属性指出,在状态空间模型(SSM)中,参数A、B和C对于所有时间步来说是恒定的。这意味着无论SSM生成哪个令牌,矩阵A、B和C都是一模一样的。
换句话说,无论你向SSM提供何种顺序的输入,A、B和C的值都不会改变。我们拥有的是一个不区分内容的静态表示。
在我们深入探讨Mamba如何应对这一挑战之前,让我们先来分析这个难题的最后一部分:矩阵A。
矩阵A的重要性
矩阵A可以说是状态空间模型(SSM)公式中最为关键的组成部分之一。正如我们之前在循环表示中所讨论的,矩阵A负责捕捉先前状态的信息,并利用这些信息来构建新的状态。
本质上,矩阵A产生隐藏状态:
因此,构建矩阵A的关键可能在于仅保留之前几个标记的记忆,并捕捉我们所见每个标记之间的差异。特别是在循环表示的背景下,由于它仅考虑前一个状态,这一点尤为重要。
那么,我们如何创建一个能够保持大容量记忆(即上下文大小)的矩阵A呢?
这时,我们就用到了HiPPO(Hungering Hungry Hippo)或者说河马,这是一个高阶多项式投影运算器。HiPPO的目标是将迄今为止观察到的所有输入信号压缩成一个系数向量。
HiPPO利用矩阵A构建一个状态表示,这个表示能够有效地捕捉最近令牌的信息,并同时让旧令牌的影响逐渐减弱。其公式可以表示为:
假设我们有一个方阵A ,这给我们:
实践证明,使用HiPPO构建矩阵A的方法明显优于随机初始化。因此,它能够更精确地重建最新的信号(即最近的令牌),而不仅仅是初始状态。
HiPPO矩阵的核心在于其能够生成一个隐藏状态,用以存储历史信息。
在数学上,这是通过追踪勒让德多项式的系数来实现的,这使得它能够近似所有历史数据。
HiPPO随后被应用到循环和卷积表示中,以处理远程依赖关系。这导致了序列的结构化状态空间(S4)的产生,这是一种能够有效处理长序列的SSM。
S4由三部分组成:
- 状态空间模型
- HiPPO用于处理远程依赖关系
- 用于创建循环和卷积表示的离散化处理
勒让德多项式
序列的结构化状态空间 (S4)
这种类型的SSM具有多个优点,具体取决于您选择的表示形式(无论是循环还是卷积)。它还能够通过构建HiPPO矩阵来有效地处理长文本序列,并高效地存储记忆。
注意:如果您想深入了解有关如何计算 HiPPO 矩阵并自行构建 S4 模型的更多技术细节,我强烈建议您阅读Annotated S4 。
第 3 部分:Mamba - 选择性 SSM
我们已经掌握了理解Mamba模型独特之处所需的基础知识。状态空间模型可以用于建模文本序列,但仍存在一系列我们希望通过Mamba来避免的缺点。
在本节中,我们将介绍Mamba的两个主要创新:
- 选择性扫描算法:这个算法允许模型筛选(ignore/reject)与任务不相关的信息。
- 一种硬件感知算法:这种算法通过并行扫描、内核融合和有效存储(middle)结果的重新计算,提高了计算效率。
这两项创新共同构成了选择性状态空间模型(S6),这些模型可以用于构建Mamba块,例如自注意力机制。
在深入探讨这两个创新之前,我们先来探讨为什么它们是必要的。
它试图解决什么问题?
状态空间模型,包括S4(结构化状态空间模型),在某些语言建模和生成任务上表现不佳,尤其是在需要关注或忽略特定输入的能力方面。
我们可以通过两个综合任务来说明这一点,即选择性复制和感应头。
在选择性复制任务中,SSM的目标是复制输入的一部分并按顺序输出:
然而,循环/卷积SSM在执行此任务时表现不佳,因为它是线性时不变的。正如我们之前所观察到的,对于SSM生成的每个令牌,矩阵A、B和C都是固定不变的。
因此,SSM无法执行内容感知推理,因为它将每个标记视为固定A、B和C矩阵的结果。这是一个问题,因为我们希望SSM能够对输入(提示)进行推理。
SSM在另一个任务上表现不佳,那就是感应头,其目标是重现输入中发现的模式:
在上面的示例中,我们实际上是在执行一个一次性提示,试图“教会”模型在“Q:”之后提供“A:”响应。然而,由于SSM是时不变的,它无法从其历史记录中选择要调用的先前令牌。
让我们通过关注矩阵B来进一步说明这一点。无论输入x是什么,矩阵B都保持完全不变,因此与x无关:
同样,无论输入如何,A和C也保持固定。这证明了我们迄今为止所看到的 SSM 的静态性质。
Transformers在这些任务上表现较为出色,因为它们能够根据输入序列动态调整注意力,选择性地“查看”或“参与”序列的不同部分。
相比之下,SSM在这些任务上的表现不尽人意,这揭示了其根本问题:时不变的SSM(如矩阵A、B和C的静态性质)导致了内容感知上的困难。
有选择地保留信息
SSM的循环表示创建了一个非常高效的小状态,因为它压缩了整个历史记录。然而,与不压缩历史记录(通过注意力矩阵)的Transformer模型相比,它的功能要弱得多。
Mamba的目标是实现两全其美:创建一个像Transformer一样强大的小状态。
正如上文所述,Mamba通过有选择地将数据压缩到状态中来实现这一目标。当输入一个句子时,通常会包含一些没有多大意义的信息,例如停用词。
为了有选择地压缩信息,我们需要参数依赖于输入。为此,我们首先探讨训练期间SSM中输入和输出的维度。
在结构化状态空间模型 (S4) 中,矩阵A 、B和C独立于输入,因为它们的维度N和D是静态的并且不会改变。
相反,Mamba通过合并输入的序列长度和批量大小,使得矩阵B和C以及步长Δ都取决于输入。
这意味着对于每个输入标记,我们现在有不同的B和C矩阵,可以解决内容感知问题!
注意:矩阵A保持不变,因为我们希望状态本身保持静态,但它受到影响的方式(通过B和C )是动态的。
他们一起有选择地决定哪些内容应该保留在隐藏状态,以及哪些内容应该被忽略,因为它们现在依赖于输入。
较小的步长Δ会导致忽略特定单词,而更多地利用先前的上下文;而较大的步长Δ则会更多地关注输入单词而不是上下文。
扫描操作
由于这些矩阵现在是动态的,无法使用卷积表示来计算它们,因为卷积假设固定内核。我们只能使用循环表示,这导致失去了卷积提供的并行性。
为了实现并行化,让我们探讨如何使用循环计算输出。
每个状态都是前一个状态(乘以A )加上当前输入(乘以B )的总和。这称为扫描操作,可以使用 for 循环轻松计算。
相反,并行化似乎是不可能的,因为只有在我们拥有前一个状态的情况下才能计算每个状态。然而,Mamba 通过并行扫描算法使这成为可能。
它假设我们执行操作的顺序与关联属性无关。因此,我们可以分段计算序列并迭代地组合它们:
动态矩阵B和C以及并行扫描算法一起创建选择性扫描算法来表示使用循环表示的动态和快速本质。
硬件感知算法
最新 GPU 的一个缺点是其小型但高效的 SRAM 与大型但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAM 和 DRAM 之间频繁复制信息成为瓶颈。
Mamba 与 Flash Attention 一样,试图限制我们需要从 DRAM 到 SRAM 的次数,反之亦然。它通过内核融合来实现,这允许模型防止写入中间结果并连续执行计算直到完成。
我们可以通过可视化 Mamba 的基础架构来查看 DRAM 和 SRAM 分配的具体实例:
在这里,以下内容被融合到一个内核中:
-
步长为Δ的离散化步长
-
选择性扫描算法
-
与C相乘
硬件感知算法的最后一部分是重新计算。
中间状态不会被保存,但对于向后传递计算梯度是必需的。相反,作者在向后传递期间重新计算这些中间状态。
尽管这看起来效率低下,但它比从相对较慢的 DRAM 读取所有这些中间状态的成本要低得多。
我们现在已经涵盖了其架构的所有组件,使用其文章中的以下图像进行了描述:
选择性 SSM
这种架构通常被称为选择性 SSM或S6模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。
The Mamba Block
到目前为止,我们探索的选择性 SSM可以作为一个块来实现,就像我们在解码器块中表示自注意力一样。
与解码器一样,我们可以堆叠多个 Mamba 块并将它们的输出用作下一个 Mamba 块的输入:
它从线性投影开始,以扩展输入嵌入。然后,在选择性 SSM之前应用卷积以防止独立的令牌计算。
选择性SSM具有以下属性:
-
通过离散化创建循环 SSM
-
HiPPO对矩阵A进行初始化以捕获长程依赖性
-
选择性扫描算法选择性压缩信息
-
加速计算的硬件感知算法
在查看代码实现时,我们可以进一步扩展此架构,并探索端到端示例的外观:
请注意一些变化,例如包含归一化层和用于选择输出标记的 softmax。
当我们将所有内容放在一起时,我们可以获得快速推理和训练,甚至无限的上下文!
使用这种架构,作者发现它可以匹配甚至有时甚至超过相同大小的 Transformer 模型的性能!