【论文解读】2017 STGCN: Spatio-Temporal Graph Convolutional Networks

一、简介

使用历史速度数据预测未来时间的速度。同时用于序列学习的RNN(GRU、LSTM等)网络需要迭代训练,它引入了逐步累积的误差,并且RNN模型较难训练。为了解决以上问题,我们提出了新颖的深度学习框架STGCN,用于交通预测。

二、STGCN模型架构

2.1 整体架构图示

在这里插入图片描述

2.2 ST-Conv blocks

符号含义
M历史时间序列长度
n节点数
C i C_i Ci输入的channel 数
C o C_o Co输出的channel 数

2.2.1 TemporalConv: Gated CNNs 用于提取时间特征

Note: nn.Conv2d的输入 channel在第一维度

[ P Q ] = C o n v ( x ) ; o u t = P ⊙ σ ( Q ) [P Q] = Conv(x); \\ out = P \odot \sigma (Q) [PQ]=Conv(x);out=Pσ(Q)

  • x ∈ R C i × M × n x \in \mathbb{R}^{C_i \times M \times n } xRCi×M×n
  • [ P Q ] ∈ R 2 C o ∗ ( M − K t + 1 ) × n [\text{P Q}] \in \mathbb{R}^{2C_o * (M - K_t + 1) \times n } [P Q]R2Co(MKt+1)×n

示例代码:

class TCN(nn.Module):
    def __init__(self, c_in: int, c_out: int, dia: int=1):
        """TemporalConvLayer
        input_dim:  (batch_size, 1, his_time_seires_len, node_num)
        sample:     [b, 1, 144, 207]
        Args:
            c_in (int): channel in
            c_out (int): channel out
            dia (int, optional): 空洞卷积大小. Defaults to 1.
        """
        super(TCN, self).__init__()
        self.c_out = c_out * 2
        self.c_in = c_in
        self.conv = nn.Conv2d(
            c_in, self.c_out, (2, 1), 1, padding=(0, 0), dilation=dia
        )

    def forward(self, x):
        # [batch, channel, his_n, node_num] 
        #  仅在时间维度上进行卷积 
        c = self.c_out//2
        out = self.conv(x)
        if len(x.shape) == 3: # channel, his_n, node_num
            P = out[:c, :, :]
            Q = out[c:, :, :]
        else:
            P = out[:, :c, :, :]
            Q = out[:, c:, :, :]
        return P * torch.sigmoid(Q)

2.2.2 SpatialConv: Graph CNNs 提取空间信息

迭代定义的切比雪夫多项式

o u t = Θ ∗ G x = ∑ k = 0 K − 1 θ k T k ( L ~ ) x = ∑ k = 0 K − 1 W K , l z k , l out= \Theta_{* \mathcal{G}} x = \sum_{k=0}^{K-1}\theta_k T_k(\tilde{L})x=\sum_{k=0}^{K-1}W^{K, l}z^{k, l} out=ΘGx=k=0K1θkTk(L~)x=k=0K1WK,lzk,l

  • Z 0 , l = H l Z^{0, l} = H^{l} Z0,l=Hl
  • Z 1 , l = L ~ ⋅ H l Z^{1, l} = \tilde{L} \cdot H^{l} Z1,l=L~Hl
  • Z k , l = 2 ⋅ L ~ ⋅ Z k − 1 , l − Z k − 2 , l Z^{k, l} = 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l} Zk,l=2L~Zk1,lZk2,l
  • L ~ = 2 ( I − D ~ − 1 / 2 A ~ D ~ − 1 / 2 ) / λ m a x − I \tilde{L} = 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I L~=2(ID~1/2A~D~1/2)/λmaxI

论文: Recursive formulation for fast filtering

示例代码:

class STCN_Cheb(nn.Module):
    def __init__(self, c, A, K=2):
        """spation cov layer
        Args:
            c (int): hidden dimension
            A (adj matrix): adj matrix
        """
        super(STCN_Cheb, self).__init__()
        self.K = K
        self.lambda_max = 2
        self.tilde_L = self.get_tilde_L(A)
        self.weight = nn.Parameter(torch.empty((K * c, c)))
        self.bias = nn.Parameter(torch.empty(c))
        stdv = 1.0 / np.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def get_tilde_L(self, A):
        I = torch.diag(torch.Tensor([1] * A.size(0))).float().to(A.device)
        tilde_A = A + I 
        tilde_D = torch.diag(torch.pow(tilde_A.sum(axis=1), -0.5))
        return 2 / self.lambda_max * (I - tilde_D @ tilde_A @ tilde_D) - I

    def forward(self, x):
        # [batch, channel, his_n, node_num] -> [batch, node_num, his_n, channel] -> [batch, his_n, node_num, channel] 
        x = x.transpose(1, 3)
        x = x.transpose(1, 2)
        output = self.m_unnlpp(x)
        output = output @ self.weight + self.bias
        output = output.transpose(1, 2)
        output = output.transpose(1, 3)
        return torch.relu(output) 

    def m_unnlpp(self, feat):
        K = self.K
        X_0 = feat
        Xt = [X_0]
        # X_1(f)
        if K > 1:
            X_1 = self.tilde_L @ X_0
            # Append X_1 to Xt
            Xt.append(X_1)
        # Xi(x), i = 2...k
        for _ in range(2, K):
            X_i =  2 * self.tilde_L @ X_1 - X_0
            # Add X_1 to Xt
            Xt.append(X_i)
            X_1, X_0 = X_i, X_1
        # 合并数据
        Xt = torch.cat(Xt, dim=-1)
        return Xt

