文章目录
- 参数和数据尺寸约定
- class MambaBlock
- def forward
- def __ int__
- def ssm
- def selective_scan
johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)
manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现
参数和数据尺寸约定
之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示
参数及简写 | Mamba论文简写 |
---|---|
batch_size b | B |
序列长度 l | L |
隐藏维度 d / d_model | |
潜在状态维度 n / d_state | N |
扩展因子 expand | E |
d_in / d_inner | D |
数据依赖步长 Δ \Delta Δ / delta | |
delta秩 dt_rank |
class MambaBlock
def forward
根据forward简单梳理MambaBlock的结构
中间变量 | 来源 | shape |
---|---|---|
输入x | (b, l, d_model) | |
x_and_res | x经过输入映射后 | (b, l, 2* d_in) |
x | 切分后作为ssm分支输入 | (b, l, d_in) |
res | 切分后作为门控分支输入 | (b, l, d_in) |
y | 经过卷积,激活,ssm,门控后的输出 | (b, l, d_in) |
output | y经过输出映射后得到 | (b, l, d_model) |
def forward(self, x):
(b, l, d) = x.shape
x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
x = rearrange(x, 'b l d_in -> b d_in l')
x = self.conv1d(x)[:, :, :l]
x = rearrange(x, 'b d_in l -> b l d_in')
x = F.silu(x)
y = self.ssm(x)
y = y * F.silu(res)
output = self.out_proj(y)
return output
def __ int__
初始化主要初始了几个部分
组件定义
操作及简写 | 维度变换 |
---|---|
输入映射 in_proj | (b, l, d_model) -> (b, l, 2*d_in) |
序列变换 conv1d | 只取前l (b, d_in, l) -> (b, d_in, l) |
非线性激活 silu | |
输出映射 out_proj | (b, l, d_in) -> (b, l, d) |
ssm初始化
操作及简写 | 作用 |
---|---|
参数生成映射 x_proj | 生成数据依赖的参数B, C, Δ \Delta Δ |
delta映射 dt_proj | 将 Δ \Delta Δ从dt_rank映射到d_in |
矩阵A初始化 | 简单初始化 |
矩阵D初始化 | 简单初始化 |
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
self.conv1d = nn.Conv1d(
in_channels=args.d_inner,
out_channels=args.d_inner,
bias=args.conv_bias,
kernel_size=args.d_conv,
groups=args.d_inner,
padding=args.d_conv - 1,
)
# ssm模型的初始化部分
# x_proj takes in `x` and outputs the input-specific Δ, B, C
self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
# dt_proj projects Δ from dt_rank to d_in
self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)
A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(args.d_inner))
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
def ssm
这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。
SSM参数 | shape | 来源 |
---|---|---|
状态矩阵A | (d_in, n) | 在初始化中定义,非数据依赖 |
输入矩阵B | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
输出矩阵C | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
直接传递矩阵D | (d_in) | 在初始化中定义,非数据依赖 |
数据依赖步长 Δ \Delta Δ | (b, l, d_in) | 由x_db1切分而来,因此数据依赖 |
其中一部分变量初始化于class MambaBlock的初始化部分
中间变量及简写 | 来源 |
---|---|
数据生成变量 x_db1 | x经过参数映射x_proj生成 |
最终delta Δ \Delta Δ | 切分而来的 Δ \Delta Δ经过映射和softplus |
def ssm(self, x):
(d_in, n) = self.A_log.shape
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
D = self.D.float()
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
return y
SSM参数 | shape |
---|---|
状态矩阵A | (d_in, n) |
输入矩阵B | (b, l, n) |
输出矩阵C | (b, l, n) |
直接传递矩阵D | (d_in) |
def selective_scan
我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。
在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化
前向欧拉离散化
x
k
=
(
I
+
Δ
k
A
)
x
k
−
1
+
Δ
k
B
⋅
u
k
x
(
t
+
Δ
)
=
(
I
+
Δ
A
)
x
(
t
)
+
Δ
B
⋅
u
(
t
)
\begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned}
xkx(t+Δ)=(I+ΔkA)xk−1+ΔkB⋅uk=(I+ΔA)x(t)+ΔB⋅u(t)
零阶保持离散化
x
k
=
e
Δ
k
A
x
k
−
1
+
(
Δ
k
A
)
−
1
(
e
Δ
k
A
−
I
)
⋅
Δ
k
B
⋅
u
k
x
(
t
+
Δ
)
=
e
Δ
A
x
(
t
)
+
(
Δ
A
)
−
1
(
e
Δ
A
−
I
)
⋅
Δ
B
⋅
u
(
t
)
\begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned}
xkx(t+Δ)=eΔkAxk−1+(ΔkA)−1(eΔkA−I)⋅ΔkB⋅uk=eΔAx(t)+(ΔA)−1(eΔA−I)⋅ΔB⋅u(t)
这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢
def selective_scan(self, u, delta, A, B, C, D):
(b, l, d_in) = u.shape
n = A.shape[1]
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
# Perform selective scan (see scan_SSM() in The Annotated S4 [2])
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
ys.append(y)
y = torch.stack(ys, dim=1) # shape (b, l, d_in)
y = y + u * D
return y