机器学习周记(第四十二周:AT-LSTM)2024.6.3~2024.6.9

目录

  • 摘要
  • Abstract
  • 一、文献阅读
    • 1. 题目
    • 2. abstract
    • 3. 网络架构
      • 3.1 LSTM
      • 3.2 注意力机制概述
      • 3.3 AT-LSTM
      • 3.4 数据预处理
    • 4. 文献解读
      • 4.1 Introduction
      • 4.2 创新点
      • 4.3 实验过程
      • 4.3.1 训练参数
      • 4.3.2 数据集
      • 4.3.3 实验设置
      • 4.3.4 实验结果
    • 5. 基于pytorch的transformer

摘要

本周阅读了题为Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia的论文。这项工作提出了一种基于长期短期记忆的神经网络和 注意力机制的混合模型——AT-LSTM。其中,LSTM缺乏对子窗口特征进行不同程度关注的能力,这可能会导致一些相关信息被忽略,无法重视时间序列的重要特征。该文应用注意力机制来有效捕获更远的关键信息,并通过在每个时间步对隐藏层元素进行加权来增强重要特征对预测模型的影响。这样就提出了一个集成模型并有效地获得了统计特征。基于真实数据的实验证明,注意力机制的加入提高了 LSTM 模型的预测性能。

Abstract

This week read the paper titled Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia. This work proposes a hybrid model of neural network and attention mechanism based on long-term short-term memory, AT-LSTM. Among them, LSTM lacks the ability to pay attention to the features of sub-windows to varying degrees, which may lead to some relevant information being ignored and unable to pay attention to the important features of the time series. In this paper, the attention mechanism is applied to effectively capture the key information in the distance, and the influence of important features on the prediction model is enhanced by weighting the hidden layer elements at each time step. In this way, an ensemble model is proposed and statistical features are effectively obtained. Experiments based on real data show that the addition of attention mechanism improves the prediction performance of the LSTM model.

一、文献阅读

1. 题目

标题:Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia

作者:Honglei Chen, etc.

期刊名:Sustainability

链接:https://doi.org/10.3390/su142013231

2. abstract

该研究旨在开发长短期记忆(LSTM)网络及其基于注意力的(AT-LSTM)模型,以实现澳大利亚伯内特河水质的预测。该研究开发的模型在对伯内特河断面水质数据进行特征提取后,考虑不同时刻序列对预测结果的影响,引入注意力机制,增强关键特征对预测结果的影响。该研究利用 LSTM 和 AT-LSTM 模型对伯内特河的溶解氧 (DO) 进行一步预报和多步预报,并对结果进行比较。研究结果表明,注意力机制的加入提高了 LSTM 模型的预测性能。因此,本研究开发的基于 AT-LSTM 的水质预测模型显示出比 LSTM 模型更强的能力,可以为澳大利亚昆士兰州水质改善计划提供准确预测伯内特河水质的能力。

The present study aims to develop a long short-term memory (LSTM) network and its attention-based (AT-LSTM) model to achieve the prediction of water quality in the Burnett River of Australia. The models developed in this study introduced an attention mechanism after feature extraction of water quality data in the section of Burnett River considering the effect of the sequences on the prediction results at different moments to enhance the influence of key features on the prediction results. This study provides one-step-ahead forecasting and multistep forward forecasting of dissolved oxygen (DO) of the Burnett River utilizing LSTM and AT-LSTM models and the comparison of the results. The research outcomes demonstrated that the inclusion of the attention mechanism improves the prediction performance of the LSTM model. Therefore, the AT-LSTM-based water quality forecasting model, developed in this study, demonstrated its stronger capability than the LSTM model for informing the Water Quality Improvement Plan of Queensland, Australia, to accurately predict water quality in the Burnett River.

3. 网络架构

该文提出的模型改进了编码器-解码器网络结构,使其能够更好地适应多步时间序列预测。同时,为了解决时间序列数据的降噪问题,该工作采用SG滤波器对原始数据进行去噪。G 滤波器可以有效保留时间序列的特征,并去除其噪声。同时,结合基于LSTM的编码器-解码器模型,模型显着提高了多步预测的准确性。

