diffusion model 简单demo

参考自:
Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
diffusion 简单demo
扩散模型之DDPM
Diffusion model 原理剖析
张振虎-扩散概率模型
生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼

核心公式和逻辑

在这里插入图片描述

核心公式:

在这里插入图片描述

训练阶段

在这里插入图片描述
实际上是根据加噪后的图 和时间步 t 去预测噪声
在这里插入图片描述

q_x 计算公式,后面会用到:
在这里插入图片描述

推理

在这里插入图片描述

代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve, make_swiss_roll
from PIL import Image
import torch
import io

# get data
# s_curve, _ = make_s_curve(10**4 , noise=0.1)
# s_curve = s_curve[:, [0, 2]] / 10.0

swiss_roll, _ = make_swiss_roll(10**4,noise=0.1)
s_curve = swiss_roll[:, [0, 2]]/10.0

print('shape of moons: ', np.shape(s_curve))

data = s_curve.T
fix, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolors='white', alpha=0.5)

ax.axis('off')

# plt.show()
plt.savefig('./s_curve.png')

dataset = torch.Tensor(s_curve).float()

# set params
num_steps = 100

betas = torch.linspace(-6, 6, num_steps)    # # 逐渐递增
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5    # β0,β1,...,βt

print('beta: ', betas)

alphas = 1 - betas
alphas_pro = torch.cumprod(alphas, 0)   # αt^ = αt的累乘

# αt^往右平移一位, 原第t步的值维第t-1步的值, 第0步补1
alphas_pro_p = torch.cat([torch.tensor([1]).float(), alphas_pro[:-1]], 0)   # p表示previous, 即 αt-1^


alphas_bar_sqrt = torch.sqrt(alphas_pro)    # αt^ 开根号
one_minus_alphas_bar_log = torch.log(1 - alphas_pro)    # log (1 - αt^)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_pro)  # 根号下(1-αt^)

assert alphas.shape == alphas_pro.shape == alphas_pro_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape

print('beta: shape ', betas.shape)

# diffusion process

def q_x(x_0, t):
    ''' get q_x_{\t}
    作用: 可以基于x[0]得到任意时刻t的x[t]
    输入: x_0:初始干净图像; t:采样步
    输出: x_t:第t步时的x_0的样子
    '''
    noise = torch.randn_like(x_0) # 正态分布的随机噪声
    alphas_t = alphas_bar_sqrt[t]
    alphas_l_m_t = one_minus_alphas_bar_sqrt[t]

    return (alphas_t * x_0 + alphas_l_m_t * noise)


# test add noise
num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
plt.rc('text', color='blue')

# 测试一下加噪下过
## 共有10000个点,每个点包含两个坐标
## 生成100步以内,每个5步加噪后图像


for i in range(num_shows):
    j = i // 10
    k = i % 10
    q_i = q_x(dataset, torch.tensor(i * num_steps // num_shows))    # 生成t时刻的采样数据
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
    axs[j, k].set_axis_off()
    axs[j, k].set_title('$q(\mathbf{x}_{' + str(i*num_steps // num_shows) + '})$')
    
# plt.show()
plt.savefig('diffusion_process.png')

# diffusion reverse process

# --------------------- diffusion model -----------------

import torch
import torch.nn as nn

class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=32):
        super(MLPDiffusion, self).__init__()
        
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2)
            ]
        )
        
        self.step_embeddings = nn.ModuleList(
            [nn.Embedding(n_steps, num_units),
             nn.Embedding(n_steps, num_units),
             nn.Embedding(n_steps, num_units),
             ]
        )

    def forward(self, x, t):
        """
        模型的输入是加噪后的图片x和加噪step-> t, 输出是噪声
        """
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)

        x = self.linears[-1](x) # shape: [10000, 2]

        return x

