DDPM代码详解

最近准备要学习一下AIGC,因此需要从一些基本网络开始了解,比如DDPM,本篇文章会从代码解析角度来供大家学习了解。DDPM(Denoising Diffusion Probabilistic Models) 是一种扩散模型。

扩散模型包含两个主要的过程:加噪过程去噪过程。对应到上述图中,从x0到xt是加噪的过程,从xt到x0是去噪的过程。

前向加噪过程和反向去噪过程都是马尔可夫链,全过程大约需要1000步。

前向的加噪过程是对输入数据不断的加噪声(高斯噪声)。

反向去噪过程是对从标准高斯分布中逐步地得到一个个噪声样本,最终得到生成的样本的数据。


DDPM在代码中的定义如下:

代码采用的是Bubbliiing的代码。

net    = GaussianDiffusion(UNet(3, self.channel), self.input_shape, 3, betas=betas)

可以看到该扩散模型传入参数有UNet网络,input_shape为输入大小,3指的图像输入通道,betas为一个线性时间表,其值在 schedule_lowschedule_high 之间,在总时间步数 num_timesteps (这里设置的是1000)内均匀分布。而betas定义如下(当然你可以可以用cosine生成,这里我只是用线性的举例子):

betas = generate_linear_schedule(
                self.num_timesteps,
                self.schedule_low * 1000 / self.num_timesteps,
                self.schedule_high * 1000 / self.num_timesteps,
            )

forward函数部分

然后我们进入GaussianDiffusion的代码内部看一下各个组成部分。我们直接去看一下内部的forward函数,看看是如何处理图片的。

    def forward(self, x, y=None):
        b, c, h, w  = x.shape
        device      = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        if w != self.img_size[0]:
            raise ValueError("image width does not match diffusion parameters")
        # 随机生成batch个范围在0~1000内的数
        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)

 可以看到在GaussianDiffusion的forward部分,x是输入的图片,然后里面有个t,表示随机生成范围在0~num_timesteps【时间步长】batch_size个数,或者可以理解为给每个batch(图片)随机打上时间戳。然后再一步一步深挖代码,进入get_losses函数。


get_losses部分

下面是get_losses代码,有三个输入,x,t,y。这里的x就是我们训练输入的图片t就是上面随机生成的时间戳

    def get_losses(self, x, t, y):
        # x, noise [batch_size, 3, 64, 64]
        noise           = torch.randn_like(x)  # 产生与输入图片shape一样的随机噪声(正态分布)

        perturbed_x     = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        return loss

在函数内部首先是创建了一个与输入图片大小的相同的符合正态分布的随机噪声noise,然后perturb_x函数是对输入图片在时间t上加入噪声进行加噪的扰动处理。

perturb_x函数部分

那么就看一下perturb_x中是如何给图片在时间t上加噪的(要保持头脑清醒,这些代码和套娃一样一层一层的)。

在该函数中有三个输入参数:x(输入图片),t(时间序列),noise(随机噪声),函数最终返回的是经过加噪(扰动)后的图像,比如我现在输入一张图片,然后此时的t=323,那么就可以理解为在时间戳为323的时候为我的这张图加上噪声t,而这图就是对应于t时刻的输入Xtsqrt_alphas_cumprodsqrt_one_minus_alphas_cumprod 使用了这两个张量来控制输入图像x和噪声noise在时间维度上的混合比例。

    def perturb_x(self, x, t, noise):
        '''
        :param x:输入图像
        :param t: 每个图片不同的时间戳(范围在0~1000)
        :param noise: 与输入图片shape一样的正态分布随机噪声
        :return:经过扰动后的图像
        '''
        return (
            extract(self.sqrt_alphas_cumprod, t,  x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

 我们可以对perturb_x的过程进行可视化,比如我有下面一张未加噪声的原始图片

 通过perturb_x对x进行扰动加噪后的效果

我们还可以控制噪声在图像上的扩散扰动效果:

 上面就是通过对时间t时刻对应的图片Xt加噪的处理过程了。会随着时间t的推移而变得越来越模糊

然后再返回get_losses函数(代码如下),perturbed_x就是我们加噪后的t时刻的图片Xt,然后这里的model就是我们的主干网络UNet网络(UNet网络部分我会单独拿出来)。那么可以总结一下get_losses的主要过程:

步骤1.通过perturb_x对输入图像进行时间域上的扰动,并与随机噪声 noise 混合,生成 perturbed_x 扰动后的图像

步骤2.通过UNet网络对加噪后的图像进行预测,得到预测后的噪声信号estimated_noise。

步骤3.计算预测噪声estimated_noise和真实噪声noise的loss。

    def get_losses(self, x, t, y):
        # x, noise [batch_size, 3, 64, 64]
        noise           = torch.randn_like(x)  # 产生与输入图片shape一样的随机噪声(正态分布)

        perturbed_x     = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        return loss

DDP是由Unet组成的,那就先看一下Unet中的组成。

class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 4, 8),
        num_res_blocks=3, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=SiLU(),
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):