3.1 LSTM

Long Short-term Memory的结构:

  • LSTM中的memory存储空间是受网络控制的
  • 输入门/写门(input gate):当收到对应的网络网络信号时,才可以向空间中写入信息
  • 记忆门/读门(output gate):收到信号后,才能从空间中读取信息
  • 遗忘门(forget gate):收到信号后,会选择性的删除存储空间中的信息

这个结构的记忆是相对短时的,故为short-term;而不仅仅像RNN中仅保留上次输入的记忆,故Long Short-term;同时需要forget gate删除一些信息来保持信息的有效性,故不为long-term

LSTM的运行逻辑如下

  • 通常信号控制为sigmoid function,这样可以保证数值分布在0-1
  • 假设输入为 g ( z ) g(z) g(z),输入门为 f ( z i ) f(z_i) f(zi),memory存储空间中为c,遗忘门为 f ( z f ) f(z_f) f(zf)
  • 在上述情况下,存储空间中的值 c ′ = g ( z ) f ( z i ) + c f ( z f ) c'=g(z)f(z_i)+cf(z_f) c=g(z)f(zi)+cf(zf)
  • 若输出门为 f ( z 0 ) f(z_0) f(z0),则 a = h ( c ′ ) f ( z 0 ) a=h(c')f(z_0) a=h(c)f(z0)

LSTM总体框架简要概括如下,下图上半部分为LSTM的结构,后半部分为各个门对应计算

image-20240428172104892

3.2 注意力机制概述

注意力机制的简要框架概括如下图

image-20240428172203279

主要用到的方法是dot-product

在网络中使用dot-product计算相关性的流程如下,假设要查询 a 1 a^1 a1与其他向量的相关性

  • 首先,计算 q u e r y query query向量 q 1 = W q a 1 q^1=W^qa^1 q1=Wqa1,之后计算 k e y key key向量 k i = W k a i k^i=W^ka^i ki=Wkai a i a^i ai为输入序列中的所有向量。
  • 其次,计算 a t t e n t i o n   s c o r e attention\ score attention score,若查询向量对应 a 1 a^1 a1、关键词向量对应 a 2 a^2 a2,则有 α 1 , 2 = q 1 ⋅ k 2 \alpha_{1,2}=q^1\cdot k^2 α1,2=q1k2
    • 以此类推计算所有向量的attention score
  • 之后,将所有的 a t t e n t i o n   s c o r e attention\ score attention score输入soft-max中,将其映射为一个分布, α 1 , 2 \alpha_{1,2} α1,2对应的输出为 α 1 , 2 ′ \alpha'_{1,2} α1,2
  • 最后,将 a i a^i ai乘上矩阵 W v W^v Wv,得到 v i v^i vi,用 α 1 , i ′ \alpha'_{1,i} α1,i乘上 v i v^i vi,将所有的按照该流程的得到的结果累加 b 1 = ∑ i α 1 , i ′ v i b^1=\sum_i{\alpha'_{1,i}v^i} b1=iα1,ivi a i a^i ai为输入序列中的所有向量。若其他向量 b i b^i bi b 1 b^1 b1越相近,则 a i a^i ai a 1 a^1 a1越相近

大致计算过程如下

多头注意力机制

以下以两个head为例,计算过程如下

3.3 AT-LSTM

不同时序输入序列首先经过LSTM网络,然后作为注意力层的输入,经过全连接层和softmax激活后输入多头注意力层,经过flattern扁平化操作后输入全连接层,最后输出结果

image-20240428172313626

该模型的主要思想是通过对神经网络隐含层元素进行自适应加权,减少无关因素对结果的影响,突出相关因素的影响,从而提高预测精度。模型框架如图7所示,主要组成部分是LSTM层和注意力层。

3.4 数据预处理

  1. 将数据清洗,将异常值设定为控制,然后通过缺失值补充填充空值。
  2. 首先,应用Pearson相关性检验选取特征,不同水质参数之间的相关性分析、执行,并且与要预测的特征相关的关键特征被用作模型的输入。然后,使用窗口大小为100的滑动窗口技术来捕获水质变量的趋势。最后通过最小-最大归一化用于减轻不同特征尺度对模型训练的影响。
  3. 最后将经过数据预处理的数据用于训练网络

综上,算法的整体结构如下

image-20240428173141089

4. 文献解读

4.1 Introduction

研究以水质评价的关键参数DO作为模型构建和预测评价的目标。以往的方法仍然没有充分学习时间序列中隐藏的相关特征,这显着影响了预测精度。LSTM缺乏对子窗口特征进行不同程度关注的能力,这可能会导致一些相关信息被忽略,无法重视时间序列的重要特征。该文应用注意力机制来有效捕获更远的关键信息,并通过在每个时间步对隐藏层元素进行加权来增强重要特征对预测模型的影响。在此基础上,引入了注意力机制,并在LSTM模型的基础上开发了AT-LSTM模型,重点是更好地捕捉水质变量。利用水质监测原始数据预测了澳大利亚伯内特河河段的溶解氧浓度。最后将预测结果与LSTM模型进行比较。我们的目标是实现多元时间数据的长期依赖性和隐藏相关性特征的自适应学习,以使河流水质预测更加准确。伯内特河被认为是一个案例研究,以说明所提出的 AT-LSTM 模型的适用性。

4.2 创新点

这项工作设计了一种混合模型,使用基于 LSTM和注意力机制的神经网络(称为 AT-LSTM)来预测未来的水质。主要贡献总结如下:

  1. 提出了一种改进的网络结构,可以更好地预测多步水质时间序列数据。因此,所提出的AT-LSTM可以更好地处理时间序列数据中的长序列。
  2. 创新性地将注意力机制同LSTM结合和集成,显着提高了多步预测精度。

4.3 实验过程

4.3.1 训练参数

模型结构和主要参数如表3所示。该文通过试错法将时间窗口设置为100,利用贝叶斯优化[56]进行模型超参数优化,识别出相对较好的超参数和激活函数。

image-20240428172614050

4.3.2 数据集

研究使用的数据为伯内特河自动监测点的水质数据,其位置及流域边界如图1所示。

image-20240428170905747

为保证模型的可靠性和适用性,使用了伯内特河2015年1月至2020年1月采集的水质监测数据。每半小时收集一次数据,包括五个特征:水温 (Temp)、pH、溶解氧 (DO)、电导率 (EC)、叶绿素-a (Chl-a) 和浊度 (NTU)。该文采用每小时39752个特征的水质数据和溶解氧作为输出变量。表1显示了数据的描述性统计。

image-20240428170953054

研究期间DO的变化如图2所示。

image-20240428171059266

水质数据按照8:1:1的比例分为三个数据集:训练集、验证集和测试集。在本研究中,训练集包含31,802个每小时条目(从2015年1月1日到2019年2月4日),验证集包含3975个每小时条目(从2019年2月4日到2019年7月20日),测试集包含3975个每小时条目(从2019年2月4日到2019年7月20日)。 2019年7月20日至2020年1月1日)。

4.3.3 实验设置

评估指标:采用平均绝对误差(MAE)、均方根误差(RMSE)和决定系数( R 2 R^2 R2)来定量评价模型预测效果,计算方法如下

image-20240428173303669

使用MSE作为模型的损失函数,并使用以下标准方程进行计算:

image-20240428173408459

两个模型均使用 Adam 优化器在训练集上进行训练,批量大小为 64。为了加速误差的收敛,使用了反向传播学习方法。验证集用作early stop方法,以确保模型不会过度训练。

使用LSTM作为基准模型

4.3.4 实验结果

从下图中,可以看到AT-LSTM模型在Burnett River测试集的水质预测方面优于LSTM模型:

image-20240428173619520

在图(b)中,蓝色曲线表示实际值,橙色曲线表示建模的预测值。虽然LSTM可以预测水质变化,但AT-LSTM的预测与实际值的差异较小,表明AT-LSTM的泛化能力比LSTM更强。

image-20240428173649318

