摘要
本文利用从互联网上收集的2.5亿个图像/文本对数据,训练了一个120亿参数的自回归transformer,进而得到一个可以通过自然语言/图像控制生成的高保真图像生成模型。在大多数数据集上的表现超越以往的方法。
框架
本文的目标为通过训练一个自回归transformer,通过将文本和图像tokens自回归建模为单个数据流,进而结合图像解码器进行图像生成,整体分为两个阶段:
- 第一阶段:训练一个离散变分自编码器(dVAE),其编码器会将输入图像从 256 × 256 256 × 256 256×256压缩为 32 × 32 32 × 32 32×32的图像tokens,其中每个token都会映射到 K = 8192 K = 8192 K=8192的codebook向量中。相比于直接使用像素作为图像token,这可以使后续步骤的自回归transfromer的上下文大小减少192倍,同时不会大幅降低视觉质量(如上图)。
- 第二阶段:将256个由BPE编码的文本tokens 与 32 × 32 = 1024 32 × 32 = 1024 32×32=1024个图像tokens拼接起来,基于此训练一个自回归transformer,实现对文本和图像tokens的联合分布的建模。
整个过程可以被看作最大化模型在图像
x
x
x、文本
y
y
y和tokens
z
z
z上的联合似然的ELBO,通过因子分解可以将该分布建模为
p
θ
,
ψ
(
x
,
y
,
z
)
=
p
θ
(
x
∣
y
,
z
)
p
ψ
(
y
,
z
)
p_{θ,ψ}(x, y, z) = p_θ(x | y, z)p_ψ(y, z)
pθ,ψ(x,y,z)=pθ(x∣y,z)pψ(y,z),对应的ELBO为:
ln
p
θ
,
ψ
(
x
,
y
)
≥
E
z
∼
q
ϕ
(
z
∣
x
)
(
ln
p
θ
(
x
∣
y
,
z
)
−
β
D
K
L
(
q
ϕ
(
y
,
z
∣
x
)
,
p
ψ
(
y
,
z
)
)
)
\ln p_{θ,ψ}(x, y) \ge \mathbb{E}_{z \sim q_ϕ(z|x)}(\ln p_θ(x|y,z)-\beta D_{KL}(q_ϕ(y,z|x),p_ψ(y,z)))
lnpθ,ψ(x,y)≥Ez∼qϕ(z∣x)(lnpθ(x∣y,z)−βDKL(qϕ(y,z∣x),pψ(y,z)))
其中, q ϕ q_ϕ qϕ表示给定图像 x x x经过dVAE编码器生成的 32 × 32 32 × 32 32×32的tokens的分布; p θ p_θ pθ表示由tokens经过dVAE解码器生成的图像的分布; p ψ p_ψ pψ表示由自回归transformer建模的文本和图像tokens的联合分布。该ELBO只适用与 β = 1 \beta = 1 β=1的情况。
Learning the Visual Codebook
第一阶段的训练目标为最大化 ϕ ϕ ϕ和 θ θ θ的ELBO,即通过给定图像训练dVAE。先验 p ψ p_ψ pψ初始化为基于codebook( K = 8192 K = 8192 K=8192)向量的均匀分类分布(uniform categorical distribution); q ϕ q_ϕ qϕ初始化为在编码器输出的 32 × 32 × 8192 32 × 32×8192 32×32×8192的logits参数化的分类分布(uniform categorical)。
由于 p ψ p_ψ pψ是一个离散分布,无法使用梯度进行优化,故此处采用gumbel-softmax松弛,用 q ϕ τ q ^τ_ ϕ qϕτ取代 q ϕ q_ϕ qϕ,当 τ → 0 τ → 0 τ→0时,松弛程度会逐渐缩小,逼近原始分布。 p θ p_θ pθ的似然使用log-laplace分布评估,以避免离群值导致的生成模糊问题。
松弛后的ELBO使用Adam和EMA优化,以下配置对训练稳定性很重要:
- 松弛temperature和步长的具体退火方法。实验发现 τ τ τ退火到1/16时,松弛ELBO的 q ϕ τ q ^τ_ ϕ qϕτ和真实ELBO的 q ϕ q_ϕ qϕ之间的gap就会消失。
- 在编码器的末尾和解码器的开头使用1 × 1卷积。 实验发现,通过减少松弛方法周围的卷积层的感受野大小,可以使其泛化到真实ELBO的情况。
- 将编码器和解码器的输出激活值乘一个小的常数,可以使初始化时的训练更加训练。
另外,KL权重增加到 β = 6.6 β = 6.6 β=6.6时,可以得到更好codebook,故而使训练结束时的重构误差更小。
Learning the Prior
第二阶段在固定 ϕ ϕ ϕ和 θ θ θ的情况下,最大化关于 ψ ψ ψ的ELBO,学习文本和图像token的联合先验分布。其中, p ψ p_ψ pψ是一个120亿参数的稀疏transformer。
具体,给定一个文本/图像对,首先通过对小写文本进行BPE编码(词汇表大小为16384)得到最多256个文本token ,并对dVAE编码器输出的logits进行argmax采样codebook得到1024个图像token,此处没有添加gumbel噪声。最后,拼接这些文本和图像token作为单个数据流进行自回归建模。
本文限制文本标题的最大长度为256,每个文本位置都会学习一个特殊的“padding” token,当对应位置没有文本token时使用此token。得到文本和图像token的交叉熵损失后,将文本交叉熵损失乘以1/8,图像交叉熵损失乘以7/8,以对loss归一化,本阶段也使用EMA和Adam进行优化。
Data Collection
本文从互联网上收集了2.5亿个文本/图像对,创建了一个与JFT-300M相似规模的数据集。该数据集包括一部分Conceptual Captions和YFCC100M的经过滤的子集。
Mixed-Precision Training
为了节省GPU内存并增加吞吐量,模型的大多数参数、Adam矩阵和模型激活值都以FP16存储,并使用了activation checkpointing技术。
在训练过程中发现,随着模型变得更深更广,resblocks的激活梯度会单调减少,较深层的resblocks的激活梯度可能小于FP16的最小值,其会被四舍五入为0,这种现象称为下溢(underflow)。实验发现消除下溢可以使训练更加稳定。
故对于模型中每个的resblock,通过执行“gradient scale”可以解决下溢问题,如上图。
Distributed Optimization
DELL-E有120亿参数,以FP16精度存储时会消耗约24GB内存,这超过单张NVIDIA V100 GPU的16GB内存,故使用参数分片。如上图。
本文的实现中,每台机器上的每个GPU都独立地计算其参数分片梯度的低秩因子,而不依赖于其相邻的GPU。一旦计算出低秩因子,每台机器都会将其error buffer设置为其八个GPU上未压缩的参数梯度的平均值(通过reduce-scatter获得)与通过解压缩低秩因子得到的梯度之间的残差(两者偏差)。
对于一个模型训练集群,其机器之间的带宽远低于同一机器上不同GPU之间的带宽,故机器之间梯度平均操作(all-reduce)成为训练期间的主要速度瓶颈,通过引入PowerSGD压缩梯度,可以大大降低这种成本。 PowerSGD会将未压缩的参数梯度的通信操替换为基于其低秩因子的两个更小的通信操作。给定压缩rank r r r和transformer激活尺寸 d m o d e l d_{model} dmodel,其压缩率为 1 − 5 r / ( 8 d m o d e l ) 1 − 5r/(8d_{model}) 1−5r/(8dmodel)。如上表显示,无论模型大或小,该方法可以实现约85%的压缩率。
Sample Generation
对于从transformer中生成的一系列图像,本文采用预训练CLIP对生成图像与文本标题的匹配程度来分配分数并排序。如上图显示了给定生成的N张图像,并从中选择的top-k图像。除非另有说明,用于定性和定量结果的所有样本都是在不降低temperature的情况下获得的,并使用N = 512重新排序。
实验
Quantitative Results
上图定性比较了DALL-E和AttnGAN、DM-GAN和DF-GAN的生成。
上图为人类验证实验。给定一个文本标题,相比DF-GAN,DALL-E的生成在93%的情况下与文本标题更好地匹配,而获得更多的人类投票。在90%的情况下,也因为更真实而获得了大多数人类投票。
上图(a上)为在MS-COCO数据集上验证的定量结果,DALL-E与之前最佳方法只差2个点的FID分数。由于DALL-E的训练数据中包含一个YFCC100M的过滤子集,其中包含MS-COCO验证集中大约21%的图像,故为了隔离这种影响,另外分别计算了验证集有这些图像(实线)和没有这些图像(虚线)的FID信息,结果没有明显变化。
用dVAE编码器的token训练transformer,可以使模型学习更多的图像低频信息,使图像在视觉上更真实。,但这也不利于模型学习产生高频细节。为了验证模型的高频建模能力,本实验对验证图像和模型生成的样本应用了不同半径的高斯滤波器,并计算对应IS值。结果如上图(a下),随着模糊半径的增加,DALL-E和其他方法之间的差距越拉越大,当模糊半径大于等于2时,DALL-E取得了最佳结果。
DALL-E在CUB数据集上的表现比较差,如上图(b),和之前的主要方法有近40点FID的差距。经过检测发现,训练数据集中包含12%的CUB数据,但去除这些数据后模型表现仍旧不佳。故推测zero-shot DALL-E不太可能在CUB等专业分布的数据集上获得优势。
上图(c)显示了当用于CLIP重排序的样本增加时,DALL-E的FID有了明显改进。
上图显示了DALL-E在CUB数据集中不同文本标题下的生成示例。
Qualitative Findings
通过验证发现DALL-E有不以最初预期的方式进行泛化的能力。当给出文本标题“a tapir made of accordion… ”,该模型似乎画了一个以手风琴为身 体的貘(上图a)。这表明,其发展出一种基本的能力,可以在较高的抽象层次上组合概念。
DALL-E似乎也能进行组合泛化,例如在渲染如“an illustration of a baby hedgehog in a christmas sweater walking a dog”这样的句子(上图b、c)。
在有限的可靠性程度上,还发现该模型能够由自然语言控制图像到图像的翻译。当模型被赋予标题“the exact same cat on the top as a sketch at the bottom”时,其能够在底部画一个类似的猫的草图(上图d)。 这也适用于其他几种类型的转换,包括图像操作(例如改变图像的颜色、将其转换为灰度或翻转图像)和样式转换(例如在贺卡、邮票或手机壳上画猫)。一些只涉及改变动物颜色的转换,表明DALL-E能够执行基本的对象分割。
Appendix
Details for Discrete VAE
Architecture
dVAE编码器和解码器都为具有bottleneck-style resblocks的ResNets。编码器的第一层卷积核尺寸为 7 × 7 7 × 7 7×7,编码器的最后一层卷积核尺寸为 1 × 1 1×1 1×1(输出尺寸为 32 × 32 × 8192 32 × 32 × 8192 32×32×8192,用作图像token的分类分布的logits)。解码器的第一层卷积和最后一层卷积核尺寸都为 1 × 1 1×1 1×1。编码器使用最大池化下采样,解码器使用最近邻上采样。
Training
dVAE在与transformer相同的数据集上进行训练,使用上图中给出的数据增强代码。以下量在训练过程中使用余弦退火进行衰减:
- KL权重 β β β在前5000次迭代中从0增加到6.6
- 松弛 τ τ τ在前150000次迭代中从1退火到1/16
- 在1200000迭代中,step size从 1 ⋅ 1 0 − 4 1\cdot10^{−4} 1⋅10−4退火到 1.25 ⋅ 1 0 − 6 1.25 \cdot 10^{−6} 1.25⋅10−6
使用 β 1 = 0.9 , β 2 = 0.999 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.999, ϵ = 10^{−8} β1=0.9,β2=0.999,ϵ=10−8的AdamW和 1 0 − 4 10^{−4} 10−4的weight decay和 0.999 0.999 0.999的EMA优化模型。该模型在64个16 GB NVIDIA V100 gpu上使用混合精度训练,每个gpu的batch size为8,总batch size为512,总共3000000次。
Details for Transformer
Architecture
本文第二阶段模型是一个仅解码器的稀疏transformer,其输入tokens embedding格式如上图。其包括64个注意力层,每个层使用62个注意力头,每个头的维度大小为64。
该模型使用三种稀疏注意力mask,如上图。给定自注意力层的索引
i
i
i(
i
∈
[
1
,
63
]
i ∈ [1, 63]
i∈[1,63]),如果
i
−
2
m
o
d
4
=
0
i − 2\ mod \ 4 = 0
i−2 mod 4=0,则使用列注意力mask(c),否则使用行注意力mask,例如,前四个自注意力层分别使用row、column、row、row。卷积注意力mask(d)仅用于最后的自注意力层。
Training
对于训练transformer的训练,在使用dVAE编码器编码图像之前,首先对图像进行如上图代码所示的数据增强。在用BPE编码文本标题时,还应用了10%的BPE dropout。该模型使用逐resblock缩放和梯度压缩进行训练,总压缩rank为896(每个GPU的参数分片使用112的压缩rank)。
使用 β 1 = 0.9 , β 2 = 0.96 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.96, ϵ = 10^{−8} β1=0.9,β2=0.96,ϵ=10−8的AdamW与和 4.5 ⋅ 1 0 − 2 4.5 \cdot 10^{−2} 4.5⋅10−2的weight decay和0.99的EMA优化参数。在应用Adam更新前,会使用阈值为4的norm对解压后的梯度进行裁剪,梯度裁剪仅在训练开始的预热阶段运行。为了节省内存,大部分Adam矩阵以FP16格式存储,其中运行平均值为1-6-9格式(即1位用于符号,6位用于指数,9位用于尾数),运行方差为0-6-10格式,在更新参数或动量之前,会将运行方差裁剪为5。其次,还会异步地将模型参数从GPU复制到CPU(每25次更新复制一次),以获得更稳定的更新。
该模型在1024个16 GB NVIDIA V100 gpu和总batch size为1024的设置下训练模型,总共进行了430000更新。step size在前5000次迭代中,线性退火到 4.5 ⋅ 1 0 − 4 4.5 · 10^{−4} 4.5⋅10−4,并在每次训练损失趋于稳定时将step size减半,训练周期内,总共减半了5次,比初始步长小32倍的最终步长结束训练。
Details for Human Evaluation Experiments
对于人类验证实验,本文使每个模型对每个文本标题生成一个示例图像,并给定文本和示例图像让人类给出比较结果,实验提交给了亚马逊的Mechanical Turk,每组生成都由五名不同的人类回答。工作人员被要求比较两张图像并选择答案:(1)哪张图像最真实,(2)哪张图像最匹配文本标题。提供给人类的实验设置如上图。
Zero-Shot Image-to-Image Translation
上图显示了DALL-E的zero-shot图像到图像转换的示例。
reference
Ramesh, A. , Pavlov, M. , Goh, G. , Gray, S. , Voss, C. , & Radford, A. , et al. (2021). Zero-shot text-to-image generation.