Group Query Attention (GQA) 机制详解以及手动实现计算

Group Query Attention (GQA) 机制详解

1. GQA的定义

Grouped-Query Attention (GQA) 是对 Multi-Head Attention (MHA) 和 Multi-Query Attention (MQA) 的扩展。通过提供计算效率和模型表达能力之间的灵活权衡,实现了查询头的分组。GQA将查询头分成了G个组,每个组共享一个公共的键(K)和值(V)投影。

2. GQA的变体

GQA有三种变体:

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
  • GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。
3. GQA的优势

使用G个组可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。GQA提供了对模型质量和效率的细致控制。

4. GQA的实现

GQA的最简形式可以通过实现 GroupedQueryAttention 类来实现。GroupedQueryAttention 类继承自 Attention 类,重写了 forward 方法,其中使用了 MultiQueryAttention 类的实例来处理每个组的查询。通过将每个组的结果拼接起来,然后与投影矩阵进行矩阵乘法运算,最终得到 GQA 的输出。[1]

pytorch 示例实现:

假设我们有以下初始化的query, key, value:

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
1. 确定分组数量

首先,我们需要确定将查询头分为多少组。在这个例子中,我们有8个查询头,而键和值的头数为2,所以我们可以将查询头分为4组,每组有2个查询头。

2. 对查询进行分组

然后,我们将查询头分组。我们可以使用 torch.chunk 函数将查询张量沿着头维度分割成4个组,每个组有2个头。

query_groups = torch.chunk(query, 4, dim=2)  # shape of each group: (1, 256, 2, 64)

3. 计算注意力分数

对于每一个查询组,我们计算它与键的注意力分数。我们首先计算查询组和键的点积,然后通过 torch.softmax 函数得到注意力分数。

attention_scores = []
for query_group in query_groups:
    score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)
    score = torch.softmax(score, dim=-1)
    attention_scores.append(score)
4. 计算注意力输出

接下来,我们使用注意力分数对值进行加权求和,得到每一个查询组的注意力输出。

attention_scores = []
for query_group in query_groups:
    score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)
    score = torch.softmax(score, dim=-1)
    attention_scores.append(score)
5. 拼接输出

最后,我们将所有查询组的注意力输出拼接起来,得到最终的 Grouped Query Attention 的输出。

attention_outputs = []
for score in attention_scores:
    output = torch.matmul(score, value)  # shape: (1, 256, 2, 64)
    attention_outputs.append(output)

这就是 Grouped Query Attention 的实现过程。在这个过程中,我们将查询头分组,然后对每一个查询组分别计算注意力分数和输出,最后将所有查询组的输出拼接起来。这样可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。


  1. Grouped-Query Attention (GQA) - The Large Language Model Playbook

  2. 安全验证 - 知乎
  3. 安全验证 - 知乎
  4. 安全验证 - 知乎
  5. Grouped-Query Attention (GQA) - The Large Language Model Playbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/557931.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PE文件注入恶意代码改变代码执行顺序教程

原理:因为PE文件3个段都需要对齐200h,这样就会导致会有很多填充0,在填充0的地方就有机会进行插入恶意代码 在原有的PE文件的基础上修改,文件是我之前发布的一篇博客PE文件构造 主要思路:原来的PE文件中.text的有Mess…

2024面试软件测试,常见的面试题(上)

一、综合素质 1、自我介绍 面试官您好,我叫XXX,一直从事车载软件测试,负责最多的是中控方面。 以下是我的一些优势: 车载的测试流程我是熟练掌握的,且能够独立编写测试用例。 平时BUG提交会使用到Jira,类似…

CSS中的盒子模型

目录 盒子模型介绍 盒子模型组成 盒子边框 边框的基本使用 边框影响盒子大小 盒子内边距 内边距的基本使用 内边距影响盒子大小 内边距不影响盒子大小的情况 盒子外边距 外边距的基本使用 外边距的常见使用 外边距合并问题 相邻块元素垂直外边距的合并 嵌套块元…

Qt - 窗口

目录 1. 前言 2. 菜单栏(QMenuBar) 2.1. 创建菜单栏 2.1.1. 方式一 2.1.2. 方式二 2.2. 在菜单栏中添加菜单和创建菜单项 2.3. 在菜单项之间添加分割线 2.4. 综合示例 3. 工具栏(QToolBar) 3.1. 创建工具栏 3.2. 设置停靠位置 3.2.1. 方式一 3.2.2. 方式二 3.3. 设…

【前端】input输入框输入文字加文字轮廓效果

【前端】input输入框输入文字加文字轮廓效果 两种方案 方案一 输入框文字轮廓DEMO1通过文字阴影实现 <!DOCTYPE html> <html lang"en"> <head><title>输入框文字轮廓DEMO1通过文字阴影实现</title> <meta charset"UTF-8&quo…

【Linux进阶之路】高级IO

一、 铺垫 I&#xff0c;即input为输入&#xff1b;O&#xff0c;即output为输出&#xff0c;IO&#xff0c;即input output为输入输出。IO一般是基于网卡&#xff0c;磁盘&#xff0c;光盘&#xff0c;U盘&#xff0c;磁盘&#xff0c;磁带等毫秒级别的外存&#xff0c;相较…

