【论文笔记】Mamba: Linear-Time Sequence Modeling with Selective State Spaces

原文链接:https://arxiv.org/abs/2312.00752

1. 引言

基石模型(FM)的主干网络通常是序列模型,处理任意的输入序列。但现代FM主要基于Transformer这一序列模型,及其核心的注意力。但是,自注意力仅能在上下文窗口中密集地传递信息,而无法建模窗口外部的数据;此外,其尺度与窗口长度成二次方关系。注意力相关高效的改进牺牲了其有效性,因此也未被有效地用于不同领域。

最近,结构状态空间序列模型(SSM)作为序列建模的有前景方法,可被理解为RNN与CNN的结合,并受经典状态空间模型的启发。这类模型能高效计算,且尺度与序列长度成比例关系。此外,在部分模态下,还可建模长距离依赖关系,且在连续信号(如音频与视觉)下取得了成功。但对于离散且信息密集的数据(如文本)则不那么有效。

本文提出选择性状态空间模型,在前面的工作上做出改进,达到Transformer的建模能力,且尺度随序列长度线性增大。

选择机制:过去的方法缺乏以数据依赖的方式高效选择数据的能力(关注或忽视特定输入)。本文通过将SSM的参数基于输入参数化,设计选择机制,使模型过滤无关信息并记忆相关信息。

硬件感知的算法:所有之前的SSM需要是时不变和输入不变的,以高效计算。本文使用硬件感知的算法来克服这一问题,递归地使用扫描而非卷积计算模型,且不实现扩展状态以避免在GPU内存层次结构的不同层进行IO访问。这样,实施速度在理论上和现代硬件上均能超过过去的方法(伪线性时间)。

结构:本文将之前的SSM结构与Transformer的MLP组合为块(Manba),一种包含了选择性状态空间的简单而同质的结构设计。

选择性SSM与其扩展Manba均为完全递归的模型,适合作为以序列为输入的通用基石模型的主干网络。其关键属性为:

  1. 高质量:选择性能为密集模态(语言、基因组)带来强性能;
  2. 快速训练和推断:训练时的计算与存储尺度均随序列长度线性变化,推断时自回归地展开模型使得每步只需常数时间,因为无需过去元素的缓存。
  3. 长上下文:质量与效率使其能在1M长度序列上产生性能提升。

在语言、音频、基因组等领域上的实验表明,Mamba只需更少的参数量就能达到Transformer相同的性能,且速度更快。

2. 状态空间模型

结构状态空间模型(S4)与RNN、CNN以及经典状态空间模型相关。其受到特定连续系统的启发,该连续系统通过隐状态 h ( t ) ∈ R N h(t)\in\mathbb R^N h(t)RN映射1维函数或序列 x ( t ) ∈ R → y ( t ) ∈ R x(t)\in\mathbb R\rightarrow y(t)\in\mathbb R x(t)Ry(t)R

S4模型由4个参数定义( Δ , A , B , C \Delta,A,B,C Δ,A,B,C),包含序列到序列的两阶段变换(式(1)):
h ′ ( t ) = A h ( t ) + B x ( t ) ( 1 a ) h t = A ˉ h t − 1 + B ˉ x t ( 2 a ) K ˉ = ( C B ˉ , C A ˉ B ˉ , ⋯   , C A ˉ k B ˉ , ⋯   ) ( 3 a ) y ( t ) = C h ( t ) ( 1 b ) y t = C h t ( 2 b ) y = x ∗ K ˉ ( 3 b ) \begin{matrix} h'(t)=Ah(t)+Bx(t)&(1a)&h_t=\bar Ah_{t-1}+\bar Bx_t&(2a)&\bar K=(C\bar B,C\bar A\bar B,\cdots,C\bar A^k\bar B,\cdots)&(3a)\\ y(t)=Ch(t)&(1b)&y_t=Ch_t&(2b)&y=x*\bar K&(3b) \end{matrix} h(t)=Ah(t)+Bx(t)y(t)=Ch(t)(1a)(1b)ht=Aˉht1+Bˉxtyt=Cht(2a)(2b)Kˉ=(CBˉ,CAˉBˉ,,CAˉkBˉ,)y=xKˉ(3a)(3b)

