Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
0. 摘要
1. 引言
2. 背景与概述
2.1 结构化状态空间模型
2.2 Attention
2.3 结构化矩阵
2.4 概述:结构化状态空间的对偶性
2.5 符号
3. 状态空间模型是结构化矩阵
3.1 状态空间模型的矩阵变换形式
3.2 半可分离矩阵
3.2.1 顺序半可分离(SSS)表示
3.2.2 1-半可分离矩阵:标量 SSM 递推
3.3 状态空间模型是半可分矩阵
3.4 通过结构化矩阵算法计算状态空间模型
3.4.1 线性(递归)模式
3.4.2 二次(朴素, Naive)模式
3.4.3 总结
4. 结构化掩码注意力:使用结构化矩阵泛化线性注意力
4.1 注意力框架
4.1.1 注意力
4.1.2 自注意力
4.1.3 核注意力
4.1.4 掩码(核)注意力
4.2 线性注意力
4.2.1 线性注意力的张量收缩证明
4.3 结构化掩码注意力
4.3.1 总结:掩码注意力的双重形式
5. 状态空间对偶性
5.1 标量-恒等结构状态空间模型
5.2 1-半分离结构化掩码注意力
5.3 结构化状态空间对偶性(SSD)
6. 硬件高效的 SSD 模型算法
6.1 对角块
6.2 低秩块
6.3 计算成本
7. Mamba-2 的架构
7.1 块设计
7.2 序列变换的多头模式
7.3 来自线性注意力机制的其他 SSD 扩展
8. SSM 的系统优化
8.1 张量并行
8.2 序列并行
8.3 可变长度
9. 实验
9.1 合成任务:关联召回
9.2 语言建模
9.2.1 扩展定律
9.2.2 下游评估
9.2.3 混合模型:将 SSD 层与 MLP 和注意力结合
9.3 速度基准测试
9.4 架构消融
9.4.1 块设计
9.4.2 头结构
9.4.3 注意力核近似
10. 相关工作和讨论
10.1 状态空间模型
10.2 结构化矩阵
10.3 (线性)注意力
10.4 相关模型
11. 结论
0. 摘要
虽然 Transformers 一直是深度学习在语言建模方面取得成功的主要架构,但最近表明,诸如 Mamba 之类的状态空间模型(SSM)在小到中等规模上可以匹敌甚至超越 Transformer。我们展示了这些模型家族实际上是相当紧密相关的,并且开发了一个丰富的理论框架,将 SSM 和各种注意力机制的变体通过对一类研究良好的结构化半可分离矩阵的各种分解联系起来。我们的状态空间对偶(state space duality,SSD)框架使我们能够设计一个新架构(Mamba-2),其核心层是 Mamba 选择性 SSM 的精炼版本,速度提高了 2-8 倍,同时在语言建模方面继续与 Transformers 竞争。
模型代码和预训练检查点:https://github.com/state-spaces/mamba
1. 引言
Transformers,特别是仅解码模型(例如 GPT (Brown et al., 2020),Llama (Touvron, Lavril, 等, 2023)),以因果方式处理输入序列,是现代深度学习成功的主要驱动因素之一。许多方法试图近似核心注意力层,以解决其效率问题(Tay 等,2022),例如在训练期间序列长度成二次方增长,并且在自回归生成期间需要线性序列长度大小的缓存。同时,另一类替代序列模型,结构化状态空间模型(SSM),在训练期间序列长度线性增长,生成期间状态大小恒定。它们在长距离任务中表现强劲(例如 S4 (Gu, Goel, 和 Ré,2022)),最近在小到中等规模的语言建模中匹敌甚至超越了 Transformers(例如 Mamba (Gu 和 Dao,2023))。然而,SSM 的发展似乎与社区改善 Transformers 的集体努力相脱节,例如理论理解和在现代硬件上优化。因此,与 Transformer 相比,理解和实验 SSM 更为困难,从算法和系统的角度来看,训练 SSM 也更具挑战性。
我们的主要目标是建立一个丰富的理论连接体系,将结构化 SSM 与各种注意力机制联系起来。这将使我们能够将最初为 Transformers 开发的算法和系统优化转移到 SSMs 上,以构建比 Transformers 性能更好且在序列长度上更高效的基础模型。一个重要的贡献是线性注意力(Linear Attention,LA)框架(Katharopoulos 等,2020),通过展示二次核注意力(quadratic kernelized attention)的 “对偶形式” 与特定线性递归的等价性,导出了自回归注意力与线性 RNN 之间的联系。这种对偶性使得能够进行高效的可并行训练和高效的自回归推理。同样,本论文提供了多种视角,将线性复杂度的 SSM 与二次复杂度的形式联系起来,结合 SSM 和注意力的优势。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
状态空间对偶。我们称之为结构化状态空间对偶(structured state space duality,SSD)的框架,通过对结构化矩阵(具有亚二次(subquadratic)参数和乘法复杂度的矩阵)的抽象(abstraction)来连接结构化 SSM 和各种注意力机制。我们开发了两种广泛的表示序列模型的框架,一种是矩阵变换,另一种是张量收缩(tensor contraction),各自揭示了对偶性的不同视角。我们的技术贡献包括:
- 我们展示了状态空间模型和一类研究良好的结构化矩阵(称为半可分离矩阵,semiseparable matrices)之间的等价性(第 3 节)。这种联系是我们框架的核心,揭示了 SSM 的新性质和算法。本文的一个中心信息是,计算状态空间模型的不同方法可以重新框定为在结构化矩阵上进行各种矩阵乘法算法。
- 我们显著改进了线性注意力(Katharopoulos 等,2020)的理论。我们首先通过张量收缩的语言提供了其递归形式的精确证明,然后将其推广到一类新的结构化掩码注意力(structured masked attention,SMA)(第 4 节)。
- 我们连接了 SSM 和 SMA,展示了它们有很大的交集,是相互的对偶,具有像 SSM 的线性形式和像注意力的二次形式(第 5 节)。我们还证明,任何具有快速递归形式的核注意力方法必须是一个 SSM。
除了其内在的理论价值,我们的框架还开辟了广泛的方向来理解和改进序列模型。
高效算法。首先也是最重要的,我们的框架揭示了计算 SSM 的新高效且易于实现的算法(第 6 节)。我们引入了一种基于半可分离矩阵块分解的新 SSD 算法,利用了线性 SSM 递归和二次对偶形式,在所有主要效率轴(如训练和推理计算、内存使用和利用现代硬件上的矩阵乘法单元的能力)上获得了最佳权衡。专门的 SSD 实现比优化的 Mamba 选择性扫描实现快 2-8 倍,同时允许更大的递归状态尺寸(是 Mamba 的 8 倍或更高,几乎没有减速)。SSD 与 softmax 注意力(FlashAttention-2(Dao 2024))的优化实现竞争,在序列长度 2K 处交叉,在序列长度 16K 时快 6 倍。
架构设计。采用新架构如 SSM 的一个主要障碍是针对 Transformers 量身定制的生态系统,例如用于大规模训练的硬件高效优化和并行技术。我们的框架允许使用已建立的注意力惯例和技术来构建 SSM 的架构设计选择词汇,并进一步改进它们(第 7 节)。例如,我们引入了多头注意力(multi-head attention,MHA)的头部在 SSM 中的类似物。我们展示了 Mamba 架构是一个多输入 SSM(multi-input SSM,MIS),类似于多值注意力(multi-value attention,MVA),并比较了具有不同头部结构的其他 Mamba 变体。
我们还使用这些想法对 Mamba 模块进行轻微修改,以便实现张量并行(例如,Megatron(Shoeybi 等,2019)的风格)。主要思想包括引入分组值注意力(grouped-value attention,GVA)头部结构,并将所有依赖于数据的投影移至模块开始时并行发生。
修改后的并行 Mamba 模块与作为内部 SSM 层的 SSD 相结合,形成了 Mamba-2 架构。我们在与 Mamba 相同的设置中研究了 Mamba-2 的 Chinchilla 扩展定律(scaling laws),发现它在困惑度和实际运行时间上优于 Mamba 和 Transformer++。我们还在 Pile 上训练了不同大小的 Mamba-2 模型家族,显示它在标准下游评估中匹敌或超越了 Mamba 和开源 Transformers。例如,具有 2.7B 参数的 Mamba-2 在 Pile 上训练了 300B token,在相同数据集上超过了 Mamba-2.8B、Pythia-2.8B 甚至 Pythia-6.9B。
系统优化。SSD 框架连接了 SSM 和 Transformers,使我们能够利用大量针对 Transformers 开发的系统优化工作(第 8 节)。
- 例如,张量并行(Tensor Parallelism,TP)是一种重要的模型并行技术,通过在同一节点上的 GPU 之间拆分每一层来训练大型 Transformer 模型。我们设计的 Mamba-2 友好于 TP,将每个模块的同步点减少了一半。
- 对于非常长的序列,其激活不能放在一个设备上,已为注意力模块开发了序列并行。我们描述了如何训练一般的 SSMs,特别是 Mamba-2,通过在设备之间传递递归状态实现序列并行。
- 对于具有不同长度示例的微调,为了获得最佳效率,Transformer 需要复杂的技术来移除填充标记(padding tokens)并对可变长度序列执行注意力。我们展示了如何高效地训练 Mamba-2 具有可变序列长度,不需要填充 token。
第 9 节在语言建模、训练效率和一个难度较大的多查询关联回忆任务(Arora, Eyuboglu, Zhang, 等,2024)上实证验证了 Mamba-2。最后,在第 10 节中,我们提供了扩展的相关工作并讨论了我们框架带来的潜在研究方向。
2. 背景与概述
2.1 结构化状态空间模型
(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模结构化状态空间序列模型(Structured state space sequence models,S4)是一类最近用于深度学习的序列模型,广泛关联于 RNNs、CNNs 和经典状态空间模型。它们受到特定连续系统(1)的启发,该系统通过隐式潜在状态 h ∈ R^(T,N) 将一维序列 x ∈ R^T 映射到 y ∈ R^T。
结构化 SSMs 的一般离散形式如下:
结构化 SSMs 之所以命名为 “结构化”,是因为控制时间动态的矩阵 A 必须是结构化的,以便在深度神经网络中高效地计算这种序列到序列的转换。最初引入的结构是对角加低秩(diagonal plus low-rank,DPLR)(Gu, Goel 和 Ré 2022)和对角(Gu, Gupta 等 2022;Gupta, Gu 和 Berant 2022;J. T. Smith, Warrington 和 Linderman 2023),这些结构仍然是最受欢迎的。
在本研究中,我们使用术语 “状态空间模型”(SSM)来指代结构化 SSMs。这类 SSMs 有许多变体,与神经序列模型的几大主要范式如连续时间、递归和卷积模型有深厚的联系(Gu, Johnson, Goel 等 2021)。我们在下文中提供了简要概述,并参考了先前的工作以获得更多背景和细节(Gu 2023;Gu 和 Dao 2023)。
连续时间模型。原始的结构化 SSMs 起源于函数 x(t)∈R ↦ y(t)∈R 上的连续时间映射,而不是直接作用于序列。在连续时间视角中,在方程(1a)中,矩阵 (A,B) 不是直接学习的,而是从底层(underlying)参数 (∘A,∘B) 生成的,并且带有一个参数化的步长 Δ。通过固定公式 A=fA(Δ,∘A) 和 B=fB(Δ,∘B) 将“连续参数” (Δ,∘A,∘B) 转换为“离散参数” (A,B),其中 (f_A,f_B) 被称为离散化规则。
备注 1: 尽管我们的主要模型采用了与先前工作相同的参数化和离散化步骤(见 Gu 和 Dao(2023)了解详细信息),为了简化阐述和符号表示,我们在本文其余部分中省略了它。我们注意到,先前关于结构化 SSMs 的工作将连续参数 (∘A,∘B) 和离散参数 (A,B) 分别称为 (A,B) 和 (ˉA,ˉB);我们更改了符号表示以简化展示并直接关注控制主要 SSM 递归的离散参数。
递归模型。方程(1)和(2)采用了输入 x 线性递归的形式。因此,结构化 SSMs 可以视为递归神经网络(RNNs)的一种类型,其中线性赋予了它们额外的性质,并使其避免了传统 RNNs 的顺序计算。相反,尽管有这种简化,SSMs 仍然是全表达的序列转换(在通用近似的意义上)(Kaul 2020;Orvieto 等 2023;Shida Wang 和 Xue 2023)。
卷积模型。当 SSM 的动态在时间上是恒定的,如方程(1)所示,模型被称为线性时不变(linear time-invariant,LTI)。在这种情况下,它们等同于卷积。因此,SSMs 也可以视为 CNNs 的一种类型,但(i)卷积核通过 SSM 参数 (A,B,C) 隐式参数化,(ii)卷积核通常是全局而非局部的。相反,通过经典信号处理理论,所有足够良好的卷积都可以表示为 SSMs。
通常,以前的 LTI SSMs 会使用卷积模式进行高效的可并行训练(提前看到整个输入序列),并切换到递归模式(1)以进行高效的自回归推理(逐步看到输入)。
选择性状态空间模型。Mamba 中引入了参数 (A,B,C) 也可以随时间变化的形式(2),称为选择性 SSM。与标准 LTI 公式(1)相比,该模型可以在每个时间步选择性地关注或忽略输入。事实证明,与信息密集型数据(如语言)上的 LTI SSMs 相比,尤其是当其状态大小 N 增加以允许更多信息容量时,表现更好。然而,它只能在递归模式下计算,而不能在卷积模式下计算,并且需要仔细的硬件感知实现才能高效。即便如此,它仍然不如硬件友好的模型(如 CNNs 和 Transformers)高效,因为它没有利用现代加速器(如 GPUs 和 TPUs)专门针对的矩阵乘法单元。
虽然时不变 SSMs 与连续、递归和卷积序列模型密切相关,但它们与注意力没有直接关系。在本文中,我们展示了选择性 SSMs 与注意力之间更深层次的关系,并利用它显著提高了 SSMs 的训练速度,同时允许更大的状态尺寸 NNN。
结构化 SSMs 作为序列转换(Sequence Transformations)。
定义 2.1。我们使用术语序列转换来指代对序列的参数化映射 Y=fθ(X),其中 X,Y ∈ R^(T,P),且 θ 是参数的任意集合。T 代表序列或时间轴;下标索引到第一维,例如 Xt,Yt ∈ R^P。
序列转换(例如 SSMs 或自注意力)是深度序列模型的基石,它们被纳入神经网络架构中(例如 Transformers)。方程(1)或(2)中的 SSM 是一个序列转换,其中 P=1;可以通过简单地在此维度上进行广播将其推广到 P>1(换句话说,将输入视为 P 个独立序列,并对每个序列应用 SSM)。我们可以将 P 视为一个头维度,我们将在第 7 节详细说明。
定义 2.2。我们定义 SSM 运算符
为序列转换 X ∈ R^(T,P) ↦ Y ∈ R^(T,P),由方程(2)定义。
在 SSMs 中,N 维度是一个自由参数,称为状态大小或状态维度。我们也称之为状态扩展因子,因为它将输入/输出的大小扩展了一个因子 N,这对这些模型的计算效率有影响。
最后,我们指出,许多类型的序列转换,例如注意力,可以表示为跨序列维度的单个矩阵乘法。
定义 2.3。如果序列转换 Y=fθ(X) 可以写成 Y=M_θ·X 的形式,其中 M 是依赖于参数 θ 的矩阵,我们称其为矩阵转换。我们用矩阵 M 识别序列转换,并在上下文清晰时通常省略对 θ 的依赖。
2.2 Attention
注意力机制广泛指一种计算方法,它为序列中的每对位置分配分数,使得每个元素能够 “关注” 其他元素。迄今为止,最常见和重要的注意力变体是 softmax 自注意力,其定义为:
其中 Q,K,V ∈ R^(T,P)。成对比较机制(通过 QK^T 产生)导致了注意力的特征性二次训练成本。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
虽然提出了许多注意力的变体,但它们都共享这些注意力分数的核心,并采用各种近似(Tay et al. 2022)。对本文工作最重要的变体是线性注意力(Katharopoulos et al. 2020)。大致来说,这类方法通过将 softmax 折叠到核特征映射中来省略 softmax,并利用矩阵乘法的结合性将 (QK^T)⋅V 重写为 Q⋅(K^T·V)
此外,在因果(自回归)注意力的重要情况下,当因果掩码(causal mask)被合并到左侧
中时,其中 L 是下三角全 1 矩阵,那么右侧可以展开(unrolling)为递归。最近的多项工作如 RetNet(Y. Sun et al. 2023)和 GateLoop(Katsch 2023)将这一点加强为更一般形式的 L(见第 10 节)。在本文中,我们的结构化掩码注意力的公式将强烈地概括这些思想。
2.3 结构化矩阵
一般矩阵 M∈R^(T,T) 需要 T^2 个参数来表示,并且需要 O(T^2) 时间来执行矩阵-向量乘法等基本操作。结构化矩阵是指:
- 通过压缩表示可以以次二次(理想情况下线性)参数表示的矩阵,且
- 可以通过直接操作这种压缩表示来实现快速算法(最重要的是矩阵乘法)。
也许最典型的结构化矩阵家族是稀疏和低秩矩阵。然而,还有许多其他家族,如 Toeplitz、Cauchy、Vandermonde 和 butterfly 矩阵,它们都已被用于机器学习中以实现高效模型(Dao, Gu, et al. 2019;D. Fu et al. 2024;Gu, Gupta, et al. 2022;Thomas et al. 2018)。结构化矩阵是高效表示和算法的强大抽象。在本文中,我们将展示 SSMs 等价于另一类以前未用于深度学习的结构化矩阵,并利用这一联系推导出高效的方法和算法。
2.4 概述:结构化状态空间的对偶性
虽然本文发展了 SSMs、注意力和结构化矩阵之间更丰富的联系框架,但我们提供了主要方法的简要总结,该方法在算法上实际上是非常自包含(self-contained)和简单的。
递归(线性)形式。状态空间对偶(SSD)层可以定义为选择性 SSM(2)的特例。SSM 的标准计算可以作为递归(或并行扫描)应用,其在序列长度上具有线性复杂度。与 Mamba 中使用的版本相比,SSD 有两个细微的差异:
- A 的结构从对角化简化为标量乘以单位结构。在这种情况下,每个 A_t 也可以识别为仅一个标量。
- 我们使用较大的头维度 P,相比于 Mamba 中使用的 P=1。通常选择 P={64,128},这类似于现代 Transformers 的惯例。
与原始选择性 SSM 相比,这些变化可以看作是略微减少表达能力以换取显著的训练效率改进。特别是,我们的新算法将允许在现代加速器上使用矩阵乘法单元。
对偶(二次)形式。SSD 的对偶形式是与注意力紧密相关的二次计算,定义为:
其中,𝑎𝑖 是在 [0, 1] 范围内的依赖于输入的标量。 与标准的 softmax 注意力相比,有两个主要区别:
- 丢弃了 softmax。
- 注意力矩阵与一个额外的掩码矩阵 𝐿 进行逐元素乘法。
这两个变化可以看作是解决 vanilla 注意力中的问题。例如,最近观察到 softmax 会在注意力分数中引发问题,如 “注意力陷阱” (attention sink)现象(Darcet 等,2024;Xiao 等,2024)。更重要的是,掩码矩阵 𝐿 可以视为用不同的数据依赖的位置掩码替代了 Transformer 的启发式位置嵌入,控制信息在时间上的传递量。
更广泛地说,这种形式是我们在第 4 节中定义的线性注意力结构化掩码注意力的一种实例。
矩阵形式和 SSD 算法。通过展示 SSM 具有矩阵变换形式 𝑌 = 𝑀𝑋,对于依赖于 𝜃 = (𝐴, 𝐵, 𝐶) 的矩阵 𝑀_𝜃 ∈ ℝ^(T,T),可以将各种形式的 SSD 连接起来。特别地,SSD 的对偶形式等价于通过矩阵 𝑀 的朴素(二次时间)乘法,而递归形式则是一种利用 𝑀 结构的特定高效(线性时间)算法。
除此之外,任何用于矩阵 𝑀 的乘法算法都可以应用。我们提出的硬件高效 SSD 算法(第 6 节)是一种新结构矩阵乘法方法,涉及 𝑀 的块分解,获得比纯线性或二次形式更好的效率权衡。与通用选择性 SSM(Gu 和 Dao 2023)相比,它相对简单且易于实现;List 1 提供了几行代码的完整实现。 图 1 提供了本文概念之间关系的简单路线图。
2.5 符号
在本文中,我们倾向于使用可以映射到代码的精确符号。
矩阵和向量。我们通常使用小写字母表示向量(即单轴张量),使用大写字母表示矩阵(即多轴张量)。在本文中,我们不使用粗体表示矩阵。有时,如果一个矩阵在一轴上绑定或重复(因此也可以视为向量),我们可能会用大写或小写字母表示它。 · 表示标量或矩阵乘法,◦ 表示 Hadamard(逐元素)乘法。
索引。我们使用 Python 风格的索引,例如 𝑖:𝑗 表示当 𝑖 < 𝑗 时的范围 (𝑖, 𝑖 +1, . . . , 𝑗 −1),以及当 𝑖 > 𝑗 时 (𝑖, 𝑖 −1, . . . , 𝑗 +1) 。例如,对于任何符号 𝑣,我们令 𝑣_(𝑗:𝑖) 对于 𝑗 ≥ 𝑖 表示序列 (𝑣_𝑗, . . . , 𝑣_(𝑖+1))。 [𝑖] 等价于 0:𝑖 = (0, . . . , 𝑖 − 1)。简写时,我们令 𝑣×_𝑗:𝑖 表示乘积 𝑣_𝑗 × · · · × 𝑣_(𝑖+1)。
维度。为了区别于矩阵和张量,我们经常使用打字机字体的大写字母(例如 D、N、T)来表示维度和张量形状。与传统符号 𝑀 ∈ R^(𝑇 ×𝑇) 不同,我们经常使用 𝑀 ∈ R^(T,T) 来反映代码中的张量维度。
张量收缩(Tensor Contractions)。我们将大量依赖张量收缩或 einsum 符号来阐明和证明我们的结果。我们假设读者熟悉这种符号,这在现代张量库(如 numpy)中经常使用。例如,我们可以使用 contract(MN, NK → MK) 来表示矩阵-矩阵乘法运算符,在我们的符号中,contract(MN, NK → MK) (𝑋,𝑌)(等价于 𝑋 · 𝑌)可以翻译成代码为 numpy.einsum('mn, nk → mk', X, Y)。
附录 A 中包含了大量的符号术语。
3. 状态空间模型是结构化矩阵
本节探讨了状态空间模型作为序列变换的不同视角,并概述了这些映射的属性和算法。本节的主要结果是关于状态空间模型与一类称为半可分离矩阵(semiseparable matrices)的结构化矩阵之间的等价性,这意味着新的效率结果(定理 3.5 和 3.7)。
3.1 状态空间模型的矩阵变换形式
回顾我们对 SSM 的定义是通过 (2) 定义的参数化映射。我们的理论框架从简单地将这个变换写成一个矩阵乘法开始,映射向量 𝑥 ∈ R^T ↦ 𝑦 ∈ R^T 。
按定义,ℎ0 = 𝐵0𝑥0。归纳地,
乘以 𝐶𝑡 以产生 𝑦𝑡,并在 𝑡 ∈ [T] 上对方程进行向量化,我们得到了 SSM 的矩阵变换形式。
3.2 半可分离矩阵
方程(3)中的 𝑀 是一类称为半可分离矩阵的矩阵的特定表示。半可分离矩阵是一种基本的矩阵结构。我们首先定义这些矩阵及其属性。
定义 3.1。一个(下三角)矩阵 𝑀 是 N-半可分离的,如果下三角部分(即对角线上或以下的部分)中的每个子矩阵的秩最多为 N。我们称 N 为半可分离矩阵的阶或秩。
定义 3.1,以及其他相关 “可分离” 结构形式(例如准可分离矩阵和其他半可分离矩阵的定义)有时被称为结构化秩矩阵(structured rank matrices)(或秩结构化矩阵,rank-structured matrices),因为它们以子矩阵的秩条件为特征。半可分离矩阵有许多结构化表示,包括分层半可分离(hierarchical semiseparable,HSS)、顺序半可分离(sequential semiseparable,SSS)和 Bruhat 形式(Pernet 和 Storjohann 2018)。我们主要使用 SSS 形式。
3.2.1 顺序半可分离(SSS)表示
定义 3.2。一个下三角矩阵 𝑀 ∈ R^(T,T) 具有 N-顺序半可分离(SSS)表示,如果它可以写成以下形式:
对于向量 𝐵_0, . . . , 𝐵_(T−1),𝐶_0, . . . ,𝐶_(T−1) ∈ R^N 和矩阵𝐴_0, . . . ,𝐴_(T−1) ∈ R^(N,N)。
我们定义操作符 SSS,使得
半可分离矩阵的一个基本结果是它们与具有 SSS 表示的矩阵完全等价。可以用一个简单的构造性证明来推导这个方向。
引理 3.3。具有表示(4)的 N-SSS 矩阵 𝑀 是 N-半可分离的。
证明。考虑任意的非对角线块 𝑀_(𝑗 :𝑗 ′,𝑖′ :𝑖),其中 𝑗 ′ > 𝑗 ≥ 𝑖 > 𝑖′。这有一个显式的秩-N 分解,如下所示:
方程(5)将在推导我们的序列模型的快速算法中被广泛使用。另一个方向在半可分离矩阵的文献中已经被很好地建立。
命题 3.4。每个 N-半可分离矩阵都有一个 N-SSS 表示。
此外,请注意,尽管定义 3.2 涉及表示的 𝑂(N^2·T) 参数(特别是为了存储 𝐴 矩阵),但实际上可以将其压缩到 𝑂(NT) 参数,这在渐近意义下是紧密的(tight)。因此,在本文的其余部分中,我们将混淆结构化矩阵类(定义 3.1)和它的特定表示(定义 3.2);我们将始终使用这种表示,而不是其他候选。反过来,我们将使用 N-SS 来指代 SSS 形式中的 N-半可分离矩阵。
半可分离矩阵是一种基本的矩阵结构,具有许多重要的性质。它们与递推关系密切相关,并可以通过多种表征(例如定义 3.1 和 3.2)来定义,这些表征揭示了它们的不同联系和有效算法。我们在附录 C.1 中提及了它们的一些其他性质。
备注 2。半可分离性的概念非常广泛,文献中出现了许多类似但细微不同的定义;我们的定义可能与其他约定略有不同。首先,因为本文主要关注因果或自回归设置,我们将半可分离性的定义限制为三角形的情况;一些作者更正式地可能称之为 (N, 0)-半可分离性。一些作者也可能将其称为一种准可分离性形式(quasiseparability)(Eidelman and Gohberg 1999; Pernet 2016)。有关简短调查,请参阅Vandebril等人(2005)。
3.2.2 1-半可分离矩阵:标量 SSM 递推
我们将特别关注 1-SS 矩阵的特殊情况。请注意,在这种情况下,𝐶𝑗 和 𝐵𝑖 是标量,并且可以从SSS 表示(4)中因子分解出来(在这种情况下,我们还使用小写字母来强调这些参数是标量)
由于对角矩阵很容易处理(例如,对角矩阵的乘法等同于逐元素标量乘法),我们可以忽略这些术语。因此,我们对于 1-SS 矩阵的基本表示是 𝑀_𝑗𝑖 = 𝑎_(𝑗 :𝑖) 或
1-SS 矩阵的重要性在于它们与标量递归的最简形式的等价性——即状态维度为 N = 1 且没有 (𝐵,𝐶) 投影的退化 SSM 情况。请注意,乘法 𝑦 = 𝑀𝑥 可以通过如下递归计算:
我们因此也将 1-SS 矩阵的矩阵乘法称为标量 SSM 递归(scalar SSM recurrence)或累积乘积和(cumprodsum,累积乘积和的缩写;是累积乘积以及累积和的推广)运算符。作为递归的基本形式,1-SS 矩阵的乘法是我们主要算法的重要构建模块之一。我们强调本文的一个中心主题是,许多序列模型的算法可以归结为结构化矩阵乘法算法。1-SS 矩阵体现了这一联系:存在许多用于计算原始标量递归或累积乘积和运算符的快速算法,而这些算法实际上都是不同的 1-SS 矩阵结构化分解。我们在附录 B 中专门讨论了这些用于 1-SS 矩阵乘法的算法。
3.3 状态空间模型是半可分矩阵
回顾我们对 SSM 的定义,这是通过定义 2.1 确定的参数化映射。SSM 与半可分矩阵之间的联系仅仅是通过将这种变换表示为将向量 𝑥 映射到 𝑦 ∈ R^T 的矩阵乘法。
方程(3)直接建立了状态空间模型与顺序半可分表示之间的联系,这在一般情况下等同于半可分矩阵(引理 3.3 和命题 3.4)。
定理 3.5。状态空间模型变换 𝑦 = SSM(𝐴, 𝐵, 𝐶) (𝑥) 的状态大小为 N,与顺序半可分表示中的 N-SS 矩阵乘法 𝑦 = SSS(𝐴, 𝐵, 𝐶) · 𝑥 相同。
换句话说,序列变换运算符 SSM(定义 2.2)与矩阵构造运算符 SSS(定义 3.2)一致,我们可以互换使用它们(有时简写为 SS)。此外,巧合的是,结构化状态空间模型和顺序半可分矩阵具有相同的缩写,强调了它们的等效性!我们可以方便地使用这些缩写 SSM(状态空间模型或半可分矩阵)、SSS(结构化状态空间或顺序半可分)或 SS(状态空间或半可分)来明确地指代任一概念。然而,我们通常遵循的约定是 SSM 指状态空间模型,SS 指半可分,SSS 指顺序半可分。
图 2 说明了将状态空间模型视为半可分矩阵的序列变换视角。
3.4 通过结构化矩阵算法计算状态空间模型
定理 3.5 之所以重要,是因为它将允许我们将高效计算 SSM(以及其他序列模型)的问题归结为高效的结构化矩阵乘法算法。我们简要概述并将我们主要的新算法推迟到第 6 节,在第 4 和第 5 节中展示 SSM 与其他序列模型的等效性之后。
如前所述,半可分矩阵(即秩结构矩阵)是一类经典的结构化矩阵:
- 它们具有压缩表示,例如 SSS 形式,其参数数量为 𝑂(T) 而不是 𝑂(T^2)。
- 它们具有直接在压缩表示上操作的快速算法。 此外,参数化和矩阵乘法成本在半可分顺序中可以是紧密的。
命题 3.6 (Pernet、Signargout 和 Villard (2023))。大小为 T 的 N-SS 矩阵可以用 𝑂(NT) 参数表示,并且其矩阵-向量乘法在时间和空间上均为 𝑂(NT)。
例如,1-SS 矩阵展示了这种联系的本质。矩阵 𝑀 = 1SS(𝑎) 由正好 T − 1 个参数
定义,并且可以通过遵循标量递归(7)在 𝑂(T) 时间内计算。
3.4.1 线性(递归)模式
在对角结构 SSM(S4D (Gu、Gupta 等 2022))的情况下,命题 3.6 可以很容易地通过利用状态空间模型公式(2)和展开递归来看出。我们在(8)中提供了正式的张量收缩算法,其中维度 S 等于 T。
其中,𝐿 ∈ R^(T,T) 定义为 1SS(𝐴),换句话说,
该算法包括对应(2)的三个步骤:
- 通过输入矩阵 𝐵 扩展输入 𝑋 (8a)
- 展开(unrolling)独立的标量 SSM 递归 (8b)
- 通过输出矩阵 𝐶 收缩隐藏状态 𝐻 (8c)。 请注意,我们在步骤 (8b) 中使用了标量 SSM 和 1-SS 矩阵之间的等效性。
备注 3。我们注意到 (8) 是 Mamba (S6) 模型的特例。然而,由于扩展张量 𝑍 和 𝐻 的大小为 (T, P, N),因此朴素实现是缓慢的;Gu 和 Dao (2023) 引入了硬件感知的实现,以避免实例化这些张量。
令人惊讶的是,定理 3.5 和命题 3.6 立即暗示所有 SSM 的渐近效率与算法 (8) 相同。
定理 3.7。任何状态大小为 N 的状态空间模型(定义 2.2)在序列长度 T 上的计算时间为 𝑂(TN)(不包括潜在的预处理)。
我们注意到,这一结果对结构化 SSM 文献来说是新的。特别是,给定稠密非结构化 𝐴𝑡 矩阵,总表示似乎为 𝑂(TN^2) 的大小。因此,定理 3.7 表明,即使是非结构化 SSM,通过预处理步骤也可以以最优效率计算,其上界与由 𝐵 和 𝐶 的大小给出的下界 𝑂(TN) 相匹配。
备注 4。考虑到几乎所有 R^(N,N) 上的稠密矩阵在 C 上都是可对角化的这一事实,定理 3.7 或许并不令人意外,这导致几乎所有稠密实数 SSM 等效于对角复数 SSM。这一事实是对角 SSM 是最受欢迎的结构化 SSM 形式的原因 (Gu、Gupta 等 2022;Gupta、Gu 和 Berant 2022;J.T. Smith、Warrington 和 Linderman 2023)。然而,定理 3.7 对所有实数 SSM(不仅仅是可对角化的那些)以及其他域上的稠密 SSM(包括 C 本身)意味着更强的结果。
实际上,高效可计算的 SSM 仍然需要对 𝐴 施加额外的结构,特别是为了避免昂贵的预处理步骤(这既有顺序 N 的额外 FLOP,又涉及硬件效率低下的操作,如奇异值分解)。这些结构是过去关于结构化 SSM 的工作(例如 S4(D) 和 Mamba)以及我们新算法的重点。特别是,当对 𝐴 施加稍强的结构时,我们将在第 6 节中通过 SSM 矩阵 𝑀 = SSS(𝐴, 𝐵, 𝐶) 的块分解设计非常高效的硬件算法。
3.4.2 二次(朴素, Naive)模式
我们注意到,通过我们新的矩阵视角,还有另一种计算 SSM 的方法。朴素计算矩阵 SSM 表示(3)涉及简单地实现序列变换矩阵 𝑀 = SSS(𝐴, 𝐵, 𝐶)。这是一个 (T, T) 矩阵,因此这种朴素算法将在序列长度上呈二次增长。然而,当序列长度 T 较短时,由于常数因素和计算模式的硬件友好性(例如,利用矩阵-矩阵乘法),这实际上可能比线性算法更高效。事实上,对于某种特定的结构化 SSM,这看起来非常类似于二次注意力计算(第 5 节)。
3.4.3 总结
许多序列模型显式地由矩阵序列变换驱动或定义——最著名的是 Transformers,其中矩阵混合器是注意力矩阵。另一方面,RNN 和 SSM 以前并未以这种方式描述。通过提供状态空间模型的显式矩阵变换形式,我们揭示了理解和使用它们的新方法。从计算的角度来看,计算状态空间模型前向传递的任何方法都可以视为半可分矩阵上的矩阵乘法算法。半可分矩阵视角提供了状态空间二重性 (SSD) 的一个视角,其中二重模式分别指线性时间半可分矩阵乘法算法和二次时间朴素矩阵乘法。
此外,利用半可分矩阵的丰富结构可以带来更好的算法和更多的见解(例如,第 6 节和附录 B)。在附录 C.1 中,我们描述了半可分矩阵的一些附加属性。
4. 结构化掩码注意力:使用结构化矩阵泛化线性注意力
本节我们将从基础重新审视线性注意力框架。该部分的主要结果包括一个基于张量收缩的简单线性注意力证明(命题 4.1),以及我们对结构化掩码注意力的泛化抽象(定义 4.2)。需要注意的是,本节从不同于状态空间模型的方向推导了主要的对偶性结果,可以完全独立于第 3 节阅读。
- 第 4.1 节建立了我们对注意力变体的框架,特别关注核注意力和掩码核注意力。
- 第 4.2 节提供了我们的第一个主要注意力结果,即通过张量收缩视角的线性注意力的简单证明。
- 第 4.3 节定义了结构化掩码注意力,这是我们通过结构化矩阵对以前注意力变体的泛化。
4.1 注意力框架
4.1.1 注意力
注意力的基本形式是一个将三个向量序列(𝑄, 𝐾, 𝑉)映射到 𝑌 的操作。
我们使用 “形状注释” 来指示张量的维度,例如 𝑄 ∈ R^(T, N)。在这种一般形式中,S 和 T 分别表示源序列和目标序列的长度,N 表示特征维度,P 表示头部维度。最常见的 softmax 注意力变体使用一个 softmax 激活 𝑓 = softmax 来规范化 𝐺 矩阵的行。
4.1.2 自注意力
我们的处理由最重要的自注意力情况驱动,其中:
- 源序列和目标序列相同(即 S = T),
- 通常特征和头部维度相同(即 N = P),
- 并且 𝑄, 𝐾, 𝑉 是通过对同一输入向量的线性投影生成的(𝑄 = 𝑊𝑄 · 𝑋, 𝐾 = 𝑊𝐾 · 𝑋, 𝑉 = 𝑊𝑉 · 𝑋)。
然而,我们的介绍抽象了这些选择,从 𝑄, 𝐾, 𝑉 矩阵开始。
备注 5。我们关注的是头部和特征维度相等的自注意力情况(即 S = T 和 N = P),应该作为运行示例。我们定义注意力的一般公式,不仅是为了使我们的框架捕获跨注意力等变体,还因为分离维度的符号(例如 S 和 T)使得本节主要结果的收缩符号证明更加清晰。
备注 6。虽然注意力通常被框定为对这三个对称输入 𝑄, 𝐾, 𝑉 的操作,但输入和输出维度在(9)中表明情况并非如此。特别是,特征维度 N 不存在于输出中;因此,当 S = T 时(例如自注意力),我们将 𝑉 视为主要输入,这样(9)定义了一个适当的序列变换 𝑉 ↦ 𝑌(定义 2.1)。
4.1.3 核注意力
对 Gram 矩阵 𝐺 应用 softmax 函数的步骤可以分解为两部分:
- 对 𝐺 矩阵进行指数化。
- 在 S 轴上归一化 𝐺 矩阵。
我们现在可以忽略规范化项,因为它相当于简单地传递 𝑉 = 1 并进行除法(我们将在第 7.3 节中重新讨论这一点)。指数化项可以视为一个核变换:存在一个(无限维的)特征映射𝜑,使得 exp(𝑄𝐾⊤) = 𝜑(𝑄)𝜑(𝐾)⊤。通过将特征映射抽象到𝑄和𝐾的定义中(即定义𝑄, 𝐾为变换后的版本),我们可以忽略 softmax 变换,并假设𝑄, 𝐾由核特征映射任意生成,并且可能N ≠ P。
许多核注意力(kernel attention)的实例已经被提出,包括:
- 原始的线性注意力(Katharopoulos等,2020)将核特征映射定义为任意逐点激活函数,例如 𝑥 ↦ 1 + elu(𝑥)。
- 随机特征注意力(Random Feature Attention,RFA)(H. Peng等,2021)选择核特征映射来近似 softmax 注意力(即 exp 特征映射),使用高斯核的随机傅里叶特征近似(Rahimi和Recht 2007)。这涉及随机投影(即将 𝑄 和 𝐾 乘以一个随机投影 𝑊 并应用激活𝑥 ↦ (cos(𝑥), sin(𝑥)))。
- Performer(Choromanski等,2021)提出通过正交随机特征(FAVOR+)实现快速注意力。正随机特征(PRF)部分选择核特征映射为随机投影,然后是特征映射 𝑥 ↦ 2^(-1/2)·(exp(𝑥), exp(−𝑥))。这种选择的动机是使核元素为正值,并且可以证明近似 softmax 注意力。[它还建议选择正交方向的随机投影,我们不考虑这一点。]
- cosFormer(Qin, Weixuan Sun等,2022)使用余弦重加权机制来增强 RFA,结合位置信息以强调局部性。这有效地将 𝑄𝑡, 𝐾𝑡 通过特征映射 𝑥 ↦ (𝑥 cos(𝜋𝑡/2𝑇), sin(𝜋𝑡/2𝑇))。
- 线性随机化注意力(Zheng, C. Wang, 和 Kong 2022)从重要性抽样的角度泛化 RFA,并将其推广为提供更好的全 softmax 核估计(而不仅仅是 exp 变换的分子)。
其他相关注意力变体包括 Linformer(Sinong Wang等,2020)和 Nyströformer(Xiong等,2021),它们都使用注意力矩阵 𝐺 的低秩近似(因此与方程(9)兼容),分别通过随机投影(Johnson-Lindenstrauss)和核近似(Nyström 方法)。
4.1.4 掩码(核)注意力
令 𝐿 为形状为 (T, S) 的掩码。最常见的是,在 S = T 的自回归自注意力情况下,𝐿 可能是一个下三角为 1 的矩阵,表示因果掩码。除了强制因果性,还可以应用许多其他类型的掩码——特别是各种稀疏模式,如带状、扩展或块对角——其动机是减少稠密注意力的复杂性。
掩码注意力通常写成矩阵形式:
更精确地说,通过形状注解(shape annotations)和将其分解为精确的计算序列:
我们在本节中改进的注意力变体推导从注意到该公式可以写成一个单一收缩(single contraction)开始:
公式 (11) 中的算法可以通过一对一的收缩特定顺序重新构建为 (12) 的计算:
4.2 线性注意力
线性注意力和许多其他高效注意力的变体通常以改变核心注意力计算中的矩阵结合顺序为动机 (即 (𝑄·𝐾^⊤)𝑉=𝑄(𝐾^⊤·𝑉)。然而,当添加掩码时,推导变得不那么直接(例如,原始论文 (Katharopoulos et al. 2020) 和变体 (Y. Sun et al. 2023) 都没有给出证明)。
大致来说,线性注意力方法声称以下公式等价于 (10),需要通过展开和仔细跟踪索引来验证。
命题 4.1 ((Katharopoulos et al. 2020))。自回归核注意力,即具有因果掩码的掩码核注意力,可以每步以在 𝑂(𝑇) 内的固定时间完成递归计算。
4.2.1 线性注意力的张量收缩证明
我们提出了一个简单而严格的线性注意力推导,这也将立即揭示如何对其进行推广。主要思想是以不同的顺序执行收缩 (12)。我们避免了模糊的矩阵表示,直接使用收缩表示:
直观上,我们将此收缩顺序解释如下。
第一步 (15a) 执行特征的 “扩展”,扩展的特征维度为 N。第三步 (15c) 将扩展的特征维度收缩。 如果将 𝐾 视为输入(备注 6),则 𝑉 和 𝑄 分别执行扩展和收缩。
第二步是最关键的,解释了线性注意力的线性部分。首先注意到 (15b) 只是通过 𝐿 的直接矩阵乘法(因为 (P, N) 轴可以被展平)。还要注意到这是唯一一个涉及 T 和 S 轴的项,因此应该具有 𝛺(TS) 复杂度(即序列长度的二次复杂度)。然而,当掩码 𝐿 是标准的因果注意力掩码(下三角的 1)时,矩阵-向量乘法通过 𝐿 与特征累积和相同。
4.3 结构化掩码注意力
通过张量收缩视角(公式15)看掩码注意力,我们可以立即看到,原始线性注意力的关键在于因果掩码的矩阵-向量乘法等同于累积和操作。然而,我们观察到注意力掩码并不必全是1。为了让线性注意力速度更快,所有必要的是 𝐿 是一个结构化矩阵,这些矩阵通过定义具有快速矩阵乘法(参见 2.3 节)。
特别地,我们可以使用任何具有次二次(理想情况下为线性)矩阵-向量乘法的掩码矩阵 𝐿,通过加速瓶颈方程(15b)来实现与标准线性注意力相同的复杂度。
定义 4.2。结构化掩码注意力(SMA,或简称结构化注意力)被定义为一个关于查询/键/值 𝑄,𝐾,𝑉 以及任何结构化矩阵 𝐿(即具有次二次矩阵乘法)的函数,通过四向(4-way)张量收缩
SMA 二次模式算法是由 (13) 定义的一对一收缩序列,对应于标准(掩码)注意力计算。
SMA 线性模式算法是由 (15) 定义的一对一收缩序列,其中步骤 (15b) 通过次二次结构化矩阵乘法进行优化。
我们可以将结构化掩码注意力实例化为任何给定类的矩阵结构。一些示例包括(图3):
- 线性注意力使用因果掩码。
- RetNet (Y. Sun et al. 2023) 使用衰减掩码,衰减因子 γ∈[0,1]
- 衰减掩码可以推广为托普利兹(Toeplitz)矩阵 𝐿_ij = α_(i−j),对于一些可学习的(或输入依赖的)参数集 α∈R。这可以解释为一种相对位置编码,类似于其他方法如 AliBi (Press, N. Smith, and Lewis 2022) 但使用乘法而不是加法。
- 另一个变体可以使用傅里叶矩阵 𝐿_ij=ω^(ij/T) 以不同方式编码位置结构。
在第 5 节中,我们考虑半分离 SMA,这定义了我们的主要 SSD 模型。
4.3.1 总结:掩码注意力的双重形式
标准的(掩码内核)注意力经常混淆在函数和算法之间。分离这个差异提供了一种清晰的方式来理解不同的注意力变体。
- 我们将掩码注意力视为一个特定的函数(公式 12)。
- 标准的二次注意力计算(公式 13)可以被视为计算该函数的算法。
- 线性注意力(公式 15)是计算相同函数的另一种算法。
此外,在这种情况下,
- 掩码注意力函数仅仅是四项(four terms)的特定收缩。
- 二次和线性注意力算法只是执行收缩的两种不同顺序。
已知收缩顺序可以在计算复杂度上产生很大的差异,导致二次与线性的分裂。就像状态空间模型是一种可以通过多种方式计算的变换,具有二次与线性形式的双重性(第 3.4 节),线性注意力也有类似的双重性,这源于两种收缩顺序。事实上,这些结果是对相同底层双重性的不同观点,我们在第 5 节中明确说明。
5. 状态空间对偶性
在第 3 和第 4 节中,我们定义了结构化状态空间模型和结构化注意力,讨论了它们的性质,并展示了它们都有一个二次算法和一个线性算法。本节将它们连接起来。我们的主要结果是表明特定情况下的结构化状态空间模型与特定情况下的结构化注意力一致,并且线性时间 SSM 算法和二次时间内核注意力算法是彼此的对偶形式。
- 第 5.1 节将状态空间模型专门化为标量结构,其中简单的二次计算可以看作是内核注意力的一个实例。
- 第 5.2 节将结构化掩码注意力专门化为半分离 SMA,它表征了具有高效自回归的掩码注意力。
- 第 5.3 节总结了结构化掩码注意力和结构化状态空间模型之间的连接,称为结构化状态空间对偶性。
5.1 标量-恒等结构状态空间模型
在第 3 节中,我们展示了状态空间模型等价于半分离矩阵变换,结果是既有线性递归形式,也有二次简单形式。
回忆一下,SSM 由 y=SSM(A,B,C)(x) 定义,SSM 的矩阵形式使用 SSS(顺序半分离)表示 M=SSS(A,B,C),其中(公式(3))
现在让我们考虑 Aj 只是一个标量的情况;换句话说,是一个结构化 SSM 的实例,其中 A 矩阵是极度结构化的: A=aI,其中 a 是标量, I 是单位阵。然后我们可以重新排列
这可以向量化为
使用这种形式,完整输出 Y=MX 精确地计算为
其中 S=T。但这正是掩码内核注意力定义(13)的原始定义!
因此,如第 3.4 节所述,简单地计算标量结构化 SSM——通过实现半分离矩阵 M 并执行二次矩阵-向量乘法——与二次掩码核注意力完全相同。
5.2 1-半分离结构化掩码注意力
结构化掩码注意力允许使用任何结构化掩码 L。当 L 是因果掩码时,就是标准的线性注意力。请注意,因果掩码是 L=SS(1_T),即 1-SS 掩码由定义 (6) 中的 a_t = 1 生成。这激发了将 L 推广到 1-半分离掩码类,即 1-半分离结构化掩码注意力(1-SS SMA),其中线性注意力的递归中的累积和被更一般的递归——标量 SSM 扫描,即 1-半分离矩阵乘法(第 3.2.2 节)——所取代。
最后,考虑 1-半分离 SMA 最重要的原因是其计算的线性形式是对角状态空间模型的特例。SMA 的线性形式是算法 (15),其中瓶颈步骤 (15b) 可以看作是通过 1-SS 掩码的矩阵乘法。在第 3 节中,我们还写出了对角 SSM 的计算(8),其中瓶颈步骤(8b)是标量 SSM 递归,相当于 1-SS 乘法。唯一的区别是 (8b) 在 L 中有一个额外的 N 维度,因为矩阵 A 是大小为 N 的对角矩阵。如果 A 的所有对角条目都是相同的,这个 N 维度将消失,这导致推论 5.1。
推论 5.1。1-SS SMA(带有 1-半分离结构矩阵 L 的掩码注意力)(15)是对角 SSM(8)的特例,其中对角矩阵是标量的恒等矩阵的倍数。
虽然推论 5.1 表明 1-SS SMA 有一个高效的递归形式,但我们也可以展示一个相反的结果,表征了哪些 SMA 实例具有高效自回归。
定理 5.2。对于任何实例化的结构化掩码注意力(定义 4.2),如果它是有界阶数的自回归过程,则结构化掩码 L 必须是一个半分离矩阵。
换句话说,高效自回归注意力是广义的半分离 SMA。定理 5.2 在附录 C.2 中证明。
备注 7。虽然 1-半分离 SMA 是状态空间模型的特例,但广义的半分离 SMA 比 1-SS SMA 更具表达能力,且不能用标准 SSM 来描述。然而,矩阵 L 的半分离乘法和 SMA 的线性形式(公式15a)都涉及扩展和收缩步骤,可以通过单个(更大的)扩展,吸收到 1-SS SMA 的类似实例中。
总而言之,1-半分离结构化注意力是 SMA 最重要的情况,因为它是:
- 具有输入依赖递归的线性注意力的自然推广。
- 等同于高效自回归注意力的最简单广义半分离注意力情况。
- 对角状态空间模型的特例。
5.3 结构化状态空间对偶性(SSD)
总结我们的结果:
- 结构化状态空间模型(第 3 节)通常通过线性时间递归定义。然而,通过扩展矩阵形式表征其线性序列到序列变换,可以导出二次形式。
- 注意力变体(第 4 节)通过二次时间成对交互定义。然而,通过将其视为四向张量收缩(4-way tensor contraction)并以不同顺序简化,可以导出线性形式。
- 每一个的自然特例——更准确地说,具有标量-恒等结构的状态空间模型的 A 矩阵,以及具有1-半分离结构的 L 掩码的结构化掩码注意力——是彼此的对偶形式,具有完全相同的线性和二次形式。
图 4 总结了这两种表示之间的对偶性。
扩展的相关工作和讨论(第 10 节)更详细地描述了 SSD 与一般 SSM / 注意力的关系。
6. 硬件高效的 SSD 模型算法
构建理论上的 SSD 框架的好处在于利用这些连接来改进模型和算法。在本节中,我们展示了如何从各种计算结构化矩阵乘法的算法中导出高效计算 SSD 模型的算法。我们的主要计算结果是一个结合了线性(递归)模式和二次(注意力)模式的 SSD 模型计算算法。这个算法的计算效率与 SSMs 一样(在序列长度上线性扩展),并且与注意力一样对硬件友好(主要使用矩阵乘法)。
定理 6.1。考虑一个具有状态扩展因子 N 和头维度 P = N 的 SSD 模型。存在一种算法可以在任何输入 X∈R^(T,P) 上计算模型,该算法只需要 O(TN2) 的训练 FLOPs、 O(TN) 的推理 FLOPs、 O(N^2) 的推理内存,并且其工作主要由矩阵乘法主导。需要注意的是,这些界限都是紧的,因为状态扩展为 N 的状态空间模型在头大小为 N 的情况下总状态大小为 N^2(分别产生训练和推理 FLOPs 的下界 O(TN^2)。此外,输入 X 本身有 TN 个元素,从而产生内存下界。
定理 6.1 背后的主要思想是再次将计算状态空间模型的问题视为半分离矩阵乘法,但以一种新的方式利用其结构。我们不在递归或注意力模式下计算整个矩阵,而是对矩阵进行块分解。对角块可以使用双重注意力模式计算,这可以通过矩阵乘法高效完成,而非对角块可以通过半分离矩阵的秩结构分解并减少到一个较小的递归。我们强调,List 1 提供了一个自包含的(self-contained) SSD 算法实现。与 Gu 和 Dao(2023)提出的一般选择性 SSM 相比,这个实现更简单,即使在原生PyTorch中也相对高效,不需要特殊的低级内核。
首先,对于某个块大小 Q,我们将矩阵 M 划分为一个 T/Q × T/Q 的子矩阵网格,每个子矩阵大小为 Q×Q。注意,由半分离矩阵的定义性质(定义3.1),非对角块是低秩的。(注意:即使分块大小不同(例如,Q ∤ T),块分解也是有效的,但为了简单起见,我们假设了均匀的可分割性。)
这最容易通过一个示例来说明,例如对于 T=9,将其分解成长度为 Q=3 的块。阴影单元格是半分离矩阵的非对角块的低秩因式分解。
6.1 对角块
对角块很容易处理,因为它们只是一个较小大小的自相似问题。第 j 个块表示计算范围为
的答案
关键是,这个块可以使用任何所需的方法计算。特别地,对于小的块长度 Q,这个问题使用对偶二次 SMA 形式来更有效地计算。此外,这些块可以并行计算。
这些子问题可以解释为:假设初始状态(对于块)为 0,每块的输出是什么。换句话说,对于块 j,这计算了考虑了仅块输入 x_(jQ:(j+1)Q) 的正确输出。
6.2 低秩块
低秩因式分解由 3 个项组成,因此有相应的三个部分的计算。在这个分解中,我们将术语定义如下:
类似于如上所示的的项分别称为右因子(或 B-块因子)、中心因子(或 A-块因子)、左因子(或 C-块因子)。
右因子。这一步计算了低秩因式分解的右 B-块因子的乘积。请注意,对于每个块,这是一个 (N,Q)×(Q,P) 的矩阵乘法,其中 N 是状态维度,而 P 是头维度。结果是每个块的 (N,P) 张量,其与扩展的隐藏状态 h 具有相同的维度。
这可以解释为:假设初始状态(对于块)为 0,每块的最终状态是什么。换句话说,假设 x_(0:jQ) = 0,这计算了 h_(jQ+Q−1)。
中心因子。这一步计算了低秩因式分解中的中心 A-块因子的效果。在前一步中,每个块的最终状态总共有形状 (T/Q,N,P)。现在这将乘以由
生成的 1-SS 矩阵。
这一步可以通过任何用于计算 1-SS 乘法的算法(也称为标量 SSM 扫描或 cumprodsum 运算符)来计算。
这可以解释为:考虑所有先前的输入,每个块的实际最终状态是什么;换句话说,这计算了 h_(jQ),考虑到所有的 x_(0:(j+1)Q)。
左因子。这一步计算了低秩因式分解的左 C-块因子的乘积。对于每个块,这可以用一个矩阵乘法 contract(QN, NP → QP) 表示。
这可以解释为:考虑正确的初始状态 h_(jQ−1),并假设输入 x_(jQ:(j+1)Q) 为 0,每个块的输出是什么。换句话说,对于块 j,这计算了仅考虑先前的输入 x_(0:jQ) 的正确输出。
6.3 计算成本
我们将符号 BMM(B, M, N, K) 定义为一个批量矩阵乘法 contract(MK, KN → MN),其中批次(batch)维度为 B。从这个符号我们可以推断出计算的三个方面的效率:
- 计算成本:总共 O(BMNK) FLOPs。
- 内存成本:总共 O(B(MK+KN+MN)) 空间。
- 并行化:较大的 M、N、K 项可以利用现代加速器上的专门矩阵乘法单元。
中心块。二次 SMA 计算的成本包括三个步骤(参见方程(16)):
- 计算核矩阵 C^T·B,其成本为 BMM(T/Q, Q, Q, N)。
- 乘以掩码矩阵,这是一个形状为 (T/Q, Q, Q) 的张量的逐元素操作。
- 乘以 XXX 值,其成本为 BMM(T/Q, Q, P, N)。
低秩块:右因子。这一步是一次矩阵乘法,成本为 BMM(T/Q, N, P, Q)。
低秩块:中心因子。这一步是一个标量 SSM 扫描(或 1-SS 乘法),长度为 T/Q,独立通道为 (N, P)。这个扫描的工作量为 TNP/Q,与其他因素相比可以忽略不计。
需要注意的是,由于分块将序列长度从 T 减小到 T/Q,这个扫描的成本比纯 SSM 扫描(例如 Mamba 的选择扫描)小了 Q 倍。因此我们可以观察到,在大多数问题长度上,其他算法(附录B)可能更有效,或者更容易实现而不会显著减慢速度。例如,通过 1-SS 矩阵乘法的朴素实现成本为 BMM(1, T/Q, NP, T/Q),这比朴素的递归/扫描实现更容易,并且可能比其更有效。
低秩块:左因子。这一步是一次矩阵乘法,成本为 BMM(T/Q, Q, P, N)。
总成本。如果我们设置 N=P=Q(换句话说,状态维度、头维度和块长度相等),那么上述所有 BMM 项都变成了 BMM(T/N, N, N, N)。其计算特性是:
- 总 FLOP 数量为 O(TN^2)。
- 总内存为 O(TN)。
- 工作主要由形状为 (N,N) 的矩阵乘法组成。
需要注意的是,内存消耗是紧凑的;输入和输出 x,y 的形状为 (T,P)=(T,N)。同时,FLOP 计数反映了额外的因素 N,这是由于自回归状态大小所产生的成本,对所有模型都是通用的。
除了矩阵乘法之外,还有一个对 NP = N^2 个特征和长为 T/Q 序列进行标量 SSM 扫描的操作。这个扫描的成本为 O(T/QN^2) FLOPs 和 O(log(T/Q)) 深度。虽然它不使用矩阵乘法,但仍然可以并行化,并且总工作量与其他步骤相比可以忽略不计;在我们的 GPU 实现中,这的成本可以忽略不计。
与纯 SSM 和注意力模型的比较。二次注意力通过只利用矩阵乘法也非常高效,但总 FLOP 数量为 T^2·N。它在训练和推理的较慢计算速度可以直接看作是具有更大状态大小的后果——标准注意力的状态大小随着序列长度 T 扩展,因为它缓存其历史记录并且不压缩其状态。
线性 SSM 具有 TNP=TN^2 的总 FLOP 数量,与 SSD 相同。然而,一个朴素的实现需要状态扩展(15a)导致额外的内存,并且标量操作(15b)不利用矩阵乘法。
我们注意到许多其他矩阵分解方法都是可能的(例如,请参阅附录 B,其中列举了通过不同结构的矩阵分解进行 1-SS 乘法的算法),这可能会导致更多用于 SSD 的算法,这些算法可能对其他专业设置更有效。更广泛地说,我们注意到半可分矩阵具有丰富的文献和除我们使用的 SSS 形式(定义 3.2)之外的许多其他表示形式,可能还可以使用更高效的算法。
7. Mamba-2 的架构
通过将 SSMs 和注意力联系起来,SSD 框架使我们能够开发共享的词汇和技术库,用于两者。在本节中,我们讨论了一些例子,通过最初为 Transformers 开发的思想来理解和修改 SSD 层。我们讨论了几种设计选择,导致了 Mamba-2 架构。这些变化轴在第 9.4 节中被消融。
7.1 块设计
我们首先讨论与神经网络块的修改,这些修改与内部序列混合层独立(即在核心 SSD 层之外)。
并行参数投影。Mamba-1 受 SSM 为中心的观点启发,其中选择性 SSM 层被视为从 𝑋 ↦ 𝑌 的映射。SSM 参数 𝐴, 𝐵,𝐶 被视为附属参数,并且是 SSM 输入 𝑋 的函数。因此,定义(𝐴, 𝐵,𝐶)的线性投影发生在创建 𝑋 之后。
在 Mamba-2 中,SSD 层被视为从 𝐴,𝑋, 𝐵,𝐶 ↦ 𝑌 的映射。因此,在块的开头进行单个投影并行产生 𝐴,𝑋, 𝐵,𝐶 是有意义的。注意这类似于标准注意力架构,其中 𝑋, 𝐵,𝐶 对应于并行创建的 𝑄,𝐾,𝑉 投影。
需要注意的是,采用并行投影以获得 SSM 的 𝐴, 𝐵,𝐶,𝑋 输入略微减少了参数,更重要的是,通过使用标准的 Megatron 分片图样(sharding patterns)(Shoeybi等,2019),更适合于较大模型的张量并行化。
额外的归一化。在初步实验中,我们发现较大模型容易出现不稳定。我们通过在块的最后输出投影之前添加一个额外的归一化层(例如 LayerNorm,GroupNorm 或 RMSNorm)来缓解这一问题。这种使用归一化的方式最直接与 NormFormer 架构(Shleifer,Weston和Ott,2021)相关,该架构还在 MLP 和 MHA 块的末尾添加了归一化层。
我们还注意到,这种改变类似于最近从线性注意力视角衍生出的与 Mamba-2 相关的其他模型。原始线性注意力公式通过模拟 softmax 函数的标准注意力中的归一化,通过分母项进行归一化。TransNormerLLM(Qin,Dong Li等,2023)和 RetNet(Y. Sun等,2023)发现,这种归一化是不稳定的,并在线性注意力层之后添加了额外的 LayerNorm 或 GroupNorm。我们的额外归一化层与这些略有不同,它出现在乘法门分支之后而不是之前。
7.2 序列变换的多头模式
回想一下,SSMs 被定义为一个序列变换(定义2.1),其中:
- 𝐴, 𝐵,𝐶 参数具有状态维度 N。
- 它们定义一个序列变换 R^T → R^T,例如可以表示为矩阵 𝑀 ∈ R^(T,T)。
- 此变换独立于 P 轴对输入序列 𝑋 ∈ R^(T,P) 进行操作。
我们可以将这视为定义序列变换的一个头。
定义 7.1(多头模式)。多头序列变换由 H 个独立头组成,总模型维度为 D = d_model。参数可能在头之间被绑定,导致头图样(head pattern)。
状态大小 N 和头维度 P 分别类似于注意力的 𝑄𝐾 头维度和 𝑉 头维度。就像在现代 Transformer 架构(Chowdhery等,2023;Touvron,Lavril等,2023)中一样,在 Mamba-2 中,我们通常选择这些常量为 64 或 128;当模型维度 D 增加时,我们增加头的数量,同时保持头维度 N 和 P 固定。为了描述如何做到这一点,我们可以迁移和泛化多头注意力的思想,为 SSMs 或任何一般序列变换定义类似的图样 。
多头选择性状态空间模型(MHS)/多头注意力(MHA)模式。经典的 MHA 模式假设头部维度 P 可以整除模型维度 D。头数定义为 H = D/P。然后,通过创建每个参数的 H 个独立副本,来创建 H 个核心序列转换的副本。请注意,虽然 MHA 模式最初是针对注意力序列转换描述的,但它可以应用于与定义 2.1 兼容的任何内容。例如,多头 SSD 层将接受与方程(17)中的形状相符的输入,其中 SSD 算法在 H = n_heads 维度上广播。
多收缩状态空间模型(MCS)/多查询注意(MQA)模式。多查询注意(Shazeer 2019)是一种对注意力的巧妙优化,可以显着提高自回归推断的速度,它依赖于缓存 𝐾 和 𝑉 张量。这种技术简单地避免了给出额外的头维度 𝐾 和 𝑉,换句话说,在所有 𝑄 头中广播(𝐾, 𝑉)的单个头。
利用状态空间的对偶性,我们可以将 MQA 定义为等效的 SSM 版本,如方程(18)所示。在这里,𝑋 和 𝐵(注意力的 𝑉 和 𝐾 的 SSM 类比)在 H 个头部之间共享。我们还将其称为多收缩 SSM(MCS)头模式,因为控制 SSM 状态收缩的 𝐶 参数每个头都有独立副本。
类似地,我们可以定义多键注意(MKA)或多扩展 SSM(MES)头模式,其中每个头的 𝐵(控制 SSM 扩展)独立,而 𝐶 和 𝑋 在头之间共享。
多输入 SSM(MIS)/ 多值注意(MVA)模式。虽然 MQA 对于注意力是有意义的,因为它的 KV 缓存,但它不是 SSM 的自然选择。在 Mamba 中,相反,𝑋 被视为 SSM 的主要输入,因此 𝐵 和 𝐶 被视为跨输入通道共享的参数。我们在方程(20)中定义了一种新的多值注意(MVA)的多输入SSM(MIS)模式,它可以再次应用于诸如 SSD 之类的任何序列转换。
有了这些术语,我们可以更准确地描述原始的 Mamba 架构。
命题 7.2。Mamba 架构(Gu和Dao 2023)的选择性 SSM(S6)层可以被视为具有
- 头维度 𝑃 = 1:每个通道都有独立的 SSM 动态 𝐴。
- 多输入 SSM(MIS)或多值注意(MVA)头结构:𝐵,𝐶 矩阵(对应于注意力对偶的 𝐾,𝑄)在输入 𝑋 的所有通道之间共享(对应于注意力中的 𝑉)。
当应用于 SSD 时,我们还可以消融这些头图样(pattern)变体(见第 9.4.3 节)。有趣的是,尽管在参数计数和总状态维度上受到控制,但在下游性能上有明显差异。我们凭经验发现,最初在 Mamba 中使用的 MVA 模式表现最佳。
分组头模式。多查询注意的思想可以扩展到分组查询注意(Ainslie等人,2023年):可以创建 G 个独立的 K 和 V 头,其中 1 < G 且 G 可以整除 H。这既是为了弥合多查询和多头注意力之间的性能差异,也为了通过将 G 设置为碎片数量的倍数来实现更高效的张量并行。
类似地,Mamba-2 中使用的多输入 SSM 头模式可以轻松扩展到分组输入 SSM(grouped-input SSM,GIS),或者可以同义地称为分组值注意(grouped-value attention,GVA)。这种泛化很简单,我们为简单起见省略了细节。
7.3 来自线性注意力机制的其他 SSD 扩展
我们在此描述了由线性注意力激发的 SSD 的架构修改示例。我们在第 9.4.3 节中对这些进行了消融,作为一种负面结果,发现它们并没有显着提高性能以至于将它们作为默认设置采用。尽管如此,这些例子说明了如何将大量关于注意力的文献纳入,以定义 SSD 的变体。在 Mamba-2 架构中,我们将核特征映射的选择视为超参数,并且期望还有其他受注意力启发的简单修改也是可能的。
核注意力对 Softmax 注意力的近似。许多线性注意力或核注意力的变体是基于将注意力分数 softmax(𝑄𝐾^⊤) 视为由以下组成的。
- 指数核 𝑍 = exp(𝑄𝐾^⊤),对于某些核特征映射,它可以被 𝑍 = 𝜓(𝑄)𝜓(𝐾)^⊤ 近似。
- 通过下式来对核进行标准化,其中除法是逐元素进行的,1 是全 1 向量,使行的总和归一化为1。
指数核特征映射。在 Mamba-2 中,我们引入了一个灵活的核特征映射,并将其应用于𝐵 和 𝐶 分支(对应于注意力中的 𝐾 和 𝑉 分支)。该特征映射也可以选择性地应用于 𝑋(𝑉)分支,以实现简单性和对称性。这在图 6 中用任意非线性表示。默认情况下,我们简单地选择 𝜓 为逐元素的 Swish / SiLU函数(Hendrycks和Gimpel 2016; Ramachandran,Zoph和Le 2017)。我们在第 9.4.3 节中探讨了其他选项,包括线性注意力、执行者、随机特征注意力和 cosFormer 使用的特征映射。
整合归一化(分母)项。要找到分母项,我们只需要计算 𝑀1。但请记住,模型的最终输出只是 𝑌 = 𝑀𝑋(方程(16))。因此,可以通过在 𝑋 中添加一个额外的列 1 来简单地找到归一化项,从而产生形状为(T,P + 1)的张量。
请注意,在这种情况下,核特征映射 𝜓 必须为正,以确保总和为正。
8. SSM 的系统优化
我们描述了几种用于 SSM 的系统优化方法,特别是 Mamba-2 架构,用于大规模高效的训练和推断。具体来说,我们关注用于大规模训练的张量并行和序列并行,以及用于高效微调和推断的可变长度序列。
8.1 张量并行
张量并行(Tensor parallelism,TP)(Shoeybi等人,2019年)是一种模型并行技术,它将每一层(例如,注意力,MLP)分割为在多个加速器(如 GPU)上运行。这种技术被广泛用于在 GPU 集群上训练大多数大型模型(Brown等人,2020年;Chowdhery等人,2023年;Touvron,Lavril等人,2023年;Touvron,L. Martin等人,2023年),每个节点通常具有 4-8 个 GPU,并配备快速网络(例如 NVLink)。TP 最初是为 Transformer 架构开发的,并且将其应用于其他架构并不直接。我们首先展示了在 Mamba 架构中使用 TP 的挑战,然后展示了 Mamba-2 架构如何设计以使 TP 高效。
回顾一下 Mamba 架构,其中单个输入 𝑢 ∈ R^(𝐿×𝑑)(为简单起见,没有批处理),输入投影矩阵𝑊^(𝑥),𝑊^(𝑧) ∈ R^(𝑑×𝑒𝑑),其中 𝑒 是扩展因子(通常为 2),输出投影矩阵 𝑊^(𝑜) ∈ R^(𝑒𝑑×𝑑):
使用 TP,假设我们想要沿着 2 个 GPU 分割计算。将输入投影矩阵 𝑊^(𝑥),𝑊^(𝑧) 分割成两个大小为 𝑑×𝑒𝑑 / 2 的分区是很容易的。然后每个 GPU 将持有 𝑥_𝑐 大小的一半: 𝐿×𝑒𝑑 / 2 。然而,我们注意到由于 Δ,𝐵,𝐶 是关于 𝑥_𝑐 的函数,所以在计算 Δ,𝐵,𝐶 之前,我们需要在 GPU 之间进行额外的 all-reduce,以获取完整的 𝑥_𝑐。之后,两个 GPU 可以并行计算 SSM,因为它们在 𝑑 上是独立的。最后,我们可以将输出投影矩阵 𝑊^(𝑜) 分割成两个大小为 𝑒𝑑 / 2 × 𝑑 的分区,并在最后进行 all-reduce。与 Transformer 相比,我们会产生两个 all-reduce,而不是一个,将通信时间加倍。对于大规模 Transformer 训练,通信可能已经占据了相当大的时间比例(例如 10-20%),而加倍通信会使 Mamba 在大规模训练中效率不高。
在 Mamba-2 中,我们的目标是每个块只进行一次 all-reduce,类似于 Transformer 中的注意力或 MLP 块。因此,我们通过投影直接从 𝑢 获取 Δ,𝐵,𝐶,而不是从 𝑥_𝑐 获取,从而允许我们分割这些投影矩阵。这意味着在不同的 GPU 上有不同的 Δ,𝐵,𝐶 集合,这相当于在更大的 “逻辑 GPU” 上有几个 “组” 的 Δ,𝐵,𝐶。此外,我们在每个块内使用 GroupNorm,组数可被 TP 度(degree)整除,以便 TP 组内的 GPU 在块内不进行通信:
我们看到我们只需要分割输入投影矩阵和输出投影矩阵,并且只需要在块结束时进行 all-reduce。这类似于对注意力和 MLP 层进行 TP 设计。特别地,如果我们有 TP 度为 2,我们将分割
对于 𝑖 = 1, 2,TP Mamba-2 层可以写为:
我们在图 7(左侧)中用 Mamba-2 来说明张量并行。
8.2 序列并行
对于非常长的序列,我们可能需要沿着序列长度维度将输入和激活分配到不同的 GPU 上。有两种主要技术:
- 序列并行(Sequence parallelism,SP)用于残差和归一化操作:由 Korthikanti等人(2023年)首次提出,这种技术将 TP 中的 all-reduce 分解为 reduce-scatter 和 all-gather。注意到残差和归一化操作在同一输入上对同一 TP 组中的所有 GPU 重复执行,SP 沿着序列长度维度分割激活,执行:reduce-scatter,残差和归一化,然后 all-gather。由于 Mamba-2 架构使用相同的残差和归一化结构,因此 SP 无需修改。
- 用于标记混合操作(注意力或 SSM)的序列并行,也称为 “上下文并行”(context parallelism,CP)。已经为注意力层开发了几种技术(例如,环形注意力(Liu,Yan等人,2024年;Liu,Zaharia和Abbeel 2023年),具有复杂的负载平衡技术(Brandon等人,2023年)。序列并行在注意力中的困难在于,我们可以将查询和键分割成块,但每个查询块都需要与键块交互,导致通信带宽与工作人员(worker)数量的平方成正比。 对于 SSM,我们可以简单地以一种简单的方式分割序列:每个工作人员获取一个初始状态,根据其输入计算 SSM,返回最终状态,并将该最终状态传递给下一个工作人员。通信带宽与工作人员数量成线性关系。这种分解与 SSD 算法(图 5)中的块分解完全相同,将其分割成块/块(blocks / chunks)。我们在图 7(右侧)中说明了这种上下文并行。
8.3 可变长度
尽管预训练通常对批处理使用相同的序列长度,但在微调或推断过程中,模型可能需要处理不同长度的不同输入序列。处理这种情况的一种简单方式是将批次中的所有序列右填充到最大长度,但如果序列长度差异很大,则这可能效率低下。对于 Transformer,已经开发了复杂的技术来避免填充并在 GPU 之间进行负载平衡(Zeng等人,2022年;Y. Zhai等人,2023年),或者将多个序列打包到同一批次中并调整注意力掩码(Ding等人,2024年;Pouransari等人,2024年)。对于 SSMs,特别是 Mamba,我们可以通过将整个批次视为一个长序列来处理可变序列长度,并避免在各个序列之间传递状态。这相当于把一个序列终止处的标记 𝑡 简单地设置 𝐴_𝑡 = 0,从而不将信息传递给属于另一个序列的标记 𝑡 + 1。
9. 实验
我们通过对合成召回(recall)任务(第 9.1 节)和标准语言建模预训练以及下游评估(第 9.2 节)进行实证评估来评估 Mamba-2,在这些任务中,循环模型一直存在挑战。我们验证了我们的 SSD 算法比 Mamba-1 更高效(第 9.3 节),并且在中等序列长度下与优化的注意力相当。最后,我们在 Mamba-2 架构中消融了各种设计选择(第 9.4 节)。
9.1 合成任务:关联召回
合成的关联召回(associative recall)任务一直是测试语言模型在上下文中查找信息能力的流行选择。广义上来说,这些任务涉及向自回归模型提供键-值关联对,然后提示模型在显示先前看到的键时产生正确的补完。多查询关联召回(multi-query associative recall,MQAR)任务是这个任务的一个特定形式,它要求模型记住多个关联(Arora等人,2024年)。最初的 Mamba 论文报告了相关合成任务的结果,特别是选择性复制(Gu和Dao,2023年)和归纳头部(Olsson等人,2022年),这可以看作是较简单的关联召回任务。MQAR 任务与 “电话簿查找” 任务密切相关,已经显示出对于循环模型(如 SSMs)来说是具有挑战性的,这是由于它们的有限状态容量(De等人,2024年;Jelassi等人,2024年)。
我们在(Arora,Eyuboglu,Zhang等人,2024年)的 MQAR 设置中比较了一个具有挑战性的版本,使用了更困难的任务、更长的序列和更小的模型。我们的基线包括标准的多头 softmax 注意力,以及基于卷积、局部注意力和线性注意力变体的 Based 架构。
结果如图 8 所示。尽管 Mamba-1 在这项任务上表现困难,但 Mamba-2 在所有设置下表现良好。令人惊讶的是,即使在控制状态大小(N = 16)时,它在很大程度上优于 Mamba-1。(我们不确定架构的哪个方面是主要因素,这仍然是未来工作中要探索的问题。)此外,这个任务验证了状态大小的重要性:从 N = 16 增加到 N = 64 和 N = 256,一致地提高了 MQAR 的性能,因为更大的状态允许记住更多的信息(键-值对)。
9.2 语言建模
遵循 LLMs 中的标准协议,我们对标准的自回归语言建模对 Mamba-2 架构进行了训练和评估,与其他架构进行比较。我们比较了预训练指标(困惑度)和 zero-shot 评估。模型大小(深度和宽度)遵循 GPT3 规格,从 125M 到 2.7B。我们使用了 Pile 数据集,并遵循了 Brown等人(2020年)描述的训练方法。这与 Mamba(Gu和Dao,2023年)中报告的相同设置一致;训练细节在附录 D 中描述。
9.2.1 扩展定律
对于基线,我们与 Mamba 及其基于 PaLM 和 LLaMa 架构的 Transformer++ 配方(例如旋转嵌入、SwiGLU MLP、RMSNorm 代替 LayerNorm、无线性偏置和更高的学习率)进行比较。由于 Mamba 已经证明优于标准的 Transformer 架构(GPT3 架构)以及最近的次二次架构(H3(Dao,D. Y. Fu等人,2023年),Hyena(Poli等人,2023年),RWKV-4(B. Peng,Alcaide等人,2023年),RetNet(Y. Sun等人,2023年)),我们在图中将其省略以保持清晰度(有关比较,请参见 Gu 和 Dao(2023年))。 图 9 展示了在标准的 Chinchilla 协议下的扩展规律,模型从 ≈125M 到 ≈1.3B 参数。
9.2.2 下游评估
表 1 显示了 Mamba-2 在一系列流行的下游 zero-shot 评估任务上的性能,与这些规模的最知名的开源模型进行了比较,其中最重要的是 Pythia(Biderman等人,2023年),它们使用了与我们的模型相同的标记器(tokenizer)、数据集和训练长度(300B token)。
9.2.3 混合模型:将 SSD 层与 MLP 和注意力结合
最近的一些研究表明,使用既有 SSM 层又有注意力层的混合架构可能会提高模型质量,尤其是在上下文学习方面,超过了 Transformer 或纯 SSM(例如 Mamba)模型的质量。我们探索了将 SSD 层与注意力和 MLP 结合的不同方式,以了解每种方式的优势。实证结果显示,将约 10% 的总层数用于注意力层效果最佳。将 SSD 层、注意力层和 MLP 结合在一起也比纯 Transformer++ 或 Mamba-2 效果更好。
SSD 和注意力:我们发现,SSD 和注意力层相辅相成:单独使用(例如在 Mamba-2 架构与Transformer++ 之间)它们的性能(以困惑度衡量)几乎相同,但混合使用 SSD 和注意力层的效果优于纯 Mamba-2 或 Transformer++ 架构。我们在 Pile 上对训练了 7B token 的 350M 模型(48层)进行了一些结果展示(见表 2),使用了 GPT-2 tokenizer(相同的参数数量、相同的超参数、相同的训练和验证集)。增加少量注意力层已经明显改善了性能,并在质量和效率之间达到了最佳平衡。我们假设 SSM 层能够很好地实现一般的序列到序列映射,而注意力层则作为检索机制,快速地引用序列中的先前 token,而不是强迫模型将所有上下文压缩到其内存中(SSM 状态)。
带有 SSD、MLP 和注意力的混合模型:我们比较了不同的方式,即 SSD 如何与(门控)MLP 和注意力层结合,并在 2.7B 规模(64 层)上进行了评估,在 Pile上 训练了 300B token(相同数量的参数、相同的超参数、相同的训练和验证集、相同的数据顺序):
- Transformer++:32 个注意力层和 32 个门控 MLP,交错排列。
- Mamba-2:64 个 SSD 层。
- Mamba-2-MLP:32 个 SSD 和 32 个门控 MLP 层,交错排列。
- Mamba-2-Attention:58 个 SSD 层和 6 个注意力层(索引为 9、18、27、36、45、56)。(在小规模实验中,我们发现只要注意力层分布均匀,不在开头或结尾位置,模型的质量并不太依赖于注意力层的确切位置。)
- Mamba-2-MLP-Attention:28 个 SSD 层和 4 个注意力层,与 32 个门控 MLP 层交错排列。
我们在 Pile 上报告验证困惑度,以及 zero-shot 评估,见表 3。总体而言,Transformer++ 和 Mamba-2 模型的质量大致相同。我们发现,仅增加 6 个注意力层就明显改善了纯 Mamba-2 模型的性能(以及 Transformer++ 模型)。增加 MLP 层会降低模型的质量,但可以(i)由于 MLP 层的简单性和硬件效率而加速训练和推理(ii)更容易地将模型升级为 MoE 模型,方法是将 MLP 层替换为专家混合(MoE)层。
9.3 速度基准测试
我们对 SSD 算法与 Mamba 的扫描实现以及 FlashAttention-2 的速度进行了基准测试(见图 10)。由于 SSD 利用矩阵乘法作为子程序的重组,可以利用 GPU 上的专用矩阵乘法(matmul)单元,也被称为张量核心。因此,它比 Mamba 的融合关联扫描(fused associative scan)快 2-8 倍,后者不利用矩阵乘法单元。由于其在序列长度上的线性扩展,SSD 在序列长度为 2𝐾 及以上时比 FlashAttention-2 更快。
然而,我们注意到整个 Mamba-2 模型在短序列长度(例如 2𝐾)的训练效率可能不如 Transformer 高,因为具有 𝐿 层的 Transformer 将具有 𝐿/2 个 MLP 层和 𝐿/2 个注意力层,而 Mamba-2 模型将具有相同数量的参数。通常,MLP 层非常高效,因为它们由简单的矩阵乘法和逐点线性组成。正如第 9.2.3 节所示,可以将 𝐿/2 个 SSD 层和 𝐿/2 个 MLP 层组合起来,以加快短序列长度的训练。
9.4 架构消融
9.4.1 块设计
第 7.1 节介绍了 Mamba-2 块,它对 Mamba-1 块进行了一些小的修改,部分是由于与注意力的连接,部分是为了改善 Mamba-2 的可扩展性。表 4 消融了块的这些架构变化,这些变化发生在核心 SSM 层之外。
消融结果证实了并行投影以创建(𝐴,𝐵,𝐶,𝑋)比 Mamba 的序列投影节省参数并且执行略有改进。更重要的是,这种修改有利于在更大的模型规模(第 8 节)下进行张量并行计算。此外,额外的归一化层也略微提高了性能。更重要的是,初步实验观察到,在更大规模上,它还有助于训练的稳定性。
9.4.2 头结构
第 7.2 节描述了 𝐵,𝐶,𝑋 投影的维度可以被视为与多头注意力和多查询注意力的概念类似的超参数。我们还展示了原始 Mamba 架构类似于多值注意力(命题 7.2),这是从状态空间模型的角度自然发展出来的选择,之前没有进行消融。 表 5 消融了 Mamba-2 架构的多头结构选择。引人注目的是,我们发现多值和多查询或多键头模式之间存在很大差异,尽管它们看起来非常相似。请注意,这并不是由总状态大小解释的,对于所有这些模式来说,它是相同的(等于 HPN 或头数、头维度和状态维度的乘积)。 我们还比较了 𝐶, 𝐵,𝑋(类似于 𝑄,𝐾,𝑉)头结构的多头模式,其中 𝑖 = 1, 2。我们将其与标准的多头模式进行比较,以及一个具有激进共享的模式,其中它们都只有 1 个头。请注意,在后一种情况下,模型仍然具有 H 个不同的序列混合器 𝑀,因为每个头仍然具有不同的 𝐴。当参数匹配时,这些多头模式彼此表现相似,介于 MVA 和 MQA / MKA 模式之间。
9.4.3 注意力核近似
第 7.3 节指出了如何将 SSD 与线性注意力文献中的思想相结合,例如各种形式的核近似。我们在表 6 中消融了前人提出的几种这样的变体。这些包括 cosFormer(秦,孙卫轩等人,2022年),随机特征注意(Random Feature Attention)(H. Peng等人,2021年)和正随机特征(Performer)(Choromanski等人,2021年)。
我们还尝试了添加一个归一化项,类似于标准注意力中 softmax 函数的分母。我们发现这引入了大多数变体的不稳定性,但对于 ReLU 激活函数 𝜓,性能略有提高。
表 7 还测试了更近期的改进线性注意力的提议,这些提议涉及扩展特征维度(Based(Arora,Eyuboglu,张等人,2024年)和ReBased(Aksenov等人,2024年))。这些线性注意力扩展旨在使用二次逼近适当地处理 exp 核。ReBased 还提出用层归一化替换 QK 激活函数;从 SSM-centric 视角来看,在应用 SSM 函数之前,我们在(𝐵,𝐶)之上应用归一化。我们注意到,这种技术已经被独立提出为 softmax 注意力的 “QK-Norm”(Team,2024年)和 Mamba 的 “内部归一化”(Lieber等人,2024年)。
总体而言,表 6 和表 7 发现我们尝试的核近似方法似乎并没有比简单的逐点非线性激活函数更好。因此,遵循 Mamba-1,我们 Mamba-2 的默认设置使用 𝜓(𝑥) = Swish(𝑥) ,但我们建议完全移除这个激活可能是一个更简单的选择,我们没有进行全面测试。 然而,我们要强调的是,SSD 和 普通的线性注意力在包含 1-半可分隔掩码 𝐿 方面存在差异,而文献中各种线性注意力方法是为了近似没有此项的 softmax 注意力;因此,我们的负面结果可能并不令人意外。
10. 相关工作和讨论
状态空间对偶(SSD)框架连接了 SSMs(状态空间模型)、结构化矩阵和注意力机制。我们将更深入地讨论 SSD 与这些概念之间的广泛关系。利用每种观点中的思想,我们还提出了一些未来扩展 SSD 框架的方向。
10.1 状态空间模型
结构化状态空间模型可以沿以下轴线进行表征:
- 是否是时不变的或时变的。
- 系统的维度。
- 递归转移 A 的结构。
SSD 可以描述为具有 SISO(单输入单输出)维度和标量-恒等结构的选择性 SSM。
时变性(选择性)。最初的结构化 SSM(S4)是线性时间不变(LTI)系统(Gu 2023;Gu, Goel, and Ré 2022),其动机是连续时间在线记忆(Gu, Dao, et al. 2020;Gu, Johnson, Goel, et al. 2021;Gu, Johnson, Timalsina, et al. 2023)。提出了许多结构化 SSM 的变体(Dao, D. Y. Fu, et al. 2023;Gu, Gupta, et al. 2022;Gupta, Gu, and Berant 2022;Ma et al. 2023;J. T. Smith, Warrington, and Linderman 2023),包括一些不再关注递归而专注于 LTI SSM 的卷积表示的变体(D. Y. Fu et al. 2023;Y. Li et al. 2023;Poli et al. 2023;Qin, Han, Weixuan Sun, B. He, et al. 2023)。
SSD 是一种时变的结构化 SSM,也被称为在 Mamba 中引入的选择性 SSM(Gu and Dao 2023)。选择性 SSM 与 RNN 的门控机制密切相关,包括经典 RNN,如 LSTM(Hochreiter and Schmidhuber 1997)和 GRU(J. Chung et al. 2014)以及更现代的变体,如 QRNN(Bradbury et al. 2016)、SRU(Lei 2021;Lei et al. 2017)、RWKV(B. Peng, Alcaide, et al. 2023)、HGRN(Qin, Yang, and Zhong 2023)和 Griffin(Botev et al. 2024;De et al. 2024)。这些 RNN 在参数化方面有所不同,最重要的是缺少状态扩展。
维度和状态扩展。SSD 的一个重要特征是它与其谱系中的早期 SSM(S4、H3、Mamba)一样,是一个单输入单输出(SISO)系统,其中输入通道独立处理。这导致 ND 的更大有效状态大小,其中 N 是 SSM 状态大小(也称为状态扩展因子),D 是标准模型维度。传统 RNN 要么 N = 1,要么是具有密集 B, C 矩阵的多输入多输出(MIMO)系统,这两者都导致较小的状态。虽然 MIMO SSM 在某些领域表现良好(Lu et al. 2023;Orvieto et al. 2023;J. T. Smith, Warrington, and Linderman 2023),但 Mamba 表明状态扩展对于信息密集型领域(如语言)至关重要。SSD 的主要优势之一是允许更大的状态扩展因子,而不会减慢模型速度。许多后续工作自此采用了状态扩展(见 10.4 节)。
结构。与以前的结构化 SSM 相比,SSD 的主要限制在于状态转移 A_𝑡 的表达能力。我们注意到,更通用的 SSM,如对角 A_𝑡 的情况,具有与 SSD 相同的理论效率,但在硬件上不太友好。这是因为对偶二次形式失去了类似注意力的解释,并且变得更难计算。因此,与 Mamba 相比,SSD 只在对角 A_𝑡 的形式上稍微更具限制性,并以此表达能力换取了更高的硬件效率(和易于实现性)。我们假设可以改进我们的结构矩阵算法,以改进到一般对角 SSM 的情况。
10.2 结构化矩阵
状态空间对偶(SSD)的第一个观点采用了这些模型作为矩阵序列变换或 “矩阵混合器” 的观点:可以表示为矩阵乘法(通过 T × T 矩阵)沿序列维度 T 进行的序列变换(定义 2.1)。
之前已经提出了几种这样的矩阵混合器,主要的变化轴是矩阵的表示。这些包括 MLP-Mixer(Tolstikhin et al. 2021)(非结构化矩阵)、FNet(Lee-Thorp et al. 2021)(傅里叶变换矩阵)、M2(Dao, B. Chen, et al. 2022;Dao, Gu, et al. 2019;Dao, Sohoni, et al. 2020;D. Fu et al. 2024)(蝶形/ monarch 矩阵)、Toeplitz 矩阵(Poli et al. 2023;Qin, Han, Weixuan Sun, B. He, et al. 2023),甚至更奇特的结构(De Sa et al. 2018;Thomas et al. 2018)。
一个重要的特征是,高效(次二次)矩阵序列变换恰好是那些具有结构化矩阵混合器的变换。SSD 框架的核心结果是将 SSM 视为具有特定结构的矩阵混合器——半可分矩阵(见第 3 节)。线性与二次对偶性则采用结构化矩阵乘法与朴素的矩阵乘法的形式。
结构矩阵表示通过特定半可分矩阵的块分解引导我们设计了高效的 SSD 算法(第 6 节)。我们注意到,半可分矩阵在科学计算文献中得到了充分研究,结合这些思想可能是改进状态空间模型的有希望途径。我们还建议,专注于矩阵混合器的观点可以为序列模型带来更多有成效的方向,例如设计有原则的非因果 Mamba 变体,或通过分析其矩阵变换结构找到表征和弥合 softmax 注意力与次二次模型之间差距的方法。
10.3 (线性)注意力
与标准(因果)注意力相比,SSD 只有两个主要区别。
首先,SSD 不使用标准注意力的 softmax 激活(Bahdanau, Cho, 和 Bengio 2015;Vaswani 等 2017),这正是使注意力具有二次复杂度的原因。当去掉 softmax 时,可以通过线性注意力框架(Katharopoulos 等 2020)线性扩展计算序列。
其次,SSD 通过输入依赖的 1-半可分掩码乘以 logits 矩阵。因此,这个掩码可以视为取代了标准注意力中的 softmax。
这种半可分掩码也可以看作提供了位置信息。元素 𝑎_𝑡 在 RNN 的意义上充当 “门” 或 “选择” 机制(参见 Mamba 论文中的讨论),其累积乘积 𝑎_(𝑗 :𝑖) 控制位置 𝑖 和 𝑗 之间允许的交互量。位置嵌入(例如,正弦(Vaswani 等 2017),AliBi(Press, N. Smith, 和 Lewis 2022),和 RoPE(Su 等 2021))是 Transformer 的重要组成部分,通常被视为启发式的,而 SSD 的 1-SS 掩码可以看作是一种更有原则的相对位置嵌入形式。我们注意到,这种观点也在 GateLoop(Katsch 2023)中同时提出。
状态空间对偶的第二个观点是我们更一般的结构化掩码注意力(SMA)框架的特例,其中对偶性在一个简单的四向张量收缩上被揭示为不同的收缩顺序。SMA 是线性注意力的强泛化,比 SSD 更一般;其他形式的结构化掩码可能会导致具有不同于 SSD 特性的高效注意力的更多变体。
除了引领新模型,这些与注意力的联系可以引导理解 SSMs 的其他方向。例如,我们好奇注意力吸收(attention sinks)现象(Darcet 等 2024;Xiao 等 2024)是否存在于 Mamba 模型中,以及更广泛地,解释性技术是否可以转移到 SSMs(Ali, Zimerman, 和 Wolf 2024)。
最后,已经提出了许多其他线性注意力的变体(Arora, Eyuboglu, Timalsina, 等 2024;Arora, Eyuboglu, Zhang, 等 2024;Choromanski 等 2021;H. Peng 等 2021;Qin, Han, Weixuan Sun, Dongxu Li, 等 2022;Qin, Weixuan Sun, 等 2022;Schlag, Irie, 和 Schmidhuber 2021;Zhang 等 2024;Zheng, C. Wang, 和 Kong 2022)(参见第 4.1.3 节对其中几种的描述),我们预计许多技术可以转移到 SSMs(例如,第 7.3 节)。
我们强调,SSD 不泛化标准 softmax 注意力,或任何没有有限特征映射 𝜓 的注意力核矩阵的变换。与一般注意力相比,SSD 的优势在于具有可控的状态扩展因子 N,可以压缩历史,而不是二次注意力的整个历史缓存随序列长度 T ≫ N 缩放。同期工作已经开始研究这些表示的权衡,例如在复制和上下文学习任务上(Akyürek 等 2024;Grazzi 等 2024;Jelassi 等 2024;Park 等 2024)。我们注意到,Mamba-2 在某些能力上显著改进了 Mamba(例如,通过第 9.1 节中的 MQAR 结果展示),但仍有许多需要理解的地方。
10.4 相关模型
最后,我们重点介绍了一些与 Mamba 和 Mamba-2 非常相似的序列模型的近期和并行工作。
- RetNet(Y. Sun 等 2023)和 TransNormerLLM(Qin, Dong Li, 等 2023)通过使用衰减项而不是累加和来泛化线性注意力,并提出了并行/递归算法以及混合 “块状(chunkwise)” 模式。这些算法可以看作是 SSD 的一个实例,其中 A_𝑡 是时不变的(对所有 𝑡 都是常数);在 SMA 解释中,掩码矩阵 L 将是一个衰减矩阵 L_(𝑖, 𝑗) = 𝛾_(𝑖−𝑗) 。这些模型在架构上也有所不同。例如,由于它们是从注意力中心视角派生的,它们保留了多头注意力(MHA)模式;由于 Mamba-2 是从 SSM 中心模式派生的,它保留了多值注意力(MVA)或多扩展 SSM(MES)模式,我们证明这种模式更好(第 9.4 节)。
- GateLoop(Katsch 2023)同时提出使用输入依赖的衰减因子 A_𝑡,并开发了与 SSD 相同的对偶二次形式,他们称之为 “代理注意力” 形式。
- Gated Linear Attention(GLA)(Yang 等 2024)提出了一种具有数据依赖门控的线性注意力变体,以及计算块状模式和硬件感知实现的高效算法。
- HGRN(Qin, Yang, 和 Zhong 2023)引入了一种具有输入依赖门控的 RNN,在 HGRN2(Qin, Yang, Weixuan Sun, 等 2024)中通过引入状态扩展进行了改进。
- Griffin(De 等 2024)和 RecurrentGemma(Botev 等 2024)表明,结合局部注意力的具有输入依赖门控的 RNN 可以与强大的现代 Transformer 非常有竞争力。Jamba 还表明,将 Mamba 与几层注意力结合在语言建模中表现非常好(Lieber 等 2024)。
- xLSTM(Beck 等 2024)通过采用状态扩展和其他门控、规范化和稳定化技术改进了 xLSTM。
- RWKV(-4)(B. Peng, Alcaide, 等 2023)是一种基于不同线性注意力近似(无注意力 Transformer(S. Zhai 等 2021))的 RNN。通过采用选择性和状态扩展的理念,最近已改进为 RWKV-5/6(Eagle 和 Finch)架构(B. Peng, Goldstein, 等 2024)。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
(2021,AFT,MHA,RWKV 基础,线性内存复杂度)无注意力的 Transformer
(2023|EMNLP,RWKV,Transformer,RNN,AFT,时间依赖 Softmax,线性复杂度)
(2024,RWKV-5/6,RNN,矩阵值注意力状态,数据依赖线性插值,LoRA,多语言分词器)Eagle 和 Finch
(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模
(2024,Attention-Mamba,MoE 替换 MLP)Jamba:混合 Transformer-Mamba
(2024,LSTM,Transformer,指数门控,归一化器状态,多头内存混合)xLSTM:扩展的 LSTM
(2024,FLOPs 动态分配,MoD,MoDE,top-k 路由,块丢弃)在基于 Transformer 的语言模型中动态分配计算
(2024,Infini-T,Infini-A,压缩记忆,长期记忆)使用无限注意力的高效无限上下文 Transformer
11. 结论
我们提出了一个基于结构矩阵的理论框架,该框架弥合了 SSM 和注意力变体之间的概念差距。该框架提供了有关最近 SSM(例如 Mamba)在语言建模中表现良好的见解。此外,我们的理论工具通过连接两方面的算法和系统进展,为改进 SSM(以及可能的 Transformer)提供了新思路。作为演示,该框架指导我们设计了在 SSM 和结构化注意力交叉点的新架构(Mamba-2)。