Flow Matching For Generative Modeling

Flow Matching For Generative Modeling

一、基于流的(Flow based)生成模型

生成模型

我们先回顾一下所谓的生成任务,究竟是想要做什么事情。我们认为,世界上所有的图片,是符合某种分布 p d a t a ( x ) p_{data}(x) pdata(x) 的。当然,这个分布肯定是个极其复杂的分布。而我们有一堆图片 x 1 , x 2 , … , x m {x_1,x_2,\dots,x_m} x1,x2,,xm ,则可以认为是从这个分布中采样出来的 m m m 个样本。我们通过训练希望得到一个生成器网络 G G G ,该网络能够做到输入一个从正态分布 π ( z ) \pi(z) π(z) 中采样出来的 z z z ,输出一张看起来像真实世界的图片 x = G ( z ) ∼ p G ( x ) x=G(z)\sim p_G(x) x=G(z)pG(x) 。我们希望采样并生成出的数据分布 p G ( x ) p_G(x) pG(x) 与真实的数据分布 p d a t a ( x ) p_{data}(x) pdata(x) 越接近越好。

从概率模型的角度来看,想要做到上面说的这件事情,可以通过最大化对数似然 log ⁡ p G \log p_G logpG,来优化生成器 G G G 的参数:
G ∗ = arg ⁡ max ⁡ G ∑ i = 1 m log ⁡ p G ( x i ) G^*=\arg\max_{G}\sum_{i=1}^m\log p_G(x_i) G=argGmaxi=1mlogpG(xi)
可以证明,最大化这个对数似然,就相当于最小化生成器分布 p G ( x ) p_G(x) pG(x) 与目标分布 p d a t a ( x ) p_{data}(x) pdata(x) 的 KL散度,即让这两个分布尽量接近:
G ∗ ≈ arg ⁡ min ⁡ G K L ( p d a t a ∣ ∣ p G ) G^*\approx\arg\min_GKL(p_{data}||p_G) GargGminKL(pdata∣∣pG)

概率密度的变量变换定理

