【深度学习】注意力机制

https://blog.csdn.net/weixin_43334693/article/details/130189238
https://blog.csdn.net/weixin_47936614/article/details/130466448
https://blog.csdn.net/qq_51320133/article/details/138305880

注意力机制:在处理信息的时候,会将注意力放在需要关注的信息上,对于其他无关的外部信息进行过滤,这种处理方式被称为注意力机制。

注意力机制可以使模型在处理序列数据时更加准确和有效。在传统的神经网络中,每个神经元的输出只依赖于前一层的所有神经元的输出,而在注意力机制中,每个神经元的输出不仅仅取决于前一层的所有神经元的输出,还可以根据输入数据的不同部分进行加权,即对不同部分赋予不同的权重。这样可以使模型更加关注输入序列中的关键信息,从而提高模型的精度和效率。

注意力机制原理

1.计算注意力权重:注意力机制的第一步是计算每个输入位置的注意力权重。这个权重可以根据输入数据的不同部分进行加权,即对不同部分赋予不同的权重。权重的计算通常是基于输入数据和模型参数的函数,可以使用不同的方式进行计算,比如点积注意力、加性注意力、自注意力等。

2.加权求和输入表示:计算出注意力权重之后,下一步就是将每个输入位置的表示和对应的注意力权重相乘,并对所有加权结果进行求和。这样可以得到一个加权的输入表示,它可以更好地反映输入数据中重要的部分。

3.计算输出:注意力机制的最后一步是根据加权的输入表示和其他模型参数计算输出结果。这个输出结果可以作为下一层的输入,也可以作为最终的输出。

需要注意的是,注意力机制并不是一种特定的神经网络结构,而是一种通用的机制,可以应用于不同的神经网络结构中。比如,可以在卷积神经网络中使用注意力机制来关注输入图像中的重要区域,也可以在循环神经网络中使用注意力机制来关注输入序列中的重要部分。

查询(Query): 指的是查询的范围,自主提示,即主观意识的特征向量
键(Key): 指的是被比对的项,非自主提示,即物体的突出特征信息向量
值(Value) : 则是代表物体本身的特征向量,通常和Key成对出现
注意力机制是通过Query与Key的注意力汇聚(给定一个 Query,计算Query与 Key的相关性,然后根据Query与Key的相关性去找到最合适的 Value)实现对Value的注意力权重分配,生成最终的输出结果。

注意力机制计算过程:

阶段1、根据Query和Key计算两者之间的相关性或相似性(常见方法点积、余弦相似度,MLP网络),得到注意力得分

在这里插入图片描述

阶段2、对注意力得分进行缩放scale(除以维度的根号),再softmax函数,一方面可以进行归一化,将原始计算分值整理成所有元素权重之和为1的概率分布;另一方面也可以通过softmax的内在机制更加突出重要元素的权重。一般采用如下公式计算
在这里插入图片描述
阶段3、根据权重系数对Value值进行加权求和,得到Attention Value(此时的V是具有一些注意力信息的,更重要的信息更关注,不重要的信息被忽视了)
在这里插入图片描述
在这里插入图片描述

自注意力机制

神经网络接收的输入是很多大小不一的向量,并且不同向量向量之间有一定的关系,但是实际训练的时候无法充分发挥这些输入之间的关系而导致模型训练结果效果极差。比如机器翻译(序列到序列的问题,机器自己决定多少个标签),词性标注(Pos tagging一个向量对应一个标签),语义分析(多个向量对应一个标签)等文字处理问题。

自注意力机制实际上是想让机器注意到整个输入中不同部分之间的相关性

自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。自注意力机制的关键点在于,Q、K、V是同一个东西,或者三者来源于同一个X,三者同源。通过X找到X里面的关键点,从而更关注X的关键信息,忽略X的不重要信息。不是输入语句和输出语句之间的注意力机制,而是输入语句内部元素之间或者输出语句内部元素之间发生的注意力机制。

自注意力机制原理
1、得到Q,K,V的值
对于每一个向量x,分别乘上三个系数 Wq,Wk,Wv,得到的Q,K和V分别表示query,key和value (这三个W就是我们需要学习的参数)
在这里插入图片描述

2、计算注意力权重
利用得到的Q和K计算每两个输入向量之间的相关性,一般采用点积计算

3、Scale+Softmax
将刚得到的相似度除以 d   k   \sqrt{d~k~} d k  (dk 表示键向量的维度),再进行Softmax。经过Softmax的归一化后,每个值是一个大于0且小于1的权重系数,且总和为1,这个结果可以被理解成一个权重矩阵。

4、使用刚得到的权重矩阵,与V相乘,计算加权求和。
在这里插入图片描述

