【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

解码器组件

在这里插入图片描述

解码器部分:

  • 由N个解码器层堆叠而成
  • 每个解码器层由三个子层连接结构组成
  • 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
  • 第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
  • 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接

解码器层的作用:
作为解码器的组成单元,每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程。

解码器层代码

解码器曾主要由三个子层组成,这里面三个子层还用之前构建Encoder时的代码,详情请看:【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        """

        :param size: 词嵌入维度
        :param self_attn: 多头自注意力层 Q=K=V
        :param src_attn: 多头注意力层 Q!=K=V
        :param feed_forward: 前馈全连接层
        :param dropout: 置0比率
        """
        super(DecoderLayer, self).__init__()

        # 传参到类中
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = dropout

        # 按照解码器层的结构图,使用clones函数克隆3个子层连接对象
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, source_mask, target_mask):
        """构建出三个子层:多头自注意力子层、普通的多头注意力子层、前馈全连接层

        :param x: 上一层输入的张量
        :param memory: 编码器的语义存储张量(K=V)
        :param source_mask: 源数据的掩码张量
        :param target_mask: 目标数据的掩码张量
        :return:一层解码器的解码输出
        """
        m = memory

        # 第一步,让x进入第一个子层(多头自注意力机制的子层)
        # 采用target_mask,将解码时未来的信息进行遮掩。
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))

        # 第二步,让x进入第二个子层(常规多头注意力机制的子层,Q!=K=V)
        # 采用source_mask,遮掩掉已经判定出来的对结果信息无用的数据(减少对无用信息的关注),提升计算效率
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))

        # 第三步,让x进入第三个子层(前馈全连接层)
        return self.sublayer[2](x, self.feed_forward)

示例

# 定义参数
size = d_model = 512
head = 8
d_ff = 64
dropout = 0.2

self_attn = src_attn = MultiHeadedAttention(head, d_model, dropout)     # 定义多头注意力层
ff = PositionwiseFeedForward(d_model, d_ff, dropout)                    # 定义前馈全连接层
x = pe_res
memory = enc_res    # 将之前编码器实例中的enc_res结果赋值给memory作为K和V
mask = torch.zeros(2, 4, 4)
source_mask = target_mask = mask    # 简单示范,都先给同样的mask

dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)
dl_res = dl(x, memory, source_mask, target_mask)
print(f"dl_res: {dl_res}\n shape:{dl_res.shape}")


dl_res: tensor([[[-2.7233e+01,  3.7782e+01,  1.7257e+01,  ...,  1.2275e+01,
          -4.7017e+01,  1.7687e+01],
         [-2.6276e+01,  1.4660e-01,  5.5642e-02,  ..., -2.5157e+01,
          -2.8655e+01, -3.8758e+01],
         [ 1.0419e+00, -2.7726e+01, -2.3628e+01,  ..., -7.7137e+00,
          -5.7320e+01,  4.6977e+01],
         [-3.3436e+01,  3.2082e+01, -1.6754e+01,  ..., -2.5161e-01,
          -4.0380e+01,  4.7144e+01]],
        [[-5.3706e+00, -2.4270e+01,  2.1009e+01,  ...,  6.5833e+00,
          -4.3054e+01,  2.5535e+01],
         [ 3.1999e+01, -8.3981e+00, -5.6480e+00,  ...,  3.1037e+00,
           2.1093e+01,  3.0293e+00],
         [ 5.5799e+00,  1.0306e+01, -2.0165e+00,  ...,  3.8163e+00,
           4.0567e+01, -1.2256e+00],
         [-3.6323e+01, -1.4260e+01,  3.3353e-02,  ..., -9.4611e+00,
          -1.6435e-01, -3.5157e+01]]], grad_fn=<AddBackward0>)
 shape:torch.Size([2, 4, 512])

对比下面编码器的编码结果:

enc_res: tensor([[[-0.9458,  1.4723,  0.6997,  ...,  0.6569, -1.9873,  0.7674],
         [-0.9278,  0.0055, -0.0309,  ..., -1.2925, -1.2145, -1.6950],
         [ 0.1456, -1.1068, -0.8927,  ..., -0.2079, -2.2481,  1.8858],
         [-1.2406,  1.3828, -0.8069,  ...,  0.1041, -1.5828,  1.9792]],
        [[-0.1922, -1.1158,  0.7787,  ...,  0.2102, -1.7763,  1.1359],
         [ 1.4014, -0.3193, -0.3572,  ..., -0.0428,  0.7563,  0.1116],
         [ 0.3749,  0.4738, -0.0470,  ...,  0.1295,  1.8679,  0.0937],
         [-1.5545, -0.5667, -0.0432,  ..., -0.6391, -0.0121, -1.4567]]],
       grad_fn=<AddBackward0>)

原数据的掩码张量存在意义:
掩码原数据中,关联性弱的数据,不让注意力计算分散,提升计算效率。

解码器代码

N个解码器层构成一个解码器

class Decoder(nn.Module):
    def __init__(self, layer, N):
        """ 确定解码器层和层数

        :param layer: 解码器层
        :param N: 解码器层的个数
        """
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)          # 使用clones函数克隆N个类
        self.norm = LayerNorm(layer.size)        # 实例化规范化层

    def forward(self, x, memory, source_mask, target_mask):
        """ 循环构建解码器,经过规范化层后输出

        :param x:目标数据的嵌入表示
        :param memory:解码器层的输出QV
        :param source_mask:源数据掩码张量
        :param target_mask:目标数据掩码张量
        :return:经过规范化后的解码器
        """
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)

        return self.norm(x)

示例

size = d_model = 512
head = 8
d_ff =64
dropout = 0.2
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)     # 第一个attn作为自注意力机制,第二个attn作为注意力机制

N = 8
x = pe_res
memory = enc_res
mask = torch.zeros(2, 4, 4)
source_mask = target_mask = mask

de = Decoder(layer, N)      # 实例化解码器
de_res = de(x, memory, source_mask, target_mask)
print(f"de_res: {de_res}\n shape: {de_res.shape}")