下表总结了LSTM和AT-LSTM模型在监测段中预测溶解氧DO任务的性能

image-20240428173721840

以下是两种模型多步预测结果的比较:

image-20240428173753738

5. 基于pytorch的transformer

掩码部分代码如下图

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        """
        类的初始化函数
        d_model:指词嵌入的维度
        vocab:指词表的大小
        """
        super(Embeddings, self).__init__()
        #之后就是调用nn中的预定义层Embedding,获得一个词嵌入对象self.lut
        self.lut = nn.Embedding(vocab, d_model)
        #最后就是将d_model传入类中
        self.d_model =d_model
    def forward(self, x):
        """
        Embedding层的前向传播逻辑
        参数x:这里代表输入给模型的单词文本通过词表映射后的one-hot向量
        将x传给self.lut并与根号下self.d_model相乘作为结果返回
        """
        embedds = self.lut(x)
        return embedds * math.sqrt(self.d_model)

位置编码部分代码如下

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        """
        位置编码器类的初始化函数
        
        共有三个参数,分别是
        d_model:词嵌入维度
        dropout: dropout触发比率
        max_len:每个句子的最大长度
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings
        # 注意下面代码的计算方式与公式中给出的是不同的,但是是等价的,你可以尝试简单推导证明一下。
        # 这样计算是为了避免中间的数值计算结果超出float的范围,
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

编码器代码如下

# 定义一个clones函数,来更方便的将某个结构复制若干份
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Encoder(nn.Module):
    """
    Encoder
    The encoder is composed of a stack of N=6 identical layers.
    """
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        # 调用时会将编码器层传进来,我们简单克隆N分,叠加在一起,组成完整的Encoder
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

编码器层代码如下

class EncoderLayer(nn.Module):
    "EncoderLayer is made up of two sublayer: self-attn and feed forward"                                                                                                         
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size   # embedding's dimention of model, 默认512

    def forward(self, x, mask):
        # attention sub layer
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        # feed forward sub layer
        z = self.sublayer[1](x, self.feed_forward)
        return z

注意力机制层代码如下

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"

    #首先取query的最后一维的大小,对应词嵌入维度
    d_k = query.size(-1)
    #按照注意力公式,将query与key的转置相乘,这里面key是将最后两个维度进行转置,再除以缩放系数得到注意力得分张量scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    #接着判断是否使用掩码张量
    if mask is not None:
        #使用tensor的masked_fill方法,将掩码张量和scores张量每个位置一一比较,如果掩码张量则对应的scores张量用-1e9这个置来替换
        scores = scores.masked_fill(mask == 0, -1e9)
        
    #对scores的最后一维进行softmax操作,使用F.softmax方法,这样获得最终的注意力张量
    p_attn = F.softmax(scores, dim = -1)
    
    #之后判断是否使用dropout进行随机置0
    if dropout is not None:
        p_attn = dropout(p_attn)
    
    #最后,根据公式将p_attn与value张量相乘获得最终的query注意力表示,同时返回注意力张量
    return torch.matmul(p_attn, value), p_attn

多头注意力机制代码如下

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        #在类的初始化时,会传入三个参数,h代表头数,d_model代表词嵌入的维度,dropout代表进行dropout操作时置0比率,默认是0.1
        super(MultiHeadedAttention, self).__init__()
        #在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,这是因为我们之后要给每个头分配等量的词特征,也就是embedding_dim/head个
        assert d_model % h == 0
        #得到每个头获得的分割词向量维度d_k
        self.d_k = d_model // h
        #传入头数h
        self.h = h
        
        #创建linear层,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用,为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        #self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        #前向逻辑函数,它输入参数有四个,前三个就是注意力机制需要的Q,K,V,最后一个是注意力机制中可能需要的mask掩码张量,默认是None
        if mask is not None:
            # Same mask applied to all h heads.
            #使用unsqueeze扩展维度,代表多头中的第n头
            mask = mask.unsqueeze(1)
        #接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本
        nbatches = query.size(0)
        
        # 1) Do all the linear projections in batch from d_model => h x d_k 
        # 首先利用zip将输入QKV与三个线性层组到一起,然后利用for循环,将输入QKV分别传到线性层中,做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结构进行维度重塑,多加了一个维度h代表头,这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,计算机会根据这种变换自动计算这里的值,然后对第二维和第三维进行转置操作,为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,从attention函数中可以看到,利用的是原始输入的倒数第一和第二维,这样我们就得到了每个头的输入
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch. 
        # 得到每个头的输入后,接下来就是将他们传入到attention中,这里直接调用我们之前实现的attention函数,同时也将mask和dropout传入其中
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear. 
        # 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法。这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,所以,下一步就是使用view重塑形状,变成和输入形状相同。  
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        #最后使用线性层列表中的最后一个线性变换得到最终的多头注意力结构的输出
        return self.linears[-1](x)

解码器整体结构代码结构如下

#使用类Decoder来实现解码器
class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        #初始化函数的参数有两个,第一个就是解码器层layer,第二个是解码器层的个数N
        super(Decoder, self).__init__()
        #首先使用clones方法克隆了N个layer,然后实例化一个规范化层,因为数据走过了所有的解码器层后最后要做规范化处理。
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        #forward函数中的参数有4个,x代表目标数据的嵌入表示,memory是编码器层的输出,source_mask,target_mask代表源数据和目标数据的掩码张量,然后就是对每个层进行循环,当然这个循环就是变量x通过每一个层的处理,得出最后的结果,再进行一次规范化返回即可。
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

解码器层代码如下

#使用DecoderLayer的类实现解码器层
class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        #初始化函数的参数有5个,分别是size,代表词嵌入的维度大小,同时也代表解码器的尺寸,第二个是self_attn,多头自注意力对象,也就是说这个注意力机制需要Q=K=V,第三个是src_attn,多头注意力对象,这里Q!=K=V,第四个是前馈全连接层对象,最后就是dropout置0比率
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        #按照结构图使用clones函数克隆三个子层连接对象
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
 
    def forward(self, x, memory, src_mask, tgt_mask):
        #forward函数中的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量memory,以及源数据掩码张量和目标数据掩码张量,将memory表示成m之后方便使用。
        "Follow Figure 1 (right) for connections."
        m = memory
        #将x传入第一个子层结构,第一个子层结构的输入分别是x和self-attn函数,因为是自注意力机制,所以Q,K,V都是x,最后一个参数时目标数据掩码张量,这时要对目标数据进行遮掩,因为此时模型可能还没有生成任何目标数据。
        #比如在解码器准备生成第一个字符或词汇时,我们其实已经传入了第一个字符以便计算损失,但是我们不希望在生成第一个字符时模型能利用这个信息,因此我们会将其遮掩,同样生成第二个字符或词汇时,模型只能使用第一个字符或词汇信息,第二个字符以及之后的信息都不允许被模型使用。
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        #接着进入第二个子层,这个子层中常规的注意力机制,q是输入x;k,v是编码层输出memory,同样也传入source_mask,但是进行源数据遮掩的原因并非是抑制信息泄露,而是遮蔽掉对结果没有意义的padding。
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        
        #最后一个子层就是前馈全连接子层,经过它的处理后就可以返回结果,这就是我们的解码器结构
        return self.sublayer[2](x, self.feed_forward)

整体网络框架如下

# Model Architecture
#使用EncoderDecoder类来实现编码器-解码器结构
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. 
    Base for this and many other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        #初始化函数中有5个参数,分别是编码器对象,解码器对象,源数据嵌入函数,目标数据嵌入函数,以及输出部分的类别生成器对象.
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed    # input embedding module(input embedding + positional encode)
        self.tgt_embed = tgt_embed    # ouput embedding module
        self.generator = generator    # output generation module
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        #在forward函数中,有四个参数,source代表源数据,target代表目标数据,source_mask和target_mask代表对应的掩码张量,在函数中,将source source_mask传入编码函数,得到结果后与source_mask target 和target_mask一同传给解码函数
        memory = self.encode(src, src_mask)
        res = self.decode(memory, src_mask, tgt, tgt_mask)
        return res
    
    def encode(self, src, src_mask):
        #编码函数,以source和source_mask为参数,使用src_embed对source做处理,然后和source_mask一起传给self.encoder
        src_embedds = self.src_embed(src)
        return self.encoder(src_embedds, src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        #解码函数,以memory即编码器的输出,source_mask target target_mask为参数,使用tgt_embed对target做处理,然后和source_mask,target_mask,memory一起传给self.decoder
        target_embedds = self.tgt_embed(tgt)
        return self.decoder(target_embedds, memory, src_mask, tgt_mask)


# Full Model
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    构建模型
    params:
        src_vocab:
        tgt_vocab:
        N: 编码器和解码器堆叠基础模块的个数
        d_model: 模型中embedding的size,默认512
        d_ff: FeedForward Layer层中embedding的size,默认2048
        h: MultiHeadAttention中多头的个数,必须被d_model整除
        dropout:
    """
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))
    
    # This was important from their code. 
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/695112.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【ArcGISPro SDK】构建多面体要素

