【LLM训练系列02】如何找到一个大模型Lora的target_modules

方法1:观察attention中的线性层

import numpy as np
import pandas as pd
from peft import PeftModel
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from typing import List
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
import os
os.environ['CUDA_VISIBLE_DEVICES']='1,2'
os.environ["TOKENIZERS_PARALLELISM"] = "false"


model_path ="/home/jovyan/codes/llms/Qwen2.5-14B-Instruct"
base_model = AutoModel.from_pretrained(model_path, device_map='cuda:0',trust_remote_code=True)



打印attention模型层的名字

for name, module in base_model.named_modules():
    if 'attn' in name or 'attention' in name:  # Common attention module names
        print(name)
        for sub_name, sub_module in module.named_modules():  # Check sub-modules within attention
            print(f"  - {sub_name}")

方法2:通过bitsandbytes量化查找线性层

import bitsandbytes as bnb
def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            names = name.split(".")
            # model-specific
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)

加载模型

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
base_model = AutoModel.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto"
    )

查找Lora的目标层

find_all_linear_names(base_model)


还有个函数,一样的原理

def find_target_modules(model):
    # Initialize a Set to Store Unique Layers
    unique_layers = set()
    
    # Iterate Over All Named Modules in the Model
    for name, module in model.named_modules():
        # Check if the Module Type Contains 'Linear4bit'
        if "Linear4bit" in str(type(module)):
            # Extract the Type of the Layer
            layer_type = name.split('.')[-1]
            
            # Add the Layer Type to the Set of Unique Layers
            unique_layers.add(layer_type)

    # Return the Set of Unique Layers Converted to a List
    return list(unique_layers)

find_target_modules(base_model)

方法3:通过分析开源框架的源码swift

代码地址

from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class ModelKeys:

    model_type: str = None

    module_list: str = None

    embedding: str = None

    mlp: str = None

    down_proj: str = None

    attention: str = None

    o_proj: str = None

    q_proj: str = None

    k_proj: str = None

    v_proj: str = None

    qkv_proj: str = None

    qk_proj: str = None

    qa_proj: str = None

    qb_proj: str = None

    kva_proj: str = None

    kvb_proj: str = None

    output: str = None


@dataclass
class MultiModelKeys(ModelKeys):
    language_model: Union[List[str], str] = field(default_factory=list)
    connector: Union[List[str], str] = field(default_factory=list)
    vision_tower: Union[List[str], str] = field(default_factory=list)
    generator: Union[List[str], str] = field(default_factory=list)

    def __post_init__(self):
        # compat
        for key in ['language_model', 'connector', 'vision_tower', 'generator']:
            v = getattr(self, key)
            if isinstance(v, str):
                setattr(self, key, [v])
            if v is None:
                setattr(self, key, [])


LLAMA_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    q_proj='model.layers.{}.self_attn.q_proj',
    k_proj='model.layers.{}.self_attn.k_proj',
    v_proj='model.layers.{}.self_attn.v_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

INTERNLM2_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.feed_forward',
    down_proj='model.layers.{}.feed_forward.w2',
    attention='model.layers.{}.attention',
    o_proj='model.layers.{}.attention.wo',
    qkv_proj='model.layers.{}.attention.wqkv',
    embedding='model.tok_embeddings',
    output='output',
)

CHATGLM_KEYS = ModelKeys(
    module_list='transformer.encoder.layers',
    mlp='transformer.encoder.layers.{}.mlp',
    down_proj='transformer.encoder.layers.{}.mlp.dense_4h_to_h',
    attention='transformer.encoder.layers.{}.self_attention',
    o_proj='transformer.encoder.layers.{}.self_attention.dense',
    qkv_proj='transformer.encoder.layers.{}.self_attention.query_key_value',
    embedding='transformer.embedding',
    output='transformer.output_layer',
)

BAICHUAN_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    qkv_proj='model.layers.{}.self_attn.W_pack',
    embedding='model.embed_tokens',
    output='lm_head',
)

YUAN_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    qk_proj='model.layers.{}.self_attn.qk_proj',
    o_proj='model.layers.{}.self_attn.o_proj',
    q_proj='model.layers.{}.self_attn.q_proj',
    k_proj='model.layers.{}.self_attn.k_proj',
    v_proj='model.layers.{}.self_attn.v_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