离散化:第一阶段通过固定公式 A ˉ = f A ( Δ , A ) , B ˉ = f B ( Δ , A , B ) \bar A=f_A(\Delta,A),\bar B=f_B(\Delta,A,B) Aˉ=fA(Δ,A),Bˉ=fB(Δ,A,B),将连续参数( Δ , A , B \Delta,A,B Δ,A,B)转化为离散参数( Δ , A ˉ , B ˉ \Delta,\bar A,\bar B Δ,Aˉ,Bˉ)。其中 f A , f B f_A,f_B fA,fB被称为离散规则。例如,零阶保持(ZOH)由下式定义:
A ˉ = exp ⁡ ( Δ A ) , B ˉ = ( Δ A ) − 1 ( exp ⁡ ( Δ A ) − I ) ⋅ Δ B \bar A=\exp(\Delta A),\bar B=(\Delta A)^{-1}(\exp (\Delta A)-I)\cdot\Delta B Aˉ=exp(ΔA),Bˉ=(ΔA)1(exp(ΔA)I)ΔB

离散化可为连续时间系统赋予额外特性,如分辨率不变性或保证模型被恰当地归一化。同时,离散化也与RNN的门控机制有关(见3.5节)。离散化可被简单地视为SSM前向过程计算图的第一步。一些类型的SSM可以直接参数化 ( A ˉ , B ˉ ) (\bar A,\bar B) (Aˉ,Bˉ),而绕过离散化步骤。

计算:得到离散化的模型后,可通过线性递归(式(2))或全局卷积(式(3))方式计算。通常会使用卷积模式进行并行训练(可一次获取整个序列),然后切换为递归模式进行自回归推断(一次只能获取一个时间步长的数据)。

线性时不变(LTI):式(1)(2)(3)的特点是模型动态随时间恒定,即 Δ , A , B , C \Delta,A,B,C Δ,A,B,C A ˉ , B ˉ \bar A,\bar B Aˉ,Bˉ恒定,称为LTI。LTI SSM与任何线性递归或卷积等价,因此可将LTI视为这些模型的总括术语。

目前位置,由于效率限制,所有的结构SSM均为LTI的(如通过卷积计算)。LTI模型在特定类型的数据上有局限性,本文则移除这一约束并克服效率瓶颈。

结构与维度:结构SSM需要强制矩阵 A A A的结构,如对角矩阵。这样 A ∈ R N × N , B ∈ R N × 1 , C ∈ R 1 × N A\in\mathbb R^{N\times N},B\in\mathbb R^{N\times1},C\in\mathbb R^{1\times N} ARN×N,BRN×1,CR1×N均可由 N N N个数表达。若输入序列 x x x的批量大小为 B B B,长度为 L L L,通道数为 D D D,则会对每个通道独立应用SSM。此时,总的隐状态维度为 D N DN DN,整个序列的计算需要 O ( B L D N ) O(BLDN) O(BLDN)的时间与存储。

通用状态空间模型:状态空间可以表达任何带有隐状态的递归过程,如马尔科夫决策过程、动态因果建模、卡尔曼滤波器、隐马尔科夫模型、线性动态系统、递归模型等。

SSM的结构:SSM可合并到端到端神经网络架构中。如线性注意力(含递归的自注意力近似,可视为退化线性SSM)、H3(门控连接之间加SSM,或在SSM层前加卷积)、Hyena(将H3的S4替换为MLP参数化的全局卷积)、RetNet(添加额外门控并使用更简单的SSM,使用多头注意力的变体而非卷积从而可并行计算)、RWKV(基于另一线性注意力近似的RNN,包含LTId递归,可视为两个SSM之比)。

