无需向量量化的自回归图像生成

摘要

https://arxiv.org/pdf/2406.11838
传统观点认为,用于图像生成的自回归模型通常伴随着向量量化的标记。我们观察到,尽管离散值空间可以方便地表示分类分布,但它对于自回归建模来说并不是必需的。在这项工作中,我们提出使用扩散过程来建模每个标记的概率分布,这使得我们可以在连续值空间中应用自回归模型。我们定义了一个扩散损失函数来建模每个标记的概率,而不是使用分类交叉熵损失。这种方法消除了对离散值标记器的需求。我们在包括标准自回归模型和广义掩码自回归(MAR)变体在内的广泛案例中评估了其有效性。通过去除向量量化,我们的图像生成器在享受序列建模的速度优势的同时,取得了出色的结果。我们希望这项工作能够激发在其他连续值领域和应用中使用自回归生成的想法。

1、引言

自回归模型目前是自然语言处理中生成模型的公认解决方案[38,39,3]。这些模型基于之前的词作为输入来预测序列中的下一个词或标记。由于语言的离散性,这些模型的输入和输出都位于一个分类的、离散值空间中。这种普遍的方法导致了人们广泛认为自回归模型与离散表示法之间存在着固有的联系。

因此,关于将自回归模型推广到连续值域——尤其是图像生成——的研究,强烈地关注于数据的离散化[6, 13, 40]。一个常用的策略是在图像上训练一个离散值标记器,这涉及通过向量量化(VQ)获得的有限词汇表[51, 41]。然后,自回归模型在离散值标记空间上操作,类似于它们在语言处理中的对应物。

在这项工作中,我们旨在回答以下问题:自回归模型是否必须与向量量化表示相结合?我们注意到,自回归性质,即“基于先前的标记预测下一个标记”,与值是离散的还是连续的无关。所需的是建模每个标记的概率分布,这可以通过损失函数来衡量,并用于从中抽取样本。离散值表示可以通过分类分布来方便地建模,但从概念上讲,这并不是必需的。如果提供了替代模型来建模每个标记的概率分布,那么可以在不使用向量量化的情况下应用自回归模型。
在这里插入图片描述

基于这一观察,我们提出通过在连续值域上操作的扩散过程来建模每个标记的概率分布。我们的方法利用了扩散模型[45,24,33,10]的原理来表示任意概率分布。具体来说,我们的方法为每个标记自回归地预测一个向量 z z z,该向量作为去噪网络(例如,一个小型多层感知器MLP)的条件。去噪扩散过程使我们能够表示输出 x x x的潜在分布 p ( x ∣ z ) p(x \mid z) p(xz)(如图1所示)。这个小的去噪网络与自回归模型一起训练,以连续值标记作为输入和目标。从概念上讲,这个应用于每个标记的小预测头就像是一个损失函数,用于衡量 z z z的质量。我们将这个损失函数称为扩散损失(Diffusion Loss)。

我们的方法消除了对离散值标记器的需求。向量量化标记器难以训练,并且对梯度近似策略敏感[51, 41, 40, 27]。与连续值对应物相比,它们的重建质量通常较差[42]。我们的方法允许自回归模型享受更高质量、非量化标记器的优势。

为了拓宽范围,我们进一步将标准的自回归(AR)模型[13]和掩码生成模型[4,29]统一到一个广义的自回归框架中(图3)。从概念上讲,掩码生成模型以随机顺序同时预测多个输出标记,同时仍保持“基于已知标记预测下一个标记”的自回归性质。这导致了掩码自回归(MAR)模型,它可以无缝地与扩散损失(Diffusion Loss)一起使用。

我们通过实验展示了扩散损失在广泛情况下(包括AR和MAR模型)的有效性。它消除了对向量量化标记器的需求,并持续提高了生成质量。我们的损失函数可以灵活地与不同类型的标记器一起应用。此外,我们的方法享有序列模型快速速度的优势。我们的带有扩散损失的MAR模型可以在每秒不到0.3秒的速度下生成图像,同时在ImageNet 256 × 256 256 \times 256 256×256 上实现了低于2.0的FID分数。我们最好的模型可以达到接近1.55的FID分数。

我们方法的有效性揭示了图像生成中一个尚未充分探索的领域:通过自回归来建模标记之间的相互依赖关系,同时利用扩散来建模每个标记的分布。这与典型的潜在扩散模型[42,37]形成了对比,其中扩散过程建模了所有标记的联合分布。鉴于我们方法的有效性、速度和灵活性,我们希望扩散损失将推动自回归图像生成的发展,并在未来的研究中被推广到其他领域。

2、相关工作

序列模型用于图像生成。在自回归图像模型方面的开创性工作[17, 50, 49, 36, 7, 6]是在像素序列上进行的。自回归可以通过RNNs[50]、CNNs[49, 7]以及最近和最流行的Transformers[36, 6]来实现。受到语言模型的启发,另一系列工作[51, 41, 13, 40]将图像建模为离散值的标记。自回归[13, 40]和掩码生成模型[4, 29]可以在离散值标记空间上操作。但是,离散标记器难以训练,这最近引起了特别的关注[27, 54, 32]。

与我们工作相关的是,最近的GIVT工作[48]也关注序列模型中的连续值标记。GIVT和我们的工作都揭示了这一方向的重要性和潜力。在GIVT中,标记分布由高斯混合模型表示。它使用预定义数量的混合,这可能限制了它可以表示的分布类型。相比之下,我们的方法利用扩散过程建模任意分布的有效性。

