[大模型]MiniCPM-2B-chat Lora Full 微调

MiniCPM-2B-chat 介绍

MiniCPM 是面壁智能与清华大学自然语言处理实验室共同开源的系列端侧大模型,主体语言模型 MiniCPM-2B 仅有 24亿(2.4B)的非词嵌入参数量。

经过 SFT 后,MiniCPM 在公开综合性评测集上,MiniCPM 与 Mistral-7B相近(中文、数学、代码能力更优),整体性能超越 Llama2-13B、MPT-30B、Falcon-40B 等模型。
经过 DPO 后,MiniCPM 在当前最接近用户体感的评测集 MTBench上,MiniCPM-2B 也超越了 Llama2-70B-Chat、Vicuna-33B、Mistral-7B-Instruct-v0.1、Zephyr-7B-alpha 等众多代表性开源大模型。
以 MiniCPM-2B 为基础构建端侧多模态大模型 MiniCPM-V,整体性能在同规模模型中实现最佳,超越基于 Phi-2 构建的现有多模态大模型,在部分评测集上达到与 9.6B Qwen-VL-Chat 相当甚至更好的性能。
经过 Int4 量化后,MiniCPM 可在手机上进行部署推理,流式输出速度略高于人类说话速度。MiniCPM-V 也直接跑通了多模态大模型在手机上的部署。
一张1080/2080可高效参数微调,一张3090/4090可全参数微调,一台机器可持续训练 MiniCPM,二次开发成本较低。

环境准备

在autodl平台中租一个单卡3090等24G显存的显卡机器,如下图所示镜像选择PyTorch–>2.1.0–>3.10(ubuntu22.04)–>12.1
接下来打开刚刚租用服务器的JupyterLab, 图像 并且打开其中的终端开始环境配置、模型下载和运行演示。

注意:这里要选择一个 cpuintel 的机器,amdcpu 有可能会导致 deepspeed zero2 offload加载失败。

在这里插入图片描述

接下来打开刚刚租用服务器的JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行demo

pip换源和安装依赖包

# 升级pip
python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

pip install modelscope transformers sentencepiece accelerate langchain

MAX_JOBS=8 pip install flash-attn --no-build-isolation

pip install peft deepspeed

注意:flash-attn 安装会比较慢,大概需要十几分钟。

模型下载

使用 modelscope 中的snapshot_download函数下载模型,第一个参数为模型名称,参数cache_dir为模型的下载路径。

/root/autodl-tmp 路径下新建 download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/download.py执行下载,模型大小为 10 GB,下载模型大概需要 5~10 分钟

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
model_dir = snapshot_download('OpenBMB/MiniCPM-2B-sft-fp32', cache_dir='/root/autodl-tmp', revision='master')

数据集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
    "instrution":"回答以下用户问题,仅输出答案。",
    "input":"1+1等于几?",
    "output":"2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
    "instruction": "现在你要扮演皇帝身边的女人--甄嬛",
    "input":"你是谁?",
    "output":"家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录的dataset下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 512    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<用户>{example['instruction']+example['input']}<AI>", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

MiniCPM 所使用的 Prompt_Templatetransformers默认的Prompt_Template,感兴趣的同学可以在transformers仓库查看源码构造哦~

加载模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('./OpenBMB/miniCPM-bf32', use_fast=False, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained('./OpenBMB/miniCPM-bf32', trust_remote_code=True, torch_dtype=torch.half, device_map="auto")

定义LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

注意:如果要进行全量微调的话,可以在加载模型的时候选择不加载LoraConfig

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch

等等,还有很多参数可以调整,感兴趣的同学可以去看看transformers的源码哦,也欢迎各位同学来给本项目提交PR哦~

使用deepspeed训练

deepspeed 是一个分布式训练的工具,可以很方便的进行分布式训练,这里我们使用deepspeed进行训练。

  • 如果你要进行lora训练,那请在train.py中将67行代码的注释取消掉,然后运行train.sh脚本即可。

lora 训练大概需要14G显存左右,因为没有开梯度检查所以显存占用会多一点。

# 创建模型并以半精度形式加载
model = AutoModelForCausalLM.from_pretrained(finetune_args.model_path, trust_remote_code=True, torch_dtype=torch.half, device_map={"": int(os.environ.get("LOCAL_RANK") or 0)})

model = get_peft_model(model, config)
  • 全量训练直接运行train.sh脚本即可。

全量训练,大概需要22G显存左右,因为开启了deepspeed zero2 的 cpu offload 功能,优化器参数会加载到cpu上进行计算,所以降低了显存。

完整train.py源码

from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, HfArgumentParser, Trainer
import os
import torch
from peft import LoraConfig, TaskType, get_peft_model
from dataclasses import dataclass, field
import deepspeed
deepspeed.ops.op_builder.CPUAdamBuilder().load()

@dataclass
class FinetuneArguments:
    # 微调参数
    # field:dataclass 函数,用于指定变量初始化
    model_path: str = field(default="./OpenBMB/miniCPM-bf32")

# 用于处理数据集的函数
def process_func(example):
    MAX_LENGTH = 512    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<用户>{example['instruction']+example['input']}<AI>", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

 # loraConfig
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_proj", "v_proj"],  # 这个不同的模型需要设置不同的参数,需要看模型中的attention层
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)


