随机微分方程的分数扩散模型 (score-based diffusion model) 代码示例

随机微分方程的分数扩散模型(Score-Based Generative Modeling through Stochastic Differential Equations)

基于分数的扩散模型,是估计数据分布梯度的方法,可以在不需要对抗训练的基础上,生成与GAN一样高质量的图片。来源于文章:Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." Internation Conference on Learning Representations, 2021

score-based diffusion是diffusion模型大火之后,又一个里程碑式的工作,将扩散模型和分数生成模型进行了统一。原始的扩散模型也有缺点,它的采样速度慢,通常需要数千个评估步骤才能抽取一个样本。而 score-based 的扩散模型可以在较短的时间内完成采样。

网络上有很多关于score-based diffusion原理介绍,应用案例等,还有文章解读,大家可以参考。但是,提供代码简介的很少,为此这里提供了score-based diffusion 模型的简单的可运行的代码示例。

1. 定义time-dependent score-based模型

导入相关模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm

1.1 将时间t嵌入的投影层

其实并没有投影层的说法,这里是为了描述将时间t (time step),随机初始化采样权重,然后使用[sin(2πωt);cos(2πωt)]生成相应的高斯随机特征向量的过程。注意,里面的参数是不可训练的。

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""  
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # 在初始化期间随机采样权重。 这些权重是固定的 
    # 在优化期间并且不可训练
    self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
  def forward(self, x):
    x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

将时间t嵌入的投影层的出现,是因为score-based的扩散模型和正常的扩散模型的训练过程不一样。score-based的扩散模型在训练过程中,神经网络接受带有随机噪音的 x ,然后随机的时间信息 t 添加x中,然后利用x 和 t 作为输入,计算模型损失。

维度转换全连接层:

class Dense(nn.Module):
  """A fully connected layer that reshapes outputs to feature maps."""
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)
  def forward(self, x):
    return self.dense(x)[..., None, None]

1.2 时间依赖基于分数的Unet模型

(time-dependent score-based model) 时间依赖,打分相关的Unet模型,froward函数中,输入除了x,还有时间t. 时间t经过GaussianFourierProjection嵌入后融合到模型中,然后输出marginal_prob_std正则化的结果。

class ScoreNet(nn.Module):
  """初始化一个依赖时间的基于分数的Unet网络."""

  def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
    """.

    Args:
      marginal_prob_std: 输入时间 t 并给出扰动核的标准差的函数 p_{0t}(x(t) | x(0)).
      channels: 各分辨率特征图的通道数.
      embed_dim: 高斯随机特征嵌入的维数,与1.1中GaussianFourierProjection相同.
    """
    super().__init__()
    # 时间t的高斯随机特征嵌入层
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, embed_dim))
    # Encoding layers where the resolution decreases
    self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
    self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
    self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
    self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])    

    # 分辨率增加的解码层
    self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
    self.dense5 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
    self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)    
    self.dense6 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
    self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)    
    self.dense7 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
    self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
    
    # Swish 激活函数
    self.act = lambda x: x * torch.sigmoid(x)
    self.marginal_prob_std = marginal_prob_std
  
  def forward(self, x, t): 
    # 0   
    embed = self.act(self.embed(t))    
    # Encoding path
    h1 = self.conv1(x)    
    ## 合并来自 t 的信息
    h1 += self.dense1(embed)
    ## 组标准化
    h1 = self.gnorm1(h1)
    h1 = self.act(h1)
    h2 = self.conv2(h1)
    h2 += self.dense2(embed)
    h2 = self.gnorm2(h2)
    h2 = self.act(h2)
    h3 = self.conv3(h2)
    h3 += self.dense3(embed)
    h3 = self.gnorm3(h3)
    h3 = self.act(h3)
    h4 = self.conv4(h3)
    h4 += self.dense4(embed)
    h4 = self.gnorm4(h4)
    h4 = self.act(h4)

    # Decoding path
    h = self.tconv4(h4)
    ## 从编码路径跳过连接
    h += self.dense5(embed)
    h = self.tgnorm4(h)
    h = self.act(h)
    h = self.tconv3(torch.cat([h, h3], dim=1))
    h += self.dense6(embed)
    h = self.tgnorm3(h)
    h = self.act(h)
    h = self.tconv2(torch.cat([h, h2], dim=1))
    h += self.dense7(embed)
    h = self.tgnorm2(h)
    h = self.act(h)
    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Normalize output 正则化输出
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h

