【深度学习实验】注意力机制(四):点积注意力与缩放点积注意力之比较

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
    • 3. 打分函数——加性注意力模型
    • 3. 打分函数——点积注意力与缩放点积注意力
      • a. 缩放点积注意力模型
      • b. 点积注意力模型
      • c. 模拟实验
      • d. 模型比较与选择
      • e. 代码整合

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

在这里插入图片描述

2. 掩码Softmax 操作

【深度学习实验】注意力机制(二):掩码Softmax 操作
在这里插入图片描述

3. 打分函数——加性注意力模型

  • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

【深度学习实验】注意力机制(三):打分函数——加性注意力模型

3. 打分函数——点积注意力与缩放点积注意力

  • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

  • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

a. 缩放点积注意力模型

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

  1. 初始化方法 (__init__):

    • 参数:
      • dropout: Dropout 正则化的概率。
    • 说明: 初始化方法定义了模型的组件,仅包含了一个 Dropout 正则化层。
  2. 前向传播方法 (forward):

    • 参数:
      • queries: 查询张量,形状为 (batch_size, num_queries, d)
      • keys: 键张量,形状为 (batch_size, num_kv_pairs, d)
      • values: 值张量,形状为 (batch_size, num_kv_pairs, value_size)
      • valid_lens: 有效长度张量,形状为 (batch_size,)(batch_size, num_queries)
    • 返回值: 加权平均后的值张量,形状为 (batch_size, num_queries, value_size)
  3. 实现细节:

    • 计算缩放点积得分:通过张量乘法计算 querieskeys 的点积,然后除以 d \sqrt{d} d 进行缩放,其中 d d d 是查询或键的维度。
    • 使用 masked_softmax 函数计算注意力权重,根据有效长度对注意力进行掩码。
    • 将注意力权重应用到值上,得到最终的加权平均结果。
    • 使用 Dropout 对注意力权重进行正则化。

b. 点积注意力模型

class DotProductAttention2(nn.Module):
    """点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention2, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # P195:(8.3),(8.4)
        # 在计算得分时不进行缩放操作(即不再除以sqrt(d))
        self.scores = torch.bmm(queries, keys.transpose(1, 2))
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

  • 区别:
    • 计算点积得分:通过张量乘法计算 querieskeys 的点积。

c. 模拟实验

  • 模型数据
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])
  • 模型应用
# 创建缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
# 使用模型进行前向传播
attention(queries, keys, values, valid_lens)

# 创建点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
# 使用模型进行前向传播
attention2(queries, keys, values, valid_lens)

在这里插入图片描述

  • 权重可视化
      为了直观地展示模型在不同输入下的注意力分布,使用前文 show_heatmaps 函数,通过热图的形式展示注意力权重。
# 可视化缩放点积注意力权重
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

# 可视化点积注意力权重
show_heatmaps(attention2.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

在这里插入图片描述

d. 模型比较与选择

  • 缩放点积注意力模型

    • 适用于处理高维度的查询和键。
    • 通过缩放操作有助于防止点积得分的方差过大。
  • 点积注意力模型

    • 适用于处理相对较低维度的查询和键。
    • 更方便地利用矩阵乘积提高计算效率。
      在这里插入图片描述

e. 代码整合

# 导入必要的库
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import data


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6)

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        print(queries)
        d = queries.shape[-1]
        print(d)
        self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        print(self.scores)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


class DotProductAttention2(nn.Module):
    """点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention2, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)

    def forward(self, queries, keys, values, valid_lens=None):
        # P195:(8.3),(8.4)
        # 在计算得分时不进行缩放操作(即不再除以sqrt(d))
        self.scores = torch.bmm(queries, keys.transpose(1, 2))
        print(self.scores)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])


# 缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

# 点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
attention2(queries, keys, values, valid_lens)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/173673.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【GUI】-- 12 贪吃蛇小游戏之让小蛇动起来

GUI编程 04 贪吃蛇小游戏 4.3 第三步:让小蛇动起来(键盘控制) 首先,在构造器中要获取焦点事件、键盘监听事件并加入定时器(定时器定义需要实现ActionListener接口并重写actionPerformed方法): //构造器public GamePanel() {init();this.s…

郎酒“掉队”,经销商们能等来春天吗?

文 | 螳螂观察(TanglangFin) 作者 | 渡过 有“六朵金花”之称的川酒品牌中,五粮液、泸州老窖、舍得、水井坊都已成功上市,只剩下郎酒和剑南春未上市。 与IPO的“掉队”相对应的,是郎酒在冲刺高端、内部管理、渠道管…

【LeetCode刷题-链表】--23.合并K个升序链表