给定一个随机变量 z z z 及其概率密度函数 z ∼ π ( z ) z\sim\pi(z) zπ(z) ,通过一个一对一的映射函数 f f f 构造一个新的随机变量 x = f ( z ) x=f(z) x=f(z)。如果存在逆函数 f − 1 f^{-1} f1 满足 z = f − 1 ( x ) z=f^{-1}(x) z=f1(x),那么新变量 x x x 的概率密度函数 p ( x ) p(x) p(x) 计算如下:
p ( x ) = π ( z ) ∣ d z d x ∣ = π ( f − 1 ( x ) ) ∣ d f − 1 d x ∣ = π ( f − 1 ( x ) ) ∣ ( f ′ − 1 ( x ) ∣     若 z 为随机变量 p ( x ) = π ( z ) ∣ det ⁡ ( d z d x ) ∣ = π ( f − 1 ( x ) ) ∣ det ⁡ ( J f − 1 ) ∣      若 z 为随机向量 p(x)=\pi(z)|\frac{dz}{dx}|=\pi (f^{-1}(x))|\frac{df^{-1}}{dx}|=\pi(f^{-1}(x))|(f'^{-1}(x)| \ \ \ \ 若z为随机变量 \\ p(\mathbf{x})=\pi(\mathbf{z})|\det(\frac{d\mathbf{z}}{d\mathbf{x}})|=\pi(f^{-1}(\mathbf{x}))|\det(\mathbf{J}_{f^{-1}})| \ \ \ \ \ 若\mathbf{z}为随机向量 p(x)=π(z)dxdz=π(f1(x))dxdf1=π(f1(x))(f1(x)    z为随机变量p(x)=π(z)det(dxdz)=π(f1(x))det(Jf1)     z为随机向量
其中 det ⁡ ( ⋅ ) \det(\cdot) det() 表示行列式, J \mathbf{J} J 表示雅可比矩阵,是向量函数中因变量各维度关于自变量各维度的偏导数组成的矩阵,可类比为单变量函数的导数。

流模型推导

现在流行的生成模型五花八门,各显神通。VAE 优化变分下界 ELBO、GAN 通过对抗训练来隐式地逼近数据分布。流模型则可以直接优化对数似然。

在这里插入图片描述

流模型通过最大化对数似然,来优化生成器 G G G
G ∗ = arg ⁡ max ⁡ G ∑ i = 1 m log ⁡ p G ( x i ) G^*=\arg\max_{G}\sum_{i=1}^m\log p_G(x_i) G=argGmaxi=1mlogpG(xi)
而根据变量变换定理,有:
p G ( x i ) = π ( z i ) ∣ det ⁡ ( J G − 1 ) ∣ ,      z i = G − 1 ( x i ) p_G(x_i)=\pi(z_i)|\det(J_{G^{-1}})|,\ \ \ \ z_i=G^{-1}(x_i) pG(xi)=π(zi)det(JG1),    zi=G1(xi)
则对数似然:
log ⁡ p G ( x i ) = log ⁡ π ( G − 1 ( x i ) ) + log ⁡ ∣ det ⁡ ( J G − 1 ) ∣ \log p_G(x_i)=\log \pi(G^{-1}(x_i))+\log |\det(J_{G^{-1}})| logpG(xi)=logπ(G1(xi))+logdet(JG1)
要训练一个好的生成器 G G G 我们只需要训练一个(或一系列)网络完成从噪声分布 π ( z ) \pi(z) π(z) 到数据分布 p data ( x ) p_\text{data}(x) pdata(x) 的变换就可以了。在采样生成时,求出生成器的逆向网络 G − 1 G^{-1} G1,再将随机采样的噪声输入,即可生成新的符合数据分布的样本。只要最大化上面这个式子,就可以了。

现在的问题就是怎么把这个式子算出来,具体来说,这个式子计算的关键在以下两点:

  • 如何计算行列式 det ⁡ ( J G ) \det(J_G) det(JG)
  • 如何求逆矩阵 G − 1 G^{-1} G1

我们设计的生成器网络 G G G 需要满足上面这两个条件,这就是流模型生成器数学上的限制。在之前的流模型研究中,研究者们提出了许多设计精巧的网络(如 decoupling layer),可以巧妙地使得网络满足上述两点便利计算的要求。

另外要提一点,流模型的输入输出的尺寸必须是一致的。这是因为如果想要 G G G 可逆,它的输入输出维度一致是一个必要条件(非方阵不可能可逆)。比如要生成 100 × 100 × 3 100\times 100\times 3 100×100×3 的图像,那输入的随机噪声也是 100 × 100 × 3 100\times 100\times 3 100×100×3​​ 的。这与 VAE、GAN 等生成模型很不一样,这些生成模型的输入维度通常远小于输出维度。

堆叠多个网络

在实际中,由于可逆神经网络存在数学上的诸多限制,其单个网络的表达能力有限,我们一般需要堆叠多层网络来得到一个生成器,这也是 “流模型” 这个名称的由来。不过虽然堆叠了很多层,在公式上也没有什么复杂的。无非就是把一堆 G i G_i Gi 连乘起来,通过 log ⁡ \log log​ 之后,又变成连加。

比如我们有 K K K 个网络 { f i } i = 1 K \{f_i\}_{{i=1}}^K {fi}i=1K,对噪声分布 π ( z 0 ) \pi(\mathbf{z}_0) π(z0) 进行 K K K 步变换,得到数据 x \mathbf{x} x,即有:
x = z k = f K ( f K − 1 . . . f 1 ( z 0 ) ) \mathbf{x}=\mathbf{z}_k=f_K(f_{K-1}...f_1(\mathbf{z}_0)) x=zk=fK(fK1...f1(z0))
对于其中第 i i i 步有:
z i ∼ p i ( z i ) z i = f i ( z i − 1 ) ,    z i − 1 = f − 1 ( z i ) \mathbf{z}_i\sim p_i(\mathbf{z}_i)\\ \mathbf{z}_i=f_i(\mathbf{z}_{i-1}),\ \ \mathbf{z}_{i-1}=f^{-1}(\mathbf{z}_i) zipi(zi)zi=fi(zi1),  zi1=f1(zi)
根据变量变换定理,相邻两步之间的隐变量分布的关系为:
p i ( z i ) = p i − 1 ( f i − 1 ( z i ) ) ∣ det ⁡ J f i − 1 ∣ = p i − 1 ( z i − 1 ) ∣ det ⁡ J f i ∣ − 1 \begin{align} p_i(\mathbf{z}_{i})&=p_{i-1}(f_i^{-1}(\mathbf{z}_i))|\det\mathbf{J}_{f_i^{-1}}|\\ &=p_{i-1}(\mathbf{z}_{i-1})|\det\mathbf{J}_{f_i}|^{-1} \end{align} pi(zi)=pi1(fi1(zi))detJfi1=pi1(zi1)detJfi1
每一步的对数似然为:
log ⁡ p i ( z i ) = log ⁡ p i − 1 ( z i − 1 ) − log ⁡ ( det ⁡ ( J f i ) ) \log p_i(\mathbf{z}_i)=\log p_{i-1}(\mathbf{z}_{i-1})-\log(\det(\mathbf{J}_{f_i})) logpi(zi)=logpi1(zi1)log(det(Jfi))
对于整个 K K K 步的过程,对数似然为:
log ⁡ p ( x ) = log ⁡ p K ( z K ) = log ⁡ p K − 1 ( z K − 1 ) − log ⁡ ( det ⁡ ( J f K ) ) = log ⁡ p K − 2 ( z K − 2 ) − log ⁡ ( det ⁡ ( J f K − 1 ) ) − log ⁡ ( det ⁡ ( J f K ) ) =   . . . = log ⁡ π ( z 0 ) − ∑ i = 1 K log ⁡ ( det ⁡ ( J f i ) ) \begin{align} \log p(\mathbf{x})&=\log p_K(\mathbf{z}_K)\\ &=\log p_{K-1}(\mathbf{z}_{K-1})-\log(\det(\mathbf{J}_{f_K})) \\ &=\log p_{K-2}(\mathbf{z}_{K-2})-\log(\det(\mathbf{J}_{f_{K-1}}))-\log(\det(\mathbf{J}_{f_K})) \\ &=\ ... \\ &=\log\pi(\mathbf{z}_0)-\sum_{i=1}^K\log(\det(\mathbf{J}_{f_i})) \end{align} logp(x)=logpK(zK)=logpK1(zK1)log(det(JfK))=logpK2(zK2)log(det(JfK1))log(det(JfK))= ...=logπ(z0)i=1Klog(det(Jfi))

在这里插入图片描述

可以看到,流模型的核心思路就是通过多个可逆神经网络,一步步地将噪声分布转换为数据分布。在采样生成时,直接使用逆向网络,将随机采样的噪声样本转换为新的数据样本。

二、连续归一化流

常规流模型是在设定了离散的有限个(比如 K K K 个)可逆神经网络来逐步完成分布变换。而连续归一化流(Continuous Normalizing Flow, CNF),则是将其扩展为连续的情形。

设有 d d d 维空间中的数据 x = ( x 1 , x 2 , … , x d ) ∈ R d x=(x^1,x^2,\dots,x^d)\in\mathbb{R}^d x=(x1,x2,,xd)Rd 。CNF 有两个核心的研究对象:

  • 概率密度路径 (Probability Density Path) p p p [ 0 , 1 ] × R d → R > 0 [0,1]\times \mathbb{R}^d\rightarrow\mathbb{R}_{>0} [0,1]×RdR>0 ,这是一个关于时间的概率密度函数,即有 ∫ p t ( x ) d x = 1 \int p_t(x)dx=1 pt(x)dx=1
  • 关于时间的向量场 (time-dependent vector field) v v v [ 0 , 1 ] × R d → R d [0,1]\times \mathbb{R}^d\rightarrow\mathbb{R}^d [0,1]×RdRd ,它定义了每一个数据点在状态空间中随时间的变化方向和大小(所以叫向量场),可以理解为描述概率分布随时间变化的速率。

向量场 v t v_t vt 可以用来构建关于时间的微分同胚的映射,称为流 (flow) ϕ \phi ϕ [ 0 , 1 ] × R d → R d [0,1]\times \mathbb{R}^d\rightarrow\mathbb{R}^d [0,1]×RdRd 。通过常微分方程来定义:
d d t ϕ t ( x ) = v t ( ϕ t ( x ) ) \frac{d}{dt}\phi_t(x)=v_t(\phi_t(x))\\ dtdϕt(x)=vt(ϕt(x))

ϕ 0 ( x ) = x \phi_0(x)=x ϕ0(x)=x

这里的 ϕ t ( x ) \phi_t(x) ϕt(x) 可以理解为 flow ϕ \phi ϕ 在时间 t t t 时的状态,对应于扩散模型中时间步 t t t 的噪声图。 p t ( x ) p_t(x) pt(x) 是概率密度路径 p p p 时的状态,也就是 flow ϕ \phi ϕ 在时间 t t t 的概率分布。

之前,Neural ODE 提出使用一个参数为 θ ∈ R p \theta\in\mathbb{R}^p θRp 的神经网络 v t ( x ; θ ) v_t(x;\theta) vt(x;θ) 来建模向量场 v t v_t vt ,从而就能够计算出 flow ϕ t \phi_t ϕt,来实现 CNF。

CNF 可以通过 push forward 公式,将一个简单的先验分布 p 0 p_0 p0 (即纯噪声)转化为复杂的分布 p 1 p_1 p1 (即数据分布):
p t = [ ϕ t ] ∗ p 0 p_t=[\phi_t]_*p_0 pt=[ϕt]p0
其中 push forward 操作符 ∗ * 定义为:
[ ϕ t ] ∗ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ⁡ [ ∂ ϕ t − 1 ∂ x ( x ) ] [\phi_t]_*p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial\phi_t^{-1}}{\partial x}(x)] [ϕt]p0(x)=p0(ϕt1(x))det[xϕt1(x)]
如果满足了上述公式,可以看作是一个向量场 v t v_t vt 生成了一个概率密度路径 p t p_t pt

