如何训练一个大模型:LoRA篇

目录

写在前面

一、LoRA算法原理

1.设计思想

2.具体实现

二、peft库

三、完整的训练代码

四、总结


写在前面

        现在有很多开源的大模型,他们一般都是通用的,这就意味着这些开源大模型在特定任务上可能力不从心。为了适应我们的下游任务,就需要对预训练模型进行微调。

        全参数微调有两个问题:在新的数据集上训练,会破坏大模型原来的能力,使其泛化能力急剧下降;而且现在的模型参数动辄几十亿上百亿,要执行全参数微调的话,他贵啊!!

        于是LoRA出现了, LoRA(Low-Rank Adaptation)是微软提出的一种参数有效的微调方法,可以降低微调占用的显存以及更轻量化的迁移。同时解决了上述两个问题,那它凭什么这么厉害?往下看吧。

一、LoRA算法原理

1.设计思想

        论文地址:https://arxiv.org/pdf/2106.09685

        模型是过参数化的,它们有更小的内在维度,模型主要依赖于这个低的内在维度(low intrinsic dimension)去做任务适配。假设模型在适配任务时参数的改变量是低秩的,由此引出低秩自适应方法lora,通过低秩分解来模拟参数的改变量,从而以极小的参数量来实现大模型的间接训练。

       上面那段话也许有点难以理解。简单来讲,LoRA是大模型的低秩适配器,或者就简单的理解为适配器,在图像生成中可以将lora理解为某种图像风格(比如SD社区中的各种漂亮妹子的lora,可插拔式应用,甚至组合式应用实现风格的融合)的适配器,在NLP中可以将其理解为某个任务的适配器(比如基于通用大模型训练的各个领域的专家大模型)。

2.具体实现

        LoRA的实现方式是在基础模型的线性变换模块(全连接、Embedding、卷积)旁边增加一个旁路,这个旁路是由两个小矩阵做内积得来的,两个小矩阵的中间维度,就是秩!!

        通过低秩分解(先降维再升维)来模拟参数的更新量。

        下面是LoRA的公式:

h = W_0x +\Delta Wx = W_0x + ((A \bigotimes B) * \alpha / r)x

       上面公式中x是这一层的输入,h是这一层的输出,W_0是基础模型的权重参数;A和B是两个小矩阵,A的输入和B的输出形状跟W_0一样,A的输出和B的输入一样,称为秩,秩一般很小,微调的所有“新知识”都保存在A和B里面\alpha /r是一个缩放系数,这个数越大,LoRA权重的影响就越大。

        下面就是经典的LoRA运算流程图:

        我们以ChatGLM的attention模块的query_key_value(是一个linear(4096, 12288))为例,描述一下流程,其中输入4096、输出12288,LoRA的秩是8:

        初始化时,lora_A采用高斯分布初始化,lora_B初始化为全0,保证训练开始时旁路为0矩阵;        

        训练时,原模型固定,只训练降维矩阵A和升维矩阵B;

        推理时需要做参数合并,就是将AB的内积(一个与基础模型形状一样的低秩矩阵)加到原参数上,这样不引入额外的推理延迟。对于上图的例子,lora_A与lora_B做内积,得到4096x1228的参数矩阵,然后与基础模型W相加就可以了。

        我们来算算需要训练多少参数,如果是全参数需要训练4096*12288=50331648个参数,LoRA需要训练4096*8+8*12288=131072,参数可是数量级的减少啊。

二、peft库

        Pytorch中peft库实现了LoRA算法,而且使用非常方便,我们以ChatGLM代码为例,看一下LoRA对ChatGLM模型做了什么,直接上代码:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel, HfArgumentParser, TrainingArguments

from finetune import CastOutputToFloat, FinetuneArguments


def count_params(model):
    for name, param in model.named_parameters():
        print(name, param.shape)



def make_peft_model():
    # 初始化原模型
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", load_in_8bit=False, trust_remote_code=True, device_map="auto", local_files_only=True
    ).float()
    

    # 给原模型施加LoRA
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=['query_key_value'],
    )
    model = get_peft_model(model, peft_config).float()
    count_params(model)



