什么是注意力机制?注意力机制的计算规则

我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的),是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果,正是基于这样的理论,就产生了注意力机制。

什么是注意力计算规则:

它需要三个指定的输入Q(query),K(key),V(value),然后通过计算公式得到注意力的结果,这个结果代表query在key和value作用下的注意力表示.当输入的Q=K=V时,称作自注意力计算规则。

常见的注意力计算规则:

|| ·将Q,K进行纵轴拼接,做一次线性变化,再使用softmax处理获得结果最后与V做张量乘法。

在这里插入图片描述

|| ·将Q,K进行纵轴拼接,做一次线性变化后再使用tanh函数激活,然后再进行内部求和,最后使用softmax处理获得结果再与V做张量乘法.

在这里插入图片描述

|| ·将Q与K的转置做点积运算,然后除以一个缩放系数再使用softmax处理获得结果最后与V做张量乘法。

在这里插入图片描述

说明:当注意力权重矩阵和V都是三维张量且第一维代表为batch条数时, 则做bmm运算.bmm是一种特殊的张量乘法运算。

bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

注意力机制的作用

在解码器端的注意力机制:能够根据模型目标有效的聚焦编码器的输出结果,当其作为解码器的输入时提升效果,改善以往编码器输出是单一定长张量,无法存储过多信息的情况。

在编码器端的注意力机制:主要解决表征问题,相当于特征提取过程,得到输入的注意力表示。般使用自注意力(self-attention)。

注意力机制实现步骤

第一步:根据注意力计算规则,对Q,K,V进行相应的计算

第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积一般是自注意力,Q与V相同,则不需要进行与Q的拼接

第三步:最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做个线性变换,得到最终对Q的注意力表示

常见注意力机制的代码分析:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        """初始化函数中的参数有5个, query_size代表query的最后一维大小
           key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, 
           value = (1, value_size1, value_size2)
           value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 初始化注意力机制实现第一步中需要的线性层.
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

        # 初始化注意力机制实现第三步中需要的线性层.
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)


    def forward(self, Q, K, V):
        """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
           张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""

        # 第一步, 按照计算规则进行计算, 
        # 我们采用常见的第一种计算规则
        # 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
        attn_weights = F.softmax(
            self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)

        # 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, 
        # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, 
        # 需要将Q与第一步的计算结果再进行拼接
        output = torch.cat((Q[0], attn_applied[0]), 1)

        # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
        # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights

调用:

query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)
out = attn(Q, K ,V)
print(out[0])
print(out[1])

输出效果:

tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -8.4096, -0.5982, 0.1548,
-8.8771, -8.0951. 8.1833. 8.3128. 8.1260, 8.4420. 8.8495.
-0.7774, -0.0995, 0.2629, 0.4957, 1.0922, 0.1428, 0.3024.
-0.2646, -0.0265, 0.0632, 0.3951, 0.1583, 0.1130, 0.5500,
-0.1887, -0.2816, -0.3800, -0.5741, 0.1342, 0.0244, -0.2217,
0.1544, 0.1865, -0.2019, 0.4090, -0.4762, 0.3677, -0.2553,
-0.5199, 0.2290, -0.4407, 0.0663, -8.0182, -8.2168, 0.0913,
-0.2340, 0.1924, -0.3687, 0.1508, 0.3618, -0.0113, 0.2864.
-0.1929, -0.6821, 0.0951, 0.1335, 0.3560, -0.3215
,0.6461,
0.1532]]],grad_fn=<UnsqueezeBackward0>)
tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,
0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,
0.0166, 0.8225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,
0.0419, 8.0352, 8.0392, 8.0546, 0.0224]], grad_fn=<SoftmaxBackward>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/59046.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】初阶结构体

