论文阅读:Scalable Diffusion Models with Transformers

Scalable Diffusion Models with Transformers

论文链接

介绍

传统的扩散模型基于一个U-Net骨架,这篇文章提出了一种新的扩散模型结构,将U-Net替换为一个transformer,并将这种结构称为Diffusion Transformers (DiTs)。他们还发现,transformer的规模越大(通过Gflops衡量),生成的图片的质量越好(FID越低)。
如图2所示,DiT的规模越大,图片生成的质量越好(左图),和当前流行的扩散模型相比,DiT的计算效率也表现优异。
ImageNet generation with Diffusion Transformers (DiTs)

相关工作

  • Transformers:这篇文章研究了transformer作为扩散模型的骨架时,其规模的性质。
  • Denoising diffusion probabilistic models (DDPMs):传统的扩散模型都使用U-Net作为骨架,本文尝试使用纯transformer作为骨架。
  • Architecture complexity:在结构设计领域,Gflops是常见的衡量结构复杂度的指标。

方法(Diffusion Transformers)

预备知识

  • Diffusion formulation:扩散模型Diffusion Model(DM)在训练过程中,首先向图片中添加噪声,然后预测噪声来从图片中将噪声去除。这样,在推理过程中,首先初始化一个高斯噪声图片,然后去除预测的噪声,即可得到生成的图片。
  • Classifier-free guidance:条件扩散模型引入了额外信息 c c c(比如,类别)作为输入。而classifier-free guidance可以引导生成的图片 x x x是类别 c c c的概率 l o g ( c ∣ x ) log(c|x) log(cx)最大。
  • Latent diffusion models:扩散模型在像素空间上训练和推理的计算开销过大,Latent Diffusion Model(LDM)将像素空间替换为VAE编码得到的潜在空间 z = E ( x ) z=E(x) z=E(x),可以提高计算效率。本文提出的DiT沿用了LDM中的潜在空间,但是在预测潜在空间特征的模型上,将LDM中的U-Net替换为了纯Transformer骨架。

Diffusion Transformer Design Space

Diffusion Transformers (DiTs)是基于Vision Transformer (ViT)的模型,它的大体结构如图3所示,从左图可以看到,输入的噪音特征被分解为不同批,然后被若干个DiT块处理;右边的三张图展示了DiT块的详细结构,分别是三种不同的变体。
The Diffusion Transformer (DiT) architecture
下面对DiT的各层进行分析:
Patchify. 从图3中可以看到,DiT的第一个层是Patchify,其将输入转化为 T T T个token序列。在这之后,作者使用标准ViT中基于频率的位置嵌入处理前面的token序列。而token序列的数量是由一个超参数 p p p决定的, p p p减半导致 T T T翻四倍,并且导致整个transformer的GFlops至少翻四倍,如图4所示。
Input specifications for DiT
DiT block design. 在patchfiy层之后,几个transformer块处理输入token以及一些额外的条件信息,比如,类标签 c c c和时间步数 t t t。作者尝试了4种不同的ViT变体:

  • In-context conditioning:这种变体直接将时间步数 t t t和类标签 c c c作为额外的token添加到输入token序列后面,类似于ViT的cls tokens,因此也可以直接使用标准的ViT块。这种方式引入的Gflops可以忽略不计。
  • Cross-attention block:这种变体将条件信息拼接为一个长度为2的序列,独立于图片输入序列。然后,在transformer块的self-attention层后添加了一个cross-attention层,类似于LDM,在cross-attention层将条件信息加入图片特征中。cross-attention方案增加的Gflops最多,大概15%。
  • Adaptive layer norm (adaLN) block:这种变体将transformer块中标准的layer norm layers替换为adaptive layer
    norm (adaLN),这一技术在GAN相关的模型中被广泛采用。不同于直接学习维度放缩和偏移因子 γ \gamma γ β \beta β,该方案回归 t t t c c c的嵌入的和得到这两个参数。在目前的三种方案中,该变体额外增加的Gflops最少。
  • adaLN-Zero block:先前的工作说明,ResNet中的恒等映射是有益处的。Diffusion U-Net在残差之前,零初始化了每个块中最后一个卷积层。作者采用了和Diffusion U-Net相同的方案。此外,除了回归 γ \gamma γ β \beta β,该方案还对DiT块中残差连接上的放缩因此 α \alpha α进行了回归。对于所有的 α \alpha α,作者初始化MLP以输出零向量,这使得DiT块为一个恒等函数。和adaLN方案一样,ada-Zero方案引入的Gflops也可以忽略不计。

Model Size. 作者设置了四种规模的DiT:DiT-S, DiT-B, DiT-L and DiT-XL,结构复杂度依次增大。
Transformer decoder. 在经过最后的DiT块之后,使用tranformer decoder将输入tokens转化为和输入同等性状的噪音预测。

综上,作者探索了DiT设计空间中的patch_size、transformer架构(4种,in-context,cross-attention, adaptive layer
norm and adaLN-Zero blocks)和model size(4种,DiT-S, DiT-B, DiT-L and DiT-XL)。

