【AAAI 2024】解锁深度表格学习(Deep Tabular Learning)的关键:算术特征交互

近日,阿里云人工智能平台PAI与浙江大学吴健、应豪超老师团队合作论文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在国际人工智能顶会AAAI-2024上发表。本项工作聚焦于深度表格学习中的一个核心问题:在处理结构化表格数据(tabular data)时,深度模型是否拥有有效的归纳偏差(inductive bias)。我们提出算术特征交互(arithmetic feature interaction)对深度表格学习是至关重要的假设,并通过创建合成数据集以及设计实现一种支持上述交互的AMFormer架构(一种修改的Transformer架构)来验证这一假设。实验结果表明,AMFormer在合成数据集表现出显著更优的细粒度表格数据建模、训练样本效率和泛化能力,并在真实数据的对比上超过一众基准方法,成为深度表格学习新的SOTA(state-of-the-art)模型。

背景

图1:结构化表格数据示例,引用自[Borisov et al.]
图1:结构化表格数据示例,引用自[Borisov et al.]

结构化表格数据——这些数据往往以表(Table)的形式存储于数据库或数仓中——作为一种在金融、市场营销、医学科学和推荐系统等多个领域广泛使用的重要数据格式,其分析一直是机器学习研究的热点。表格数据(图1)通常同时包含数值型(numerical)特征和类目型(categorical)特征,并往往伴随有特征缺失、噪声、类别不平衡(class imblanance)等数据质量问题,且缺少时序性、局部性等有效的先验归纳偏差,极大地带来了分析上的挑战。传统的树集成模型(如,XGBoost、LightGBM、CatBoost)因在处理数据质量问题上的鲁棒性,依然是工业界实际建模的主流选择,但其效果很大程度依赖于特征工程产出的原始特征质量。

随着深度学习的流行,研究者试图引入深度学习端到端建模,从而减少在处理表格数据时对特征工程的依赖。相关的研究工作至少可以可以分成四大类:(1)在传统建模方法中叠加深度学习模块(通常是多层感知机MLP),如Wide&Deep、DeepFMs;(2)形状函数(shape function)采用深度学习建模的广义加性模型(generalized additive model),如 NAM、NBM、SIAN;(3)树结构启发的深度模型,如NODE、Net-DNF;(4)基于Transformer架构的模型,如AutoInt、DCAP、FT-Transformer。尽管如此,深度学习在表格数据上相比树模型的提升并不显著且持续,其有效性仍然存在疑问,表格数据因此被视为深度学习尚未征服的最后堡垒。

算术特征交互在深度表格学习的“必要性”

我们认为现有的深度表格学习方法效果不尽如人意的关键症结在于没有找到有效的建模归纳偏差,并进一步提出算术特征交互对深度表格学习是至关重要的假设。本节介绍我们通过创建一个合成数据集,并对比引入算数特征交互前后的模型效果,来验证该假设。

合成数据集的构造方法如下:我们设计了一个包含八个特征(

图片

)的合成数据集。

图片

图2:合成数据集上的结果对比。图中+x%表示AMFormer相比Transformer的相对提升。
标图2:合成数据集上的结果对比。图中+x%表示AMFormer相比Transformer的相对提升。

在上述数据中,我们将引入了算数特征交互的AMFormer架构与经典的XGBoost和Transformer架构对比。实验结果显示:

以上结果共同证实了算术特征交互在深度表格学习中的显著意义。

算法架构

图片
图3:AMFormer架构,其中L表示模型层数。

本节介绍AMFormer架构(图3),并重点介绍算数特征交互的引入。AMFormer架构借鉴了经典的Transformer框架,并引入了Arithmetic Block来增强模型的算术特征交互能力。在AMFormer中,我们首先将原始特征转换为具有代表性的嵌入向量,对于数值特征,我们使用一个1输入d输出的线性层;对于类别特征,则使用一个d维的嵌入查询表。之后,这些初始嵌入通过L个顺序层进行处理,这些层增强了嵌入向量中的上下文和交互元素。每一层中的算术模块采用了并行的加法和乘法注意力机制,以刻意促进算术特征之间的交互。为了促进梯度流动和增强特征表示,我们保留了残差连接和前馈网络。最终,依据这些丰富的嵌入向量,AMFormer使用分类或回归头部生成最终输出。

算术模块的关键组件包括并行注意力机制和提示标记。为了补偿需要算术特征交互的特征,我们在AMFormer中配置了并行注意力机制,这些机制负责提取有意义的加法和乘法交互候选者。这些交互候选随着会沿着候选维度被串联(concatenate)起来,并通过一个下采样的线性层进行融合,使得AMFormer的每一层都能有效捕捉算术特征交互,即特征上的四则算法运算。为了防止由特征冗余引起的过拟合并提升模型在超大规模特征数据集上的伸缩,我们放弃了原始Transformer架构中平方复杂度的自注意力机制,而是使用两组提示向量(prompt token vectors)作为加法和乘法查询。这种方法为AMFormer提供了有限的特征交互自由度,并且作为一个附带效果,优化了内存占用和训练效率。