2.2.3 ST-Block

组合TCNSTCN_Cheb
v l + 1 = Γ 1 ∗ T l ReLU ( Θ ∗ G l ( Γ 0 ∗ T l v l ) ) v^{l+1} = \Gamma ^{l} _{1*\mathcal{T}} \text{ReLU}( \Theta ^l_{*\mathcal{G}} (\Gamma ^{l} _{0*\mathcal{T}} v^l) ) vl+1=Γ1TlReLU(ΘGl(Γ0Tlvl))

  • Γ 0 ∗ T l v l \Gamma ^{l} _{0*\mathcal{T}} v^l Γ0Tlvl: 第一个TCN
  • Θ ∗ G l \Theta ^l_{*\mathcal{G}} ΘGl : STCN_Cheb
  • Γ 1 ∗ T l v l \Gamma ^{l} _{1*\mathcal{T}} v^l Γ1Tlvl: 第二个TCN
class STBlock(nn.Module):
    def __init__(
        self,
        A,
        K=2,
        TST_channel: List=[64, 16, 64]
        T_dia: List=[2, 4]
    ):
        # St-Conv Block1[  TCN(64, 16*2)->SCN(16, 16)->TCN(16, 64*2) ] 
        super(STBlock, self).__init__()
        self.T1 = TCN(TST_channel[0], TST_channel[1], dia=T_dia[0])
        # STCN_Cheb out have relu
        self.S = STCN_Cheb(TST_channel[1], Lk=A, K=K)
        self.T2 = TCN(TST_channel[1], TST_channel[2], dia=T_dia[1])

    def forward(self, x):
        return self.T2(self.S(self.T1(x)))

三、简单复现

复现可以看笔者的github: train.ipynb
用的数据是metr-la.h5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/39844.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity 多相机 同屏显示

一 首先了解: 相机和Canvas 的渲染先后关系 什么是相机的渲染顺序? 答:简单理解就是 用毛刷 刷墙面,先刷的,会被后刷的 挡住 。 列如:相机01: 先渲染的大海,相机02:后…

如何使用DiskPart命令行格式化分区?

想要格式化磁盘分区,您可以使用磁盘管理工具,或在Windows文件资源管理器中右键单击驱动器并选择“格式化”。如果您更想使用命令行来格式化磁盘,那么Windows自带的DiskPart将是首选。 DiskPart有很多优点,例如,如果您想…

PyTorch 1.13简介

# 1.  PyTorch 1.13 据官方介绍,PyTorch 1.13 中包括了 BetterTransformer 的稳定版,且不再支持 CUDA 10.2 及 11.3,并完成了向 CUDA 11.6 及 11.7 的迁移。此外 Beta 版还增加了对 Apple M1 芯片及 functorch 的支持。 1.1 主要更新 Be…

Java虚拟机——字节码指令简介

Java虚拟机的指令由一个字节长度的、代表着某种特定操作含义的数字(称为操作码) 以及 跟随其后的零至多个代表此操作所需的参数(称为操作数)构成。大多数指令都不包括操作数,只有一个操作码,指令参数都存放…

【云原生】k8s安全机制

前言 Kubernetes 作为一个分布式集群的管理工具,保证集群的安全性是其一个重要的任务。API Server 是集群内部各个组件通信的中介, 也是外部控制的入口。所以 Kubernetes 的安全机制基本就是围绕保护 API Server 来设计的。 比如 kubectl 如果想向 API…

记一次真实MySQL百万数据优化

证实下确实是150万+数据哈 原SQL 原SQL执行计划 原SQL执行时间 5秒左右 原SQL分析 思路来源 整体看下SQL好像没啥可优化的。那咱们就大错特错了。 可能有人会说B表为啥在A表后面不正常呀,因为这是内连接查询不是左右连接查询。A,B表的顺序是可以交换的(实测无影响) 首先我们…

JVM之内存与垃圾回收篇2

文章目录 3 运行时区域3.1 本地方法栈3.2 程序计数器3.3 方法区3.3.1 Hotspot中方法区的演进3.3.2 设置方法区内存大小3.3.3 运行时常量池3.3.4 方法区使用举例3.3.5 方法区的演进3.3.5 方法区的垃圾回收 3.4 栈3.4.1 几个面试题 3.5 堆3.5.1 Minor GC、Major GC和Full GC3.5.2…

