GNN系统学习:简单图论、环境配置、PyG中图与图数据集的表示和使用

Reference

datawhale开源学习资料

开篇

1.1 为什么要在图上进行深度学习?

在过去的深度学习应用中,我们接触的数据形式主要是这四种:矩阵、张量、序列(sequence)和时间序列(time series)。然而来自现实世界应用的数据更多地是图的结构,如社交网络、交通网络、蛋白质与蛋白质相互作用网络、知识图谱和大脑网络等。
在这里插入图片描述

此外,大量的现实世界的问题可以作为图上的一组小的计算任务来解决。推断节点属性、检测异常节点(如垃圾邮件发送者)、识别与疾病相关的基因、向病人推荐药物等,都可以概括为节点分类问题。推荐、药物副作用预测、药物与目标的相互作用识别和知识图谱的完成(knowledge graph completion)等,本质上都是边预测问题

同一图的节点存在连接关系,这表明节点不是独立的。然而,传统的机器学习技术假设样本是独立且同分布的,因此传统机器学习方法不适用于图计算任务。图机器学习研究如何构建节点表征,节点表征要求同时包含节点自身的信息和节点邻接的信息,从而我们可以在节点表征上应用传统的分类技术实现节点分类。图机器学习成功的关键在于如何为节点构建表征。深度学习已经被证明在表征学习中具有强大的能力,它大大推动了计算机视觉、语音识别和自然语言处理等各个领域的发展。因此,将深度学习与图连接起来,利用神经网络来学习节点表征,将带来前所未有的机会。

然而,如何将神经网络应用于图,这一问题面临着巨大的挑战。首先,传统的深度学习是为规则且结构化的数据设计的图像、文本、语音和时间序列等都是规则且结构化的数据。但图是不规则的,节点是无序的,节点可以有不同的邻居节点。其次,规则数据的结构信息是简单的,而图的结构信息是复杂的,特别是在考虑到各种类型的复杂图,它们的节点和边可以关联丰富的信息,这些丰富的信息无法被传统的深度学习方法捕获。

以往的深度学习技术是为规则且结构化的数据设计的,无法直接用于图数据。应用于图数据的神经网络,要求:

  • 适用于不同度的节点;
  • 节点表征的计算与邻接节点的排序无关;
  • 不但能够根据节点信息、邻接节点的信息和边的信息计算节点表征,还能根据图拓扑结构计算节点表征。下面的图片展示了一个需要根据图拓扑结构计算节点表征的例子。图片中展示了两个图,它们同样有俩黄、俩蓝、俩绿,共6个节点,因此它们的节点信息相同;假设边两端节点的信息为边的信息,那么这两个图有一样的边,即它们的边信息相同。但这两个图是不一样的图,它们的拓扑结构不一样
    在这里插入图片描述

1.2 此组队学习涵盖的话题

此组队学习由五个话题组成,每一话题都包含理论部分与实践部分:

话题一:

  • 我们将首先学习简单图论知识、了解常规的图预测任务(见第2节);
  • 然后学习基于PyG包的图数据的表示与使用(见第3节)。

话题二:

  • 我们将首先学习实现图神经网络的通用范式,即消息传递范式;
  • 其次学习PyG中的消息传递(MessagePassing)基类的属性、方法和运行流程;
  • 最后学习如何自定义一个消息传递图神经网络(见第4节)。

话题三:

  • 图计算应用中最基础的任务是节点表征(Node Representation)学习。
  • 我们将以GCN和GAT(两个最为经典的图神经网络)为例,学习基于图神经网络的节点表征学习的一般过程;并且通过MLP、GCN和GAT三者在节点分类任务中的比较,学习图神经网络为什么强于普通的MLP神经网络,以及GCN和GAT的差别(见第5节)。
  • 此外,我们还将学习如何构造一个数据全部存于内存的数据集类(见第6-1节);
  • 并学习基于节点表征学习的图节点预测任务和边预测任务的实践(见第6-2节)。

话题四:

  • 我们将首先分析在超大图上进行节点表征学习面临着的挑战**;
  • 接着学习应对挑战的一种解决方案;
  • 最后学习超大图节点预测任务的实践(见第7节)。

话题五:

  • 我们将首先学习基于图神经网络的图表征学习的一般过程(见第8节);
  • 接着学习样本按需获取的数据集类的构造方法(见第9-1节);
  • 最后学习基于图表征学习的图预测任务的实践(见第9-2节)。

