ChatGLM-6B部署和微调实例

文章目录

  • 前言
  • 一、ChatGLM-6B安装
    • 1.1 下载
    • 1.2 环境安装
  • 二、ChatGLM-6B推理
  • 三、P-tuning 微调
    • 3.1微调数据集
    • 3.2微调训练
    • 3.3微调评估
    • 3.4 调用新的模型进行推理
  • 总结


前言

ChatGLM-6B ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。ChatGLM-6B是本人尝试使用和微调的第一个大语言模型,自我感觉该模型很适合作为大语言模型的入门级选手,无论是部署配置还是推理微调都十分方便。本文主要介绍如何配置部署ChatGLM-6B,以及ChatGLM-6B推理和P-tuning v2微调基本步骤,希望可以帮助大家使用ChatGLM-6B。


一、ChatGLM-6B安装

1.1 下载

ChatGLM-6B项目仓库地址为 GitHub,模型文件下载地址为Huggingface,将下载好的模型文件chatglm-6b文件放至项目仓库中的ptuning文件目录下(如下图所示)。整个下载时间的长短根据网速和是否使用远程服务器因人而异,本人因使用的是远程服务器,下载时间共约5个小时。
在这里插入图片描述

1.2 环境安装

服务器的版本为RTX 3090,内存为24GB。Python版本为3.8.16,ubuntu的版本为20.04,Cuda的版本为11.6

库名版本
transformers4.27.1
torch1.13.1

详情可见requirements.txt,其中gradio库有的时候会安装失败,如果后续不考虑前端交互的平台的构建,此库可以先不安装,并不影响模型推理和微调。环境配置步骤如下代码所示:

conda create -n test python=3.8.16 -y
source activate test
pip install -r requirements.txt
cd ChatGLM

二、ChatGLM-6B推理

ChatGLM-6B推理部分,只要找到cli_demo.py文件运行即可。

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

以下是推理部分的展示:
在这里插入图片描述
当然我们也想要批量式询问ChatGLM-6B,这里我自己写了一个批量调用的py文件:

import torch
from transformers import AutoTokenizer, AutoModel
import torch
import sys
import pandas as pd
model_path="ptuning/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b",trust_remote_code=True).float()
#model =model.to("cpu")
model = model.eval()
data = pd.read_csv('Q1.csv')
MC = data['Question'].tolist()
j = -1
for i in MC:
    j = j+1
    input1 = f"{i}"
    print(input1)
    response,history = model.chat(tokenizer,input1,history=[],temperature=1)
    print(response)
    print("--------------------------------------------------")
    data['Answer'].loc[j] = response
    data.to_csv('Q1.csv',index = False,encoding='utf_8_sig')

最终Q1.csv的结果为:
在这里插入图片描述

三、P-tuning 微调

下图展示出ChatGLM-6B进行P-tuning v2微调的大致流程,首先需要构建好微调模型使用的数据集(包括训练集,验证集和测试集),接着是配置运行train.sh,进行数小时的训练之后将会得到模型参数权重文件Checkpoint,然后对evaluate.sh进行参数配置和运行,将会得到一系列的测试集结果,到此便是微调部分。为了检测微调后的模型在新数据上的效果,可以对cli_demo.py文件进行配置和运行。
在这里插入图片描述

3.1微调数据集

我的课题是研究法律判决预测任务,因此我的微调数据集的输入为案情陈述,输出为罪行判决。ChatGLM-6B的微调数据集有很多的格式可以选择,这里是经典的content+summary格式。以下是一个例子🌰:
{“content”: “经审理查明,2017年10月10日18时左右,被告人张某酒后驾驶牌号为川Q???二轮摩托车从宜宾市翠屏区牟坪镇牟坪村5组35号家中出发,前往宜宾市翠屏区牟坪镇派出所办事,被办案民警发现被告人张某饮酒驾驶机动车,即对张某进行了呼气式酒精检测,检出酒精含量268mg/100mL。”, “summary”: “根据中华人民共和国刑法第133条,判处张某危险驾驶罪。其中检出张某酒精含量268mg/100mL,根据中国的交通法规,血液中酒精含量超过80mg/100ml,即被认定为醉驾,因此张某符合危险驾驶罪中的醉酒驾驶机动车。”}
我们将标注好的数据分成训练集、验证集和测试集,一起存入Legal_data文件夹中,并放在ptuning目录下,如下图所示:
在这里插入图片描述
其中train.json中有44条数据, dev.json中有10条数据, test.json中有10条数据。数据量不大,只是为了方便走一遍微调流程,大家可以在创建自己的微调数据集的时候多标注些,这样会大大提高模型的性能。

3.2微调训练

