扩散模型从原理到实战 入门

diffusion-models-class-CN/unit1/README_CN.md at main · darcula1993/diffusion-models-class-CN · GitHub

你可以使用命令行来通过此令牌登录 (huggingface-cli login) 或者运行以下单元来登录:

from huggingface_hub import notebook_login

notebook_login()

 https://huggingface.co/settings/tokens

在上面的网站中

获取令牌登录

 直接调包的例子

导入将要使用的库,并定义一些方便函数,稍后将会在 Notebook 中使用这些函数:

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image


def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im


def make_grid(images, size=64):
    """Given a list of PIL images, stack them together into a line for easy viewing"""
    output_im = Image.new("RGB", (size * len(images), size))
    for i, im in enumerate(images):
        output_im.paste(im.resize((size, size)), (i * size, 0))
    return output_im


# Mac users may need device = 'mps' (untested)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

加载管道

from diffusers import StableDiffusionPipeline

# Check out https://huggingface.co/sd-dreambooth-library for loads of models from the community
model_id = "sd-dreambooth-library/mr-potato-head"

# Load the pipeline
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
    device
)

管道加载完成后,我们可以使用以下命令生成图像:

# prompt = "an abstract oil painting of sks mr potato head by picasso"
prompt = "A dancing elephant"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image

 Diffusers 的核心 API 被分为三个主要部分:

  1. 管道: 从高层出发设计的多种类函数,旨在以易部署的方式,能够做到快速通过主流预训练好的扩散模型来生成样本。
  2. 模型: 训练新的扩散模型时用到的主流网络架构,e.g. UNet.
  3. 管理器 (or 调度器): 在 推理 中使用多种不同的技巧来从噪声中生成图像,同时也可以生成在 训练 中所需的带噪图像。

能够生成小蝴蝶图片的管道。

from diffusers import DDPMPipeline

# Load the butterfly pipeline
butterfly_pipeline = DDPMPipeline.from_pretrained(
    "johnowhitaker/ddpm-butterflies-32px"
# ).to(device)
).to("cuda")

# Create 8 images
images = butterfly_pipeline(batch_size=8).images

# View the result
make_grid(images)

下面这里会是最终的结果: 

自己构建管道

下载一个训练数据集

在这个例子中,我们会用到一个来自 Hugging Face Hub 的图像集。具体来说,是个 1000 张蝴蝶图像收藏集. 这是个非常小的数据集,我们这里也同时包含了已被注释的内容指向一些规模更大的选择。如果你想使用你自己的图像收藏,你也可以使用这里被注释掉的示例代码,从一个指定的文件夹来装载图片。

# %pip install -qq -U datasets
import torchvision
from datasets import load_dataset
from torchvision import transforms
import torch

dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")

# Or load images from a local folder
# dataset = load_dataset("imagefolder", data_dir="path/to/folder")

# We'll train on 32-pixel square images, but you can try larger sizes too
image_size = 32
# You can lower your batch size if you're running out of GPU memory
batch_size = 64

# Define data augmentations
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


dataset.set_transform(transform)

# Create a dataloader from the dataset to serve up the transformed images in batches
train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)

我们可以从中取出一批图像数据来看一看他们是什么样子:

xb = next(iter(train_dataloader))["images"].to("cuda")[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

定义管理器

我们的训练计划是,取出这些输入图片然后对它们增添噪声,在这之后把带噪的图片送入模型。在推理阶段,我们将用模型的预测值来不断迭代去除这些噪点。在diffusers中,这两个步骤都是由 管理器(调度器) 来处理的。

噪声管理器决定在不同的迭代周期时分别加入多少噪声。我们可以这样创建一个管理器,是取自于训练并能取样 'DDPM' 的默认配置。 (基于此篇论文 "Denoising Diffusion Probabalistic Models":

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

DDPM 论文这样来描述一个损坏过程,为每一个 ' 迭代周期 '(timestep) 增添一点少量的噪声。设在某个迭代周期有 xt−1, 我们可以得到它的下一个版本 xt (比之前更多一点点噪声):

这就是说,我们取 xt−1, 给他一个 的系数,然后加上带有 βt 系数的噪声。 这里 β 是根据一些管理器来为每一个 t 设定的,来决定每一个迭代周期中添加多少噪声。 现在,我们不想把这个推演进行 500 次来得到 x500,所以我们用另一个公式来根据给出的 x0 计算得到任意 t 时刻的 xt:

数学符号看起来总是很可怕!好在有管理器来为我们完成这些运算。我们可以画出 α¯t−−√ (标记为sqrt_alpha_prod) 和 (1−α¯t)−−−−−−−√ (标记为sqrt_one_minus_alpha_prod) 来看一下输入 (x) 与噪声是如何在不同迭代周期中量化和叠加的:

plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

噪声越来越大

不论你选择了哪一个管理器 (调度器),我们现在都可以使用noise_scheduler.add_noise功能来添加不同程度的噪声,就像这样:

timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

定义模型

大多数扩散模型使用的模型结构都是一些 [U-net] 的变形 (https://arxiv.org/abs/1505.04597) 也是我们在这里会用到的结构。

概括来说:

  • 输入模型中的图片经过几个由 ResNetLayer 构成的层,其中每层都使图片尺寸减半。
  • 之后在经过同样数量的层把图片升采样。
  • 其中还有对特征在相同位置的上、下采样层残差连接模块。

模型一个关键特征既是,输出图片尺寸与输入图片相同,这正是我们这里需要的。

Diffusers 为我们提供了一个易用的UNet2DModel类,用来在 PyTorch 创建所需要的结构。

我们来使用 U-net 为我们生成目标大小的图片吧。 注意这里down_block_types对应下采样模块 (上图中绿色部分), 而up_block_types对应上采样模块 (上图中红色部分):

from diffusers import UNet2DModel

# Create a model
model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)
model.to(device);

当在处理更高分辨率的输入时,你可能想用更多层的下、上采样模块,让注意力层只聚焦在最低分辨率(最底)层来减少内存消耗。我们在之后会讨论该如何实验来找到最适用与你手头场景的配置方法。

我们可以通过输入一批数据和随机的迭代周期数来看输出是否与输入尺寸相同:

with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape

创建训练循环

终于可以训练了!下面这是 PyTorch 中经典的优化迭代循环,在这里一批一批的送入数据然后通过优化器来一步步更新模型参数 - 在这个样例中我们使用学习率为 0.0004 的 AdamW 优化器。

对于每一批的数据,我们要

  • 随机取样几个迭代周期
  • 根据预设为数据加入噪声
  • 把带噪数据送入模型
  • 使用 MSE 作为损失函数来比较目标结果与模型预测结果(在这里是加入噪声的场景)
  • 通过loss.backward ()optimizer.step ()来更新模型参数

在这个过程中我们记录 Loss 值用来后续的绘图。

NB: 这段代码大概需 10 分钟来运行 - 你也可以跳过以下两块操作直接使用预训练好的模型。供你选择,你可以探索下通过缩小模型层中的通道数会对运行速度有多少提升。

# Set the noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward(loss)
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

绘制 loss 曲线,我们能看到模型在一开始快速的收敛,接下来以一个较慢的速度持续优化(我们用右边 log 坐标轴的视图可以看的更清楚):

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

生成图像

方法 1:建立一个管道:
from diffusers import DDPMPipeline

image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline_output = image_pipe()
pipeline_output.images[0]

我们可以在本地文件夹这样保存一个管道:

image_pipe.save_pretrained("my_pipeline")

检查文件夹的内容:

!ls my_pipeline/

方法 2:写一个取样循环

从随机噪声开始,遍历管理器的迭代周期来看从最嘈杂直到最微小的噪声变化,基于模型的预测一步步减少一些噪声:

# Random starting point (8 random images):
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # Get model pred
    with torch.no_grad():
        residual = model(sample, t).sample

    # Update sample with step
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)

 把你的模型 Push 到 Hub

在上面的例子中我们把管道保存在了本地。把模型 push 到 hub 上,我们会需要建立模型和相应文件的仓库名。我们根据你的选择(模型 ID)来决定仓库的名字(大胆的去替换掉model_name吧;需要包含你的用户名,get_full_repo_name ()会帮你做到):

from huggingface_hub import get_full_repo_name

model_name = "zhzhzhzhzhz-butterflies-32"
hub_model_id = get_full_repo_name(model_name)
hub_model_id

然后,在 🤗 Hub 上创建模型仓库并 push 它吧:

from huggingface_hub import HfApi, create_repo

create_repo(hub_model_id)
api = HfApi()
api.upload_folder(
    folder_path="my_pipeline/scheduler", path_in_repo="", repo_id=hub_model_id
)
api.upload_folder(folder_path="my_pipeline/unet", path_in_repo="", repo_id=hub_model_id)
api.upload_file(
    path_or_fileobj="my_pipeline/model_index.json",
    path_in_repo="model_index.json",
    repo_id=hub_model_id,
)

最后一件事是创建一个超棒的模型卡,如此,我们的蝴蝶生成器可以轻松的在 Hub 上被找到(请在描述中随意发挥!):

from huggingface_hub import ModelCard

content = f"""
---
license: mit
tags:
- pytorch
- diffusers
- unconditional-image-generation
- diffusion-models-class
---

# Model Card for Unit 1 of the [Diffusion Models Class 🧨](https://github.com/huggingface/diffusion-models-class)

This model is a diffusion model for unconditional image generation of cute 🦋.

## Usage

```python
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained('{hub_model_id}')
image = pipeline().images[0]
image
```
"""

card = ModelCard(content)
card.push_to_hub(hub_model_id)

现在模型已经在 Hub 上了,你可以这样从任何地方使用DDPMPipelinefrom_pretrained ()方法来下来它:

from diffusers import DDPMPipeline

hub_model_id = "https://huggingface.co/zhzhzhzhzhz/zhzhzhzhzhz-butterflies-32"

image_pipe = DDPMPipeline.from_pretrained(hub_model_id)
pipeline_output = image_pipe()
pipeline_output.images[0]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/921819.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

企业信息化-走进身份管理之搭建篇

​一、身份管理是什么 我们先要弄懂统一身份管理到底是什么? 统一身份管理(Unified Identity Manager,UIM),身份管理(Identity Management,简称IDM),也被称为IAM&#…

周期法频率计的设计

目录 周期法频率计 分析: 设计过程: 周期法频率计 对于低频信号,应用周期法进行测频。周期法测频的基本原理是:应用标准频率信号统计被测信号两个相邻脉冲之间的脉冲数,然后通过脉冲数计算出被测信号的周期&#xff…

C语言--分支循环编程题目

第一道题目&#xff1a; #include <stdio.h>int main() {//分析&#xff1a;//1.连续读取int a 0;int b 0;int c 0;while (scanf("%d %d %d\n", &a, &b, &c) ! EOF){//2.对三角形的判断//a b c 等边三角形 其中两个相等 等腰三角形 其余情…

MySQL Join 的原理与优化实践

文章目录 引言一、基础准备&#xff1a;创建环境与示例数据1. 初始化示例表2. 示例 Join 查询3. EXPLAIN 输出分析 二、MySQL Join 的核心算法与执行机制1. 三种 Join 算法的实现与原理1.1 Index Nested-Loop Join&#xff08;INLJ&#xff09;1.2 Simple Nested-Loop Join&…

关于安卓模拟器或手机设置了BurpSuite代理和安装证书后仍然抓取不到APP数据包的解决办法

免责申明 本文仅是用于学习研究安卓系统设置代理后抓取不到App数据包实验,请勿用在非法途径上,若将其用于非法目的,所造成的一切后果由您自行承担,产生的一切风险和后果与笔者无关;本文开始前请认真详细学习《‌中华人民共和国网络安全法》【学法时习之丨网络安全在身边一…

飞凌嵌入式旗下教育品牌ElfBoard与西安科技大学共建「科教融合基地」

近日&#xff0c;飞凌嵌入式与西安科技大学共同举办了“科教融合基地”签约揭牌仪式。此次合作旨在深化嵌入式创新人才的培育&#xff0c;加速科技成果的转化应用&#xff0c;标志着双方共同开启了一段校企合作的新篇章。 出席本次签约揭牌仪式的有飞凌嵌入式梁总、高总等一行…

下载安装Android Studio

&#xff08;一&#xff09;Android Studio下载地址 https://developer.android.google.cn/studio 滑动到 点击下载文档 打开新网页 切换到english ![](https://i-blog.csdnimg.cn/direct/b7052b434f9d4418b9d56c66cdd59fae.png 等待一会&#xff0c;出现 点同意后&#xff0…

准备阶段 Profiler性能分析工具的使用(一)

Unity 性能分析器 (Unity Profiler) 性能分析器记录应用程序性能的多个方面并显示相关信息。使用此信息可以做出有关应用程序中可能需要优化的事项的明智决策&#xff0c;并确认所做的优化是否产生预期结果。 默认情况下&#xff0c;性能分析器记录并保留游戏的最后 300 帧&a…

01Web3.0行业

目录 一、什么是Web 3.0? 二、Web 1.0 vs Web 2.0 vs Web 3.0 三、为什么选择Web 3.0 四、从法律角度观察Web 3.0 1. Web 3.0前时代的数字身份 问题1&#xff1a;个人信息的过度收集 问题2&#xff1a;个人信息的泄露和滥用 2. Web 3.0的解决方案及其法律问题 问题一&…

archlinux安装waydroid

目录 参考资料 注意 第一步切换wayland 第二步安装binder核心模组 注意 开始安装 AUR安裝Waydroid 启动waydroid 设置网络&#xff08;正常的可以不看&#xff09; 注册谷歌设备 安装Arm转译器 重启即可 其他 参考资料 https://ivonblog.com/posts/archlinux-way…

互联网时代的隐私保护

在这个数字化时代&#xff0c;我们的生活与互联网密不可分。打开手机刷刷朋友圈&#xff0c;浏览一下购物网站&#xff0c;约个网约车&#xff0c;点个外卖&#xff0c;这些看似平常的行为都在默默产生着数据足迹。可就在这不经意间&#xff0c;我们的个人信息正在被收集、分析…

python之使用django框架开发web项目

本问将对django框架在python的web项目中的使用进行介绍,有不对之处,烦请指正。 首先使用创建一个django工程(本示例中使用pycharm2024+python3.12),名称和项目保存路径根据自己的需要自行修改,新手直接默认本机环境就好(关于conda将会另开一篇进行讲解。),最后点击cre…

基于YOLOv8深度学习的扰乱公共秩序打架异常行为检测系统研究与实现(PyQt5界面+数据集+训练代码)

随着智能监控技术和人工智能的发展&#xff0c;基于深度学习的行为检测技术在公共安全和防范领域中发挥着越来越重要的作用。传统的监控系统通常依赖于人工监控&#xff0c;这不仅耗费大量的人力和时间&#xff0c;且容易因为人的疲劳或疏忽而漏检关键的异常行为。而近年来&…

gocv调用opencv添加中文乱码的解决方案

前言 相信很多做视觉的同学在使用opencv给图片添加中文文字的时候会出现这样的乱码显示: 而实际上你期望的是“告警时间:2011-11-11 11:11:11 告警类型:脱岗检测告警 Area:XXXXX Camera:Camera001-001”这样的显示内容,那么这篇文章我将用很简单的方法来解决乱码问题,只需…

JavaScript中的this指向问题

JavaScript中的this指向问题 1.1 为什么需要this? 为什么需要this? 在常见的编程语言中&#xff0c;几乎都有this这个关键字&#xff08;Objective-C中使用的是self),但是在JavaScript中的this和常见的面向对象语言中的this不太一样 常见面向对象的编程语言中&#xff0c;比…

预测气动阻尼

TLDR&#xff1a;通过结合 ANSYS Mechanical 和 ANSYS CFX&#xff0c;可以通过模拟预测气动阻尼。此方法可用于涡轮叶片、飞机机翼或 MEMS 微镜&#xff01; MEMS 系统的频率响应。峰值的高度取决于阻尼……那么阻尼比是多少&#xff1f; 多年来&#xff0c;很多人问我“嘿&am…

在 CentOS 系统上直接安装 MongoDB 4.0.25

文章目录 步骤 1&#xff1a;配置 MongoDB 官方源步骤 2&#xff1a;安装 MongoDB步骤 3&#xff1a;启动 MongoDB 服务步骤 4&#xff1a;验证安装步骤 5&#xff1a;可选配置注意事项 以下是在 CentOS 系统上直接安装 MongoDB 4.0.25 的详细步骤&#xff1a; 步骤 1&#x…

.NET9 - 新功能体验(一)

被微软形容为“迄今为止最高效、最现代、最安全、最智能、性能最高的.NET版本”——.NET 9已经发布有一周了&#xff0c;今天想和大家一起体验一下新功能。 此次.NET 9在性能、安全性和功能等方面进行了大量改进&#xff0c;包含了数千项的修改&#xff0c;今天主要和大家一起体…

乐理的学习(调式)

大致了解乐理之后的总结 跟着西蒙哥也是把基础乐理差不多能有一个大致的总结框架了&#xff0c;主要还是为了弹钢琴&#xff0c;也是知道了很多的规则都是为了人们的感受服务的 对手指要了解 对于手指的弹音局限 各个手指的使用频率 不同年龄的不同的人对手指的使用存在差…

08 —— Webpack打包图片

【资源模块 | webpack 中文文档 | webpack中文文档 | webpack中文网】https://www.webpackjs.com/guides/asset-modules/?sid_for_share99125_3 Webpack打包图片以8KB为临界值判断 大于8KB的文件&#xff1a;发送一个单独的文件并导出URL地址 小于8KB的文件&#xff1a;导出一…