23.合并K个升序链表 方法:顺序合并 在前面已经知道合并两个升序链表的前提下,用一个变量ans来维护以及合并的链表,第i次循环把第i个链表和ans合并,答案保存到ans中 /*** Definition for singly-linked list.* public class List…

SQL基础理论篇(八):视图

文章目录 简介创建视图修改视图删除视图总结参考文献 简介 视图,即VIEW,是SQL中的一个重要概念,它其实是一种虚拟表(非实体数据表,本身不存储数据)。 视图类似于编程中的函数,也可以理解成是一个访问数据的接口。 从…

VPX 插座(VITA46)介绍及应用 (简单介绍)

1. VPX 插座的介绍 VPX是VITA(VME International Trade Association, VME国际贸易协会)组织于2007年在其VME总线基础上提出的新一代高速串行总线标准。VPX总线的基本规范、机械结构和总线信号等具体内容均在ANSI/VITA46系列技术规范中定义。VPX就是基于高速串行总线的新一代总线…

基于Surfer与Voxler数据处理及可视化实践技术应用

Surfer和Voxler分别是美国Golden Software 公司开发的用于二维和三维数据可视化软件,具有强大的数据处理和插值功能,软件主要应用于气象、环境和地质(以及生物、医学等)等领域。其中Surfer主要用于绘制二维等值线图、三维表面图以…

仪表盘:pyecharts绘制

一、仪表盘 在数据分析中,仪表盘图(dashboard)的作用是以一种简洁、图表化的方式呈现数据的关键指标和核心信息,以帮助用户快速理解数据的情况,并从中提取关键见解。 仪表盘图通常由多个图表、指标和指示器组成&…

Java —— String类

目录 1. String类的重要性 2. 常用方法 2.1 字符串构造 2.2 String对象的比较 2.3 字符串查找 2.4 转化 1. 数值和字符串转化 2. 大小写转换 3. 字符串转数组 4. 格式化 2.5 字符串替换 2.6 字符串拆分 2.7 字符串截取 2.8 其他操作方法 2.9 字符串常量池 2.9.1 创建对象的思考…

基于安卓android微信小程序的刷题系统

项目介绍 面试刷题系统的开发过程中,采用B / S架构,主要使用jsp技术进行开发,中间件服务器是Tomcat服务器,使用Mysql数据库和Eclipse开发环境。该面试刷题系统包括会员、答题录入员和管理员。其主要功能包括管理员:个…

芯片IO口不加电阻会怎样?

芯片IO口不加电阻会怎样? 可能会导致以下几个后果: 1.高电流问题,IO口没有电阻限流,当与外部设备直接连接时,就可能会导致过大的电流流过IO口,这就可能损坏IO口,引起短路或烧坏其它电路组件。像…

Spring Boot要如何学习?【云驻共创】

Spring Boot 是由 Pivotal 团队提供的全新框架,其设计目的是用来简化新 Spring 应用的初始搭建以及开发过程。该框架使用了特定的方式来进行配置,从而使开发人员不再需要定义样板化的配置。我这里会分享一些学习Spring Boot的方法和干货,包括…

【C++】泛型编程 ⑨ ( 类模板的运算符重载 - 函数声明 和 函数实现 写在同一个类中 | 类模板 的 外部友元函数问题 )

文章目录 一、类模板 - 函数声明与函数实现分离1、函数声明与函数实现分离2、代码示例 - 函数声明与函数实现分离3、函数声明与函数实现分离 友元函数引入 二、普通类的运算符重载 - 函数声明 和 函数实现 写在同一个类中三、类模板的运算符重载 - 函数声明 和 函数实现 写在同…

Java实现windows系统截图

Java提供了一种方便的方式来截取Windows系统的截图。这个过程通常需要使用Java的Robot类来模拟用户的鼠标和键盘输入操作。下面将介绍如何使用Java实现Windows系统截图。 步骤1:导入Robot和AWT包 Java提供了一个Robot类,它可以模拟用户的键盘和鼠标操作…

医疗器械维修售后技术培训与支持的重要性

医疗器械维修售后技术培训与支持的重要性 随着我国医疗器械产业的的高速发展、医疗器械企业的崛起,大量创新医疗器械产品进入医疗机构。但医疗设备在使用和维护过程中,暴露出许多问题和不足,如部分设备故障率较高、临床工程培训不足、售后服务模式整体比较落后等,这影响了医疗…

当前系统并无桌面环境,或无显示器,无法显示远程桌面,您需要自行安装X11桌面环境,或者使用终端文件功能

ToDesk远程遇到的问题如上图,换向日葵直接黑屏; 问题原因 截止发文时间,Todesk只支持X11协议,没有适配最新的Wayland协议,所以我们需要把窗口系统调整为X11才可以。 解决方法 修改配置文件,关闭wayland su…

c# 文件读取和写入

文件写入 using System.Collections.Generic; namespace demo1;/// <summary> /// System.IO下的所有的Stream类是所有数据流的基类 /// 流是用于传输数据的对象&#xff0c;流就是用来传输数据的 /// 数据传输的两种方式&#xff1a;1、数据从外部源传输到程序中&#…

技术互联 创新交流 | 广汽研究院举办技术交流会圆满落幕

技术互联 创新交流 2023年11月1日&#xff0c;同星智能走进广汽研究院技术交流会圆满举行并落下帷幕。本次技术交流会得到广汽研究院相关部门的大力支持&#xff0c;并邀请到多名人员参与&#xff0c;涉及其技术、研发等部门。 本次活动的举办意义重大&#xff0c;目前广汽研究…

如何把A3 pdf 文章打印成A4

1. 用Adobe Acrobat 打开pdf 2 打印 选择海报 进行调整即可如下图,见下面红色的部分。

互联网医院源码搭建部署功能

互联网将医院与患者、医院内部&#xff08;医生、护士、领导层&#xff09;、医院与生态链上的各类组织机构连接起来。以患者为中心&#xff0c;优化医院业务流程&#xff0c;提升医疗服务质量与医院资源能效&#xff0c;让患者通过移动互联网随时随地的享受医院的服务&#xf…

练习八-利用有限状态机进行时序逻辑的设计

利用有限状态机进行时序逻辑的设计 1&#xff0c;任务目的&#xff1a;2&#xff0c;RTL代码&#xff0c;及原理框图3&#xff0c;测试代码&#xff0c;输出波形 1&#xff0c;任务目的&#xff1a; &#xff08;1&#xff09;掌握利用有限状态机实现一般时序逻辑分析的方法&am…