手撕Transformer -- Day7 -- Decoder

手撕Transformer – Day7 – Decoder

Transformer 网络结构图

目录

  • 手撕Transformer -- Day7 -- Decoder
    • Transformer 网络结构图
    • Decoder 代码
      • Part1 库函数
      • Part2 实现一个解码器Decoder,作为一个类
      • Part3 测试
    • 参考

在这里插入图片描述

Transformer 网络结构

Decoder 代码

Part1 库函数

# 该板块主要是对解码器进行串接,实现得到解码器部分
# 输入为x,还没嵌入的,但是PAD好的输入,输出需要对注意力值进行线性转化和softmax,最后得到一个单维向量,长度为词库大小。
'''
# Part1 导入库函数
'''
import torch
from torch import nn
from dataset import train_dataset, de_vocab, en_vocab, de_preprocess, en_preprocess,PAD_IDX
from encoder import Encoder
from decoder_block import DecoderBlock
from emb import EmbeddingWithPosition

Part2 实现一个解码器Decoder,作为一个类

'''
# Part2 设计解码器的类
'''


class Decoder(nn.Module):
    def __init__(self, en_vocab_size, emd_size, nums_decoder_block, head, q_k_size, v_size, f_size):
        super().__init__()
        self.nums_decoder_block=nums_decoder_block
        # 首先对x进行编码
        self.emd = EmbeddingWithPosition(vocab_size=en_vocab_size, emd_size=emd_size)
        # 然后输入n个编码器
        self.decoder_list = nn.ModuleList()
        for _ in range(nums_decoder_block):
            self.decoder_list.append(
                DecoderBlock(head=head, emd_size=emd_size, q_k_size=q_k_size, v_size=v_size, f_size=f_size))
        # 然后需要线性化和softmax,目前是(batch_size,q_sqen_len,emd)
        # 得到(batch_size,vocab_size)
        self.linear1=nn.Linear(emd_size,en_vocab_size)
        self.softmax=nn.Softmax(-1)

    def forward(self, x, encoder_z,encoder_x): # encoder_x是编码器的输入(batch_size,q_seq_len)
        # x(batch_size,q_sqen_len)
        # 首先对解码器输入的padding位置进行掩码设置。
        mask1=(x==PAD_IDX).unsqueeze(1) # (batch_size,1,q_seq_len)
        mask1.expand(-1,x.size()[1],-1)  # (batch_size,q_seq_len,q_seq_len)
        # 然后要对解码器的输入的上半部分也取True然后和mask1或一下(也就是符号|),注意True表示需要隐藏的位置。
        # 注意:torch.tril 和 torch.triu 的区别就是决定矩阵的上半部分(不包含对角线)还是下半部分(不包含对角线)置为0,diagonal=1,表示置0的区域向上移动一行
        mask1=mask1 | torch.triu(torch.ones(mask1.size()[-1],mask1.size()[-1]),diagonal=1).bool().unsqueeze(0).expand(mask1.size()[0],-1,-1)


        # 然后对编码器的mask2进行掩码设置。在交叉注意力中,Padding 掩码的区域由K 和 V 的来源决定,
        # 而不是由Q 的来源决定。这确保了来自Q 的查询只关注K 中有效的信息位置。

        mask2 = (encoder_x == PAD_IDX).unsqueeze(1) # (batch_size,1,q_seq_len)
        mask2.expand(-1, encoder_x.size()[1], -1) # (batch_size,1,q_seq_len)

        x=self.emd(x)  # (batch_size,q_sqen_len,emd)

        # 进入解码器
        output=x
        for i in range(self.nums_decoder_block):
            output = self.decoder_list[i](output,encoder_z,mask1,mask2)

        # 输出进行线性层和softmax
        output=self.linear1(output)
        output=self.softmax(output)
        return output

Part3 测试

