大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn


# rms归一化
class RMSNorm(nn.Module):
    """

    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.float()


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


# 旋转位置编码
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)
        freqs = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size,
                 mode):
        super().__init__()
        self.dim = dim  # 隐藏层维度
        self.n_heads = n_heads  # 总头数
        self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度
        self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度
        self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度
        self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度
        self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度
        self.mode = mode
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵
        # 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944

        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵
        # nn.Linear(self.dim, self.kv_lora_rank)
        # nn.Linear(self.dim, self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵

        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)

        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码
        # 没有矩阵融合
        if self.mode == 'naive':
            self.register_buffer('k_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),
                                 persistent=False)
            self.register_buffer('v_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),
                                 persistent=False)
        # 有矩阵融合
        else:
            self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                                 persistent=False)
            self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                                 persistent=False)

    def forward(self, x, mask=None):

        bs, seq_len, _ = x.shape

        q = self.wq_a(x)  # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]
        q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                                   dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]

        kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],
                               dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]

        k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个key
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        if self.mode == 'naive':

            q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]

            kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]
            kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]
            kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
            # k shape:[bs, seq_len, n_heads, qk_head_dim]
            self.k_cache[:bs, :seq_len, :, :] = k
            self.v_cache[:bs, :seq_len, :, :] = v
            # scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            scores = torch.matmul(q.transpose(1, 2),
                                  self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(
                                      self.qk_nope_head_dim + self.qk_rope_head_dim))
            scores = scores.transpose(1, 2)

        else:
            k_pe = k_pe.squeeze(2)
            wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
            wkv_b = wkv_b.view(self.n_heads, -1,
                               self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope,
                                  wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]
            # q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v
            # wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/v
            kv = self.kv_norm(kv)
            self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]
            self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]
            scores_nope = torch.einsum("bshc,btc->bsht", q_nope,
                                       self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bsht
            scores_pe = torch.einsum("bshr,btr->bsht", q_pe,
                                     self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bsht
            scores = (scores_nope + scores_pe) / math.sqrt(
                self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]

        if mask is not None:
            # mask shape:[bs, seq_len, seq_len]
            scores += mask.unsqueeze(2)

        scores = scores.softmax(dim=-1)

        if self.mode == 'naive':
            x = torch.einsum("bsht,bthd->bshd", scores,
                             self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshd
        else:

            # scores * v = scores * c * wkv_b[:, -self.v_head_dim:]
            x = torch.einsum("bsht,btc->bshc", scores,
                             self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshd

        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x) 

        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    x = torch.randn(1, 4, 16)

    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    mode = 'none'

    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size,
              mode=mode)

    print(mla(x))
    print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/966588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue3+codemirror6实现公式(规则)编辑器

实现截图 实现/带实现功能 插入标签 插入公式 提示补全 公式验证 公式计算 需要的依赖 "codemirror/autocomplete": "^6.18.4","codemirror/lang-javascript": "^6.2.2","codemirror/state": "^6.5.2","cod…

MIT 6.5940(一)

记录了Lecture 1~8 Lecture 1 Introduction TinyML and Efficient Deep Learning Computing 摘要 AI systems need to continually adapt to new data collected locally 在设备学习:better privacy, lower cost, customization, life-long learningTraining is…

Linux TCP 编程详解与实例

一、引言 在网络编程的领域中,TCP(Transmission Control Protocol)协议因其可靠的数据传输特性而被广泛应用。在 Linux 环境下,使用 C 或 C 进行 TCP 编程可以实现各种强大的网络应用。本文将深入探讨 Linux TCP 编程的各个方面&…

一款由 .NET 官方团队开源的电子商务系统 - eShop

项目介绍 eShop是一款由.NET官方开源的,基于.NET Aspire构建的用于参考学习的服务架构电子商务系统,旨在展示如何利用.NET框架及其相关技术栈构建一个现代化的电子商务网站。该项目采用服务架构,将应用程序分解为多个独立的服务,…

crewai框架第三方API使用官方RAG工具(pdf,csv,json)

最近在研究调用官方的工具,但官方文档的说明是在是太少了,后来在一个视频里看到了如何配置,记录一下 以PDF RAG Search工具举例,官方文档对于自定义模型的说明如下: 默认情况下,该工具使用 OpenAI 进行嵌…

2011-2020年各省电话普及率数据

2011-2020年各省电话普及率数据 1、时间:2011-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、电话普及率(包括移动电话)(部/百人) 4、范围:31省 5、指标说明:电话普及率是衡量一个…

【自开发工具介绍】SQLSERVER的ImpDp和ExpDp工具演示05

SQLSERVER的ImpDp和ExpDp工具演示 1、表部分数据导出 (-query) ※「-query」和「-include_table」必须一起使用 「-query」后面字符串是sql文的where语句,但要注意要使用%,需要写%% 验证用:导出的表,导入到新的数据库 db的数…

ASP.NET Core 使用 WebClient 从 URL 下载

本文使用 ASP .NET Core 3.1,但它在.NET 5、 .NET 6和.NET 8上也同样适用。如果使用较旧的.NET Framework,请参阅本文,不过,变化不大。 如果想要从 URL 下载任何数据类型,请参阅本文:HttpClient 使用WebC…

快速上手Vim的使用

Vim Linux编辑器-vim使用命令行模式下所有选项都可以带数字底行模式可视块模式(ctrlV进入) Linux编辑器-vim使用 Vim有多种模式的编辑器。能帮助我们很快的进行代码的编辑,甚至完成很多其他事情。 默认情况下我们打开vim在命令模式下&#x…

nodejs - vue 视频切片上传,本地正常,线上环境导致磁盘爆满bug

nodejs 视频切片上传,本地正常,线上环境导致磁盘爆满bug 原因: 然后在每隔一分钟执行du -sh ls ,发现文件变得越来越大,即文件下的mp4文件越来越大 最后导致磁盘直接爆满 排查原因 1、尝试将m3u8文件夹下的所有视…

114,【6】攻防世界 web wzsc_文件上传

进入靶场 传个桌面有的 直接空白了 我们 访问一下上传的东西 /index 没显示用于解析的.htaccess和.user.ini 文件,还两个都不显示 .htaccess 和 .user.ini 文件分别用于 Apache 服务器和 PHP-FPM 环境的目录级配置 但上传的时候bp查看状态码是200,…

Open3d Qt的环境配置

Open3d Qt的环境配置 一、概述二、操作流程2.1 下载文件2.2 新建文件夹2.3 环境变量设置2.4 qt6 引用3、qt中调用4、资源下载一、概述 目前统一使用qt6配置,open3d中可视化功能目前使用vtk代替,语言为c++。 二、操作流程 2.1 下载文件 访问open3d github链接,进入releas…

零基础都可以本地部署Deepseek R1

文章目录 一、硬件配置需求二、详细部署步骤1. 安装 Ollama 工具2. 部署 DeepSeek-R1 模型3. API使用4. 配置图形化交互界面(可选)5. 使用与注意事项 一、硬件配置需求 不同版本的 DeepSeek-R1 模型参数量不同,对硬件资源的要求也不尽相同。…

Rocky Linux9安装Zabbix7.0(精简版)

Linux 系统版本 Rocky Linux release 9.3 (Blue Onyx) 注意:zabbix 7以上版本不支持CentOS 7系统,需要CentOS 8以上, 本教程支持CentOS9及Rocky Linux 9 在Rocky Linux release 9.3测试通过 Linux环境准备 关闭防火墙和selinux #关闭防…

Qt程序发布

关注后回复 qt 获取相关资料 找到Qt安装目录中的 windeployqt.exe 将其路径添加到Path环境变量中可能会涉及到多平台架构的版本,选择一个目标版本将Release版中的 ***.exe 复制到某空文件夹cmd 进入上述文件夹中执行 windeployqt.exe ***.exe此时会将该 ***.exe 文件…

从O(k*n)到O(1):如何用哈希表终结多层if判断的性能困局

【前言】   本文将以哈希表重构实战为核心,完整展示如何将传统条件匹配逻辑(上千层if-else判断)转化为O(1)的哈希表高效实现。通过指纹验证场景的代码级解剖,您将深入理解:   1.哈希函数设计如何规避冲突陷阱   2.链式寻址法的工程实现…

后端java工程师经验之谈,工作7年,mysql使用心得

mysql 工作7年,mysql使用心得 mysql1.创建变量2.创建存储过程2.1:WHILE循环2.2:repeat循环2.3:loop循环2.4:存储过程,游标2.5:存储过程,有输入参数和输出参数 3.三种注释写法4.case …

【WB 深度学习实验管理】利用 Hugging Face 实现高效的自然语言处理实验跟踪与可视化

本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~ https://github.com/LFF8888/FF-Studio-Resources 在自然语言处理领域,使用Hugging Face的Transformers库进行模型训练已经成为主流。然而,随着模型复…

智能理解 PPT 内容,快速生成讲解视频

当我们想根据一版 PPT 制作出相对应的解锁视频时,从撰写解锁词,录制音频到剪辑视频,每一个环节都需要投入大量的时间和精力,本方案将依托于阿里云函数计算 FC 和百炼模型服务,实现从 PPT 到视频的全自动转换&#xff0…

如何使用Gemini模型,国内如何订阅购买Gemini Pro的教程,Gemini Pro 免费试用操作步骤, 谷歌 aistudio 使用入口

最近的榜首又被Gemini给霸占了,很多童鞋想要体验一翻 Gemini免费库模型更新了 Gemini2.0向所有人开放了!使用了真香 目前呢2.0flash和Gemini-2.0-Flash-Thinking-Exp、Gemini-2.0-Flash-Thinking-Exp-with-apps已经免费给所有注册用户开放了&#xff0c…