李沐66_使用注意力机制的seq2seq——自学笔记

加入注意力

1.编码器对每次词的输出作为key和value

2.解码器RNN对上一个词的输出是query

3.注意力的输出和下一个词的词嵌入合并进入RNN

一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

总结

1.seq2seq通过隐状态在编码器和解码器中传递信息

2.注意力机制可以根据解码器RNN的输出来匹配到合适的编码器RNN的输出来更有效的传递信息。

pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab
import torch
from torch import nn
from d2l import torch as d2l

注意力解码器

AttentionDecoder类定义了带有注意力机制解码器的基本接口

class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError

Seq2SeqAttentionDecoder类中 实现带有Bahdanau注意力的循环神经网络解码器。

1.编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

2.上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

3.编码器有效长度(排除在注意力池中填充词元)。

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(num_hiddens,num_hiddens,num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,
        # num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # 输出X的形状为(num_steps,batch_size,embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # query的形状为(batch_size,1,num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # context的形状为(batch_size,1,num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # 在特征维度上连结
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # 将x变形为(1,batch_size,embed_size+num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # 全连接层变换后,outputs的形状为
        # (num_steps,batch_size,vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights

使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2
                                  )
decoder.eval()
X = d2l.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.020, 7390.3 tokens/sec on cuda:0

在这里插入图片描述

模型训练后,我们用它将几个英语句子翻译成法语并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => il est mouillé .,  bleu 0.658
i'm home . => je suis chez moi .,  bleu 1.000
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((
    1, 1, -1, num_steps))

训练结束后,下面通过可视化注意力权重 会发现,每个查询都会在键值对上分配不同的权重,这说明 在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/574525.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

.NET 个人博客-添加RSS订阅功能

个人博客-添加RSS订阅功能 前言 个人博客系列已经完成了 留言板文章归档推荐文章优化推荐文章排序 博客地址 然后博客开源的原作者也是百忙之中添加了一个名为RSS订阅的功能,那么我就来简述一下这个功能是干嘛的,然后照葫芦画瓢实现一下。 RSS简述…

php代码比对工具优化版

下载地址:php代码比对工具优化版.zip 一款强大且专业的文件对比工具(php代码比对),用户可以直接在线进行两个或多个文件的差异对比,支持用户进行多种格式的问价对比,用户可以在这里轻松查找出相同会不同之处,支持用户…

elementui el-date-picker禁止选择今年、今天、之前、时间范围限制18个月

1、禁止选择今年之前的所有年份 <el-date-pickerv-if"tabsActive 0":clearable"false"v-model"yearValue"change"yearTimeChange"type"year"placeholder"选择年"value-format"yyyy":picker-options…

文化旅游3D数字孪生可视化管理平台推动文旅产业迈向更加美好的未来

随着数字化、智能化管理成为文旅产业发展的必然趋势&#xff0c;数字孪生公司深圳华锐视点创新性地推出了景区三维可视化数字孪生平台&#xff0c;将线下的实体景区与线上的虚拟世界完美融合&#xff0c;引领智慧文旅新潮流。 我们运用先进的数字孪生、web3D开发和三维可视化等…

怎么设置 idea terminal 窗口的编码格式

1 修改Terminal 窗口为 Git bash 窗口 打开 settings 设置界面&#xff0c;选择 Tools 中的 Terminal (File -> settings -> Tools -> Terminal) 修改 Shell path 为你的 Git bash 安装路径&#xff0c;我的在 C:\my_software\java\Git\bin\bash.exe 2 解决中文显示…

高端制造企业生产设备文件管理,怎样保证好用不丢失文件?

高端制造业在市场经济中占据重要角色&#xff0c;在高端制造业企业内部&#xff0c;生产设备又是最关键的一环环&#xff0c;它们不仅负责完成生产任务&#xff0c;同时也会产生大量的文件。这些数据反映了设备的运行状态、生产效率、能源消耗以及产品质量等多个方面&#xff0…

网站内容下载软件有哪些 网站内容下载软件推荐 网站内容下载软件安全吗 idm是啥软件 idm网络下载免费

一招搞定网页内容下载&#xff0c;并且各大网站通用&#xff01;绕过资源审查&#xff0c;所有网站内容随意下载。解锁速度限制&#xff0c;下载即高速无视网站限速。跳过会员充值&#xff0c;所有VIP资源免费下载。有关网站内容下载软件有哪些&#xff0c;网站内容下载软件推荐…

汽车信息安全--如何理解TrustZone(1)

目录 1.车规MCU少见TrustZone 2. 什么是TrustZone 2.1 TrustZone隔离了什么&#xff1f; 2.2 处理器寄存器和异常处理 3.小结 1.车规MCU少见TrustZone 在车规MCU里&#xff0c;谈到信息安全大家想到的大多可能都是御三家的HSM方案&#xff1a;英飞凌的HSM\SHE、瑞萨的ICU…

