图神经网络实战(10)——归纳学习

图神经网络实战(10)——归纳学习

    • 0. 前言
    • 1. 转导学习与归纳学习
    • 2. 蛋白质相互作用数据集
    • 3. 构建 GraphSAGE 模型实现归纳学习
    • 小结
    • 系列链接

0. 前言

归纳学习 (Inductive learning) 通过基于已观测训练数据,建立一个通用模型,使模型能够对未见过的节点和图进行归纳预测,而转导学习(Transductive learning, 也称直推学习)是基于所有已经观测到的训练和测试数据构建模型,这种方法是通过已经有标记的节点信息来预测无标记数据节点,因此,在图神经网络 (Graph Neural Networks, GNN)、图卷积网络 (Graph Convolutional Network, GCN)、图注意力网络 (Graph Attention Networks,GAT) 和 GraphSAGE 等节中所构建的模型均属于转导学习模型。在本节中,我们将介绍图数据中的归纳学习和多标签分类,使用 GraphSAGE 模型在蛋白质相互作用 (protein-protein interactions) 数据集执行多标签分类任务,并了解归纳学习的优势和实现方法。

1. 转导学习与归纳学习

在图神经网络 (Graph Neural Networks, GNN)中,可以将学习分为两类,转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning):

  • 在归纳学习中,GNN 在训练过程中只看到训练集中的数据,而在测试过程中需要对未见过的数据进行预测,这属于机器学习中典型的监督学习 (supervised learning)。在这种情况下,标签用来调整 GNN 的参数,模型需要具备良好的泛化能力,能够从有限的样本中推断出普遍适用的规律
  • 在转导学习中,GNN 在训练过程中会看到来自训练集和测试集的数据,它通过对已有的样本进行学习来进行预测和分类。模型只从训练集中学习数据,模型会尝试将已有的样本归类到已知的类别中,并根据这些样本的特征进行预测,标签用于信息扩散。转导学习不是直接从训练集中学习出一般性的规律,而是利用图数据间的相似性或连接性进行预测

我们在之前构建的图神经网络 (Graph Neural Networks, GNN) 和图卷积网络 (Graph Convolutional Network, GCN) 属于转导学习情况。而 GraphSAGE 模型可以在训练过程中使用整个图进行预测 (self(batch.x, batch.edge_index)),然后部分屏蔽这些预测,只使用训练数据计算损失并训练模型 (criterion(out[batch.train_mask], batch.y[batch.train_mask]))。
转导学习只能为固定的图生成嵌入,不能泛化到未见过的节点或图。但由于采用了邻居采样,GraphSAGE 可以在局部水平上对经过剪枝的计算图进行预测,这种情况下属于归纳学习框架,可以应用于具有相同特征模式的任何计算图。

2. 蛋白质相互作用数据集

在 GraphSAGE 一节中,我们已经在 PubMed 数据集上构建 GraphSAGE 模型实现了转导学习。接下来,我们将 GraphSAGE 应用于由 Agrawal 等人提出的蛋白质相互作用 (protein-protein interaction, PPI) 网络数据集。该数据集是 24 个图的集合,其中节点( 21,557 个)是人类蛋白质,边( 342,353 条)是人类细胞中蛋白质之间的连接。用 Gephi 制作的 PPI 图数据集可视化结果如下所示:

PPI 数据集

该数据集的目标是使用 121 个标签进行多标签分类,这意味着每个节点可以具有 0121 个标签。这不同于多类别分类,多类别分类中每个节点只会属于一个类别。接下来,我们使用 PyTorch Geometric (PyG) 实现 GraphSAGE 模型用于对 PPI 数据集执行多标签分类任务。

(1)PPI 数据集加载为三个不同的子集,训练集、验证集和测试集:

import torch
from sklearn.metrics import f1_score

from torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE

# Load training, evaluation, and test sets
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')