除了话题四和话题五都依赖于话题三之外,其余话题都依赖于该话题自身的前一话题。

图结构数据

2.1 图的表示

定义一(图)

  • 一个图被记为 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E},其中 V = { v 1 , … , v N } \mathcal{V} = \left\{v_{1}, \ldots, v_{N}\right\} V={v1,,vN}是数量为 N = ∣ V ∣ N=|\mathcal{V}| N=V 的节点的集合, E = { e 1 , … , e M } \mathcal{E}=\left\{e_{1}, \ldots, e_{M}\right\} E={e1,,eM}是数量为 M M M 的边的集合。
  • 图用节点表示实体(entities ),用边表示实体间的关系(relations)。
  • 节点和边的信息可以是类别型的(categorical),类别型数据的取值只能是哪一类别。一般称类别型的信息为标签(label)。
  • 节点和边的信息可以是数值型的(numeric),数值型数据的取值范围为实数。一般称数值型的信息为属性(attribute)。
  • 在图的计算任务中,我们认为,节点一定含有信息(至少含有节点的度的信息),边可能含有信息

定义二(图的邻接矩阵)

  • 给定一个图 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E},其对应的邻接矩阵被记为 A ∈ { 0 , 1 } N × N \mathbf{A} \in \left\{0,1\right\}^{N \times N} A{0,1}N×N A i , j = 1 \mathbf{A}_{i, j}=1 Ai,j=1表示存在从节点 v i v_i vi v j v_j vj的边,反之表示不存在从节点 v i v_i vi v j v_j vj的边。
  • 在无向图中,从节点 v i v_i vi v j v_j vj的边存在,意味着从节点 v j v_j vj v i v_i vi的边也存在。因而无向图的邻接矩阵是对称的。
  • 在无权图中,各条边的权重被认为是等价的,即认为各条边的权重为 1 1 1
  • 对于有权图,其对应的邻接矩阵通常被记为 W ∈ R N × N \mathbf{W} \in \mathbb{R}^{N \times N} WRN×N,其中 W i , j = w i j \mathbf{W}_{i, j}=w_{ij} Wi,j=wij表示从节点 v i v_i vi v j v_j vj的边的权重。若边不存在时,边的权重为 0 0 0

2.2 图的属性

定义三(节点的度,degree)

  • 对于有向有权图,节点 v i v_i vi的出度(out degree)等于从 v i v_i vi出发的边的权重之和,节点 v i v_i vi的入度(in degree)等于从连向 v i v_i vi的边的权重之和。
  • 无向图是有向图的特殊情况,节点的出度与入度相等。
  • 无权图是有权图的特殊情况,各边的权重为 1 1 1,那么节点 v i v_i vi的出度(out degree)等于从 v i v_i vi出发的边的数量,节点 v i v_i vi的入度(in degree)等于从连向 v i v_i vi的边的数量。
  • 节点 v i v_i vi的度记为 d ( v i ) d(v_i) d(vi),入度记为 d i n ( v i ) d_{in}(v_i) din(vi),出度记为 d o u t ( v i ) d_{out}(v_i) dout(vi)

定义四(邻接节点,neighbors)

  • 节点 v i v_i vi的邻接节点为与节点 v i v_i vi直接相连的节点,其被记为 N ( v i ) \mathcal{N(v_i)} N(vi)
  • 节点 v i v_i vi k k k跳远的邻接节点(neighbors with k-hop)指的是到节点 v i v_i vi要走 k k k步的节点(一个节点的 2 2 2跳远的邻接节点包含了自身)。

定义五(行走,walk)

  • w a l k ( v 1 , v 2 ) = ( v 1 , e 6 , e 5 , e 4 , e 1 , v 2 ) walk(v_1,v_2)=(v_1,e_6,e_5,e_4,e_1,v_2) walk(v1,v2)=(v1,e6,e5,e4,e1,v2),这是一次“行走”,它是一次从节点 v 1 v_1 v1出发,依次经过边 e 6 , e 5 , e 4 , e 1 e_6,e_5,e_4,e_1 e6,e5,e4,e1,最终到达节点 v 2 v_2 v2的“行走”。
  • 下图所示为 w a l k ( v 1 , v 2 ) = ( v 1 , e 6 , e 5 , e 4 , e 1 , v 2 ) walk(v_1, v_2) = (v_1, e_6,e_5,e_4,e_1,v_2) walk(v1,v2)=(v1,e6,e5,e4,e1,v2),其中红色数字标识了边的访问序号。
  • 在“行走”中,节点是允许重复的。