if "__main__" == __name__:
    # 解析参数
    # Parse 命令行参数
    finetune_args, training_args = HfArgumentParser(
        (FinetuneArguments, TrainingArguments)
    ).parse_args_into_dataclasses()

    # 处理数据集
    # 将JSON文件转换为CSV文件
    df = pd.read_json('./huanhuan.json')
    ds = Dataset.from_pandas(df)
    # 加载tokenizer
    tokenizer = AutoTokenizer.from_pretrained(finetune_args.model_path, use_fast=False, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # 将数据集变化为token形式
    tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

    # 创建模型并以半精度形式加载
    model = AutoModelForCausalLM.from_pretrained(finetune_args.model_path, trust_remote_code=True, torch_dtype=torch.half, device_map={"": int(os.environ.get("LOCAL_RANK") or 0)})
    # model = get_peft_model(model, config)
    # 使用trainer训练
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_id,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        )
    trainer.train() # 开始训练
    trainer.save_model() # 保存模型

train.sh

num_gpus=1

deepspeed --num_gpus $num_gpus train.py \
    --deepspeed ./ds_config.json \
    --output_dir="./output/MiniCPM" \
    --per_device_train_batch_size=4 \
    --gradient_accumulation_steps=4 \
    --logging_steps=10 \
    --num_train_epochs=3 \
    --save_steps=500 \
    --learning_rate=1e-4 \
    --save_on_each_node=True \

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/703958.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发文件管理:【@ohos.securityLabel (数据标签)】

数据标签 该模块提供文件数据安全等级的相关功能&#xff1a;向应用程序提供查询、设置文件数据安全等级的JS接口。 说明&#xff1a; 本模块首批接口从API version 9开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 导入模块 import security…

pycharm终端pip安装模块成功但还是显示找不到 ModuleNotFoundError: No module named

报错信息&#xff1a; ModuleNotFoundError: No module named 但是分明已经安装过此模块&#xff1a; 在cmd运行pip list 查看所有安装过的包找到了安装过&#xff1a; 如果重新安装就是这样&#xff1a;显示已经存在了 问题排查&#xff1a; 直接根据重新安装的显示已存在的…

用 KV 缓存量化解锁长文本生成

很高兴和大家分享 Hugging Face 的一项新功能: KV 缓存量化 &#xff0c;它能够把你的语言模型的速度提升到一个新水平。 太长不看版: KV 缓存量化可在最小化对生成质量的影响的条件下&#xff0c;减少 LLM 在长文本生成场景下的内存使用量&#xff0c;从而在内存效率和生成速度…

禁渔期水域监管:EasyCVR视频智能监控方案

一、背景与需求分析 根据农业部印发的《中国渔政亮剑2024系列专项执法行动方案》&#xff0c;我国将持续推进长江十年禁渔、海洋伏季休渔、黄河等内陆重点水域禁渔等专项行动。根据四川省相关规定&#xff0c;每年3月1日至6月30日为禁渔期&#xff0c;在此期间&#xff0c;四川…

C语言 | Leetcode C语言题解之第148题排序链表

题目&#xff1a; 题解&#xff1a; struct ListNode* merge(struct ListNode* head1, struct ListNode* head2) {struct ListNode* dummyHead malloc(sizeof(struct ListNode));dummyHead->val 0;struct ListNode *temp dummyHead, *temp1 head1, *temp2 head2;while…

等级考试3-2021年12月题目

十二月&#xff1a; #include <iostream> using namespace std; int sum(int); int main(){int n;cin>>n;for(int i1;i<100000;i){for(int j1;j<i;j){if(sum(i)-j*2n){cout<<j<< <<i;return 0; }}}return 0; } int sum(int n){int sum0;fo…

大数据学习——linux操作系统(Centos)安装mysql(Hive的元数据库)

一. 准备工作 1. 打开虚拟机并连接shell工具 2. 将mysql安装包上传至虚拟机 mysql安装包 提取码&#xff1a;6666 将下载好的jar包拖至install_package目录下 3. 检查环境 rpm -qa|grep mariadb 如果上述命令返回有结果 那么进行mariadb的卸载 rpm -e --nodeps mariadb-l…

【前端项目笔记】1 登录与登出功能实现

项目笔记 ☆☆代表面试常见题 前后端分离&#xff1a;后端负责写接口&#xff0c;前端负责调接口。 登录/退出功能 登录业务流程 登录页面&#xff1a;用户名密码 调用后台接口进行验证 通过验证&#xff0c;根据后台响应状态跳到项目主页 登录业务相关技术点&#xff1…

计算机网络 —— 运输层(四次挥手)

