【扩散模型】实战:创建一个类别条件扩散模型

创建一个类别条件扩散模型

  • 1. 配置和数据准备
  • 2. 创建一个以类别为条件的UNet模型
  • 3. 训练和采样

本文介绍一种给扩散模型添加额外条件信息的方法。具体地,将在MNIST数据集上训练一个以类别为条件的扩散模型。并且可以在推理阶段指定想要生成的是哪个数字。

1. 配置和数据准备

首先安装diffusers库:

!pip install -q diffusers

导入相关依赖包:
导入依赖包
加载MNIST数据集:

# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(
    root="./mnist/", 
    train=True, 
    download=True, 
    transform=torchvision.transforms.ToTensor()
    )
# 创建数据加载器
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# 查看MNIST数据集中的部分样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

部分样本

2. 创建一个以类别为条件的UNet模型

输入类别这一条件的流程:
(1)创建一个标准的UNet2DModel加入一些额外的输入通道
(2)通过一个嵌入层,把类别标签映射到一个长度为class_emb_size的特征向量上。
(3)把这个信息作为额外通道和原有的输入向量拼接起来。

net_input = torch.cat((x, class_cond), 1)

(4)将net_input(其中包含class_emb_size + 1个通道)输入UNet模型,得到最终的预测结果。

这里,class_emb_size被设置成4,但它其实是可以进行任意修改的,或者把需要学到的nn.Embedding替换成简单地对类别进行one-hot编码,代码如下:

class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()
    # 这个网络层会把数字所属的类别映射到一个长度为class_emb_size的特征向量上
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # self.model是一个不带生成条件的UNet模型,这里,给他添加了额外的输入通道,用于接收条件信息
    self.model = UNet2DModel(
        sample_size = 28, # 所生成图片的尺寸
        in_channels = 1 + class_emb_size, # 加入额外的输入通道
        out_channels = 1, # 输出结果的通道数
        layers_per_block=2, # 残差层个数
        block_out_channels=(32, 64, 64),
        down_block_types=(
            "DownBlock2D", # 常规的ResNet下采样模块
            "AttnDownBlock2D", # 含有spatial self-attention的ResNet下采样模块
            "AttnDownBlock2D", 
        ),
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D", # 含有spatil self-attention的ResNet上采样模块
            "UpBlock2D",  # 上采样模块
        ),
    )
  
  # 此时扩散模型的前向计算就会含有额外的类别标签作为输入了
  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape
    # 类别条件将会以额外通道的形式输入
    class_cond = self.class_emb(class_labels)  # 将类别映射为向量形式,
    # 并扩展成类似于(bs, 4, 28, 28)的张量形式
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # 将原始输入和类别条件信息拼接到一起
    net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
    # 使用模型进行预测
    return self.model(net_input, t).sample # (bs, 1, 28, 28)

3. 训练和采样

这里使用prediction = unet(x, t, y)在训练时把正确的标签作为第三个输入发送给模型。如果一切正常,模型将会输出与之相匹配的图片。y在这里的范围是0~9.

# 创建一个调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

# 定义数据加载器
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
loss_fn = nn.MSELoss()
net = ClassConditionedUnet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 
losses = []

# 训练开始
for epoch in range(n_epochs):
  for x, y in tqdm(train_dataloader):
    # 获取数据并添加噪声
    x = x.to(device) * 2 - 1 # 数据被归一化到区间(-1, 1)
    y = y.to(device)
    noise = torch.randn_like(x)
    timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
    noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

    # 预测
    pred = net(noisy_x, timesteps, y)  # 注意这里输入了类别信息
    # 计算损失值
    loss = loss_fn(pred, noise) 
    opt.zero_grad()
    loss.backward()
    opt.step()

    losses.append(loss.item())
  
  avg_loss = sum(losses[-100:])/100
  print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')
plt.plot(losses)

