2024.6.16周报

目录

摘要

ABSTRACT

一、文献阅读

一、题目

二、摘要

三、创新点

四、模型架构

五、文章解读

1、Introduction

2、实验

3、结论

二、代码复现

1、模型代码

2、实验结果 

三、总结


摘要

本周我阅读了一篇题目为《Contaminant Transport Modeling and Source Attribution With Attention‐Based Graph Neural Network》的论文,这篇论文引入了一种新的基于注意力的图神经网络(aGNN),专门用于在有限监测数据下模拟污染物迁移并量化污染源及其传播之间的因果关系。此外,aGNN的解释性分析能有效量化每个污染源的影响,证实了其在地下污染物运移研究中的高效性和减少计算成本的能力,为地下水管理提供了一个有力的工具。通过复现其代码,对模型的架构有了更深刻的理解。

ABSTRACT

This week, I rear a paper titled "Contaminant Transport Modeling and Source Attribution With Attention-Based Graph Neural Network" . In the paper, a new attention-based graph neural network (aGNN) was introduced, which is specifically designed to simulate contaminant migration under limited monitoring data and to quantify the causal relationships between pollutant sources and their propagation. Moreover, the interpretative analysis of aGNN was shown to effectively quantify the impact of each pollution source, confirming its efficiency in studies of subsurface contaminant migration and its ability to reduce computational costs, providing a powerful tool for groundwater management. By reproducing its code, a deeper understanding of the model's architecture was gained.

一、文献阅读

一、题目

题目:Contaminant Transport Modeling and Source Attribution With Attention‐Based Graph Neural Network

期刊:Water Resources Research

链接:https://doi.org/10.1029/2023WR035278

二、摘要

文章引用了一种称为基于注意力的图神经网络(aGNN)的新型机器学习模型,该模型旨在使用稀疏的监测数据对污染物传输进行建模,并分析污染物源与特定位置观测到的浓度之间的因果关系。文章在具有不同检测设置的不同含水层系统中进行了五个综合案例研究,其中aGNN表现最佳;此外,aGNN的解释性分析有效地量化了每个污染物源的影响,总结来说,这篇论文将aGNN确立为一种高效而稳健的地下污染物迁移复杂时空学习方法,它也成为地下水管理和污染源识别的一个重要工具。

The article employs a novel machine learning model known as Attention-based Graph Neural Networks (aGNN), which is designed to model the transport of pollutants using sparse monitoring data and to analyze the causal relationships between pollutant sources and the concentrations observed at specific locations. Five comprehensive case studies were conducted in various aquifer systems with different detection setups, where the aGNN demonstrated superior performance. Furthermore, the interpretability analysis of aGNN effectively quantified the impact of each pollutant source. In summary, this paper establishes aGNN as an efficient and robust method for complex spatiotemporal learning of subsurface pollutant migration, making it a significant tool for groundwater management and pollutant source identification.

三、创新点

(1)该文提出一种基于图的深度学习方法,用于模拟受监测数据约束的污染物迁移;

(2)所提出的模型量化了每个潜在污染源对任意位置观测浓度的贡献;

(3)与基于物理的污染物传输模型相比,深度学习方法大大降低了计算成本;

四、模型架构

使用深度学习和基于物理的模型(MODFLOW和MT3DMS)两种方法进行污染物传输建模的工作流程和数据概述。这些模型在三个任务中进行评估:转导学习、归纳学习和模型解释。

图1展示了使用深度学习(DL)方法和基于物理的模型来模拟地下水质量对多源污染排放的时空响应。深度学习模型,如aGNN、CNN和RNN,不需详细的水文地质信息,而物理模型如MODFLOW和MT3DMS则依赖这些数据。DL模型通过端到端学习,整合MODFLOW和MT3DMS的功能,处理水排放、污染物释放及其浓度和地下水位的数据。文章还评估了这些模型在转导学习、归纳学习和模型解释方面的效果,特别是通过Shapley值方法来分析和量化多点污染源的影响,以提供地下水管理和污染源识别的见解。

图2展示了aGNN的体系结构,这是一个基于编码器-解码器框架的系统。该体系由五个主要模块组成:

1、输入模块

(1)编码器输入和解码器输入:这两个模块负责构建节点的特征向量(包括监测点的污染物浓度、流量动态等)、空间信息以及邻接矩阵。编码器输入通常设计过去的时间步骤;解码器输入则关注未来的时间步骤。

(2)图嵌入模块:空间嵌入:通过对节点的地理位置或其他空间属性进行编码,捕捉节点间的空间关系。

                                时间嵌入:将时间信息转换为嵌入表示,使模型能够捕捉时间变化的模式和趋势。

时间嵌入可以使用时间顺序信息,给定一个时间序列S =(s0,s1,…,sT),时间嵌入层形成一个有限维表示来表示si在序列S中的位置。研究中的时间嵌入是将正弦变换串联到时间顺序,形成矩阵TE\in\mathbb{R}^{T\times d_{emb}},其中T和d_{emb}分别是时间长度和向量维数。TE设计为式2和式3,
其中2d和2d + 1分别表示偶数维和奇数维,t为时间序列中的时间位置。时间嵌入的维数为demb
3所示,时间嵌入中的每个元素都结合了时间顺序位置和特征空间的信息。

(3)编码器模块

 查询(Q),键(K)和值(V)。其思想是将Q和一组K‐V对映射到输出,使输出表示V的加权和。权重由相应的K和Q决定,然后应用Softmax函数对权重值进行归一化。

