【论文翻译】TKDE 2024 | ST-MAN:用于交通预测的时空记忆增强的多级注意力网络

image-20241023232727384

论文题目Spatio-Temporal Memory Augmented Multi-Level Attention Network for Traffic Prediction
论文链接https://ieeexplore.ieee.org/document/10285880
发表期刊/年份TKDE 2024
关键词城市计算、时空预测、交通预测、记忆网络、注意力网络

摘要
交通预测是城市计算中一个重要的时空预测任务,对交通控制、车辆调度等许多应用具有重要意义。随着城市的扩展和公共交通的发展,远距离和长期的时空相关性在交通预测中发挥着越来越重要的作用。然而,建模远距离空间依赖和长期时间依赖具有挑战性,原因有二:
1)复杂的影响因素,包括空间、时间和外部因素;
2)多种时空相关性,包括远距离和近距离的空间相关性,以及长期和短期的时间相关性。
为了解决这些问题,我们提出了一种用于细粒度交通预测的时空记忆增强多级注意力网络,命名为ST-MAN
具体来说,我们设计了一个时空记忆网络,用于编码和记忆细粒度的空间信息和典型的时间模式。然后,我们提出了一个多级注意力网络,明确地在不同的空间尺度(如网格和区域层次)和时间尺度(如日常和每周层次)下,建模局部的短期时空依赖和全局的长期时空依赖。此外,我们设计了一个外部组件,使用外部因素和空间嵌入作为输入,更高效地生成外部因素对位置的影响。最后,我们设计了一个端到端的框架,通过对比目标和监督目标的联合优化来提升模型性能。
在粗粒度和细粒度的真实数据集上的实验结果表明,ST-MAN模型相比于几种最先进的基线模型具有显著的优势。

ST-MAN

  • 1 引言
  • 2 预备知识
    • 2.1 问题定义
    • 2.2 系统框架
  • 3 方法
    • 3.1 特征图构建
      • 3.1.1 ConvLSTM组件
      • 3.1.2 外部组件
    • 3.2 时空记忆网络
      • 3.2.1 空间嵌入
      • 3.2.2 记忆编码
    • 3.3 多层次注意力网络
      • 3.3.1 短期局部注意力模块 (SLA)
      • 3.3.2 长期全局注意力模块 (LGA)
      • 3.3.3 融合机制
    • 3.4 优化
  • 4 实验
    • 4.1 实验设置
      • 4.1.1 数据集
      • 4.1.2 基线
      • 4.1.3 评估指标
      • 4.1.4 实现
    • 4.2 实验结果
      • 4.2.1 性能比较
      • 4.2.2 消融研究
      • 4.2.3 注意力机制的效果
      • 4.2.4 记忆簇数量的影响
  • 5 相关工作
    • 5.1 城市交通预测
    • 5.2 记忆网络
  • 6 结论

1 引言

近年来,时空预测在我们的日常生活中变得越来越重要,因为它提供了有用且必要的信息,如交通预测、人群流动预测、空气质量预测等。交通预测是基本的时空预测任务之一,由于大规模交通数据的日益可用性,在城市计算中发挥着越来越重要的作用。交通预测旨在基于历史观测数据预测潜在的交通量(例如,流入/流出交通量、乘客上下车需求量等),并可以为政府提供交通控制、车辆调度等方面的见解。然而,它可能会受到多种复杂的空间和时间因素的广泛影响,因此仍面临一些挑战。

交通预测需要解决的主要问题之一是建模复杂的时空依赖性,因为区域内的潜在交通量不仅与其先前的观测数据有关,还受到邻域历史数据的影响。提出了多种时空预测模型来捕捉空间和时间相关性。传统的研究通过一些基于时间序列的方法(如ARIMA)从历史观测数据中获取序列模式,但通常忽略了不同区域之间的空间依赖性,无法建模复杂的非线性关系。近年来,许多深度学习模型在时空预测任务中表现出了可喜的性能。基于卷积神经网络(CNN)的方法能够提取不同区域之间的空间依赖性,例如DeepST。基于循环神经网络(RNN)的方法通过将序列记录嵌入到隐藏状态向量中,擅长于建模时间相关性。

随着城市的扩展和公共交通的发展,长距离和长期的时空因素在城市交通预测中起着更为重要的作用。图1展示了城市中基于网格的交通流使用案例,其中城市被划分为多个小网格,每个区域由多个具有相似功能的网格单元组成,如办公区域和住宅区域。对于空间依赖性,远距离区域之间的交通流量可能彼此相关。一方面,随着人口流动的增加,越来越多在市中心工作的人(如 图1(a) 中的蓝色网格)可能居住在郊区区域(如 图1(a) 中的紫色网格),暗示了交通流转移的长距离地理相关性。另一方面,由于城市的功能划分,两个遥远的区域可能会表现出相似的交通模式(例如,图1(b) 中的网格A和网格B),这暗示了交通流模式的长距离语义相关性。

image-20241106232041367

一些研究在解决这些问题上取得了显著进展。大多数研究通过堆叠多个卷积层来建模区域之间的长距离空间依赖性,因为卷积层只能在局部尺度上捕捉短距离依赖性。例如,ST-ResNet和DeepSTN+通过残差机制实现更深的卷积神经网络。然而,堆叠过多的卷积层以捕捉长距离空间依赖性可能会导致高计算成本和优化难度,限制了在城市中划分大量区域时的预测性能。此外,一些经验知识被引入以捕捉多种时间模式。

很少有工作能够同时建模长距离空间依赖性和长期时间依赖性,因为复杂的时空相关性使得这一任务变得极具挑战性。直观上,结合CNN和RNN来同时捕捉空间和时间相关性是合理的。例如,STDN采用CNN提取空间特征,然后输入RNN进一步建模时间依赖性。然而,分别建模长距离空间依赖性和长期时间依赖性无法捕捉到时空相关性的内部连接。进一步地,研究人员尝试在标准ConvLSTM单元中设计额外的记忆单元以编码长距离空间关系。然而,由于潜在向量通常过小以表达复杂的时空相关性,它们在长期时间依赖性方面的表示能力仍然有限。因此,捕捉复杂的长距离和长期时空依赖性仍然具有挑战性。

为填补这一空白,我们旨在引入外部记忆模块,以尝试编码和记忆时空信息。最近,记忆网络在许多序列预测任务中表现出了可喜的性能,如序列推荐和问答任务等。相比其他传统的RNN/LSTM模型,记忆网络通过合适的读写操作来存储隐藏向量的外部记忆组件。然而,传统的记忆网络无法直接用于编码复杂的时空相关性,我们面临以下三个特定的挑战:

  1. 如何设计记忆网络以同时编码和记忆空间和时间信息? 之前的外部记忆网络主要关注于存储序列特征,忽略了空间信息。然而,城市中区域之间的空间依赖性对于交通预测建模至关重要。
  2. 如何提取长距离空间依赖性和长期时间依赖性? 尽管一些有效策略已经被设计用于从外部记忆中读取和更新信息,这些研究仍忽略了时间模式和空间分布之间复杂的内部连接。
  3. 如何有效建模外部因素对每个网格单元的影响? 先前的研究使用全连接层来学习外部因素(如天气条件、假期)对多个网格的影响,映射到高维交通流图中。然而,这种方法随着网格单元数量的增加,参数急剧增加。

在本文中,我们提出了一种用于交通预测的时空记忆增强多层次注意力网络(ST-MAN)。与现有方法相比,我们的模型能够学习长距离空间依赖性和长期时间模式,以实现高效的时空预测。具体而言,我们设计了一个时空记忆网络(STMN)来解决第一个挑战。STMN由键记忆矩阵和值记忆矩阵组成,通过空间嵌入编码全球空间相关性作为先验知识,并通过记忆编码捕捉长期时间模式。针对第二个挑战,我们提出了一个包含注意力机制的多层次注意力网络(MAN),显式建模短期局部和长期全球的时空相关性。MAN包括两个主要的注意力模块:短期局部注意力模块专注于在两个空间粒度(即跨网格和跨区域流动转换)上捕捉交通流动转换的地理时空依赖性;长期全球注意力模块旨在建模具有相似功能区域的交通流模式的语义时空依赖性。该模块利用时空记忆网络来存储和检索长期和长距离的时空信息。为了解决第三个挑战,我们引入了一个外部组件,该组件学习外部因素的空间感知影响。该组件基于空间嵌入为每个网格单元生成特定响应,从而减少了大量参数的需求。最后,设计了一个端到端框架用于交通预测,并同时基于监督目标和对比目标进行训练。

总结来说,我们的贡献如下:

  • 我们提出了一种用于交通预测的时空记忆增强多层次注意力网络ST-MAN,能够建模长距离空间依赖性和长期时间依赖性。
  • 我们设计了一个时空记忆网络,能够编码和记忆细粒度的空间信息和时间模式。据我们所知,这是首次引入外部记忆网络用于时空预测,以丰富时空模型的表达能力。
  • 我们引入了一个多层次注意力网络,有效地捕捉地理邻居间的短期局部时空依赖性和具有相似功能的语义邻居间的长期全球相关性。
  • 在粗粒度和细粒度的真实世界数据集上的大量实验表明,我们的模型相比最先进的模型取得了显著的改进。

