机器翻译之多头注意力(MultiAttentionn)在Seq2Seq的应用

目录

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

2.2验证上述封装的代码 

2.3 创建 添加了Bahdanau的decoder 

 2.4训练

 2.5预测

3.知识点个人理解 


 

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

class MultiHeadAttention(nn.Module):
    #初始化属性和方法
    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        """
        query_size_size: query_size的特征数features
        key_size: key_size的特征数features
        value_size: value_size的特征数features
        num_hiddens:隐藏层的神经元的数量
        num_heads:多头注意力的header的数量
        dropout: 释放模型需要计算的参数的比例
        bias=False:没有偏差
        **kwargs : 不定长度的关键字参数
        """
        super().__init__(**kwargs)
        #接收参数
        self.num_heads = num_heads
        #初始化注意力,    #使用DotProductAttention时, keys与 values具有相同的长度, 经过decoder,他们长度相同
        self.attention = dltools.DotProductAttention(dropout)
        #初始化四个w模型参数
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        def transpose_qkv(X, num_heads):
            """实现queries, keys, values的数据维度转化"""
            #输入的X的shape=(batch_size, 查询数/键值对数量, num_hiddens)
            #这里,不能直接用reshape,需要索引维度,防止数据不能一一对应
            X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)   #将原维度的num_hiddens拆分成num_heads, -1,  -1相当于num_hiddens/num_heads的数值
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, num_size, 查询数/键值对数量, num_hiddens/num_heads)
            return X.reshape(-1, X.shape[2], X.shape[3])  #X的shape=(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)

        def transpose_outputs(X, num_heads):
            """逆转transpose_qkv的操作"""
            #此时数据的X的shape =(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  #X的shape=(batch_size, num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, 查询数/键值对数量, num_heads,  num_hiddens/num_heads)
            return X.reshape(X.shape[0], X.shape[1], -1)  #X的shape还原了=(batch_size, 查询数/键值对数, num_hiddens)

        #queries, keys, values,传入的shape=(batch_size, 查询数/键值对数, num_hiddens)
        #获取转换维度之后的queries, keys, values,
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        #若valid_len不为空,存在
        if valid_lens is not None:
            #将valid_lens重复数据self.num_heads次,在0维度上
            valid_lens = torch.repeat_interleave(valid_lens, repeats = self.num_heads, dim=0)
        #若为空,什么都不做,跳出if判断,继续执行其他代码

        #通过注意力函数获取输出outputs
        #outputs的shape = (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
        outputs = self.attention(queries, keys, values, valid_lens)

        #逆转outputs的维度
        outputs_concat = transpose_outputs(outputs, self.num_heads)

        return self.W_o(outputs_concat)

2.2验证上述封装的代码 

#假设变量
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
attention.eval()  #需要预测,加上
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
#假设变量
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])

X = torch.ones((batch_size, num_queries, num_hiddens))  #shape(2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  #shape(2,6,100) 

attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

2.3 创建 添加了Bahdanau的decoder 

# 添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens

            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)

            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)

            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)

            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            

        outputs = self.dense(torch.cat(outputs, dim=0))

        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

 2.4训练

# 训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)

decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)

net = dltools.EncoderDecoder(encoder, decoder)

dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 2.5预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/883952.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

protobuf编码方式

protobuf编码方式 一个简单的例子 message Test1 {optional int32 a 1; }上述的proto文件,设置a 150,那么将其序列化后,得到的数据就是08 96 01,然后你使用protoscope工具去解析这些数据,就得到1 : 150&#xff0c…

基于深度学习的花卉智能分类识别系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 传统的花卉分类方法通常依赖于专家的知识和经验,这种方法不仅耗时耗力,而且容易受到主观因素的影响。本系统利用 TensorFlow、Keras 等深度学习框架构建卷积神经网络&#…

【Linux:共享内存】

共享内存的概念: 操作系统通过页表将共享内存的起始虚拟地址映射到当前进程的地址空间中共享内存是由需要通信的双方进程之一来创建但该资源并不属于创建它的进程,而属于操作系统 共享内存可以在系统中存在多份,供不同个数,不同进…

推荐5款压箱底的宝贝,某度搜索就有

​ 今天要给大家推荐5款压箱底的宝贝软件了,都是在某度搜索一下就能找到的好东西。 1.桌面壁纸——WinDynamicDesktop ​ WinDynamicDesktop是一款创新的桌面壁纸管理工具,能根据用户的地理位置和时间自动更换壁纸。软件内置多个美丽的动态壁纸主题&am…

苹果电脑系统重磅更新——macOS Sequoia 15 系统 新功能一 览

有了 macoS Sequoia,你的工作效率将再次提升:快速调整桌面布局,一目了然地浏览网页重点,还可以通过无线镜像功能操控你的iPhone。 下面就来看看几项出色新功能,还有能够全面发挥这些功能的 App 和游戏。 macOS Sequo…

智能新突破:AIOT 边缘计算网关让老旧水电表图像识别