本文通过连续性方程(Continuity Equation)来测试一个向量场是否能生成一个概率密度路径,这是一个偏微分方程(PDE),给出了概率场生成概率密度路径的充要条件:
d d t p t ( x ) + div ( p t ( x ) v t ( x ) ) = 0 \frac{d}{dt}p_t(x)+\text{div}(p_t(x)v_t(x))=0 dtdpt(x)+div(pt(x)vt(x))=0
其中散度运算符 div \text{div} div 是关于空间变量 x = ( x 1 , … , x d ) x=(x^1,\dots,x^d) x=(x1,,xd) 的偏导数: div = ∑ i = 1 d ∂ ∂ x i \text{div}=\sum_{i=1}^d\frac{\partial}{\partial x^i} div=i=1dxi。本文附录 C 还介绍了更多关于 CNF 的前置知识,尤其是如何在空间中任意点 x ∈ R d x\in\mathbb{R}^d xRd 处,计算概率 p 1 ( x ) p_1(x) p1(x)

为什么说向量场 v t v_t vt “生成” 了概率密度路径 p t p_t pt?为什么要用常微分方程 ODE 来表达?

v t v_t vt ϕ t \phi_t ϕt 的导数(微分)。导数或者说微分,就是一个量随着另一个量极小变化时的变化,其实写成离散形式也好理解了,微分就是变化量: ϕ t ′ = ϕ t + Δ t − ϕ t \phi'_t=\phi_{t+\Delta t}-\phi_t ϕt=ϕt+Δtϕt 。就是从上一个时间点,怎么到下一个时间点,再知道初值 ϕ 0 = x \phi_0=x ϕ0=x 之后,就能从第一个点 “流” 到最后一个点,得到一个路径 p t p_t pt,所以说 “向量场( ϕ t \phi_t ϕt ODE 的解 ϕ t ′ = v t \phi'_t=v_t ϕt=vt生成了一条概率路径”。而 ODE d ϕ t / d t = v ( z t , t ) d\phi_t/dt=v(z_t,t) dϕt/dt=v(zt,t) 定义了一个向量场 v v v

三、Flow Matching

在构建生成模型时,我们假设有一个未知的数据分布 q ( x 1 ) q(x_1) q(x1) (注意本文中的符号与扩散模型论文中常用的符号相反,本文中 x 1 x_1 x1 表示真实数据, x 0 x_0 x0 表示随机噪声),我们能从其中采样出大量数据样本,但是不知道该分布的具体函数。