实验

实验设置

  • 训练:在256 × 256和512 × 512 图片分辨率的ImageNet数据集上训练。超参数设置几乎和ADM一致。
  • Diffusion:和Stable DIffusion一样使用VAE编码图片和解码特征。
  • 评估指标:主要使用Fr´echet Inception Distance (FID),还使用了Inception Score [51], sFID [34] and Precision/Recall [32]
  • 计算平台:在JAX [1]这个深度学习框架上实现了DiT,在TPU上训练模型。

实验结果

DiT block design. 四个不同的DiT块:in-context (119.4 Gflops), cross-attention (137.6 Gflops),
adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)中, adaLN-zero (118.6 Gflops) 取得最低的FID。其中,adaLN-zero相较于adaptive layer norm的提升,说明了恒等映射的好处。(后续的实验除非特别说明都是在adaLN-zero上做的)

Comparing different conditioning strategies
Scaling model size and patch size. 模型size增大和patch zise减小,均会提高Gflops,降低FID。我们注意到,DiT-L 和DiT-XL的FID很接近,因为它们的Gflops也相对更接近。
Scaling the DiT model improves FID at all stages of training
DiT Gflops are critical to improving performance. 上面的图6再次说明了模型参数量的增大并不等同于DiT模型的图片质量提高,真正的关键是提高Gflops。比如,DiT S/2的表现和DiT B/4接近,因为小的batch size会增大Gflops,二者的Gflops接近,所以FID也接近。
Larger DiT models are more compute-efficient
小的DiT模型即便训练时间更长,相对于训练时间更短的大的DiT模型,其计算效率也是更差的。
这里,作者估计训练计算量的方式为model Gflops · batch size · training steps · 3。
Larger DiT models use large compute more effi-
ciently

State-of-the-Art Diffusion Models

和主流的扩散模型相比,DiT-XL/2 (即参数量最大,patch size最小的DiT)的表现最优。

Scaling Model vs. Sampling Compute

扩散模型有一个比较特殊的点,在生成图片时,它可以通过增加调整采样步数,引入额外的增加的计算量,但是,这并不能弥补训练时模型计算量的差距,即大GFlops的DiT在采样步数少的情况下,仍然能比小GFlops的DiT在采样步数多的情况下,取得更低的FID。

结论

Diffusion Transformers (DiTs)作为一种新的扩散模型,比基于U-Net的扩散模型表现更加优异。并且,其在模型复杂度提高的时候,能够有明显的性能提高,因此,使用更大规模的DiT有助于提高模型性能。此外,DiT也可以用于文生图生成任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/441040.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【网络】:HTTP服务器

HTTP服务器 一.预备知识二.HTTP的请求和响应三.写一个简单的HTTP服务器四.返回响应五.HTTP方法和状态码 一.预备知识 1.域名 https://www.baidu.com,这是一个域名。在技术角度上,访问一个服务器其实只需要知道它的ip和域名就行了,而域名主要…

电力物联网系统设计

电力物联网系统设计 简介 在新能源行业从业多年,参与和负责过大大小小的的项目,发电侧、电网侧、用户侧系统都有过实际的项目经验,这些项目或多或少都有物联网采集方面的需求,本篇文章将会对电力行业物联网经验做一个总结分享。 …

LeetCode 刷题 [C++] 第3题.无重复字符的最长子串

题目描述 给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 题目分析 可以使用滑动窗口加哈希表来实现: 使用start和end两个变脸来表示滑动窗口的头部位置和尾部位置,两者开始均为0;借助哈希表来记录已经遍…

六、长短时记忆网络语言模型(LSTM)

为了解决深度神经网络中的梯度消失问题,提出了一种特殊的RNN模型——长短期记忆网络(Long Short-Term Memory networks, LSTM),能够有效的传递和表达长时间序列中的信息并且不会导致长时间前的有用信息被忽略。 长短时记忆网络原理…

vue iis 配置

下载安装两个IIS模块 1). 传送门:URL Rewrite 2). 传送门:Application Request Routing 注 : 只有在 服务器的主页 有Application Request Routing 部署VUE网站 生成网站 在VUE项目打包生成出发布文件,即文件夹 dist,此处忽略 复制到你需要存放网站的…

10 事务控制

文章目录 事务控制事务概述事务操作事务四大特性事务隔离级别 事务控制 事务概述 MySQL 事务主要用于处理操作量大,复杂度高的数据。比如说,在人员管理系统中,你删除一个人员,既需要删除人员的基本资料,也要删除和该…

kafka报文模拟工具的使用

日常项目中经常会碰到消费kafka某个topic的数据,如果知道报文格式,即可使用工具去模拟发送报文,以此测试代码中是否能正常消费到这个数据。 工具资源已上传,可直接访问连接下载:https://download.csdn.net/download/w…

Learn OpenGL 02 你好,三角形