结果展示 每个面构建顺序 代码 using ArcGIS.Core.CIM; using ArcGIS.Core.Data; using ArcGIS.Core.Geometry; using ArcGIS.Desktop.Catalog; using ArcGIS.Desktop.Core; using ArcGIS.Desktop.Editing; using ArcGIS.Desktop.Extensions; using ArcGIS.Desktop.Framework;…

基于AT89C51单片机的红外防盗报警器设计

第一章 绪论1.1 选题背景 随着社会科学的不断进步和发展,人们生活水平得到很大的提高,对个人私有财产的保护越来越重视,因而对于防盗的措施提出了更高的要求。本设计就是为了满足现代生活防盗的需要而设计的应用于家庭、车库、仓库和保险柜等处进行防盗监控的无线防盗报警装…

安装搭建java版的悟空crm遇到 网络错误请稍候再试 终极解决办法(hrm人力资源模块)

java版 项目目录 ├── build – webpack 配置文件 ├── config – 项目配置文件 ├── src – 源码目录 │ ├── api – axios请求接口 │ ├── assets – 静态图片资源文件 │ ├── components – 通用组件 │ ├── directives – 通用指令 │ ├── filters –…

Objective-C之通过协议提供匿名对象

概述 通过协议提供匿名对象的设计模式,遵循了面向对象设计的多项重要原则: 接口隔离原则:通过定义细粒度的协议来避免实现庞大的接口。依赖倒置原则:高层模块依赖于抽象协议,而不是具体实现。里氏替换原则&#xff1…

