MicroDiffusion——采用新的掩码方法和改进的 Transformer 架构,实现了低预算的扩散模型

介绍

论文地址:https://arxiv.org/abs/2407.15811
现代图像生成模型擅长创建自然、高质量的内容,每年生成的图像超过十亿幅。然而,从头开始训练这些模型极其昂贵和耗时。文本到图像(T2I)扩散模型降低了部分计算成本,但仍需要大量资源。

目前最先进的技术需要大约 18 000 个 A100 GPU 小时,而使用 8 个 H100 GPU 进行训练则需要一个多月的时间。此外,该技术通常依赖于大型或专有数据集,因此难以普及。

在这篇评论性论文中,我们开发了一种低成本、端到端文本到图像扩散建模管道,目的是在没有大型数据集的情况下显著降低成本。它侧重于基于视觉变换器的潜在扩散模型,利用其简单的设计和广泛的适用性。为了降低计算成本,可通过随机屏蔽输入标记来减少每幅图像需要处理的斑块数量。本文克服了现有遮蔽方法在高遮蔽率下性能下降的难题。

为了克服文本到图像扩散模型性能不佳的问题,本文提出了一种 "延迟掩蔽 "策略。通过在轻量级补丁混合器中处理补丁,然后将其输入扩散变换器,可以低成本实现可靠的训练,同时即使在高掩蔽率下也能保留语义信息。它还结合了变压器架构的最新发展,以提高大规模训练的性能。

该实验训练了一个 1.16 亿参数的稀疏扩散变换器,预算仅为 1,890 美元、3,700 万张图像和 75% 的屏蔽率。结果,在 COCO 数据集上零镜头生成的 FID 达到了 12.7。在一台 8×H100 GPU 机器上的训练时间仅为 2.6 天,与目前最先进的方法(37.6 天,GPU 成本 28,400 美元)相比缩短了 14 倍。

建议方法

延迟掩蔽

由于变换器的计算复杂度与序列的长度成正比,降低训练成本的一种方法是通过使用大尺寸补丁来减少序列,如图 1-b 所示。使用大尺寸补丁可使每幅图像的补丁数量呈二次方减少,但由于图像的大区域会被主动压缩成单个补丁,因此会显著降低性能。

一种方法是使用遮罩去除变换器输入层中的一些斑块,如图 1-c 所示,同时保持斑块大小。这种方法类似于卷积网络中的随机裁剪训练,但遮蔽补丁可以在图像的非连续区域进行训练。这种方法被广泛应用于视觉和语言领域。

图 1-d 中的MaskDiT还增加了补充自编码损失,以鼓励从遮蔽的斑块中学习表示法,从而促进遮蔽斑块的重建。这种方法屏蔽了 75% 的输入图像,大大降低了计算成本。

图 1.压缩补丁序列以降低计算成本。

然而,高遮罩率会大大降低转换器的整体性能:即使使用 MaskDiT,也只能看到与简单遮罩相比的微弱改进。这是因为即使采用这种方法,大部分图像斑块也会在输入层被去除。

本文引入了一个名为 "补丁混合器 "的预处理模块,用于在屏蔽之前处理补丁嵌入。这可确保未屏蔽的补丁保留整个图像的信息,从而提高学习效率。这种方法有可能提高性能,同时在计算上与现有的 MaskDiT 策略相当。

补丁混合器和学习障碍

补片混合器指的是任何能够融合单个补片嵌入的神经架构。在变换器模型中,这一目标自然可以通过注意机制和前馈层的结合来实现。因此,本文使用轻量级变压器(只有几层)作为补片混合器。输入序列标记经补丁混合器处理后,将被屏蔽(图 2e)。假定掩码为二进制 m,则使用以下损失函数对模型进行训练。

引入专家混合(MoE)和分层缩放的变换器架构

论文采用了先进变压器结构的创新技术,在计算受限的情况下提高了模型性能。

  • 专家混合物(MoE,Zhou 等人,2022 年):使用 MoE 层扩展模型的参数和表现力,同时避免训练成本的显著增加 简化的 MoE 层与专家选择路由允许额外的辅助损失函数无需调整负载。
  • 分层缩放 (Mehta 等人,2024 年 ):这种方法已被证明能提高大型语言模型的性能,其中变换器块的宽度(隐藏层的维度)随深度线性增加。更多的参数被分配给更深的层,以学习更复杂的特征。

整体架构如图 2 所示。

图 2:拟议方法的总体概览。

试验

验证延迟遮蔽和补丁混合器的有效性

当许多补丁被遮蔽时,遮蔽性能会下降;Zheng 等人(2024 年)指出,当遮蔽率超过 50%,MaskDiT 的性能会显著下降。本文评估了遮蔽率高达 87.5% 时的性能,并将其与不使用补丁混合器的传统天真遮蔽方法进行了比较。本文中的 "延迟掩蔽 "使用了一个四层变压器块贴片混合器,其参数小于主干变压器参数的 10%。两者都使用了设置完全相同的 AdamW 优化器。