CODEFUSE_KEYS = ModelKeys(
    module_list='gpt_neox.layers',
    mlp='gpt_neox.layers.{}.mlp',
    down_proj='gpt_neox.layers.{}.mlp.dense_4h_to_h',
    attention='gpt_neox.layers.{}.attention',
    o_proj='gpt_neox.layers.{}.attention.dense',
    qkv_proj='gpt_neox.layers.{}.attention.query_key_value',
    embedding='gpt_neox.embed_in',
    output='gpt_neox.embed_out',
)

PHI2_KEYS = ModelKeys(
    module_list='transformer.h',
    mlp='transformer.h.{}.mlp',
    down_proj='transformer.h.{}.mlp.c_proj',
    attention='transformer.h.{}.mixer',
    o_proj='transformer.h.{}.mixer.out_proj',
    qkv_proj='transformer.h.{}.mixer.Wqkv',
    embedding='transformer.embd',
    output='lm_head',
)

QWEN_KEYS = ModelKeys(
    module_list='transformer.h',
    mlp='transformer.h.{}.mlp',
    down_proj='transformer.h.{}.mlp.c_proj',
    attention='transformer.h.{}.attn',
    o_proj='transformer.h.{}.attn.c_proj',
    qkv_proj='transformer.h.{}.attn.c_attn',
    embedding='transformer.wte',
    output='lm_head',
)

PHI3_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    qkv_proj='model.layers.{}.self_attn.qkv_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

PHI3_SMALL_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.dense',
    qkv_proj='model.layers.{}.self_attn.query_key_value',
    embedding='model.embed_tokens',
    output='lm_head',
)

DEEPSEEK_V2_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    qa_proj='model.layers.{}.self_attn.q_a_proj',
    qb_proj='model.layers.{}.self_attn.q_b_proj',
    kva_proj='model.layers.{}.self_attn.kv_a_proj_with_mqa',
    kvb_proj='model.layers.{}.self_attn.kv_b_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

我的博客即将同步至腾讯云开发者社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=3hiaca88ulogc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/922737.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Selenium的八种定位方式