(2) 训练集包含 20 个图,而验证集和测试集只有两个图。对训练集应用邻居采样,为了方便起见,使用 Batch.from_data_list() 将所有训练图统一到一个集合中,然后应用邻居采样:

train_data = Batch.from_data_list(train_dataset)
train_loader = NeighborLoader(train_data, batch_size=2048, shuffle=True, num_neighbors=[20, 10], num_workers=2, persistent_workers=True)

(3) 训练集创建完毕后,使用 DataLoader 类创建批数据,将 batch_size 值定义为 2,与每批图的数量相对应:

val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

(4) 定义设备使批处理能够在 GPU 上进行处理。如果计算机中有 GPU,使用 GPU,否则就使用 CPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3. 构建 GraphSAGE 模型实现归纳学习

使用 PyTorch Geometrictorch_geometric.nn 模块构建 GraphSAGE 模型。

(1) 使用 GraphSAGE() 类初始化一个两层的 GraphSAGE 模型,其中隐藏维度为 512,此外,还需要使用 to(device) 将模型放置在与数据相同的设备上:

model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=512,
    num_layers=2,
    out_channels=train_dataset.num_classes,
).to(device)

(2) fit() 函数与 GraphSAGE 一节中使用的函数类似,不同之处在于,我们希望尽可能将数据移动到 GPU 上,并且由于每批数据有两个图,因此将损失乘以 2 (data.num_graphs):

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def fit(loader):
    model.train()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        total_loss += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    return total_loss / len(loader.data)

由于 val_loadertest_loader 包含两个图且 batch_size 值为 2,因此在 test() 函数中,两个图位于同一个批数据中,而无需像训练时那样在加载器中循环。

(3) 使用度量指标 F1 分数代替准确率,F1 分数相当于精确度和召回率的调和平均值。但,模型的预测结果是 121 维的实数向量,我们需要将其转换成二进制向量,使用 out > 0 将它们与 data.y 进行比较:

@torch.no_grad()
def test(loader):
    model.eval()

    data = next(iter(loader))
    out = model(data.x.to(device), data.edge_index.to(device))
    preds = (out > 0).float().cpu()

    y, pred = data.y.numpy(), preds.numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

(4) 对模型进行 300epoch 的训练,并打印训练过程中模型在验证数据集上的 F1 分数:

for epoch in range(301):
    loss = fit(train_loader)
    val_f1 = test(val_loader)
    if epoch % 50 == 0:
        print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')
'''
Epoch   0 | Train Loss: 12.686 | Val F1-score: 0.4866
Epoch  50 | Train Loss: 8.734 | Val F1-score: 0.7963
Epoch 100 | Train Loss: 8.600 | Val F1-score: 0.8098
Epoch 150 | Train Loss: 8.531 | Val F1-score: 0.8202
Epoch 200 | Train Loss: 8.495 | Val F1-score: 0.8230
Epoch 250 | Train Loss: 8.497 | Val F1-score: 0.8255
Epoch 300 | Train Loss: 8.432 | Val F1-score: 0.8290
'''

(5) 最后,计算测试集上的 F1 分数:

print(f'Test F1-score: {test(test_loader):.4f}')

# Test F1-score: 0.8527

可以看到,在归纳学习中,模型在 PPI 数据集上训练后的 F1 分数为 0.9360。当增加或减少隐藏维度的大小时,模型的性能会有有大幅改变,我们可以使用不同的值,如 1281,024,并观察训练的后的模型 F1 分数变化。
需要注意的是,在以上代码中,我们并未显式的使用掩码。这是由于实际上,归纳学习是由 PPI 数据集强制实现的;训练数据、验证数据和测试数据位于不同的图和数据加载器中。我们也可以使用 Batch.from_data_list() 将它们合并,然后再使用归纳学习的设定。

小结