多头自注意力机制(MSA):允许模型在处理每个节点的特征时,同时考虑多种不同的解释和侧重点,从而更好地理解数据中的复杂模式。

 Q与解码器输入相关,K和V与编码器生成的隐藏特征相关。MSA特别关注自我注意机制,该机制适用于与自身交互的输入,在数学上,Q、K和V采用相同的原始输入(如公式6中的Xq=Xk=Xv)。MSA允许模型捕获输入序列中的不同方面和依赖关系,从而对特征元素之间的关系提供更全面的理解。

图卷积网络(GCN):通过在图结构中传播和更新节点信息,学习节点的特征表示。GCN通过使用节点及其邻居的信息,增强了模型对整个网络结构的理解。

在MSA阻塞后,GCN通过图结构在节点之间交换信息来提取中间表示,从而对空间依赖关系进行建模。GCNs使用图卷积过滤器,设计用于建模节点依赖关系。GCN的主要思想是构建一个消息传递网络,其中信息沿着图内相邻节点传播。

多头注意(MTA):MAT将信息从编码器传输到解码器。MAT作为编码器和解码器之间的链接。编码器的堆叠输出作为V和K传递给MAT,并将注意力分数分配给解码器输入的表示(即Q)。解码器中的MSA和GCN进行类似于机器翻译任务的学习过程,其中,解码器输入表示需要翻译成另一种“语言”的“一种语言中的句子”。

(4)解码器模块

与编码器结构相似,解码器同样包括多头自注意力机制和GCN层。不同的是,解码器更侧重于使用编码器的输出(隐藏状态)来生成对未来状态的预测。

(5)输出模块

最终生成的是目标序列预测,如污染物在未来某一时间点在地下水中的预期移动。

五、文章解读

1、Introduction

文章提出了aGNN,一种新的基于注意力的图神经建模框架,它结合了(a)图卷积网络(GCN)、(b)注意力机制和(c)嵌入层来模拟地下水系统中的污染物输送过程。GCN通过消息通过节点和边缘提取图信息,有效学习空间模式。注意机制是变压器网络中擅长序列分析的关键组成部分。嵌入层是潜在空间学习机制,代表了时空过程中的高维性。对交通和行人轨迹的研究表明,基于注意力的图神经网络在单过程时空预测任务中表现出竞争性的表现。在本研究中,作者将其应用扩展到学习地下水流动和溶质输送问题中的多个过程。此外,在尚未研究的未监测污染位置,采用了新的坐标嵌入方法进行归纳学习。本研究的目标有三个方面。首先研究了aGNN在涉及污染物迁移建模的多过程中的性能。基于GNN、CNN和LSTM的方法适用于多步空间预测的端到端学习任务,以深入了解每个模型的执行情况。其次,根据数据的可用性和含水层的非均质性,评估了aGNN将从监测数据中学习到的知识通过归纳学习转移到未监测站点的能力。第三,采用了一种可解释的人工智能技术,即沙普利值,它起源于合作博弈论的概念。

2、实验

1、研究区域

本研究设计了两个采用非承压含水层的合成研究场地,用于方法的开发和验证。第一个研究场地面积为497,500平方米,通过MODFLOW划分为30列和15行的网格,每个网格50米x50米。研究场地设置了两侧无通量边界和两侧恒定水头边界(分别为100米和95米)。为了研究水力传导率异质性对污染物传输模型的影响,考虑了两种水力传导率场景:一种是五个不同区域的水力传导率从15到35米/天变化;另一种是水力传导率从0到50米/天变化。污染物传输在MT3DMS中以30米的均匀纵向分散性进行模拟,并设置了三个间歇性排放污染水的注水井。第二个研究场地(场景C)覆盖面积180平方公里,是第一个场地的约360倍,划分为120列和150行的网格,每个网格100米x100米,并设有四个区域的水力传导率从30到350米/天变化。两个场地都设置了监测系统,包括水位下降和污染物浓度的日常数据记录。本研究还考察了三个水力传导率场的三个监测网络,观察它们对污染物移动反应的学习过程如何受到数据大小的影响。

2、实验准备

使用MODFLOW和MT3DMS模拟生成污染物运输数据集,并用于训练和评估不同的深度学习模型。数据集中80%用于训练,20%用于性能评估。所有DL模型均通过批量优化进行训练,批次大小为16,迭代400个周期,模型输出观测位置的地下水位降低(GD)和污染物浓度(CC)的预测,预测时域为50时间步。

上表为三种检测网络中不同算法的输入维度和参数数量

3、实验结果

四种DL模型:DCRNN、aGNN、aGNN-noE(无嵌入模块的aGNN变体)和ConvLSTM。这些模型都使用编解码器框架,但在输入设计上有所不同。输入特征包括静态特征(S)、历史行为(H)和计划特征(F)。静态特征代表坐标信息,历史行为详细记录了地下水排放和污染物释放的两个计划及监测的GD和CC,计划特征包含预测期内的地下水排放和污染物释放计划。

上图(a)、(b)、(c)显示的是传导学习中,三个检测网络中,污染源及其邻居具有较大的节点强度。

上图(d)、(e)、(f)显示的是归纳学习M1、M2、M3的预测区域。 