扩散在表示学习中的应用。去噪扩散过程已被探索作为视觉自监督学习的标准。例如,DiffMAE[53]用去噪扩散解码器替换了原始MAE[21]中的L2损失;DARL[30]使用去噪扩散块解码器训练自回归模型。这些努力一直专注于表示学习,而不是图像生成。在他们的场景中,生成多样化的图像不是目标;这些方法还没有展示出从头开始生成新图像的能力。

扩散在策略学习中的应用。我们的工作在概念上与机器人学中的Diffusion Policy[8]相关。在那些场景中,执行动作的分布被制定为对机器人观测值的去噪过程,这些观测值可以是像素或潜在特征[8, 34]。在图像生成中,我们可以将生成一个标记视为要采取的“动作”。尽管存在这种概念上的联系,但在机器人学中,生成样本的多样性并不是像图像生成那样是一个核心考虑因素。

3、方法

简而言之,我们的图像生成方法是在标记化的潜在空间上操作的序列模型[6, 13, 40]。但与之前基于向量量化标记器(例如VQ-VAE的变体[51, 13])的方法不同,我们旨在使用连续值标记器(例如[42])。我们提出了扩散损失(Diffusion Loss),使序列模型与连续值标记兼容。

3.1、重新思考离散值标记

首先,我们重新审视自回归生成模型中离散值标记的角色。假设 x x x是在下一个位置要预测的真实标记。使用离散标记器, x x x可以表示为一个整数: 0 ≤ x < K 0 \leq x < K 0x<K,其中 K K K是词汇表的大小。自回归模型生成一个连续值的 D D D维向量 z ∈ R D z \in \mathbb{R}^{D} zRD,然后通过一个 K K K路分类器矩阵 W ∈ R K × D W \in \mathbb{R}^{K \times D} WRK×D进行投影。从概念上讲,这种公式化形式将分类概率分布建模为 p ( x ∣ z ) = softmax ⁡ ( W z ) p(x \mid z) = \operatorname{softmax}(W z) p(xz)=softmax(Wz)

在生成建模的上下文中,这个概率分布必须表现出两个基本属性。(i) 一个能够衡量估计分布和真实分布之间差异的损失函数。在分类分布的情况下,这可以通过交叉熵损失简单地实现。(ii) 一个在推理时可以从分布 x ∼ p ( x ∣ z ) x \sim p(x \mid z) xp(xz)中抽取样本的采样器。在分类分布的情况下,这通常通过从 p ( x ∣ z ) = softmax ⁡ ( W z / τ ) p(x \mid z) = \operatorname{softmax}(W z / \tau) p(xz)=softmax(Wz/τ)中抽取样本来实现,其中 τ \tau τ是一个控制样本多样性的温度参数。从分类分布中采样可以通过Gumbel-max方法[18]或逆变换采样来实现。

这种分析表明,离散值标记对于自回归模型来说并不是必要的。相反,建模分布的要求才是本质。离散值标记空间暗示了一个分类分布,其损失函数和采样器定义起来很简单。我们真正需要的是用于分布建模的损失函数及其对应的采样器。

3.2、扩散损失

去噪扩散模型[24]提供了一个有效的框架来建模任意分布。但与常见的用于表示所有像素或所有标记的联合分布的扩散模型用法不同,在我们的案例中,扩散模型是用于表示每个标记的分布。

考虑一个连续值向量 x ∈ R d x \in \mathbb{R}^{d} xRd,它表示要在下一个位置预测的真实标记。自回归模型在该位置生成一个向量 z ∈ R D z \in \mathbb{R}^{D} zRD。我们的目标是建模条件分布 p ( x ∣ z ) p(x \mid z) p(xz),即 x x x在给定 z z z时的概率分布。损失函数和采样器可以根据扩散模型[24, 33, 10]来定义,具体描述如下。

损失函数。根据[24, 33, 10],底层概率分布 p ( x ∣ z ) p(x \mid z) p(xz)的损失函数可以制定为一个去噪准则:

L ( z , x ) = E ε , t [ ∥ ε − ε θ ( x t ∣ t , z ) ∥ 2 ] \mathcal{L}(z, x) = \mathbb{E}_{\varepsilon, t}\left[\left\|\varepsilon - \varepsilon_{\theta}\left(x_{t} \mid t, z\right)\right\|^{2}\right] L(z,x)=Eε,t[εεθ(xtt,z)2]

在这里, ε ∈ R d \varepsilon \in \mathbb{R}^{d} εRd 是一个从 N ( 0 , I ) \mathcal{N}(\mathbf{0}, \mathbf{I}) N(0,I) 采样得到的噪声向量。噪声干扰后的向量 x t x_{t} xt x t = α ˉ t x + 1 − α ˉ t ε x_{t} = \sqrt{\bar{\alpha}_{t}} x + \sqrt{1 - \bar{\alpha}_{t}} \varepsilon xt=αˉt x+1αˉt ε,其中 α ˉ t \bar{\alpha}_{t} αˉt 定义了一个噪声时间表[24, 33]。 t t t 是噪声时间表的时间步。噪声估计器 ε θ \varepsilon_{\theta} εθ,由参数 θ \theta θ 参数化,是一个小型多层感知机(MLP)网络(参见第4节)。符号 ε θ ( x t ∣ t , z ) \varepsilon_{\theta}\left(x_{t} \mid t, z\right) εθ(xtt,z) 意味着这个网络以 x t x_{t} xt 作为输入,并且同时依赖于 t t t z z z。根据[46, 47],等式(1)在概念上类似于一种得分匹配:它与 p ( x ∣ z ) p(x \mid z) p(xz) 的得分函数相关的损失函数有关,即 ∇ log ⁡ x p ( x ∣ z ) \nabla \log_{x} p(x \mid z) logxp(xz)。扩散损失是一种参数化的损失函数,与对抗性损失[15]或感知损失[56]类似。

