[大模型]GLM4-9B-chat Lora 微调

本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境准备

在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。

在这里插入图片描述

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

pip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0
pip install tiktoken==0.7.0

MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

考虑到部分同学配置环境可能会遇到一些问题,我们在 AutoDL 平台准备了 GLM-4 的环境镜像,该镜像适用于本教程需要 GLM-4 的部署环境。点击下方链接并直接创建 AutoDL 示例即可。(vLLM 对 torch 版本要求较高,且越高的版本对模型的支持更全,效果更好,所以新建一个全新的镜像。) https://www.codewithgpu.com/i/datawhalechina/self-llm/GLM-4

在本节教程里,我们将微调数据集放置在根目录 /dataset。

模型下载

使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。

在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py 执行下载。

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os

model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat', cache_dir='/root/autodl-tmp/glm-4-9b-chat', revision='master')

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
  "instruction": "回答以下用户问题,仅输出答案。",
  "input": "1+1等于几?",
  "output": "2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
  "instruction": "你是谁?",
  "input": "",
  "output": "家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer((f"[gMASK]<sop><|system|>\n假设你是皇帝身边的女人--甄嬛。<|user|>\n"
                            f"{example['instruction']+example['input']}<|assistant|>\n"
                            ), 
                            add_special_tokens=False)
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

GLM4-9B-chat 采用的Prompt Template格式如下:

[gMASK]<sop><|system|> 
假设你是皇帝身边的女人--甄嬛。<|user|> 
小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——<|assistant|> 
嘘——都说许愿说破是不灵的。<|endoftext|>

加载 tokenizer 和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', use_fast=False, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True)

