Transformer英语-法语机器翻译实例

依照Transformer结构来实例化编码器-解码器模型。在这里,指定Transformer编码器和解码器都是2层,都使用4头注意力。为了进行序列到序列的学习,我们在英语-法语机器翻译数据集上训练Transformer模型,如图11.2所示。

data_path = "weibo_senti_100k.csv"
data_list = open(data_path,"r",encoding='UTF-8').readlines()[1:]
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

loss 0.030, 5244.8 tokens/sec on cuda:0

图11.2  在英语-法语机器翻译数据集上训练Transformer模型

训练结束后,使用Transformer模型将一些英语句子翻译成法语,并且计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

结果如下:

go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000

当进行最后一个英语到法语的句子翻译工作时,需要可视化Transformer的注意力权重。编码器自注意力权重的形状为[编码器层数,注意力头数,num_steps或查询的数目,num_steps或“键-值”对的数目]。

enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads,
    -1, num_steps))
enc_attention_weights.shape

结果如下:

torch.Size([2, 4, 10, 10])

在编码器的自注意力中,查询和键都来自相同的输入序列。由于填充词元是不携带信息的,因此通过指定输入序列的有效长度,可以避免查询与使用填充词元的位置计算注意力。接下来,将逐行呈现两层多头注意力的权重。每个注意力头都根据查询、键和值不同的表示子空间来表示不同的注意力,如图11.3所示。

d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))

图11.3  4头注意力模型

为了可视化解码器的自注意力权重和编码器-解码器的注意力权重,我们需要完成更多的数据操作工作。

例如,用零填充被遮蔽住的注意力权重。值得注意的是,解码器的自注意力权重和编码器-解码器的注意力权重都有相同的查询,即以序列开始词元(Beginning-Of-Sequence,BOS)打头,再与后续输出的词元共同组成序列。

dec_attention_weights_2d = [head[0].tolist()
for step in dec_attention_weight_seq
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = torch.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.permute(1, 2, 3, 0, 4)
dec_self_attention_weights.shape, dec_inter_attention_weights.shape

结果如下:

(torch.Size([2, 4, 6, 10]), torch.Size([2, 4, 6, 10]))

由于解码器自注意力的自回归属性,查询不会对当前位置之后的键-值对进行注意力计算。结果如图11.4所示。

# Plusonetoincludethebeginning-of-sequencetoken
d2l.show_heatmaps(
    dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
    xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))

图11.4  查询不会对当前位置之后的键-值对进行注意力计算

与编码器的自注意力的情况类似,通过指定输入序列的有效长度,输出序列的查询不会与输入序列中填充位置的词元进行注意力计算。结果如图11.5所示。

d2l.show_heatmaps(
    dec_inter_attention_weights, xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))

图11.5  指定输入序列的有效长度的4头注意模型

尽管Transformer结构是为了序列到序列的学习而提出的,Transformer编码器或Transformer解码器通常被单独用于不同的深度学习任务中。

节选自《Python深度学习原理、算法与案例》。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/106673.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网络协议--DNS:域名系统