值得注意的是,条件向量 z z z 是由自回归网络产生的: z = f ( ⋅ ) z = f(\cdot) z=f(),我们稍后会讨论这一点。 z = f ( ⋅ ) z = f(\cdot) z=f() 的梯度是从等式(1)中的损失函数传播的。在概念上,等式(1)定义了一个用于训练网络 f ( ⋅ ) f(\cdot) f() 的损失函数。

我们注意到等式(1)中的期望 E ε , t [ ⋅ ] \mathbb{E}_{\varepsilon, t}[\cdot] Eε,t[] 是针对 t t t 的,对于任何给定的 z z z。由于我们的去噪网络较小,我们可以对任何给定的 z z z 多次采样 t t t。这有助于改善损失函数的利用,而无需重新计算 z z z。在训练过程中,我们为每个图像在每次迭代中采样 t t t 四次。

采样器。在推理时,需要从分布 p ( x ∣ z ) p(x \mid z) p(xz) 中抽取样本。采样是通过反向扩散过程[24]完成的: x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ε θ ( x t ∣ t , z ) ) + σ t δ x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}\left(x_{t} - \frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \varepsilon_{\theta}\left(x_{t} \mid t, z\right)\right) + \sigma_{t} \delta xt1=αt 1(xt1αˉt 1αtεθ(xtt,z))+σtδ。这里 δ \delta δ 是从高斯分布 N ( 0 , I ) \mathcal{N}(\mathbf{0}, \mathbf{I}) N(0,I) 中采样的,而 σ t \sigma_{t} σt 是在时间步 t t t 的噪声水平。从 x T ∼ N ( 0 , I ) x_{T} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) xTN(0,I) 开始,这个过程生成一个样本 x 0 x_{0} x0,使得 x 0 ∼ p ( x ∣ z ) x_{0} \sim p(x \mid z) x0p(xz)[24]。

当使用分类分布(第3.1节)时,自回归模型可以享受通过温度 τ \tau τ 控制样本多样性的好处。事实上,无论是语言还是图像方面的现有文献都表明,温度在自回归生成中起着关键作用。希望扩散采样器能提供一个温度的对应物。我们采用了[10]中提出的温度采样。在概念上,给定温度 τ \tau τ,人们可能希望从(重新归一化)概率 p ( x ∣ z ) 1 τ p(x \mid z)^{\frac{1}{\tau}} p(xz)τ1 中采样,其得分函数是 1 τ ∇ log ⁡ x p ( x ∣ z ) \frac{1}{\tau} \nabla \log _{x} p(x \mid z) τ1logxp(xz)。在实践中,[10]建议要么将 ε θ \varepsilon_{\theta} εθ 除以 τ \tau τ,要么将噪声乘以 τ \tau τ。我们采用了后者:在采样器中,我们将 σ t δ \sigma_{t} \delta σtδ 乘以 τ \tau τ。直观地讲, τ \tau τ 通过调整噪声方差来控制样本多样性。

3.3、自回归模型的扩散损失

接下来,我们描述用于图像生成的自回归模型及其扩散损失。给定一个标记序列 { x 1 , x 2 , … , x n } \left\{x^{1}, x^{2}, \ldots, x^{n}\right\} {x1,x2,,xn},其中上标 1 ≤ i ≤ n 1 \leq i \leq n 1in 指定了顺序,自回归模型[17,50,49,36,7,6]将生成问题表述为“下一个标记预测”:

p ( x 1 , … , x n ) = ∏ i = 1 n p ( x i ∣ x 1 , … , x i − 1 ) p\left(x^{1}, \ldots, x^{n}\right)=\prod_{i=1}^{n} p\left(x^{i} \mid x^{1}, \ldots, x^{i-1}\right) p(x1,,xn)=i=1np(xix1,,xi1)

使用一个网络来表示条件概率 p ( x i ∣ x 1 , … , x i − 1 ) p\left(x^{i} \mid x^{1}, \ldots, x^{i-1}\right) p(xix1,,xi1)。在我们的案例中, x i x^{i} xi 可以是连续值。我们可以将这个表述重写为两个部分。首先,我们通过一个网络(例如Transformer[52])对之前的标记进行操作来生成一个条件向量 z i z^{i} zi z i = f ( x 1 , … , x i − 1 ) z^{i}=f\left(x^{1}, \ldots, x^{i-1}\right) zi=f(x1,,xi1)。然后,我们通过 p ( x i ∣ z i ) p\left(x^{i} \mid z^{i}\right) p(xizi) 来建模下一个标记的概率。等式(1)中的扩散损失可以应用于 p ( x i ∣ z i ) p\left(x^{i} \mid z^{i}\right) p(xizi)。梯度将反向传播到 z i z^{i} zi 以更新 f ( ⋅ ) f(\cdot) f() 的参数。

3.4、统一自回归和掩码生成模型

我们展示了掩码生成模型,如MaskGIT[4]和MAGE[29],可以在自回归的广泛概念下被泛化,即下一个标记的预测。

双向注意力可以执行自回归。自回归的概念与网络架构是正交的:自回归可以通过RNNs[50]、CNNs[49, 7]和Transformers[38, 36, 6]来实现。当使用Transformers时,尽管自回归模型通常通过因果注意力来实现,但我们展示了它们也可以通过双向注意力来实现。参见图2。请注意,自回归的目标是根据先前的标记预测下一个标记;它并不限制先前的标记如何与下一个标记通信。

