论文名称:Neural Discrete Representation Learning
开源地址
发表时间:NIPS2017
作者及组织:Aaron van den Oord,Oriol Vinyals和Koray Kavukcuoglu, 来自DeepMind。
1、VAE
简单回顾下VAE的损失函数,ELBO的下界为:
L
o
w
e
r
B
o
u
n
d
=
E
q
φ
(
z
∣
x
)
[
l
o
g
p
θ
(
x
∣
z
)
]
−
D
K
L
(
q
φ
(
z
∣
x
)
∣
∣
p
(
z
)
)
\begin{equation} Lower Bound =E_{q_\varphi(z|x)}[logp_\theta(x|z)] - D_{KL}(q_\varphi(z|x)||p(z)) \tag{0} \end{equation}
LowerBound=Eqφ(z∣x)[logpθ(x∣z)]−DKL(qφ(z∣x)∣∣p(z))(0)
其中第一项为解码器的重构损失(regression loss) ;第二项为正则项,用KL散度来使Encoder----后验概率
q
φ
(
z
∣
x
)
q_\varphi(z|x)
qφ(z∣x) 和 先验
p
(
z
)
p(z)
p(z) 分布近似,通常
p
(
z
)
p(z)
p(z) 假设为多元标准正太分布,该项主要防止VAE坍塌到一个点,毕竟是生成模型。
而VQVAE和VAE主要不同:Encoder输出是离散的,而不是连续的隐变量z。
1、方法
1.1.模型结构
为了实现离散化编码,VQVAE引入了一个可学习的codebook,即上图中的EmbeddingSpace。大概说下流程:输入一张图像,经过CNN得到 z e ( x ) ∈ R H ∗ W ∗ D z_e(x) \in \mathbb{R}^{H*W*D} ze(x)∈RH∗W∗D ,然后计算 z e z_e ze 中每条特征向量跟codebook的最接近的向量的索引,得到 q ( z ∣ x ) ∈ R H ∗ W q(z|x) \in \mathbb{R}^{H*W} q(z∣x)∈RH∗W , 然后用codebook中向量 e i e_i ei 来替换 z e ( x ) z_e(x) ze(x) 得到 z q ( x ) z_q(x) zq(x) 。最后经过Decoder得到 x x x 。
1.2.训练
先说下总体损失函数,其实跟VAE的损失函数类似:
L
=
l
o
g
p
(
x
∣
z
q
(
x
)
)
+
∣
∣
s
g
[
z
e
(
x
)
]
−
e
∣
∣
2
2
+
β
∣
∣
z
e
(
x
)
−
s
g
[
e
]
∣
∣
2
2
\begin{equation} L = logp(x|z_q(x)) + ||sg[z_e(x)]- e|| ^2_2 + \beta||z_e(x)-sg[e]||^2_2 \tag{1} \end{equation}
L=logp(x∣zq(x))+∣∣sg[ze(x)]−e∣∣22+β∣∣ze(x)−sg[e]∣∣22(1)
其中第一项就是VAE中的重构损失,但有个问题:在用L2 Loss计算重构损失后,反向传播时,由于在codebook中argmin这个操作是不可导的,这样就优化不了Encoder,于是本文直接将
z
q
(
x
)
z_q(x)
zq(x) 节点的梯度拷贝给了
z
e
(
x
)
z_e(x)
ze(x) ,使得反向传播得以继续。具体的表达式如下:
l
o
g
p
(
x
∣
z
q
(
x
)
)
=
∣
∣
x
−
d
e
c
o
d
e
r
(
z
e
(
x
)
+
s
g
(
z
q
(
x
)
−
z
e
(
x
)
)
)
∣
∣
2
2
\begin{equation} logp(x|z_q(x)) = ||x-decoder(z_e(x)+sg(z_q(x)-z_e(x)))||_2^2 \tag{2} \end{equation}
logp(x∣zq(x))=∣∣x−decoder(ze(x)+sg(zq(x)−ze(x)))∣∣22(2)
式中的
s
g
sg
sg 表示 .detach() 操作,由于VQVAE多了一个可学习的codebook,而重构损失并没有梯度传过去。因此损失第二项就是让
e
e
e 逼近
z
e
(
x
)
z_e(x)
ze(x) ,这项仅更新codebook。
由于训练过程中,Encoder相较于codebook,肯定易于优化,也就是Encoder收敛快,而codebook收敛慢 ,为了让Encoder别距离codebook太远,于是增加了第三项损失,让 z e ( x ) z_e(x) ze(x) 逼近 e e e 。
在回过头来跟VAE的式子比较下 ,发现缺少了KL散度项:这是因为在VQVAE中,在根据 x x x 取得 e e e 的概率非0即1: q ( z = e ∣ x ) = 1 , q ( z = o t h e r ∣ x = 0 ) q(z=e|x)=1,q(z=other|x=0) q(z=e∣x)=1,q(z=other∣x=0) ,相当于二项分布,同时假设 p ( z ) p(z) p(z) 是均匀分布,两个均匀分布的KL散度是常数,在损失中可忽略。
1.3.生成
在训练集上训练完VQVAE后,VQVAE学习到的是一个有效的低维度的离散表示。然后将VQVAE置为推理阶段,用自回归模型PixCNN来拟合 q ( z ∣ x ) q(z|x) q(z∣x) ,训练完成后,PixCNN就能生成有意义的索引矩阵,然后去codebook中拿到对应的张量,送去VQVAE的Decoder中解码生成图像。
2、实验
生成的小图还是可以的。
思考
替换更强的自回归模型Transformer也就是后来VQGAN的工作了。