在本节中,学习了图神经网络中转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning) 的区别。其中,图神经网络中的归纳学习通常指的是从给定的训练图数据中学习出一个泛化能力强的模型,以便对未知图数据中的节点或边进行预测或分类,而转导学习通常指的是利用训练图数据和测试图数据之间的关联性进行推断,从而对给定的测试图数据进行预测或分类。并且构建了 GraphSAGE 模型在 PPI 数据集上测试了归纳学习,以执行多标签分类任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/619793.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Maven:Maven基础

Maven apache旗下的一个开源项目,一款用于管理和构建java项目的工具 什么是Maven 一个项目管理和构建工具,基于项目对象模型(POM)的概念,通过一小段描述信息来管理项目的构建,报告和文档. Maven的作用 依赖管理 方便快捷的管理项目依赖的资源jar包,避免版本冲突问题 统一…

C++的数据结构(四):队列

在数据结构中,队列(Queue)是一种特殊的线性表,只允许在表的前端(front)进行删除操作,而在表的后端(rear)进行插入操作。队列中没有元素时,称为空队列。队列的…

小程序的小组件

进度的组件 文字换行过滤 以及 排序 简单易懂 只为了记录工作 <template><div><ProgressBar :progress"progress" /><button click"increaseProgress">增加进度</button><view class"goods-name">12…

电脑锁屏快捷键是哪个?1分钟弄懂锁屏设置!

“当我暂时不需要使用电脑时&#xff0c;想给电脑设置锁屏&#xff0c;有朋友知道电脑锁屏快捷键是哪个吗&#xff1f;” 随着信息技术的飞速发展&#xff0c;我们在日常生活中经常需要使用电脑。然而&#xff0c;当我们暂时离开电脑时&#xff0c;如何确保电脑信息安全&#x…

【解决】Android APK文件安装时 已包含数字签名相同APP问题

引言 在开发Android程序过程中&#xff0c;编译好的APK文件&#xff0c;安装至Android手机时&#xff0c;有时会报 包含数字签名相同的APP 然后无法安装的问题&#xff0c;这可能是之前安装过同签名的APP&#xff0c;但是如果不知道哪个是&#xff0c;无法有效卸载&#xff0c;…

图文详解:synchronized关键字 及其底层原理

目录 一.线程安全问题 二.synchronized关键字 ▐ synchronized图解 ▐ 可重入锁及图解 ▐ synchronized用于方法上 三.Java标准库中synchronized的使用 四.synchronized的底层实现原理 一.线程安全问题 线程安全是指在多线程环境下&#xff0c;对共享资源的访问不会导致…

详解循环队列——链表与数组双版本

前言&#xff1a;本节内容主要是讲解循环队列。 在本篇中会讲到两个版本——数组版本、链表版本。本篇内容适合正在学习数据结构队列章节或者已经学过队列但对循环队列感觉模糊的友友们 。 首先先来看一下什么是循环队列 什么是循环队列 因为是刚开始讲解&#xff0c; 所以我们…

【基础绘图】 10.饼图

效果图&#xff1a; 主要步骤&#xff1a; 1. 数据准备&#xff1a;自己赋值的随机数 2. 图像绘制&#xff1a;绘制饼图 详细代码&#xff1a;着急的直接拖到最后有完整代码 步骤一&#xff1a;导入库包及图片存储路径并设置中文字体为宋体&#xff0c;西文为新罗马&#…

totoriseSVN 常见问题

1. SVN 无法 clean up 上传时没有关闭 Excel&#xff0c;导致传入了一些临时文件&#xff08;文件名以$开头&#xff09;&#xff0c;关闭文件后临时文件自动删除&#xff0c;导致 SVN 版本错乱&#xff0c;使用 CleanUp 功能无效 更新时提示【Previous operation has not fin…

win7 phpstudy 多站点无法保存hosts的原因

1、先找到hosts文件位置 C:\Windows\System32\drivers\etc hosts文件不是txt的后缀&#xff0c;它是一个系统文件 2、如果不显示需要查找隐藏文件 组织-》文件夹和搜索选项-》查看-》取消隐藏文件夹的的√ 3、文件无法编辑 属性不要勾选只读