图形渲染管线 图形渲染管线的每个阶段的抽象展示。要注意蓝色部分代表的是我们可以注入自定义的着色器的部分 首先,我们以数组的形式传递3个3D坐标作为图形渲染管线的输入,用来表示一个三角形,这个数组叫做顶点数据(Vertex Data)。 顶点着色…

编译内核错误 multiple definition of `yylloc‘

编译内核错误 # make ARCHarm CROSS_COMPILEarm-mix410-linux- uImageHOSTLD scripts/dtc/dtc /usr/bin/ld: scripts/dtc/dtc-parser.tab.o:(.bss0x10): multiple definition of yylloc; scripts/dtc/dtc-lexer.lex.o:(.bss0x0): first defined here collect2: error: ld ret…

昏暗场景增强-低照度增强-弱光增强(附代码)

引言 随着现代科技的发展,图像采集设备已经渗透到生活的方方面面,然而在昏暗场景、低照度或弱光条件下,图像的质量往往受到严重影响,表现为亮度不足、对比度低下、色彩失真以及细节丢失等问题。这类图像对于人眼识别和计算机视觉…

FPGA IBUFG

IBUFG和IBUFGDS的输入端仅仅与芯片的专用全局时钟输入管脚有物理连接,与普通IO和其它内部CLB等没有物理连接。 所以,IBUFG输入的不能直接接另外信号。 GTH transceiver primitives are called GTHE3_COMMON and GTHE3_CHANNEL in UltraScale FPGAs, an…

部署LVS+Keepalived高可用群集(抢占模式,非抢占模式,延迟模式)

目录 一、LVSKeepalived高可用群集 1、实验环境 2、 主和备keepalived的配置 2.1 yum安装ipvsadm和keepalived工具 2.2 添加ip_vs模块并开启ipvsadm 2.3 修改keepalived的配置文件 2.4 调整proc响应参数,关闭linux内核的重定向参数响应 2.5 将主服务器的kee…

SpringBoot整合Redis实现分布式锁

SpringBoot整合Redis实现分布式锁 分布式系统为什么要使用分布式锁? 首先,分布式系统是由多个独立节点组成的,这些节点可能运行在不同的物理或虚拟机器上,它们通过网络进行通信和协作。在这样的环境中,多个节点可能同…

群智能优化算法:巨型犰狳优化算法(GAO)求解23个基准函数(提供MATLAB代码)

一、巨型犰狳优化算法 巨型犰狳优化算法(Giant Armadillo Optimization,GAO)由Omar Alsayyed等人于2023年提出,该算法模仿了巨型犰狳在野外的自然行为。GAO设计的基本灵感来自巨型犰狳向猎物位置移动和挖掘白蚁丘的狩猎策略。GAO…

LLM 构建Data Muti-Agents 赋能数据分析平台的实践之①:数据采集

一、 概述 在推进产业数字化的过程中,数据作为最重要的资源是优化产业管控过程和提升产业数字化水平的基础一环,如何实现数据采集工作的便利化、高效化、智能化是降低数据分析体系运转成本以及推动数据价值挖掘体系的基础手段。随着数字化在产业端的推进…

电脑小问题:Windows更新后黑屏

Windows 更新后黑屏解决方法 在 Windows 更新后,伴随了一个小问题,电脑启动后出现了桌面黑屏。原因可能是火绒把 explorer.exe 当病毒处理了。 下面讲解 Windows 更新后黑屏的解决方法,步骤如下: 1. 按 ctrl alt delete 组合键…

JAVA实战开源项目:生活废品回收系统(Vue+SpringBoot)

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容三、界面展示3.1 登录注册3.2 资源类型&资源品类模块3.3 回收机构模块3.4 资源求购/出售/交易单模块3.5 客服咨询模块 四、免责说明 一、摘要 1.1 项目介绍 生活废品回收系统是可持续发展的解决方案,旨在鼓…

【操作系统概念】 第9章:虚拟内存管理

文章目录 0.前言9.1 背景9.2 按需调页9.2.1 基本概念9.2.2 按需调页的性能 9.3 写时复制9.4 页面置换9.4.1 基本页置换9.4.2 FIFO页置换9.4.3 最优(Optimal)置换9.4.4 LRU(Least Recently Used)页置换9.4.5 近似LRU页置换9.4.6 页缓冲算法 9.5 帧分配9.5…

遥遥领先!基于transformer变体的时间序列预测新SOTA!

目前,以CNN、RNN和 Transformer 模型为代表的深度学习算法已经超越了传统机器学习算法,成为了时间序列预测领域一个新的研究趋向。这其中,基于Transformer架构的模型在时间序列预测中取得了丰硕的成果。 Transformer模型因其强大的序列建模能…

JVM-垃圾收集器G1

G1垃圾回收器 概述: 是一款面向服务器的垃圾收集器,主要针对配备多个处理器及大容量内存的机器. 以极高效率满足GC停顿时间要求的同时,还具备高吞吐量性能特征.G1保留了年轻代和老年代的概念,但不再是物理隔阂了,它们都是(可以不连…