《QT实用小工具·三十一》基于QT开发的访客管理平台demo2

1、概述 源码放在文章末尾 该项目为访客管理平台demo&#xff0c;包含主界面、系统设置、警情查询、调试帮助、用户退出功能。 项目部分代码如下&#xff1a; #pragma execution_character_set("utf-8")#include "frmmain.h" #include "ui_frmmain…

SpringBoot + Redis实现用户信息登录的缓存

&#x1f34e;前言 &#x1f350;项目的背景 背景&#xff1a;&#x1f349;当我们在完成用户信息登录时&#xff0c;我们往往每次都会在数据库中查询用户的记录&#xff0c;生成token并返回给前端&#xff0c;不过这样会有一定的问题。 &#x1f350;造成的问题 问题&#xf…

Linux 用户和组

理解Linux 用户和组的概念 掌握passwd 文件的组成以及作用 掌握shadow 文件的组成以及作用 了解group 文件的内容 1.用户分类&#xff1a; 超级管理员&#xff08;root&#xff09; 普通用户 程序用户 1.用户信息文件 /etc/passwd 文件中存储了所有用户信息。 1.passwd 格…

python中中英文打印对齐解决方案

在python中&#xff0c;有时候会出现中英文混合输出的情形&#xff0c;但是由于中文默认是全角格式&#xff08;一个中文字符占用两个字符宽度&#xff09;&#xff0c;这会对python原生的print函数带来一些障碍。尤其是用户用print对齐输出的时候&#xff0c;这种差异会导致文…

月球地形数据介绍(LOLA)

月球地形数据介绍 LOLA介绍LOLA数据的处理与发布数据类型和格式投影坐标系SIMPLE CYLINDRICALPOLAR STEREOGRAPHIC 数据下载与浏览 LOLA介绍 目前最新的月球地形高程数据来源于美国2009年发射的LRO探测器。 “月球勘测轨道器”(Lunar Reconnaissance Orbiter&#xff0c;LRO)…

文心一言 VS 讯飞星火 VS chatgpt (240)-- 算法导论17.3 6题

六、证明&#xff1a;如何用两个普通的栈实现一个队列(练习10.1-6)&#xff0c;使得每个 ENQUEUE 和 DEQUEUE 操作的摊还代价为 O(1) 。练习10.1-6的内容是&#xff1a;说明如何用两个栈实现一个队列&#xff0c;并分析相关队列操作的运行时间。如果要写代码&#xff0c;请用go…

【MIT6.824】lab2C-persistence, lab2D-log compaction 实现笔记

引言 lab2C的实验要求如下 Complete the functions persist() and readPersist() in raft.go by adding code to save and restore persistent state. You will need to encode (or “serialize”) the state as an array of bytes in order to pass it to the Persister. Us…

el-table使用show-summary合计,但只需要合并某一列

el-table使用show-summary合计&#xff0c;但只需要合并某一列 这里有两种方法&#xff0c;一种是网上的&#xff0c;我会引用他的链接给你们看。 一种是我自己看源码发现的 方法一 这个就是方法一的链接 点击我跳转方法一 方法二 不需要计算的列 去掉prop 然后用插槽显示即可…

比特币叙事大转向

作者&#xff1a;David Lawant 编译&#xff1a;秦晋 要理比特币解减半动态&#xff0c;最关键的图表是下面这张&#xff0c;而不是价格图表。它显示了自 2012 年以来&#xff0c;矿业总收入与比特币现货交易量的比例&#xff0c;并标注了三个减半日期。 虽然矿工仍然是比特币生…

Python八股文:基础知识Part2

1. Python中变量的保存和访问 Python中的变量实际上是一个指向对象的引用&#xff0c;每个对象都有一个唯一的标识符&#xff08;即内存地址&#xff09;。对于一些不可变对象&#xff0c;如字符串和整数&#xff0c;因为它们的值不可更改&#xff0c;所以当多个变量引用相同的…

OSPF大型实验

OSPF大型实验 实验拓扑图 实验思路 1、R4为ISP&#xff0c;其上只配置IP地址&#xff1b;R4与其他所直连设备间均使用公有IP&#xff1b; 2、R3-R5、R6、R7为MGRE环境&#xff0c;R3为中心站点&#xff1b; 3、整个OSPF环境IP基于172.16.0.0/16划分&#xff1b;除了R12有两…

Java精品项目--第8期基于SpringBoot的宠物用品商城的设计分析与实现

项目使用技术栈 SpringBootThymeleafMyBatisMySQLAopJavajdk1.8 项目介绍 项目截图

JAVA 集合框架(一) Collection集合详解和常用方法

Java集合框架&#xff08;Java Collections Framework&#xff09;是一个强大的、高度灵活的数据结构库&#xff0c;它为Java应用程序提供了组织、存储和操作对象集合的标准方法。 集合类体系结构 接口类&#xff1a; Collection: 是单例集合类的顶层接口&#xff0c;包括Lis…

移动端适配之viewport

目录 盒模型&#xff1a;widthcontent&#xff08;padding border&#xff09; class"content-box"内容盒模型&#xff08;W3C盒&#xff09; class"border-box"边框盒模型&#xff08;IE 盒&#xff09; scroll滚动 window浏览器视窗&#xff1a;包括…