MHA(Multi-Head Attention) 与GQA(Grouped Query Attention)的区别

引言

       Grouped Query Attention(GQA,分组查询注意力)和多头注意力机制(Multi-Head Attention,MHA)都是Transformer模型中用于捕获输入序列中不同位置之间关系的注意力机制。然而,它们在实现方式和计算复杂度上有所不同。下面我将详细介绍它们的原理以及它们之间的区别。

1. 多头注意力机制(MHA)

1.1 概念

       多头注意力机制(Multi-Head Attention,MHA)是Transformer模型中的核心组件。它通过并行的多组注意力机制,让模型能够在不同的子空间中关注序列的不同方面,从而加强模型的表达能力。

1.2 工作原理

       输入表示:给定输入序列 X ∈ R n × d model \mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}} XRn×dmodel n n n是序列长度, d model d_{\text{model}} dmodel是隐藏维度。

       线性投影:通过线性变换将输入 X \mathbf{X} X映射为查询( Q \mathbf{Q} Q)、键( K \mathbf{K} K)和值( V \mathbf{V} V):

Q = X W Q , K = X W K , V = X W V \mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V Q=XWQ,K=XWK,V=XWV

       拆分多头:将 Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V沿着隐藏维度拆分成 h h h个头,每个头的维度为 d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h

计算每个头的注意力:

head i = Attention ( Q i , K i , V i ) \text{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i) headi=Attention(Qi,Ki,Vi)

其中,

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}} \right) \mathbf{V} Attention(Q,K,V)=softmax(dk QKT)V

拼接和线性变换:将所有头的输出拼接起来,再通过一个线性层:

Z = Concat ( head 1 , … , head h ) W O \mathbf{Z} = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \mathbf{W}_O Z=Concat(head1,,headh)WO

其中 W O ∈ R d model × d model \mathbf{W}O \in \mathbb{R}^{d{\text{model}} \times d_{\text{model}}} WORdmodel×dmodel

1.3 特点

       每个头有独立的Q、K和V的投影矩阵,能够在不同的子空间中捕获输入序列的不同特征。提高模型的多样性和表达能力。

2. 分组查询注意力(Grouped Query Attention,GQA)

2.1 概念

       分组查询注意力(Grouped Query Attention,GQA)是一种改进的注意力机制,旨在降低模型的参数量和计算复杂度,特别适用于资源受限的环境,如移动设备上的应用。

2.2 工作原理

       分组思想:GQA将多头注意力机制中的多个注意力头分为 g g g个组,每组共享一个查询投影矩阵,但仍然拥有独立的键和值投影矩阵。

输入表示:与MHA相同,输入 X ∈ R n × d model \mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}} XRn×dmodel

线性投影:
       查询:每组内的头共享查询投影矩阵,共有 g g g个查询投影矩阵:

Q ( j ) = X W Q ( j ) , j = 1 , 2 , … , g \mathbf{Q}^{(j)} = \mathbf{X} \mathbf{W}_Q^{(j)}, \quad j = 1, 2, \dots, g Q(j)=XWQ(j),j=1,2,,g
       键和值:每个头仍然有独立的键和值投影矩阵,总共有 h h h个键和值投影矩阵。

计算每个头的注意力:

       对于第 j j j组,第 i i i个头:

head ( j , i ) = Attention ( Q ( j ) , K ( j , i ) , V ( j , i ) ) \text{head}{(j, i)} = \text{Attention}(\mathbf{Q}^{(j)}, \mathbf{K}{(j, i)}, \mathbf{V}_{(j, i)}) head(j,i)=Attention(Q(j),K(j,i),V(j,i))

拼接和线性变换:与MHA类似,将所有头的输出拼接起来,通过线性层输出。

2.3 特点

减少参数量:由于查询投影矩阵在组内共享,参数量较MHA有所减少。

降低计算复杂度:共享查询减少了计算量,特别是在查询投影的部分。

折衷方案:在保持一定表达能力的同时,降低了模型的资源消耗。

3. MHA 和 GQA 的区别

3.1 查询投影矩阵(Q)的共享

MHA:每个头都有独立的查询、键和值投影矩阵。

GQA:查询投影矩阵在每个组内共享,键和值投影矩阵仍然是独立的。

3.2 参数量和计算复杂度

参数量

MHA:总参数量与头数 h h h成正比,因为每个头都有独立的投影矩阵。

