python-pytorch 下批量seq2seq+Bahdanau Attention实现问答1.0.000

python-pytorch 下批量seq2seq+Bahdanau Attention实现简单问答1.0.000

    • 前言
    • 原理看图
    • 数据准备
    • 分词、index2word、word2index、vocab_size
    • 输入模型的数据构造
    • 注意力模型
    • decoder的编写
    • 关于损失函数和优化器
    • 在预测时
    • 完整代码
    • 参考

前言

前面实现了 luong的dot 、general、concat注意力实现简单问答,这里参考官方文档,实现了python-pytorch 下批量seq2seq+Bahdanau Attention实现问答

原理看图

在这里插入图片描述
这里模型选择和官方不一样,官方选择的是GRU,我更喜欢使用LSTM,解码器和编码器都是如此。
意思大致思路是:

  1. 计算encoder的encoder_outputs、encoder_hn、encoder_cn
  2. 使用encoder_outputs、encoder_hn计算新的向量和注意力
  3. 在deconder中,以SOS单字开始,循环句子最大长度,在循环中,使用新的向量和单字SOS做cat计算得到decoder的LSTM输入数据,将该LSTM存起来,最后做cat计算得到decoder的输出

数据准备

结果类似还是采用前面的结构和数据

seq_example = [“你认识我吗”, “你住在哪里”, “你知道我的名字吗”, “你是谁”, “你会唱歌吗”, “谁是张学友”]
seq_answer = [“当然认识”, “我住在成都”, “我不知道”, “我是机器人”, “我不会”, “她旁边那个就是”]

分词、index2word、word2index、vocab_size

分词然后做基础准备,包括数据:index2word、word2index、vocab_size、最长的句子长度seq_length,和一些超参数的设置

输入模型的数据构造

  1. 长度要统一
  2. 问答的句子以EOS结尾,不足补0,如

tensor([[ 3, 4, 5, 6, 2, 0, 0],
[ 3, 7, 8, 9, 2, 0, 0],
[ 3, 10, 5, 11, 12, 6, 2],
[ 3, 13, 14, 2, 0, 0, 0],
[ 3, 15, 16, 6, 2, 0, 0],
[14, 13, 17, 2, 0, 0, 0]])

注意力模型

可以复用,用官方的即可

# Bahdanau
# query=hidden [layer_num,batch_size,hidden_size] keys=encoder_outputs  [seq_len,batch_size,hidden_size]
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) #[seq_len,batch_size,1]
        scores = scores.permute(1,0,2).squeeze(2).unsqueeze(1)#[batch_size,1,seq_len]

        weights = nn.functional.softmax(scores, dim=-1)#[batch_size,1,seq_len]
        context = torch.bmm(weights, keys.permute(1,0,2))#[batch_size,1,hidden_size]

        return context, weights

decoder的编写

思路是,获得encoder的输出和hn后,计算得到向量,然后使用向量和目标的每一个字做cat计算,输入decoder的模型中,然后得出一个字的预测,循环完了以后,就会得到最大句子长度,最后做cat和softmax计算得到输出。另外,这里要区分训练和测试,训练的时候有target,测试的没有target数据。

关于损失函数和优化器

NLLLoss+Adam的组合优于CrossEntropyLoss+SGD的组合

在预测时

获取到模型输出,size是[batch_size,seq_len,vocab_size]后,对结果做topk计算,会得到每一字在vocab_size的概率,连接起来就是一句话

完整代码

# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask

# seq_answer,seq_example=getAQ()



import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm
 
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/642717.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CCF20230901——坐标变换(其一)

CCF20230901——坐标变换&#xff08;其一&#xff09; #include<bits/stdc.h> using namespace std; int main() {int n,m,x[101],y[101],x1[101],y1[101];cin>>n>>m;for(int i0;i<n;i)cin>>x1[i]>>y1[i];for(int j0;j<m;j)cin>>x[…

【活动】开源与闭源大模型:探索未来趋势的双轨道路

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 开源与闭源大模型&#xff1a;探索未来趋势的双轨道路引言一、开源大模型&#…

数据库缓存 buffer pool详解

什么是buffer pool buffer pool, 又称之缓存池, 是mysql中为了提升查询性能而引入的缓存, 如果每次查询和修改都去操作磁盘的话, 性能就会很差, 从而引入 Buffer Pool包含多个缓冲页&#xff08;默认大小通常为16KB&#xff09;&#xff0c;每个缓冲页都有对应的控制信息&#…

【TB作品】stm32单片机读取DS2401程序

DS2401是由Analog Devices公司生产的一种硅序列号芯片&#xff0c;它提供了一个绝对唯一的64位ROM识别码&#xff0c;用于确保可追溯性。以下是对DS2401器件的分析&#xff1a; 特点和优势&#xff1a; 唯一性&#xff1a;每个DS2401芯片都有一个独一无二的64位注册码&#x…

Windows安装VMware(Broadcom)

1.安装前提 1.检查BIOS中是否开启了虚拟化技术。1.1 打开任务管理器&#xff0c;查看性能&#xff0c;CPU部分&#xff0c;虚拟化处于“已启用”状态。1.2 如果没有开启&#xff0c;则需要进入BIOS系统&#xff0c;将 Intel Virtualization Technology改为Enalble。2.下载VMwa…

海外CDN加速方式

随着全球化经济的进一步推进和互联网时代的到来&#xff0c;给对外贸易行业带来了巨大的商机&#xff0c;众多传统的贸易公司都纷纷建立起自已的外贸网站或服务站点等各种信息化平台&#xff0c; 相当多的贸易公司也从他们所构建的平台中得到了很高的利益&#xff0c;然而由于当…