2. 设置SDE

SDE用于将P_0扰动到P_T, 其中,包含两个重要函数:之前提到的marginal_prob_std和扩散系数diffusion_coeff marginal_prob_std,计算 p_{0t}(x(t) | x(0)) 的平均值和标准差; diffusion_coeff,计算SDE的扩散系数.

device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}

def marginal_prob_std(t, sigma):
  """计算p_{0t}(x(t) | x(0))的平均值和标准差.

  Args:    
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.  
  
  Returns:
    标准差.
  """    
  t = torch.tensor(t, device=device)
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  """计算SDE的扩散系数.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.
  
  Returns:
    扩散系数向量.
  """
  return torch.tensor(sigma**t, device=device)
  
sigma =  25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

3. 定义损失函数

损失函数是一个复杂的公式,但是具体形式固定. 代码如下:

def loss_fn(model, x, marginal_prob_std, eps=1e-5):
  """The loss function for training score-based generative models.

  Args:
    model: 时间依赖,基于分数的 PyTorch model.
    x: A mini-batch of training data.    
    marginal_prob_std: A function that gives the standard deviation of 
      the perturbation kernel.
    eps: A tolerance value for numerical stability.
  """
  random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps  
  z = torch.randn_like(x)
  std = marginal_prob_std(random_t)
  perturbed_x = x + z * std[:, None, None, None]
  score = model(perturbed_x, random_t)
  loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
  return loss

4. 训练模型

与正常的训练模型相似,调用模型,建立优化器,损失反向等;代码如下:

score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs =   50#@param {'type':'integer'}
## size of a mini-batch
batch_size =  32 #@param {'type':'integer'}
## learning rate
lr=1e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)
for epoch in tqdm_epoch:
  avg_loss = 0.
  num_items = 0
  for x, y in data_loader:
    x = x.to(device)    
    loss = loss_fn(score_model, x, marginal_prob_std_fn)
    optimizer.zero_grad()
    loss.backward()    
    optimizer.step()
    avg_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]
  # Print the averaged training loss so far.
  tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
  # Update the checkpoint after each epoch of training.
  torch.save(score_model.state_dict(), 'ckpt.pth')

训练过程输出如下:

5. 采样器/求解器

score-based diffusion模型有多种求解器,

5.1 欧拉-丸山采样器/求解器(Euler-Maruyama sampler)

欧拉-丸山采样方法属于数值SDE求解的方法,是基于神经网络预测的分数,利用逆时的SDE数值解,进行采样。

## 采样步数
num_steps =  500 #@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model, 
                           marginal_prob_std,
                           diffusion_coeff, 
                           batch_size=64, 
                           num_steps=num_steps, 
                           device='cuda', 
                           eps=1e-3):
  """使用 Euler-Maruyama 求解器从基于分数的模型生成样本.

  Args:
    score_model: 时间依赖,基于分数的 PyTorch model.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: 批次大小.
    num_steps: 采样步数, 等价于相当于离散时间步数.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: 数值稳定性的最小时间步长.
  
  Returns:
    采样样本.    
  """
  t = torch.ones(batch_size, device=device)
  init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
    * marginal_prob_std(t)[:, None, None, None]
  time_steps = torch.linspace(1., eps, num_steps, device=device)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):      
      batch_time_step = torch.ones(batch_size, device=device) * time_step
      g = diffusion_coeff(batch_time_step)
      mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
      x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)      
  # Do not include any noise in the last sampling step.
  return mean_x

5.2 预测-检验采样器

