AGCRN论文解读

一、创新点

        传统GCN只能基于静态预定义图建模全局共享模式,而AGCRN通过两种GCN的增强模块(NAPL、DAGG)实现了更精细的节点特性学习和图结构生成。

1 节点自适应参数学习模块(NAPL)

        传统GCN通过共享参数(权重和偏置)在所有节点上执行图卷积操作,这样只能考虑到邻近节点间的空间相关性。然而在交通序列中,还因时间序列的动态特性和节点的多种影响因素(例如兴趣点、天气)而表现出多样性模式。

        因此,传统的图卷积公式必须作出如下变换。将权重 Θ 分解为两个更小的权重矩阵。其中EG代表节点嵌入矩阵,它为每个节点提供了一个区分性的表示,使得模型能够识别节点间的差异。WG是一个共享的权重池,包含了所有节点共享的候选特征模板,由节点嵌入EG确定哪些模板最适合每个节点。

         这么讲非常抽象晦涩!!!作者据说受到了矩阵分解的启发,也许可以从这方面入手理解。Θ 是三维矩阵(节点N,输入特征C,输出特征F),这时一个节点对应死了一个权重值。如果把第0维和第1维中间拦腰斩断,插入一个索引维度d,那么就分解出了EG(节点N,索引d)和WG(索引d,输入特征C,输出特征F)。因此,每个节点 Ni 就能通过索引 di 在权重池子里找到适合自己的那个权重。这种生成方式允许节点共享一个全局权重池,同时保留节点特定性。代码如下:

self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)

self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))

weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  #N, cheb_k, dim_in, dim_out
bias = torch.matmul(node_embeddings, self.bias_pool)

        那么到底改进了什么?

  1. 在GCN中,一个权重矩阵 Θ 就能处理所有节点的特征,不管每个节点今天刮风下雨还是有车祸。而在NAPL中,每个节点都有选择自己权重矩阵的权力。这样模型针对不同的节点才会预测得更加精准。
  2. 如果每个节点都有多个不同的权重矩阵,而不是共享一个权重池。那样会导致模型过于复杂,徒增计算资源。

        因此NAPL是一个很折中的办法! 

2 数据自适应图生成(DAGG) 

        在传统的基于GCN的交通预测模型中需要预定义的邻接矩阵 A 以进行图卷积操作。而预定义的图无法包含完整的空间依赖信息,也无法直接与预测任务相关联。

        有两点原因:首先,两条物理上相邻的道路可能由于单行道或隔离带而没有直接的交通交互。某条道路因为施工或重大活动会突然变得与其他道路关联更强。其次,构建预定义图通常需要依赖领域专家的知识,而且需要大量人力来设计和验证图的质量。

        因此,作者提出了数据自适应图生成(DAGG)模块。这个模块首先基于输入数据(如交通流量或速度)为所有节点随机初始化一个可学习的节点嵌入字典 EA ,是一个二维矩阵(节点数量N,嵌入维度d)。通过 EA 和其转置矩阵相乘即可得到节点之间的内积相似度矩阵,即新的邻接矩阵。这是为什么?

        首先,内积的本质就是用来衡量两个向量的相似性。两个向量方向相同,且相似性越大,内积就会越大。当两个向量垂直时,内积为0;方向相反,内积为负。而EA 的第 i 行和其转置矩阵的第 i 列是同一向量,那么 EA 的 i 行与其转置矩阵的 i+1 列相乘得到的结果就代表,第 i 个节点和 第 i+1 个节点的相似度。因此,逐行逐列相乘后就可以得到一个 N 行 N 列的邻接矩阵。

        下面是作者的实现代码,node_embeddings 就是上图的 EA 矩阵。

# AGCRN.py 中的代码
self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)

# AGCN.py 中的代码
supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)

        不过,自适应图生成的生成并不是作者的首创,早在2019年的 Graph WaveNet 论文(Graph WaveNet for Deep Spatial-Temporal Graph Modeling)中就提出过类似的方法,叫做自适应邻接矩阵。本篇论文是 2020 年的,不管是查看论文还是代码,可以发现就是说的一个东西。。。

