Sora 基础作品之 DiT:Scalable Diffusion Models with Transformer

Paper name

Scalable Diffusion Models with Transformers (DiT)

Paper Reading Note

Paper URL: https://arxiv.org/abs/2212.09748

Project URL: https://www.wpeebles.com/DiT.html

Code URL: https://github.com/facebookresearch/DiT

TL;DR

  • 2022 年 UC Berkeley 出品的论文,将 transformer 应用于 diffusion 上实现了当时最佳的生成效果。DiT 论文作者也是 OpenAI 项目领导者之一,该论文是 Sora 的基础工作之一。

Introduction

背景

  • transformer 在自回归模型中得到了广泛应用,但在其他生成模型框架中的采用较少。例如,扩散模型已经处于近期图像级生成模型进展的前沿;然而,它们都采用了卷积 U-Net 架构作为默认的主干选择
  • 本文展示了 U-Net 的归纳偏置对扩散模型的性能并非至关重要,可以替换为 transformer

本文方案

  • 提出了基于 transformer 的 diffusion 模型 Diffusion Transformers (简称 DiTs)。该架构具有良好的可扩展性,即网络复杂度(以Gflops衡量)与样本质量(以FID衡量)之间存在强相关性。通过简单地放大 DiT 并训练一个具有高容量主干的 LDM,能够在类条件 256 × 256 ImageNet 生成基准测试上达到 2.27 FID 的最新结果
    DiT 效果可视化

Methods

整体设计思路

  • 使用 Latent diffusion models(LDM) + Classifier-free guidance + transformer + VAE (Conv) 架构设计,从下图可以看出该设计的优势,左图显示有 scaling law,右图显示 LDM 相比于 pixel space diffusion 模型 ADM 有优势,不仅精度更高,训练计算量也更低
    scaling law for DiT

Diffusion Transformers

  • DiT 基于 ViT 修改得到,整体架构如下图所示:DiT block 通过区分 condition 的添加方式分为三种设计思路,分别是通过 adaLN (或 adaLN-Zero),cross-attention 或 In-Context
    DiT 架构
DiT 前向过程的各个模块
  • Patchify:DiT 的输入是图像的空间表示 z(对于 256×256×3 的图像,z 的形状为 32×32×4)。DiT的第一层是 “patchify”,它通过将输入中的每个补丁线性嵌入,将空间输入转换成一系列 T 个 token,每个 token 的维度为 d。在执行 patchify 之后,我们对所有输入 token 应用标准的 ViT 基于频率的位置嵌入(正弦-余弦版本)。由 patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。DiT 中主要实验了 p = 2, 4, 8
    patchify

  • DiT block:如整体框架图中所示,根据 condition 加入的不同方式分为以下四种设计思路

    • 上下文条件化。我们简单地将 t 和 c 的向量嵌入作为输入序列中的两个额外 tokens 附加上去,对待它们与图像 tokens 没有区别。这与 ViT 中的 cls tokens 类似,它允许我们无需修改就使用标准的 ViT 模块。在最后一个模块之后,我们从序列中移除条件化 tokens。这种方法对模型的新 Gflops 增加可以忽略不计。
    • 交叉注意力模块。我们将 t 和 c 的嵌入 concat 成一个长度为二的序列,与图像 token 序列分开。transformer 模块被修改为在多头自注意力模块后面增加一个额外的多头交叉注意力层,类似于 Attention is All you need 中的原始设计,也类似于 LDM 用于条件化类别标签的设计。交叉注意力对模型的 Gflops 增加最多,大约增加了 15% 的开销。
    • 自适应层归一化(adaLN)模块。在 GANs 和具有 UNet 骨干的扩散模型中广泛使用自适应归一化层之后,我们探索了用自适应层归一化(adaLN)替换 transformer 模块中的标准归一化层。adaLN 并不是直接学习维度规模的缩放和偏移参数 γ 和 β,而是从 t 和 c 的嵌入向量之和中回归得到它们。在我们探索的三种模块设计中,adaLN 增加的 Gflops 最少,因此是最计算高效的。它也是唯一一个限制对所有 tokens 应用相同函数的条件化机制。
    • adaLN-Zero 模块。之前的 ResNets 工作发现,将每个残差块初始化为恒等函数是有益的。例如,在监督学习环境中,将每个块中最后的批量归一化缩放因子 γ 零初始化可以加速大规模训练。扩散 U-Net 模型使用了类似的初始化策略,在任何残差连接之前零初始化每个块中的最终卷积层。我们探索了对 adaLN DiT 模块的修改,它做了同样的事情。除了回归 γ 和 β,我们还回归了在 DiT 模块内的任何残差连接之前作用的 dimension-wise 的缩放参数 α。初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。与标准的 adaLN 模块一样,adaLNZero 对模型的 Gflops 增加可以忽略不计。
  • Model size

    • 使用四种配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。它们涵盖了从 0.3 到 118.6 Gflops 不同范围的模型大小和浮点运算分配,使我们能够评估扩展性能。下表提供了配置的详细信息。
      model size
  • Transformer 解码器

    • 在最后的 DiT 模块之后,需要将图像 tokens 序列解码成输出噪声预测和输出对角协方差预测。这两个输出的形状等于原始的空间输入。我们使用标准 linear 解码器来完成这一任务;我们应用最终的层归一化(如果使用 adaLN 则为自适应的)并将每个 token 线性解码成一个 p×p×2C 的张量,其中 C 是输入到 DiT 的空间输入中的通道数。最后,我们将解码后的 tokens 重新排列成它们原始的空间布局,以得到预测的噪声和协方差。