GQA:参数量减少,因为查询投影矩阵共享,参数量与组数 g g g和头数 h h h有关。

计算复杂度

MHA:计算复杂度较高,需要计算所有独立投影。

GQA:由于查询投影减少,计算量有所降低,效率更高。

3.3 表达能力

MHA:每个头完全独立,具有最大的表达灵活性,能够在不同的子空间中捕获多样化的特征。

GQA:在组内头的查询受限于共享的投影矩阵,可能会略微降低表达能力,但通过保留独立的键和值投影矩阵,仍然能够捕获丰富的特征。

3.4 应用场景

       MHA:适用于对模型性能要求高、资源相对充足的场景,如服务器端的模型训练和推理。

       GQA:适用于资源受限的场景,如移动设备、嵌入式系统等,追求在降低资源消耗的同时保持较好的模型性能。

4. 直观理解

       MHA:想象每个注意力头都有自己独立的“视角”,从查询、键和值三个方面独立观察输入序列。

       GQA:在GQA中,每组内的注意力头共享“视角”(查询),但仍然可以通过自己的键和值关注不同的信息。这有点像一组人看着同一张地图(查询),但关注不同的地标(键和值)。

5. 总结

多头注意力机制(MHA)

特点:每个头都有独立的查询、键和值投影矩阵,最大化模型的表达能力。
优点:能够捕获输入序列中丰富的特征,适用于对性能要求高的场景。
缺点:参数量大,计算复杂度高,对资源要求较高。

分组查询注意力(GQA)

特点:在组内共享查询投影矩阵,减少参数量和计算量。
优点:在降低资源消耗的同时,尽可能保持模型的性能,适用于资源受限的场景。
缺点:由于共享查询,可能会影响模型的表达能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/905554.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Spring MVC】DispatcherServlet 请求处理流程

一、 请求处理 Spring MVC 是 Spring 框架的一部分,用于构建 Web 应用程序。它遵循 MVC(Model-View-Controller)设计模式,将应用程序分为模型(Model)、**视图(View)和控制器&#x…

现代数字信号处理I--最佳线性无偏估计 BLUE 学习笔记

目录 1. 最佳线性无偏估计的由来 2. 简单线性模型下一维参数的BLUE 3. 一般线性模型下一维参数的BLUE 4. 一般线性模型下多维参数的BLUE 4.1 以一维情况说明Rao论文中的结论 4.2 矢量参数是MVUE的本质是矢量参数中的每个一维参数都是MVUE 4.3 一般线性模型多维参数BLUE的…

QT(绘图)

目录 QPainter QPainter 的一些关键步骤和使用方法: QPainter 的一些常用接口: 1. 基础绘制接口 2. 颜色和画刷设置 3. 图像绘制 4. 文本绘制 5. 变换操作 6. 渲染设置 7. 状态保存与恢复 8. 其它绘制方法 示例代码1: 示例代码…

【js逆向学习】某多多anti_content逆向(补环境)

文章目录 声明逆向目标逆向分析逆向过程总结 声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,不提供完整代码,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的…

【安全解决方案】深入解析:如何通过CDN获取用户真实IP地址

一、业务场景 某大型互联网以及电商公司为了防止客户端获取到真实的ip地址,以及达到保护后端业务服务器不被网站攻击,同时又可以让公安要求留存网站日志和排查违法行为,以及打击犯罪的时候,获取不到真实的ip地址,发现…

Java | Leetcode Java题解之第524题通过删除字母匹配到字典里最长单词

题目&#xff1a; 题解&#xff1a; class Solution {public String findLongestWord(String s, List<String> dictionary) {int m s.length();int[][] f new int[m 1][26];Arrays.fill(f[m], m);for (int i m - 1; i > 0; --i) {for (int j 0; j < 26; j) {…

python爬虫抓取豆瓣数据教程

环境准备 在开始之前&#xff0c;你需要确保你的Python环境已经安装了以下库&#xff1a; requests&#xff1a;用于发送HTTP请求。BeautifulSoup&#xff1a;用于解析HTML文档。 如果你还没有安装这些库&#xff0c;可以通过以下命令安装&#xff1a; pip install requests…

