[大模型]# Yi-6B-Chat Lora 微调

Yi-6B-Chat Lora 微调

概述

本节我们介绍如何基于 transformers、peft 等框架,对 Yi-6B-Chat 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出Lora。

本节所讲述的代码脚本在同级目录 04-Yi-6B-Chat Lora 微调 下,运行该脚本来执行微调过程,但注意,本文代码未使用分布式框架,微调 Yi-6B-Chat 模型至少需要 20G 及以上的显存,且需要修改脚本文件中的模型路径和数据集路径。

环境配置

在完成基本环境配置和本地模型部署的情况下(本教程中使用的模型路径是 /root/autodl-tmp/01ai/Yi-6B-Chat ),你还需要安装一些第三方库,可以使用以下命令:

pip install transformers==4.35.2
pip install peft==0.4.0
pip install datasets==2.10.1
pip install accelerate==0.20.3
pip install tiktoken
pip install transformers_stream_generator

在本节教程里,我们将微调数据集放置在根目录 /dataset。

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
    "instrution":"回答以下用户问题,仅输出答案。",
    "input":"1+1等于几?",
    "output":"2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:##

{
    "instruction": "现在你要扮演皇帝身边的女人--甄嬛",
    "input":"你是谁?",
    "output":"家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["<|im_start|>system", "现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>" + "\n<|im_start|>user\n" + example["instruction"] + example["input"] + "<|im_end|>\n"]).strip(), add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer("<|im_start|>assistant\n" + example["output"] + "<|im_end|>\n", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的,所以补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  # Yi-6B的构造就是这样的
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

然后加载我们的数据集,用上面定义的函数处理数据

# 将JSON文件转换为CSV文件
import pandas as pd
from datasets import Dataset
df = pd.read_json('/root/dataset/huanhuan.json')
ds = Dataset.from_pandas(df)

tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

经过格式化的数据,也就是送入模型的每一条数据,都是一个字典,包含了 input_idsattention_masklabels 三个键值对,其中 input_ids 是输入文本的编码,attention_mask 是输入文本的 attention mask,labels 是输出文本的编码。decode之后应该是这样的:

<|im_start|>system
现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>
<|im_start|>user
小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——<|im_end|>
<|im_start|>assistant
嘘——都说许愿说破是不灵的。<|im_end|>
<|endoftext|>

我们可以输出一条文本观察一下:

print(tokenizer.decode(tokenized_id[0]['input_ids']))
print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[0]["labels"]))))

输出结果如下图所示:

在这里插入图片描述

加载tokenizer和半精度模型

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('01ai/Yi-6B-Chat', use_fast=False, trust_remote_code=True)

# 模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载
model = AutoModelForCausalLM.from_pretrained('01ai/Yi-6B-Chat', trust_remote_code=True, torch_dtype=torch.half, device_map="auto")

定义LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理
from peft import LoraConfig, TaskType, get_peft_model
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_attn", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

训练模型

首先,使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解模型中可训练参数的数量,可以使用print_trainable_parameters方法。

model = get_peft_model(model, config)
model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法
model.print_trainable_parameters()

接下来,我们自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
from transformers import DataCollatorForSeq2Seq, TrainingArguments, Trainer
args = TrainingArguments(
    output_dir="./output/Yi-6B",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    logging_steps=10,
    num_train_epochs=3,
    gradient_checkpointing=True,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True
)

最后,使用Traniner训练模型

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

模型训练完成后,会输出如下图所示的信息:
在这里插入图片描述

模型推理

下载好的模型被保存在了 ./output/Yi-6B 目录下,如果想要从头加载微调好的模型,需要执行下面的代码

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "output/Yi-6B/checkpoint-600"  # 这里我训练出效果最好的一版是 checkpoint-600,所以调用了这个,大家可以根据自己情况选择
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

然后使用以下代码进行模型推理:

model.eval()
input = tokenizer("<|im_start|>system\n现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n".format("你是谁?", "").strip() + "\nassistant\n ", return_tensors="pt").to(model.device)

max_length = 512

