GANs和Diffusion模型(3)

接GANs和Diffusion模型(2)

扩散(Diffusion)模型

生成学习三重困难(Trilemma)

指生成学习(genrative learning)的模型都需要满足三个需求:

  • 高质量的采样(High Quality Samples):模型应该能生成非常高质量的采样
  • 快速采样(Fast Sampling):模型应该能快速地从噪声中产生采样
  • 多样性或模式覆盖(Diversity or Mode Coverage):模型应该能提供好的多样性或者模式覆盖

好的多样性或模式覆盖往往是生成模型的一个非常重要的需求。然而,事实上生成模型很难同时满足上述3个需求,不同的生成模型会在这3个需求中作出权衡和取舍。根据你的实际用例做出正确的权衡,这就是生成学习的三重困境(trilemma)。

前面学习的GAN模型,能够较好地满足前2个需求,但无法满足第3个需求,即无法提供好的多样性或者模式覆盖。如果你遇到了一个现有采样很少的问题,GANs并不是一个很好的选择。此时,你或许可以选择一个能够较好地满足第1/3需求的生成学习模型,即扩散模型(Diffusion Model)。

去噪扩散概率模型介绍

Denoising Diffusion Probabilistic Model,简写为DDPM。

DDPM首次发表于2020年的文献“Denosing Diffusion Probabilistic Models”:https://arxiv.org/pdf/2006.11239.pdf。

FID(Fréchet Inception Distance,弗雷切特起始距离)

用于评价一个生成模型产生的图片质量的指标。FID分数会比较生成图片的分布和用于训练生成器的真实图片的分布。由于FID是在测量两种图片分布之间的距离,因此一个完美的FID分数是0.0。通常,一个模型的FID分数越低,说明这个模型产生的图片质量越高。

关于FID的详细公式,可以参考这篇博文:【基础知识】FID(Fréchet Inception Distance)公式及解释

DDPM的工作方式

DDPM的实现实际上是非常复杂的,但是这个模型的工作方式给人的直观感觉却并不复杂。我们从一个来自真实数据集的图片开始,且将这个数据标记为x0,则

  • 扩散过程会反复地(通常有1000次)向这个x0添加非常小的噪声(按照原文的意思,是无穷小:infinitesimal)
  • 所有添加的噪声有自己的分布,即独立高斯噪声
  • 每次添加的噪声会在前一次的基础上进行一个微小地放大
  • 直到我们得到一个纯噪声的图片为止,此时已经不能从当前图片中分辨出原始的图像了

如下图所示,从最左边的图片开始,每次添加一个独立高斯分布的、且逐渐放大的噪声,直到最后成为纯噪声图片。

此刻,一个想法是:上述过程如果可逆,那么我们便可以从一个纯噪声图片得到一张和真实采样接近的图片。DDPM文献的作者对上述想法做出了数学证明,并得出了这样一个结论:如果添加噪声的过程(扩散)满足我们给出的条件,即噪声按照一种特殊的方式逐渐添加,则对一张图片添加噪声的前向过程、和对一张图片去噪的反向过程,会具有相同的函数形式

下图给出了反向去噪的示意。

于是,通过成功训练一个扩散模型的参数,我们可以得到具有相同参数的去噪模型,然后通过这个去噪模型就可以根据噪声产生新的采样了。

前向扩散过程

假定扩散过程从x0开始,直到xT结束

x_{0}\rightarrow x_{1}\rightarrow... \rightarrow x_{T}

在实践中,这个T比较大,一般在1000~4000之间。最终得到的图片xT,对x0而言,接近于一个高斯分布(又称为正态分布,其中均值为0,方差为1的高斯分布又称为标准正态分布)

q(x_{T}|x_{0}) \approx N(0, 1)

于是最终的xT接近于一个纯噪声。

