Transformer应用之构建聊天机器人(二)

四、模型训练解析

在PyTorch提供的“Chatbot Tutorial”中,关于训练提到了2个小技巧:

  • 使用”teacher forcing”模式,通过设置参数“teacher_forcing_ratio”来决定是否需要使用当前标签词汇来作为decoder的下一个输入,而不是把decoder当前预测出来的词汇当做decoder的下一个输入,这是因为存在这样的情况,如果当前预测出来的词汇跟输入词汇从语义上来讲没有多大关联时,如果继续使用预测出来的词汇来训练模型,有可能就会造成比较大的预测偏差,从而导致模型训练后的预测效果很差,如果改为直接使用输入词汇对应的目标词汇(标签)来作为decoder的下一个输入,相当于进行强制纠偏,使decoder训练时输出与输入之间不至于出现偏差很大的情况。
  • 第2个小技巧是使用梯度裁剪(Gradient Clipping),这是一种常用的防止梯度爆炸的技术。在深度学习训练过程中,因为网络层数较多,梯度可能会非常大,导致模型无法收敛。梯度裁剪的目的就是限制梯度的大小,使其不超过一个预设的阈值,从而避免梯度爆炸的问题。

训练过程如下:

  1. 输入语句正向传播通过encoder
  2. 使用SOS token作为decoder的初始输入,使用encoder的final hidden state来初始化decoder的hidden state
  3. Decoder端根据输入单步执行产生输出
  4. 如果执行”teacher forcing”模式,则把当前对应的目标词汇(标签)作为decoder的下一个输入,否则使用当前decoder的输出词汇作为decoder的下一个输入
  5. 计算并累加损失
  6. 执行反向传播
  7. 执行梯度裁剪
  8. 更新decoder和encoder的模型参数

以下是代码示例:

以下是Transformer模型训练代码示例,

  • 首先把输入sequence(对话输入),输出sequence(对话输出),以及各自的mask传入模型做正向传播
  • 计算预测结果与标签的损失,然后反向传播更新模型参数
  • 训练时可以使用验证集(dev dataset)对训练效果进行评估

五、模型预测(推理)过程解析

下面这个图描述了Transformer的预测推理过程:

  • 假设使用两个encoder和两个decoder来构成这个Transformer模型,首先把输入语句转为embedding词向量,并加入位置编码信息
  • 正向传播通过encoder1,它的输出再通过encoder2,期间会使用多头注意力机制对输入序列中的每个词向量并行地进行注意力Q,K,V的计算
  • Decoder1使用<START> token进行初始化,并使用带掩码多头注意力机制进行计算,并且需要根据前面encoder2的输出进行注意力的计算,然后输出预测得到的词汇
  • Decoder1输出的词汇作为decoder2的输入,同样decoder2在进行多头注意力计算时也需要使用encoder2的注意力计算输出结果
  • Decoder2的输出传入线性层,之后使用Softmax函数转为0到1之间的概率,然后可以使用greedy search(贪心解码)算法得到概率最高的词汇作为预测结果

下面是预测相关代码的示例:

再来看下PyTorch提供的聊天机器人样例的预测操作:

  • 用户输入正向传播通过encoder模型
  • 把encoder的final hidden layer作为decoder模型的first hidden input
  • 使用SOS_token作为decoder的第一个输入来初始化模型
  • decoder根据encoder的输出(上篇文章提到的“Luong attention”注意力机制计算),以及当前decoder的输入,hidden state来输出预测得到的词汇(迭代操作)
  • 使用Softmax计算概率并根据概率获取最有可能出现的词汇
  • 把当前预测得到的词汇作为下一个decoder的输入
  • 收集所有预测得到的词汇

以下是预测相关代码的示例:

六、聊天机器人对话效果解析

基于Transformer的聊天机器人和PyTorch提供的聊天机器人都使用同样的训练语料(“Cornell Movie-Dialogs Corpus.”)进行训练,基于Transformer的聊天机器人模型训练了20个epochs,输入语句最大长度设置为60,PyTorch提供的聊天机器人训练配置如下:

clip = 50.0

teacher_forcing_ratio = 1.0

learning_rate = 0.0001

decoder_learning_ratio = 5.0

n_iteration = 4000

print_every = 1

save_every = 500

使用同样的测试对话语料分别对两个模型进行测试,基于Transformer模型的对话测试结果如下:

