动手学图神经网络(2):跆拳道俱乐部案例实战

动手学图神经网络(2):跆拳道俱乐部案例实战

在深度学习领域,图神经网络(GNNs)能将传统深度学习概念推广到不规则的图结构数据,使神经网络能够处理对象及其关系。将基于 PyTorch Geometric 库,一步步探索图神经网络的奥秘。

安装必要的包

首先, 安装所需的 Python 包。在开始之前, 需要获取当前使用的 PyTorch 版本,

import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

接下来, 可以使用以下命令安装必要的库:

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

同时, 还需要一些用于可视化的辅助函数:

%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()

def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()

图神经网络基础介绍

图神经网络(GNNs)旨在将经典深度学习概念推广到不规则结构数据(与图像或文本不同),使神经网络能够推理对象及其关系。遵循简单的神经消息传递方案,在图 G = ( V , E ) \mathcal{G} = (\mathcal{V}, \mathcal{E}) G=(V,E) 中,所有节点 v ∈ V v \in \mathcal{V} vV 的节点特征 x v ( ℓ ) \mathbf{x}_v^{(\ell)} xv() 通过聚合其邻居 N ( v ) \mathcal{N}(v) N(v) 的局部信息来迭代更新:
x v ( ℓ + 1 ) = f θ ( ℓ + 1 ) ( x v ( ℓ ) , { x w ( ℓ ) : w ∈ N ( v ) } ) \mathbf{x}_v^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_v^{(\ell)}, \left\{ \mathbf{x}_w^{(\ell)} : w \in \mathcal{N}(v) \right\} \right) xv(+1)=fθ(+1)(xv(),{xw():wN(v)})

本教程将基于 PyTorch Geometric (PyG) 库 介绍图神经网络的一些基本概念。PyTorch Geometric 是流行深度学习框架 PyTorch 的扩展库,包含各种方法和实用工具,便于实现图神经网络。

将以著名的 Zachary’s karate club network 为例,深入了解图神经网络。这个图描述了一个空手道俱乐部 34 名成员的社交网络,并记录了俱乐部外成员之间的联系。我们的目标是检测由成员互动产生的社区。

加载数据集

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

初始化 KarateClub 数据集后,可以检查其一些属性。可以看到,这个数据集只包含 一个图,每个节点都被分配了一个 34 维的特征向量,用于唯一描述空手道俱乐部的成员。此外,图中正好有 4 个类,代表每个节点所属的社区。

在这里插入图片描述

查看图的详细信息

data = dataset[0]  # Get the first graph object.

print(data)
print('==============================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

在这里插入图片描述

在 PyTorch Geometric 中,每个图都由一个 Data 对象表示,它包含了描述图所需的所有信息。通过 print(data) 可以查看数据对象的属性和形状的简要摘要。

Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])

可以看到,这个 data 对象包含 4 个属性:

  1. edge_index 属性保存了 图连接性 的信息,即每个边的源节点和目标节点的索引对。
  2. 节点特征x 表示(34 个节点中的每个节点都被分配了一个 34 维的特征向量)。
  3. 节点标签y 表示(每个节点都被分配到一个类)。
  4. 还有一个额外的属性 train_mask,它描述了哪些节点的社区分配是已知的。

查看图的连接信息

from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

edge_index = data.edge_index
print(edge_index.t())

通过打印 edge_index,我们可以了解 PyG 如何在内部表示图的连接性。可以看到,对于每条边,edge_index 包含一个两个节点索引的元组,其中第一个值描述源节点的索引,第二个值描述目标节点的索引。

tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
        [ 0, 10],
        [ 0, 11],
        [ 0, 12],
        [ 0, 13],
        [ 0, 17],
        [ 0, 19],
        [ 0, 21],
        [ 0, 31],
        [ 1,  0],
        [ 1,  2],
        [ 1,  3],
        [ 1,  7],
        [ 1, 13],
        [ 1, 17],
        [ 1, 19],
        [ 1, 21],
        [ 1, 30],
        [ 2,  0],
        [ 2,  1],
        [ 2,  3],
        [ 2,  7],
        [ 2,  8],
        [ 2,  9],
        [ 2, 13],
        [ 2, 27],
        [ 2, 28],
        [ 2, 32],
        [ 3,  0],
        [ 3,  1],
        [ 3,  2],
        [ 3,  7],
        [ 3, 12],
        [ 3, 13],
        [ 4,  0],
        [ 4,  6],
        [ 4, 10],
        [ 5,  0],
        [ 5,  6],
        [ 5, 10],
        [ 5, 16],
        [ 6,  0],
        [ 6,  4],
        [ 6,  5],
        [ 6, 16],
        [ 7,  0],
        [ 7,  1],
        [ 7,  2],
        [ 7,  3],
        [ 8,  0],
        [ 8,  2],
        [ 8, 30],
        [ 8, 32],
        [ 8, 33],
        [ 9,  2],
        [ 9, 33],
        [10,  0],
        [10,  4],
        [10,  5],
        [11,  0],
        [12,  0],
        [12,  3],
        [13,  0],
        [13,  1],
        [13,  2],
        [13,  3],
        [13, 33],
        [14, 32],
        [14, 33],
        [15, 32],
        [15, 33],
        [16,  5],
        [16,  6],
        [17,  0],
        [17,  1],
        [18, 32],
        [18, 33],
        [19,  0],
        [19,  1],
        [19, 33],
        [20, 32],
        [20, 33],
        [21,  0],
        [21,  1],
        [22, 32],
        [22, 33],
        [23, 25],
        [23, 27],
        [23, 29],
        [23, 32],
        [23, 33],
        [24, 25],
        [24, 27],
        [24, 31],
        [25, 23],
        [25, 24],
        [25, 31],
        [26, 29],
        [26, 33],
        [27,  2],
        [27, 23],
        [27, 24],
        [27, 33],
        [28,  2],
        [28, 31],
        [28, 33],
        [29, 23],
        [29, 26],
        [29, 32],
        [29, 33],
        [30,  1],
        [30,  8],
        [30, 32],
        [30, 33],
        [31,  0],
        [31, 24],
        [31, 25],
        [31, 28],
        [31, 32],
        [31, 33],
        [32,  2],
        [32,  8],
        [32, 14],
        [32, 15],
        [32, 18],
        [32, 20],
        [32, 22],
        [32, 23],
        [32, 29],
        [32, 30],
        [32, 31],
        [32, 33],
        [33,  8],
        [33,  9],
        [33, 13],
        [33, 14],
        [33, 15],
        [33, 18],
        [33, 19],
        [33, 20],
        [33, 22],
        [33, 23],
        [33, 26],
        [33, 27],
        [33, 28],
        [33, 29],
        [33, 30],
        [33, 31],
        [33, 32]])

这种表示方式称为 COO 格式(坐标格式),通常用于表示稀疏矩阵。PyG 以稀疏方式表示图,只保存邻接矩阵 A \mathbf{A} A 中非零元素的坐标和值。

可视化图

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

可以将图转换为 networkx 库的格式,利用其强大的可视化工具来可视化图。
在这里插入图片描述

实现图神经网络

在了解了 PyG 的数据处理之后, 实现 第一个图神经网络 ! 将使用最简单的 GNN 算子之一,即 GCN 层 (Kipf et al. (2017)),其定义为:
x v ( ℓ + 1 ) = W ( ℓ + 1 ) ∑ w ∈ N ( v )   ∪   { v } 1 c w , v ⋅ x w ( ℓ ) \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} xv(+1)=W(+1)wN(v){v}cw,v1xw()
其中 W ( ℓ + 1 ) \mathbf{W}^{(\ell + 1)} W(+1) 表示形状为 [num_output_features, num_input_features] 的可训练权重矩阵, c w , v c_{w,v} cw,v 指的是每条边的固定归一化系数。

定义图神经网络模型

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.

        # Apply a final (linear) classifier.
        out = self.classifier(h)

        return out, h

model = GCN()
print(model)

__init__ 方法中,初始化了所有的构建块,并在 forward 方法中定义了网络的计算流程。定义并堆叠了 三个图卷积层,这相当于聚合每个节点周围的 3 跳邻域信息。每个 GCNConv 层后都应用了一个 tanh 非线性激活函数。

