DGL在异构图上的GraphConv模块

回顾同构图GraphConv模块

首先回顾一下同构图中实现GraphConv的主要思路(以GraphSAGE为例):
在初始化模块首先是获取源节点和目标节点的输入维度,同时获取输出的特征维度。根据SAGE论文提出的三种聚合操作,需要获取所使用的聚合类型,方便后面使用Pytorch中的nn模块实现。最后是特征归一化操作。
其具体的代码段为:

获取相关输入特征

        # 获取源节点和目标节点的输入特征维度
        self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)
        # 输出特征维度
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

根据聚合类型选择Pytorch对应的nn模块中的函数

        # 聚合类型:mean、pool、lstm、gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

权重初始化

构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2

forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:

  1. 检测输入图对象是否符合规范。
  2. 消息传递和聚合
  3. 聚合后,更新特征作为输出。

检测输入图对象的规范性

# 输入图对象的规范检测
with graph.local_scope():
    # 指定图类型,然后根据图类型扩展输入特征
    feat_src, feat_dst = expand_as_pair(feat, graph)

对于expand_as_pair()函数,其实现的操作是如果输入的特征不是一对的话(源节点和目标节点),就根据图Graph将特征变成一对,但要求图必须是一个block,其对应的源码为:

def expand_as_pair(input_, g=None):
    """Return a pair of same element if the input is not a pair.
    如果输入不是一对,则返回相同元素的一对。


    If the graph is a block, obtain the feature of destination nodes from the source nodes.
    如果图是块,则从源节点中获取目的节点的特征。
    Parameters
    ----------
    input_ : Tensor, dict[str, Tensor], or their pairs
        The input features
    g : DGLGraph or None
        The graph.

        If None, skip checking if the graph is a block.

    Returns
    -------
    tuple[Tensor, Tensor] or tuple[dict[str, Tensor], dict[str, Tensor]]
        The features for input and output nodes
        输入和输出节点的特性
    """
    if isinstance(input_, tuple):
        return input_
    elif g is not None and g.is_block:
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()
            }
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        return input_, input_

消息传递和聚合

聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() APIDGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

        # 消息传递和聚合
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

如果是gcn聚合方式的话还需要用到它自身的特征,但是SAGE不需要,它只需要聚合邻居的特征,这里通过一条判断语句加以区分:

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

更新特征

聚合后,更新特征作为输出——forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

        # 更新特征作为输出
        # 激活函数
        if self.activation is not None:
            rst = self.activation(rst)
        # 归一化
        if self.norm is not None:
            rst = self.norm(rst)
        return rst

异构图GraphConv模块

DGL提供了 HeteroGraphConv,用于定义异构图上GNN模块。 实现逻辑与消息传递级别的API multi_update_all() 相同,它包括:

  • 每个关系上的DGL NN模块。
  • 聚合来自不同关系上的结果。
    其对应的数学公式为:(r表示关系)

在这里插入图片描述

__ init __函数

异构图的卷积操作接受一个字典类型参数 mods。这个字典的键为关系名,值为作用在该关系上NN模块对象。参数 aggregate 则指定了如何聚合来自不同关系的结果。

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # 获取聚合函数的内部函数
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

nn.ModuleDict() 用于保存字典中的子模块。Pytorch官方也给出了对应的示例:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

forward函数

对于前向传播函数,除了需要输入图和输入张量以外,它还需要2个额外的字典参数mod_argsmod_kwargs。这2个字典与 self.mods 具有相同的键,值则为对应NN模块自定义参数
forward() 函数的输出结果也是一个字典类型的对象。其键为 nty,其值为每个目标节点类型 nty 的输出张量的列表, 表示来自不同关系的计算结果HeteroGraphConv 会对这个列表进一步聚合,并将结果返回给用户。聚合操作主要是:

if g.is_block:
    src_inputs = inputs
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
    src_inputs = dst_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
    rel_graph = g[stype, etype, dtype]
    if rel_graph.num_edges() == 0:
        continue
    if stype not in src_inputs or dtype not in dst_inputs:
        continue
    dstdata = self.mods[etype](
        rel_graph,
        (src_inputs[stype], dst_inputs[dtype]),
        *mod_args.get(etype, ()),
        **mod_kwargs.get(etype, {}))
    outputs[dtype].append(dstdata)