【SAP-FICO】SAP-FICO生产订单-结算规则配置路径(OKO7)

需求&#xff1a; 作为一个ABAPer&#xff0c;有接到一个狗屁倒灶的配置需求&#xff0c;要求如下&#xff0c;给生产订单的结算规则显示出来 图1&#xff1a;找一个生产订单&#xff0c;显示其结算规则 CO03→菜单栏-表头→结算规则 图2&#xff1a;查看该生产订单&#xff0c…

SMB/RPC协议分析之-命名/匿名管道pipe

在前面的文章中&#xff0c;介绍了SMB协议共享相关的内容&#xff0c;详见我的专栏《网络攻防协议实战分析》&#xff0c;连接这里。在SMB协议中往往需要连接到对应的远程管道&#xff0c;如果你经常接触到SMB协议&#xff0c;相信你对于lsass&#xff0c;svcctl等多种命名管道…

数据结构-二叉树-AVL树(平衡二叉树)

红黑树是平衡二叉树的一个变种。 一、 产生平衡二叉树的原因。 二叉搜索树的问题在于极端场景下退化为类似链表的结构&#xff0c;所以搜索的时间复杂度就变成了O(N)。为了保证二叉树不退化为链表&#xff0c;我们必须保证二叉树的的平衡性。 二叉平衡搜索树就是解决上面的问…

职场新人小王的沟通挑战与成长

近日&#xff0c;职场新人小王遇到了一个沟通上的小难题。作为刚刚踏入社会的新鲜人&#xff0c;小王在工作会议上因为一次直接的反馈而无意间触动了同事的敏感神经&#xff0c;导致双方关系稍显紧张。 在一次团队会议上&#xff0c;小王被要求分享对项目进度的看法以及建议。他…

【图解计算机网络】TCP 重传、滑动窗口、流量控制、拥塞控制

TCP 重传、滑动窗口、流量控制、拥塞控制 TCP 重传超时重传快速重传 滑动窗口流量控制拥塞控制慢启动拥塞避免拥塞发生快速恢复 TCP 重传 TCP重传是当发送的报文发生丢失的时候&#xff0c;重新发送丢失报文的一种机制&#xff0c;它是保证TCP协议可靠性的一种机制。 TCP重传…

9. SVG中的text元素

SVG (Scalable Vector Graphics) 提供了强大的文本渲染能力&#xff0c;其中<text>元素是常用 的文本操作的元素。本文将详细介绍<text>标签的基本使用方法&#xff0c;并展示如何通过<tspan>和<textPath>增强文本的表现力。 <text>标签基础 &…

【PHP【实战项目】系统性教学】——使用最精简的代码完成用户的登录与退出

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

MyBatis——MyBatis 参数处理

一、单个简单类型参数 简单类型包括&#xff1a; byte short int long float double char Byte Short Integer Long Float Double Character String java.util.Date java.sql.Date parameterType 属性&#xff1a;告诉 MyBatis 参数的类型 MyBatis 自带类型自动推断机制…

【Linux】centos7安装软件(rpm、yum、编译安装),补充:查找命令的相关文件路径,yum安装mysql

【Linux】技术上&#xff0c;Linux是内核。而术语上&#xff0c;我们通常说的Linux是完整的操作系统&#xff0c;其实称为"Linux发行版"&#xff0c;是将Linux内核和应用系统打包&#xff0c;由不同的发行家族发行了不同版本。Linux发行版众多&#xff0c;主要有RedH…

Debian常用命令:高效管理与运维的必备指南

在Linux世界中&#xff0c;Debian以其稳定性、安全性和开源精神赢得了广大用户的青睐。作为一个基于Linux的操作系统&#xff0c;Debian拥有丰富且强大的命令行工具&#xff0c;这些命令对于系统管理员和开发者来说至关重要。本文将为您介绍一系列Debian系统中的常用命令&#…