PyTorch提供的聊天机器人对话测试结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/23117.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

< ElementUi组件库: el-progress 进度条Bug及样式调整 >

ElementUi组件库&#xff1a; el-progress 进度条Bug及样式调整 &#x1f449; 前言&#x1f449; 一、实现原理> 修改 el-progress 进度条样式 及 渐变进度条样式 &#x1f449; 二、案例代码&#xff08;前言效果图案例&#xff09;> HTML代码> CSS代码 &#x1f44…

C++学习day--12 循环的应用,暴力破解密码和输出动图

第 1 节 职场修炼&#xff1a;程序员到底能干多久 现状&#xff1a; 很多程序员&#xff0c;过了 30 岁&#xff0c;纷纷转行。 原因&#xff1a; 1 &#xff09;薪资过万后&#xff0c;很难进一步提升 2 &#xff09;可替代性高&#xff0c;在新人面前&#xff0c;没有…

SolVES模型在生态系统服务社会价值评估中的运用

SolVES模型&#xff08;Social Values for Ecosystem Services&#xff09;全称为生态系统服务社会价值模型&#xff0c;是由美国地质勘探局和美国科罗拉多州立大学联合开发的一款地理信息系统应用程序&#xff0c;开发该模型的目的主要是对生态系统服务功能中的社会价值进行空…

全面了解Java连接MySQL的基础知识,快速实现数据交互

全面了解Java连接MySQL的基础知识&#xff0c;快速实现数据交互 1. 数据库的重要性2. MySQL数据库简介2.1 MySQL数据库的基本概念2.2 MySQL的基本组成部分包括服务器、客户端和存储引擎。2.3 安装MySQL数据库2.3.1安装MySQL数据库2.3.2 下载MySQL安装程序2.3.3 运行MySQL安装程…

帽子设计作品——蒸汽朋克的乌托邦,机械配件的幻想世界!

蒸汽朋克是由蒸汽steam和朋克punk两个词组成&#xff0c; 蒸汽代表着以蒸汽机作为动力的大型机械&#xff0c;而朋克则代表一种反抗、叛逆的精神。 蒸汽朋克的作品通常以蒸汽时代为背景&#xff0c;通过如新能源、新机械、新材料、新交通工具等新技术&#xff0c;使画面充满想…

理解深度可分离卷积

1、常规卷积 常规卷积中&#xff0c;连接的上一层一般具有多个通道&#xff08;这里假设为n个通道&#xff09;&#xff0c;因此在做卷积时&#xff0c;一个滤波器&#xff08;filter&#xff09;必须具有n个卷积核&#xff08;kernel&#xff09;来与之对应。一个滤波器完成一…

PMP课堂模拟题目及解析(第13期)

121. 项目经理、团队成员以及若干干系人共同参与一次风险研讨会。已经根据风险管理计划生成并提供一份风险报告。若要为各个项目风险进行优先级排序&#xff0c;现在必须执行哪一项分析&#xff1f; A. 定量风险分析 B. 根本原因分析 C. 偏差分析 D. 定性风险分析 122. …

带你手撕链式二叉树—【C语言】

前言&#xff1a; 普通二叉树的增删查改没有意义&#xff1f;那我们为什么要先学习普通二叉树呢&#xff1f; 给出以下两点理由&#xff1a; 1.为后面学习更加复杂的二叉树打基础。&#xff08;搜索二叉树、ALV树、红黑树、B树系列—多叉平衡搜索树&#xff09; 2.有很多二叉树…

Linux安装MongoDB数据库并内网穿透在外远程访问

文章目录 前言1.配置Mongodb源2.安装MongoDB数据库3.局域网连接测试4.安装cpolar内网穿透5.配置公网访问地址6.公网远程连接7.固定连接公网地址8.使用固定公网地址连接 转发自CSDN cpolarlisa的文章&#xff1a;Linux服务器安装部署MongoDB数据库 - 无公网IP远程连接「内网穿透…

亚马逊开放个人卖家验证入口?亚马逊卖家验证到底怎么搞?

亚马逊卖家账户的安全对于所有卖家来说都非常重要。如果卖家想要在亚马逊上长期稳定地发展&#xff0c;赚取更多的钱并推出更多热卖产品&#xff0c;就必须确保他们的亚马逊卖家账户安全&#xff0c;特别是一直存在的亚马逊账户验证问题。 近期&#xff0c;根据亚马逊官方披露的…

