transfomer中Multi-Head Attention的源码实现

简介

Multi-Head Attention是一种注意力机制,是transfomer的核心机制.
在这里插入图片描述
Multi-Head Attention的原理是通过将模型分为多个头,形成多个子空间,让模型关注不同方面的信息。每个头独立进行注意力运算,得到一个注意力权重矩阵。输出的结果再通过线性变换和拼接操作组合在一起。这样可以提高模型的表示能力和泛化性能。
在Multi-Head Attention中,每个头的权重矩阵是随机初始化生成的,并在训练过程中通过梯度下降等优化算法进行更新。通过这种方式,模型可以学习到如何将输入序列的不同部分关联起来,从而捕获更多的上下文信息。
总之,Multi-Head Attention通过将模型分为多个头,形成多个子空间,让模型关注不同方面的信息,提高了模型的表示能力和泛化性能。它的源码实现基于Scaled Dot-Product Attention,通过并行运算和组合输出来实现多头注意力机制。

源码实现:

具体源码及其注释如下,配好环境可直接运行:

import torch
from torch import nn


class MultiheadAttention(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 att_dropout=0.1,
                 out_dropout=0.1,
                 average_attn_weights=True):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.att_dropout = nn.Dropout(att_dropout)
        self.out_dropout = nn.Dropout(out_dropout)
        self.average_attn_weights = average_attn_weights
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** 0.5
        assert self.embed_dim == self.num_heads * self.head_dim, \
            'embed_dim <{}> must be divisible by num_heads <{}>'.format(self.embed_dim, self.num_heads)
        self.fuse_heads = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                identity=None,
                query_pos=None,
                key_pos=None):
        assert query.dim() == 3 and key.dim() == 3 and value.dim() == 3
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
        tgt_len, bsz, embed_dim = query.shape  # [查询数量 batch数量 特征维度]
        src_len, _, _ = key.shape  # [被查询数量,_,_]
        # 默认和query进行shortcut(要在位置编码前,因为output为输出特征,特征和原特征shortcut,下一层再重新加位置编码,否则不就重了)
        if identity is None:
            identity = query
        # 位置编码
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        # 特征划分为self.num_heads 份 [tgt,b,embed_dim] -> [b,n_h, tgt, d_h]
        # [n,b,n_h*d_h] -> [b,n_h,n,d_h] 主要是target和source之前的特征匹配和提取, batch和n_h维度不处理
        query = query.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        key = key.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        value = value.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
        # [b,n_h,tgt_len,src_len] Scaled Dot-Product Attention
        attention = query @ key.transpose(-2, -1)
        attention /= self.scale  # 参考: https://blog.csdn.net/zwhdldz/article/details/135462127
        attention = torch.softmax(attention, dim=-1)  # 行概率矩阵
        attention = self.att_dropout(input=attention)  # 正则化方法 DropKey,用于缓解 Vision Transformer 中的过拟合问题
        # [b,n_h,tgt_len,d_h] = [b,n_h,tgt_len,src_len] * [b,n_h,src_len,d_h]
        output = attention @ value
        # [b,n_h,tgt_len,d_h] -> [b,tgt_len,embed_dim]
        output = output.permute(0, 2, 1, 3).contiguous().view(tgt_len, bsz, embed_dim)
        # 头之间通过全连接融合一下
        output = self.fuse_heads(output)
        output = self.out_dropout(output)
        # shortcut
        output = output + identity
        # 多头head求平均
        if self.average_attn_weights:
            attention = attention.sum(dim=1) / self.num_heads
        # [tgt_len,b,embed_dim],[b,tgt_len,src_len]
        return output, attention


if __name__ == '__main__':
    query = torch.rand(size=(10, 2, 64))
    key = torch.rand(size=(5, 2, 64))
    value = torch.rand(size=(5, 2, 64))
    query_pos = torch.rand(size=(10, 2, 64))
    key_pos = torch.rand(size=(5, 2, 64))

    att = MultiheadAttention(64, 4)
    # 返回特征采样结果和attention矩阵
    output = att(query=query, key=key, value=value,query_pos=query_pos,key_pos=key_pos)
    pass