表2总结了整个数据集的统计特征,并按80/20的比例划分为训练和测试集。结果显示aGNN在五种不同情况下的测试性能。CC的变异范围是GD的五倍,表现出更高的分散性。本研究将多目标任务(涉及GD和CC)转化为单目标,使用加权和方法,CC权重为5,GD权重为1。此外,含水层非均质性对GD的影响较小,与CC相比,水头对电导率的非均质性敏感度较低。场景C中,由于场地更大且监测井更少,所有模型的精度均有所下降。在所有算法中,aGNN在几乎所有五种情况下均获得最低RMSE和最高R^{2}(表2),表明其在模拟非均匀分布监测系统中污染物迁移方面的性能优于其他算法。

 图7展示了四种模型的预测误差。ConvLSTM在空间上的RMSE较高,通常超过1 mg/L,而DCRNN的RMSE普遍低于0.3 mg/L,尤其在A-M1、B-M1、A-M2和B-M2区域。aGNN和aGNN-noE的性能优于DCRNN,显示更小的RMSE波动,这证明了基于注意力的图卷积网络的优势。aGNN在所有模型中展示了最小的RMSE变化,突显出其在捕捉空间变化,尤其是在污染源下游区域的高效性。此外,研究还使用了相对绝对误差(RAE)来测量预测值与真实值之间的差异,发现使用aGNN时RAE降低。

3、结论

本研究开发了一种新型数据驱动模型aGNN,用于模拟非均质地下水含水层中的污染物传输,特别强调数据有限且分布不均的情况。aGNN模型结合了注意力机制、时空嵌入层和图卷积网络层(GCN),优化了污染物传输的时空学习精度,通过动态权重分配、特征转换和信息传递提高模型效率。实验结果显示,aGNN在预测精度上达到了99%的R^{2}值,证明了其高效的预测能力。此外,aGNN能够利用图学习从监测地点提供的数据推断未监测地点的观测,即使在监测井有限的大型场地或高度非均质的含水层中也能有效捕捉污染物的时空变化。aGNN还通过SHAP方法分析污染源归因,展示了其作为数值模拟模型MODFLOW和MT3DMS的有效替代品。此方法大幅减轻了基于物理模型的计算负担,特别是在需要处理大量注入井和长期管理的情景中,显著提高了计算效率。

二、代码复现

1、模型代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import numpy as np
from utils_dpl3_contam import norm_Adj


class RBF(nn.Module):
    """
    Transforms incoming data using a given radial basis function:
    u_{i} = rbf(||x - c_{i}|| / s_{i})
    Arguments:
        in_features: size of each input sample
        out_features: size of each output sample
    Shape:
        - Input: (N, in_features) where N is an arbitrary batch size
        - Output: (N, out_features) where N is an arbitrary batch size
    Attributes:
        centres: the learnable centres of shape (out_features, in_features).
            The values are initialised from a standard normal distribution.
            Normalising inputs to have mean 0 and standard deviation 1 is
            recommended.

        log_sigmas: logarithm of the learnable scaling factors of shape (out_features).

        basis_func: the radial basis function used to transform the scaled
            distances.
    """

    def __init__(self, in_features, out_features, num_vertice,basis_func):
        super(RBF, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.centres1 = nn.Parameter(torch.Tensor(num_vertice, self.in_features))  # (out_features, in_features)
        self.alpha = nn.Parameter(torch.Tensor(num_vertice,out_features))
        self.log_sigmas = nn.Parameter(torch.Tensor(out_features))
        self.basis_func = basis_func
        self.reset_parameters()


        # self.alpha1 = nn.Parameter(torch.Tensor(num_vertice, self.out_features))
    def reset_parameters(self):
        nn.init.normal_(self.centres1, 0, 1)
        nn.init.constant_(self.log_sigmas, 0)

    def forward(self, input):

        size1= (input.size(0), input.size(0), self.in_features)
        x1 = input.unsqueeze(1).expand(size1)
        c1 = self.centres1.unsqueeze(0).expand(size1)
        distances1 = torch.matmul((x1 - c1).pow(2).sum(-1).pow(0.5),self.alpha) / torch.exp(self.log_sigmas).unsqueeze(0)
        return self.basis_func(distances1) #distances1


# RBFs

def gaussian(alpha):
    phi = torch.exp(-1 * alpha.pow(2))
    return phi


def linear(alpha):
    phi = alpha
    return phi


def quadratic(alpha):
    phi = alpha.pow(2)
    return phi


def inverse_quadratic(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2))
    return phi


def multiquadric(alpha):
    phi = (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi


def inverse_multiquadric(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi


def spline(alpha):
    phi = (alpha.pow(2) * torch.log(alpha + torch.ones_like(alpha)))
    return phi


def poisson_one(alpha):
    phi = (alpha - torch.ones_like(alpha)) * torch.exp(-alpha)
    return phi


def poisson_two(alpha):
    phi = ((alpha - 2 * torch.ones_like(alpha)) / 2 * torch.ones_like(alpha)) \
          * alpha * torch.exp(-alpha)
    return phi


def matern32(alpha):
    phi = (torch.ones_like(alpha) + 3 ** 0.5 * alpha) * torch.exp(-3 ** 0.5 * alpha)
    return phi


def matern52(alpha):
    phi = (torch.ones_like(alpha) + 5 ** 0.5 * alpha + (5 / 3) \
           * alpha.pow(2)) * torch.exp(-5 ** 0.5 * alpha)
    return phi


def basis_func_dict():
    """
    A helper function that returns a dictionary containing each RBF
    """

    bases = {'gaussian': gaussian,
             'linear': linear,
             'quadratic': quadratic,
             'inverse quadratic': inverse_quadratic,
             'multiquadric': multiquadric,
             'inverse multiquadric': inverse_multiquadric,
             'spline': spline,
             'poisson one': poisson_one,
             'poisson two': poisson_two,
             'matern32': matern32,
             'matern52': matern52}
    return bases
###############################################################################################################

def clones(module, N):
    '''
    Produce N identical layers.
    :param module: nn.Module
    :param N: int
    :return: torch.nn.ModuleList
    '''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def subsequent_mask(size):
    '''
    mask out subsequent positions.
    :param size: int
    :return: (1, size, size)
    '''
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0   # 1 means reachable; 0 means unreachable


class spatialGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels):
        super(spatialGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix, x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))


class GCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels):
        super(GCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, F_in)
        :return: (batch_size, N, F_out)
        '''
        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix, x)))  # (N,N)(b,N,in)->(b,N,in)->(b,N,out)


class Spatial_Attention_layer(nn.Module):
    '''
    compute spatial attention scores
    '''
    def __init__(self, dropout=.0):
        super(Spatial_Attention_layer, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, T, N, N)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        score = torch.matmul(x, x.transpose(1, 2)) / math.sqrt(in_channels)  # (b*t, N, F_in)(b*t, F_in, N)=(b*t, N, N)

        score = self.dropout(F.softmax(score, dim=-1))  # the sum of each row is 1; (b*t, N, N)

        return score.reshape((batch_size, num_of_timesteps, num_of_vertices, num_of_vertices))


class spatialAttentionGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels, dropout=.0):
        super(spatialAttentionGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)
        self.SAt = Spatial_Attention_layer(dropout=dropout)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''

        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        spatial_attention = self.SAt(x)  # (batch, T, N, N)

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        spatial_attention = spatial_attention.reshape((-1, num_of_vertices, num_of_vertices))  # (b*T, n, n)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix.mul(spatial_attention), x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))
        # (b*t, n, f_in)->(b*t, n, f_out)->(b,t,n,f_out)->(b,n,t,f_out)


class spatialAttentionScaledGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels, dropout=.0):
        super(spatialAttentionScaledGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)
        self.SAt = Spatial_Attention_layer(dropout=dropout)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        spatial_attention = self.SAt(x) / math.sqrt(in_channels)  # scaled self attention: (batch, T, N, N)

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))
        # (b, n, t, f)-permute->(b, t, n, f)->(b*t,n,f_in)

        spatial_attention = spatial_attention.reshape((-1, num_of_vertices, num_of_vertices))  # (b*T, n, n)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix.mul(spatial_attention), x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))
        # (b*t, n, f_in)->(b*t, n, f_out)->(b,t,n,f_out)->(b,n,t,f_out)



class SpatialPositionalEncoding_RBF(nn.Module):
    def __init__(self, d_model, logitudelatitudes,num_of_vertices, dropout, gcn=None, smooth_layer_num=0):
        super(SpatialPositionalEncoding_RBF, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # self.embedding = torch.nn.Embedding(num_of_vertices, d_model)
        self.embedding = RBF(2, d_model, num_of_vertices,quadratic) # gaussin nn.Linear(4, d_model-4)
        self.logitudelatitudes = logitudelatitudes
        self.gcn_smooth_layers = None
        if (gcn is not None) and (smooth_layer_num > 0):
            self.gcn_smooth_layers = nn.ModuleList([gcn for _ in range(smooth_layer_num)])

    def forward(self, x,log1,lat1):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        # x,log,lat,t= x[0],x[1],x[2],x[3]
        batch, num_of_vertices, timestamps, _ = x.shape
        x_indexs = torch.concat((torch.unsqueeze(log1.mean(0).mean(-1),-1),torch.unsqueeze(lat1.mean(0).mean(-1),-1)),-1)# (N,)

        x_ind = torch.concat((
                              x_indexs[:, 0:1] ,
                              x_indexs[:, 1:] )
                             , axis=1)

        embed = self.embedding(x_ind.float()).unsqueeze(0)
        if self.gcn_smooth_layers is not None:
            for _, l in enumerate(self.gcn_smooth_layers):
                embed = l(embed)  # (1,N,d_model) -> (1,N,d_model)
        x = x + embed.unsqueeze(2)  # (B, N, T, d_model)+(1, N, 1, d_model)

        return self.dropout(x)


class TemporalPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len, lookup_index=None):
        super(TemporalPositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)
        self.lookup_index = lookup_index
        self.max_len = max_len
        # computing the positional encodings once in log space
        pe = torch.zeros(max_len, d_model)
        for pos in range(max_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i+1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        pe = pe.unsqueeze(0).unsqueeze(0)  # (1, 1, T_max, d_model)
        self.register_buffer('pe', pe)
        # register_buffer:
        # Adds a persistent buffer to the module.
        # This is typically used to register a buffer that should not to be considered a model parameter.

    def forward(self, x,t):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        if self.lookup_index is not None:
            x = x + self.pe[:, :, self.lookup_index, :]  # (batch_size, N, T, F_in) + (1,1,T,d_model)
        else:
            x = x + self.pe[:, :, :x.size(2), :]

        return self.dropout(x.detach())


class SublayerConnection(nn.Module):
    '''
    A residual connection followed by a layer norm
    '''
    def __init__(self, size, dropout, residual_connection, use_LayerNorm):
        super(SublayerConnection, self).__init__()
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        self.dropout = nn.Dropout(dropout)
        if self.use_LayerNorm:
            self.norm = nn.LayerNorm(size)

    def forward(self, x, sublayer):
        '''
        :param x: (batch, N, T, d_model)
        :param sublayer: nn.Module
        :return: (batch, N, T, d_model)
        '''
        if self.residual_connection and self.use_LayerNorm:
            return x + self.dropout(sublayer(self.norm(x)))
        if self.residual_connection and (not self.use_LayerNorm):
            return x + self.dropout(sublayer(x))
        if (not self.residual_connection) and self.use_LayerNorm:
            return self.dropout(sublayer(self.norm(x)))


class PositionWiseGCNFeedForward(nn.Module):
    def __init__(self, gcn, dropout=.0):
        super(PositionWiseGCNFeedForward, self).__init__()
        self.gcn = gcn
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''
        :param x:  (B, N_nodes, T, F_in)
        :return: (B, N, T, F_out)
        '''
        return self.dropout(F.relu(self.gcn(x)))


