深度学习理论基础(六)多头注意力机制的自定义及Pytoch库的使用详细代码

目录

    • 1. Scaled Dot-Product Attention
    • 2. 多头注意力机制框图
      • (1)计算公式
      • (2)具体计算过程
      • (3)具体代码

  
  深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息,提高模型的性能和泛化能力。
  下图 展示了人类在看到一幅图像时如何高效分配有限注意力资源的,其中红色区域表明视觉系统更加关注的目标,从图中可以看出:人们会把注意力更多的投入到人的脸部。文本的标题以及文章的首句等位置。
在这里插入图片描述

1. Scaled Dot-Product Attention

  在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。
Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.
在这里插入图片描述

2. 多头注意力机制框图

  多头注意力机制是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。
在这里插入图片描述

(1)计算公式

在这里插入图片描述

(2)具体计算过程

①计算注意力得分:根据Query和Key计算两者的相似性或相关性。常见方法:求两者的向量点积(内积)。
②对注意力得分进行softmax归一化处理。
③输出:根据权重系数对value进行加权求和。

(3)具体代码

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
 
class MultiHeadAttention(nn.Module):
	#embedding_dim:输入向量的维度,num_heads:注意力机制头数
    def __init__(self, embedding_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads       #总头数
        self.embedding_dim = embedding_dim   #输入向量的维度
        self.d_k= self.embedding_dim// self.num_heads  #每个头 分配的输入向量的维度数
        self.softmax=nn.Softmax(dim=-1)
 
        self.W_query = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.W_key = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.W_value = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.fc_out = nn.Linear(embedding_dim, embedding_dim)
        
   #输入张量 x 中的特征维度分成 self.num_heads 个头,并且每个头的维度为 self.d_k。
	def split_head(self, x, batch_size):
		x = x.reshape(batch_size, -1, self.num_heads, self.d_k)
		return x.permute(0,2,1,3)   #x  (N_size, self.num_heads, -1, self.d_k)
	     
 
    def forward(self, x):
     	batch_size=x.size(0)  #获取输入张量 x 的批量(batch size)大小
        q= self.W_query(x)  
        k= self.W_key(x)  
        v= self.W_value(x)
        
       #使用 split_head 函数对 query、key、value 进行头部切分,将其分割为多个注意力头。
		q= self.split_head(q, batch_size)
		k= self.split_head(k, batch_size)
		v= self.split_head(v, batch_size)
		
		##attention_scorce = q*k的转置/根号d_k
 		attention_scorce=torch.matmul(q, k.transpose(-2,-1))/torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attention_weight= self.softmax(attention_scorce)
 
        ## output = attention_weight * V
        output = torch.matmul(attention_weight, v)  # [h, N, T_q, num_units/h]
        output  = out.permute(0,2,1,3).contiguous() # [N, T_q, num_units]
 		output  = out.reshape(batch_size,-1, self.embedding_dim)
		output  = self.fc_out(output)
		
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/520195.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

轻量应用服务器4核8G12M配置优惠价格646元一年零3个月,12M公网带宽

腾讯云轻量4核8G12M服务器优惠价格646元15个月,买一年送3个月,配置为轻量4核8G12M、180GB SSD盘、2000GB月流量、12M带宽,腾讯云优惠活动页面 yunfuwuqiba.com/go/txy 活动链接打开如下图: 腾讯云4核8G服务器租用价格 腾讯云&…

SaaS模式Java版云HIS系统源码 覆盖医院所有业务的HIS信息管理系统源码

SaaS模式Java版云HIS系统源码 覆盖医院所有业务的HIS信息管理系统源码 HIS(Hospital Information System)是覆盖医院所有业务和业务全过程的信息管理系统。 HIS系统以财务信息、病人信息和物资信息为主线,通过对信息的收集、存储、传递、统…

Android 窗口那些事儿

目录 1. 📂 前言 你,是否有过这些疑问? 2. 🔱 Window 2.1 认识 Window 的几个阶段 1)阶段一:Window 约等于 Activity 2)阶段二:Window 约等于 View 3)阶段三&…

list的使用

前言 我们前面已经对string和vector进行了学习使用,以及对他们的底层进行了模拟实现!本期我们继续学习STL的另外一个容器---list。 本期内容介绍 什么是list? list的常用接口 什么是list? 还是来看看官方的文档说明! 这里通过…

[蓝桥杯 2017 国 C] 合根植物

[蓝桥杯 2017 国 C] 合根植物 题目描述 w 星球的一个种植园,被分成 m n m \times n mn 个小格子(东西方向 m m m 行,南北方向 n n n 列)。每个格子里种了一株合根植物。 这种植物有个特点,它的根可能会沿着南北…

【MySQL】增删改查操作(基础)