outputs = model.generate(
    **input,
    max_length=max_length,
    eos_token_id=7,
    do_sample=True,
    repetition_penalty=1.3,
    no_repeat_ngram_size=5,
    temperature=0.1,
    top_k=40,
    top_p=0.8,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/542018.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ThignsBoard通过服务端订阅共享属性

MQTT基础 客户端 MQTT连接 通过服务端订阅属性 案例 1、首先需要创建整个设备的信息&#xff0c;并复制访问令牌 ​​2、通过工具MQTTX连接上对应的Topic 3、测试链接是否成功 4、在MQTT上订阅对应的Topic 5、在客户端添加共享属性信息 6、查看整个设备的遥测数据 M…

数据库(2)

目录 6.buffer pool,redo log buffer和undo logo&#xff0c;redo logo,bin log概念以及关系&#xff1f; 7.从准备更新一条数据到事务的提交的流程描述&#xff1f; 8.能说下myisam和innodb的区别吗&#xff1f; 9.说下MySQL的索引有哪些吧&#xff1f; 10.什么是B树&…

基于Pytorch实现图像分类——基于jupyter

分类任务 网络基本构建与训练方法&#xff0c;常用函数解torch.nn.functional模块nn.Module模块 MNIST数据集下载 from pathlib import Path import requestsDATA_PATH Path("data") PATH DATA_PATH / "mnist"PATH.mkdir(parentsTrue, exist_okTrue)U…

vue3中使用webstocket

1.在项目中创建webstocket.ts文件 export default class SocketService {// 单例static instance null;static get Instance() {if (!this.instance) {this.instance new SocketService();}return this.instance;}// 和服务端连接的socket对象ws null;// 存储回调函数callB…

202206青少年软件编程(scratch图形化) 等级考试试卷(四级)

第1题&#xff1a;【 单选题】 执行下列程序&#xff0c; 说的内容是&#xff1f; &#xff08; &#xff09; A:使 B:命 C:初 D:心 【正确答案】: D 【试题解析】 : 注意标点符号也是一个字符&#xff0c; 连接后字符串是“牢记使命&#xff01; 不忘初心&#xff0c; …

宝藏免费音乐软件LX music

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 宝藏免费音乐软件LX music 前言LX Music的特色功能&#xff1a;音乐播放的新境界安装与配置&#xff1a;在不同平台上使用LX Music下载页面 主题定制 本文将深入研究LX Music&#xff0c;一款备受欢迎…

pytorch车牌识别

目录 使用pytorch库中CNN模型进行图像识别收集数据集定义CNN模型卷积层池化层全连接层 CNN模型代码使用模型 使用pytorch库中CNN模型进行图像识别 收集数据集 可以去找开源的数据集或者自己手做一个 最终整合成 类别分类的图片文件 定义CNN模型 卷积层 功能&#xff1a;提…

opencv基础图行展示

"""试用opencv创建画布并显示矩形框&#xff08;适用于目标检测图像可视化&#xff09; """ # 创建一个黑色的画布&#xff0c;图像格式(BGR) img np.zeros((512, 512, 3), np.uint8)# 画一个矩形&#xff1a;给定左上角和右下角坐标&#xff0…

Redis入门到通关之Hash命令

文章目录 ⛄介绍⛄命令⛄RedisTemplate API❄️❄️添加缓存❄️❄️设置过期时间(单独设置)❄️❄️添加一个Map集合❄️❄️提取所有的小key❄️❄️提取所有的value值❄️❄️根据key提取value值❄️❄️获取所有的键值对集合❄️❄️删除❄️❄️判断Hash中是否含有该值 ⛄…

文献阅读:猕猴的单个基底外侧杏仁核神经元表现出与额叶皮层不同的连接模式

文献介绍 「文献题目」 Single basolateral amygdala neurons in macaques exhibit distinct connectional motifs with frontal cortex 「研究团队」 Peter H. Rudebeck&#xff08;美国西奈山伊坎医学院&#xff09; 「发表时间」 2023-10-18 「发表期刊」 Neuron 「影响因…

Springboot+Vue项目-基于Java+MySQL的母婴商城系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

Ubuntu去除烦人的顶部【活动】按钮

文章目录 一、需求说明二、打开 extensions 网站三、安装 GNOME Shell 插件四、安装本地连接器五、安装 Hide Activities Button 插件六、最终效果七、卸载本地连接器命令参考 本文所使用的 Ubuntu 系统版本是 Ubuntu 22.04 ! 一、需求说明 使用 Ubuntu 的过程中&#xff0c;屏…

【大语言模型】应用:10分钟实现搜索引擎

本文利用20Newsgroup这个数据集作为Corpus(语料库)&#xff0c;用户可以通过搜索关键字来进行查询关联度最高的News&#xff0c;实现对文本的搜索引擎&#xff1a; 1. 导入数据集 from sklearn.datasets import fetch_20newsgroupsnewsgroups fetch_20newsgroups()print(fNu…

在Linux驱动中,如何确保中断上下文的正确保存和恢复?

大家好&#xff0c;今天给大家介绍在Linux驱动中&#xff0c;如何确保中断上下文的正确保存和恢复&#xff1f;&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 在Linux驱动中&am…

AI图书推荐:如何在课堂上使用ChatGPT 进行教育

ChatGPT是一款强大的新型人工智能&#xff0c;已向公众免费开放。现在&#xff0c;各级别的教师、教授和指导员都能利用这款革命性新技术的力量来提升教育体验。 本书提供了一个易于理解的ChatGPT解释&#xff0c;并且更重要的是&#xff0c;详述了如何在课堂上以多种不同方式…

STM32利用软件I2C通讯读MPU6050的ID号

今天的读ID号是建立在上篇文章中有了底层的I2C通讯的6个基本时序来编写的。首先需要完成的就是MPU6050的初始化函数 然后就是编写 指定地址写函数&#xff1a; 一&#xff1a;开始 二&#xff1a;发送 从机地址读写位&#xff08;1&#xff1a;读 0&#xff1…

MySQL之索引失效、覆盖、前缀索引及单列、联合索引详细总结

索引失效 最左前缀法则 如果索引了多列(联合索引)&#xff0c;要遵守最左前缀法则&#xff0c;最左前缀法则指的是查询从索引的最左列开始&#xff0c;并且不跳过索引中的列。如果跳跃某一列&#xff0c;索引将部分失效&#xff08;后面的字段索引失效&#xff09;。 联合索…

第1章 计算机网络体系结构

王道学习 【考纲内容】 &#xff08;一&#xff09;计算机网络概述 计算机网络的概念、组成与功能&#xff1b;计算机网络的分类&#xff1b; 计算机网络的性能指标 &#xff08;二&#xff09;计算机网络体系结构与参考模型 计算机网络分层结…

权威Scrum敏捷开发企业级实训/敏捷开发培训课程

课程简介 Scrum是目前运用最为广泛的敏捷开发方法&#xff0c;是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程&#xff0c;面向研发管理者、项目经理、产品经理、研发团队等&#xff0c;旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮助企业快速启动敏…

Ansys Workbench拓扑优化教程

很基础。 前言 进行拓扑优化的好处在于可以简化结构&#xff0c;满足力学性能的同时简化结构。 如赵州桥的一大一小的拱&#xff0c;就可以用拓扑优化优化出来&#xff0c;可见一千四百多年以前古人的智慧是多么丰富。 步骤 大体的步骤是需要 1.先导入模型&#xff08;需…