1. DDPM模型概述
扩散模型(DM,Diffusion Model)是一类生成模型,常见的生成模型还有GAN和VAE。扩散模型分为前向阶段和逆向阶段,在前向阶段中逐步向数据中添加噪声,直至数据变成完全的高斯噪声,然后在逆向阶段学习从高斯噪声中还原为原始数据。
图1 扩散模型的前向阶段和逆向阶段. (图片来源:Ho et al. 2020)
前向阶段表示为图1中从右往左的过程。从原始图像�0开始,第�步在��−1的基础上添加噪声得到��。��只与��−1有关,直至�步后��完全变为高斯噪声。逆向过程表示为图1中从左往右的过程。首先给定高斯噪声��,通过逐步去噪,直至将原始数据�0恢复。
模型训练完后,只要给定高斯噪声就能够生成对应的图片,这样就达到了生成各种图片的目的。
2. DDPM模型原理
以DDPM(Denoising Diffusion Probabilistic Models)为例,讲解扩散模型的基本原理。
2.1 前向阶段
假设原始数据�0采样自真实分布�(�),前向阶段就是逐步往数据中添加高斯噪声,从而产生噪声序列�1,�2,⋯,��。由于��只与��−1有关,因此满足:
(1)�(�1:�|�0)=∏�=1��(��|��−1)
添加噪声的过程可以描述为条件概率分布:
(2)�(��|��−1)=�(��;1−����−1,���)
式(2)表示��服从正态分布�(1−����−1,���),其中��∈(0,1)是高斯分布的方差超参,表示��的方差,且满足�1<�2<⋯<��。随着�增大,��的分布逐步趋向于标准正态分布。
式(2)中有一项1−����−1,这其实是为了重参数化技巧(从标准正态分布�中按照�=�+�⊙�来采样分布�(�,�2�))设置的,我们可以推导一下。令��=1−��, �¯�=∏�=1���,�⋅是标准正态分布。有如下关系:
(3)��=����−1+1−����−1;reparameterization trick=��(��−1��−2+1−��−1��−2)+1−����−1=����−1��−2+(��(1−��−1)��−2+1−����−1)=����−1��−2+1−����−1�¯�−2;merge two Gaussians=…;merge all Gaussians=��¯�0+1−�¯��;�∼�(0,�)
这里合并��−2和��−1用到了性质�(0,�12�)+�(0,�22�)=�(0,(�12+�22)�),因此有:
(4)��(1−��−1)��−2+1−����−1∼�(0,(��(1−��−1)+1−��)�)
因此,从式(3)可以得到:
(5)�(��|�0)=�(��;��¯�0,(1−�¯�)�)
我们容易知道1>�1>�2>⋯>��>0,当�→∞时,�¯�=∏�=1���→0,此时��∼�(0,�)是标准正态分布。因此式(2)中均值是1−����−1可以使得最终��收敛到标准正态分布。
2.2 逆向阶段
前向阶段不断往数据中加入噪声得到��,逆向过程就是从��一步步去除噪声获得原始输入�0的过程。如果我们知道真实分布�(��|��−1),那么我们很容易生成原始数据。然而,我们获取不到真实分布,因此需要使用模型��来学习分布�(��|��−1)。其中�是模型参数:
(6)��(�0:�)=�(��)∏�=1���(��−1|��)��(��−1|��)=�(��−1;��(��,�),Σ�(��,�))
虽然我们无法求得真实分布�(��−1|��),但是在已知�0的情况下,我们可以根据贝叶斯公式得到�(��−1|��,�0):
�(��−1|��,�0)=�(��,��−1,�0)�(��,�0)=�(��|��−1,�0)⋅�(��−1,�0)�(��|�0)⋅�(�0)=�(��|��−1,�0)⋅�(��−1|�0)⋅�(�0)�(��|�0)⋅�(�0)=�(��|��−1,�0)⋅�(��−1|�0)�(��|�0)∝exp(−12((��−����−1)2��+(��−1−�¯�−1�0)21−�¯�−1−(��−�¯��0)21−�¯�))) ;because (2) and (3)=exp(−12(��2−2������−1+����−12��+��−12−2�¯�−1�0��−1+�¯�−1�021−�¯�−1−(��−�¯��0)21−�¯�)))=exp(−12((����+11−�¯�−1))��−12−(2������+2�¯�−11−�¯�−1�0)��−1+�(��,�0))
上式巧妙地通过贝叶斯公式将逆向过程转换为前向过程,且最终得到的概率密度函数和高斯概率密度函数的指数部分exp(−(�−�)22�2)=exp(−12(1�2�2−2��2�+�2�2))相对应。其中�(��,�0)是与��−1无关的常数项。令:
(7)�(��−1|��,�0)∼�(��−1;�~(��,�0),�~��)
我们将绿色部分与红色部分一一对应,有:
(8)�~�=1����+11−�¯�−1=��⋅(1−�¯�−1)��−�¯�+��=��⋅1−�¯�−11−�¯��~(��,�0)=(2������+2�¯�−11−�¯�−1�0)/(����+11−�¯�−1)=(2������+2�¯�−11−�¯�−1�0)⋅��⋅1−�¯�−11−�¯�=��(1−�¯�−1)1−�¯���+��−1��1−�¯��0
式(3)揭示了��和�0的关系,我们可以使用��来表示�0,有:
(9)�0=1�¯�(��−1−�¯���)
将式(9)代入到式(8)的�~(��,�0)项,有:
(10)��~=��(1−�¯�−1)1−�¯���+��−1��1−�¯�⋅1�¯�(��−1−�¯���)=��(1−�¯�−1)��(1−�¯�)��+����(1−�¯�)(��−1−�¯���)=����−�¯���+(1−��)��−(1−��)1−�¯�����(1−�¯�)=(1−�¯�)��−(1−��)1−�¯�����(1−�¯�)=1��(��−1−��1−�¯���)
2.3 模型训练
模型的目标是最大化对数似然,即优化在�0∼�(�0)下��(�0)的交叉熵:
���=��(�0)[−log��(�0)]
和VAE类似,可以使用VLB(Variational Lower Bound)来优化负对数似然:
与无关与无关(11)−log��(�0)≤−log��(�0)+���(�(�1:�|�0)||��(�1:�|�0))=−log��(�0)+��1:�∼�(�1:�|�0)[log�(�1:�|�0)��(�0:�)/��(�0)];Bayes' rule=−log��(�0)+��(�1:�|�0)[log�(�1:�|�0)��(�0:�)+log��(�0)⏟与�无关]=−log��(�0)+��(�1:�|�0)[log�(�1:�|�0)��(�0:�)]+log��(�0)=��(�1:�|�0)[log�(�1:�|�0)��(�0:�)]
注意到式(11)的左边求期望就是���,我们可以对式(11)右边也求期望,根据Fubini定理,有:
��(�0)(��(�1:�|�0)[log�(�1:�|�0)��(�0:�)])=��(�0:�)[log�(�1:�|�0)��(�0:�)]≜����
于是我们找到了���的上界,即:
(12)���=��(�0)[−log��(�0)]≤����
可以通过最小化����来最小化���。
除了从式(11)中利用Fubini定理推导出式(12),还可以使用Jensen不等式直接推导,具体如下:
���=−��(�0)log��(�0)=−��(�0)log(∫��(�0:�)��1:�);Law of total probability=−��(�0)log(∫�(�1:�|�0)��(�0:�)�(�1:�|�0)��1:�)=−��(�0)log(��(�1:�|�0)��(�0:�)�(�1:�|�0))≤−��(�0:�)log��(�0:�)�(�1:�|�0);�(�[�])≤�[�(�)]=��(�0:�)log�(�1:�|�0)��(�0:�)=����
虽然����能够作为模型的优化目标,但目前为止它是不好计算的,所以还需要进一步化简:
(13)����=��(�0:�)[log�(�1:�|�0)��(�0:�)]=��[log∏�=1��(��|��−1)��(��)∏�=1���(��−1|��)]=��[−log��(��)+∑�=1�log�(��|��−1)��(��−1|��)];see Eq.14=��[−log��(��)+∑�=2�log�(��|��−1)��(��−1|��)+log�(�1|�0)��(�0|�1)]=��[−log��(��)+∑�=2�log(�(��−1|��,�0)��(��−1|��)⋅�(��|�0)�(��−1|�0))+log�(�1|�0)��(�0|�1)]=��[−log��(��)+∑�=2�log�(��−1|��,�0)��(��−1|��)+∑�=2�log�(��|�0)�(��−1|�0)+log�(�1|�0)��(�0|�1)]=��[−log��(��)+∑�=2�log�(��−1|��,�0)��(��−1|��)+log�(��|�0)�(�1|�0)+log�(�1|�0)��(�0|�1)]=��[log�(��|�0)��(��)+∑�=2�log�(��−1|��,�0)��(��−1|��)−log��(�0|�1)]=��[���(�(��|�0)||��(��))⏟��+∑�=2����(�(��−1|��,�0)||��(��−1|��))⏟��−1−log��(�0|�1)⏟�0]
其中,式(13)的第3行到第4行的变换可以参考下面过程:
(14)�(��|��−1)=�(��,��−1)�(��−1)=�(��,��−1|�0)�(�0)�(��−1|�0)�(�0)=�(��|��−1,�0)�(��|�0)�(��−1|�0)
于是,我们可以将����拆解成如下形式:
(15)����=��+��−1+⋯+�0s.t.��=���(�(��|�0)||��(��))��=���(�(��−1|��,�0)||��(��−1|��))1≤�≤�−1�0=log��(�0|�1)
除了�0外,每一项KL项都是比较两个高斯分布。��是一个常数,因为��是标准正态分布,而�没有可学习的参数,所以训练时可以忽略��这一项。在Ho et al. 2020工作中,�0这一项是单独建模的,接下来我们主要讨论��怎么计算。
��要求分布�(��|��−1)=�(��−1;�~(��,�0),�~��)和��(��−1|��)=�(��−1;��(��,�),Σ�(��,�))之间的KL散度。因此我们希望模型��(��,�)预估�~�=1��(��−1−��1−�¯���)(式10)。由于��在训练阶段会作为输入,因此它是已知的,我们可以转而让模型去预估噪声��,即:
(16)��(��,�)=1��(��−1−��1−�¯���(��,�))
��(��,�)需要预测在前向阶段加入到��的噪声。因此有:
(17)��(��−1|��)=�(��−1;1��(��−1−��1−�¯���(��,�)),Σ�(��,�))
于是我们可以根据多元正态分布的KL散度进行如下推导(忽略常数项):
(18)��=��0,�[12‖Σ�(��,�)‖22‖��~(��,�0)−��(��,�)‖2]=��0,�[12‖Σ�(��,�)‖22‖1��(��−1−��1−�¯���)−1��(��−1−��1−�¯���(��,�))‖2]=��0,�[(1−��)22��(1−�¯�)‖Σ�(��,�)‖22‖��−��(��,�))‖2]=��0,�[(1−��)22��(1−�¯�)‖Σ�(��,�)‖22‖��−��(��¯�0+1−�¯���,�))‖2];because Eq.3
实验中,Ho et al. 2020发现忽略优化式(17)的的权重部分效果更好,也就是:
(19)��simple=��∼[1,�],�0,��[‖��−��(��,�)‖2]=��∼[1,�],�0,��[‖��−��(��¯�0+1−�¯���,�))‖2]
于是,最终DDPM模型的训练和采样算法如图2所示。
图2 DDPM模型的训练和采样算法. (图片来源:Ho et al. 2020)
训练阶段,每次采样数据�0,然后采样时间步�和该步的高斯噪声��,然后根据模型预测的高斯噪声与实际的噪声计算损失并更新参数。
采样阶段(逆向阶段),首先从标准高斯分布中采样��,然后串行执行逆向过程。在每一步中,首先从标准正态分布中采样�用于重参数化。我们需要根据��生成��−1,首先根据式(17)计算出��−1的均值,然后根据式(8)得到��−1的方差��=�~�,拿到均值和方差后就可以通过重参数化技巧得到��−1,最终生成�0。
3. 模型加速
DDPM模型整个流程分为前向阶段和逆向阶段,前向阶段和逆向阶段可能有上千步,因此用它生成样本实际上是很慢的,远远低于GAN生成样本的速度。一种简单的改进方法是进行等距采样(Nichol & Dhariwal, 2021),比如从�步中等距抽取�步,新的采样时间步可以表示为{�1,�2,…,��},其中�1<�2<⋯<��∈[1,�],�<�。
参考资料
[1] lilianweng,"What are Diffusion Models?",What are Diffusion Models? | Lil'Log.
[2] 珍妮的选择,"扩散模型 (Diffusion Model) 简要介绍与源码分析",扩散模型 (Diffusion Model) 简要介绍与源码分析 - 知乎.