transformers生成式对话机器人

简介

生成式对话机器人是一种先进的人工智能系统,它能够通过学习大量的自然语言数据来模拟人类进行开放、连贯且创造性的对话。与基于规则或检索式的聊天机器人不同,生成式对话机器人并不局限于预定义的回答集,而是可以根据对话上下文动态地生成新的。这类机器人通常依赖于深度学习框架,特别是Transformer架构(如GPT-3、BERT等)或其他循环神经网络(RNN),例如长短期记忆网络(LSTM)。

核心技术组件

神经网络架构
现代生成式对话机器人大多基于深度学习模型,尤其是Transformer架构。这种架构因其卓越的并行化能力和处理长距离依赖的能力而被广泛采用。Transformers中的多头注意力机制使得模型可以更有效地捕捉输入序列中各个部分之间的关系,从而生成更加相关和连贯的。

自回归模型
在生成回复的过程中,自回归模型按照词或子词单元的顺序预测下一个单元,直到构建出完整的句子。这种方式确保了文本序列的连续性和上下文的一致性。自回归模型的一个显著特点是它们会逐步构建输出,每一次迭代都会根据之前生成的内容调整后续的预测。

训练数据
高质量的训练数据对于生成式对话机器人的性能至关重要。这些数据可以来源于各种渠道,比如电影剧本、社交媒体对话、论坛帖子、客服记录等。丰富的多样化数据有助于训练出一个能够理解和回应多种话题及情境的对话系统。

注意力机制
特别是在Transformer架构中,注意力机制允许模型聚焦于输入序列的关键部分,这对于理解复杂的查询以及产生恰当的回答尤为重要。多头注意力机制进一步增强了这一能力,因为它可以在同一层内同时关注多个不同的信息源。

强化学习
为了优化对话机器人的行为,有时会结合强化学习策略。这种方法可以帮助模型适应不断变化的环境,并依据用户的反馈调整对话策略,以达到更好的交互效果。通过奖励机制,模型可以学习哪些类型的回答更能满足用户需求,进而改进自身的性能。

对话管理
除了基本的回复生成外,一个完整的对话机器人还需要具备对话管理功能,用以跟踪对话状态,确保对话流程的连贯性,以及适时切换话题或结束对话。这涉及到对对话历史的理解和对未来可能发展的预测。

后处理与控制
为了保证生成内容的质量和安全性,生成式对话机器人可能会包含一些后处理步骤,比如过滤不当内容或者调整语气风格,以避免生成不准确、误导性或是不合适的信息。

基于预训练模型训练生成式对话机器人

1, 训练实施方案

这次使用的模型是Langboat/bloom-389m-zh 是澜舟科技开源的。
数据集:nlpcc_2017
将数据集如何处理传给模型,训练出想要的模型实现对话机器人了。
因为模型是自回归的,所以训练任务就是要将完整的序列输入,基于上下文token预测当前token结束位置要有特殊token,eos_token。自回归上部简介中有介绍(自回归模型按照词或子词单元的顺序预测下一个单元),这样就好理解了
在这里插入图片描述
数据处理大概方向已经清楚了,那具体怎么处理了。
在对话中都是一问一答方式,nlpcc_2017也是这样。是对话那么就免不了是多轮的,那么我们喂给模型要是一轮还是多轮实现这样的结果了。能一轮肯定是一轮就要搞定了。
那么数据就要处理成这样:
在这里插入图片描述

input部分提问和答复两部分,label只有答复部分,因为计算原因input和label长度要相同,label缺少部分就要用-100补齐。
图中的黄色部分是提问,蓝色是答复最后要介绍标记eos。
这样数据集处理格式,模型可以识别出来,能计算loss。

单轮问答讲解(作为参考):
在这里插入图片描述
多轮问答讲解(参考):
在这里插入图片描述

2,代码实现

# 生成式对话机器人
## Step1 导入相关包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

## Step2 加载数据集

ds = Dataset.load_from_disk("./alpaca_data_zh/")
print(ds)