3. 选择性状态空间模型

3.1 动机:选择为一种压缩方式

序列建模的基本问题是将上下文压缩为更小的状态。高效的模型需要小状态,而有效的模型需要包含上下文所有必要信息的状态。Transformer没有压缩上下文,因此是有效但低效的;递归模型有有限状态,是高效的,但其有效性受到上下文压缩程度的限制。

合成任务中的选择性复制与归纳头均需要内容感知的推理,这说明了LTI模型的缺陷。从递归角度看,其常数动态( A ˉ , B ˉ \bar A,\bar B Aˉ,Bˉ)不能使其选择上下文中正确的信息,或以输入依赖的方式影响隐状态。从卷积角度看,静态卷积核不能建模变化的输入输出关系。

本文提出建立序列模型的基本原则是选择性(或上下文感知能力),能够关注或过滤输入,得到序列状态。选择机制可以控制信息沿序列维度的传播与交互。

3.2 使用选择改进SSM

为模型引入选择性的方法之一是将影响序列交互的参数(如RNN的递归动态或CNN的卷积核)改为输入依赖的。
在这里插入图片描述
为SSM(算法1)添加选择机制的算法如算法2所示。主要的改动为将 B B B C C C改为输入的函数,使其参数与时间相关(维度 L L L表明参数为时变的)。这使得SSM失去了与卷积的等价性。

本文设置 s B ( x ) = L i n e a r N ( x ) , s C ( x ) = L i n e a r N ( x ) , s Δ ( x ) = B r o a d c a s t D ( L i n e a r 1 ( x ) ) s_B(x)=\mathtt{Linear}_N(x),s_C(x)=\mathtt{Linear}_N(x),s_\Delta(x)=\mathtt{Broadcast}_D(\mathtt{Linear}_1(x)) sB(x)=LinearN(x),sC(x)=LinearN(x),sΔ(x)=BroadcastD(Linear1(x)) τ Δ = s o f t p l u s \tau_\Delta=\mathtt{softplus} τΔ=softplus。其中 L i n e a r d \mathtt{Linear}_d Lineard的输出维度为 d d d

3.3 选择性SSM的高效实施

由于时变的SSM失去了卷积计算能力,因此其计算效率受到了影响。

3.3.1 过去模型的动机

  • 尽管递归模式比卷积模式更为灵活,但需要计算大小为 ( B , L , D , N ) (B,L,D,N) (B,L,D,N)的隐状态 h h h。因此,往往采用更为高效的卷积模式,绕过隐状态计算,并设置大小为 ( B , L , D ) (B,L,D) (B,L,D)的卷积核。
  • LTI SSM使用双循环-卷积形式, N N N倍地增加有效的状态维度,而不影响效率。

3.3.2 选择性扫描概述:硬件感知的状态扩展

本文使用三种经典技巧处理选择性SSM的效率问题:核融合,并行扫描和重计算。注意到:

  • 递归计算的FLOP为 O ( B L D N ) O(BLDN) O(BLDN),而卷积计算为 O ( B L D log ⁡ ( L ) ) O(BLD\log(L)) O(BLDlog(L)),且前者的常系数更小。因此长序列和不大的状态维度下,递归模式的FLOP更低。
  • 递归的顺序性与大存储消耗为两大挑战。为解决后者,本文不计算完整状态 h h h

由于多数操作(包括扫描)会受限于存储带宽,本文使用核融合以减小存储器IO的次数,从而进行加速。

具体来说,不在GPU的高带宽存储器(HBM)中加载大小为 ( B , L , D , N ) (B,L,D,N) (B,L,D,N)的扫描输入 ( A ˉ , B ˉ ) (\bar A,\bar B) (Aˉ,Bˉ),而直接从较慢的HBM加载SSM参数 ( Δ , A , B , C ) (\Delta,A,B,C) (Δ,A,B,C)到较快的SRAM中,并在SRAM中进行离散化与递归。最后将大小为 ( B , L , D ) (B,L,D) (B,L,D)的输出写入HBM。