推荐个免费天气接口

http://www.tianqiapi.com/index/doc?versionmonthhttp://www.tianqiapi.com/index/doc?versionmonth 个人博客使用足够了&#xff01;

浅析FAT32文件系统

本文通过实验测试了FAT文件系统的存储规律&#xff0c;并且探究了部分可能的文件隐藏方法。 实验背景 现有一块硬盘&#xff08;U盘&#xff09;&#xff0c;其中存在两个分区&#xff0c;分别为FAT32和NTFS文件系统分区。 在FAT分区中存在如下文件&#xff1a; 现需要阅读底…

Android Compose 九:常用组件列表 简单使用

遇事不决 先看官方文档 列表和网格 如果不需要任何滚动&#xff0c;通过Column 或 Row可以使用verticalScroll() 使Column滚动 Column(modifier Modifier.verticalScroll(rememberScrollState())) {for (i in 0..50){Text(text "条目>>${i}")}}显示大量列表…

移动硬盘难题:不显示容量与无法访问的解决策略

在使用移动硬盘的过程中&#xff0c;有时会遇到一些棘手的问题&#xff0c;比如移动硬盘不显示容量且无法访问。这种情况让人十分头疼&#xff0c;因为它不仅影响了数据的正常使用&#xff0c;还可能导致重要数据的丢失。接下来&#xff0c;我们就来详细探讨一下这个问题及其解…

弹性盒子布局,flex布局

弹性盒子布局&#xff08;Flexbox&#xff09;是CSS3引入的一种新的布局模式&#xff0c;它提供了一种更加有效的方式来设计、布局和对齐容器中的项目&#xff0c;即使容器的大小动态改变或者项目的数量未知。 弹性盒子布局的主要特点是能够轻松地在不同的屏幕大小和设备上实现…

虚拟局域网(VLAN)

关键词&#xff1a;veth、vlan、bridge、iptables、nat、tcpdump、icmp、cidr、arp、路由表、计算机网络协议栈 前言 在过去的几十年里&#xff0c;互联网发展得非常快。许多新兴技术迅速崛起&#xff0c;也有不少曾经的主流技术被淘汰。然而&#xff0c;有些技术因为其基础性…

iPhone实况照片从Windows资源管理器复制的JPG+MOV无法正常还原到iPhone

背景&#xff1a; 之前使用的iPhone 15 Pro&#xff0c;使用的Windows资源管理器当中复制导出的实况照片&#xff0c;复制出来的格式例如IMG_0001.JPG, IMG_0001.MOV。之后手机就卖掉了。现在使用的iPhone 14 Pro Max&#xff0c;想要导回之前备份的实况照片。尝试使用爱思助手…

解决vue版本不一致导致不能正常编译

解决vue版本不一致导致不能正常编译 异常现象分析原因解决方案 异常现象 项目原本运行无异常&#xff0c;但安装了一个el-table-infinite-scroll的插件后&#xff0c;编译报错&#xff0c;截图如下 分析原因 vue版本与compile版本不一致&#xff0c;应该统一起来&#xff0…

网创教程:WordPress插件网创自动采集并发布

网创教程&#xff1a;WordPress插件网创自动采集并发布 使用插件注意事项&#xff1a; 如果遇到404错误&#xff0c;请先检查并调整网站的伪静态设置&#xff0c;这是最常见的问题。需要定制化服务&#xff0c;请随时联系我。 本次更新内容 我们进行了多项更新和优化&#x…

1.存储部分

1.Flash Memory--闪速存储器&#xff08;注&#xff1a;U盘&#xff0c;SD卡就是闪存&#xff09;在EEPROM基础上发展而来的&#xff0c;断电后也能保存信息&#xff0c;且可进行多次 快速擦除重写。注意&#xff1a;由于闪存需要先擦除再写入&#xff0c;因此闪存写的速度要比…

ssm141餐厅点菜管理系统+vue

餐厅点菜管理系统的设计与实现 摘 要 网络技术和计算机技术发展至今&#xff0c;已经拥有了深厚的理论基础&#xff0c;并在现实中进行了充分运用&#xff0c;尤其是基于计算机运行的软件更是受到各界的关注。加上现在人们已经步入信息时代&#xff0c;所以对于信息的宣传和管…

Unity-Sprite Atlas+UGUI系统的运行原理

每日一句&#xff1a;别听世俗耳语&#xff0c;看自己的风景就好 目录 SA的原理&#xff1a; SA的优点&#xff1a; SA的缺点&#xff1a; DrawCall是什么&#xff1f; 批处理是什么&#xff1f; 我们先了解一下UGUI系统的运行原理吧&#xff01; 提到图集优化&#xff0…

性能测试--线程的监控

1.线程的状态 1.1.线程的5种状态 java的线程总共有5种状态&#xff0c;如下&#xff1a; * 新建&#xff1a;new 【新建之后不启用都是new】* 运行&#xff1a;runnable* 等待&#xff1a;waitting(无限期等待),timed waitting(限期等待)* 阻塞&#xff1a;blocked* 结束&am…

C++ 求 f(x) = sin x / x 的函数极限

Python&#xff1a; import sympy from sympy import oo #注意无究符号表示形式为两个小定字母o import numpy as np x sympy.Symbol(x) #注意Symbol首字母大写f sympy.sin(x) / x # 定义极限表达式lim sympy.limit(f,x,oo) print(lim)输出&#xff1a; 0 C #include…