注意力机制(代码实现案例)

学习目标

  • 了解什么是注意力计算规则以及常见的计算规则.
  • 了解什么是注意力机制及其作用.
  • 掌握注意力机制的实现步骤.

1 注意力机制介绍

1.1 注意力概念

  • 我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果. 正是基于这样的理论,就产生了注意力机制.

1.2 注意力计算规则

  • 它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时, 称作自注意力计算规则.

1.3 常见的注意力计算规则

  • bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

2 什么是注意力机制

  • 注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
  • 说明: NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.

3 注意力机制的作用

  • 在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量, 无法存储过多信息的情况.
  • 在编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制在网络中实现的图形表示:

4 注意力机制实现步骤

4.1 步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

4.2 代码实现

  • 常见注意力机制的代码分析:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Attn(nn.Module):
        def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
            """初始化函数中的参数有5个, query_size代表query的最后一维大小
               key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, 
               value = (1, value_size1, value_size2)
               value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""
            super(Attn, self).__init__()
            # 将以下参数传入类中
            self.query_size = query_size
            self.key_size = key_size
            self.value_size1 = value_size1
            self.value_size2 = value_size2
            self.output_size = output_size
    
            # 初始化注意力机制实现第一步中需要的线性层.
            self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
    
            # 初始化注意力机制实现第三步中需要的线性层.
            self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
    
    
        def forward(self, Q, K, V):
            """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
               张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""
    
            # 第一步, 按照计算规则进行计算, 
            # 我们采用常见的第一种计算规则
            # 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
            attn_weights = F.softmax(
                self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
    
            # 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, 
            # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
            attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
    
            # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, 
            # 需要将Q与第一步的计算结果再进行拼接
            output = torch.cat((Q[0], attn_applied[0]), 1)
    
            # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
            # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度
            output = self.attn_combine(output).unsqueeze(0)
            return output, attn_weights
    

  • 调用:
  • query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

  • 输出效果:
    tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -0.4096, -0.5982,  0.1548,
              -0.0771, -0.0951,  0.1833,  0.3128,  0.1260,  0.4420,  0.0495,
              -0.7774, -0.0995,  0.2629,  0.4957,  1.0922,  0.1428,  0.3024,
              -0.2646, -0.0265,  0.0632,  0.3951,  0.1583,  0.1130,  0.5500,
              -0.1887, -0.2816, -0.3800, -0.5741,  0.1342,  0.0244, -0.2217,
               0.1544,  0.1865, -0.2019,  0.4090, -0.4762,  0.3677, -0.2553,
              -0.5199,  0.2290, -0.4407,  0.0663, -0.0182, -0.2168,  0.0913,
              -0.2340,  0.1924, -0.3687,  0.1508,  0.3618, -0.0113,  0.2864,
              -0.1929, -0.6821,  0.0951,  0.1335,  0.3560, -0.3215,  0.6461,
               0.1532]]], grad_fn=<UnsqueezeBackward0>)
    
    
    tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,
             0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,
             0.0166, 0.0225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,
             0.0419, 0.0352, 0.0392, 0.0546, 0.0224]], grad_fn=<SoftmaxBackward>)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/430058.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

docker快照备份回滚

1. 安装系统 1.1 vm安装Ubuntu 参考:https://blog.csdn.net/u010308917/article/details/125157774 1.2 其他操作 添加自定义物理卷 –待补充– 1.2.1 查询可用物理卷 fdisk -l 输出如下 Disk /dev/loop0: 73.9 MiB, 77492224 bytes, 151352 sectors Units: sectors of …

运维随录实战(4)

添加账号并为账号赋予root权限 1,使用root账号添加一个普通账号 adduser test passwd test # 赋予密码 2,赋予root权限 修改/etc/sudoers文件,如果使用vi 命令打开提示仅只读,则使用 visudo命令打开 在root下面添加一行“test ALL=(ALL) ALL”,如下所示 3,将test账…

【MySQL使用】show processlist 命令详解

show processlist 命令详解 一、命令含义二、命令返回参数三、Command值解释四、State值解释五、参考资料 一、命令含义 对于一个MySQL连接&#xff0c;或者说一个线程&#xff0c;任何时刻都有一个状态&#xff0c;该状态表示了MySQL当前正在做什么。SHOW PROCESSLIST 命令的…

行人实时动作识别

详细资料和代码请加微信&#xff1a;17324069443 一、项目介绍 基于PyTorchVideo的实时动作识别框架&#xff1a; 我们选择了yolov5作为目标检测器&#xff0c;而不是Faster R-CNN&#xff0c;它速度更快、更方便。 我们使用一个跟踪器&#xff08;deepsort&#xff09;来为不…

STM32CubeIDE基础学习-安装芯片固件支持包

STM32CubeIDE基础学习-添加芯片固件支持包 前言 前面的文章在安装STM32CubeIDE软件时没有安装这个芯片PACK包&#xff0c;如果工程没有这个固件支持包的话是无法正常使用的&#xff0c;随便安装一个和芯片对应系列的支持包就可以了。 这篇文章来记录一下新增PACK包的常用操作…

【Docker】Windows11操作系统下安装、使用Docker保姆级教程

【Docker】Windows11操作系统下安装、使用Docker保姆级教程 大家好 我是寸铁&#x1f44a; 总结了一篇【Docker】Windows11操作系统下安装、使用Docker保姆级教程的文章✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 什么是 Docker&#xff1f; Docker 是一个开源平台&…

【优选算法】前缀和