本文其余部分组织如下。第2节介绍问题的公式化和系统框架。第3节详细阐述ST-MAN的设计。第4节展示实验评估。第5节简要回顾相关工作。最后,第6节讨论结论和未来工作。

2 预备知识

在本节中,我们首先描述交通流量预测问题的定义。然后,我们简要介绍我们提出的框架。为了简洁起见,我们在表1中展示了本文所使用的符号表。

image-20241106230745196

2.1 问题定义

定义 1 (网格单元)

根据以往的工作[13],我们将一个城市区域表示为一个矩形,并沿经度和纬度将其划分为一个 H × W H \times W H×W 的网格地图,表示为 C ∈ R H × W C \in \mathbb{R}^{H \times W} CRH×W。如图1所示,总共有 H × W H \times W H×W 个网格单元,它们具有相同的大小。

定义 2 (区域)

由于城市功能分区的原因,相邻的网格单元通常具有相似的功能,如图1中的办公区域(蓝色网格)和住宅区域(紫色网格)所示。使用粗粒度的区域,与精细粒度的网格地图相比,可以有效提取全局语义邻域信息。

定义 3 (基于网格的交通量图)

T = { T 1 , T 2 , …   } \mathcal{T} = \{T_1, T_2, \dots\} T={T1,T2,} 代表一系列出行轨迹,每个出行轨迹 T i ∈ T T_i \in \mathcal{T} TiT 包含一个在 p p p 个连续时间间隔内的一系列地理坐标。我们定义在时间步 t t t 时网格 g i , j g_{i,j} gi,j 内的交通量为在该时间间隔内该网格中的出行次数。正式地,交通量图可以表示为一个三维张量 M t ∈ R K × H × W \mathcal{M}_t \in \mathbb{R}^{K \times H \times W} MtRK×H×W,其中 K K K 是交通量测量的数量。例如,当考虑起点和终点交通量时, K = 2 K = 2 K=2

交通量图 ( M t ) 0 , i , j (\mathcal{M}_t)_{0,i,j} (Mt)0,i,j ( M t ) 1 , i , j (\mathcal{M}_t)_{1,i,j} (Mt)1,i,j 分别表示在时间间隔 t t t 进入/离开网格 g i , j g_{i,j} gi,j 的出行次数。

问题陈述 (交通量预测)

给定历史观察到的交通量 { M t ∣ t = 1 , … , T − 1 } \{\mathcal{M}_t | t = 1, \dots, T - 1\} {Mtt=1,,T1},交通量预测问题旨在预测下一时间间隔 T T T 的交通量图 M T ∈ R K × H × W \mathcal{M}_T \in \mathbb{R}^{K \times H \times W} MTRK×H×W

2.2 系统框架

ST-MAN 的框架如图2所示,由四个主要组件组成。

image-20241106230816029

  1. 特征图构建
    给定历史交通量数据序列 { M t ∣ t = 1 , … , T − 1 } \{\mathcal{M}_t | t = 1, \dots, T - 1\} {Mtt=1,,T1},我们的目标是构建用于预测模型的输入特征图,该特征图结合了空间、时间和外部信息。一方面,我们考虑分布在整个城市的所有网格单元,并选择它们对应的最近、每日和每周时间步作为输入数据,表示为 X c , X p , X q X_c, X_p, X_q Xc,Xp,Xq,并进一步将其输入到ConvLSTM中以获取空间和时间信息。另一方面,某些时间步的外部因素,表示为 X e x t X_{ext} Xext,例如天气状况和假日,也提供了外部信息,以通过参数嵌入捕获外部信息。

  2. 时空记忆网络 (STMN)
    STMN 旨在编码细粒度的空间依赖性和具有代表性的时间模式,然后将它们存储在关键记忆矩阵和值记忆矩阵中,作为先验知识增强模型的表达能力。需要注意的是,关键记忆矩阵通过空间嵌入预训练以获取每个区域的嵌入向量,而值记忆矩阵则与预测模型联合学习。特别地,为了提高记忆的保真度,STMN 通过最小化对比目标来编码不同区域间具有辨别性和不变的时间特征。

  3. 多层次注意力网络 (MAN)
    MAN 采用注意力机制显式地建模短期局部和长期全局时空依赖性。具体来说,MAN 包括两个设计用于分别捕获不同类型的时空相关性的注意力模块。第一个是短期局部注意力模块,利用 ConvLSTM 从最近的观测中获取的信息,专注于捕捉地理网格单元之间交通流量转换的短期局部时空依赖性。第二个是长期全局注意力模块,提取通过时空记忆网络集成的长期空间依赖性和长期时间依赖性。此模块负责捕捉城市中语义单元间的相似交通流模式。为了获得全面的特征表示,采用了融合模块自适应集成不同的特征。这一融合过程确保最终的特征表示捕捉到短期局部和长期全局时空依赖性的综合影响。

  4. 交通量预测
    对于交通量预测任务,提取出的高级时空特征被输入到预测层(例如卷积层)中,以生成时间步 T T T 的预测交通量图 M ^ T \hat{\mathcal{M}}_T M^T

3 方法

本节详细阐述了我们提出的模型。我们首先基于ConvLSTM和外部组件构建用于预测模型的特征图,包含空间和时间信息以及外部信息。然后,我们提出了时空记忆网络(STMN),用于编码细粒度的空间信息和具有代表性的时间模式。此外,构建在时空记忆网络之上的多层次注意力机制用于显式建模短期局部依赖性和长期全局依赖性,这些依赖性通过有效的融合机制自适应地集成,以增强模型在交通预测中的表现。最后,我们提供了一种端到端的方法,通过联合损失函数,包括对比目标和监督目标来优化模型。

3.1 特征图构建

输入数据在交通预测中起着至关重要的作用,因为它包含了各种复杂因素,包括空间、时间和外部因素。特别地,网格单元的交通量数据不仅依赖于其自身的历史观测值,还表现出与其相邻网格历史记录的相关性。在交通预测中,考虑这些时空依赖性是非常重要的。此外,交通量数据受到外部因素的强烈影响,如天气状况和假日。由于信息种类的多样性,必须构建专门为交通预测量身定制的特征,考虑到原始数据的独特特性。为了解决这一问题,我们通过基于原始交通量数据创建输入特征图,作为模型学习有效和可区分表示的基础,使其能够捕捉城市交通的潜在模式和动态。

3.1.1 ConvLSTM组件

城市中固有的独特物理拓扑结构和社会关系结构使得城市交通数据表现出多样的时空依赖性,涵盖不同的尺度,如短期和长期时间相关性,以及短程和长程空间相关性。为了建模这些复杂的时空相关性,我们首先将城市中每个时间步的历史交通量转换为类似图像的矩阵(即基于网格的交通量图),并将时间划分为三个片段,包括最近时间步、每日和每周周期时间步。具体来说,给定交通量数据 { M t ∣ t = 1 , … , T − 1 } \{\mathcal{M}_t | t = 1, \dots, T - 1\} {Mtt=1,,T1},我们为未来某一特定时间步 T T T 构建三个输入序列,分别代表邻近性、周期性和趋势性。

对于邻近性,我们考虑前 l c l_c lc 个时间步的交通量数据,表示为
X c = [ M T − l c , M T − ( l c − 1 ) , … , M T − 1 ] ∈ R K l c × H × W X_c = [\mathcal{M}_{T-l_c}, \mathcal{M}_{T-(l_c-1)}, \dots, \mathcal{M}_{T-1}] \in \mathbb{R}^{K l_c \times H \times W} Xc=[MTlc,MT(lc1),,MT1]RKlc×H×W

对于周期性,我们选择前 l p l_p lp 天的每日历史记录,表示为

X p = [ M T − l p × p , M T − ( l p − 1 ) × p , … , M T − p ] ∈ R K l p × H × W X_p = [\mathcal{M}_{T-l_p \times p}, \mathcal{M}_{T-(l_p-1) \times p}, \dots, \mathcal{M}_{T-p}] \in \mathbb{R}^{K l_p \times H \times W} Xp=[MTlp×p,MT(lp1)×p,,MTp]RKlp×H×W

其中 p p p 表示时间跨度(例如每天24小时)。

对于趋势性,我们选择前 l q l_q lq 周的每周历史记录,表示为
X q = [ M T − l q × q , M T − ( l q − 1 ) × q , … , M T − q ] ∈ R K l q × H × W X_q = [\mathcal{M}_{T-l_q \times q}, \mathcal{M}_{T-(l_q-1) \times q}, \dots, \mathcal{M}_{T-q}] \in \mathbb{R}^{K l_q \times H \times W} Xq=[MTlq×q,MT(lq1)×q,,MTq]RKlq×H×W