de_res: tensor([[[-0.7714,  0.1066,  1.8197,  ..., -0.1137,  0.2005,  0.5856],
         [-0.9215, -0.9844, -0.4962,  ..., -0.1074,  0.4848,  0.3493],
         [-2.2495,  0.0859, -0.7644,  ..., -0.0679, -0.7270, -1.3438],
         [-0.4822,  0.2821,  1.0786,  ..., -1.9442,  0.8834, -1.1757]],
        [[-0.2491, -0.6117,  0.7908,  ..., -2.1624,  0.6212,  0.6190],
         [-0.3938, -0.5203,  0.6412,  ..., -0.8679,  0.8462,  0.3037],
         [-1.0217, -1.0685, -0.5138,  ...,  1.2010,  2.0795, -0.0143],
         [-0.2919, -0.5916,  1.5231,  ..., -0.1215,  0.7127, -0.0586]]],
       grad_fn=<AddBackward0>)
 shape: torch.Size([2, 4, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/138437.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

猪酒店房价采集

<?php // 设置代理 $proxy_host jshk.com.cn;// 创建一个cURL资源 $ch curl_init();// 设置代理 curl_setopt($ch, CURLOPT_PROXY, $proxy_host.:.$proxy_port);// 连接URL curl_setopt($ch, CURLOPT_URL, "http://www.zujia.com/");// 发送请求并获取HTML文档…

快跑RUSH

欢迎来到程序小院 快跑RUSH 玩法&#xff1a;跑动的小人&#xff0c;点击鼠标左键跳过障碍物&#xff0c;跳过不同的阶梯&#xff0c;经过金币吃掉获取1分&#xff0c;赶紧去快跑PUSH看看你能够获得多少金币哦^^。开始游戏https://www.ormcc.com/play/gameStart/202 html <…

2024dh网站导航最新,你以为它很花俏?确是牛逼的人人资源站

2024dh网站app.2024网站导航最新。2024免费中文导航。2024dh手机网站导航。2024年还好用的导航app 2024资讯导航是一个专注于新闻和资讯的视频导航网站。电影导航网站&#xff0c;图片导航网站&#xff0c;爱奇艺导航&#xff0c;优酷电影导航&#xff0c;土豆导航&#xff0c…

苹果手机怎么将图片转为excel/word?

第一种方案就是用苹果手机自带的OCR功能来实现需求&#xff0c;但低版本的IOS系统不支持此功能&#xff0c;目前已知IOS15以上版本可以支持&#xff0c;只需要在“设置”--“相册”那打开“实况文本”即可&#xff0c;如下图。 IOS15系统打开“实况文本” 开启后&#xff0c;打…

Pass基础-DevOps

&#xff0c;DevOps是Dev&#xff08;开发&#xff09;和Ops&#xff08;运维/运营&#xff09;的结合&#xff0c;它将人、流程、工具、工程实践等等结合起来应用到IT价值流的实现过程中&#xff0c;是一系列原则、方法、流程、实践、工具的综合体。DevOps面向应用的全生命周期…

递归和master公式 系统栈 + 计算时间复杂度

前置知识&#xff1a;无 1&#xff09;从思想上理解递归&#xff1a;对于新手来说&#xff0c;递归去画调用图是非常重要的&#xff0c;有利于分析递归 2&#xff09;从实际上理解递归&#xff1a;递归不是玄学&#xff0c;底层是利用系统栈来实现的 3&#xff09;任何递归函…

Rust语言入门:理解基础语法

大家好&#xff0c;我是[lincyang]。 今天&#xff0c;我们将深入了解Rust编程语言的基础语法&#xff0c;为你打下坚实的Rust编程基础。 Rust简介 Rust是一种系统编程语言&#xff0c;它注重内存安全、并发性和实用性。Rust的设计哲学是提供安全性&#xff0c;而不损失性能。它…

自动生成Form表单提交在苹果浏览器中的注意事项

以下是本人在公司旧系统中看到的该段代码 function Post(URL, PARAMTERS) {//创建form表单var temp_form document.createElement("form");temp_form.action URL;//如需打开新窗口&#xff0c;form的target属性要设置为_blanktemp_form.target "_blank"…

跟我一起从零开始学python(二)网络编程

前言 昨天讲解了关于从零入门python的第一遍&#xff0c;编程语法必修内容&#xff0c;比如python3基础入门&#xff0c;列表与元组&#xff0c;字符串&#xff0c;字典&#xff0c;条件丶循环和其他语句丶函数丶面向对象丶异常和文件处理 。 今天讲第二篇&#xff1a;python…

修改/etc/fstab文件导致Linux无法正常启动解决方法

如果把 /etc/fstab 文件修改错了&#xff0c;也重启了&#xff0c;系统崩溃启动不了了&#xff0c;那该怎么办&#xff1f;比如&#xff1a; [rootlocalhost ~]# vi /etc/fstab UUIDc2ca6f57-b15c-43ea-bca0-f239083d8bd2 ext4 defaults 1 1 UUID0b23d315-33a7-48a4-bd37-9248…

ceph-deploy bclinux aarch64 ceph 14.2.10【2】vdbench rbd 块设备rbd 测试失败

上篇 ceph-deploy bclinux aarch64 ceph 14.2.10-CSDN博客 安装vdbench 下载vdbench 下载页面 Vdbench Downloads (oracle.com) 包下载 需要账号登录&#xff0c;在弹出层点击同意才能继续下载 用户手册 https://download.oracle.com/otn/utilities_drivers/vdbench/vdb…

搜集的升压芯片资料

DC-DC升压芯片,输入电压0.65v/1.5v/1.8v/2v/2.5v/2.7v/3v/3.3v/3.6v/5v/12v/24v航誉微 HUB628是一款超小封装高效率、直流升压稳压电路。输入电压范围可由低2V伏特到24伏特&#xff0c;升压可达28V可调&#xff0c;且内部集成极低RDS内阻100豪欧金属氧化物半导体场效应晶体管的…

人物百科怎么创建?教你如何创建人物百度百科注意以下方式技巧!

百科就像互联网上的名片&#xff0c;不仅代表身份&#xff0c;而且拥有极高的可信度。因此&#xff0c;许多名人都希望利用百科提高自己的知名度。任何人都可以编辑人物百科词条&#xff0c;但为了成功上传&#xff0c;需要一些技巧。以下是小媒同学给大家带来的人物百科快速创…

成都瀚网科技有限公司抖音带货正规

随着互联网的蓬勃发展&#xff0c;越来越多的公司开始利用网络平台进行产品销售。其中&#xff0c;抖音作为一款广受欢迎的短视频平台&#xff0c;已经成为众多商家眼中的“香饽饽”。在这场电商狂欢中&#xff0c;成都瀚网科技有限公司&#xff08;以下简称“瀚网科技”&#…

AMEYA360:江苏润石再次重磅发布11颗通过AEC-Q100认证的车规级芯片

为进一步满足众多新能源汽车客户对车规级芯片的需求&#xff0c;江苏润石持续研发更多的车规级产品&#xff0c;再次重磅发布11颗通过AEC-Q100 Grade1 & MSL 1湿敏等级认证的车规级芯片;进一步展示了江苏润石在车规级芯片领域孜孜不倦的追求&#xff0c;以及深耕汽车电子市…

研究生做实验找不到数据集咋办?

做实验找不到数据集咋办?这是很多研究者和开发者都会遇到的问题。数据集是实验的基础,没有合适的数据集,就无法验证模型的性能和效果。那么,有没有什么方法可以快速地找到我们需要的数据集呢?本文将介绍4个常用的数据集搜索平台,希望能够帮助大家解决这个难题。下面以室内…

单极性非归零码(NRZ)、双极性非归零码(NRZ)、单极性归零码、双极性非归零码(NRZ)、差分码的编码规则与其功率谱

数字信号的基带传输的基本概念与传输码型 主要涉及一些数字基带传输的基本概念和数字基带传输的简单码型。码型包括&#xff1a;单极性非归零码&#xff08;NRZ&#xff09;、双极性非归零码&#xff08;NRZ&#xff09;、单极性归零码、双极性非归零码&#xff08;NRZ&#xf…

【第2章 Node.js基础】2.4 Node.js 全局对象(一)

什么是Node.js 全局对象 对于浏览器引擎来说&#xff0c;JavaScript 脚本中的 window 是全局对象&#xff0c;而Node.js程序中的全局对象是 global&#xff0c;所有全局变量(除global本身外)都是global 对象的属性。全局变量和全局对象是所有模块都可以调用的。Node.is 的全局…

零代码Prompt应用大赛正式开始!飞桨星河社区五周年活动第一站

五周年盛典将至&#xff01;抢发第一站&#xff01; 在大模型时代&#xff0c;飞桨星河社区致力于让人人都成为大模型开发者&#xff01;飞桨星河社区零代码应用开发工具链&#xff0c;帮助大家轻松实现灵感落地、场景化需求落地&#xff0c;助力每个人实现工作与生活的效能提…

Node-RED系列教程-29nodered与三菱PLC基于MC协议通信测试

安装mc通信节点: node-red-contrib-mcprotocol 包含2个节点,一个节点负责读,一个节点负责写。 本教程目前只演示读功能。由于没有硬件,首先利用hsl demo软件模拟出一个用于测试mc通信的服务端。 mc读过程如下: 输入节点开启定时即可。 MC读节点配置: