使用 LlamaFactory 结合开源大语言模型实现文本分类:从数据集构建到 LoRA 微调与推理评估

文章目录

    • 背景介绍
      • 文本分类数据集
      • Lora 微调
      • 模型部署与推理
        • 期待模型的输出结果
    • 文本分类评估代码

背景介绍

本文将一步一步地,介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验,以 LoRA微调 qwen/Qwen2.5-7B-Instruct 为例。

文本分类数据集

按照 alpaca 样式构建数据集,并在将其添加到 LLaMA-Factory/data/dataset_info.json 文件中。如此方便直接根据自定义数据集的名字,获取到数据集的数据。

[
  {
    "instruction": "",
    "input": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:\n\n要求}}\nreason: \nlabel:",
    "output": "reason: 该文本主要讨论的是xxx。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
  },
  ...
]

Lora 微调

llamafactory 框架支持网页端训练,但本文选择在终端使用命令行微调模型。

模型微调训练的参数较多,将模型训练的参数都存储在 yaml 文件中。

qwen_train_cls.yaml 的文件内容如下:

### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
# dataset_dir: data
dataset_dir: LLaMA-Factory/data/ 填写相应路径
dataset: 数据集名 
template: qwen
cutoff_len: 2048
# max_samples: 1000 若数据集较大,可随机筛选一部分数据微调模型
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: output/qwen2.5-7B/cls_epoch2 训练的LoRA权重输出路径
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

使用下述命令启动模型训练:

nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &

命令分解介绍:
nohup, 全称为 “no hangup”(不要挂起)。它的作用是让命令在退出终端后仍然运行,防止因关闭终端或会话中断导致进程被终止。
默认情况下,nohup 会将输出重定向到 nohup.out 文件,但这里已经显式指定了输出位置。
llamafactory-cli train qwen_train_cls.yaml 运行 llamafactory-cli 工具,用于执行训练任务。
train 是子命令,表示进行训练。
qwen_train_cls.yaml 是一个配置文件,包含训练所需的超参数、数据路径、模型结构等。
qwen_train_cls.log
将标准输出 (stdout) 重定向到 qwen_train_cls.log 文件中。
即运行过程中的正常日志信息会被记录到这个文件。
2>&1: 将标准错误输出 (stderr) 重定向到标准输出 (stdout)。
这样,所有错误信息也会被写入到 qwen_train_cls.log 文件中。
&: 表示将整个命令放到后台运行。终端会立即返回,您可以继续进行其他操作,而不用等待命令完成。

模型部署与推理

模型训练完成后得到 Lora 权重。相关微调模型部署与推理,请浏览下述两篇文章,相比llamafactory原本的模型推理速度更快。

  • 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
  • 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

目前llamafactory已经支持 vllm_infer 推理,这个PR是笔者提交的:

  • llamafactory vllm.https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py
期待模型的输出结果

下述是使用 llamafactory 推理出的数据格式,建议大家在做推理评估时,也做成这个样式,方便统一评估。

{
	"prompt": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:...", 
	"predict": "\nreason: 该文本主要讨论了改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。这些内容与社会管理和经济管理密切相关,但更侧重于行政管理和社会治理的改革,因此更符合“社会管理”这一类别。\n\nlabel: 社会管理", 		 
	"label": "reason: 该文本主要讨论的是改革创新、行政区划调整、体制机制障碍的破除以及行政管理体制等与政府治理和社会管理相关的内容,强调了与高质量发展和生活品质的关系。这些内容显示出对社会管理和行政管理的关注,尤其是在推动城乡一体化和适应高质量发展要求方面。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
}

文本分类评估代码

import os
import re
import json


from sklearn.metrics import classification_report, confusion_matrix

# 文本类别
CLASS_NAME = [
    "产业相关",
    ...
    "法律法规与行政事务",
    "其他",
]


