Diffusion on the MNIST dataset
- MNIST dataset
- ddpm/script_utils.py
- scripts/train_mnist.py
- Displaying the sampling results
Code source: https://github.com/abarankab/DDPM
How to fix the wandb issue:
Step 1: follow https://blog.csdn.net/weixin_43164054/article/details/124156206 step by step.
Step 2: set project_name="cifar" and run python train_cifar.py. If you hit the error "wandb: ERROR It appears that you do not have permission to access the requested resource.", see https://blog.csdn.net/weixin_43835996/article/details/126955917.
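A quick way to check that the wandb credentials work before launching a full training run. This is a minimal sketch using the standard wandb.login and wandb.init calls; the API key string is a placeholder you get from https://wandb.ai/authorize, not something from this post:

import wandb

wandb.login(key="YOUR_API_KEY")    # placeholder key; taken from https://wandb.ai/authorize
run = wandb.init(project="cifar")  # should succeed without the permission error
run.finish()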
MNIST dataset
ddpm/script_utils.py
Line 90: img_channels=1, because CIFAR images have 3 channels while MNIST images have 1.
Line 101: initial_pad=2, because CIFAR images are 32*32, a power of 2, so their size stays divisible by 2 throughout downsampling; MNIST images are 28*28, so they are padded up to 32*32 by setting initial_pad=2.
Line 120: CIFAR-10 images are 32*32, MNIST images are 28*28.
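Before the full file, a quick sanity check of that padding arithmetic. This is a minimal sketch that only assumes initial_pad is applied as symmetric padding on the input (e.g. via F.pad); it is not copied from the repo's UNet:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)          # a batch of one single-channel MNIST image
x = F.pad(x, (2, 2, 2, 2))             # 2 pixels on each border: 28 + 2*2 = 32
print(x.shape)                         # torch.Size([1, 1, 32, 32])

# 32 halves cleanly through every downsampling stage, 28 does not:
print([32 // 2**k for k in range(4)])  # [32, 16, 8, 4]
print([28 / 2**k for k in range(4)])   # [28.0, 14.0, 7.0, 3.5] -> 7 is odd, so 28 breaks down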
import argparse
import torchvision
import torch.nn.functional as F

from .unet import UNet
from .diffusion import (
    GaussianDiffusion,
    generate_linear_schedule,
    generate_cosine_schedule,
)


def cycle(dl):
    """
    https://github.com/lucidrains/denoising-diffusion-pytorch/
    """
    while True:
        for data in dl:
            yield data


def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1

    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])
def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def diffusion_defaults():
    defaults = dict(
        num_timesteps=1000,
        schedule="linear",
        loss_type="l2",
        use_labels=False,
        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        num_res_blocks=2,
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation="silu",
        attention_resolutions=(1,),
        ema_decay=0.9999,
        ema_update_rate=1,
    )
    return defaults
def get_diffusion_from_args(args):
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }

    # base_channels=128
    model = UNet(
        img_channels=1,  # MNIST is single-channel (CIFAR uses 3)
        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,
        num_classes=None if not args.use_labels else 10,
        initial_pad=2,
    )
    # line 102: initial_pad=0 in the CIFAR version

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    # This file changes three places in total: line 90, line 101, line 120.
    # CIFAR version: model, (32, 32), 3, 10,
    # CIFAR-10 images are 32*32 with 3 channels; MNIST images are 28*28 with 1 channel
    diffusion = GaussianDiffusion(
        model, (28, 28), 1, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )

    return diffusion
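generate_linear_schedule and generate_cosine_schedule live in ddpm/diffusion.py and are not reproduced in this post. For reference, here is a minimal sketch of the standard DDPM schedules they presumably implement (linear betas from Ho et al. 2020, cosine betas from Nichol and Dhariwal 2021); treat the exact formulas as an assumption rather than the repo's code. Note that get_diffusion_from_args rescales schedule_low and schedule_high by 1000 / num_timesteps before building the linear schedule.

import numpy as np

def linear_betas(T, low=1e-4, high=0.02):
    # evenly spaced betas between low and high (Ho et al., 2020)
    return np.linspace(low, high, T)

def cosine_betas(T, s=0.008):
    # betas derived from a cosine alpha-bar curve (Nichol and Dhariwal, 2021)
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0, 0.999)

print(linear_betas(1000)[:3])  # betas grow linearly from 1e-4 towards 0.02
print(cosine_betas(1000)[:3])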
scripts/train_mnist.py
Remove the entity='treaptofun' argument from wandb.init.
import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils
def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)

        # resume training from the checkpoint saved when the last run was interrupted
        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            # wandb.init(project="ddpm_cifar")
            run = wandb.init(
                project=args.project_name,
                config=vars(args),
                name=args.run_name,
            )
            # entity='treaptofun',
            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )

        test_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)

        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()

            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()

                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)

                # rescale samples from [-1, 1] back to [0, 1] and move channels last for logging
                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0

            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)

        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")
def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=80000,
        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="./ddpm_logs_mnist",
        project_name="mnist",
        run_name=run_name,
        model_checkpoint=None,
        optim_checkpoint=None,
        schedule_low=1e-4,
        schedule_high=0.02,
        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
Training command (run from the command line):
python train_mnist.py
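Because model_checkpoint and optim_checkpoint sit in the defaults dict, add_dict_to_argparser also exposes them as command-line flags, so an interrupted run can be resumed. The .pth paths below are placeholders:

python train_mnist.py --model_checkpoint "your saved model .pth" --optim_checkpoint "your saved optimizer .pth"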
Sampling command (run from the command line):
python sample_images.py --model_path "your model path" --save_dir "your save img path" --schedule cosine
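sample_images.py itself is not shown in this post. Below is a minimal sketch of an equivalent sampling-and-saving loop built only from the pieces above (get_diffusion_from_args and diffusion.sample); the Namespace construction, the file names, and the save loop are illustrative assumptions, not the repo's actual script:

import argparse
import os
import torch
import torchvision
from ddpm import script_utils

# placeholder paths; the real script takes them via --model_path / --save_dir
model_path = "your model path"
save_dir = "your save img path"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# the args must match the ones used for training; diffusion_defaults() supplies them,
# plus the two linear-schedule endpoints from the training defaults
args = argparse.Namespace(schedule_low=1e-4, schedule_high=0.02, **script_utils.diffusion_defaults())

diffusion = script_utils.get_diffusion_from_args(args).to(device)
diffusion.load_state_dict(torch.load(model_path, map_location=device))

os.makedirs(save_dir, exist_ok=True)
samples = diffusion.sample(10, device)      # (10, 1, 28, 28), values roughly in [-1, 1]
samples = ((samples + 1) / 2).clip(0, 1)    # map back to [0, 1] for saving
for i, img in enumerate(samples):
    torchvision.utils.save_image(img, os.path.join(save_dir, f"{i}.png"))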
Displaying the sampling results
import os

import matplotlib.pyplot as plt
import numpy as np


def show(num_imgs, dir_path):
    '''
    num_imgs: number of images to display
    dir_path: directory containing the images
    '''
    img_names = os.listdir(dir_path)
    img_names.sort(key=lambda x: int(x.split('.')[0]))  # sort by the numeric file name

    plt.figure(figsize=(20, 5))  # canvas size
    N = 2
    M = 10
    # lay the images out on an N x M grid
    for i in range(num_imgs):
        path = os.path.join(dir_path, img_names[i])
        img = plt.imread(path)
        plt.subplot(N, M, i + 1)  # subplot indices start at 1, not 0
        plt.imshow(img)
        plt.title(img_names[i], color='black')
        # remove each image's own axis ticks, which would otherwise clutter the figure
        plt.xticks([])
        plt.yticks([])
    plt.show()


print("mnist generation results:")
show(20, './scripts/save_dir_mnist/')  # directory where the trained model's samples were saved
The titles shown here are just the indices of the generated images, not predicted labels!