机器学习第四十六周周报 FMP

文章目录

  • week46 FMP
  • 摘要
  • Abstract
    • 1. 题目
    • 2. Abstract
    • 3. FMP
      • 3.1 优化框架
      • 3.2 优化器
    • 4. 文献解读
      • 4.1 Introduction
      • 4.2 创新点
      • 4.3 实验过程
    • 5. 结论
    • 6.代码复现
      • 1. FMP
      • 2. fairGNN
      • 小结
      • 参考文献

week46 FMP

摘要

本周阅读了题为Chasing Fairness in Graphs: A GNN Architecture Perspective的论文。该文提出了一种基于图神经网络(GNN)统一优化框架指导的公平消息传递方案(FMP),旨在提高图数据处理的公平性。该方案通过两个核心步骤实现:首先聚合图数据,随后明确追求代表各统计学群体中心以减少偏见。该方法构建了一个同时考虑公平性和数据平滑性的优化问题,并运用Fenchel共轭和梯度下降技术结合softmax函数特性来高效求解,生成既公平又信息丰富的节点表示。此方案直接嵌入到GNN中,以改善节点分类任务的公平性和准确性,无需数据预处理。实验结果显示,在真实数据集上,FMP较多个基线模型表现更佳,从模型架构、效率及敏感属性利用等角度全面验证了其有效性。

Abstract

This week’s weekly newspaper decodes the paper entitled Chasing Fairness in Graphs: A GNN Architecture Perspective. This paper introduces a Fair Message Passing (FMP) scheme guided by a unified optimization framework of Graph Neural Networks (GNNs), aiming to enhance fairness in graph data processing. The FMP achieves this through two core steps: first, aggregating graph data, and then explicitly striving to represent the centers of various statistical groups to mitigate bias. This approach formulates an optimization problem that considers both fairness and data smoothness, and leverages Fenchel duality and gradient descent techniques, combined with the properties of the softmax function, to efficiently solve the problem and generate fair and informative node representations. This scheme is directly embedded into GNNs to improve the fairness and accuracy of node classification tasks, without the need for data preprocessing. Experimental results on real-world datasets show that FMP outperforms multiple baseline models, comprehensively validating its effectiveness from the perspectives of model architecture, efficiency, and the utilization of sensitive attributes.

1. 题目

标题:Chasing Fairness in Graphs: A GNN Architecture Perspective

作者:Zhimeng Jiang1, Xiaotian Han1, Chao Fan2, Zirui Liu3, Na Zou4, Ali Mostafavi1, Xia Hu3

发布:Vol. 38 No. 19: AAAI-24

链接:https://doi.org/10.1609/aaai.v38i19.30115

2. Abstract

该文旨在通过新的 GNN 框架实现更好的公平性,故提出在 GNN 的统一优化框架内设计的公平消息传递(FMP)。值得注意的是,FMP 使用交叉熵损失显式地呈现节点分类任务的前向传播中敏感属性的使用,而无需进行数据预处理。在FMP中,首先采用聚合来利用邻居的信息,然后偏差缓解步骤明确地将数据统计组节点表示中心推到一起。通过这种方式,FMP 方案可以聚合来自邻居的有用信息并减轻偏差,以实现更好的公平性和预测权衡性能。节点分类任务的实验表明,所提出的 FMP 在三个真实世界数据集上的公平性和准确性方面优于多个基线。

3. FMP

FMP 可以从模型主干的角度实现公平的预测。具体来说,将公平消息传递制定为一个优化问题,以同时追求平滑性和公平节点表示。结合有效且高效的优化算法,推导出封闭形式的公平消息传递。最后,所提出的 FMP 在三个阶段被集成到公平的 GNN 中,包括变换、聚合和去偏差步骤,如图 1 所示。这三个阶段分别采用节点特征、图拓扑和敏感属性。

image-20240707151714855

3.1 优化框架

min ⁡ F λ s 2 t r ( F T T ~ F ) + 1 2 ∣ ∣ F − X t r a n s ∣ ∣ F 2 + λ f ∣ ∣ Δ s S F ( F ) ∣ ∣ 1 (1) \min_{\mathbf F}\frac{\lambda_s}{2}tr(\mathbf F^T \tilde {\mathbf T}\mathbf F)+\frac12||\mathbf F-\mathbf X_{trans}||^2_F+\lambda_f||\Delta_sSF(\mathbf F)||_1 \tag{1} Fmin2λstr(FTT~F)+21∣∣FXtransF2+λf∣∣ΔsSF(F)1(1)

$\tilde L 表示归一化拉普拉斯矩阵, 表示归一化拉普拉斯矩阵, 表示归一化拉普拉斯矩阵,h_s(·)$ 和 h f ( ⋅ ) h_f(·) hf() 表示平滑性和公平性目标, X t r a n s ∈ R n × d o u t X_{trans} \in R^{n×d_{out}} XtransRn×dout 是变换后的 d o u t d_{out} dout维节点特征。 F ∈ R n × d o u t F \in R^{n×d_{out}} FRn×dout​ 是相同矩阵大小的聚合节点特征。

前两项保留了连接节点表示的相似性,从而增强了图的平滑性。最后一项强制执行公平的节点表示,以便不同敏感属性组之间的平均预测概率可以保持不变。