我们可以采用在Masked Autoencoder(MAE)[21]中使用的双向注意力实现。参见图2(b)。具体来说,我们首先对已知标记(带有位置嵌入[52])应用MAE风格的编码器 1 {}^{1} 1。然后,我们将编码后的序列与掩码标记(再次添加位置嵌入)进行连接,并使用MAE风格的解码器对该序列进行映射。掩码标记上的位置嵌入可以让解码器知道要预测哪些位置。与因果注意力不同,这里的损失仅计算在未知标记上[21]。

通过使用MAE风格的技巧,我们允许所有已知标记相互可见,同时也允许所有未知标记看到所有已知标记。这种全注意力引入了在标记之间比因果注意力更好的通信。在推理时,我们可以使用这种双向表述来生成标记(每步一个或多个),这是一种自回归的形式。作为折衷,我们不能使用因果注意力的键值(kv)缓存[44]来加速推理。但是,由于我们可以一起生成多个标记,我们可以减少生成步骤以加快推理速度。跨标记的全注意力可以显著提高质量,并提供更好的速度/准确度的权衡。

自回归模型在随机顺序中。为了与掩码生成模型[4, 29]相联系,我们考虑了随机顺序的自回归变体。模型会得到一个随机排列的序列。这个随机排列对于每个样本都是不同的。参见图3(b)。在这种情况下,需要预测的下一个标记的位置必须能够被模型访问。我们采用与MAE[21]类似的策略:我们在解码器层中添加位置嵌入(对应于未打乱的位置),这可以告诉模型要预测哪些位置。这种策略既适用于因果版本也适用于双向版本。

如图3(b)©所示,随机顺序的自回归表现得像一种特殊的掩码生成形式,其中一次生成一个标记。我们详细解释如下。

掩码自回归模型。在掩码生成建模[4, 29]中,模型基于已知/预测的标记预测一个随机的标记子集。这可以表述为通过随机顺序排列标记序列,然后基于先前的标记预测多个标记。参见图3©。从概念上讲,这是一个自回归过程,可以写为估计条件分布: p ( { x i , x i + 1 … , x j } ∣ x 1 , … , x i − 1 ) p\left(\left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} \mid x^{1}, \ldots, x^{i-1}\right) p({xi,xi+1,xj}x1,,xi1),其中需要预测多个标记 { x i , x i + 1 … , x j } \left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} {xi,xi+1,xj} i ≤ j i \leq j ij)。我们可以将这个自回归模型写为:

p ( x 1 , … , x n ) = p ( X 1 , … , X K ) = ∏ k K p ( X k ∣ X 1 , … , X k − 1 ) p\left(x^{1}, \ldots, x^{n}\right)=p\left(X^{1}, \ldots, X^{K}\right)=\prod_{k}^{K} p\left(X^{k} \mid X^{1}, \ldots, X^{k-1}\right) p(x1,,xn)=p(X1,,XK)=kKp(XkX1,,Xk1)

在这里, X k = { x i , x i + 1 … , x j } X^{k}=\left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} Xk={xi,xi+1,xj}是在第 k k k步要预测的一组标记,其中 ∪ k X k = { x 1 , … , x n } \cup_{k} X^{k}=\left\{x^{1}, \ldots, x^{n}\right\} kXk={x1,,xn}。从这个意义上讲,这本质上是“下一组标记预测”,因此也是自回归的一种一般形式。我们将这种变体称为掩码自回归(MAR)模型。MAR是一种随机顺序的自回归模型,可以同时预测多个标记。

MAR在概念上与MAGE[29]相关。然而,MAR通过应用于每个标记的概率分布的温度 τ \tau τ来采样标记(这是像GPT这样的生成式语言模型中的标准做法)。相比之下,MAGE(遵循MaskGIT[4])应用一个温度来采样要预测的标记的位置:这不是一个完全随机的顺序,这会在训练时间和推理时间行为之间造成差异。

4、实现

本节描述了我们的实现。我们注意到,本文中介绍的概念是通用的,并不限于特定的实现。更详细的特定信息在附录B中。

4.1、扩散损失

扩散过程。我们的扩散过程遵循[33]。我们的噪声计划具有余弦形状,训练时有1000步;在推理时,它会重新采样为较少的步骤(默认情况下为100步)[33]。我们的去噪网络预测噪声向量 ε \varepsilon ε[24]。损失可以选择性地包括变分下界项 L v l b \mathcal{L}_{\mathrm{vlb}} Lvlb[33]。扩散损失自然支持无分类器指导(CFG)[23](详细见附录B)。

去噪MLP。我们使用一个小型的MLP,由几个残差块[20]组成,用于去噪。每个块依次应用LayerNorm(LN)[1]、一个线性层、SiLU激活函数[12]、和另一个线性层,并通过残差连接合并。默认情况下,我们使用3个块和1024个通道的宽度。去噪MLP的条件是AR/MAR模型产生的向量 z z z(见图1)。向量 z z z被添加到噪声计划时间步 t t t的时间嵌入中,该时间嵌入通过AdaLN[37]在LN层中作为MLP的条件。

4.2、自回归和掩码自回归图像生成

分词器。我们使用LDM[42]提供的公开可用的分词器。我们的实验将涉及他们的VQ-16和KL-16版本[42]。VQ-16是VQ-GAN[13],即带有GAN损失[15]和感知损失[56]的VQ-VAE[51];KL-16是通过Kullback-Leibler (KL)散度正则化的对应版本,没有向量量化。16表示分词器的步长。

Transformer。我们的架构遵循ViT[11]中的Transformer[52]实现。给定来自分词器的标记序列,我们添加位置嵌入[52]并附加类别标记[cls];然后我们用Transformer处理这个序列。默认情况下,我们的Transformer有32个块和1024的宽度,我们称之为Large大小或-\mathrm{L}(~4亿参数)。