p t p_t pt 为概率路径,而 p 0 = p p_0=p p0=p 是一个简单的已知分布(如标准高斯分布 p ( x ) ∼ N ( 0 , I ) p(x)\sim\mathcal{N}(0,\mathbf{I}) p(x)N(0,I)),并令 p 1 p_1 p1 在分布上大致与 q q q 相等。Flow Matching 的目标就是去匹配这样一条目标概率路径,从而我们能够从 p 0 p_0 p0 ”流动“ 到 p 1 p_1 p1,实现生成。如何构造这样一条目标路径,稍后会介绍。

给定一个目标概率密度路径 p t ( x ) p_t(x) pt(x) 以及对应的生成这条路径的向量场 u t ( x ) u_t(x) ut(x),Flow Matching 的目标函数定义为:
L FM ( θ ) = E t , p t ( x ) ∣ ∣ v t ( x ) − u t ( x ) ∣ ∣ 2 \mathcal{L}_\text{FM}(\theta)=\mathbb{E}_{t,p_t(x)}||v_t(x)-u_t(x)||^2 LFM(θ)=Et,pt(x)∣∣vt(x)ut(x)2
其中 θ \theta θ 是 CNF 向量场 v t v_t vt 的参数, t ∼ U ( 0 , 1 ) ,   x ∼ p t ( x ) t\sim\mathcal{U}(0,1),\ x\sim p_t(x) tU(0,1), xpt(x) 。简单来说,FM 损失就是通过一个神经网络 v t v_t vt 对向量场 u t u_t ut 进行回归。当损失达到零时,训练好的的 CNF 模型就能够生成各时间 t t t p t ( x ) p_t(x) pt(x),当然就能生成符合数据分布 q ( x 1 ) = p 1 ( x ) q(x_1)=p_1(x) q(x1)=p1(x) 的样本。

Flow Matching 目标函数非常简洁,不过实际中它本身是无法计算的,因为我们并不知道 p t p_t pt u t u_t ut。有许多条概率路径能够实现 p 1 ( x ) ≈ q ( x ) p_1(x)\approx q(x) p1(x)q(x),更重要的是,我们无法计算生成目标 p t p_t pt u t u_t ut​ 的闭式解。

由条件概率路径和条件向量场构建 p t p_t pt u t u_t ut

接下来我们介绍构建目标概率路径 p t p_t pt 和向量场 u t u_t ut 的方法,本方法的思路是通过单个样本构建条件概率路径和条件向量场,再通过积分将条件概率路径/向量场与边缘概率路径/向量场联系起来,从而有一个容易计算的流匹配目标函数。

构建目标概率路径的一个简单方法是通过混合一个更简单的概率路径:给定一个特定的数据样本 x 1 x_1 x1,我们用 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 表示一个条件概率路径,它需要满足:

  • 时间 t = 0 t=0 t=0 p 0 ( x ∣ x 1 ) = p ( x ) p_0(x|x_1)=p(x) p0(xx1)=p(x),也就是说 p 0 ( x ) p_0(x) p0(x) 和样本数据 x 1 x_1 x1 无关,是一个标准噪声分布;
  • t = 1 t=1 t=1 时的 p 1 ( x ∣ x 0 ) p_1(x|x_0) p1(xx0) 是一个在 x = x 1 x=x_1 x=x1 附近的分布(如一个均值为 x 1 x_1 x1,标准差 σ > 0 \sigma>0 σ>0 足够小的正态分布 p 1 ( x ∣ x 1 ) = N ( x ∣ x 1 , σ 2 I ) p_1(x|x_1)=\mathcal{N}(x|x_1,\sigma^2\mathbf{I}) p1(xx1)=N(xx1,σ2I))。也就是说 t = 1 t=1 t=1 时要大致符合数据分布,即 p 1 ( x ) ≈ q ( x ) p_1(x)\approx q(x) p1(x)q(x)

将条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 对所有的 q ( x 1 ) q(x_1) q(x1) 进行积分(相当于遍历数据集中所有的真实数据),就得到了我们想要的边缘概率路径 p t ( x ) p_t(x) pt(x)
p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 ) d x 1 p_t(x)=\int p_t(x|x_1)q(x_1)d{x_1} pt(x)=pt(xx1)q(x1)dx1

特别地,当时间 t = 1 t=1 t=1 时,边缘概率 p 1 p_1 p1 是一个混合分布,能够对数据分布 q q q 进行很好的近似:
p 1 ( x ) = ∫ p 1 ( x ∣ x 1 ) q ( x 1 ) d x 1 ≈ q ( x ) p_1(x)=\int p_1(x|x_1)q(x_1)dx_1\approx q(x) p1(x)=p1(xx1)q(x1)dx1q(x)
我们也可以通过对条件向量场进行 ”边缘化“,来定义一个边缘向量场 (marginal vector field) (假设对所有的 t , x t,x t,x p t ( x ) > 0 p_t(x)>0 pt(x)>0):
u t ( x ) = ∫ u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) q ( x 1 ) p t ( x ) d x 1 u_t(x)=\int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)}dx_1 ut(x)=ut(xx1)pt(x)pt(xx1)q(x1)dx1
其中 u t ( ⋅ ∣ x 1 ) :   R d → R d u_t(\cdot|x_1):\ \mathbb{R}^d\rightarrow\mathbb{R}^d ut(x1): RdRd 是生成 p t ( ⋅ ∣ x 1 ) p_t(\cdot|x_1) pt(x1) 的条件向量场。

那么,这种对条件向量场积分,来构造的边缘向量场 u t ( x ) u_t(x) ut(x),能否生成对应的边缘概率路径 p t ( x ) p_t(x) pt(x) 呢?作者证明,是可以的。原文中附录 A 给出了完整的证明过程,其实要证明就是上述构造边缘概率路径/向量场的形式,能够满足连续性方程。

