了解diffusion model:什么是diffusion model? 它为什么好用? - 知乎
摘要
去噪扩散概率模型目前正成为许多重要数据模式生成建模的主要范式。扩散模型在计算机视觉社区中最为流行,最近也在其他领域引起了一些关注,包括语音、NLP 和图形数据。在这项工作中,我们研究了扩散模型的框架是否可用于解决表格问题,其中数据点通常由异构特征的向量表示。表格数据固有的异构性使得准确建模变得非常具有挑战性,因为各个特征可能具有完全不同的性质,即其中一些特征可能是连续的,而另一些特征可能是离散的。为了解决此类数据类型,我们引入了 TabDDPM——一种可以普遍应用于任何表格数据集并处理任何类型特征的扩散模型。我们在一系列基准上对 TabDDPM 进行了广泛的评估,并证明了它优于现有 GAN/VAE 替代方案,这与扩散模型在其他领域的优势一致。此外,我们表明 TabDDPM 适用于隐私导向的设置,其中原始数据点不能公开共享。 TabDDPM 的源代码和我们的实验可在 https://github.com/rotot0/tab-ddpm 上找到。
1.引言
去噪扩散概率模型 (DDPM) (Sohl-Dickstein 等人,2015 年;Ho 等人,2020 年) 最近成为生成模型界研究的热门话题,因为它们在单个样本的真实感和多样性方面往往优于其他方法 (Dhariwal & Nichol,2021 年)。 DDPM 最令人印象深刻的成功是在自然图像领域 (Dhariwal & Nichol, 2021; Saharia et al, 2022; Rombach et al, 2022) 证明的,其中扩散模型的优势在诸如着色 (Song et al, 2021)、修复 (Song et al, 2021)、分割 Baranchuk et al (2021)、超分辨率 (Saharia et al, 2021; Li et al, 2021)、语义编辑 (Meng et al, 2021) 等应用中得到成功利用。除了计算机视觉之外,DDPM 框架还在其他领域得到了研究,例如 NLP(Austin 等,2021;Li 等,2022)、波形信号处理(Kong 等,2020;Chen 等,2020)、分子图(Jing 等,2022;Hoogeboom 等,2022)、时间序列(Tashiro 等,2021),证明了扩散模型在广泛问题中的普适性。
本文的工作目的是了解 DDPM 是否可以扩展到表格问题。由于各个特征的异质性(有的特征是连续数据有的是离散数据)以及典型表格数据集的规模相对较小,训练高质量的表格数据模型与计算机视觉或 NLP 相比更具挑战性。在本文中,作者表明,尽管存在这两个复杂性,但扩散模型可以成功地近似表格数据的典型分布,从而在大多数基准测试中实现最优的性能。
更详细地说,我们工作的主要贡献如下:
- 引入了 TabDDPM — 针对表格问题的最简单 DDPM 设计,可应用于任何表格任务,并可处理混合数据,包括数值和分类特征。
- 证明 TabDDPM 优于为表格数据设计的替代方法,包括文献中基于 GAN 和基于 VAE 的模型,并说明了多个数据集的这一优势来源。
- 当使用合成数据替代无法共享的真实用户数据时,TabDDPM 生成的数据似乎是隐私问题场景的“最佳点”。
2.相关工作
扩散模型 (Sohl-Dickstein 等人,2015 年;Ho 等人,2020 年) 是一种生成建模范式,旨在通过马尔可夫链的端点近似目标分布,该链从给定的参数分布(通常是标准高斯分布)开始。每个马尔可夫步骤都由深度神经网络执行,该网络可以有效地学习使用已知高斯核来反转扩散过程。 Ho 等人证明了扩散模型和分数匹配的等价性 (Song & Ermon,2019 年;2020 年),表明它们是通过迭代去噪过程将简单的已知分布逐步转换为目标分布的两个不同视角。最近的几项工作 (Nichol,2021 年;Dhariwal & Nichol,2021 年) 开发了更强大的模型架构以及不同的高级学习协议,这导致 DDPM 在计算机视觉领域的生成质量和多样性方面“胜过”GAN。在我们的工作中,我们证明了扩散模型也可以成功地用于表格问题。
表格问题的生成模型目前是机器学习社区的一个活跃的研究方向(Xu 等人,2019 年;Engelmann & Lessmann,2021 年;Jordon 等人,2018 年;Fan 等人,2020 年;Torfi 等人,2022 年;Zhao 等人,2021 年;Kim 等人,2021 年;Zhang 等人,2021 年;Nock & GuillameBert,2022 年;Wen 等人,2022 年),因为许多表格任务对高质量的合成数据有着很大的需求。首先,表格数据集的大小通常有限,这与视觉或 NLP 问题不同,因为互联网上有大量“额外”数据可用。其次,适当的合成数据集不包含实际的用户数据,因此它们不受 GDPR 类法规的约束,并且可以公开共享而不会违反匿名性。最近的研究开发了大量模型,包括表格 VAE(Xu et al,2019)和基于 GAN 的方法(Xu et al,2019;Engelmann & Lessmann,2021;Jordon et al,2018;Fan et al,2020;Torfi et al,2022;Zhao et al,2021;Kim et al,2021;Zhang et al,2021;Nock & Guillame-Bert,2022;Wen et al,2022)。通过对大量公共基准进行广泛的评估,我们表明我们的 TabDDPM 模型超越了现有的替代方案,而且往往领先幅度很大。
“浅层”合成生成。与非结构化图像或自然文本不同,表格数据通常是结构化的,即各个特征通常是可解释的,并且尚不清楚它们的建模是否需要多层“深层”架构。因此,简单的插值技术,如 SMOTE(Chawla 等人,2002 年)(最初提出用于解决类别不平衡问题)可以作为简单而强大的解决方案,如(Camino 等人,2020 年)中所示,其中 SMOTE 在小类过采样方面的表现优于表格 GAN。在实验中,我们从隐私保护的角度展示了 TabDDPM 生成的合成图像相对于插值技术生成的合成图像的优势。
3.背景
扩散模型(Sohl-Dickstein 等人,2015 年;Ho 等人,2020 年)是基于似然的生成模型,通过正向和反向马尔可夫过程处理数据。正向过程 逐渐将噪声添加到来自数据分布 的初始样本 ,从预定义分布 中采样噪声,方差为 。逆扩散过程 逐渐对潜在变量 进行去噪,并允许从 生成新的数据样本。分布 通常未知,并由具有参数 θ 的神经网络近似。这些参数是通过优化变分下限从数据中学习的:公式(1)
高斯扩散模型在连续空间 中运行,其中正向和反向过程以高斯分布为特征:
Ho 等人 (2020) 建议使用具有常数 的对角线 ,并计算 作为 和 的函数:
其中 且 预测噪声数据样本 的“真实”噪声分量。实际上,目标 (1) 可以简化为 和 在所有时间步 t 上的均方误差之和: 公式(2)
多项式扩散模型 (Hoogeboom 等人,2021) 旨在生成分类数据,其中 是具有 K 个值的独热编码分类变量。多项式正向扩散过程将 定义为分类分布,该分布通过 K 个类的均匀噪声破坏数据:
从上面的方程中,可以推导出后验概率 :
其中,。
逆分布 被参数化为 ,其中 由神经网络预测。然后,训练模型以最大化变分下界 (1)。
4.TABDDPM
在本节中,我们描述了 TabDDPM 的设计及其影响模型有效性的主要超参数。
TabDDPM 使用多项式扩散来对分类特征和二元特征进行建模,使用高斯扩散来对数值特征进行建模。更详细地讲,对于一个表格数据样本 ,其中包含 个数值特征 和 C 个分类特征 (每个特征有 个类别),我们的模型将分类特征的独热编码版本作为输入(即),并使用标准化的数值特征。因此,输入 的维数为 ()。对于预处理,我们使用 scikit-learn 库(Pedregosa 等人,2011)中的高斯分位数变换。每个分类特征都由单独的前向扩散过程处理,即所有特征的噪声成分都是独立采样的。 TabDDPM 中的反向扩散步骤由多层神经网络建模,该网络具有与 相同的维数输出,其中前 个数值是高斯扩散的 预测,其余数值是多项式扩散的 i 预测。
TabDDPM 模型用于分类问题的示意图如图1所示。该模型通过最小化高斯扩散项的均方误差 (公式 (2))和每个多项式扩散项的 KL 散度(公式 (1))之和来进行训练。多项式扩散的总损失还会额外除以分类特征的数量。公式(3)
对于分类数据集,我们使用一个类条件模型,即学习 。对于回归数据集,我们将目标值视为一个额外的数值特征,并学习联合分布。
为了建模逆过程,我们使用一个简单的 MLP 架构,改编自 (Gorishniy 等人,2021):公式(4)
与 (Nichol, 2021; Dhariwal & Nichol, 2021) 一样,表格输入 、时间步长 t 和类标签 y 的处理如下。公式(5)
其中 SinTimeEmb 指的是正弦时间嵌入,如 (Nichol, 2021; Dhariwal & Nichol, 2021) 中所示,维度为 128。公式 5 中的所有线性层都具有固定的投影维度 128。
TabDDPM 中的超参数至关重要,因为在实验中我们观察到它们对模型有效性有很大影响。表 1 列出了我们建议使用的主要超参数以及每个超参数的搜索空间。实验部分详细描述了调整过程。
简述上面的过程:
(1) 数据预处理
- 数值特征:对数值特征进行高斯分位数变换,使其接近标准正态分布。
- 离散特征:将离散特征转换为独热编码形式(one-hot encoding)。
- (2)正向扩散(Forward Diffusion)
- 数值特征:通过高斯扩散过程处理,即向数值特征添加高斯噪声,逐步扰动数据,使其接近高斯分布。
- 离散特征:每个离散特征独立处理,通过多项扩散过程添加噪声。每个类别特征的噪声分量是独立采样的。
(3) 构建输入数据:将处理的数值特征和one-hot编码的离散特征合并成输入向量,其维度为 。
(4)模型构建和训练:使用一个多层感知器(MLP)来建模反向扩散过程。该MLP的输出维度与 相同。
- 数值特征预测:MLP输出向量的前 个坐标用于预测高斯扩散过程中的噪声,从而恢复数值特征。
- 离散特征预测:MLP输出向量的剩余坐标用于预测多项扩散过程中的噪声,从而恢复离散特征。
(5)损失函数
- 数值特征:通过最小化均方误差(mean-squared error, )来训练模型。
- 类别特征:通过最小化每个多项扩散过程的KL散度( )来训练模型。所有类别特征的总损失除以类别特征的数量 𝐶 进行归一化。
总损失函数如下:
(6)处理时间步长和类别标签
- 使用一个固定维度为128的线性层来处理时间步长 𝑡 和类别标签 𝑦。
- 时间步长 t 经过正弦时间嵌入(SinTimeEmb)处理。
- 类别标签 𝑦 经过嵌入(Embedding)处理。
- 最终输入数据 𝑥 结合了初始输入 、时间步长嵌入 和类别标签嵌入 。
(7) 数据生成过程
分类任务:使用类条件模型,即学习 (; y)。
回归任务:将目标值作为一个额外的数值特征,并学习联合分布。
5.实验
在本节中,我们将广泛评估 TabDDPM 与现有替代方案。
数据集。为了系统地研究表格生成模型的性能,我们考虑了一组不同的 15 个现实世界的公共数据集。这些数据集具有不同的大小、性质、特征数量及其分布。大多数数据集以前用于表格模型评估(Zhao et al, 2021; Gorishniy et al, 2021)。表 2 中列出了数据集及其属性的完整列表。
基线。由于针对表格数据提出的生成模型数量巨大,我们仅针对生成建模每个范式中的领先方法来评估 TabDDPM。此外,我们仅考虑具有已发布源代码的基线。
- TVAE(Xu et al, 2019)——用于表格数据生成的最先进的变分自动编码器。据我们所知,没有其他类似 VAE 的模型能够胜过 TVAE 并拥有公开源代码。
- CTABGAN(Zhao 等人,2021 年)——一种基于 GAN 的最新模型,在一系列不同的基准测试中,其表现优于现有的表格 GAN。这种方法无法处理回归任务。
- CTABGAN+(Zhao 等人,2022 年)——CTABGAN 模型的扩展,该模型在最近的预印本中发布。我们不知道在 CTABGAN+ 之后提出的、具有公开源代码的基于 GAN 的表格数据模型。
- SMOTE(Chawla 等人,2002 年)——一种基于“浅”插值的方法,它将合成点“生成”为真实数据点与其来自数据集的第 k 个最近邻的凸组合。该方法最初是为小类过采样而提出的。在这里,我们对其进行概括并将其作为简单的健全性检查应用于合成数据生成。
评估指标。我们的主要评估指标是机器学习 (ML) 效率(或效用)(Xu et al, 2019)。更详细地说,ML 效率量化了在合成数据上训练并在真实测试集上评估的分类或回归模型的性能。直观地说,在高质量合成数据上训练的模型应该与在真实数据上训练的模型具有竞争力(甚至更胜一筹)。在我们的实验中,我们使用两种评估协议来计算 ML 效率。在文献中更常见的第一种协议中(Xu et al, 2019;Zhao et al, 2021;Kim et al, 2022),我们针对一组不同的 ML 模型(逻辑回归、决策树等)计算平均效率。在第二种协议中,我们仅针对 CatBoost 模型(Prokhorenkova et al, 2018)评估 ML 效率,该模型是领先的 GBDT 实现,在表格任务上提供最先进的性能。在第 5.2 小节的实验中,我们表明使用第二种协议至关重要,而第一种协议通常会产生误导。
调整过程。为了调整 TabDDPM 和基线的超参数,我们使用了 Optuna 库(Akiba 等人,2019 年)。调整过程由保留验证数据集上生成的合成数据的 ML 效率值(相对于 Catboost)指导(分数在五个不同的采样种子上取平均值)。 TabDDPM 所有超参数的搜索空间在表 1 中报告(对于基线 - 在附录 B 中)。此外,我们证明使用 CatBoost 指导调整超参数不会引入任何类型的“Catboost 偏差”,并且经过 Catboost 调整的 TabDDPM 产生的合成物对于其他模型(如 MLP)也更胜一筹。这些结果报告在附录 A 中。
5.1 定性比较
在这里,定性地研究了与 TVAE 和 CTABGAN+ 相比,TabDDPM 建模个体和联合特征分布的能力。具体来说,对于每个数据集,我们从 TabDDPM、TVAE 和 CTABGAN+ 中抽取一个合成数据集,其大小与特定数据集中的真实训练集相同。对于分类数据集,每个类都根据其在真实数据集中的比例进行抽样。然后,我们在图 2 中可视化真实数据和合成数据的典型个体特征分布。为了完整起见,我们展示了不同类型和分布的特征。在大多数情况下,与 TVAE 和 CTABGAN+ 相比,TabDDPM 产生的特征分布更真实。优势在
- 对于均匀分布的数值特征,TabDDPM生成的分布更为真实。这说明TabDDPM在处理和模拟连续数值特征时具有更好的表现。
- 对于高基数(类别数量多)的类别特征,TabDDPM表现出更强的能力,这意味着TabDDPM在处理复杂的类别数据时比其他模型更有效。
- 同时包含连续和离散分布的混合类型特征,TabDDPM能够生成更逼真的数据分布。这显示了TabDDPM在处理多样化数据类型时的灵活性和准确性。
然后,我们还可视化了针对不同数据集基于真实数据和合成数据计算的相关矩阵之间的差异,见图 3。为了计算相关矩阵,我们使用 Pearson 相关系数来表示数值-数值的相关性,使用相关比来表示分类-数值情况,并使用 Theil 的 U 统计量来表示分类特征。与 CTABGAN+ 和 TVAE 相比,TabDDPM 生成的合成数据集具有更真实的成对相关性。这些图示表明,与其他方案相比,我们的 TabDDPM 模型更加灵活,并能生成更出色的合成数据。
- Pearson 相关系数:Pearson相关系数用于衡量两个连续变量之间的线性相关程度。它的取值范围在-1到1之间,其中1表示完全正相关,-1表示完全负相关,0表示无相关性。
- 相关比 (Correlation Ratio):相关比用于衡量一个分类变量与一个连续变量之间的相关性。它衡量了不同类别之间的变量分布的差异性。相关比的取值范围是0到1之间,其中0表示没有相关性,1表示完全相关。
- Theil's U统计量:Theil's U统计量用于测量两个分类变量之间的关联程度。它可以看作是对熵的一种扩展,用于评估两个分类变量之间的非对称关联。Theil's U统计量的取值范围是0到1之间,其中0表示没有关联,1表示完全关联。
5.2 机器学习效率
在本节中,我们将 TabDDPM 与其他生成模型在机器学习效率方面进行比较。从每个生成模型中,我们按表 1 中的比例抽取一个大小与真实训练集相同的合成数据集。然后使用这些合成数据训练分类/回归模型,然后使用真实测试集对其进行评估。在我们的实验中,分类性能通过 F1 分数来评估,回归性能通过 R2 分数来评估。我们使用两种协议:
- 首先,我们计算一组多样化 ML 模型的平均 ML 效率,如之前的研究(Xu 等人,2019 年;Zhao 等人,2021 年;Kim 等人,2022 年)中所述。该集合包括来自 scikit-learn 库 (Pedregosa et al, 2011) 的决策树、随机森林、逻辑回归(或岭回归)和 MLP 模型,这些模型具有默认超参数,但决策树和随机森林的“最大深度”等于 28,逻辑回归和岭回归的“最大迭代次数”等于 500,MLP 的“最大迭代次数”等于 100。
- 其次,我们根据当前最先进的表格数据模型计算 ML 效率。具体来说,我们考虑使用 CatBoost(Prokhorenkova et al, 2018)和来自 (Gorishniy et al, 2021) 的 MLP 架构进行评估。使用来自 (Gorishniy et al, 2021) 的搜索空间在每个数据集上彻底调整 CatBoost 和 MLP 超参数。我们认为,该评估协议更可靠地证明了合成数据的实际价值,因为在大多数实际场景中,从业者对使用弱和次优分类器/回归器不感兴趣。
主要结果。表 3 和表 4 显示了两种协议计算出的 ML 效率值。附录 A 报告了调整后的 MLP 的 ML 效率。为了计算每个值,我们对合成生成的五个随机种子的结果取平均值,对于每个生成的数据集,我们对训练分类器/回归器的十个随机种子取平均值。关键观察结果如下:
- 在两种评估协议中,TabDDPM 在大多数数据集上的表现都明显优于 TVAE 和 CTABGAN+,这凸显了表格数据扩散模型的优势,并在先前的工作中证明了其他领域的优势。
- 基于插值的 SMOTE 方法表现出与 TabDDPM 相媲美的性能,并且通常明显优于 GAN/VAE 方法。有趣的是,大多数关于表格数据生成模型的先前研究都没有与 SMOTE 进行比较,而它似乎是一个简单的基线,很难被击败。
- 虽然许多先前的研究使用第一个评估协议来计算 ML 效率,但我们认为第二个协议(使用最先进的模型,如 CatBoost)更合适。表 3 和表 4 显示,第一个协议的分类/回归性能的绝对值要低得多,即在考虑的基准上,弱分类器/回归器明显不如 CatBoost。因此,很难使用这些次优模型代替 CatBoost,而且它们的性能值对从业者来说没有参考价值。此外,在第一个协议中,与对真实数据进行训练相比,对合成数据进行训练通常更有优势。这给人一种印象,即生成模型生成的数据比真实数据更有价值。然而,当使用经过调整的 ML 模型时,情况并非如此,就像在大多数实际场景中一样。附录 A 证实了对经过适当调整的 MLP 模型的这一观察结果。
总体而言,TabDDPM 提供了最先进的生成性能,可用作高质量合成数据的来源。有趣的是,就 ML 效率而言,简单的“浅层” SMOTE 方法与 TabDDPM 相媲美,这引发了一个问题,是否需要复杂的深度生成模型。在下面的部分中,我们对这个问题给出了肯定的答案。
5.3 隐私
在这里,我们证明 TabDDPM 在有隐私问题的设置中比 SMOTE 更可取,例如在不泄露个人或敏感信息的情况下共享数据。在这些设置中,人们对不会泄露原始真实数据集中的数据点的高质量合成感兴趣。为了量化合成的隐私,我们使用合成和真实数据点之间的中位数距离最近记录 (DCR) (Zhao et al, 2021)。具体而言,对于每个合成样本,我们找到到真实数据点的最小距离并取这些距离的中位数。低 DCR 值表示所有合成样本本质上都是一些真实数据点的副本,这违反了隐私要求。相反,较大的 DCR 值表示生成模型可以生成一些“新”的东西,而不仅仅是真实数据的副本。表 5 比较了 SMOTE 和 TabDDPM 的 DCR 值,并展示了 TabDDPM 对所有数据集的一致优势。我们还在图 4 上可视化了最小合成到真实距离的直方图。对于 SMOTE,大多数距离值集中在零附近,而 TabDDPM 样本与真实数据点的分离效果更好。该实验证实,TabDDPM 合成数据在提供高 ML 效率的同时,也更适合隐私保护场景。
6.总结
在本文中,我们研究了扩散建模框架在表格数据领域的应用前景。特别是,描述了 DDPM 的设计,它可以处理由数值、序数和分类特征组成的混合数据。还展示了模型超参数的重要性,并解释了它们的调整协议。对于最受关注的基准,模型生成的合成结果与基于 GAN/VAE 的竞争对手和插值技术生成的合成结果相比始终具有更高的质量,尤其是对于必须确保数据隐私的设置。