正则化系数 λs 和 λf 自适应地控制图平滑性和公平性之间的权衡。
h s ( F ) = min ⁡ F λ s 2 t r ( F T T ~ F ) + 1 2 ∣ ∣ F − X t r a n s ∣ ∣ F 2 h_s(\mathbf F)=\min_{\mathbf F}\frac{\lambda_s}{2}tr(\mathbf F^T \tilde {\mathbf T}\mathbf F)+\frac12||\mathbf F-\mathbf X_{trans}||^2_F hs(F)=Fmin2λstr(FTT~F)+21∣∣FXtransF2
平滑度目标 hs(·):现有图消息传递方案中的相邻矩阵被归一化,以提高数值稳定性并实现卓越的性能。从以边为中心的角度来看,平滑目标强制连接的节点表示相似,因为
t r ( F T T ~ F ) = ∑ ( v i , v j ) ∈ E ∣ ∣ F i d i + 1 − F i d j + 1 ∣ ∣ F 2 (2) tr(\mathbf F^T \tilde {\mathbf T}\mathbf F)=\sum_{(v_i,v_j)\in {\Epsilon}}||\frac{\mathbf F_i}{\sqrt {d_i+1}}-\frac{\mathbf F_i}{\sqrt{d_j+1}}||^2_F \tag{2} tr(FTT~F)=(vi,vj)E∣∣di+1 Fidj+1 FiF2(2)
公平目标 hf(·):公平性目标衡量聚合后节点表示的偏差。敏感属性事件向量 Δs 通过符号和绝对值求和表示敏感属性组和组大小。敏感属性事件向量为
Δ s = 1 > 0 ( s ) ∣ ∣ 1 > 0 ( s ) ∣ ∣ 1 − 1 > 0 ( − s ) ∣ ∣ 1 > 0 ( − s ) ∣ ∣ 1 (3) \Delta_s=\frac{\mathbf 1_{>0}(\mathbf s)}{||\mathbf 1_{>0}(\mathbf s)||_1}-\frac{\mathbf 1_{>0}(-\mathbf s)}{||\mathbf 1_{>0}(-\mathbf s)||_1} \tag{3} Δs=∣∣1>0(s)11>0(s)∣∣1>0(s)11>0(s)(3)
SF(F) 表示节点分类任务的预测概率,其中 S F ( F ) i j = P ^ ( y i = j ∣ X ) SF(\mathbf F)_{ij} = \hat P(y_i = j|\mathbf X) SF(F)ij=P^(yi=jX)。公平性目标, Δ s S F ( F ) Δ_sSF(F) ΔsSF(F)​ 的 l1 范数表征了具有不同敏感属性的两组之间的预测概率差异。因此,提出的优化框架可以同时追求图的平滑性和公平性。
( Δ s S F ( F ) ) j = P ^ ( y i = j ∣ s i = 1 , X ) − P ^ ( y i = j ∣ s i = − 1 , X ) (Δ_sSF(F))_j=\hat P(y_i=j|s_i=1,\mathbf X)-\hat P(y_i=j|\mathbf s_i=-1,\mathbf X) (ΔsSF(F))j=P^(yi=jsi=1,X)P^(yi=jsi=1,X)

3.2 优化器