同时,在扩散模型中,前向扩散过程的模型是一个马克可夫链(Markov Chain),即每一个步骤之和其前一个步骤有关,和再之前的所有步骤都不相关;即x(n+1)只依赖于x(n),而独立于x(n)之前的所有状态。于是,在扩散过程中,每一次产生的图片只和其前一张图片有关,和更早产生的图片无关。用公式表达为:

q(x_{t}|x_{t-1}) = N(x_{t}; \sqrt{1-\beta_{t} }x_{t-1}, \beta_{t}I)

公式中

  • 等号左边表示:当(t-1)时刻值为x(t-1)时、t时刻值为xt的状态转移分布
  • 等号右边表示一个高斯分布,均值为\sqrt{1-\beta_{t} },方差为\beta_{t}\beta_{t}是一个超参数

Beta的下标t表明这个方差并不是固定的,而是随步骤变化的,且满足

\beta_{1} < \beta_{2} <\beta_{3}<...<\beta_{t} ,\beta_{t} \in (0, 1)

随着t的增大,均值会无限接近于0,方差无限接近于1,于是得到前面的标准正态分布。

反向去噪过程

反向过程和扩散过程相反,从xT开始,直到x0结束

x_{0}\leftarrow ...\leftarrow x_{T-1}\leftarrow x_{T}

反向过程也是一个马尔科夫链。给定t时刻的图片xt,则(t-1)时刻产生的图片x(t-1)的分布为:

p_{\theta }(x_{t-1}|x_{t}) = N(x_{t-1}; \mu _{\theta }(x_{t},t),\Sigma \theta (x_{t},t))

反向过程的分布依赖于t时刻的采样xt。反向模型的神经网络模型的目标是学习去噪过程的均值和方差,以便从噪声中产生图片。

扩散模型的直观感觉:可变自动编码器

在前面的描述的扩散过程中,x0通常是真实数据,x1,x2,...,xT是通过不断增加噪声产生的,也称为潜在变量(latent variables),因此扩散模型也可以看作一种潜在变量生成模型(latent variable generator model)。

事实上,还有另一种非常流行的潜在变量生成模型:可变自动编码器,Variational Autoencoder (VAE)。并且我们会发现,两者其实很相似。

自动编码器(Autoencoder)

这里讲到的自动编码器是特指的用于降维的一种神经网络模型。自动编码器会从基础数据(underlying data)中抽取潜在变量或者特性。假定输入自动编码器的数据是图片,这些图片在自动编码器中会以较低的维度被编码。这些较低维度的编码表示了输入图片的潜在或重要的特征。潜在特征尽可能多地包含了原始图片的特征信息、但却比原始图片具有更低的维度。自动编码器利用这些抽取出来的较低维度的潜在特征,产生尽可能接近原始输入的数据。如下图所示:

自动编码器模型分为编码器和解码器两个部分。其中,编码器会输出一个较低维度的潜在特征z,然后由解码器恢复为一个具有原始维度的新数据x'。x'应该尽可能和x一样。

可变自动编码器(Variational Autoencoder)

可自动编码器相比,将输出由确定的变为不确定的,即增加了概率;因此可以看作是概率自动编码器(Probabilistic Autoencoder)。此外,普通的自动编码器实际上是在输出侧重构了输入,并不能产生新的采样,因此不能看作生成模型。但是可变自动编码器可以产生新的采样,是一种生成自动编码器(Generative Autoencoder),因此是一种生成模型。如下图所示:

VAE编码器产生的不再是一个固定的数据,而是一个具有均值和方差的概率分布。也就是说,VAE中的z是一个满足一定概率分布的随机变量。VAE的解码器同样也不再产生一个固定的输出,而是一个满足一定概率分布的随机变量。

VAE和DDPM的关系

DDPM的扩散过程和VAE的编码器十分相似,扩散过程通过逐步添加噪声最终产生一个具有标准高斯分布的随机变量xT

DDPM的反向过程和VAE的解码器十分相似

正是因为如此,很多训练VAE的概念,也同样可以用于训练DDPM上。

程序演示