为避免顺序递归,可以使用并行扫描算法进行并行化。

为避免存储反向传播时必要的中间状态,在反向传播时对其进行重计算。因此,融合的选择扫描层与用FlashAttention优化的Transformer有着相同的存储需求。

完整的选择性SSM层如下图所示。
在这里插入图片描述

3.4 简化的SSM结构

在这里插入图片描述
如图所示,本文将H3中的线性注意力与MLP合并。

首先,将模型维度 D D D乘以可控的扩张因子 E E E。此时,大多数参数位于线性投影层,SSM的参数占比很小。复制该块并插入标准归一化和残差连接,得到的Mamba。激活函数 σ \sigma σ使用SiLU或Swish。最后,额外使用可选的归一化层(LayerNorm)。

3.5 选择机制的特性

3.5.1 与门控机制的联系

RNN的经典门控机制是本文SSM选择性机制的实例。

3.5.2 选择机制的解释

可变间距:选择性使得模型可以跳过输入中的噪声token,这在各种模态中(尤其是离散模态)无处不在。

过滤上下文:尽管更多的上下文应该导致严格更高的性能,但许多序列模型在长上下文下并没有提升。这是因为它们不能忽略无关上下文(如全局卷积),而选择性模型可以在任何时刻重置状态,以删除无关的历史,因此性能可以随序列长度增加而单调递增。

边界重置:当多个独立序列被缝合时,Transformer可以通过注意力掩膜使其保持分离,而LTI模型会在各序列间传递信息。选择性SSM也可在边界处重置状态。

Δ \Delta Δ的理解 Δ \Delta Δ控制了对当前输入 x t x_t xt关注或忽略程度的平衡,其泛化了RNN的门控。大的 Δ \Delta Δ会重置状态 h h h并关注当前输入,而小的 Δ \Delta Δ则保留状态并忽视当前输入。SSM可视为按照时间步长 Δ \Delta Δ进行离散化, Δ → ∞ \Delta\rightarrow\infty Δ表示系统关注当前输入而忽略其状态,而 Δ → 0 \Delta\rightarrow0 Δ0则表明瞬态输入被忽略。

A A A的解释:尽管 A A A也可是选择性的,但 Δ \Delta Δ的选择性就可保证 A ˉ = exp ⁡ ( Δ A ) \bar A=\exp(\Delta A) Aˉ=exp(ΔA)的选择性。

B B B C C C的理解:改变 B B B C C C,可以对输入 x t x_t xt是否影响状态 h t h_t ht或状态 h t h_t ht是否影响输出 y t y_t yt进行细粒度控制,即允许模型根据内容(输入)或上下文(隐状态)调节递归动态。

3.6 额外的模型细节

实数&复数:许多SSM使用复数状态 h h h以获取强性能,但某些设置下实数状态可能更好。这可能与数据模态相关,其中复数状态对连续模态有帮助,而实数状态对离散模态更好。

初始化:在复数(实数)情况下,本文使用S4D-Lin(S4D-Real)作为默认初始化方法,将 A A A的第 n n n个元素定义为 − 1 / 2 + n i -1/2+ni 1/2+ni − ( n + 1 ) -(n+1) (n+1))。

Δ \Delta Δ的参数化 Δ \Delta Δ被初始化为 τ Δ − 1 ( U n i f o r m ( [ 0.001 , 0.1 ] ) ) \tau_\Delta^{-1}(\mathtt{Uniform}([0.001,0.1])) τΔ1(Uniform([0.001,0.1]))。此外,可将维度1泛化为更大的维度 R R R(设置为 D D D的因数),且将广播操作视为另一线性投影,从而得到 s Δ s_\Delta sΔ的另一形式: s Δ ( x ) = L i n e a r D ( L i n e a r R ( x ) ) s_\Delta(x)=\mathtt{Linear}_D(\mathtt{Linear}_R(x)) sΔ(x)=LinearD(LinearR(x))

