PyG教程:MessagePassing基类

PyG教程:MessagePassing基类

  • 一、引言
  • 二、如何自定义消息传递网络
    • 1.构造函数
    • 2.propagate函数
    • 3.message函数
    • 4.aggregate函数
    • 5.update函数
  • 三、代码实战
    • 1.图数据定义
    • 2.实现GNN的消息传递过程
    • 3.完整代码
    • 4.完整代码的精简版本
  • 四、总结
    • 1.MessagePassing各个函数的执行顺序
    • 2.参考资料

一、引言

PyG框架中提供了一个消息传递基类torch_geometric.nn.MessagePassing,它实现了消息传递的自动处理,继承该类可以简单方便的构建自己的消息传播GNN。

二、如何自定义消息传递网络

要自定义GNN模型,首先需要继承MessagePassing类,然后重写如下方法:

  • message(...):构建要传递的消息;
  • aggregate(...):将从源节点传递过来的消息聚合到目标结点;
  • update(...):更新节点的消息。

上述方法并不是一定都要自定义,若MessagePassing类默认实现满足你的需求,则可以不重写。

1.构造函数

继承MessagePassing类后,在构造函数中可以通过super().__init__方法来向基类MessagePassing传递参数,来指定消息传递过程中的一些行为。MessagePassing类的初始化函数如下:
在这里插入图片描述
参数说明:

参数名参数说明
aggr消息传递中的消息聚合方式,常用的包括summeanminmaxmul等等。default: sum
flow消息传播的方向,其中source_to_targe表示从源节点到目标节点、target_to_source表示从目标节点到源节点。default:source_to_target
node_dim传播的维度,default:-2
decomposed_layers这个参数没用过,我也还不知道,后面会更新。

2.propagate函数

在具体介绍消息传递的三个相关函数之前,首先先介绍propagate函数,该函数是消息传递的启动函数,调用该函数后依次会执行messageaggregateudpate函数来完成消息的传递聚合更新。该函数的声明如下:
在这里插入图片描述
参数说明:

参数名参数说明
edge_index边索引
size这个参数目前我理解的不是很透彻,后面透彻了补一下
**kwargs构建、聚合和更新消息所需的额外数据,都可以传入propagate函数,这些参数可以在消息传递过程中的三个函数中接收。

该函数一般会传入edge_index和特征x

3.message函数

message函数是用来构建节点的消息的。传递给propagate函数的tensor可以映射到中心(target)节点邻居(source)节点上,只需要在相应变量名后加上_ior_j即可,通常称_i为中心(target)节点,称_j为邻居(source)节点。

source节点和target节点的关系:
在这里插入图片描述
message实现源码:
在这里插入图片描述

从源码的默认实现可以看出,message传递的消息就是邻居节点自身的特征向量。

示例:

def forward(self, data):
	out = self.propagate(edge_index, x=x)
	pass

def message(self, x_i, x_j, edge_index_i, edge_index_j):
	pass

该例子中利用propagate函数传递了两个参数edge_indexx,则message函数可以根据propagate函数中的两个参数构造自己的参数,上述message函数中的构造参数为:

  • x_i:中心节点(target)的特征向量组成的矩阵,注意该矩阵与图节点的矩阵x是不同的;
  • x_j:邻居节点(source)的特征向量组成的矩阵;
  • edge_index_i:中心节点的索引;
  • edge_index_j:邻居节点的索引。

注意,若flow='source_to_target',则消息将由邻居节点传向中心节点,若flow='target_to_source'则消息将从中心节点传向邻居节点,默认为第一种情况

4.aggregate函数

消息聚合函数aggregate用来聚合来自邻居的消息,常用的包括summeanmaxmin等,可以通过super().__init__()中的参数aggr来设定。该函数的第一个参数为message函数的返回值。

5.update函数

update函数用来更新节点的消息,aggregate函数的返回值作为该函数的第一个参数。

默认实现:
在这里插入图片描述

从默认实现可以看出update函数没有进行任何的操作,只是将raggregate函数的返回值返回了而已。

实际写代码的过程中,我们也不会去重写这个方法,而是,在forward函数中调用完propagate(…)函数后编写代码,代替update函数的功能。

三、代码实战

假设我们设计一个GNN模型,其中消息传递过程用公式表示如下:
X i ( k ) = X i ( k − 1 ) + ∑ j ∈ N ( i ) X j ( k − 1 ) (1) X_i^{(k)} = X_i^{(k-1)} + \sum _{j\in {\mathcal {N(i)}}} X_j^{(k-1) }\tag {1} Xi(k)=Xi(k1)+jN(i)Xj(k1)(1)

  • message生成的消息就是中心节点的邻居节点的特征向量。
  • aggregaet聚合消息的方式是sum,即把所有邻居节点的特征向量加起来。
  • update更新中心节点的方式是:将聚合得到的消息和中心节点自身的特征向量相加。

