WWW24因果论文(3/8) |通过因果干预实现图分布外泛化

【摘要】由于图神经网络 (GNN) 通常会随着分布变化而出现性能下降,因此分布外 (OOD) 泛化在图学习中引起了越来越多的关注。挑战在于,图上的分布变化涉及节点之间错综复杂的互连,并且数据中通常不存在环境标签。在本文中,我们采用自下而上的数据生成视角,并通过因果分析揭示了一个关键观察结果:GNN 在 OOD 泛化中失败的关键在于来自环境的潜在混杂偏差。后者误导模型利用自我图特征与目标节点标签之间的环境敏感相关性,导致在新的未见节点上出现不良的泛化。基于这一分析,我们引入了一种概念上简单但原则性的方法,用于在节点级分布变化下训练稳健的 GNN,而无需事先了解环境标签。我们的方法采用了一种源自因果推理的新学习目标,该目标协调了环境估计器和专家混合 GNN 预测器。新方法可以抵消训练数据中的混杂偏差,并促进学习可推广的预测关系。大量实验表明,我们的模型可以有效地增强各种分布偏移下的泛化能力,并且在图 OOD 泛化基准上比最先进的方法提高高达 27.4% 的准确率。

原文:Graph Out-of-Distribution Generalization via Causal Intervention
地址:https://arxiv.org/abs/2402.11494
代码:https://github.com/fannie1208/CaNet
出版:www 24
机构: 上海交通大学

写的这么辛苦,麻烦关注微信公众号“码农的科研笔记”!

1 研究问题

本文研究的核心问题是: 如何设计一个图神经网络模型,使其能够在结点属性分布发生变化时,仍然保持良好的泛化性能。

假设一个社交网络中,用户的爱好与其朋友的年龄分布密切相关。在大学生群体中,朋友都比较年轻的用户往往更喜欢篮球运动。但是这种相关性可能只在大学生群体中成立,对于职场社交网络LinkedIn,用户的年龄与其爱好的相关性可能就很弱。如果我们基于大学生的社交网络训练了一个用于预测用户爱好的图神经网络,那么将其直接应用于LinkedIn,可能会遇到泛化失败的问题。

本文研究问题的特点和现有方法面临的挑战主要体现在以下几个方面:

  • 图数据中的分布变化往往涉及到结点之间的复杂交互与关联,需要模型能够充分考虑不同结点的结构化特征。

  • 在图学习问题中,每个结点所处的环境信息通常是隐含的,难以直接获取。这为模型从观测数据中推断有用的环境信息,以指导学习过程,带来了障碍。

针对这些挑战,本文提出了一种基于因果干预的"因果网络(CaNet)"方法:

CaNet巧妙地借鉴了因果推理中的 do-calculus 思想,通过显式地对环境变量建模,消除了隐含的混淆偏差。具体来说,它引入了一个环境估计器,负责基于输入的局部子图推断可能的环境信息。同时,图神经网络的每一层都配备了一组 mixture-of-expert 的传播单元,可以动态地根据推断的环境选择不同的传播方式。通过环境估计器和图神经网络的协同优化,CaNet 可以自动地发现观测数据中的稳定关系,同时避免捕获那些容易受环境变化影响的虚假相关性。这一设计理念犹如为图神经网络装上了一副"透视镜",让它对不可见的因果机制具备了感知和适应的能力。

2 研究方法

2.1 因果分析

论文首先从因果的角度来分析GNN面临分布外泛化问题的根本原因。如图2(a)所示,论文使用有向无环图建模了节点的ego-graph特征、节点标签和环境因素之间的因果依赖关系。可以看到,作为未观测的混淆因素,会影响和的生成过程。当使用最大似然估计来训练GNN时,由于忽略了的影响,模型会错误地学习到由某些特定的引起的和之间的强相关性(如在大学生群体中,"朋友年轻"和"喜欢打篮球"往往同时成立)。然而,这类相关性是不稳定的,一旦测试环境发生变化(如职场人士的社交网络),先前学到的相关性便不再成立。这导致了GNN在分布外数据上的泛化性能显著下降。

2.2 因果干预

为了消除环境因素的混淆偏差,进而提升GNN的分布外泛化性能,论文提出了一种基于因果干预的方法。借助后门调整公式,论文指出,优化干预分布而非观测分布,可以有效避免环境因素的混淆。然而,的求解需要穷举所有可能的环境,这在实际中是不可行的。为此,论文进一步引入变分推断,得到的一个变分下界,如式(5)所示:

其中,是根据节点的ego-graph特征来推断环境因素的估计器,是给定ego-graph特征和推断的环境标签来预测节点标签的GNN。通过最大化该变分下界,可以得到协同优化的算法:环境估计器尽可能准确地推断环境因素,同时要求推断结果与ego-graph特征保持独立;而GNN预测器则根据ego-graph特征和推断的环境因素来预测节点标签。通过协同学习,模型可以学习到与具体环境无关的稳定预测模式,进而提升分布外泛化性能。

2.3 模型实例化

在CaNet中,环境估计器将环境因素表示为一系列伪环境标签向量,其中表示GNN的层数。如式(6-7)所示(太难打了,公式见原文),对于每个节点的第层表示,环境估计器首先计算该节点属于每个伪环境的概率,然后通过Gumbel-Softmax技巧对重参数化,得到。

GNN预测器的核心是一个层的混合专家传播网络,每一层包含个专家分支。如式(8-9)所示,每个分支采用独立的参数,并由推断的伪环境标签进行选择。不同分支学习ego-graph特征的不同组合模式,赋予模型更强的表征能力。GNN预测器的最后一层输出节点表示,并通过全连接层映射为节点标签的预测值。

算法1总结了CaNet的前向计算和训练优化流程。其中,环境估计器和GNN预测器通过梯度下降法交替优化,协同学习对分布外泛化有利的预测模式。

5 实验

5.1 实验场景介绍

该论文提出了一个处理图神经网络节点级别分布差异的因果干预方法CaNet。实验主要在节点属性预测任务中,验证CaNet相比其他模型在训练集和测试集节点分布不一致情况下的泛化优势。同时通过消融实验、超参数分析等进一步探究模型内部机制。

5.2 实验设置

  • Datasets:使用Cora、Citeseer、Pubmed、Twitch、Arxiv、Elliptic等6个不同规模和属性的节点预测数据集,通过时间属性、子图、动态快照等不同方式构建训练集和测试集的分布差异

  • Baseline:ERM, IRM, DeepCoral, DANN, GroupDRO, Mixup等通用OOD方法;SR-GNN, EERM等图数据OOD方法;均使用GCN和GAT作为编码器骨干

  • Implementation details:基于PyTorch 1.13和PyG 2.1, Adam优化器,训练500轮,网格搜索超参数

  • metric:Accuracy, ROC-AUC, macro F1

5.3 实验结果

5.3.1 实验一、不同数据集上的性能对比

目的:在多个数据集上验证CaNet相比其他模型处理分布差异节点的优势

涉及图表:表1,表3,图4

实验细节概述:在Cora、Citeseer、Pubmed上测试合成特征和结构导致的分布差异;在Arxiv上测试不同时间的论文节点;在Twitch上测试不同子图的节点;在Elliptic上测试不同时间快照的节点

结果:

  • CaNet在所有OOD测试集上显著优于对应的基线,在ID测试集上也有竞争力的表现

  • 在Cora和Citeseer上,CaNet在OOD数据的绝对性能接近ID数据

  • 在Arxiv的跨时间差异最大的测试集上,CaNet超出次优baseline 14.1%和27.4%

  • 在Elliptic的动态图快照测试集上,CaNet平均超出次优baseline 12.16%

5.3.2 实验二、消融实验

目的:验证正则化损失、层级环境推断等关键组件的有效性

涉及图表:图5

实验细节概述:去除正则化损失、使用复杂先验分布、采用全局环境表示、使用非参数环境估计器等简化变体

结果:

  • 正则化损失和层级环境推断能有效提升OOD性能

  • 简单先验分布优于复杂先验,更利于泛化

5.3.3 实验三、超参数分析

目的:探究伪环境数K和温度τ对模型性能的影响

涉及图表:图6

实验细节概述:在Arxiv和Twitch上分别评估不同K和τ下模型在各OOD测试集的表现

结果:

  • 性能对K不太敏感,过大或过小的K在Arxiv的OOD 2/3上可能降低性能

  • 适中的τ(如1)效果最佳,过大的τ会导致性能下降

5.3.4 实验四、可视化分析

目的:直观展现不同分支学习到的权重模式差异

涉及图表:图7,8,9,10

实验细节概述:可视化K=3时模型在Arxiv和Twitch上第一层和最后一层不同分支的权重矩阵

结果:不同分支权重有明显差异,说明mixture-of-expert结构能学习到区分不同伪环境的表达模式,利于泛化