定理六

  • 有一图,其邻接矩阵为 A A A, A n A^n An邻接矩阵的 n n n次方,那么 A n [ i , j ] \mathbf{A}^{n}[i,j] An[i,j]等于从节点 v i v_i vi到节点 v j v_j vj的长度为 n n n的行走的个数。(也就是,以节点 v i v_i vi为起点,节点 v j v_j vj为终点,长度为 n n n的节点访问方案的数量,节点访问中可以兜圈子重复访问一些节点)

定义七(路径,path)

  • “路径”是节点不可重复的“行走”。

定义八(子图,subgraph)

  • 有一图 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E},另有一图 G ′ = { V ′ , E ′ } \mathcal{G}^{\prime}=\left\{\mathcal{V}^{\prime}, \mathcal{E}^{\prime}\right\} G={V,E},其中 V ′ ∈ V \mathcal{V}^{\prime} \in \mathcal{V} VV E ′ ∈ E \mathcal{E}^{\prime} \in \mathcal{E} EE并且 V ′ \mathcal{V}^{\prime} V不包含 E ′ \mathcal{E}^{\prime} E中未出现过的节点,那么 G ′ \mathcal{G}^{\prime} G G \mathcal{G} G的子图。

定义九(连通分量,connected component)

  • 给定图 G ′ = { V , E } \mathcal{G}^{\prime}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E}是图 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E}的子图。记属于图 G \mathcal{G} G但不属于 G ′ \mathcal{G}^{\prime} G图的节点集合记为 V / V ′ \mathcal{V}/\mathcal{V}^{\prime} V/V 。如果属于 V ′ \mathcal{V}^{\prime} V的任意节点对之间存在至少一条路径,但不存在一条边连接属于 V ′ \mathcal{V}^{\prime} V的节点与属于 V / V ′ \mathcal{V}/\mathcal{V}^{\prime} V/V的节点,那么图 G ′ \mathcal{G}^{\prime} G是图 G \mathcal{G} G的连通分量。

在这里插入图片描述

左右两边子图都是整图的连通分量。

定义十(连通图,connected graph)

  • 当一个图只包含一个连通分量,即其自身,那么该图是一个连通图。

定义十一(最短路径,shortest path)

  • v s , v t ∈ V v_s, v_t \in \mathcal{V} vs,vtV 是图 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E}上的一对节点,节点对 v s , v t ∈ V v_{s}, v_{t} \in \mathcal{V} vs,vtV之间所有路径的集合记为 P s t \mathcal{P}_{\mathrm{st}} Pst。节点对 v s , v t v_{s}, v_{t} vs,vt之间的最短路径 p s t s p p_{\mathrm{s} t}^{\mathrm{sp}} pstsp P s t \mathcal{P}_{\mathrm{st}} Pst中长度最短的一条路径,其形式化定义为 p s t s p = arg ⁡ min ⁡ p ∈ P s t ∣ p ∣ p_{\mathrm{s} t}^{\mathrm{sp}}=\arg \min \limits_{p \in \mathcal{P}_{\mathrm{st}}}|p| pstsp=argpPstminp 其中, p p p表示 P s t \mathcal{P}_{\mathrm{st}} Pst中的一条路径, ∣ p ∣ |p| p是路径 p p p的长度。

定义十二(直径,diameter):

  • 给定一个连通图 G = V , E \mathcal{G}={\mathcal{V}, \mathcal{E}} G=V,E,其直径为其所有节点对之间的最短路径的最大值,形式化定义为
    diameter ⁡ ( G ) = max ⁡ v s , v t ∈ V min ⁡ p ∈ P s t ∣ p ∣ \operatorname{diameter}(\mathcal{G})=\max {v{s}, v_{t} \in \mathcal{V}} \min {p \in \mathcal{P}{s t}}|p| diameter(G)=maxvs,vtVminpPstp

