[大模型]LLaMA3-8B-Instruct Lora 微调

本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境准备

在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。

在这里插入图片描述

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

pip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0

MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

考虑到部分同学配置环境可能会遇到一些问题,我们在 AutoDL 平台准备了 LLaMA3 的环境镜像,该镜像适用于该仓库的所有部署环境。点击下方链接并直接创建 Autodl 示例即可。
https://www.codewithgpu.com/i/datawhalechina/self-llm/self-llm-LLaMA3

在本节教程里,我们将微调数据集放置在根目录 /dataset。

模型下载

使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。

在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py 执行下载,模型大小为 15 GB,下载模型大概需要 2 分钟。

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os

model_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct', cache_dir='/root/autodl-tmp', revision='master')

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
  "instruction": "回答以下用户问题,仅输出答案。",
  "input": "1+1等于几?",
  "output": "2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
  "instruction": "你是谁?",
  "input": "",
  "output": "家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

Llama-3-8B-Instruct 采用的Prompt Template格式如下:

<|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|>'
<|start_header_id|>user<|end_header_id|>
你是谁?<|eot_id|>'
<|start_header_id|>assistant<|end_header_id|>
我是一个有用的助手。<|eot_id|>"

加载 tokenizer 和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct', use_fast=False, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct', device_map="auto",torch_dtype=torch.bfloat16)

定义 LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是 4 倍。

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(
    output_dir="./output/llama3",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

保存 lora 权重

lora_path='./llama3_lora'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

mode_path = '/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct'
lora_path = './llama3_lora' # lora权重路径

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16)

# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

prompt = "你是谁?"
messages = [
    # {"role": "system", "content": "现在你要扮演皇帝身边的女人--甄嬛"},
    {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=512,
    do_sample=True,
    top_p=0.9, 
    temperature=0.5, 
    repetition_penalty=1.1,
    eos_token_id=tokenizer.encode('<|eot_id|>')[0],
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

完整微调源码:
https://github.com/datawhalechina/self-llm/blob/master/LLaMA3/LLaMA3-8B-Instruct%20Lora.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/700823.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

观成科技:基于深度学习技术的APT加密流量检测与分类检测方案

一、前言 近年来APT攻击的案例屡见不鲜&#xff0c;给国家、企业以及个人的利益造成极大威胁。随着流量加密技术的不断成熟&#xff0c;许多APT组织倾向于将流量加密后进行传输&#xff0c;从而保护传输内容。由于加密流量的实际载荷已被加密&#xff0c;故采用原始的流量检测…

[大模型]Llama-3-8B-Instruct FastApi 部署调用

环境准备 在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器&#xff0c;如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1。 接下来打开刚刚租用服务器的 JupyterLab&#xff0c;并且打开其中的终端开始环境配置、模型下载和运行演示。 pip 换源…

RabbitMQ从入门到入土

同步与异步 同步调用 优势&#xff1a; 时效性强&#xff0c;等到结果后就返回 问题&#xff1a; 扩展性差 性能下降 级联失败问题 异步调用 优势&#xff1a; 耦合度低&#xff0c;扩展性强 无需等待&#xff0c;性能好 故障隔离&#xff0c;下游服务故障不影响上游 缓…

【C语言】12.C语言内存函数

文章目录 1.memcpy使用和模拟实现2.memmove使用和模拟实现3.memset函数的使用4.memcmp函数的使用 memcpy&#xff1a;内存拷贝 memmove&#xff1a;内存移动 memset&#xff1a;内存设置 memcmp&#xff1a;内存比较 1.memcpy使用和模拟实现 memcpy&#xff1a;内存拷贝 void…

Mysql查询分析工具Explain的使用

一、前言 作为一名合格的开发人员&#xff0c;与数据库打交道是必不可少的&#xff0c;尤其是在业务规模和数据体量大规模增长的条件下&#xff0c;应用系统大部分请求读写比例在10:1左右&#xff0c;而且插入操作和一般的更新操作很少出现性能问题&#xff0c;遇到最多的&…

产品人生(12):从“产品生命周期管理”看如何做“职业规划”

产品生命周期管理是产品人常接触的一个概念&#xff0c;它是一种全面管理产品从概念构想、设计开发、生产制造、市场推广、销售使用&#xff0c;直至最终退役的全生命周期过程的方法论和一系列业务流程。下面我们来简单介绍下产品生命周期管理&#xff1a; 概念阶段&#xff1a…

MybatisPlus代码生成器使用案例

针对数据库中的实体类表&#xff0c;自动生成相关的pojo类&#xff0c;mapper&#xff0c;service等 1. Get-Started 基于mybatisplus&#xff0c;idea下载mybatisplus插件 sql文件 /*!40101 SET OLD_CHARACTER_SET_CLIENTCHARACTER_SET_CLIENT */; /*!40101 SET NAMES utf8 …

面试官:MySQL也可以实现分布式锁吗?

首先说结论&#xff0c;可以做&#xff0c;但不推荐做。 我们并不推荐使用数据库实现分布式锁。 如果非要这么做&#xff0c;实现大概有两种。 1、锁住Java的方法&#xff0c;借助insert实现 如何用数据库实现分布式锁呢&#xff0c;简单来说就是创建一张锁表&#xff0c;比…

PB案例学习笔记-19制作一个图片按钮

写在前面 这是PB案例学习笔记系列文章的第19篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

MAC认证

简介 MAC认证是一种基于接口和MAC地址对用户的网络访问权限进行控制的认证方法&#xff0c;它不需要用户安装任何客户端软件。设备在启动了MAC认证的接口上首次检测到用户的MAC地址以后&#xff0c;即启动对该用户的认证操作。认证过程中&#xff0c;不需要用户手动输入用户名…

Leetcode3174. 清除数字

Every day a Leetcode 题目来源&#xff1a;3174. 清除数字 解法1&#xff1a;栈 用栈模拟&#xff0c;遇到数字就弹出栈顶&#xff0c;遇到字母就插入栈。 最后留在栈里的就是答案。 代码&#xff1a; /** lc appleetcode.cn id3174 langcpp** [3174] 清除数字*/// lc c…

如何做好期货投资?

期货&#xff0c;这个词对于很多人来说可能还是个陌生的词汇&#xff0c;但是&#xff0c;随着经济的发展和人们对金融投资的需求增加&#xff0c;期货投资也变得越来越受到关注。那么&#xff0c;如何才能做好期货投资呢&#xff1f; 首先&#xff0c;了解期货的基本知识是非…

现货黄金交易多少克一手?国内外情况大不同

如果大家想参与国际市场上的现货黄金交易&#xff0c;就应该从它交易细则的入手&#xff0c;先彻底认识这个品种&#xff0c;因为它是来自欧美市场的投资方式&#xff0c;所以无论是从合约的计的单位&#xff0c;计价的货币&#xff0c;交易的具体时间&#xff0c;以及买卖过程…

word空白页删除不了怎么办?

上方菜单栏点击“视图”&#xff0c;下方点击“大纲视图”。找到文档分页符的位置。将光标放在要删除的分节符前&#xff0c;按下键盘上的“Delet”键删除分页符。

Filament 【表单操作】修改密码

场景描述&#xff1a; 新增管理员信息时需要填写密码&#xff0c;修改管理员信息时密码可以为空&#xff08;不修改密码&#xff09;&#xff0c;此时表单中密码输入有冲突&#xff0c;需要对表单中密码字段进项条件性的判断&#xff0c;使字段在 create 操作时为必需填写&…

服务器部署spring项目jar包使用bat文件,省略每次输入java -jar了

echo off set pathC:\Program Files\Java\jre1.8.0_191\bin START "YiXiangZhengHe-8516" "%path%/java" -Xdebug -jar -Dspring.profiles.activeprod -Dserver.port8516 YiXiangZhengHe-0.0.1-SNAPSHOT.jar 将set path后面改成jre的bin文件夹 START 后…

在微信小程序中安装和使用vant框架

目录 1、初始化项目2、安装vant相关依赖3、修改 app.json4、修改 project.config.json5、构建npm6、使用示例 本文将详细介绍如何在微信小程序中安装并使用vant框架&#xff5e; 开发工具&#xff1a;微信开发者工具 1、初始化项目 从终端进入小程序项目目录&#xff0c;执行…

一个数据查询导出工具

数据查询导出工具 安装说明 安装完成后在桌面会创建“数据查询导出工具”的查询工具。 程序初始化 配置数据库连接 首次运行&#xff0c;请先配置数据库连接 点击“数据库连接”后&#xff0c;会出现下面的窗体&#xff0c;要求输入维护工程师密码。&#xff08;维护工程师密码…

VScode如何调试

调试 1.打断点 1.点击调试按钮 3.点击下拉选择环境node&#xff0c;点击绿三角选择输入调试的命令&#xff08;具体命令查看package.json中scripts中的哪一个命令和运行的文件&#xff09;&#xff0c;点击右边的设置&#xff08;可以直接跳下面第八步&#xff01;&#xff…

实体类status属性使用枚举类型的步骤

1. 问题引出 当实体类的状态属性为Integer类型时&#xff0c;容易写错 2. 初步修改 把状态属性强制为某个类型&#xff0c;并且自定义一些可供选择的常量。 public class LessonStatus {public static final LessonStatus NOT_LEARNED new LessonStatus(0,"未学习"…