LoRA:大模型的低阶自适用(使用BERT在IMDB数据集上运用LoRA微调)

文章目录

  • 简介
  • LoRA文章主要贡献
  • LoRA技术模型图
  • 技术细节
  • 论文实验结果
  • LoRA在bert的运用
    • LoRA核心代码
    • 实战分析

简介

论文链接https://arxiv.org/pdf/2106.09685v2.pdf

本文将先介绍论文中的LoRA技术,然后以BERT为例在IMDB数据集上代码实现运用这项微调技术。

代码+数据

LoRA文章主要贡献

文章的主要贡献是提出了一种名为LoRA(Low-Rank Adaptation)的方法,用于在不牺牲模型质量的前提下,高效地对大型预训练语言模型进行微调。LoRA的核心思想是在Transformer架构的每一层注入可训练的低秩分解矩阵,同时冻结预训练模型权重,从而大幅减少下游任务中的可训练参数数量。

具体来说,LoRA的主要贡献包括:

高效的参数更新:LoRA通过低秩矩阵更新模型权重,而不是对整个模型进行微调。这种方法大幅减少了所需的训练参数数量和GPU内存需求。例如,与GPT-3 175B模型的全参数微调相比,LoRA可以将可训练参数减少10,000倍,GPU内存需求减少3倍。

保持模型质量:尽管LoRA使用的可训练参数远少于全参数微调,但它在多个模型(如RoBERTa、DeBERTa、GPT-2和GPT-3)上的表现与全参数微调相当或更好。

提高训练效率:LoRA降低了硬件门槛,因为它不需要计算大多数参数的梯度或维护优化器状态。此外,LoRA的设计允许在部署时将训练的矩阵与冻结的权重合并,从而不会引入额外的推理延迟。

实证研究:文章提供了关于语言模型适应性中秩不足性的实证研究,这有助于解释LoRA的有效性。

总的来说,LoRA提出了一种创新的方法来解决大型语言模型在特定任务上的适应问题,同时保持了模型的性能,降低了资源消耗,并提高了操作效率。这对于需要在资源受限的环境中部署和使用大型模型的应用场景尤为重要。

LoRA技术模型图

Description
正所谓大智若愚,LoRA这项技术的模型图就是这么简洁明了,x表示数据输入,左边表示预训练大模型的参数(冻结),右边表示两个低秩矩阵(训练),当大模型微调的时候,不再是全参数微调,而是仅仅微调右边的低秩矩阵。

这样一来,就能大大减少我们微调时候的工作量和需要的资源,并且使用这种方法微调模型的性能和全参数微调差不多,从而实现四两拨千斤的效果。

技术细节

假设预训练模型要进行常规全参数微调
Description
其中Φ表示模型的参数,x表示输入,y表示输出
Description
表示进行微调任务的数据集

此时我们需要调整的参数就是全参数:
Description
如果是175B的模型,微调一个下游任务的模型,每次都要调整这么多参数,工作量巨大。

但是使用LoRA技术的话
Description
预训练模型的参数都冻结,不调整

只是额外加一组小小的参数

也能做到和下游任务适配
Description
而此时需要调整的参数远远小于预训练模型的参数
Description
也就是说此时需要调整的参数很小。

文章主要聚焦于将LoRA在transformer注意力机制上进行使用,因为这也是transformer的精髓
Description
Description
分别用于表示四个线性层的参数。

Description
用于表示预训练模型的参数

Description
是自适应过程中的累积梯度更新

r就是低秩矩阵的秩

例如我们在

W 0上加个LoRA
Description
假设 W 0为512*512
就单单只看这部分的话

全参数微调需要调整512*512 = 262144个参数

使用LoRA后,这262144个参数就冻结了

此时增加两个低秩矩阵 例如5122和2512

那么此时需要调整的参数大小就为5122+2512 = 2048个参数

2048 / 262144 = 0.0078125

此时要训练的参数就减少了许多

而且,当我们面对不同的下游任务时,因为原本的预训练模型是冻结的,所以预训练模型用一个就行,只需要保存的参数就是加入的低秩矩阵,这样的话,也能节省大量的存储空间。

可以看个伪代码:

class LowRankMatrix(nn.Module):

    def __init__(self, weight_matrix, rank, alpha=1.0):

        super(LowRankMatrix, self).__init__()

        self.weight_matrix = weight_matrix

        self.rank = rank

        self.alpha = alpha / rank  # 将缩放因子与秩相关联

        # 初始化低秩矩阵A和B

        self.A = nn.Parameter(torch.randn(weight_matrix.size(0), rank), requires_grad=True)

        self.B = nn.Parameter(torch.randn(rank, weight_matrix.size(1)), requires_grad=True)



    def forward(self, x):

        # 计算低秩矩阵的乘积并添加到原始权重上

        # 应用缩放因子

        updated_weight = self.weight_matrix + self.alpha * torch.mm(self.B.t(), self.A)

        return updated_weight

α和r用于缩放矩阵,帮助更好的训练

A矩阵使用随机高斯初始化

B矩阵初始化为0

论文实验结果

Description
LoRA相较于Adapter不会增加推理的时间。

Description
Description
Description
LoRA效果好

Description
LoRA一起用到Wq和Wv效果比较好

Description
低秩已足够

LoRA在bert的运用

这里主要以bert-base-uncased为例来实现LoRA微调技术的运用。
bert-base-uncased的参数量为110M也就是1.1亿个参数

LoRA核心代码

主要使用文章提出的开源loralib来对bert的注意力机制线性层进行LoRA层的增加

def get_lora_bert_model(model, r=8, lora_layer=["q", 'k', 'v', 'o']):
    encoder_layers = list(model.encoder.layer)
    for layer_index, encoder_layer in enumerate(encoder_layers):
        # 访问多头自注意力层
        attention = encoder_layer.attention
        # 获取Q、K、V线性层
        q_linear = attention.self.query
        k_linear = attention.self.key
        v_linear = attention.self.value
        # 获取O线性层(实际上,O是V经过加权求和后的结果,通常不单独存储)
        o_linear = attention.output.dense

        for l in lora_layer:
            if l == 'q':
                new_q_proj = lora.Linear(q_linear.in_features, q_linear.out_features, r=r)
                model.encoder.layer[layer_index].attention.self.query = new_q_proj
            elif l == 'k':
                new_k_proj = lora.Linear(k_linear.in_features, k_linear.out_features, r=r)
                model.encoder.layer[layer_index].attention.self.key = new_k_proj
            elif l == 'v':
                new_v_proj = lora.Linear(v_linear.in_features, v_linear.out_features, r=r)
                model.encoder.layer[layer_index].attention.self.value = new_v_proj
            elif l == 'o':
                new_o_proj = lora.Linear(o_linear.in_features, o_linear.out_features, r=r)
                model.encoder.layer[layer_index].attention.output.dense = new_o_proj
    return model

可以看到对每层注意注意力机制层的q k v o的线性层都添加了LoRA层