4. 经验评估

4.5 速度与存储基准

在长序列下,本文高效的SSM扫描比最快的注意力实施还快,且推断速度比同等大小的Transformer快很多。注意由于无需缓存 K V KV KV,因此Manba可以设置更大的批量大小。

4.6 模型消融

4.6.1 结构

  • 过去LTI SSM的性能均相近。
  • 将复数S4替换为实数,不会太影响性能,因为实数SSM可能是硬件高效的。
  • 将上述任一方法替换为选择性SSM能极大地提高性能。
  • Mamba的结构与H3的性能相近(当使用选择性层时,会略高)。

4.6.2 选择性SSM

考虑为 Δ , B , C \Delta,B,C Δ,B,C之间不同的组合添加选择性,实验表明 Δ \Delta Δ为最关键的参数,因其与RNN门控相关。

不同的SSM初始化方法在不同模态上的表现可能有较大差异。

增大 Δ , B , C \Delta,B,C Δ,B,C的维度,可以较小的参数量适中地增加性能。

增加状态大小 N N N可以极小的参数量极大地提高性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/442687.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JVM基本概念、命令、参数、GC日志总结

原文: 赵侠客 一、前言 NPE(NullPointerException)和OOM(OutofMemoryError)在JAVA程序员中扮演着重要的角色,它也是很多人始终摆脱不掉的梦魇,与NPE不同的是OOM一旦在生产环境中出现就意味着只靠代码已经无…

Git使用教程:入门到精通

Git使用教程:入门到精通 一、Git安装根据需求选择电脑位数安装;20231023210945建议这里先新建一个文件夹如:D:/Git;专门来存放Git安装包和后续Git代码,方便管理; 二、Git使用前的配置需要先创建自己的Gitee…

四桥臂三相逆变器动态电压恢复器(DVR)MATLAB仿真

微❤关注“电气仔推送”获得资料(专享优惠) 简介 四桥臂三相逆变器 电路 的一般形式如图 1,为 便于分析 ,将其等效成图所示的电路 。以直流母线电压Ud的 1/2处为参考点 ,逆变器三相和零线相 输 出可等效成…

Kotlin dist downloading failed

现象: 在使用AndroidStudio编写Flutter项目时总是在工具的右下角提示错误信息 该问题通常在刚刚打开AndroidStudio时报出,但可以正常编译和运行flutter项目即Android项目 分析:Flutter项目组认为这是AndroidStudio工具平台本身的问题非Flut…

【CSP试题回顾】202009-2-风险人群筛查