输入 g 可以是异构图或来自异构图的子图区块。和普通的NN模块一样,forward() 函数需要分别处理不同的输入图类型

上述代码中的for循环为处理异构图计算的主要逻辑

  • 首先我们遍历图中所有的关系(通过调用 canonical_etypes)。
  • 通过关系名,我们可以使用g[ stype, etype, dtype ]的语法将只包含该关系的子图( rel_graph )抽取出来。
  • 对于二分图,输入特征将被组织为元组 (src_inputs[stype], dst_inputs[dtype])
  • 接着调用用户预先注册在该关系上的NN模块,并将结果保存在outputs字典中。

最后,HeteroGraphConv 会调用用户注册的 self.agg_fn 函数聚合来自多个关系的结果。

rsts = {}
for nty, alist in outputs.items():
    if len(alist) != 0:
        rsts[nty] = self.agg_fn(alist, nty)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/186288.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Day40力扣打卡

打卡记录 包子凑数(裴蜀定理 DP) 根据裴蜀定理,存在 c gcd(a, b) 使不定方程ax by c满足条件,如果gcd(a, b) 1即a与b互素的情况下,就会 ax by 1,由于为1可以构造后面的无穷数字,故得到结…

项目实战详细讲解带有条件响应的 SQL 盲注、MFA绕过技术、MFA绕过技术、2FA绕过和技巧、CSRF绕过、如何寻找NFT市场中的XSS漏洞

项目实战详细讲解带有条件响应的 SQL 盲注、MFA绕过技术、MFA绕过技术、2FA绕过和技巧、CSRF绕过、如何寻找NFT市场中的XSS漏洞。 带有条件响应的 SQL 盲注 这篇文章的核心要点如下: 漏洞发现:作者在Portswigger提供的实验室中发现了一个盲SQL注入漏洞。这个漏洞存在于一个应…

【libGDX】Mesh纹理贴图

1 前言 纹理贴图的本质是将图片的纹理坐标与模型的顶点坐标建立一一映射关系。纹理坐标的 x、y 轴正方向分别朝右和朝下,如下。 2 纹理贴图 本节将使用 Mesh、ShaderProgram、Shader 实现纹理贴图,OpenGL ES 的实现见博客 → 纹理贴图。 DesktopLauncher…

Prometheus环境搭建和认识