计算机网络9——无线网络和移动网络2无线个人区域网 WPAN

文章目录 一、蓝牙系统二、低速 WPAN三、高速 WPAN 无线个人区域网WPAN(Wireless Personal Area Network)就是在个人工作的地方把属于个人使用的电子设备(如便携式电脑、平板电脑、便携式打印机以及蜂窝电话等)用无线技术连接起来自组网络,不需要使用接入点AP&#…

nlp学习笔记

目录 很多入门例子 bert chinese 很多入门例子 https://github.com/lansinuote/Huggingface_Toturials bert chinese import torch import torch.nn as nn from transformers import AutoTokenizer, AutoModel, BertModel, TFBertModel, BertTokenizer# youpath = D:/bert-…

彩虹易支付最新版源码

源码简介 彩虹易支付最新版源码,更新时间为5.1号 2024/05/01: 1.更换全新的手机版支付页面风格 2.聚合收款码支持填写备注 3.后台支付统计新增利润、代付统计 4.删除结算记录支持直接退回商户金额 安装环境 1.PHP版本>7.4 2.Mysql数据库 安装教…

springCloudAlibaba之分布式事务组件---seata

Seata Sea学习分布式事务Seata二阶段提交协议AT模式 Sea学习 事务:事务是访问数据库并更新数据库中各项数据的一个程序执行单元。在关系数据库中,一个事务由一组或多组SQL语句组成。事务应该具有4个属性:原子性、一致性、隔离性、持久性。例如…

vuInhub靶场实战系列--Kioptrix Level #2