Experiments

训练配置
  • ImageNet, 256x256 或 512x512 训练
  • AdamW, no weight decay
  • constant lr: 1 x 10−4
  • batch size: 256
  • EMA: decay 0.9999
VAE/Diffusion
  • Stable Diffusion 中的 VAE,下采样倍数为 8: 256 × 256 × 3 -> 32 × 32 × 4.
  • tmax=1000
  • linear variance schedule: 1e-4 -> 2e-2
评测指标
  • FID, FID-50k,250 DDPM sampling step

DiT block 消融

  • condition 的加入方式很影响精度:adaLN-Zero 精度最佳,说明权重初始化方式也很重要(让 DiT 的 block 初始化为 identity 函数)。
  • 计算量:in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)
    DiT block

scaling 分析

  • 提升模型计算量稳定涨点
    scaling 基础实验

  • 只要模型计算量接近,FID 就接近。
    scaling

训练效率
  • 更大的 DiT 模型计算效率更高。
    • 训练计算量的评估方式:Gflops · batch size · training steps · 3。因子 3 大致近似于反向传播的计算量是前向传播的两倍。发现,即使训练时间更长,小型 DiT 模型最终相对于训练步数更少的大型 DiT 模型而言,在计算上变得效率低下。同样,我们发现,除了 patch 大小不同之外,其他都相同的模型即使在控制了训练 Gflops 的情况下,也有不同的性能表现。例如,在大约 1 0 10 10^{10} 1010 Gflops 之后,XL/4 的性能被 XL/2 超越。
      训练效率
可视化效果
  • 提升模型计算量可视化效果明显提升
    visualize scaling

class condition 定量分析

  • 达到 sota 效果,比之前的 sota StyleGAN-XL 精度更高
    class2img
    class2img 512px
增加模型参数量 or 增加 sampling 步数
  • 扩散模型的独特之处在于可以通过增加生成图像时的采样步骤来在训练后使用额外的计算资源。
  • 研究了在使用更多采样计算的情况下,较小模型计算的 DiT 是否能够超越较大的模型。结论是增加采样计算的规模无法弥补模型计算能力的不足
    scaling model vs scaling sampling

Thoughts

  • 符合 scaling law 的简洁架构才是王道。scaling law 在 DiT 这的实验效果极佳,和 OpenAI 价值观相符,这应该是作为 Sora 基础工作之一的原因
  • condition 的加入方式可能还需要更多的 class condition 之外的消融实验,比如 image condition、text condition 等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/510873.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 59 螺旋矩阵(模拟)

给你一个正整数 n ,生成一个包含 1 到 n2 所有元素,且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 示例 1: 输入:n 3 输出:[[1,2,3],[8,9,4],[7,6,5]]示例 2: 输入:n 1 输出&…

模型训练----parser.add_argument添加配置参数

现在需要配置参数来达到修改训练的方式,我现在需要新建一个参数来开关wandb的使用。 首先就是在def parse_option():函数里添加上你要使用的变量名 parser.add_argument("--open_wandb",type bool,defaultFalse,helpopen wandb) 到config文件里增加你的…

2024-04-02 作业

作业要求&#xff1a; 整理思维导图使用模板类&#xff0c;实现顺序栈写一个char类型的字符数组&#xff0c;对该数组访问越界时抛出异常&#xff0c;并做处理。 作业1&#xff1a; 作业2&#xff1a; 运行代码: #include <iostream>using namespace std; #define LEN …

OpenGL_Learn16(模板测试)

模板缓冲首先会被清除为0&#xff0c;之后在模板缓冲中使用1填充了一个空心矩形。场景中的片段将会只在片段的模板值为1的时候会被渲染&#xff08;其它的都被丢弃了&#xff09;。 模板缓冲操作允许我们在渲染片段时将模板缓冲设定为一个特定的值。通过在渲染时修改模板缓冲的…

LeetCode_394(字符串解码)