# loss function
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, use_cuda=False):
    """
    作用: 对任意时刻t进行采样计算loss
    参数:
        model: 模型
        x_0: 干净的图
        alphas_bar_sqrt: 根号下αt^
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
        n_steps: 采样步
    """
    batch_size = x_0.shape[0]

    # 对一个batchsize样本生成随机的时刻t, 覆盖到更多不同的t
    t = torch.randint(0, n_steps, size=(batch_size//2,))  # 在0~99内生成整数采样步
    t = torch.cat([t, n_steps-1-t], dim=0)  # 一个batch的采样步, 尽量让生成的t不重复
    t = t.unsqueeze(-1)  # 扩展维度 -> [batchsize, 1]
    if use_cuda:
        t = t.cuda()

    # x0的系数
    a = alphas_bar_sqrt[t]  # 根号下αt^

    # eps的系数
    aml = one_minus_alphas_bar_sqrt[t]  # 根号下(1-αt^)

    # 生成随机噪音eps
    e = torch.randn_like(x_0)
    if use_cuda:
        e = e.cuda()

    # 构造模型的输入
    x = x_0 * a + e * aml  # 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps

    # 送入模型,得到t时刻的随机噪声预测值
    output = model(x, t.squeeze(-1))  # 模型预测的是噪声, 噪声维度与x0一样大, [10000,2]

    # 与真实噪声一起计算误差,求平均值
    return (e - output).square().mean()



# --------------- reverse process ---------------
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt, use_cuda=False):
    """
    作用: 从x[T]恢复x[T-1]、x[T-2]、...x[0]
    输入:
        model:模型
        shape:数据大小,用于生成随机噪声
        n_steps:逆扩散总步长
        betas: βt
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
    输出:
        x_seq: 一个序列的x, 即 x[T]、x[T-1]、x[T-2]、...x[0]
    """
    if use_cuda:
        cur_x = torch.randn(shape).cuda()
    else:
        cur_x = torch.randn(shape)  # 随机噪声, 对应xt
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt, use_cuda=use_cuda)
        x_seq.append(cur_x)

    return x_seq


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt, use_cuda=False):
    """
    作用: 从x[T]采样t时刻的重构值
    输入:
        model:模型
        x: 采样的随机噪声x[T]
        t: 采样步
        betas: βt
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
    输出:
        sample: 样本
    """
    if use_cuda:
        t = torch.tensor([t]).cuda()
    else:
        t = torch.tensor([t])

    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]  # 模型输出的系数:βt/根号下(1-αt^) = 1-αt/根号下(1-αt^)
    
    eps_theta = model(x, t)  # 模型的输出: εθ(xt, t)
        
    # (1/根号下αt) * (xt - (1-αt/根号下(1-αt^))*εθ(xt, t))
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))  
    if use_cuda:
        z = torch.randn_like(x).cuda()  # 对应公式中的 z
    else:
        z = torch.randn_like(x)  # 对应公式中的 z

    sigma_t = betas[t].sqrt()  # 对应公式中的 σt

    sample = mean + sigma_t * z

    return (sample)


# ----------- trainning ------------

print('Training model...')
if_use_cuda = True
batch_size = 1024
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)
num_epoch = 4000
plt.rc('text',color='blue')


model = MLPDiffusion(num_steps)  # 输出维度是2,输入是x和step
if if_use_cuda:
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

iteration = 0
for t in range(num_epoch):
    for idx, batch_x in enumerate(dataloader):
        # 损失计算
        if if_use_cuda:
            loss = diffusion_loss_fn(model, batch_x.cuda(), alphas_bar_sqrt.cuda(), one_minus_alphas_bar_sqrt.cuda(), num_steps, use_cuda=if_use_cuda)
        else:
            loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)

        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 损失回传
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.)  # 梯度裁剪
        optimizer.step()

        iteration += 1

        # if iteration % 100 == 0:
        if(t % 100 == 0):
            print(f'epoch: {t} , loss: ', loss.item())
            if if_use_cuda:
                x_seq = p_sample_loop(model, dataset.shape, num_steps, betas.cuda(), one_minus_alphas_bar_sqrt.cuda(), use_cuda=True)
            else:
                x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt, if_use_cuda)

            fig, axs = plt.subplots(1, 10, figsize=(28,3))
            for i in range(1, 11):
                cur_x = x_seq[i*10].cpu().detach()
                axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
                axs[i-1].set_axis_off();
                axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

            plt.savefig('./diffusion_train_tmp.png')


