低成本微调LLM

低成本微调LLM

最近在微调不同量级上的大模型,包括Llama-2-7b,Llama-2-13b,Llama-2-70b,Yi-34b,Qwen-14b,Qwen-72b等大模型。在有限的资源上微调大模型,节约显存,可以考虑使用
LoRA这个算法,来自论文《LoRA: Low-Rank Adaptation of Large Language Models》,目前可以用的包有两个,分别是loralib 和peft这两个包,其中peft 和huggingface 中的transformers结合一起使用非常方便大家的使用。

LoRA 原理

LLM模型中有一部分层是线性层,LoRA是对线性层的参数增加两个低维度的矩阵进行线性运算,从而降低了模型微调的参数量和显存消耗。通常低维度的参数rank 可以自行设置,相比于LLM的hidden_size的值要小很多。下面介绍LoRA的数学表达式:
y = ( w T + α l o r a _ A T l o r a _ B T ) x = w T x + α l o r a _ A T l o r a _ B T x y = (w^{T} + \alpha lora\_{A}^{T}lora\_{B}^{T}) x\\ =w^{T} x+ \alpha lora\_{A}^{T}lora\_{B}^{T} x y=(wT+αlora_ATlora_BT)x=wTx+αlora_ATlora_BTx
其中 w ∈ R m × h w \in R^{m\times h} wRm×h l o r a _ A ∈ R r × h lora\_{A} \in R^{r\times h} lora_ARr×h l o r a _ B ∈ R m × r lora\_{B} \in R^{m\times r} lora_BRm×r r r r通常可以自行设置的值比较小,而且远远小于m, l o r a _ A ∈ R r × h , l o r a _ B ∈ R m × r lora\_{A} \in R^{r\times h},lora\_{B} \in R^{m\times r} lora_ARr×h,lora_BRm×r的参数量远小于 w ∈ R m × h w \in R^{m\times h} wRm×h。在采用LoRA的方式微调LLM,梯度更新的参数为每一层的线性层相应的lora参数,例如表达式中的 l o r a _ A lora\_{A} lora_A l o r a _ B lora\_{B} lora_B,原模型的参数不进行梯度更新,这样做的目的是训练参数量减少,节约显存,加快训练速度。在显卡资源不足的情况下,可以选择LoRA的方式进行微调LLM。而且 l o r a _ A T l o r a _ B T ∈ R h × m lora\_{A}^{T}lora\_{B}^{T}\in R^{h\times m} lora_ATlora_BTRh×m w T w^{T} wT的大小是一致的,方便将 l o r a _ A T l o r a _ B T lora\_{A}^{T}lora\_{B}^{T} lora_ATlora_BT参数合并到原始模型参数 w T w^{T} wT上,并未对原始模型增加新的参数,从而采用原始模型的推理方式进行推理。有时候微调的时候需要训练embedding层和lm head 这一层,在采用LoRA训练的时候也可以训练这两层,在peft中可以实现这两层的训练。

LoRA微调LLM

下面将给出Qwen采用LoRA的方式进行微调,是基于peft和transformers实现的。下面将给出示例代码。注意这里使用的是Qwen1不是Qwen2。

from peft import PeftModel
from peft import LoraConfig
from peft import get_peft_model

from modeling_qwen import QWenLMHeadModel

def load_qwen_model_lora(pretrain_model_path,  use_gradient_checkpointing, lora_r, lora_alpha,
                          lora_dropout, bf16=False, fp16=False, checkpoint_dir=None
                          ):
    model = QWenLMHeadModel.from_pretrained(pretrain_model_path, 
                                             torch_dtype=torch.bfloat16 if bf16 else torch.float16 if fp16 else torch.float32)
    print("loadding model")
    model.config.torch_dtype = (torch.float16 if fp16 else (torch.bfloat16 if bf16 else torch.float32))
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        print("using gradient_checkpointing_enable")

    model.config.use_cache = False

    target_modules = find_all_linear_names(model)
    print("模型中linear层的名称集合:", target_modules)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM", 
    )

    if checkpoint_dir is not None:
        print("加载模型继续训练.")
        model = PeftModel.from_pretrained(model, checkpoint_dir)
         for name, param in model.named_parameters():
             if 'lora_' in name:
                 param.requires_grad = True
    else:
        print('添加 LoRA 网络...')
        model = get_peft_model(model, config)

    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if bf16:
                module = module.to(torch.bfloat16)
        if 'ln' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'wte' in name:
            if hasattr(module, 'weight'):
                if bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    return model


