图神经网络实战(15)——SEAL链接预测算法

图神经网络实战(15)——SEAL链接预测算法

    • 0. 前言
    • 1. SEAL 框架
      • 1.1 基本原理
      • 1.2 算法流程
    • 2. 实现 SEAL 框架
      • 2.1 数据预处理
      • 2.2 模型构建与训练
    • 小结
    • 系列链接

0. 前言

我们已经学习了基于节点嵌入的链接预测算法,这种方法通过学习相关的节点嵌入来计算链接可能性。接下来,我们介绍另一类方法,通过查看目标节点周围的局部邻域执行链接预测任务,这类技术称为基于子图的算法,由 SEAL 广泛使用,可以说 SEAL 表示用于链接预测的子图、嵌入和属性 (Subgraphs, Embeddings, and Attributes for Link prediction) 的缩写,但并不完全准确。在本节中,我们将介绍 SEAL 框架,并使用 PyTorch Geometric 实现该框架。

1. SEAL 框架

1.1 基本原理

SEALZhangChen2018 年提出,是一个学习图结构特征以进行链接预测的框架。它将目标节点 ( x , y ) (x,y) (x,y) 和它们的 k 跳 (k-hop) 邻居所形成的子图定义为封闭子图 (enclosing subgraph)。每个封闭子图(而非整个图)都被用作预测链接可能性的输入。从另一个角度来看,SEAL 自动学习了一种用于链接预测的局部启发式方法。

1.2 算法流程

SEAL 框架包括三个步骤:

  1. 封闭子图提取 (Enclosing subgraph extraction),包括提取一组真实链接和一组虚假链接(负抽样)来形成训练数据
  2. 节点信息矩阵构建 (Node information matrix construction),包括节点标记、节点嵌入和节点特征三个部分
  3. 图神经网络 (Graph Neural Networks, GNN) 训练 (GNN training),将节点信息矩阵作为输入,并输出链接的可能性

这些步骤可以用下图进行总结:

SEAL 框架

封闭子图提取是一个简单的过程,列出目标节点及其 k 跳邻居,以提取它们的边和特征。k 值越大,SEAL 所能学习到的启发式算法的质量就越高,但同时也会创建更大、计算开销更大的子图。
节点信息构建的第一个部分是节点标记 (node labeling)。这一过程为每个节点分配一个特定的编号,如果没有进行标记,GNN 就无法区分目标节点和上下文节点(目标节点的邻居)。它还融合了距离,用来描述节点的相对位置和结构重要性。
在实践中,目标节点 x x x y y y 必须共享一个唯一的标签,以确定它们是目标节点。对于上下文节点 i i i j j j,如果它们与目标节点的距离相同,则必须共享相同的标签—— d ( i , x ) = d ( j , x ) d(i, x) = d(j, x) d(i,x)=d(j,x) d ( i , y ) = d ( j , y ) d(i, y) = d(j, y) d(i,y)=d(j,y)。我们称这种距离为双半径 (double radius),表示为 ( d ( i , x ) , d ( i , y ) ) (d(i, x), d(i, y)) (d(i,x),d(i,y))
SEAL 中使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 算法,其工作原理如下:

  1. 首先,将标签 1 分配给节点 x x x y y y
  2. 将标签 2 分配给半径为 (1,1) 的节点
  3. 将标签 3 分配给半径为 (1,2)(2,1) 的节点
  4. 将标签 4 分配给半径为 (1,3)(3,1) 的节点,以此类推

DRNL 函数的数学表达式如下:
f ( i ) = 1 + m i n ( d ( i , x ) , d ( i , y ) ) + ( d / 2 ) [ ( d / 2 ) + ( d % 2 ) − 1 ] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d\%2)-1] f(i)=1+min(d(i,x),d(i,y))+(d/2)[(d/2)+(d%2)1]
其中, d = d ( i , x ) + d ( i , y ) d= d(i, x) + d(i, y) d=d(i,x)+d(i,y) ( d / 2 ) (d/2) (d/2) ( d % 2 ) (d\%2) (d%2) 分别是 d d d 除以 2 2 2 的整数商和余数。最后,对这些节点标签进行独热编码 (one-hot encode)。
节点信息矩阵构建过程中的其它两个部分比较容易获得。节点嵌入是可选的,可以使用其他算法(如 Node2Vec )计算。然后,将它们与节点特征和独热编码标签连接起来,构建最终的节点信息矩阵 (node information matrix)。
最后,训练 GNN,利用封闭子图的信息和邻接矩阵来预测链接。为此,SEAL 使用了深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN),该架构执行以下三个步骤:

  1. 使用数个图卷积网络 (Graph Convolutional Network, GCN) 层计算节点嵌入,然后将其串联起来,类似于图同构网络 (Graph Isomorphism Network, GIN)
  2. 使用全局排序池化层按照一致的顺序排列这些嵌入,然后再将它们深入到卷积层,而卷积层不具备置换不变性
  3. 使用传统的卷积层和全连接层应用于排序后图表示,并输出链接概率