以上是AMFormer在架构层引入的主要创新,关于模型更详细的实现细节可以参考原文以及我们的开源实现。

进一步实验结果

图片
表1:真实数据集统计以及评估指标。标题

为了进一步展示AMFormer的效果,我们挑选了四个真实数据集进行实验。被挑选数据集覆盖了二分类、多分类以及回归任务,数据集统计如表1所示。

图片
标表2:AMFormer以及基准方法的性能对比,其中括号内的数字表示该方法在当前数据集上表现的排名,最优以及次优的结果分别以加粗以及下划线突出。题

我们一共测试了包含传统树模型(XGBoost)、树架构深度学习方法(NODE)、高阶特征交互(DCN-V2、DCAP)以及Transformer派生架构(AutoInt、FT-Trans)在内的六个基准算法以及两个AMFormer实现(分别选择AutoInt、FT-Trans做基础架构,即AMF-A和AMF-F),结果汇总在表2中。

在一系列对比实验中,AMFormer表现更突出。结果显示,基于MLP的深度学习方法如DCN-V2在表格数据上的性能不尽如人意,而基于Transformer架构的模型显示出更大的潜力,但未能始终超过树模型XGBoost。我们的AMFormer在四个不同的数据集上,与所有六个基准模型相比,表现一致更优:在分类任务中,它将AutoInt和FT-transformer的准确率或AUC提升至少0.5%,最高达到1.23%(EP)和4.96%(CO);在回归任务中,它也显著减少了平均平方误差。相比其它深度表格学习方法,AMFormer具有更好的鲁棒和稳定性,这使得在性能排序中AMFormer断层式优于其它基准算法,这些实验结果充分证明了AMFormer在深度表格学习中的必要性和优越性。

结论

本工作研究了深度模型在表格数据上的有效归纳偏置。我们提出,算术特征交互对于表格深度学习是必要的,并将这一理念融入Transformer架构中,创建了AMFormer。我们在合成数据和真实世界数据上验证了AMFormer的有效性。合成数据的结果展示了其在精细表格数据建模、训练数据效率以及泛化方面的优越能力。此外,对真实世界数据的广泛实验进一步确认了其一致的有效性。因此,我们相信AMFormer为深度表格学习设定了强有力的归纳偏置。

进一步阅读:

● 论文标题:

Arithmetic Feature Interaction is Necessary for Deep Tabular Learning

● 论文作者:

程奕、胡仁君、应豪超、施兴、吴健、林伟

● 论文PDF链接:

https://arxiv.org/abs/2402.02334

● 代码链接:

https://github.com/aigc-apps/AMFormer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/460079.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

html5的css使用display: flex进行div居中的坑!