双级优化问题公式化:Fenchel 共轭(又名凸共轭)(Rockafellar 2015)可以使用原始算法将原始问题转换为等效鞍点问题(Liu et al. 2021)。这样可以降低计算复杂度并兼容反向传播训练。类似地,为了以更有效和高效的方式解决优化问题1,引入Fenchel共轭(Rockafellar 2015)将原始问题转化为双层优化问题。对于一般凸函数h(·),其共轭函数定义为 h ∗ ( U ) ≜ s u p X ⟨ U , X ⟩ − h ( X ) h^*(U) \triangleq sup_X ⟨U,X⟩ − h(X) h(U)supXU,Xh(X)。基于 Fenchel 共轭,公平性目标可以转化为变分表示 h f ( p ) = s u p u ⟨ p , u ⟩ − h f ∗ ( u ) h_f(p) = sup_u ⟨p, u⟩ − h^∗_f(u) hf(p)=supup,uhf(u),其中 p = Δ s S F ( F ) ∈ R 1 × d o u t p = Δ_sSF(F) ∈ R^{1×d_{out}} p=ΔsSF(F)R1×dout​​ 是预测概率用于分类的向量。此外,原始优化问题等价于
min ⁡ F max ⁡ u h s ( F ) + ⟨ p , u ⟩ + h f ∗ ( u ) (4) \min_{\mathbf F}\max_{\mathbf u}h_s(\mathbf F)+⟨p,u⟩+h^*_f(u) \tag{4} Fminumaxhs(F)+p,u+hf(u)(4)
解决方法:最小-最大优化问题 (4) 可以通过以下定点方程求解,每次迭代计算复杂度较低和收敛保证
{ F = F − ∇ h s ( F ) − ∂ ⟨ p , u ⟩ ∂ F u = prox h f ∗ ( u + Δ s S F ( F ) ) (5) \begin{cases} \mathbf F=\mathbf F-\nabla h_s(\mathbf F)-\frac{\partial ⟨p,u⟩}{\partial \mathbf F}\\ \mathbf u=\text{prox}_{h^*_f}(\mathbf u+\Delta_s SF(\mathbf F)) \end{cases} \tag{5} {F=Fhs(F)Fp,uu=proxhf(u+ΔsSF(F))(5)
采用迭代算法来寻找最小-最大优化问题的鞍点。具体来说,从 (Fk, uk) 开始,对原始变量 F 采用梯度下降步骤到达 ( F ‾ k + 1 , u k ) ( \overline {\mathbf F}^{k+1}, \mathbf u^k) (Fk+1,uk),然后在对偶变量 u 中进行近端上升步骤。最后,对点 ( F ‾ k + 1 , u k ) (\overline {\mathbf F}^{k+1}, \mathbf u^k) (Fk+1,uk) 中的原始变量进行梯度下降以到达 ( F k + 1 , u k ) (\mathbf F^{k+1}, \mathbf u^k) (Fk+1,uk)​。简而言之,迭代可以概括为:

image-20240707164033812

其中 γ 和 β 是原始变量和对偶变量的步长。注意, ∂ ⟨ p , u ⟩ ∂ F ∈ R n × d o u t \frac{\partial ⟨p,u⟩}{ \partial F} ∈ R^{n×d_{out}} Fp,uRn×dout prox β h f ∗ ( ⋅ ) \text{prox}_{βh^∗_f} (·) proxβhf()​ 的闭式还不清楚,将提供解决方案。

FMP:令 γ = 1 1 + λ s , β = 1 2 γ \gamma=\frac1{1+\lambda_s},\beta=\frac1{2\gamma} γ=1+λs1,β=2γ1

image-20240707164631002

故可将提出的 FMP 总结为两个阶段,包括跳跃连接传播 step1 和偏差缓解 step2-5。step2更新聚合节点特征以实现公平性目标;step3&4旨在分别学习和“重塑”概率空间中的扰动向量。step5基于原始变量的梯度下降显式减轻节点特征的偏差。数学公式如下:

image-20240707164809372

梯度计算加速:采用softmax属性来加速梯度计算。

根据下面定理中的softmax函数的性质简化了梯度计算。

T h e o r e m   0.2 ( Gradient Computation ) \mathbf {Theorem\ 0.2}(\text{Gradient\ Computation}) Theorem 0.2(Gradient Computation)
T h e   g r a d i e n t   o v e r p r i m a l   v a r i a b l e   ∂ ⟨ p , u ⟩ ∂ F   s a t i s f i e s \\The\ gradient\ over primal\ variable\ \frac{\partial ⟨p,u⟩}{\partial \mathbf {F}}\ satisfies\\ The gradient overprimal variable Fp,u satisfies
image-20240707165220629

这样的梯度计算可以解释为三个步骤:Softmax变换、概率空间中的扰动以及表示空间中的去偏。具体来说,首先通过 softmax 变换将节点表示映射到概率空间。随后,计算概率空间中公平目标的梯度。可以看出,扰动 Δ s ⊤ u Δ^⊤_s u Δsu 实际上在概率空间中造成了低秩去偏,其中具有不同敏感属性的节点包含相反的扰动。换句话说,对偶变量u代表概率空间中的扰动方向。最后,概率空间中的扰动将通过雅可比变换 ∂ S F ( F ) ∂ F \frac{∂SF(F)}{∂F} FSF(F)​变换到表示空间。

FMP可以为敏感属性提供白盒形式的调用,因为可以直接识别出敏感属性的使用是在前向传播期间强制统计学上的组节点表示中心聚集在一起。

4. 文献解读

4.1 Introduction

该文通过设计一个由 GNN 的统一优化框架指导的公平消息传递方案来提供确定性。

实现公平消息传递的关键思想是

  • 首先聚合,
  • 然后通过明确追求一致的人口群体代表中心来减轻偏见。

具体来说,

  • 首先制定一个集成图数据的公平性和平滑性目标的优化问题。
  • 然后,通过 Fenchel 共轭和梯度下降来解决公式化的问题,以生成公平且信息丰富的表示,其中采用 softmax 函数的特性来加速原始变量的梯度计算。
    • 此外,将优化问题求解器解释为两个主要步骤。
  • 最后,将FMP集成到图神经网络中,以实现节点分类任务的公平和准确的预测。

4.2 创新点

该文的主要贡献有四个方面:

  1. 证明了精心设计的 GNN 架构可以提高图数据的公平性的概念验证。与专注于数据预处理和公平训练策略设计的传统方法相比,工作提供了全新的前景。
  2. 提出FMP,在统一优化框架的指导下,通过在消息传递中明确纳入敏感属性信息来实现公平性。此外,引入了一种基于softmax属性的加速方法来降低梯度计算复杂度。
  3. FMP 的有效性和效率在三个真实数据集上进行了实验评估。结果表明,与最先进的技术相比,FMP 在预测性能和公平性之间表现出相当或更好的权衡,而计算开销可以忽略不计。

4.3 实验过程

数据集:Pokecz 和 Pokec-n 是根据省份信息从斯洛伐克一个更大的类 Facebook 社交网络 Pokec(Takac 和 Zabovsky 2012)中采样的,其中区域信息被视为敏感属性,预测标签是该网络的工作领域。用户。 NBA 数据集扩展自 Kaggle 数据集 7,其中包含约 400 名 NBA 篮球运动员。球员信息包括年龄、国籍、2016-2017赛季工资等。玩家的链接关系来自Twitter,官方抓取API。采用二元国籍(美国和海外球员)作为敏感属性,预测标签是工资是否高于中位数。

评估标准 Δ D P = ∣ P ( y ^ = 1 ∣ s = − 1 ) − P ( y ^ = 1 ∣ s = 1 ) Δ_{DP} = |P(\hat y = 1|s = −1)−P(\hat y = 1|s = 1) ΔDP=P(y^=1∣s=1)P(y^=1∣s=1) Δ E O = ∣ P ( y ^ = 1 ∣ s = − 1 , y = 1 ) − P ( y ^ = 1 ∣ s = 1 , y = 1 ) ∣ Δ_{EO} = |P(\hat y = 1|s = −1, y = 1) − P(\hat y = 1|s = 1, y = 1)| ΔEO=P(y^=1∣s=1,y=1)P(y^=1∣s=1,y=1)

基线:将 FMP 与代表性的 GNN 进行比较,GCN、GAT、SGC、APPNP、JKNet和MLP。对于所有模型,训练 2 层具有 64 个隐藏单元的神经网络 300 个时期。此外,还比较了对抗性去偏差和添加统计学正则化方法,以显示所提出方法的有效性。

实施细节:运行实验 5 次并报告每种方法的平均性能。对所有模型采用 Adam 优化器,学习率为 0.001,权重衰减为 10−5。对于对抗性去偏差,采用训练分类器和对手分别具有 70 和 30 epoch。对手损失的超参数在 {0.0, 1.0, 2.0, 5.0, 8.0, 10.0, 20.0, 30.0} 中调整。为了添加正则化,采用超参数集 {0.0, 1.0, 2.0, 5.0, 8.0, 10.0, 20.0, 50.0, 80.0, 100.0}。

与现有 GNN 的比较:表 1 显示了针对 Pokec-z、Pokec-n、NBA 数据集提出的 FMP 与 MLP、GAT、GCN、SGC 和 APPNP 的准确性、人口统计平等性和平等机会指标。这三个数据集的详细统计信息如表3所示。

image-20240707170745596

image-20240707170845327

从这些结果中,可以得到以下观察结果:

  • 就公平性指标而言,许多现有的 GNN 在所有三个数据集上的 MLP 模型都表现不佳。例如,在 Pokec-z 数据集上,MLP 的人口统计奇偶性比 GAT、GCN、SGC 和 APPNP 低 32.64%、50.46%、66.53% 和 58.72%。较高的预测偏差来自于相同敏感属性节点内的聚合和图数据中的拓扑偏差。
  • FMP 在所有数据集的人口平等和机会均等方面始终实现最低的预测偏差。具体而言,与 Pokecz、Pokec-n 和 NBA 数据集中所有基线中的最低偏差相比,FMP 将人口统计均等性降低了 49.69%、56.86% 和 5.97%。同时,FMP 在 NBA 数据集中实现了最佳精度,在 Pokec-z 和 Pokec-n 数据集中达到了相当的精度。简而言之,所提出的 FMP 可以有效减轻预测偏差,同时保持预测性能。

与对抗性去偏和正则化的比较:随机划分 50%/25%/25% 用于训练、验证和测试数据集。图 2 显示了所有方法的帕累托最优曲线,其中右下角点代表理想性能(最高准确度和最低预测偏差)。

image-20240707171047802

根据结果,列出以下观察结果:

  • 与对抗性去偏并为许多 GNN 和 MLP 添加正则化相比,我们提出的 FMP 可以实现更好的 DP-Acc 权衡。这样的观察验证了 FMP 中关键思想的有效性:先聚合,然后去偏。此外,由于透明且高效的去偏差,FMP 可以以可忽略不计的性能成本来降低人口统计平等。
  • GNN 中的消息传递确实很重要。为了添加正则化或对抗性去偏差,不同的 GNN 存在巨大差异,这意味着适当的消息传递方式可能会带来更好的权衡性能。此外,许多 GNN 在低标签同质系数数据集(例如 NBA)中表现不如 MLP。基本原理是,当邻居具有相同标签的概率较低时,聚合可能并不总是能带来准确性方面的好处。

5. 结论

在这项工作中,从模型架构的角度提高了图的公平性。设计了一种公平的消息传递方案,以使用普通训练损失来实现节点分类的公平预测,而无需进行数据预处理。具体来说,受 GNN 统一优化框架的推动,FMP 被设计为首先聚合,然后减轻偏差,以明确追求平滑性和公平性目标。还从模型架构解释、效率以及敏感属性的白盒使用方面对FMP进行了全面的讨论。真实数据集上的实验结果证明了 FMP 与节点分类任务中的多个基线相比的有效性。

6.代码复现

tip:该部分代码需要使用numpy1.x

1. FMP

from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor

import torch
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from dgl.nn.pytorch import GraphConv
import torch_sparse
from torch_sparse import SparseTensor, matmul

def get_sen(sens, idx_sens_train):
    sens_zeros = torch.zeros_like(sens)
    # print(f'sens={sens}')
    sens_1 = sens 
    sens_0 = (1 - sens) 

    # print(f'idx_sens_train={idx_sens_train.shape}')
    # print(f'idx_sens_train={idx_sens_train.shape}')

    # print(f'sens_1={sens_1.shape}')

    sens_1[idx_sens_train] = sens_1[idx_sens_train] / len(sens_1[idx_sens_train])
    sens_0[idx_sens_train] = sens_0[idx_sens_train] / len(sens_0[idx_sens_train])

    # print(f'sens_1={sens_1.shape}')

    sens_zeros[idx_sens_train] = sens_1[idx_sens_train] - sens_0[idx_sens_train]

    sen_mat = torch.unsqueeze(sens_zeros, dim=0)
    # print(f'sen_mat={sen_mat[0, 0:10]}')
    # print(f'sen_mat={sen_mat[0, 10:20]}')

    return sen_mat

# def sen_norm(sen, edge_index):
#     ## edge_index: unnormalized adjacent matrix
#     ## normalize the sensitive matrix
#     edge_index = torch_sparse.fill_diag(edge_index, 1.0) ## add self loop to avoid 0 degree node
#     deg = torch_sparse.sum(edge_index, dim=1)
#     deg_inv_sqrt = deg.pow(-0.5)
#     sen = torch_sparse.mul(sen, deg_inv_sqrt.view(1, -1)) ## col-wise
#     return sen

def check_sen(edge_index, sen):
    nnz = edge_index.nnz()
    deg = torch.eye(edge_index.sizes()[0]).cuda()
    adj = edge_index.to_dense()
    lap = (sen.t() @ sen).to_dense()
    lap2 = deg - adj
    diff = torch.sum(torch.abs(lap2-lap)) / nnz
    assert diff < 0.000001, f'error: {diff} need to make sure L=B^TB'


class FMP(GraphConv):
    _cached_sen = Optional[SparseTensor]

    def __init__(self, 
                 in_feats: int,
                 out_feats: int,
                 K: int, 
                 lambda1: float = None,
                 lambda2: float = None,
                 L2: bool = True,
                 dropout: float = 0.,
                 cached: bool = False, 
                 **kwargs):

        super(FMP, self).__init__(in_feats, out_feats)
        self.K = K
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.L2 = L2
        self.dropout = dropout
        self.cached = cached
        self._cached_sen = None

        self.propa = GraphConv(in_feats, in_feats, weight=False, bias=False, activation=None)

    def reset_parameters(self):
        self._cached_sen = None
       def forward(self, x: Tensor,   
                g,   
                idx_sens_train,  
                edge_weight: OptTensor = None,   
                sens=None) -> Tensor:  
        """  
        前向传播函数  
  
        Args:  
            x (Tensor): 节点特征矩阵  
            g: 图结构信息  
            idx_sens_train (Tensor): 训练集中敏感属性的索引  
            edge_weight (OptTensor, optional): 边权重,默认为None  
            sens (Tensor, optional): 敏感属性矩阵,默认为None  
  
        Returns:  
            Tensor: 处理后的节点特征矩阵  
        """  
        if self.K <= 0:  
            return x  # 如果传播次数为0或负数,则直接返回输入  
  
        cache = self._cached_sen  
        if cache is None:  
            # 如果未缓存敏感矩阵,则计算并可能缓存  
            sen_mat = get_sen(sens=sens, idx_sens_train=idx_sens_train)  
            if self.cached:  
                self._cached_sen = sen_mat  
                # 初始化z,可能用于后续计算  
                self.init_z = torch.zeros((sen_mat.size(0), x.size(-1))).cuda()  
        else:  
            sen_mat = self._cached_sen  
  
        # 开始传播和敏感性处理  
        hh = x  # 初始化hh为x  
        x = self.emp_forward(g, x=x, hh=hh, K=self.K, sen=sen_mat)  
        return x  
    def emp_forward(self, g, x, hh, K, sen):  
        # 获取类的属性lambda1和lambda2,这些可能是正则化项的系数  
        lambda1 = self.lambda1  
        lambda2 = self.lambda2  
        # 计算gamma和beta,这些参数用于后续的加权和正则化  
        gamma = 1 / (1 + lambda2)  
        beta = 1 / (2 * gamma)  

        # 循环K次,可能是执行K步迭代或更新  
        for _ in range(K):  
            # 根据lambda2的值决定是否使用自定义传播函数(propa)还是直接使用x  
            if lambda2 > 0:  
                y = gamma * hh + (1-gamma) * self.propa(g, feat=x)  
            else:  
                y = gamma * hh + (1-gamma) * x  

            # 如果lambda1 > 0,执行更复杂的更新逻辑  
            if lambda1 > 0:  
                # 使用softmax和sen矩阵计算z  
                z = sen @ F.softmax(y, dim=1) / (gamma * sen @ sen.t())  
                # 通过z反推出x_bar0,再对x_bar0进行softmax得到x_bar1  
                x_bar0 = sen.t() @ z  
                x_bar1 = F.softmax(x_bar0, dim=1)  
                # 计算修正项correct,用于调整x_bar  
                correct = x_bar0 * x_bar1  
                coeff = torch.sum(x_bar0 * x_bar1, 1, keepdim=True)  
                correct = correct - coeff * x_bar1  
                x_bar = y - gamma * correct  
                # 更新z并考虑正则化  
                z_bar = z + beta * (sen @ F.softmax(x_bar, dim=1))  
                if self.L2:  
                    # 如果使用L2正则化,则调用L2_projection进行投影  
                    z = self.L2_projection(z_bar, lambda_=lambda1, beta=beta)  
                else:  
                    # 否则,使用L1正则化  
                    z = self.L1_projection(z_bar, lambda_=lambda1)  
                # 重新计算x_bar0, x_bar1, correct以考虑正则化后的z  
                x_bar0 = sen.t() @ z  
                x_bar1 = F.softmax(x_bar0, dim=1)  
                correct = x_bar0 * x_bar1  
                coeff = torch.sum(x_bar0 * x_bar1, 1, keepdim=True)  
                correct = correct - coeff * x_bar1  
                # 更新x  
                x = y - gamma * correct  
            else:  
                # 如果lambda1 <= 0,则不执行复杂的更新逻辑,直接令x=y  
                x = y  
            # 对x应用dropout,以减少过拟合  
            x = F.dropout(x, p=self.dropout, training=self.training)  
        # 返回更新后的x  
        return x
        def L1_projection(self, x: torch.Tensor, lambda_):  
        """  
        对输入张量x应用L1投影。  
        通过clamp函数将x的值限制在[-lambda_, lambda_]之间。  
          
        参数:  
            x (torch.Tensor): 输入张量。  
            lambda_ (float): L1投影的限制值。  
          
        返回:  
            torch.Tensor: 应用L1投影后的张量。  
        """  
        return torch.clamp(x, min=-lambda_, max=lambda_)  
  
    def L2_projection(self, x: torch.Tensor, lambda_, beta):  
        """  
        对输入张量x应用L2投影,通过缩放x来实现。  
          
        参数:  
            x (torch.Tensor): 输入张量。  
            lambda_ (float): L2正则化的系数。  
            beta (float): 另一个系数,可能与正则化强度有关。  
          
        返回:  
            torch.Tensor: 应用L2投影后的张量。  
        """  
        coeff = (2*lambda_) / (2*lambda_ + beta)  
        return coeff * x  
  
    def message(self, x_j: torch.Tensor, edge_weight: torch.Tensor) -> torch.Tensor:  
        """  
        根据边权重计算消息。  
        将节点特征x_j与边权重相乘,用于图神经网络中的消息传递步骤。  
          
        参数:  
            x_j (torch.Tensor): 目标节点的特征。  
            edge_weight (torch.Tensor): 边权重。  
          
        返回:  
            torch.Tensor: 乘以边权重后的节点特征,作为消息。  
        """  
        return edge_weight.view(-1, 1) * x_j  
  
    def message_and_aggregate(self, adj_t: SparseTensor, x: torch.Tensor) -> torch.Tensor:  
        """  
        在图上进行消息传递和聚合。  
        使用稀疏矩阵乘法将邻接矩阵adj_t与节点特征x相乘,并根据聚合方式聚合消息。  
          
        参数:  
            adj_t (SparseTensor): 图的邻接矩阵的稀疏表示。  
            x (torch.Tensor): 节点特征张量。  
          
        返回:  
            torch.Tensor: 聚合后的节点特征张量。  
        """  
        return matmul(adj_t, x, reduce=self.aggr)  
  
    def __repr__(self):  
        """  
        返回类的字符串表示,包括类的名称和关键属性。  
          
        返回:  
            str: 类的字符串表示。  
        """  
        return '{}(K={}, lambda1={}, lambda2={}, L2={})'.format(  
            self.__class__.__name__, self.K, self.lambda1, self.lambda2, self.L2)

2. fairGNN

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear

from fmp import FMP

class FairGNN(nn.Module):  
    """  
    FairGNN类,一个结合了公平性约束的图神经网络模型。  
  
    Args:  
        input_size (int): 输入特征的维度。  
        size (int): 隐藏层的维度。  
        num_classes (int): 输出的类别数。  
        num_layer (int): 图神经网络中隐藏层的数量(不包括输入和输出层)。  
        prop (nn.Module): 自定义的传播层,用于实现公平性相关的图传播逻辑。  
        **kwargs: 其他可能传递给父类的关键字参数。  
    """  
    def __init__(self, input_size, size, num_classes, num_layer, prop, **kwargs):  
        super(FairGNN, self).__init__()  
  
        # 初始化隐藏层列表  
        self.hidden = nn.ModuleList()  
        for _ in range(num_layer-2):  # 减去2是因为已经手动添加了输入层和输出层  
            self.hidden.append(nn.Linear(size, size))  
  
        # 输入层  
        self.first = nn.Linear(input_size, size)  
        # 输出层  
        self.last = nn.Linear(size, num_classes)  
  
        # 自定义的传播层  
        self.prop = prop  
  
    def reset_parameters(self):  
        """  
        重置模型参数。注意:原始代码中的lin1和lin2在FairGNN类中未定义,这里假设重置所有层。  
        """  
        self.first.reset_parameters()  
        for layer in self.hidden:  
            layer.reset_parameters()  
        self.last.reset_parameters()  
        # 如果prop层有reset_parameters方法,则也应该调用  
        if hasattr(self.prop, 'reset_parameters'):  
            self.prop.reset_parameters()  
  
    def forward(self, features, g, sens, idx_sens_train):  
        """  
        前向传播函数。  
  
        Args:  
            features (Tensor): 节点特征矩阵。  
            g (Graph): 图结构,可能包含边索引和其他图信息。  
            sens (Tensor): 敏感属性矩阵。  
            idx_sens_train (Tensor): 训练集敏感属性的索引。  
  
        Returns:  
            Tensor: 模型输出的类别预测。  
        """  
        x = features  
  
        # 通过第一层  
        out = F.relu(self.first(x))  
  
        # 通过隐藏层  
        for layer in self.hidden:  
            out = F.relu(layer(out))  
  
        # 通过输出层  
        x = self.last(out)  
  
        # 应用自定义的传播层,考虑公平性  
        x = self.prop(x, sens=sens, g=g, idx_sens_train=idx_sens_train)  
  
        # 返回原始输出,如果需要,可以在外部调用F.log_softmax  
        return x  
  
  
def get_model(args, data):  
    """  
    根据给定的参数和数据构建FairGNN模型。  
  
    Args:  
        args (Namespace): 包含模型配置参数的命名空间。  
        data (Data): 包含图数据和特征等的数据对象。  
  
    Returns:  
        nn.Module: 构建的FairGNN模型。  
    """  
    Model = FairGNN  
  
    # 创建自定义传播层FMP  
    prop = FMP(in_feats=data.num_features,  
               out_feats=data.num_features,  # 注意:这里out_feats通常为隐藏层大小,这里可能是一个示例  
               K=args.num_layers,  # 注意:这里的K可能与FairGNN中的num_layer不一致,取决于FMP的实现  
               lambda1=args.lambda1,  
               lambda2=args.lambda2,  
               L2=args.L2,  
               cached=True)  
  
    # 构建FairGNN模型  
    model = Model(input_size=data.num_features,   
                  size=args.num_hidden,   
                  num_classes=data.num_classes,   
                  num_layer=args.num_gnn_layer,   
                  prop=prop).cuda()  
  
    return model

小结

该文提出了一种基于图神经网络(GNN)统一优化框架指导的公平消息传递方案(FMP),旨在提高图数据处理的公平性。该方案通过两个核心步骤实现:首先聚合图数据,随后明确追求代表各统计学群体中心以减少偏见。该方法构建了一个同时考虑公平性和数据平滑性的优化问题,并运用Fenchel共轭和梯度下降技术结合softmax函数特性来高效求解,生成既公平又信息丰富的节点表示。此方案直接嵌入到GNN中,以改善节点分类任务的公平性和准确性,无需数据预处理。实验结果显示,在真实数据集上,FMP较多个基线模型表现更佳,从模型架构、效率及敏感属性利用等角度全面验证了其有效性。

未来的阅读计划:

  1. Hypergraph-enhanced Dual Semi-supervised Graph Classification:ICML2024,HEAL:用于半监督图分类的超图增强对偶框架
  2. PGODE: Towards High-quality System Dynamics Modeling:ICML2024,PGODE:迈向高质量的系统动力学建模

参考文献

[1] Jiang, Z., Han, X., Fan, C., Liu, Z., Zou, N., Mostafavi, A., & Hu, X. (2024). Chasing Fairness in Graphs: A GNN Architecture Perspective. Proceedings of the AAAI Conference on Artificial Intelligence, 38(19), 21214-21222. https://doi.org/10.1609/aaai.v38i19.30115

[2] Liu, X.; Jin, W.; Ma, Y.; Li, Y.; Liu, H.; Wang, Y.; Yan, M.; and Tang, J. 2021. Elastic graph neural networks. In International Conference on Machine Learning, 6837–6849. PMLR.

[3] Rockafellar, R. T. 2015. Convex analysis. Princeton university press.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/780832.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

初识java—jdk17的一些新增特性

文章目录 前言一 &#xff1a; yield关键字二 &#xff1a;var关键字三 &#xff1a;密封类四 &#xff1a;空指针异常&#xff1a;五&#xff1a;接口中的私有方法&#xff1a;六&#xff1a;instanceof关键字 前言 这里介绍jdk17相对于jdk1.8的部分新增特性。 一 &#xff…

Spring Boot的无缝衔接:深入解析与实践

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ &#x1f680;The begin&#x1f697;点点关注&#xff0c;收藏不迷路&#x1f6a9; 引言 在快速迭代的软件开发环境中&#xff0c;无缝衔接是提升开发效率、降低维护成本、增强系统稳定性的关键。Spring Boo…

STM32芯片系列与产品后缀解读

一. 产品系列 STM32单片机是一系列基于ARM Cortex-M内核的32位微控制器&#xff0c;广泛应用于嵌入式系统中。 STM32系列由STMicroelectronics&#xff08;意法半导体&#xff09;开发和生产&#xff0c;并凭借其灵活的设计、丰富的外设和强大的生态系统&#xff0c;成为嵌入式…

LLM - 卷积神经网络(CNN)

1. 卷积神经网络结构&#xff1a;分为输入层&#xff0c;卷积层&#xff0c;池化层&#xff0c;全连接层&#xff1b; &#xff08;1&#xff09;首先进入输入层&#xff0c;对数据数据进行处理&#xff0c;将输入数据向量化处理&#xff0c;最终形成输入矩阵。 &#xff08;…

C++ 什么是虚函数?什么是纯虚函数,以及区别?(通俗易懂)

&#x1f4da; 当谈到虚函数时&#xff0c;通常是指在面向对象编程中的一种机制&#xff0c;它允许在派生类中重写基类的函数&#xff0c;并且能够通过基类指针或引用调用派生类中的函数。 目录 前言 &#x1f525; 虚函数 &#x1f525; 纯虚函数 &#x1f525; 两者区别…

用 Echarts 画折线图

https://andi.cn/page/621503.html

leetcode每日一题-3033. 修改矩阵

题目描述&#xff1a; 解题思路&#xff1a;简单题目&#xff0c;思路非常直接。对列进行遍历&#xff0c;记录下最大值&#xff0c;然后再遍历一遍&#xff0c;把-1替换为最大值。需要注意的是进行列遍历和行遍历是不同的。 官方题解&#xff1a; class Solution { public:v…

VRay渲染有什么技巧?渲染100邀请码1a12

渲染是视觉行业非常重要的一环&#xff0c;没有渲染就没有效果图&#xff0c;常用的渲染器有Vray&#xff0c;而Vray渲染有很多技巧&#xff0c;可以让渲染更快更省&#xff0c;下面我们总结下。 1、删除无用对象 检查场景&#xff0c;看是否有一些不需要渲染的物体和灯光&am…

将大型语言模型模块化打造协作智能体

B UILDING C OOPERATIVE E MBODIED A GENTS MODULARLY WITH L ARGE L ANGUAGE M ODELS 论文链接&#xff1a; https://arxiv.org/abs/2307.02485https://arxiv.org/abs/2307.02485 1.概述 在去中心化控制及多任务环境中&#xff0c;多智能体合作问题因原始感官观察、高昂…

绝区肆--2024 年AI安全状况

前言 随着人工智能系统变得越来越强大和普及&#xff0c;与之相关的安全问题也越来越多。让我们来看看 2024 年人工智能安全的现状——评估威胁、分析漏洞、审查有前景的防御策略&#xff0c;并推测这一关键领域的未来可能如何。 主要的人工智能安全威胁 人工智能系统和应用程…

el-date-picker 设置默认值为当前日期

this.listQuery.Date new Date().toISOString().substr(0, 10); <el-date-picker v-model"listQuery.Date" format"yyyy-MM-dd" value-format"yyyy-MM-dd" type"date" placeholder"选择日期" change"getList()&qu…

Java语言程序设计篇一

Java语言概述 Java语言起源编程语言最新排名名字起源Java语言发展历程Java语言的特点Java虚拟机垃圾回收Java语言规范Java技术简介Java程序的结构Java程序注意事项&#xff1a;注释编程风格练习 Java语言起源 1990年Sun公司提出一项绿色计划。1992年语言开发成功最初取名为Oak…

Blender新手入门笔记收容所(一)

基础篇 基础操作 视角的控制 控制观察视角&#xff1a;鼠标中键平移视图&#xff1a;Shift鼠标中键缩放视图&#xff1a;滚动鼠标中键滚轮 选中物体后&#xff1a;移动物体快捷键G&#xff0c;移动后单击鼠标就会定下来。 进入移动状态后&#xff1a;按Y会沿着Y轴移动进入移动…

成人高考本科何时报名-深职训学校帮您规划学习之路

你有想过继续深造自己的学历吗&#xff1f;也许你已经工作多年&#xff0c;但总觉得学历是一块心病&#xff0c;想要通过成人高考本科来提升自己。不用着急&#xff0c;今天我们来聊一聊成人高考本科的报名时间&#xff0c;以及深职训学校如何帮助你顺利完成报名。 深圳成人高…

2024上半年网络工程师考试《应用技术》试题一

阅读以下说明&#xff0c;回答问题。 【说明】 MPLS基于(1)进行转发&#xff0c;进行MPLS标签交换和报文转发的网络设备称为(2)&#xff0c;构成MPLS域(MPSDomain)。位于MPLS域边缘、连接其他网络的LSR称为(3),区域内部的LSR称为核心LSR(CoreLSR)IP报文进入MPLS网络时&#xf…

文件管理下:文件函数的学习

前言 Hello,小伙伴们你们的作者君又来了&#xff0c;上次我们简单介绍了文件的坐拥并简单提到了数据的读取&#xff0c;和C语言的默认流的作用&#xff0c;今天我将继续带领大家探索文件的奥秘&#xff0c;大家准别好了吗&#xff1f; 在内容开始之前还是按照惯例&#xff0c…

Alt与Tab切换窗口时将Edge多个标签页作为一个整体参与切换的方法

本文介绍在Windows电脑中&#xff0c;使用Alt与Tab切换窗口时&#xff0c;将Edge浏览器作为一个整体参与切换&#xff0c;而不是其中若干个页面参与切换的方法。 最近&#xff0c;需要将主要使用的浏览器由原本的Chrome换为Edge&#xff1b;但是&#xff0c;在更换后发现&#…

Python爬虫系列-让爬虫自己写爬虫(半自动化,代替人工写爬虫)

现在的PC、手机客户端等终端设备大量使用了网页前后端技术&#xff0c;另外主流的网站也会经常会更新&#xff0c;导致以前一个月更新一次爬虫代码&#xff0c;变成了天天需要更新代码&#xff0c;所以自动化爬虫技术在当前就显得特别重要&#xff0c;最近我也是在多次更新某个…

Java | Leetcode Java题解之第220题存在重复元素III

题目&#xff1a; 题解&#xff1a; class Solution {public boolean containsNearbyAlmostDuplicate(int[] nums, int k, int t) {int n nums.length;Map<Long, Long> map new HashMap<Long, Long>();long w (long) t 1;for (int i 0; i < n; i) {long i…

【Java探索之旅】初识多态_概念_实现条件

文章目录 &#x1f4d1;前言一、多态1.1 概念1.2 多态的实现条件 &#x1f324;️全篇总结 &#x1f4d1;前言 多态作为面向对象编程中的重要概念&#xff0c;为我们提供了一种灵活而强大的编程方式。通过多态&#xff0c;同一种操作可以应用于不同的对象&#xff0c;并根据对象…