这样就将条件向量场(可以生成条件概率路径)和边缘向量场(可以生成边缘概率路径)联系了起来。从而我们就可以将未知且难以计算的边缘概率场转换为更简单的条件概率场。条件概率场定义起来要简单得多,因为它仅依赖于单个数据样本。正式地表述为:

定理1 给定条件概率路径 p ( x ∣ x 1 ) p(x|x_1) p(xx1) 以生成该路径的条件向量场 u ( x ∣ x 1 ) u(x|x_1) u(xx1),对于任意数据分布 q ( x 1 ) q(x_1) q(x1),边缘向量场 u t u_t ut p t p_t pt 满足连续性方程,即 u t u_t ut 能够生成 p t p_t pt

条件流匹配 Conditional Flow Matching

遗憾的是,由于边缘向量场和边缘概率路径中的积分无法计算,我们还是无法得到 u t u_t ut ,从而也就无法直接计算原始 Flow Matching 目标函数。这里,作者提出了一个更简单的目标函数,它能导出与原目标函数相同的最优解。具体来说,作者提出了 条件流匹配 (Conditional Flow Matching) 目标:
L CFM ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∣ ∣ v t ( x ) − u t ( x ∣ x 1 ) ∣ ∣ 2 \mathcal{L}_\text{CFM}(\theta)=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}||v_t(x)-u_t(x|x_1)||^2 LCFM(θ)=Et,q(x1),pt(xx1)∣∣vt(x)ut(xx1)2

其中 t ∼ U ( 0 , 1 ) ,   x 1 ∼ q ( x 1 ) t\sim\mathcal{U}(0,1),\ x_1\sim q(x_1) tU(0,1), x1q(x1),而此时 x ∼ p t ( x ∣ x 1 ) x\sim p_t(x|x_1) xpt(xx1)。也就是说,我们不回归向量场 u t ( x ) u_t(x) ut(x) 了,而是改为回归条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1)。不同于 FM 目标函数,在 CFM 目标函数中,只要我们能从 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 中采样,并计算 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1),就可以计算出无偏估计。而由于我们是在单个样本上进行的定义,这两点要求都很容易满足。

作者证明了:

定理2 假设对所有的 x ∈ R d x\in\mathbb{R}^d xRd t ∈ [ 0 , 1 ] t\in[0,1] t[0,1],都有 p t ( x ) > 0 p_t(x)>0 pt(x)>0,那么 L CFM \mathcal{L}_\text{CFM} LCFM L FM \mathcal{L}_\text{FM} LFM 是相等的(至多差一个与 θ \theta θ 无关的常数),即 ∇ θ L CFM ( θ ) = ∇ θ L FM ( θ ) \nabla_\theta\mathcal{L}_\text{CFM}(\theta)=\nabla_\theta\mathcal{L}_\text{FM}(\theta) θLCFM(θ)=θLFM(θ)

也就是说,优化 CFM 目标(在期望上)等同于优化 FM 目标。因此,我们可以用 CFM 目标训练一个 CNF 来生成边际概率路径 p t p_t pt,在 t = 1 t=1 t=1 时近似未知数据分布 q q q,而无需已知边缘概率路径或边缘向量场。我们只需要设计合适的条件概率路径和条件向量场。

四、高斯条件概率路径和条件向量场

CFM 目标适用于所有的条件概率路径和条件向量场。本节中,我们重点讨论高斯条件概率路径族的 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1)。即,我们考虑如下形式的高斯条件概率路径:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t 2 ( x 1 ) I ) p_t(x|x_1)=\mathcal{N}(x|\mu_t(x_1),\sigma_t^2(x_1)\mathbf{I}) pt(xx1)=N(xμt(x1),σt2(x1)I)
其中 μ : [ 0 , 1 ] × R d → R d \mu:[0,1]\times \mathbb{R}^d\rightarrow\mathbb{R}^d μ:[0,1]×RdRd σ : [ 0 , 1 ] × R → R > 0 \sigma:[0,1]\times\mathbb{R}\rightarrow\mathbb{R}_{>0} σ:[0,1]×RR>0 分别是关于时间 t t t 的高斯分布的均值和标准差。需要满足:

  1. 在时间 t = 0 t=0 t=0 时,满足 μ 0 ( x 1 ) = 0 , σ 0 ( x 1 ) = 1 \mu_0(x_1)=0,\sigma_0(x_1)=1 μ0(x1)=0,σ0(x1)=1,从而所有的条件概率路径都收敛到标准高斯分布 p ( x ) = N ( x ∣ 0 , I ) p(x)=\mathcal{N}(x|0,\mathbf{I}) p(x)=N(x∣0,I)
  2. 在时间 t = 1 t=1 t=1 时,满足 KaTeX parse error: Got function '\min' with no arguments as subscript at position 37: …_1(x_1)=\sigma_\̲m̲i̲n̲,其中 KaTeX parse error: Got function '\min' with no arguments as subscript at position 8: \sigma_\̲m̲i̲n̲ 需足够小,使得 p 1 ( x ∣ x 1 ) p_1(x|x_1) p1(xx1) 是足够聚集于中心 x 1 x_1 x1 的高斯分布。