Finished epoch 0. Average of the last 100 loss values: 0.053393
Finished epoch 1. Average of the last 100 loss values: 0.047172
Finished epoch 2. Average of the last 100 loss values: 0.045227
Finished epoch 3. Average of the last 100 loss values: 0.043402
Finished epoch 4. Average of the last 100 loss values: 0.041524
Finished epoch 5. Average of the last 100 loss values: 0.040847
Finished epoch 6. Average of the last 100 loss values: 0.040252
Finished epoch 7. Average of the last 100 loss values: 0.040134
Finished epoch 8. Average of the last 100 loss values: 0.038976
Finished epoch 9. Average of the last 100 loss values: 0.039234
损失曲线:
loss
训练结束后,可以通过输入不同的标签作为条件来采样图片:

# 准备一个随机噪声作为起点,并准备想要的图片标签
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)
print(y)
# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
  with torch.no_grad():
    residual = net(x, t, y)
  x = noise_scheduler.step(residual, t, x).prev_sample

# 显示结果
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')

这里,我们的y标签为:

tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
        6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
        9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0'

因此对应生成的图片为:
生长指定标签的图片
至此,已经实现了对输出图片的控制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/130637.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Navicat 解放双手的自动运行功能

Navicat 的自动运行功能可以帮助用户自动化重复性和周期性的任务,提高工作效率和数据安全性。用户可以根据自己的需求设置自动运行的任务和计划,以确保数据库管理和数据操作的顺利进行。为帮助用户更便捷、更直观地了解自动运行功能以及电子邮件通知的操…

烟草5G智慧工厂数字孪生可视化平台,赋能烟草工业数字化智慧转型

随着卷烟工厂提质增效需求增强,信息化建设推进及生产制造系统智能化改革发展,各生产单元逐步升级完善数字化,最终实现智能制造成为必然趋势。因此,5G卷烟加工工厂的数字化转型迫在眉睫。中国烟草制造行业正迈向全新的市场经济时代…

win11 idea 错误: 找不到或无法加载主类

买了新电脑win11系统,配置环境之后运行项目,始终运行不起来,一直报 刚开始以为是环境没装好,但是我创建其他项目运行时是正常的 纠结了好久突然发现,是不是因为项目路径中有中文造成的找不到编译文件 最后把项目改为…

汽车标定技术(九)--标定常量与#pragma的趣事

目录 1. 不添加#pragma语句 2. 添加#pragma语句 3. 标定量只给flash空间,不给ram指定空间 4. 总结 在之前不会使用overlay机制的时候,我们想要做汽车标定,标定常量编译出来的地址一般都应该是ram的地址,而且在链接文件中都会指…

HTML5学习系列之简单使用1

HTML5学习系列之简单使用1 前言基础显示学习定义网页标题定义网页元信息定义网页元信息定义文档结构div元素di和classtitlerole注释 总结 前言 下班加班期间的简单学习。 基础显示学习 定义网页标题 <html lang"en"> <head> <title>从今天开始努…

内存缓存系统

胤凯 (oyto.github.io)&#xff0c;欢迎到我的博客阅读。 今天我们围绕一个面试题来实现一个内存缓存系统。 面试题内容 1. 支持设置过期时间&#xff0c;精度到秒 2. 支持设置最大内存&#xff0c;当内存超出时做出合理的处理 3. 支持并发安全 4. 按照以下接口要求实现 typ…

【poi导出excel模板——通过建造者模式+策略模式+函数式接口实现】

poi导出excel模板——通过建造者模式策略模式函数式接口实现 poi导出excel示例优化思路代码实现补充建造者模式策略模式 poi导出excel示例 首先我们现看一下poi如何导出excel&#xff0c;这里举个例子&#xff1a;目前想要导出一个Map<sex,List>信息&#xff0c;sex作为…

使用Dockerfile依赖maven基础镜像部署springboot的程序案例

1、准备springboot Demo代码 就一个controller层代码&#xff0c;返回当前时间及hello world 2、项目根目录下&#xff0c;新建DockerFile文件 注意&#xff0c;等本地配置完毕后&#xff0c;Dockerfile文件需要与项目helloworld同级&#xff0c;这里先放项目里面 3、docker …

【MATLAB源码-第73期】基于matlab的OFDM-IM索引调制系统不同子载波数目误码率对比,对比OFDM系统。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 OFDM-IM索引调制技术是一种新型的无线通信技术&#xff0c;它将正交频分复用&#xff08;OFDM&#xff09;和索引调制&#xff08;IM&#xff09;相结合&#xff0c;以提高频谱效率和系统容量。OFDM-IM索引调制技术的基本思想…