数字化高速发展的时代,AIOT(人工智能物联网)技术正以惊人的速度改变着我们的生活和工作方式。而其中,AIOT 边缘计算网关凭借其强大的功能,成为了推动物联网发展的关键力量。 这款边缘计算网关拥有令人瞩目的 1T POS 算…

使用build_chain.sh离线搭建匹配的区块链,并通过命令配置各群组节点的MySQL数据库

【任务】 登陆Linux服务器,以MySQL分布式存储方式安装并部署如图所示的三群组、四机构、 七节点的星形组网拓扑区块链系统。其中,三群组名称分别为group1、group2和group3, 四个机构名称为agencyA、agencyB、agencyC、agencyD。p2p_port、cha…

powerbi计算销售额累计同比增长率——dax

目录 效果展示: 一、建立日期表 二、建立度量值 1.销售收入 2.本年累计销售额 3.去年累计销售额 4.累计同比增长率 三、矩阵表制作 效果展示: 数据包含2017-2019年的销售收入数据 一、建立日期表 日期表建立原因及步骤见上一篇文章https://blog…

数据处理与统计分析篇-day11-RFM模型案例

会员价值度模型介绍 会员价值度用来评估用户的价值情况,是区分会员价值的重要模型和参考依据,也是衡量不同营销效果的关键指标之一。 价值度模型一般基于交易行为产生,衡量的是有实体转化价值的行为。常用的价值度模型是RFM RFM模型是根据…

UNI-SOP应用场景(1)- 纯前端预开发

在平时新项目开发中,前端小伙伴是否有这样的经历,hi,后端小伙伴们,系统啥时候能登录,啥时候能联调了,这是时候往往得到的回答就是,再等等,我们正在搭建系统呢,似曾相识的…

20个数字经济创新发展试验区建设案例【2024年发布】

数据简介:国家数字经济创新发展试验区的建设是一项重要的国家战略,旨在推动数字经济与实体经济的深度融合,促进经济高质量发展。自2019年10月启动以来,包括河北省(雄安新区)、浙江省、福建省、广东省、重庆…

通过OpenScada在ARMxy边缘计算网关上实现数字化转型

随着工业4.0概念的普及,数字化转型已成为制造业升级的关键路径之一。在此背景下,边缘计算技术因其能够有效处理大量数据、减少延迟并提高系统响应速度而受到广泛关注。ARMxy边缘计算网关,特别是BL340系列,凭借其强大的性能和灵活的…

Linux网络之UDP与TCP协议详解

文章目录 UDP协议UDP协议数据报报头 TCP协议确认应答缓冲区 超时重传三次握手其他问题 四次挥手滑动窗口流量控制拥塞控制 UDP协议 前面我们只是说了UDP协议的用法,但是并没有涉及到UDP协议的原理 毕竟知道冰箱的用法和知道冰箱的原理是两个层级的事情 我们首先知道计算机网…

使用API有效率地管理Dynadot域名,设置域名服务器(NS)

前言 Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮箱&…

在虚幻引擎中实现Camera Shake 相机抖动/震屏效果

在虚幻引擎游戏中创建相机抖动有时能让画面更加高级 , 比如 遇到大型的Boss , 出现一些炫酷的特效 加一些短而快的 Camera Shake 能达到很好的效果 , 为玩家提供沉浸感 创建Camera Shake 调整Shake参数 到第三人称或第一人称蓝图 调用Camera Shake Radius值越大 晃动越强

拍卖的价格怎么定?聊聊转转拍卖场的起拍定价算法演变

价格策略、定价调价算法是诸多中大规模电商不可或缺的一项能力,涉及到精准定价、智能调价、智能发券、成本控制等一系列智能运营场景,尤其对于二手行业来说,定价能力更是面临诸多挑战,却又不可或缺。本文将旨在介绍转转 TOB 拍卖场…

kibana开启访问登录认证

编辑es配置文件,添加以下内容开启es认证 vim /etc/elasticsearch/elasticsearch.yml http.cors.enabled: true http.cors.allow-origin: "*" http.cors.allow-headers: Authorization xpack.security.enabled: true xpack.security.transport.ssl.enable…

解释器模式原理剖析和Spring中的应用

解释器模式原理剖析和Spring中的应用 解释器模式 是一种行为型设计模式,它定义了一种语言的文法表示,并提供了一个解释器来处理该文法的表达式。解释器模式可以用于构建语法解释器,例如计算器、简单编程语言的解释器等。 核心思想&#xff1a…

Java框架学习(mybatis)(01)

简介:以本片记录在尚硅谷学习ssm-mybatis时遇到的小知识 详情移步:想参考的朋友建议全部打开相互配合学习! 官方文档: MyBatis中文网https://mybatis.net.cn/index.html 学习视频: 067-mybatis-介绍和对比_哔哩哔…

人工智能时代,程序员如何保持核心竞争力?

引言 随着AIGC(如ChatGPT、Midjourney、Claude等)大语言模型接二连三的涌现,AI辅助编程工具日益普及,程序员的工作方式正在发生深刻变革。有人担心AI可能取代部分编程工作,也有人认为AI是提高效率的得力助手。面对这一…