接GANs和Diffusion模型(2)
扩散(Diffusion)模型
生成学习三重困难(Trilemma)
指生成学习(genrative learning)的模型都需要满足三个需求:
- 高质量的采样(High Quality Samples):模型应该能生成非常高质量的采样
- 快速采样(Fast Sampling):模型应该能快速地从噪声中产生采样
- 多样性或模式覆盖(Diversity or Mode Coverage):模型应该能提供好的多样性或者模式覆盖
好的多样性或模式覆盖往往是生成模型的一个非常重要的需求。然而,事实上生成模型很难同时满足上述3个需求,不同的生成模型会在这3个需求中作出权衡和取舍。根据你的实际用例做出正确的权衡,这就是生成学习的三重困境(trilemma)。
前面学习的GAN模型,能够较好地满足前2个需求,但无法满足第3个需求,即无法提供好的多样性或者模式覆盖。如果你遇到了一个现有采样很少的问题,GANs并不是一个很好的选择。此时,你或许可以选择一个能够较好地满足第1/3需求的生成学习模型,即扩散模型(Diffusion Model)。
去噪扩散概率模型介绍
Denoising Diffusion Probabilistic Model,简写为DDPM。
DDPM首次发表于2020年的文献“Denosing Diffusion Probabilistic Models”:https://arxiv.org/pdf/2006.11239.pdf。
FID(Fréchet Inception Distance,弗雷切特起始距离)
用于评价一个生成模型产生的图片质量的指标。FID分数会比较生成图片的分布和用于训练生成器的真实图片的分布。由于FID是在测量两种图片分布之间的距离,因此一个完美的FID分数是0.0。通常,一个模型的FID分数越低,说明这个模型产生的图片质量越高。
关于FID的详细公式,可以参考这篇博文:【基础知识】FID(Fréchet Inception Distance)公式及解释
DDPM的工作方式
DDPM的实现实际上是非常复杂的,但是这个模型的工作方式给人的直观感觉却并不复杂。我们从一个来自真实数据集的图片开始,且将这个数据标记为x0,则
- 扩散过程会反复地(通常有1000次)向这个x0添加非常小的噪声(按照原文的意思,是无穷小:infinitesimal)
- 所有添加的噪声有自己的分布,即独立高斯噪声
- 每次添加的噪声会在前一次的基础上进行一个微小地放大
- 直到我们得到一个纯噪声的图片为止,此时已经不能从当前图片中分辨出原始的图像了
如下图所示,从最左边的图片开始,每次添加一个独立高斯分布的、且逐渐放大的噪声,直到最后成为纯噪声图片。
此刻,一个想法是:上述过程如果可逆,那么我们便可以从一个纯噪声图片得到一张和真实采样接近的图片。DDPM文献的作者对上述想法做出了数学证明,并得出了这样一个结论:如果添加噪声的过程(扩散)满足我们给出的条件,即噪声按照一种特殊的方式逐渐添加,则对一张图片添加噪声的前向过程、和对一张图片去噪的反向过程,会具有相同的函数形式。
下图给出了反向去噪的示意。
于是,通过成功训练一个扩散模型的参数,我们可以得到具有相同参数的去噪模型,然后通过这个去噪模型就可以根据噪声产生新的采样了。
前向扩散过程
假定扩散过程从x0开始,直到xT结束
在实践中,这个T比较大,一般在1000~4000之间。最终得到的图片xT,对x0而言,接近于一个高斯分布(又称为正态分布,其中均值为0,方差为1的高斯分布又称为标准正态分布)
于是最终的xT接近于一个纯噪声。
同时,在扩散模型中,前向扩散过程的模型是一个马克可夫链(Markov Chain),即每一个步骤之和其前一个步骤有关,和再之前的所有步骤都不相关;即x(n+1)只依赖于x(n),而独立于x(n)之前的所有状态。于是,在扩散过程中,每一次产生的图片只和其前一张图片有关,和更早产生的图片无关。用公式表达为:
公式中
- 等号左边表示:当(t-1)时刻值为x(t-1)时、t时刻值为xt的状态转移分布
- 等号右边表示一个高斯分布,均值为,方差为,是一个超参数。
Beta的下标t表明这个方差并不是固定的,而是随步骤变化的,且满足
随着t的增大,均值会无限接近于0,方差无限接近于1,于是得到前面的标准正态分布。
反向去噪过程
反向过程和扩散过程相反,从xT开始,直到x0结束
反向过程也是一个马尔科夫链。给定t时刻的图片xt,则(t-1)时刻产生的图片x(t-1)的分布为:
反向过程的分布依赖于t时刻的采样xt。反向模型的神经网络模型的目标是学习去噪过程的均值和方差,以便从噪声中产生图片。
扩散模型的直观感觉:可变自动编码器
在前面的描述的扩散过程中,x0通常是真实数据,x1,x2,...,xT是通过不断增加噪声产生的,也称为潜在变量(latent variables),因此扩散模型也可以看作一种潜在变量生成模型(latent variable generator model)。
事实上,还有另一种非常流行的潜在变量生成模型:可变自动编码器,Variational Autoencoder (VAE)。并且我们会发现,两者其实很相似。
自动编码器(Autoencoder)
这里讲到的自动编码器是特指的用于降维的一种神经网络模型。自动编码器会从基础数据(underlying data)中抽取潜在变量或者特性。假定输入自动编码器的数据是图片,这些图片在自动编码器中会以较低的维度被编码。这些较低维度的编码表示了输入图片的潜在或重要的特征。潜在特征尽可能多地包含了原始图片的特征信息、但却比原始图片具有更低的维度。自动编码器利用这些抽取出来的较低维度的潜在特征,产生尽可能接近原始输入的数据。如下图所示:
自动编码器模型分为编码器和解码器两个部分。其中,编码器会输出一个较低维度的潜在特征z,然后由解码器恢复为一个具有原始维度的新数据x'。x'应该尽可能和x一样。
可变自动编码器(Variational Autoencoder)
可自动编码器相比,将输出由确定的变为不确定的,即增加了概率;因此可以看作是概率自动编码器(Probabilistic Autoencoder)。此外,普通的自动编码器实际上是在输出侧重构了输入,并不能产生新的采样,因此不能看作生成模型。但是可变自动编码器可以产生新的采样,是一种生成自动编码器(Generative Autoencoder),因此是一种生成模型。如下图所示:
VAE编码器产生的不再是一个固定的数据,而是一个具有均值和方差的概率分布。也就是说,VAE中的z是一个满足一定概率分布的随机变量。VAE的解码器同样也不再产生一个固定的输出,而是一个满足一定概率分布的随机变量。
VAE和DDPM的关系
DDPM的扩散过程和VAE的编码器十分相似,扩散过程通过逐步添加噪声最终产生一个具有标准高斯分布的随机变量xT
DDPM的反向过程和VAE的解码器十分相似
正是因为如此,很多训练VAE的概念,也同样可以用于训练DDPM上。
程序演示
去噪扩散概率模型(DDPM)的实际数学模型和实现其实是非常复杂的,本文不会对这些内容进行详细描述。事实上,在github上已经有很多DDPM的实现了,包括文献作者的实现。本文的演示中会借用这个github上的模型实现:https://github.com/awjuliani/pytorch-diffusion。该实现完成于2022年。
由于DDPM比较复杂,因此本节程序演示的目标并不是去理解模型本身,而是初步体会如何运用现有的DDPM即可。即,如何通过使用pytorch和pytorch-lighting、利用github上的DDPM、对现有的数据集进行学习,并产生新的采样。
github模型简介
打卡这个链接,可以看到这个模型的实现包含几个python文件:data.py,model.py,modules.py,用于实现模型;以及一个entry.ipynb文件,可以认为是一个驱动程序,或者对模型的使用,即创建和训练模型等。
接下来的“Readme”部分说明了,这个程序实现了文献中的DDPM。
同时可以看到,程序使用了MNIST,Fashion-MNIST和CIFAR这3个数据集进行了验证。其中
- MNIST是一个28x28像素的手写数字图片数据集,在神经网络必备基础中介绍过
- Fashion-MNIST是一个28x28像素的流行图片数据集,在GANs和Diffusion模型(1)中介绍过
- CIFAR是一个32x32的彩色图片数据集,分为CIFAR-10(10个类别的图片集)和CIFAR-100(100个类别的图片集),包括飞机,汽车,鸟,猫,轮船等等。链接:CIFAR-10 and CIFAR-100 datasets
再来看本模型的源文件
主要的文件有4个(请依次打开每个文件浏览一下):
- data.py:负责导入数据。其中会对导入的数据做scaling和centering
- modules.py:定义了DDPM相关的层,如DoubleConv(双卷积层),Down(下采样层,主要指汇聚层pooling),Up(上采样层,主要指转置卷积,transposed convolution),OutConv(输出卷积层)等。此外,该模块还定义了SelfAttention(自注意机制),及其封装SAWrapper(Self Attention Wrapper)
- model.py:定义DDPM本身。包括forward(前向扩散),get_loss(损失函数),beta,alpha(计算模型中的beta,alpha参数,可以查阅文献中关于这些参数的定义),denoise_sample(反向去噪),training_step(训练),validation_step(校验)等等。
- entry.ipynb:前面三个文件是DDPM的完整定义,这个文件是对DDPM的运用示例。在使用这个DDPM的时候,可以参考这个文件。
使用DDPM的程序
本节通过参考entry.ipynb,定义自己的程序,利用DDPM训练MNIST数据集,并产生自己的采样。
1. 挂在google colab的网盘(drive)
本程序会将github上的模型文件拷贝下来,存放到自己的google colab网盘中,因此需要执行这个步骤。否则,程序将无法识别本地文件。
# check my current drive
import os
os.getcwd()
# mount my drive in google colab
from google.colab import drive
drive.mount('/content/drive')
# change to my working directory, all sources are in this folder
%cd /content/drive/My Drive/Colab Notebooks/DeepLearning/gan_diffusion/ddpm
其中
- getcwd()的作用是获取本地网盘的路径,默认一般是'/content'。
- "%cd"的作用是将工作路径从google colab的“根目录”切换到ddpm的路径,方便后面使用时,不再需要添加相对路径。
2. 安装&导入python包
!pip install pytorch_lightning
#import pytorch, whose name is torch
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from data import DiffSet
from model import DiffusionModel
from torch.utils.data import DataLoader
其中,
- Google colab默认不带pytorch_lightning,因此需要安装。
- “from data import DiffSet”和“from model import DiffusionModel”均来自本地定义的python模块,即前面讲到的data.py和model.py。
3. 定义模型参数
diffusion_steps = 1000
dataset_choice = "MNIST"
max_epoch = 10
batch_size = 128
# tracks if the model will be loaded from a checkpoint, not used yet
load_model = False
load_version_num = 1
其中
- diffusion_steps即扩散次数/去噪次数。这个参数在前面讲过,会取比较大的值,一般取值为1000~4000。
4. 加载MNIST数据集
#def __init__(self, train, dataset="MNIST")
#if train=True, return training dataset, else return validation dataset
#all data are already scaled in DiffSet
train_dataset = DiffSet(True, dataset_choice)
val_dataset = DiffSet(False, dataset_choice)
#Build a pytorch DataLoader, where "shuffle=True" means the data will be
# shuffled (not in sequence) during the building
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=True)
注意,使用pytorch_lightning训练模型的时候,需要将数据集转化为pytorch的DataLoader
5. 构建模型
if load_model:
pass
else:
model = DiffusionModel(train_dataset.size * train_dataset.size,
diffusion_steps,
train_dataset.depth)
tb_logger = pl.loggers.TensorBoardLogger(
'lighting_logs/',
name = dataset_choice,
#version = pass_version,
)
# Using pytorch-lighting trainer class to do the training
trainer = pl.Trainer(
max_epochs = max_epoch,
log_every_n_steps = 10,
#gpus = 1,
#auto_select_gpus = True,
#resume_from_checkpoint = last_checkpoint,
logger = tb_logger
)
注意这里需要创建两个对象
- 通过自定义的DDPM创建的模型对象model
- 通过pytorch_lightning创建的训练对象trainer
这点和直接使用已经集成好的模型(比如TensorFlow中的Sequential())是不太一样的。
6. 训练模型
# Training model
trainer.fit(model, train_loader, val_loader)
这里需要将自定义好的模型model,传给通过pytorch_lightning创建的训练对象trainer;参数使用前面已经创建好的pytorch DataLoader。
另外,训练过程会非常慢,笔者在Runtime中设置T4 GPU大概也要训练10~20分钟,故不建议频繁执行“Run all”。
7. 产生新的采样
sample_batch_size = 9
gen_samples = []
x = torch.randn((sample_batch_size, train_dataset.depth,
train_dataset.size, train_dataset.size))
sample_steps = torch.arange(model.t_range, 0, -1)
for t in sample_steps:
x = model.denoise_sample(x, t)
# every 50 steps, add a sample in current step(t_x) into gen_samples
# as there are 1000 steps, so length of gen_samples will be 1000/50 = 20
if t % 50 == 0:
gen_samples.append(x)
print(model.t_range)
print(len(gen_samples), gen_samples[0].shape)
gen_samples = torch.stack(gen_samples, dim =0).moveaxis(2, 4).squeeze(-1)
gen_samples = (gen_samples.clamp(-1, 1) + 1) / 2
print(gen_samples[0].shape)
说明:
- 通过torch.randn()产生一个随机噪声,通过DDPM提供的denoise_sample()去噪并产生新的采样
- sample_steps = torch.arange(model.t_range, 0, -1),中的model.t_range即创建模型时给出的diffusion_steps参数,这里是1000,torch.arange是按照start和end从model.t_range中产生一个向量(这里实际上是1~1000的数组)
- “if t % 50 == 0: gen_samples.append(x)”这句是每隔50步,将当前采样(即前面提到的“潜在变量”)添加到gen_samples[]中,以便后面显示去噪过程。
8. 展示新的采样
fig = plt.figure(figsize = (12, 6))
for i in range(gen_samples.shape[0]):
plt.subplot(5, 6, i+1)
plt.imshow((gen_samples[i, 4, :, :] * 255).type(torch.uint8), cmap = 'gray')
plt.axis('off')
对前面gen_samples中保存的20个中间状态,使用plt.imshow()显示其效果,得到如下结果
通过展示可以看到,从纯噪声到产生一个新的采样(有点像手写的数字5)的大致过程。
由于DDPM是一个概率模型,这意味着每次根据噪声产生的图片是不一样的。保持已有的训练好的模型,再次运行步骤7和步骤8,可能得到下面的结果