定义十三(拉普拉斯矩阵,Laplacian Matrix):

  • 给定一个图 G = { V , E } \mathcal{G}=\left\{\mathcal{V}, \mathcal{E}\right\} G={V,E},其邻接矩阵为 A A A,其拉普拉斯矩阵定义为 L = D − A \mathbf{L=D-A} L=DA,其中 D = d i a g ( d ( v 1 ) , ⋯   , d ( v N ) ) \mathbf{D=diag(d(v_1), \cdots, d(v_N))} D=diag(d(v1),,d(vN))

定义十四(对称归一化的拉普拉斯矩阵,Symmetric normalized Laplacian):

  • 给定一个图 G = V , E \mathcal{G}={\mathcal{V}, \mathcal{E}} G=V,E,其邻接矩阵为 A A A,其规范化的拉普拉斯矩阵定义为
    L = D − 1 2 ( D − A ) D − 1 2 = I − D − 1 2 A D − 1 2 \mathbf{L}=\mathbf{D}^{-\frac{1}{2}}(\mathbf{D}-\mathbf{A})\mathbf{D}^{-\frac{1}{2}}=\mathbf{I}-\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}} L=D21(DA)D21=ID21AD21

2.3 图的种类

  • 同质图(Homogeneous Graph):只有一种类型的节点和一种类型的边的图。
  • 异质图(Heterogeneous Graph):存在多种类型的节点和多种类型的边的图。

在这里插入图片描述

  • 二部图(Bipartite Graphs):节点分为两类,只有不同类的节点之间存在边。
    在这里插入图片描述

2.4 图结构数据上的机器学习

  1. 节点预测:预测节点的类别或某类属性的取值
    例子:对是否是潜在客户分类、对游戏玩家的消费能力做预测
  2. 边预测:预测两个节点间是否存在链接
    例子:Knowledge graph completion、好友推荐、商品推荐
  3. 图的预测:对不同的图进行分类或预测图的属性
    例子:分子属性预测
  4. 节点聚类:检测节点是否形成一个社区
    例子:社交圈检测
  5. 其他任务
    图生成:例如药物发现
    图演变:例如物理模拟
    ……

环境配置与PyG中图与图数据集的表示和使用

3.1 PyG引言

PyTorch Geometric (PyG) 是面向几何深度学习的PyTorch的扩展库,几何深度学习指的是应用于图和其他不规则、非结构化数据的深度学习。

基于PyG库,我们可以轻松地根据数据生成一个图对象,然后很方便的使用它;我们也可以容易地为一个图数据集构造一个数据集类,然后很方便的将它用于神经网络。

通过此节的实践内容,我们将

  1. 首先学习程序运行环境的配置。
  2. 接着学习PyG中图数据的表示及其使用,即学习PyG中Data类。
  3. 最后学习PyG中图数据集的表示及其使用,即学习PyG中Dataset类。

3.2 环境配置

  1. 安装正确版本的pytorch和cudatoolkit(Cudatoolkit的目的是为开发人员提供方便操作GPU的工具和库,避免他们需要重复编写底层功能),此处安装1.8.1版本的pytorch和11.1版本的cudatoolkit
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia

确认是否正确安装,正确的安装应出现下方的结果

$ python -c "import torch; print(torch.__version__)"
# 1.8.1
$ python -c "import torch; print(torch.version.cuda)"
# 11.1

在一个torch版本为1.9.0的上:

1.9.0+cpu
None

在一个torch版本为2.4.1+cu121的上:

2.4.1+cu121
12.1
  1. 安装正确版本的PyG
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-geometric
# 通用形式
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric

# 电脑安装
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
pip install torch-geometric

# 服务器安装
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
pip install torch-geometric

其他版本的安装方法以及安装过程中出现的大部分问题的解决方案可以在 Installation of of PyTorch Geometric 页面找到。

3.3 Data类——PyG中图的表示及其使用

3.3.1 Data对象的创建

A. 通过构造函数

Data类的构造函数:

class Data(object):

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
    r"""
    Args:
        x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`
        edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第0行可称为头(head)节点、源(source)节点、邻接节点,第1行可称为尾(tail)节点、目标(target)节点、中心节点
        edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
        y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签)
	
    """
    self.x = x
    self.edge_index = edge_index
    self.edge_attr = edge_attr
    self.y = y

    for key, item in kwargs.items():
        if key == 'num_nodes':
            self.__num_nodes__ = item
        else:
            self[key] = item