1.图数据定义

我们有如下数据:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
						   [1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())

在这里插入图片描述

2.实现GNN的消息传递过程

class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		out = self.propagate(data.edge_index, x=data.x)
		# out = out + x 
		return out

	def message(self, x_i, x_j, edge_index_i, edge_index_j):
		# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
		return x_j

	def aggregate(self, message, edge_index_i):
		# 这里只是写的样例,实际上一般不会重写这个方法,直接使用默认的就好了,只需要自己选择一下聚合的方式即可
		return super().aggregate(message, edge_index_i, dim_size=len(x))

	def update(self, aggregate, x):
		# 一般也不会重写这个方法的,update阶段可以在forward函数中调用完propagate(...)函数后编写代码。
		return x + aggregate

3.完整代码

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		out = self.propagate(data.edge_index, x=data.x)
		out = out + data.x
		return out

	def message(self, x_i, x_j, edge_index_i, edge_index_j):
		# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
		return x_j

	# def aggregate(self, message, edge_index_i):
	# 	return super().aggregate(message, edge_index_i, dim_size=len(x))

	# def update(self, aggregate, x):
	# 	return x + aggregate


if __name__ == '__main__':
	edge_index = torch.tensor([[0, 1],
							   [1, 0]], dtype=torch.long)
	x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
	data = Data(x=x, edge_index=edge_index.contiguous())

	myConv = MyConv()
	print(myConv(data))

4.完整代码的精简版本

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops


class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		edge_index, _ = add_self_loops(data.edge_index, num_nodes=len(data.x))
		out = self.propagate(edge_index, x=data.x)
		return out

if __name__ == '__main__':
	edge_index = torch.tensor([[0, 1],
							   [1, 0]], dtype=torch.long)
	x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
	data = Data(x=x, edge_index=edge_index.contiguous())

	myConv = MyConv()
	print(myConv(data))

思考:大家可以根据上面讲解的细节,理解一下这个精简版本的代码的实现逻辑和过程。

四、总结

1.MessagePassing各个函数的执行顺序

在这里插入图片描述

2.参考资料

  • PyG: MessagePassing
  • PyG: Creating Message Passing Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/925525.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Win10 系统下使用研华XNavi安装板卡驱动失败

配置:主板 AIMB-705G2,CPU i5-6500,系统 Windows10_64bit_Pro_22H2, 测试: 1、多次安装驱动。FAIL 2、尝试在其他电脑上移植板卡驱动并且使用数字签名安装。FAIL 3、系统更新到WIN10最新版本。FAIL 4、杀毒软件卸…

用三维模型的顶点法向量计算法线贴图

法线贴图的核心概念是在不增加额外多边形数目的情况下,通过模拟细节来改善光照效果。具体流程包括: 法线的计算与存储:通过法线映射将三维法线向量转化为法线贴图的 RGB 值。渲染中的使用:在片段着色器中使用法线贴图来替代原有的…

idea编译与maven编译的问题

先说下idea编译按钮的位置 编译运行时,会在idea底部出现Build面板 比较: idea编译器编译整个项目 maven编译器根据pom.xml的配置,可实现灵活编译 两套编译会遇到的问题: maven 编译成功 ,但idea编译失败&#xff…

deepin 安装 chrome 浏览器

deepin 安装 chrome 浏览器 最近好多小伙伴儿和我说 deepin 无法安装最新的谷歌浏览器 其实是因为最新的 谷歌浏览器 其中的一个依赖需要提前安装 提前安装依赖然后再安装谷歌浏览器就可以了 安装 fonts-liberationsudo apt -y install fonts-liberation安装 chrome 浏览器sudo…

《String类》