其中 q q q 表示时间跨度(例如每周24小时 × 7天)。

在数据准备阶段之后,这三种输入序列被输入到组件中,该组件旨在将准备好的交通数据转换为表示其时空属性的嵌入。幸运的是,ConvLSTM[28]通过其独特的卷积操作能够捕捉到时空特性。卷积长短时记忆网络(ConvLSTM)是全连接LSTM网络的扩展,用卷积操作代替所有线性层,以捕捉除时间依赖性之外的空间依赖性。在本文中,ConvLSTM用于捕捉不同类型输入交通量的数据中的时空信息。请注意,三个ConvLSTM组件共享相同的网络结构,但使用不共享的参数。

ConvLSTM单元的结构如图3所示。

image-20241106230840968

形式上,设 H t − 1 \mathcal{H}_{t-1} Ht1 C t − 1 \mathcal{C}_{t-1} Ct1 分别表示时间间隔 t − 1 t-1 t1 时的隐藏状态和单元状态。在当前时间 t t t,给定输入数据 X t X_t Xt,我们可以通过更新操作对时空表示 H t \mathcal{H}_t Ht 进行建模,并获得新的记忆单元 C t \mathcal{C}_t Ct
i t = σ ( W x i ∗ X t + W h i ∗ H t − 1 + b i ) i_t = \sigma (W_{xi} * X_t + W_{hi} * \mathcal{H}_{t-1} + b_i) it=σ(WxiXt+WhiHt1+bi)

f t = σ ( W x f ∗ X t + W h f ∗ H t − 1 + b f ) f_t = \sigma (W_{xf} * X_t + W_{hf} * \mathcal{H}_{t-1} + b_f) ft=σ(WxfXt+WhfHt1+bf)

o t = σ ( W x o ∗ X t + W h o ∗ H t − 1 + b o ) o_t = \sigma (W_{xo} * X_t + W_{ho} * \mathcal{H}_{t-1} + b_o) ot=σ(WxoXt+WhoHt1+bo)

θ t = tanh ⁡ ( W x c ∗ X t + W h c ∗ H t − 1 + b c ) \theta_t = \tanh (W_{xc} * X_t + W_{hc} * \mathcal{H}_{t-1} + b_c) θt=tanh(WxcXt+WhcHt1+bc)

C t = f t ∘ C t − 1 + i t ∘ θ t \mathcal{C}_t = f_t \circ \mathcal{C}_{t-1} + i_t \circ \theta_t Ct=ftCt1+itθt

H t = o t ∘ tanh ⁡ ( C t ) (1) \mathcal{H}_t = o_t \circ \tanh (\mathcal{C}_t) \tag{1} Ht=ottanh(Ct)(1)

其中 ∗ * 表示卷积操作, ∘ \circ 表示Hadamard积。 W W W b b b 是可学习参数, σ \sigma σ 是逻辑Sigmoid函数。 i t , f t , o t i_t, f_t, o_t it,ft,ot 分别是输入门、遗忘门和输出门。

3.1.2 外部组件

外部因素,如天气状况和假期,对城市不同区域的交通量有显著影响。例如,周末的交通模式可能与工作日大不相同。之前的研究主要使用全连接层将外部因素的影响映射到每个网格单元,包括使用嵌入层来组合每个因素,最终层将这些嵌入映射到与输入流量图相同形状的高维特征。然而,在细粒度网格级别(例如128 × 128 交通量图)预测交通量时,随着网格单元数量的增加,参数爆炸问题变得严峻。具体来说,模型的最后一层需要大量的参数,这些参数数量与最终层的网格单元数量成正比。

为了解决上述挑战,我们提出了一种外部组件,用于学习不同网格单元中外部因素的地点感知影响。之前的研究提供了有关外部因素与网格单元之间关系的见解,表明相关网格单元可能对外部因素表现出类似的响应。因此,我们旨在利用特定于每个网格单元的信息,生成针对动态外部因素的网格特定响应,而与网格单元的数量无关。

在3.2.1节中,我们引入了空间嵌入方法,该方法捕获城市内网格之间的空间相关性。通过利用这些网格嵌入,我们可以计算外部因素对每个网格单元的影响,从而避免了引入大量额外参数的需求。

对于外部因素,我们首先利用嵌入层,接着通过全连接层生成参数嵌入,根据最新工作,这些工作基于矩阵分解[29] 计算基于单元的响应。为了简化处理,我们使用 X e x t ∈ R l e X_{ext} \in \mathbb{R}^{l_e} XextRle 表示外部特征的嵌入向量,该向量进一步输入到两个全连接层中,生成参数嵌入 M e x t ∈ R l p M_{ext} \in \mathbb{R}^{l_p} MextRlp。为了使外部因素的地点感知影响得以实现,我们引入网格单元嵌入以学习每个网格单元的特定响应。具体而言,给定网格单元嵌入 M k ∈ R H × W × D k m M_k \in \mathbb{R}^{H \times W \times D_{km}} MkRH×W×Dkm (即在3.2.1节中介绍的时空记忆网络中的关键矩阵),我们可以将参数嵌入重新整形成 M e x t ∈ R D k m × D f M_{ext} \in \mathbb{R}^{D_{km} \times D_f} MextRDkm×Df,其中 l p = D k m × D f l_p = D_{km} \times D_f lp=Dkm×Df。这样,我们可以计算每个网格单元的特定外部特征响应 F e x t ∈ R H × W × D f F_{ext} \in \mathbb{R}^{H \times W \times D_f} FextRH×W×Df,满足 F e x t = M k M e x t F_{ext} = M_k M_{ext} Fext=MkMext。我们提出的外部组件的详细计算过程如图4所示。

image-20241106230907677

3.2 时空记忆网络

尽管传统的基于CNN和RNN的模型在同时建模长程空间和长程时间相关性方面能力有限,我们旨在通过显式和自适应的方式引入先验知识,这些知识可以基于历史数据进行编码,以增强模型的表达能力。具有读取和写入长时记忆槽能力的记忆网络已被用于提供额外的知识表示,从而增加了模型在许多序列任务中的容量。然而,大多数记忆网络不足以捕捉和存储城市中不同区域的空间信息。因此,我们提出了时空记忆网络(STMN),用于同时提供额外的空间和时间信息,以进行时空预测。

我们引入了关键矩阵和值矩阵来存储每个区域的空间相关性和时间模式,分别存储在记忆槽中,然后设计了在这两个矩阵上进行的有效训练操作。

具体来说,首先将网格级别的空间信息通过空间嵌入作为先验知识进行编码到关键矩阵中。通过在关键矩阵中加入空间嵌入向量,我们可以获得区域之间的全局关系。此外,值矩阵的设计使模型能够学习如何利用关键查询与当前信息相关的记忆。注意,我们将在下一节讨论如何有效读取记忆,而本节集中于如何设计时空记忆以有效地记住可用信息。一般来说,STMN通过编码、存储和更新时空模式的能力,能够丰富模型的表达能力。

3.2.1 空间嵌入

通常,对于交通量预测,城市中的每个网格单元并非独立的,而是彼此相关。因此,我们的目标是编码城市内网格之间的全局空间相关性,并将其存储在关键记忆矩阵中,作为网格级别的先验知识。特别是,我们从两个方面衡量网格的相关性:空间分布功能相似性

  • 对于空间分布,网格单元的潜在交通量可能会受到附近网格单元的影响。
  • 对于功能相似性,一些网格单元可能会与同一区域中的其他网格单元(例如住宅区)共享相似的时间模式。

受word2vec的启发,我们提出了一种基于交通量的空间嵌入方法,将每个网格编码为向量,以便相关网格在潜在空间中彼此靠近。直观上,嵌入矩阵应通过最小化相关网格在潜在空间中的距离来学习。为了有效地学习嵌入矩阵,我们进一步为每个网格构建训练实例,采样正例和负例,考虑网格的相关性。形式上,给定历史交通量图 M t \mathcal{M}_t Mt,对于城市中的每个网格 g k g_k gk,我们首先计算基于曼哈顿距离的距离矩阵 D k ∈ R H × W D_k \in \mathbb{R}^{H \times W} DkRH×W 来衡量空间分布,并计算基于皮尔逊相关系数的相似矩阵 S k ∈ R H × W S_k \in \mathbb{R}^{H \times W} SkRH×W 来衡量功能相似性。然后将它们结合以测量城市中网格 g k g_k gk 和其他网格之间的相关性, C k = S k − λ D k C_k = S_k - \lambda D_k Ck=SkλDk,其中 λ \lambda λ 是平衡参数。

M k ∈ R N × D k m M_k \in \mathbb{R}^{N \times D_{km}} MkRN×Dkm 表示关键矩阵,用于存储每个网格 g k g_k gk 的嵌入向量,其中 N = H × W N = H \times W N=H×W D k m D_{km} Dkm 是嵌入向量的维度。为了根据各种移动模式自适应地获取和更新空间嵌入向量,我们使用卷积网络来整合历史交通量数据 M a \mathcal{M}_a Ma 和前一阶段学习的嵌入矩阵 M b \mathcal{M}_b Mb