预测校正器采样器结合了逆时 SDE 的数值求解器和 Langevin MCMC 方法。 具体来说,我们首先应用数值 SDE 求解器的一个步骤从 xt 获得 xt−Δt,这称为“预测器”步骤。 接下来,我们应用 Langevin MCMC 的几个步骤来细化 xt ,使得 xt 成为 pt−Δt(x) 的更准确的样本。 这是“校正器”步骤,因为 MCMC 有助于减少数值 SDE 求解器的误差。

signal_to_noise_ratio = 0.16 #@param {'type':'number'}

## The number of sampling steps.
num_steps =  500#@param {'type':'integer'}
def pc_sampler(score_model, 
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64, 
               num_steps=num_steps, 
               snr=signal_to_noise_ratio,                
               device='cuda',
               eps=1e-3):
  """
  使用预测-校正方法从基于分数的模型生成样本.

  Args:
    score_model: 时间依赖,基于分数的 PyTorch model.
    marginal_prob_std: A function that gives the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient 
      of the SDE.
    batch_size: 批次大小.
    num_steps: 采样步数, 等价于相当于离散时间步数.    
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: 数值稳定性的最小时间步长.
  
  Returns: 
    采样样本.
  """
  t = torch.ones(batch_size, device=device)
  init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
  time_steps = np.linspace(1., eps, num_steps)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):      
      batch_time_step = torch.ones(batch_size, device=device) * time_step
      # 检验器 step (Langevin MCMC)
      grad = score_model(x, batch_time_step)
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = np.sqrt(np.prod(x.shape[1:]))
      langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
      x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)      

      # 预测器 step (Euler-Maruyama)
      g = diffusion_coeff(batch_time_step)
      x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
      x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)      
    
    # The last step does not include any noise
    return x_mean

5.3 ODE数值求解器

每一个SDE都对应着一个ODE,通过逆时间方向求解此 ODE, 我们可以从与求解逆时间 SDE 相同的分布中进行采样。 我们将此 ODE 称为概率流 ODE。 这可以使用 scipy 等软件包提供的许多黑盒 ODE 求解器来完成。

from scipy import integrate

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64, 
                atol=error_tolerance, 
                rtol=error_tolerance, 
                device='cuda', 
                z=None,
                eps=1e-3):
  """Generate samples from score-based models with black-box ODE solvers.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    marginal_prob_std: A function that returns the standard deviation 
      of the perturbation kernel.
    diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samplers to generate by calling this function once.
    atol: Tolerance of absolute errors.
    rtol: Tolerance of relative errors.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    z: The latent code that governs the final sample. If None, we start from p_1;
      otherwise, we start from the given z.
    eps: The smallest time step for numerical stability.
  """
  t = torch.ones(batch_size, device=device)
  # Create the latent code
  if z is None:
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
      * marginal_prob_std(t)[:, None, None, None]
  else:
    init_x = z
    
  shape = init_x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper of the score-based model for use by the ODE solver."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))    
    with torch.no_grad():    
      score = score_model(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)
  
  def ode_func(t, x):        
    """The ODE function for use by the ODE solver."""
    time_steps = np.ones((shape[0],)) * t    
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    return  -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
  
  # Run the black-box ODE solver.
  res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')  
  print(f"Number of function evaluations: {res.nfev}")
  x = torch.tensor(res.y[:, -1], device=device).reshape(shape)

  return x

6. 采样

from torchvision.utils import make_grid

## Load the pre-trained checkpoint from disk.
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 64 #@param {'type':'integer'}
# 采样器配置
sampler = ode_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}

## Generate samples using the specified sampler.
samples = sampler(score_model, 
                  marginal_prob_std_fn,
                  diffusion_coeff_fn, 
                  sample_batch_size, 
                  device=device)

## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()

输出结果如下:

大家可以多试试其他的采样器,看看不同采样器输出结果的区别。

7. 似然计算(Likelihood Computation)

概率流 ODE 公式的副产品是似然计算。

def prior_likelihood(z, sigma):
  """The likelihood of a Gaussian distribution with mean zero and 
      standard deviation sigma."""
  shape = z.shape
  N = np.prod(shape[1:])
  return -N / 2. * torch.log(2*np.pi*sigma**2) - torch.sum(z**2, dim=(1,2,3)) / (2 * sigma**2)

