辅导男朋友转算法岗的第2天|self Attention与kv cache

文章目录

    • 公式
    • KV Cache
    • MHA、MQA、GQA
  • 面试题

公式

$ \text{Output} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \times V$ 复杂度是O( n 2 n^2 n2)

KV Cache

推理阶段最常用的缓存机制,用空间换时间。

原理:

在进行自回归解码的时候,新生成的token会加入序列,一起作为下一次解码的输入。

由于单向注意力的存在,新加入的token并不会影响前面序列的计算,因此可以把已经计算过的每层的kv值保存起来,这样就节省了和本次生成无关的计算量。

通过把kv值存储在速度远快于显存的L2缓存中,可以大大减少kv值的保存和读取,这样就极大加快了模型推理的速度。

分别做一个k cache和一个v cache,把之前计算的k和v存起来

以v cache为例:

在这里插入图片描述

存在的问题:存储碎片化

解决方法:page attention(封装在vllm里了)

MHA、MQA、GQA

Multi-Head Attention、Multi-Query Attention、Group-Query Attention

目的:优化KV Cache所需空间大小

原理是共享k和v,但是使用MQA效果会差一些,于是又出现了GQA这种折中的办法

在这里插入图片描述

面试题

为什么除以 d k \sqrt{d_k} dk

压缩softmax输入值,以免输入值过大,进入了softmax的饱和区,导致梯度值太小而难以训练。

Multihead的好处

1、每个head捕获不同的信息,多个头能够分别关注到不同的特征,增强了表达能力。多个头中,会有部分头能够学习到更高级的特征,并减少注意力权重对角线值过大的情况。

比如部分头关注语法信息,部分头关注知识内容,部分头关注近距离文本,部分头关注远距离文本,这样减少信息缺失,提升模型容量。

2、类似集成学习,多个模型做决策,降低误差

decoder-only模型在训练阶段和推理阶段的input有什么不同?

  • 训练阶段:模型一次性处理整个输入序列,输入是完整的序列,掩码矩阵是固定的上三角矩阵。
  • 推理阶段:模型逐步生成序列,输入是一个初始序列,然后逐步添加生成的 token。掩码矩阵需要动态调整,以适应不断增加的序列长度,并考虑缓存机制。

手撕必背-多头注意力

逐头计算

import torch.nn as nn
class MultiHeadAttentionScores(nn.Module):

    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiHeadAttentionScores, self).__init__()
        self.num_attention_heads = num_attention_heads # 8,16, 32, 64
        
        # Create a query, key, and value projection layer
        # for each attention head.  W^Q, W^K, W^V
        self.query_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.key_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.value_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])

    def forward(self, hidden_states):
        # Create a list to store the outputs of each attention head
        all_attention_outputs = []

        for i in range(self.num_attention_heads): # i.e. 8
            query_vectors = self.query_layers[i](hidden_states)
            key_vectors = self.key_layers[i](hidden_states)
            value_vectors = self.value_layers[i](hidden_states)
            
            # softmax(Q&K^T)*V
            attention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
            # attention_scores combined with softmax--> normalized_attention_score
            attention_outputs = torch.matmul(attention_scores, value_vectors)
            all_attention_outputs.append(attention_outputs)

        return all_attention_outputs

矩阵运算

import torch
import torch.nn as nn

class MultiHeadAttentionScores(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiHeadAttentionScores, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.hidden_size = hidden_size
        
        self.query = nn.Linear(hidden_size, num_attention_heads * attention_head_size)
        self.key = nn.Linear(hidden_size, num_attention_heads * attention_head_size)
        self.value = nn.Linear(hidden_size, num_attention_heads * attention_head_size)

    def forward(self, hidden_states):
        batch_size = hidden_states.size(0)
        
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        
        attention_outputs = torch.matmul(attention_probs, value_layer)
        
        attention_outputs = attention_outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.num_attention_heads * self.attention_head_size)
        
        return attention_outputs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/664115.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

nodemcu32s 和 mini D1 组局域网并用 webSocket 通信

实现思路 使用 mini D1 来搭建一个 webSocket 服务,然后使用 nodemcu32 连接,然后就可以进行通信了。 服务端代码(mini D1) 在代码中主要是需要控制好 loop 函数中的延时,也就是最后一行代码 delay,如果…

B端系统:角色与权限界面设计,一文读懂。

一、什么是角色与权限系统 角色与权限系统是一种用于管理和控制用户在系统中的访问和操作权限的机制。它通过将用户分配到不同的角色,并为每个角色分配相应的权限,来实现对系统资源的权限控制和管理。 在角色与权限系统中,通常会定义多个角色…

入门flask:Python后端开发的首选框架

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言:从零开始学习弗拉斯克 二、弗拉斯克的微框架哲学 三、弗拉斯克的核心…

字符串匹配算法(一)BF算法、RK算法

文章目录 BF算法算法详解算法实现 RK算法算法详解算法实现 BF算法 算法详解 BF算法也就是Brute Force 算法,中文叫暴力匹配算法,也叫朴素匹配算法。模式串和主串:例如:我们在字符串A中查找字符串B,那么字符串A就是主…

留言板——增添功能(持久化存储数据,使用MyBatis)