目录 一、定义与概述 二、创建字符串对象 2.1 直接赋值 2.2 使用构造函数 三、字符串的不可变性 四、常用方法 4.1 String对象的比较 4.1.1 比较是否引用同一个对象 4.1.2 boolean equals(Object anObject)方法:按照字典序比较 4.1.3 int compareTo(Strin…

OpenSSH-9.9p1 OpenSSL-3.4.0 升级步骤详细

前言 收到漏洞扫描通知 OpenSSH 安全漏洞(CVE-2023-38408) OpenSSH 安全漏洞(CVE-2023-51385) OpenSSH 安全漏洞(CVE-2023-51384) OpenSSH 安全漏洞(CVE-2023-51767) OpenSSH 安全漏洞(CVE-2023-48795) OpenSSH(OpenBSD SecureShell)是加拿大OpenBSD计划…

【Stable Diffusion】安装教程

目录 一、python 安装教程 二、windows cuda安装教程 三、Stable Diffusion下载 四、Stable Diffusion部署(重点) 一、python 安装教程 (1)第一步下载 打开python下载页面,找到python3.10.9,点击右边…

Scala身份证上的秘密以及Map的遍历

object test {def main(args: Array[String]): Unit {val id "42032220080903332x"//1.生日是?//字符串截取val birthday id.substring(10,14) //不包括终点下标println(birthday)val year id.substring(6,10) //println(year)//性别:倒数第…

springboot 异步 @Async 的日常使用及失效场景

文章目录 springboot 异步 Async 的日常使用引言一、Async 使用位置二、Async 使用三、注解 Async 失效的情况(1)调用同一个类中的异步方法(内部调用)(2)未使用 EnableAsync 注解(3)…

Laravel8.5+微信小程序实现京东商城秒杀方案

一、商品秒杀涉及的知识点 鉴权策略封装掊口访问频次限制小程序设计页面防抖接口调用订单创建事务使用超卖防御 二、订单库存系统方案(3种) 下单减库存 优点是库存和订单的强一致性,商品不会卖超,但是可能导致恶意下单&#xff…

三角网格体的光滑性问题

三角网格体的光滑性问题 在计算机图形学和计算机辅助设计中,C0连续性(也称为位置连续性)是指两个曲线或曲面在它们的公共边界上具有相同的位置。这意味着它们在边界处没有缝隙或重叠,但它们的切线方向可以不同。C0连续性是最低级…

独家|京东调整职级序列体系

原有的M、P、T、S主序列将正式合并为新的专业主序列P。 作者|文昌龙 编辑|杨舟 据「市象」独家获悉,京东已在近日在内部宣布对职级序列体系进行调整,将原有的M、P、T、S主序列正式合并为新的专业主序列P,合并后的职级体系将沿用原有专业序…

Echarts 绘制地图

一、Apache Echarts 官网地址:https://echarts.apache.org/ npm install echarts --save 二、获取地图的GeoJSON 地址:DataV.GeoAtlas地理小工具系列 左侧是地图,右侧是JSON数据路径,点击你想要生成的地图省市、地级&#xff0…

想入手养宠宠物空气净化器,养宠宠物空气净化器哪个好?

家里有了宠物后,确实多了很多欢乐,但掉落的毛发也多了不少,特别是换毛期,掉毛问题真的很让人头疼!作为养了多年宠物的铲屎官,我真心推荐大家买一台宠物空气净化器,它能大大提升家里的空气质量&a…

ASUS/华硕ROG掌机 2023款 RC71 NR2301原厂win11系统 工厂文件 带ASUS Recovery恢复

华硕工厂文件恢复系统 ,安装结束后带隐藏分区,一键恢复,以及机器所有驱动软件。 系统版本:windows11 原厂系统下载网址:http://www.bioxt.cn 需准备一个20G以上u盘进行恢复 请注意:仅支持以上型号专用…

nginx 升级http 到 http2

同步发布于我的网站 🚀 背景介绍准备工作配置过程遇到的问题及解决方法验证升级总结参考资料 背景介绍 HTTP/2 是 HTTP 协议的最新版本,相比 HTTP/1.1,它带来了多项重要的改进,包括多路复用、头部压缩和服务端推送。这些特性可…

Spark 内存管理机制

Spark 内存管理 堆内内存和堆外内存 作为一个 JVM 进程,Executor 的内存管理建立在 JVM(最小为六十四分之一,最大为四分之一)的内存管理之上,此外spark还引入了堆外内存(不在JVM中的内存),在spark中是指不…

透视投影(Perspective projection)与等距圆柱投影(Equirectangular projection)

一、透视投影 1.方法概述 Perspective projection(透视投影)是一种模拟人眼观察三维空间物体时的视觉效果的投影方法。它通过模拟观察者从一个特定视点观察三维场景的方式来创建二维图像。在透视投影中,远处的物体看起来比近处的物体小&…

uniapp开发微信小程序笔记8-uniapp使用vant框架

前言:其实用uni-app开发微信小程序的首选不应该是vant,因为vant没有专门给uni-app设置专栏,可以看到目前Vant 官方提供了 Vue 2 版本、Vue 3 版本和微信小程序版本,并由社区团队维护 React 版本和支付宝小程序版本。 但是vant的优…

Spring Web MVC其他扩展(详解下)

文章目录 Spring MVC其他扩展(下)异常处理异常处理机制声明式异常好处基于注解异常声明异常处理 拦截器拦截器概念拦截器使用拦截器作用位置图解拦截器案例拦截器工作原理源码 参数校验校验概述操作演示SpringMVC自定义参数验证ValueObject(VO) 文件上传…