文章目录
- DiT原始项目
- Fast-DiT readme
- Sampling
- Training
- 训练之前的准备
- 训练DiT
- PyTorch 训练结果
- 改进训练效果
- Evaluation (FID, Inception Score, etc.)
- 总结
DiT原始项目
该项目仅针对DiT训练,并未包含VAE 的训练
项目地址
论文主页
Fast-DiT readme
该项目仅针对DiT训练,并未包含VAE 的训练
项目地址
该项目是基于论文Scalable Diffusion Models with Transformers的pytorch改进实现
包含:
-
PyTorch的改进实现 和原始的DiT实现
-
在ImageNet (512x512 and 256x256)预训练的条件分类的 DiT 模型;
-
一个独立的 Hugging Face Space 和 Colab notebook,用于运行预训练的DiT-XL/2模型
-
改进的DiT 训练脚本 和一些 训练建议
启动
首先,下载开源代码
git clone https://github.com/chuanyangjin/fast-DiT.git
cd DiT
我们提供了 environment.yml
文件,可创建Conda 虚拟环境。
如果想要在本地CPU运行预训练的模型,可以在文件中删除cudatoolkit
和pytorch-cuda
相关的依赖项。
conda env create -f environment.yml
conda activate DiT
Sampling
预训练的DiT checkpoints
你可以使用预训练模型样例sample.py
预训练的DiT模型权重会根据使用的模型自动下载。
根据输入模型尺寸的不同,脚本中进行了不同的参数设置转换(256x256 and 512x512),比如针对512x512 DiT-XL/2 模型,你可以使用以下命令:
python sample.py --image-size 512 --seed 1
为了更加方便,我们的预训练模型也可直接进行下载:
DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
---|---|---|---|---|
XL/2 | 256x256 | 2.27 | 278.24 | 119 |
XL/2 | 512x512 | 3.04 | 240.82 | 525 |
自定义的DiT checkpoints
如果你想训练一个新的DiT 模型,可使用 train.py
(see below)。
你可以增加一个参数–ckpt 使用你自己的checkpoint进行演示。比如运行一个 256x256 DiT-L/4 模型,可使用以下命令 :
If you’ve trained a new DiT model with train.py
(see below), you can add the --ckpt
argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
Training
训练之前的准备
在一个GPU节点上抽取ImageNet 特征:
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features
训练DiT
We provide a training script for DiT in train.py
. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.
我们提供了一个DiT的训练脚本 train.py
。该训练脚本可以被用作训练条件分类的 DiT 模型,但是也可修改后用于其他类型的条件。
To launch DiT-XL/2 (256x256) training with 1
GPUs on one node:
accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
To launch DiT-XL/2 (256x256) training with N
GPUs on one node:
accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
或者,你可以选择提取并训练文件夹training options中的脚本。
PyTorch 训练结果
我们训练了 DiT-XL/2 和 DiT-B/4 模型
我们使用PyTorch训练脚本从头开始训练DiT-XL/2和DiT-B/4模型,以验证它重现了原始的JAX结果,达到数十万次训练迭代。
在我们的实验中,在合理的随机变化范围内,pytorch训练的模型与jax训练的模型相比,给出了类似(有时略好)的结果。一些数据点如下:
DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed |
---|---|---|---|---|
XL/2 | 400K | 19.5 | 18.1 | 42 |
B/4 | 400K | 68.4 | 68.9 | 42 |
B/4 | 400K | 68.4 | 68.3 | 100 |
这些模型在256x256分辨率下进行训练; 我们使用8x A100来训练XL/2,使用4x A100来训练B/4。注意,这里的FID是使用mse
VAE解码器,在没有指导的情况下(cfg-scale=1
),通过250个DDPM采样步骤计算得到的。
改进训练效果
与原始实现相比,实现了一些训练加速和节省内存的特征,包括梯度检查点、混合精度训练和预提取的VAE特征,在DiT-XL/2上的速度提高了95%,内存减少了60%。一些数据点使用A100的全局批处理大小为128:
gradient checkpointing | mixed precision training | feature pre-extraction | training speed | memory |
---|---|---|---|---|
❌ | ❌ | ❌ | - | out of memory |
✔ | ❌ | ❌ | 0.43 steps/sec | 44045 MB |
✔ | ✔ | ❌ | 0.56 steps/sec | 40461 MB |
✔ | ✔ | ✔ | 0.84 steps/sec | 27485 MB |
Evaluation (FID, Inception Score, etc.)
我们提供一个’ sample_ddp.py '脚本,它可以并行地从DiT模型中对大量图像进行采样。
这个脚本生成一个样本文件夹和. npz文件一样,可以直接与ADM的TensorFlow评估套件一起使用,以计算FID, Inception分数和其他指标。
例如,要在N
个gpu上从预训练的DiT-XL/2模型中采样50K张图像,运行:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
补充建议,详情请参见 sample_ddp.py
总结
开源项目仅能下载DiT-XL模型
微软开源了DiT-B模型,下载链接