linux之Ubuntu系列 find 、 ln 、 tar、apt 指令 软链接和硬链接 snap

查找文件 find 命令 功能非常强大,通常用来在 特定的目录下 搜索 符合条件的文件 find [path] -name “.txt” 记得要加 “ ” 支持通配符 ,正则表达式 包括子目录 ls 不包括 子目录 如果省略路径,表示 在当前路径下,搜索 软链接…

Python爬虫——urllib_微博cookie登陆

cookie登陆适用场景: 适用场景:数据采集的时候,需要绕过登陆,然后进入到某个页面 # 适用场景:数据采集的时候,需要绕过登陆,然后进入到某个页面 import urllib.requesturl https://weibo.cn/7…

Linux 学习记录52(ARM篇)

Linux 学习记录52(ARM篇) 本文目录 Linux 学习记录52(ARM篇)一、汇编语言相关语法1. 汇编语言的组成部分2. 汇编指令的类型3. 汇编指令的使用格式 二、基本数据处理指令1. 数据搬移指令(1. 格式(2. 指令码类型(3. 使用示例 2. 立即数(1. 一条指令的组成 3. 移位操作指令(1. 格式…

Revit中如何创建水的效果及基坑?

一、Revit中如何创建水的效果? 我们在创建建筑的时候会遇上小池塘啊小池子之类的装饰景观,Revit又不像专业的3D软件那样可以有非常真实的水的效果,那么我们该如何简单创建水呢?下面来看步骤: 1、 在水池位置创建一块楼板,并将该…

【DevOps】Atlassian插件开发指南

本文以Bamboo插件开发为例,记录一下插件开发过程。 一、简介 Atlassian Bamboo 6.9.1 是一款持续集成和持续交付(CI/CD)工具,支持使用插件扩展其功能。如果需要开发自己的 Bamboo 插件并添加到 Bamboo 中,则可以参考…

sqli-labs 堆叠注入 解析

打开网页首先判断闭合类型 说明为双引号闭合 我们可以使用单引号将其报错 先尝试判断回显位 可以看见输出回显位为2,3 尝试暴库爆表 这时候进行尝试堆叠注入,创造一张新表 ?id-1 union select 1,database(),group_concat(table_name) from informatio…

mac端好用的多功能音频软件 AVTouchBar for mac 3.0.7

AVTouchBar是来自触摸栏的视听播放器,将跳动笔记的内容带到触摸栏,触摸栏可显示有趣的音频内容,拥有更多乐趣,以一种有趣的方式播放音乐,该软件支持多种音频播放软件,可在Mac上自动更改音乐~ 音频选择-与内…

javascript实现久久乘法口诀表、document、write、console、log

文章目录 正序乘法口诀表倒序乘法口诀表logconsoledocumentwrite 正序乘法口诀表 function multiplicationTable() {for (let i 1; i < 9; i) {let val ;for (let j 1; j < i; j) {document.write(j * i (i * j) &nbsp );val ${j}*${i}${i * j} ;}consol…

【Linux】进程间通信——管道/共享内存

文章目录 1. 进程间通信2. 管道匿名管道命名管道管道的特性管道的应用&#xff1a;简易的进程池 3. System V共享内存共享内存的概念共享内存的结构共享内存的使用代码实现 1. 进程间通信 进程间通信&#xff08;Inter-Process Communication&#xff0c;简称IPC&#xff09;是…

跨网络的通信过程、路由的作用以及默认网关

如下网络拓扑图&#xff0c;交换机0所在的网段为192.168.1.0/24&#xff0c;交换机1所在网段为192.168.2.0/24&#xff0c;且各自有2台主机&#xff1a; 假设PC0&#xff08;192.168.1.10/32&#xff09;要跟PC4&#xff08;192.168.2.11/32&#xff09;通信&#xff0c;如何实…

基于 chinese-roberta-wwm-ext 微调训练 6 分类情感分析模型

一、模型和数据集介绍 1.1 预训练模型 chinese-roberta-wwm-ext 是基于 RoBERTa 架构下开发&#xff0c;其中 wwm 代表 Whole Word Masking&#xff0c;即对整个词进行掩码处理&#xff0c;通过这种方式&#xff0c;模型能够更好地理解上下文和语义关联&#xff0c;提高中文文…

DuiLib中的list控件以及ListContainerElement控件

文章目录 前言1、创建list控件2、创建 ListContainerElement 元素&#xff0c;并添加到 List 控件中,这里的ListContainerElement用xml来表示3、在 ListContainerElement 元素中添加子控件 1、List控件2、ListContainerElement控件 前言 在 Duilib 中&#xff0c;List 控件用于…

Python 集合 add()函数使用详解,集合添加元素

「作者主页」&#xff1a;士别三日wyx 「作者简介」&#xff1a;CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」&#xff1a;小白零基础《Python入门到精通》 add函数使用详解 1、元素的顺序2、可以添加的元素类型3、添加重复的元素4、一次只…