计算机网络 —— 运输层&#xff08;四次挥手&#xff09; 四次挥手客户端到服务器的关闭第一次挥手第二次挥手 服务器到客户端连接的关闭第三次挥手第四次挥手 等待时间的必要性 我们今天来看TCP协议的四次挥手&#xff1a; 四次挥手 TCP的四次挥手&#xff08;Four-Way Han…

BigDecimal-解决java中的浮点运算

《阿里巴巴 Java 开发手册》中提到&#xff1a;“为了避免精度丢失&#xff0c;可以使用 BigDecimal 来进行浮点数的运算”。浮点数的运算竟然还会有精度丢失的风险吗&#xff1f;确实会&#xff01; 示例代码&#xff1a; float a 2.0f - 1.9f; float b 1.8f - 1.7f; Syst…

【沟通管理】项目经理《葵花宝典》之跨部门沟通

为什么每次跟其它部门的沟通总是不欢而散&#xff1f; 为什么每次想好好的就事论事的时候&#xff0c;却总是像在吵架&#xff1f; 为什么沟通总是不同频&#xff1f; 这是不是你作为项目经理在跨部门沟通时经常会遇到的问题&#xff1f; 在企业项目管理中&#xff0c;跨部门沟…

【SpringCloud学习笔记】RabbitMQ(上)

1. RabbitMQ简介 官网地址&#xff1a;https://www.rabbitmq.com/ 2. 安装方式 安装前置准备&#xff1a; 此处基于Linux平台 Docker进行安装&#xff0c;前置准备如下&#xff1a; Linux云服务器 / 虚拟机Docker环境 安装命令&#xff1a; docker run \-e RABBITMQ_DEFAU…

ControlNet作者新作Omost 一句话将LLM的编码能力转化为图像生成能力,秒变构图小作文,再也不用为不会写提示词担心了!

近日&#xff0c;ControlNet的作者推出了一个全新的项目—Omost。Omost是一个将LLM的编码能力转化为图像生成能力的项目。对现有图像模型的提示词理解有着巨大的帮助。通过很短的提示词&#xff0c;就可以生成非常详细并且空间表现很准确的图片。 完美解决新手小白不会写提示词…

数据结构01 栈及其相关问题讲解

栈是一种线性数据结构&#xff0c;栈的特征是数据的插入和删除只能通过一端来实现&#xff0c;这一端称为“栈顶”&#xff0c;相应的另一端称为“栈底”。 栈及其特点 用一个简单的例子来说&#xff0c;栈就像一个放乒乓球的圆筒&#xff0c;底部是封住的&#xff0c;如果你想…

QField测量功能

QField提供开箱即用的测量功能&#xff0c;可以灵活更改工程中测量距离和面积的单位。您可以在 "常规" 部分导航到 "工程" 菜单&#xff0c;并选择 "工程属性..." 完成此操作。 要启用测量工具&#xff0c;请打开主菜单并选择 测量工具 。 启…

JS 描述二叉树(含二叉树的前、中、后序遍历)

JS 对象描述二叉树 const binaryTreeNode {value: A,left: {value: B,left: {value: C,right: {value: G}},right: {value: D}},right: {value: E,left: {value: F,left: { value: H },right: { value: L }}} }前、中、后序遍历结果 前序遍历&#xff08;中在前&#xff09;&a…

网络的下一次迭代:AVS 将为 Web2 带去 Web3 的信任机制

撰文&#xff1a;Sumanth Neppalli&#xff0c;Polygon Ventures 编译&#xff1a;Yangz&#xff0c;Techub News 本文来源香港Web3媒体&#xff1a;Techub News AVS &#xff08;主动验证服务&#xff09;将 Web2 的规模与 Web3 的信任机制相融合&#xff0c;开启了网络的下…

vue中使用emit

1. vue中使用emit 1.1. 在子组件中触发事件 1.1.1. 子组件示例 (ChildComponent.vue) 1.2. 在父组件中监听事件 1.2.1. 父组件示例 (ParentComponent.vue) vue3中使用emit 1.3. 使用 setup 函数和 defineEmits 1.3.1. 子组件示例 (ChildComponent.vue)1.3.2. 父组件示例 (Pare…

Node.js进阶——数据库

文章目录 一、步骤1、安装操作 MySQL数据库的第三方模块(mysql)2、通过 mysql 模块连接到 MySQL 数据库3、测试 二、操作 mysql 数据库1、查询语句2、插入语句3、插入语句快捷方式4、更新数据5、更新语句快捷方式6、删除数据7、标记删除 二、前后端的身份认证1、web开发模式1&a…

AIRNet模型使用与代码分析(All-In-One Image Restoration Network)

AIRNet提出了一种较为简易的pipeline&#xff0c;以单一网络结构应对多种任务需求&#xff08;不同类型&#xff0c;不同程度&#xff09;。但在效果上看&#xff0c;ALL-In-One是不如One-By-One的&#xff0c;且本文方法的亮点是batch内选择patch进行对比学习。在与sota对比上…