定义 LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是 4 倍。

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],  # 现存问题只微调部分演示即可
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(
    output_dir="./output/GLM4",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=50,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-5,
    save_on_each_node=True,
    gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

保存 lora 权重

lora_path='./GLM4'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

mode_path = '/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat'
lora_path = './GLM4_lora'

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()

# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path)

prompt = "你是谁?"
inputs = tokenizer.apply_chat_template([{"role": "user", "content": "假设你是皇帝身边的女人--甄嬛。"},{"role": "user", "content": prompt}],
                                       add_generation_prompt=True,
                                       tokenize=True,
                                       return_tensors="pt",
                                       return_dict=True
                                       ).to('cuda')


gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
    outputs = model.generate(**inputs, **gen_kwargs)
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

完整脚本参考:
https://github.com/datawhalechina/self-llm/blob/master/GLM-4/05-GLM-4-9B-chat%20Lora%20%E5%BE%AE%E8%B0%83.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/697709.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【docker】 pull access denied for alpine-java, repository does not exist

问题&#xff1a; com.spotify.docker.client.exceptions.DockerException: pull access denied for alpine-java, repository does not exist or may require docker login: denied: requested access to the resource is denied org.apache.maven.plugin.MojoExecutionExce…

[Algorithm][动态规划][完全背包问题][零钱兑换][零钱兑换Ⅱ][完全平方数]详细讲解

目录 1.零钱兑换1.题目链接2.算法原理详解3.代码实现 2.零钱兑换 II1.题目链接2.算法原理详解3.代码实现 3.完全平方数1.题目链接2.算法原理详解3.代码实现 1.零钱兑换 1.题目链接 零钱兑换 2.算法原理详解 思路&#xff1a; 确定状态表示 -> dp[i][j]的含义 dp[i][j]&am…

Gitlab安装配置

gitlab git是一个分布式的代码版本管理软件。用于敏捷高效地处理任何或小或大的项目。Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。 1.版本控制 是指对软件开发过程中各种程序代码&#xff0c;配置文件及说明文档等文件变更的管…

如何用R语言ggplot2画高水平期刊散点图

文章目录 前言一、数据集二、ggplot2画图1、全部代码2、细节拆分1&#xff09;导包2&#xff09;创建图形对象3&#xff09;主题设置4&#xff09;轴设置5&#xff09;图例设置6&#xff09;散点颜色7&#xff09;保存图片 前言 一、数据集 数据下载链接见文章顶部 处理前的数据…

LabVIEW图像采集处理项目中相机选择与应用

在LabVIEW图像采集处理项目中&#xff0c;选择合适的相机是确保项目成功的关键。本文将详细探讨相机选择时需要关注的参数、黑白相机与彩色相机的区别及其适用场合&#xff0c;帮助工程师和开发者做出明智的选择。 相机选择时需要关注的参数 1. 分辨率 定义&#xff1a;分辨率…

上心师傅的思路分享(三)--Nacos渗透

目录 1. 前言 2. Nacos 2.1 Nacos介绍 2.2 鹰图语法 2.3 fofa语法 2.3 漏洞列表 未授权API接口漏洞 3 环境搭建 3.1 方式一: 3.2 方式二: 3.3 访问方式 4. 工具监测 5. 漏洞复现 5.1 弱口令 5.2 未授权接口 5.3.1 用户信息 API 5.3.2 集群信息 API 5.3.3 配置…

Functional ALV系列 (10) - 将填充FieldCatalog封装成函数

在前面的博文中&#xff0c;已经讲了封装的思路和实现&#xff0c;主要是利用 cl_salv_data_descr>read_structdescr () 方法来实现。在这里&#xff0c;贴出代码方便大家参考。 编写获取内表组件的通用方法 form frm_get_fields using pt_data type any tablechanging…

微信小程序双层/多层 wx:for 循环嵌套,关于内外层的 index 和 item ;data-index 传递两个参数

微信小程序用 wx:for 循环可以快速将后端 js 的数组快速显示到前端&#xff1b; 那假如数组中嵌套数组&#xff1b;就存在内外层两层及以上的多层嵌套循环了。 那么如果两层的嵌套式循环 index 究竟是属于哪一层呢&#xff1f;item 又属于哪一个呢&#xff1f; <view><…

java之面向对象2笔记

1 接口(interface) 1.1 概述 接口&#xff08;Interface&#xff09;在计算机科学中&#xff0c;特别是在面向对象编程&#xff08;OOP&#xff09;中&#xff0c;是一个重要的概念。它定义了一组方法的规范&#xff0c;但没有实现这些方法的具体代码。接口的主要目的是确保类…

介绍Linux

目录 1.什么是操作系统 2.现实生活中的操作系统 3.操作系统的发展史 4.操作系统的发展 Linux的不同版本以及应用领域 1.Linux内核及发行版介绍 <1>Linux内核版本 <2>Linux发行版本 2.应用领域 个⼈桌⾯领域的应⽤ 服务器领域 嵌⼊式领域 3.文件和目录 …

Pulsar 社区周报 | No.2024-06-07 | Apache Pulsar 新分支 3.3 版本发布

“ 各位热爱 Pulsar 的小伙伴们&#xff0c;Pulsar 社区周报更新啦&#xff01;这里将记录 Pulsar 社区每周的重要更新&#xff0c;每周发布。 ” 本期主题&#xff1a;Apache Pulsar 新分支 3.3 版本发布 Apache Pulsar 新分支 3.3 版本发布&#xff1a;Apache Pulsar 3.3.0[1…

2024屈原故里端午文化节开幕

6月7日&#xff0c;“中国端午诗意宜昌”2024屈原故里端午文化节在宜昌市秭归县屈原广场盛大开幕。相关嘉宾、屈氏后裔、华商侨商及外资企业代表、主流媒体代表和当地民众聚首屈原祠前&#xff0c;缅怀诗祖屈原&#xff0c;共襄端午盛典。 本届屈原故里端午文化节由湖北省人民政…

【SQLAlChemy】filter过滤条件如何使用?

filter 过滤条件 生成 mock 数据 # 创建 session 对象 session sessionmaker(bindengine)()# 本地生成mock数据 for i in range(6):# 生成随机名字, 长度为4到7个字符name .join(random.choice(string.ascii_letters) for _ in range(random.randint(4, 7)))# 生成随机年龄…

Lua搭建网站后台教程

本文讲解如何使用二进制发布包和FastWeb网站管理工具搭建站点 FastWeb网站管理工具 使用该工具可快速在Windows平台部署。支持官方或三方模块的自动安装、日志调试、版本更新等。 1、下载最新版本压缩包 2、解压到任意目录(建议英文) 3、运行 ①点击 [设置]->[安装] 部…

IO进程线程(八)线程(pthread_t)

文章目录 一、线程(LWP)概念二、线程相关函数&#xff08;一&#xff09;创建 pthread_create1. 定义2. 使用&#xff08;不传参&#xff09;3. 使用&#xff08;单个参数&#xff09;4. 使用&#xff08;多个参数&#xff09;5. 多线程执行的顺序6. 多线程内存空间 &#xff0…

BUUCTF---web---[GYCTF2020]Blacklist

1、来到题目连接页面 2、测试单引号和双引号&#xff0c;单引号报错&#xff0c;双引号没报错 1 1" 3、使用万能句式 4、使用堆叠注入测试&#xff0c;查看数据库名 1;show databases;# 5、查看表名 1;show tables;# 6、查看FlagHere中字段名 1;show columns from FlagH…

Python | Leetcode Python题解之第143题重排链表

题目&#xff1a; 题解&#xff1a; class Solution:def reorderList(self, head: ListNode) -> None:if not head:returnmid self.middleNode(head)l1 headl2 mid.nextmid.next Nonel2 self.reverseList(l2)self.mergeList(l1, l2)def middleNode(self, head: ListNo…

Webpack 从入门到精通-基础篇

一、webpack 简介 1.1 webpack 是什么 webpack 是一种前端资源构建工具&#xff0c;一个静态模块打包器(module bundler)。 在 webpack 看来, 前端的所有资源文件(js/json/css/img/less/...)都会作为模块处理。 它将根据模块的依赖关系进行静态分析&#xff0c;打包生成对应的…

CA证书及PKI

文章目录 概述非对称加密User Case: 数据加密User Case: 签名验证潜在问题 CACA证书的组成CA签发证书流程CA验证签名流程CA吊销证书流程 PKI信任链证书链 概述 首先我们需要简单对证书有一个基本的概念&#xff0c;以几个问题进入了解 ❓ Question1: 什么是证书&#xff1f; 证…

数据可视化——pyecharts库绘图

目录 官方文档 使用说明&#xff1a; 点击基本图表 可以点击你想要的图表 安装&#xff1a; 一些例图&#xff1a; 柱状图&#xff1a; 效果&#xff1a; 折线图&#xff1a; 效果&#xff1a; 环形图&#xff1a; 效果&#xff1a; 南丁格尔图&#xff08;玫瑰图&am…