if __name__ == '__main__':
    make_peft_model()

        输出如下:       

base_model.model.transformer.word_embeddings.weight torch.Size([130528, 4096])
base_model.model.transformer.layers.0.input_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.0.input_layernorm.bias torch.Size([4096])
base_model.model.transformer.layers.0.attention.query_key_value.base_layer.weight torch.Size([12288, 4096])
base_model.model.transformer.layers.0.attention.query_key_value.base_layer.bias torch.Size([12288])
base_model.model.transformer.layers.0.attention.query_key_value.lora_A.default.weight torch.Size([8, 4096])
base_model.model.transformer.layers.0.attention.query_key_value.lora_B.default.weight torch.Size([12288, 8])

base_model.model.transformer.layers.0.attention.dense.weight torch.Size([4096, 4096])
base_model.model.transformer.layers.0.attention.dense.bias torch.Size([4096])
base_model.model.transformer.layers.0.post_attention_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.0.post_attention_layernorm.bias torch.Size([4096])
base_model.model.transformer.layers.0.mlp.dense_h_to_4h.weight torch.Size([16384, 4096])
base_model.model.transformer.layers.0.mlp.dense_h_to_4h.bias torch.Size([16384])
base_model.model.transformer.layers.0.mlp.dense_4h_to_h.weight torch.Size([4096, 16384])
base_model.model.transformer.layers.0.mlp.dense_4h_to_h.bias torch.Size([4096])
base_model.model.transformer.layers.1.input_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.1.input_layernorm.bias torch.Size([4096])

......

        可以看到模型中被添加了LoRA模块(红色部分),是根据全连接“query_key_value”生成的。因为query_key_value层输入是4096,输出是12288,而配置中LoRA的秩是8,所以两个LoRA块是(8,4096)和(12288, 8)

        代码也很好理解,get_peft_model方法将原模型参数冻结并且根据配置向模型中添加LoRA模块。

        解释一下配置LoraConfig,下面是这个对象的主要参数:

 1.task_type:

        SEQ_CLS:序列分类(Sequence Classification)任务。这种任务涉及对输入序列整体进行分类,例如情感分析、文本分类等。

        SEQ_2_SEQ_LM:序列到序列语言建模(Sequence-to-Sequence Language Modeling)任务。这种任务能够将一个输入序列映射到另一个输出序列,例如机器翻译、文本摘要等。

        CAUSAL_LM:因果语言建模(Causal Language Modeling)任务。这种任务涉及训练一个模型,使其能够预测给定先前上下文的下一个标记,例如自动补全、语言生成等。

        TOKEN_CLS:标记分类(Token Classification)任务。这种任务涉及对输入序列中的每个标记进行分类,例如命名实体识别、词性标注等。

        QUESTION_ANS:问答(Question Answering)任务。这种任务涉及根据给定的问题和相关的上下文文本来预测答案。输入是Prompt+问题。

        FEATURE_EXTRACTION:特征提取(Feature Extraction)任务。这种任务涉及从文本或序列中提取有用的特征,以供其他任务或模型使用。

2.r:LoRA秩的维度,这数越大,微调带来的“影响”越强,但是需要训练的参数量会增加。

3.lora_alpha:LoRA在前向传播的过程中引入一个额外的扩展系数(scaling coefficient),用于将LoRA权重应用于预训练权重。这个数越大,LoRA权重的影响就越大。

4.target_modules:要施加LoRA的模块名称,需要注意的是,参数是字符串数组,模块类型必须是`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`中的一个。比如这个例子中还可以填写"word_embeddings"和"dense"。

三、完整的训练代码

        现在给出一个完整的基于LoRA的ChatGLM训练代码,peft库在原模型基础上添加LoRA非常方便,对代码的侵入也很小。下面的代码我添加了注释,流程还是很清楚的:

from transformers.integrations import TensorBoardCallback
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainingArguments
from transformers import Trainer, HfArgumentParser
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType
from dataclasses import dataclass, field
import datasets
import os


tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)


