Qwen2大模型微调入门实战(完整代码)

Qwen2是通义千问团队的开源大语言模型,由阿里云通义实验室研发。以Qwen2作为基座大模型,通过指令微调的方式实现高准确率的文本分类,是学习大语言模型微调的入门任务。

在这里插入图片描述

指令微调是一种通过在由(指令,输出)对组成的数据集上进一步训练LLMs的过程。 其中,指令代表模型的人类指令,输出代表遵循指令的期望输出。 这个过程有助于弥合LLMs的下一个词预测目标与用户让LLMs遵循人类指令的目标之间的差距。

在这个任务中我们会使用Qwen2-1.5b-Instruct模型在zh_cls_fudan_news数据集上进行指令微调任务,同时使用SwanLab进行监控和可视化。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本教程参考了这篇文章。

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerat
pandas

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.9

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载Qwen2-1.5B-Instruct模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallback


swanlab_callback = SwanLabCallback(...)

trainer = Trainer(
    ...
    callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlab


def dataset_jsonl_transfer(origin_path, new_path):
    """
    将原始数据集转换为大模型微调所需数据格式的新数据集
    """
    messages = []

    # 读取旧的JSONL文件
    with open(origin_path, "r") as file:
        for line in file:
            # 解析每一行的json数据
            data = json.loads(line)
            context = data["text"]
            catagory = data["category"]
            label = data["output"]
            message = {
                "instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",
                "input": f"文本:{context},类型选型:{catagory}",
                "output": label,
            }
            messages.append(message)

    # 保存重构后的JSONL文件
    with open(new_path, "w", encoding="utf-8") as file:
        for message in messages:
            file.write(json.dumps(message, ensure_ascii=False) + "\n")
            
            
def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = 384 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|im_start|>system\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|im_end|>\n<|im_start|>user\n{example['input']}<|im_end|>\n<|im_start|>assistant\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   


def predict(messages, model, tokenizer):
    device = "cuda"
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    print(response)
     
    return response
    
# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法

# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"

train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"

if not os.path.exists(train_jsonl_new_path):
    dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):
    dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)

# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  # 训练模式
    r=8,  # Lora 秩
    lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1,  # Dropout 比例
)

model = get_peft_model(model, config)