二、自适应图卷积递归网络(AGCRN)

        作者讲上述两个模块与 GRU 结合起来可以得到:

        在前文中,可以得到图卷积 Z,再将 Z 提取为更新门 z重置门 r,那么我们就可以计算出候选状态以及最终输出 h。代码如下:

state = state.to(x.device)  # 确保 state 与输入 x 在相同设备(如 GPU)上
input_and_state = torch.cat((x, state), dim=-1)  # 将当前输入 x 和隐藏状态 state 沿特征维度拼接
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
candidate = torch.cat((x, z*state), dim=-1)
hc = torch.tanh(self.update(candidate, node_embeddings))
h = r*state + (1-r)*hc
return h

三、实验

        作者使用了 PeMSD4PeMSD8 数据集,其中训练集、验证集和测试集的划分比例为 6:2:2。将 12 步的历史数据作为输入,接下来的 12 步数据作为输出。

3.1 消融实验

        同时,作者为了更好地评估 NAPL 和 DAGG 的性能,还进行了全面的消融实验。分别对比了如下几组实验:

  1. GCGRU:传统的 GCN ;
  2. NAPL-GCGRU:将传统的 GCN 替换为我们提出的 NAPL-GCN;
  3. DAGG-GCGRU:将预定义图替换为 DAGG 模块;
  4. AGCRN-I:不统一节点嵌入,而是在不同的 NAPL-GCN 层和 DAGG 中使用独立的节点嵌入矩阵。
  5. AGCRN:作者的模型。

        其中作者提到了节点嵌入,是指在模型中使用统一的节点嵌入向量来表示不同模块或层之间的节点特性。这意味着,在 AGCRN 中,节点的嵌入向量在整个模型中是共享的,而不是为每个模块分别学习不同的嵌入。 查看代码,也可以发现整个项目只有一个 node_embeddings

        下图为作者的测试结果。可以发现NAPL在短期预测时表现不佳。这是因为短期预测(例如 5 分钟或 15 分钟的预测)通常能从已有的历史数据中获得足够的信息,所以不太依赖于节点特定的嵌入(即 NAPL 学到的每个节点的特性)。换句话说,短期预测可以依赖于前几个时间步的交通数据,而不需要太多的额外个性化特征。然而长期预测(例如 30 分钟或 60 分钟的预测)则面临更大的不确定性,因为历史数据可能无法充分预测未来的变化。这时,NAPL 所学习的节点特定嵌入变得尤为重要,因为它能帮助模型更好地捕捉每个节点的独特模式(例如道路的特殊流量模式、影响因素等),从而提高对长期预测的准确性。

3.2 图生成过程分析

        在上述的图卷积的公式中,还有一个单位矩阵 I。它的作用是保持每个节点的自信息,即让节点的当前特征在计算中不被忽视。这就类似于一个节点在做预测时不考虑自己的历史数据,只依赖邻居节点的状态,可能导致信息丢失或不准确。因此,作者对自适应图生成(DAGG)做了两个变种,即不使用单位矩阵的 DAGG-r不使用生成的图拉普拉斯矩阵的 DAGG-2

        由下表可以看出,如果不使用单位矩阵,那么效果会大打折扣。但是不使用图生成的拉普拉斯矩阵(预定义图),效果并没有很大差距。

        最后,作者还考虑了嵌入维度对实验的影响。发现嵌入维度为10时效果最好。

        本论文的PeMSD4数据集运行结果如下: 