a=ds[:3]
print(a)

## Step3 数据集预处理

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")
print(tokenizer)

# 数据集处理
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
print(tokenized_ds)


t = tokenizer.decode(tokenized_ds[1]["input_ids"])
print(t)

p = tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"])))
print(p)

## Step4 创建模型

model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")

## Step5 配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=2
)

## Step6 创建训练器

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
)
## Step7 模型训练
trainer.train()


## Step8 模型推理


from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)

ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
s = pipe(ipt, max_length=256, do_sample=True, )
print(s)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/934329.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

NanoLog起步笔记-4-Server端的两个线程

nonolog起步笔记-4-Server端的两个线程 Server端的两个线程两个线程的角色与各自的职责RuntimeLogger::compressionThreadMain线程 详细学习一下相关的代码第三个线程第一次出现原位置swip buffer Server端的两个线程 如前所述,nanolog的server端,相对而…

Freertos任务切换

一、操作系统进行任务切换的时机: 采用信号量实现任务的互斥: 二、FreeRTOS 任务切换场合 PendSV 中断的时候提到了上下文(任务)切换被触发的场合: ● 可以执行一个系统调用 ● 系统滴答定时器(SysTick)中断。 1、执行系统调用 执行系统…

【硬件测试】基于FPGA的4FSK调制解调通信系统开发与硬件片内测试,包含信道模块,误码统计模块,可设置SNR

目录 1.算法仿真效果 2.算法涉及理论知识概要 3.Verilog核心程序 4.开发板使用说明和如何移植不同的开发板 5.完整算法代码文件获得 1.算法仿真效果 本文是之前写的文章: 《基于FPGA的4FSK调制解调系统,包含testbench,高斯信道模块,误码率统计模块,可以设置不同SNR》 的…

【Vue2+Element-ui】el-dialog宽度适配

1、不适配问题 分辨率100%-页面 分辨率150%-页面 在项目中,我开发分辨率一直是100%,但是客户使用的分辨率不相同,所以宽度要适配 2、解决-封装mixins.js 1)、封装的mixins 我将宽度设置成动态的,因为我的项目中需求不同。 expor…

Tr0ll: 1 Vulnhub靶机渗透笔记

Tr0ll: 1 本博客提供的所有信息仅供学习和研究目的,旨在提高读者的网络安全意识和技术能力。请在合法合规的前提下使用本文中提供的任何技术、方法或工具。如果您选择使用本博客中的任何信息进行非法活动,您将独自承担全部法律责任。本博客明确表示不支…

23. C++STL 9 (priority_queue的使用和适配实现详解)

⭐本篇重点: 1 priority_queue的使用与底层原理 2 使用容器来适配 priority_queue ⭐本篇代码:c学习 橘子真甜/c-learning-of-yzc - 码云 - 开源中国 (gitee.com) ⭐标⭐是比较重要的部分 目录 一. priority_queue(优先级队列)的…

十四、Pod的升级和回滚

当集群中的某个服务需要升级时,我们需要停止目前与该服务相关的所有Pod,然后下载新版本镜像并创建新的Pod。如果集群规模比较大,则这个工作变成了一个挑战,而且先全部停止然后逐步升级的方式会导致较长时间的服务不可用。Kubernetes提供了滚动升级功能来解决上述问题。 如…

中间件--MongoDB部署及初始化js脚本(docker部署,docker-entrypoint-initdb.d,数据迁移,自动化部署)

一、概述 MongoDB是一种常见的Nosql数据库(非关系型数据库),以文档(Document)的形式存储数据。是非关系型数据库中最像关系型数据库的一种。本篇主要介绍下部署和数据迁移。 在 MongoDB 官方镜像部署介绍中&#xff…

MES系统通过eDrawings Pro API开发图纸批量转换工具,实现3D在线查看

声明:部分代码来源于网络,如有疑问,请联系本人删除。 通过C#结合eDrawings API提供接口,实现图纸转换为换.jpg、.tif、.bmp、.stl、.exe、.html、.zip、.edrw、.eprt 和 .eas格式工具,尤其是.html格式,可以…

