原文链接:https://arxiv.org/abs/2402.10739
1. 引言
基于Transformer的点云分析方法有二次时空复杂度,一些方法通过限制感受野降低计算。这引出了一个问题:如何设计方法实现线性复杂度并有全局感受野。
状态空间模型(SSM)作为序列建模方法,Mamba在结构状态空间模型(S4)的基础上使用时变SSM参数和硬件感知算法,实现了线性复杂度和全局感受野。但目前的Mamba较少用于视觉任务。
本文探索SSM在点云分析任务中的潜力。直接使用Mamba的性能不佳,这是因为SSM的单向建模能力(相反,自注意力是输入顺序不变的)。本文提出点状态空间模型(PointMamba),首先生成点的token序列,然后使用重排序策略以特定顺序扫描数据,使模型捕捉点云结构。最后将重排序后点的token输入Mamba编码器,进行全局建模。
实验表明,本文方法可以超过基于Transformer方法的性能,且有更少的参数和计算量。
3. 方法
3.1 准备知识
状态空间模型:状态空间模型建模了时不变(LTI)系统,使用一阶微分方程捕捉系统动态:
h
˙
(
t
)
=
A
h
(
t
)
+
B
x
(
t
)
,
y
(
t
)
=
C
h
(
t
)
+
D
x
(
t
)
.
\dot h(t)=Ah(t)+Bx(t),\\y(t)=Ch(t)+Dx(t).
h˙(t)=Ah(t)+Bx(t),y(t)=Ch(t)+Dx(t).
为处理离散token序列输入,需要进行离散化:
h
k
=
A
ˉ
h
k
−
1
+
B
ˉ
x
k
,
y
k
=
C
ˉ
h
k
+
D
ˉ
x
k
.
h_k=\bar Ah_{k-1}+\bar Bx_k,\\y_k=\bar Ch_k+\bar Dx_k.
hk=Aˉhk−1+Bˉxk,yk=Cˉhk+Dˉxk.
其中
A
ˉ
∈
R
N
×
N
,
B
ˉ
∈
R
N
×
1
,
C
ˉ
∈
R
1
×
N
,
D
ˉ
∈
R
\bar A\in\mathbb R^{N\times N},\bar B\in\mathbb R^{N\times 1},\bar C\in\mathbb R^{1\times N},\bar D\in\mathbb R
Aˉ∈RN×N,Bˉ∈RN×1,Cˉ∈R1×N,Dˉ∈R为参数矩阵。
D
ˉ
\bar D
Dˉ为残差连接,通常可简化或忽略。离散化需要使用时间步长
Δ
\Delta
Δ,在连续信号
x
(
t
)
x(t)
x(t)进行采样,得到
x
k
=
x
(
k
Δ
)
x_k=x(k\Delta)
xk=x(kΔ)。这使得:
A
ˉ
=
(
I
−
Δ
/
2
⋅
A
)
−
1
(
I
+
Δ
/
2
⋅
A
)
,
B
ˉ
=
(
I
−
Δ
/
2
⋅
A
)
−
1
Δ
B
,
C
ˉ
=
C
\bar A=(I-\Delta/2\cdot A)^{-1}(I+\Delta/2\cdot A),\\\bar B=(I-\Delta/2\cdot A)^{-1}\Delta B,\\\bar C=C
Aˉ=(I−Δ/2⋅A)−1(I+Δ/2⋅A),Bˉ=(I−Δ/2⋅A)−1ΔB,Cˉ=C
选择性SSM: B ˉ , C ˉ \bar B,\bar C Bˉ,Cˉ和 Δ \Delta Δ为动态、输入相关的参数,从而使得SSM为时变模型。这样能够过滤和捕捉时间相关的特征和关系,从而更精确地表达输入序列。
3.2 PointMamba
3.2.1 概述
如图所示,本文方法包括点tokenizer,重排序策略、Mamba和下游任务头。本文使用轻量化PointNet嵌入点的patch,得到点的token,然后根据几何坐标进行重排序,将序列长度变为3倍,输入Mamba。
3.2.2 点tokenizer
使用最远点采样(FPS)和K近邻(KNN)算法将点云分为不规则的点patch。具体来说,给定含 M M M个点的点云 I ∈ R M × 3 I\in\mathbb R^{M\times3} I∈RM×3,使用FPS采样 n n n个关键点,然后为每个关键点,使用KNN算法选择 k k k个最近点,得到 n n n个patch P ∈ R n × k × 3 P\in\mathbb R^{n\times k\times3} P∈Rn×k×3。然后,求取patch中各点相对关键点的相对坐标,并使用轻量化PointNet映射到特征空间,得到点token E 0 ∈ R n × C E_0\in\mathbb R^{n\times C} E0∈Rn×C。
3.2.3 重排序策略
由于Mamba是单向处理数据,适合1D数据;但难以处理点云这类无序数据。
本文通过特定顺序扫描点云,以捕捉点云结构。如图所示,本文分别基于点token簇中心的几何
x
,
y
,
z
x,y,z
x,y,z坐标进行排序并拼接,得到
E
0
′
∈
R
3
n
×
C
E'_0\in\mathbb R^{3n\times C}
E0′∈R3n×C。该方法通过提供更有逻辑的几何扫描顺序,提高了Mamba的几何建模能力。
3.2.4 Mamba块
每个Mamba块包含层归一化(LN)、SSM、逐深度卷积和残差连接,如图1右侧所示。公式表示为:
Z
l
′
=
D
W
(
M
L
P
(
L
N
(
Z
l
−
1
)
)
)
,
Z
l
=
M
L
P
(
L
N
(
S
S
M
(
σ
(
Z
l
′
)
)
)
×
σ
(
L
N
(
Z
l
−
1
)
)
)
+
Z
l
−
1
Z'_l=DW(MLP(LN(Z_{l-1}))),\\Z_l=MLP(LN(SSM(\sigma(Z_l')))\times\sigma(LN(Z_{l-1})))+Z_{l-1}
Zl′=DW(MLP(LN(Zl−1))),Zl=MLP(LN(SSM(σ(Zl′)))×σ(LN(Zl−1)))+Zl−1
其中 Z l ∈ R 3 n × C Z_l\in\mathbb R^{3n\times C} Zl∈R3n×C为第 l l l块的输出, Z 0 = E 0 ′ Z_0=E'_0 Z0=E0′; σ \sigma σ为SiLU激活函数。
3.2.5 预训练
本文使用PointMAE的设置进行预训练,即随机掩蔽60%的点patch,使用自编码器提取点的特征并使用预测头重建点云。
自编码器可公式化为:
T
v
=
F
e
(
T
v
+
P
E
)
,
H
v
,
H
m
=
F
d
(
C
o
n
c
a
t
(
T
v
,
T
m
)
)
,
P
m
=
F
h
(
H
m
)
.
T_v=F_e(T_v+PE),\\H_v,H_m=F_d(Concat(T_v,T_m)),\\P_m=F_h(H_m).
Tv=Fe(Tv+PE),Hv,Hm=Fd(Concat(Tv,Tm)),Pm=Fh(Hm).
其中 F e F_e Fe为编码器,以未掩蔽的token T v T_v Tv为输入; F d F_d Fd为Mamba解码器,以 F e F_e Fe的输出和掩蔽的token T m T_m Tm为输入。本文仅在编码器和解码器的第一层加入位置编码 P E PE PE。 F h F_h Fh为线性层,将掩蔽token H m H_m Hm投影为与掩蔽输入点形状相同的向量。使用Chamfer距离作为重建损失,以恢复掩蔽点的坐标。
4. 实验
4.1 实施细节
与ViT不同,本文不使用类别token。分类时,本文将最后一层Mamba的所有输出平均值用于分类。分割任务则将中间多层的输出合并,进行最大和均值池化得到全局特征,然后与逐点特征拼接,输入线性层预测。
4.2 与基于Transformer的方法比较
实验表明,本文方法在无预训练情况下能达到与基于Transformer的方法相当的性能,且有更少的参数和计算量。预训练和使用重排序策略均能提高性能。
此外,随着序列长度的增加,基于Transformer的方法GPU内存占用显著增加,但本文的PointMamba仅线性增长。
4.3 消融研究
重排序策略:比较不进行重排序(1倍序列长度)、进行重排序(3倍序列长度)和双向重排序(即将重排序结果逆序后与重排序结果拼接,6倍序列长度)。实验表明,基于Transformer的方法在序列长度增加时,性能略微下降;重排序策略能提高单向建模Mamba在点云中的适应能力;进一步增加序列长度能进一步提高性能,但为平衡计算量与性能,本文选择3倍序列长度;尽管如此,由于本文方法的线性复杂度,计算量增长也远小于基于Transformer的方法。
分类token的分析:实验表明,不使用类别token能达到最好的分类性能。
4.4 局限性
预训练没有考虑Mamba的单向建模特点;重排序需要将序列长度变为3倍。