MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

  • 注意力机制详解及代码
    • 前言:
    • MHA
    • MQA
    • GQA

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

在这里插入图片描述

在这里插入图片描述

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key)
        value = self.split_head(value)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        ## 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
 
        
    def split_head(self, x):
        batch_size = x.size()[0]
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim) ###
        self.v_linear = nn.Linear(hidden_size, self.head_dim) ###
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key, 1)
        value = self.split_head(value, 1)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
        
        
        
        
    def split_head(self, x, head_num=None):
        
        batch_size = x.size()[0]
        
        if head_num == None:
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
        else:
            return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)
    
    

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 分组注意力查询
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, group_num):
        super(MutiGroupAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.group_num = group_num
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key, self.group_num)
        value = self.split_head(value, self.group_num)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
        
        
        
        
    def split_head(self, x, group_num=None):
        
        batch_size,seq_len = x.size()[:2]
        
        if group_num == None:
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
        else:
            x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)
            return x
    
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/628946.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

安防视频汇聚/智能分析云平台EasyCVR调用localfile接口会返回日志的问题该如何解决?

视频汇聚/安防视频融合云平台EasyCVR视频监控系统支持多协议接入、兼容多类型设备,平台能在复杂的网络环境中(专网、局域网、广域网、VPN、公网等)将前端海量的设备进行统一集中接入与视频汇聚管理。视频监控/集中存储系统EasyCVR平台可支持国…

中青杯全国大学生数学建模竞赛纳入多所高校学科竞赛认定目录

2024年第六届中青杯全国大学生数学建模竞赛将于2024年5月23日17:00至5月26日17:00举行,中青杯全国大学生数学建模竞赛是中国高校学科竞赛中规模较大、影响较广的学科竞赛之一,并且纳入多所高校学科竞赛认定目录。 报名截止时间:2024年5月23日12:00 报名网站:http://www.c…

Hadoop 3.4.0 项目实战

1环境基于 上一篇搭建 高可用分布式集群 2 官方提供MapReduce程序 #评估圆周率 cd /data/hadoop/share/hadoop/mapreduce/ hadoop jar hadoop-mapreduce-examples-3.4.0.jar pi 2 6 3 实例项目分析1 #预分析的文件如,如单词统计 # #上传文件到hdfs hdfs …

淘系淘宝订单详情api接口(订单详情,订单列表,出售中,库存等属性)

淘系淘宝订单详情api接口(订单详情,订单列表,出售中,库存等属性)

【基础算法总结】二分查找二

二分查找二 1.山脉数组的峰顶索引2.寻找峰值3.寻找旋转排序数组中的最小值4.点名 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃😃 1.山脉数组的…

【vue3】vue3中如何使用typescript

简言 现在vue3和typescript搭配使用是一个较常见的方案,下面参考vue3官网总结下在vue项目中使用ts(TypeScript)的方法。 typescript配置 新建项目 如果你准备新建vue3项目,那么使用create-vue官方脚手架,它提供了搭建基于 Vite 且 TypeSc…

vue-pure-admin项目内复制文字粘贴到word中之后存在边框问题