Java阶段三06

第3章-第6节 一、知识点 理解MVC三层模型、理解什么是SpringMVC、理解SpringMVC的工作流程、了解springMVC和Struts2的区别、学会使用SpringMVC封装不同请求、接收参数 二、目标 理解MVC三层模型 理解什么是SpringMVC 理解SpringMVC的工作流程 学会使用SpringMVC封装请求…

【计算机网络】期末速成(2)

部分内容来源于网络,侵删~ 第五章 传输层 概述 传输层提供进程和进程之间的逻辑通信,靠**套接字Socket(主机IP地址,端口号)**找到应用进程。 传输层会对收到的报文进行差错检测。 比特流(物理层)-> 数据帧(数据链路层) -> 分组 / I…

word poi-tl 表格功能增强,实现表格功能垂直合并

目录 问题解决问题poi-tl介绍 功能实现引入依赖模版代码效果图 附加(插件实现)MergeColumnData 对象MergeGroupData 类ServerMergeTableData 数据信息ServerMergeTablePolicy 合并插件 问题 由于在开发功能需求中,word文档需要垂直合并表格&…

记一次:使用C#创建一个串口工具

前言:公司的上位机打不开串口,发送的时候设备总是关机,因为和这个同事关系比较好,编写这款软件是用C#编写的,于是乎帮着解决了一下(是真解决了),然后整理了一下自己的笔记 一、开发…

LLama系列模型简要概述

LLama-1(7B, 13B, 33B, 65B参数量;1.4T tokens训练数据量) 要做真正Open的AI Efficient:同等预算下,增大训练数据,比增大模型参数量,效果要更好 训练数据: 书、Wiki这种量少、质量高…

【OpenCV】模板匹配

理论 模板匹配是一种在较大图像中搜索和查找模板图像位置的方法。为此,OpenCV 带有一个函数 cv.matchTemplate() 。它只是在输入图像上滑动模板图像(如在 2D 卷积中),并比较模板图像下的模板和输入图像的补…

从 Zuul 迁移到 Spring Cloud Gateway:一步步实现服务网关的升级

从 Zuul 迁移到 Spring Cloud Gateway:一步步实现服务网关的升级 迁移前的准备工作迁移步骤详解第一步:查看源码第二步:启动类迁移第三步:引入 Gateway 依赖第四步 编写bootstrap.yaml第五步:替换路由配置第六步&#…

网站中的QQ在线客服接入

1. 开通QQ通讯组件 QQ通讯组件官网:https://shang.qq.com 默认未开通通讯组件,登陆上QQ之后会提示开通,点击开通即可 2. 唤起QQ临时会话(对方不是自己的QQ好友也能唤起) 复制链接地址 http://wpa.qq.com/msgrd?v3&…

赋能加速AI应用交付,F5 BIG-IP Next for Kubernetes方案解读

随着AI工作负载的爆炸式增长,服务提供商和企业需要加速计算,以安全高效地在大规模云上交付高性能的AI应用。前段时间,F5公司宣布推出一项全新的创新AI应用交付和应用安全解决方案,即BIG-IP Next for Kubernetes。那么该方案有何性…

域内DNS信息收集

目录 一、查询域内DNS记录 1. 使用 PowerView.ps1 2. 使用 adidnsdump 二、添加域内DNS记录 1. 使用 Invoke-DNSUpdate.ps1 在默认情况下,域内所有用户 都有权限读取 Active Directory 数据库中的 DNS 信息,包括所有记录。这是因为: DNS 记录被视为公共信息,用于解析域…

Odoo :一款免费且开源的食品生鲜领域ERP管理系统

文 / 贝思纳斯 Odoo金牌合作伙伴 引言 提供业财人资税的精益化管理,实现研产供销的融通、食品安全的追踪与溯源,达成渠道的扁平化以及直面消费者的 D2C 等数字化解决方案,以此提升运营效率与核心竞争力,支撑高质量的变速扩张。…