Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])


    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的
        dec_inputs: [batch_size, tgt_len]
        enc_inpus: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量
        dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵
        # 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为False
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +
                                       dec_self_attn_subsequence_mask), 0).cuda()
        # 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]

        for layer in self.layers:
            dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

        return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]],
        [[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],
        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        # 1) 计算注意力分数QK^T/sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]
        # 2)  进行 mask 和 softmax
        # mask为True的位置会被设为-1e9
        scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9
        attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]
        # 3) 乘V得到最终的加权和
        context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]
        '''
        得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,
        换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列

        返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,
        只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来
        '''
        return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/381505.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

统一数据格式返回,统一异常处理

目录 1.统一数据格式返回 2.统一异常处理 3.接口返回String类型问题 1.统一数据格式返回 添加ControllerAdvice注解实现ResponseBodyAdvice接口重写supports方法,beforeBodyWrite方法 /*** 统一数据格式返回的保底类 对于一些非对象的数据的再统一 即非对象的封…

【资料分享】基于单片机大气压监测报警系统电路方案设计、基于飞思卡尔的无人坚守点滴监控自动控制系统设计(程序,原理图,pcb,文档)

基于单片机大气压监测报警系统电路方案设计 功能:实现的是大气压检测报警系统,可以通过传感器实时检测当前大气压值,可以设定大气压正常范围,当超过设定范围进行报警提示。 资料:protues仿真,程序&#x…

计算机二级C语言备考学习记录

一、C语言程序的结构 1.程序的构成,main函数和其他函数。 程序是由main函数和其他函数构成main作为主函数,一个C程序里只有一个main函数其他函数可以分为系统函数和用户函数,系统函数为编译系统提供,用户函数由用户自行编写 2.…

[职场] 抖音运营SOP全攻略 #微信#职场发展

抖音运营SOP全攻略 1.养号的步骤 注册一机—卡一号,在注册的前5天只看视频不发视频,单日观看视频的时长不少于30分钟。观看过程中正常评论点赞互动,关注5-10个头部大号。关注20个二三十万至百万的竟品账号。 粉丝量低于1W的账号下不要留下…

Compose之Slider全面解析

JetPack Compose系列(14)—Slider Slider,即拖动条,默认包含了一个滑块和一个滑动轨道。允许用户在一个数值范围内进行选择。 按照惯例,先观察其构造函数: Composable fun Slider(value: Float,onValueCh…

Debezium发布历史120

原文地址: https://debezium.io/blog/2022/04/07/read-only-incremental-snapshots/ 欢迎关注留言,我是收集整理小能手,工具翻译,仅供参考,笔芯笔芯. Read-only Incremental Snapshots for MySQL April 7, 2022 by K…

【Python中Selenium元素定位的各种方法】

1、元素定位操作: 2、创建浏览器驱动操作,导入By模块: from selenium import webdriver # 用于界面与浏览器互动 from selenium.webdriver.common.by import By # 用于元素定位 driver webdriver.Chrome() # 调用Chrome类,创…

C++ 贪心 区间问题 区间选点

给定 N 个闭区间 [ai,bi] ,请你在数轴上选择尽量少的点,使得每个区间内至少包含一个选出的点。 输出选择的点的最小数量。 位于区间端点上的点也算作区间内。 输入格式 第一行包含整数 N ,表示区间数。 接下来 N 行,每行包含两…

.NET高级面试指南专题六【线程安全】5种方法解决线程安全问题

前言 多线程编程相对于单线程会出现一个特有的问题,就是线程安全的问题。所谓的线程安全,就是如果你的代码所在的进程中有多个线程在同时运行,而这些线程可能会同时运行这段代码。如果每次运行结果和单线程运行的结果是一样的,而且…

探索未来:集成存储器计算(IMC)与深度神经网络(DNN)的机遇与挑战

开篇部分:人工智能、深度神经网络与内存计算的交汇 在当今数字化时代,人工智能(AI)已经成为科技领域的一股强大力量,而深度神经网络(DNN)则是AI的核心引擎之一。DNN是一种模仿人类神经系统运作…

视觉开发板—K210自学笔记(二)

视觉开发板—K210 一、开发之前的准备 工欲善其事必先利其器。各位同学先下载下面的手册: 1.Sipeed-Maix-Bit 资料下载:https://dl.sipeed.com/shareURL/MAIX/HDK/Sipeed-Maix-Bit/Maix-Bit_V2.0_with_MEMS_microphone 2.Sipeed-Maix-Bit 规格书下载&…

解决dockor安装nginx提示missing signature key的问题

问题描述 使用dockor安装nginx拉取nginx的时候提示key丢失问题 问题定位 由于dockor版本低导致 问题解决 卸载重新安装最新版本dockor 解决步骤 1. 卸载旧版本的Docker: sudo yum remove docker docker-common docker-selinux docker-engine 2. 安装依赖包&am…

C++入门学习(二十六)for循环

for (初始化; 条件; 递增/递减) { // 代码块 } 打印1~10&#xff1a; #include <iostream> using namespace std; int main() { for (int i 1; i < 10; i) { cout <<i<<endl; } return 0; } 打印九九乘法表&#xff1a; #include <iostream…

Git版本与分支

目录 一、Git 二、配置SSH 1.什么是SSH Key 2.配置SSH Key 三、分支 1.为什么要使用分支 2.四个环境及特点 3.实践操作 1.创建分支 2.查看分支 3.切换分支 4.合并分支 5.删除分支 6.重命名分支 7.推送远程分支 8.拉取远程分支 9.克隆指定分支 四、版本 1.什…

春晚刘谦魔术——约瑟夫环

昨晚&#xff0c;刘谦在春晚上表演了一个魔术&#xff0c;通过对四张撕成两半的纸牌连续操作&#xff0c;最终实现了纸牌的配对。 这个魔术虽然原理不是很难&#xff0c;但是通过刘谦精湛的表演还是让这个魔术产生了不错的效果&#xff08;虽然我感觉小尼的效果更不错&#xff…

【北邮鲁鹏老师计算机视觉课程笔记】02 filter

1 图像的类型 二进制图像&#xff1a; 灰度图像&#xff1a; 彩色图像&#xff1a; 2 任务&#xff1a;图像去噪 噪声点让我们看得难受是因为噪声点与周边像素差别很大 3 均值 滤波核 卷积核 4 卷积操作 对应相乘再累加起来 卷积核记录了权值&#xff0c;把权值套到要卷积…

2023年总结

人们总说时间会改变一切&#xff0c;但事实上你得自己来。 今年开始给自己的时间读书、工作、生活都加上一个2.0的release版本号&#xff0c;相比过去的一年还是有很多进步的。 就跟git commit一样&#xff0c;一步一步提交优化&#xff0c;年底了发个版本。用李笑来的话说&am…

【洛谷题解】P1075 [NOIP2012 普及组] 质因数分解

题目链接&#xff1a;[NOIP2012 普及组] 质因数分解 - 洛谷 题目难度&#xff1a;入门 涉及知识点&#xff1a;枚举&#xff08;优化&#xff09; 题意&#xff1a; 输入样例&#xff1a;21 输出样例&#xff1a;7 分析&#xff1a;枚举到小因数&#xff0c;再除a&#x…

何时以及如何选择制动电阻

制动电阻的选择是优化变频器应用的关键因素 制动电阻器在变频器中是如何工作的&#xff1f; 制动电阻器在 VFD 应用中的工作原理是将电机减速到驱动器设定的精确速度。它们对于电机的快速减速特别有用。制动电阻还可以将任何多余的能量馈入 VFD&#xff0c;以提升直流母线上的…

单片机的认识

单片机的定义 先简单理解为&#xff1a; 在一片集成电路芯片上集成了微处理器&#xff08;CPU &#xff09;存储器&#xff08;ROM和RAM&#xff09;、I/O 接口电路&#xff0c;构成单芯片微型计算机&#xff0c;即为单片机。 把组成微型计算机的控制器、运算器、存储器、输…