具体流程说明:

  1. 将input映射为qkv,如果是cross_attention,q与kv的行数可以不同,但列数(编码维度/通道数)必须相同
  2. q和v附加位置编码
  3. Scaled Dot-Product :通过计算Query和Key之间的点积除以scale得到注意力权重,经过dropout再与Value矩阵相乘得到输出。*scale和dropout的说明参考我的上一篇博客
  4. 输出的结果再通过线性变换融合多头信息。

在实现中,为了提高模型的表示能力和泛化性能,将Scaled Dot-Product Attention过程多次并行运行,形成多个头(head)。每个头分别进行注意力运算,输出的结果再通过线性变换和拼接操作组合在一起。每个头的权重矩阵是随机初始化生成的,并在训练过程中通过梯度下降等优化算法进行更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/319378.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大模型背景下计算机视觉年终思考小结(一)

1. 引言 在过去的十年里&#xff0c;出现了许多涉及计算机视觉的项目&#xff0c;举例如下&#xff1a; 使用射线图像和其他医学图像领域的医学诊断应用使用卫星图像分析建筑物和土地利用率相关应用各种环境下的目标检测和跟踪&#xff0c;如交通流统计、自然环境垃圾检测估计…

国内首款支持苹果Find My芯片-伦茨科技ST17H6x

深圳市伦茨科技有限公司&#xff08;以下简称“伦茨科技”&#xff09;发布ST17H6x Soc平台。成为继Nordic之后全球第二家取得Apple Find My「查找」认证的芯片厂家&#xff0c;该平台提供可通过Apple Find My认证的Apple查找&#xff08;Find My&#xff09;功能集成解决方案。…

PYTHON通过跳板机巡检CENTOS的简单实现

实现的细节和引用的文件和以前博客记录的基本一致 https://shaka.blog.csdn.net/article/details/106927633 差别在于,这次是通过跳板机登陆获取的主机信息,只记录差异的部份 1.需要在跳板机相应的路径放置PYTHON的脚本resc.py resc.py这个脚本中有引用的文件(pm.sh,diskpn…

代码随想录 Leetcode242. 有效的字母异位词

题目&#xff1a; 代码&#xff08;首刷看解析 2024年1月14日&#xff09;&#xff1a; class Solution { public:bool isAnagram(string s, string t) {int hash[26] {0};for(int i 0; i < s.size(); i) {hash[s[i] - a];}for(int i 0; i < t.size(); i) {hash[t[i]…

java编程解小学生一年级竞赛题

抖音教学视频 目录 1、题目三角形加起来为10 大纲 1、题目三角形加起来为10 连接&#xff1a;小学一年级数学竞赛练习题3套&#xff0c;有点难度&#xff01; 第16题 此方法不是最优解&#xff0c;穷举法&#xff0c;比较暴力解决 主要给大家演示如何用编程去解决我们的实…

智能寻迹避障清障机器人设计(电路图附件+代码)

附 录 智能小车原理图 智能小车拓展板原理图 智能小车拓展板PCB 智能小车底板PCB Arduino UNO原理图 Arduino UNO PCB 程序部分 void Robot_Traction() //机器人循迹子程序{//有信号为LOW 没有信号为HIGHSR digitalRead(SensorRight);//有信号表明在白…

vue3 - 自定义弹框组件

写了一个弹框组件 <template><transition name"modal-fade"><div v-if"showFlag" class"myModal"><div class"content"><div class"topBox"><div class"leftTitle"><spa…

Chapter 8 怎样使用类和对象(下篇)

⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️ 8.2 对象数组 1.对象数组的每一个元素都是同类的对象 2.在建立数组时&#xff0c;同样…

day18【LeetCode力扣】19.删除链表的倒数第N个结点

day18【LeetCode力扣】19.删除链表的倒数第N个结点 1.题目描述 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5]示例 2&#xff1a; 输入&a…

SpringBoot+SSM项目实战 苍穹外卖(10) Spring Task WebSocket

继续上一节的内容&#xff0c;本节学习Spring Task和WebSocket&#xff0c;并完成订单状态定时处理、来单提醒和客户催单功能。 目录 Spring Task&#xff08;cron表达式&#xff09;入门案例 订单状态定时处理WebSocket入门案例 来单提醒客户催单 Spring Task&#xff08;cron…

pyqt调用UI和开启子进程

