Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [
    # enc_input   dec_input    dec_output
    ['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在这里插入图片描述

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *

model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

# 训练开始
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len] [2,5]
        dec_inputs: [batch_size, tgt_len] [2,6]
        dec_outputs: [batch_size, tgt_len] [2,6]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]
        outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]
        loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量

        # 更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/379353.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

最新的 Ivanti SSRF 零日漏洞正在被大规模利用

Bleeping Computer 网站消息,安全研究员发现 Ivanti Connect Secure 和 Ivanti Policy Secure 服务器端请求伪造 (SSRF) 漏洞(CVE-2024-21893 )正在被多个威胁攻击者大规模利用。 2024 年 1 月 31 日,Ivanti 首次就网关 SAML 组件…

深入理解Netty及核心组件使用—上

目录 Netty的优势 为什么Netty使用NIO而不是AIO? Netty基本组件 Bootstrap、EventLoop(Group) 、Channel 事件和 ChannelHandler、ChannelPipeline ChannelFuture Netty入门程序 服务端代码 客户端代码 运行结果 Netty的优势 1. API 使用简单&#xff0c…

【QT】day6

#include "home.h" #include "ui_home.h"Home::Home(QWidget *parent): QWidget(parent), ui(new Ui::Home) {ui->setupUi(this);// 从配置文件读取用户名QSettings settings("kim", "ad");username settings.value("usernam…

PlateUML绘制UML图教程

UML(Unified Modeling Language)是一种通用的建模语言,广泛用于软件开发中对系统进行可视化建模。PlantUML是一款强大的工具,通过简单的文本描述,能够生成UML图,包括类图、时序图、用例图等。PlantUML是一款…

彩虹系统7.0免授权+精美WAP端模板源码

最低配置环境 PHP7.2 1、上传源码到网站根目录,导入数据库文件 2、修改数据库配置文件:/config.php 3、后台:/admin 账号: 4、前台用户:123456 密码:1234561

请手写几种js排序算法

什么是排序算法 冒泡排序选择排序插入排序快速排序归并排序(Merge Sort) 思想实现测试分析动画 快速排序 (Quick Sort) 思想实现测试分析动画 思考:快排和归并用的都是分治思想,递推公式和递归代码也非常相…

【数据分享】1929-2023年全球站点的逐月平均风速(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据,气象指标包括气温、风速、降水、能见度等指标,说到气象数据,最详细的气象数据是具体到气象监测站点的数据! 有关气象指标的监测站点数据,之前我们分享过1929-2023年全球气象站…

[算法前沿]--058- LangChain 构建 LLM 应用详细教程

什么是LLMs? LLM,即大型语言模型,是指经过大量文本数据训练的最先进的语言模型。它利用深度学习技术来理解和生成类似人类的文本,使其成为各种应用程序的强大工具,例如文本完成、语言翻译、情感分析等。LLMs最著名的例子之一是 OpenAI 的 GPT-3,它因其语言生成能力而受到…

《MySQL 简易速速上手小册》第5章:高可用性和灾难恢复(2024 最新版)

文章目录 5.1 构建高可用性 MySQL 解决方案5.1.1 基础知识5.1.2 重点案例:使用 Python 构建高可用性的电子商务平台数据库5.1.3 拓展案例 5.2 数据备份策略和工具5.2.1 基础知识5.2.2 重点案例:使用 Python 实现 MySQL 定期备份5.2.3 拓展案例 5.3 灾难恢…

【网工】华为设备命令学习(服务器发布)

本次实验主要是内网静态nat配置没,对外地址可以理解为一台内网的服务器,外网设备可以ping通内网的服务器设备,但是ping不通内网的IP。 除了AR1设备配置有区别,其他设备都是基础IP的配置。 [Huawei]int g0/0/0 [Huawei-GigabitEt…

排序算法---快速排序

原创不易,转载请注明出处。欢迎点赞收藏~ 快速排序是一种常用的排序算法,采用分治的策略来进行排序。它的基本思想是选取一个元素作为基准(通常是数组中的第一个元素),然后将数组分割成两部分,其中一部分的…

在Visual Studio中引用和链接OpenSceneGraph (OSG) 库

在Visual Studio中引用和链接OpenSceneGraph (OSG) 库,按照以下步骤操作: 构建或安装OSG库 下载OpenSceneGraph源代码(如3.0版本)并解压。使用CMake配置项目,为Visual Studio生成解决方案文件。通常您需要设置CMake中的…

PHPExcel导出excel

PHPExcel下载地址 https://gitee.com/mirrors/phpexcelhttps://github.com/PHPOffice/PHPExcel 下载后目录结构 需要的文件如下图所示 将上面的PHPExcel文件夹和PHPExcel.php复制到你需要的地方 这是一个简单的示例代码 <?php$dir dirname(__FILE__); //require_once …

电脑通电自启动设置

首先要进入BIOS&#xff0c;以华硕为例&#xff0c;按下电源键&#xff0c;在开机之前按下delete键&#xff0c;其他电脑可能是esc或者某个f键&#xff0c;请自行查找。 进入BIOS后要找到电源管理&#xff0c;可以在高级选项中找一找&#xff0c;如上图右下角选择高级模式。 …

【DDD】学习笔记-理解领域模型

Eric Evans 的领域驱动设计是对软件设计领域的一次重新审视&#xff0c;是在面向对象语言大行其道时对数据建模的“拨乱反正”。Eric 强调了模型的重要性&#xff0c;例如他在书中总结了模型在领域驱动设计中的作用包括&#xff1a; 模型和设计的核心互相影响模型是团队所有成…

基于微信小程序的校园二手交易平台

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

Linux下代码的运行

在Windows环境下&#xff0c;我们代码都是在集成开发环境下运行&#xff0c;也就是说代码的编辑、编译、调试、运行都在一个软件上&#xff0c;而在Linux环境下这些都是分开执行的。 Linux编辑器-vim vim是一款多模式编辑器&#xff0c;vim有很多模式&#xff0c;最常用的三个…

idea: 无法创建Java Class文件(SpringBoot)已解决

第一&#xff1a;点击file-->project Sructure... 第二步&#xff1a;点击Moudules 选择自己需要创建java的文件夹&#xff08;我这里选择的是main&#xff09;右键点击Sources&#xff0c;然后点击OK即可 然后就可以创建java类了

【漏洞复现】EasyCVR智能边缘网关用户信息泄漏漏洞

Nx01 产品简介 EasyCVR智能边缘网关是一种基于边缘计算和人工智能技术的设备&#xff0c;旨在提供高效的视频监控和智能分析解决方案。它结合了视频监控摄像头、计算能力和网络连接&#xff0c;能够在现场进行视频数据处理和分析&#xff0c;减轻对中心服务器的依赖。 Nx02 漏…

【深度学习】pytorch 与 PyG 安装(pip安装)

【深度学习】pytorch 与 PyG 安装&#xff08;pip安装&#xff09; 一、PyTorch安装和配置&#xff08;一&#xff09;、安装 CUDA&#xff08;二&#xff09;、安装torch、torchvision、torchaudio三个组件&#xff08;1&#xff09;下载镜像文件&#xff08;2&#xff09;创建…