开发敏捷高效 | 云原生应用开发与运维新范式

5 月 18 日&#xff0c;腾讯云举办了 Techo Day 腾讯技术开放日&#xff0c;以「开箱吧&#xff01;腾讯云」为栏目&#xff0c;对外发布和升级了腾讯自研的一系列云原生产品和工具。其中&#xff0c;腾讯云开发者产品中心总经理刘毅围绕“开发敏捷高效”这一话题&#xff0c;分…

单体项目偶遇并发漏洞!短短一夜时间竟让老板蒸发197.83元

事先声明&#xff1a;以下故事基于真实事件而改编&#xff0c;如有雷同&#xff0c;纯属巧合~ 眼下这位正襟危坐的男子&#xff0c;名为小竹&#xff0c;他正是本次事件的主人公&#xff0c;也即将成为熊猫集团的被告&#xff0c;嗯&#xff1f;这究竟怎么一回事&#xff1f;欲…

手写简单的RPC框架(一)

一、RPC简介 1、什么是RPC RPC&#xff08;Remote Procedure Call&#xff09;远程过程调用协议&#xff0c;一种通过网络从远程计算机上请求服务&#xff0c;而不需要了解底层网络技术的协议。RPC它假定某些协议的存在&#xff0c;例如TPC/UDP等&#xff0c;为通信程序之间携…

PMP考试应该要如何备考?如何短期通过PMP?

我从新考纲考完下来&#xff0c;3A通过了考试&#xff0c;最开始也被折磨过一段时间&#xff0c;但是后面还是找到了方法&#xff0c;也算有点经验&#xff0c;给大家分享一下吧。 程序猿应该是考PMP里面人最多的&#xff0c;毕竟有一个30大坎&#xff0c;大部分人还是考虑转型…

什么是网络编程

目录 一、什么是网络编程&#xff1f; 二、协议 1.用户数据报协议(User Datagram Protocol) 2.TCP协议 TCP三次握手过程 三、实例 1.UDP通信程序 实现步骤 TCP接收数据 四、TCP协议和UDP协议的区别和联系 一、什么是网络编程&#xff1f; 1.在网络通信协议下&#xf…

一图看懂!RK3568与RK3399怎么选?

▎简介 RK3568和RK3399都是Rockchip公司的处理器&#xff0c;具有不同的特点和适用场景。以下是它们的主要区别和应用场景。 ▎RK3568 RK3568是新一代的高性能处理器&#xff0c;采用了22nm工艺&#xff0c;具有更高的性能和更低的功耗。它支持4K视频解码和编码&#xff0c;支持…

电脑如何查找重复文件?轻松揪出它!

电脑如何查找重复文件&#xff1f;小编每天要接触各种文档、图片等资料&#xff0c;很多时候下载了一些图片后&#xff0c;我根本记不住&#xff0c;下次看到不错的图片&#xff0c;我又会下载下来&#xff0c;结果就是和之前下载的图片是一样的内容。下载的重复文件多了&#…

人员定位及轨迹管理技术原理及应用领域

人员定位及轨迹管理的实现涉及多种技术和设备。例如&#xff0c;在GPS定位方面&#xff0c;使用卫星系统可以提供全球范围内的准确定位信息。然而&#xff0c;GPS在室内环境下的信号覆盖可能存在限制&#xff0c;因此在室内定位应用中&#xff0c;常常采用无线传感器网络&#…

第一行代码 第十一章 基于位置的服务

第11章 基于位置的服务 在本章中&#xff0c;我们将要学习一些全新的Android技术&#xff0c;这些技术有别于传统的PC或Web领域的应用技术&#xff0c;是只有在移动设备上才能实现的。 基于位置的服务&#xff08;Location Based Service&#xff09;。由于移动设备相比于电脑…

案例分享 | 纽扣电池石墨片厚度及缺陷检测

石墨片是一种导热散热材料&#xff0c;质轻柔软&#xff0c;能够轻松贴合在各种热源点&#xff0c;在新能源、航天、3C电子等领域应用广泛。 汽车钥匙中的纽扣电池也需要使用石墨片&#xff0c;石墨片会有统一的厚度标准&#xff0c;装配过程中表面不可避免地会出现裂纹、划痕…