存在无限多个向量场可以生成任何特定的概率路径,但这些中的绝大多数是由于存在使底层分布不变的分量(比如像连续性方程中添加一个无散度的分量),导致的不必要的额外计算。作者使用最简单的,对应于高斯分布的标准变换的向量场。具体来说,考虑条件于 x 1 x_1 x1 的流:

ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(x)=\sigma_t(x_1)x+\mu_t(x_1) ψt(x)=σt(x1)x+μt(x1)
x x x 是标准的高斯分布时, ψ t ( x ) \psi_t(x) ψt(x) 是一个仿射变换,映射到均值为 μ t ( x 1 ) \mu_t(x_1) μt(x1)、标准差为 σ t ( x 1 ) \sigma_t(x_1) σt(x1) 的正态分布随机变量。也就是说,根据上式, ψ t \psi_t ψt 的前向过程从噪声分布 p 0 ( x ∣ x 1 ) p_0(x|x_1) p0(xx1) 流向 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) ,即:
[ ψ t ] ∗ p ( x ) = p t ( x ∣ x 1 ) [\psi_t]_*p(x)=p_t(x|x_1) [ψt]p(x)=pt(xx1)
生成这个条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 的条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1) 为:
d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt}\psi_t(x)=u_t(\psi_t(x)|x_1) dtdψt(x)=ut(ψt(x)x1)
ψ t \psi_t ψt 重写为仅关于 x 0 x_0 x0,并将上式代入到 CFM 损失中,有:
L CFM ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∣ ∣ v t ( ψ t ( x 0 ) ) − d d t ψ t ( x 0 ) ∣ ∣ 2 \mathcal{L}_\text{CFM}(\theta)=\mathbb{E}_{t,q(x_1),p(x_0)}||v_t(\psi_t(x_0))-\frac{d}{dt}\psi_t(x_0)||^2 LCFM(θ)=Et,q(x1),p(x0)∣∣vt(ψt(x0))dtdψt(x0)2
由于 ψ t \psi_t ψt 是可逆的仿射映射,我们可以闭式计算出 u t u_t ut

f ′ f' f 表示关于时间的函数 f f f 对时间的微分,即 f ′ = d d t f f'=\frac{d}{dt}f f=dtdf

定理3 设 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 是一个高斯概率路径, ψ t \psi_t ψt 是其对应的 flow map,那么有唯一的向量场 ψ t \psi_t ψt,其形式为:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) u_t(x|x_1)=\frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x-\mu_t(x_1))+\mu'_t(x_1) ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1)
该向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1) 可以生成高斯路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1)​。

高斯条件概率路径的特殊情形

我们的形式化对于任意函数 μ t ( x 1 ) \mu_t(x_1) μt(x1) σ t ( x 1 ) \sigma_t(x_1) σt(x1) 都是完全通用的,我们可以将它们设置为任何满足所需边界条件的可微函数。本节讨论两个实例,首先讨论已有的经典扩散模型(如 VP/VE)在本文形式化下的推导。然后,由于我们直接使用概率路径工作,可以完全不依赖于关于扩散过程的推理。因此,我们可以直接基于 Wasserstein-2 最优传输解来制定一个概率路径,这是第二个实例。

例子1:Diffusion Conditional VFs

扩散模型对一个真实数据样本逐渐添加噪声,直到其成为纯噪声。扩散模型可以表示为随机过程,其具有一定的要求,从而对任意时间 t t t 有闭式表示。选择不同的均值 μ t ( x 1 ) \mu_t(x_1) μt(x1) 和标准差 σ t ( x 1 ) \sigma_t(x_1) σt(x1),就得到特定高斯条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1)

首先来看 Variance Exploding,其反向(噪声->数据)路径为:
p t ( x ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x)=\mathcal{N}(x|x_1,\sigma^2_{1-t}\mathbf{I}) pt(x)=N(xx1,σ1t2I)
其中 σ t \sigma_t σt 是一个单增函数, σ 0 = 0 , σ 1 > > 1 \sigma_0=0,\sigma_1>>1 σ0=0,σ1>>1。上式这种 VE 扩散模型,是选择了均值和标准差分别为 μ t ( x 1 ) = x 1 , σ t ( x 1 ) = σ 1 − t \mu_t(x_1)=x_1,\sigma_t(x_1)=\sigma_{1-t} μt(x1)=x1,σt(x1)=σ1t 。带入到定理 3 的公式中:
u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) u_t(x|x_1)=-\frac{\sigma'_{1-t}}{\sigma_{1-t}}(x-x_1) ut(xx1)=σ1tσ1t(xx1)
另一种经典的扩散模型 Variance Preserving 扩散路径的形式为:
p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) α t = e − 1 2 T ( t ) T ( t ) = ∫ 0 t β ( s ) d s p_t(x|x_1)=\mathcal{N}(x|\alpha_{1-t}x_1,(1-\alpha^2_{1-t})\mathbf{I})\\ \alpha_t=e^{-\frac{1}{2}T(t)}\\ T(t)=\int_0^t\beta(s)ds pt(xx1)=N(xα1tx1,(1α1t2)I)αt=e21T(t)T(t)=0tβ(s)ds
其中 β \beta β 是关于 t t t 的 noise scale 函数。上式是选择了均值和标准差分别为 μ t ( x 1 ) = α 1 − t x 1 , σ t ( x 1 ) = 1 − α 1 − t 2 \mu_t(x_1)=\alpha_{1-t}x_1,\sigma_t(x_1)=\sqrt{1-\alpha_{1-t}^2} μt(x1)=α1tx1,σt(x1)=1α1t2 。带入到定理 3 的公式中:
u t ( x ∣ x 1 ) = α 1 − t ′ 1 − α 1 − t 2 ( α 1 − t x − x 1 ) = − T ′ ( 1 − t ) 2 [ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) ] u_t(x|x_1)=\frac{\alpha'_{1-t}}{1-\alpha^2_{1-t}}(\alpha_{1-t}x-x_1)=-\frac{T'(1-t)}{2}[\frac{e^{-T(1-t)}x-e^{-\frac{1}{2}T(1-t)}x_1}{1-e^{-T(1-t)}}] ut(xx1)=1α1t2α1t(α1txx1)=2T(1t)[1eT(1t)eT(1t)xe21T(1t)x1]
实际上,本文在指定特定的条件扩散过程时构建出的条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1) ,与宋飏等人(Diff SDE 论文,公式 12)中给出的确定性概率流模型是相符的。并且,将扩散条件向量场与 FM 训练目标结合起来,能得到另一种训练 score matching 的方法,作者发现该方法训练起来更加稳定。