去噪扩散概率模型(DDPM)的实际数学模型和实现其实是非常复杂的,本文不会对这些内容进行详细描述。事实上,在github上已经有很多DDPM的实现了,包括文献作者的实现。本文的演示中会借用这个github上的模型实现:https://github.com/awjuliani/pytorch-diffusion。该实现完成于2022年。

由于DDPM比较复杂,因此本节程序演示的目标并不是去理解模型本身,而是初步体会如何运用现有的DDPM即可。即,如何通过使用pytorch和pytorch-lighting、利用github上的DDPM、对现有的数据集进行学习,并产生新的采样

github模型简介

打卡这个链接,可以看到这个模型的实现包含几个python文件:data.py,model.py,modules.py,用于实现模型;以及一个entry.ipynb文件,可以认为是一个驱动程序,或者对模型的使用,即创建和训练模型等。

接下来的“Readme”部分说明了,这个程序实现了文献中的DDPM。

同时可以看到,程序使用了MNIST,Fashion-MNIST和CIFAR这3个数据集进行了验证。其中

  • MNIST是一个28x28像素的手写数字图片数据集,在神经网络必备基础中介绍过
  • Fashion-MNIST是一个28x28像素的流行图片数据集,在GANs和Diffusion模型(1)中介绍过
  • CIFAR是一个32x32的彩色图片数据集,分为CIFAR-10(10个类别的图片集)和CIFAR-100(100个类别的图片集),包括飞机,汽车,鸟,猫,轮船等等。链接:CIFAR-10 and CIFAR-100 datasets

再来看本模型的源文件

主要的文件有4个(请依次打开每个文件浏览一下):

  • data.py:负责导入数据。其中会对导入的数据做scaling和centering
  • modules.py:定义了DDPM相关的层,如DoubleConv(双卷积层),Down(下采样层,主要指汇聚层pooling),Up(上采样层,主要指转置卷积,transposed convolution),OutConv(输出卷积层)等。此外,该模块还定义了SelfAttention(自注意机制),及其封装SAWrapper(Self Attention Wrapper)
  • model.py:定义DDPM本身。包括forward(前向扩散),get_loss(损失函数),beta,alpha(计算模型中的beta,alpha参数,可以查阅文献中关于这些参数的定义),denoise_sample(反向去噪),training_step(训练),validation_step(校验)等等。
  • entry.ipynb:前面三个文件是DDPM的完整定义,这个文件是对DDPM的运用示例。在使用这个DDPM的时候,可以参考这个文件。

使用DDPM的程序

本节通过参考entry.ipynb,定义自己的程序,利用DDPM训练MNIST数据集,并产生自己的采样。

1. 挂在google colab的网盘(drive)

本程序会将github上的模型文件拷贝下来,存放到自己的google colab网盘中,因此需要执行这个步骤。否则,程序将无法识别本地文件。

# check my current drive
import os
os.getcwd()

# mount my drive in google colab
from google.colab import drive
drive.mount('/content/drive')

# change to my working directory, all sources are in this folder
%cd /content/drive/My Drive/Colab Notebooks/DeepLearning/gan_diffusion/ddpm

其中

  • getcwd()的作用是获取本地网盘的路径,默认一般是'/content'。
  • "%cd"的作用是将工作路径从google colab的“根目录”切换到ddpm的路径,方便后面使用时,不再需要添加相对路径。

2. 安装&导入python包

!pip install pytorch_lightning

#import pytorch, whose name is torch
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from data import DiffSet
from model import DiffusionModel
from torch.utils.data import DataLoader

其中,

  • Google colab默认不带pytorch_lightning,因此需要安装。
  • “from data import DiffSet”和“from model import DiffusionModel”均来自本地定义的python模块,即前面讲到的data.py和model.py。

3. 定义模型参数

diffusion_steps = 1000
dataset_choice = "MNIST"
max_epoch = 10
batch_size = 128

# tracks if the model will be loaded from a checkpoint, not used yet
load_model = False
load_version_num = 1

其中

  • diffusion_steps即扩散次数/去噪次数。这个参数在前面讲过,会取比较大的值,一般取值为1000~4000。

