从开始实现扩散概率模型 PyTorch 实现

目录

一、说明

二、从头开始实施

三、线性噪声调度器

四、时间嵌入

五、下层DownBlock类块

六、中间midBlock类块

七、UpBlock上层类块

八、UNet 架构

九、训练

十、采样

十一、配置(Default.yaml)

十二、数据集 (MNIST)


keyword:  Diffusion Probabilistic Models 

一、说明

        扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。

        在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。

        通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。

训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。

二、从头开始实施

        对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。

        为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。

        作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。

        为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。

三、线性噪声调度器

import torch


class LinearNoiseScheduler:

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。

  def add_noise(self, original, noise, t):

        original_shape = original.shape
        batch_size = original_shape[0]
        
        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        
        # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
        for _ in range(len(original_shape) - 1):
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
        
        # Apply and Return Forward process equation
        return (sqrt_alpha_cum_prod.to(original.device) * original
                + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)

add_noise()函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。

    def sample_prev_timestep(self, xt, noise_pred, t):

        x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
              torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
        x0 = torch.clamp(x0, -1., 1.)
        
        mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
        mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
        
        if t == 0:
            return mean, x0
        else:
            variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
            variance = variance * self.betas.to(xt.device)[t]
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)
            
            return mean + sigma * z, x0

        调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀​ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。

        对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀​ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。

        这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。

作者图片

        无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。

        对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。

作者图片

        时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。

作者图片

        UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。

作者图片

        大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。

        在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。

        为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。

作者图片

        另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。

四、时间嵌入

import torch
import torch.nn as nn


def get_time_embedding(time_steps, temb_dim):

    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    
    # factor = 10000^(2i/d_model)
    factor = 10000 ** ((torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
    )
    
    # pos / factor
    # timesteps B -> B, 1 -> B, temb_dim
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb

第一个函数为给定的时间步长get_time_embedding生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。
time_steps:时间步长值的张量(形状:[B]其中B是批次大小)。每个值代表批次元素的一个离散时间步长。
temb_dim:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。

确保这temb_dim是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。

五、下层DownBlock类块

class DownBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for i in range(num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
             for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()


    def forward(self, x, t_emb):
        out = x
        for i in range(self.num_layers):
            
            # Resnet block of Unet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            # Attention block of Unet
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
        out = self.down_sample_conv(out)
        return out

DownBlock 类结合了ResNet 块自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。

参数

  • in_channels:输入通道数。
  • out_channels:输出通道数。
  • t_emb_dim:时间嵌入的维度。
  • down_sample:布尔值,确定是否在块末尾应用下采样。
  • num_heads:多头注意力层中的注意力头的数量。
  • num_layers:此块中的 ResNet + 注意力层的数量。

ResNet块

  • resnet_conv_first:ResNet 块的第一个卷积层。
  • t_emb_layers:时间嵌入投影层。
  • resnet_conv_second:ResNet 块的第二个卷积层。
  • residual_input_conv:用于残差连接的 1x1 卷积。

自注意力模块

  • attention_norms:在注意力机制之前对规范化层进行分组。
  • attentions:多头注意力层。

下采样

  • down_sample_conv:应用卷积来减少空间维度(如果down_sample=True)。

Forward Pass 方法定义了如何x通过块处理输入张量:out初始化为输入x。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。

在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb到线性层(投影到out_channels),并将此投影时间嵌入添加到out(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input)添加到第二个卷积的输出。

在自注意力模块中,我们将空间维度扁平化为一个维度(h * w)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接下采样。

六、中间midBlock类块

class MidBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers+1)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers + 1)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers+1)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers+1)
            ]
        )
    
    def forward(self, x, t_emb):
        out = x
        
        # First resnet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out)
        out = out + self.residual_input_conv[0](resnet_input)
        
        for i in range(self.num_layers):
            
            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
            # Resnet Block
            resnet_input = out
            out = self.resnet_conv_first[i+1](out)
            out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i+1](out)
            out = out + self.residual_input_conv[i+1](resnet_input)
        
        return out

该类MidBlock是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:

  • 时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
  • 层迭代:在注意力ResNet 块之间交替,按num_layers这些组合的顺序处理输入。

七、UpBlock上层类块

class UpBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)
            ]
        )
        
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()
    
    def forward(self, x, out_down, t_emb):
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)
        
        out = x
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

        return out

该类UpBlock是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样跳过连接ResNet 块自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。

  • 上采样:通过转置卷积(ConvTranspose2d)实现,以增加特征图的空间分辨率。
  • 跳过连接:允许解码器重用编码器的详细特征,帮助重建。
  • ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
  • 自我注意力:捕获远程空间依赖关系以保留全局上下文。
  • 时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。

八、UNet 架构

class Unet(nn.Module):

    def __init__(self, model_config):
        super().__init__()
        im_channels = model_config['im_channels']
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        
        # Initial projection from sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
        
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels)-1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
                                        down_sample=self.down_sample[i], num_layers=self.num_down_layers))
        
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels)-1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
                                      num_layers=self.num_mid_layers))
        
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels)-1)):
            self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
                                    self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))
        
        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
    
    def forward(self, x, t):
        # Shapes assuming downblocks are [C1, C2, C3, C4]
        # Shapes assuming midblocks are [C4, C4, C3]
        # Shapes assuming downsamples are [True, True, False]
        # B x C x H x W
        out = self.conv_in(x)
        # B x C1 x H x W
        
        # t_emb -> B x t_emb_dim
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        
        down_outs = []
        
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)
        # down_outs  [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
        # out B x C4 x H/4 x W/4
            
        for mid in self.mids:
            out = mid(out, t_emb)
        # out B x C3 x H/4 x W/4
        
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)
            # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        # out B x C x H x W
        return out

该类是U-Net 架构Unet的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样中级处理上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。

作者图片

  • 时间嵌入:实现时间动态。
  • 跳过连接:通过连接将细粒度的空间细节集成到解码器中。
  • 灵活的架构:允许通过model_config不同的深度、分辨率和功能丰富度进行定制。
  • 规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
  • 输出一致性:确保输出图像保留原始的空间尺寸和通道数。

九、训练

import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(args):
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    
    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']
    
    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])
    
    # Create the dataset
    mnist = MnistDataset('train', im_path=dataset_config['im_path'])
    mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)
    
    # Instantiate the model
    model = Unet(model_config).to(device)
    model.train()
    
    # Create output directories
    if not os.path.exists(train_config['task_name']):
        os.mkdir(train_config['task_name'])
    
    # Load checkpoint if found
    if os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):
        print('Loading checkpoint as found one')
        model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                      train_config['ckpt_name']), map_location=device))
    # Specify training parameters
    num_epochs = train_config['num_epochs']
    optimizer = Adam(model.parameters(), lr=train_config['lr'])
    criterion = torch.nn.MSELoss()
    
    # Run training
    for epoch_idx in range(num_epochs):
        losses = []
        for im in tqdm(mnist_loader):
            optimizer.zero_grad()
            im = im.float().to(device)
            
            # Sample random noise
            noise = torch.randn_like(im).to(device)
            
            # Sample timestep
            t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
            
            # Add noise to images according to timestep
            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print('Finished epoch:{} | Loss : {:.4f}'.format(
            epoch_idx + 1,
            np.mean(losses),
        ))
        torch.save(model.state_dict(), os.path.join(train_config['task_name'],
                                                    train_config['ckpt_name']))
    
    print('Done Training ...')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm training')
    parser.add_argument('--config', dest='config_path',
                        default='config/default.yaml', type=str)
    args = parser.parse_args()
    train(args)

加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。

设置组件

  • 初始化噪声调度器,用于在不同的时间步添加噪声。
  • 创建一个MNIST 数据集加载器
  • 实例化U-Net模型

检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。

训练循环:每个时期:

  • 遍历数据集,根据采样的时间步长向图像添加噪声。
  • 使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
  • 使用反向传播更新模型参数并保存模型检查点。

优化:使用 Adam 优化器和 MSE 损失函数来训练模型。

完成:打印 epoch 损失并在每个 epoch 结束时保存模型。

十、采样

import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample(model, scheduler, train_config, model_config, diffusion_config):

    xt = torch.randn((train_config['num_samples'],
                      model_config['im_channels'],
                      model_config['im_size'],
                      model_config['im_size'])).to(device)
    for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
        # Get prediction of noise
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
        
        # Use scheduler to get x0 and xt-1
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
        
        # Save x0
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config['num_grid_rows'])
        img = torchvision.transforms.ToPILImage()(grid)
        if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
            os.mkdir(os.path.join(train_config['task_name'], 'samples'))
        img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
        img.close()


def infer(args):
    # Read the config file #
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    
    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config['train_params']
    
    # Load model with checkpoint
    model = Unet(model_config).to(device)
    model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                  train_config['ckpt_name']), map_location=device))
    model.eval()
    
    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])
    with torch.no_grad():
        sample(model, scheduler, train_config, model_config, diffusion_config)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
    parser.add_argument('--config', dest='config_path',
                        default='config/default.yaml', type=str)
    args = parser.parse_args()
    infer(args)

加载配置:从 YAML 文件读取模型、扩散和训练参数。

模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。

采样过程

  • 从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
  • 在每个时间步:
  • 使用模型预测噪音。
  • 使用调度程序计算去噪图像(x0)并更新当前噪声图像(xt)。
  • 将中间去噪图像作为 PNG 文件保存在输出目录中。

推理:执行采样过程并保存结果而不改变模型。

十一、配置(Default.yaml)

dataset_params:
  im_path: 'data/train/images'

diffusion_params:
  num_timesteps : 1000
  beta_start : 0.0001
  beta_end : 0.02

model_params:
  im_channels : 1
  im_size : 28
  down_channels : [32, 64, 128, 256]
  mid_channels : [256, 256, 128]
  down_sample : [True, True, False]
  time_emb_dim : 128
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2
  num_heads : 4

train_params:
  task_name: 'default'
  batch_size: 64
  num_epochs: 40
  num_samples : 100
  num_grid_rows : 10
  lr: 0.0001
  ckpt_name: 'ddpm_ckpt.pth'

该配置文件提供了扩散模型的训练和推理的设置。

数据集参数im_path:指定训练图像的路径( )。

扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_startbeta_end)。

模型参数

  • 定义模型架构,包括:
  • 输入图像通道(im_channels)和大小(im_size)。
  • 下采样、中间处理和上采样的通道数。
  • 每一级是否发生下采样(down_sample)。
  • 各种块的嵌入尺寸和层数。

训练参数

  • 指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
  • 包括采样设置,例如用于可视化的样本数量和网格行数。

十二、数据集 (MNIST)

import glob
import os

import torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


class MnistDataset(Dataset):
        self.split = split
        self.im_ext = im_ext
        self.images, self.labels = self.load_images(im_path)
    
    def load_images(self, im_path):
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        for d_name in tqdm(os.listdir(im_path)):
            for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
                ims.append(fname)
                labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.ToTensor()(im)
        
        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1
        return im_tensor

初始化:采用分割名称、图像文件扩展名(im_ext)和图像路径(im_path)。调用load_images以加载图像路径及其相应的标签。

图像加载load_images遍历 处的目录结构im_path,假设子目录已标记(例如,数字类别的01、...)。收集图像文件路径并根据文件夹名称分配标签。

数据集长度__len__返回图像的总数。

数据检索__getitem__通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/936749.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MTK Android12 更换开机LOGO和开机动画

1、路径: (1)device/mediatek/system/common/device.mk (2)vendor/audio-logo/animation/bootanimation.zip (3)vendor/audio-logo/products/resource-copy.mk (4)vendo…

数据分析思维(一):业务指标(数据分析并非只是简单三板斧)

个人认为,数据分析并非只是简单的数据分析工具三板斧——Excel、SQL、Python,更重要的是数据分析思维。没有数据分析思维和业务知识,就算拿到一堆数据,也不知道如何下手。 推荐书本《数据分析思维——分析方法和业务知识》&#x…

matlab测试ADC动态性能的原理

目录 摘要: 简介: 动态规范和定义 动态规格: 双面到单边的功率谱转换 摘要: 模数转换器(adc)代表了接收器、测试设备和其他电子设备中的模拟世界和数字世界之间的联系。正如本文系列的第1部分中所概述…

5G中的ATG Band

Air to Ground Networks for NR是R18 NR引入的。ATG很多部分和NTN类似中的内容类似。比较明显不同的是,NTN的RF内容有TS 38.101-5单独去讲,而ATG则会和地面网络共用某些band,这部分在38.101-1中有描述。 所以会存在ATG与地面网络之间的相邻信…

vue组件开发:构建响应式快捷导航

前言 快捷导航不仅能够显著提升系统的灵活性和用户交互性,还极大地增强了用户的操作体验。本文将展示如何在 vue 中实现一个既可自定义又具备响应式特性的快捷导航菜单。 一、实现思路 列表页 结构设计 定义页面结构,包含一个导航卡片和一个对话框组件&a…

事务管理与锁机制

title: 事务管理与锁机制 date: 2024/12/14 updated: 2024/12/14 author: cmdragon excerpt: 在数据库系统中,事务管理至关重要,它确保多个数据库操作能够作为一个单一的逻辑单元来执行,从而维护数据的一致性和完整性。一个良好的事务管理系统能够解决并发操作带来的问题…

《操作系统 - 清华大学》7 -1:全局页面置换算法:局部页替换算法的问题、工作集模型