edge_index的每一列定义一条边,其中第一行为边起始节点的索引,第二行为边结束节点的索引。这种表示方法被称为COO格式(coordinate format),通常用于表示稀疏矩阵
PyG 不是用稠密矩阵 A ∈ { 0 , 1 } ∣ V ∣ × ∣ V ∣ \mathbf{A} \in \left\{ 0, 1 \right\}^{|\mathcal{V}| \times |\mathcal{V}|} A{0,1}V×V来持有邻接矩阵的信息,而是用仅存储邻接矩阵 A \mathbf{A} A 0 0 0元素的稀疏矩阵来表示图。

通常,一个图至少包含x, edge_index, edge_attr, y, num_nodes5个属性当图包含其他属性时,我们可以通过指定额外的参数使Data对象包含其他的属性:

graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, num_nodes=num_nodes, other_attr=other_attr)
B. 转dict对象为Data对象

我们也可以将一个dict对象转换为一个Data对象

graph_dict = {
    'x': x,
    'edge_index': edge_index,
    'edge_attr': edge_attr,
    'y': y,
    'num_nodes': num_nodes,
    'other_attr': other_attr
}
graph_data = Data.from_dict(graph_dict)

from_dict是一个类方法:

@classmethod
def from_dict(cls, dictionary):
    r"""Creates a data object from a python dictionary."""
    data = cls()
    for key, item in dictionary.items():
        data[key] = item

    return data

注意:graph_dict中属性值的类型与大小的要求与Data类的构造函数的要求相同。

3.3.2 Data对象转换成其他类型数据

我们可以将Data对象转换为dict对象:

def to_dict(self):
    return {key: item for key, item in self}

或转换为namedtuple:

def to_namedtuple(self):
    keys = self.keys
    DataTuple = collections.namedtuple('DataTuple', keys)
    return DataTuple(*[self[key] for key in keys])

获取Data对象属性

x = graph_data[‘x’]

设置Data对象属性

graph_data[‘x’] = x

获取Data对象包含的属性的关键字

graph_data.keys()

对边排序并移除重复的边

graph_data.coalesce()

3.3.3 Data对象的其他性质

我们通过观察PyG中内置的一个图来查看Data对象的性质:

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
data = dataset[0]  # Get the first graph object.
print(data)
print('==============================================================')

# 获取图的一些信息
print(f'Number of nodes: {data.num_nodes}') # 节点数量
print(f'Number of edges: {data.num_edges}') # 边数量
print(f'Number of node features: {data.num_node_features}') # 节点属性的维度
print(f'Number of node features: {data.num_features}') # 同样是节点属性的维度
print(f'Number of edge features: {data.num_edge_features}') # 边属性的维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') # 平均节点度
print(f'if edge indices are ordered and do not contain duplicate entries.: {data.is_coalesced()}') # 是否边是有序的同时不含有重复的边
print(f'Number of training nodes: {data.train_mask.sum()}') # 用作训练集的节点
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') # 用作训练集的节点数占比
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') # 此图是否包含孤立的节点
print(f'Contains self-loops: {data.contains_self_loops()}')  # 此图是否包含自环的边
print(f'Is undirected: {data.is_undirected()}')  # 此图是否是无向图

3.4 Dataset类——PyG中图数据集的表示及其使用

PyG内置了大量常用的基准数据集,接下来我们以PyG内置的Planetoid数据集为例,来学习PyG中图数据集的表示及使用。

Planetoid数据集类的官方文档为torch_geometric.datasets.Planetoid。

3.4.1 生成数据集对象并分析数据集

如下方代码所示,在PyG中生成一个数据集是简单直接的。在第一次生成PyG内置的数据集时,程序首先下载原始文件,然后将原始文件处理成包含Data对象的Dataset对象并保存到文件

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/dataset/Cora', name='Cora')
# Cora()

len(dataset)
# 1 该数据集只有一个图

dataset.num_classes
# 7

dataset.num_node_features
# 1433

3.4.2 分析数据集中样本

可以看到该数据集只有一个图,包含7个分类任务,节点的属性为1433维度。

data = dataset[0]
# Data(edge_index=[2, 10556], test_mask=[2708],
#         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

data.is_undirected()
# True

data.train_mask.sum().item()
# 140

data.val_mask.sum().item()
# 500

data.test_mask.sum().item()
# 1000

现在我们看到该数据集包含的唯一的图,有2708个节点,节点特征为1433维,有10556条边,有140个用作训练集的节点,有500个用作验证集的节点,有1000个用作测试集的节点