4. 加载MNIST数据集

#def __init__(self, train, dataset="MNIST")
#if train=True, return training dataset, else return validation dataset
#all data are already scaled in DiffSet
train_dataset = DiffSet(True, dataset_choice)
val_dataset = DiffSet(False, dataset_choice)

#Build a pytorch DataLoader, where "shuffle=True" means the data will be
# shuffled (not in sequence) during the building
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=True)

注意,使用pytorch_lightning训练模型的时候,需要将数据集转化为pytorch的DataLoader

5. 构建模型

if load_model:
  pass
else:
  model = DiffusionModel(train_dataset.size * train_dataset.size,
                         diffusion_steps,
                         train_dataset.depth)

tb_logger = pl.loggers.TensorBoardLogger(
    'lighting_logs/',
    name = dataset_choice,
    #version = pass_version,
)

# Using pytorch-lighting trainer class to do the training
trainer = pl.Trainer(
    max_epochs = max_epoch,
    log_every_n_steps = 10,
    #gpus = 1,
    #auto_select_gpus = True,
    #resume_from_checkpoint = last_checkpoint,
    logger = tb_logger
)

注意这里需要创建两个对象

  • 通过自定义的DDPM创建的模型对象model
  • 通过pytorch_lightning创建的训练对象trainer

这点和直接使用已经集成好的模型(比如TensorFlow中的Sequential())是不太一样的。

6. 训练模型

# Training model
trainer.fit(model, train_loader, val_loader)

这里需要将自定义好的模型model,传给通过pytorch_lightning创建的训练对象trainer;参数使用前面已经创建好的pytorch DataLoader。

另外,训练过程会非常慢,笔者在Runtime中设置T4 GPU大概也要训练10~20分钟,故不建议频繁执行“Run all”。

7. 产生新的采样

sample_batch_size = 9

gen_samples = []
x = torch.randn((sample_batch_size, train_dataset.depth,
                 train_dataset.size, train_dataset.size))
sample_steps = torch.arange(model.t_range, 0, -1)

for t in sample_steps:
  x = model.denoise_sample(x, t)
  # every 50 steps, add a sample in current step(t_x) into gen_samples
  # as there are 1000 steps, so length of gen_samples will be 1000/50 = 20
  if t % 50 == 0:
    gen_samples.append(x)

print(model.t_range)
print(len(gen_samples), gen_samples[0].shape)

gen_samples = torch.stack(gen_samples, dim =0).moveaxis(2, 4).squeeze(-1)
gen_samples = (gen_samples.clamp(-1, 1) + 1) / 2
print(gen_samples[0].shape)

说明:

  • 通过torch.randn()产生一个随机噪声,通过DDPM提供的denoise_sample()去噪并产生新的采样
  • sample_steps = torch.arange(model.t_range, 0, -1),中的model.t_range即创建模型时给出的diffusion_steps参数,这里是1000,torch.arange是按照start和end从model.t_range中产生一个向量(这里实际上是1~1000的数组)
  • “if t % 50 == 0: gen_samples.append(x)”这句是每隔50步,将当前采样(即前面提到的“潜在变量”)添加到gen_samples[]中,以便后面显示去噪过程。

8. 展示新的采样

fig = plt.figure(figsize = (12, 6))
for i in range(gen_samples.shape[0]):
  plt.subplot(5, 6, i+1)
  plt.imshow((gen_samples[i, 4, :, :] * 255).type(torch.uint8), cmap = 'gray')
  plt.axis('off')

对前面gen_samples中保存的20个中间状态,使用plt.imshow()显示其效果,得到如下结果

通过展示可以看到,从纯噪声到产生一个新的采样(有点像手写的数字5)的大致过程。

由于DDPM是一个概率模型,这意味着每次根据噪声产生的图片是不一样的。保持已有的训练好的模型,再次运行步骤7和步骤8,可能得到下面的结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/506822.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用 Python 模拟布朗运动(和股票价格)

