DGL中NN模块的构造函数

在这里插入图片描述
上图引用自:dgl用户文档第三章(nn模块编写)

"""

构造函数完成以下几个任务:
1、设置选项。
2、注册可学习的参数或者子模块。
3、初始化参数。

"""
import torch.nn as nn
from dgl.utils import expand_as_pair
import dgl.nn
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
"""
在构造函数中,用户首先需要设置数据的维度。
对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。
对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型
包括 mean、 sum、 max 和 min。一些模块可能会使用更加复杂的聚合函数,
比如 lstm。
"""
"""注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear、 nn.LSTM 等。"""

class SAGE(nn.Module):
    def __init__(self, in_feats, out_feats, aggregator_type,
                 bias=True, norm=None, activation=None):
        super(SAGE, self).__init__()
        # 获取源节点和目标节点的输入特征维度
        self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)
        # 输出特征维度
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        # 聚合类型:mean、pool、lstm、gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
    #  构造函数的最后调用了 reset_parameters() 进行权重初始化。
    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2
    def forward(self, graph, feat):    #SAGEConv示例中的 forward() 函数
        # 输入图对象的规范检测
        with graph.local_scope():
            # 指定图类型,然后根据图类型扩展输入特征
            feat_src, feat_dst = expand_as_pair(feat, graph)
        # 消息传递和聚合
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # 更新特征作为输出
        # 激活函数
        if self.activation is not None:
            rst = self.activation(rst)
        # 归一化
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
"""
在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,
DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:
1、检测输入图对象是否符合规范。
2、消息传递和聚合。
3、聚合后,更新特征作为输出。

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 
比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入
度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是
,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输
出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图和子图块。



聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息
传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的
消息传递代码 里所介绍的性能优化。

聚合后,更新特征作为输出
forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
"""


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/185671.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实例分割12篇顶会论文及代码合集,含2023最新

同学们,你们觉得视觉经典四个任务中哪个最难?我个人觉得是实例分割。 因为它既具备语义分割的特点,需要做到像素层面上的分类,也具备目标检测的一部分特点,即需要定位出不同实例,即使它们是同一种类。 但…

Docker 的基本概念和优势,以及在应用程序开发中的实际应用

Docker是一种轻量级的虚拟化技术,可以将应用程序及其依赖项打包在一个容器中,并在不同的计算机上运行。以下是Docker的基本概念和优势: 基本概念: 镜像(Image):一个只读的文件,包含…

Unity中Shader的Standard材质解析(二)

文章目录 前言一、我们对 Standard 的 PBR 的 GI 进行解析1、我们先创建一个PBR的.cginc文件,用于整理用到的函数2、然后在Standard的Shader中引用该cginc文件 二、依次整理函数到该cginc文件中我们来看一下PBR中GI的镜面反射做了些什么 二、最终代码.cginc代码&…

98、Text2Room: Extracting Textured 3D Meshes from 2D Text-to-Image Models

简介 github 利用预训练的2D文本到图像模型来合成来自不同姿势的一系列图像。为了将这些输出提升为一致的3D场景表示,将单目深度估计与文本条件下的绘画模型结合起来,提出了一个连续的对齐策略,迭代地融合场景帧与现有的几何形状&#xff0…

【C/PTA —— 11.函数2(课外实践)】

C/PTA —— 11.函数2(课外实践) 一.函数题6-1 计算A[n]1/(1 A[n-1])6-2 递归实现顺序输出整数6-3 自然数的位数(递归版)6-4 分治法求解金块问题6-5 汉诺塔6-6 重复显示字符(递归版)6-7 显示平行四边形(右)(递归版) 二.编程题7-2 N阶楼梯上楼问题 一.函数…

React项目中发生空白但不报错的原因分析和解决?

文章目录 前言组件渲染问题状态管理问题异步操作问题代码错误但未抛出异常如果我们使用的是chorme浏览器的话,可以下载一个开发者工具,例如下图:代码审查使用调试工具日志和输出检查外部依赖异步操作终极大法,不到万不得已不可以使…

Twincat使用:EtherCAT通信扫描硬件设备链接PLC变量

EtherCAT通信采用主从架构,其中一个主站设备负责整个EtherCAT网络的管理和控制,而从站设备则负责在数据环网上传递数据。 主站设备可以是计算机、工控机、PLC等, 而从站设备可以是传感器、执行器、驱动器等。 EL3102:MDP5001_300_CF8D1684;…

【LM358AD运放方波振荡器可控输出幅值】2022-2-25

