1、前言
这篇文章,我们讲VQ_GAN,这是一个将特征向量离散化的模型,其效果相当不错,搭配Transformer(GPT)或者CLIP使用,达到的效果在当时可谓是令人拍案叫绝!
原论文:Taming Transformers for High-Resolution Image Synthesis (arxiv.org)
参考代码:dome272/VQGAN-pytorch: Pytorch implementation of VQGAN
视频:[GPT与GAN结合生成图像——VQGAN原理解析-哔哩哔哩]
效果演示:
图像生成
其他任务
2、VQVAE
VQGAN其实是VQVAE修改过来的,是VQVAE先对VAE中的编码向量离散化。而后,VQGAN就是在VQVAE的基础上进行了一些修改,以提高其生成效果
由于这篇文章讲的是VQGAN,所以不会涉及VQVAE里面的公式推导,我们就直观的理解就行了,后续我看看是否需要补一个VQVAE
3、VQGAN
论文里面提到,VQGAN的出现的动机是针对transformer,由于transformer在NLP(自然语言处理)取得了令人惊讶的效果。所以,就有很多人尝试,是否可以将transformer应用在图像处理领域
在这篇论文之前,已经有人进行尝试,transformer可以应用在图像领域,并且取得了相当不错的效果。然而,相对于NLP,图像处理的难度却比较大,在transformer中,一句话的长度往往不会很长,里面的自注意力机制的计算量仍然可以接收;可图像处理领域的每个像素如果都要做自注意力的话,在低像素的或许还可以接收,但是一旦到了高像素,其计算量往往令人望而生畏。
受VQVAE的启发,作者先把图像通过编码器,编码成维度较低的向量,从而减少自注意力机制的计算量。并且,会把编码后的向量离散化。作者认为,在自然界中,图像本身应该是由一个个离散的量组合而成的,就好比东一个西一个,就拼成了车。
4、VQGAN流程
首先,从左下角开始,有一张狗的照片(红框),把这张图送给一个卷积编码器( E E E),输出向量 z ^ \hat z z^。
接着,初始化一个码本(Codebook Z ∈ R ( n u m , d i m ) Z\in R^{(num,dim)} Z∈R(num,dim),num是码本有多少行,dim是每行多少维度),把向量 z ^ \hat z z^在像素层面上,都在码本中找到与它最像的一个向量(使用最近邻搜索)。得到 z q z_q zq(图中像素上面的数字代表码本对应位置向量)
把得到的 z q z_q zq,送给解码器G,恢复图像,然后把这张还原的图像和生成的图像,送给卷积判别器D,判断真伪。
这就是整个流程。
我们看图中的码本,码本中对应的向量,分别表示图中那只狗某一块的特征,这种就是特征的离散化,能够让特征充分解耦。
5、VQVAE的损失
VQGAN的目标,就是学习到一个足够好的码本,编码器和解码器。
在讲VQGAN之前,我们先来看VQVAE。
5.1、VQVAE重构损失
这是VQVAE的模型图(与VQGAN相比,少了判别网络D)
如果你知道VAE或者AE,就应该知道,我们要让编码后再解码得到的图像和原始图像很像,那就说明这两个编码和解码器足够好。所以,我们要让重构的损失最小。即
L
r
e
c
=
∣
∣
x
−
x
^
∣
∣
2
=
∣
∣
x
−
G
(
z
q
)
∣
∣
2
L_{rec} = ||x-\hat x||^2=||x-G(z_q)||^2
Lrec=∣∣x−x^∣∣2=∣∣x−G(zq)∣∣2
x
^
\hat x
x^表示重构出来的图像,
G
G
G是解码器。
这是一种非常朴素的想法,但是,这里有个问题,那就是里面的
z
q
z_q
zq是
z
^
\hat z
z^在码本中最近邻搜索弄出来,这种最近邻匹配的方法是没有办法把梯度传递会编码器E那边的。于是,作者提出了straight-through estimator,具体做法如下,我们令
z
q
=
z
^
+
s
g
(
z
q
−
z
^
)
(1)
z_q = \hat z+ sg(z_q-\hat z)\tag{1}
zq=z^+sg(zq−z^)(1)
其中,里面的sg就是停止梯度的意思,也就是当反向传播的时候,括号里面那一项梯度不计。
于是,便有
s
g
=
{
s
g
=
1
;
正向传播
s
g
=
0
;
反向传播
sg=\left\{\begin{matrix}sg = 1;正向传播\\sg=0;反向传播\end{matrix}\right.
sg={sg=1;正向传播sg=0;反向传播
当正向传播,把
s
g
=
1
sg=1
sg=1代入式(1),等式成立;反向传播的时候,
s
g
=
0
sg=0
sg=0,会导致直接传梯度到
z
^
\hat z
z^
也就是说,当正向传播时,有损失
L
r
e
c
=
∣
∣
x
−
G
(
z
^
+
s
g
(
z
q
−
z
^
)
)
∣
∣
2
=
∣
∣
x
−
G
(
z
q
)
∣
∣
2
L_{rec}=||x-G(\hat z+ sg(z_q-\hat z))||^2=||x-G(z_q)||^2
Lrec=∣∣x−G(z^+sg(zq−z^))∣∣2=∣∣x−G(zq)∣∣2
反向传播时,有
L
r
e
c
=
∣
∣
x
−
G
(
z
^
+
s
g
(
z
q
−
z
^
)
)
∣
∣
2
=
∣
∣
x
−
G
(
z
^
)
∣
∣
2
L_{rec}=||x-G(\hat z+ sg(z_q-\hat z))||^2=||x-G(\hat z)||^2
Lrec=∣∣x−G(z^+sg(zq−z^))∣∣2=∣∣x−G(z^)∣∣2
或许你会想,为什么可以这样做,这样做真的可以收敛吗?是可以的!
试想一下,当 z ^ \hat z z^通过与码本中找到最相近的向量替代原来的向量,得到 z q z_q zq,换句话说, z ^ \hat z z^与 z q z_q zq是近似的,那么其更新方向也是近似相等的。
5.2、码本损失
我们要构造一个足够好的码本,去表示图像的离散特征。而我们知道
z
^
\hat z
z^是编码器编码图像得到的特征,那么理所应当的,我们只需要让
L
c
o
d
e
=
z
i
∈
Z
∣
∣
E
(
x
)
−
z
q
∣
∣
2
2
L_{code}=_{z_i\in Z}||E(x)-z_q||_2^2
Lcode=zi∈Z∣∣E(x)−zq∣∣22
z
q
z_q
zq是像素点,在码本的对应最近邻向量。
作者认为,编码器 E E E和码本向量不应该以一样的速率优化,码本的是要学习把自己的向量与编码器的向量尽量的接近,码本的学习速率必须要快于编码器,否则码本自己优化,而不是向着编码器的方向优化。
所以将其拆分成两项
L
c
o
d
e
=
∣
∣
s
g
(
E
(
x
)
)
−
z
q
∣
∣
2
2
+
β
∣
∣
E
(
x
)
−
s
g
(
z
q
)
∣
∣
2
2
L_{code}=||sg(E(x))-z_q||_2^2+\beta ||E(x)-sg(z_q)||_2^2
Lcode=∣∣sg(E(x))−zq∣∣22+β∣∣E(x)−sg(zq)∣∣22
β
\beta
β是学习速率。取值
0.1
0.1
0.1到
2.0
2.0
2.0之间,但是作者经过实验发现,
β
\beta
β的取值对结果的影响很小,几乎没有。在VQVAE中,
β
=
0.25
\beta=0.25
β=0.25
5.3、总损失
故而,我们得到VQVAE的总损失函数
L
V
Q
=
L
r
e
c
+
L
c
o
d
e
\mathcal{L}_{VQ}=L_{rec}+L_{code}
LVQ=Lrec+Lcode
6、VQGAN损失
6.1、感知损失
与VQVAE相比,VQGAN的作者首先把里面的重构损失 L r e c L_{rec} Lrec换成感知损失(perceptual loss)
所谓的感知损失,在一般请看下,就是把真实的图像,和解码器复原的图像,一起送给一个神经网络,比如VGG16,把这两张图像经过VGG16,都编码成特征向量,然后计算特征向量的差别,比如
L
p
e
r
=
∣
∣
V
G
G
(
x
)
−
V
G
G
(
x
^
)
∣
∣
2
(2)
L_{per}=||VGG(x)-VGG(\hat x)||_2\tag{2}
Lper=∣∣VGG(x)−VGG(x^)∣∣2(2)
这只是举个例子,在文章中VQGAN的代码中,比这个复杂一点,它是在很多层都进行都去计算式(2)。
另外,值得注意的是,虽然论文里面写的是把重构损失换成感知损失,但是在本文上面的代码中,其实两种损失都用到了。我个人觉得也没什么不妥的,很显然重构损失是在图像层面的差异,而感知损失是特征向量的差异,所以两者加起来应当不会有什么问题。
6.2、判别网络的损失
VQGAN比VQVAE多了一个判别网络,故而加上一个判别网络的损失,以优化参数让解码器G生成的图像更好。公式如下(这是GAN的基本公式,在此不过多赘述)
L
G
A
N
(
{
E
,
G
,
Z
}
,
D
)
=
[
log
D
(
x
)
+
log
(
1
−
D
(
x
^
)
)
]
\mathcal{L}_{GAN}(\{E,G,Z\},D)=[\log D(x)+\log(1-D(\hat x))]
LGAN({E,G,Z},D)=[logD(x)+log(1−D(x^))]
因此,最终的损失函数如下
L
=
min
E
,
G
,
Z
max
D
E
x
∼
p
(
x
)
[
L
V
Q
(
E
,
G
,
Z
)
+
λ
L
G
A
N
(
{
E
,
G
,
Z
}
,
D
)
]
L=\min\limits_{E,G,Z}\max\limits_{D}\mathbb{E}_{x\sim p(x)}\left[\mathcal{L}_{VQ}(E,G,Z)+\lambda\mathcal{L}_{GAN}(\{E,G,Z\},D)\right]
L=E,G,ZminDmaxEx∼p(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)]
其中,
λ
\lambda
λ是动态变化的,其公式如下
λ
=
∇
G
L
[
L
r
e
c
]
∇
G
L
[
L
G
A
N
]
+
δ
\lambda = \frac{\nabla_{G_L}[\mathcal{L_{rec}}]}{\nabla_{G_L}[\mathcal{L}_{GAN}]+\delta}
λ=∇GL[LGAN]+δ∇GL[Lrec]
论文里面,
δ
=
1
0
−
6
\delta=10^{-6}
δ=10−6,
∇
G
L
\nabla_{G_L}
∇GL是关于解码器最后一层求梯度。
7、GPT及图像生成
在VQGAN里面,当训练好之后,就会得到一个训练好的编码器,解码器,以及码本。
可是,我们该如何生成图像呢?就是依靠transformer,换句话中,作者在实验的时候,其实用的是GPT2
以下为具体流程(以单张图像为例):
首先,从训练图像中,采样出一张图像。送给编码器,得到编码向量,并按像素,寻找在码本中的最近邻。但是,得到的最近邻我们不要它的向量值,只要对应的索引。
于是,我们得到的就是一行索引。比如indexs=【1,5,9,3,5,1,10,20】。
接着,只需要按照GPT的训练步骤,随机掩掉一部分值,比如indexs_mask=【1,?,?,3,5,?,10,?】
掩掉的这一部分(也就是问号),写入一些随机值,然后把indexs_mask送给GPT,让其预测出index。更准确的说,其实就是让它预测那些被掩码掉的部分,以这种方式,学习到索引之间的关系。
在这个过程中,VQGAN的参数固定不变,只训练GPT,训练完成后,就可以依靠GPT,随机初始化一个开始值,然后一点点的预测出后面的索引,得到了索引后,送给解码器,得到图像。
8、结束
其实VQGAN可以配合CLIP模型使用,达到文生图的效果。
以上,就是VQGAN的全部内容了,如有问题,还望指出。阿里嘎多!
系。
在这个过程中,VQGAN的参数固定不变,只训练GPT,训练完成后,就可以依靠GPT,随机初始化一个开始值,然后一点点的预测出后面的索引,得到了索引后,送给解码器,得到图像。
8、结束
其实VQGAN可以配合CLIP模型使用,达到文生图的效果。
以上,就是VQGAN的全部内容了,如有问题,还望指出。阿里嘎多!