GLM4指令微调实战(完整代码)

GLM4是清华智谱团队最近开源的大语言模型。

以GLM4作为基座大模型,通过指令微调的方式做高精度文本分类,是学习LLM微调的入门任务。

在这里插入图片描述

使用的9B模型,显存要求相对较高,需要40GB左右。

在本文中,我们会使用 GLM4-9b-Chat 模型在 复旦中文新闻 数据集上做指令微调训练,同时使用SwanLab监控训练过程、评估模型效果。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

相关文章:Qwen2指令微调

知识点:什么是指令微调?

大模型指令微调(Instruction Tuning)是一种针对大型预训练语言模型的微调技术,其核心目的是增强模型理解和执行特定指令的能力,使模型能够根据用户提供的自然语言指令准确、恰当地生成相应的输出或执行相关任务。

指令微调特别关注于提升模型在遵循指令方面的一致性和准确性,从而拓宽模型在各种应用场景中的泛化能力和实用性。

在实际应用中,我的理解是,指令微调更多把LLM看作一个更智能、更强大的传统NLP模型(比如Bert),来实现更高精度的文本预测任务。所以这类任务的应用场景覆盖了以往NLP模型的场景,甚至很多团队拿它来标注互联网数据

下面是实战正片:

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerate
pandas
tiktoken

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas accelerate tiktoken

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.10、tiktokn==0.7.0,更多环境细节可以查看这里

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载GLM4-9b-Chat模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallback

swanlab_callback = SwanLabCallback(...)

trainer = Trainer(
    ...
    callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlab


def dataset_jsonl_transfer(origin_path, new_path):
    """
    将原始数据集转换为大模型微调所需数据格式的新数据集
    """
    messages = []

    # 读取旧的JSONL文件
    with open(origin_path, "r") as file:
        for line in file:
            # 解析每一行的json数据
            data = json.loads(line)
            context = data["text"]
            catagory = data["category"]
            label = data["output"]
            message = {
                "instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",
                "input": f"文本:{context},类型选型:{catagory}",
                "output": label,
            }
            messages.append(message)

    # 保存重构后的JSONL文件
    with open(new_path, "w", encoding="utf-8") as file:
        for message in messages:
            file.write(json.dumps(message, ensure_ascii=False) + "\n")
            
            
def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = 384 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|system|>\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   


def predict(messages, model, tokenizer):
    device = "cuda"
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    print(response)
     
    return response
    
# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法

# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"

train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"

if not os.path.exists(train_jsonl_new_path):
    dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):
    dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)

# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"],
    inference_mode=False,  # 训练模式
    r=8,  # Lora 秩
    lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1,  # Dropout 比例
)

model = get_peft_model(model, config)

args = TrainingArguments(
    output_dir="./output/GLM4-9b",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

swanlab_callback = SwanLabCallback(
    project="GLM4-fintune",
    experiment_name="GLM4-9B-Chat",
    description="使用智谱GLM4-9B-Chat模型在zh_cls_fudan-news数据集上微调。",
    config={
        "model": "ZhipuAI/glm-4-9b-chat",
        "dataset": "huangjintao/zh_cls_fudan-news",
    },
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]

test_text_list = []
for index, row in test_df.iterrows():
    instruction = row['instruction']
    input_value = row['input']
    
    messages = [
        {"role": "system", "content": f"{instruction}"},
        {"role": "user", "content": f"{input_value}"}
    ]

    response = predict(messages, model, tokenizer)
    messages.append({"role": "assistant", "content": f"{response}"})
    result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"
    test_text_list.append(swanlab.Text(result_text, caption=response))
    
swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始,这些loss、grad_norm等信息会到一定的step时打印出来:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的qwen2的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的qwen2能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了glm2指令微调的训练!

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/696703.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

1.VMware软件的安装与虚拟机的创建

1. VMware软件的安装 1.1 为什么需要虚拟机 嵌入式Linux开发需要在Linux系统下运行&#xff0c;我们选择Ubuntu。   1、双系统安装     有问题&#xff0c;一次只能使用一个系统。Ubuntu基本只做编译用。双系统安装不能同时运行Windows和Linux。   2、虚拟机软件   …

解决!word转pdf时,怎样保持图片不失真

#今天用word写了期末设计报告&#xff0c;里面有很多过程的截图&#xff0c;要打印出来&#xff0c;想到pdf图片不会错位&#xff0c;就转成了pdf&#xff0c;发现图片都成高糊了&#xff0c;找了好多方法&#xff0c;再不下载其他软件和插件的情况下&#xff0c;导出拥有清晰的…

单片机嵌入式计算器(带程序EXE)

单片机嵌入式计算器 主要功能&#xff1a;完成PWM占空比计算&#xff0c;T溢出时间&#xff08;延时&#xff09;&#xff1b; [!NOTE] 两个程序EXE&#xff1b; [!CAUTION] 百度网盘链接&#xff1a;链接&#xff1a;https://pan.baidu.com/s/1VJ0G7W5AEQw8_MiagM7g8A?pwdg8…

DDMA信号处理以及数据处理的流程---原始数据生成

Hello&#xff0c;大家好&#xff0c;我是Xiaojie&#xff0c;好久不见&#xff0c;欢迎大家能够和Xiaojie一起学习毫米波雷达知识&#xff0c;Xiaojie准备连载一个系列的文章—DDMA信号处理以及数据处理的流程&#xff0c;本系列文章将从目标生成、信号仿真、测距、测速、cfar…

G盘文件系统损坏:全面解析与应对策略

在数字时代&#xff0c;数据的重要性不言而喻。然而&#xff0c;G盘文件系统损坏却时常给我们的数据安全带来威胁。当G盘文件系统受损时&#xff0c;可能导致文件丢失、数据无法访问等严重后果。本文将深入探讨G盘文件系统损坏的现象、原因、恢复方案以及预防措施&#xff0c;帮…

鸿蒙开发必备:《DevEco Studio 系列一:实用功能解析与常用快捷键大全》

系列文章目录 文章目录 系列文章目录前言一、下载与安装小黑板 二、IDE被忽略的实用功能-帮助&#xff08;Help&#xff09;1.Quick Start2. API Reference3.FAQ 三、常用快捷键一、编辑二、查找或替换三、编译与运行四、调试五、其他 前言 DevEco Studio&#xff09;是基于In…

算法笔记1-高精度模板(加减乘除)个人模板

目录 加法 减法 乘法 ​编辑 除法 加法 #include <iostream> #include <cstring> #include <algorithm> #include <cmath> #include <queue>using namespace std;typedef pair<int,int> PII;const int N 1e5 10;int n; int a[N],…

Spring运维之业务层测试数据回滚以及设置测试的随机用例

业务层测试数据回滚 我们之前在写dao层 测试的时候 如果执行到这边的代码 会在数据库 里面留下数据 运行一次留一次数据 开发有开发数据库&#xff0c;运行有运行数据库 我们先连数据库 在pom文件里引入mysql的驱动和mybatis-plus的依赖 在数据层写接口 用mybatis-plus进…

Qt 布局管理

布局基础 1)Qt 布局管理系统使用的类的继承关系如下图: QLayout 和 QLayoutItem 这两个类是抽象类,当设计自定义的布局管理器时才会使用到,通常使用的是由 Qt 实现的 QLayout 的几个子类。 2)Qt 使用布局管理器的步骤如下: 首先创建一个布局管理器类的对象。然后使用该…