def ode_likelihood(x, 
                   score_model,
                   marginal_prob_std, 
                   diffusion_coeff,
                   batch_size=64, 
                   device='cuda',
                   eps=1e-5):
  """Compute the likelihood with probability flow ODE.
  
  Args:
    x: Input data.
    score_model: A PyTorch model representing the score-based model.
    marginal_prob_std: A function that gives the standard deviation of the 
      perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the 
      forward SDE.
    batch_size: The batch size. Equals to the leading dimension of `x`.
    device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
    eps: A `float` number. The smallest time step for numerical stability.

  Returns:
    z: The latent code for `x`.
    bpd: The log-likelihoods in bits/dim.
  """

  # Draw the random Gaussian sample for Skilling-Hutchinson's estimator.
  epsilon = torch.randn_like(x)
      
  def divergence_eval(sample, time_steps, epsilon):      
    """Compute the divergence of the score-based model with Skilling-Hutchinson."""
    with torch.enable_grad():
      sample.requires_grad_(True)
      score_e = torch.sum(score_model(sample, time_steps) * epsilon)
      grad_score_e = torch.autograd.grad(score_e, sample)[0]
    return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))    
  
  shape = x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper for evaluating the score-based model for the black-box ODE solver."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))    
    with torch.no_grad():    
      score = score_model(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)
  
  def divergence_eval_wrapper(sample, time_steps):
    """A wrapper for evaluating the divergence of score for the black-box ODE solver."""
    with torch.no_grad():
      # Obtain x(t) by solving the probability flow ODE.
      sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
      time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))    
      # Compute likelihood.
      div = divergence_eval(sample, time_steps, epsilon)
      return div.cpu().numpy().reshape((-1,)).astype(np.float64)
  
  def ode_func(t, x):
    """The ODE function for the black-box solver."""
    time_steps = np.ones((shape[0],)) * t    
    sample = x[:-shape[0]]
    logp = x[-shape[0]:]
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    sample_grad = -0.5 * g**2 * score_eval_wrapper(sample, time_steps)
    logp_grad = -0.5 * g**2 * divergence_eval_wrapper(sample, time_steps)
    return np.concatenate([sample_grad, logp_grad], axis=0)

  init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
  # Black-box ODE solver
  res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')  
  zp = torch.tensor(res.y[:, -1], device=device)
  z = zp[:-shape[0]].reshape(shape)
  delta_logp = zp[-shape[0]:].reshape(shape[0])
  sigma_max = marginal_prob_std(1.)
  prior_logp = prior_likelihood(z, sigma_max)
  bpd = -(prior_logp + delta_logp) / np.log(2)
  N = np.prod(shape[1:])
  bpd = bpd / N + 8.
  return z, bpd

计算数据集的似然率:
batch_size = 32 #@param {'type':'integer'}

dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

all_bpds = 0.
all_items = 0
try:
  tqdm_data = tqdm.notebook.tqdm(data_loader)
  for x, _ in tqdm_data:
    x = x.to(device)
    # uniform dequantization
    x = (x * 255. + torch.rand_like(x)) / 256.    
    _, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
                            diffusion_coeff_fn,
                            x.shape[0], device=device, eps=1e-5)
    all_bpds += bpd.sum()
    all_items += bpd.shape[0]
    tqdm_data.set_description("Average bits/dim: {:5f}".format(all_bpds / all_items))

except KeyboardInterrupt:
  # Remove the error message when interuptted by keyboard or GUI.
  pass

8. 自己心得总结:

(1)随机微分方程的分数扩散模型需要一个时间依赖的基于分数的神经网络;

(2)时间依赖的基于分数的神经网络forward函数,输入是扰动后的x, t,输出是分数,这一点与传统的扩散模型不同; 传统的扩散模型神经网络输入是扰动后的x,然后输出不带噪音的x或者噪音;