def attention(query, key, value, mask=None, dropout=None):
    '''
    :param query:  (batch, N, h, T1, d_k)
    :param key: (batch, N, h, T2, d_k)
    :param value: (batch, N, h, T2, d_k)
    :param mask: (batch, 1, 1, T2, T2)
    :param dropout:
    :return: (batch, N, h, T1, d_k), (batch, N, h, T1, T2)
    '''
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # scores: (batch, N, h, T1, T2)

    if mask is not None:
        scores = scores.masked_fill_(mask == 0, -1e9)  # -1e9 means attention scores=0
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # p_attn: (batch, N, h, T1, T2)

    return torch.matmul(p_attn, value), p_attn  # (batch, N, h, T1, d_k), (batch, N, h, T1, T2)


class MultiHeadAttention(nn.Module):
    def __init__(self, nb_head, d_model, dropout=.0):
        super(MultiHeadAttention, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask: (batch, T, T)
        :return: x: (batch, N, T, d_model)
        '''
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        query, key, value = [l(x).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3) for l, x in
                             zip(self.linears, (query, key, value))]

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_qc_kc(nn.Module):  # key causal; query causal;
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag, kernel_size=3, dropout=.0):
        '''
        :param nb_head:
        :param d_model:
        :param num_of_weeks:
        :param num_of_days:
        :param num_of_hours:
        :param points_per_hour:
        :param kernel_size:
        :param dropout:
        '''
        super(MultiHeadAttentionAwareTemporalContex_qc_kc, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.padding = kernel_size - 1
        self.conv1Ds_aware_temporal_context = clones(nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding)), 2)  # # 2 causal conv: 1  for query, 1 for key
        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h, key_h = [l(x.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query[:, :, self.w_length + self.d_length:self.w_length + self.d_length + self.h_length, :], key[:, :, self.w_length + self.d_length:self.w_length + self.d_length + self.h_length, :]))]
                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query, key = [l(x.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query, key))]

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.conv1Ds_aware_temporal_context[0](query.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.conv1Ds_aware_temporal_context[1](key[:, :,0:self.n_length, :].permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_q1d_k1d(nn.Module):  # 1d conv on query, 1d conv on key
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag,  kernel_size=3, dropout=.0): #num_of_weeks, num_of_days, num_of_hours

        super(MultiHeadAttentionAwareTemporalContex_q1d_k1d, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.padding = (kernel_size - 1)//2

        self.conv1Ds_aware_temporal_context = clones(
            nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding)),
            2)  # # 2 causal conv: 1  for query, 1 for key

        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag  #num_of_hours * points_per_hour


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h, key_h = [l(x.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query[:, :,0: self.n_length, :], key[:, :, 0: self.n_length, :]))]
                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query, key = [l(x.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query, key))]

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.conv1Ds_aware_temporal_context[0](query.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.conv1Ds_aware_temporal_context[1](key[:, :, 0:self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_qc_k1d(nn.Module):  # query: causal conv; key 1d conv
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag,  kernel_size=3, dropout=.0):
        super(MultiHeadAttentionAwareTemporalContex_qc_k1d, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.causal_padding = kernel_size - 1
        self.padding_1D = (kernel_size - 1)//2
        self.query_conv1Ds_aware_temporal_context = nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.causal_padding))
        self.key_conv1Ds_aware_temporal_context = nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding_1D))
        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h = self.query_conv1Ds_aware_temporal_context(query[:, :, 0: self.n_length, :].permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1,
                                                                                                                4, 2)
                key_h = self.key_conv1Ds_aware_temporal_context(key[:, :,0: self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query = self.query_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
            key = self.key_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.query_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.key_conv1Ds_aware_temporal_context(key[:, :, 0: self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(
                    nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, trg_dim,decoder1, src_dense, encode_temporal_position,decode_temporal_position, generator1, DEVICE,spatial_position): #generator2,
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder1 = decoder1
        # self.decoder2 = decoder2
        self.src_embed = src_dense
        # self.trg_embed = trg_dense
        self.encode_temporal_position = encode_temporal_position
        self.decode_temporal_position = decode_temporal_position
        self.prediction_generator1 = generator1
        # self.prediction_generator2 = generator2
        self.spatial_position = spatial_position
        self.trg_dim = trg_dim
        self.to(DEVICE)

    def forward(self, src, trg,x,y,te,td):
        '''
        src:  (batch_size, N, T_in, F_in)
        trg: (batch, N, T_out, F_out)
        '''
        encoder_output = self.encode(src,x,y,te)  # (batch_size, N, T_in, d_model)

        trg_shape = self.trg_dim#int(trg.shape[-1]/2)
        return self.decode1(trg[:, :, :, -trg_shape:], encoder_output, trg[:, :, :, :trg_shape],x,y,td)#trg[:, :, :, :trg_shape],x,y,td)  # src[:,:,-1:,:2])#

    def encode(self, src,x,y,t):
        '''
        src: (batch_size, N, T_in, F_in)
        '''
        src_emb = self.src_embed(src)
        if self.encode_temporal_position ==False:
            src_tmpo_emb = src_emb
        else:
            src_tmpo_emb = self.encode_temporal_position(src_emb,t)
        if self.spatial_position == False:
            h = src_tmpo_emb
        else:
            h = self.spatial_position(src_tmpo_emb, x,y)

        return self.encoder(h)


    def decode1(self, trg, encoder_output,encoder_input,x,y,t):
        trg_embed = self.src_embed
        trg_emb_shape = self.trg_dim
        trg_emb = torch.matmul(trg, list(trg_embed.parameters())[0][:, trg_emb_shape:].T)
        if self.encode_temporal_position ==False:
            trg_tempo_emb = trg_emb
        else:
            trg_tempo_emb = self.decode_temporal_position(trg_emb, t)

        if self.spatial_position ==False:
            a =  self.prediction_generator1(self.decoder1(trg_tempo_emb, encoder_output))+encoder_input#[:,:,:,0:2]
            return a
        else:
            a =  self.prediction_generator1(self.decoder1(self.spatial_position(trg_tempo_emb,x,y), encoder_output))+encoder_input#[:,:,:,0:2]
            return a




class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, gcn, dropout, residual_connection=True, use_LayerNorm=True):
        super(EncoderLayer, self).__init__()
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        self.self_attn = self_attn
        self.feed_forward_gcn = gcn
        if residual_connection or use_LayerNorm:
            self.sublayer = clones(SublayerConnection(size, dropout, residual_connection, use_LayerNorm), 2)
        self.size = size

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        if self.residual_connection or self.use_LayerNorm:
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, query_multi_segment=True, key_multi_segment=True))
            return self.sublayer[1](x, self.feed_forward_gcn)
        else:
            x = self.self_attn(x, x, x, query_multi_segment=True, key_multi_segment=True)
            return self.feed_forward_gcn(x)


class Encoder(nn.Module):
    def __init__(self, layer, N):
        '''
        :param layer:  EncoderLayer
        :param N:  int, number of EncoderLayers
        '''
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, gcn, dropout, residual_connection=True, use_LayerNorm=True):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward_gcn = gcn
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        if residual_connection or use_LayerNorm:
            self.sublayer = clones(SublayerConnection(size, dropout, residual_connection, use_LayerNorm), 3)

    def forward(self, x, memory):
        '''
        :param x: (batch_size, N, T', F_in)
        :param memory: (batch_size, N, T, F_in)
        :return: (batch_size, N, T', F_in)
        '''
        m = memory
        tgt_mask = subsequent_mask(x.size(-2)).to(m.device)  # (1, T', T')
        if self.residual_connection or self.use_LayerNorm:
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, query_multi_segment=False, key_multi_segment=False))  # output: (batch, N, T', d_model)
            x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, query_multi_segment=False, key_multi_segment=True))  # output: (batch, N, T', d_model)
            return self.sublayer[2](x, self.feed_forward_gcn)  # output:  (batch, N, T', d_model)
        else:
            x = self.self_attn(x, x, x, tgt_mask, query_multi_segment=False, key_multi_segment=False)  # output: (batch, N, T', d_model)
            x = self.src_attn(x, m, m, query_multi_segment=False, key_multi_segment=True)  # output: (batch, N, T', d_model)
            return self.feed_forward_gcn(x)  # output:  (batch, N, T', d_model)


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory):
        '''
        :param x: (batch, N, T', d_model)
        :param memory: (batch, N, T, d_model)
        :return:(batch, N, T', d_model)
        '''
        for layer in self.layers:
            x = layer(x, memory)
        return self.norm(x)

class EmbedLinear(nn.Module):
    def __init__(self, encoder_input_size, d_model,bias=False):
        '''
        :param layer:  EncoderLayer
        :param N:  int, number of EncoderLayers
        '''
        super(EmbedLinear, self).__init__()
        self.layers = nn.Linear(encoder_input_size, d_model, bias=bias)

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        #for layer in self.layers:
        y = self.layers(x)
        return y

def search_index(max_len, num_of_depend, num_for_predict,points_per_hour, units):
    '''
    Parameters
    ----------
    max_len: int, length of all encoder input
    num_of_depend: int,
    num_for_predict: int, the number of points will be predicted for each sample
    units: int, week: 7 * 24, day: 24, recent(hour): 1
    points_per_hour: int, number of points per hour, depends on data
    Returns
    ----------
    list[(start_idx, end_idx)]
    '''
    x_idx = []
    for i in range(1, num_of_depend + 1):
        start_idx = max_len - points_per_hour * units * i
        for j in range(num_for_predict):
            end_idx = start_idx + j
            x_idx.append(end_idx)
    return x_idx



def make_model(DEVICE,logitudelatitudes, num_layers, encoder_input_size,decoder_input_size, decoder_output_size, d_model, adj_mx, nb_head, num_of_lags,points_per_lag,
                 num_for_predict, dropout=.0, aware_temporal_context=True,
               ScaledSAt=True, SE=True, TE=True, kernel_size=3, smooth_layer_num=0, residual_connection=True, use_LayerNorm=True):

    # LR rate means: graph Laplacian Regularization

    c = copy.deepcopy

    norm_Adj_matrix = torch.from_numpy(norm_Adj(adj_mx)).type(torch.FloatTensor).to(DEVICE)  # 通过邻接矩阵,构造归一化的拉普拉斯矩阵

    num_of_vertices = norm_Adj_matrix.shape[0]

    src_dense = EmbedLinear(encoder_input_size, d_model, bias=False)#nn.Linear(encoder_input_size, d_model, bias=False)

    if ScaledSAt:  # employ spatial self attention
        position_wise_gcn = PositionWiseGCNFeedForward(spatialAttentionScaledGCN(norm_Adj_matrix, d_model, d_model), dropout=dropout)
    else:  #
        position_wise_gcn = PositionWiseGCNFeedForward(spatialGCN(norm_Adj_matrix, d_model, d_model), dropout=dropout)

    # encoder temporal position embedding
    max_len = num_of_lags

    if aware_temporal_context:  # employ temporal trend-aware attention
        attn_ss = MultiHeadAttentionAwareTemporalContex_q1d_k1d(nb_head, d_model, num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
        attn_st = MultiHeadAttentionAwareTemporalContex_qc_k1d(nb_head, d_model,num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
        att_tt = MultiHeadAttentionAwareTemporalContex_qc_kc(nb_head, d_model, num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
    else:  # employ traditional self attention
        attn_ss = MultiHeadAttention(nb_head,d_model, dropout=dropout) #d_model, dropout=dropout)
        attn_st = MultiHeadAttention(nb_head,d_model, dropout=dropout)# d_model, dropout=dropout)
        att_tt = MultiHeadAttention(nb_head,d_model, dropout=dropout) #d_model, dropout=dropout)

    encode_temporal_position = TemporalPositionalEncoding(d_model, dropout, max_len)  #   en_lookup_index   decoder temporal position embedding
    decode_temporal_position = TemporalPositionalEncoding(d_model, dropout, num_for_predict)
    spatial_position = SpatialPositionalEncoding_RBF(d_model, logitudelatitudes,num_of_vertices, dropout, GCN(norm_Adj_matrix, d_model, d_model), smooth_layer_num=smooth_layer_num) #logitudelatitudes,


    encoderLayer = EncoderLayer(d_model, attn_ss, c(position_wise_gcn), dropout, residual_connection=residual_connection, use_LayerNorm=use_LayerNorm)

    encoder = Encoder(encoderLayer, num_layers)

    decoderLayer1 = DecoderLayer(d_model, att_tt, attn_st, c(position_wise_gcn), dropout, residual_connection=residual_connection, use_LayerNorm=use_LayerNorm)

    decoder1 = Decoder(decoderLayer1, num_layers)

    generator1 = nn.Linear(d_model, decoder_output_size)#



    model = EncoderDecoder(encoder,decoder_output_size,
                       decoder1,
                           src_dense,
                       encode_temporal_position,
                       decode_temporal_position,
                       generator1,
                       DEVICE,
                       spatial_position) #,generator2

    # param init
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

2、实验结果 

模型经过399个epoch训练后,在验证阶段,损失为0.1143,其他性能指标包括c-r为0.0185和L-hr为0.0215,验证阶段耗时约3.655秒,模型在第308个epoch达到最佳性能。第二张图的训练和验证损失曲线显示,训练损失从高到低逐渐稳定,验证损失经过初始波动后也趋于平稳,这表明模型随着训练逐渐适应数据,达到了较好的泛化能力。

三、总结

本周阅读的这篇论文,受益颇多,回顾了很多知识,比如说GCN、多头自注意力等,文中提到的方法aGNN大幅减轻了基于物理模型的计算负担,特别是在需要处理大量注入井和长期管理的情景中,显著提高了计算效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/714297.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

工厂方法模式实战之某商场一次促销活动

目录 1.5.1、前言1.5.2、实战场景简介1.5.3、开发环境1.5.4、用传统的if-else语句实现1.5.4.1、工程结构1.5.4.2、if-else需求实现1.5.4.3、测试验证 1.5.5、工厂模式优化代码1.5.5.1、工程结构1.5.5.2、代码实现1.5.5.2.1、定义各种商品发放接口及接口实现1.5.5.2.2、定义工厂…

项目经理,请勇敢Say No~

为什么要say no? 培养say no的勇气 优雅的say no! say no 三部曲,项目经理,你准备好了吗? 为什么要say no? 保护项目完整性的屏障 项目管理的核心在于平衡时间、成本与质量三大要素,任何一项的…

STL——set、map、multiset、multimap的介绍及使用

文章目录 关联式容器键值对树形结构与哈希结构setset的介绍set的使用set的模板参数列表set的构造set的使用set的迭代器使用演示 multisetmultiset演示 mapmap的定义方式map的插入map的查找map的[ ]运算符重载map的迭代器遍历multimapmultimap的介绍multimap的使用 在OJ中的使用…

全球“抱团”美股,美股“抱团”AI

内容提要 过去一个月内,全球约有300亿美元新资金流入股票基金,其中高达94%投向了美国资产;一季度,海外投资者购入了1870亿美元美国公司债券,同比增长61%。 文章正文 尽管美国面临债务问题和大选带来的政治分歧&#…

索引-定义、创建(CREATE INDEX)、删除(DROP INDEX)

一、概述 1、索引是SQL语言定义的一种数据对象,是大多数DBMS为数据库中基本表创建的一种辅助存取结构,用于响应特定查询条件进行查询时的查询速度,DBMS根据查询条件从数据库文件中,选择出一条或者多条数据记录以供检索&#xff0…

【JS重点17】原型继承

目录 一:什么是原型继承 二:通过赋值方式实现原型继承 三:通过构造函数实现原型继承 四:如何赚钱 一:什么是原型继承 通过往构造函数上的原型对象添加属性和方法,再new一个实例对象,从而实例…

18. 第十八章 继承

18. 继承 和面向对象编程最常相关的语言特性就是继承(inheritance). 继承值得是根据一个现有的类型, 定义一个修改版本的新类的能力. 本章中我会使用几个类来表达扑克牌, 牌组以及扑克牌性, 用于展示继承特性.如果你不玩扑克, 可以在http://wikipedia.org/wiki/Poker里阅读相关…

CSS期末复习速览(二)

1.元素显示模式分为三种&#xff1a;块元素&#xff0c;行内元素&#xff0c;行内块元素 2.块元素&#xff1a;常见的块元素&#xff1a;<h1>~<h6> <p> <div> <ul> <ol> <li>&#xff0c;特点&#xff1a;自己独占一行&a…

需求:如何给文件添加水印

今天给大家介绍一个简单易用的水印添加框架&#xff0c;框架抽象了各个文件类型的对于水印添加的方法。仅使用几行代码即可为不同类型的文件添加相同样式的水印。 如果你有给PDF、图片添加水印的需求&#xff0c;EasyWatermark是一个很好的选择&#xff0c;主要功能就是传入一…

嵌入式实训day5

1、 from machine import Pin import time # 定义按键引脚控制对象 key1 Pin(27,Pin.IN, Pin.PULL UP) key2 Pin(26,Pin.IN, Pin.PULL UP)led1 Pin(15,Pin.ouT, value0) led2 Pin(2,Pin.ouT, value0) led3 Pin(0,Pin.ouT, value0) # 定义key1按键中断处理函数 def key1 ir…

2.线上论坛项目

一、项目介绍 线上论坛 相关技术&#xff1a;SpringBootSpringMvcMybatisMysqlSwagger项目简介&#xff1a;本项目是一个功能丰富的线上论坛&#xff0c;用户可编辑、发布、删除帖子&#xff0c;并评论、点赞。帖子按版块分类&#xff0c;方便查找。同时&#xff0c;用户可以…

【CT】LeetCode手撕—121. 买卖股票的最佳时机

目录 题目1- 思路2- 实现⭐121. 买卖股票的最佳时机——题解思路 2- ACM实现 题目 原题连接&#xff1a;121. 买卖股票的最佳时机 1- 思路 模式识别 模式1&#xff1a;只能某一天买入 ——> 买卖一次 ——> dp 一次的最大利润 动规五部曲 1.定义dp数组&#xff0c;确…

跻身中国市场前三,联想服务器的“智变”与“质变”

IDC发布的《2024年第一季度中国x86服务器市场报告》显示&#xff0c;联想服务销售额同比增长200.2%&#xff0c;在前十厂商中同比增速第一&#xff0c;并跻身中国市场前三&#xff0c;迈入算力基础设施“第一阵营”。 十年砺剑联想梦&#xff0c;三甲登榜领风骚。探究联想服务器…

IDEA模版快速生成Java方法体

新建模版组myLive 在模版组下新建模版finit 在模版text内输入以下脚本 LOGGER.info("$className$.$methodName$>$parmas1$", $parmas2$); try {} catch (Exception e) {LOGGER.error("$className$.$methodName$>error:", e); }LOGGER.info("$c…

redis未授权到getshell

0 前言 现在是redis数据库未授权访问到getshell的部分了,不好意思&#xff0c;因为个人原因&#xff0c;和上篇mysql的getshell文章间隔较久. 1 漏洞产生原因 redis安装完之后&#xff0c;默认情况下绑定在 0.0.0.0:6379&#xff0c;且没有对登录IP做限制&#xff0c;并且没…

T113 Tina5.0 添加板级支持包

文章目录 环境介绍添加板级支持包修改板级文件验证总结 环境介绍 硬件&#xff1a;韦东山T113工业板 软件&#xff1a;全志Tina 5.0 添加板级支持包 进入源码目录<SDK>/device/config/chips/t113/configs&#xff0c;可以看到有如下文件夹&#xff1a; 复制一份evb1_…

python15 数据类型 集合类型

集合类型 无序的不重复元素序列 集合中只能存储不可变的数据类型 声明集合 使用 {} 定义 与列表&#xff0c;字典一样&#xff0c;都是可变数据类型 代码 集合类型 无序的不重复元素序列 集合中只能存储不可变的数据类型 声明集合 使用 大括号{} 定义 与列表&#xff0c;字典一…

linux驱动学习(十)之内存管理

一、linux内核启动过程中&#xff0c;关于内存信息 1、内核的内存的分区 [ 0.000000] Memory: 1024MB 1024MB total ---> 1G [ 0.000000] Memory: 810820k/810820k available, 237756k reserved, 272384K highmem [ 0.000000] Virtual kernel memory layout: 内…

UnityAPI学习之碰撞检测与触发检测

碰撞检测 发生碰撞检测的前提&#xff1a; 1. 碰撞的物体需要有Rigidbody组件和boxcllidder组件 2. 被碰撞的物体需要有boxcollider组件 示例1&#xff1a;被碰撞的物体拥有Rigidbody组件 两个物体会因为都具有刚体的组件而发生力的作用&#xff0c;如下图所示&#xff0c…

人工智能模型组合学习的理论和实验实践

组合学习&#xff0c;即掌握将基本概念结合起来构建更复杂概念的能力&#xff0c;对人类认知至关重要&#xff0c;特别是在人类语言理解和视觉感知方面。这一概念与在未观察到的情况下推广的能力紧密相关。尽管它在智能中扮演着核心角色&#xff0c;但缺乏系统化的理论及实验研…