def mark_only_LLM_lora_as_trainable(model: nn.Module, bias: str = 'none', LLM_name: str = 'default_value') -> None:
    if LLM_name == 'default_value':
        for n, p in model.named_parameters():
            if 'lora_' not in n:
                p.requires_grad = False
        if bias == 'none':
            return
        elif bias == 'all':
            for n, p in model.named_parameters():
                if 'bias' in n:
                    p.requires_grad = True
        elif bias == 'lora_only':
            for m in model.modules():
                if isinstance(m, LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                        m.bias.requires_grad = True
        else:
            raise NotImplementedError
    else:
        for n, p in model.named_parameters():
            if 'lora_' not in n and LLM_name in n:
              # and "bert.pooler" not in 
                p.requires_grad = False
        if bias == 'none':
            return
        elif bias == 'all':
            for n, p in model.named_parameters():
                if 'bias' in n:
                    p.requires_grad = True
        elif bias == 'lora_only':
            for m in model.modules():
                if isinstance(m, LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                        m.bias.requires_grad = True
        else:
            raise NotImplementedError

添加LoRA层后,每次训练模型的时候,就只需要训练bert加入的LoRA层,此时我们就需要用到mark_only_LLM_lora_as_trainable()来帮助我们实现,考虑到可能我们基于bert的分类模型可能还会涉及到我们自己加入的某些结构,这些部分是需要进行训练的,所以对于这种情况就这么来使用:

mark_only_LLM_lora_as_trainable(model, LLM_name='bert')

实战分析

本文采用IMDB影评情感分析数据集测试训练集各25000条来进行实验。

因为bert才1.1B,可能在bert上使用这个东西有点小题大做了,但是一屋不扫何以扫天下,现在的大模型架构基本都是基于transformer架构的(bert可以说是第一个),其实本质上都是差不多的,只不过我感觉可能更大一些的模型LoRA的效果会更加显著,模型越大,这个方法的优越性就会越强。

之前对bert全参数微调的准确率是93%,而使用LoRA微调技术得出的结果大约是86%左右,确实有一定的差距,我个人感觉可能是因为模型不够大,只有1.1B,因为低秩势必导致信息的损失,只有当你的模型够大的时候,这些损失才能够忽略不计。

但是使用LoRA技术,对于训练速度、显存占用有了巨大的提升。

首先来看显存占用量(同样是batch_size=64):
Description
这是全参数微调的显存占用。

Description
这是使用LoRA后的显存占用(q k v o都使用,r=8)

可以看到,使用了LoRA后,显存占用少了16G左右,节约了约31.5%的显存使用。

再看看训练速度有什么区别:
Description
这是全参数微调的结果,可以看到准确率确实挺高的,但是训练一个epoch需要4分钟

Description
这是使用LoRA之后的,可以看到除了第一个epoch可能涉及数据加载、GPU预热等情况稍微慢点,其余epoch都是2.5分钟不到就完成了,节约了大概43%的训练时间。

不过准确率也下降了,从93%掉到了86%,准确率大约下降了7.5%。

如果对更大的模型使用LoRA技术,训练时间和显存占用的节省会更多,而性能的下降则会更少,确实是一项很不错的技术。

由此可见,LoRA这项技术确实十分有意义,能够大大降低模型微调的成本,同时不会增加推理的时间延迟,我们可以看到模型评估的时间都是一模一样的。

所以,这项技术其实一定程度上让大模型的门槛降低了一些,让大模型的使用成本大大降低,虽然性能上可能有些损失,但是,至少落地的可能性变大了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/550042.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Day 14 网络协议

常见网络设备:交换机 路由器 中继器 多协议网关(路由器的前身) 交换机:用于连接统一网络的设备,实现内网设备通信。 从广义上分为:局域网交换机,广域网交换机 从网络构成分为:接…

Prompt提示工程上手指南:基础原理及实践-思维树 (ToT)策略下的Prompt

前言 此篇文章已经是本系列的第五篇文章,之前我们已经将检索增强生成(RAG)策略,逐渐我们掌握的知识和技术都在不断提高,对于Prompt的技巧策略也不能只局限于局部运用而要适应LLM大模型的整体框架去进行改进休整。较为主流的LLM模型框架设计基…

通过adb 命令打印安装在第三方模拟器上的log

1,环境:Windows 11 ,第三方模拟器 网易的MuMu 步骤: 1,打开cmd,输入 adb connect 172.0.0.1:7555 2,在cmd,再次输入adb logcat 回车

【Web】陇原战“疫“2021网络安全大赛 题解

目录 CheckIN eaaasyphp EasyJaba CheckIN 拿到附件,贴出关键代码 func getController(c *gin.Context) {cmd : exec.Command("/bin/wget", c.QueryArray("argv")[1:]...)err : cmd.Run()if err ! nil {fmt.Println("error: ", …

【HCIP】OSPF的高级特性

OSPF的高级特性1 --- 不规则区域 一、OSPF不规则区域类型 产生原因:区域划分不合理,导致的问题 1、非骨干区域无法和骨干区域保持连通 2、骨干区域被分割 造成后果:非骨干区域没和骨干区域相连,导致ABR将不会帮忙转发区域间的路由…

element-ui设置弹窗等级最高

通过参数:appendToBody"true"设置弹窗等级最高 主要是 :appendToBody“true”&#xff0c;其他参数可根据自己需求配置 <el-dialog :title"title" :visible.sync"isShow" top"5vh" :appendToBody"true"><el-image…

Web前端开发——Ajax,Axios概述及在Vue框架中的使用

前言&#xff1a; 整理下学习笔记&#xff0c;打好基础&#xff0c;daydayup!!! Ajax Ajax是什么&#xff1f; Ajax全称Asynchromous JavaScript And Xml&#xff0c;是异步的JavaScript和Xml。 Ajax的作用&#xff1f; 1&#xff0c;数据交换&#xff1a;通过Ajax可以给服务器…

uniapp 当前系统没有安装苹果根证书,是否打开证书目录(打开后依次安装证书

当你遇到这类问题时&#xff0c;说明你也极其的困惑&#xff01;这就是为啥大抵国内这些货色搞的东西总是不尽人意&#xff01;连开发者生态都搞不好&#xff0c;就急着吹嘘。 这是官方给的技术说明方案&#xff1a; 恭喜你&#xff0c;当你按照这个搞之后&#xff0c;你的问题…

Map与Set的模拟实现封装

目录 一. 底层原理 二. 红黑树节点的定义 三. 仿函数封装 四. 基本函数的封装 五. 迭代器的封装 5.1 迭代器的基本定义 5.2 *与->操作 5.3 迭代器的操作 5.3.1 右子树不为空 5.3.2 右子树为空 5.4 迭代器的--操作 5.4.1 当前节点的父节点…

CSS基础:最详细 padding的 4 种用法解析

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃&#xff0c;大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 We…

Adobe Premiere Pro将加入AI生成式功能,以提高视频编辑的效率;OpenAI宣布在东京设立亚洲首个办事处

&#x1f989; AI新闻 &#x1f680; Adobe Premiere Pro将加入AI生成式功能&#xff0c;以提高视频编辑的效率 摘要&#xff1a;Adobe宣布&#xff0c;将为Premiere Pro引入由生成式AI驱动的新功能&#xff0c;以提高视频编辑的效率。这些功能包括“生成扩展”&#xff0c;能…

免费开源多客圈子婚恋社交校园跑腿线上线下陪玩 源码交付 可打包小程序 支持二开!

聊天软件作为一种现代化的通讯工具&#xff0c;其好处可以总结如下&#xff1a; 1.方便快捷&#xff1a;聊天软件只要有网络连接&#xff0c;就可以随时随地与他人进行交流&#xff0c;不受时间和地点的限制&#xff0c;可以随时随地进行沟通&#xff0c;大大方便了人们的日常…

【结构型模式】装饰器模式

​一、装饰器模式概述 装饰器模式&#xff08;装饰者模式&#xff09;定义&#xff1a;装饰器模式动态地将责任附加到对象上。若要拓展功能&#xff0c;装饰者提供了比继承更有弹性地替代方案。&#xff08;对象结构型模型&#xff09;通俗点来说&#xff1a;动态的给一个对象增…

适用于 Windows 的 10 个顶级 PDF 编辑器 [免费和付费]

曾经打开PDF文件&#xff0c;感觉自己被困在数字迷宫中吗&#xff1f;无法编辑的文本、无法调整大小的图像以及签署感觉像是一件苦差事的文档&#xff1f;好吧&#xff0c;不用再担心了&#xff01;本指南解开了在 Windows 上掌握 PDF 的秘密&#xff0c;其中包含 10 款适用于 …

vscode vue template模板中 tab键无法快速补全

之前记得一直可以的突然不知道咋的就不行了… 解决办法: 菜单栏 - 文件 - 首选项 - 设置- emmet:tab ✔就好了

Flink CDC 的 debezium-json 格式和 debezium 原生格式是一回事吗?

博主历时三年精心创作的《大数据平台架构与原型实现&#xff1a;数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行&#xff0c;点击《重磅推荐&#xff1a;建大数据平台太难了&#xff01;给我发个工程原型吧&#xff01;》了解图书详情&#xff0c;…

【介绍下负载均衡原理及算法】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

IP协议如何进行地址管理?

如今&#xff0c;IP协议有两个版本&#xff0c;分别是IPv4和IPv6&#xff0c;IPv4是目前主要应用的版本。IPv4的IP地址是以4个字节的数字来表示的&#xff0c;比如 127.0.0.1。因此&#xff0c;IPv4所能表示IP地址的个数是2^32次方&#xff0c;也就是42亿多个&#xff0c;看起来…

48.HarmonyOS鸿蒙系统 App(ArkUI)常用组件的使用

48.HarmonyOS鸿蒙系统 App(ArkUI)常用组件的使用 按钮触发事件 toast信息提示 单选按钮 复选框 切换按钮&#xff0c;开关按钮 进度条 textbox,textinput,TextArea文本输入框 气泡提示 import prompt from ohos.prompt; import promptAction from ohos.promptAction; …

Qt对象池,单例模式,对象池可以存储其他类的对象指针

代码描述&#xff1a; 写了一个类&#xff0c;命名为对象池&#xff08;ObjectPool &#xff09;&#xff0c;里面放个map容器。 3个功能&#xff1a;添加对象&#xff0c;删除对象&#xff0c;查找对象 该类只构建一次&#xff0c;故采用单例模式功能描述&#xff1a;对象池可…