diffusion model(十三):DiT技术小结

info
paperhttps://arxiv.org/abs/2212.09748
githubhttps://github.com/facebookresearch/DiT/tree/main
个人博客主页http://myhz0606.com/article/dit
create date2024-03-08

阅读前需要具备以下前置知识:

DDPM(扩散模型基本原理):知乎地址 个人博客地址 paper

LDM (隐空间扩散模型基本原理,stable diffusion 底层架构) 知乎地址 个人博客地址 paper

classifier-free guided(文生图基本原理) 知乎地址 个人博客地址 paper

Motivate

虽然Transformer架构已经在诸多自然语言处理和计算机视觉任务中展现出卓越的scalable能力,但目前主导扩散模型架构的仍是UNet。本文旨在探讨以Transformer取代UNet在扩散模型中的可行性和潜在方案,并对所提出的Diffusion Transformer (DIT)架构的scalable能力进行了验证和评估。

Method

采用DiT架构替换UNet主要需要探索以下几个关键问题:

  1. Token化处理。Transformer的输入为一维序列,形式为 R T × d \mathbb{R}^{T \times d} RT×d(忽略batch维度),而LDM的latent表征 z ∈ R H f × W f × C z \in \mathbb{R}^{\frac{H}{f} \times \frac{W}{f} \times C} zRfH×fW×C为spatial张量。因此,需要设计合适的Token化方法将二维latent映射为一维序列。
  2. 条件信息嵌入。sable diffusion火出圈的一个关键在于它能够根据用户的文本指令生成高质量的图像。这里面的核心在于需要将文本特征嵌入到扩散模型中协同生成。并且扩散模型的每一个生成还需要融入time-embedding来引入时间步的信息。因此,若要用Transformer架构取代Unet需要系统研究Transformer架构的条件嵌入

DiT这篇paper的核心在于对上述两个问题的系统研究。

在这里插入图片描述

Patchify(token化)

假定原始图片 x ∈ R 256 × 256 × 3 x \in \mathbb{R} ^ {256\times256\times3} xR256×256×3,经过auto-encoder后得到latent表征 z ∈ R 32 × 32 × 4 z \in \mathbb{R} ^ {32\times32\times4} zR32×32×4。首先DiT 用ViT中patch化的方式将隐表征 z z z 转化为token序列,随后给序列添加位置编码。图中展示了patch化的过程。patch_size p是一个超参数。文中分别尝试了p=2,4,8。(DiT的输出会将每一个token线性解码成pxpx2C,再reshape为nose和协方差)

在这里插入图片描述

DiT block设计

这个部分系统探究了4中在DiT中引入控制信号的方案。

(一)In-context conditioning

直接将时间步信号、文本控制信号作为addition token和输入sequence进行拼接。其角色类似于类似于ViT里面的[CLS]token。这样做有一个好处,原本的ViT架构都可以不动,并且增加的的计算量可以忽略不计。

(二)Cross-Attention block

这个方法首先将时间步信号 ( R 1 × d ) (\mathbb{R} ^{1 \times d}) (R1×d)和文本信号 ( R 1 × d ) (\mathbb{R} ^{1 \times d}) (R1×d)进行拼接,得到拼接后的控制信号 ( R 2 × d ) (\mathbb{R} ^{2 \times d}) (R2×d)。随后类似文献[1]的做法,在ViT中添加cross attention层,将控制信号作为cross-attention的key,value进行融入。

(三)Adaptive Layer Norm (adaLN) block

作者参考文献[2]提出的adaptive normalization layer(adaLN),将transformer block的layer norm替换为adaLN。简单来说就是,原本的将原本layer norm用于仿射变换的scale parameter γ \gamma γ和shift parameter β \beta β 用condition embedding来替代。下面给出了最简的示例代码便于理解。

论文原话:Rather than directly learn dimensionwise scale and shift parameters γ and β, we regress them from the sum of the embedding vectors of t and c.

import numpy as np