缘由仿真如何缩小方波振荡电路方波幅值?-有问必答-CSDN问答

C#,《小白学程序》第六课:队列(Queue)其二,队列的应用,编写《实时叫号系统》

医院里面常见的《叫号系统》怎么实现的&#xff1f; 1 文本格式 /// <summary> /// 下面定义一个新的队列&#xff0c;用于演示《实时叫号系统》 /// </summary> Queue<Classmate> q2 new Queue<Classmate>(); /// <summary> /// 《小白学程序…

笔记:内网渗透流程之信息收集

信息收集 首先&#xff0c;收集目标内网的信息&#xff0c;包括子网结构、域名信息、IP地址范围、开放的端口和服务等。这包括通过主动扫描和渗透测试工具收集信息&#xff0c;以及利用公开的信息源进行信息搜集。 本机信息收集 查看系统配置信息 查看系统详细信息&#xf…

单链表的反转?太细了哥们!细到离谱!

单链表的反转&#xff08;面试常出&#xff09;&#xff1a; ​ 单链表的反转&#xff0c;可以通过很多种方法实现。包括迭代法&#xff0c;递归法&#xff0c; 迭代法&#xff1a; 定义三个指针&#xff1a;prev、current和next&#xff0c;它们分别表示前一个节点、当前节点…

【黑马甄选离线数仓day04_维度域开发】

1. 维度主题表数据导出 1.1 PostgreSQL介绍 PostgreSQL 是一个功能强大的开源对象关系数据库系统&#xff0c;它使用和扩展了 SQL 语言&#xff0c;并结合了许多安全存储和扩展最复杂数据工作负载的功能。 官方网址&#xff1a;PostgreSQL: The worlds most advanced open s…

中国企业500强的排名也在不断变化。面对不确定性的挑战,企业如何应对?

随着全球经济的不断发展和变化&#xff0c;中国企业500强的排名也在不断变化。面对不确定性的挑战&#xff0c;企业如何应对&#xff1f;在本文中&#xff0c;挖数据平台将提供数据源探讨中国企业500强在应对不确定性方面的突围与变革。 一、数据挖掘与分析 从2006年到2023年&…

【科普知识】什么是步进电机?

德国百格拉公司于1973年发明了五相混合式步进电机及其驱动器&#xff0c;1993年又推出了性能更加优越的三相混合式步进电机。我国在80年代以前&#xff0c;一直是反应式步进电机占统治地位&#xff0c;混合式步进电机是80年代后期才开始发展。 步进电机是一种用电脉冲信号进行…

CAN实验

CAN 寄存器 HAL库函数 代码 #include "./BSP/CAN/can.h"CAN_HandleTypeDef g_can1_handle; CAN_TxHeaderTypeDef g_can1_txheader; CAN_RxHeaderTypeDef g_can1_rxheader;/* STM32F103 TS1 8 TS2 7 BRP 3 波特率&#xff1a;36000 / [(9 8 1) * 4] 500Kbps …

数据结构-树-二叉树-堆的实现

1.树概念及结构 树是一种 非线性 的数据结构&#xff0c;它是由 n &#xff08; n>0 &#xff09;个有限结点组成一个具有层次关系的集合。 把它叫做树是因 为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的 。 有一个特殊的结点&#xff…

[chroot+seccomp逃逸] THUCTF2019 之 固若金汤

题目分析 附件为一个源码, 其中注释我都写好了, 主要就讲关键的知识点. #define _GNU_SOURCE#include <stdio.h> #include <stdlib.h> #include <fcntl.h> #include <string.h> #include <errno.h> #include <sched.h> #include <uni…

【蓝桥杯省赛真题47】Scratch小猫踩球 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

目录 scratch小猫踩球 一、题目要求 编程实现 二、案例分析 1、角色分析

flutter之graphic图表自定义tooltip

renderer graphic中tooltip的TooltipGuide类提供了renderer方法&#xff0c;接收三个参数Size类型&#xff0c;Offset类型&#xff0c;Map<int, Tuple>类型。可查到的文档是真的少&#xff0c;所以只能在源码中扒拉例子&#xff0c;做符合需求的修改。 官方github示例 …

【TypeScript】常见数据结构与算法(二):链表

文章目录 链表结构&#xff08;LinkedList&#xff09;链表以及数组的缺点数组链表的优势 什么是链表?封装链表相关方法源码链表常见面试题237-删除链表中的节点206 - 反转链表 数组和链表的复杂度对比 链表结构&#xff08;LinkedList&#xff09; 链表以及数组的缺点 链表…