【扩散模型系列3】DiT开源项目

文章目录

  • DiT原始项目
  • Fast-DiT readme
    • Sampling
    • Training
      • 训练之前的准备
      • 训练DiT
      • PyTorch 训练结果
      • 改进训练效果
    • Evaluation (FID, Inception Score, etc.)
  • 总结

DiT原始项目

该项目仅针对DiT训练,并未包含VAE 的训练

项目地址

论文主页

Fast-DiT readme

该项目仅针对DiT训练,并未包含VAE 的训练

项目地址

该项目是基于论文Scalable Diffusion Models with Transformers的pytorch改进实现

包含:

  • PyTorch的改进实现 和原始的DiT实现

  • 在ImageNet (512x512 and 256x256)预训练的条件分类的 DiT 模型;

  • 一个独立的 Hugging Face Space 和 Colab notebook,用于运行预训练的DiT-XL/2模型

  • 改进的DiT 训练脚本 和一些 训练建议

启动
首先,下载开源代码

git clone https://github.com/chuanyangjin/fast-DiT.git
cd DiT

我们提供了 environment.yml 文件,可创建Conda 虚拟环境。

如果想要在本地CPU运行预训练的模型,可以在文件中删除cudatoolkitpytorch-cuda 相关的依赖项。

conda env create -f environment.yml
conda activate DiT

Sampling

在这里插入图片描述

预训练的DiT checkpoints

你可以使用预训练模型样例sample.py

预训练的DiT模型权重会根据使用的模型自动下载。

根据输入模型尺寸的不同,脚本中进行了不同的参数设置转换(256x256 and 512x512),比如针对512x512 DiT-XL/2 模型,你可以使用以下命令:

python sample.py --image-size 512 --seed 1

为了更加方便,我们的预训练模型也可直接进行下载:

DiT ModelImage ResolutionFID-50KInception ScoreGflops
XL/2256x2562.27278.24119
XL/2512x5123.04240.82525

自定义的DiT checkpoints

如果你想训练一个新的DiT 模型,可使用 train.py (see below)。

你可以增加一个参数–ckpt 使用你自己的checkpoint进行演示。比如运行一个 256x256 DiT-L/4 模型,可使用以下命令 :

If you’ve trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt

Training

训练之前的准备

在一个GPU节点上抽取ImageNet 特征:

torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features

训练DiT

We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.

我们提供了一个DiT的训练脚本 train.py。该训练脚本可以被用作训练条件分类的 DiT 模型,但是也可修改后用于其他类型的条件。

To launch DiT-XL/2 (256x256) training with 1 GPUs on one node:

accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features

To launch DiT-XL/2 (256x256) training with N GPUs on one node:

accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features

或者,你可以选择提取并训练文件夹training options中的脚本。

PyTorch 训练结果

我们训练了 DiT-XL/2 和 DiT-B/4 模型

我们使用PyTorch训练脚本从头开始训练DiT-XL/2和DiT-B/4模型,以验证它重现了原始的JAX结果,达到数十万次训练迭代。

在我们的实验中,在合理的随机变化范围内,pytorch训练的模型与jax训练的模型相比,给出了类似(有时略好)的结果。一些数据点如下:

DiT ModelTrain StepsFID-50K (JAX Training)FID-50K (PyTorch Training)PyTorch Global Training Seed
XL/2400K19.518.142
B/4400K68.468.942
B/4400K68.468.3100

这些模型在256x256分辨率下进行训练; 我们使用8x A100来训练XL/2,使用4x A100来训练B/4。注意,这里的FID是使用mse VAE解码器,在没有指导的情况下(cfg-scale=1),通过250个DDPM采样步骤计算得到的。

改进训练效果

与原始实现相比,实现了一些训练加速和节省内存的特征,包括梯度检查点、混合精度训练和预提取的VAE特征,在DiT-XL/2上的速度提高了95%,内存减少了60%。一些数据点使用A100的全局批处理大小为128:

gradient checkpointingmixed precision trainingfeature pre-extractiontraining speedmemory
-out of memory
0.43 steps/sec44045 MB
0.56 steps/sec40461 MB
0.84 steps/sec27485 MB

Evaluation (FID, Inception Score, etc.)

我们提供一个’ sample_ddp.py '脚本,它可以并行地从DiT模型中对大量图像进行采样。

这个脚本生成一个样本文件夹和. npz文件一样,可以直接与ADM的TensorFlow评估套件一起使用,以计算FID, Inception分数和其他指标。

例如,要在N个gpu上从预训练的DiT-XL/2模型中采样50K张图像,运行:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

补充建议,详情请参见 sample_ddp.py

总结

开源项目仅能下载DiT-XL模型
微软开源了DiT-B模型,下载链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/429969.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

性能优化篇(七) UI优化注意事项以及使用Sprite Atlas打包精灵图集

UI优化注意事项 1.尽量避免使用IMGUI(OnGUI)来做游戏时的UI,因为IMGUI的开销比较大。 2.如果一个UGUI的控件不需要进行射线检测,则可以取消勾选Raycast Target 3.尽量避免使用完全透明的图片和UI控件。因为即使完全透明,我们看不见它&#xf…

论文笔记:Code Llama: Open Foundation Models for Code

导语 Code Llama是开源模型Llama 2在代码领域的一个专有模型,作者通过在代码数据集上进行进一步训练得到了了适用于该领域的专有模型,并在测试基准中超过了同等参数规模的其他公开模型。 链接:https://arxiv.org/abs/2308.12950机构&#x…

[cg] Games 202 - NPR 非真实感渲染

NPR特性(基于真实感渲染) 真实感--》翻译成非真实感的过程 NPR风格 需要转换为渲染中的操作 1.描边 B-->普通边界(不是下面几种的) C-->折痕 M-->材质边界 S-->需要在物体外面一圈上,并且是多个面共享…

win11部署自己的privateGpt(2024-0304)

什么是privateGpt? privategpt开源项目地址 https://github.com/imartinez/privateGPT/tree/main 官方文档 https://docs.privategpt.dev/overview/welcome/welcome PrivateGPT是一个可投入生产的人工智能项目,利用大型语言模型(LLMs)的…

流行 NFT 的必备指南

​作者:stellafootprint.network 编译:mingfootprint.network 来源:Footprint Analytics Blog 随着爱好者们对 NFT 的兴趣不断高涨,Footprint Analytics 发布了一系列文章,重点介绍各种热门 NFT 系列。这些文章深入…

GBU808-ASEMI整流桥GBU808参数、封装、尺寸

编辑:ll GBU808-ASEMI整流桥GBU808参数、封装、尺寸 型号:GBU808 品牌:ASEMI 封装:GBU-4 最大重复峰值反向电压:800V 最大正向平均整流电流(Vdss):8A 功率(Pd):中小功率 芯片个数&#…

【网站项目】075学生信息管理系统

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

xss.haozi.me:0x0B

<svg><script>(1)</script>

WebSocket 详解教程

概述 WebSocket 是什么&#xff1f; WebSocket 是一种网络通信协议。RFC6455 定义了它的通信标准。 WebSocket 是 HTML5 开始提供的一种在单个 TCP 连接上进行全双工通讯的协议。 为什么需要 WebSocket &#xff1f; 了解计算机网络协议的人&#xff0c;应该都知道&#x…

基于云的虚拟桌面基础架构 (VDI)的优势有哪些?

OpenText™ Exceed TurboX™ &#xff08;ETX&#xff09; 长期以来一直是虚拟化在 Linux 主机上运行的图形要求苛刻的软件的黄金标准。ETX 最新版本&#xff08;12.5&#xff09;增加了许多Microsoft Windows功能&#xff0c;包括Windows服务器&#xff0c;使ETX成为任何Windo…

HCIA-Datacom题库(自己整理分类的)_33_DHCP协议多选【7道题】