一、说明 本文先介绍布朗运动的概念&#xff0c;紧接着应用布朗方程到股票的随机斩落模型。进而用python实现&#xff0c;并给出各种各样的条件模型。从中烘托出股票模型的规律所在。 二、什么是布朗运动&#xff1f; 布朗运动以罗伯特布朗的名字命名&#xff0c;他是第一个在通…

持续交付与持续部署相关概念(CD)

目录 一、概述 二、持续交付基本概念 2.1 持续交付的含义 2.1.1 项目管理的视角 2.1.2 产品研发的视角 2.1.3 总结 2.2 持续交付涉及的运作环境 2.2.1 开发环境 2.2.2 测试环境 2.2.3 UAT环境 2.2.4 准生产环境 2.2.5 生产环境 2.3 总结 三、持续部署基本概念 3.…

创新之路:云边对接与行业生态的前沿探索

全球 80% 的数据来自物联网&#xff0c;不论是传统行业还是新兴行业&#xff0c;都将利用更多有价值的数据来驱动业务&#xff0c;实现降本增效。智慧教育、资产追踪、环境监测、工业物联网、智慧城市、家居互联、智慧电力、智慧农业。从智能电表到智能家居&#xff0c;从机器人…

RAG:检索增强生成系统如何工作

随着大型语言模型&#xff08;LLM&#xff09;的发展&#xff0c;人工智能世界取得了巨大的飞跃。经过大量数据的训练&#xff0c;LLM可以发现语言模式和关系&#xff0c;使人工智能工具能够生成更准确、与上下文相关的响应。 但LLM也给人工智能工程师带来了新的挑战&#xff…

shopee、lazada、temu测评自养号策略解析

在跨境电商领域&#xff0c;测评作为提升销量的重要手段&#xff0c;其策略的制定和实施显得尤为重要。特别是对于Shopee和Lazada两大主流平台上的卖家而言&#xff0c;如何有效利用测评策略提升产品销量成为了一大挑战。 自养号测评系统可以批量注册买家账号、模拟真实人工操…

U8二次开发-钉钉集成

钉钉开放平台作为企业沟通和协作的重要工具,其技术的每一次迭代都为企业带来了新的机遇和挑战。随着企业对于高效沟通和智能化管理的需求日益增长,钉钉平台的SDK更新显得尤为重要。把传统的U8与钉钉平台集成,可以有效的将业务功能和角色进行前移,打破应用系统二八原则,即8…

Vue(十二):脚手架配置代理,github案例,插槽

一、脚手架配置代理 老师讲的主要有两种方法&#xff1a; 但是我的没有proxy&#xff0c;只有proxyTable,之前一直不成功&#xff0c;现在我是这样配置的&#xff1a; config文件夹下的index.js: App.vue: 然后就成功了&#xff1a;&#xff08;我真服了&#xff0c;之前在这…

Linux中xz一次恶意后门处理的名场面-尚文网络xUP楠哥

进Q群11372462领取专属报名福利! 说在前面 Linux系统中所使用的xz软件是用于日常文件的归档压缩工具&#xff0c;据悉就在今日&#xff0c;Utils 5.6.0、5.6.1版本存在恶意后门植入漏洞&#xff08;CVE-2024-3094&#xff09;。开发人员在调查SSH性能问题时发现了涉及XZ Util…

Taro多行文本最多展示5行,超出“查看更多”展示,点击弹层

Taro中&#xff0c;页面需求&#xff1a; 多行文本&#xff0c;展示最多展示5行&#xff0c;超出5行&#xff0c;展示“查看更多”按钮&#xff0c;点击弹层展示文本详细信息。 弹层代码就不说了&#xff0c;着重说一下怎么获取区域高度&#xff5e; 1.区域设置max-height&am…

2_2.Linux中的远程登录服务