def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    print("lora name: " , lora_module_names)
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

采用transformers 的Trainer 进行模型训练,这里也可以采用deepspeed ddp 进行模型训练。

下面将介绍在采用LoRA进行微调的时候训练embedding 和lm header 这两层。如果模型字典进行延扩,embedding 和lm header 这两个层需要训练。只需要在LoraConfig指定就可以。

def load_qwen_model_lora(pretrain_model_path, use_gradient_checkpointing, lora_r, lora_alpha,
                          lora_dropout, bf16=False, fp16=False, checkpoint_dir=None,
                          finetuning_embedding_and_lm_head=True
                          ):
    model = QWenLMHeadModel.from_pretrained(pretrain_model_path, 
                                             torch_dtype=torch.bfloat16 if bf16 else torch.float16 if fp16 else torch.float32)
    print("loadding model")
    model.config.torch_dtype = (torch.float16 if fp16 else (torch.bfloat16 if bf16 else torch.float32))

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        print("using gradient_checkpointing_enable")

    model.config.use_cache = False

    target_modules = find_all_linear_names(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM", 
        modules_to_save=['base_model.model.lm_head', 'base_model.model.transformer.wte']
    )

    if checkpoint_dir is not None:
        print("从 checkpoint 加载 adapters.")
        model = PeftModel.from_pretrained(model, checkpoint_dir)
        
        for name, param in model.named_parameters():
            if 'lora_' in name:
                param.requires_grad = True
    else:
        print('添加 LoRA 网络...')
        model = get_peft_model(model, config)

    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if bf16:
                module = module.to(torch.bfloat16)
        if 'ln' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'wte' in name:
            if hasattr(module, 'weight'):
                if bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    # 设置lm head 和 embedding是否训练
    if finetuning_embedding_and_lm_head:
        for name, param in model.named_parameters():
            if 'lm_head' in name or 'wte' in name:
                param.requires_grad = True

    return model

QLoRA 微调LLM

QLoRA是模型量化(Quantilization) 和LoRA结合起来使得降低显存并加快训练效率,可以看作是对LoRA的优化。下面将给出千问的示例。在实际中使用的是4bit进行微调的。

import os

import torch
from peft import PeftModel
from peft import LoraConfig
from peft import get_peft_model
from peft import prepare_model_for_kbit_training

from modeling_qwen import QWenLMHeadModel

def load_qwen_model_qlora(pretrain_model_path, use_gradient_checkpointing, 
                          bits,
                          double_quant,
                          quant_type, 
                          lora_r, lora_alpha, lora_dropout, 
                          checkpoint_dir=None,
                          bf16=False, fp16=False,
                          finetuning_embedding_and_lm_head=False):
    
    ## lora 微调embedding和lm head 的话,有两种方式,一种是下面两行代码取消注释,一种是LoraConfig中modules_to_save添加要训练的层名
    if finetuning_embedding_and_lm_head:
        replace_peft_save_pretrined()

    model_dict = load_model_state_dict(pretrain_model_path)

    if checkpoint_dir is not None:
        fintuning_checkpoint_file = os.path.join(checkpoint_dir, 'embedding_and_lm_head.pt')
        if os.path.exists(fintuning_checkpoint_file):
            # 加载保存点的变化的权重
            finetuning_state = torch.load(fintuning_checkpoint_file, map_location=lambda storage, loc: storage)
            # 
            for n, p in finetuning_state.items():
                n = n.replace('base_model.model.', '')
                if n in model_dict:
                    model_dict[n] = p
                else:
                    raise NameError
    
    weight_dtype = torch.bfloat16 if bf16 else torch.float16 if fp16 else torch.float32
    model = QWenLMHeadModel.from_pretrained(pretrain_model_path, 
                                             torch_dtype=weight_dtype,
                                             state_dict=model_dict,
                                             load_in_4bit=bits == 4,
        load_in_8bit=bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=bits == 4,
            load_in_8bit=bits == 8,
            bnb_4bit_compute_dtype=weight_dtype,
            bnb_4bit_use_double_quant=double_quant,
            bnb_4bit_quant_type=quant_type
        ),)
    print("loadding model")

    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
    target_modules = find_all_linear_names(bits, model)
    print("模型中linear层的名称集合:", target_modules)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        # modules_to_save=['base_model.model.lm_head', 'base_model.model.transformer.wte'] # , 
    )

    if checkpoint_dir is not None:
        print("从 checkpoint 加载 adapters.")
        model = PeftModel.from_pretrained(model, checkpoint_dir)

    else:
        print('添加 LoRA 网络...')
        model = get_peft_model(model, config)

    for name, param in model.named_parameters():
        if finetuning_embedding_and_lm_head:
            if 'lm_head' in name or 'wte' in name:
                param.requires_grad = True
                continue

    return model

以上是对LoRA和QLoRA的使用介绍,如有理解错误,欢迎指证。

欢迎使用Markdown编辑器

你好! 这是你第一次使用 Markdown编辑器 所展示的欢迎页。如果你想学习如何使用Markdown编辑器, 可以仔细阅读这篇文章,了解一下Markdown的基本语法知识。

新的改变

我们对Markdown编辑器进行了一些功能拓展与语法支持,除了标准的Markdown编辑器功能,我们增加了如下几点新功能,帮助你用它写博客:

  1. 全新的界面设计 ,将会带来全新的写作体验;
  2. 在创作中心设置你喜爱的代码高亮样式,Markdown 将代码片显示选择的高亮样式 进行展示;
  3. 增加了 图片拖拽 功能,你可以将本地的图片直接拖拽到编辑区域直接展示;
  4. 全新的 KaTeX数学公式 语法;
  5. 增加了支持甘特图的mermaid语法1 功能;
  6. 增加了 多屏幕编辑 Markdown文章功能;
  7. 增加了 焦点写作模式、预览模式、简洁写作模式、左右区域同步滚轮设置 等功能,功能按钮位于编辑区域与预览区域中间;
  8. 增加了 检查列表 功能。

功能快捷键

撤销:Ctrl/Command + Z
重做:Ctrl/Command + Y
加粗:Ctrl/Command + B
斜体:Ctrl/Command + I
标题:Ctrl/Command + Shift + H
无序列表:Ctrl/Command + Shift + U
有序列表:Ctrl/Command + Shift + O
检查列表:Ctrl/Command + Shift + C
插入代码:Ctrl/Command + Shift + K
插入链接:Ctrl/Command + Shift + L
插入图片:Ctrl/Command + Shift + G
查找:Ctrl/Command + F
替换:Ctrl/Command + G

合理的创建标题,有助于目录的生成

直接输入1次#,并按下space后,将生成1级标题。
输入2次#,并按下space后,将生成2级标题。
以此类推,我们支持6级标题。有助于使用TOC语法后生成一个完美的目录。

如何改变文本的样式

强调文本 强调文本

加粗文本 加粗文本

标记文本

删除文本

引用文本

H2O is是液体。

210 运算结果是 1024.

插入链接与图片

链接: link.

图片: Alt

带尺寸的图片: Alt

居中的图片: Alt

居中并且带尺寸的图片: Alt

当然,我们为了让用户更加便捷,我们增加了图片拖拽功能。

如何插入一段漂亮的代码片

去博客设置页面,选择一款你喜欢的代码片高亮样式,下面展示同样高亮的 代码片.

// An highlighted block
var foo = 'bar';

生成一个适合你的列表

  • 项目
    • 项目
      • 项目
  1. 项目1
  2. 项目2
  3. 项目3
  • 计划任务
  • 完成任务

创建一个表格

一个简单的表格是这么创建的:

项目Value
电脑$1600
手机$12
导管$1

设定内容居中、居左、居右

使用:---------:居中
使用:----------居左
使用----------:居右

第一列第二列第三列
第一列文本居中第二列文本居右第三列文本居左

SmartyPants

SmartyPants将ASCII标点字符转换为“智能”印刷标点HTML实体。例如:

TYPEASCIIHTML
Single backticks'Isn't this fun?'‘Isn’t this fun?’
Quotes"Isn't this fun?"“Isn’t this fun?”
Dashes-- is en-dash, --- is em-dash– is en-dash, — is em-dash

创建一个自定义列表

Markdown
Text-to- HTML conversion tool
Authors
John
Luke

如何创建一个注脚

一个具有注脚的文本。2

注释也是必不可少的

Markdown将文本转换为 HTML

KaTeX数学公式

您可以使用渲染LaTeX数学表达式 KaTeX:

Gamma公式展示 Γ ( n ) = ( n − 1 ) ! ∀ n ∈ N \Gamma(n) = (n-1)!\quad\forall n\in\mathbb N Γ(n)=(n1)!nN 是通过欧拉积分

Γ ( z ) = ∫ 0 ∞ t z − 1 e − t d t   . \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}dt\,. Γ(z)=0tz1etdt.

你可以找到更多关于的信息 LaTeX 数学表达式here.

新的甘特图功能,丰富你的文章

2014-01-07 2014-01-09 2014-01-11 2014-01-13 2014-01-15 2014-01-17 2014-01-19 2014-01-21 已完成 进行中 计划一 计划二 现有任务 Adding GANTT diagram functionality to mermaid
  • 关于 甘特图 语法,参考 这儿,

UML 图表

可以使用UML图表进行渲染。 Mermaid. 例如下面产生的一个序列图:

张三 李四 王五 你好!李四, 最近怎么样? 你最近怎么样,王五? 我很好,谢谢! 我很好,谢谢! 李四想了很长时间, 文字太长了 不适合放在一行. 打量着王五... 很好... 王五, 你怎么样? 张三 李四 王五

这将产生一个流程图。:

链接
长方形
圆角长方形
菱形
  • 关于 Mermaid 语法,参考 这儿,

FLowchart流程图

我们依旧会支持flowchart的流程图:

Created with Raphaël 2.3.0 开始 我的操作 确认? 结束 yes no
  • 关于 Flowchart流程图 语法,参考 这儿.

导出与导入

导出

如果你想尝试使用此编辑器, 你可以在此篇文章任意编辑。当你完成了一篇文章的写作, 在上方工具栏找到 文章导出 ,生成一个.md文件或者.html文件进行本地保存。

导入

如果你想加载一篇你写过的.md文件,在上方工具栏可以选择导入功能进行对应扩展名的文件导入,
继续你的创作。


  1. mermaid语法说明 ↩︎

  2. 注脚的解释 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/540668.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

EPSON L4160 Series打印机驱动安装

EPSON L4160 Series 官方网站下载 win64驱动 accept后自动下载。 安装 添加 网络打印机可以自动搜索并识别。 win11 设置里 -这里改名字 -比如我是192.168.50.115

基于springboot+vue的汽车租赁管理系统

背景介绍: 网络发展的越来越迅速,它深刻的影响着每一个人生活的各个方面。每一种新型事务的兴起都是为了使人们的生活更加方便。汽车租赁管理系统是一种低成本、更加高效的电子商务方式,它已慢慢的成为一种全新的管理模式。人们不再满足于在互联网上浏览…

AutoCAD 2024 安装注册教程

前言 大家好,我是梁国庆。 本篇分享的安装包是 AutoCAD 的全新版本——AutoCAD 2024,下文安装教程中简称 AutoCAD。 时间:2024年4月8日。 获取 AutoCAD 安装包 我已将本篇所使用的安装包打包上传至百度云,扫描下方二维码关注…

Docker+Uwsgi部署Django项目

在之前的文章中,已经给大家分享了在docker中使用django自带的命令部署项目,这篇文章主要讲解如何使用uwsgi部署。 1. 在Django项目的根目录下新建Dockerfile文件 #Dockerfile文件 # 使用 Python 3.9 作为基础镜像 FROM python:3.9# 设置工作目录 WORKDI…

计算机视觉——图像特征提取D2D先描述后检测特征提取算法原理

概述 局部特征提取是计算机视觉中的一个重要任务,它旨在从图像中提取出能够代表图像局部结构和外观信息的特征。这些特征通常用于图像匹配、物体识别、三维重建、跟踪和许多其他应用。传统方法,如尺度不变特征变换(SIFT)&#xf…

Java——面向对象的初步认识

目录 一.什么是面向对象 二.面向对象与面向过程 1. 传统洗衣服过程(面向过程) 2. 现代洗衣服过程(面向对象) 一.什么是面向对象 Java是一门纯面向对象的语言(Object Oriented Program,简称OOP),在面向…

Navicat连接SQL server出现:[IM002] [Microsoft][ODBC 驱动程序管理器] 未发现数据源名称并且未指定默认驱动程序(0)

问题 解决方法 一 找到Navicat的安装路径,然后找到sqlncli_x64.msi文件并安装,安装成功后重启Navicat重新进行连接,看是否成功。 解决方法 二 如果方法一没有找到找到sqlncli_x64.msi 还是Navicat的安装路径,然后找到msodbcsql_64…

【Proteus仿真】按键控制LED流水灯定时器时钟

0~65535 每隔1us计数加1 总共定时时间65535us 64535离计数器溢出差值1000&#xff0c;所以计时时间为1ms #include <REGX51.H> void inittimer0() {TMOD0x01;//0000 0001TF00;//SCON可位寻址&#xff0c;TF1产生中断TR01;//定时器启动TL064535%256;//定时1msTH064536/256…

功能测试_验证某城市电话号码的正确性

案例&#xff1a;验证某城市电话号码的正确性 功能测试_等价类设计用例&#xff1a; 步骤&#xff1a; 1:明确需求&#xff1a;电话号码是否正确 2:划分等价类&#xff1a;有效等价类、有效取值、无效等价类、无效取值 3&#xff1a;提取数据编写用例&#xff1a;用例编号…

基于matlab动态化绘制一个彩色边框的爱心

一、版本1 % 定义爱心曲线的参数方程 t linspace(0, 2*pi, 100); x 16*sin(t).^3; y 13*cos(t) - 5*cos(2*t) - 2*cos(3*t) - cos(4*t);% 创建图形 figure; axis equal; axis off; title(爱心);% 循环遍历每个点&#xff0c;绘制不同颜色的线段 for i 1:length(t)-1% 清除…

【深度学习】图像风格混合——StyleGAN2原理解析

1、前言 上一篇文章&#xff0c;我们详细讲解了StyleGAN的原理。这篇文章&#xff0c;我们就来讲解一下StyleGAN2&#xff0c;也就是StyleGAN的改进版。 原论文&#xff1a;Analyzing and Improving the Image Quality of StyleGAN 参考代码&#xff1a;①Pytorch版本&#…

openjudge_2.5基本算法之搜索_1700:八皇后问题

题目 1700:八皇后问题 总时间限制: 10000ms 内存限制: 65536kB 描述 在国际象棋棋盘上放置八个皇后&#xff0c;要求每两个皇后之间不能直接吃掉对方。 输入 无输入。 输出 按给定顺序和格式输出所有八皇后问题的解&#xff08;见Sample Output&#xff09;。 样例输入 样例输…

汇舟问卷:海外问卷调查怎么样?

越来越多的企业决定采用线上调查的方式来了解消费者的意愿。这种转变不仅反映了科技发展的必然趋势&#xff0c;也凸显了企业对市场动态和消费者需求的高度重视。 线上调查能够覆盖更广泛的受众群体&#xff0c;通过互联网的普及&#xff0c;企业可以轻松地触及全国各地的消费…

HackTheBox-Machines--MonitorsTwo

文章目录 0x01 信息收集0x02 CVE-2022-46169 漏洞利用0x03 权限提升0x04 提升到root权限 MonitorsTwo 测试过程 0x01 信息收集 a.端口扫描: 发现22、80端口    b.信息收集: 1.2.22 Cacti信息收集 nmap -sC -sV 10.129.186.1321.访问 10.129.186.132&#xff0c;为 1.2.22 Ca…

Vue pdfjs

最终效果图 官网 https://mozilla.github.io/pdf.js 下载 放入项目 vue页面嵌入本地下载好的html sessionStorage.setItem(sdfDldj8KJ45SDF, encodeURIComponent(file_url)) <template><div style"height:100%"><iframe:id"1":key"…

vue快速入门(二十一)计算属性

注释很详细&#xff0c;直接上代码 上一篇 新增内容 计算属性的基本应用 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

OceanBase V4.2 MySQL模式下,如何通过DBLINK实现跨数据源访问

概述 跨数据源访问可通过 DBLINK&#xff08;以下简称DBLINK&#xff09;实现&#xff0c;从而使得业务代码能够像访问本地数据库一样轻松访问远端数据库。原先&#xff0c;DBLINK主要服务于Oracle模式&#xff0c;但由于OceanBase 的MySQL模式租户同样存在访问远端数据库的需…

C语言 | Leetcode C语言题解之第26题删除有序数组中的重复项

题目&#xff1a; 题解&#xff1a; int removeDuplicates(int* nums, int numsSize) {if (numsSize 0) {return 0;}int fast 1, slow 1;while (fast < numsSize) {if (nums[fast] ! nums[fast - 1]) {nums[slow] nums[fast];slow;}fast;}return slow; }

【机器学习】深入剖析贝叶斯算法原理及其广泛应用

一、引言 在机器学习的广阔领域中&#xff0c;贝叶斯算法以其独特的概率推理方式占据了重要的地位。它不仅为分类问题提供了有效的解决方案&#xff0c;还在自然语言处理、信息检索、垃圾邮件过滤等诸多领域发挥着不可替代的作用。 贝叶斯算法的基本思想源于贝叶斯定理&#xf…

leetcode热题100.爬楼梯(从二进制到快速幂)

Problem: 70. 爬楼梯 文章目录 题目思路Code复杂度 题目 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方…