自回归基线。因果注意力遵循GPT[38]的常见做法(图2(a))。输入序列通过一个标记(这里是[cls])进行移位。在注意力矩阵上应用三角形掩码[52]。在推理时,应用温度( τ \tau τ)采样。我们使用kv-cache[44]进行高效的推理。

掩码自回归模型。使用双向注意力(图2(b)),我们可以基于任意数量的已知标记预测任意数量的未知标记。在训练时,我们在[0.7,1.0]范围内随机采样一个掩码比率[21,4,29]:例如,0.7意味着70%的标记是未知的。由于采样序列可能非常短,我们总是在编码器序列的开头填充64个[cls]标记,这提高了我们编码的稳定性和容量。如图2所示,在解码器中引入了掩码标记[\mathrm{m}],并添加了位置嵌入。为了简单起见,与[21]不同,我们让编码器和解码器具有相同的大小:每个都有所有块的一半(例如,在MAR-L中为16个块)。

在推理时,MAR执行“下一组标记预测”。它使用余弦计划[4,29]逐步将掩码比率从1.0减少到0。默认情况下,我们在该计划中使用64步。应用温度( τ \tau τ)采样。与[4, 29]不同,MAR总是使用完全随机的顺序。

5、实验

我们在分辨率为 256 × 256 256 \times 256 256×256的ImageNet[9]上进行实验。我们评估FID[22]和IS[43],并提供精度和召回率作为参考,以遵循常见的实践[10]。我们遵循[10]提供的评估套件。

5.1、扩散损失的特性

扩散损失与交叉熵损失的比较。我们首先比较使用扩散损失的连续值标记与标准离散值标记的交叉熵损失(表1)。为了公平比较,我们从LDM代码库[42]下载了这两种分词器(“VQ-16”和“KL-16”)。这些分词器被广泛使用(例如,[13,42,37])。

如表1所示,在所有情况下,扩散损失都一致优于对应的交叉熵损失。具体来说,在MAR(例如默认设置)中,使用扩散损失可以将FID相对降低约50%~60%。这是因为连续值的KL-16比VQ-16具有更小的压缩损失(接下来在表2中讨论),而且扩散过程比分类过程更有效地建模分布。

在以下消融实验中,除非另有说明,我们遵循表1中的“默认”MAR设置。

扩散损失的灵活性。扩散损失的一个显著优势是它与各种分词器的灵活性。我们在表2中比较了几种公开可用的分词器。

即使给定VQ分词器,扩散损失也可以轻松使用。我们简单地将VQ层之前的连续值潜在作为标记。这种变体给出了7.82的FID(不使用CFG),与使用相同VQ分词器的交叉熵损失的8.79 FID(表1)相比表现良好。这表明扩散在建模分布方面的能力更强。

这种变体还使我们能够使用相同的损失来比较VQ-16和KL-16分词器。如表2所示,VQ-16的重建FID(rFID)远差于KL-16,这导致生成的FID也差得多(例如,表2中的7.82与3.50)。

有趣的是,扩散损失还允许我们使用步长不匹配的分词器。在表2中,我们研究了步长为8、输出序列长度为 32 × 32 32 \times 32 32×32的KL-8分词器。在不增加生成器序列长度的情况下,我们将 2 × 2 2 \times 2 2×2个标记组合成一个新标记。尽管存在不匹配,但我们仍然能够获得相当不错的结果,例如,KL-8给出了2.05的FID,与KL-16的1.98 FID相比。此外,这一特性允许我们研究其他分词器,例如Consistency Decoder[35],这是一个针对不同目标设计、具有不同架构/步长的非VQ分词器。

为了全面起见,我们还使用[42]的代码在ImageNet上训练了一个KL-16分词器,注意到[42]中原始的KL-16是在OpenImages[28]上训练的。比较结果列在表2的最后一行。在以下的探索中,我们使用这个分词器。

扩散损失中的去噪MLP。我们研究了表3中的去噪MLP。即使是一个非常小的MLP(例如,2M)也能带来有竞争力的结果。正如预期的那样,增加MLP的宽度有助于提高生成质量;我们已经探索了增加深度并得出了类似的观察结果。请注意,我们默认的MLP大小(1024宽度,21M)仅为MAR-L模型增加了约5%的额外参数。在推理时,扩散采样器有一个相当不错的成本,约占总体运行时间的10%。在我们的实现中,增加MLP的宽度具有可忽略的额外成本(表3),部分原因是主要的开销不在于计算而在于内存通信。

扩散损失的采样步数。我们的扩散过程遵循DDPM的常用做法[24,10]:我们使用1000步的噪声计划进行训练,但在推理时使用更少的步数。图4显示了在推理时使用100步扩散步骤足以达到较强的生成质量。

扩散损失的温度。在交叉熵损失的情况下,温度是至关重要的。扩散损失也提供了一个温度对应物来控制多样性和保真度。图5展示了在推理时扩散采样器中的温度 τ \tau τ(见第3.2节)的影响。在我们的模型中,温度 τ \tau τ起着重要作用,类似于基于交叉熵的对应物的观察结果(注意表1中的交叉熵结果是在其最佳温度下获得的)。

5.2、广义自回归模型的特性

从AR到MAR。表1还对比了AR/MAR的变种,接下来我们将讨论这些变种。首先,将AR中的栅格顺序替换为随机顺序带来了显著的增益,例如,将FID从19.23降低到13.07(无CFG)。接下来,将因果注意力替换为双向的对应物带来了另一个巨大的增益,例如,将FID从13.07降低到3.43(无CFG)。