免责声明 本文档仅供学习和研究使用,请勿使用文中的技术源码用于非法用途,任何人造成的任何负面影响,与本人无关。 目录 免责声明前言一、环境配置1.1 靶场信息1.2 靶场配置 二、信息收集2.1 主机发现2.1.1 netdiscover2.1.2 nmap主机扫描2.1.3 arp-scan主机扫描 2.2 端口扫描…

LangChain基础知识入门

LangChain的介绍和入门 1 什么是LangChain LangChain由 Harrison Chase 创建于2022年10月,它是围绕LLMs(大语言模型)建立的一个框架,LLMs使用机器学习算法和海量数据来分析和理解自然语言,GPT3.5、GPT4是LLMs最先进的代…

打字侠是一款PWA网站,如何下载到电脑桌面?

嘿,亲爱的键盘侠们! 你是否还在为寻找一款好用的打字练习工具而烦恼?别担心,今天我要给大家介绍一位超级英雄——打字侠!它不仅是一个超级酷的打字练习网站,还是一款PWA(渐进式网页应用&#x…

汇编:结构体

在32位汇编中,结构体(structures)用于组织和管理复杂的数据类型,结构体可以包含多个不同类型的数据项(成员);在MASM(Microsoft Macro Assembler)中,使用结构体…

stm32编写Modbus步骤

1. modbus协议简介: modbus协议基于rs485总线,采取一主多从的形式,主设备轮询各从设备信息,从设备不主动上报。 日常使用都是RTU模式,协议帧格式如下所示: 地址 功能码 寄存器地址 读取寄存器…

电子设计入门教程硬件篇之集成电路IC(二)

前言:本文为手把手教学的电子设计入门教程硬件类的博客,该博客侧重针对电子设计中的硬件电路进行介绍。本篇博客将根据电子设计实战中的情况去详细讲解集成电路IC,这些集成电路IC包括:逻辑门芯片、运算放大器与电子零件。电子设计…

汇编语言LDS指令

在8086架构的实模式下,LDS指令(Load Pointer Using DS)用于从内存中加载一个32位的指针到指定寄存器和DS寄存器。我们来详细解释一下这条指令为什么会修改DS段寄存器。 LDS指令的功能 LDS指令格式如下: LDS destination, sourc…

Python中报错提示:TypeError: Student() takes no arguments

Python中报错提示:TypeError: Student() takes no arguments 在Python编程中,类是创建对象的蓝图。每个类都可能包含一个特殊的方法__init__,我们称之为构造函数,它在创建新实例时被调用。如果你在尝试创建一个类的实例时遇到了Ty…

找寻窗口句柄

FindWindow FindWindow这个函数检索顶级窗口的类名和窗口名称匹配指定的字符串。这个函数不搜索子窗口。 该函数是个宏,定义如下 #ifdef UNICODE #define FindWindow FindWindowW #else #define FindWindow FindWindowA #endif // !UNICODE ​​​​​​FindW…

SpringBoot快速整合MyBatisPlus

文章目录 创建项目配置pom.xml配置数据源创建实体类创建Mapper接口配置MyBatis Plus MyBatis Plus 是 MyBatis 的增强工具,在 MyBatis 的基础上进行扩展和增强,主要目标是简化开发、提高效率。它提供了一系列功能,包括 CRUD 封装、条件构造器…

#01 Stable Diffusion基础入门:了解AI图像生成

文章目录 前言什么是Stable Diffusion?Stable Diffusion的工作原理如何使用Stable Diffusion?Stable Diffusion的应用场景结论 前言 在当今迅速发展的人工智能领域,AI图像生成技术以其独特的魅力吸引了广泛的关注。Stable Diffusion作为其中的一项前沿技术&#…

k8s概述

文章目录 一、什么是Kubernetes1、官网链接2、概述3、特点4、功能 二、Kubernetes架构1、架构图2、核心组件2.1、控制平面组件(Control Plane Components)2.1.1、kube-apiserver2.1.2、etcd2.1.3、kube-scheduler2.1.4、kube-controller-manager 2.2、No…