class LayerNorm:
    def __init__(self, feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        self.gamma = np.random.rand(feature_dim)  # scale parameters
        self.beta = np.random.rand(feature_dim)  # shift parametrs

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
    Args:
        x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
    return:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
    """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _std = np.var(x, axis=-1, keepdims=True)
        x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.beta
        return x_layer_norm

class DiTAdaLayerNorm:
    def __init__(self,feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        self.weight = np.random.rand(feature_dim, feature_dim * 2)

    def __call__(self, x, condition):
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            condition (np.ndarray): shape: (batch_size, 1, feature_dim)
                Ps: condition = time_cond_embedding + class_cond_embedding
        return:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        affine = condition @ self.weight  # shape: (batch_size, 1, feature_dim * 2)
        gamma, beta = np.split(affine, 2, axis=-1)
        _mean = np.mean(x, axis=-1, keepdims=True)
        _std = np.var(x, axis=-1, keepdims=True)
        x_layer_norm = gamma * (x - _mean / (_std + self.epsilon)) + beta
        return x_layer_norm

(四)adaLN-Zero block

这个方法是(三)的延伸。简单来说就是condition embedding除了融入到layer norm中,还作为residual的强度融入到residual连接中。下面给出了最简的示例代码

import numpy as np

class LayerNorm:
    def __init__(self, epsilon=1e-6):
        self.epsilon = epsilon

    def __call__(self, x: np.ndarray, gamma: np.ndarray, beta: np.ndarray) -> np.ndarray:
        """
    Args:
        x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            gamma (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embedding
            beta (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embedding
    return:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
    """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _std = np.var(x, axis=-1, keepdims=True)
        x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.beta
        return x_layer_norm

class DiTBlock:
    def __init__(self, feature_dim):
        self.MultiHeadSelfAttention = lambda x: x # mock multi-head self-attention
        self.layer_norm = LayerNorm()
        self.MLP = lambda x: x # mock multi-layer perceptron
        self.weight = np.random.rand(feature_dim, feature_dim * 6)

    def __call__(self, x: np.ndarray, time_embedding: np.ndarray, class_emnedding: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            time_embedding (np.ndarray): shape: (batch_size, 1, feature_dim)
            class_emnedding (np.ndarray): shape: (batch_size, 1, feature_dim)
        return:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        condition_embedding = time_embedding + class_emnedding
        affine_params = condition_embedding @ self.weight  # shape: (batch_size, 1, feature_dim * 6)
        gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = np.split(affine_params, 6, axis=-1)
        x = x + alpha_1 * self.MultiHeadSelfAttention(self.layer_norm(x, gamma_1, beta_1))
        x = x + alpha_2 * self.MLP(self.layer_norm(x, gamma_2, beta_2))
        return x

Result

作者在imagenet数据上,以classifier-free的方式训练DiT(仅做class-control,即text condition embedding为类别embedding)。作者设置了4种不同model size的DiT,并开展实验。

在这里插入图片描述

DiT的scalable能力验证

作者分别尝试了的patch size,不同model size的DiT,从图中不难发现

  • patch size越小生成的效果越好(意味着初始时sequence的token数越多)。这里不太明白为什么作者不实验p=1的情形。因为latent表征本身就可以视作是CNN抽取的隐式token,只要flatten即可,很多hybrid的架构(CNN+ViT)都是这么玩的,或许是为了控制计算量?
  • model size越大生成效果越好。从实验结果中DiT-XLDiT-L的差距很小,可能是因为训练数据量还不够大体现不出大模型的优势

在这里插入图片描述

在这里插入图片描述

DiT Block有效性验证

作者在imagenet数据集上验证上面提出的四种DiT block的的生成效果。ada LN-Zero方案的生成效果最好。

在这里插入图片描述

小结

DiT 系统研究了diffusion transformer的token化和条件嵌入两个关键问题,验证了基于transformer架构的扩散模型的scalable能力。

参考文献

[1] Attention is all you need.

[2] Film: Visual reasoning with a general conditioning layer.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/448106.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习Java的第六天

目录 一、变量 1、变量的定义 2、变量的声明格式 3、变量的注意事项 4、变量的作用域 二、常量 三、命名规范 Java 语言支持如下运算符: 1、算术运算符 解析图: 示例: 2、赋值运算符 解析图: 示例: 3、关…

如何使用Everything+cpolar实现公网远程搜索下载内网储存文件资料

文章目录 前言1.软件安装完成后,打开Everything2.登录cpolar官网 设置空白数据隧道3.将空白数据隧道与本地Everything软件结合起来总结 前言 要搭建一个在线资料库,我们需要两个软件的支持,分别是cpolar(用于搭建内网穿透数据隧道…

鸿蒙报错:Hhvigor Update the SDKs by going to Tools > SDK Manager....

鸿蒙报错:Hhvigor Update the SDKs by going to Tools > SDK Manager… 打开setting里面的sdk,将API9工程下的全部勾上,应用下载 刚打开 js 和 Native 是没勾上的

黑苹果RX590驱动解决方案

遇到的问题: 1.手头上的显卡是 华硕RX590 GAME,MacOS运行查看到显存为7m,使用起来非常卡顿。 2.免驱后,屏幕紫色。 使用的工具如下: 工具包下载地址:https://download.csdn.net/download/qq_33544860/88944761 解压密码:20240311 流程如下: 解决无法免驱问题:刷入5…

力扣:链表篇章

1、链表 链表是一种通过指针串联在一起的线性结构,每一个节点由两部分组成,一个是数据域一个是指针域(存放指向下一个节点的指针),最后一个节点的指针域指向null(空指针的意思)。 ​

IP形象设计是什么设计?如何做?

随着市场竞争的激烈,越来越多的企业开始关注品牌形象的塑造和推广。在品牌形象中,知识产权形象设计是一个非常重要的方面。在智能和互联网的趋势下,未来的知识产权形象设计可能更加关注数字和社交网络。通过数字技术和社交媒体平台&#xff0…

npm install报错,error <https://npm.community>解决方法

报错信息如下: 分析原因: 1.可能是由于node版本过低,或者过高,解决方法看我另一文章:npm install报错,npm版本过高,需要切换低版本node,过程记录 2.网络问题导致 3.切换node版本后&#xff0…

Jmeter测试关联接口

Jmeter用于接口测试时,后一个接口经常需要用到前一次接口返回的结果,本文主要介绍jmeter通过正则表达式提取器来实现接口关联的方式,可供参考。 一、实例场景: 有如下两个接口,通过正则表达式提取器,将第…

【保姆级】Protobuf详解及入门指南

目录 Protobuf概述 什么是Protobuf 为什么要使用Protobuf Protobuf实战 环境配置 创建文件 解析/封装数据 附录 AQin.proto 完整代码 Protobuf概述 什么是Protobuf Protobuf(Protocol Buffers)协议😉 Protobuf 是一种由 Google 开…

Mysql8的优化(DBA)

Mysql8的优化 1、Mysql的安装优化1.1 修改配置参数(命令行、配件文件)1.1.1 命令行修改配置参数1.1.2 参数持久化1.1.3 Mysql多实例启动,以及配置密码文件 1.2 查询表的相关参数,以及表空间管理 2、Mysql高级优化(SQL&…

编译支持国密的抓包工具 WireShark

目录 前言WireShark支持国密的 WireShark小结前言 在上一篇文章支持国密的 Web 服务器中,我们搭建了支持国密的 Web 服务器,但是,我们使用 360 安全浏览器去访问,却出现了错误: 是我们的 Web 服务器没有配置好?在这里插入图片描述还是 360 安全浏览器不支持国密?还是两…

SSL证书:构建网络安全的基石

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Linux学习——线程的控制

目录 ​编辑 一,线程的创建 二,线程的退出 1,在子线程内return 2,使用pthread_exit(void*) 三,线程等待 四,线程获取自己的id值 五,线程取消 六,线程分离 一,线程的创建 在对…

全面的 DevSecOps 指南:有效保护 CI/CD 管道的关键注意事项

数字化转型时代带来了对更快、更高效、更安全的软件开发流程的需求。DevSecOps:一种将安全实践集成到 DevOps 流程中的理念,旨在将安全性嵌入到开发生命周期的每个阶段 - 从代码编写到生产中的应用程序部署。DevSecOps 的结合可以带来许多好处&#xff0…

抖音视频提取gif怎么做?分分钟帮你生成gif

通过将视频转换成gif动图的方式能够方便的在各种平台上分享、传播。相较于视频文件,gif动图的体积更小,传播起来更方便,能够吸引大众的注意力。下面,就来给大家分享一个gif图片制作(https://www.gif.cn/)的…

mybatisplus的条件构造器

条件构造器wrapper,主要用于构造sql语句的where条件,他更擅长这个,但也可以用于构造其他类型的条件,比如order by、group by等。 条件构造器的使用经验: 基于QueryWrapper的查询 练习1. void testQueryWrapper(){Q…

服务器集群 -- nginx配置tcp负载均衡

当面临高流量、高可用性、水平扩展、会话保持或跨地域流量分发等需求时,单台服务器受限于硬件资源、性能有限不能满足应用场景的并发需求量时,引入负载均衡器部署多个服务器共同处理客户端的并发请求,可以帮助优化系统架构,提高系…

猫咪挑食怎么治?排行榜靠前适口性好的主食冻干推荐

在如今,养猫人士几乎都将自己的小猫咪视作珍宝,宠溺有加。但宠爱过度有时也会导致猫咪养成挑食的坏习惯。猫咪挑食怎么治呢?今天,我要分享一个既能让猫咪不受苦,又能纠正挑食问题的方法。 一、为什么猫会挑食呢&#x…

Linux调试器--gdb的介绍以及使用

文章目录 1.前言 ✒️2.介绍gdb✒️3.Debug模式和Release模式的区别✒️4.如何使用gdb✒️1️⃣.在debug模式下编译2️⃣.进入调试3️⃣ .调试命令集合⭐️⭐️ 1.前言 ✒️ 🕗在我们之前的学习中已经学会了使用vim编译器编写c/c代码,但是对于一个程序员…

ThreadLocal出现内存泄露原因分析

ThreadLocal 导致内存泄漏的主要原因是它的工作方式。在 Java 中,ThreadLocal 通过维护一个以 Thread 为键,以用户设置的值为值的映射来工作。每个线程都拥有其自身的线程局部变量副本,不同线程间的这些变量互不干扰。这个映射是存储在每个 T…