image-20241106232537563

图5所示, M k = W r ∗ ( M a ⊕ M b ) M_k = W_r * (\mathcal{M}_a \oplus \mathcal{M}_b) Mk=Wr(MaMb),其中 ∗ * 表示卷积操作, W r W_r Wr 是可学习参数矩阵。为了有效学习嵌入矩阵,我们为每个网格基于 C k C_k Ck 和阈值 μ \mu μ 采样正例和负例构建训练实例。具体来说,对于网格 g k g_k gk,我们采样相关值 c k p c_{kp} ckp 大于 μ \mu μ 的网格 g p g_p gp 作为正例,而采样相关值 c k n c_{kn} ckn 小于 μ \mu μ 的网格 g n g_n gn 作为负例,分别表示为三元组 ( g k , g p , 1 ) (g_k, g_p, 1) (gk,gp,1) ( g k , g n , 0 ) (g_k, g_n, 0) (gk,gn,0)。然后,空间嵌入的优化目标定义为:
y ^ = ( M k ⋅ z i ) T ( M k ⋅ z j ) \hat{y} = (M_k \cdot z_i)^T (M_k \cdot z_j) y^=(Mkzi)T(Mkzj)

L r = − 1 N ∑ n = 1 N [ y n × log ⁡ ( y ^ n ) + ( 1 − y n ) × log ⁡ ( 1 − y ^ n ) ] (2) \mathcal{L}^r = -\frac{1}{N} \sum_{n=1}^N [y_n \times \log(\hat{y}_n) + (1 - y_n) \times \log(1 - \hat{y}_n)] \tag{2} Lr=N1n=1N[yn×log(y^n)+(1yn)×log(1y^n)](2)

其中 y ^ n \hat{y}_n y^n 是预测值, y n ∈ ( 0 , 1 ) y_n \in (0,1) yn(0,1) 是真实标签。 z i z_i zi z j z_j zj 是每个训练实例中表示网格索引的one-hot向量,因此 M k ⋅ z i M_k \cdot z_i Mkzi M k ⋅ z j M_k \cdot z_j Mkzj 分别表示网格 g i g_i gi g j g_j gj 的嵌入向量。

3.2.2 记忆编码

通过引入键矩阵,我们进一步引入了一个值矩阵来记忆长期的时空模式。为了在记忆组件中保持有效的信息,一些工作设计了受神经图灵机(NTM)启发的写/更新操作,每一步添加新信息之前会先清空记忆网络中的记忆矩阵。然而,这些更新操作忽略了不同网格之间的空间相关性。例如,一些商业区的网格可能具有相似的时间模式,这与居住区的网格不同。因此,我们设计了一种记忆编码策略,将具有代表性的时间模式学习并写入值矩阵中。

在STM-N的值矩阵中,我们的目标不仅是编码长期的时空模式,还要学习到可以迁移到其他具有相同功能的网格的良好表示。直观上,对于交通流量预测任务,长期的时空模式比短期模式在一段时间内更加稳定。因此,我们将值矩阵与预测模型共同学习为共享知识,而不是在每一步使用传统的写操作。这里,我们展示记忆网络的优化目标,读取操作将在下一节介绍。

最近,许多基于对比学习的模型在无监督学习任务中取得了出色的表现,因为它们具有强大的学习良好表示的能力。对比学习方法的目标是学习一个嵌入空间,其中相似样本对彼此保持靠近,而不同样本则相隔较远,因此我们可以获得对同一实例不同视角不变的表示。受对比学习的启发,我们旨在学习值矩阵以编码在不同网格中不变的代表性特征,并使用对比NT-Xent损失进行训练。

M V ∈ R N × D v m \mathbf{M}_V \in \mathbb{R}^{N \times D_{vm}} MVRN×Dvm 为值矩阵,用于存储每个网格的特征表示,其中 N = H × W N = H \times W N=H×W D v m D_{vm} Dvm 是特征表示的维度。我们首先使用K-means++方法将所有网格聚类为 S S S 类,然后随机从 S S S 个聚类中采样一个包含 S S S 个网格的小批量 B 1 = ( g 1 1 , … , g S 1 ) B^1 = (g^1_1, \dots, g^1_S) B1=(g11,,gS1),然后再采样另一个包含 S S S 个网格的小批量 B 2 = ( g 1 2 , … , g S 2 ) B^2 = (g^2_1, \dots, g^2_S) B2=(g12,,gS2)。对于簇中的任意两个网格 ( g i 1 , g i 2 ) (g^1_i, g^2_i) (gi1,gi2),我们将该对视为正例,而将小批量 B 1 B^1 B1 B 2 B^2 B2 中的其他 ( 2 S − 2 ) (2S - 2) (2S2) 对视为负例。令 v i \mathbf{v}_i vi 表示值矩阵中网格 g i g_i gi 的特征向量,那么正例对 ( g i 1 , g i 2 ) (g^1_i, g^2_i) (gi1,gi2) 的损失函数定义为:

L ( g i 1 , g i 2 ) = − log ⁡ exp ⁡ ( sim ( v i 1 , v i 2 ) / τ ) φ 1 + φ 2 \mathcal{L}(g^1_i, g^2_i) = - \log \frac{\exp(\text{sim}(\mathbf{v}^1_i, \mathbf{v}^2_i)/\tau)}{\varphi_1 + \varphi_2} L(gi1,gi2)=logφ1+φ2exp(sim(vi1,vi2)/τ)

φ 1 = ∑ s = 1 S I [ s ≠ i ] exp ⁡ ( sim ( v i 1 , v s 1 ) / τ ) (3) \varphi_1 = \sum^S_{s=1} \mathbb{I}[s \neq i] \exp(\text{sim}(\mathbf{v}^1_i, \mathbf{v}^1_s)/\tau) \tag{3} φ1=s=1SI[s=i]exp(sim(vi1,vs1)/τ)(3)

φ 2 = ∑ s = 1 S I [ s ≠ i ] exp ⁡ ( sim ( v i 1 , v s 2 ) / τ ) \varphi_2 = \sum^S_{s=1} \mathbb{I}[s \neq i] \exp(\text{sim}(\mathbf{v}^1_i, \mathbf{v}^2_s)/\tau) φ2=s=1SI[s=i]exp(sim(vi1,vs2)/τ)

其中 sim ( v 1 , v 2 ) = v 1 ⊤ v 2 ∥ v 1 ∥ ∥ v 2 ∥ \text{sim}(\mathbf{v}_1, \mathbf{v}_2) = \frac{\mathbf{v}_1^\top \mathbf{v}_2}{\|\mathbf{v}_1\| \|\mathbf{v}_2\|} sim(v1,v2)=v1∥∥v2v1v2 为余弦相似度, τ \tau τ 为控制相似度测量水平的温度参数。

最终损失在所有正例对上求和得到:

L c = 1 2 S ∑ i = 1 S L ( g i 1 , g i 2 ) + L ( g i 2 , g i 1 ) (4) \mathcal{L}_c = \frac{1}{2S} \sum^S_{i=1} \mathcal{L}(g^1_i, g^2_i) + \mathcal{L}(g^2_i, g^1_i) \tag{4} Lc=2S1i=1SL(gi1,gi2)+L(gi2,gi1)(4)

为了进一步学习具有判别性的表示以增强模型的鲁棒性,利用一个全连接层来预测每个网格的类别 c i ∈ { 1 , … , S } c_i \in \{1, \dots, S\} ci{1,,S}。输出 c ^ i \hat{c}_i c^i 是一个 S S S 类的softmax,它预测了城市中 S S S 个不同簇的概率分布,其中 N N N 是城市中网格的总数。

c ^ i = tanh ⁡ ( W c v i + b c ) \hat{c}_i = \tanh(\mathbf{W}_c \mathbf{v}_i + \mathbf{b}_c) c^i=tanh(Wcvi+bc)

L s = − 1 N ∑ i = 1 N ∑ j = 1 S c i j log ⁡ ( c ^ i j ) (5) \mathcal{L}_s = - \frac{1}{N} \sum^N_{i=1} \sum^S_{j=1} c_{ij} \log (\hat{c}_{ij}) \tag{5} Ls=N1i=1Nj=1Scijlog(c^ij)(5)

3.3 多层次注意力网络

一般来说,城市交通预测涉及两种主要的时空依赖类型:地理时空相关性和语义时空相关性。第一种依赖关系来自不同区域交通流量之间的地理交互,称为动态交通流量转移。这种依赖性在相邻区域之间普遍存在,主要受近期交通流量的影响。我们称之为短期局部时空依赖性。第二种依赖关系源于交通流模式的语义关联,特别是流量的长期变化。此外,由于城市的功能分区,远距离区域可能会表现出类似的交通模式,从而导致长期全局时空依赖性。这两种时空依赖性本质不同,但彼此紧密相连,在城市交通预测中发挥决定性作用。