作者还指出,上述提到的这些概率路径通过扩散过程推导得出,所以他们在最终的时间步并没有达到真正的噪声分布(Zero Terminal SNR 也提出了一样的问题)。实际中, p 0 ( x ) p_0(x) p0(x) 只是通过一个合适的高斯分布近似,来进行采样和似然计算。而本文提出的构造方式,则可以对概率路径有完全的控制,可以直接设置 μ t \mu_t μt σ t \sigma_t σt。接下来,我们就试试这样做。

例子2:Optimal Transport Conditional VFs

一个更自然的选择是将均值和标准差定义为简单的线性变换,即:
KaTeX parse error: Got function '\min' with no arguments as subscript at position 42: …x)=1-(1-\sigma_\̲m̲i̲n̲)t
根据定理 3,产生上述路径的向量场为:
KaTeX parse error: Got function '\min' with no arguments as subscript at position 33: …{x_1-(1-\sigma_\̲m̲i̲n̲)x}{1-(1-\sigma…

其中 t ∈ [ 0 , 1 ] t\in[0,1] t[0,1]。其对应的 flow 为:

KaTeX parse error: Got function '\min' with no arguments as subscript at position 25: …)=(1-(1-\sigma_\̲m̲i̲n̲)t)x+tx_1
此时,CFM 损失为:
KaTeX parse error: Got function '\min' with no arguments as subscript at position 95: …(x_1-(1-\sigma_\̲m̲i̲n̲)x_0)||^2
本文这种线性的均值标准差构造方法,不仅能得到简单直观的路径,实际上在以下意义上也是最优的。条件流 ψ t ( x ) \psi_t(x) ψt(x) 实际上是两个高斯分布 p 0 ( x ∣ x 1 ) p_0(x|x_1) p0(xx1) p 1 ( x ∣ x 1 ) p_1(x|x_1) p1(xx1) 之间的最优传输映射(Optimal Transport (OT) Displacement Map)。最优传输插值(OT Interpolant),即是一个概率路径,被定义为:
p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t=[(1-t)\text{id}+t\psi]_*p_0 pt=[(1t)id+tψ]p0
其中 ψ : R d → R d \psi:\mathbb{R}^d\rightarrow\mathbb{R}^d ψ:RdRd 是从 p 0 p_0 p0 p 1 p_1 p1 的最优传输映射, id \text{id} id 表示恒等映射,即 id ( x ) = x \text{id}(x)=x id(x)=x ( 1 − t ) id + t ψ (1-t)\text{id}+t\psi (1t)id+tψ 即 OT displacement map。之前的研究表明,在这种情况汇总,两个高斯分布(其中第一个是标准高斯)的 OT displacement map 形如式 23。

直观地说,在最优传输位移图下,粒子总是沿着直线轨迹并以恒定速度移动。下图展示了扩散和最优传输条件向量场的采样路径。作者还发现,从扩散路径中采样的轨迹可能会“超出”最终样本,导致不必要的回溯,而最优传输路径则保证保持直线。

在这里插入图片描述

下图比较了扩散条件得分函数(典型扩散方法中的回归目标),即 ∇ log ⁡ p t ( x ∣ x 1 ) \nabla \log p_t(x|x_1) logpt(xx1),与 OT 条件向量场。两个示例中的起始 $p_0 $ 和结束 p 1 p_1 p1高斯分布是相同的。一个有趣的观察是,最优传输向量场在时间上具有恒定的方向,这无疑会导致一个更简单的回归任务。这个属性也可以从 OT 的形式中验看出,因为向量场可以写成 u t ( x ∣ x 1 ) = g ( t ) h ( x ∣ x 1 ) u_t(x|x_1) = g(t)h(x|x_1) ut(xx1)=g(t)h(xx1) 的形式。最后,我们注意到,尽管条件流是最优的,但这并不意味着边际向量场是最优传输解。尽管如此,我们期望边际向量场保持相对简单。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/728166.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MES管理系统如何设计生产质量管理功能

在现代制造业中,MES管理系统作为连接企业计划层与车间操作层的关键桥梁,其生产者计量管理功能的设计显得尤为重要。一个完善的MES管理系统生产质量管理模块,不仅要求能够实时、准确地采集和分析生产过程中的质量数据,还需要能够与…

Unity3d 游戏暂停(timeScale=0)引起的deltaTime关联的系列问题解决

问题描述 游戏暂停的功能是通过设置timeScale0实现的,不过在暂停游戏的时候,需要对角色进行预览和设置,为了实现这个功能,是通过鼠标控制相机的操作,为了使相机的操作丝滑,获取鼠标操作系数乘以Time.delta…

如何在React中使用CSS模块,并解释为什么使用它们比传统CSS更有益?

在React中使用CSS模块是一种将CSS类名局部化到单个组件的方法,从而避免了全局作用域中的类名冲突。CSS模块允许你为组件编写样式,并确保这些样式只应用于该组件,而不会影响到其他组件。 以下是在React中使用CSS模块的步骤: 安装C…

Excel 识别数据层次后转换成表格

某列数据可分为 3 层,第 1 层是字符串,第 2 层是日期,第 3 层是时间: A1NAME122024-06-03304:06:12404:09:23508:09:23612:09:23717:02:2382024-06-02904:06:121004:09:231108:09:2312NAME2132024-06-031404:06:121504:09:231620…

FreeBSD在zfs挂接第二块ssd 硬盘

为FreeBSD机器新增加了一块ssd硬盘:骑尘 256G 先格式化分区硬盘 进入bsdconfig 选Disk Management 选择ada1 ,也就是新增加的硬盘 选择auto 然后选择Entire Disk 提示信息 The existing partition scheme on this disk (MBR) │ …

如何解决windows自动更新,释放C盘更新内存

第一步:首先关闭windows自动更新组件 没有更新windows需求,为了防止windows自动更新,挤占C盘空间,所以我们要采取停止Windows Update服务。按下WinR打开运行对话框,输入services.msc, 然后按Enter。在服务…

传输大文件之镭速自动清理过期文件

电子文档的普及无疑极大地便利了我们的工作与生活,但随之而来的是如何有效管理这些日益增多的文件。企业面临着存储空间紧张、文件传输复杂、敏感信息泄露等挑战。自动化文件清理的需求日益凸显,这不仅关乎个人对高效工作环境的追求,更是企业…

绘唐3工具—让创作触手可及

绘唐AIGC(Artificial Intelligence Generated Creativity)是一种新兴的技术,通过人工智能生成创意,让创作更加触手可及。 绘唐3下载地址https://qvfbz6lhqnd.feishu.cn/wiki/KnRawcWQxiFrj5kC8lVcCEypnZc 绘唐AIGC结合了人工智能…

element-plus的form表单组件之checkbox组件

单个checkbox 绑定的响应式的值类型为bool类型,同一个组的checkbox多选其值对应值的数组,类型根据checkbox的value值而来。 label只用来显示具体的值,根据value属性来设置。 element-plus的checkbox提供多种特性。 如单选,多选…

B站广告开户投流是什么政策?要哪些资质?

B站(哔哩哔哩)作为年轻人喜爱的视频分享社区,其广告价值也日益凸显。为了更好地服务广告主,B站近日对广告开户投流政策进行了更新,云衔科技作为专业的数字营销服务商,也积极响应,为广告主提供一…

【java分布式计算】控制反转和依赖注入(DI IOC AOP)

考试要求:了解控制反转的基本模式,用依赖注入编写程序 目录 控制反转(Inversion of Control, IOC): 依赖注入(Dependency Injection, DI): 依赖注入的三种实现方式 具体的例子 …

C#——装箱与拆箱详情

装箱与拆箱 装箱: 将值类型转换成引用类型的过程; 拆箱: 把引用类型转为值类型的过程,就是拆箱 装箱 拆箱

韩顺平0基础学Java——第27天

p548-568 明天开始坦克大战 Entry 昨天没搞明白的Map、Entry、EntrySet://GPT教的 Map 和 Entry 的关系 1.Map 接口:它定义了一些方法来操作键值对集合。常用的实现类有 HashMap、TreeMap 等。 2. Entry接口:Entry 是 Map 接口的一个嵌…

LDO的原理及测试方法

一、基本结构 这是LM317芯片的核心,这个电路单元称为Bandgap Reference带隙基准源。属于模拟集成电路中的经典电路结构。 LDO拓扑结构图 常见的基本结构 利用VBE的负温度系数,而VT是正温度系数,正负温度系数抵消就的得到稳定的基准参考电压了(三极管的方程VBE=VT*In(lC/IS…

鄂州职业大学2024年成人高等继续教育招生简章

鄂州职业大学,作为一所享有盛誉的高等学府,一直以来都致力于为社会培养具备专业技能和良好素养的优秀人才。在成人高等继续教育领域,该校同样表现出色,为广大渴望继续深造、提升自身能力的成年人提供了宝贵的学习机会。 随着社会…

通过Socket通信实现局域网下Amov无人机连接与数据传输

1.局域网下的通信 1.1 局域网 厂家提供的方式是通过Homer图数传工具(硬件)构建的amov局域网实现通信连接. 好处是通信距离足够长,支持150m;坏处是"局部",无法访问互联网. [IMAGE:…

生信算法8 - HGVS转换与氨基酸字母表

HGVS 概念 HGVS 人类基因组变异协会(Human Genome Variation Society)提出的转录本编号,cDNA 参考序列(以前缀“c.”表示)、氨基酸参考序列(以前缀“p.”表示)。cDNA 中一种碱基被另一种碱基取代,以“>”进行表示,如:c.2186A&…

14-Kafka-Day03

第 5 章 Kafka 消费者 5.1 Kafka 消费方式 5.2 Kafka 消费者工作流程 5.2.1 消费者总体工作流程 一个消费者组中的多个消费者,可以看作一个整体,一个组内的多个消费者是不可能去消费同一个分区的数据的,要不然就消费重复了。 5.2.2 消费者…

达梦8 兼容MySQL语法支持非分组项作为查询列

MySQL 数据库迁移到达梦后,部分GROUP BY语句执行失败,报错如下: 问题原因: 对于Oracle数据库,使用GROUP BY时,SELECT中的非聚合列必须出现在GROUP BY后面,否则就会报上面的错误,达梦…

C++项目——负载均衡在线OJ

前言 学习了这么久的C/C与Linux,终于到了做项目的时候,想想还是有点小激动,哈哈哈哈哈。我们的目标是做一个跟leetcode、牛客类似的在线OJ系统,功能阉割了一些,比如说登录、论坛、求职等等。主要实现了提交题目与判定…