大模型平民化技术之LORA

1. 引言

在这篇博文中, 我将向大家介绍LoRA技术背后的核心原理以及相应的代码实现。

LoRA Low-Rank AdaptationLow-Rank Adaptors 的首字母缩写词,它提供了一种高效且轻量级的方法,用于微调预先训练好的的大语言模型。这包括 BERT RoBERTa 等掩码语言模型,以及 GPTLlamaMistral 等因果推断模型。

闲话少说,我们直接开始吧!

2. 优势分析

LoRA的主要优点之一是他们的效率。通过使用更少的参数,LoRA显著降低了模型训练过程中计算复杂性和显存使用量。这可以让我们在消费级的GPU上来训练大模型,并且可以便利地将我们训练好的LoRA权重(以兆为单位)分发给其他人。

此外,LoRA可以提升模型的泛化性。通过限制模型的复杂度,可以有助于防止在训练数据有限场景下的过拟合现象;由于LoRA至少保留了初始模型的能力,在处理一些新的,未见过的数据时更具有弹性。

最后,LoRA可以无缝地集成到现有的神经网络架构中。这种集成允许以最小的额外训练成本对预训练模型进行微调和调整,使其非常适合迁移学习应用。

3. 工作原理

LoRA 的基本思想是保持预训练矩阵(即原始模型的参数)冻结(即处于固定状态),并且只在原始矩阵中添加一个小的增量,其参数量比原始矩阵少很多。

例如,考虑矩阵 W,它可以是全连接层的参数,也可以是来Transformer中计算自注意力机制的矩阵之一:
在这里插入图片描述

显然,如果 Worig 的维数为 n×m,而假如我们只是初始化一个具有相同维数的新的增量矩阵进行微调,虽然我们也实现类似的功能,但是我们的参数量将会加倍。 LoRA使用的Trick就是通过训练低维矩阵 B A ,通过矩阵乘法来构造 ΔW ,来使 ΔW 的参数量低于原始矩阵。
在这里插入图片描述
这里我们不妨定义秩 r,它明显小于基本矩阵维度 r≪nr≪m。则矩阵 B n×r,矩阵 A r×m。将它们相乘会得到一个维度为 nxmW 矩阵,但构建的参数量减小了很多。

此外,我们希望我们的增量ΔW在训练开始时为零,这样微调就会从原始模型一样开始。因此,B 通常初始化为全零,而 A 初始化为随机值(通常呈正态分布)。

4. 举个栗子

我们不妨来看个直观的栗子,如下图所示:
在这里插入图片描述
想象一下,我们的基本维数是 1024,我们选择了 LoRA 的秩r为 4,则对于上述过程:
● 权重W的参数量为1024X1024≈1M
AB的参数量一致,均为rX1024≈4K,这样二者之和为8K
● 这样使用LoRA技术,在上述例子中我们仅仅需要训练0.8%的参数就可以更新我们的参数矩阵

5. LoRA指令速查

主要可查阅微软的官方文档Github: 戳我 , 由于封装的很好,目前该库页整合至 HuggingFace Parameter-Efficient Fine-Tuning (PEFT) 。

  • 如果模型要将特定层替换成 LoRA,需要调整模型的结构,但调用很简单:
# ===== Before =====
# layer = nn.Linear(in_features, out_features)
# ===== After ======
import loralib as lora
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)
  • 在训练之前要把原本的LLM模型 Freeze 住,并且设定只有 LoRA 的参数是可训练的
import loralib as lora
model = BigModel()
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:
  • 保存模型时也可以只储存LoRA所训练的权重,这特性将方便大家分享自己的权重
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch.save(lora.lora_state_dict(model), checkpoint_path)
  • 推理时读取 LoRA 或是原本 LLM 的权重时,要将 strict 设定为 False
# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

6. SD-LoRA应用