查看train.sh,模型主要的训练参数有PRE_SEQ_LEN,max_target_length,max_source_length,learning_rate,per_device_train_batch_size,max_steps,per_device_train_batch_size,gradient_accumulation_stepsquantization_bit。下面将详细介绍这些训练参数的含义与作用。
PRE_SEQ_LEN是指自然语言指令的长度,而max_source_length是指整个输入序列的最大长度,max_target_length指整个输出序列的最大长度。 一般来说,PRE_SEQ_LEN应该小于或者等于max_source_length,因为输入序列除了包含指令之外,还可能包含其他内容,例如上下文信息或对话历史。根据微调标注数据的输入输出的文本长度,我们设置PRE_SEQ_LEN为128,max_target_lengthmax_source_length为300。
learning_rate是一个关键参数,它决定了每次更新模型权重时,根据梯度下降的方向应该迈出多大的步伐。我们使用ChatGLM的默认值1e-3。
per_device_train_batch_size设置为1,那么每个设备上将会有1个样本作为输入进行模型训练,并且基于这1个样本的损失值来进行一次模型参数的更新。
gradient_accumulation_steps指定了在执行一次模型权重更新(即一次反向传播步骤)之前,要累积多少个批次的梯度。gradient_accumulation_steps的值为16,这样模型会在每16个批次后进行一次权重更新,等价于使用大小为 16的批次进行训练。
max_steps 参数用于指定在结束训练前,模型应进行多少步的更新。每一“步”通常包括一个前向传播和一个反向传播,并且可能涉及到多个批次,在实验中max_steps设置为300。
quantization_bit 参数通常关联到模型权重的量化过程。量化是一种将模型权重从浮点数转换为低精度(如整数)表示形式的技术。这个过程可以显著减少模型的存储需求和计算复杂性,从而提高推理速度并减少内存使用。

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file Legal_data/train.json \
    --validation_file Legal_data/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 300 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 300 \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate $LR \
    # --report_to tensorboard \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

在运行sh train.sh之前,我们需要额外安装datasetsjiebarouge_chinesenltk库:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple datasets
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jieba
pip install rouge_chinese
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple nltk

微调训练成功的截图如下:
在这里插入图片描述
由于logging_steps 设置为10,因此每运行十个step,便会保留一次loss,learning_rate和epoch值。最后也会输出一个train metrics
在这里插入图片描述
在这里插入图片描述

sh train.sh完成之后会发现ptuning路径下有新的文件夹output,文件夹中保存了微调训练后的checkpoint和一些评估指标,下图中有三个checkpoint是因为max_steps 为300,save_steps为100。在这里插入图片描述

3.3微调评估

执行sh evaluate.sh,注意evaluate.sh中的一些参数要和train.sh中的参数一致。如max_source_length,max_target_lengthPRE_SEQ_LEN

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=300

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file Legal_data/dev.json \
    --test_file Legal_data/test.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path chatglm-6b \
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 300 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

sh evaluate.sh执行成功的截图如下:
在这里插入图片描述
sh evaluate.sh完成之后会发现output文件夹中保存了评估后的generated_predictions.txt``predict_results.json文件。

3.4 调用新的模型进行推理

完成微调后的模型在测试集上的评估之后,我们如何使用微调好的模型进行推理呢?这里我们以cli_demo.py为例,之前的是这样的调用模型:

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

我们仅需要将上面的代码变为:

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)