2024-12-07 06:19: Total training time: 25.8792min, best loss: 19.460032
2024-12-07 06:19: Horizon 01, MAE: 18.68, RMSE: 30.26, MAPE: 12.5175%
2024-12-07 06:19: Horizon 02, MAE: 18.78, RMSE: 30.61, MAPE: 12.5259%
2024-12-07 06:19: Horizon 03, MAE: 18.98, RMSE: 31.03, MAPE: 12.6548%
2024-12-07 06:19: Horizon 04, MAE: 19.22, RMSE: 31.49, MAPE: 12.8187%
2024-12-07 06:19: Horizon 05, MAE: 19.50, RMSE: 31.96, MAPE: 13.0233%
2024-12-07 06:19: Horizon 06, MAE: 19.77, RMSE: 32.41, MAPE: 13.2145%
2024-12-07 06:19: Horizon 07, MAE: 20.01, RMSE: 32.82, MAPE: 13.3573%
2024-12-07 06:19: Horizon 08, MAE: 20.20, RMSE: 33.19, MAPE: 13.4612%
2024-12-07 06:19: Horizon 09, MAE: 20.38, RMSE: 33.54, MAPE: 13.5726%
2024-12-07 06:19: Horizon 10, MAE: 20.59, RMSE: 33.93, MAPE: 13.7235%
2024-12-07 06:19: Horizon 11, MAE: 20.93, RMSE: 34.43, MAPE: 13.9569%
2024-12-07 06:19: Horizon 12, MAE: 21.45, RMSE: 35.11, MAPE: 14.3262%
2024-12-07 06:19: Average Horizon, MAE: 19.87, RMSE: 32.60, MAPE: 13.2627%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/935478.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用观测云排查数据库死锁故障

故障发现 核心应用 pod 发生重启,同时接收到对应使用者反馈业务问题,开始排查。 观测云排查现场 1、根据重启应用信息,查询 APM 执行数据库 update 操作大量报错,执行时间在 5min 以上。 分析 APM 链路异常,发现是触…

UNIX数据恢复—UNIX系统常见故障问题和数据恢复方案

UNIX系统常见故障表现: 1、存储结构出错; 2、数据删除; 3、文件系统格式化; 4、其他原因数据丢失。 UNIX系统常见故障解决方案: 1、检测UNIX系统故障涉及的设备是否存在硬件故障,如果存在硬件故障&#xf…

npm或yarn包配置地址源

三种方法 1.配置.npmrc 文件 在更目录新增.npmrc文件 然后写入需要访问的包的地址 2.直接yarn.lock文件里面修改地址 简单粗暴 3.yarn install 的时候添加参数 设置包的仓库地址 yarn config set registry https://registry.yarnpkg.com 安装:yarn install 注意…

opencv——图片矫正

图像矫正 图像矫正的原理是透视变换,下面来介绍一下透视变换的概念。 听名字有点熟,我们在图像旋转里接触过仿射变换,知道仿射变换是把一个二维坐标系转换到另一个二维坐标系的过程,转换过程坐标点的相对位置和属性不发生变换&a…

【学习】企业通过CMMI认证,还需要申请CSMM资质吗

​ 企业通过CMMI认证之后,是否还有必要申请CSMM资质?这是一个值得软件企业深思的问题。虽然CMMI和CSMM都在组织的软件过程改进和认证方面发挥着重要作用,但它们各自拥有自己的特点在。企业需要根据自身发展需求来选择适合的认证方式。首先我…

OpenHarmony-3.HDF input子系统(5)

HDF input 子系统OpenHarmony-4.0-Release 1.Input 概述 输入设备是用户与计算机系统进行人机交互的主要装置之一,是用户与计算机或者其他设备通信的桥梁。常见的输入设备有键盘、鼠标、游戏杆、触摸屏等。本文档将介绍基于 HDF_Input 模型的触摸屏器件 IC 为 GT91…

BurpSuite之移动端流量抓包

学习视频来自B站UP主泷羽sec,如涉及侵权马上删除文章。 笔记只是方便学习,以下内容只涉及学习内容,切莫逾越法律红线。 安全见闻,包含了各种网络安全,网络技术,旨在明白自己的渺小,知识的广博&a…

Any2Policy: Learning Visuomotor Policy with Any-Modality(类似AnyGPT)

发表时间:NeurIPS 2024 论文链接:https://readpaper.com/pdf-annotate/note?pdfId2598959255168534016&noteId2598960522854466816 作者单位:Midea Group Motivation:Current robotic learning methodologies often focus…

QTreeView 与 QTreeWidget 例子