前缀和思想其实就是一种简单的dp思想&#xff0c;也就是动态规划 什么时候用到前缀和&#xff1f;当要快速求出数组中某一个区间的和 前缀和模板 暴力解法 定义一个指针从左向右遍历&#xff0c;并且累加值即可&#xff0c;这里就不过多赘述&#xff0c;主要还是来看前缀和…

缓存雪崩、击穿、穿透

目录 前言 一、缓存雪崩 1.大量数据同时过期 1.均匀设置过期时间 2.互斥锁 3.后台更新缓存 2.Redis故障宕机 1.服务熔断或请求限流机制 2.构建Redis缓存高可靠集群 二、缓存击穿 1.设置互斥锁&#xff1b; 2.不给热点数据设置过期时间&#xff0c;由后台更新缓存。 …

可行性研究报告-范例直接套用

1业务需求可行性分析 2技术可行性分析 2.1规范化原则 2.2高度的兼容性和可移植性 2.3人性化、适用性 2.4标准化统一设计原则 2.5先进安全可扩展性原则 3开发周期可行性分析 4人力资源可行性分析 5成本分析 6收益分析 7结论 软件开发多套实际项目案例、方案、源码获取&#xff1…

使用API有效率地管理Dynadot域名,进行DNS域名解析

关于Dynadot Dynadot是通过ICANN认证的域名注册商&#xff0c;自2002年成立以来&#xff0c;服务于全球108个国家和地区的客户&#xff0c;为数以万计的客户提供简洁&#xff0c;优惠&#xff0c;安全的域名注册以及管理服务。 Dynadot平台操作教程索引&#xff08;包括域名邮…

nodejs,JSDOM 补 window环境

window[atob] 是一个在浏览器中使用的 JavaScript 函数&#xff0c;用于将 base64 编码的字符串解码为原始数据。具体来说&#xff0c;atob 函数会将 base64 字符串解码为一个 DOMString&#xff0c;其中包含解码后的二进制数据。这在处理从服务器获取的 base64 编码的数据或在…

RedisDesktopManager连接Ubuntu的Redis失败解决办法

配置redis 1.设置redis在后台服务&#xff0c;修改配置文件 默认情况下是 no ,修改为yes&#xff0c;可以后台服务 2、设置redis端口&#xff0c;默认端口是6379&#xff0c;可以根据自己的需要&#xff0c;找到/et/redis/redis.conf文件, 修改port 3、设置密码 配置文件中…

基于pytorch的手写体识别

一、环境搭建 链接: python与深度学习——基础环境搭建 二、数据集准备 本次实验用的是MINIST数据集&#xff0c;利用MINIST数据集进行卷积神经网络的学习&#xff0c;就类似于学习单片机的点灯实验&#xff0c;学习一门机器语言输出hello world。MINIST数据集&#xff0c;可以…

【书籍推广】这本书太好了!150页就能让你上手大模型应用开发

文章目录 蛇尾书特色蛇尾书思维导图作译者简介业内专家书评 如果问个问题&#xff1a;有哪些产品曾经创造了伟大的奇迹&#xff1f;ChatGPT 应该会当之无愧入选。仅仅发布 5 天&#xff0c;ChatGPT 就吸引了 100 万用户——当然&#xff0c;数据不是关键&#xff0c;关键是其背…

【Unity】VR开发的正确测试节奏

【背景】 VR开发由于其测试时对设备的依赖较大&#xff0c;因此有时在没有测试条件时想当然地写了大量代码&#xff0c;一旦到正式测试时需要debug&#xff0c;往往无法判断到底是哪个环节的问题&#xff08;代码&#xff0c;环境&#xff0c;等等&#xff09;。相对于PC平台的…

AI预测福彩3D第2弹【2024年3月5日预测】

书接上回&#xff0c;首先声明下&#xff0c;写这一系列文章的目的不为别的&#xff0c;就是想看下到底使用一些强大的AI算法能不能挖掘出彩票的规律&#xff0c;毕竟彩票的规律太乱&#xff0c;不是说没有规律&#xff0c;而是规律太多。经过上一篇文章的图片&#xff0c;大家…

git使用教程14-Pycharm版本控制与分支管理

一、版本控制 1、版本控制介绍 &#xff08;1&#xff09;Version Control System 版本控制系统&#xff0c;简称VCS。 &#xff08;2&#xff09;版本控制系统分类&#xff1a; 集中式版本控制工具&#xff1a;SVN 分布式版本控制工具&#xff1a;Git 2、Pycharm 支持的版本…

C++:Vector的使用

一、vector的介绍 vector的文档介绍 1. vector是表示可变大小数组的序列容器。 2. 就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元素进行访问&#xff0c;和数组一样高效。但是又不像数组&#xff0c;它的大小是可以…

全链路监控

1. 全链路监控的兴起与发展 当代的互联网的服务&#xff0c;通常都是用复杂的、大规模分布式集群来实现的。互联网应用构建在不同的软件模块集上&#xff0c;这些软件模块&#xff0c;有可能是由不同的团队开发、可能使用不同的编程语言来实现、有可能布在了几千台服务器&…

CentOS 7操作系统安装教程

CentOS 7操作系统安装教程 CentOS 7是一款功能强大、稳定可靠的操作系统&#xff0c;适用于服务器、桌面等多种场景。下面将介绍CentOS 7的安装教程。 准备工作 下载CentOS 7镜像文件&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/centos/7/isos/x86_64/准备安装介质&am…