摘要
在本研究中,我们提出了一种名为对抗扩散蒸馏(ADD)的创新训练技术,它能够在1至4步的采样过程中,高效地对大规模基础图像扩散模型进行处理,同时保持图像的高质量。该方法巧妙地结合了分数蒸馏技术,利用现有的大型图像扩散模型作为指导信号,同时引入对抗损失,确保在较少的采样步骤下依然能够产生高保真度的图像。通过这种结合,我们的模型在单步操作中的表现显著超越了现有的少步骤方法,包括生成对抗网络(GAN)和潜在一致性模型。更为引人注目的是,ADD在仅经过四个采样步骤的情况下,就能达到与当前最先进的扩散模型SDXL相媲美的性能水平。这一成就标志着ADD成为首个能够利用基础模型实现单步实时图像合成的技术,为图像生成领域带来了新的突破。
论文地址: https://stability.ai/research/adversarial-diffusion-distillation
1.简介
扩散模型(DMs)在生成建模领域扮演着核心角色,尤其在高质量图像和视频合成方面取得了显著的进展。DMs的一个关键优势在于它们的可伸缩性和迭代性,使得它们能够处理复杂的任务,如根据自由格式文本提示合成图像。然而,DMs的迭代推理过程需要大量的采样步骤,这限制了它们在实时应用中的使用。相比之下,生成式对抗网络(GANs)以其单步生成和快速特性而闻名。尽管GANs尝试扩展到大型数据集,但在样本质量方面往往无法与DMs相媲美。本研究的目标是将DMs的高质量样本生成能力与GANs的速度优势结合起来。
我们提出了一种概念上简单的方法——对抗扩散蒸馏(ADD),这是一种将预训练扩散模型的推理步骤减少到1至4个采样步骤的通用方法,同时保持高采样保真度,并可能提升模型的整体性能。我们结合了两个训练目标:对抗损失和与分数蒸馏抽样(SDS)相对应的蒸馏损失。对抗损失确保模型在每次前向传递时直接生成高质量的图像样本,避免了其他蒸馏方法中常见的模糊和伪影。蒸馏损失则利用另一个预训练的DM作为教师模型,有效地利用其丰富的知识,并保持大型DMs的强组合性。在推理过程中,我们的方法避免了无分类器引导,从而进一步降低了内存需求,同时保留了通过迭代细化改进结果的能力,这比以往的基于GAN的一步方法更具优势。
主要贡献包括:
- 引入了ADD,这是一种能够将预训练的扩散模型转化为高保真的实时图像生成器的方法,仅需1至4个采样步骤。
- 我们的方法结合了对抗性训练和分数蒸馏,经过精心设计,显著优于现有的强基线模型,如LCM、LCM-XL和单步GAN,并能够在单个推理步骤中处理复杂的图像组成,同时保持高图像真实感。
- 使用四个采样步骤的ADD-XL在512像素分辨率上的表现甚至超过了其教师模型SDXL-Base。
2. 研究背景
扩散模型在合成和编辑高分辨率图像及视频方面取得了显著成就,但其迭代性质限制了在实时应用中的使用。潜在扩散模型尝试通过在潜在空间中表示图像来降低计算成本,但仍然依赖于大型模型的迭代处理。为了解决这一问题,研究者们探索了更快的采样器和模型蒸馏技术,如渐进蒸馏和引导蒸馏,尽管这些方法减少了采样步骤,但可能会牺牲原始性能,并需要迭代训练过程。
一致性模型通过在ODE轨迹上施加一致性正则化来提高性能,并在少量样本设置中展示了基于像素的模型的强大能力。其他方法,如InstaFlow,推荐使用修正流来优化蒸馏过程。然而,这些方法通常在较少的采样步骤下会产生模糊和伪影问题。
与此同时,生成对抗网络(GANs)作为单步模型,虽然在文本到图像合成方面表现出快速的采样速度,但其性能通常不如基于扩散的模型。这是因为GANs的训练需要精细地平衡特定架构以实现稳定的对抗目标,而且在不破坏这种平衡的情况下扩展模型并集成神经网络架构的进步是具有挑战性的。目前最先进的文本到图像GANs缺乏像无分类器制导这样的技术,这对于大规模的DMs来说至关重要。
分数蒸馏抽样是一种新提出的方法,用于将基础T2I模型的知识转移到3D综合模型中。一些基于分数蒸馏抽样的作品专注于3D对象的逐场景优化,并已应用于文本到3D视频合成和图像编辑。最近的研究展示了基于分数的模型和GANs之间的紧密联系,并提出了使用来自DM的基于分数的扩散流进行训练的Score gan。类似地,DiffInstruct是一种推广分数蒸馏抽样的方法,它可以将预训练的扩散模型转化为不带鉴别器的生成器。
我们的方法结合了对抗性训练和分数蒸馏,旨在解决当前几步生成模型中的最佳表现问题。通过这种混合目标的方法,我们旨在提升模型的性能,同时减少采样步骤,以实现更快、更高质量的图像合成。
3. 研究方法
我们的目标是在尽可能少的采样步骤中生成高保真度的样本,同时匹配最先进模型的质量[7,50,53,55]。对抗性目标[14,60]自然适合快速生成,因为它训练的模型可以在单个向前步骤中输出图像。然而,将gan扩展到大型数据集的尝试[58,59]发现,不仅要依赖鉴别器,还要使用预训练的分类器或CLIP网络来改善文本对齐,这一点至关重要。如文献[59]所述,过度使用判别网络会引入伪影,影响图像质量。相反,我们通过分数蒸馏目标利用预训练扩散模型的梯度来改善文本对齐和样本质量。此外,我们使用预训练的扩散模型权值初始化模型,而不是从头开始训练;已知预训练生成器网络可以显著改善具有对抗性损失的训练[15]。最后,我们采用了标准的扩散模型框架,而不是使用仅用于GAN训练的解码器架构[26,27]。这种设置自然支持迭代细化。
3.1. 训练过程
我们的训练过程如图2所示,涉及三个网络:ADD-student从一个权重为θ的预训练UNet-DM初始化,一个权重为φ的可训练鉴别器初始化,一个权重为ψ的DM教师初始化。在训练过程中,ADD-student从有噪声的数据xs中生成样本δ xθ(xs, s)。通过前向扩散过程xs = αsx0 + σ sε,从真实图像数据集x0产生带噪数据点。在我们的实验中,我们使用相同的系数αs和σs作为学生DM和样本s,均匀地来自集合Tstudent = {τ1,…, τn} (N个选择的学生时间步长)在实践中,我们选择N = 4。
重要的是,我们设置τn = 1000,并在训练过程中强制执行zero-terminalSNR[33],因为模型在推理过程中需要从纯噪声开始。
对于对抗目标,生成的样本x0和真实图像x0被传递给鉴别器,目的是区分它们。鉴别器和对抗性损耗的设计将在3.2节中详细描述。为了从DM老师那里提取知识,我们用老师的前向过程将学生样本δ xθ扩散到δ xθ,t,并使用老师的去噪预测δ xψ(δ xθ,t, t)作为蒸馏损失Ldistill的重建目标,参见3.3节。因此,总体目标是
虽然我们在像素空间中制定了我们的方法,但很容易将其适应于在潜在空间中工作的ldm。当使用师生共享潜在空间的ldm时,蒸馏损失可以在像素或潜在空间中计算。我们在像素空间中计算蒸馏损失,因为这在蒸馏潜在扩散模型时产生更稳定的梯度[72]。
3.2. 对抗损失
对于鉴别器,我们遵循[59]中提出的设计和训练程序,并对其进行简要总结;有关细节,我们建议读者参阅原文。我们使用一个固定的预训练特征网络F和一组可训练的轻量级鉴别器头d φ,k。对于特征网络F, Sauer等[59]发现视觉变压器(ViTs)[9]工作良好,我们在第4节中对ViTs目标和模型尺寸进行了不同的选择。将可训练判别器头应用于特征网络不同层的特征Fk上。
为了提高性能,鉴别器可以通过投影以附加信息为条件[46]。通常,在文本到图像设置中使用文本嵌入ctext。但是,与标准GAN训练相比,我们的训练配置还允许对给定图像进行条件设置。当τ < 1000时,ADD-student接收到来自输入图像x0的信号。因此,对于给定的生成样本xθ(xs, s),我们可以根据来自x0的信息来约束鉴别器。这鼓励add学生有效地利用输入。在实践中,我们使用附加的特征网络来提取嵌入cimg的图像。
接下来[57,59],我们使用hinge loss[32]作为对抗目标函数。因此,add学生的对抗性目标Ladv(xθ(xs, s), ϕ)等于
而鉴别器被训练成最小化
其中R1为R1梯度惩罚[44]。我们不是计算相对于像素值的梯度惩罚,而是在每个鉴别器头dpn,k的输入上计算它。我们发现,当输出分辨率大于128像素时,R1惩罚特别有益。
3.3. 分数蒸馏损失
式(1)中的蒸馏损失表示为
式中sg为停止梯度操作。直观地说,损失使用距离度量d来测量由ad -student和DM-teacher的输出生成的样本xθ之间的不匹配:xψ(φ xθ,t, t) = (φ xθ,t - σt φ ϵψ(φ xθ,t, t))/αt随时间步长t和噪声ε′的平均值。值得注意的是,教师并没有直接应用于ADD-student的生成的δ xθ上,而是应用于扩散的输出δ xθ上,t = αt δ xθ + σ δ δ ',因为对于教师模型来说,非扩散的输入将不在分布范围内[68]。
下面,我们定义距离函数d(x, y):= ||x−y||2。对于加权函数c(t),我们考虑两种选择:指数加权,其中c(t) = αt(噪声水平越高贡献越小),以及分数蒸馏抽样(SDS)加权[51]。在补充材料中,我们证明了当d(x, y) = ||x−y||2以及c(t)的特定选择时,我们的蒸馏损失等于SDS目标LSDS,如[51]中提出的那样。我们的公式的优点是它能够实现重建目标的直接可视化,并且它自然地促进了几个连续去噪步骤的执行。最后,我们还评估了无噪声分数蒸馏(NFSD)目标,这是最近提出的SDS的变体[28]。
4. 实验
在我们的实验中,我们训练了两个不同容量的模型,ADD-M (860M参数)和ADD-XL (3.1B参数)。为了消融ADD-M,我们使用了稳定扩散(SD) 2.1主干[54],为了与其他基线进行公平比较,我们使用了SD1.5。ADD-XL使用SDXL[50]骨干网。所有实验均在512x512像素的标准化分辨率下进行;产生更高分辨率的模型的输出被下采样到这个大小。
我们在所有实验中采用λ = 2.5的蒸馏加权因子。此外,R1惩罚强度γ设为10−5。对于判别器条件,我们使用预训练的clip - vit -g-14文本编码器[52]来计算文本嵌入上下文,并使用DINOv2 vit- l编码器[47]的CLS嵌入来进行图像嵌入。对于基线,我们使用了最好的公开可用模型:潜在扩散模型[50,54](SD1.5, SDXL)级联像素扩散模型[55](IF-XL),蒸馏扩散模型[39,41](LCM-1.5, LCM-1.5- xl),以及OpenMUSE [48],这是MUSE[6]的重新实现,是专门为快速推理而开发的转换器模型。注意,我们比较的是没有附加细化模型的sdxl - base -1.0模型;这是为了保证公平的比较。由于没有公开的最先进的GAN模型,我们使用改进的鉴别器重新训练StyleGAN-T[59]。这个基线(stylegan - t++)在FID和CS中明显优于之前最好的gan,见补充。我们通过FID[18]量化样本质量,通过CLIP评分[17]量化文本对齐。对于CLIP评分,我们使用LAION-2B训练的ViT-g-14模型[61]。这两个指标在来自COCO2017的5k个样本上进行了评估[34]。
4.1. 消融实验
我们的训练设置在对抗性损失、蒸馏损失、初始化和损失相互作用方面开辟了许多设计空间。我们对表1中的几个选择进行了消融研究;每个表格下面突出显示了关键的见解。我们将在下文中讨论每个实验。
鉴别器特征网络。(表1)。Stein等人[67]最近的见解表明,使用CLIP[52]或DINO[5,47]目标训练的vit特别适合评估生成模型的性能。同样地,这些模型作为鉴别器特征网络似乎也很有效,其中DINOv2成为最佳选择。
鉴别器。(表1 b)。与先前的研究类似,我们观察到鉴别器的文本条件反射增强了结果。值得注意的是,图像条件反射优于文本条件反射,并且文本和图像条件反射的组合产生了最好的结果。
学生pretraining。(表1 c)。我们的实验证明了对add学生进行预训练的重要性。与纯GAN方法相比,能够使用预训练的生成器是一个显著的优势。gan的一个问题是缺乏可扩展性;Sauer等[59]和Kang等[25]都观察到在达到一定网络容量后,性能会达到饱和。这一观察结果与dm一般平滑的标度规律形成了对比[49]。然而,ADD可以有效地利用更大的预训练DM(见表1c),并从稳定的DM预训练中获益。
损失项。(表1 d)。我们发现这两种损失都是必不可少的。蒸馏损失本身是无效的,但当与对抗损失相结合时,结果会有明显的改善。不同的权重计划导致不同的行为,指数计划倾向于产生更多样化的样本,正如较低的FID, SDS和NFSD计划所表明的那样,可以提高质量和文本对齐。当我们使用指数计划作为所有其他消融的默认设置时,我们选择NFSD加权来训练我们的最终模型。选择最优权重函数提供了改进的机会。或者,如3D生成建模文献[23]中所探讨的,可以考虑在训练上调度蒸馏权值。
老师类型。(表1 e)。有趣的是,一个更大的学生和老师并不一定会带来更好的FID和CS。相反,学生采用教师的特点。SDXL通常获得更高的FID,可能是因为其输出的多样性较少,但它具有更高的图像质量和文本对齐[50]。
老师的步数。(表1)。虽然我们的蒸馏损失公式允许通过构造与老师一起采取连续的几个步骤,但我们发现几个步骤并不能最终导致更好的性能。
4.2. 与最新技术的定量比较
对于我们与其他方法的主要比较,我们避免使用自动化指标,因为用户偏好研究更可靠[50]。在这项研究中,我们的目标是评估及时的依从性和整体形象。作为一种性能度量,我们在比较几种方法时计算两两比较的获胜百分比和ELO分数。对于报告的ELO分数,我们计算提示跟随和图像质量之间的平均分数。关于ELO分数计算和研究参数的详细信息在补充材料中列出。
图5、图6为研究结果。最重要的结果是:首先,ADD-XL仅用一步就优于LCM-XL(4步)。其次,ADD-XL在大多数比较中可以击败SDXL(50步)。这使得ADD-XL在单步和多步设置中都是最先进的。图7显示了相对于推理速度的ELO分数。最后,表2比较了使用相同基本模型的不同小步取样和蒸馏方法。ADD优于所有其他方法,包括标准的八步DPM求解器。
4.3. 定性结果
为了补充我们上面的定量研究,我们在本节中提出了定性结果。为了描绘更完整的画面,我们在补充材料中提供了额外的样本和定性比较。图3比较了ADD-XL(1步)和当前的最佳基线。ADD-XL的迭代采样过程如图4所示。这些结果显示了我们的模型在初始样本基础上的改进能力。这种迭代改进代表了与纯GAN方法(如stylegan - t++)相比的另一个显著优势。最后,图8将ADD-XL与其教师模型SDXL-Base进行了直接比较。从4.2节的用户研究中可以看出,ADD-XL在质量和快速对齐方面都优于它的老师。