机器学习周记(第三十一周:文献阅读-GGNN)2024.3.18~2024.3.24

目录

摘要

ABSTRACT

1 论文信息

1.1 论文标题

1.2 论文模型

1.2.1 数据处理

1.2.2 门控图神经网络

1.2.3 掩码操作

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

3 相关代码


摘要

  本周阅读了一篇利用图神经网络(GNN)与门控循环单元(GRU)进行配水网络(WDN)水质预测的论文。论文模型(GGNN)实现了扩展图邻接矩阵在有向图中加入双向信息流,从而增强了模型的双向学习能力。同时模型还利用掩码操作模拟了站点故障导致数据缺失的情况,根据正常站点数据也能对故障站点进行预测,并且还能解决模型过拟合或者欠拟合的问题。

ABSTRACT

  This week, We read a paper on water quality prediction in water distribution networks (WDNs) using Graph Neural Networks (GNN) and Gated Recurrent Units (GRU). The paper introduces a model called GGNN, which extends the graph adjacency matrix to incorporate bidirectional information flow in directed graphs, thus enhancing the model's bidirectional learning capability. Additionally, the model utilizes masking operations to simulate data missing due to station failures, enabling the prediction of faulty stations based on normal station data. Moreover, it addresses the issues of model overfitting or underfitting.

1 论文信息

1.1 论文标题

Real-time water quality prediction in water distribution networks using graph neural networks with sparse monitoring data

1.2 论文模型

  论文模型(GGNN)旨在利用门控图神经网络(GGNN)处理网络拓扑结构、流向以及水质监测站的历史氯浓度测量数据来预测配水网络(WDN)中的实时水质。该模型由两个主要部分组成:(1)对供水网络信息进行数据处理,输入到图神经网络中;(2)利用收集到的数据构建模型。

Fig.1 基于GGNN的实时水质预测方法示意图

1.2.1 数据处理

  GGNN模型需要两类数据:传感器监测站的WDN拓扑结构和历史水质监测数据。假设一个WDNn个节点和m条管道组成,配备N_{s}个传感器站监测水质。网络拓扑由图G=(V,E)表示,其中V表示由水库、储罐和连接点组成的节点集,E表示由管道、阀门和泵组成的边集。网络的流向信息和空间拓扑细节通常可以从EPANET等水力模型中获得。利用这些数据构建有向图的邻接矩阵A \in \mathbb{R}^{n\times n},其中每个元素A_{ij}表示水是否从节点i流向节点j (A_{ij}=1)或不流向 (A_{ij}=0)。论文仅在边的权重相等时考虑水流方向。更进一步还可以同时考虑流量的动态变化和加权边。

  通过在WDN中实现的监控和数据采集(SCADA)系统,可以获得各监测站的历史水质数据。该数据采集过程包括在指定的时间窗口内采集水质测量数据,记为T_{c},也表示采集历史数据的周期时间。然后将采集到的数据作为数据集中被监测节点的节点属性,对于未被监测节点,将空值替换为0,得到节点属性X\in \mathbb{R}^{n \times N_{c}}N_{c}表示数据采集周期T_{c}内获得的水质测量次数,对应于指定时间窗口内的时间步数。它是预测下一时刻水质所需数据大小的指标。

1.2.2 门控图神经网络

  为了解决WDN的非欧氏图域带来的挑战,将GGNN架构用于水质预测。GGNN是一种图神经网络,用于处理复杂的图结构数据,如WDN拓扑。它扩展了通常定义在欧氏域上的传统神经网络,使其能够直接处理非欧氏图数据。GGNN模型根据相邻节点和边之间传递的消息为每个节点v\in V计算状态向量h_{v}。状态向量h_{v}表示节点学习到的特征表示,编码了关于图的局部和全局信息。它可以被认为是节点的隐藏状态,从其邻域和整个图中捕获相关信息。最终,状态向量可用于水质预测。GGNN的整体工作流程如Fig.2所示。

Fig.2 GGNN总体架构示意图

  首先,通过扩展邻接矩阵A \in \mathbb{R}^{n \times n},在有向图中加入双向信息流来作为输入。主要通过将邻接矩阵A与其转置连接起来,形成一个扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]来实现的,这样可以同时考虑输入边和输出边。\widehat{A} \in \mathbb{R}^{n \times 2n}捕获了节点之间的复杂关系和消息传播方向,从而增强了GGNN的双向学习能力。

  然后,通过标准线性组合修正线性单元(rectified linear unit, ReLU)激活函数将节点v的节点属性x_{v}从原始空间\mathbb{R}^{N_{c}}映射到新空间\mathbb{R}^{M}的原始隐藏状态h_{v}^{(0)}。这种映射过程有效地扩大了节点属性的大小,使GGNN能够捕获节点属性之间潜在的重要非线性关系。隐藏状态的大小用M表示,是一个决定模型容量的超参数。然而,至关重要的是要与M取得平衡,以防止过拟合并控制训练期间的计算复杂性。

  GGNN以扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]和映射的节点属性h^{(0)}为输入,在固定的k步上递归计算节点状态以产生最终的状态矩阵h^{(K)}\in \mathbb{R}^{n \times M}。在聚合阶段,利用扩展邻接矩阵\widehat{A}计算聚合向量a_{v}a_{v}表示节点v和相邻节点状态的聚合,聚合向量的计算公式如下:

a_{v}^{(k)}=\widehat{A}^{T}_{v:}\left [ h_{1}^{(k-1)^{T}},...,h_{n}^{(k-1)^{T}} \right ]^{T}+b                                                                              (1)

其中,上标k表示时间步长,\widehat{A}_{v:}\in \mathbb{R}^{n \times 2}是块\widehat{A}中对应节点v的两列,b是偏移向量。在聚合阶段之后,传播阶段采用门控循环单元(gated recurrent units, GRU)机制更新节点状态。GRU传播方程描述如下:

r_{v}^{(k)}=\sigma (W_{r} \cdot a_{v}^{(k)}+U_{r}\cdot h_{v}^{(k-1)})                                                                                       (2)

z_{v}^{(k)}=\sigma (W_{z} \cdot a_{v}^{(k)}+U_{z}\cdot h_{v}^{(k-1)})                                                                                       (3)

\widetilde{h}_{v}^{(k)}=\tanh (W \cdot a_{v}^{(k)}+U\cdot (r_{v}^{(k)}\bigodot h_{v}^{(k-1)}))                                                                   (4)

h_{v}^{(k)}=(1-z_{v}^{(k)})\bigodot h_{v}^{(k-1)}+z_{v}^{(k)}\bigodot \widetilde{h}_{v}^{(k)}                                                                         (5)

其中rz是重置门和更新门;W_{r},W_{z},WU_{r},U_{z},U是每层的权重和偏差;\sigma (\cdot)sigmoid激活函数;\bigodot是元素点积运算。

  GGNN中的聚合和传播步骤允许模型迭代更新和细化节点状态,合并来自节点先前的特征及其邻近节点的特征信息。这个迭代过程捕获了图结构内的动态和交互规则,使GGNN能够学习和表示节点之间的复杂关系和依赖关系。传播步长K(也即GNN层数)决定了GGNN中信息传播的深度。当K=1时,每个节点只能从其近邻节点学习。随着K的增加,GGNN可以从距离K步的节点捕获信息,包括它们的间接连接。K的选择影响模型的学习能力和效率。较高的K值会导致训练较慢以及增加内存需求,而较低的K值会限制每个节点可以学习的依赖关系的数量。因此,K的选择应该在模型性能和计算效率之间取得平衡。

  在使用GRU模块更新节点状态后,使用线性层将更新后的状态h^{(K)}转换为表示每个节点预测状态的\widehat{Y}\in \mathbb{R}^{n}。在本研究中,节点属性为历史水质浓度数据,其预测状态表示模型对每个节点下一时间步水质浓度的预测。这种转换允许模型根据其更新的表示和从邻近节点传播的信息在每个节点生成对水质的预测。