time_mlp

self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

time_mlp又由PositionalEmbedding层、Linear、SiLu、Linear组成。

PositionalEmbedding层

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

代码中forward中的x为time(时间轴)并不是图像

该函数主要是用来做位置编码的。而位置编码可以用正余弦来计算位置。所用到的公式为:

PE_{pos,2i}=sin(pos/10000^{2i/d_{model}})

PE_{pos,2i+1}=cos(pos/10000^{2i/d_{model}})

在位置编码公式中,pos 表示序列中的每个位置的索引。对于长度为 4 的序列 x,每个位置的索引从 0 到 3。在计算每个位置的位置编码向量时,我们会利用这个索引值进行计算。

具体来说,公式中的 pos 表示序列中的位置索引,在计算位置编码向量的过程中,会使用它来计算正弦和余弦的函数参数。

例如,在计算位置编码矩阵的第一个位置编码向量时,pos 的值为 0;在计算第二个位置编码向量时,pos 的值为 1,以此类推。

可以举个例子,比如我现在有个序列X,长度为4,位置编码的维度也设置为4.然后计算每个序列的位置信息(通过正余弦)

# 设置向量的长度和位置编码的维度
vector_length = 4
embedding_dim = 4

# 生成位置编码矩阵
pos_encoding = np.zeros((vector_length, embedding_dim))

for pos in range(vector_length):
    for i in range(embedding_dim):
        if i % 2 == 0:
            pos_encoding[pos, i] = np.sin(pos / (10000 ** (2 * i / embedding_dim)))
        else:
            pos_encoding[pos, i] = np.cos(pos / (10000 ** (2 * (i - 1) / embedding_dim)))

# 打印位置编码矩阵
print(pos_encoding)

得到的位置编码矩阵如下 

[[ 0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00]
 [ 8.41470985e-01  5.40302306e-01  9.99999998e-05  9.99999995e-01]
 [ 9.09297427e-01 -4.16146837e-01  1.99999999e-04  9.99999980e-01]
 [ 1.41120008e-01 -9.89992497e-01  2.99999995e-04  9.99999955e-01]]

其中,数组的每一行对应位置编码矩阵的一个位置,第一列表示正弦函数在该位置上的取值,第二列表示余弦函数在该位置上的取值,以此类推。 

也就是说在该函数中,我们可以将输入信息的位置信息,通过sin和cos映射到高纬度空间中,得到位置特征

ResidualBlock

class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=SiLU(),
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

。。。。暂未更新完

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/209590.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言--每日选择题--Day32

如果大家对读研究生和就业不知道如何抉择,我的建议是看大家的经济基础,如果家里不是很需要你们工作,就读研提升自己的学历,反之就就业;毕竟经济基础决定上层建筑; 第一题 1. 下面代码的结果是:…

牛客小白月赛82(A~C)

目录 A.谜题:质数 输入描述 输出描述 输入 输出 解析 B.Kevin逛超市 2 (简单版本) 输入描述 输出描述 输入 输出 思路 C.被遗忘的书籍 题目描述 输入描述 输出描述 输入 输出 输入 输出 思路 比赛链接 牛客小白月赛82_ACM/NOI/CSP/CCPC/ICPC算…

C#Backgroundworker与Thread的区别

前言 当谈到多线程编程时,C#中的BackgroundWorker和Thread是两个常见的选择。它们都可以用于实现并行处理和异步操作,但在某些方面有一些重要的区别。本文将详细解释BackgroundWorker和Thread之间的区别以及它们在不同场景中的使用。 目录 前言1. Backgr…

微软 Power Platform 零基础 Power Pages 网页搭建教程学习实践进阶以及常见问题解答(二)

微软 Power Platform 零基础 Power Pages 网页搭建教程学习实践进阶及常见问题解答(二) Power Pages 学习实践进阶 微软 Power Platform 零基础 Power Pages 网页搭建教程学习实践进阶及常见问题解答(二)Power Pages 核心工具和组…

基于单片机设计的智能水泵控制器

一、前言 在一些场景中,如水池、水箱等水体容器的管理中,保持水位的稳定是至关重要的。传统上,人们通常需要手动监测水位并进行水泵的启停控制,这种方式不仅效率低下,还可能导致水位过高或过低,从而对水体…

在 AlmaLinux9 上安装Oracle Database 23c

在 AlmaLinux9 上安装Oracle Database 23c 0. 下载 Oracle Database 23c 安装文件1. 安装 Oracle Database 23c3. 连接 Oracle Database 23c4. (谨慎)卸载 Oracle Database 23c 0. 下载 Oracle Database 23c 安装文件 版权问题,下载地址请等待…

企业加密软件有哪些(公司防泄密软件)