@dataclass
class FinetuneArguments:
    dataset_path: str = field(default="data/alpaca")
    model_path: str = field(default="output")
    lora_rank: int = field(default=8)


class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)


def data_collator(features: list) -> dict:
    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids)
    input_ids = []
    labels_list = []
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        labels = (
            [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
        )
        ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
        _ids = torch.LongTensor(ids)
        labels_list.append(torch.LongTensor(labels))
        input_ids.append(_ids)
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
    }



class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

    def save_model(self, output_dir=None, _internal_call=False):
        self.model.save_pretrained(output_dir)


def main():
    writer = SummaryWriter()
    # 组织训练参数
    finetune_args, training_args = HfArgumentParser(
        (FinetuneArguments, TrainingArguments)
    ).parse_args_into_dataclasses()

    # init model
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", load_in_8bit=False, trust_remote_code=True, device_map="auto", local_files_only=True
    ).float()
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    # 模型是可以并行化的。
    model.is_parallelizable = True
    # 启用模型的并行化。
    model.model_parallel = True
    # 将模型的 lm_head(语言模型头)的输出转换为浮点数类型。
    model.lm_head = CastOutputToFloat(model.lm_head)
    # 禁用模型配置中的缓存,用于禁止缓存中间结果,可以减少显存占用,但是训练时间会变长
    model.config.use_cache = (
        False  # silence the warnings. Please re-enable for inference!
    )

    # LoRA配置
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=finetune_args.lora_rank,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    # 对模型使用LoRA
    model = get_peft_model(model, peft_config).float()

    # 使用alpaca数据集
    dataset = datasets.load_from_disk(finetune_args.dataset_path)
    print(f"\n{len(dataset)=}\n")

    # for d in dataset.iter(batch_size=1):
    #     print("d:", d)

    # start train
    trainer = ModifiedTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        callbacks=[TensorBoardCallback(writer)],
        data_collator=data_collator,
    )
    trainer.train()
    writer.close()
    # 存训练后的参数
    model.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()

        训练之后模型文件会保存在output_dir目录中。到这里我们发现一个问题,毕竟LoRA在原模型的基础上加了分支,这会带来推理效率的降低,其实我们调用merge_and_unload方法就能将LoRA的分支模块合并到基础模型,推理代码如下:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel, AutoModelForSeq2SeqLM
import torch
from transformers import AutoTokenizer

# 加载基础模型
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# 配置LoRA
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=True,
    target_modules=['query_key_value'],
    r=8, lora_alpha=32, lora_dropout=0.1
)
# 对模型使用LoRA
model = get_peft_model(model, peft_config).half()
# 加载LoRA参数
model.load_state_dict(torch.load("output/checkpoint-1000/adapter_model.bin", map_location=torch.device("cuda")), strict=False)
# 将LoRA的分支模块合并到基础模型
model.merge_and_unload()

while True:
    prompt = input("Prompt: ")
    inputs = tokenizer(prompt, return_tensors="pt")
    model.params_dtype = torch.float32
    response = model.generate(input_ids=inputs["input_ids"],
                              max_length=inputs["input_ids"].shape[-1] + 128)
    response = response[0, inputs["input_ids"].shape[-1]:]
    print("responseL", response)
    for r in response:
        print(r, ":", tokenizer.decode([r], skip_special_tokens=False))
    print("Response:", tokenizer.decode(response, skip_special_tokens=True))

四、总结

1.LoRA的实现方式是在原模型的线性变换模块(全连接、Embedding、卷积)旁边增加一个旁路,通过低秩分解(先降维再升维)来模拟参数的更新量。

2.LoRA模块由两个小矩阵组成,这两个矩阵内积的输入输出形状与原模型一致,大模型需要的“新知识”就存在这个模块中;

3.秩可以很小,有实验表明,就算秩=1,效果也不是很差;

4.尽量多的对模型中的线性变换模块使用秩很小LoRA;而不是对一个模块使用秩很大的LoRA;

5.推理时需要做参数合并,就是将AB的内积加到原参数上,从而不引入额外的推理延迟;