config = AutoConfig.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join('ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-300', "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

主要注意⚠️的是以下一行代码,需要导入微调好的checkpoint:

prefix_state_dict = torch.load(os.path.join('ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-300', "pytorch_model.bin"))

这样我们便可以调用自己微调好的ChatGLM-6B模型:
在这里插入图片描述


总结

ChatGLM-6B对于中文的问题回答能力优秀,希望大家可以通过我的分享来测试它❤️❤️❤️

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/332557.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Prism框架的WPF前端框架开发《知产代理数字化解决方案》

最近新开发了一套WPF前端界面框架,叫《知产代理数字化解决方案》,采用了时下流行的Prism框架作为整个系统的基础架构,演示了Prism中的IRegionManager区域管理器、IDialogAware对话框、IDialogService对话框服务、IContainerExtension容器等用…

Python实现自动化办公(使用第三方库操作Excel)

1 使用 xlrd 读取Excel数据 1.1 获取具体单元格的数据 import xlrd# 1. 打开工作簿 workbook xlrd.open_workbook("D:/Python_study_projects/Python自动化办公/Excel/test1.xlsx") # 2. 打开工作表 sheet1 workbook.sheets()[0] # 选择所有工作表中的第一个 # …

阿里云地域和可用区分布表,2024更新

2024年阿里云服务器地域分布表,地域指数据中心所在的地理区域,通常按照数据中心所在的城市划分,例如华北2(北京)地域表示数据中心所在的城市是北京。阿里云地域分为四部分即中国、亚太其他国家、欧洲与美洲和中东&…

springcloud Ribbon负载均衡服务调用

文章目录 代码下载地址简介测试 Ribbon负载均衡算法手写RoundRobinRule源码8001/8002微服务改造80订单微服务改造测试 代码下载地址 地址:https://github.com/13thm/study_springcloud/tree/main/days6_Ribbon 简介 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端…

ora-12154无法解析指定的连接标识符

用户反映查询的时候报错ora-12154 这个系统只做历史数据查询使用,使用并不平凡,该数据库曾做过一次服务器间的迁移。 用户描述,所有oracle客户端查询该视图都报tns错误,一般ora-12154会发生在连接数据库时,因为tns配…

flutter开发windows桌面软件,使用Inno Setup打包成安装程序,支持中文

最近使用flutter开发windows桌面软件的时候,想要将软件打包成安装程序,使用了flutter官方推荐的msix打包,但是打包出来的软件生成的桌面快捷方式有蓝色背景: 这个蓝色背景应该是没有设置为动态导致的,windows系统的屏幕…

C#,字符串匹配(模式搜索)RK(Rabin Karp)算法的源代码

M.O.Rabin Rabin-Karp算法,是由M.O.Rabin和R.A.Karp设计实现的一种基于移动散列值的字符串匹配算法。 通常基于散列值的字符串匹配方法:(1)首先计算模式字符串的散列函数;(2)然后利用相同的散…

mysql数据迁移报错Specified key was too long; max key length is 767 bytes

目录 场景: 说明: 疑问: 解决: 验证: 场景: 线上项目支持的过程中遇到mysql库表结构和数据由A库迁移到B库上提示Specified key was too long; max key length is 767 bytes报错,第一次遇到特此…

每日一题——LeetCode1266.访问所有点的最小时间

方法一 个人方法 找规律: 当前的点为current,下一个点为next,x为两点横坐标之间距离,y为两点竖坐标之间距离 1、当两点横坐标相同时,两点距离为y 2、当两点竖坐标相同时,两点距离为x 3、当两点x与y相同…

30分钟带你深入优化安卓Bitmap大图

30分钟带你源码深入了解Bitmap以及优化安卓大图 一、前言二、Bitmap入门1. 如何创建Bitmap?2. Bitmap的堆内存分布在哪里3. 图片文件越大,Bitmap堆内存会越大吗?4. 如何管理Bitmap的内存?5. 实战修改Bitmap的堆内存,改变图片的图…

MySQL中锁的概述

按照锁的粒度来分可分为:全局锁(锁住当前数据库的所有数据表),表级锁(锁住对应的数据表),行级锁(每次锁住对应的行数据) 加全局锁:flush tables with read lo…

4 python快速上手

计算机常识知识 1.Python代码运行方式2.进制2.1 进制转换 3. 计算机中的单位4.编码4.1 ascii编码4.2 gb-2312编码4.3 unicode4.4 utf-8编码4.5 Python相关的编码 总结 各位小伙伴想要博客相关资料的话关注公众号:chuanyeTry即可领取相关资料! 1.Python代…

FlinkAPI开发之状态管理

案例用到的测试数据请参考文章: Flink自定义Source模拟数据流 原文链接:https://blog.csdn.net/m0_52606060/article/details/135436048 Flink中的状态 概述 有状态的算子 状态的分类 托管状态(Managed State)和原始状态&…

RDMA Scatter Gather List详解

1. 前言 在使用RDMA操作之前,我们需要了解一些RDMA API中的一些需要的值。其中在ibv_send_wr我们需要一个sg_list的数组,sg_list是用来存放ibv_sge元素,那么什么是SGL以及什么是sge呢?对于一个使用RDMA进行开发的程序员来说&#…

微信小程序(六)tabBar的使用

注释很详细,直接上代码 上一篇 新增内容: 1. 标签栏文字的内容以及默认与选中颜色 2. 标签栏图标的默认样式与选中样式 3. 标签选项路径页面 4.标签栏背景颜色 🐼(文末补充)设置标签栏后为什么navigator标签无法跳转页…

数据集成时表模型同步方法解析

01 背景介绍 数据治理的第一步,也是数据中台的一个基础功能 — 即将来自各类业务数据源的数据,同步集成至中台 ODS 层。业务数据源多种多样,单单可能涉及到的主流关系型数据库就有近十种。功能更加全面的数据中台通常还具有对接非关系型数据…

elasticsearch[一]-索引库操作(轻松创建)、文档增删改查、批量写入(效率倍增)

elasticsearch[一]-索引库操作(轻松创建)、文档增删改查、批量写入(效率倍增) 1、初始化 RestClient 在 elasticsearch 提供的 API 中,与 elasticsearch 一切交互都封装在一个名为 RestHighLevelClient 的类中,必须先完成这个对象的初始化,…

Docker中创建并配置MySQL、nginx、redis等容器

Docker中安装并配置MySQL、nginx、redis等 文章目录 Docker中安装并配置MySQL、nginx、redis等一、创建nginx容器①:拉取镜像②:运行nginx镜像③:从nginx容器中映射nginx配置文件到本地④:重启nginx并重新配置nginx的挂载 二、创建…

苹果Find My可查找添加32件物品,伦茨科技ST17H6x芯片加速产品赋能

苹果最近更新的支持文档证实,从 iOS 16 开始,"Find My"可查找添加物品从16件增加到32件,AirTag 和“查找”网络中的物品利用“查找”网络的强大功能来发挥作用,这个网络由数亿台加密的匿名 Apple 设备构成。“查找”网络…

TCP高并发服务器简介(select、poll、epoll实现与区别)

select、poll、epoll三者的实现: select实现TCP高并发服务器的流程: 一、创建套接字(socket函数):二、填充服务器的网络信息结构体:三、套接字和服务器的网络信息结构体进行绑定(bind函数&…