### ----------------动画演示扩散过程和逆扩散过程-------------------------
# 前向过程
imgs = []
for i in range(100):
    plt.clf()
    q_i = q_x(dataset,torch.tensor([i]))
    plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off');
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    imgs.append(img)

# 逆向过程
reverse = []
for i in range(100):
    plt.clf()
    cur_x = x_seq[i].cpu().detach()
    plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off')

    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    reverse.append(img)

print('save gif...')
imgs = imgs
imgs[0].save("diffusion_forward.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

imgs = reverse
imgs[0].save("diffusion_denoise.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/568343.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

自适应STFT及其在地震时间行程自动拾取中的应用【附MATLAB代码】

文章来源:微信公众号:EW Frontie 摘要 在本文中,首先,我们提出的方法来产生高分辨率的短时傅里叶变换,通过计算最佳瞬时窗口长度。其次,利用生成的时频图提取瞬时走时属性,实现地震同相轴走时的…

vmstat命令详解

一、参数信息 vmstat 命令是用于报告虚拟内存统计信息的工具,常用于 Unix/Linux 系统上。它可以提供关于系统资源使用情况的详细信息,包括 CPU、内存、虚拟内存、磁盘、系统调用等方面的统计数据。以下是常见的 vmstat 命令参数的详解: vms…

k8s学习(三十六)centos下离线部署kubernetes1.30(单主节点)

文章目录 服务器准备工作一、升级操作系统内核1 查看操作系统和内核版本2 下载内核离线升级包3 升级内核4 确认内核版本 二、修改主机名/hosts文件1 修改主机名2 修改hosts文件 三、关闭防火墙四、关闭SELINUX配置五、时间同步1 下载NTP2 卸载3 安装4 配置4.1 主节点配置4.2 从…

2024商业地产五一劳动节健康大会朋克养生市集活动策划方案

2024商业地产五一劳动节健康大会朋克养生市集(带薪健康 快乐打工主题)活动策划方案 活动策划信息: 方案页码:53页 文件格式:PPT 方案简介: 打工不养生 赚钱养医生 期待已久的五一假期, …

WebSocket的原理、作用、常见注解和生命周期的简单介绍,附带SpringBoot示例

文章目录 WebSocket是什么WebSocket的原理WebSocket的作用全双工和半双工客户端【浏览器】API服务端 【Java】APIWebSocket的生命周期WebSocket的常见注解SpringBoot简单代码示例 WebSocket是什么 WebSocket是一种 通信协议 ,它在 客户端和服务器之间建立了一个双向…

开发环境中的调试视图(IDEA)

当程序员写完一个代码时必然要运行这个代码,但是一个没有异常的代码却未必满足我们的要求,因此就要求程序员对已经写好的代码进行调试操作。在之前,如果我们要看某一个程序是否满足我们的需求,一般情况下会对程序运行的结果进行打…

java泛型介绍

Java 泛型是 JDK 5 引入的一个特性,它允许我们在定义类、接口和方法时使用类型参数,从而使代码更加灵活和类型安全。泛型的主要目的是在编译期提供类型参数,让程序员能够在编译期间就捕获类型错误,而不是在运行时才发现。这样做提…

C语言学习/复习30--结构体的声明/初始化/typedef改名/内存对齐大小计算

一、自定义数据类型 二、结构体 1.结构体的定义(与数组相对比) 2.结构体全局/局部变量的定义 3.typedef对结构体改名 4.匿名结构体类型的声明 注意事项1: 匿名后必须立即创建结构体变量 、 5.结构体与链表节点定义 注意事项1&…

Python基础07-高级列表推导式和Lambda函数

在Python中,列表推导式和Lambda函数是处理数据的强大工具。本文将介绍如何使用嵌套列表推导式、带有条件的列表推导式、多可迭代对象的列表推导式、Lambda函数、在列表推导式中使用Lambda函数、用于展平嵌套列表的列表推导式、对元素应用函数、使用Lambda函数与Map和…

Arena-Hard:开源高质量大模型评估基准

开发一个安全、准确的大模型评估基准通常需要包含三个重要内容:1)稳定识别模型的能力;2)反映真实世界使用情况中的人类偏好;3)经常更新以避免过拟合或测试集泄漏。 但传统的基准测试通常是静态的或闭源的&…

