扩散模型实战(六):Diffusers DDPM初探

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

       之前的五篇文章主要是为了解释扩散模型的基本概念和流程,使读者更容易理解扩散模型的工作原理,但与实际工作中使用的模型差异较大,从本文开始,我们将初步使用DDPM模型的开源实现库Diffusers,在Diffusers库中DDPM模型的实现库是UNet2DModel

UNet2DModel模型实战

UNet2DModel模型比之前介绍的BasicUNet模型有一些改进,具体如下:

  • 退化过程的处理方式不同,UNet2DModel通过调节时间步来调节噪声量,t作为一个额外参数被传入前向过程;
  • 训练目标不同,UNet2DModel旨在预测不带缩放系数的噪声(也就是单位正太分布的噪声)而不是”去噪“的图像;
  • UNet2DModel有更多的采样策略可供选择;

下面我们来看一下UNet2DModel的模型参数以及结构,代码如下:

model = UNet2DModel(    sample_size=28,            # 目标图像的分辨率    in_channels=1,         # 输入图像的通道数,RGB图像的通道数为3    out_channels=1,        # 输出图像的通道数    layers_per_block=2,    # 设置要在每一个UNet块中使用多少个ResNet层    block_out_channels=(32, 64, 64), # 与BasicUNet模型的配置基本相同    down_block_types=(         "DownBlock2D",      # 标准的ResNet下采样模块        "AttnDownBlock2D",  # 带有空域维度self-att的ResNet下采样模块        "AttnDownBlock2D",    ),     up_block_types=(        "AttnUpBlock2D",         "AttnUpBlock2D",    # 带有空域维度self-att的ResNet上采样模块        "UpBlock2D",        # 标准的ResNet上采样模块       ),) # 输出模型结构(看起来虽然冗长,但非常清晰)print(model)

我们继续来查看一下UNet2DModel模型的参数量,代码如下:

sum([p.numel() for p in model.parameters()]) # UNet2DModel模型使用了大约170万个参数,BasicUNet模型则使用了30多万个参数
# 输出1707009

       下面是我们使用UNet2DModel代替BasicUNet模型,重复前面展示的训练以及采样过程(这里t=0,以表明模型是在没有时间步的情况下训练的),完整的代码如下:

#@markdown Trying UNet2DModel instead of BasicUNet:# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = UNet2DModel(    sample_size=28,  # the target image resolution    in_channels=1,  # the number of input channels, 3 for RGB images    out_channels=1,  # the number of output channels    layers_per_block=2,  # how many ResNet layers to use per UNet block    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example    down_block_types=(        "DownBlock2D",  # a regular ResNet downsampling block        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention        "AttnDownBlock2D",    ),    up_block_types=(        "AttnUpBlock2D",        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention        "UpBlock2D",   # a regular ResNet upsampling block      ),) #<<<net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3)# Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x, 0).sample #<<< Using timestep 0 always, adding .sample        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# Plot losses and some samplesfig, axs = plt.subplots(1, 2, figsize=(12, 5))# Lossesaxs[0].plot(losses)axs[0].set_ylim(0, 0.1)axs[0].set_title('Loss over time')# Samplesn_steps = 40x = torch.rand(64, 1, 28, 28).to(device)for i in range(n_steps):  noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps)) # Starting high going low  with torch.no_grad():    pred = net(x, 0).sample  mix_factor = 1/(n_steps - i)  x = x*(1-mix_factor) + pred*mix_factoraxs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Generated Samples');
# 输出Finished epoch 0. Average loss for this epoch: 0.020033Finished epoch 1. Average loss for this epoch: 0.013243Finished epoch 2. Average loss for this epoch: 0.011795

可以看出,比BasicUNet网络生成的结果要好一些。

DDPM原理

论文名称:《Denoising Diffusion Probabilistic Models》

论文地址:https://arxiv.org/pdf/2006.11239.pdf

      下面是DDPM论文中的公式,Training步骤其实是退化过程,给原始图像逐渐添加噪声的过程,预测目标是拟合每个时间步的采样噪声。

       还有一点非常重要:我们都知道在前向过程中是不断添加噪声的,其实这个噪声的系数不是固定的,而是与时间t线性增加的(也成为扩散率),这样的好处是在后向过程开始过程先把"明显"的噪声给去除,对应着较大的扩散率;当去到一定程度,逐渐逼近真实真实图像的时候,去噪速率逐渐减慢,开始微调,也就是对应着较小的扩散率。

下面我们使用代码来看一下输入数据与噪声在不同迭代周期的变化:

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${ \sqrt{\bar{\alpha}_t}}$")plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5,  label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")plt.legend(fontsize="x-large");

生成的结果,如下图所示:

       下面我们来看一下,噪声系数不变与DDPM中的噪声方式在MNIST数据集上的加噪效果:

# 可视化:DDPM加噪过程中的不同时间步# 对一批图片加噪,看看效果fig, axs = plt.subplots(3, 1, figsize=(16, 10))xb, yb = next(iter(train_dataloader))xb = xb.to(device)[:8]xb = xb * 2. - 1. # 映射到(-1,1)print('X shape', xb.shape) # 展示干净的原始输入axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().    cpu(), cmap='Greys')axs[0].set_title('Clean X') # 使用调度器加噪timesteps = torch.linspace(0, 999, 8).long().to(device)noise = torch.randn_like(xb) # <<注意是使用randn而不是randnoisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)print('Noisy X shape', noisy_xb.shape) # 展示“带噪”版本(使用或不使用截断函数clipping)axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap='Greys')axs[1].set_title('Noisy X (clipped to (-1, 1))')axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].   detach().cpu(), cmap='Greys')axs[2].set_title('Noisy X');X shape torch.Size([8, 1, 28, 28])Noisy X shape torch.Size([8, 1, 28, 28])

结果如下图所示:

采样补充

       采样在扩散模型中扮演非常重要的角色,我们可以输入纯噪声,然后期待模型能一步输出不带噪声的图像吗?根据前面的所学内容,这显然行不通。那么针对采样会有哪些改进的思路呢?

  • 可以使用模型多预测几次,以通过估计一个更高阶的梯度来更新得到更准确的结果(更高阶的方法和一些离散的ODE处理器);
  • 保留一些历史的预测值来尝试指导当前步的更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/90265.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

idea 对JavaScript进行debug调试

文章目录 1.新增 JavaScript Debug 配置2.配置访问地址3.访问url. 打断点测试 前言 : 工作中接手别人的前端代码没有注释&#xff0c;看浏览器的network或者console切来切去&#xff0c;很麻烦&#xff0c;可以试试idea自带的javscript debug功能。 1.新增 JavaScript Debug 配…

【网络安全】防火墙知识点全面图解(三)

本系列文章包含&#xff1a; 【网络安全】防火墙知识点全面图解&#xff08;一&#xff09;【网络安全】防火墙知识点全面图解&#xff08;二&#xff09;【网络安全】防火墙知识点全面图解&#xff08;三&#xff09; 防火墙知识点全面图解&#xff08;三&#xff09; 39、什…

基于Spring Boot的智慧团支部建设网站的设计与实现(Java+spring boot+MySQL)

获取源码或者论文请私信博主 演示视频&#xff1a; 基于Spring Boot的智慧团支部建设网站的设计与实现&#xff08;Javaspring bootMySQL&#xff09; 使用技术&#xff1a; 前端&#xff1a;html css javascript jQuery ajax thymeleaf 微信小程序 后端&#xff1a;Java sp…

ChatGPT⼊门到精通(1):ChatGPT 是什么

⼀、直观感受 1、公司 OpenAI&#xff08;美国&#xff09; 2、官⽅⽹站 3、登录ChatGPT ![在这里插入图片描述](https://img-blog.csdnimg.cn/26901096553a4ba0a5c88c49b2601e6a.png 填⼊帐号、密码&#xff0c;点击登录。登录成功&#xff0c;如下 3、和ChatGPT对话 开始…

【论文笔记】Planning and Decision-Making for Autonomous Vehicles

文章目录 Summary1. INTRODUCTION2. MOTION PLANNING AND CONTROL2.1. Vehicle Dynamics and Control2.2. Parallel Autonomy2.3. Motion Planning for Autonomous Vehicles 3. INTEGRATED PERCEPTION AND PLANNING3.1. From Classical Perception to Current Challenges in Ne…

APP上线为什么要提前部署安全产品呢?

一般平台刚上线或者日活跃量比较高的时候&#xff0c;很容易成为攻击者的目标&#xff0c;服务器如果遭遇黑客攻击&#xff0c;资源耗尽会导致平台无法访问&#xff0c;业务也无法正常开展&#xff0c;服务器一旦触发黑洞机制&#xff0c;就会被拉进黑洞很长一段时间&#xff0…

激活函数总结(十九):激活函数补充(ISRU、ISRLU)

激活函数总结&#xff08;十九&#xff09;&#xff1a;激活函数补充 1 引言2 激活函数2.1 Inverse Square Root Unit &#xff08;ISRU&#xff09;激活函数2.2 Inverse Square Root Linear Unit (ISRLU)激活函数 3. 总结 1 引言 在前面的文章中已经介绍了介绍了一系列激活函…

Midjourney 完整版教程(从账号注册到设计应用)

目录 一、Midjourney 介绍 二、Midjourney 的AI出图示例 三、手把手教你上手Midjourney 1、账号&初始化 1.1 账号注册登录 1.2 账号付费 1.3 账号初始化 2、Midjourney的基础设置 3、Midjourney 出图步骤。 (一)直接描述出图 (二)垫图生图。 4、Midjourney的…