Python实现深度学习模型预测控制(tensorflow)DL-MPC(Deep Learning Model Predictive Control

链接&#xff1a;深度学习模型预测控制 &#xff08;如果认为有用&#xff0c;动动小手为我点亮github小星星哦&#xff09;&#xff0c;持续更新中…… 链接&#xff1a;WangXiaoMingo/TensorDL-MPC&#xff1a;DL-MPC&#xff08;深度学习模型预测控制&#xff09;是基于 P…

简单的ELK部署学习

简单的ELK部署学习 1. 需求 我们公司现在使用的是ELK日志跟踪&#xff0c;在出现问题的时候&#xff0c;我们可以快速定为到问题&#xff0c;并且可以对日志进行分类检索&#xff0c;比如对服务名称&#xff0c;ip , 级别等信息进行分类检索。此文章为本人学习了解我们公司的…

神经网络进行波士顿房价预测

前言 前一阵学校有五一数模节校赛&#xff0c;和朋友一起参加做B题&#xff0c;波士顿房价预测&#xff0c;算是第一次自己动手实现一个简单的小网络吧&#xff0c;虽然很简单&#xff0c;但还是想记录一下。 题目介绍 波士顿住房数据由哈里森和鲁宾菲尔德于1978年Harrison …

Spark的集群环境部署

一、Standalone集群 1.1、架构 架构&#xff1a;普通分布式主从架构 主&#xff1a;Master&#xff1a;管理节点&#xff1a;管理从节点、接客、资源管理和任务 调度&#xff0c;等同于YARN中的ResourceManager 从&#xff1a;Worker&#xff1a;计算节点&#xff1a;负责利…

[java][基础]JSP

目标&#xff1a; 理解 JSP 及 JSP 原理 能在 JSP中使用 EL表达式 和 JSTL标签 理解 MVC模式 和 三层架构 能完成品牌数据的增删改查功能 1&#xff0c;JSP 概述 JSP&#xff08;全称&#xff1a;Java Server Pages&#xff09;&#xff1a;Java 服务端页面。是一种动态的…

常见问题 | 数字签名如何保障电子商务交易安全?

如何解决电商交易中数据泄露、交易欺诈等问题&#xff1f; 数字签名是一种类似于电子“指纹”的安全技术&#xff0c;它在电子商务中扮演着至关重要的角色。随着电子商务的迅猛发展&#xff0c;网上交易的数量不断增加&#xff0c;确保交易的安全性和完整性成为了亟待解决的问题…

【Python基础】

一、编程语言介绍 1、分类 机器语言 (直接用 0 1代码编写&#xff09;汇编语言 &#xff08;英文单词替代二进制指令&#xff09;高级语言 2、总结 1、执行效率&#xff1a;机器语言&#xff1e;汇编语言>高级语言&#xff08;编译型>解释型&#xff09; 2、开发效率&…

Java项目实战II基于Java+Spring Boot+MySQL的编程训练系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在当今数字…

双指针习题篇(上)

双指针习题篇(上) 文章目录 双指针习题篇(上)1.移动零题目描述&#xff1a;算法原理&#xff1a;算法流程&#xff1a;代码实现&#xff1a; 2.复写零题目描述&#xff1a;算法原理&#xff1a;算法流程&#xff1a;代码实现&#xff1a; 3.快乐数题目描述&#xff1a;算法原理…

更安全高效的文件传输工具,Ftrans国产FTP替代方案可以了解

文件传输协议&#xff08;FTP&#xff09;&#xff0c;诞生于1971年&#xff0c;自20世纪70年代发明以来&#xff0c;FTP已成为传输大文件的不二之选。内置有操作系统的 FTP 可提供一个相对简便、看似免费的文件交换方法&#xff0c;因此得到广泛使用。 随着企业发展过程中新增…

Leetcode21:合并两个有效链表

原题地址&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4]示…

《Mini-internVL》论文阅读:OpenGVLab+清华/南大等开源Mini-InternVL | 1~4B参数,仅用5%参数实现90%性能

论文地址Mini-InternVL: A Flexible-Transfer Pocket Multimodal Model with 5% Parameters and 90% PerformanceGitHub仓库地址模型使用教程和权重下载地址 该论文发表于2024年10月份&#xff0c;截止2024年11月&#xff0c;引用数<10 文章目录 论文摘要1. 引用介绍2. 本文…

双目视觉标定——1原理与实践

0 前言 双目视觉定位是目前机器&#xff08;机器人&#xff09;等领域中使用得非常广泛的视觉定位技术&#xff0c;双目视觉是模拟人的视觉系统利用两个不同位置的摄像头的视差来确定物体的位置。由于有需要采集两个摄像头的图像共同参与计算&#xff0c;所以双目相机装配要求…