程序员缓解工作压力小技巧

文章目录 1. 规划时间和任务2. 学会放松和调节情绪3. 培养兴趣爱好4. 保持健康的生活方式总结 当面对程序员这样需要高度精神集中和持续创新的工作时,缓解工作压力是至关重要的。下面分享一些我个人的经验和方法,希望能对大家有所帮助: 1. 规…

如何让AI生成自己喜欢的歌曲-AI音乐创作的正确方式 - 第507篇

历史文章 AI音乐,8大变现方式——Suno:音乐版的ChatGPT - 第505篇 日赚800,利用淘宝/闲鱼进行AI音乐售卖实操 - 第506篇 导读 在使用AI生成音乐(AI写歌)的时候,你是不是有这样的困惑: &…

线性模型算法

简介 本文章介绍机器学习中的线性模型有关内容,我将尽可能做到详细得介绍线性模型的所有相关内容 前置 什么是回归 回归的就是整合+预测 回归处理的问题可以预测: 预测房价 销售额的预测 设定贷款额度 可以根据事物的相关特征预测出对…

模型部署的艺术:让深度学习模型跃入生产现实

模型部署的艺术:让深度学习模型跃入生产现实 1 引言 1.1 部署的意义:为何部署是项目成功的关键 在深度学习项目的生命周期中,模型的部署是其成败的关键之一。通常,一个模型从概念构思、数据收集、训练到优化,最终目的…

【UML建模】用例图

1 参与者 参与者的概念: 指系统以外的、需要使用系统或与系统交互的外部实体 可以分为:人、外部设备、外部系统 参与者的图形符号: 例 3.1 在一个银行业务系统中,可能会有以下参与者 客户 :在银行业务系统中办理…

详细分析MySQL中的distinct函数(附Demo)

目录 前言1. 基本知识2. 基础Demo3. 进阶Demo 前言 该函数主要用于去重,对于细节知识,此文详细补充说明 1. 基本知识 DISTINCT 是一种用于查询结果中去除重复行的关键字 在查询数据库时,可能会得到重复的结果行,但有时只需要这…

奇妙的探索——偶然发现的bug

今天想在腾讯招聘官网找几个前端的岗位投一下,最近自己也在找工作,结果简历还没有投出去,就发现了腾旭招聘官网的3个前端bug。 1.有时候鼠标hover还没有滑倒下拉选框的菜单上,就消失了,消失的太快了,根本点…

揭秘App广告变现,如何轻松赚取额外收入?

揭秘App广告变现,如何轻松赚取额外收入? 在移动互联网高速发展的今天,APP广告变现已经成为了众多开发者和公司的主要盈利方式。但是,如何让一个APP实现高效的广告变现呢?这是一门大学问,需要我们用心去揣摩…

聚观早报 | TCL召开电视新品发布会;OceanBase 4.3发布

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 4月22日消息 TCL召开电视新品发布会 OceanBase 4.3发布 科大讯飞推出耳背式助听器 F1联想中国大奖赛开赛 蔚来展…

个人博客建设必备:精选域名和主机的终极攻略

本文目录 🌏引言🌏域名的选择🌕域名的重要性品牌识别营销和宣传可访问性和易记性信任和权威感搜索引擎优化(SEO)未来的灵活性和扩展性保护品牌 🌕如何选择域名🌕工具与资源分享国内的主流域名注…