&#x1f388;个人主页&#xff1a;库库的里昂 &#x1f390;CSDN新晋作者 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 ✨收录专栏&#xff1a;C语言初阶 ✨其他专栏&#xff1a;代码小游戏 &#x1f91d;希望作者的文章能对你有所帮助&#xff0c;有不足的地方请在评论…

P14 电路定理——巧妙-灵性-智慧

1、替代定理 图示表示&#xff1a; 叠加定理和齐性定理只能用于线性电路&#xff0c;但是替代定理无论线不线性都可以用。 常见的&#xff1a;线性电路将某复杂支路等效成电压源或电流源之后&#xff0c;就可以使用叠加原理了。 引入两个相互抵消的电压源&#xff0c;拿其中一…

基于STM32设计的智能空调

一、项目背景 随着人们生活水平的不断提高&#xff0c;对居住环境的舒适度要求也越来越高。空调作为一种重要的家电设备&#xff0c;已经成为了现代家庭中必不可少的一部分。本文介绍了一种基于STM32的智能空调设计方案&#xff0c;可以自动地根据环境温度进行温度调节。 二、…

Opencv-C++笔记 (15) : 像素重映射 与 图像扭曲

文章目录 一、重映射简介二、图像扭曲 一、重映射简介 重映射&#xff0c;就是把一幅图像中某位置的像素放置到另一图像指定位置的过程。即&#xff1a; 在重映射过程中&#xff0c;图像的大小也可以同时发生改变。此时像素与像素之间的关系就不是一一对应关系&#xff0c;因…

PHP从入门到精通—PHP开发入门-PHP概述、PHP开发环境搭建、PHP开发环境搭建、第一个PHP程序、PHP开发流程

每开始学习一门语言&#xff0c;都要了解这门语言和进行开发环境的搭建。同样&#xff0c;学生开始PHP学习之前&#xff0c;首先要了解这门语言的历史、语言优势等内容以及了解开发环境的搭建。 PHP概述 认识PHP PHP最初是由Rasmus Lerdorf于1994年为了维护个人网页而编写的一…

租赁类小程序定制开发|租赁管理系统源码|免押租赁系统开发

随着互联网的发展&#xff0c;小程序成为了一种重要的移动应用开发方式。租赁小程序作为其中的一种类型&#xff0c;可以为很多行业提供便利和创新。下面我们将介绍一些适合开发租赁小程序的行业。   房屋租赁行业&#xff1a;租房小程序可以帮助房东和租户快速找到合适的租赁…

对当下AI的一些观感思考

目前来看&#xff0c;AI技术地震的震中还是在美帝那旮瘩。尤其是M7&#xff0c;这几家市值加总快15万亿美元了&#xff0c;个个都是行业翘楚&#xff0c;个个都有拿得出手的东西。AI是个技术密集、人才密集、计算密集的产业。美帝拥有全球一流的顶尖人才&#xff0c;以及财力、…

IntelliJ IDEA 2023.2社区版插件汇总

参考插件帝&#xff1a;https://gitee.com/zhengqingya/java-developer-document 突发小技巧&#xff1a;使用插件时要注意插件的版本兼容性&#xff0c;并根据自己的实际需求选择合适的插件。同时&#xff0c;不要过度依赖插件&#xff0c;保持简洁和高效的开发环境才是最重要…

2023下半年软考初级程序员报名入口-报名流程-备考方法

软考初级程序员2023下半年考试时间&#xff1a; 2023年下半年软考初级程序员的考试时间为11月4日、5日。考试时间在全国各地一致&#xff0c;建议考生提前备考。共分两科&#xff0c;第一科基础知识考试具体时间为9:00到11:30&#xff1b;第二科应用技术考试具体时间为2:00到4…

STM32 低功耗-睡眠模式

STM32 睡眠模式 文章目录 STM32 睡眠模式第1章 低功耗模式简介第2章 睡眠模式简介2.1 进入睡眠模式2.1 退出睡眠模式 第3章 睡眠模式代码示例总结 第1章 低功耗模式简介 在 STM32 的正常工作中&#xff0c;具有四种工作模式&#xff1a;运行、睡眠、停止和待机模式。 在系统或…