vue-pure-admin项目内复制文字粘贴到word中之后存在黑色边框是由于reset.scss文件内设置了通配符的border样式 修改前 代码 *, ::before, ::after {box-sizing: border-box;// 添加这个样式会导致复制的文字粘贴到word中带有边框问题border-color: currentColor;border-styl…

CCF PTA 2022年11月C++学生会提名

【问题描述】 学生会选举要开始了。根据选举规则,首先由全体同学进行提名,每位同学可以从全体同学中提 名一名同学参选。选举时,会从全体同学的提名中选出一名学生会主席,再从三个年级分别的提名中 各选出一名副主席。现在&#…

sa-token权限认证框架,最简洁,最实用讲解

查看源码,可知,sa sa-token框架 测试代码源码配置自动装配SaTokenConfigSaTokenConfigFactory SaManager工具类SaFoxUtilStpUtilSaResult StpLogic持久层定时任务 会话登录生成token创建account-session事件驱动模型写入tokenSaSessionSaCookieSaTokenDa…

elementui,iview等 表格单元格合并之固定列

要的效果如下 需要合并 show weak 及 Siginin这三列 上代码 <template><Table:columns"columns":span-method"handleSpan":data"data"bordersize"small"ref"table"></Table> </template> <sc…

Linux备份---异地

参考文档&#xff1a;Linux环境实现mysql所在服务器定时同步数据文件到备份服务器&#xff08;异地容灾备份场景&#xff09;_mysql异地备份-CSDN博客 通过SSH进行连接&#xff1a; 应用服务器&#xff1a; 通过ssh-keygen -t rsay建立ssh通信的密钥 密钥建立后&#xff0c;…

JavaScript-输入输出语句

输出语句 document.write( 输出的内容 ) 语法&#xff1a;document.write( 输出的内容) 作用&#xff1a;内容会显示在网页上 如果输出的内容是标签&#xff0c;也会被解析为网页元素 代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head>&…

cubemx配置stm32f407VET6实现can通信

背景&#xff1a; 项目上需要把原先的TMC5160电机驱动器替换为购买的电机控制模块&#xff08;该模块采用canopen通信&#xff09; 移植canopen的前提是can通信正常&#xff0c;现在添加一下can通信&#xff08;先用标准帧&#xff0c;250K bit/S的波特率测试&#xff09; 原理…

【回溯】1255. 得分最高的单词集合

本文涉及知识点 回溯 力扣难道&#xff1a;1881 LeetCode1255. 得分最高的单词集合 你将会得到一份单词表 words&#xff0c;一个字母表 letters &#xff08;可能会有重复字母&#xff09;&#xff0c;以及每个字母对应的得分情况表 score。 请你帮忙计算玩家在单词拼写游戏…

系统管理(System Keeping):Codigger资源与配置管理(上)

系统管理&#xff08;System Keeping&#xff09;&#xff0c;作为Codigger不可或缺的一部分&#xff0c;为开发者提供全面而高效的资源与配置管理体验。下面&#xff0c;让我们从它的其中三方面来一探究竟其强大的功能如何助力开发者提升工作效率。 一、环境配置&#xff1a;全…

Linux交叉编译

一. 交叉编译 1.使用环境要求 新版本的orangepi-build是在Ubuntu22.04的x64电脑或虚拟机上运行的 lsb_release -a //查看自己的虚拟机版本 因为编译出的SDK大概有16G大小&#xff0c;因此&#xff0c;至少给虚拟机分配50G的大小。 2.获取Linux SDK 方法一&#xff1a;从…

React框架-Next 学习-1

创建一个 Next.js 应用,node版本要高&#xff0c;16.5以上 npm淘宝镜像切为https://registry.npmmirror.com npm config set registry https://registry.npmmirror.com npx create-next-applatest//安装后 使用npm run dev 启动 Next.js 是围绕着 页面&#xff08;pages&am…

智慧园区EasyCVR视频智能管理方案:构建高效安全园区新视界

一、背景分析 园区作为城市的基本单元&#xff0c;是最重要的人口和产业聚集区。根据行业市场调研&#xff0c;90%以上城市居民工作与生活在园区进行&#xff0c;80%以上的GDP和90%以上的创新在园区内产生&#xff0c;可以说“城市&#xff0c;除了马路都是园区”。 园区形态…

高通QCS6490开发(二)AI板卡接口

QCS6490是高通公司针对高端物联网终端而优化的SoC&#xff0c;在性能和功耗上有最优的平衡。《高通QCS6490 AIoT应用开发》是一系列AIoT应用开发文章&#xff0c;介绍如何基于QCS6490平台做AIIoT的应用开发。 本文主要介绍FV01开发板的内部和外部接口。 内部的板载接口如下 接口…

怎么做私域?先来了解私域运营模式!

现在&#xff0c;很多企业都在做私域&#xff0c;但仍旧有很多人会问&#xff1a;我的私域到底要怎么做&#xff1f; 关于这个问题&#xff0c;不同产品无论在消费频次与客单价上&#xff0c;还是在决策链路的长度和复杂度上&#xff0c;都有巨大的差异&#xff0c;消费者需要…