【深度学习实验】注意力机制(二):掩码Softmax 操作

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制:
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
      • a. 导入必要的库
      • b. masked_softmax
      • c. 实验结果

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制:

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

2. 掩码Softmax 操作

  掩码Softmax操作的用处在于在处理序列数据时,对于某些位置的输入可能需要进行忽略或者特殊处理。通过使用掩码张量,可以将这些无效或特殊位置的权重设为负无穷大,从而在进行Softmax操作时,使得这些位置的输出为0。
  这种操作通常在序列模型中使用,例如自然语言处理中的文本分类任务。在文本分类任务中,输入是一个句子或一个段落,长度可能不一致。为了保持输入的统一性,需要进行填充操作,使得所有输入的长度相同。然而,在经过填充操作后,一些位置可能对应于填充字符,这些位置的权重应该被忽略。通过使用掩码Softmax操作,可以确保填充位置的输出为0,从而在计算损失函数时不会对填充位置产生影响。

a. 导入必要的库

import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l

b. masked_softmax

  带有掩码的 softmax 操作主要用于处理变长序列,其中填充的元素不应该对 softmax 操作的结果产生影响。

def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
        

参数解释

  • X: 一个三维张量,表示输入的 logits。

  • valid_lens: 一个一维或二维张量,表示每个序列的有效长度。如果是一维张量,它会被重复到匹配 X 的第二维。

函数流程

  1. 如果 valid_lensNone,则直接应用标准的 softmax 操作,返回 nn.functional.softmax(X, dim=-1)

  2. 如果 valid_lens 不是 None,则进行以下步骤:

    • 获取 X 的形状 shape

    • 如果 valid_lens 是一维张量,将其重复到匹配 X 的第二维,以便与 X 进行逐元素运算。

    • X 重塑为一个二维张量,形状为 (-1, shape[-1]),这样可以在最后一个轴上进行逐元素操作。

    • 使用 d2l.sequence_mask 函数,将有效长度外的元素替换为一个很大的负数(-1e6)。这样,这些元素在经过 softmax 后的输出会趋近于零。

    • 将处理后的张量重新塑形为原始形状,然后应用 softmax 操作。最终输出是带有掩码的 softmax 操作结果。

c. 实验结果

masked_softmax(torch.rand(3, 8, 5), torch.tensor([2, 2, 2]))
  • 随机生成了一个形状为 (3, 8, 5) 的 3D 张量,其中有效长度全为 2。

在这里插入图片描述

masked_softmax(torch.rand(3, 8, 5), torch.tensor([1, 2, 3]))

在这里插入图片描述

  • 使用二维张量,为矩阵样本中的每一行指定有效长度
masked_softmax(torch.rand(2, 2, 5), torch.tensor([[1, 3], [2, 4]]))

  • 对于形状为 (2, 2, 5) 的 3D 张量
    • 第一个二维矩阵的第一个序列的有效长度为 1,第二个序列的有效长度为 3。
    • 第二个二维矩阵的第一个序列的有效长度为 2,第二个序列的有效长度为 4。
      在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/164896.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第25天|216.组合总和III 17.电话号码的字母组合

JAVA代码编写 216. 组合总和III 找出所有相加之和为 n 的 k 个数的组合,且满足下列条件: 只使用数字1到9每个数字 最多使用一次 返回 所有可能的有效组合的列表 。该列表不能包含相同的组合两次,组合可以以任何顺序返回。 示例 1: 输入: k …

Spring IOC/DI和MVC及若依对应介绍