目录 一、数据准备 二、引入MyBatis 和 MySQL驱动依赖 三、配置MySQL账号密码 四、编写后端代码 五、调整前端代码 六、测试 之前的代码:综合性练习(后端代码练习3)——留言板_在线留言板前后端交互-CSDN博客 一、数据准备 创建数据库…

vivo鄢楠:基于OceanBase 的降本增效实践

在3 月 20 日的2024 OceanBase 数据库城市行中,vivo的 体系与流程 IT 部 DBA 组总监鄢楠就“vivo 基于 OceanBase 的降本增效实践”进行了主题演讲。本文为该演讲的精彩回顾。 vivo 在1995年于中国东莞成立,作为一家全球领先的移动互联网智能终端公司&am…

CentOS 7基础操作01_安装CentOS 7操作系统

1、实验环境 因为 Windows图形界面占用系统资源较高,所以公司准备将面向互联网的网站,数据库等重要应用基于Linux平台部署,并计划于近期将服务器安装开源免费的 CentOS 系统。进行前期准备工作时,需要公司的系统管理员尽快掌握 CentOS 系统的安装过程 2、需要描述 …

CSS 空间转换 动画

目录 1. 空间转换1.1 视距 - perspective1.2 空间转换 - 旋转1.3 立体呈现 - transform-style1.4 空间转换 - 缩放 2. 动画 - animation2.1 动画的基本用法2.1 animation 复合属性2.2 animation 拆分属性2.3 多组动画 正文开始 1. 空间转换 空间:是从坐标轴角度定义…

在AutoDL上部署Yi-34B大模型

在AutoDL上部署Yi-34B大模型 Yi介绍 Yi 系列模型是 01.AI 从零训练的下一代开源大语言模型。Yi 系列模型是一个双语语言模型,在 3T 多语言语料库上训练而成,是全球最强大的大语言模型之一。Yi 系列模型在语言认知、常识推理、阅读理解等方面表现优异。 …

【JavaEE】多线程(1)

🎆🎆🎆个人主页🎆🎆🎆 🎆🎆🎆JavaEE专栏🎆🎆🎆 🎆🎆🎆计算机是怎么工作的🎆&#x1f3…

第七在线惊艳亮相第11届奥莱峰会,AI驱动零售供应链升级

2024年5月22-24日,第11届奥莱领秀峰会暨2024奥莱产业经济论坛在南京盛大举行。论坛上,智能商品计划管理系统服务商第七在线凭借富有前瞻性的AI技术,引领零售供应链迈入全新升级阶段,赢得了与会嘉宾的广泛关注与赞誉。 峰会由中国奥…

【一百】【算法分析与设计】N皇后问题常规解法+位运算解法

N皇后问题 链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 题目描述 给出一个nnn\times nnn的国际象棋棋盘,你需要在棋盘中摆放nnn个皇后,使得任意两个皇后之间不能互相攻击。具体来说,不能存在两个皇后位于同…

BUUCTF Crypto RSA详解《1~32》刷题记录

文章目录 一、Crypto1、 一眼就解密2、MD53、Url编码4、看我回旋踢5、摩丝6、password7、变异凯撒8、Quoted-printable9、篱笆墙的影子10、Rabbit11、RSA12、丢失的MD513、Alice与Bob14、大帝的密码武器15、rsarsa16、Windows系统密码17、信息化时代的步伐18、凯撒?…

springboot基本使用十一(自定义全局异常处理器)

例如:我们都知道在java中被除数不能为0,为0就会报by zero错误 RestController public class TestController {GetMapping("/ex")public Integer ex(){int a 10 / 0;return a;}} 打印结果: 如何将这个异常进行处理? 创…

java——网络原理初识

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 小比特 大梦想 目录 1.网络通信概念初识1.1 IP地址1.2端口号1.3协议1.3.1协议分层协议分层带来的好处主要有两个方面 1.3.2 TCP/IP五层 (或四层模型)1.3.3 协议的层和层之间是怎么配合工作的 1.网络通信概念初识…

RLC防孤岛保护装置如何工作的?

什么是RLC防孤岛保护装置? 孤岛保护装置是电力系统中一道强大的守护利器,它以敏锐的感知和迅速的反应,守护着电网的平稳运行。当电网遭遇故障或意外脱离主网时,孤岛保护装置如同一位机警的守门人,立刻做出决断&#xf…

算法(七)插入排序

文章目录 插入排序简介代码实现 插入排序简介 插入排序(insertion sort)是从第一个元素开始,该元素就认为已经被排序过了。然后取出下一个元素,从该元素的前一个索引下标开始往前扫描,比该值大的元素往后移动。直到遇到比它小的元…

MySQL 索引的使用

本篇主要介绍MySQL中索引使用的相关内容。 目录 一、最左前缀法则 二、索引失效的场景 索引列运算 字符串无引号 模糊查询 or连接条件 数据分布 一、最左前缀法则 当我们在使用多个字段构成的索引时(联合索引),需要考虑最左前缀法则…

【VTKExamples::Utilities】第十七期 ZBuffer

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例ZBuffer,并解析接口vtkWindowToImageFilter,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ…

【面试八股总结】MySQL事务:事务特性、事务并行、事务的隔离级别

参考资料:小林coding 一、事务的特性ACID 原子性(Atomicity) 一个事务是一个不可分割的工作单位,事务中的所有操作,要么全部完成,要么全部不完成,不会结束在中间某个环节。原子性是通过 undo …