Prometheus 环境搭建 1.prometheus 简介 Prometheus是基于go语言开发的一套开源的监控、报警和时间序列数据库的组合,是由SoundCloud公司开发的开源监控系统,Prometheus于2016年加入CNCF(Cloud Native Computing Foundation,云原生计算基金…

虾皮插件:优化Shopee商家店铺运营的利器

在如今竞争激烈的电商市场中,如何提升Shopee商家店铺的运营效率和销售业绩成为了摆在每个商家面前的一道难题。然而,幸运的是,虾皮插件-知虾的出现为商家们带来了一种全新的解决方案。本文将介绍虾皮插件的用途和优势,并详细介绍其…

【第一部也是唯一一部】3DMAX脚本语言MAXScript 中文帮助

3DMAX我们很多3D设计师和艺术家都在使用这款功能强大的三维软件,但是再强大的工具也不可能包罗万象,无所不能,所以,通常官方努力在功能和性能平衡之间的同时,也提供第三方扩展软件功能的可能—插件开发。 3DMAX插件开发…

Linux的基本指令(3)

16.cal指令 cal命令可以用来显示公历(阳历)日历。公历是现在国际通用的历法,又称格列历,通称阳历。“阳历”又名“太阳历”,系以地球绕行太阳一周为一年,为西方各国所通用,故又名“西历”。 命…

安防系统智能视频监控中出现画面异常该如何自检?

大家都知道,在当今社会,摄像头无处不在,除了常见的生活与工作场景中,在一些无法人员无法长期驻点场景,如野生动物监测、高空作业监控、高压电缆监控等场景,在这些地方安装摄像头就是为方便日常监控。但是由…

java 反射和注解1-反射详解

反射和注解本就是一家人,注解离不开反射,这里先将反射的写法,本文涉到的注解暂时可以不不用理解 1,创建一个类 public class ReflexUser {public String name;private String namePrivate;protected String nameProtected;Strin…

【自主探索】基于 rrt_exploration 的单个机器人自主探索建图

文章目录 一、rrt_exploration 介绍1、原理2、主要思想3、拟解决的问题4、优缺点 二、安装环境三、安装与运行1、安装2、运行 四、配置自己的机器人1、Robots Network2、Robots frame names in tf3、Robots node and topic names4、Setting up the navigation stack on the rob…

CSS特效018:科技动画,hover后点亮阁楼,拉伸出楼梯

CSS常用示例100专栏目录 本专栏记录的是经常使用的CSS示例与技巧,主要包含CSS布局,CSS特效,CSS花边信息三部分内容。其中CSS布局主要是列出一些常用的CSS布局信息点,CSS特效主要是一些动画示例,CSS花边是描述了一些CSS…

Linux时间命令—— 显示时间,日历等

目录 1.date显示时间 1.1 常用的标记列表: 1.2 设定时间: 2.cal显示日历 3.时间戳 1.date显示时间 date 用法:date [OPTION] ... [FORMAT] 1.1 常用的标记列表: %H : 小时 (00..23) %M : 分钟 (00..59) %S : 秒 (00..61…

nginx基础篇学习

一、nginx编译安装 1、前往nginx官网获取安装包 下载安装包 2、解压 3、安装 进入安装包 安装准备:nginx的rewrite module重写模块依赖于pcre、pcre-devel、zlib和zlib-devel库,要先安装这些库 安装: 编译: 启动&#xff…

ErphpdownV16.21插件 安装教程和插件下载

ErphpdownV16.21插件下载_新版本 上传插件并解压 登入后台插件管理启动ErphpdownV16.21插件即可 启动后设置即可使用此版本为学习版插件 功能介绍: Erphpdown会员推广下载专业版 经过完美测试运行于wordpress 3.x-6.x版本。后续会增加更多实用的功能。已针对此插件…

医学图像分割:U_Net 论文阅读

“U-Net: Convolutional Networks for Biomedical Image Segmentation” 是一篇由Olaf Ronneberger, Philipp Fischer, 和 Thomas Brox发表的论文,于2015年在MICCAI的医学图像计算和计算机辅助干预会议上提出。这篇论文介绍了一种新型的卷积神经网络架构——U-Net&a…

3、Qt使用windeploy工具打包可执行文件

新建一个文件夹,把要打包的可执行文件exe拷贝过来 点击输入框,复制一下文件夹路径 点击电脑左下角,找到Qt文件夹, 点击打开 “Qt 5.12.0 for Desktop” (我安装的是Qt 5.12.0版本) 输入“cd bin”&#xff…

【云原生 Prometheus篇】Prometheus的动态服务发现机制

自动发现 一、Prometheus服务发现 理论部分1.1 Prometheus数据采集配置1.2 基于文件的服务发现1.3 基于consul的服务发现1.4 基于 Kubernetes API 的服务发现1.4.1 概念1.4.2 部分配置参数1.4.3 配置模板 二、实例一:部署基于文件的服务发现2.1 创建用于服务发现的文…

保姆级连接FusionInsight MRS kerberos Hive

数新网络,让每个人享受数据的价值https://xie.infoq.cn/link?targethttps%3A%2F%2Fwww.datacyber.com%2F 概述 本文将介绍在华为云 FusionInsight MRS(Managed Relational Service)的Kerberos环境中,如何使用Java和DBeaver实现远…

命名空间、字符串、布尔类型、nullptr、类型推导

面向过程语言:C ——> 重视求解过程 面向对象语言:C ——> 重视求解的方法 面向对象的三大特征:封装、继承和多态 C 和 C 在语法上的区别 1、命名空间(用于解决命名冲突问题) 2、函数重载和运算符重载&#xf…

vatee万腾的科技征途:Vatee独特探索的数字化力量

在数字化时代的浪潮中,Vatee万腾以其独特的科技征途成为引领者。公司在数字化领域的探索之路不仅是技术的创新,更是一种对未知的勇敢涉足,是对新时代的深刻洞察和积极实践。 Vatee万腾通过独特的探索,展示了在数字化征途上的创新力…