AIGC笔记--基于PEFT库使用LoRA

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict

# 创建模型
class Simple_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 128)
        self.linear2 = nn.Linear(128, 256)
    def forward(self, x: torch.Tensor):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

if __name__ == "__main__":
    # 初始化原始模型
    origin_model = Simple_Model()

    # 配置lora config
    model_lora_config = LoraConfig(
        r = 32, 
        lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1
        init_lora_weights = "gaussian", # 参数初始化方式
        target_modules = ["linear1", "linear2"], # 对应层添加lora层
        lora_dropout = 0.1
    )

    # Test data
    input_data = torch.rand(2, 64)
    origin_output = origin_model(input_data)

    # 原始模型的权重参数
    origin_state_dict = origin_model.state_dict() 

    # 两种方式生成对应的lora模型,调用后会更改原始的模型
    new_model1 = get_peft_model(origin_model, model_lora_config)
    new_model2 = LoraModel(origin_model, model_lora_config, "default")

    output1 = new_model1(input_data)
    output2 = new_model2(input_data)
    # 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX
    # origin_output == output1 == output2

    # 获取lora权重参数,两者在key_name上会有区别
    new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)
    new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)

    # origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]
    print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重
    from diffusers import StableDiffusionPipeline
    save_path = "./"
    global_step = 0
    StableDiffusionPipeline.save_lora_weights(
            save_directory = save_path,
            unet_lora_layers = new_model1_lora_state_dict,
            safe_serialization = True,
            weight_name = f"checkpoint-{global_step}.safetensors",
        )

    # 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本
    from safetensors import safe_open
    alpha = 1. # 参数融合因子
    lora_path = "./" + f"checkpoint-{global_step}.safetensors"
    state_dict = {}
    with safe_open(lora_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    all_lora_weights = []
    for idx,key in enumerate(state_dict):
        # only process lora down key
        if "lora_B." in key: continue

        up_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名
        model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = new_model1

        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key].to(curr_layer.weight.data.device)
        weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)
        # 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
        all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])
        print('Load Lora Done')

5--完整代码

PEFT_LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/653750.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

boost asio异步服务器(2)实现伪闭包延长连接生命周期

闭包 在函数内部实现一个子函数,子函数的作用域内能访问外部函数的局部变量。闭包就是能够读取其他函数内部变量。但是由于闭包会使得函数中的变量都被保存在内存中,内存消耗很大,所以不能滥用闭包,否则会造成程的性能问题&#x…

没开玩笑!高速信号不能参考电源网络这条规则,其实很难做到

高速先生成员--黄刚 看到这篇文章的题目,我相信大家心里都呈现出了这么一个场景:高速信号线在L20层,我只要把L19和L21层都铺上完整的地平面,这不就满足了高速信号线不能参考电源平面这条规则了吗?这难道很难做到吗&…

51驱动DY-SV20F语音播放模块

51驱动DY-SV20F语音播放模块 简介模块特征电气参数工作模式配置原理图代码结果图 简介 DY-SV20F 是一款一对一分段触发控制播放器,支持 MP3,WAV 解码格式; 可分段触发 9 首曲目;低电平触发;3.7-5VDC 宽电压供电,直驱 …

expect自动交互

在执行命令或脚本的时候,当控制台提示我们需要输入账号密码、参数等信息的时候,expect可以将预设的参数值自动输入到控制台,实现了自动交互。 1. 安装expect yum install expect 2. 案例: 创建 demo.exp 文件,并添…

英语四级翻译练习笔记③——大学英语四级考试2023年12月真题(第三套)

目录 引言(必看) 四级翻译评分标准分析及真题解析 四级翻译评分标准 四级翻译真题 学生作答 1. 评分 2. 修正翻译中的错误 错误标记: 3. 改正句子 4. 标出错误单词 5. 标准答案 6. 常考万能句子 7.重点单词的中文意思 引言&…

移动应用程序设计详解:基本概念和原理

移动应用程序设计是什么? 一般来说,应用程序设计师的核心职责是让用户有体验应用的欲望,而开发者负责让它正常工作。移动应用程序设计包括用户界面 (UI) 和用户体验 (UX)。设计者负责应用程序的整体风格,包括配色方案、字体选择、…

关于如何通过APlayer+MetingJS为自己的wordpress博客网页添加网易音乐播放器(无需插件)

本文转自博主的个人博客:https://blog.zhumengmeng.work,欢迎大家前往查看。 原文链接:点我访问 序言:最近在网上冲浪,发现大家的博客大部分都有一个音乐播放器能够播放音乐,随机我也开始寻找解决方法。可是找来找去我…

达梦数据库查看字符集、页大小

1.查看字符集select UNICODE (); 0 表示 GB18030,1 表示 UTF-8,2 表示 EUC-KR 2.查看页大小select SF_GET_PAGE_SIZE(); 也可以通过管理工具去查看

【组合数学 放球问题 虚拟点 小于等于转小于】1621. 大小为 K 的不重叠线段的数目

本文涉及知识点 放球问题 组合数学汇总 本题难道分:2198 LeetCode1621. 大小为 K 的不重叠线段的数目 给你一维空间的 n 个点,其中第 i 个点(编号从 0 到 n-1)位于 x i 处,请你找到 恰好 k 个不重叠 线段且每个线段…

菊花链通信技术整理

目录 一、菊花链简介 二、菊花链与CAN通信的区别 三、常见的菊花链AFE芯片 四、菊花链数据结构 五、菊花链方案介绍 一、菊花链简介 首先简单的说一下菊花链以及菊花链的应用,在目前国内的BMS开发中,我们应用最广泛的目前还还是分布式,…

代码随想录算法训练营第七天| 454.四数相加II 、383. 赎金信、 15. 三数之和、18. 四数之和

454.四数相加II 题目链接: 454.四数相加II 文档讲解:代码随想录 状态:没做出来,没想到考虑重复的情况! 题解: public int fourSumCount(int[] nums1, int[] nums2, int[] nums3, int[] nums4) {// 结果计数…

java的变量关系~使用和扩展

一、变量的概述 1、什么是变量 白话:变量就是一个装东西的盒子。 通俗:变量是用于存放数据的容器。我们通过变量名 获取数据,甚至数据可以修改。 2、变量在内存中的存储 本质:变量是程序在内存中申请的一块用来存放数据的空间,类似我们酒店的房间&a…

基于多源数据的微服务系统失败测试用例诊断

简介 本文介绍由南开大学、华为云及清华大学共同合作的论文:基于多源数据的微服务系统失败测试用例诊断。该论文已被FSE 2024(The ACM International Conference on the Foundations of Software Engineering) 会议录用,论文标题为: Fault D…

JS中的数组很重要,怎样定义(声明)

为什么呢?在java中有集合,数组的作用就弱了,其高光时刻基本都被集合代替了。在JS中没有集合,数组就有点忙不过来了。你说它重要不重要?! 在JS中,怎样定义一个数组呢? 数组的声明方…

动手学操作系统(二、编写MBR主引导记录)

动手学操作系统(二、编写MBR主引导记录) 文章目录 动手学操作系统(二、编写MBR主引导记录)1. 实模式和保护模式2. BIOS与MBR3. MBR程序Reference 在之前的学习内容中,我们已经实现了基本的仿真环境bochs的搭建&#xf…

【Linux】数据链路层协议+ICMP协议+NAT技术

欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:Linux 目录 👉🏻数据链路层👉🏻以太网以太网帧格式网卡Mac地址对比ip地址 👉🏻MTUMTU…

员工管理和激励怎么做?试试场景化激励解决方案!

截止到2020年底,中国企业主体数量达3858.3万,同比增速达11.1%。如何留住人才、激励人才以强化人才与企业“黏性”,最大化提升员工的忠诚度与敬业度,成为企业未来人才发展战略的主要方向之一。 一、传统激励方式存在哪些不足 传统的…

【NumPy】权威指南:使用NumPy的percentile函数进行百分位数计算

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

计算机找不到msvcr110.dll如何解决,总结5种简单靠谱的方法

在日常使用电脑的过程中,我们可能会遇到一些错误提示,其中之一就是“msvcr110.dll丢失”。这个错误通常会导致某些程序无法正常运行,为了解决这个问题,下面我将介绍5种有效的解决方法。 一,了解msvcr110.dll是什么 ms…

网络之再谈体系结构

大家都知道的是网络的体系结构,现代软件常用的体系结构无非是TCP/IP协议栈,OSI因为实现复杂并且效率没有TCP/IP协议栈好,所以不用OSI,但是,最近在复习网络知识的时候,发现了一些奇怪的地方,那就…