def load_jsonl(file_path):
    """
    加载指定路径的 JSON 文件并返回解析后的数据。

    :param file_path: JSON 文件的路径
    :return: 解析后的数据(通常是字典或列表)
    :raises FileNotFoundError: 如果文件未找到
    :raises json.JSONDecodeError: 如果 JSON 格式不正确
    """
    data = []
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                tmp = json.loads(line)
                data.append(tmp)
    except FileNotFoundError as e:
        print(f"文件未找到:{file_path}")
        raise e
    except json.JSONDecodeError as e:
        print(f"JSON 格式错误:{e}")
        raise e
    return data


def parser_label(text: str):
    pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"
    matches = re.findall(pattern, text, re.DOTALL)
    if len(matches) == 1:
        return matches[0]
    return None


def trans2num(item):
    predict = parser_label(item["predict"])
    label = parser_label(item["label"])

    predict_idx = -1
    label_idx = -1
    for idx, cls_name in enumerate(CLASS_NAME):
        if predict == cls_name:
            predict_idx = idx

        if label == cls_name:
            label_idx = idx

    return predict_idx, label_idx

def cls_eval(input_file):
    data = load_jsonl(file_path=input_file)
    predicts = []
    labels = []

    for item in data:
        predict, label = trans2num(item)
        if label == -1:
            continue

        predicts.append(predict)
        labels.append(label)

    return classification_report(predicts, labels, output_dict=False)

本文使用了大模型生成式预测文本类别,我没有使用结构化输出的方式,大家可以使用结构化的json格式输出,这样在提取大模型预测结果的时候会方便很多。

大家按照自己模型的输出结果,修改parser_label 函数,这个函数用于从大模型的输出结果提取label。

cls_eval("xxx/generated_predictions.jsonl")

就会得到下述的输出结果:

-1 代表模型预测的类别不在给定的类别中。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/930536.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

链式设计模式

链式设计模式——装饰器模式和职责链模式 装饰模式 定义: 指在不改变现有对象结构的情况下,动态地给该对象增加一些职责(即增加其额外功能)的模式。 结构 装饰(Decorator)模式中的角色: 抽…

小红薯x-s算法最新补环境教程12-06更新(下)

在上一篇文章中已经讲了如何去定位x-s生成的位置,本篇文章就直接开始撸代码吧 如果没看过的话可以看:小红薯最新x-s算法分析12-06(x-s 56)(上)-CSDN博客 1、获取加密块代码 首先来到参数生成的位置&…

同三维TL200H2S2 2机位互动录播主机

2路HDMI(1路4K30)输入2路SDI输入4路网络摄像机输入1路USB摄像头3路互动远程网络信号解码,2路HDMI(1路4K60)输出, 音频输入:2路3.5mm立体声线路,2路凤凰头带48V幻相电源麦克风,音频输…

dell电脑开不了机怎么回事?戴尔电脑无法开机解决方法

dell戴尔电脑开不了机,这是很多使用dell电脑用户常遇到的问题。这种故障情况是由多种原因引起,包括硬件故障、软件问题或电源问题等等。dell电脑开不了机怎么办呢?下面便为大家介绍一下相关解决修复方法,帮助用户解决戴尔电脑无法…

【AI系统】Auto-Tuning 原理

Auto-Tuning 原理 在硬件平台驱动算子运行需要使用各种优化方式来提高性能,然而传统的手工编写算子库面临各种窘境,衍生出了自动生成高性能算子的的方式,称为自动调优。在本文我们首先分析传统算子库面临的挑战,之后介绍基于 TVM…

Windows电脑伪关机(快速启动模式),怎么真关机

Windows电脑在关机的时候,进入到一个伪关机的状态,也就是并没有真正的关机,但是在一些系统更新、变更了一些设置,进行重启等操作也会进入到真关机状态 这种一般是开启快速启动模式,开启了快速启动模式功能会在关机的时…

C# WPF抽奖程序

C# WPF抽奖程序 using Microsoft.Win32; using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Windows; using System.…

【智能控制】实验,基于MATLAB的模糊推理系统设计,模糊控制系统设计

关注作者了解更多 我的其他CSDN专栏 过程控制系统 工程测试技术 虚拟仪器技术 可编程控制器 工业现场总线 数字图像处理 智能控制 传感器技术 嵌入式系统 复变函数与积分变换 单片机原理 线性代数 大学物理 热工与工程流体力学 数字信号处理 光电融合集成电路…