1.使用动态主机配置协议DHCP分配IP地址有哪些优点? 可以实现IP地址重复利用 工作量大且不好管理 配置信息发生变化(如DNS),只需要管理员在DHCP服务器上修改,方便统一管理 避免IP地址冲突 2.网络中部署了一台DHCP服务器,但是管理员发现部分主机并没有正确获取到该DHCP服务…

Kafka面经

1.Kafka如何保证消息不丢失 生产者&#xff1a; 1.Producer 默认是异步发送消息&#xff0c;这种情况下要确保消息发送成功&#xff0c;有两个方法 a. 把异步发送改成同步发送&#xff0c;这样 producer 就能实时知道消息发送的结果。 b. 添加异步回调函数来监听消息发送的结…

UNIapp实现局域网内在线升级

首先是UNIapp 生成apk 用Hbuilder 进行打包 可以从网站https://www.yunedit.com/reg?gotocert 使用自有证书&#xff0c;目测比直接使用云证书要快一些。 发布apk 网站 用IIS发布即可 注意事项中记录如下内容 第一、需要在 iis 的MiMe 中添加apk 的格式&#xff0c;否则无法…

STM32(15)USART编程

使用USART实现STM32与电脑之间的通信 中介&#xff1a;USB转TTL模块 闭合总开关&#xff0c;外部时钟才会传输到分频器 c8t6手册里面写了&#xff0c;usart最大支持4.5MHz&#xff0c;所以选10 重映射时记得开启AFIO的时钟

Android logcat系统

一 .logcat命令介绍 android log系统: logcat介绍 : logcat是android中的一个命令行工具&#xff0c;可以用于得到程序的log信息. 二.C/Clogcat访问接口 Android系统中的C/C日志接口是通过宏来使用的。在system/core/include/android/log.h定义了日志的级别&#xff1a; /…

opencart3 添加速卖通商品脚本

非爬虫&#xff0c;只能把速卖通商品信息拿下来解析插入到自己的项目里。 刚接触opencart3没多久&#xff0c;有一些新项目需要添加商品&#xff0c;每次手动从速卖通复制信息又很慢&#xff0c;就自己写了一个脚本。 思路&#xff1a;速卖通商品详情页有一段数据包含了几乎所…

Vue-02

开发者工具 安装插件&#xff0c;用于调试 Vue 应用。 https://chrome.zzzmh.cn/index 搜索 Vue &#xff0c;下载 Vue.js Devtools &#xff0c;此插件可以帮助更新信息&#xff0c;而不通过控制台更新&#xff0c;更方便调试。 注&#xff1a;安装插件后&#xff0c;记得在插…

Hello World!第一个labview程序

软件版本&#xff1a; labview myrio 2021英文版 因为没有找到中文版的&#xff0c;据说是myrio没有中文版本 实验内容&#xff1a; 文本显示&#xff0c;程序界面输入任意文本&#xff0c;然后运行程序 在前面板显示出输入的文本 以下为具体步骤&#xff1a; 第一步&…

no declaration can be found for element ‘rabbit:connection-factory‘

spring-mvc 配置 rabbitmq 出现问题。 我的解决方案如下&#xff1a; 1 找到配置文件 spring-rabbitmq.xml 我的配置文件叫&#xff1a;spring-rabbitmq.xml&#xff0c;你们按照自己的查找。 2 定位如下URI 接着 Ctrl鼠标左键 3 确定spring-rabbit-x.x.xsd 按照步骤2 &…

拿到年终奖马上离职,厚道吗?

拿到年终奖马上离职&#xff0c;厚道吗&#xff1f; 大家好&#xff0c;我是銘&#xff0c;全栈开发程序员。 今天在知乎上看到一个问题&#xff1a;拿到年终奖后马上辞职&#xff0c;厚道吗&#xff1f; image-20240229232132786 我的答案是&#xff1a;厚道&#xff0c;非常厚…