小白到运维工程师自学之路 第六十四集 (dockerfile构建tomcat、mysql、lnmp、redis镜像)

一、tomcat&#xff08;更换jdk&#xff09; mkdir tomcat cd tomcat/ tar xf jdk-8u191-linux-x64.tar.gz tar xf apache-tomcat-8.5.40.tar.gzvim Dockerfile FROM centos:7 MAINTAINER Crushlinux <syh163.com> ADD jdk1.8.0_191 /usr/local/java ENV JAVA_HOME /us…

AcWing257. 关押罪犯(二分图+染色法)

输入样例&#xff1a; 4 6 1 4 2534 2 3 3512 1 2 28351 1 3 6618 2 4 1805 3 4 12884输出样例&#xff1a; 3512 解析&#xff1a; 二分&#xff0c;每次查看是否是二分图 #include<bits/stdc.h> using namespace std; typedef long long ll; const int N2e45,M2e55…

《Spring Boot源码解读与原理分析》书籍推荐

Spring Boot 1.0.0 早在2014年就已经发布&#xff0c;只不过到了提倡“降本增效”的今天&#xff0c;Spring Boot才引起了越来越多企业的关注。Spring Boot是目前Java EE开发中颇受欢迎的框架之一。依托于底层Spring Framework的基础支撑&#xff0c;以及完善强大的特性设计&am…

使用 Go 语言实现二叉搜索树

原文链接&#xff1a; 使用 Go 语言实现二叉搜索树 二叉树是一种常见并且非常重要的数据结构&#xff0c;在很多项目中都能看到二叉树的身影。 它有很多变种&#xff0c;比如红黑树&#xff0c;常被用作 std::map 和 std::set 的底层实现&#xff1b;B 树和 B 树&#xff0c;…

前端实现打印1 - 使用 iframe 实现 并 分页打印

目录 打印代码对话框预览打印预览 打印代码 <!-- 打印 --> <template><el-dialogtitle"打印":visible.sync"dialogVisible"width"50%"top"7vh"append-to-bodyclose"handleClose"><div ref"print…

uni-app uView自定义底部导航栏

因项目需要自定义底部导航栏&#xff0c;我把它写在了组件里&#xff0c;基于uView2框架写的&#xff08;vue2&#xff09;&#xff1b; 一、代码 在components下创建tabbar.vue文件&#xff0c;代码如下&#xff1a; <template><view><u-tabbar :value"c…

推荐一款非常简单实用的数据库连接工具Navicat Premium

Navicat Premium是一款非常实用的数据库连接工具&#xff0c;别再用HeidiSQL和idea自带的数据库连接了&#xff0c;看完这篇文章&#xff0c;赶紧把Navicat Premium用起来吧。 首先&#xff0c;需要获取Navicat Premium的安装包&#xff0c;可以通过以下网盘链接下载&#xff0…

刷题笔记 day5

力扣 202 快乐数 首先来分析什么样的数是快乐数&#xff0c; 解题思路&#xff1a; 1&#xff09;定义快慢指针 &#xff1b; 2&#xff09;快指针走两步&#xff0c;慢指针走一步 &#xff1b;3&#xff09;两个指针相遇时判断相遇的数值是否为1。 怎样处理 取各分位数的平…

力扣初级算法(二分查找)

力扣初级算法(二分法)&#xff1a; 每日一算法&#xff1a;二分法查找 学习内容&#xff1a; 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 2.二分查找流程&…

SystemVerilog scheduler

文章目录 简介调度器simulation regionPreponed regionActive regionInactive regionNBA(Non-blocking Assignment Events region)Observed regionReactive regionRe-Inactive Events regionRe-NBA RegionPostponed Region PLI region:Pre-active regionPre-NBA regionPost-NBA…