4 总结后记

本论文针对图神经网络在面对分布偏移时泛化能力较差的问题,从因果分析的角度揭示了其根源在于未观测到的环境混淆因素。基于此分析,提出了一种通过因果干预改进图神经网络泛化性的方法CaNet。该方法引入了环境估计器和混合专家传播网络,可以在没有先验环境标签的情况下,通过优化一个新的学习目标来捕获对环境不敏感的预测关系,从而提高模型的分布外泛化能力。实验结果表明,该模型在多个具有不同类型分布偏移的数据集上,相比现有方法可以显著提升泛化性能,泛化准确率提升高达27.4%。

疑惑和想法:

  1. 除了节点层面的分布偏移,该方法是否可以推广到处理图层面或子图层面的分布偏移?

  2. 环境估计器推断出的伪环境标签对应着什么物理含义?能否赋予它们可解释性?

  3. 除了混合专家传播网络,是否可以设计其他形式的条件传播机制来建模环境因素对节点表示的影响?

可借鉴的方法点:

  1. 从因果分析角度揭示模型泛化不足的根源,为诊断和改进其他类型的图学习任务提供了新的思路。

  2. 通过优化包含对抗正则化项的目标函数来消除混淆偏差,可以推广到其他需要增强鲁棒性的机器学习场景。

  3. 环境估计器和条件传播网络的思想可以借鉴到图预训练等其他图表示学习任务中,以建模不同图之间的差异。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/654008.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C#电子名片(vCard)

目录 1.介绍 2.基本定义 3.字段信息 4,字段详解。 4.1,预定义类型的用法 4.2,基本类型 4.3,通讯地址类型 4.4,电信通信类型 4.5,地理类型 4.6,解释类型 5,应用。 6&…

Rviz 复选框插件

Rviz 复选框插件 0.引言1.实现效果 0.引言 参考1参考2参考3参考4 我想做的插件是类似于 pangolin 侧面的复选框,动态传递 bool 值给程序内部使用。查了一下只能是通过插件的方式进行实现。但是Display 的参数在编译阶段就写死了,我想要在运行期给定参数…

DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion

UW&UCB&Google&NVIDIA ICCV23https://github.com/johannakarras/DreamPose?tabreadme-ov-file 问题引入 输入参考图片 x 0 x_0 x0​和pose序列 { p 1 , ⋯ , p N } \{p_1,\cdots,p_N\} {p1​,⋯,pN​},输出对应视频 { x 1 ′ , ⋯ , x N ′ } \{x_1,…

失落的方舟台服注册接收不到验证码 注册怎么验证手机号教程

《失落的方舟》(Lost Ark)是一款引人入胜的大型多人在线角色扮演游戏(MMORPG),由韩国知名游戏开发商Smilegate精心打造。这款游戏凭借其绚丽的视觉效果、错综复杂的故事情节、媲美动作游戏的战斗机制以及无边无际的探索…

【OCPP】ocpp1.6协议第3.13章节SmartCharging介绍及翻译

目录 3.13 SmartCharging智能充电-概述 智能充电的目标 关键功能 消息类型 负载管理 动态电量配置 总结 3.13 SmartCharging智能充电-译文 3.13.1 Charging Profile Purpose充电配置的目的 3.13.2 Stacking charging profile堆叠充电配置 3.13.3 Combining charging profile pu…

为什么要学习c++?

你可能在想,“C?那不是上个时代的产物吗?” 哎呀,可别小看了这位“老将”,它在21世纪的科技舞台上依旧光芒万丈,是许多尖端技术不可或缺的基石! 1. 无可替代 c源于c语言,它贴近于硬…

C++面向对象程序设计 - 标准输出流

在C中,标准输出流通常指的是与标准输出设备(通常是终端或控制台)相关联的流对象。这个流对象在C标准库中被定义为std::cout、std::err、std::clog,它们是std::ostream类的一个实例。 一、cout,cerr和clog流 ostream类…

去除uni微信小程序button的边框

想要去除button的边框,如下未去除边框时,非常影响观感。 解决方法 使用伪元素::after,简单但是易忘,正常情况下,我直接是给button上加上一个类名直接设置border:none,但是这样是无效的,应该如下…

【软考】下篇 第13章 层次式架构设计理论与实践