企业加密软件是专门为企业设计的软件,旨在保护企业的敏感数据和信息安全。这些软件通过使用加密技术来对数据进行加密,使得数据在传输和存储过程中不会被未经授权的人员获取和滥用。 企业加密软件的主要功能包括数据加密、文件加密、文件夹加密、移动设备…

专业视频剪辑利器Final Cut Pro for Mac,让你的创意无限发挥

在如今的数字时代,视频内容已经成为人们生活中不可或缺的一部分。无论是在社交媒体上分享生活点滴,还是在工作中制作专业的营销视频,我们都希望能够以高质量、高效率地进行视频剪辑和制作。而Final Cut Pro for Mac作为一款专业级的视频剪辑软…

2023-12-01 AndroidR 系统在root目录下新建文件夹和创建链接,编译的时候需要修改sepolicy权限

一、想在android 系统的根目录下新建一个tmp 文件夹,建立一个链接usr链接到data目录。 二、在system/core/rootdir/Android.mk里面的LOCAL_POST_INSTALL_CMD 增加 dev proc sys system data data_mirror odm oem acct config storage mnt apex debug_ramdisk tmp …

20、Resnet 为什么这么重要

(本文已加入“计算机视觉入门与调优”专栏,点击专栏查看更多文章信息) resnet 这一网络的重要性,上一节大概介绍了一下,可以从以下两个方面来有所体现:第一是 resnet 广泛的作为其他神经网络的 back bone&…

麻吉POS集成:如何无代码开发实现电商平台和CRM系统的高效连接

麻吉POS集成的前沿技术:无代码开发 在竞争激烈的电商市场中,商家们急需一种高效且易于操作的技术手段来实现系统间的快速连接与集成。麻吉POS以其前沿的无代码开发技术,让这一需求成为可能。无代码开发是一种允许用户通过图形用户界面进行编…

RocketMQ事务消息源码解析

RocketMQ提供了事务消息的功能,采用2PC(两阶段协议)补偿机制(事务回查)的分布式事务功能,通过这种方式能达到分布式事务的最终一致。 一. 概述 半事务消息:指的是发送至broker但是还没被commit的消息,在半…

java引入jjwt时候报错Undable to load class named ...

原因是包没有引全,像下面这样写重新加载maven <dependency><groupId>io.jsonwebtoken</groupId><artifactId>jjwt-api</artifactId><version>0.11.1</version></dependency><dependency><groupId>io.jsonwebtoken…

mac安装elasticsearch

下载地址&#xff1a; Past Releases of Elastic Stack Software | Elastic https://www.elastic.co/cn/downloads/past-releases#elasticsearch 选择7.10版本 进入es bin目录下执行启动命令 ./elasticsearch 会报错 ./elasticsearch-env: line 126: syntax error near u…

使用Python免费调用通义千问大模型

Qwen-72b开源模型 模型的主要用途是预测或描述一个系统或现象的行为模式。它可以帮助人们更好地理解这个系统或现象&#xff0c;例如预测股市变化、天气预报、地震预警、交通流量等。模型也常用于设计和优化产品和工艺。在科学研究中&#xff0c;模型也是一种方法&#xff0c;用…

【开源】基于JAVA语言的校园电商物流云平台

项目编号&#xff1a; S 034 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S034&#xff0c;文末获取源码。} 项目编号&#xff1a;S034&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 商品数据模块2.3 快…

Leetcode 剑指 Offer II 055. 二叉搜索树迭代器

题目难度: 中等 原题链接 今天继续更新 Leetcode 的剑指 Offer&#xff08;专项突击版&#xff09;系列, 大家在公众号 算法精选 里回复 剑指offer2 就能看到该系列当前连载的所有文章了, 记得关注哦~ 题目描述 实现一个二叉搜索树迭代器类 BSTIterator &#xff0c;表示一个按…

03数据仓库Flume

Flume 功能 Flume主要作用&#xff0c;就是实时读取服务器本地磁盘数据&#xff0c;将数据写入到 HDFS。 Flume是 Cloudera提供的高可用&#xff0c;高可靠性&#xff0c;分布式的海量日志采集、聚合和传输的系统工具。 Flume 架构 Flume组成架构如下图所示&#xff1a; A…

力扣232-用栈实现队列

文章目录 力扣232-用栈实现队列示例代码实现总结收获 力扣232-用栈实现队列 示例 代码实现 class MyQueue {Deque<Integer> instack;Deque<Integer> outstack ;public MyQueue() {instacknew ArrayDeque<Integer>();outstacknew ArrayDeque<Integer>(…

scrapy介绍,并创建第一个项目

一、scrapy简介 scrapy的概念 Scrapy是一个Python编写的开源网络爬虫框架。它是一个被设计用于爬取网络数据、提取结构性数据的框架。 Scrapy 使用了Twisted异步网络框架&#xff0c;可以加快我们的下载速度。 Scrapy文档地址&#xff1a;http://scrapy-chs.readthedocs.io/z…