PyG教程:MessagePassing基类
- 一、引言
- 二、如何自定义消息传递网络
- 1.构造函数
- 2.propagate函数
- 3.message函数
- 4.aggregate函数
- 5.update函数
- 三、代码实战
- 1.图数据定义
- 2.实现GNN的消息传递过程
- 3.完整代码
- 4.完整代码的精简版本
- 四、总结
- 1.MessagePassing各个函数的执行顺序
- 2.参考资料
一、引言
在PyG
框架中提供了一个消息传递基类torch_geometric.nn.MessagePassing
,它实现了消息传递的自动处理,继承该类可以简单方便的构建自己的消息传播GNN。
二、如何自定义消息传递网络
要自定义GNN模型,首先需要继承MessagePassing
类,然后重写如下方法:
message(...)
:构建要传递的消息;aggregate(...)
:将从源节点传递过来的消息聚合到目标结点;update(...)
:更新节点的消息。
上述方法并不是一定都要自定义,若
MessagePassing
类默认实现满足你的需求,则可以不重写。
1.构造函数
继承MessagePassing
类后,在构造函数中可以通过super().__init__
方法来向基类MessagePassing
传递参数,来指定消息传递过程中的一些行为。MessagePassing
类的初始化函数如下:
参数说明:
参数名 | 参数说明 |
---|---|
aggr | 消息传递中的消息聚合方式,常用的包括sum 、mean 、min 、max 、mul 等等。default: sum |
flow | 消息传播的方向,其中source_to_targe 表示从源节点到目标节点、target_to_source 表示从目标节点到源节点。default:source_to_target |
node_dim | 传播的维度,default:-2 |
decomposed_layers | 这个参数没用过,我也还不知道,后面会更新。 |
2.propagate函数
在具体介绍消息传递的三个相关函数之前,首先先介绍propagate
函数,该函数是消息传递的启动函数,调用该函数后依次会执行message
、aggregate
、udpate
函数来完成消息的传递、聚合和更新。该函数的声明如下:
参数说明:
参数名 | 参数说明 |
---|---|
edge_index | 边索引 |
size | 这个参数目前我理解的不是很透彻,后面透彻了补一下 |
**kwargs | 构建、聚合和更新消息所需的额外数据,都可以传入propagate 函数,这些参数可以在消息传递过程中的三个函数中接收。 |
该函数一般会传入
edge_index
和特征x
。
3.message函数
message
函数是用来构建节点的消息的。传递给propagate
函数的tensor
可以映射到中心(target)节点
和邻居(source)节点上
,只需要在相应变量名后加上_i
or_j
即可,通常称_i
为中心(target)节点,称_j
为邻居(source)节点。
source节点和target节点的关系:
message实现源码:
从源码的默认实现可以看出,message
传递的消息就是邻居节点自身的特征向量。
示例:
def forward(self, data):
out = self.propagate(edge_index, x=x)
pass
def message(self, x_i, x_j, edge_index_i, edge_index_j):
pass
该例子中利用propagate
函数传递了两个参数edge_index
和x
,则message
函数可以根据propagate
函数中的两个参数构造自己的参数,上述message
函数中的构造参数为:
x_i
:中心节点(target)的特征向量组成的矩阵,注意该矩阵与图节点的矩阵x
是不同的;x_j
:邻居节点(source)的特征向量组成的矩阵;edge_index_i
:中心节点的索引;edge_index_j
:邻居节点的索引。
注意,若
flow='source_to_target'
,则消息将由邻居节点传向中心节点,若flow='target_to_source'
则消息将从中心节点传向邻居节点,默认为第一种情况
4.aggregate函数
消息聚合函数aggregate
用来聚合来自邻居的消息,常用的包括sum
、mean
、max
和min
等,可以通过super().__init__()
中的参数aggr
来设定。该函数的第一个参数为message
函数的返回值。
5.update函数
update
函数用来更新节点的消息,aggregate
函数的返回值作为该函数的第一个参数。
默认实现:
从默认实现可以看出update
函数没有进行任何的操作,只是将raggregate
函数的返回值返回了而已。
实际写代码的过程中,我们也不会去重写这个方法,而是,在forward函数中调用完propagate(…)函数后编写代码,代替update函数的功能。
三、代码实战
假设我们设计一个GNN模型,其中消息传递过程用公式表示如下:
X
i
(
k
)
=
X
i
(
k
−
1
)
+
∑
j
∈
N
(
i
)
X
j
(
k
−
1
)
(1)
X_i^{(k)} = X_i^{(k-1)} + \sum _{j\in {\mathcal {N(i)}}} X_j^{(k-1) }\tag {1}
Xi(k)=Xi(k−1)+j∈N(i)∑Xj(k−1)(1)
message
生成的消息就是中心节点的邻居节点的特征向量。aggregaet
聚合消息的方式是sum
,即把所有邻居节点的特征向量加起来。update
更新中心节点的方式是:将聚合得到的消息和中心节点自身的特征向量相加。
1.图数据定义
我们有如下数据:
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1],
[1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())
2.实现GNN的消息传递过程
class MyConv(MessagePassing):
def __init__(self):
super().__init__(aggr='sum')
def forward(self, data):
out = self.propagate(data.edge_index, x=data.x)
# out = out + x
return out
def message(self, x_i, x_j, edge_index_i, edge_index_j):
# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
return x_j
def aggregate(self, message, edge_index_i):
# 这里只是写的样例,实际上一般不会重写这个方法,直接使用默认的就好了,只需要自己选择一下聚合的方式即可
return super().aggregate(message, edge_index_i, dim_size=len(x))
def update(self, aggregate, x):
# 一般也不会重写这个方法的,update阶段可以在forward函数中调用完propagate(...)函数后编写代码。
return x + aggregate
3.完整代码
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
class MyConv(MessagePassing):
def __init__(self):
super().__init__(aggr='sum')
def forward(self, data):
out = self.propagate(data.edge_index, x=data.x)
out = out + data.x
return out
def message(self, x_i, x_j, edge_index_i, edge_index_j):
# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
return x_j
# def aggregate(self, message, edge_index_i):
# return super().aggregate(message, edge_index_i, dim_size=len(x))
# def update(self, aggregate, x):
# return x + aggregate
if __name__ == '__main__':
edge_index = torch.tensor([[0, 1],
[1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())
myConv = MyConv()
print(myConv(data))
4.完整代码的精简版本
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
class MyConv(MessagePassing):
def __init__(self):
super().__init__(aggr='sum')
def forward(self, data):
edge_index, _ = add_self_loops(data.edge_index, num_nodes=len(data.x))
out = self.propagate(edge_index, x=data.x)
return out
if __name__ == '__main__':
edge_index = torch.tensor([[0, 1],
[1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())
myConv = MyConv()
print(myConv(data))
思考:大家可以根据上面讲解的细节,理解一下这个精简版本的代码的实现逻辑和过程。
四、总结
1.MessagePassing各个函数的执行顺序
2.参考资料
- PyG: MessagePassing
- PyG: Creating Message Passing Networks