(3)在时间依赖的基于分数的神经网络forward函数中,需要几个重要的支持函数: GaussianFourierProjection:输入时间t,输出高斯随机特征向量,使t可以被整合到x中; marginal_prob_std:计算时间步t的方差,用于神经网络输出分数的归一化; *意味着基于分数的扩散模型,需要重新写模型架构

(4)基于分数的扩散模型的损失函数非常简单,为: loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))

(5)基于神经网络输出的分数采样有多种方法,分别为: 欧拉-丸山采样器(Euler-Maruyama sampler),预测-检验采样器,ODE数值求解器;

(6)每一种采样器都需要先设置好SDE,里面一个重要函数是diffusion_coeff_fn,用于计算SDE的扩散系数

(7)每一种采样器都有固定的形式直接使用就好;

写在最后,关于score-based的diffusion模型的原理,我这里并没有介绍。因为现在又很多博客或者公众号,视频都有详细的介绍,包括详细的共识推导,另外,我不是数学专业,里面很多的数学原理也是半知半解的,就不耽误大家了。大家可以查看相关资料。

关于原理,大家有想法或者想通俗的了解,可以留言,可以考虑出一个,专门说一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/117267.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Kotlin精简】第7章 泛型

1 泛型 泛型即 “参数化类型”,将类型参数化,可以用在类,接口,函数上。与 Java 一样,Kotlin 也提供泛型,为类型安全提供保证,消除类型强转的烦恼。 1.1 泛型优点 类型安全:通用允许…

CoDeSys系列-4、基于Ubuntu的codesys运行时扩展包搭建Profinet主从环境

CoDeSys系列-4、基于Ubuntu的codesys运行时扩展包搭建Profinet主从环境 文章目录 CoDeSys系列-4、基于Ubuntu的codesys运行时扩展包搭建Profinet主从环境一、前言二、资料收集三、Ubuntu18.04从安装到更换实时内核1、下载安装Ubuntu18.042、下载安装实时内核,解决编…

如何将PDF文件转换成翻页电子书?这个网站告诉你

​随着电子书的普及,越来越多的人开始将PDF文件转换成翻页电子书。翻页电子书不仅方便阅读,而且还可以在手机上轻松翻页。那么如何将PDF文件转换成翻页电子书呢?今天就为大家介绍一个网站,可以帮助你轻松完成这个任务。 1.首先&am…

Proteus仿真--12864LCD显示计算器键盘按键实验(仿真文件+程序)

本文主要介绍基于51单片机的12864LCD液晶显示电话拨号键盘按键实验(完整仿真源文件及代码见文末链接) 仿真图如下 本设计主要介绍计算器键盘仿真,按键按下后在12864液晶上显示对应按键键值 仿真运行视频 Proteus仿真--12864LCD显示计算器…

【漏洞复现】IIS_7.o7.5解析漏洞

感谢互联网提供分享知识与智慧,在法治的社会里,请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描3、漏洞验证 1.5、修复建议 1.1、漏洞描述 漏洞原理: cgi.fix_path1 1.png/.php该…

第九章《搞懂算法:决策树是怎么回事》笔记

决策树算法是机器学习中很经典的一个算法,它既可以作为分类算法,也可以作为回归算法。 9.1 典型的决策树是什么样的 决策树算法是依据“分而治之”的思想,每次根据某属性的值对样本进行分类,然后传递给下个属性继续进行分类判断…

项目实战:新增@Controller和@Service@Repository@Autowire四个注解

1、Controller package com.csdn.mymvc.annotation; import java.lang.annotation.*; Target(ElementType.TYPE) Retention(RetentionPolicy.RUNTIME) Inherited public interface Controller { }2、Service package com.csdn.mymvc.annotation; import java.lang.annotation.*…

zookeeper节点类型