UI制作 qrc 注意调用UI前把样式表里绑定的资源(qrc)转换成py导入进去 xxx.qrc转xxx.py 两种方法 1命令 pyrcc5 -o icons_rc.py icons.qrc 2外部工具pyrcc 实参 -o $FileNameWithoutExtension$.py $FileNameWithoutExtension$.qrcsdz.qrc→→sdaz.py 在代码里写 import…

Springboot3新特性:开发第一个 GraalVM 本机应用程序(完整教程)

在讲述之前&#xff0c;各位先自行在网上下载并安装Visual Studio 2022&#xff0c;安装的时候别忘了勾选msvc 概述&#xff1a;GraalVM 本机应用程序&#xff08;Native Image&#xff09;是使用 GraalVM 的一个特性&#xff0c;允许将 Java 应用程序编译成本机二进制文件&am…

视频剪辑软件Camtasia2024最新版本快捷键大全

Camtasia Studio是一款专门录制屏幕动作的工具&#xff0c;它能在任何颜色模式下轻松地记录 屏幕动作&#xff0c;包括影像、音效、鼠标移动轨迹、解说声音等等。 今天来给大家介绍一下Camtasia快捷键的相关内容&#xff0c;Camtasia也是一个十分好用的电脑屏幕录制与视频剪辑…

GRE隧道(初级VPN)配置步骤

一、拓朴图&#xff1a; 二、配置步骤&#xff1a; 1、配置IP 2、R1、R2 配置nat&#xff0c;代理内网地址通过G0/0/0口上外网 acl 2000rule permit source anyquit # int G0/0/0ip addr 100.1.1.1 24nat outbound 2000 # 3、R1、R2 配置默认出口路由G0/0/0&#xff0c;这一…

Windows启动MongoDB服务报错(错误 1053:服务没有及时响应启动或控制请求)

问题描述&#xff1a;修改MongoDB服务bin目录下的mongod.cfg&#xff0c;然后在任务管理器找到MongoDB服务-->右键-->点击【开始】&#xff0c;启动失败无提示&#xff1a; 右键点击任务管理器的MongoDB服务-->点击【打开服务】&#xff0c;跳转到服务页面-->找到M…

C++ QtCreator启动执行报错的各类问题解决++持续更新!!

1.QTCreator启动报错"由于找不到 python310.dll" 在QtCreator加载自动缩进的LLVM插件后, 再次打开Qt时, 会报错找不到python310.dll 解决方法&#xff1a;下载python310.dll,随后复制到目录&#xff1a;C:\Program Files\LLVM\bin 即可解决该问题。下载路径附件如…

VUE element-ui实现表格动态展示、动态删减列、动态排序、动态搜索条件配置、表单组件化。

1、实现效果 1.1、文件目录 1.2、说明 1、本组件支持列表的表头自定义配置&#xff0c;checkbox实现 2、本组件支持列表列排序&#xff0c;vuedraggable是拖拽插件&#xff0c;上图中字段管理里的拖拽效果 &#xff0c;需要的话请自行npm install 3、本组件支持查询条件动态…

使用 C++/WinRT 创作 API

本主题展示了如何直接或间接使用 winrt::implements 基结构来创作 C/WinRT API 。 在此上下文中&#xff0c;“创作”的同义词有“生成”或“实现” 。 本主题介绍以下在 C/WinRT 类型上实现 API 的情形&#xff08;按此顺序&#xff09;。 你不是在创作一个 Windows 运行时类…

Python——函数的参数

1.位置参数 位置参数可以在函数中设置一个或者多个参数&#xff0c;但是必须有对应个数的值传入该函数才能成功调用&#xff0c;例如&#xff1a; def power(x):return x*xprint(powr(5)) 如果传入的值与对应函数设置的位置参数不符合&#xff0c;则会报错&#xff1a; Traceba…

小白浅学Vue3

目录 前端环境 依赖管理NPM安装配置 创建Vue项目 模板语法 文本插值{{ }} v-html 属性绑定 条件渲染 v-if 、v-else-if 、v-else v-show 列表渲染v-for 状态管理 事件 事件 事件传参 事件修饰符 数组变化监听 计算属性 Class绑定 Style绑定 侦听器 v-mod…