CSP-202009-2-风险人群筛查 解题思路 主循环(对每个查询): 使用一个布尔变量pass来标记风险人群是否至少一次进入了特定区域,以及一个布尔变量onlyOnce来确保停留计数 stayNum 在每次查询中最多只增加一次。内循环(对…

站长必备溯源教程-绕过CDN查找背后IP的方法手段

绕过CDN查询背后真实IP方法: 方法一 DNS历史解析记录 查询域名的历史解析记录,可能会找到网站使用CDN前的解析记录,从而获取真实IP 相关查询的网站有:iphistory、DNS查询、微步在线、域名查询、DNS历史查询、Netcraft 方法二 …

Aop注解+Redis解决SpringBoot接口幂等性(源码自取)

目录 一、什么是幂等性? 二、哪些请求天生就是幂等的? 三、为什么需要幂等 1.超时重试 2.异步回调 3.消息队列 四、实现幂等的关键因素 关键因素1 关键因素2 五、引入幂等性后对系统的影响 六、Restful API 接口的幂等性 实战Aop注解redis解…

STM32基本定时功能

1、定时器就是计数器。 2、怎么计数? 3、我们需要有一恒定频率的方波信号,再加上一个寄存器。 4、比如每来一个上升沿信号,寄存器值加1,就可以完成计数。 5、假设方波频率是100Hz,也就是1秒100个脉冲。…

海外媒体宣发套餐如何利用3种方式洞察市场-华媒舍

在当今数字化时代,媒体宣发成为了企业推广产品和品牌的重要手段之一。其中,7FT媒体宣发套餐是一种常用而有效的宣传方式。本文将介绍这种媒体宣发套餐,以及如何利用它来洞察市场。 一、关键概念 在深入讨论7FT媒体宣发套餐之前,让…

Django工具

一、分页器介绍 1.1、介绍 分页,就是当我们在页面中显示一些信息列表,内容过多,一个页面显示不完,需要分成多个页面进行显示时,使用的技术就是分页技术 在django项目中,一般是使用3种分页的技术: 自定义分页功能,所有的分页功能都是自己实现django的插件 django-pagin…

机器学习笔记 大语言模型是如何运作的?一、语料库和N-gram模型

一、语料库 语言模型、ChatGPT和人工智能似乎无处不在。了解大型语言模型(LLM)“背后”发生的事情将是驾驭数字世界的关键。 首先在提示中键入一个单词,然后点击提交。您可以尝试新的提示,并根据需要多次重新生成响应。 这个我们称之为“T&C”的语言模型是在一…

大模型概念解析 | In-context Learning

注1:本文系"概念解析"系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:大模型中的In-context Learning 大模型概念解析 | In-context Learning PR-418: What learning algorithm is in-context learning? Investigations with linear mo…

【CSP试题回顾】202012-1-期末预测之安全指数

CSP-202012-1-期末预测之安全指数 解题代码 #include <iostream> #include <algorithm> using namespace std;int n, sum;int main() {cin >> n;for (int i 0; i < n; i){int w, s;cin >> w >> s;sum w * s;}sum max(sum, 0);cout <&…

RocketMQ快速入门_2. rocketmq 的应用场景、与其他mq的差异

0. 引言 之前我们讲解过rabbitMQ&#xff0c;本期我们将进入吞吐量更加强大的rocketMQ的学习。 1. 基础概念 如果你是刚接触MQ的同学&#xff0c;还不清楚消息队列的基础概念的&#xff0c;可以参考我之前这篇文章&#xff1a; https://wu55555.blog.csdn.net/article/deta…

JPEG照片被误删除如何恢复?学会这个方法就够了

JPG/JPEG是一种后缀名为“.jpg”或“.jpeg”的图形格式。它是存储照片图像的常用格式&#xff0c;因此我们可以使用数码相机、手机或其他设备来获取大量的JPG/JPEG文件。有时&#xff0c;我们会遇到由于意外删除、格式化驱动器或其他未知原因导致 JPEG 文件丢失的情况。无论哪种…

外包干了30天,技术明显退步。。

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 这次来聊一个大家可能也比较关心的问题&#xff0c;那就是就业城市选择的问题。而谈到这个问题&a…

程序员失业,被迫开启 PlanB——成为自由职业/独立开发者的第 0 天

程序员失业&#xff0c;被迫开启 PlanB——成为自由职业/独立开发者的第 0 天 今天在逛V2EX的时候看到的一个帖子&#xff0c;程序员中年被裁&#xff0c;被迫开启独立开发这条路。 原贴如下&#xff1a; lastday, 失业啦 公司年前通知我合同到期不续签&#xff0c;今天是我…

seq2seq翻译实战-Pytorch复现

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客 &#x1f366; 参考文章&#xff1a;365天深度学习训练营 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制]\n&#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/…

有点NB的免费wordpress主题模板

一个不错的黄色模板&#xff0c;用WP免费主题模板搭建家政服务公司网站。 https://www.wpniu.com/themes/15.html

Spring Boot中Excel数据导入导出的高效实现

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…