context
- 1. 剪枝方案图释
- 2. 正交矩阵Q
1. 剪枝方案图释
图中的阴影是表示丢弃掉这部分数据。通过引入正交矩阵
Q
Q
Q使
Q
⊤
Q
=
Q
Q
⊤
=
I
\mathrm{Q}^\top\mathrm{Q}=\mathrm{Q}\mathrm{Q}^\top=\mathrm{I}
Q⊤Q=QQ⊤=I,来大量缩减
X
X
X的列数和
W
W
W的行数。
由于
Q
Q
Q是正交矩阵,有
∥
Q
x
∥
=
x
⊤
Q
⊤
Q
x
=
x
⊤
x
=
∥
x
∥
\|\mathbf{Q}x\|=\sqrt{x^\top\mathbf{Q}^\top\mathbf{Q}x}=\sqrt{x^\top x}=\|x\|
∥Qx∥=x⊤Q⊤Qx=x⊤x=∥x∥,所以
Q
Q
Q与
x
x
x相乘不会影响
x
x
x的范数。
在一般情况下,假设
X
ℓ
\mathbf{X}_{\ell}
Xℓ是transformer中一个块的输出,在经过RMSNorm(对每一行
x
←
X
∣
∣
X
∣
∣
x\leftarrow \frac{\mathbf{X}}{\left|\left|\mathbf{X}\right|\right|}
x←∣∣X∣∣X处理),然后
R
M
S
N
o
r
m
(
X
ℓ
)
\mathrm{RMSNorm}(\mathbf{X}_{\ell})
RMSNorm(Xℓ)作为下一块的输入。若引入矩阵
Q
Q
Q,则有
R
M
S
N
o
r
m
(
X
ℓ
)
=
R
M
S
N
o
r
m
(
X
ℓ
Q
)
Q
⊤
\mathrm{RMSNorm}(\mathbf{X}_\ell)=\mathrm{RMSNorm}(\mathbf{X}_\ell\mathbf{Q})\mathbf{Q}^\top
RMSNorm(Xℓ)=RMSNorm(XℓQ)Q⊤,所以实际上引入
Q
Q
Q不改变transformer的结构。对于transformer中的每一attention或FFN层都有线性层,同时由于transformer中有残差连接(图中的
+
◯
\textcircled{+}
+◯操作),这里把矩阵
Q
Q
Q引入每一块的线性层,所以需要把矩阵
Q
Q
Q引入到所有之前的层(一直到编码阶段)和所有之后的层(一直到LM头)。
令
W
i
n
ℓ
\mathbf{W}_{in}^\ell
Winℓ和
W
o
u
t
ℓ
\mathbf{W}_{out}^\ell
Woutℓ为transformer的第
ℓ
\ell
ℓ块的线性层的权重矩阵,
b
i
n
ℓ
\mathbf{b}_{in}^\ell
binℓ和
b
o
u
t
ℓ
\mathbf{b}_{out}^\ell
boutℓ为相对应的偏置,
W
e
m
b
d
\mathbf{W}_{embd}
Wembd和
W
h
e
a
d
\mathbf{W}_{head}
Whead为编码和头矩阵,
Q
Q
Q为
D
D
D维矩阵,则可以用以下矩阵来模型不变性变换
W
~
e
m
b
d
=
W
e
m
b
d
Q
,
(1)
b
~
o
u
t
ℓ
=
Q
⊤
b
o
u
t
ℓ
,
(4)
W
~
i
n
ℓ
=
Q
⊤
W
i
n
ℓ
,
(2)
W
~
h
e
a
d
=
Q
⊤
W
h
e
a
d
.
(5)
W
~
o
u
t
ℓ
=
W
o
u
t
ℓ
Q
,
(3)
\begin{aligned}\tilde{\mathbf{W}}_{embd}&=\mathbf{W}_{embd}\mathbf{Q} ,&&\text{(1)}&&\tilde{b}_{out}^{\ell}=\mathbf{Q}^{\top}b_{out}^{\ell} ,&&\text{(4)}\\\tilde{\mathbf{W}}_{in}^{\ell}&=\mathbf{Q}^{\top}\mathbf{W}_{in}^{\ell},&&\text{(2)}&&\tilde{\mathbf{W}}_{head}=\mathbf{Q}^{\top}\mathbf{W}_{head} .&&\text{(5)}\\\tilde{\mathbf{W}}_{out}^{\ell}&=\mathbf{W}_{out}^{\ell}\mathbf{Q} ,&&\text{(3)}\end{aligned}
W~embdW~inℓW~outℓ=WembdQ,=Q⊤Winℓ,=WoutℓQ,(1)(2)(3)b~outℓ=Q⊤boutℓ,W~head=Q⊤Whead.(4)(5)偏置矩阵保持不变
b
~
i
n
ℓ
=
b
i
n
ℓ
,
b
~
h
e
a
d
=
b
h
e
a
d
\tilde{b}_{in}^{\ell}=b_{in}^{\ell},\tilde{b}_{head}=b_{head}
b~inℓ=binℓ,b~head=bhead
文章主题思想如图Fig. 1.2
图中,(a)中的
W
Q
W_Q
WQ、
W
K
W_K
WK和
W
V
W_V
WV是注意力中的QKV操作,
W
V
W_V
WV表示注意力机制的输出矩阵,
M
=
I
−
1
D
1
1
⊤
\mathbf{M}=\mathbf{I}-\frac{1}{D}\mathbf{1}\mathbf{1}^{\top}
M=I−D111⊤是用来使矩阵
X
X
X中的每一个元素拉回到0上下,与下一步的
x
←
X
∣
∣
X
∣
∣
x\leftarrow \frac{\mathbf{X}}{\left|\left|\mathbf{X}\right|\right|}
x←∣∣X∣∣X共同完成归一化处理,
W
1
W_1
W1和
W
2
W_2
W2是MLP操作。(b)与(c)中的
(
α
)
(\alpha)
(α)就是diag(
α
\alpha
α),矩阵
(
α
′
)
(\alpha^{'})
(α′)来自前一块。向量
α
\alpha
α和偏置
β
\beta
β在每个LayerNorm实例上独立学习。diag(
α
\alpha
α)是一个矩阵操作,表示将一个向量
(
α
)
(\alpha)
(α)作为对角线元素创建一个对角矩阵。
最后移除一些不重要的行和列。
2. 正交矩阵Q
使用主成分分析(PCA)来求解
Q
ℓ
Q_{\ell}
Qℓ(transformer中第
ℓ
\ell
ℓ块),在训练集中抽取一些数据作为校准数据,喂给模型用来从前到后逐层提取正交矩阵。对于校准数据集中的
i
i
i条数据,使模型中第
ℓ
\ell
ℓ层输出为
X
ℓ
,
i
X_{\ell,i}
Xℓ,i,则有
C
ℓ
=
∑
i
X
ℓ
,
i
⊤
X
ℓ
,
i
\mathrm{C}_{\ell}=\sum_{i}\mathrm{X}_{\ell,i}^{\top}\mathrm{X}_{\ell,i}
Cℓ=i∑Xℓ,i⊤Xℓ,i则
Q
ℓ
Q_{\ell}
Qℓ是
C
ℓ
\mathrm{C}_{\ell}
Cℓ的降序排列特征值的特征矩阵。