[大模型]Gemma2b-Instruct Lora 微调

本节我们简要介绍如何基于 transformers、peft 等框架,对 Gemma2b 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境准备

在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。

在这里插入图片描述

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

pip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0

MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

在本节教程里,我们将微调数据集放置在根目录 /dataset。

模型下载

使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。

在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py 执行下载,模型大小为 15 GB,下载模型大概需要 2 分钟。

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os

model_dir = snapshot_download('Lucachen/gemma2b', cache_dir='/root/autodl-tmp', revision='master')

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
  "instruction": "回答以下用户问题,仅输出答案。",
  "input": "1+1等于几?",
  "output": "2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
  "instruction": "你是谁?",
  "input": "",
  "output": "家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # 分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<start_of_turn>system\n现在你要扮演皇帝身边的女人--甄嬛<end_of_turn>\n<start_of_turn>user\n{example['instruction'] + example['input']}<end_of_turn>\n<start_of_turn>model\n", add_special_tokens=False)
    #response = tokenizer(f"{example['output']}", add_special_tokens=False)
    response = tokenizer(f"{example['output']}<end_of_turn>model", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

gemma 采用的Prompt Template格式如下:

<start_of_turn>system
你是一个小助手<end_of_turn>
<start_of_turn>user
你好<end_of_turn>
<start_of_turn>model
你好!我也很高兴见到你!有什么问题或话题想聊天吗?你好!很高兴你来了。请问您有什么问题或需要我帮助的吗?<end_of_turn>model

加载 tokenizer 和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/Lucachen/gemma2b', use_fast=False, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/Lucachen/gemma2b', device_map="auto",torch_dtype=torch.bfloat16)

定义 LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是 4 倍。

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(
    output_dir="./output/gemma2b",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

保存 lora 权重

lora_path='./gemma2b_lora'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

mode_path = '/root/autodl-tmp/Lucachen/gemma2b'
lora_path = './gemma2b_lora' # lora权重路径

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16)

# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