自注意力机制问题:
1、自注意力机制的原理是筛选重要信息,过滤不重要信息。这就导致自注意力机制无法完全利用图像本身具有的尺度,平移不变性,以及图像的特征局部性。这就导致自注意力机制只有在大数据的基础上才能有效地建立准确的全局关系
2、自注意力机制虽然考虑了所有的输入向量,但没有考虑到向量的位置信息。在实际的文字处理问题中,可能在不同位置词语具有不同的性质(可通过位置编码解决:对每一个输入向量加上一个位置向量e,位置向量的生成方式有多种,通过e来表示位置信息带入self-attention层进行计算)

多头注意力机制:Multi-Head Self-Attention

多头注意力机制在自注意力的基础上,通过增加多个注意力头来并行地对输入信息进行不同维度的注意力分配,从而捕获更丰富的特征和上下文信息。

第1步:定义多组W,生成多组Q、K、V
在这里插入图片描述
线性变换:首先,对输入序列中的每个位置的向量分别进行三次线性变换(即加权和偏置),生成查询矩阵Q, 键矩阵K, 和值矩阵V。在多头注意力中,这一步骤实际上会进行h次(其中h为头数),每个头拥有独立的权重矩阵,从而将输入向量分割到h个不同的子空间。

第2步:
并行注意力计算:对每个子空间,应用自注意力机制计算注意力权重,并据此加权求和值矩阵V,得到每个头的输出。公式上表现为:
在这里插入图片描述

第3步:
合并与最终变换:将所有头的输出拼接起来,再经过一个最终的线性变换和层归一化,得到多头注意力的输出。这一步骤整合了不同子空间学到的信息,增强模型的表达能力。

import torch
from torch.nn import Module, Linear, Dropout, LayerNorm
 