args = TrainingArguments(
    output_dir="./output/Qwen1.5",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

swanlab_callback = SwanLabCallback(
    project="Qwen2-fintune",
    experiment_name="Qwen2-1.5B-Instruct",
    description="使用通义千问Qwen2-1.5B-Instruct模型在zh_cls_fudan-news数据集上微调。",
    config={
        "model": "qwen/Qwen2-1.5B-Instruct",
        "dataset": "huangjintao/zh_cls_fudan-news",
    }
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]

test_text_list = []
for index, row in test_df.iterrows():
    instruction = row['instruction']
    input_value = row['input']
    
    messages = [
        {"role": "system", "content": f"{instruction}"},
        {"role": "user", "content": f"{input_value}"}
    ]

    response = predict(messages, model, tokenizer)
    messages.append({"role": "assistant", "content": f"{response}"})
    result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"
    test_text_list.append(swanlab.Text(result_text, caption=response))
    
swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的qwen2的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的qwen2能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了qwen2指令微调的训练!

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/694521.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用vite从0开始搭建vue项目

使用Vite从0开始创建vue项目 第一步&#xff1a;创建项目目录 mkdir vue-demo -创建目录 cd vue-demo --进入项目 npm init -y --生成package.json文件 第二步&#xff1a;安装vite、typescript--ts、vue、vitejs/plugin-vue--对单文件组件、热重载、生产优化的支持 pnpm…

【Python】探索 One-Class SVM:异常检测的利器

我已经从你的 全世界路过 像一颗流星 划过命运 的天空 很多话忍住了 不能说出口 珍藏在 我的心中 只留下一些回忆 &#x1f3b5; 牛奶咖啡《从你的全世界路过》 在数据科学和机器学习领域&#xff0c;异常检测&#xff08;Anomaly Detection&#xff09;是…

工业通讯现场中关于EtherCAT转TCPIP网关的现场应用

在当今工业自动化的浪潮中&#xff0c;EtherCAT技术以其高效、实时的特性成为了众多制造业的首选。然而&#xff0c;随着工业互联网的发展&#xff0c;对于数据的远程访问和云平台集成的需求日益增长&#xff0c;这就需要将EtherCAT协议转化为更为通用的TCP/IP协议。于是开疆智…

【MMU】——MMU 权限控制

文章目录 权限控制内存访问权限内存类型XN execute neverDomain 权限控制 内存访问权限 内存类型 TEX C B bit 控制信息 XN execute never 不可执行区域&#xff0c;例如设备地址空间通常标记为不可执行区域&#xff0c;如果有指令预取访问了该空间&#xff0c;就会触发指令…

CVE-2022-4230

CVE-2022-4230 漏洞介绍 WP Statistics WordPress 插件13.2.9之前的版本不会转义参数&#xff0c;这可能允许经过身份验证的用户执行 SQL 注入攻击。默认情况下&#xff0c;具有管理选项功能 (admin) 的用户可以使用受影响的功能&#xff0c;但是该插件有一个设置允许低权限用…

快速入门Linux及使用VSCode远程连接Linux服务器

在当前的技术环境中&#xff0c;Linux操作系统因其强大的功能和灵活性而广受欢迎。无论你是开发人员、系统管理员还是技术爱好者&#xff0c;学习Linux都是提升技术技能的重要一步。本文将介绍如何快速入门Linux&#xff0c;并使用Visual Studio Code&#xff08;VSCode&#x…

【RAG入门教程03】Langchian框架-文档加载

Langchain 使用文档加载器从各种来源获取信息并准备处理。这些加载器充当数据连接器&#xff0c;获取信息并将其转换为 Langchain 可以理解的格式。 LangChain 中有几十个文档加载器&#xff0c;可以在这查看https://python.langchain.com/v0.2/docs/integrations/document_lo…

王学岗鸿蒙开发(北向)——————(七、八)ArkUi的各种装饰器

arts包含如下&#xff1a;1&#xff0c;装饰器 &#xff1b;2&#xff0c;组件的描述(build函数)&#xff1b;3&#xff0c;自定义组件(Component修饰的),是可复用的单元&#xff1b;4&#xff0c;系统的组件(鸿蒙官方提供)&#xff1b;等 装饰器的作用:装饰类、变量、方法、结…

著名AI人工智能社会学家唐兴通谈数字社会学网络社会学主要矛盾与数字空间社会网络社会的基本议题与全球海外最新热点与关注社会结构社会分工数字财富数字游民数字经济

如果人工智能解决了一切&#xff0c;人类会做什么&#xff1f; 这个问题的背后是人工智能时代的社会主要矛盾会是什么&#xff1f;那么整个社会的大的分工体系就会围绕主要矛盾开展。 《人工智能社会主要矛盾》 在农业社会&#xff0c;主要矛盾是人口增长和土地资源之间的关…

这家叉车AGV巨头2024年一季度销售4075万~

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》人俱乐部 一、BALYO公司概览 BALYO&#xff0c;这家来自法国的仓储机器人公司&#xff0c;自2006年成立以来&#xff0c;一直致力于为全球客户提供…

LabVIEW 与组态软件在自动化系统中的应用比较与选择

LabVIEW 确实在非标单机设备、测试和测量系统中有着广泛的应用&#xff0c;特别是在科研、教育、实验室和小型自动化设备中表现突出。然而&#xff0c;LabVIEW 也具备一定的扩展能力&#xff0c;可以用于更复杂和大型的自动化系统。以下是对 LabVIEW 与组态软件在不同应用场景中…

嵌入式单片机产品微波炉拆解分享

在厨房电器中,微波炉可以说是最具技术含量的电器,它的工作原理不像其他电器那样一眼就能看个明白,于是拆解了一个微波炉,分析内部电路。 微波炉的结构 微波炉由箱体、磁控管、变压器、高压电容器、高压二极管、散热风扇、转盘装置及一系列控制保护开关组成,大多数微波炉还…

【详细的Kylin使用心得,什么是Kylin?】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

【JAVASE】面向对象编程综合案例--------模仿电影信息系统

需求&#xff1a; &#xff08;1&#xff09;展示系统中的全部电影&#xff08;每部电影展示&#xff1a;名称、价格&#xff09; &#xff08;2&#xff09;允许用户根据电影编号&#xff08;ID&#xff09;查询出某个电影的详细信息。 目标&#xff1a;使用所学的面向对象…

报表或者BI的价值在哪?这是十几年的问题啦!

对&#xff0c;问题已经十几年了&#xff0c;答案也应该普世都懂了吧&#xff0c;但非常遗憾&#xff0c;答案没有问题普及的广。看似简单&#xff0c;但也难说清楚&#xff0c;不同的人&#xff0c;总会有不同的看法。 为什么要解释这个并不新鲜的问题&#xff1f; 因为有人问…

计网总结☞网络层

.................................................. 思维导图 ........................................................... 【Wan口和Lan口】 WAN口&#xff08;Wide Area Network port&#xff09;&#xff1a; 1)用于连接外部网络&#xff0c;如互联…

Java面向对象-[封装、继承、多态、权限修饰符]

Java面向对象-封装、继承、权限修饰符 一、封装1、案例12、案例2 二、继承1、案例12、总结 三、多态1、案例 四、权限修饰符1、private2、default3、protected4、public 一、封装 1、案例1 package com.msp_oop;public class Girl {private int age;public int getAge() {ret…

【AI 高效问答系统】机器阅读理解实战内容

⭐️我叫忆_恒心&#xff0c;一名喜欢书写博客的研究生&#x1f468;‍&#x1f393;。 如果觉得本文能帮到您&#xff0c;麻烦点个赞&#x1f44d;呗&#xff01; 近期会不断在专栏里进行更新讲解博客~~~ 有什么问题的小伙伴 欢迎留言提问欧&#xff0c;喜欢的小伙伴给个三连支…

Git【版本控制命令】

02 【本地库操作】 1.git的结构 2.Git 远程库——代码托管中心 2.1 git工作流程 代码托管中心用于维护 Git 的远程库。包括在局域网环境下搭建的 GitLab 服务器&#xff0c;以及在外网环境下的 GitHub 和 Gitee (码云)。 一般工作流程如下&#xff1a; 1&#xff0e;从远程…

大模型创新企业集结!百度智能云千帆AI加速器Demo Day启动

新一轮技术革命风暴席卷而来&#xff0c;为创业带来源源不断的创新动力。过去一年&#xff0c;在金融、制造、交通、政务等领域&#xff0c;大模型正从理论到落地应用&#xff0c;逐步改变着行业的运作模式&#xff0c;成为推动行业创新和转型的关键力量。 针对生态伙伴、创业…