A: 为什么节点数为 2708,但训练、验证、测试节点的总和为 1650?
Q: 在 PyG 中,数据集中并不是所有节点都被分配到训练、验证或测试集。剩余的 1058 个节点(2708 - 1650 = 1058)通常没有标签,也就不参与监督训练、验证或测试。在实际应用中,尤其是图神经网络中,这是一个常见情况,因为图中的部分节点可能并不具有标签,而是作为上下文信息参与训练

A: train_mask、val_mask 和 test_mask 的维度解释
Q: 在 PyG 中,train_mask、val_mask 和 test_mask 是长度为 2708 的布尔张量,而不是单纯地包含分配的节点数量。每个布尔张量表示该节点是否属于相应的集合
train_mask[i] 为 True 表示节点 i 属于训练集
val_mask[i] 为 True 表示节点 i 属于验证集。
test_mask[i] 为 True 表示节点 i 属于测试集。
这样,每个节点的位置都固定地映射到训练、验证、测试集中某一类别或不属于任何类别。这种方式便于处理节点信息,并方便后续将特定集合的节点从图数据中分割出来。

3.4.3 数据集的使用

假设我们定义好了一个图神经网络模型,其名为Net。在下方的代码中,我们展示了节点分类图数据集在训练过程中的使用。

model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/913236.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式面试八股文(六)·ROM和RAM的区别、GPIO的八种工作模式、串行通讯和并行通讯的区别、同步串行和异步串行的区别