图 3 对结果进行了总结。延迟掩蔽在所有指标上都明显优于天真掩蔽和MaskDiT,表明随着掩蔽率的增加,性能差异也在扩大。例如,在屏蔽率为 75% 时,原始屏蔽的 FID 得分为 80,MaskDiT 的FID 得分为16.5,而拟议方法的 FID 得分为 5.03,优于未屏蔽时的 3.79。

图 3:验证延迟屏蔽和贴片混频器的有效性。

验证 "专家混合 "和 "分层缩放 "的有效性

分层缩放: 使用 DiT-Tiny 架构进行的实验比较了分层缩放和恒宽变换器与天真屏蔽。两个模型都在相同的计算负荷下进行了相同时间的训练。在所有性能指标上,逐层缩放方法始终优于恒定宽度模型,而且在屏蔽训练中更为有效。

专家混合物(MoE): 测试了在交替区块中具有 MoE 层的 DiT-Tiny/2 变压器。总体性能与无 MoE 层的基线模型相似,Clip-score 略有提高(从 28.11 到 28.66),FID 分数有所下降(从 6.92 到 6.98)。改进幅度有限的原因是 60K 步的训练和每位专家看到的样本量较小。

与以往研究的比较

COCO 数据集(表 1)上的零镜头图像生成: 根据标题生成 30 000 幅图像,并使用 FID-30K 比较其与真实图像的分布情况。拟议方法的FID-30K 得分为 12.66,与之前的低成本训练方法相比,计算成本降低了 14 倍,而且不依赖于专有数据集。该方法的计算成本也比 Würstchen 低 19 倍(Pernias et al.)

表 1:在 COCO 数据集上生成零镜头图像

详细图像生成比较**(表 2)****:**GenEval(Ghosh 等人,2024 年)用于评估生成物体位置、共现、数量和颜色的能力。与 Stable-DiffusionXL-turbo 和 PixArt-α 模型相比,拟议方法在单个物体生成方面的准确性接近完美,与 Stable-Diffusion 变体相当,优于 Stable-Diffusion-1.5。在以下方面也表现出卓越的性能

表 2.详细图像生成对比

总结

本评论文章重点讨论了旨在降低扩散变换器训练计算成本的补丁掩蔽策略。本文提出了一种 "延迟掩蔽 "策略,以缓解现有掩蔽方法的不足,并显示了所有掩蔽比率下的显著性能改进。

特别是使用了 75% 的延迟掩蔽率,并在真实和合成图像数据集上进行了大规模训练。尽管与最先进的技术相比成本大大降低,但还是取得了具有竞争力的零镜头图像生成性能结果。希望这种低成本的训练机制能鼓励更多研究人员参与大规模扩散模型的训练和开发。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/944991.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用 Three.js 创建一个 3D 人形机器人仿真系统

引言 在这篇文章中,我们将探讨如何使用 Three.js 创建一个简单但有趣的 3D 人形机器人仿真系统。这个机器人可以通过键盘控制进行行走和转向,并具有基本的动画效果。 技术栈 HTML5Three.jsJavaScript 实现步骤 1. 基础设置 首先,我们需要…

【c++高阶DS】最小生成树

🔥个人主页:Quitecoder 🔥专栏:c笔记仓 目录 01.最小生成树Kruskal算法Prim算法 01.最小生成树 连通图中的每一棵生成树,都是原图的一个极大无环子图,即:从其中删去任何一条边,生成…

自学记录鸿蒙API 13:实现人脸比对Core Vision Face Comparator

完成了文本识别和人脸检测的项目后,我发现人脸比对是一个更有趣的一个小技术玩意儿。我决定整一整,也就是对HarmonyOS Next最新版本API 13中的Core Vision Face Comparator API的学习,这项技术能够对人脸进行高精度比对,并给出相似…

2024/12/29 黄冈师范学院计算机学院网络工程《路由期末复习作业一》

一、选择题 1.某公司为其一些远程小站点预留了网段 172.29.100.0/26,每一个站点有10个IP设备接到网络,下面那个VLSM掩码能够为该需求提供最小数量的主机数目 ( ) A./27 B./28 C./29 D./30 -首先审题我们需要搞清楚站点与网…

redis cluster集群