目录 一、概述1.2 通用分层1.2 分层架构注意点 二、表现层(展示层)2.1 表现层设计模式(MVC、MVP、MVVM)2.2 使用XML设计表现层2.3 UIP2.4 表现层动态生成设计思想 三、中间层(业务层、业务逻辑层)3.1 组件设计3.2 工作流设计3.3 实体设计3.4 业务逻辑层框…

CC1310 Debug interface is locked

XDS110连接CC1310板子,打开Smart RF 弹出窗口如下: 解决办法: 1 打开SmartRF Flash Programmer 2 选择连接的设备CC1310, 弹出如下窗口,点击OK。 3 点击Tools图标,选择CC26XX/CC13XX Forced Mass Erase。 4 在弹出的…

分布式锁的三种实现方案

目录 为什么需要分布式锁实现一个分布式锁需要考虑的点如何实现一个分布式锁基于数据库的实现基于 zookeeper 实现基于 redis 实现 实际项目中好用的轮子 为什么需要分布式锁 传统的 jdk 锁,比如 synchronized, reentrantlock 只能在单机环境中使用分布式场景下&am…

5.26 基于UDP的网络聊天室

需求&#xff1a; 如果有人发送消息&#xff0c;其他用户可以收到这个人的群聊信息 如果有人下线&#xff0c;其他用户可以收到这个人的下线信息 服务器可以发送系统信息实现模型 模型&#xff1a; 代码&#xff1a; //chatser.c -- 服务器端实现 #include <stdio.h>…

FastAPI - 组织模块2

FastAPI没有强制指定某种格式来组织项目结构&#xff0c;开发者可以根据自己喜好和项目需要来定制自己的项目结构。 https://fastapi.tiangolo.com/zh/tutorial/bigger-applications/ 在项目根目录创建python包routers&#xff0c;然后创建member.py文件 member.py文件内容 …

互联网应用主流框架整合之AOP

SpringAOP基本概念和约定流程 在实际的开发中有些内容并不是面向对象能够解决的&#xff0c;比如数据库事务&#xff0c;对于企业级应用非常重要&#xff0c;比如在电商系统中订单模块和交易模块&#xff0c;一边是交易记录一边是账户数据&#xff0c;就需要对这两者有统一的事…

权限维持--linux

隐藏文件/夹&-开头文件 如何创建: 在文件名之前加.即可 touch .1.s 如何清除、查找&#xff1a; ls -al rm -fr -文件 已-开头的文件直接读取是不行的需要带目录 隐藏时间戳 ①用其他文件的时间 touch -r zww.php testq.txt 如何清除、查看&#xff1a; stat test…

大模型基础知识

文章目录 1. 位置编码1.1 绝对位置编码1.2 相对位置编码1.3 旋转位置编码2. 注意力机制2.1 MHA(muti head attention)2.2 MQA(muti query attention)2.3 GQA(grouped query attention)3. 大模型分类4. 微调方法4.1 Prompt Tuning4.2 Prefix Tuning4.3 Lora4.4 QLora5. La…

探索机器人智能设备:开启智慧生活新篇章

机器人智能设备作为科技创新的代表&#xff0c;正以其独特的魅力吸引着越来越多的关注。它们不仅具备高度的智能化和自主化能力&#xff0c;还能在各种场景下发挥出强大的功能。 机器人智能设备的张总说&#xff1a;在智能家居领域&#xff0c;机器人智能设备可以帮助我们实现家…

改进rust代码的35种具体方法-类型(十九)-避免使用反射

上一篇文章 从其他语言来到Rust的程序员通常习惯于将反思作为工具箱中的工具。他们可能会浪费很多时间试图在Rust中实现基于反射的设计&#xff0c;却发现他们所尝试的事情只能做得不好&#xff0c;如果有的话。这个项目希望通过描述Rust在反思方面做什么和不做什么&#xff0c…

2024年【西式面点师(中级)】新版试题及西式面点师(中级)考试试卷

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年【西式面点师&#xff08;中级&#xff09;】新版试题及西式面点师&#xff08;中级&#xff09;考试试卷&#xff0c;包含西式面点师&#xff08;中级&#xff09;新版试题答案和解析及西式面点师&#xff08;…

手动安装maven依赖到本地仓库

使用mvn install命令安装jar包到指定的仓库。 命令如下&#xff1a; mvn install:install-file -Dmaven.repo.localC:\Users\liyong.m2\repository -DgroupIdcom.aspose -DartifactIdwords -Dversion18.4 -Dpackagingjar -DfileC:\Users\liyong\Desktop\jar\words-18.4.jar 解释…