prompt = "你是谁?"
messages = [
    # {"role": "system", "content": "现在你要扮演皇帝身边的女人--甄嬛"},
    {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

完整微调脚本参考:https://github.com/datawhalechina/self-llm/blob/master/Gemma/04-Gemma-2B-Lora%E5%BE%AE%E8%B0%83.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/698384.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI智能体的分级

技术的分级 人们往往通过对一个复杂的技术进行分级&#xff0c;明确性能、适用范围和价值&#xff0c;方便比较、选择和管理&#xff0c;提高使用效率&#xff0c;促进资源合理分配和技术改进和标准化。 比如&#xff0c;国际汽车工程师学会&#xff08;SAE&#xff09;定义了自…

诊所管理系统有没有免费的?

随着消费市场的不断变化&#xff0c;消费型医疗领域逐渐受到重视&#xff0c;医疗机构面临激烈的行业竞争。为了提升诊所的经营管理效率和营销能力&#xff0c;选择一款合适的诊所管理系统变得尤为重要。然而&#xff0c;面对市场上众多的系统选择&#xff0c;如何找到适合自己…

Linux 安装ab测试工具

yum -y install httpd-tools ab -help #10个并发连接&#xff0c;100个请求 ab -n 200 -c 100 http://www.baidu.com/

如何制作MapBox个性化地图

我们在《如何在QGIS中加载MapBox图源》一文中&#xff0c;为你分享了在QGIS中加载MapBox的方法。 现在为你分享如何制作MapBox个性化地图的方法&#xff0c;如果你需要最新版本的QGIS及高清图源&#xff0c;请在文末查看获取软件安装包的方法。 新建地图样式 进入Mapbox Stu…

智慧监狱大数据整体解决方案(51页PPT)

方案介绍&#xff1a; 智慧监狱大数据整体解决方案通过集成先进的信息技术和大数据技术&#xff0c;实现对监狱管理的全面升级和智能化改造。该方案不仅提高了监狱管理的安全性和效率&#xff0c;还提升了监狱的智能化水平&#xff0c;为监狱的可持续发展提供了有力支持。 部…

计算机组成原理之计算机的性能指标

目录 计算机的性能指标 复习提示 1.计算机的主要性能指标 1.1机器字长 1.1.1与机器字长位数相同的部件 1.2数据通路带宽 1.3主存容量 1.4运算速度 1.4.1提高系统性能的综合措施 1.4.2时钟脉冲信号和时钟周期的相关概念 1.4.3主频和时钟周期的转换计算 1.4.4IPS的相关…

任务3.8.1 利用RDD实现词频统计

实战&#xff1a;利用RDD实现词频统计 目标 使用Apache Spark的RDD&#xff08;弹性分布式数据集&#xff09;模块实现一个词频统计程序。 环境准备 选择实现方式 确定使用Spark RDD进行词频统计。 Spark版本与Scala版本匹配 选择Spark 3.1.3与Scala 2.12.15以匹配现有Spar…

使用seq2seq架构实现英译法

seq2seq介绍 模型架构&#xff1a; Seq2Seq&#xff08;Sequence-to-Sequence&#xff09;模型是一种在自然语言处理&#xff08;NLP&#xff09;中广泛应用的架构&#xff0c;其核心思想是将一个序列作为输入&#xff0c;并输出另一个序列。这种模型特别适用于机器翻译、聊天…

定个小目标之刷LeetCode热题(11)

这是道简单题&#xff0c;只想到了暴力解法&#xff0c;就是用集合存储起来&#xff0c;然后找出其中的众数&#xff0c;看了一下题解&#xff0c;发现有多种解法&#xff0c;我觉得Boyer-Moore 投票算法是最优解&#xff0c;看了官方对这个算法的解释&#xff0c;我是这样理解…

每天一个数据分析题(三百五十九)- 多维分析模型

图中是某公司记录销售情况相关的表建立好的多维分析模型&#xff0c;请根据模型回答以下问题&#xff1a; 2&#xff09;产品表左连接品牌表的对应关系属于&#xff1f; A. 一对多 B. 一对一 C. 多对一 D. 多对多 数据分析认证考试介绍&#xff1a;点击进入 题目来源于CD…

高光谱成像光源 实现对细微色差的分类--51camera

光源在机器视觉中的重要性不容小觑&#xff0c;它直接影响到图像的质量&#xff0c;进而影响整个系统的性能。然而自然光LED光源不能完全满足实际需求&#xff0c;比如对细微的色差进行分类&#xff0c;我们就需要考虑红外高光谱光源。 所谓高光谱成像&#xff0c;是指使用具有…

实现思路:Vue 子组件高度不固定下实现瀑布流布局

实现思路&#xff1a;Vue 子组件高度不固定下实现瀑布流布局 一、瀑布流布局基础实现原理 在深入解说不定高度子组件的瀑布流如何实现之前&#xff0c;先大体说一下子组件高度固定已知的这种实现原理&#xff1a; 有一个已知组件高度的数组。定义好这个瀑布流的列数&#xff…

【回调函数】

1.回调函数是什么&#xff1f; 回调函数就是⼀个通过函数指针调用的函数。 如果你把函数的指针&#xff08;地址&#xff09;作为参数传递给另⼀个函数&#xff0c;当这个指针被用来调用其所指向的函数 时&#xff0c;被调用的函数就是回调函数。回调函数不是由该函数的实现方…

linux的持续性学习

安装php 第一步&#xff1a;配置yum源 第二步&#xff1a;下载php。 yum install php php-gd php-fpm php-mysql -y 第三步&#xff1a;启动php。 systemctl start php-fpm 第四步&#xff1a;检查php是否启动 lsof -i :9000 计划任务 作用&am…

问题:下列关于信息的说法,错误的是( )。 #媒体#知识分享#学习方法

问题&#xff1a;下列关于信息的说法&#xff0c;错误的是&#xff08; &#xff09;。 A. 信息无时不在&#xff0c;无时不有 B. 信息随着时间的推移&#xff0c;效用会越来越小 C. 信息具有可处理性 D. 信息是同一时刻只能被单方使用 参考答案如图所示

pnpm : 无法加载文件 C:\Users\xxxxx\AppData\Roaming\npm\pnpm.ps1,因为在此系统上禁止运行脚本。

vscode中执行pnpm install的时候&#xff0c;直接报了上面的错误。 解决&#xff1a; 然后输入&#xff1a;set-ExecutionPolicy RemoteSigned&#xff0c;按回车&#xff0c;然后根据提示&#xff0c;我们选A。 然后回车。 这样我们再次回到vscode中的我们就会发现可以了。 …

C++:SLT容器-->stack

C:SLT容器--&#xff1e;stack 1. stack容器2. stack 常用接口 1. stack容器 先进后出&#xff0c;后进先出不允许有遍历行为可以判断容器是否为空可以返回元素的个数 2. stack 常用接口 构造函数 stack<T> stk; // stack采用模板类实现&#xff0c;stack对象的默认构造形…

认识Spring 中的BeanPostProcessor

关于BeanPostProcessor和BeanFactoryPostProcessors&#xff0c;将分2篇文章来写&#xff0c;这篇文章是对Spring 中BeanPostProcessor进行了总结 先看下大模型对这个类的介绍&#xff0c;随后再看下这两个类的示例&#xff0c;最后看下这两个类的实现。 这两个类从名字看都很类…

三丰云免费服务器

云网址&#xff1a; https://www.sanfengyun.com 可申请免费云服务器&#xff0c;1核/1G内存/5M宽带/有公网IP/10G SSD硬盘/免备案。 收费云服务器&#xff0c;买2年送1年&#xff0c;有很多优惠

MAVEN架构项目管理工具(下)

1、maven工程约定目录结构 每一个maven在磁盘中都是一个文件夹&#xff08;即项目&#xff0c;以hello项目为例&#xff09; Hello/---/src------/main #放置主程序Java代码和配置文件-----------/java #程序的包和包中的文件-----------/resource #java程…