应用一个线性变换 (torch.nn.Linear) 作为分类器,将节点映射到 4 个类/社区之一。
在这里插入图片描述

可视化节点嵌入

model = GCN()

_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize_embedding(h, color=data.y)

即使在训练模型之前,模型产生的节点嵌入已经很好地反映了图的社区结构。相同颜色(社区)的节点在嵌入空间中已经紧密聚集在一起,这表明 GNNs 引入了强大的归纳偏置,使得在输入图中彼此接近的节点具有相似的嵌入。
在这里插入图片描述

训练模型

import time
from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})'''))

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

训练过程与其他 PyTorch 模型类似。定义了损失函数 (CrossEntropyLoss) 和随机梯度优化器 (Adam)。在每一轮优化中,进行前向传播和反向传播,计算模型参数相对于损失的梯度,并更新参数。

通过观察节点嵌入的变化,可以看到 3 层 GCN 模型能够很好地线性分类,并正确分类大多数节点。

在这里插入图片描述

本教程是对图神经网络和 PyTorch Geometric 的初步介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/959928.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Elastic Agent 对 Kafka 的新输出:数据收集和流式传输的无限可能性

作者:来 Elastic Valerio Arvizzigno, Geetha Anne 及 Jeremy Hogan 介绍 Elastic Agent 的新功能:原生输出到 Kafka。借助这一最新功能,Elastic 用户现在可以轻松地将数据路由到 Kafka 集群,从而实现数据流和处理中无与伦比的可扩…

1.25学习

web bugku-源代码 打开环境后看到了一个提交的界面,我们根据题目查看源代码,看到了js代码,其中有几处是url编码,我们对其进行解码,后面的unescape()函数就是将p1解码以及%35%34%61%61%32p2解码…

Hive详细讲解-基础语法快速入门

文章目录 1.DDL数据库相关操作1.1创建数据库1.2指定路径下创建数据库1.3添加额外信息创建with dbproperties1.4查看数据库 结合like模糊查询 2.查看某一个数据库的相关信息2.1.如何查看数据库信息,extended可选2.2修改数据库 3.Hive基本数据类型4.复杂数据类型5.类型…

深度解析:基于Vue 3与Element Plus的学校管理系统技术实现

一、项目架构分析 1.1 技术栈全景 核心框架:Vue 3 TypeScript UI组件库:Element Plus(含图标动态注册) 状态管理:Pinia(用户状态持久化) 路由方案:Vue Router(动态路…

基于Django的个人博客系统的设计与实现

【Django】基于Django的个人博客系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 系统采用Python作为主要开发语言,结合Django框架构建后端逻辑,并运用J…

【架构面试】一、架构设计认知

涉及分布式锁、中间件、数据库、分布式缓存、系统高可用等多个技术领域,旨在考查候选人的技术深度、架构设计能力与解决实际问题的能力。 1. 以 Redis 是否可以作为分布式锁为例: 用 Redis 实现分布式锁会存在哪些问题? 死锁:如果…

DrawDB:超好用的,免费数据库设计工具

DrawDB:超好用的,免费数据库设计工具 引言 在软件开发过程中,数据库设计是一个至关重要的环节。 无论是关系型数据库还是非关系型数据库,良好的数据库设计都能显著提升系统的性能和可维护性。 然而,数据库设计往往…

如何将xps文件转换为txt文件?xps转为pdf,pdf转为txt,提取pdf表格并转为txt

文章目录 xps转txt方法一方法二 pdf转txt整页转txt提取pdf表格,并转为txt 总结另外参考XPS文件转换为TXT文件XPS文件转换为PDF文件PDF文件转换为TXT文件提取PDF表格并转为TXT示例代码(部分) 本文测试代码已上传,路径如下&#xff…

【Linux】线程、线程控制、地址空间布局

⭐️个人主页:小羊 ⭐️所属专栏:Linux 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 1、Linux线程1.1 线程的优缺点1.2 线程异常和用途1.3 线程等待1.3 线程终止1.4 线程分离1.5 线程ID和地址空间布局1.6 线程栈 1、…

c语言操作符(详细讲解)

目录 前言 一、算术操作符 一元操作符: 二元操作符: 二、赋值操作符 代码例子: 三、比较操作符 相等与不相等比较操作符: 大于和小于比较操作符: 大于等于和小于等于比较操作符: 四、逻辑操作符 逻辑与&…

宏_wps_宏修改word中所有excel表格的格式_设置字体对齐格式_删除空行等

需求: 将word中所有excel表格的格式进行统一化,修改其中的数字类型为“宋体, 五号,右对齐, 不加粗,不倾斜”,其中的中文为“宋体, 五号, 不加粗,不倾斜” 数…

第一届“启航杯”网络安全挑战赛WP

misc PvzHE 去这个文件夹 有一张图片 QHCTF{300cef31-68d9-4b72-b49d-a7802da481a5} QHCTF For Year 2025 攻防世界有一样的 080714212829302316092230 对应Q 以此类推 QHCTF{FUN} 请找出拍摄地所在位置 柳城 顺丰 forensics win01 这个软件 云沙盒分析一下 md5 ad4…

GESP2024年3月认证C++六级( 第三部分编程题(2)好斗的牛)

参考程序&#xff08;暴力枚举&#xff09; #include <iostream> #include <vector> #include <algorithm> using namespace std; int N; vector<int> a, b; int ans 1e9; int main() {cin >> N;a.resize(N);b.resize(N);for (int i 0; i &l…

QFramework实现原理 一 :日志篇

作为一款轻量级开源的Unity程序框架&#xff0c;QFramework结合了作者凉鞋多年的开发经验&#xff0c;是比较值得想要学习框架的初学者窥探一二的对象&#xff0c;我就尝试结合凉鞋大大给出的文档和ai&#xff0c;解析一下其背后的代码逻辑&#xff0c;以作提升自己的一次试炼 …

图论汇总1

1.图论理论基础 图的基本概念 二维坐标中&#xff0c;两点可以连成线&#xff0c;多个点连成的线就构成了图。 当然图也可以就一个节点&#xff0c;甚至没有节点&#xff08;空图&#xff09; 图的种类 整体上一般分为 有向图 和 无向图。 有向图是指 图中边是有方向的&a…

_CLASSDEF在C++中的用法详解及示例

_CLASSDEF在C++中的用法详解及示例 _CLASSDEF的定义与使用示例说明代码解析总结在C++编程中,宏(Macro)是一种预处理指令,它允许程序员在编译之前对代码进行文本替换。_CLASSDEF是一个自定义的宏,它提供了一种便捷的方式来定义类及其相关类型。本文将详细介绍_CLASSDEF在C+…

华为数据之道-读书笔记

内容简介 关键字 数字化生产 已经成为普遍的商业模式&#xff0c;其本质是以数据为处理对象&#xff0c;以ICT平台为生产工具&#xff0c;以软件为载体&#xff0c;以服务为目的的生产过程。 信息与通信技术平台&#xff08;Information and Communication Technology Platf…

从CRUD到高级功能:EF Core在.NET Core中全面应用(四)

初识表达式树 表达式树&#xff1a;是一种可以描述代码结构的数据结构&#xff0c;它由一个节点组成&#xff0c;节点表示代码中的操作、方法调用或条件表达式等&#xff0c;它将代码中的表达式转换成一个树形结构&#xff0c;每个节点代表了代码中的操作例如&#xff0c;如果…

系统思考—问题分析

很多中小企业都在面对转型的难题&#xff1a;市场变化快&#xff0c;资源有限&#xff0c;团队协作不畅……这些问题似乎总是困扰着我们。就像最近和一位企业主交流时&#xff0c;他提到&#xff1a;“我们团队每天都很忙&#xff0c;但效率始终没见提升&#xff0c;感觉像是在…

MySQL 的索引类型【图文并茂】

基本分类 文本生成MindMap:https://app.pollyoyo.com/planttext <style> mindmapDiagram {node {BackgroundColor yellow}:depth(0) {BackGroundColor SkyBlue}:depth(1) {BackGroundColor lightGreen} } </style> * MySQL 索引** 数据结构角度 *** B树索引*** 哈…