大模型技术论文不断,每个月总会新增上千篇。本专栏精选论文重点解读,主题还是围绕着行业实践和工程量产。若在某个环节出现卡点,可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技(Mamba,xLSTM,KAN)则提供了大模型领域最新技术跟踪。若对于具身智能感兴趣的请移步具身智能专栏。技术宅麻烦死磕AI架构设计。
Mamba的出现为带来了全新的思路和可能性,通过对结构化半可分离矩阵的各种分解方法的理论研究,可以将状态空间模型SSM与注意力机制Attention的变种进行紧密关联,进而提出一种状态空间对偶SSD的理论框架。
状态空间对偶使得研究人员设计一种新的架构 (Mamba-2),其核心层是对 Mamba(选择性SSM)进行改进,速度提高了2-8倍,同时在语言建模方面能够保持对Transformers的压力。
在开始之前提醒下读者,在Mamba不再真正认为SSM是连续的。事实上,正如在原始论文的讨论,Mamba与S4在对不同类型的数据进行建模方面进行了权衡:S4 是一种连续时间模型,擅长对连续数据进行建模,例如音频波形和像素级视觉等感知信号。Mamba S6是一种离散时间模型,擅长对离散数据进行建模,例如语言等标记化数据。
线性注意力机制代表着在注意力运算时去掉了softmax。这点在线性RNN<要是忘记了,记得温习下!>已经讲过了。
张量收缩
张量收缩这个词一时之间很难解释清楚,后续开专题介绍。大白话的意思就是多个高维的矩阵按照某种方式转化(相乘)压缩到一定的维度,其实传统的矩阵乘法也是其中的一种。
假如用图标来表示,那么如下为各种维度的张量表示:
那么多个矩阵之间的相乘就可以用下面的图标进行简化表示:
张量收缩在高维张量中使用广泛,例如下面4维的矩阵AB之间要按照某个维度进行压缩,整个过程如下:
d = 10
A = np.random.rand(d,d,d,d)
B = np.random.rand(d,d,d,d)
Ap = A.transpose(0,2,1,3); Bp = B.transpose(0,3,1,2)
App = Ap.reshape(d**2,d**2); Bpp = Bp.reshape(d**2,d**2)
Cpp = App @ Bpp; C = Cpp.reshape(d,d,d,d)
当然也可以使用numpy的函数einsum,
einsum("some string describing an operation", tensor_1, tensor_2, ...) 。
例如输入,不用这个函数的话,你只能这么写:
n = A.shape[0]
out = (
A[t.arange(n), t.arange(n), :, None, None]
* B.permute(2, 1, 0)[:, :, :, None]
* C[None, None, None, :]
).sum(1,3).T
若采用einsum函数,则:
out = t.einsum("iij,kji,l->ki", A, B, C)
SMA
Structured Masked Attention
结构化掩蔽注意力 (SMA)(或简称为结构化注意力)被定义为𝑄、𝐾、𝑉以及任何结构化矩阵𝐿的函数,通过4向张量收缩,
这里请读者注意每个矩阵的形状参数,
SMA有两种算法可以通过进行张量收缩:
一种为二次模态算法,例如的标准注意力机制,这个计算大家应该不陌生吧。要是陌生的话,请移步。
另外一种为线性模式算法:
众所周知,收缩顺序会对计算复杂度产生巨大影响。状态空间模型是一种可以通过多种方式计算,具有二次与线性对偶形式。线性注意力具有类似的二元(对偶)性。
到了这里说明选择不同的Mask L可以生成各种线性注意力的变种。而SMA也借助这L矩阵将各种线性注意力统一到一个框架之下。在这个框架下Mamba-2则是采用了1-半分离矩阵。
下面其实列出了SSM和SMA之间的紧密联系,他们都拥有二元模态和线性模态。SSM和SMA在矩阵A为标量的时候相交,产生了一大堆SSD的模型,而这些模型只是SSD中的特例。<此处请注意A和L!>
因此SSD层可以看成是SSM也可以看成是线性注意力:
Mamba-2架构
上图右为Mamba-2的块结构,Mamba-2块通过删除连续的线性投影简化了Mamba块;SSM参数𝐴、𝐵、𝐶在块的开头生成,而不是作为SSM输入𝑋 的函数。同时Mamba-2添加了一个额外的规范化层,就像在NormFormer中一样,以提高稳定性。𝐵 和 𝐶 投影在 𝑋 头部之间只存在一个单头进行共享,类似于多值注意力 (MVA)。
请注意不同头,ABC的维度
这里重点讨论的是Mamba-2的并行策略。使用张量并行对 Mamba-1进行大规模训练存在一个问题,它每层需要2个all-reduce,而Transformer中每个注意力或 MLP层只需要 1个all-reduce。这是因为SSM参数是内部激活函数,而不是层输入的函数。
在Mamba-2中使用平行投影,所有SSM参数都是层输入的函数,可以很容易地将TP应用于输入投影,即每个SSM head (𝐴, 𝐵,𝐶,𝑋) ↦ 𝑌都在单一的设备上。例如将输入投影和输出投影矩阵拆分为 2、4、8 个分片,以及每个GPU单独进行归一化。这些更改会导致每层1个 all-reduce。
当在非常长的序列上进行训练时,可以沿着序列长度进行拆分,并将不同的部分分配给不同的设备。序列并行性有两种主要形式:对于残差和归一化运算,将张量并行中的all-reduce 替换为reduce-scattter、残差+归一化,然后是all-gather。由于 Mamba-2 使用与 Transformer 相同的残差和归一化结构,因此这种形式的序列并行无需修改即可直接应用。
对于注意力,可以使用环形注意力沿序列维度将其拆分。对于 Mamba-2,SSD框架使用相同的块分解,可以让每个 GPU 计算其本地输出和最终状态,然后在 GPU 之间传递状态(使用发送/接收通信原语),然后再更新每个 GPU 的最终输出。