5.LoRA智能一定程度提升模型在某个领域的能力,并不能使模型发生根本性的能力提升。

LoRA就介绍到这里,关注不迷路(#^.^#)

关注订阅号了解更多精品文章

交流探讨、商务合作请加微信

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/619576.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用python写算法——队列笔记

1.队列定义 队列是一种特殊的线性表,它只允许在表的前端进行删除操作,在表的后端进行插入操作,和栈一样,队列是一种操作受限制的线性表。进行插入操作的端称为队尾,进行删除操作的端称为队头。队列中没有元素时&#…

C控制语句:分支和跳转

1.1if语句 //colddays.c --找出0摄氏度以下的天数占总天数的百分比 #include <stdio.h>int main(void) {const int FREEZING 0;float temperature;int cold_days 0;int all_days 0;printf("Enter the list of daily low temperature.\n");printf("Use…

Kexp 动态展示 k8s 资源对象依赖关系

kexp[1] 旨在以可视化的方式帮助用户理解和探索 Kubernetes 的能力。 适用场景&#xff1a; 学习和探索 Kubernetes 的功能。 应用开发&#xff0c;提供每个应用的对象图预设。 控制器和操作器的开发&#xff0c;支持动态对象图。 即将推出类似 Postman 的 Kubernetes API …

程序猿成长之路之数据挖掘篇——距离公式介绍

上一篇介绍了朴素贝叶斯&#xff0c;那么这次讲讲距离公式 什么是距离公式 用自己的话来说距离公式就是判断两个属性&#xff08;参数&#xff09;相似度的度量公式,比如以两点间距离为例&#xff0c;A地经纬度为(110.9802,120.9932)&#xff0c;B地经纬度为(110.9980,120.828…

Java | Leetcode Java题解之第86题分隔链表

题目&#xff1a; 题解&#xff1a; class Solution {public ListNode partition(ListNode head, int x) {ListNode small new ListNode(0);ListNode smallHead small;ListNode large new ListNode(0);ListNode largeHead large;while (head ! null) {if (head.val < x…

PyQt5中的LineEdit单行文本框

文章目录 1. 简介1.1 常用方法&#xff1a;1.2 常用信号&#xff1a; 2. LineEdit常用方法使用案例3. LineEdit常用信号使用案例 1. 简介 在PyQt5中&#xff0c;LineEdit&#xff08;单行文本框&#xff09;是一个常用的组件&#xff0c;它允许用户输入文本。以下是一些LineEd…

【游戏引擎】unity

目录 Unity入门教程&#xff1a;从零到英雄的旅程前言第一步&#xff1a;下载和安装Unity第二步&#xff1a;创建你的第一个Unity项目第三步&#xff1a;熟悉Unity界面第四步&#xff1a;创建一个简单的游戏对象第五步&#xff1a;编写脚本赋予游戏对象生命第六步&#xff1a;运…

华为OD机试【统一限载货物数最小值】(java)(200分)

1、题目描述 火车站附近的货物中转站负责将到站货物运往仓库&#xff0c;小明在中转站负责调度 2K 辆中转车(K辆干货中转车&#xff0c;K 辆湿货中转车)货物由不同供货商从各地发来&#xff0c;各地的货物是依次进站&#xff0c;然后小明按照卸货顺序依次装货到中转车&#xf…

如何解决pycharm在HTML文件中注释快捷键出错的问题(HTML注释规则出错)

文章目录 💢 问题 💢🏡 演示环境 🏡💯 解决方案 💯⚓️ 相关链接 ⚓️💢 问题 💢 你是否在编程时遇到过这样的烦恼?当你正专注地编写HTML代码,想要快速注释掉某部分内容时,却发现PyCharm的注释快捷键失灵了(没有使用正确的注释格式)。这不仅打断了你的工作…

【论文笔记】利用扩散模型DDPM做变化检测change detection

去噪扩散模型DDPM去年开始在各种视觉任务取得惊人的效果&#xff0c;变化检测领域也不例外&#xff0c;本文介绍两篇关于如何使用扩散模型实现变化检测的论文。第一篇做法较为自然&#xff0c;先利用遥感数据预训练DDPM&#xff0c;然后将预训练好的网络当作变化检测任务的特征…

设计模式-结构型-适配器模式-Adapter

地址类 public class Address {public void street() {System.out.println("普通的街道");}public void zip() {System.out.println("普通的邮政编码");}public void city() {System.out.println("普通的城市");} } 荷兰地址类 public class …

用lobehub打造一个永久免费的AI个人助理

Lobe Chat是一个开源的高性能聊天机器人框架&#xff0c;它被设计来帮助用户轻松创建和部署自己的聊天机器人。这个框架支持多种智能功能&#xff0c;比如语音合成&#xff08;就是让机器人能说话&#xff09;&#xff0c;还能理解和处理多种类型的信息&#xff0c;不仅限于文字…

关于USB 3.1电气参数的探讨

目录 0 引言 1 抖动预算 2 时钟恢复-CDR 3 测试码型-PRBS16 4 传输码型-128b/132b 5 眼图模板-Eye Mask 6 发射均衡 7 接收均衡 7.1 CTLE均衡 7.2 DFE均衡

Postman历史版本安装与runner测试

前言 实际上就是笔者本地做demo&#xff0c;postman使用了最新版本&#xff0c;本身也没问题&#xff0c;不过postman不支持不登录做runner测试了&#xff0c;很多功能必须登录账号才能使用&#xff0c;否则只能使用http工具发送的能力&#xff0c;而postman本身就是一个简单工…

栈和队列经典练习题

目录 前言&#xff1a; 一、括号匹配问题 1.题目描述 2.解题思路 3.题目链接 二、用队列实现栈 1.题目描述 2.解题思路 3.题目链接 三、用栈实现队列 1.题目描述 2.题目分析 3.题目链接 四、设计循环队列 1.题目描述 2. 题目分析 3.题目链接 最后 前言&#xff1a; 前…

JCR一区 | Matlab实现1D-2D-GASF-CNN-BiLSTM-MATT的多通道输入数据分类预测

JCR一区 | Matlab实现1D-2D-GASF-CNN-BiLSTM-MATT的多通道输入数据分类预测 目录 JCR一区 | Matlab实现1D-2D-GASF-CNN-BiLSTM-MATT的多通道输入数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 Matlab实现1D-2D-GASF-CNN-BiLSTM-MATT的多通道输入数据分类预…

未授权访问:VNC未授权访问

目录 1、漏洞原理 2、环境搭建 3、未授权访问 防御手段 今天继续学习各种未授权访问的知识和相关的实操实验&#xff0c;一共有好多篇&#xff0c;内容主要是参考先知社区的一位大佬的关于未授权访问的好文章&#xff0c;还有其他大佬总结好的文章&#xff1a; 这里附上大…

修改MTU值解决Linux下运行top命令卡死问题

上周明月的Linux服务器上运行top命令总是莫名的出现卡死现象&#xff0c;甚至是CtrlC都无法终止进程&#xff0c;今天终于抽空找到了解决办法&#xff0c;原来是需要修改Linux的MTU值&#xff0c;将服务器操作系统数据包调小&#xff0c;加上VxLAN数据包小于1500即可。 top命令…

Python-VBA函数之旅-sum函数

目录 一、sum函数的常见应用场景 二、sum函数使用注意事项 三、如何用好sum函数&#xff1f; 1、sum函数&#xff1a; 1-1、Python&#xff1a; 1-2、VBA&#xff1a; 2、推荐阅读&#xff1a; 个人主页&#xff1a; https://myelsa1024.blog.csdn.net/ 一、sum函数的常…

摩苏尔大坝形变监测

摩苏尔大坝&#xff0c;是伊拉克最大的大坝。它位于底格里斯河35公里&#xff0c;北距摩苏尔市&#xff0c;这是一座粘土质地的水坝&#xff0c;高113米&#xff0c;长3.2公里&#xff0c;于1986落成。 大坝建成后不久&#xff0c;大坝就遇到了由软石膏地基造成的一些结构性问题…