DGCNN 模型使用二进制交叉熵损失进行训练,输出的概率介于 01 之间

2. 实现 SEAL 框架

SEAL 框架需要进行大量的预处理,以提取并标注封闭子图。接下来,我们使用 PyTorch Geometric 来实现 SEAL 框架。

2.1 数据预处理

(1) 首先,导入所有必要的库:

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.sparse.csgraph import shortest_path

import torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLoss

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

(2) 加载 Cora 数据集,并应用链接级随机拆分:

transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True)
dataset = Planetoid('.', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]

(3) 链接级随机拆分会在数据对象中创建新字段,用于存储每条正样本边(真实的边)和负样本边(虚假的边)的标签和索引:

print(train_data)

# Data(x=[2708, 1433], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])

(4) 创建函数 seal_processing() 处理拆分后的数据集,并获得带有独热编码节点标签和节点特征的封闭子图,使用列表 data_list 存储这些子图:

def seal_processing(dataset, edge_label_index, y):
    data_list = []

对于数据集中的每一对节点(源和目的节点),提取 k 跳邻居(本节中 k = 2):

    for src, dst in edge_label_index.t().tolist():
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)
        src, dst = mapping.tolist()

使用双半径节点标记 (Double-Radius Node Labeling, DRNL) 函数计算距离。首先,从子图中删除目标节点:

        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]

根据上一个子图计算源节点和目标节点的邻接矩阵:

        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

计算每个节点与源节点/目标节点之间的距离:

        # Calculate the distance between every node and the source target node
        d_src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
        d_src = np.insert(d_src, dst, 0, axis=0)
        d_src = torch.from_numpy(d_src)

        # Calculate the distance between every node and the destination target node
        d_dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1)
        d_dst = np.insert(d_dst, src, 0, axis=0)
        d_dst = torch.from_numpy(d_dst)

计算子图中每个节点的节点标签 z z z

        dist = d_src + d_dst
        z = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)
        z[src], z[dst], z[torch.isnan(z)] = 1., 1., 0.
        z = z.to(torch.long)

在本节中,并未使用节点嵌入,但仍将特征和独热编码标签串联起来,以构建节点信息矩阵:

        node_labels = F.one_hot(z, num_classes=200).to(torch.float)
        node_emb = dataset.x[sub_nodes]
        node_x = torch.cat([node_emb, node_labels], dim=1)

创建一个 Data 对象并将其附加到列表 data_list 中,作为函数的最终输出:

        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)
        data_list.append(data)

    return data_list

(5) 调用 deal_processing 提取每个数据集的封闭子图。将正样本和负样本分开,以获得正确的预测标签:

train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)
train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)

val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)
val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)

test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)
test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)

(6) 合并正负数据列表,重建训练、验证和测试数据集:

train_dataset = train_pos_data_list + train_neg_data_list
val_dataset = val_pos_data_list + val_neg_data_list
test_dataset = test_pos_data_list + test_neg_data_list

(7) 创建数据加载器,使用批数据训练 GNN

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

2.2 模型构建与训练

(1) 定义 DGCNN 类,其中参数 k 表示每个子图的节点数:

class DGCNN(torch.nn.Module):
    def __init__(self, dim_in, k=30):
        super().__init__()

创建四个 GCN 层,设定隐藏维度为 32

        self.gcn1 = GCNConv(dim_in, 32)
        self.gcn2 = GCNConv(32, 32)
        self.gcn3 = GCNConv(32, 32)
        self.gcn4 = GCNConv(32, 1)

实例化全局排序池化层 (深度图卷积神经网络 (Deep Graph Convolutional Neural Network, DGCNN) 架构的核心):

        self.global_pool = aggr.SortAggregation(k=k)

全局排序池化层提供的节点排序使我们能够使用传统的卷积层:

        self.conv1 = Conv1d(1, 16, 97, 97)
        self.conv2 = Conv1d(16, 32, 5, 1)
        self.maxpool = MaxPool1d(2, 2)

最后,实例化多层感知机 (Multilayer Perceptron, MLP) 用于获取预测:

        self.linear1 = Linear(352, 128)
        self.dropout = Dropout(0.5)
        self.linear2 = Linear(128, 1)