双栈法 public String decodeString(String s) {String res "";Stack<Integer> countStack new Stack<>();Stack<String> resStack new Stack<>();int idx 0;while (idx < s.length()){char cur s.charAt(idx);//处理数字if(Charact…

css基础(一文读懂css)

1.css简介 css是一种用于描述网页样式和布局的样式表语言。它与HTML结合使用&#xff0c;用于控制网页中各个元素的外观和排版。 2.css样式引入方式 2.1 行内样式 行内优先级最高&#xff0c;针对当前标签 2.2 行外头部引入 行外头部&#xff1a;style&#xff0c;针对当前…

ISELED-演示项目代码

目录 一、main函数二、点灯函数一、main函数 int main(void) {/* Write your local variable definition here */iseledInitType.crcEnable = 1;iseledInitType.firstLedAdr = 1;iseledInitType.tempCmpEnable = 0;iseledInitType.voltSwing = 0;/*** End of Processor Expert…

【二叉树】Leetcode 105. 从前序与中序遍历序列构造二叉树【中等】

从前序与中序遍历序列构造二叉树 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示例1&#xff1a; 输入: preorder [3,9,20,15,7], inorder …

基于Springboot的一站式家装服务管理系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的一站式家装服务管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体…

微信小程序开发学习笔记——4.7 api中navigate路由接口与组件的关系

>>跟着b站up主“咸虾米_”学习微信小程序开发中&#xff0c;把学习记录存到这方便后续查找。 一、跳转 1、方法一&#xff1a;组件 组件-导航-navigator <navigator url"/pages/demo/demo?id123" open-type"reLaunch">go demo page <…

DETREC数据集标注 VOC格式

经过将DETRAC数据集转换成VOC格式&#xff0c;并使用labelimg软件进行查看&#xff0c;发现该数据集存在很多漏标情况&#xff0c;截图如下所示。

acwing1388. 游戏 + LC1406.石子游戏 零和博弈

零和博弈 有点类似那个Min-Max 游戏 考虑DP【l,r】 为当前考虑到[l,r]当前的先手能得到的最大的分 #include<bits/stdc.h> using namespace std; using ll long long; using pii pair<int,int>; const int N 1e510; const int inf 0x3f3f3f3f; const int mod …

NIKKI DENSO伺服驱动器维修NCR-CAB1A2D-801B

NEXSRT伺服驱动器维修NPSA-MU日机电装伺服维修ACTUS POWER&#xff0c;NCS-ZE12MDA/ZE1MDA-601A&#xff0c;NEXSRT日机电装伺服维修NCS-ZE12MDB-401A/NCS-ZAMDA-401AG。 NIKKI常见故障原因及处理方法&#xff1a; 1、电机在一个方向上比另一个方向跑得快&#xff1b; (1) 故…

4个惊艳的AI项目,开源了!

大家好&#xff0c;今天继续聊聊科技圈发生的那些事。 一、Champ 三维参数导引下可控一致的人体图像动画生成项目。只需要一张照片&#xff0c;就能让照片里的人物动起来。 给出一个动作视频&#xff0c;Champ 可以让不同的人像复刻出相同的动作。 我们先来看看真实人物照片…

【PowerDesigner】PGSQL反向工程过程已中断

问题 反向工程过程已中断,原因是某些字符无法通过ANSI–&#xff1e;UTF-16转换进行映射。pg导入sql时报错&#xff0c;一查询是power designer 反向工程过程已中断&#xff0c;某些字符无法通过ANSI–>UTF-16转换进行映射&#xff08;会导致数据丢失&#xff09; 处理 注…

生活篇——关于分期贷或信用贷的等额本息、先息后本、月利率、年利率、年利率单利的个人理解

首先我先就年利率的理解问一下各位读者2个问题。 问题1&#xff1a;假设你要借100000元&#xff0c;借一年&#xff0c;月利息0.2%&#xff0c;等额本息&#xff0c;那么你觉得你总共需要还多少利息&#xff1f;它的实际年利率约为多少&#xff1f; A.2400&#xff0c;2.4% …

C语言一维数组及二维数组详解

引言&#xff1a; 小伙伴们&#xff0c;我发现我正文更新的有些慢&#xff0c;但相信我&#xff0c;每一篇文章真的都很用心在写的&#xff0c;哈哈&#xff0c;在本篇博客当中我们将详细讲解一下C语言中的数组知识&#xff0c;方便大家后续的使用&#xff0c;有不会的也可以当…

Java设计之道:色即是空,空即是色

0.引子 我们的这个世界上&#xff0c;存在这么一种东西&#xff1a; 第一&#xff1a;它不占据任何3D之体积&#xff0c;即它没有Volume第二&#xff1a;它也不占据任何2D之面积&#xff0c;即它没有Area第三&#xff1a;它也不占据任何1D之长度&#xff0c;即它没有Length 总…

《QT实用小工具·三》偏3D风格的异型窗体

1、概述 源码放在文章末尾 可以在窗体中点击鼠标左键进行图片切换&#xff0c;项目提供了一些图片素材&#xff0c;整体风格偏向于3D类型&#xff0c;也可以根据需求自己放置不同的图片。 下面是demo演示&#xff1a; 项目部分代码如下所示&#xff1a; 头文件部分&#xff…

NULL与nullptr的区别

NULL是宏定义&#xff0c;如下&#xff1a; 如果用NULL&#xff0c;在函数重载时&#xff0c;NULL的类型被推断为int。这是不好的&#xff0c;所以引入nullptr。nullptr是c11引入的关键字&#xff0c;它就代表空指针。