近年来生成式 AI DALLE 再到 Stable-diffusion,都显示了现在的 AI 可以生成高质量以及高分辨率的图片,但是让人诟病的还是需要大量的运算资源才能够训练得了这种高分辨率的模型,因为要训练一个高分辨率的扩散模型是需要相当多内存的,即便 Stable-diffusion 将原本的 Pixel-level Diffusion 变成 Latent Diffusion Model 已经大幅降低训练的内存,但仍无法在单一张 11 GB 的 GPU 上训练,但现在不一样了,有人将 LoRA 技术整合到 Stable-diffusion,推出了 Stable Diffusion LoRA
在这里插入图片描述
整合 LoRAStable-diffusion 直接带来了以下的好处:

  • 训练快很多
  • 可在 11GB 显卡上直接进行训练
  • LoRA 权重的保存只有 3MB~200MB,易于分享

7. SD-LoRA更多资源

LoRA这项技术上的突破也使得Stable Diffusion 的社区多了许多生成模型,甚至可将模型上传至网站 CivitAI,可以看到上面有许多模型是使用 LoRA 进行训练的:

在这里插入图片描述
当然,网络上也有许多资源是使用 Colab 或是在个人 PC 上面生成/训练模型,最近 Stable diffusion 的社群已经开源相当多项目,并提供 GUI 界面,甚至不需要懂程序代码就可以训练好生成式 AI。

stable-diffusion-webui-colab

Kohya’s GUI, Support Windows

如果只是想来看看 Stable-diffusion 的人建议使用 WebUI,不仅能使用官方释出的模型,也可以直接登陆到 CivitAI,直接下载别人的训练好的生成模型:

在这里插入图片描述
在这里插入图片描述

8. 总结

共享大型的LLM模型是未来的趋势,如果要适应到某个具体任务上,只要训练LoRA模组即可,而这项技术也带来方便的替换性,未来大家只要分享LoRA的模型权重,就可以快速切换至不同的任务。

此外,LoRA通过大量降低训练参数,来大幅降低了硬体的训练门槛,并且与完全 Fine-tuning 的模型相比,推论速度的增加是相当少的。

9. 附录

本文重点参考链接如下:

[1] LoRA论文:戳我
[2] LoRA tutorial:戳我
[3] PEFT tutorial: 戳我
[4] Stable Diffusion Webui: 戳我
[5] ai-drawing: 戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/405890.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据之巅:揭秘企业数据分析师如何成为企业的决策智囊

引言 在数字化浪潮中,企业数据分析师已成为企业决策的重要支撑。他们如同探险家,在数据的丛林中寻找着能够指引企业前行的宝贵信息。本文将深入剖析企业数据分析师的角色、挑战与成就,带你领略这个充满智慧与激情的职业风采。 一、从数字到智…

计算机毕业设计 基于SpringBoot的宠物商城网站系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

代理模式笔记

代理模式 代理模式代理模式的应用场景先理解什么是代理,再理解动静态举例举例所用代码 动静态的区别静态代理动态代理 动态代理的优点代理模式与装饰者模式的区别 代理模式 代理模式在设计模式中是7种结构型模式中的一种,而代理模式有分动态代理&#x…

WordPres Bricks Builder 前台RCE漏洞

免责声明:文章来源互联网收集整理,请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该…

<网络安全>《48 网络攻防专业课<第十四课 - 华为防火墙的使用(1)>

1 DHCP Snooping 概述 DHCP Snooping功能用于防止 1、DHCP Server仿冒者攻击; 2、中间人攻击与IP/MAC Spoofing攻击; 3、改变CHADDR值的DoS攻击。 1.2 DHCP Server 仿冒者攻击 1.3 中间人与IP/MAC Spoofing 攻击 1.4 改变CHADDR 值的DoS 攻击 CHADDR…

2024年数学建模美赛详细总结以及经验分享

前言: 本文记录与二零二四年二月六日,正好今天是数学建模结束,打算写篇文章记录一下整个过程,以及一些感受、还有经验分享。记录这个过程的原因就是我在赛前,在博客上找了很久,也没有像我这么类似记…

Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码) 1. 背景介绍 Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不…

JAVA--File类与IO流

目录 1. java.io.File类的使用 1.1 概述 1.2 构造器 1.3 常用方法 1、获取文件和目录基本信息 2、列出目录的下一级 3、File类的重命名功能 4、判断功能的方法 5、创建、删除功能 2. IO流原理及流的分类 2.1 Java IO原理 2.2 流的分类 2.3 流的API 3. 节点流之一…

微服务学习