Flink SQL自定义标量函数(Scalar Function)

使用场景&#xff1a; 标量函数即 UDF&#xff0c;⽤于进⼀条数据出⼀条数据的场景。 开发流程&#xff1a; 实现 org.apache.flink.table.functions.ScalarFunction 接⼝实现⼀个或者多个⾃定义的 eval 函数&#xff0c;名称必须叫做 eval&#xff0c;eval ⽅法签名必须是 p…

快速入门安装及使用git与svn的区别常用命令

一、导言 1、什么是svn&#xff1f; SVN是Subversion的简称&#xff0c;是一个集中式版本控制系统。与Git不同&#xff0c;SVN没有分布式的特性。在SVN中&#xff0c;项目的代码仓库位于服务器上&#xff0c;团队成员通过向服务器提交和获取代码来实现版本控制。SVN记录了每个…

Hbuilder打包项目为h5

Hbuilder打包项目为h5 manifest.json 配置 修改 web 配置下的 页面标题、路由模式、运行的基础路径 发行 H5 发行 填入网站标题和网站域名 编译 编译完成之后存放在 unpackage/dist/build/h5 目录下

Day26力扣打卡

打卡记录 搜索旋转排序数组&#xff08;二分&#xff09; 链接 class Solution {int findMin(vector<int> &nums) {int left -1, right nums.size() - 1; // 开区间 (-1, n-1)while (left 1 < right) { // 开区间不为空int mid left (right - left) / 2;if…

医学图像 ABIDE 等数据集 .nii.gz Python格式化显示

nii.gz 文件 .nii.gz 文件通常是医学影像数据的一种常见格式&#xff0c;比如神经影像&#xff08;如脑部MRI&#xff09;。这种文件格式通常是经过gzip压缩的NIfTI格式&#xff08;Neuroimaging Informatics Technology Initiative&#xff09;。 要在Python中查看.nii.gz文…

设备零部件更换ar远程指导系统加强培训效果

随着科技的发展&#xff0c;AR技术已经成为了一种广泛应用的新型技术。AR远程指导系统作为AR技术的一种应用&#xff0c;具有非常广泛的应用前景。 一、应用场景 气象监测AR教学软件适用于多个领域&#xff0c;包括气象、环境、地理等。在教学过程中&#xff0c;软件可以帮助学…

黑客(网络安全)技术——高效自学1.0

前言 前几天发布了一篇 网络安全&#xff08;黑客&#xff09;自学 没想到收到了许多人的私信想要学习网安黑客技术&#xff01;却不知道从哪里开始学起&#xff01;怎么学 今天给大家分享一下&#xff0c;很多人上来就说想学习黑客&#xff0c;但是连方向都没搞清楚就开始学习…

Paimon 与 Spark 的集成(一)

Paimon Apache Paimon (incubating) 是一项流式数据湖存储技术&#xff0c;可以为用户提供高吞吐、低延迟的数据摄入、流式订阅以及实时查询能力。Paimon 采用开放的数据格式和技术理念&#xff0c;可以与 ApacheFlink / Spark / Trino 等诸多业界主流计算引擎进行对接&#xf…

听GPT 讲Rust源代码--library/core/src(2)

题图来自 5 Ways Rust Programming Language Is Used[1] File: rust/library/core/src/iter/adapters/by_ref_sized.rs 在Rust的源代码中&#xff0c;rust/library/core/src/iter/adapters/by_ref_sized.rs 文件实现了 ByRefSized 适配器&#xff0c;该适配器用于创建一个可以以…

在Node.js中,什么是事件发射器(EventEmitter)?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

全新Inner-IoU损失函数!!!通过辅助边界框计算IoU有效提升检测效果

摘要 1 简介 2 方法 2.1 边界框回归模式分析 2.2 Inner-IoU 损失 3 实验 3.1 模拟实验 3.2 对比实验 3.2.1 PASCAL VOC上的YOLOv7 3.2.2 YOLOv5 在 AI-TOD 上 4. 参考 摘要 随着检测器的快速发展&#xff0c;边界框回归&#xff08;BBR&#xff09;损失函数不断进…