1--DDIM Introduction
Original paper: DENOISING DIFFUSION IMPLICIT MODELS
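The code comments below refer to Eq. (12) of the DDIM paper. In the paper's notation, where \alpha_t denotes the cumulative product of the per-step alphas (written \bar{\alpha}_t in DDPM), the sampling update and the standard deviation \sigma_t are:

x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left( \frac{x_t - \sqrt{1 - \alpha_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right)}_{\text{predicted } x_0} + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_t^2}\, \epsilon_\theta(x_t, t)}_{\text{direction pointing to } x_t} + \sigma_t \epsilon_t

\sigma_t(\eta) = \eta \sqrt{(1 - \alpha_{t-1}) / (1 - \alpha_t)}\, \sqrt{1 - \alpha_t / \alpha_{t-1}}

Setting \eta = 0 removes the random-noise term and gives the deterministic DDIM sampler; \eta = 1 corresponds to the DDPM-style stochastic update.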
2--Core Code
# DDIM implementation
import torch

def compute_alpha(beta, t):
    # Prepend a zero so the cumulative product starts at 1: beta -> [0, beta_1, ..., beta_T]
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    # cumprod computes the cumulative product alpha_hat_t = alpha_t * alpha_(t-1) * ... * alpha_1
    # (with a leading 1 from the padding); alpha_hat_t is then selected with index t + 1,
    # so t = -1 maps to alpha_hat = 1
    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
    return a
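As a quick sanity check of compute_alpha (illustrative toy values, not part of the original snippet), the padded cumulative product maps index t + 1 to alpha_hat_t, so t = -1 returns 1, which is why the sampling loop below can use -1 as the final "previous" timestep:

beta_toy = torch.tensor([0.1, 0.2, 0.3])
t_toy = torch.tensor([-1, 0, 2])
print(compute_alpha(beta_toy, t_toy).flatten())  # tensor([1.0000, 0.9000, 0.5040])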
# DDIM sampling loop, adapted from: https://github.com/ermongroup/ddim/blob/main/functions/denoising.py
def generalized_steps(x, seq, model, b, **kwargs):
    with torch.no_grad():
        n = x.size(0)                                    # batch size
        seq_next = [-1] + list(seq[:-1])                 # shifted sub-sequence, e.g. [-1, 0, 10, 20, ..., 980] for 100 sampling steps
        x0_preds = []
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)):  # i = current step t, j = previous step in the sub-sequence
            t = (torch.ones(n) * i).to(x.device)         # current timestep t
            next_t = (torch.ones(n) * j).to(x.device)    # previous timestep in the skipped sequence
            at = compute_alpha(b, t.long())              # alpha_hat_t
            at_next = compute_alpha(b, next_t.long())    # alpha_hat_(t-1)
            xt = xs[-1].to('cuda')                       # current sample x_t
            et = model(xt, t)                            # predicted noise eps_theta(x_t, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # "predicted x0" in Eq. (12) of the paper
            x0_preds.append(x0_t.to('cpu'))              # record the predicted x0 at this step
            # standard deviation sigma_t in Eq. (12); eta = 0 makes the update deterministic
            c1 = (kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt())
            # coefficient of the "direction pointing to x_t" term in Eq. (12)
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            # compute x_(t-1) according to Eq. (12)
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xs.append(xt_next.to('cpu'))                 # record x_(t-1) at every step
    return xs, x0_preds                                  # trajectories of x_t and of the predicted x0
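A minimal sketch of how these two functions might be called end to end. The linear beta schedule, the 100-step sub-sequence seq, and the DummyModel placeholder are illustrative assumptions, not part of the original code; in practice a trained noise-prediction network (e.g. a U-Net) is used, and since the snippet above hard-codes .to('cuda'), a GPU is assumed:

# Hypothetical end-to-end usage (illustrative values)
import torch

num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps).to('cuda')   # assumed linear beta schedule
skip = num_timesteps // 100
seq = list(range(0, num_timesteps, skip))                      # [0, 10, 20, ..., 990]

class DummyModel(torch.nn.Module):
    # Placeholder for a trained eps_theta(x_t, t); returns zeros just to make the loop runnable
    def forward(self, x, t):
        return torch.zeros_like(x)

model = DummyModel().to('cuda')
x_T = torch.randn(4, 3, 32, 32, device='cuda')                 # start from pure Gaussian noise
xs, x0_preds = generalized_steps(x_T, seq, model, betas, eta=0.0)
x_0 = xs[-1]                                                   # final generated sample (moved to CPU inside the loop)

With eta=0.0 the trajectory is deterministic (pure DDIM); passing eta=1.0 recovers the DDPM-like stochastic sampler described in the paper.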
3--Complete Code
DDIM_Demo