if __name__ == '__main__':
    # 取2个de句子转词ID序列,输入给encoder
    de_tokens1, de_ids1 = de_preprocess(train_dataset[0][0])
    de_tokens2, de_ids2 = de_preprocess(train_dataset[1][0])
    # 对应2个en句子转词ID序列,再做embedding,输入给decoder
    en_tokens1, en_ids1 = en_preprocess(train_dataset[0][1])
    en_tokens2, en_ids2 = en_preprocess(train_dataset[1][1])

    # de句子组成batch并padding对齐
    if len(de_ids1) < len(de_ids2):
        de_ids1.extend([PAD_IDX] * (len(de_ids2) - len(de_ids1)))
    elif len(de_ids1) > len(de_ids2):
        de_ids2.extend([PAD_IDX] * (len(de_ids1) - len(de_ids2)))

    enc_x_batch = torch.tensor([de_ids1, de_ids2], dtype=torch.long)
    print('enc_x_batch batch:', enc_x_batch.size())

    # en句子组成batch并padding对齐
    if len(en_ids1) < len(en_ids2):
        en_ids1.extend([PAD_IDX] * (len(en_ids2) - len(en_ids1)))
    elif len(en_ids1) > len(en_ids2):
        en_ids2.extend([PAD_IDX] * (len(en_ids1) - len(en_ids2)))

    dec_x_batch = torch.tensor([en_ids1, en_ids2], dtype=torch.long)
    print('dec_x_batch batch:', dec_x_batch.size())

    # Encoder编码,输出每个词的编码向量
    enc = Encoder(vocab_size=len(de_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_encoderblock=3)
    enc_outputs = enc(enc_x_batch)
    print('encoder outputs:', enc_outputs.size())

    # Decoder编码,输出每个词对应下一个词的概率
    dec = Decoder(en_vocab_size=len(en_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_decoder_block=3)
    enc_outputs = dec(dec_x_batch, enc_outputs, enc_x_batch)
    print(enc_outputs)
    print('decoder outputs:', enc_outputs.size())

参考

视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili

github代码库:github.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/954398.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UI自动化测试:异常截图和page_source

自动化测试过程中&#xff0c;是否遇到过脚本执行中途出错却不知道原因的情况&#xff1f;测试人员面临的不仅是问题的复现&#xff0c;还有对错误的快速定位和分析。而异常截图与页面源码&#xff08;Page Source&#xff09;的结合&#xff0c;正是解决这一难题的利器。 在实…

Unity-Mirror网络框架-从入门到精通之RigidbodyBenchmark示例

文章目录 前言示例代码逻辑测试结论性能影响因素最后前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。本系列文章将为读者提供对Mirror网络框架的深入了解,涵盖从基础到高级的多个主题。Mirror是一个用于Unity的开源网络框架,专为多人游戏开发设计,它…

【Unity3D】【已解决】TextMeshPro无法显示中文的解决方法

TextMeshPro无法显示中文的解决方法 现象解决方法Assets 目录中新建一个字体文件夹在C:\Windows\Fonts 中随便找一个中文字体的字体文件把字体文件拖到第一步创建的文件夹中右键导入的字体&#xff0c;Create---TextMeshPro---Font Asset&#xff0c;创建字体文件资源把 SDF文件…

走出实验室的人形机器人,将复刻ChatGPT之路?

1月7日&#xff0c;在2025年CES电子展现场&#xff0c;黄仁勋不仅展示了他全新的皮衣和采用Blackwell架构的RTX 50系列显卡&#xff0c;更进一步展现了他对于机器人技术领域&#xff0c;特别是人形机器人和通用机器人技术的笃信。黄仁勋认为机器人即将迎来ChatGPT般的突破&…

Docker PG流复制搭建实操

目录标题 制作镜像1. 删除旧的容器2. 创建并配置容器3. 初始化数据库并启动 主库配置参数4. 配置主库5. 修改 postgresql.conf 配置 备库配置参数6. 创建并配置备库容器7. 初始化备库 流复制8. 配置&检查主库复制状态9. 检查备库配置 优化建议问题1&#xff1a;FATAL: usin…

【Flink】Flink内存管理

Flink内存整体结构图&#xff1a; JobManager内存管理 JVM 进程总内存(Total Process Memory)Flink总内存(Total Flink Memory)&#xff1a;JVM进程总内存减去JVM Metaspace(元空间)和JVM Overhead(运行时开销)上图解释&#xff1a; JVM进程总内存为2G;JVM运行时开销(JVM Overh…

Flink系统知识讲解之:Flink内存管理详解

Flink系统知识讲解之&#xff1a;Flink内存管理详解 在现阶段&#xff0c;大部分开源的大数据计算引擎都是用Java或者是基于JVM的编程语言实现的&#xff0c;如Apache Hadoop、Apache Spark、Apache Drill、Apache Flink等。Java语言的好处是不用考虑底层&#xff0c;降低了程…

VM(虚拟机)和Linux的安装

文章目录 1.虚拟机1.1 VM的安装和删除1.1.1 安装前提1.1.2 安装步骤 1.2 虚拟机快照1.3 虚拟机的克隆 2.Linux的安装2.1 CentOS2.2 Ubuntu 1.虚拟机 &#xff08;1&#xff09;Linux系统的安装方式 ①物理机安装&#xff1a;直接将操作系统安装到服务器硬件上 ②虚拟机安装&am…

Unity中实现倒计时结束后干一些事情

问题描述&#xff1a;如果我们想实现在一个倒计时结束后可以执行某个方法&#xff0c;比如挑战成功或者挑战失败&#xff0c;或者其他什么的比如生成boss之类的功能&#xff0c;而且你又不想每次都把代码复制一遍&#xff0c;那么就可以用下面这种方法。 结构 实现步骤 创建一…

citrix netscaler13.1 重写负载均衡响应头(基础版)

在 Citrix NetScaler 13.1 中&#xff0c;Rewrite Actions 用于对负载均衡响应进行修改&#xff0c;包括替换、删除和插入 HTTP 响应头。这些操作可以通过自定义策略来完成&#xff0c;帮助你根据需求调整请求内容。以下是三种常见的操作&#xff1a; 1. Replace (替换响应头)…

新垂直电商的社交传播策略与AI智能名片2+1链动模式S2B2C商城小程序的应用探索

摘要&#xff1a;随着互联网技术的不断进步和电商行业的快速发展&#xff0c;传统电商模式已难以满足消费者日益增长的个性化和多元化需求。新垂直电商在此背景下应运而生&#xff0c;通过精准定位、用户细分以及深度社交传播策略&#xff0c;实现了用户群体的快速裂变与高效营…

TP4056锂电池充放电芯片教程文章详解·内置驱动电路资源!!!

目录 TP4056工作原理 TP4056引脚详解 TP4056驱动电路图 锂电池充放电板子绘制 编写不易&#xff0c;仅供学习&#xff0c;感谢理解。 TP4056工作原理 TP4056是专门为单节锂电池或锂聚合物电池设计的线性充电器&#xff0c;充电电流可以用外部电阻设定&#xff0c;最大充电…

burpsiute的基础使用(2)

爆破模块&#xff08;intruder&#xff09;&#xff1a; csrf请求伪造访问&#xff08;模拟攻击&#xff09;: 方法一&#xff1a; 通过burp将修改&#xff0c;删除等行为的数据包压缩成一个可访问链接&#xff0c;通过本地浏览器访问&#xff08;该浏览器用户处于登陆状态&a…

基于32QAM的载波同步和定时同步性能仿真,包括Costas环的gardner环

目录 1.算法仿真效果 2.算法涉及理论知识概要 3.MATLAB核心程序 4.完整算法代码文件获得 1.算法仿真效果 matlab2022a仿真结果如下&#xff08;完整代码运行后无水印&#xff09;&#xff1a; 仿真操作步骤可参考程序配套的操作视频。 2.算法涉及理论知识概要 载波同步是…

使用PWM生成模式驱动BLDC三相无刷直流电机

引言 在 TI 的无刷直流 (BLDC) DRV8x 产品系列使用的栅极驱动器应用中&#xff0c;通常使用一些控制模式来切换MOSFET 开关的输出栅极。这些控制模式包括&#xff1a;1x、3x、6x 和独立脉宽调制 (PWM) 模式。   不过&#xff0c;DRV8x 产品系列&#xff08;例如 DRV8311&…

校园跑腿小程序---轮播图,导航栏开发

hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生…

SpringBoot入门实现简单增删改查

本例子的依赖 要实现的内容 通过get、post、put和delete接口,对数据库中的trade.categories表进行增删改查操作。 目录结构 com.test/ │ ├── controller/ │ ├── CateController.java │ ├── pojo/ │ ├── dto/ │ │ └── CategoryDto.java │ ├─…

doc、pdf转markdown

国外的一个网站可以&#xff1a; Convert A File Word, PDF, JPG Online 这个网站免费的&#xff0c;算是非常厚道了&#xff0c;但是大文件上传多了之后会扛不住 国内的一个网站也不错&#xff1a; TextIn-AI智能文档处理-图像处理技术-大模型加速器-在线免费体验 https://…

信号与系统初识---信号的分类

文章目录 0.引言1.介绍2.信号的分类3.关于周期大小的求解4.实信号和复信号5.奇信号和偶信号6.能量信号和功率信号 0.引言 学习这个自动控制原理一段时间了&#xff0c;但是只写了一篇博客&#xff0c;其实主要是因为最近在打这个华数杯&#xff0c;其次是因为在补这个数学知识…

Facebook 在新兴市场的发展策略:机遇与挑战并存

随着全球互联网的普及&#xff0c;新兴市场成为了全球科技公司关注的重点。这些地区的互联网用户数量快速增长&#xff0c;对于 Facebook&#xff08;现 Meta&#xff09;来说&#xff0c;这些市场不仅蕴藏着巨大的用户潜力&#xff0c;也带来了不少挑战。Facebook 在新兴市场的…