节点类型 持久节点(Persistent Nodes) 这些是Zookeeper中最常见的一种节点类型,当创建一个持久类型节点时,该值会一直存在zookeeper中,直到被显式删除或被新值覆盖。 临时节点(Ephemeral Nodes&#xff…

【漏洞复现】Apache_Tomcat_PUT方法任意写文件(CVE-2017-12615)

感谢互联网提供分享知识与智慧,在法治的社会里,请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描3、漏洞验证工具扫描验证POC 1.6、修复建议 说明内容漏洞编号CVE-2017-12615漏洞名称Tomcat_PU…

【python】路径管理+路径拼接问题

路径管理 问题相对路径问题绝对路径问题 解决os库pathlib库最终解决 问题 环境:python3.7.16 win10 相对路径问题 因为python的执行特殊性,使用相对路径时,在不同路径下用python指令会有不同的索引效果(python的项目根目录根据执…

服务器搭建:从零开始创建自己的Spring Boot应用【含登录、注册功能】

当然,你可以先按照IDEA搭建SSM框架【配置类、新手向】完成基础框架的搭建 步骤 1:设计并实现服务器端的用户数据库 在这个示例中,我们将使用MySQL数据库。首先,你需要安装MySQL并创建一个数据库以存储用户信息。以下是一些基本步…

5.3有效的括号(LC20-E)

算法: 题目中:左括号必须以正确的顺序闭合。意思是,最后出现的左括号(对应着栈中的最后一个元素),应该先找到对应的闭合符号(右括号) 比如:s"( [ ) ]"就是False&#xf…

【错误解决方案】ModuleNotFoundError: No module named ‘my_fake_useragent‘

1. 错误提示 ModuleNotFoundError: No module named my_fake_useragent,这意味着你试图导入一个名为 my_fake_useragent 的模块,但Python找不到这个模块。 2. 解决方案 检查模块名是否正确: 确保你试图导入的模块名是正确的。也许你拼写错误或者大小写不…

【Midjourney入门教程1】Midjourney的注册、订阅

文章目录 前言一、Midjourney是什么二、Midjourney注册三、新建自己的服务器四、开通订阅 前言 AI绘画即指人工智能绘画,是一种计算机生成绘画的方式。是AIGC应用领域内的一大分支。 AI绘画主要分为两个部分,一个是对图像的分析与判断,即“…

onnx 模型加载部署运行方式

1.通过文件路径的onnx模型加载方式: 在onnxruntime下面的主要函数:session Ort::Session(env, w_modelPath.c_str(), sessionOptions); 这里的文件路径是宽字节的,通过onnx文件路径直接加载模型。 在opencv下使用dnn加载onnx模型的主要函数: std::string model…

Redo Log(重做日志)的刷盘策略

1. 概述 Redo Log(重做日志)是 InnoDB 存储引擎中的一种关键组件,用于保障数据库事务的持久性和崩溃恢复。InnoDB 将事务所做的更改先记录到重做日志,之后再将其应用到磁盘上的数据页。 刷盘策略(Flush Policy&#x…

css基础之实现轮播图

原理介绍 图片轮播的原理是通过控制显示和隐藏不同的图片来实现图像的切换,从而创建连续播放的效果。用到的知识点有定位和定时器。 实现步骤: HTML 结构: 首先,需要在HTML中创建一个包含轮播图片的容器,通常使用 &l…

Golang源码分析之golang/sync之singleflight

1.1. 项目介绍 golang/sync库拓展了官方自带的sync库,提供了errgroup、semaphore、singleflight及syncmap四个包,本次分析singlefliht的源代码。 singlefliht用于解决单机协程并发调用下的重复调用问题,常与缓存一起使用,避免缓存…

要做一名成功的测试,首先得会想?

近在做测试时,突然想到了这么个问题——在测试的过程中对某个功能想得越开,测试就完整,就越彻底! 当然我们在产生与该功能相关的想象时,其中最关键的是不能脱离需求,不能脱离该软件本身;不然这…

WebSocket Day02 : 握手连接

前言 握手连接是WebSocket建立通信的第一步,通过客户端和服务器之间的一系列握手操作,确保了双方都支持WebSocket协议,并达成一致的通信参数。握手连接的过程包括客户端发起握手请求、服务器响应握手请求以及双方完成握手连接。完成握手连接后…