泷羽sec:shell编程(9)不同脚本的互相调用和重定向操作

声明: 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&#…

记事本建java及java命名规范

1.桌面开发:c# 2. 记事本建java: 以class的名称(类名)为名,名称.java 编译jdk:javac 名称.java 调动运行jre : java 名称 查看名称.java里面的内容:cat 名称.java java 的命名规范 大驼峰(每个单词首…

数据结构---队列(Queue)

1. 简介 队列(Queue)是一种常用的数据结构,它遵循先进先出(FIFO,First In First Out)的原则。这意味着第一个进入队列的元素将是第一个被移除的元素。队列在计算机科学中有着广泛的应用,比如任…

3D 生成重建016-SA3D从nerf中分割一切

3D 生成重建016-SA3D从nerf中分割一切 文章目录 0 论文工作1 方法介绍2 实验结果 0 论文工作 1 SAM的背景和目标: SAM 是一种强大的二维视觉基础模型,能够在 2D 图像中进行任意物体的分割。传统上,SAM 在二维空间表现出色,但其无…

Ubuntu环境安装RabbitMQ

1.安装Erlang RabbitMq需要Erlang语⾔的⽀持,在安装rabbitMq之前需要安装erlang # 更新软件包 sudo apt-get update # 安装 erlang sudo apt-get install erlang 查看erlang版本 : erl 退出命令:halt(). 2. 安装RabbitMQ # 更新软件包 sudo apt-get update # 安装 …

FSWIND脉动风-风载时程生成器软件原理

大量风的实测资料表明,在风的时程曲线中,瞬时风速包含两个部分:一部分是自振周期一般在 10 分钟以上的平均风,另一部分是周期一般只有几秒左右的脉动风。平均风由于其周期一般比结构的自振周期大,因而考虑其作用性质相…

【JavaWeb后端学习笔记】MySQL的数据查询语言(Data Query Language,DQL)

MySQL DQL 1、DQL语法与数据准备1.1 DQL语法1.2 数据准备 2、基础查询2.1 查询指定字段2.2 查询返回所有字段2.3 给查询结果起别名2.4 去除重复记录 3、条件查询3.1 条件查询语法3.2 条件查询案例分析 4、分组查询4.1 分组查询语法4.2 分组查询案例分析 5、排序查询5.1 排序查询…

优化工业应用的振动筛

了解如何使用 Ansys Rocky 优化振动筛性能。 振动筛 振动筛在许多散装物料处理过程中发挥着重要作用,可实现有效的颗粒分离和分类。 它们在采矿、采石、建筑、回收和食品加工行业尤为常见。它们的设计和操作需要仔细考虑材料特性、工艺要求和环境因素,…

git修改某次commit(白痴版)

第一步 在bash窗口运行 git rebase --interactive commitId^ 比如要改的commitId是 abcedf git rebase --interactive abcedf^键盘 按 i 或者 ins 进入编辑状态 进入insert 编辑状态 在bash窗口手动把对应commit前面的pick改为e或edit 按 esc 进入退出程序 输入 :wq 保存退出…

基于 RNN(GRU, LSTM)+CNN 的红点位置检测(pytorch)

文章目录 1 项目背景2 数据集3 思路4 实验结果5 代码 1 项目背景 需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。 其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大…

word如何快速创建目录?

文章目录 1,先自己写出目录的各级标题。2、选中目标标题,然后给它们编号3、给标题按照个人需求开始分级4、插入域构建目录。4.1、利用快捷键插入域构建目录4.2、手动插入域构建目录 听懂掌声!学会了吗? 前提声明:我在此…

【金猿CIO展】复旦大学附属中山医院计算机网络中心副主任张俊钦:推进数据安全风险评估,防范化解数据安全风险,筑牢医疗数据安全防线...

‍ 张俊钦 本文由复旦大学附属中山医院计算机网络中心副主任张俊钦撰写并投递参与“数据猿年度金猿策划活动——2024大数据产业年度优秀CIO榜单及奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 数据要素时代,医疗数据已成为医院运营与决策的重要基石…