1、前言
从讲扩散模型到现在。我们很少讲过条件生成(Stable DIffusion曾提到过一点),所以本篇内容。我们就来具体讲一下条件生成。这一部分的内容我就不给原论文了,因为那些论文并不只讲了条件生成,还有一些调参什么的。并且推导过程也相对复杂。我们从一个比较简单的角度出发。
参考论文:Understanding Diffusion Models: A Unified Perspective (arxiv.org)
参考代码:
classifier guidance:GitHub - openai/guided-diffusion
classifier-free guidance:GitHub - coderpiaobozhe/classifier-free-diffusion-guidance-Pytorch: a simple unofficial implementation of classifier-free diffusion guidance
视频:[扩散模型条件生成——Classifier Guidance和Classifier-free Guidance原理解析-哔哩哔哩]
2、常用的条件生成方法
在diffusion里面,如何进行条件生成呢?我们不妨回忆一下在Stable Diffusion里面的一个常用做法。即在训练的时候。给神经网络输入一个条件。
L
=
∣
∣
ϵ
−
ϵ
θ
(
x
t
,
t
,
y
)
∣
∣
2
L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2
L=∣∣ϵ−ϵθ(xt,t,y)∣∣2
里面的y就是条件。至于为什么有效,请看我之前写过的Stable DIffusion那篇文章。在此不过多赘述了。我们来讲这种方法所存在的问题。
很显然的,这种训练的方式,会有一个问题,那就是神经网络或许会学会忽略或者淡化掉我们输入的条件信息。因为就算我们不输入信息,他也照样能够生成。
接下来我们来讲两种更为流行的方法——分类指导器(Classifier Guidance) 和无分类指导器( Classifier-Free Guidance)
3、Classifier Guidance
为了简单起见。我们从分数模型的角度出发。
回忆一下在SDE里面的结论。其反向过程为
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
(1)
\mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{1}
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dwˉ(1)
如果施加条件的话,还是根据Reverse-time diffusion equation models - ScienceDirect这篇论文,可得条件生成时的反向SDE为
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
∣
y
)
]
d
t
+
g
(
t
)
d
w
ˉ
(2)
\mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x|y)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{2}
dx=[f(x,t)−g(t)2∇xlogpt(x∣y)]dt+g(t)dwˉ(2)
我们利用贝叶斯公式,对
∇
x
log
p
t
(
x
∣
y
)
\nabla x \log p_t(x|y)
∇xlogpt(x∣y)进行处理
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
y
∣
x
)
p
t
(
x
)
p
t
(
y
)
=
∇
x
(
log
p
t
(
y
∣
x
)
+
log
p
t
(
x
)
−
log
p
t
(
y
)
)
=
∇
x
log
p
t
(
x
)
+
∇
x
log
p
t
(
y
∣
x
)
\begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log\frac{p_t(y|x)p_t(x)}{p_t(y)}\\=&\nabla_x\left(\log p_t(y|x)+\log p_t(x)-\log p_t(y)\right)\\=&\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\end{aligned}\nonumber
∇xlogpt(x∣y)===∇xlogpt(y)pt(y∣x)pt(x)∇x(logpt(y∣x)+logpt(x)−logpt(y))∇xlogpt(x)+∇xlogpt(y∣x)
第二个等号到第三个等号是因为对
log
p
t
(
y
)
\log p_t(y)
logpt(y)关于x求梯度等于0(
log
p
t
(
y
)
\log p_t(y)
logpt(y)与x无关)
把它代入Eq.(2)可得
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
(
∇
x
log
p
t
(
x
)
+
∇
x
log
p
t
(
y
∣
x
)
)
]
d
t
+
g
(
t
)
d
w
ˉ
(3)
\mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\left(\nabla_x\log p_t(x)+\nabla_x\log p_t(y|x)\right)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{3}
dx=[f(x,t)−g(t)2(∇xlogpt(x)+∇xlogpt(y∣x))]dt+g(t)dwˉ(3)
对比Eq.(1)和Eq.(3)。我们不难发现,它们的差别,居然是只多了一个
∇
x
log
p
t
(
y
∣
x
)
\nabla_x\log p_t(y|x)
∇xlogpt(y∣x)
p t ( y ∣ x ) p_t(y|x) pt(y∣x)是什么?是以 x x x作为条件,时间为t对应条件y的概率。我们可以怎么求呢?该怎么求出来呢?
当然是使用神经网络了。也就是说,我们可以额外设定一个神经网络,该神经网络输入是 x t x_t xt,输出是条件为y的概率
所以,实际上我们现在需要训练两部分,一部分是 ∇ x log p t ( x ) \nabla_x\log p_t(x) ∇xlogpt(x),这我们在SDE中已经讲过该如何训练了。
另一个就是 ∇ x log p t ( y ∣ x ) \nabla_x\log p_t(y|x) ∇xlogpt(y∣x),他就是一个分类神经网络网络。训练好之后,我们就可以使用Eq.(3)通过不同的数值求解器,进行优化了。
作者在此基础上,又引入了一个控制参数
λ
\lambda
λ
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
λ
∇
x
log
p
t
(
y
∣
x
)
(4)
\nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(y|x)\tag{4}
∇xlogpt(x∣y)=∇xlogpt(x)+λ∇xlogpt(y∣x)(4)
当
λ
=
0
\lambda=0
λ=0,表示不加入任何条件。当
λ
\lambda
λ很大时,模型会产生大量附带条件信息的样本。
这种方法的一个缺点就是,需要额外学习一个分类器 p t ( y ∣ x ) p_t(y|x) pt(y∣x)
4、Classifier-Free Guidance
之前推出
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
∇
x
log
p
t
(
y
∣
x
)
(5)
\nabla_x \log p_t(x|y)=\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\tag{5}
∇xlogpt(x∣y)=∇xlogpt(x)+∇xlogpt(y∣x)(5)
把该式子代入Eq.(4)可得
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
λ
(
∇
x
log
p
t
(
x
∣
y
)
−
∇
x
log
p
t
(
x
)
)
=
∇
x
log
p
t
(
x
)
+
λ
∇
x
log
p
t
(
x
∣
y
)
−
λ
∇
x
log
p
t
(
x
)
=
(
1
−
λ
)
∇
x
log
p
t
(
x
)
+
λ
∇
x
log
p
t
(
x
∣
y
)
\begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+\lambda\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)-\lambda\nabla_x\log p_t(x)\\=&\left(1-\lambda\right)\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)\end{aligned}\nonumber
∇xlogpt(x∣y)===∇xlogpt(x)+λ(∇xlogpt(x∣y)−∇xlogpt(x))∇xlogpt(x)+λ∇xlogpt(x∣y)−λ∇xlogpt(x)(1−λ)∇xlogpt(x)+λ∇xlogpt(x∣y)
此时我们注意到,当
λ
=
0
\lambda=0
λ=0是,第二项完全为0,会忽略掉条件;当
λ
=
1
\lambda=1
λ=1时,使用第二项,第二项就是附带有条件情况下的分布分数网络;而当
λ
>
1
\lambda> 1
λ>1,模型会优化考虑条件生成样本,并且远离第一项的无条件分数网络的方向,换句话说,它降低了生成不使用条件信息的样本的概率,而有利于生成明确使用条件信息的样本。
事实上,如果你看了free-Classifier Guidance这篇论文,会发现我们的结论不一样。
其实论文里面的控制参数是
w
w
w,也就是说,Eq.(4)就变成了这样
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
w
∇
x
log
p
t
(
y
∣
x
)
\nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+w\nabla_x\log p_t(y|x)
∇xlogpt(x∣y)=∇xlogpt(x)+w∇xlogpt(y∣x)
我们把控制参数改成
1
+
w
1+w
1+w不会有任何影响
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
(
1
+
w
)
∇
x
log
p
t
(
y
∣
x
)
\nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(y|x)
∇xlogpt(x∣y)=∇xlogpt(x)+(1+w)∇xlogpt(y∣x)
把Eq.(5)代入该式子
∇
x
log
p
t
(
x
∣
y
)
=
∇
x
log
p
t
(
x
)
+
(
1
+
w
)
(
∇
x
log
p
t
(
x
∣
y
)
−
∇
x
log
p
t
(
x
)
)
=
∇
x
log
p
t
(
x
)
+
(
1
+
w
)
∇
x
log
p
t
(
x
∣
y
)
−
(
1
+
w
)
∇
x
log
p
t
(
x
)
=
(
1
+
w
)
∇
x
log
p
t
(
x
∣
y
)
−
w
∇
x
log
p
t
(
x
)
(6)
\begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+(1+w)\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(x|y)-(1+w)\nabla_x\log p_t(x)\\=&(1+w)\nabla_x\log p_t(x|y)-w\nabla_x\log p_t(x)\end{aligned}\tag{6}
∇xlogpt(x∣y)===∇xlogpt(x)+(1+w)(∇xlogpt(x∣y)−∇xlogpt(x))∇xlogpt(x)+(1+w)∇xlogpt(x∣y)−(1+w)∇xlogpt(x)(1+w)∇xlogpt(x∣y)−w∇xlogpt(x)(6)
这就是原论文里面的结论。
那么接下来,我们来探讨一下该如何去训练。
对于 ∇ x log p t ( x ) \nabla_x\log p_t(x) ∇xlogpt(x),这个不用说了,之前我们训练的就是这个;如何计算 ∇ x log p t ( x ∣ y ) \nabla_x\log p_t(x|y) ∇xlogpt(x∣y)呢,它实际上就是在给定y的情况下,求出 p t ( x ∣ y ) p_t(x|y) pt(x∣y)。那我们可以怎么做呢?
在NCSN,我们是使用一个加噪分布 q ( x ~ ∣ x ) q(\tilde x|x) q(x~∣x)取代 p ( x ) p(x) p(x),而从让它是可解的。
对于
p
t
(
x
∣
y
)
p_t(x|y)
pt(x∣y),即便是加多了一个条件之后,我们仍然建模为
q
(
x
~
∣
x
)
q(\tilde x|x)
q(x~∣x),也就是说,我们仍然把它建模成一个正向加噪过程。因此,无论是否增加条件。最终的损失函数结果都是
L
=
∣
∣
s
θ
−
∇
x
log
q
(
x
~
∣
x
)
∣
∣
2
=
∣
∣
s
θ
−
∇
x
log
q
(
x
t
∣
x
0
)
∣
∣
2
L=||s_\theta-\nabla_x\log q(\tilde x|x)||^2=||s_\theta-\nabla_x\log q(x_t|x_0)||^2
L=∣∣sθ−∇xlogq(x~∣x)∣∣2=∣∣sθ−∇xlogq(xt∣x0)∣∣2
后者是通过SDE统一的结果(我在SDE那一节讲过)
那该如何体现条件y呢?其实我们在第二节的时候已经说过了,就是在里面神经网络的输出加入一个条件y。
L
=
∣
∣
s
θ
(
x
t
,
t
,
y
)
−
∇
x
log
q
(
x
t
∣
x
0
)
∣
∣
2
(7)
L=||s_\theta(x_t,t,y)-\nabla_x\log q(x_t|x_0)||^2\tag{7}
L=∣∣sθ(xt,t,y)−∇xlogq(xt∣x0)∣∣2(7)
而不施加条件的时候,长这样
L
=
∣
∣
s
θ
(
x
t
,
t
)
−
∇
x
log
q
(
x
t
∣
x
0
)
∣
∣
2
(8)
L=||s_\theta(x_t,t)-\nabla_x\log q(x_t|x_0)||^2\tag{8}
L=∣∣sθ(xt,t)−∇xlogq(xt∣x0)∣∣2(8)
由Eq.(5)可知,我们需要训练两种情况,一种是有条件的,对应Eq.(7);另外一种是无条件的,对应Eq.(8)。
理论上,我们其实也是要训练两个神经网络。但实际上,我们可以把他们结合成一种神经网络。
具体操作就是把无条件的情况作为一种特例。
当我们训练有条件的神经网络的时候,会照样把条件输入进网络里面。而训练无条件的时候,我们构造一个无条件的标识符,把它作为条件输入给神经网络,比如对于所有无条件的情况,我都构造一个0作为条件输入到神经网络里面。通过这种方式,我们就可以把两个网络变成一个网络了,
对于损失函数,直接使用Eq.(7)。我们在SDE里面讲过
∇
x
log
p
(
x
)
=
−
1
σ
ϵ
\nabla_x \log p(x)=-\frac{1}{\sigma}\epsilon
∇xlogp(x)=−σ1ϵ。所以我们最终我们把预测噪声,变成了预测分数。我们同样可以把它变回来,变成预测分数
L
=
∣
∣
ϵ
−
ϵ
θ
(
x
t
,
t
,
y
)
∣
∣
2
L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2
L=∣∣ϵ−ϵθ(xt,t,y)∣∣2
所以损失函数就变成了这样。在训练的时候,作者设定一个大于等于0,小于等于1的超参数
p
u
n
c
o
n
d
p_{uncond}
puncond,它的作用就是判断是否需要输入条件(从0-1分布采样一个值,大于
p
u
n
c
o
n
d
p_{uncond}
puncond则使用条件,反之则不使用)。也就是说,这相当于dropout一样,随机舍弃掉一些条件,把他们作为无条件的情况(因为我们既要学习有条件的,又要学习无条件的)。所以,最终的训练过程就是这样
其中里面的 λ \lambda λ你就当作是时刻t吧(其实不是,其实是时刻t的噪声(噪声的初始化不一样,不是传统的等差数列,是用三角函数初始化的)。由于与本篇内容无关,故而忽略),c是条件。
同样的,采用过程使用Eq.(6)的结构进行采样