Android kotlin 跳转手机热点开关页面和判断热点是否打开

Android kotlin 跳转手机热点开关页面和判断热点是否打开 判断热点是否打开跳转手机热点开关页面顺带介绍一些其他常用的设置页面跳转 其他热点的一些相关知识Local-only hotspot 参考 判断热点是否打开 网上方法比较多&#xff0c;我这边使用了通过WifiManager 拿反射的getWi…

咸虾米之一些快捷方式的操作,一行方块的左右滑动,方块在一区域内的任意移动

由于本着只学习微信小程序的目的&#xff0c;上面的几篇博文都是跟着黑马程序的课程走的&#xff01;后面的就讲解uni-app的实验呢&#xff01;一个人的精力是有限的&#xff0c;于是换了们课程继续深造微信小程序&#xff01;&#xff01;&#xff01; 以下是在 .wxml中的一些…

More Effective C++学习笔记(4)

目录 条款16&#xff1a;谨记 80 - 20 法则条款17&#xff1a;考虑使用lazy evaluation&#xff08;缓式评估&#xff09;条款18&#xff1a;分期摊还预期的计算成本条款19&#xff1a;了解临时对象来源条款20&#xff1a;协助完成 “ 返回值优化 ”条款21&#xff1a;利用重载…

【C++心愿便利店】No.2---函数重载、引用

文章目录 前言&#x1f31f;一、函数重载&#x1f30f;1.1.函数重载概念&#x1f30f;1.2.C支持函数重载的原理 -- 名字修饰 &#x1f31f;二、引用&#x1f30f;2.1.引用的概念&#x1f30f;2.2.引用特性&#x1f30f;2.3.常引用&#x1f30f;2.4.使用场景&#x1f30f;2.5.传…

[当前就业]2023年8月25日-计算机视觉就业现状分析

计算机视觉就业现状分析 前言&#xff1a;超越YOLO&#xff1a;计算机视觉市场蓬勃发展 如今&#xff0c;YOLO&#xff08;You Only Look Once&#xff09;新版本的发布周期很快&#xff0c;每次迭代的性能都优于其前身。每 3 到 4 个月就会推出一个升级版 YOLO 变体&#xf…

腾讯云服务器价格表大全_轻量服务器_CVM云服务器报价明细

腾讯云服务器租用费用表&#xff1a;轻量应用服务器2核2G4M带宽112元一年&#xff0c;540元三年、2核4G5M带宽218元一年&#xff0c;2核4G5M带宽756元三年、云服务器CVM S5实例2核2G配置280.8元一年、GPU服务器GN10Xp实例145元7天&#xff0c;腾讯云服务器网长期更新腾讯云轻量…

【线程池】线程池拒绝策略还有这个大坑(二)

目录 踩坑代码 后果展示 原因 小结 概要 上文我们聊了聊阻塞队列&#xff0c;有需要的小伙伴可以去瞅瞅【线程池】换个姿势来看线程池中不一样的阻塞队列&#xff08;一&#xff09;_走了一些弯路的博客-CSDN博客 这波我们一起来研究下线程池的拒绝策略。 你肯定要说了&a…

【ArcGIS微课1000例】0071:普通最小二乘法 (OLS)回归分析案例

严重声明:本文来自专栏《ArcGIS微课1000例:从点滴到精通》,为CSDN博客专家刘一哥GIS原创,原文及专栏地址为:(https://blog.csdn.net/lucky51222/category_11121281.html),谢绝转载或爬取!!! 文章目录 一、空间自回归模型二、ArcGIS普通最小二乘法回归(OLS)一、空间自…

拒绝摆烂!C语言练习打卡第六天

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;每日一练 &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、选择题 &#x1f4dd;1.第一题 &#x1f4dd;2.第二题 &#x1f4d…

SpringCloud 教程 | 第一篇: 服务的注册与发现(Eureka)

一、spring cloud简介 spring cloud 为开发人员提供了快速构建分布式系统的一些工具&#xff0c;包括配置管理、服务发现、断路器、路由、微代理、事件总线、全局锁、决策竞选、分布式会话等等。它运行环境简单&#xff0c;可以在开发人员的电脑上跑。另外说明spring cloud是基…

node_modules.cache是什么东西

一开始没明白这是啥玩意&#xff0c;还以为是npm的属性&#xff0c;网上也没说过具体的来源出处 .cache文件的产生是由webpack4的插件cache-loader生成的&#xff0c;node_modules里下载了cache-loader插件&#xff0c;很多朋友都是vuecli工具生成的项目&#xff0c;内置了这部…

全流程R语言Meta分析核心技术

​Meta分析是针对某一科研问题&#xff0c;根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法&#xff0c;对来源不同的研究成果进行收集、合并及定量统计分析的方法&#xff0c;最早出现于“循证医学”&#xff0c;现已广泛应用于农林生态&#xff0c;资源环境等方面…