RocketMQ教程(三):RocketMQ的核心组件

四个核心组件 RocketMQ 的架构采用了典型的分布式系统设计理念,以确保高性能、高可用和可扩展性。RocketMQ 主要由四个核心组件构成:NameServer、Broker、Producer 和 Consumer。下面是对这些组件以及它们在 RocketMQ 中的角色和功能的概述: 1. NameServer 角色和功能:Name…

28.找零

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/744 题目描述 有一台自动售票机,每张票卖 …

C++11:列表初始化 初始化列表initializer_list decltype关键字

目录 前言 列表初始化 初始化列表initializer_list decltype关键字 左值和右值 move 前言 2003年C标准委员会曾经提交了一份技术勘误表&#xff08;简称TC1&#xff09;&#xff0c;使得C03这个名字取代了C98成为了C11前最新的C标准名称。不过由于C03主要是对C98标准中的…

C++必修:探索C++的内存管理

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C学习 贝蒂的主页&#xff1a;Betty’s blog 1. C/C的内存分布 我们首先来看一段代码及其相关问题 int globalVar 1; static…

Java从入门到放弃

线程池的主要作用 线程池的设计主要是为了管理线程&#xff0c;为了让用户不需要再关系线程的创建和销毁&#xff0c;只需要使用线程池中的线程即可。 同时线程池的出现也为性能的提升做出了很多贡献&#xff1a; 降低了资源的消耗&#xff1a;不会频繁的创建、销毁线程&…

Helm离线部署Rancher2.7.10

环境依赖&#xff1a; K8s集群、helm 工具 Rancher组件架构 Rancher Server 包括用于管理整个 Rancher 部署的所有软件组件。 下图展示了 Rancher 2.x 的上层架构。下图中&#xff0c;Rancher Server 管理两个下游 Kubernetes 集群 准备Rancher镜像推送到私有仓库 cat >…

CorelDRAW2024破解激活码序列号一步到位

亲们&#xff0c;今天给大家种草一个神奇的软件——CorelDRAW破解2024最新版&#xff01;&#x1f3a8;这是一款专业级的矢量图形设计软件&#xff0c;无论你是平面设计师、插画师还是设计师&#xff0c;都能在这个软件中找到你需要的工具和功能。✨ 让我来给大家介绍一下这款软…

OpenGauss数据库-7.用户及角色

第1关&#xff1a;创建用户 gsql -d postgres -U gaussdb -W passwd123123 CREATE USER jackson WITH PASSWORD jackson123; 第2关&#xff1a;修改用户 gsql -d postgres -U gaussdb -W passwd123123 ALTER USER jackson WITH PASSWORD Abcd123; 第3关&#xff1a;创建角色 …

MySQL 常见客户端程序

本篇主要介绍MySQL常见的客户端程序 目录 一、mysqlcheck 二、mysqldump 三、mysqladmin 四、mysqldumpslow 五、mysqlbinlog 六、mysqlshow 显示列的具体信息​编辑 七、mysqlslap 一、mysqlcheck mysqlcheck是MySQL的表维护程序&#xff0c;其功能主要包含以下四个方…

传感器展会|2024厦门传感器与应用技术展览会

传感器展会|2024厦门传感器与应用技术展览会 时间&#xff1a;2024年11月1-3日 地点&#xff1a;厦门国际会展中心 XISE EXPO展会介绍&#xff1a; 2024中国&#xff08;厦门&#xff09;国际传感器与应用技术展览会将于2024年11月1-3日在厦门国际会展中心举行&#xf…

Docker快速部署springboot项目

本文概述 本文主要介绍了怎么将springboot项目打包为docker镜像&#xff0c;并如何在后端服务器上使用docker快速部署springboot应用和nginx应用。 一、打包springboot项目 1、复制原来的application.yml文件然后重命名为application-pro.yml文件&#xff0c;将application-pro…