尽管ConvLSTM在某些时空预测任务中表现出色,但在建模长距离和长期时空依赖性方面存在局限性,这在交通流量预测中尤为重要。受自注意力机制启发,它能够聚合所有空间位置的特征,我们利用自注意力在全球范围内捕捉空间相关性。此外,上节介绍的STM-N被用于编码和存储长距离网格级空间依赖性和长期代表性时间模式。因此,我们提出了多层次注意力网络(MAN),它利用外部记忆网络和注意力网络的力量,同时捕捉长期全局相关性和短期局部相关性。具体而言,我们将ConvLSTM计算的特征用作自注意力模块的输入,而不是原始的交通流数据,因为它编码了输入数据的序列信息。MAN的框架如图6所示,包含两个模块:短期局部注意力和长期全局注意力。

image-20241106231024501

3.3.1 短期局部注意力模块 (SLA)

SLA的目标是在局部尺度上建模短距离空间依赖性和短期时间依赖性,因为大多数交通流量的转移涉及相邻区域之间的交互。正如图7所示,大多数交通流量转移发生在10公里或更小的范围内,很少有覆盖整个城市的交通流量。因此,考虑到人类移动轨迹的地理属性,建模短期局部空间依赖性是合理的。正如3.1.1节中所示,ConvLSTM的输出 F c F_c Fc捕捉了来自输入序列的短期时间信息,因此我们将其用作SLA的输入特征。此外,考虑到外部因素,SLA的输入特征图可以与外部特征 F e x t F_{ext} Fext拼接。接下来,我们将输入映射 F s F_s Fs输入到SLA中,以提取短期局部时空特征,其中 F s = F c + F e x t F_s = F_c + F_{ext} Fs=Fc+Fext F s ∈ R H × W × D F_s \in \mathbb{R}^{H \times W \times D} FsRH×W×D

image-20241106231042406

对于传统的空间注意力模块中的简单自注意力机制,城市中的每个网格都会与所有其他网格进行交互。然而,从短距离空间视角来看,仅需要相邻网格之间的交互。因此,我们为SLA引入了地理掩码策略,以建模短期局部时空依赖性。具体来说,当两个网格之间的距离较短时,SLA中会考虑它们之间的交互。为了便于解释,我们将从跨网格注意力和跨区域注意力两个方面介绍局部注意力模块。

为简化表示,我们将城市在两个空间尺度上表示为基于网格的映射 C ∈ R H × W C \in \mathbb{R}^{H \times W} CRH×W(例如, 128 × 128 128 \times 128 128×128)和基于区域的映射 C ′ ∈ R H ′ × W ′ C' \in \mathbb{R}^{H' \times W'} CRH×W(例如, 32 × 32 32 \times 32 32×32),其中一个区域包含一些网格。我们使用二值掩码矩阵 M m a s k M_{mask} Mmask捕捉相邻网格之间的交互。当两个网格之间的距离在一定范围内时, M m a s k M_{mask} Mmask的权重设为1,否则设为0。直观上,处于同一区域的所有网格在地理上具有相关性,因为它们彼此相对接近,促使我们为它们的掩码权重设为1。相比之下,对于区域,掩码权重仅在网格所在的两个区域之间的距离小于阈值时才设为1。如图6中的SLA模块所示,对于红色区域,区域内存在跨网格交互,且只有蓝色区域内的网格参与跨区域交互。

一般而言,跨网格和跨区域交互本质上涉及网格之间的相互作用。上一节中介绍的网格划分仅用于直观地说明如何构建掩码矩阵。在获得掩码矩阵 M m a s k M_{mask} Mmask后,我们定义网格级的空间注意力。具体而言,短期局部注意力模块通过 1 × 1 1 \times 1 1×1卷积操作将 F s F_s Fs映射到三个矩阵:查询 Q c ∈ R N × C ~ Q_c \in \mathbb{R}^{N \times \tilde{C}} QcRN×C~、键 K c ∈ R N × C ~ K_c \in \mathbb{R}^{N \times \tilde{C}} KcRN×C~和值 V c ∈ R N × C V_c \in \mathbb{R}^{N \times C} VcRN×C,其中 N = H × W N = H \times W N=H×W C ~ \tilde{C} C~ C C C为通道数。令 A c ∈ R N × N A_c \in \mathbb{R}^{N \times N} AcRN×N为注意力图, A c A_c Ac中的每个元素明确指示了输入特征图中两个网格之间的相关性。最终,时空表示 F c F_c Fc通过对 V c V_c Vc进行加权求和得到,权重由 A c A_c Ac决定。为简便起见,我们将上述短期局部注意力模块定义如下,其中 W c Q W_c^Q WcQ W c K W_c^K WcK W c V W_c^V WcV是需要学习的投影矩阵。

Q c = W c Q F s , K c = W c K F s , V c = W c V F s Q_c = W_c^Q F_s, \quad K_c = W_c^K F_s, \quad V_c = W_c^V F_s Qc=WcQFs,Kc=WcKFs,Vc=WcVFs

A c = softmax ( ( Q c ( K c ) T ) ∘ M m a s k ) A_c = \text{softmax} \left( \left( Q_c (K_c)^T \right) \circ M_{mask} \right) Ac=softmax((Qc(Kc)T)Mmask)

F c = SLA ( Q c , K c , V c , M m a s k ) = A c V c (6) F_c = \text{SLA}(Q_c, K_c, V_c, M_{mask}) = A_c V_c \tag{6} Fc=SLA(Qc,Kc,Vc,Mmask)=AcVc(6)

3.3.2 长期全局注意力模块 (LGA)

LGA模块旨在通过考虑具有相似功能的远距离区域可能表现出类似的交通流模式,捕获长距离的全局时空依赖性。不同于SLA模块的输入特征,来自ConvLSTM组件的三种时间特征(即 F c F_c Fc F p F_p Fp F q F_q Fq)以及外部特征 F e x t F_{ext} Fext将被输入到LGA模块中进行全局特征提取,表示为 F l ∈ R H × W × D F_l \in \mathbb{R}^{H \times W \times D} FlRH×W×D,其中 F l = F c + F p + F q + F e x t F_l = F_c + F_p + F_q + F_{ext} Fl=Fc+Fp+Fq+Fext。此外,细粒度的空间关系和时间交通模式被编码并存储在STM-N中。因此,LGA模块结合STM-N的信息来显式建模长距离空间和长期时间依赖性。

给定STM-N中的键记忆矩阵 M k ∈ R H × W × D k m M_k \in \mathbb{R}^{H \times W \times D_{km}} MkRH×W×Dkm和值记忆矩阵 M v ∈ R H × W × D v m M_v \in \mathbb{R}^{H \times W \times D_{vm}} MvRH×W×Dvm,以及得到的特征 F l ∈ R H × W × D F_l \in \mathbb{R}^{H \times W \times D} FlRH×W×D,LGA旨在从内部记忆和外部记忆中提取长距离空间特征和长期时间特征。然而,尽管城市中网格数量庞大(例如 128 × 128 128 \times 128 128×128),在网格级别捕获时空依赖性可能导致计算效率低下。因此,LGA的目标是通过从外部记忆网络读取关键信息来有效学习长距离空间依赖性和长期时间特征。

我们首先执行从网格空间到区域空间的下采样操作,以获取语义信息,更有利于捕捉全局依赖性。具体来说,我们采用卷积层将LGA的网格级输入映射到区域级特征:

F l ′ = C o n v ( F l ) ∈ R H ′ × W ′ × D F'_l = Conv(F_l) \in \mathbb{R}^{H' \times W' \times D} Fl=Conv(Fl)RH×W×D

M k ′ = C o n v ( M k ) ∈ R H ′ × W ′ × D k m M'_k = Conv(M_k) \in \mathbb{R}^{H' \times W' \times D_{km}} Mk=Conv(Mk)RH×W×Dkm

M v ′ = C o n v ( M v ) ∈ R H ′ × W ′ × D v m (7) M'_v = Conv(M_v) \in \mathbb{R}^{H' \times W' \times D_{vm}} \tag{7} Mv=Conv(Mv)RH×W×Dvm(7)

为了对具有相似功能的区域的交通流模式进行语义时空依赖建模,注意力机制用于从记忆网络中读取长距离空间信息和长期时间信息。我们将区域级特征映射到新的特征空间,以提高模型的灵活性,具体为查询矩阵 Q m ∈ R N × C ~ Q_m \in \mathbb{R}^{N \times \tilde{C}} QmRN×C~、键矩阵 K m ∈ R N × C ~ K_m \in \mathbb{R}^{N \times \tilde{C}} KmRN×C~和值矩阵 V m ∈ R N × C V_m \in \mathbb{R}^{N \times C} VmRN×C,其中 M = H ′ × W ′ M = H' \times W' M=H×W。然后,计算所有区域对之间的注意力权重 A m ∈ R M × M A_m \in \mathbb{R}^{M \times M} AmRM×M,可以表示全局空间关系。因此,通过计算全城所有区域记忆值的加权和,可以获得长期全局时空特征 F m ′ F'_m Fm。为了获得对应每个网格的细粒度特征,我们进一步采用 N 2 N^2 N2-Normalization方法将 H ′ × W ′ H' \times W' H×W区域的区域级特征 F m ′ F'_m Fm映射到 H × W H \times W H×W网格的网格级特征 F m F_m Fm

