| info | |
|---|---|
| paper | https://arxiv.org/abs/2212.09748 |
| github | https://github.com/facebookresearch/DiT/tree/main |
| personal blog | http://myhz0606.com/article/dit |
| create date | 2024-03-08 |
Prerequisites:
- DDPM (basics of diffusion models): Zhihu post / personal blog post / paper
- LDM (latent diffusion, the architecture underlying Stable Diffusion): Zhihu post / personal blog post / paper
- classifier-free guidance (basics of text-to-image generation): Zhihu post / personal blog post / paper
Motivation
Although the Transformer architecture has demonstrated excellent scaling ability across many natural language processing and computer vision tasks, the architecture that dominates diffusion models today is still UNet. This paper explores whether and how a Transformer can replace UNet in diffusion models, and validates the scaling ability of the proposed Diffusion Transformer (DiT) architecture.
Method
Replacing UNet with the DiT architecture requires exploring the following key questions:
- Tokenization. A Transformer takes as input a 1-D sequence in $\mathbb{R}^{T \times d}$ (ignoring the batch dimension), whereas the LDM latent representation $z \in \mathbb{R}^{\frac{H}{f} \times \frac{W}{f} \times C}$ is a spatial tensor. A suitable tokenization scheme is therefore needed to map the 2-D latent into a 1-D sequence.
- Condition embedding. A key reason Stable Diffusion became so popular is that it can generate high-quality images from user text prompts, which hinges on injecting text features into the diffusion model; in addition, each denoising step must also incorporate a time embedding to encode the timestep. Replacing UNet with a Transformer therefore requires a systematic study of condition embedding for the Transformer architecture.
The core contribution of the DiT paper is a systematic study of these two questions.
Patchify (Tokenization)
Suppose the original image is $x \in \mathbb{R}^{256 \times 256 \times 3}$; the auto-encoder maps it to the latent representation $z \in \mathbb{R}^{32 \times 32 \times 4}$. DiT first converts the latent $z$ into a token sequence with the same patchify operation as ViT, then adds positional encodings to the sequence. The figure in the paper illustrates the patchify process. The patch size $p$ is a hyperparameter; the paper tries $p = 2, 4, 8$. (At the output end, DiT linearly decodes each token into a $p \times p \times 2C$ tensor, which is then reshaped into the predicted noise and covariance.)
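To make the patchify step concrete, here is a minimal numpy sketch (my own illustration, not code from the DiT repo; the `proj` matrix stands in for the learned patch-embedding layer): it cuts the $32 \times 32 \times 4$ latent into non-overlapping $p \times p$ patches, flattens each patch to a vector of length $p \cdot p \cdot C$, and projects it to the hidden dimension $d$. With $p = 2$ this yields $T = (32/2)^2 = 256$ tokens.

```python
import numpy as np

def patchify(z: np.ndarray, p: int, proj: np.ndarray) -> np.ndarray:
    """ViT-style patchify sketch (illustrative only).

    Args:
        z: latent, shape (H, W, C), e.g. (32, 32, 4)
        p: patch size (2, 4, or 8 in the paper)
        proj: patch-embedding projection, shape (p * p * C, d)
    Returns:
        token sequence, shape (T, d) with T = (H // p) * (W // p)
    """
    H, W, C = z.shape
    # split into an (H//p, W//p) grid of p x p patches, then flatten each patch
    z = z.reshape(H // p, p, W // p, p, C)               # (H/p, p, W/p, p, C)
    z = z.transpose(0, 2, 1, 3, 4)                       # (H/p, W/p, p, p, C)
    tokens = z.reshape((H // p) * (W // p), p * p * C)   # (T, p*p*C)
    return tokens @ proj                                 # (T, d)

z = np.random.rand(32, 32, 4)
proj = np.random.rand(2 * 2 * 4, 768)
print(patchify(z, p=2, proj=proj).shape)  # (256, 768)
```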
DiT Block Design
This part systematically explores four ways of injecting the conditioning signal into a DiT block.
(1) In-context conditioning
The timestep embedding and the class/text embedding are simply appended to the input sequence as additional tokens, playing a role similar to the [CLS] token in ViT, as sketched below. The advantage is that the original ViT architecture needs no modification, and the added compute is negligible.
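A minimal numpy sketch of this idea (my own illustration; the shapes are assumptions): the condition embeddings are concatenated with the latent tokens along the sequence dimension before entering the Transformer.

```python
import numpy as np

batch, T, d = 2, 256, 768
x_tokens = np.random.rand(batch, T, d)  # patchified latent tokens
t_emb = np.random.rand(batch, 1, d)     # timestep embedding as one token
c_emb = np.random.rand(batch, 1, d)     # class/text embedding as one token

# in-context conditioning: the conditions become two extra sequence tokens
seq = np.concatenate([t_emb, c_emb, x_tokens], axis=1)  # (batch, T + 2, d)
print(seq.shape)  # (2, 258, 768)
```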
(2) Cross-attention block
This method first concatenates the timestep embedding $(\mathbb{R}^{1 \times d})$ and the text/class embedding $(\mathbb{R}^{1 \times d})$ into a conditioning sequence $(\mathbb{R}^{2 \times d})$. Then, similar to [1], a cross-attention layer is added to the ViT block, and the conditioning sequence is injected as the key and value of the cross-attention.
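Here is a minimal numpy sketch of single-head cross-attention with the conditioning sequence as key and value (my own illustration under the shapes above, not the paper's code):

```python
import numpy as np

def softmax(a: np.ndarray, axis: int = -1) -> np.ndarray:
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

batch, T, d = 2, 256, 768
x = np.random.rand(batch, T, d)     # image tokens: the queries
cond = np.random.rand(batch, 2, d)  # [time_emb; class_emb]: keys and values

W_q, W_k, W_v = (np.random.rand(d, d) for _ in range(3))
q, k, v = x @ W_q, cond @ W_k, cond @ W_v
attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))  # (batch, T, 2)
out = attn @ v                                         # (batch, T, d)
print(out.shape)  # (2, 256, 768)
```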
(3) Adaptive layer norm (adaLN) block
Following the adaptive normalization layer (adaLN) proposed in [2], the authors replace the layer norm in the transformer block with adaLN. In short, the scale parameter $\gamma$ and shift parameter $\beta$ that layer norm uses for its affine transform are no longer learned directly but are regressed from the condition embedding. A minimal example is given below for understanding.
In the paper's words: "Rather than directly learn dimensionwise scale and shift parameters γ and β, we regress them from the sum of the embedding vectors of t and c."
```python
import numpy as np


class LayerNorm:
    def __init__(self, feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        self.gamma = np.ones(feature_dim)   # learnable scale parameters
        self.beta = np.zeros(feature_dim)   # learnable shift parameters

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        Returns:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _var = np.var(x, axis=-1, keepdims=True)
        x_norm = (x - _mean) / np.sqrt(_var + self.epsilon)
        return self.gamma * x_norm + self.beta


class DiTAdaLayerNorm:
    def __init__(self, feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        # regresses gamma and beta from the condition embedding
        self.weight = np.random.rand(feature_dim, feature_dim * 2)

    def __call__(self, x: np.ndarray, condition: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            condition (np.ndarray): shape: (batch_size, 1, feature_dim)
                Note: condition = time_cond_embedding + class_cond_embedding
        Returns:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        affine = condition @ self.weight  # shape: (batch_size, 1, feature_dim * 2)
        gamma, beta = np.split(affine, 2, axis=-1)
        _mean = np.mean(x, axis=-1, keepdims=True)
        _var = np.var(x, axis=-1, keepdims=True)
        x_norm = (x - _mean) / np.sqrt(_var + self.epsilon)
        return gamma * x_norm + beta
```
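Usage differs from a standard layer norm only in that the condition embedding is passed in alongside the input:

```python
x = np.random.rand(2, 256, 768)
condition = np.random.rand(2, 1, 768)  # time embedding + class embedding
out = DiTAdaLayerNorm(768)(x, condition)
print(out.shape)  # (2, 256, 768)
```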
(4) adaLN-Zero block
This method extends (3). Besides modulating the layer norm, the condition embedding also regresses per-branch gating coefficients $\alpha$ that scale the residual connections. The "Zero" refers to the initialization: the projection producing these parameters is zero-initialized, so $\alpha = 0$ and every DiT block is the identity function at the start of training, a trick the paper finds substantially improves results. A minimal example is given below.
```python
import numpy as np


class LayerNorm:
    """Layer norm without learnable parameters; gamma/beta come from the condition."""

    def __init__(self, epsilon=1e-6):
        self.epsilon = epsilon

    def __call__(self, x: np.ndarray, gamma: np.ndarray, beta: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            gamma (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embedding
            beta (np.ndarray): shape: (batch_size, 1, feature_dim), generated by condition embedding
        Returns:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _var = np.var(x, axis=-1, keepdims=True)
        x_norm = (x - _mean) / np.sqrt(_var + self.epsilon)
        return gamma * x_norm + beta


class DiTBlock:
    def __init__(self, feature_dim):
        self.MultiHeadSelfAttention = lambda x: x  # mock multi-head self-attention
        self.layer_norm = LayerNorm()
        self.MLP = lambda x: x  # mock multi-layer perceptron
        # zero-initialized, so alpha = 0 and the block is the identity at init
        self.weight = np.zeros((feature_dim, feature_dim * 6))

    def __call__(self, x: np.ndarray, time_embedding: np.ndarray, class_embedding: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
            time_embedding (np.ndarray): shape: (batch_size, 1, feature_dim)
            class_embedding (np.ndarray): shape: (batch_size, 1, feature_dim)
        Returns:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        condition_embedding = time_embedding + class_embedding
        affine_params = condition_embedding @ self.weight  # shape: (batch_size, 1, feature_dim * 6)
        gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = np.split(affine_params, 6, axis=-1)
        # scale by (1 + gamma) so the modulation is a no-op when gamma starts at 0
        x = x + alpha_1 * self.MultiHeadSelfAttention(self.layer_norm(x, 1 + gamma_1, beta_1))
        x = x + alpha_2 * self.MLP(self.layer_norm(x, 1 + gamma_2, beta_2))
        return x
```
Results
The authors train DiT on ImageNet in the classifier-free guidance setup (class-conditional only, i.e., the condition embedding is a class embedding rather than a text embedding). They set up DiT models of four different sizes and run experiments.
Verifying DiT's Scaling Ability
The authors try different patch sizes and different DiT model sizes. From the figures, two observations stand out:
- The smaller the patch size, the better the generated results (a smaller patch size means more tokens in the initial sequence). It is not obvious why the authors did not try p = 1: the latent representation can itself be viewed as implicit tokens extracted by a CNN that only need to be flattened, which is how many hybrid (CNN + ViT) architectures work; perhaps it was to keep the compute budget in check.
- The larger the model, the better the generated results. In these experiments the gap between DiT-XL and DiT-L is small, possibly because the training data is not yet large enough to bring out the advantage of the bigger model.
Verifying the DiT Block Designs
The authors compare the generation quality of the four DiT block designs above on ImageNet; the adaLN-Zero design performs best.
Summary
DiT systematically studies two key questions for diffusion transformers, tokenization and condition embedding, and validates the scaling ability of Transformer-based diffusion models.
References
[1] Vaswani et al. Attention Is All You Need.
[2] Perez et al. FiLM: Visual Reasoning with a General Conditioning Layer.