文章目录 1. 局部页替换算法的问题2. 全局置换算法的工作原理3. 工作集模式3.1 工作集3.2 工作集的变化 4 常驻集 1. 局部页替换算法的问题 局部页面置换算法 OPT,FIFO,LRU,Clock 等等,这些算法都是针对一个正在运行的程序来讲的…

力扣-图论-12【算法学习day.62】

前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向和记录学习过程(例如想要掌握基础用法,该刷哪些题?)我的解析也不会做的非常详细,只会提供思路和一些关键点,力扣上的大佬们的题解质量是非…

每日十题八股-2024年12月14日

1.类加载器有哪些? 2.双亲委派模型的作用 3.讲一下类加载过程? 4.讲一下类的加载和双亲委派原则 5.什么是Java里的垃圾回收?如何触发垃圾回收? 6.判断垃圾的方法有哪些? 7.垃圾回收算法是什么,是为了解决了…

智能引导小车充电系统设计(论文+源码)

1总体方案设计 在16*16点阵LED字符显示器的设计中,系统总体框架如图2.4所示,包括单片机主控模复位电路模块、晶振电路模块、按键电路模块、LED点阵驱动电路模块,蓝牙模块等构成。系统功能实现主要是利用系统在软件程序编写过程中&#xff0c…

【Vue】自定义指令、插槽

目录 自定义指令 是什么 作用 使用方法 定义 使用 自定义指令配合绑定数据 语法 自定义指令的简写 语法 使用时机 插槽 什么是插槽 默认(匿名)插槽 ​编辑插槽的默认值 具名插槽 使用方法 简写 使用示例 作用域插槽 自定义指令 是什…

顺序队列的实现及其应用

一、概念 队列是允许在两端(队头、队尾)进行插入和读出操作的线性表 默认情况下,队尾插入,队头读出(这一点和排队很像),先进先出FIFO 队中没有元素时称为空队 当队列两端都允许插入、读出时&…

Web安全深度剖析

1.Web安全简介 ​ 攻击者想要对计算机进行渗透,有一个条件是必须的,就是攻击者的计算机与服务器必须能够正常通信,服务器与客户端进行通信依靠的就是端口。 ​ 如今的web应该称之为web应用程序,功能强大,离不开四个要…

C# 探险之旅:第九节 - 循环(for):无限循环的魔法轮盘!

嘿,勇敢的探险家们,欢迎回到C#的神秘世界!在这一节里,我们将踏上一场关于循环的奇妙冒险,特别是那个能带我们无限次探险的“for循环”!准备好了吗?让我们一起揭开for循环的神秘面纱,…

解决Logitech G hub 无法进入一直转圈的方案(2024.12)

如果你不是最新版本无法加载尝试以下方案:删除AppData 文件夹下的logihub文件夹 具体路径:用户名根据实际你的请情况修改 C:\Users\Administrator\AppData\Local 如果你有通过lua编译脚本,记得备份!! ↓如果你是最新…

如何使用 Docker Compose 创建 LAMP 环境 ?

现如今,通过 Docker 容器化部署环境已经逐渐成为主流,特别是在部署像 LAMP (Linux、Apache、MySQL、PHP) 这样的复杂环境时。本教程旨在带您完成使用 Docker-Compose 建立 LAMP 环境的整个过程,同时还包括定制 PHP 环境的步骤,安装…

12.1【JAVA EXP4】next项目

next项目构建问题 详解一下这个页面 什么是Node选项? Node选项是指在运行Node.js应用程序时可以传递给Node.js进程的一系列命令行参数。这些选项可以让开发者控制Node.js的行为,例如设置内存限制、启用或禁用某些功能、指定调试端口等 --inspect 和 --i…

PyTorch3D 可视化

PyTorch3D是非常好用的3D工具库。但是PyTorch3D对于可用于debug(例如调整cameras参数)的可视化工具并没有进行系统的介绍。这篇文章主要是想介绍我觉得非常使用的PyTorch3D可视化工具。 1. 新建一个Mesh 从hugging face上下载一个glb文件,例…

内网穿透讲解

什么是内网穿透 内网穿透是一种网络技术,它允许外网或者其他局域网的用户来访问这个局域网的服务器资源,让资源的利用率更高,更加灵活,但是也要确保网络安全。 工作原理 如果你在公司,但是你需要使用到你家里的那台电…

Python中PyTorch详解

文章目录 Python中PyTorch详解一、引言二、PyTorch核心概念1、张量(Tensor)1.1、创建张量1.2、张量操作 2、自动求导(Autograd)2.1、自动求导示例 三、构建神经网络1、使用nn模块2、优化器(Optimizer) 四、…