# 一.Openssh的功能 # 1.sshd服务的用途# #作用&#xff1a;可以实现通过网络在远程主机中开启安全shell的操作 Secure SHell >ssh ##客户端 Secure SHell daemon >sshd ##服务端 2.安装包# openssh-server 3.主配置文件# /etc/ssh/sshd_conf 4.…

嵌入式|蓝桥杯STM32G431(HAL库开发)——CT117E学习笔记12:DAC数模转换

系列文章目录 嵌入式|蓝桥杯STM32G431&#xff08;HAL库开发&#xff09;——CT117E学习笔记01&#xff1a;赛事介绍与硬件平台 嵌入式|蓝桥杯STM32G431&#xff08;HAL库开发&#xff09;——CT117E学习笔记02&#xff1a;开发环境安装 嵌入式|蓝桥杯STM32G431&#xff08;…

Php_Code_challenge12

题目&#xff1a; 答案&#xff1a; 解析&#xff1a; 字符串拼接。

iPhone设备中调试应用程序崩溃日志的高效方法探究

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么&#xff1f; 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志&#xff0c;以便调查崩溃的原因。我们将展示三种不同的…

Windows 上路由、端口转发配置,跨网络地址段

一、背景 有时候我们会遇到这样的场景&#xff0c;一批同一局域网中只有某一台主机带外且系统为windows&#xff0c;局域网中其他非带外的主机要想访问外网&#xff0c;本文将介绍如何配置在带外主机上开启路由及端口转发。 二、配置操作 2.1、带外主机开启路由转发 1&#x…

Centos7.X服务器搭建VOS系统的REC录音转换MP3,并支持外呼系统wav转换MP3

由于有的公司客户需要自己下载录音或做话务质检等工作需要&#xff0c;需要从VOS系统中把录音下载到其它服务器使用&#xff0c;但是VOS录音格式是REC格式的&#xff0c;就算下载下来了也无法直接播放&#xff0c;因此我们需要搭建一台转换MP3的服务器来完成需求&#xff01; 外…

15-研发流程实战:IAM项目是如何进行研发流程管理的?

为了向你演示流程&#xff0c;这里先假设一个场景。我们有一个需求&#xff1a;给IAM客户端工具iamctl增加一个helloworld命令&#xff0c;该命令向终端打印hello world。 开发阶段 开发阶段是开发者的主战场&#xff0c; 它又可分为代码开发和代码提交两个子阶段。 代码开发…

用Python标准GUI库Tkinter绘制分形图

用Python标准GUI库Tkinter绘制分形图 分形图是一种通过迭代规则生成自相似图案的艺术形式。 分形图包括曼德勃罗集、科赫曲线、谢尔宾斯基三角等代码等。 Tkinter是Python的标准GUI库&#xff0c;可以用于创建窗口、控件和其他图形界面元素。绘制分形图像&#xff0c;如曼德…

数据库---------完全备份和增量备份的数据恢复,以及断点恢复

目录 一、在数据库表中&#xff0c;分三次录入学生考试成绩 1.1先创建库&#xff0c;创建表&#xff0c;完成三次数据的录入 1.2首次录入成绩后&#xff0c;做该表的完全备份 1.3第二次插入后 做增量备份 1.4第三次插入后 做增量备份 二、模拟数据丢失&#xff0c;并使用…

【Ubuntu】用 VMware 安装 macOS

本教程使用 Ubuntu 20.04.6 LTS&#xff0c;VMware Workstation Pro 17.5.1&#xff0c;macOS Sonoma 14.4。文中所有需要的下载链接均以 Markdown 的形式体现在文字上。 下载 VMware Workstation Pro&#xff0c;目前最新版本是 17.5.1。 使用密钥&#xff0c;进行破解。 VM…

苹果应用上架流程解析

苹果上架要求是苹果公司对于提交应用程序到苹果商店上架的要求和规定。这些要求主要是为了保证用户体验、应用程序的质量和安全性。以下是苹果上架要求的详细介绍&#xff1a;1. 应用程序的内容和功能必须符合苹果公司的规 苹果上架要求是苹果公司对于提交应用程序到苹果商店上…