forward() 方法中,计算每个 GCN 的节点嵌入,并将结果串联起来:

    def forward(self, x, edge_index, batch):
        # 1. Graph Convolutional Layers
        h1 = self.gcn1(x, edge_index).tanh()
        h2 = self.gcn2(h1, edge_index).tanh()
        h3 = self.gcn3(h2, edge_index).tanh()
        h4 = self.gcn4(h3, edge_index).tanh()
        h = torch.cat([h1, h2, h3, h4], dim=-1)

对串联结果依次应用全局排序池化、卷积层和全连接层:

        # 2. Global sort pooling
        h = self.global_pool(h, batch)

        # 3. Traditional convolutional and dense layers
        h = h.view(h.size(0), 1, h.size(-1))
        h = self.conv1(h).relu()
        h = self.maxpool(h)
        h = self.conv2(h).relu()
        h = h.view(h.size(0), -1)
        h = self.linear1(h).relu()
        h = self.dropout(h)
        h = self.linear2(h).sigmoid()

        return h

(2) 将模型实例化,并使用 Adam 优化器和二进制交叉熵损失对其进行训练:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DGCNN(train_dataset[0].num_features).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = BCEWithLogitsLoss()

(3) 创建 train() 函数用于批训练:

def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_dataset)

(4)test() 函数中,计算 ROC AUC 分数和平均精度,以比较 SEAL 和变分图自编码器 (Variational Graph Autoencoder, VGAE) 的性能:

@torch.no_grad()
def test(loader):
    model.eval()
    y_pred, y_true = [], []

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        y_pred.append(out.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    auc = roc_auc_score(torch.cat(y_true), torch.cat(y_pred))
    ap = average_precision_score(torch.cat(y_true), torch.cat(y_pred))

    return auc, ap

(5)DGCNN 进行 31epoch 的训练:

for epoch in range(31):
    loss = train()
    val_auc, val_ap = test(val_loader)
    print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')

模型训练过程监测

(6) 最后,在测试数据集上对其进行测试:

test_auc, test_ap = test(test_loader)
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')

# Test AUC: 0.7899 | Test AP 0.8174

可以看到,使用 SEAL 框架得到的结果与使用 VGAE 得到的结果(AUC0.8727AP0.8620) 相似。从理论上讲,基于子图的方法(如 SEAL )比基于节点的方法(如 VGAE )更具表达能力,基于子图的方法通过明确考虑目标节点周围的整个邻域来捕捉更多信息。通过 k 参数增加所考虑的邻域数量,可以进一步提高 SEAL 的准确性。

小结

链接预测是指利用图数据中已知的节点和边的信息,来推断图中未知的连接关系或者未来可能出现的连接关系,在机器学习和数据挖掘等领域具有广泛的应用。本节中介绍了用于链接预测的 SEAL 框架,其侧重于子图表示,每个链接周围的邻域作为预测链接概率的输入。并使用边级随机分割和负采样在 Cora 数据集上实现了 SEAL 模型执行链接预测任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/744902.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Open AI 前 Superalignment部门研究员Leopold Aschenbrenner的关于Superintelligence担忧的真挚长文

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

JavaWeb——MySQL:DDL

目录 3.DDL:查询 ​编辑3.4 分组查询(group by) 3.4.1 什么是分组查询 3.4.2 聚合函数 3.4.3 分组查询 3.4.5 总结 3.DDL:查询 查询是使用最多、最频繁的操作,因为前面的修改以及删除,一般会交给数据库…

spring原理篇

第三方bean默认为方法名 自动配置 自动配置的原理 springboot的自动配置原理 首先是从 SpringBootApplication这个注解出发 有一个ComponentScan()默认扫描同级包及其子包 第二个注解是springbootconfiguration 声明当前类是一个配置类 第三个是核心 enableAutoConfigurati…

【机器学习】在【R语言】中的应用:结合【PostgreSQL数据库】的【金融行业信用评分模型】构建

目录 1.数据库和数据集的选择 1.准备工作 2.PostgreSQL安装与配置 3.R和RStudio安装与配置 2.数据导入和预处理 1.连接数据库并导入数据 1.连接数据库 2.数据检查和清洗 1.数据标准化 2.拆分训练集和测试集 3.特征工程 1.生成新特征 2.特征选择 4.模型训练和评估…

嵌入式EMC之TVS管

整理一些网上摘抄的笔记: TVS管认识: TVS的Vc要比,DCDC的最大承受电压要小

Flink 反压

反压 Flink反压是一个在实时计算应用中常见的问题,特别是在流式计算场景中。以下是对Flink反压的详细解释: 一、反压释义 反压(backpressure)意味着数据管道中某个节点成为瓶颈,其处理速率跟不上上游发送数据的速率…

cJSON源码解析之add_item_to_object函数

文章目录 前言add_item_to_object函数是干什么的add_item_to_object代码解析函数实现函数原理解析开头的代码constant_key参数的作用最后的if判断 add_item_to_array函数 总结 前言 在我们的日常编程中,JSON已经成为了一种非常常见的数据交换格式。在C语言中&#…

[深度学习] 卷积神经网络CNN

卷积神经网络(Convolutional Neural Network, CNN)是一种专门用于处理数据具有类似网格结构的神经网络,最常用于图像数据处理。 一、CNN的详细过程: 1. 输入层 输入层接收原始数据,例如一张图像,它可以被…

Qt 实战(6)事件 | 6.1、事件机制

文章目录 一、事件1、基本概念2、事件描述3、事件循环4、事件分发4.1、QApplication::notify()4.2、QObject::event() 5、事件传递6、事件处理器 前言: Qt 框架中的事件机制(Event Mechanism)是一种核心功能,它允许应用程序以事件…

【Launcher3】解决谷歌桌面的小部件重启后消失问题

1-问题摘要 这次主要解决困扰了我很久的时钟消失问题,大概是去年10月刚开始做EDLA项目的时候,需要定制谷歌桌面,桌面布局大概要改成这样: 时间显示在谷歌搜索框的上方,而安卓原生桌面大概是这样子的 我们开发一开始是使用小部件…

微服务+云原生:打造高效、灵活的分布式系统

🐇明明跟你说过:个人主页 🏅个人专栏:《未来已来:云原生之旅》🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、云原生概述 2、微服务概述 二、微服务架构基础 1、…

二、反应式集成-spring

一、Spring WebFlux what 1、简介 - Spring WebFlux 包含一个用于执行 HTTP 请求的客户端。 有一个 基于 Reactor 的功能性、流畅的 API,请参阅 Reactive Libraries, 它支持异步逻辑的声明性组合,而无需处理 线程或并发。它是完全无阻塞的…

Talking Web

1. curl 1.1 http curl http://127.0.0.1:80 向目标主机端口发送http请求 1.2 httphead curl -H “Host: 18ed3df584cd48328b5839443aa7b42b” http://127.0.0.1:80 1.3 httppath curl http://127.0.0.1:80/853c64cd218f80d0a59665666fb2ab80 1.4 URL编码路径 &#xff0…

Python学习笔记21:进阶篇(十)常见标准库使用之math模块,random模块和statistics模块

前言 本文是根据python官方教程中标准库模块的介绍,自己查询资料并整理,编写代码示例做出的学习笔记。 根据模块知识,一次讲解单个或者多个模块的内容。 教程链接:https://docs.python.org/zh-cn/3/tutorial/index.html 数学 P…

微软结束将数据中心置于海底的实验

2016 年,微软 宣布了一项名为"纳蒂克项目"(Project Natick)的实验。基本而言,该项目旨在了解数据中心能否在海洋水下安装和运行。经过多次较小规模的测试运行后,该公司于 2018 年春季在苏格兰海岸外 117 英尺…

从0开始C++(八):多态的实现

相关文章: 从0开始C(一):从C到C 从0开始C(二):类、对象、封装 从0开始C(三):构造函数与析构函数详解 从0开始C(四):作…

React+TS前台项目实战(十九)-- 全局Input组件封装:加载状态和清除功能的实现

文章目录 前言Input组件1. 功能分析2. 代码详细注释3. 使用方式4. 效果展示 总结 前言 今天我们来封装一个input输入框组件,并提供一些常用的功能,你可以选择不同的 尺寸、添加前缀、显示加载状态、触发回调函数、自定义样式 等等。这些功能在这个项目中…

vite+vue3+ts项目搭建流程 (pnpm, eslint, prettier, stylint, husky,commitlint )

vitevue3ts项目搭建 项目搭建项目目录结构 项目配置自动打开项目eslint①vue3环境代码校验插件②修改.eslintrc.cjs配置文件③.eslintignore忽略文件④运行脚本 prettier①安装依赖包②.prettierrc添加规则③.prettierignore忽略文件④运行脚本 stylint①.stylelintrc.cjs配置文…

【云原生】Kubernetes网络知识

Kubernetes网络管理 文章目录 Kubernetes网络管理一、案例概述二、案例前置知识点2.1、Kubernetes网络模型2.2、Docker网络基础2.3、Kubernetes网络通信2.3.1、Pod内容器与内容之间的通信2.3.2、Pod与Pod之间的通信 2.4、Flannel网络插件2.5、Calico网络插件2.5.1、Calico网络模…

免费下载电子书的网站

在如今的数字化时代,电子书已成为许多人书籍阅读的首选。下面小编就和大家分享一些提供免费查找下载电子书服务的网站,这些网站不仅资源丰富,而且操作简便。 免费下载电子书的网站:https://www.bgrdh.com/favorites/1355.html 1…