文章目录 1、新增操作(Create)1.1单行数据全列插入1.2多行数据指定列插入 2、查询操作(Retrieve)2.1全列查询2.2指定列查询2.3指定列查询2.4别名(as)2.5去重(distinct)2.6排序&#…

数据结构—图

图的基本概念 图就是由顶点的有穷非空集合和顶点之间的边组成的集合。通常表示为:G(V,E),其中,G 表示一个图,V 表示顶点的集合,E 表示边的集合。 顶点 图中的数据元素,我们称之为顶点,图至少有…

常见现代卷积神经网络(Pytorch 09)

本章将介绍现代的 卷积神经网络架构,许多现代卷积神经网络的研究都是建立在这一章的基础上的。在本章中的每一个模型都曾一度占据主导地位,其中许多模型都是 ImageNet竞赛 的优胜者。ImageNet竞赛自2010年以来,一直是计算机视觉中监督学习进展…

面试题——JVM老年代空间担保机制(我的想法)

这里借用一下人家的图,来说一下我的想法,嘻嘻。。。。 原文链接:一道面试题:JVM老年代空间担保机制-CSDN博客? 嗯,我觉得老年代担保机制的主要作用就是避免频繁触发FULL GC,这其实也是因为年轻代Minor GC…

Java项目:基于Springboot+vue社区医院管理系统设计与实现(源码+数据库+毕业论文)

一、项目简介 本项目是一套基于Springbootvue社区医院管理系统 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、功能…

数据结构之顺序表的相关知识点及应用

个人主页(找往期文章包括但不限于本期文章中不懂的知识点):我要学编程(ಥ_ಥ)-CSDN博客 目录 顺序表的概念及结构 顺序表的分类 顺序表的实现 在顺序表中增加数据 在顺序表中删除数据 在顺序表中查找数据 顺序表源码 顺序表的概念…

浮动辊位移测量功能块(CODESYS ST代码)

1、张力测量+标定(ST代码) 张力测量+标定(ST代码)_动态舞轮控制张力-CSDN博客文章浏览阅读804次。跳舞轮对应张力调节范围,我们可以通过改变气缸的气压方式间接改变,张力跳舞轮在收放卷闭环控制上的详细应用,可以参看下面的文章链接,这里我们主要讨论精密可调气阀的模拟量…

Java | Leetcode Java题解之第6题Z字形变换

题目: 题解: class Solution {public String convert(String s, int numRows) {int n s.length(), r numRows;if (r 1 || r > n) {return s;}int t r * 2 - 2;int c (n t - 1) / t * (r - 1);char[][] mat new char[r][c];for (int i 0, x …

[Spring Cloud] gateway全局异常捕捉统一返回值

文章目录 处理转发失败的情况全局参数同一返回格式操作消息对象AjaxResult返回值状态描述对象AjaxStatus返回值枚举接口层StatusCode 全局异常处理器自定义通用异常定一个自定义异常覆盖默认的异常处理自定义异常处理工具 在上一篇章时我们有了一个简单的gateway网关 [Spring C…

比selenium体验更好的ui自动化测试工具: cypress介绍

话说 Cypress is a next generation front end testing tool built for the modern web. And Cypress can test anything that runs in a browser.Cypress consists of a free, open source, locally installed Test Runner and a Dashboard Service for recording your tests.…

leetcode077——排序链表

题目: 给定链表的头结点 head ,请将其按 升序 排列并返回 排序后的链表 。 示例 1: 输入:head [4,2,1,3] 输出:[1,2,3,4] 思路: 1.找链表中点【使用快慢指针 慢指针每次移动一步,快指针每…

基于单片机12864的出租车计价器设计

**单片机设计介绍,基于单片机12864的出租车计价器设计 文章目录 一 概要二、功能设计三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机和12864液晶显示屏的出租车计价器设计,主要是利用单片机的强大控制能力和液晶显示屏的直观显示特性&…

牛客网BC-125 序列中整数去重复(难题讲解)

题目如下 --------------------------------------------------------------------------------------------------------------------------------- 题目讲解(思路) -------------------------------------------------------------------------------…

单一职责原则

1.1 阅读干吗不直接用手机? 电子阅读器比较专注,而手机功能比较多,影响专注。 1.2 手机不纯粹 手机确实很方便。但是现在的手机就是一台小型智能电脑。它不仅能打电话,还能听音乐、看电影电视、与个人交流、与一群人群聊&#…

基于java+SpringBoot+Vue的大学生入学审核系统设计与实现

基于javaSpringBootVue的大学生入学审核系统设计与实现 开发语言: Java 数据库: MySQL技术: SpringBoot VUE工具: IDEA/Eclipse、Navicat、Maven 系统展示 前台展示 入学办理模块:学生可以提交入学申请并跟踪入学办理进度。 后台展示 学生管理模块&#xff1…