一、服务注册发现 服务注册就是维护一个登记簿,它管理系统内所有的服务地址。当新的服务启动后,它会向登记簿交待自己的地址信息。服务的依赖方直接向登记簿要Service Provider地址就行了。当下用于服务注册的工具非常多ZooKeeper,Consul&am…

Jetson Xavier NX 与笔记本网线连接 ,网络共享,ssh连接到vscode

Jetson Xavier NX 与笔记本网线连接 ,网络共享,ssh连接到vscode Jetson Xavier NX桌面版需要连接显示屏、鼠标和键盘,操作起来并不方便,因此常常需要ssh远程连接到本地笔记本电脑,这里介绍一种连接方式,通过…

Linux实验记录:使用PXE+Kickstart无人值守安装服务

前言: 本文是一篇关于Linux系统初学者的实验记录。 参考书籍:《Linux就该这么学》 实验环境: VmwareWorkStation 17——虚拟机软件 RedHatEnterpriseLinux[RHEL]8——红帽操作系统 备注: 实际生产中安装操作系统的工作&…

论文笔记:利用词对比注意增强预训练汉字表征

整理了 ACL2020短文 Enhancing Pre-trained Chinese Character Representation with Word-aligned Att)论文的阅读笔记 背景模型实验 论文地址:论文 背景 近年来,以 BERT 为代表的预训练模型在 NLP 领域取得取得了非常显著的效果。但是&…

谈谈对BFC的理解

文章目录 一、是什么二、触发条件三、应用场景防止margin重叠(塌陷)清除内部浮动自适应多栏布局小结 参考文献 一、是什么 我们在页面布局的时候,经常出现以下情况: 这个元素高度怎么没了?这两栏布局怎么没法自适应&…

28-k8s集群中-StatefulSets控制器(进阶知识)

一、statefullsets控制器概述 1,举例 假如,我们有一个deployment资源,创建了3个nginx的副本,对于nginx来讲,它是不区分启动或者关闭的先后顺序的,也就是“没有特殊状态”的一个服务,也成“无状…

一次有趣的nginx Tcp4层代理转发的试验

nginx主配置文件添加配置: stream {log_format proxy $remote_addr [$time_local] $protocol status:$status bytes_sent:$bytes_sent bytes_received:$bytes_received $session_time upstream_addr:"$upstream_addr" "$upstream_bytes_sent" …

React18源码: React调度中的3种优先级类型和Lane的位运算

优先级类型 React内部对于优先级的管理,贯穿运作流程的4个阶段(从输入到输出),根据其功能的不同,可以分为3种类型: 1 )fiber优先级(LanePriority) 位于 react-reconciler包,也就是L…

基于Java SSM框架实现网络作业提交与批改系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架实现网络作业提交与批改系统演示 摘要 随着互联网时代的到来,同时计算机网络技术高速发展,网络管理运用也变得越来越广泛。因此,建立一个B/S结构的网络作业提交与批改系统,会使网络作业提交与批改系统工作系统化…

JavaScript字符串的常用方法(非常详细)

文章目录 一、操作方法增concat 删改trim()、trimLeft()、trimRight()repeat()padEnd() toLowerCase()、 toUpperCase()查charAt()indexOf()startWith()、includes() 二、转换方法split 三、模板匹配方法match()search()replace() 一、操作方法 我们也可将字符串常用的操作方法…

c编译器学习07:minilisp编译器改造(debug模式支持调试)

问题 原版的minilisp编译器不支持argv输入测试,不方便单步调试。 代码改造目标是既不改变原有程序的各种功能, 又能支持个人习惯的vs单步debug模式。 CMakeLists.txt变更 定义DEBUG宏 解决单步调试源码定位偏差问题 cmake_minimum_required(VERSION …

Puppeteer 使用实战:如何将自己的 CSDN 专栏文章导出并用于 Hexo 博客(二)

文章目录 上一篇效果演示Puppeteer 修改浏览器的默认下载位置控制并发数错误重试并发控制 错误重试源码 上一篇 Puppeteer 使用实战:如何将自己的 CSDN 专栏文章导出并用于 Hexo 博客(一) 效果演示 上一篇实现了一些基本功能,…