13. Pretraining is All You Need for Image-to-Image Translation
该文提出一种基于预训练扩散模型的图像转换方法,称为PITI。其思想并不复杂,就是借鉴现有视觉和NLP领域中常见的预训练方法,考虑预先在一个大规模的任务无关数据集上对扩散模型进行预训练,使其具备一个高度语义化的空间。然后,再针对特定任务对模型进行微调训练,此时微调过程只需要关注与任务相关的输入信息,而困难的图像生成工作,比如渲染一个合理布局和真实的纹理,将根据预训练时得到的知识来完成。
在本文中,作者采用GLIDE模型作为基础模型,在一个包含67M个文本-图像对的数据集上进行预训练。使用基础模型进行图像生成的过程,可以看作是对原始输入
x
0
\boldsymbol{x}_{0}
x0和条件
y
\boldsymbol{y}
y进行编码和解码的过程
x
t
=
D
~
(
E
~
(
x
0
,
y
)
)
\boldsymbol{x}_{t}=\tilde{\mathcal{D}}\left(\tilde{\mathcal{E}}\left(\boldsymbol{x}_{0}, \boldsymbol{y}\right)\right)
xt=D~(E~(x0,y))其中
D
~
\tilde{\mathcal{D}}
D~和
E
~
\tilde{\mathcal{E}}
E~分别表示解码和编码器。微调训练包含两个阶段,第一阶段时锁定解码器的参数,只对编码器进行训练;第二阶段是对两者进行联合训练。
由于扩散模型生成的结果通常分辨率较低,如64*64,因此作者也采用了一个基于扩散模型的上采样器,对生成结果进行分辨率提升。然而,作者发现提升的结果存在过度平滑的问题,因此作者又引入了GAN中常见的感知损失和对抗损失,如下式
L
perc
=
E
t
,
x
0
,
ϵ
∥
ψ
m
(
x
^
0
t
)
−
ψ
m
(
x
0
)
∥
,
L
a
d
v
=
E
t
,
x
0
,
ϵ
[
log
D
θ
(
x
^
0
t
)
]
+
E
x
0
[
log
(
1
−
D
θ
(
x
0
)
)
]
\begin{aligned} \mathcal{L}_{\text {perc }} & =\mathbb{E}_{t, \boldsymbol{x}_{0}, \boldsymbol{\epsilon}}\left\|\boldsymbol{\psi}_{m}\left(\hat{\boldsymbol{x}}_{0}^{t}\right)-\boldsymbol{\psi}_{m}\left(\boldsymbol{x}_{0}\right)\right\|, \\ \mathcal{L}_{\mathrm{adv}} & =\mathbb{E}_{t, \boldsymbol{x}_{0}, \boldsymbol{\epsilon}}\left[\log D_{\theta}\left(\hat{\boldsymbol{x}}_{0}^{t}\right)\right]+\mathbb{E}_{\boldsymbol{x}_{0}}\left[\log \left(1-D_{\theta}\left(\boldsymbol{x}_{0}\right)\right)\right] \end{aligned}
Lperc Ladv=Et,x0,ϵ
ψm(x^0t)−ψm(x0)
,=Et,x0,ϵ[logDθ(x^0t)]+Ex0[log(1−Dθ(x0))]其中
x
^
0
t
=
(
x
t
−
1
−
α
t
ϵ
θ
(
x
t
,
y
,
t
)
)
/
α
t
\hat{\boldsymbol{x}}_{0}^{t}=\left(\boldsymbol{x}_{t}-\sqrt{1-\alpha_{t}} \boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, \boldsymbol{y}, t\right)\right) / \sqrt{\alpha_{t}}
x^0t=(xt−1−αtϵθ(xt,y,t))/αt表示预测得到的生成结果。
最后,作者发现在常规的无分类器引导的扩散模型CDM中
ϵ
^
θ
(
x
t
∣
y
)
=
ϵ
θ
(
x
t
∣
y
)
+
w
⋅
(
ϵ
θ
(
x
t
∣
y
)
−
ϵ
θ
(
x
t
∣
∅
)
)
\hat{\boldsymbol{\epsilon}}_{\theta}\left(\boldsymbol{x}_{t} \mid \boldsymbol{y}\right)=\boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t} \mid \boldsymbol{y}\right)+w \cdot\left(\boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t} \mid \boldsymbol{y}\right)-\boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t} \mid \emptyset\right)\right)
ϵ^θ(xt∣y)=ϵθ(xt∣y)+w⋅(ϵθ(xt∣y)−ϵθ(xt∣∅))条件的引入会导致估计噪声的均值和方差发生漂移,如下
μ
^
=
μ
+
w
(
μ
−
μ
∅
)
\hat{\mu}=\mu+w\left(\mu-\mu_{\emptyset}\right)
μ^=μ+w(μ−μ∅)
σ
^
2
=
(
1
+
w
)
2
σ
2
+
w
2
σ
∅
2
\hat{\sigma}^{2}=(1+w)^{2} \sigma^{2}+w^{2} \sigma_{\emptyset}^{2}
σ^2=(1+w)2σ2+w2σ∅2并且这个偏移会随着迭代去噪过程逐渐累积,最终导致生成图像过饱和或者过度平滑。为此,作者提出一种规则化处理方式,如下式
ϵ
~
θ
(
x
t
∣
y
)
=
σ
σ
^
(
ϵ
^
θ
(
x
t
∣
y
)
−
μ
^
)
+
μ
\tilde{\boldsymbol{\epsilon}}_{\theta}\left(\boldsymbol{x}_{t} \mid \boldsymbol{y}\right)=\frac{\sigma}{\hat{\sigma}}\left(\hat{\boldsymbol{\epsilon}}_{\theta}\left(\boldsymbol{x}_{t} \mid \boldsymbol{y}\right)-\hat{\mu}\right)+\mu
ϵ~θ(xt∣y)=σ^σ(ϵ^θ(xt∣y)−μ^)+μ
作者在"掩码到图像"、"轮廓到图像"和”几何体到图像“等图像转换任务中,对本文提出的方法进行了测试,其效果如下