1. 先举个例子 1班有3个学生:张三、李四、王五 4个学生属性:语文 数学 英语 性别。 语文 数学 英语使用QDoubleSpinBox* 编辑,范围为0到100,1位小数 性别使用QComboBox* 编辑,选项为:男、女 实现效果: 2…

计算机视觉与医学的结合:推动医学领域研究的新机遇

目录 引言医学领域面临的发文难题计算机视觉与医学的结合:发展趋势计算机视觉结合医学的研究方向高区位参考文章结语 引言 计算机视觉(Computer Vision, CV)技术作为人工智能的重要分支,已经在多个领域取得了显著的应用成果&…

谷粒商城—分布式基础

1. 整体介绍 1)安装vagrant 2)安装Centos7 $ vagrant init centos/7 A `Vagrantfile` has been placed in this directory. You are now ready to `vagrant up` your first virtual environment! Please read the comments in the Vagrantfile as well as documentation on…

麒麟系统+达梦数据库+MybatisPlus+Redis+SpringBoot

环境准备 麒麟系统 在麒麟系统官网进行下载镜像 这里选择的是麒麟V10桌面版,使用虚拟机启动 修改root密码 # 启动到单用户模式 init 1 # 修改 root 密码 passwd root # 重启 reboot达梦数据库准备 进入达梦官网 我这里选择的是达梦数据库管理系统DM8开发版 下…

DFC:控制 ~~到达率~~ 最小化等待时间

DFC:控制 到达率 最小化等待时间 计算节点的等待成本:公式(2) ( λ i λ ( W q i C i μ c i ‾ ) ) (\frac{\lambda_i}{\lambda}(W_q^i C_i\overline{\mu_c^i})) (λλi​​(Wqi​Ci​μci​​)) 在这个到达率下的等待时间&am…

单词翻转

单词翻转 C语言实现C实现Java实现Python实现 💐The Begin💐点点关注,收藏不迷路💐 输入一个句子(一行),将句子中的每一个单词翻转后输出。 输入 只有一行,为一个字符串&#xff0c…

数据分析实战—房价特征关系

1.实战内容 (1) 读取房价特征关系表(house_price.npz)绘制离地铁站的距离与单位面积的房价的散点图,并对其进行分析; import pandas as pd import numpy as np import warnings warnings.filterwarnings(&…

网页502 Bad Gateway nginx1.20.1报错与解决方法

目录 网页报错的原理 查到的502 Bad Gateway报错的原因 出现的问题和尝试解决 问题 解决 网页报错的原理 网页显示502 Bad Gateway 报错原理是用户访问服务器时,nginx代理服务器接收用户信息,但无法反馈给服务器,而出现的报错。 查到…

Linux入门攻坚——41、Linux集群系统入门-lvs(2)

lvs-dr:GATEWAY Director只负责请求报文,响应报文不经过Director,直接由RS返回给Client。 lvs-dr的报文路线如上图,基本思路就是报文不会回送Director,第①种情况是VIP、DIP、RIP位于同一个网段,这样&…

【Python网络爬虫笔记】10- os库存储爬取数据

os库的作用 操作系统交互:os库提供了一种使用Python与操作系统进行交互的方式。使用os库来创建用于存储爬取数据的文件夹,或者获取当前工作目录的路径,以便将爬取的数据存储在合适的位置。环境变量操作:可以读取和设置环境变量。在…

在CentOS中安装和卸载mysql

在CentOS7中安装和卸载mysql 卸载mysql1、查看是否安装过mysql2、查看mysql服务状态3、关闭mysql服务4、卸载mysql相关的rpm程序5、删除mysql相关的文件6、删除mysql的配置文件my.cnf 安装mysql1、下载mysql相关的rpm程序2、检查/tmp临时目录权限3、安装mysql前的依赖检查3、安…

HDOJ 1735:字数统计 ← 贪心

【题目来源】https://acm.hdu.edu.cn/showproblem.php?pid1735【题目描述】 一天,淘气的 Tom 不小心将水泼到了他哥哥 Jerry 刚完成的作文上。原本崭新的作文纸顿时变得皱巴巴的,更糟糕的是由于水的关系,许多字都看不清了。可怜的 Tom 知道他…