文章目录 一、Spring IOC、DI注解1.介绍2.使用 二、Spring MVC注解1.介绍2.使用 一、Spring IOC、DI注解 1.介绍 什么是Spring IOC/DI? IOC(Inversion of Control:控制反转)是面向对象编程中的一种设计原则。其中最常见的方式叫做依赖注入(…

windows排除故障工具pathping、MTR、sysinternals

pathping 基本上可以认为它是ping和tracert的功能合体。 pathping首先对目标执行tracert,然后使用ICMP对每一跳进行100次ping操作。 如图,是一个对8.8.8.8进行pathing操作。 MTR MTR是另一个多工具合体工具。 winmtr是mtr的windows版本。 这个工具…

c盘清除文件

打开设置 搜索存储

设计模式(二)-创建者模式(2)-工厂模式

一、为何需要工厂模式(Factory Pattern)? 由于简单工厂模式存在一个缺点,如果工厂类创建的对象过多,使得代码变得越来越臃肿。这样导致工厂类难以扩展新实例,以及难以维护代码逻辑。于是在简单工厂模式的基础上&…

QFile文件读写操作QFileInFo文件信息读取

点击按钮选择路径,路径显示在lineEdit中 将路径下的文件的内容放在textEdit中 最后显示出来 !file.atend()//没有读到文件尾就一直读 file.readline表示按行进行读 追加的方式进行写 要是重新写的话用file.open(QIODevice::write) 用QFileInFo来读取…

AcWing 3. 完全背包问题 学习笔记

有 N� 种物品和一个容量是 V� 的背包,每种物品都有无限件可用。 第 i� 种物品的体积是 vi��,价值是 wi��。 求解将哪些物品装入背包,可使这些物品的总体积不…

pytest测试框架介绍(1)

又来每天进步一点点啦~~~ 一、Pytest介绍: pytest 是一个非常成熟的全功能的Python测试框架; pytest 简单、灵活、易上手; 支持参数化 能够支持简单的单元测试和复杂的功能测试,可以做接口自动化测试(pytestrequests&…

Linux驱动开发——块设备驱动

目录 一、 学习目标 二、 磁盘结构 三、块设备内核组件 四、块设备驱动核心数据结构和函数 五、块设备驱动实例 六、 习题 一、 学习目标 块设备驱动是 Linux 的第二大类驱动,和前面的字符设备驱动有较大的差异。要想充分理解块设备驱动,需要对系统…

国学---佛系算吉凶~

佛系算吉凶咯~,正经走访深山庙宇,前辈老人,经过调研后,搭建的轻衍计算模型,团队对国学的初次信息化尝试。 共享给有需要的朋友,准不准没关系,开心最重要。 后续还有财富,事业&…

STM32F4系列单片机GPIO概述和寄存器分析

第2章 STM32-GPIO口 2.1 GPIO口概述 通用输入/输出口 2.1.1 GPIO口作用 GPIO是单片机与外界进行数据交流的窗口。 2.1.2 STM32的GPIO口 在51单片机中,IO口,以数字进行分组(P0~P3),每一组里面又有8个IO口。 在ST…

redis+python 建立免费http-ip代理池;验证+留接口

前言: 效果图: 对于网络上的一些免费代理ip,http的有效性还是不错的;但是,https的可谓是凤毛菱角; 正巧,有一个web可以用http访问,于是我就想到不如直接拿着免费的HTTP代理去做这个! 思路: 1.单页获取ipporttime (获取time主要是为了后面使用的时候,依照时效可以做文章) 2.整…

庖丁解牛:NIO核心概念与机制详解 01

文章目录 Pre输入/输出Why NIO流与块的比较通道和缓冲区概述什么是缓冲区?缓冲区类型什么是通道?通道类型 NIO 中的读和写概述Demo : 从文件中读取1. 从FileInputStream中获取Channel2. 创建ByteBuffer缓冲区3. 将数据从Channle读取到Buffer中 Demo : 写…

照片+制作照片书神器,效果太棒了!

随着数码相机的普及,越来越多的人喜欢用照片记录生活点滴。而制作一本精美的照片书,不仅可以保存珍贵的回忆,还能让照片更加美观。今天,就为大家推荐一款制作照片书神器,让您的回忆更加完美! 一、产品介绍 …

光敏传感器模块(YH-LDR)

目录 1. YH-LDR模块说明 1.1 简介 1.2 YH-LDR 模块的引脚说明 1.3 LDR 传感器工作原理与输出特性 2. 使用单片机系统控制 YH-LDR 模块 2.1 通用控制说明 1. YH-LDR模块说明 1.1 简介 YH-LDR 是野火设计的光强传感器,使用一个光敏电阻作为采集源&#x…

用cmd看星球大战大电影,c++版本全集星球大战,超长多细节

用cmd看星球大战 最近发现了一个有趣的指令。 是不是感觉很insteresting呢 教程 进入控制面板,点击系统与安全 然后,进入以后,点击启用或关闭 Windows 功能 启用Telnet Client并点击确定 用快捷键winr打开我们的cmd 输入指令 telnet towe…

【diffuser系列】ControlNet

ControlNet: TL;DRControl TypeStableDiffusionControlNetPipeline1. Canny ControlNet1.1 模型与数据加载1.2 模型推理1.3 DreamBooth微调 2. Pose ControlNet2.1 数据和模型加载2.2 模型推理 ControlNet: TL;DR ControlNet 是在 Lvmin Zhang 和 Maneesh Agrawala 的 Adding …

03 前后端数据交互【小白入门SpringBoot + Vue3】

项目笔记,教学视频来源于B站青戈 https://www.bilibili.com/video/BV1H14y1S7YV 前两个笔记。是把前端页面大致做出来,接下来,把后端项目搞一下。 后端项目,使用IDEA软件、jdk1.8、springboot2.x 。基本上用的是稳定版。 还有My…

【算法萌新闯力扣】:旋转字符串

力扣热题:796.旋转字符串 开篇 今天下午刷了6道力扣算法题,选了一道有多种解法的题目与大家分享。 题目链接:796.旋转字符串 题目描述 代码思路 完全按照题目的要求,利用StringBuffer中的方法对字符串进行旋转,寻找相同的一项 …

月子会所信息展示服务预约小程序的作用是什么

传统线下门店经营只依赖自然流量咨询或简单的线上付费推广是比较低效的,属于靠“天”吃饭,如今的年轻人学历水平相对较高,接触的事物或接受的思想也更多更广,加之生活水平提升及互联网带来的长期知识赋能,因此在寻找/咨…