目录 1. ROM和RAM的区别 2. GPIO的八种工作模式 3. 串行通讯和并行通讯的区别 3.1 串行通讯 3.2 并行通讯 3.3 对比 4. 同步串行和异步串行的区别 4.1 时钟信号 4.2 数据传输效率 4.3 应用场景 4.4 硬件复杂性 1. ROM和RAM的区别 ROM(Read-O…

批量缓存模版

批量缓存模版 缓存通常有两种使用方式,一种是Cache-Aside,一种是cache-through。也就是旁路缓存和缓存即数据源。 一般一种用于读,另一种用于读写。参考后台服务架构高性能设计之道。 最典型的Cache-Aside的样例: //读操作 da…

Vue3学习笔记(上)

Vue3学习笔记(上) Vue3的优势: 更容易维护: 组合式API更好的TypeScript支持 更快的速度: 重写diff算法模板编译优化更高效的组件初始化 更小的体积: 良好的TreeShaking按需引入 更优的数据响应式&#xf…

SPIRE: Semantic Prompt-Driven Image Restoration 论文阅读笔记

这是一篇港科大学生在google research 实习期间发在ECCV2024的语义引导生成式修复的文章,港科大陈启峰也挂了名字。从首页图看效果确实很惊艳,尤其是第三行能用文本调控修复结果牌上的字。不过看起来更倾向于生成,对原图内容并不是很复原&…

Knowledge Graph-Enhanced Large Language Models via Path Selection

研究背景 研究问题:这篇文章要解决的问题是大型语言模型(LLMs)在生成输出时存在的事实不准确性,即所谓的幻觉问题。尽管LLMs在各种实际应用中表现出色,但当遇到超出训练语料库范围的新知识时,它们通常会生…

常见计算机网络知识整理(未完,整理中。。。)

TCP和UDP区别 TCP是面向连接的协议,发送数据前要先建立连接;UDP是无连接的协议,发送数据前不需要建立连接,是没有可靠性; TCP只支持点对点通信,UDP支持一对一、一对多、多对一、多对多; TCP是…

HashMap(深入源码追踪)

一篇让你搞懂HashMap的几个最重要的知识点,往源码跟踪可以让我们很轻松应对所谓的一些八股面试题. 一. 属性解释 先来解释HashMap中重要的常量属性值 DEFAULT_INITIAL_CAPACITY : 默认初始化容量,也就是如果不指定初始化的Map存储容量大小,默认生成一个存储16个空间的Map集合…

MySQL中的事务与锁

目录 事务 InnoDB 和 ACID 模型 原⼦性的实现 持久性的实现 ​隔离性的实现 锁 隔离级别 ​多版本控制(MVCC) 事务 1.什么是事务? 事务是把⼀组SQL语句打包成为⼀个整体,在这组SQL的执⾏过程中,要么全部成功,要么全部失败&#…

C#开发基础:WPF和WinForms关于句柄使用的区别

1、前言 在 Windows 应用程序开发中,WPF(Windows Presentation Foundation)和 WinForms(Windows Forms)是两种常见的用户界面(UI)框架。它们各自有不同的架构和处理方式,其中一个显…

基于.NET开源、功能强大且灵活的工作流引擎框架

前言 工作流引擎框架在需要自动化处理复杂业务流程、提高工作效率和确保流程顺畅执行的场景中得到了广泛应用。今天大姚给大家推荐一款基于.NET开源、功能强大且灵活的工作流引擎框架:elsa-core。 框架介绍 elsa-core是一个.NET开源、免费(MIT License…

.NET6中WPF项目添加System.Windows.Forms引用

.NET6中WPF项目添加System.Windows.Forms引用 .NET6的WPF自定义控件默认是不支持System.Windows.Forms引用的,需要添加这个引用方法如下: 1. 在项目浏览器中找到项目右击,选择编辑项目文件(Edit Project File)。 …

16.UE5拉怪机制,怪物攻击玩家,伤害源,修复原视频中的BUG

2-18 拉怪机制,怪物攻击玩家、伤害源、黑板_哔哩哔哩_bilibili 目录 1.实行行为树实现拉怪机制 1.1行为树黑板 1.2获取施加伤害对象(伤害源) 2.修复原视频中,第二次攻击怪物后,怪物卡在原地不动的BUG 3.怪物攻击玩…

<项目代码>YOLOv8 草莓成熟识别<目标检测>

YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…

Vue全栈开发旅游网项目(9)-用户登录/注册及主页页面开发

1.用户登录页面开发 1.查询vant组件 2.实现组件模板部分 3.模型层准备 4.数据上传 1.1 创建版权声明组件Copyright 新建文件&#xff1a;src\components\common\Copyright.vue <template><!-- 版权声明 --><div class"copyright">copyright xx…

后台管理系统窗体程序:文章管理 > 文章列表

目录 文章列表的的功能介绍&#xff1a; 1、进入页面 2、页面内的各种功能设计 &#xff08;1&#xff09;文章表格 &#xff08;2&#xff09;删除按钮 &#xff08;3&#xff09;编辑按钮 &#xff08;4&#xff09;发表文章按钮 &#xff08;5&#xff09;所有分类下拉框 &a…

【万字详解】如何在微信小程序的 Taro 框架中设置静态图片 assets/image 的 Base64 转换上限值

设置方法 mini 中提供了 imageUrlLoaderOption 和 postcss.url 。 其中&#xff1a; config.limit 和 imageUrlLoaderOption.limit 服务于 Taro 的 MiniWebpackModule.js &#xff0c; 值的写法要 &#xff08;&#xff09;KB * 1024。 config.maxSize 服务于 postcss-url 的…

基于STM32通过TM1637驱动4位数码管详细解析(可直接移植使用)

目录 1. 单位数码管概述 2. 对应编码 2.1 共阳数码管 2.2 共阴数码管 3. TM1637驱动数码管 3.1 工作原理 3.1.1 读键扫数据 3.1.2 显示器寄存器地址和显示模式 3.2 时序 3.2.1 指令数据传输过程&#xff08;读案件数据时序&#xff09; 3.2.2 写SRAM数据…

数字信号处理Python示例(11)生成非平稳正弦信号

文章目录 前言一、生成非平稳正弦信号的实验设计二、生成非平稳正弦信号的Python代码三、仿真结果及分析写在后面的话 前言 本文继续给出非平稳信号的Python示例&#xff0c;所给出的示例是非平稳正弦信号&#xff0c;在介绍了实验设计之后给出Python代码&#xff0c;最后给出…

Linux 系统结构

Linux系统一般有4个主要部分&#xff1a;内核、shell、文件系统和应用程序。内核、shell和文件系统一起形成了基本的操作系统结构&#xff0c;它们使得用户可以运行程序、管理文件并使用系统。 1. linux内核 内核是操作系统的核心&#xff0c;具有很多最基本功能&#xff0c;它…

网络安全之SQL初步注入

一.字符型 平台使用pikachu $name$_GET[name]; ​ $query"select id,email from member where username$name"; 用户输入的数据会被替换到SQL语句中的$name位置 查询1的时候&#xff0c;会展示username1的用户数据&#xff0c;可以测试是否有注入点&#xff08;闭…