1.2.3 掩码操作

  虽然之前的研究主要采用掩码操作(Maskng Operation来模拟传感器故障,特别是在不利条件下测试模型的鲁棒性,但本文方法在训练阶段利用掩码操作来增强模型对未监测节点的预测能力。在训练过程中,结合掩码操作对解决两个重大挑战至关重要。首先,现有研究通常假设传感器节点的输入,并根据模拟的网络中所有节点的值来计算损失,这在现实世界中是不切实际的,因为获取非传感器节点的测量数据很困难。论文使用模拟模型的合成数据,这样数据虽然完整,但作者并没有使用所有网络节点的所有数据进行训练。相反,只使用了一小部分节点数据。其次,如果模型仅基于传感器节点的输入进行训练,并基于这些节点计算损失,可能会导致过拟合,阻碍模型预测未监测节点的水质的能力。为了克服这些挑战,在训练过程中引入了掩码操作。随机选择指定比例(例如20%)的传感器节点,并通过在每个训练批次中将其输入替换为零进行掩盖。这个屏蔽操作有两个目的。首先,在训练过程中模拟非传感器节点数据的不可用性,使模型能够在观测到的传感器数据之外进行泛化,并学习预测无监测节点的值;其次,它作为正则化技术,防止模型仅依赖有限的传感器输入。通过鼓励模型捕捉传感器节点和非监测节点之间的关系,提高模型的泛化能力,降低过拟合的可能性。需要研究掩码节点的比例,因为它可以平衡模型性能和过拟合。更高的比率会减少可用的信息,增加欠拟合的风险。较低的速率可以提供更多的信息,但可能会导致过拟合。因此,掩码率也是一个十分重要的超参数。

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

  需要注意的是,常规任务情境下不会需要节点的信息传播太远。经过6~7个hops,基本上就可以使节点的信息传播到整个网络,这也使得聚合不那么有意义。实验结果也表明,2~3层的网络应该是比较好的,当GCN达到7层时,效果已经变得较差,但是通过在隐藏层间加上残差连接(Residual Connections)可以使效果变好。

3 相关代码

GCN模型定义与图结构数据定义:

import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx


class GCN(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = nn.Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)  # 输入特征与邻接矩阵
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()

        out = self.classifier(h)

        return out, h


def visualize_graph(G, color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2")
    plt.show()


def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()


dataset = KarateClub()
print(f'Dataset: {dataset}')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]
# x:[34, 34](M*F,M:样本数,F:特征维度)
# edge_index:[2, 156](两个数组,第一个为source,第二个为target,156条边)
# y:[34](标签)
# train_mask:[34](指定节点是否有标签,通过此数组可以选择哪些节点计算损失,元素类型为bool)
print(data)
print(dataset.edge_index)

G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

数据集KarateClub的图结构:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/485314.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

银行监管报送系统介绍(六):客户风险数据报送系统

【概念定义】 银监会决定从2013年起实行新版客户风险统计制度,对各政策性银行、国有商业银行、股份制商业银行进行客户信息汇总统计。 客户风险统计信息,是指新版客户风险统计报送信 息。客户风险统计报送信息包括但不限于对公及同业客户授信和 表内外业…

ClickHouse--11--物化视图

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1.物化视图什么是物化视图? 1.1 普通视图1.2 物化视图1.3 优缺点1.4 基本语法1.5 在生产环境中创建物化视图1.6 AggregatingMergeTree 表引擎3.1 概念3.2 Aggregat…

【Linux】Linux工具学习之git

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 文章目录 前言一、账号注册1.1 GitHub与Gitee 二、构建仓库三、安装git 四、配置git五、克…

树状数组原理和代码

树状数组 求下标的对应 求i管着的下标的范围 方法:拆掉最右侧的1然后1 到你自己 query sum 1-i的和 拆掉最右侧的1 再把下一个数值吸收到sum 重复这个过程直到全变0为止 add 方法:加上最右侧的1 到上限为止 lowbit方法 单点增加范围查询模板 #inc…

Redis持久化【RDB,bgsave的写时复制机制】【AOF,aof重写机制】【Redis混合持久化,以及对应改变aof重写规则】【Redis数据备份策略】

Redis持久化 RDB快照(snapshot)bgsave的写时复制(COW)机制 AOF(append-only file)AOF重写 Redis 4.0 混合持久化开启持久化后,AOF重写规则发生了变化 Redis数据备份策略: 转自 图灵课堂 RDB快照&#xff0…

第390场 LeetCode 周赛题解

A 每个字符最多出现两次的最长子字符串 滑动窗口&#xff1a;枚举窗口的左边界&#xff0c;尽可能右移窗口的右边界。 (当然也可以暴力枚举) class Solution { public:int maximumLengthSubstring(string s) {vector<int> cnt(26);int res 0;for (int l 0, r -1, n s…

python第三方库的安装,卸载和更新,以及在cmd下pip install安装的包在pycharm不可用问题的解决

目录 第三方库pip安装&#xff0c;卸载更新 1.安装&#xff1a; 2.卸载 3.更新 一、第三方库pip安装&#xff0c;卸载更新 1.安装 pip install 模块名 加镜像下载&#xff1a;pip install -i 镜像网址模块名 常用的是加清华镜像&#xff0c;如 pip install -i https://pyp…

jmeter链路压测

比如登录后返回token&#xff0c;业务打印上传的操作需要用到token 线程组中添加登录请求&#xff0c;并执行 1、添加登录并执行&#xff0c;查看结果 2、结果树中下拉选择正则表达式&#xff0c;将token参数和值复制粘贴到下方&#xff0c;将token值改为(.*?)&#xff0…

Pinctrl子系统_05_Pincontroller构造过程情景分析

上一节我们了解了Pinctrl子系统主要的数据结构&#xff0c;要想更好的掌握Pinctrl子系统&#xff0c;还需要知道他的构造过程。 本节我们就来分析一下Pinctrl子系统的构造过程。 以内核面向对象的思想&#xff0c;设备树可以分为两部分&#xff0c;左边是Pinctrl子系统节点&a…

毕业论文降重(gpt+完美降重指令),sci论文降重gpt指令——超级好用,重复率低于4%

1. 降重方法&#xff1a;gpt降重指令 2. gpt网站 https://yiyan.baidu.com/ https://chat.openai.com/ 3. 降重指令——非常好用&#xff01;&#xff01;sci论文&#xff0c;本硕大论文都可使用&#xff01; 请帮我把下面句子重新组织&#xff0c;通过调整句子逻辑&#xff0…

牛客NC218 检测循环依赖【中等 图 Java,Go,PHP】

题目 题目链接&#xff1a; https://www.nowcoder.com/practice/8dc02ad98553432a90affc3a0484910b 思路 图的基本知识要理解&#xff0c;一般用Map来表示 图解决拓扑排序&#xff0c;依赖之类的问题 感觉课程数在这道题里面可以不用&#xff0c;因为没有规定所有课程都得有先…

解决方案Please use Oracle(R) Java(TM) 11, OpenJDK(TM) 11 to run Neo4j.

文章目录 一、现象二、解决方案 一、现象 当安装好JDK跟neo4j&#xff0c;用neo4j.bat console来启动neo4却报错&#xff1a; 部分报错信息&#xff1a; Starting Neo4j. WARNING! You are using an unsupported Java runtime. Please use Oracle Java™ 11, OpenJDK™ 11 t…

Jenkins中使用Generic Webhook Trigger插件实现持续集成

项目环境 宝塔Linux面板DockerJenkinsgitee 目的 实现每次push推送dev分支到gitee上&#xff0c;Jenkins自动构建项目&#xff1b;push其它分支时&#xff0c;不运行。 实现方法 1.在Jenkins上安装Generic Webhook Trigger插件 在“系统设置–插件管理–可选插件”界面搜…

jmeter中参数加密

加密接口常用的方式有&#xff1a; MD5&#xff0c;SHA&#xff0c;HmacSHA RSA AES&#xff0c;DES&#xff0c;Base64 压测中有些参数需要进行加密&#xff0c;加密方式已接口文档为主。 MD5加密 比如MD5加密的接口文档&#xff1a; 请求URL&#xff1a;http://101.34.221…

STM32 | Systick定时器(第四天源码解析)

STM32 | Systick定时器(第四天)STM32 | STM32F407ZE中断、按键、灯(续第三天)1、参考delay_us代码,完成delay_ms的程序 定时器频率换算单位:1GHZ=1000MHZ=1000 000KHZ = 1000 000 000HZ 定时器定时时间:计数个数/f(频率) 或者 (1/f(频率))*计数的个数 500/1MHZ = 500/1…

Spring相关框架八股

单例bean是线程安全的吗&#xff1f; AOP 事务失效 Bean生命周期 Bean循环依赖解决 MVC执行流程 自动装配原理 Spring常见注解 SpringMVC注解 SpringBoot注解 MyBatis执行流程 MyBatis延迟加载 MyBatis缓存 SpringCloud五大组件 注册中心Nacos、Eureka 负载均衡Ribbon 服务雪崩…

Godot.NET C# 工程化开发(1):通用Nuget 导入+ 模板文件导出,包含随机数生成,日志管理,数据库连接等功能

文章目录 前言Github项目地址&#xff0c;包含模板文件后期思考补充项目设置编写失误环境visual studio 配置详细的配置看我这篇文章 Nuget 推荐NewtonSoft 成功Bogus 成功Github文档地址随机生成构造器生成构造器接口(推荐) 文件夹设置Nlog 成功&#xff01;Nlog.configNlogHe…

2025汤家凤考研数学视频,基础网课百度网盘课程+PDF讲义资料

2025汤家凤大神及数学全程 docs.qq.com/doc/DTmtOa0Fzc0V3WElI 复制粘贴到浏览器&#xff0c;可以见所有的Ke 第一轮 夯实基础 1.阅读大纲考查要求&#xff0c;明确每章的学习目标&#xff1b; 2.按节学习数学理论基础知识&#xff0c;吃透书中例题&#xff1b; 3.学习每章…

红外遥控器的使用和详细解释

infrared.c #include "infrared.h"/* 红外 --- PA8*/void Infrared_Init(void) {GPIO_InitTypeDef GPIO_InitStruct; EXTI_InitTypeDef EXTI_InitStruct;NVIC_InitTypeDef NVIC_InitStruct;//使能SYSCFG时钟RCC_APB2PeriphClockCmd(RCC_APB2Periph_SYSCFG, E…

【数据结构】五分钟自测主干知识(十)

上一节&#xff0c;我们讲述了二叉树的概念&#xff0c;二叉树又有什么基本操作呢&#xff1f;今天我们来讲述二叉树的应用~ 话不多说&#xff0c;书继上回 5.3二叉树的遍历及应用 二叉树由三个基本部分组成&#xff1a;根结点&#xff08;D&#xff09;&#xff0c;左子树&a…