Q m = W m Q F l ′ , K m = W m K M k ′ , V m = W m V M v ′ Q_m = W^Q_m F'_l, K_m = W^K_m M'_k, V_m = W^V_m M'_v Qm=WmQFl,Km=WmKMk,Vm=WmVMv

A m = s o f t m a x ( Q m ( K m ) T ) A_m = softmax \left( Q_m (K_m)^T \right) Am=softmax(Qm(Km)T)

F m ′ = L G A ( Q m , K m , V m ) = A m V m F'_m = LGA(Q_m, K_m, V_m) = A_m V_m Fm=LGA(Qm,Km,Vm)=AmVm

F m = N 2 − N o r m a l i z a t i o n ( F m ′ ) (8) F_m = N^2-Normalization(F'_m) \tag{8} Fm=N2Normalization(Fm)(8)

3.3.3 融合机制

一些现有研究表明,融合多种特征可以提升模型性能。因此,我们提出了一种融合机制,将两种时空特征整合以丰富模型的表示能力。具体而言,我们利用神经门控技术动态聚合从两个注意力模块中提取的特征,这可以通过控制两种时空特征的重要性来适应不同区域和时间段。给定短期局部特征 F c F_c Fc和长期全局特征 F m F_m Fm,通过一个由融合门 g f g_f gf控制的加权和得到集成的特征表示 F f i n a l ∈ R H × W × K F_{final} \in \mathbb{R}^{H \times W \times K} FfinalRH×W×K,该过程可以总结为:

F f i n a l = g f ∘ F c + ( 1 − g f ) ∘ F m F_{final} = g_f \circ F_c + (1 - g_f) \circ F_m Ffinal=gfFc+(1gf)Fm

g f = σ ( W c g ∗ F c + W m g ∗ F m + b g ) (9) g_f = \sigma (W_{cg} * F_c + W_{mg} * F_m + b_g)\tag{9} gf=σ(WcgFc+WmgFm+bg)(9)

其中 W c g W_{cg} Wcg W m g W_{mg} Wmg b g b_g bg是待学习的投影矩阵。

最终,基于提取的时空特征 F f i n a l F_{final} Ffinal计算第 T T T时刻的预测交通流量 Y ^ T ∈ R H × W × K ′ \hat{Y}_T \in \mathbb{R}^{H \times W \times K'} Y^TRH×W×K

Y ^ T = tanh ⁡ ( C o n v ( F f i n a l ) ) (10) \hat{Y}_T = \tanh (Conv(F_{final})) \tag{10} Y^T=tanh(Conv(Ffinal))(10)

其中 C o n v Conv Conv是卷积操作, tanh ⁡ \tanh tanh是双曲正切函数,确保输出值在-1和1之间。

3.4 优化

我们将交通流量预测公式化为一个回归任务,并提出一种端到端的训练方法,通过最小化以下损失函数来训练模型:

L = L r p + α L s m + β L c m + γ ∥ Θ ∥ F 2 (11) \mathcal{L} = \mathcal{L}^p_r + \alpha \mathcal{L}^m_s + \beta \mathcal{L}^m_c + \gamma \|\Theta\|^2_F \tag{11} L=Lrp+αLsm+βLcm+γ∥ΘF2(11)

其中 Θ \Theta Θ是模型参数集, α \alpha α β \beta β γ \gamma γ是权衡参数。 L r p \mathcal{L}^p_r Lrp为均方误差(MSE),用于评估预测值 Y ^ T \hat{Y}_T Y^T与真实值 M T M_T MT之间的预测性能。 L s m \mathcal{L}^m_s Lsm L c m \mathcal{L}^m_c Lcm为记忆网络的损失,如前一节所述。最终项对所有模型参数进行正则化以避免过拟合。

4 实验

在本节中,我们首先概述实验设置,包括数据集、基线、评估指标和实现。随后,我们在五个公共数据集上对所提出模型的性能进行全面评估。

4.1 实验设置

4.1.1 数据集

我们在五个交通数据集上进行实验,以评估ST-MAN的性能。表2详细列出了五个真实世界的数据集。前两个数据集,TaxiNYC和BikeNYC,包含纽约市大约600万条出租车行程记录和800万条自行车行程记录。接下来的两个数据集,TaxiDC和BikeDC,来自华盛顿特区约1600万辆出租车和300万辆自行车的轨迹。最后一个数据集,TaxiBJ+,包含北京超过3万条出租车轨迹。

image-20241106231113095

4.1.2 基线

我们将ST-MAN与四组基线进行比较。第一类方法是时间序列回归模型,包括历史平均(HA)和自回归积分滑动平均(ARIMA);第二类方法是传统回归模型,包括线性回归(LR)和树模型;第三类方法是经典的基于深度神经网络(DNN)的模型,包括多层感知机(MLP)、卷积神经网络(CNN)、长短期记忆网络(LSTM)和卷积LSTM(ConvLSTM);最后一类方法是最新的时空模型,包括DMVST-Net、DeepST、ST-ResNet、STDN、DeepSTN+、SA-ConvLSTM和PDFormer。

  • HA:历史平均仅使用历史数据的平均值来预测未来值。
  • ARIMA:自回归积分滑动平均通过历史值和残差运算的线性组合来预测未来值。ARIMA模型在实验中针对每个目标区域分别训练。
  • LR:线性回归建模当前观察值与历史数据之间的线性关系。我们基于所有区域的历史数据构建一个全局回归模型以计算预测值。
  • 树模型:使用树结构拟合数据之间的复杂相关性。
  • MLP:我们使用一个全连接网络进行交通预测任务。隐藏层的神经元数量分别为256、128和64。
  • CNN:卷积操作用于在神经网络中建模空间依赖性。实验中的CNN模型包含3个卷积层,滤波器数量为64,卷积核大小为3×3。
  • LSTM:长短期记忆网络具有128个隐藏单元,我们选择前12帧预测下一帧。
  • ConvLSTM:卷积LSTM用卷积层替换LSTM单元中的全连接层,以同时捕获时间和空间依赖性。
  • DMVST-Net:深度多视角时空网络结合了CNN和LSTM,同时建模空间和时间关系。空间特征首先通过两个卷积层提取,然后输入全连接的LSTM层以提取时间特征。
  • DeepST:第一个基于深度学习的时空预测网络,通过卷积神经网络框架捕获基于三种序列数据的空间和时间依赖性。
  • ST-ResNet:深度时空残差网络是一种先进的时空模型,应用残差机制以进一步改进DeepST中的一般卷积框架。
  • STDN:时空动态网络通过注意力机制在长时间跨度上建模时间关系。
  • DeepSTN+:改进了ST-ResNet中的融合机制,并通过ConvPlus模块建模大范围空间依赖性。
  • SA-ConvLSTM:自注意力卷积LSTM结合了注意力机制和ConvLSTM,以建模大范围空间依赖性。
  • PDFormer:一种先进的城市流量预测方法,通过自注意力模块捕获动态的时间和空间依赖性。

4.1.3 评估指标

我们基于交通流量预测的两个常用指标来评估模型性能:均方根误差(RMSE)和平均绝对误差(MAE)。RMSE和MAE的值越小,模型性能越好。

4.1.4 实现

我们使用PyTorch 1.8和Python 3.8实现了ST-MAN和所有基线。在我们的实验中,我们将纽约市、华盛顿特区和北京划分为大小为10×20、16×16和128×128的网格单元。我们还为相应城市定义了大小为5×10、8×8和16×16的区域单元。对于每个数据集,我们选择最后20%的数据作为测试数据,之前的所有数据用于训练和验证。对于时间接近度、周期性和趋势性,我们分别设置三个片段的长度为 l c ∈ { 3 , 4 , 5 } l_c \in \{3,4,5\} lc{3,4,5} l p ∈ { 3 , 4 , 5 } l_p \in \{3,4,5\} lp{3,4,5} l q ∈ { 2 , 3 , 4 } l_q \in \{2,3,4\} lq{2,3,4}。此外,键和值矩阵的维度(即 D k m D_{km} Dkm D v m D_{vm} Dvm)设置为32和64。所有卷积层使用64个滤波器,卷积核的大小设置为3×3。ConvLSTM层的深度在 { 2 , 3 , 4 } \{2,3,4\} {2,3,4}中搜索。注意力模块的隐藏维度 D D D { 16 , 32 , 64 , 128 } \{16,32,64,128\} {16,32,64,128}中搜索,注意力层的深度在 { 1 , 2 , 3 , 4 } \{1,2,3,4\} {1,2,3,4}中搜索。温度参数 τ \tau τ固定为1.0,损失函数中的权衡参数 α \alpha α β \beta β γ \gamma γ分别设为0.05、0.1和0.001。在训练过程中,模型通过Adam优化器进行优化,学习率为 1 0 − 4 10^{-4} 104,批次大小设置为32。

4.2 实验结果

4.2.1 性能比较

为了评估我们提出的模型,我们在五个数据集上比较了不同模型的性能,其中TaxiBJ+是一个具有128 × 128网格单元的细粒度数据集。结果如表3所示。

image-20241106231138152

从表3可以看出,ST-MAN在所有基准数据集上的性能均有改善,证明了我们模型的优越性。我们有以下主要观察:

  • 可以看到,基于DNN的预测模型在复杂的非线性时空关系捕获方面显著优于时间序列回归模型和传统回归方法,因为它们在捕捉复杂的非线性时空关系方面能力有限。
  • 通过将先进的时空模型与传统的DNN模型(包括MLP、CNN、FC-LSTM和ConvLSTM)进行比较,可以发现时空模型的性能更好。主要原因是大多数时空模型在设计相应模块时考虑了时空数据的特征。例如,尽管ConvLSTM扩展了FC-LSTM以更好地捕捉时空相关性,但仍未能捕捉到长距离空间依赖性和长期时间依赖性,这表明需要设计特定模块以建模复杂的时空相关性,从而提高预测模型的表示能力。
  • ST-MAN在五个数据集上的性能优于一些最先进的时空预测模型。原因在于基于CNN的模型(如DeepST、ST-ResNet、DeepSTN+)主要通过卷积层聚焦于长距离空间依赖性,但忽略了长期的时间动态。DMVST-Net和STDN结合了局部CNN和RNN,分别学习空间和时间依赖性。然而,这些模型忽略了空间和时间因素之间的动态相关性。此外,这些模型在捕获长距离空间依赖性方面不足,因为CNN模型需要堆叠许多卷积层,这在大规模空间网络中难以部署。我们提出的ST-MAN能够基于记忆网络和多层次注意力网络同时捕获短期局部时空依赖性和长期全局时空依赖性。
  • 我们可以观察到ST-MAN在所有五个预测任务中均优于SAConvLSTM和STM。尽管SAConvLSTM引入了注意力机制到ConvLSTM以提取长距离空间依赖性,但仍不足以捕捉长期时间依赖性。此外,尽管PDFormer利用Transformer结构捕捉动态的时间和空间特征,但仍不足以捕捉长期时间依赖性。在处理城市中大量网格时,它还存在显著的冗余计算。传统的自注意力机制在所有网格之间计算注意力分数,导致了显著的计算开销。我们从网格空间到区域空间执行降采样操作以获取语义信息,这在更低计算成本的情况下更有利于捕捉全局依赖性。因此,ST-MAN可以通过读取外部记忆网络中的关键信息有效地学习长距离空间依赖性和长期时间特征。此外,我们还引入了外部组件,基于网格级空间嵌入生成位置感知的外部因素影响。总体来说,ST-MAN的改进不仅来源于注意力机制,还来源于外部记忆网络的独特结构。

4.2.2 消融研究

为了研究ST-MAN中不同组件对性能提升的贡献,我们评估了模型的三个变体来进行消融研究:

  • ST-MAN w/o ConvLSTM 去除了特征图构建阶段的ConvLSTM组件。
  • ST-MAN w/o STMN 去除了ST-MAN中的外部时空记忆网络。
  • ST-MAN w/o MAN 是基础预测模型,仅利用ConvLSTM提取特征。

结果汇总在表4中。

image-20241106231156240

可以观察到,通过加入STMN和MAN,性能有所提升,这表明记忆网络能够增强模型的表达能力并提高性能。首先,可以观察到ST-MAN w/o ConvLSTM的性能低于ST-MAN,表明通过ConvLSTM获得的特征图在建模复杂的时空依赖性方面的优势。其次,ST-MAN在没有注意力模块的情况下性能也低于ST-MAN,并且在不同数据集上表现出不同的变化。这一差异的主要原因在于不同城市的交通流分布差异。此外,我们发现移除STMN中的键记忆单元和注意力单元在一定程度上会降低性能。其主要原因是键记忆单元通过空间嵌入提供了全球空间相关性信息,而忽视空间依赖性会削弱特征表示能力。简而言之,ST-MAN的良好表现证明了我们设计的有效性,引入外部记忆网络以捕捉长距离和长期的时空特征。

4.2.3 注意力机制的效果

本实验研究了我们提出的注意力模块在建模时空特征方面的有效性。我们实现了三个简化版本的ST-MAN,以研究注意力模块是否能有效地捕获时空特征。

  • ST-MAN w/o SGLA:去除了多层次注意力网络中的短期跨网格局部注意力模块;
  • ST-MAN w/o SRLA:去除了多层次注意力网络中的短期跨区域局部注意力模块;
  • ST-MAN w/o LGA:去除了来自时空记忆网络的长期跨区域全局注意力,仅使用ConvLSTM提取特征。

表5展示了实验结果。

image-20241106231213200

可以看到,ST-MAN比其他变体表现更佳,这表明我们提出的注意力机制能够有效提取时空特征并提升模型的预测性能。此外,可以观察到ST-MAN w/o LGA的表现最差,表明长距离和短距离的时空特征对于交通流量预测是有用的。此外,可以观察到ST-MAN w/o SRLA的性能比ST-MAN w/o SGLA更差,表明跨区域交通流转移的重要性。一个可能的原因是,随着交通的发展,城市内人们的流动性增加,使得跨区域交互变得更加重要。

4.2.4 记忆簇数量的影响

为了从时空记忆网络中获得关于城市不同区域的先验知识,我们将区域划分为不同的类型,其中通过聚类获得的潜在城市功能区域数量是ST-MAN模型的一个重要参数。因此,我们研究了值记忆矩阵中不同功能区域类别数量对ST-MAN模型的影响。

在本实验中,聚类功能区域数量 K K K的范围设置为 { 3 , 4 , 5 , 6 , 7 } \{3,4,5,6,7\} {3,4,5,6,7}图8显示了随着 K K K值增加,ST-MAN模型在四个数据集上的预测性能。

image-20241106231300238

从图8可以观察到,当 K K K值设置为5时,ST-MAN的性能最佳。一个可能的原因是,大多数城市已经发展成熟,城市的功能区域趋于稳定,因此在短期内某些类型的功能区域不太可能出现或消失。例如,商业区、居住区、行政区、大学区和风景区通常是大多数城市的五个主要城市功能区域。

5 相关工作

在本节中,我们简要回顾了一些关于城市交通预测和记忆网络的相关研究。

5.1 城市交通预测

随着大规模交通数据的不断增加,城市交通预测已成为城市计算中的一个重要领域。许多研究基于历史交通数据预测潜在的交通量。经典研究将交通预测视为时间序列预测问题,并应用时间序列方法建模时间模式。例如,历史平均模型(HA)仅使用历史数据的平均值来预测未来值。自回归积分滑动平均(ARIMA)模型作为典型的时间序列模型之一,通过历史值和残差运算的线性组合来预测未来值。然而,这些方法在捕捉复杂的非线性时间关系方面能力有限,并且忽略了空间信息来建模空间依赖性。

近年来,深度学习在许多任务中取得了可喜的性能,特别是在计算机视觉领域中,CNN已被成功用于提取空间特征,而RNN通过嵌入历史序列记录到隐藏状态向量中为序列学习任务编码时间信息。许多研究者开始利用深度学习方法来解决交通预测问题,并表现出优于传统方法的性能。DeepST是第一个基于CNN的时空预测网络,它利用卷积神经网络框架来捕捉空间依赖性。ST-ResNet作为具有代表性的时空模型之一,通过引入残差机制改善了DeepST的框架,使深度结构能够建模大规模城市范围的依赖性。DeepSTN+提出了一种情境感知时空神经网络,应用ConvPlus结构来捕捉不同区域间的长距离空间相关性。此外,许多研究利用图卷积神经网络建模交通预测中的轨迹和道路网络。此外,许多工作结合CNN和RNN来同时建模空间和时间相关性。提出了深度多视角时空网络(DMVST-Net)来为出租车需求预测建模空间和时间关系。STDN提出了时空动态网络以解决真实场景中的动态时间偏移问题。

这些现有工作在捕捉长距离空间依赖性和长期时间依赖性方面具有局限性。一方面,在基于CNN的网络中堆叠许多卷积层以捕捉长距离空间依赖性可能导致高计算成本和优化难度。另一方面,基于RNN的网络将历史数据压缩到隐藏状态中以编码时间依赖性,限制了表达长期时间模式的能力。相比之下,我们的目标是在城市交通预测中同时建模长距离空间依赖性和长期时间依赖性。

5.2 记忆网络

记忆网络是一种循环注意力模型,旨在解决RNN在记忆操作中面临的困难。它利用外部记忆组件在记忆槽中读写长期记忆,从而提供了额外的知识表示以增加模型容量。

近年来,记忆网络已成功应用于许多领域,如问答和序列推荐。提出了一种基于协同过滤的记忆增强神经网络(MANN),用于推荐系统中存储和更新用户兴趣。具体来说,使用了两个外部用户记忆矩阵来编码项目级和特征级信息。提出了记忆到序列(Mem2Seq)模型,用于端到端任务导向的对话系统,通过多跳机制和外部记忆来处理长序列。此外,除了原始记忆网络之外,键值记忆网络(KV-MN)将记忆组件结构化为(键、值)对,并分别基于键记忆和值记忆进行寻址和读取操作。提出了一种基于动态记忆的注意力网络(DMAN),用于长期序列推荐,包括一组用于存储用户长期兴趣的记忆块。提出了主题增强记忆网络(TEMN),用于个性化兴趣点推荐,特别是利用记忆网络组件编码用户的历史签到记录,并捕捉不同区域之间的局部关系。

以上所有工作主要关注通过隐藏向量存储序列信息,而在时空预测方面很少有工作,因为记忆网络中忽略了空间信息。我们的工作不同于上述所有工作,我们为时空预测引入了新的视角,其中记忆网络用于编码和记忆空间信息和时间信息,以增强时空预测模型的表达能力。据我们所知,我们是第一个将外部记忆网络用于时空预测任务,并研究复杂交通状态下的长距离和长期时空相关性。

6 结论

本文提出了一种新的交通预测方法,称为时空记忆增强多层次注意力网络(ST-MAN)。据我们所知,这是首次在时空预测任务中引入外部记忆网络,以通过编码和记忆细粒度空间信息和时间模式来增强预测模型的表达能力。此外,我们结合时空记忆网络和多层次注意力网络的优点,显式建模长距离空间依赖性和长期时间依赖性。我们提出的ST-MAN在五个真实世界的数据集上进行了评估,实验结果表明ST-MAN相比最先进的基线模型更为有效。在未来的工作中,我们计划利用记忆网络学习领域不变的时空模式,并从数据丰富的城市向数据稀疏的城市转移有价值的知识,以提高性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/910105.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【每日一题】2009考研数据结构 - 求倒数第k个数的值

已知一个带有表头结点的单链表,结点结构为 data 和 link。假设该链表只给出了头指针 list。在不改变链表的前提下,请设计一个尽可能高效的算法,查找链表中倒数第 k 个位置上的结点(k 为正整数)。 要求: 若…

ELF加载,进程地址空间与可执行程序的关系

1,可执行程序的格式 粗略概况 操作系统要如何认识可执行程序?我们的可执行程序是有格式的: 用指令size 加可执行程序名: 其中test就是代码块,data就是数据块,不仅可执行程序有格式,动态库&am…

超实惠的租借服务器训练深度学习方法

1. 必备软件 1.1 Xftp和Xshell 通过百度网盘分享的文件:Niha 链接:https://pan.baidu.com/s/1uHLme7H9SL2C-ZhFr107gA?pwdnadb 提取码:nadb xftp用于连接服务器, 传输本地文件到服务器上面去。 xshell用于连接服务器进行命令操作 2 恒源…

蓝桥杯-网络安全比赛题目-遗漏的压缩包

小蓝同学给你发来了他自己开发的网站链接, 他说他故意留下了一个压缩包文件,里面有网站的源代码, 他想考验一下你的网络安全技能。 (点击“下发赛题”后,你将得到一个http链接。如果该链接自动跳转到https,…

MongoDB笔记03-MongoDB索引

文章目录 一、前言1.1 概述1.2 MongoDB索引使用B-Tree还是BTree?1.3 B 树和 B 树的对比1.4 总结 二、索引的类型2.1 单字段索引2.2 复合索引2.3 其他索引 三、索引的管理操作3.1 索引的查看3.2 索引的创建3.2.1 单字段索引3.2.2 复合索引 3.3 索引的移除3.3.1 指定索…

【Android】时区规则库tzdata更新

1 背景: 最近我遇到墨西哥城时区,会出现夏令时,而墨西哥城在2022年底都已经取消夏令时了。 看起来是要更新RK3588上的时区库,我的还是2021a,而现在都已经2024年了 这样能看版本号: cat /system/usr/sha…

网络初始:TCP/IP 五层协议模型 网络通信基本流程

目录 1. 名词解释 1.1 局域网 1.2 广域网 1.3 交换机 1.4 IP 地址 1.5 端口号 2. 协议 2.1 认识协议 2.2 五元组 3. 协议分层 3.1 分层的作用 3.2 OSI 七层网络模型 & TCP/IP 五层(四层)协议模型 4. TCP/IP 五层(四层)网络模型 4.1 物理层 4.2 数据链路层 4…

小新学习k8s第六天之pod详解

一、资源限制 Pod是k8s中的最小的资源管理组件,pod也是最小化运行容器化应用的资源对象。一个Pod代表着集群中运行的一个进程。k8s中其他大多数组件都是围绕着Pod来进行支撑和扩展Pod功能的,例如,用于管理Pod运行的StatefulSet和Deployment等…

安利一款超6K+ star的可拖放响应式灵活的网格布局Gridstack.js

Gridstack.js是一个现代JavaScript(或Typescript)库,旨在帮助开发人员快速构建交互式和响应式的布局。以下是对Gridstack.js的详细介绍: 一、主要特点 灵活的网格布局:Gridstack.js允许开发者轻松地创建和管理网格布局…

接口测试基础 --- 什么是接口测试及其测试流程?

接口测试是指针对软件系统的接口进行测试的过程,主要是验证系统之间的数据传输和通信是否正常、功能是否正确。接口测试主要关注接口的输入、输出以及相应的逻辑关系,而不关注底层实现细节。接口测试可以帮助开发团队发现和解决与接口相关的问题&#xf…

1分钟解决Excel打开CSV文件出现乱码问题

一、编码问题 1、不同编码格式 CSV 文件有多种编码格式,如 UTF - 8、UTF - 16、ANSI 等。如果 CSV 文件是 UTF - 8 编码,而 Excel 默认使用的是 ANSI 编码打开,就可能出现乱码。例如,许多从网络应用程序或非 Windows 系统生成的 …

python基础(1)

声明:学习视频来自b站up主 泷羽sec,如涉及侵权马上删除文章 感谢泷羽sec 团队的教学 视频地址:初识python,环境配置,编程基础以及数据类型_哔哩哔哩_bilibili 一、什么是python Python 是一种高级、解释型、通用编程语…

USB 设备数据安全管理解决方案

在当今数字化的办公环境中,USB 设备的广泛使用为企业和组织带来了便捷,但同时也隐藏着巨大的数据泄露风险。许多企业和机构都曾因 USB 设备使用不当而遭受严重损失。 一方面,员工可能会无意或有意地使用未经授权的 USB 设备接入公司网络。这…

【UE5】一种老派的假反射做法,可以用于移动端,或对反射的速度、清晰度有需求的地方

没想到大家这篇文章呼声还挺高 这篇文章是对它的详细实现,建议在阅读本篇之前,先浏览一下前面的文章,以便更好地理解和掌握内容。 这种老派的假反射技术,适合用于移动端或对反射效果的速度和清晰度有较高要求的场合。该技术通过一…

Flink滑动窗口(Sliding)中window和windowAll的区别

滑动窗口的使用,主要是计算,在reduce之前添加滑动窗口,设置好间隔和所统计的时间,然后再进行reduce计算数据即可。 窗口设置好时间间隔,和处理时间窗口的时间,比如将滑动窗口的时间间隔都设置为5s,处理时间…

基于YOLO11/v10/v8/v5深度学习的煤矿传送带异物检测系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

Golang--文件操作

1、文件 文件:文件用于保存数据,是数据源的一种 os包下的File结构体封装了对文件的操作(记得包os包) 2、File结构体--打开文件和关闭文件 2.1 打开文件 打开文件,用于读取(函数): 传…

BSAchongsds、

一、 ## 统计基因组整体信息 srun -A 2022099 -p Debug -n 2 -N 1 seqkit stats ~/yiyaoran/workspace/06.BSRseq/guo_BSR_pipline/ref/genome.fasta > genome.allstatcat genome.allstat 文件名 格式 类型 序列数量 总长度 最小长度 平均长…

聊一聊Elasticsearch的基本原理与形成机制

1、搜索引擎的基本原理 通常搜索引擎包括:数据采集、文本分析、索引存储、搜索等模块,它们之间的协作流程如下图: 数据采集模块负责采集需要搜索的数据源。 文本分析模块是将结构化数据中的长文本切分成有实际意义的词,这样用户…

**AI的三大支柱:神经网络、大数据与GPU计算的崛起之路**

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…