最近做项目的时候,有个需求,一个高度宽度不确定的Div在另一个Div内上下左右居中。 然后以前上下居中用的都是很繁琐的,就打算去百度搜索一个更优秀的方法。 百度AI自己给我一个例子: /* div在容器里居中显示,设置外容…

单片机学到什么程度才可以去工作?

单片机学到什么程度才可以去工作? 如果没有名校或学位的加持,你还得再努力一把,才能从激烈的竞争中胜出。以下这些技能可以给你加分,你看情况学,不同行业对这些组件会有取舍: . Cortex-M内核:理解MCU内核各部件的工作机制&#…

windows的maven 低版本如何切换到高版本

要升级到 Maven 3.9.x 版本,可以按照以下步骤操作: 下载 Maven 3.9.x: 访问 Maven 的官方网站(https://maven.apache.org/download.cgi)并下载 Maven 3.9.x 版本的压缩包。选择与你的操作系统兼容的版本。 2. 解压缩 Maven 3.9.x…

一、MySQL基础学习

目录 1、MySQL启动2、MySQL客户端连接3、SQL3.1、SQL语句分类3.2、DDL(数据库定义语言)3.2.1、操作数据库3.2.2、操作数据表 3.3、DML(数据库操作语言)3.3.1、增加 insert into3.3.2、删除 delete3.3.3、修改 update 3.4、DQL&…

移动云COCA架构实现算力跃升,探索人工智能新未来

近期,随着OpenAI正式发布首款文生视频模型Sora,标志着人工智能大模型在视频生成领域有了重大飞跃。Sora模型不仅能够生成逼真的视频内容,还能够模拟物理世界中的物体运动与交互,其核心在于其能够处理和生成具有复杂动态与空间关系…

逆序对的数量 刷题笔记

思路 使用归并排序 在每次返回时 更新增加答案数 因为归并排序的两个特点 第一 使用双指针算法 第二 层层返回 从局部有序合并到整体有序 例如 {4 ,1 ,2 ,3} 划分到底层是四个数组 {4},{1},{3}, {…

【算法杂货铺】二分算法

目录 🌈前言🌈 📁 朴素二分查找 📂 朴素二分模板 📁 查找区间端点处 细节(重要) 📂 区间左端点处模板 📂 区间右端点处模板 📁 习题 1. 35. 搜索插入位…

『 Linux 』进程替换( Process replacement ) 及 简单Shell的实现(万字)

文章目录 🦄 进程替换🦩 execl()函数🦩 execlp()函数🦩 execle()函数🦩 execv()函数🦩 execvp()函数🦩 execvpe()函数🦩 execve()函数 🦄 简单Shell命令行解释器的实现&a…

Centos8安装Docker,使用阿里云源

一、前期准备 1.关闭防火墙,SELINUX systemctl stop firewalld.service systemctl disable firewalld.service setenforce 0 sed -i "s/SELINUXenforcing/SELINUXdisabled/g" /etc/selinux/config查看状态 systemctl status firewalld systemctl status…

汇编语言(Assemble Language)学习笔记(更新中)

零.学习介绍和使用工具 【1】我们使用的教材是机械工业出版社的《32位汇编语言程序设计第二版》。 指导老师是福州大学的倪一涛老师。 这门课程教授的是Intel 80*86系列处理器的32位汇编。我们现在的处理器都兼容这个处理器。 这篇博客只是大二下汇编语言学习的总结&#xff…

【C++设计模式】策略模式

文章目录 前言一、策略模式是什么?二、策略模式的实现原理三、UML图四、代码实现总结 前言 策略模式是一种行为设计模式,它允许在运行时选择算法的行为。通过将每个算法封装到具有共同接口的独立类中,客户端可以在不改变自身代码的情况下选择…

css3 实现html样式蛇形布局

文章目录 1. 实现效果2. 实现代码 1. 实现效果 2. 实现代码 <template><div class"body"><div class"title">CSS3实现蛇形布局</div><div class"list"><div class"item" v-for"(item, index) …

【C#】WPF 将string数据导出txt

示例 代码 string allInfo "123"; SaveFileDialog saveFileDialog new SaveFileDialog(); saveFileDialog.Filter "*.txt|*.txt|所有文件(*.*)|*.*"; if (!(bool)saveFileDialog.ShowDialog()) {return; } string fileName saveFileDialog.FileName; …

arcgis 计算某点到其他城市的距离,含要素转点(以北京市到各个地级市的距离为例)

导入地级市的地图导入要计算距离的地带你计算距离&#xff1a;点距离或者近邻分析 以北京市到各个地级市的距离为例 到入地级市&#xff0c;并复制一个同款地级市&#xff0c;导入一个有北京市的Excel 将一个地级市进行要素转点&#xff1a;data management tools——要素——…

【C语言】—— 指针二 : 初识指针(下)

【C语言】——函数栈帧 一、 c o n s t const const 修饰指针1.1、 c o n s t const const 修饰变量1.2、 c o n s t const const 修饰指针 二、野指针2.1野指针的成因&#xff08;1&#xff09;指针未初始化&#xff08;2&#xff09;指针越界访问&#xff08;3&#xff09;指…

4.9.CVAT——用长方体进行注释

文章目录 1.创建长方体1.1.按4点绘制长方体1.2.从长方形画出长方体 2.编辑长方体 它用于注释 3 维物体&#xff0c;例如汽车、盒子等。目前该功能支持单点透视&#xff0c;并具有垂直边缘与侧面完全平行的约束。 1.创建长方体 1.1.按4点绘制长方体 在开始之前&#xff0c;您必…

基于Java+SpringMVC+vue+element宠物管理系统设计实现

基于JavaSpringMVCvueelement宠物管理系统设计实现 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…

华为配置中心AP内漫游实验

华为配置中心AP内漫游示例 组网图形 图1 配置中心AP内漫游组网图 配置流程组网需求配置思路数据规划配置注意事项操作步骤配置文件 配置流程 WLAN不同的特性和功能需要在不同类型的模板下进行配置和维护&#xff0c;这些模板统称为WLAN模板&#xff0c;如域管理模板、射频模…

❤ css布局篇

❤ css布局篇 一、基础布局 &#xff08;1&#xff09;居中布局 ① 文字居中 <div class"div1">测试文字居中</div> body {margin: 0;padding: 0;padding: 10%; } .div1 {width: 100px;height: 100px;background: cadetblue;text-align: center; }te…

微信小程序--分享如何与ibeacon蓝牙信标建立联系

ibeacon蓝牙设备 iBeacon是苹果公司2013年9月发布的移动设备用OS&#xff08;iOS7&#xff09;上配备的新功能。其工作方式是&#xff0c;配备有 低功耗蓝牙&#xff08;BLE&#xff09;通信功能的设备使用BLE技术向周围发送自己特有的ID&#xff0c;接收到该ID的应用软件会根…