class MultiHeadAttention(Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        
        self.linear_q = Linear(d_model, d_model)
        self.linear_k = Linear(d_model, d_model)
        self.linear_v = Linear(d_model, d_model)
        self.linear_out = Linear(d_model, d_model)
        
        self.dropout = Dropout(dropout)
        self.layer_norm = LayerNorm(d_model)
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 线性变换
        q = self.linear_q(q).view(batch_size, -1, self.num_heads, self.d_head)
        k = self.linear_k(k).view(batch_size, -1, self.num_heads, self.d_head)
        v = self.linear_v(v).view(batch_size, -1, self.num_heads, self.d_head)
        
        # 转置以便于计算注意力
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        
        # 计算注意力权重
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 加权求和得到输出
        outputs = torch.matmul(attn_weights, v)
        
        # 转换回原始形状并进行最终线性变换
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        outputs = self.linear_out(outputs)
        outputs = self.layer_norm(outputs + q)
        
        return outputs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/765876.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS开发实战:UDP通讯示例规范

1. UDP简介 UDP协议是传输层协议的一种,它不需要建立连接,是不可靠、无序的,相对于TCP协议报文更简单,在特定场景下有更高的数据传输效率,在现代的网络通讯中有广泛的应用,以最新的HTTP/3为例,…

2024年6月29日 (周六) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键,实现一键唤起、一键隐藏的 Windows 工具,并且支持窗口动态绑定快捷键(无需设置自动实现)。 喜马拉雅下载工具: 字面意思 《星刃》性感女主私密部位细节逼真 让玩家感到惊讶《星刃…

探索NVIDIA A100 显卡 如何手搓A100显卡

NVIDIA A100 显卡(GPU)是基于NVIDIA的Ampere架构设计的高性能计算和人工智能任务的处理器。 A100显卡主要由以下几种关键芯片和组件组成: 1. GPU芯片 NVIDIA GA100 GPU: 核心组件,是整个显卡的核心处理单元。GA100芯…

Ubuntu24.04 Isaacgym的安装

教程1 教程2 教程3 1.下载压缩包 link 2. 解压 tar -xvf IsaacGym_Preview_4_Package.tar.gz3. 从源码安装 Ubuntu24.04还需首先进入虚拟环境 python -m venv myenv # 创建虚拟环境,已有可跳过 source myenv/bin/activate # 激活虚拟环境python编译 cd isaa…

Python容器 之 字符串--字符串的常用操作方法

1.字符串查找方法 find() 说明:被查找字符是否存在于当前字符串中。 格式:字符串.find(被查找字符) 结果:如果存在则返回第一次出现 被查找字符位置的下标 如果不存在则返回 -1 需求: 1. 现有字符串数据: 我是中国人 2. 请设计程序…

Python 作业题1 (猜数字)

题目 你要根据线索猜出一个三位数。游戏会根据你的猜测给出以下提示之一:如果你猜对一位数字但数字位置不对,则会提示“Pico”;如果你同时猜对了一位数字及其位置,则会提示“Fermi”;如果你猜测的数字及其位置都不对&…

网络爬虫基础知识

文章目录 网络爬虫基础知识爬虫的定义爬虫的工作流程常用技术和工具爬虫的应用1. 抓取天气信息2. 抓取新闻标题3. 抓取股票价格4. 抓取商品价格5. 抓取博客文章标题 网络爬虫基础知识 爬虫的定义 网络爬虫(Web Crawler 或 Spider)是一种自动化程序&…

算法训练营day24--93.复原IP地址 +78.子集 +90.子集II

一、93.复原IP地址 题目链接:https://leetcode.cn/problems/restore-ip-addresses/ 文章讲解:https://programmercarl.com/0093.%E5%A4%8D%E5%8E%9FIP%E5%9C%B0%E5%9D%80.html 视频讲解:https://www.bilibili.com/video/BV1fA4y1o715 1.1 初…

MyBatis入门案例

实施前的准备工作: 1.准备数据库表2.创建一个新的springboot工程,选择引入对应的起步依赖(mybatis、mysql驱动、lombok)3.在application.properties文件中引入数据库连接信息4.创建对应的实体类Emp(实体类属性采用驼峰…

终身免费的Navicat数据库,不需要破解,官方支持

终身免费的Navicat数据库,不需要破解,官方支持 卸载了Navicat,很不爽上干货,Navicat免费版下载地址 卸载了Navicat,很不爽 公司不让用那些破解的数据库软件,之前一直使用Navicat。换了几款其他的数据库试了…

WebStorm 2024 for Mac JavaScript前端开发工具

Mac分享吧 文章目录 效果一、下载软件二、开始安装1、双击运行软件(适合自己的M芯片版或Intel芯片版),将其从左侧拖入右侧文件夹中,等待安装完毕2、应用程序显示软件图标,表示安装成功3、打开访达,点击【文…

web权限到系统权限 内网学习第一天 权限提升 使用手工还是cs???msf可以不??

现在开始学习内网的相关的知识了,我们在拿下web权限过后,我们要看自己拿下的是什么权限,可能是普通的用户权限,这个连添加用户都不可以,这个时候我们就要进行权限提升操作了。 权限提升这点与我们后门进行内网渗透是乘…

代码查重软件-自力更生

为了减轻工作量,自研了简单实用的代码查重工具,可以对若干文件之间进行查重。通过调试,相似度大于80%的没有一个是冤枉的。好用。去掉雷同的,其他的代码再慢慢看。

pads layout 脚本导出不能运行excle解决办法

在一台新的电脑上安装好PADS,打开PCB文件导出坐标文件时: 出现“ActiveX Automation: server could not be found.”的问题,导致无法成功导出文件,错误提示截图如下: 导致上述问题的原因是在我们配置导出带坐标的脚本时,默认使用的是微软…

服务器连接不上

记录今天2024/07/02的问题: 我今天真的是非常无语,今天在连服务器的时候,突然发现连不上了。 后来才意识到,原来是我笔记本先是开了全局代理,然后再用easy connected连接。当时还跳出了一个窗口如下,我当时…

2024 MWC上海:创新力量驱动未来先行,移远智慧点亮数字蓝海

6月26日,2024年世界移动通信大会(MWC上海)如期举行,今年的展会以“未来先行”为主题,涵盖“超越 5G、数智制造和人工智能经济”三大技术主题。移远通信作为全球物联网行业的引领者之一,今年不仅在展示内容上…

性能调优 性能监控

1.影响性能考虑点包括: 数据库、应用程序、中间件(tomcat、nginx)、网络和操作系统等方面。 首先考虑自己的应用属于 CPU密集型 还是 IO密集型 cpu密集型 计算,排序,分组查询,各种算法 IO密集型 网络传输,磁盘读…

将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

将数据切分成N份,采用NCCL异步通信,让all_gathermatmul尽量Overlap 一.测试数据二.测试环境三.普通实现四.分块实现 本文演示了如何将数据切分成N份,采用NCCL异步通信,让all_gathermatmul尽量Overlap 一.测试数据 1.测试规模:8192*8192 world_size22.单算子:all_gather:0.035…

JDBC链接kerberos认证的impala数据库报错问题解决

先上代码 public static Connection connectToImpala() {try {log.info("ketTabPath:" ketTabPath);log.info("krb5Path:" krb5Path);System.setProperty("java.security.krb5.conf", krb5Path);System.setProperty("sun.security.krb5.…

冒泡排序、选择排序、菱形

冒泡排序、选择排序、菱形 文章目录 一、冒泡排序二、选择排序三、菱形 一、冒泡排序 思路: 外层(第一层)循环控制循环次数,和业务无关 内层(第二层)循环用于比较相邻的2个值的大小,根据小到大…