1. 通过 ID 定位 ID 是最直接和高效的方式来定位元素,因为每个页面中的 ID 应该是唯一的。 from selenium import webdriverdriver webdriver.Chrome(executable_pathpath/to/chromedriver) driver.get(https://example.com)# 通过 ID 定位 element driver.find…

MySQL底层概述—1.InnoDB内存结构

大纲 1.InnoDB引擎架构 2.Buffer Pool 3.Page管理机制之Page页分类 4.Page管理机制之Page页管理 5.Change Buffer 6.Log Buffer 1.InnoDB引擎架构 (1)InnoDB引擎架构图 (2)InnoDB内存结构 (1)InnoDB引擎架构图 下面是InnoDB引擎架构图,主要分为内存结构和磁…

丹摩|丹摩智算平台深度评测

1. 丹摩智算平台介绍 随着人工智能和大数据技术的快速发展,越来越多的智能计算平台涌现,为科研工作者和开发者提供高性能计算资源。丹摩智算平台作为其中的一员,定位于智能计算服务的提供者,支持从数据处理到模型训练的全流程操作…

基于企业微信客户端设计一个文件下载与预览系统

在企业内部沟通与协作中,文件分享和管理是不可或缺的一部分。企业微信(WeCom)作为一款广泛应用于企业的沟通工具,提供了丰富的API接口和功能,帮助企业进行高效的团队协作。然而,随着文件交换和协作的日益增…

LLM的原理理解6-10:6、前馈步骤7、使用向量运算进行前馈网络的推理8、注意力层和前馈层有不同的功能9、语言模型的训练方式10、GPT-3的惊人性能

目录 LLM的原理理解6-10: 6、前馈步骤 7、使用向量运算进行前馈网络的推理 8、注意力层和前馈层有不同的功能 注意力:特征提取 前馈层:数据库 9、语言模型的训练方式 10、GPT-3的惊人性能 一个原因是规模 大模型GPT-1。它使用了768维的词向量,共有12层,总共有1.…

大模型系列11-ray

大模型系列11-ray PlasmaPlasmaStore启动监听处理请求 ProcessMessagePlasmaCreateRequest请求PlasmaCreateRetryRequest请求PlasmaGetRequest请求PlasmaReleaseRequestPlasmaDeleteRequestPlasmaSealRequest ObjectLifecycleManagerGetObjectSealObject ObjectStoreRunnerPlas…

开源动态表单form-create-designer 扩展个性化配置的最佳实践教程

在开源低代码表单设计器 form-create-designer 的右侧配置面板里,field 映射规则为开发者提供了强大的工具去自定义和增强组件及表单配置的显示方式。通过这些规则,你可以简单而高效地调整配置项的展示,提升用户体验。 源码地址: Github | G…

美创科技入选2024数字政府解决方案提供商TOP100!

11月19日,国内专业咨询机构DBC德本咨询发布“2024数字政府解决方案提供商TOP100”榜单。美创科技凭借在政府数据安全领域多年的项目经验、技术优势与创新能力,入选收录。 作为专业数据安全产品与服务提供商,美创科技一直致力于为政府、金融、…

地平线 bev_cft_efficientnetb3 参考算法-v1.2.1

01 概述 在自动驾驶感知算法中 BEV 感知成为热点话题,BEV 感知可以弥补 2D 感知的缺陷构建 3D “世界”,更有利于下游任务和特征融合。 地平线集成了基于 bev 的纯视觉算法,目前已支持 ipm-based 、lss-based、 transformer-based&#xff…

C#里怎么样检测文件的属性?

C#里怎么样检测文件的属性? 对于文件来说,在C#里有一种快速的方法来检查文件的属性。 比如文件是否已经压缩, 文件是否加密, 文件是否是目录等等。 属性有下面这么多: 例子演示如下: /** C# Program to View the Information of the File*/ using System; using Syste…

最新‌VSCode保姆级安装教程(附安装包)

文章目录 一、VSCode介绍 二、VSCode下载 下载链接:https://pan.quark.cn/s/19a303ff81fc 三、VSCode安装 1.解压安装文件:双击打开并安装VSCode 2.勾选我同意协议:然后点击下一步 3.选择目标位置:点击浏览 4.选择D盘安装…

传输控制协议(TCP)和用户数据报协议(UDP)

一、传输控制协议(TCP) 传输控制协议(Transmission Control Protocol,TCP)是一种面向连接的、可靠的、基于字节流的传输层通信协议,由 IETF 的 RFC 793 定义。 它通过三次握手建立连接,确保数…

linux从0到1——shell编程9

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…

nature communications论文 解读

题目《Transfer learning with graph neural networks for improved molecular property prediction in the multi-fidelity setting》 这篇文章主要讨论了如何在多保真数据环境(multi-fidelity setting)下,利用图神经网络(GNNs&…

基于Qt/C++/Opencv实现的一个视频中二维码解析软件

本文详细讲解了如何利用 Qt 和 OpenCV 实现一个可从视频和图片中检测二维码的软件。代码实现了视频解码、多线程处理和界面更新等功能,是一个典型的跨线程图像处理项目。以下分模块对代码进行解析。 一、项目的整体结构 项目分为以下几部分: 主窗口 (M…

【Elasticsearch入门到落地】2、正向索引和倒排索引

接上篇《1、初识Elasticsearch》 上一篇我们学习了什么是Elasticsearch,以及Elastic stack(ELK)技术栈介绍。本篇我们来什么是正向索引和倒排索引,这是了解Elasticsearch底层架构的核心。 上一篇我们学习到,Elasticsearch的底层是由Lucene实…

鸿蒙主流路由详解

鸿蒙主流路由详解 Navigation Navigation更适合于一次开发,多端部署,也是官方主流推荐的一种路由控制方式,但是,使用起来入侵耦合度高,所以,一般会使用HMRouter,这也是官方主流推荐的路由 Navigation官网地址 个人源码地址 路由跳转 第一步-定义路由栈 Provide(PageInfo) pag…

java使用itext生成pdf

一、利用Adobe Acrobat DC软件创建pdf模板 备好Adobe Acrobat DC软件 1.excel/jpg/png文件转pdf文件 右击打开我们要转换的文件 2.然后点击 添加 域 3.可以看到域的名字 4.调整字体大小/对齐方式等 5.保存 二&#xff0c;代码部分 首先 上依赖 <dependency><group…

生成对抗网络模拟缺失数据,辅助PAMAP2数据集仿真实验

PAMAP2数据集是一个包含丰富身体活动信息的数据集&#xff0c;它为我们提供了一个理想的平台来开发和测试HAR模型。本文将从数据集的基本介绍开始&#xff0c;逐步引导大家通过数据分割、预处理、模型训练&#xff0c;到最终的性能评估&#xff0c;在接下来的章节中&#xff0c…

全面解析:HTML页面的加载全过程(一)--输入URL地址,与服务器建立连接

用户输入URL地址&#xff0c;与服务器建立连接 用户在浏览器地址栏输入一个URL 浏览器开始执行以下三步操作操作&#xff1a;url解析、DNS查询、TCP连接 第一步&#xff1a;URL解析 什么是URL&#xff1f; URL(Uniform Resource Locator&#xff0c;统一资源定位符)是互联网…