随机顺序、双向的AR本质上是一种形式的MAR,它一次预测一个标记。在每一步中预测多个标记(’ >1 ')可以有效地减少自回归步骤的数量。在表1中,我们展示了具有64个步骤的MAR变种在生成质量上略有妥协。接下来将讨论更全面的权衡比较。

速度/准确度的权衡。跟随MaskGIT [4],我们的MAR享有一次性预测多个标记的灵活性。这通过在推理时控制自回归步骤的数量来实现。图6展示了速度/准确度的权衡。MAR与其AR对应物相比具有更好的权衡,注意到AR使用了高效的kv-cache。

使用扩散损失时,MAR与最近流行的Diffusion Transformer(DiT)[37]相比也显示出有利的权衡。作为一个潜在扩散模型,DiT通过扩散过程来建模所有标记之间的相互依赖关系。DiT的速度/准确度权衡主要由其扩散步骤控制。与我们在小型MLP上的扩散过程不同,DiT的扩散过程涉及整个Transformer架构。我们的方法更准确且更快。值得注意的是,我们的方法能够以每幅图像小于0.3秒的速度生成图像,同时保持FID小于2.0的强劲性能。

5.3、与先前系统的基准测试

我们在表4中与领先的系统进行了比较。我们探索了不同的模型大小(见附录B)并训练了800个周期。类似于自回归语言模型[3],我们观察到了令人鼓舞的扩展行为。进一步探索扩展可能会很有前景。关于指标,我们报告了不使用CFG时的FID为2.35,大幅超过了其他基于标记的方法。我们最佳的表现FID为1.55,与其他领先系统相比具有竞争力。图7展示了定性结果。

6、讨论与结论

Diffusion Loss在各种自回归模型上的有效性为新的机会指明了方向:通过自回归建模标记之间的相互依赖关系,同时结合扩散过程建模每个标记的分布。这与常见的扩散用法不同,后者建模所有标记的联合分布。我们在图像生成上的出色结果表明,自回归模型或其扩展是超越语言建模的有力工具。这些模型不需要被向量量化表示所限制。我们希望我们的工作能激励研究社区探索在其他领域中使用连续值表示的序列模型。

A、局限性与更广泛的影响

局限性。除了展示我们方法在图像生成方面的潜力外,本文也承认其局限性。

首先,我们的图像生成系统可能会产生带有明显瑕疵的图像(如图8所示)。这种局限性在现有方法中很常见,尤其是在使用受控的学术数据(如ImageNet)进行训练时。与在大量数据上训练的商业模型相比,基于ImageNet训练的研究驱动模型在视觉质量上仍有明显的差距。

其次,我们的图像生成系统依赖于现有的预训练标记器。我们系统的质量可能会受到这些标记器质量的限制。预训练更好的标记器超出了本文的范围。然而,我们希望我们的工作将为未来开发连续值标记器变得更加容易。

最后,我们注意到,由于计算资源有限,我们主要在ImageNet基准测试集上测试了我们的方法。为了评估我们的方法在更多样化和现实场景中的可扩展性和鲁棒性,还需要进一步的验证。

更广泛的影响。我们的主要目标是推动生成模型的基础研究,并相信这将对该领域产生积极影响。我们方法的直接应用之一是将其扩展到大型视觉生成模型,如文本到图像或文本到视频生成。我们的方法有可能显著降低这些大型模型的训练和推理成本。同时,我们的方法可能表明在许多应用中用Diffusion Loss替换传统损失函数的机会。从负面来看,我们的方法从训练数据集中学习统计信息,因此可能会反映数据中的偏差;图像生成系统可能会被滥用以生成虚假信息,这值得进一步考虑。

B、附加实现细节

分类器自由引导(CFG)。为了支持CFG [23],在训练时,对于10%的样本,类别条件被替换为一个虚拟类别标记[23]。在推理时,模型使用给定的类别标记和虚拟标记运行,产生两个输出 z c z_{c} zc z u z_{u} zu。然后,预测的噪声 ε \varepsilon ε被修改为[23]: ε = ε θ ( x t ∣ t , z u ) + ω ⋅ ( ε θ ( x t ∣ t , z c ) − ε θ ( x t ∣ t , z u ) ) \varepsilon=\varepsilon_{\theta}\left(x_{t} \mid t, z_{u}\right)+\omega \cdot\left(\varepsilon_{\theta}\left(x_{t} \mid t, z_{c}\right)-\varepsilon_{\theta}\left(x_{t} \mid t, z_{u}\right)\right) ε=εθ(xtt,zu)+ω(εθ(xtt,zc)εθ(xtt,zu)),其中 ω \omega ω是引导尺度。在推理时,我们遵循[5]的CFG计划。我们针对每个模型扫描最佳引导尺度和温度组合。

训练。默认情况下,模型使用AdamW优化器[31]训练400个周期。AdamW的权重衰减和动量分别为0.02和(0.9,0.95)。我们使用批处理大小为2048,学习率(lr)为 8 e − 4 8 \mathrm{e}-4 8e4。我们的带有Diffusion Loss的模型采用100个周期的线性lr预热[16],之后是恒定[37]的lr计划。交叉熵对应的模型则使用余弦lr计划,这对它们更有效。遵循[37, 25],我们使用0.9999的动量维持模型参数的指数移动平均值(EMA)。

表4中的实现细节。为了探索我们方法的扩展行为,我们研究了以下描述的三种模型大小。除了MAR-L,我们还探索了一个较小的模型(MAR-B)和一个较大的模型(MAR-H)。MAR-B、-L和-H分别具有24、32、40个Transformer块,以及768、1024和1280的宽度。在表4中特别地,去噪MLP分别具有6、8、12个块,以及1024、1280和1536的宽度。训练周期增加到800个周期。在推理时,我们运行256个自回归步骤以达到最佳结果。

Diffusion Loss的伪代码。见算法1。

Diffusion Loss 概念的伪代码。在这里,条件向量 z z z 是从 AR/MAR 模型中输出的。梯度将反向传播到 z z z。为了简化,这里我们省略了推理重调度、温度和变分下界损失项的代码[10],这些可以很容易地合并进来。

计算资源。我们的训练主要在 16 台服务器上完成,每台服务器配备了 8 个 V100 GPU。在这些 GPU 上训练一个 400 轮次的 MAR-L 模型大约需要 2.6 天。相比之下,在同一集群上训练相同轮次的 DiT-XL/2 和 LDM-4 模型分别需要 4.6 天和 9.5 天。

MAR 和 MAGE 的比较。MAR(无论使用何种损失)在概念上与 MAGE [29] 相关。除了实现上的差异(例如,架构特定性和超参数)之外,MAR 和 MAGE 在推理时的扫描顺序上存在主要的概念差异。在 MAGE 中,遵循 MaskGIT [4] 的方法,下一个要预测的令牌的位置是根据每个位置的样本置信度动态确定的,即每个步骤中更自信的位置更有可能被选中[4, 29]。相比之下,MAR 采用完全随机的顺序,并且其温度采样应用于每个令牌。表 5 在受控设置下比较了这种差异。第一行是我们的 MAR 实现,但使用了 MAGE 的即时排序策略,其结果与简单的随机排序相似。完全随机的排序可以使训练和推理过程在令牌顺序的分布上保持一致;它还允许我们以类似于自回归语言模型(例如,GPT [38,39,3])的方式对每个令牌进行温度采样。

D、额外比较

D.1、ImageNet 512 × 512 512 \times 512 512×512

与先前的工作一样,我们也报告了在 512 × 512 512 \times 512 512×512 分辨率下 ImageNet 的结果,并与领先的系统进行了比较(见表6)。为了简化,我们使用 KL-16 分词器,它在 512 × 512 512 \times 512 512×512 图像上给出 32 × 32 32 \times 32 32×32 的序列长度。其他设置遵循表4中描述的 MAR-L 配置。我们的方法在没有使用 CFG 的情况下达到了 FID 分数 2.74,使用 CFG 时达到了 1.73。我们的结果与先前系统的结果相竞争。由于资源有限,我们尚未在 ImageNet 512 × 512 512 \times 512 512×512 上训练更大的 MAR-H,预计会有更好的结果。

D.2、L2 损失与 Diffusion 损失

对于连续值令牌的一个简单基线是直接计算预测和目标令牌之间的均方误差(MSE,即 L2)损失。在栅格顺序的自回归模型的情况下,使用 L2 损失不引入随机性,因此无法生成多样化的样本。在使用 L2 损失的 MAR 模型中,唯一的随机性是序列顺序;对于任何给定的顺序,某个位置的预测是确定的。在我们的实验中,我们训练了一个使用 L2 损失的 MAR 模型,这导致 FID 分数灾难性地大于 100(>100)。

致谢

我们感谢 Congyue Deng 和 Xinlei Chen 的有益讨论。我们感谢 Google TPU 研究云(TRC)为我们提供 TPU 访问权限,以及 Google Cloud Platform 对 GPU 资源的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/756826.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kafka入门-分区及压缩

一、生产者消息分区 Kafka的消息组织方式实际上是三级结构&#xff1a;主题-分区-消息。主题下的每条消息只会保存在某一个分区中&#xff0c;而不会在多个分区中被保存多份。 分区的作用就是提供负载均衡的能力&#xff0c;或者说对数据进行分区的主要原因&#xff0c;就是为…

spring和springboot的关系是什么?

大家好&#xff0c;我是网创有方的站长&#xff0c;今天给大家分享下spring和springboot的关系是什么&#xff1f; Spring和Spring Boot之间的关系可以归纳为以下几个方面&#xff1a; 技术基础和核心特性&#xff1a; Spring&#xff1a;是一个广泛应用的开源Java框架&#…

我们后端程序员不是操作MyBatis的CRUD Boy

大家好&#xff0c;我是南哥。 一个对Java程序员进阶成长颇有研究的人&#xff0c;今天我们接着新的一篇Java进阶指南。 为啥都戏称后端是CRUD Boy&#xff1f;难道就因为天天怼着数据库CRUD吗&#xff1f;要我说&#xff0c;是这个岗位的位置要的就是你CRUD&#xff0c;你不…

JeecgBoot中如何对敏感信息进行脱敏处理?

数据脱敏即将一些敏感信息通过加密、格式化等方式处理&#xff0c;展示给用户一个新的或是格式化后的信息&#xff0c;避免了敏感信息的暴露。 一、接口脱敏注解 针对接口数据实现脱敏加密&#xff0c;只加密&#xff0c;一般此方案用于数据加密展示。 1.1 注解介绍 注解作用域…

Qt信号槽的坑

1、重载的信号&#xff08;以QSpinBox为例&#xff09; 像是点击按钮之类的信号槽很好连接&#xff0c;这是因为它的信号没有重载&#xff0c;如果像SpinBox那样有重载信号的话&#xff08;Qt5.12的见下图&#xff0c;不过Qt5.15LTS开始就不再重载而是换信号名了&#xff09;&…

Linux常用命令大全(超详细!!!)

文章目录 1.Linux是什么1.1 关于Linux我们主要学习什么1.1 学习Linux常见命令的前置知识 2. Linux常见命令2.1 ls命令2.2 cd命令2.3 pwd命令2.4 touch命令2.5 cat命令2.6 echo命令2.7 vim命令2.8 mkdir 命令2.9 rm命令2.10 cp命令2.11 mv命令2.12 grep命令2.13 ps命令2.14 nets…

<Python><ffmpeg>基于python使用PyQt5构建GUI实例:音频格式转换程序(MP3/aac/wma/flac)(优化版2)

前言 本文是基于python语言使用pyqt5来构建的GUI,功能是使用ffmpeg来对音频文件进行格式转换,如mp3、aac、wma、flac等音乐格式。 UI示例: 环境配置 系统:windows 平台:visual studio code 语言:python 库:pyqt5、ffmpeg 概述 本文是建立在之前的博文的基础上的优化版…

(笔记)Error: qemu-virgl: Failed to download resource “qemu-virgl--test-image“解决方法

错误&#xff1a; > Downloading https://www.ibiblio.org/pub/micro/pc-stuff/freedos/files/distributions/1.2/FD12FLOPPY.zip curl: (22) The requested URL returned error: 404Error: qemu-virgl: Failed to download resource "qemu-virgl--test-image" D…

Grafana-11.0.0 在线部署教程

Grafana-11.0.0 在线部署教程 环境&#xff1a; 操作系统&#xff1a; ubuntugrafana版本&#xff1a; 11.0.0 &#xff08;建议不要按照最新版&#xff09;grafana要求的系统配置不高&#xff0c;建议直接部署在监控服务器上&#xff0c;比如zabbix服务器、prometheus服务器…

文华财经通达信同花顺期货通盘立方博易大师主图指标公式源码

买线:EMA(C,2); 卖线:EMA(SLOPE(C,21)*20C,42); BU:CROSS(买线,卖线); SEL:CROSS(卖线,买线); STICKLINE1(买线>卖线,LOW,MIN(O,C),0.1,1),COLORRED; STICKLINE1(买线>卖线,MAX(O,C),HIGH,0.1,1),COLORRED; STICKLINE(买线>卖线,CLOSE,OPEN,8,1),COLORRED; STI…

页面开发感想

页面开发 1、 前端预览 2、一些思路 2.1、首页自定义element-plus的走马灯 :deep(.el-carousel__arrow){border-radius: 0%;height: 10vh; }需要使用:deep(标签)才能修改样式 或者 ::v-deep 标签 2.2、整体设计思路 <template><div class"card" style&…

跟《经济学人》学英文:2024年6月22日这期 India’s electronics industry is surging

India’s electronics industry is surging Foreign and domestic firms are investing in local manufacturing surge:激增&#xff1b;急剧上升&#xff1b; 原文&#xff1a; To witness India’s growing role as a manufacturing hub, dodge Bangalore’s notorious t…

maven安装jar和pom到本地仓库

举例子我们要将 elastic-job-spring-boot-starter安装到本地的maven仓库&#xff0c;如下&#xff1a; <dependency><groupId>com.github.yinjihuan</groupId><artifactId>elastic-job-spring-boot-starter</artifactId><version>1.0.5&l…

使用腾讯云服务器从0搭建个人网站,超简单图文教程

使用腾讯云服务器搭建网站全流程&#xff0c;包括轻量应用服务器和云服务器CVM建站教程&#xff0c;轻量可以使用应用镜像一键建站&#xff0c;云服务器CVM可以通过安装宝塔面板的方式来搭建网站&#xff0c;腾讯云服务器网txyfwq.com整理使用腾讯云服务器建站教程&#xff0c;…

Java的IO体系

目录 1、Java的IO体系2、IO的常用方法3、Java中为什么要分为字节流和字符流4、File和RandomAccessFile5、Java对象的序列化和反序列化6、缓冲流7、Java 的IO流中涉及哪些设计模式 1、Java的IO体系 IO 即为 input 输入 和 output输出 Java的IO体系主要分为字节流和字符流两大类…

【51单片机】串口通信(发送与接收)

文章目录 前言串口通信简介串口通信的原理串口通信的作用串口编程的一些概念仿真图如何使用串口初始化串口串口模式波特率配置 发送与接收发送接收 示例代码 总结 前言 在嵌入式系统的开发中&#xff0c;串口通信是一种常见且重要的通信方式。它以其简单、稳定的特性在各种应用…

多阶段分层构建容器化Spring Boot应用程序

上一节中&#xff0c;容器化spring boot应用程序-CSDN博客我们介绍了基于简单的Dockerfile对spring boot进行容器化的过程&#xff0c;本讲将介绍如何基于Dockerfile进行多阶段的分层构建过程&#xff0c;希望对大家有所帮助。 Spring Boot从版本2.3.0开始支持分层构建容器化的…

4.x86游戏实战-人物状态标志位

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 上一个内容&#xff1a;3.x86游戏实战-寄存器 人物状态标志位&#xff1a; 什么叫人物状态标志位&…

PAE:从潮流报告中提炼有效产品属性

本文将介绍PAE&#xff0c;一种用于包含 PDF格式的文本和图像的产品属性提取算法。目前大部分的方法侧重于从标题或产品描述中提取属性&#xff0c;或利用现有产品图像中的视觉信息。与之前的工作相比&#xff0c;PAE从潮流趋势报告的PDF文件中提取属性&#xff0c;提取的属性包…

Django 自定义标签

1&#xff0c;简单标签 1.1 添加自定义标签函数 Test/app5/templatetags/mytags.py from django import template register template.Library() register.simple_tag() def show_title(value, n):if len(value) > n:return f{value[:n]}...else:return value 1.2 添加视…