提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
#AI夏令营 #Datawhale #夏令营
- 一、数据集制作
- 1.1 环境配置
- 1.2 数据处理prompt
- 1.3 训练数据集制作
- 1.4 测试集数据制作
- 二、模型微调
- 2.1 平台微调
- 2.2 平台微调
- 三、微调推理
- 提交文件
一、数据集制作
1.1 环境配置
- baseline2导入
代码部署托管在百度飞桨aistudio平台
fork到自己的项目中 - 参考task1把对应的星火大模型spark api配置的相应信息填入对应的位置
1.2 数据处理prompt
这个prompt相较于baseline01区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。
content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
1.3 训练数据集制作
jsonl样例:
jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
print(jsonl_data)
print(jsonl_data["instruction"])
print(jsonl_data["input"])
print(jsonl_data["output"])
import json
# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
data = json.load(file)
这里我们通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。大家可以经过处理后自行查阅traindata.jsonl文件,看看都有啥。
这里的训练时长大概40min左右,请耐心等待。这段等待的时间可以看看后面的内容。
# 训练集制作
# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
# 训练集行数(130)不符合要求,范围:1500~90000000
# 遍历数据列表,并将每一行写入文件
# 这里为了满足微调需求我们重复12次数据集 130*12=1560
for line_data in tqdm(data):
line_input = line_data["chat_text"]
line_output = line_data["infos"]
content = line_input
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(res)
line_write = {
"instruction":jsonl_data["instruction"],
"input":json.dumps(res, ensure_ascii=False),
"output":json.dumps(line_output, ensure_ascii=False)
}
# 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
for time in range(12):
file.write(json.dumps(line_write, ensure_ascii=False) + '\n') # '\n' 用于在每行末尾添加换行符
1.4 测试集数据制作
测试数据和训练数据相似,都是通过api清洗后存储。
# 验证集制作(提交版本)
# input,target
import json
# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
data_test = json.load(file)
这里的验证数据我们以csv文件存储,有input和target两列,由于我们没有这些数据的真实标签,我这里将target列设置为’-'。
测试集text.csv文件大概需要20min能得到,也请大家耐心等待~
import csv
# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
# 创建一个csv writer对象
csvwriter = csv.writer(csvfile)
csvwriter.writerow(["input","target"])
# 遍历数据列表,并将每一行写入CSV文件
for line_data in tqdm(data_test):
content = line_data["chat_text"]
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(line_data["chat_text"])
## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
line_list = [res, "-"]
csvwriter.writerow(line_list)
# break
训练完成后会输出两个利用大模型进行数据清洗后的纯净数据,如下图所示右键下载这两个文件即可
二、模型微调
- 登录微调平台
- 微调思路
2.1 平台微调
- 训练数据上传
登录微调平台后,上传第一步制作的数据集
上传测试集与上述步骤相同,完成后我们有两个数据集了
这里实际微调时测试集好像用不上
2.2 平台微调
等模型训练完成后即可
三、微调推理
# 定义写入函数
def write_json(json_file_path, data):
#"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
把控制台训练好的模型的api信息写入下面代码对应的位置
import SparkApi
import json
#以下密钥信息从控制台获取
appid = "" #填写控制台中获取的 APPID 信息
api_secret = "" #填写控制台中获取的 APISecret 信息
api_key ="" #填写控制台中获取的 APIKey 信息
#调用微调大模型时,设置为“patch”
domain = "patchv3"
#云端环境的服务地址
# Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat" # 微调v1.5环境的地址
Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat" # 微调v3.0环境的地址
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
def core_run(text,prompt):
# print('prompt',prompt)
text.clear
Input = prompt
question = checklen(getText("user",Input))
SparkApi.answer =""
# print("星火:",end = "")
SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
getText("assistant",SparkApi.answer)
# print(text)
return text[-1]['content']
text = []
res = core_run(text,'你好吗?')
在SparkApi.py文件的108行,引号中填入你的resourceId
提交文件
https://challenge.xfyun.cn/h5/detail?type=role-element-extraction&ch=dw24_y0SCtd
参考task1中,等待官方打分即可