【【gitlab解决git Clone 出现 Permission denied, please try again.】】

【gitlab解决git Clone 出现 Permission denied, please try again.】 问题解决随便找一个地方 点击右键输入ssh -keygen -C "邮件"显示结果输入 登录gitlab然后再次git Clone就可以了。 问题 git clone的时候出现 Permission denied, please try again 解决 随便…

跨语言指令调优深度探索

目录 I. 介绍II. 方法与数据III. 结果与讨论1. 跨语言迁移能力2. 问题的识别3. 提高跨语言表现的可能方向 IV. 结论V. 参考文献 I. 介绍 在大型语言模型的领域&#xff0c;英文数据由于其广泛的可用性和普遍性&#xff0c;经常被用作训练模型的主要语料。尽管这些模型可能在英…

jar依赖批量上传Nexus服务器(二)

jar依赖批量上传Nexus服务器&#xff08;二&#xff09; 批量上传脚本 #!/bin/bash # copy and run this script to the root of the repository directory containing files # this script attempts to exclude uploading itself explicitly so the script name is important…

华为数字化转型与数据管理实践介绍(附PPT下载)

华为作为全球领先的信息与通信技术&#xff08;ICT&#xff09;解决方案提供商&#xff0c;在数字化转型和数据管理领域拥有丰富的实践经验和技术积累。其数字化转型解决方案旨在帮助企业通过采用最新的ICT技术&#xff0c;实现业务流程、组织结构和文化的全面数字化&#xff0…

电子工艺卡在汽车制造流程中的应用

在当今高度发达的汽车工业中&#xff0c;电子工艺卡作为一种重要的工具&#xff0c;在汽车制造流程中发挥着至关重要的作用。它不仅是汽车生产的指导手册&#xff0c;更是确保汽车质量和性能的关键因素。 汽车制造是一个复杂而精密的过程&#xff0c;涉及众多的零部件和系统。电…

云LIS系统概述JavaScript+前端框架JQuery+EasyUI+Bootstrap医院云HIS系统源码 开箱即用

云LIS系统概述JavaScript前端框架JQueryEasyUIBootstrap医院云HIS系统源码 开箱即用 云LIS&#xff08;云实验室信息管理系统&#xff09;是一种结合了计算机网络化信息系统的技术&#xff0c;它无缝嵌入到云HIS&#xff08;医院信息系统&#xff09;中&#xff0c;用于连…

粘合/胶合/粘接/聚酰亚胺PI材料使用UV胶,具有高强度粘接的优势,这一点具体要如何操作?(三十五)

前面文章说明使用UV胶粘合聚酰亚胺PI材料时&#xff0c;有一点优势是&#xff1a;具有高强度粘接&#xff0c;UV胶粘剂对聚酰亚胺PI材料具有良好的附着性&#xff0c;能够提供高强度的粘接。这对于需要承受重负载或高应力的应用来说尤为重要。 这一点提到UV胶在粘合聚酰亚胺&am…

03-JAVA设计模式-状态模式

状态模式 什么是状态模式 Java中的状态模式&#xff08;State Pattern&#xff09;是一种行为型设计模式&#xff0c;主要用于解决系统中复杂对象的状态转换以及不同状态下行为的封装问题。状态模式允许一个对象在其内部状态改变时改变它的行为&#xff0c;使得对象看起来似乎…

Stable Diffusion WebUI 使用 VAE 增加滤镜效果

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 大家好&#xff0c;我是水滴~~ 本文主要介绍 VAE 模型&#xff0c;主要内容有&#xff1a;VAE 模型的概念、如果下载 VAE 模型、如何安装 VAE 模型、如…

STL——List常用接口模拟实现及其使用

认识list list的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点中通过指针指向其前一个元素和后一个元素…

FlyFlow:全新开源版问世,支持SpringBoot3+Flowable7

经过精心打磨和严格测试&#xff0c;我们隆重推出全新FlyFlow开源版&#xff0c;这款源自商业版的强大工具&#xff0c;如今已完美融入SpringBoot3和Flowable7两大核心框架&#xff0c;为开发者带来前所未有的便捷与高效。 SpringBoot3的加持&#xff0c;让FlyFlow在简化开发流…

【计算机网络】成功解决 ARP项添加失败:请求的操作需要提升

最近在用Wireshark做实验时候&#xff0c;需要清空本机ARP表和DNS缓存&#xff0c;所以在cmd窗口输入以下命令&#xff0c; 结果发生了错误&#xff1a;ARP项添加失败&#xff1a;请求的操作需要提升 一开始我还以为是操作的命令升级了&#xff0c;但是后面发现其实只是给的权…