扩散模型实战(四):从零构建扩散模型

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

本文以MNIST数据集为例,从零构建扩散模型,具体会涉及到如下知识点:

  • 退化过程(向数据中添加噪声)
  • 构建一个简单的UNet模型
  • 训练扩散模型
  • 采样过程分析

下面介绍具体的实现过程:

一、环境配置&python包的导入
     

最好有GPU环境,比如公司的GPU集群或者Google Colab,下面是代码实现:

# 安装diffusers库!pip install -q diffusers# 导入所需要的包import torchimport torchvisionfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom diffusers import DDPMScheduler, UNet2DModelfrom matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')
# 输出Using device: cuda

此时会输出运行环境是GPU还是CPU

载MNIST数据集

       MNIST数据集是一个小数据集,存储的是0-9手写数字字体,每张图片都28X28的灰度图片,每个像素的取值范围是[0,1],下面加载该数据集,并展示部分数据:

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)x, y = next(iter(train_dataloader))print('Input shape:', x.shape)print('Labels:', y)plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# 输出Input shape: torch.Size([8, 1, 28, 28])Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

扩散模型的退化过程

       所谓退化过程,其实就是对输入数据加入噪声的过程,由于MNIST数据集的像素范围在[0,1],那么我们加入噪声也需要保持在相同的范围,这样我们可以很容易的把输入数据与噪声进行混合,代码如下:

def corrupt(x, amount):  """Corrupt the input `x` by mixing it with noise according to `amount`"""  noise = torch.rand_like(x)  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works  return x*(1-amount) + noise*amount

接下来,我们看一下逐步加噪的效果,代码如下:

# Plotting the input datafig, axs = plt.subplots(2, 1, figsize=(12, 5))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')# Adding noiseamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Plottinf the noised versionaxs[1].set_title('Corrupted data (-- amount increases -->)')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

       从上图可以看出,从左到右加入的噪声逐步增多,当噪声量接近1时,数据看起来像纯粹的随机噪声。

构建一个简单的UNet模型

       UNet模型与自编码器有异曲同工之妙,UNet最初是用于完成医学图像中分割任务的,网络结构如下所示:

代码如下:

class BasicUNet(nn.Module):    """A minimal UNet implementation."""    def __init__(self, in_channels=1, out_channels=1):        super().__init__()        self.down_layers = torch.nn.ModuleList([             nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),            nn.Conv2d(32, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 64, kernel_size=5, padding=2),        ])        self.up_layers = torch.nn.ModuleList([            nn.Conv2d(64, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 32, kernel_size=5, padding=2),            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),         ])        self.act = nn.SiLU() # The activation function        self.downscale = nn.MaxPool2d(2)        self.upscale = nn.Upsample(scale_factor=2)    def forward(self, x):        h = []        for i, l in enumerate(self.down_layers):            x = self.act(l(x)) # Through the layer and the activation function            if i < 2: # For all but the third (final) down layer:              h.append(x) # Storing output for skip connection              x = self.downscale(x) # Downscale ready for the next layer                      for i, l in enumerate(self.up_layers):            if i > 0: # For all except the first up layer              x = self.upscale(x) # Upscale              x += h.pop() # Fetching stored output (skip connection)            x = self.act(l(x)) # Through the layer and the activation function                    return x

我们来检验一下模型输入输出的shape变化是否符合预期,代码如下:

net = BasicUNet()x = torch.rand(8, 1, 28, 28)net(x).shape
# 输出torch.Size([8, 1, 28, 28])

再来看一下模型的参数量,代码如下:

sum([p.numel() for p in net.parameters()])
# 输出309057

至此,已经完成数据加载和UNet模型构建,当然UNet模型的结构可以有不同的设计。

、扩散模型训练

        扩散模型应该学习什么?其实有很多不同的目标,比如学习噪声,我们先以一个简单的例子开始,输入数据为带噪声的MNIST数据,扩散模型应该输出对应的最佳数字预测,因此学习的目标是预测值与真实值的MSE,训练代码如下:

# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = BasicUNet()net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3) # Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x)        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# View the loss curveplt.plot(losses)plt.ylim(0, 0.1);
# 输出Finished epoch 0. Average loss for this epoch: 0.024689Finished epoch 1. Average loss for this epoch: 0.019226Finished epoch 2. Average loss for this epoch: 0.017939

训练过程的loss曲线如下图所示:

六、扩散模型效果评估

我们选取一部分数据来评估一下模型的预测效果,代码如下:

#@markdown Visualizing model predictions on noisy inputs:# Fetch some datax, y = next(iter(train_dataloader))x = x[:8] # Only using the first 8 for easy plotting# Corrupt with a range of amountsamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Get the model predictionswith torch.no_grad():  preds = net(noised_x.to(device)).detach().cpu()# Plotfig, axs = plt.subplots(3, 1, figsize=(12, 7))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')axs[2].set_title('Network Predictions')axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

从上图可以看出,对于噪声量较低的输入,模型的预测效果是很不错的,当amount=1时,模型的输出接近整个数据集的均值,这正是扩散模型的工作原理。

Note:我们的训练并不太充分,读者可以尝试不同的超参数来优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/83072.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【论文阅读】CONAN:一种实用的、高精度、高效的APT实时检测系统(TDSC-2020)

CONAN&#xff1a;A Practical Real-Time APT Detection System With High Accuracy and Efficiency TDSC-2020 浙江大学 Xiong C, Zhu T, Dong W, et al. CONAN: A practical real-time APT detection system with high accuracy and efficiency[J]. IEEE Transactions on Dep…

Spring Boot 知识集锦之actuator监控端点详解

文章目录 0.前言1.参考文档2.基础介绍默认支持的端点 3.步骤3.1. 引入依赖3.2. 配置文件3.3. 核心源码 4.示例项目5.总结 0.前言 背景&#xff1a; 一直零散的使用着Spring Boot 的各种组件和特性&#xff0c;从未系统性的学习和总结&#xff0c;本次借着这个机会搞一波。共同学…

超越函数界限:探索JavaScript函数的无限可能

&#x1f3ac; 岸边的风&#xff1a;个人主页 &#x1f525; 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想&#xff0c;就是为了理想的生活 ! 目录 &#x1f4da; 前言 &#x1f4d8; 1. 函数的基本概念 &#x1f4df; 1.1 函数的定义和调用 &#x1f4df; 1.2 …

化繁为简,使用Hibernate Validator实现参数校验

前言 在之前的悦享校园的开发中使用了SSM框架&#xff0c;由于当时并没有使用参数参数校验工具&#xff0c;方法的入参判断使用了大量的if else语句&#xff0c;代码十分臃肿&#xff0c;因此最近在重构代码时&#xff0c;将框架改为SpringBoot后&#xff0c;引入了Hibernate V…

57从零开始学Java之一文详解String字符串的底层实现原理

作者&#xff1a;孙玉昌&#xff0c;昵称【一一哥】&#xff0c;另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 在之前的两篇文章中&#xff0c;壹哥给大家介绍了String字符串及其常用的API方法、常用编码、正则表达…

第一讲:BeanFactory和ApplicationContext接口

BeanFactory和ApplicationContext接口 1. 什么是BeanFactory?2. BeanFactory能做什么&#xff1f;3.ApplicationContext对比BeanFactory的额外功能?3.1 MessageSource3.2 ResourcePatternResolver3.3 EnvironmentCapable3.4 ApplicationEventPublisher 4.总结 1. 什么是BeanF…

GuLi商城-前端基础Vue指令-单向绑定双向绑定

什么是指令? 指令 (Directives) 是带有 v- 前缀的特殊特性。 指令特性的预期值是:单个 JavaScript 表达式。 指令的职责是&#xff0c;当表达式的值改变时&#xff0c;将其产生的连带影响&#xff0c;响应式地作用于DOM 例如我们在入门案例中的 v-on&#xff0c;代表绑定事…

vscode + python

序 参考链接&#xff1a; 【教程】VScode中配置Python运行环境_哔哩哔哩_bilibili Python部分 Python Releases for Windows | Python.org vscode部分 Visual Studio Code - Code Editing. Redefined 一路next&#xff0c;全部勾上&#xff1a; 就可以了&#xff1a; 安装插…

盛元广通高校实验室开放预约与综合管理系统LIMS

系统概述&#xff1a; 高校实验室涉及到的课程、老师、学生多&#xff0c;管理起来费时费力&#xff0c;盛元广通高校实验室开放预约与综合管理系统LIMS提供简单易用的账号管理、实验室管理、课程管理、实验项目管理、实验时间设定&#xff1b;为学生提供简单易用的自主实验选…

设计模式篇---抽象工厂(包含优化)