14.1 引言 域名系统(DNS)是一种用于TCP/IP应用程序的分布式数据库,它提供主机名字和IP地址之间的转换及有关电子邮件的选路信息。这里提到的分布式是指在Internet上的单个站点不能拥有所有的信息。每个站点(如大学中的系、校园、…

工业4.0的安全挑战与解决方案

在当今数字化时代,工业4.0已经成为制造业的核心趋势。工业4.0的兴起为生产企业带来了前所未有的效率和灵活性,但与之伴随而来的是一系列的安全挑战。本文将深入探讨工业4.0的安全挑战,并提供一些解决方案,以确保制造业的数字化转型…

Typora(morkdown编辑器)的安装包和安装教程

Typora(morkdown编辑器)的安装包和安装教程 下载安装1、覆盖文件2、输入序列号①打开 typora ,点击“输入序列号”:②邮箱一栏中任意填写(但须保证邮箱地址格式正确),输入序列号,点击…

JS中面向对象的程序设计

面向对象(Object-Oriented,OO)的语言有一个标志,那就是它们都有类的概念,而通过类可以创建任意多个具有相同属性和方法的对象。但在ECMAScript 中没有类的概念,因此它的对象也与基于类的语言中的对象有所不…

order by数据过多引起的cpu飙升

测试环境 1.目前数据库类型为pg数据库2.目前数据库业务为共享数据库,为减少其他业务对本次测试的影响,故选在业务空闲时间执行3.服务器性能为8C 32GB 500GB硬盘 原程序测试结果 优化后程序结果 出现原因 当数据量大时,order by排序操作会消耗大量的CPU资源&#…

数据结构之队列

什么是队列? 队列这个概念非常好理解。你可以把它想象成排队买票,先来的先买,后来的人只能站末尾,不允许插队。先进者先出,这就是典型的“队列”。 我们知道,栈只支持两个基本操作:入栈 push()…

Spring Session框架

Spring Session框架 前言 Spring Session是一个用于在分布式环境中管理会话的框架。它提供了一种无状态的方式来管理用户会话,使得应用程序可以在不同的服务器之间共享会话数据。Spring Session的设计目标是为了解决传统基于Servlet容器的会话管理的局限性&#xf…

【ARM 嵌入式 C 入门及渐进 10 -- 冒泡排序 选择排序 插入排序 快速排序 归并排序 堆排序 比较介绍】

文章目录 排序算法小结排序算法C实现 排序算法小结 C语言中常用的排序算法包括冒泡排序、选择排序、插入排序、快速排序、归并排序、堆排序。下面我们来一一介绍: 冒泡排序(Bubble Sort):冒泡排序是通过比较相邻元素的大小进行排…

(el-Table)操作(不使用 ts):Element-plus 中 Table 多选框的样式等的调整

Ⅰ、Element-plus 提供的 Table 表格组件与想要目标情况的对比: 1、Element-plus 提供 Table 组件情况: 其一、Element-ui 自提供的 Table 代码情况为(示例的代码): // Element-plus 自提供的代码: // 此时是使用了 ts 语言环境…

大语言模型(LLM)综述(四):如何适应预训练后的大语言模型

A Survey of Large Language Models 前言5. ADAPTATION OF LLMS5.1 指导调优5.1.1 格式化实例构建5.1.2 指导调优策略5.1.3 指导调优的效果5.1.4 指导调优的实证分析 5.2 对齐调优5.2.1 Alignment的背景和标准5.2.2 收集人类反馈5.2.3 根据人类反馈进行强化学习5.2.4 无需 RLHF…

Leetcode. 2866.美丽塔II

要求O(N)复杂度内解决,考虑单调栈,这个题很像经典的美丽度的那个单调栈的模板题 对有每一个位置,考虑右边能扩展到哪来?不如直接从末尾来倒着看,发现从末尾需要维护一个单调增的单调栈&#xff…

前端实现打印功能Print.js

前端实现打印的方式有很多种,本人恰好经历了几个项目都涉及到了前端打印,目前较为推荐Print.js来实现前端打印 话不多说,直接上教程 官方链接: Print.js官网 在项目中如何下载Print.js 使用npm下载:npm install print-js --sav…

机器视觉3D项目评估的基本要素及测量案例分析

目录 一. 检测需求确认 1、产品名称:【了解是什么产品上的零件,功能是什么】 2、*产品尺寸:【最大兼容尺寸】 3、*测量项目:【确认清楚测量点位】 4、*精度要求:【若客户提出的精度值过大或者过小,可以和客…

【深度学习】使用Pytorch实现的用于时间序列预测的各种深度学习模型类

深度学习模型类 简介按滑动时间窗口切割数据集模型类CNNGRULSTMMLPRNNTCNTransformerSeq2Seq 简介 本文所定义模型类的输入数据的形状shape统一为 [batch_size, time_step,n_features],batch_size为批次大小,time_step为时间步长&#xff0c…

Project Costs

/*** 初始化象棋的棋子,正常情况加载双方所有棋子,残局演示加载剩余棋子,按坐标位置摆放* * 【费用】* 因甲方要求产生工作量计算费用;新增、修改、删除需求* 因乙方生产缺陷工作量不计费用;缺陷、延误* * 来说个一个栗…

电商时代,VR全景如何解决实体店难做没流量?

近日,电商和实体经济的对立成为了热门话题,尽管电商的兴起确实对线下实体店造成了一定的冲击,但实体店也不是没有办法挽救。VR全景助力线下实体店打造线上店铺,打通流量全域布局,还能实现打开产品、查看产品内部细节等…

机器学习实验一:KNN算法,手写数字数据集(使用汉明距离)

KNN-手写数字数据集: 使用sklearn中的KNN算法工具包( KNeighborsClassifier)替换实现分类器的构建,注意使用的是汉明距离; 分段解释代码: import os import pandas as pd from Levenshtein import hamming导入所需的库,包括os用于文件操作,pandas用于数据处理,以及hamm…

P7473 重力球

P7473 重力球 Solution 考虑 Brute Force:对于每一次询问,通过 BFS 处理出最近的交汇点,输出答案。 很显然,会 TLE \colorbox{navy}{\color{white}{TLE}} TLE​。 故,考虑 优化: 观察发现障碍物数量非…

服务器数据恢复-服务器系统损坏启动蓝屏的数据恢复案例

服务器故障&分析: 某公司一台华为机架式服务器,运行过程中突然蓝屏。管理员将服务器进行了重启,但是服务器操作系统仍然进入蓝屏状态。 导致服务器蓝屏的原因非常多,比较常见的有:显卡/内存/cpu或者其他板卡接触不…

高品质工地建筑模板,防水耐用,易脱模

欢迎选购我们的产品:高品质工地建筑模板。作为一家专业厂家,我们提供适用于高层建筑的建筑模板,具有出色的防水耐用性能,且不易开胶。1. 高品质工地建筑模板:我们的建筑模板经过精心设计和制作,以确保其高品…