LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666


GQA
GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。

从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:

  • GQA:从头实现 LLaMA3 网络与推理流程
  • KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理

Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QK)V

1. MultiHeadAttention

MultiHeadAttention (多头注意力机制),合计 43 行:

  1. __init__ 初始化 (10行):
    • 输入:heads(头数)、d_model(维度)、dropout (用于 scores)
    • 计算 d_k 每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk
    • 线性层是 QKVO,Dropout 层
  2. attention 注意力 (10行):
    • q q q 的维度 [bs,h,s,d],与 k ⊤ k^{\top} k[bs,h,d,s],mm 之后 scores 是 [bs,h,s,s]
    • mask 的维度是 [bs,s,s],使用 unsqueeze(1),转换成 [bs,1,s,s]
    • QKV 的计算,额外支持 Dropout
  3. forward 推理 (12行):
    • QKV Linear 转换成 [bs,s,h,dk],再转换 [bs,h,s,dk]
    • 计算 attn 的 [bs,h,s,dk]
    • 转换 [bs,s,h,dk],再 contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d
    • 再过 O
  4. 测试 (11行):
    • torch.randn 构建数据
    • Mask 的 torch.tril(torch.ones(bs, s, s))

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制 MultiHeadAttention
    """
    def __init__(self, heads, d_model, dropout=0.1):  # 10行
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):  # 10行
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):  # 12行
        bs = q.size(0)
        # 进行线性操作划分为成 h 个头
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # 矩阵转置
        k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        print(f"[Info] attn: {attn.shape}")
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # 设置超参数
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # 创建 MultiHeadAttention 实例
    attention = MultiHeadAttention(h, d, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 测试无掩码的情况
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # 测试有掩码的情况
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # 检查输出是否符合预期
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()

2. GroupQueryAttention

GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention

  1. __init__ :增加参数 kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。
  2. forward:使用 repeat_interleave 扩充 KV 维度,其他相同,增加 3 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力机制(Group Query Attention)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # 进行线性操作
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]
        k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]
        v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)
        # 复制键值头以匹配查询头的数量
        group = self.h // self.kv_heads
        k = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = v.repeat_interleave(group, dim=2)
        # 矩阵转置, 将 head 在前
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # 设置超参数, GQA 8//4=2组
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # 创建 MultiHeadAttention 实例
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 测试无掩码的情况
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # 测试有掩码的情况
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # 检查输出是否符合预期
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()

3. GQA + KVCache

GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:

  1. forward :增加参数 kv_cache,合并 [cached_k, new_k],同时返回 new_kv_cache,用于迭代,增加 5 行。
  2. 设置 cur_qkvcur_mask,迭代序列s维度,合计 8 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力机制(Group Query Attention)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None, kv_cache=None):
        bs = q.size(0)
        # 进行线性操作
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]
        new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        # 处理 KV Cache
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            new_k = torch.cat([cached_k, new_k], dim=1)
            new_v = torch.cat([cached_v, new_v], dim=1)
        # 复制键值头以匹配查询头的数量
        group = self.h // self.kv_heads
        k = new_k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = new_v.repeat_interleave(group, dim=2)
        # 矩阵转置, 将 head 在前
        # KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]
        print(f"[Info] attn: {attn.shape}")
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # 更新 KV Cache
        new_kv_cache = (new_k, new_v)  # 当前的 KV 缓存
        return output, new_kv_cache
def main():
    # 设置超参数
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # 创建 GroupQueryAttention 实例
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 模拟逐步生成序列,测试 KV Cache
    print("Testing KV Cache...")
    kv_cache, output = None, None
    for i in range(s):
        cur_q = q[:, i:i+1, :]
        cur_k = k[:, i:i+1, :]
        cur_v = v[:, i:i+1, :]
        cur_mask = mask[:, i:i+1, :i+1]   # q是 i:i+1,k是 :i+1
        output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)
        print(f"Output shape at step {i}:", output.shape)
    # 检查输出是否符合预期
    assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"
    print("Test passed!")
if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/959988.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

视觉语言模型 (VLMs):跨模态智能的探索

文章目录 一. VLMs 的重要性与挑战:连接视觉与语言的桥梁 🌉二. VLMs 的核心训练范式:四种主流策略 🗺️1. 对比训练 (Contrastive Training):拉近正例,推远负例 ⚖️2. 掩码方法 (Masking):重构…

java8-日期时间Api

目录 LocalDate更新时间LocalTimeLocalDateTimeInstantPeriod Duration格式化、解析日期-时间对象时区 java.util.Date java.util.Calendar 不支持时区 线程不安全 月份从0起线程不安全,只有包裹在ThreadLocal中才安全 java.text.DateFormat java.text.SimpleDateFo…

深度学习 Pytorch 动态计算图与梯度下降入门

在上节末尾我们发现autograd.grad函数可以灵活进行函数某一点的导数和偏导数的运算,但微分运算只是AutoGrad模块中的一小部分功能,本节将继续讲解这个模块的其他常用功能,并在此基础上介绍另一个常用优化算法:梯度下降算法。 imp…

FreeRtos的使用教程

定义: RTOS实时操作系统, (Real Time Operating System), 指的是当外界事件发生时, 能够有够快的响应速度,调度一切可利用的资源, 控制实时任务协调一致的运行。 特点: 支持多任务管理, 处理多个事件, 实现更复杂的逻辑。 与计算…

大话特征工程:1.维数灾难与特征轮回

一、维度深渊 公元 2147 年,人类文明进入了数据驱动的超级智能时代。从金融到医疗,从教育到娱乐,所有决策都仰赖“全维计算网络”(高维特征空间)。这套系统将全球所有信息抽象成数以亿计的多维特征&#xff08…

从ai产品推荐到利用cursor快速掌握一个开源项目再到langchain手搓一个Text2Sql agent

目录 0. 经验分享:产品推荐 1. 经验分享:提示词优化 2. 经验分享:使用cursor 阅读一篇文章 3. 经验分享:使用cursor 阅读一个完全陌生的开源项目 4. 经验分享:手搓一个text2sql agent (使用langchain l…

《STL基础之hashtable》

【hashtable导读】STL为大家提供了丰富的容器,hashtable也是值得大家学习和掌握的基础容器,而且面试官经常会把它和hashmap混在一起,让同学们做下区分。因此关于hashtable的一些特性,比如:底层的数据结构、插入、查找元…

本地大模型编程实战(02)语义检索(2)

文章目录 准备按批次嵌入加载csv文件,分割文档并嵌入测试嵌入效果总结代码 上一篇文章: 本地大模型编程实战(02)语义检索(1) 详细介绍了如何使用 langchain 实现语义检索,为了演示方便,使用的是 langchain 提供的内存数据库。 在实…

猿人学第一题 js混淆源码乱码

首先检查刷新网络可知,m参数被加密,这是一个ajax请求 那么我们直接去定位该路径 定位成功 观察堆栈之后可以分析出来这应该是一个混淆,我们放到解码平台去还原一下 window["url"] "/api/match/1";request function…

Dev-C++分辨率低-解决办法

目录 【工具】Dev-C分辨率低-解决办法问题背景完整操作指南第一步:打开属性设置 【工具】Dev-C分辨率低-解决办法 问题背景 Dev-C因版本老旧(长期未更新),在高分辨率显示器上存在界面模糊问题。通过修改Windows兼容性设置可优化…

Linux 小火车

1.添加epel软件源 2.安装sl 3. 安装完成后输入: sl

iic、spi以及uart

何为总线? 连接多个部件的信息传输线,是部件共享的传输介质 总线的作用? 实现数据传输,即模块之间的通信 总线如何分类? 根据总线连接的外设属于内部外设还是外部外设将总线可以分为片内总线和片外总线 可分为数…

Linux_线程控制

线程控制的相关接口 进程创建相关 之前我们已经认识到了pthread_create函数用来创建线程&#xff0c;这里不再赘述。 pthread_self函数 void* routine(void* args) {std::cout << "我是新线程..." << pthread_self() << std::endl;return null…

利用双指针一次遍历实现”找到“并”删除“单链表倒数第K个节点(力扣题目为例)

Problem: 19. 删除链表的倒数第 N 个结点 文章目录 题目描述思路复杂度Code 题目描述 思路 1.欲找到倒数第k个节点&#xff0c;即是找到正数的第n-k1、其中n为单链表中节点的个数个节点。 2.为实现只遍历一次单链表&#xff0c;我们先可以使一个指针p1指向链表头部再让其先走k步…

Ubuntu-手动安装 SBT

文章目录 前言Ubuntu-手动安装 SBT1. SBT是什么?1.1. SBT 的特点1.2. SBT 的基本功能1.3. SBT 的常用命令 2. 安装2.1. 下载2.2. 解压 sbt 二进制包2.3. 确认 sbt 可执行文件的位置2.4. 设置执行权限2.5. 创建符号链接2.6. 更新 PATH 环境变量2.7. 验证 sbt 安装 前言 如果您觉…

【ProtoBuf 安装】ProtoBuf在window/Linux下的安装 创建/删除swap分区

文章目录 1.ProtoBuf在window下的安装2.ProtoBuf在Linux下的安装创建swap分区命令解析关闭swap分区删除swap分区的影响 1.ProtoBuf在window下的安装 1、下载ProtoBuf编译器 下载地址&#xff1a;https://github.com/protocolbuffers/protobuf/releases 如果要在 C 下使用 Pro…

BAHD酰基转移酶对紫草素的手性催化-文献精读105

Two BAHD Acyltransferases Catalyze the Last Step in the Shikonin/Alkannin Biosynthetic Pathway 两个BAHD酰基转移酶催化了紫草素/左旋紫草素生物合成途径中的最后一步 一个BAHD酰基转移酶专门催化紫草素的酰基化&#xff0c;而另一个BAHD酰基转移酶则仅催化紫草素的对映…

C语言初阶力扣刷题——349. 两个数组的交集【难度:简单】

1. 题目描述 力扣在线OJ题目 给定两个数组&#xff0c;编写一个函数来计算它们的交集。 示例&#xff1a; 输入&#xff1a;nums1 [1,2,2,1], nums2 [2,2] 输出&#xff1a;[2] 输入&#xff1a;nums1 [4,9,5], nums2 [9,4,9,8,4] 输出&#xff1a;[9,4] 2. 思路 直接暴力…

在Docker 容器中安装 Oracle 19c

在 Docker 容器中安装 Oracle 19c 是可行的&#xff0c;但它相较于其他数据库&#xff08;如 MySQL、PostgreSQL 等&#xff09;会复杂一些&#xff0c;因为 Oracle 数据库有一些特定的要求&#xff0c;如操作系统和库的依赖&#xff0c;以及许可证问题。 不过&#xff0c;Ora…

WGCLOUD使用介绍 - 如何监控ActiveMQ和RabbitMQ

根据WGCLOUD官网的信息&#xff0c;目前没有针对ActiveMQ和RabbitMQ这两个组件专门做适配 不过可以使用WGCLOUD已经具备的通用监测模块&#xff1a;进程监测、端口监测或者日志监测、接口监测 来对这两个组件进行监控