文章目录 概念结构实例优化 概念 抽象工厂&#xff1a;提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无须指定它们具体的类。 工厂方法是有一个类型的产品&#xff0c;也就是只有一个产品的抽象类或接口&#xff0c;而抽象工厂相对于工厂方法来说&#xff0c;是有…

ORB-SLAM2学习笔记7之System主类和多线程

文章目录 0 引言1 整体框架1.1 整体流程 2 System主类2.1 成员函数2.2 成员变量 3 多线程3.1 ORB-SLAM2中的多线程3.2 加锁 0 引言 ORB-SLAM2是一种基于特征的视觉SLAM&#xff08;Simultaneous Localization and Mapping&#xff09;系统&#xff0c;它能够从单个、双目或RBG…

8月16日上课内容 部署LVS-DR群集

本章结构&#xff1a; 数据包流向分析: 数据包流向分析&#xff1a; &#xff08;1&#xff09;客户端发送请求到 Director Server&#xff08;负载均衡器&#xff09;&#xff0c;请求的数据报文&#xff08;源 IP 是 CIP,目标 IP 是 VIP&#xff09;到达内核空间。 &#xf…

追踪工时和控制成本 如何选对工具

研发工作中的工时管理软件是一种用于追踪、记录和分析团队成员在项目中所花费的工作时间的工具。它有助于组织、监控和优化研发项目的进展&#xff0c;确保资源得到有效利用&#xff0c;项目按时完成&#xff0c;并提供数据支持用于决策制定和资源规划。 能够记录团队成员的工…

Visual Studio 2022 你必须知道的实用调试技巧

目录 1、什么是bug&#xff1f; 2.调试是什么&#xff1f;有多重要&#xff1f; 2.1我们是如何写代码的&#xff1f; 2.2又是如何排查出现的问题的呢&#xff1f; ​编辑 2.3 调试是什么&#xff1f; 2.4调试的基本步骤 2.5Debug和Release的介绍 3.Windows环境调试介绍…

[Vue]解决npm run dev报错node:internal/modules/cjs/loader:1031 throw err;

解决: 有2中方法&#xff0c;建议先尝试第一种&#xff0c;不行再第二种 第一种: 重新安装依赖环境 删除项目的node_modules文件夹&#xff0c;重新执行 # 安装依赖环境 npm install# 运行 npm run dev 我只用了第一种方法就可以了 &#xff0c;第二种方法从别的博主那看到…

QT中的按钮控件Buttons介绍

目录 Buttons 按钮控件 1、常用属性介绍 2、按钮介绍 2.1QPushButton 普通按钮 2.2QtoolButton 工具按钮 2.3Radio Button单选按钮 2.4CheckButton复选按钮 2.5Commam Link Button命令链接按钮 2.6Dialog Button Box命令链接按钮 Buttons 按钮控件 在Qt里&#xff0c;…

【JavaEE进阶】SpringMVC

文章目录 一. 简单认识SpringMVC1. 什么是SpringMVC?2. SpringMVC与MVC的关系 二. SpringMVC1. SpringMVC创建和连接2. SpringMVC的简单使用2.1 RequestMapping 注解介绍2.2 RequestMapping支持的请求类型2.3 GetMapping 和 PostMapping 3. 获取参数3.1 传递单个参数3.2 传递对…

【1-3章】Spark编程基础(Python版)

课程资源&#xff1a;&#xff08;林子雨&#xff09;Spark编程基础(Python版)_哔哩哔哩_bilibili 第1章 大数据技术概述&#xff08;8节&#xff09; 第三次信息化浪潮&#xff1a;以物联网、云计算、大数据为标志 &#xff08;一&#xff09;大数据 大数据时代到来的原因…

如何编写一个通用的函数?

&#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;推荐专栏1: &#x1f354;&#x1f35f;&#x1f32f;C语言初阶 &#x1f43b;推荐专栏2: &#x1f354;&#x1f35f;&#x1f32f;C语言进阶 &#x1f511;个人信条: &#x1f335;知行合一 金句分享:…

Java 项目日志实例:Log4j2

点击下方关注我&#xff0c;然后右上角点击...“设为星标”&#xff0c;就能第一时间收到更新推送啦~~~ Apache Log4j 2 是对 Log4j 的升级&#xff0c;与其前身 Log4j 1.x 相比有了显着的改进&#xff0c;并提供了许多 Logback 可用的改进&#xff0c;同时支持 JCL 以及 SLF4J…