创建一个类别条件扩散模型
- 1. 配置和数据准备
- 2. 创建一个以类别为条件的UNet模型
- 3. 训练和采样
本文介绍一种给扩散模型添加额外条件信息的方法。具体地,将在MNIST数据集上训练一个以类别为条件的扩散模型。并且可以在推理阶段指定想要生成的是哪个数字。
1. 配置和数据准备
首先安装diffusers库:
!pip install -q diffusers
导入相关依赖包:
加载MNIST数据集:
# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(
root="./mnist/",
train=True,
download=True,
transform=torchvision.transforms.ToTensor()
)
# 创建数据加载器
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# 查看MNIST数据集中的部分样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
2. 创建一个以类别为条件的UNet模型
输入类别这一条件的流程:
(1)创建一个标准的UNet2DModel
,加入一些额外的输入通道。
(2)通过一个嵌入层,把类别标签映射到一个长度为class_emb_size
的特征向量上。
(3)把这个信息作为额外通道和原有的输入向量拼接起来。
net_input = torch.cat((x, class_cond), 1)
(4)将net_input
(其中包含class_emb_size + 1
个通道)输入UNet
模型,得到最终的预测结果。
这里,class_emb_size被设置成4,但它其实是可以进行任意修改的,或者把需要学到的nn.Embedding替换成简单地对类别进行one-hot编码,代码如下:
class ClassConditionedUnet(nn.Module):
def __init__(self, num_classes=10, class_emb_size=4):
super().__init__()
# 这个网络层会把数字所属的类别映射到一个长度为class_emb_size的特征向量上
self.class_emb = nn.Embedding(num_classes, class_emb_size)
# self.model是一个不带生成条件的UNet模型,这里,给他添加了额外的输入通道,用于接收条件信息
self.model = UNet2DModel(
sample_size = 28, # 所生成图片的尺寸
in_channels = 1 + class_emb_size, # 加入额外的输入通道
out_channels = 1, # 输出结果的通道数
layers_per_block=2, # 残差层个数
block_out_channels=(32, 64, 64),
down_block_types=(
"DownBlock2D", # 常规的ResNet下采样模块
"AttnDownBlock2D", # 含有spatial self-attention的ResNet下采样模块
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # 含有spatil self-attention的ResNet上采样模块
"UpBlock2D", # 上采样模块
),
)
# 此时扩散模型的前向计算就会含有额外的类别标签作为输入了
def forward(self, x, t, class_labels):
bs, ch, w, h = x.shape
# 类别条件将会以额外通道的形式输入
class_cond = self.class_emb(class_labels) # 将类别映射为向量形式,
# 并扩展成类似于(bs, 4, 28, 28)的张量形式
class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
# 将原始输入和类别条件信息拼接到一起
net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
# 使用模型进行预测
return self.model(net_input, t).sample # (bs, 1, 28, 28)
3. 训练和采样
这里使用prediction = unet(x, t, y)
在训练时把正确的标签作为第三个输入发送给模型。如果一切正常,模型将会输出与之相匹配的图片。y在这里的范围是0~9.
# 创建一个调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
# 定义数据加载器
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
n_epochs = 10
loss_fn = nn.MSELoss()
net = ClassConditionedUnet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []
# 训练开始
for epoch in range(n_epochs):
for x, y in tqdm(train_dataloader):
# 获取数据并添加噪声
x = x.to(device) * 2 - 1 # 数据被归一化到区间(-1, 1)
y = y.to(device)
noise = torch.randn_like(x)
timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
# 预测
pred = net(noisy_x, timesteps, y) # 注意这里输入了类别信息
# 计算损失值
loss = loss_fn(pred, noise)
opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.item())
avg_loss = sum(losses[-100:])/100
print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')
plt.plot(losses)
Finished epoch 0. Average of the last 100 loss values: 0.053393
Finished epoch 1. Average of the last 100 loss values: 0.047172
Finished epoch 2. Average of the last 100 loss values: 0.045227
Finished epoch 3. Average of the last 100 loss values: 0.043402
Finished epoch 4. Average of the last 100 loss values: 0.041524
Finished epoch 5. Average of the last 100 loss values: 0.040847
Finished epoch 6. Average of the last 100 loss values: 0.040252
Finished epoch 7. Average of the last 100 loss values: 0.040134
Finished epoch 8. Average of the last 100 loss values: 0.038976
Finished epoch 9. Average of the last 100 loss values: 0.039234
损失曲线:
训练结束后,可以通过输入不同的标签作为条件来采样图片:
# 准备一个随机噪声作为起点,并准备想要的图片标签
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)
print(y)
# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
with torch.no_grad():
residual = net(x, t, y)
x = noise_scheduler.step(residual, t, x).prev_sample
# 显示结果
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
这里,我们的y标签为:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0'
因此对应生成的图片为:
至此,已经实现了对输出图片的控制。