华子目录 什么是redis集群redis cluster的体系架构什么是数据sharding?什么是hash tag集群中删除或新增节点,数据如何迁移?redis集群如何使用gossip通信?定义meet信息ping消息pong消息fail消息(不是用gossip协议实现的&#xff0…

PrimeVue菜单模块(Menu),看api的重要性

以下是对PrimeVue菜单模块(Menu)的API属性的中文详解: 一、整体概述 PrimeVue的菜单(Menu)是一个支持动态和静态定位的导航/命令组件,其API通过定义一些辅助的属性(props)、事件等&…

STM32中断详解

STM32中断详解 NVIC 中断系统中断向量表相关寄存器中断优先级中断配置 外部中断实验EXTI框图外部中断/事件线映射中断步骤初始化代码实现 定时器中断通用定时器相关功能标号1:时钟源标号 2:控制器标号 3:时基单元 代码实现 NVIC 中断系统 STM…

从零开始开发纯血鸿蒙应用之逻辑封装

从零开始开发纯血鸿蒙应用 一、前言二、逻辑封装的原则三、实现 FileUtil1、统一的存放位置2、文件的增删改查2.1、文件创建与文件保存2.2、文件读取2.2.1、读取内部文件2.2.2、读取外部文件 3、文件删除 四、总结 一、前言 应用的动态,借助 UI 响应完成&#xff0…

《机器学习》——线性回归模型

文章目录 线性回归模型简介一元线性回归模型多元线性回归模型误差项分析一元线性模型实例完整代码 多元线性模型实例完整代码 线性回归模型简介 线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。 相关关系&…

【深度学习环境】NVIDIA Driver、Cuda和Pytorch(centos9机器,要用到显示器)

文章目录 一 、Anaconda install二、 NIVIDIA driver install三、 Cuda install四、Pytorch install 一 、Anaconda install Step 1 Go to the official website: https://www.anaconda.com/download Input your email and submit. Step 2 Select your version, and click i…

在HTML中使用Vue如何使用嵌套循环把集合中的对象集合中的对象元素取出来(我的意思是集合中还有一个集合那种)

在 Vue.js 中处理嵌套集合(即集合中的对象包含另一个集合)时,使用多重 v-for 指令来遍历这些层次结构。每个 v-for 指令可以用于迭代一个特定级别的数据集,并且可以在模板中嵌套多个 v-for 来访问更深层次的数据。 例如&#xff…

ip归属地是什么意思?ip归属地是实时定位吗

在数字化时代,IP地址作为网络设备的唯一标识符,不仅关乎设备间的通信,还涉及到用户的网络身份与位置信息。其中,IP归属地作为IP地址的地理位置信息,备受用户关注。本文将详细解析IP归属地的含义,并探讨其是…

基于BP训练深度学习模型(用于回归)以及验证误差值

用原生Python训练了一个BP网络,适合没有pytorch等环境的电脑,并用训练的模型对原始数据进行了预测,拿来估测比较误差值了,可以直接拿去用(需根据个人数据来调训练次数、学习效率),代码在文章末。…

C#冒泡排序

一、冒泡排序基本原理 冒泡排序是一种简单的排序算法。它重复地走访要排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。走访数列的工作是重复地进行直到没有再需要交换,也就是说该数列已经排序完成。 以一个简单的整数数…

折腾日记:如何让吃灰笔记本发挥余热——搭建一个相册服务

背景 之前写过,我在家里用了一台旧的工作站笔记本做了服务器,连上一个绿联的5位硬盘盒实现简单的网盘功能,然而,还是觉的不太理想,比如使用filebrowser虽然可以备份文件和图片,当使用手机使用网页&#xf…

从0入门自主空中机器人-2-1【无人机硬件框架】

关于本课程: 本次课程是一套面向对自主空中机器人感兴趣的学生、爱好者、相关从业人员的免费课程,包含了从硬件组装、机载电脑环境设置、代码部署、实机实验等全套详细流程,带你从0开始,组装属于自己的自主无人机,并让…

剑指Offer|LCR 013. 二维区域和检索 - 矩阵不可变

LCR 013. 二维区域和检索 - 矩阵不可变 给定一个二维矩阵 matrix,以下类型的多个请求: 计算其子矩形范围内元素的总和,该子矩阵的左上角为 (row1, col1) ,右下角为 (row2, col2) 。 实现 NumMatrix 类: NumMatrix(…

接口Mock技术介绍

相信学习过程序设计的读者朋友们,一定对“桩(Stub)”这个概念并不陌生。它是指用来替换一部分功能的程序代码段。桩程序代码段可以用来模拟已有程序的某些功或者是将实现的系统代码的一种临时替代方法。插桩方法被广泛应用于开发和测试工作中…

深入解析C#异步编程:await 关键字背后的实现原理

在C#中,async 和 await 关键字用于编写异步代码。本文将详细介绍 await 的实现原理,包括状态机的生成、回调函数的注册和触发等关键步骤。 1. 异步方法的基本概念 在C#中,async 关键字标记一个方法为异步方法,而 await 关键字用于…

【机器学习】SVM支持向量机(一)

介绍 支持向量机(Support Vector Machine, SVM)是一种监督学习模型,广泛应用于分类和回归分析。SVM 的核心思想是通过找到一个最优的超平面来划分不同类别的数据点,并且尽可能地最大化离该超平面最近的数据点(支持向量…