《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码

在注意力机制中,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention),也被称为内部注意力(intra-attention)。本节将使用自注意力进行序列编码,以及使用序列的顺序作为补充信息。

import math
import torch
from torch import nn
from d2l import torch as d2l

10.6.1 自注意力

在这里插入图片描述

给定一个由词元组成的输入序列 x 1 , … , x n \boldsymbol{x}_1,\dots,\boldsymbol{x}_n x1,,xn,其中任意 x i ∈ R d ( 1 ≤ i ≤ n ) \boldsymbol{x}_i\in\R^d\quad(1\le i\le n) xiRd(1in) 。该序列的自注意力输出为一个长度相同的序列 y 1 , … , y n \boldsymbol{y}_1,\dots,\boldsymbol{y}_n y1,,yn,其中:

y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \boldsymbol{y}_i=f(\boldsymbol{x}_i,(\boldsymbol{x}_1,\boldsymbol{x}_1),\dots,(\boldsymbol{x}_n,\boldsymbol{x}_n))\in\R^d yi=f(xi,(x1,x1),,(xn,xn))Rd

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,  # 基于多头注意力对一个张量完成自注意力的计算
                                   num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # 张量的形状为(批量大小,时间步的数目或词元序列的长度,d)。
attention(X, X, X, valid_lens).shape  # 输出与输入的张量形状相同
torch.Size([2, 4, 100])

10.6.2 比较卷积神经网络、循环神经网络和自注意力

在这里插入图片描述

  • 卷积神经网络

    • 计算复杂度为 O ( k n d 2 ) O(knd^2) O(knd2)

      • k k k 为卷积核大小

      • n n n 为序列长度是

      • d d d 为输入和输出的通道数量

    • 并行度为 O ( n ) O(n) O(n)

    • 最大路径长度为 O ( n / k ) O(n/k) O(n/k)

  • 循环神经网络

    • 计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

      d × d d\times d d×d 权重矩阵和 d d d 维隐状态的乘法计算复杂度为 O ( d 2 ) O(d^2) O(d2),由于序列长度为 n n n,因此循环神经网络层的计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

    • 并行度为 O ( 1 ) O(1) O(1)

      O ( n ) O(n) O(n) 个顺序操作无法并行化。

    • 最大路径长度也是 O ( n ) O(n) O(n)

  • 自注意力

    • 计算复杂度为 O ( n 2 d ) O(n^2d) O(n2d)

      查询、键和值都是 n × d n\times d n×d 矩阵

    • 并行度为 O ( n ) O(n) O(n)

      每个词元都通过自注意力直接连接到任何其他词元。因此有 O ( 1 ) O(1) O(1) 个顺序操作可以并行计算

    • 最大路径长度也是 O ( 1 ) O(1) O(1)

总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

10.6.3 位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加 位置编码(positional encoding) 来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。

基于正弦函数和余弦函数的固定位置编码的矩阵第 i i i 行、第 2 j 2j 2j 列和 2 j + 1 2j+1 2j+1 列上的元素为:

p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) \begin{align} p_{i,2j}&=\sin{\left(\frac{i}{10000^{2j/d}}\right)}\\ p_{i,2j+1}&=\cos{\left(\frac{i}{10000^{2j/d}}\right)} \end{align} pi,2jpi,2j+1=sin(100002j/di)=cos(100002j/di)

#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

在位置嵌入矩阵 P \boldsymbol{P} P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。从下面的例子中可以看到位置嵌入矩阵的第 6 列和第 7 列的频率高于第 8 列和第 9 列。第 6 列和第 7 列之间的偏移量(第 8 列和第 9 列相同)是由于正弦函数和余弦函数的交替。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])


在这里插入图片描述

10.6.3.1 绝对位置信息

打印出 0 , 1 , … , 7 0,1,\dots,7 0,1,,7 的二进制表示形式即可明白沿着编码维度单调降低的频率与绝对位置信息的关系。

每个数字、每两个数字和每四个数字上的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替。

for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111

在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')


在这里插入图片描述

10.6.3.2 相对位置信息

除了捕获绝对位置信息之外,上述的位置编码还允许模型学习得到输入序列中相对位置信息。这是因为对于任何确定的位置偏移 δ \delta δ,位置 i + δ i+\delta i+δ 处的位置编码可以线性投影位置 i i i 处的位置编码来表示。

这种投影的数学解释是,令 ω j = 1 / 1000 0 2 j / d \omega_j=1/10000^{2j/d} ωj=1/100002j/d,对于任何确定的位置偏移 δ \delta δ,上个式子中的任何一对 ( p i , 2 j , p i , 2 j + 1 ) (p_{i,2j},p_{i,2j+1}) (pi,2j,pi,2j+1) 都可以线性投影到 ( p i + δ , 2 j , p i + δ , 2 j + 1 ) (p_{i+\delta,2j},p_{i+\delta,2j+1}) (pi+δ,2j,pi+δ,2j+1)

[ cos ⁡ ( δ ω j ) sin ⁡ ( δ ω j ) − sin ⁡ ( δ ω j ) cos ⁡ ( δ ω j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ cos ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + sin ⁡ ( δ ω j ) cos ⁡ ( i ω j ) − sin ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + cos ⁡ ( δ ω j ) cos ⁡ ( i ω j ) ] = [ sin ⁡ ( ( i + δ ) ω j ) cos ⁡ ( ( i + δ ) ω j ) ] = [ p i , 2 j p i , 2 j + 1 ] \begin{align} &\begin{bmatrix} \cos{(\delta\omega_j)} & \sin{(\delta\omega_j)}\\ -\sin{(\delta\omega_j)} & \cos{(\delta\omega_j)} \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =&\begin{bmatrix} \cos{(\delta\omega_j)}\sin{(i\omega_j)}+\sin{(\delta\omega_j)}\cos{(i\omega_j)}\\ -\sin{(\delta\omega_j)}\sin{(i\omega_j)}+\cos{(\delta\omega_j)}\cos{(i\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} \sin{((i+\delta)\omega_j)}\\ \cos{((i+\delta)\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix} \end{align} ===[cos(δωj)sin(δωj)sin(δωj)cos(δωj)][pi,2jpi,2j+1][cos(δωj)sin(iωj)+sin(δωj)cos(iωj)sin(δωj)sin(iωj)+cos(δωj)cos(iωj)][sin((i+δ)ωj)cos((i+δ)ωj)][pi,2jpi,2j+1]

2 × 2 2\times 2 2×2 投影矩阵不依赖于任何位置的索引 i i i

练习

(1)假设设计一个深度架构,通过堆叠基于位置编码的自注意力层来表示序列。可能会存在什么问题?


(2)请设计一种可学习的位置编码方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/107160.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

图纸管理制度 《五》

1、存档文件应由专人管理,其他人未征得管理人员同意,不得随意翻阅查看。 2、档案管理人员要认真贯彻执行公司相关制度,严禁泄露档案材料中的秘密。 彩虹图纸管理软件_图纸管理系统_图纸文档管理软件系统_彩虹EDM【官网】彩虹EDM图纸管理软件…

docker 部署 若依 Ruoyi springboot+vue分离版 dockerCompose

本篇从已有虚拟机/服务器 安装好dokcer为基础开始讲解 1.部署mysql 创建conf data init三个文件夹 conf目录存放在mysql配置文件 init目录存放着若依数据库sql文件(从navicat导出的并非若依框架自带sql) 创建一个属于本次若依部署的网段(只…

【产品经理】APP备案(阿里云)

工信部《关于开展移动互联网应用程序备案工作的通知》 工业和信息化部印发了《关于开展移动互联网应用程序备案工作的通知》,“在中华人民共和国境内从事互联网信息服务的App主办者,应当依照相关法律法规等规定履行备案手续,未履行备案手续的…

【C程序设计】用心浇灌<C程序>

目录 数据类型 整数类型 实例 浮点类型 void 类型 类型转换 数据类型 在 C 语言中,数据类型指的是用于声明不同类型的变量或函数的一个广泛的系统。变量的类型决定了变量存储占用的空间,以及如何解释存储的位模式。 C 中的类型可分为以下几种&…

深度学习使用Keras进行迁移学习提升网络性能

上一篇文章我们用自己定义的模型来解决了二分类问题,在20个回合的训练之后得到了大约74%的准确率,一方面是我们的epoch太小的原因,另外一方面也是由于模型太简单,结构简单,故而不能做太复杂的事情,那么怎么提升预测的准确率了?一个有效的方法就是迁移学习。 迁移学习其…

C++反转链表递归

文章目录 题目描述解题思路代码复杂度分析 题目描述 LCR 024. 反转链表 - 力扣(LeetCode) 给定单链表的头节点 head ,请反转链表,并返回反转后的链表的头节点。 解题思路 这里我们采用递归的思路来解决首先我们分为两个视角来查看…

[debug/main.o] Error 1 QtCreator编译报错

我是用Qt5.6.0MinGW32位版本编译程序,在Pro文件中添加了预编译头文件后编译报错:mingw32-make[1]: *** [debug/main.o] Error 1; #添加预编译头文件 CONFIG precompiled_header PRECOMPILED_HEADER header.h 解决方法: 1.删除…

大数据-Storm流式框架(六)---Kafka介绍

Kafka简介 Kafka是一个分布式的消息队列系统(Message Queue)。 官网:Apache Kafka 消息和批次 kafka的数据单元称为消息。消息可以看成是数据库表的一行或一条记录。 消息由字节数组组成,kafka中消息没有特别的格式或含义。 消息有可选的键&#x…

【跟小嘉学 Rust 编程】三十三、Rust的Web开发框架之一: Actix-Web的基础

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 【跟小嘉学 Rust 编程】二、Rust 包管理工具使用 【跟小嘉学 Rust 编程】三、Rust 的基本程序概念 【跟小嘉学 Rust 编程】四、理解 Rust 的所有权概念 【跟小嘉学 Rust 编程】五、使用结构体关联结构化数据 【跟小嘉学…

【Unity精华一记】特殊文件夹

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

云原生安全:如何保护云上应用不受攻击

文章目录 云原生安全的概念1. 多层次的安全性2. 自动化安全3. 容器安全4. 持续监控5. 合规性 云原生安全的关键挑战1. 无边界的环境2. 动态性3. 多云环境4. 容器化应用程序5. API和微服务 如何保护云上应用不受攻击1. 身份验证和访问控制示例代码: 2. 数据加密示例代…

探秘Kafka背后的幕后机关,揭示消息不丢失或重复的原理与实践经验

背景 相信大家在工作中都用过消息队列,特别是 Kafka 使用得更是普遍,业务工程师在使用 Kafka 的时候除了担忧 kafka 服务端宕机外,其实最怕如下这样两件事。 消息丢失。下游系统没收到上游系统发送的消息,造成系统间数据不一致。…

PyTorch中grid_sample的使用方法

官方文档首先Pytorch中grid_sample函数的接口声明如下: torch.nn.functional.grid_sample(input, grid, modebilinear, padding_modezeros, align_cornersNone)input : 输入tensor, shape为 [N, C, H_in, W_in]grid: 一个field flow, shape为…

JAVA实现校园失物招领管理系统 开源

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 招领管理模块2.2 寻物管理模块2.3 系统公告模块2.4 感谢留言模块 三、界面展示3.1 登录注册3.2 招领模块3.3 寻物模块3.4 公告模块3.5 感谢留言模块3.6 系统基础模块 四、免责说明 一、摘要 1.1 项目介绍 基于VueSpri…

深入剖析SQL与NoSQL的优劣势,帮你决定最佳数据存储方案

你是否在为系统的数据库来一波大流量就几乎打满 CPU,日常 CPU 居高不下烦恼?你是否在各种 NoSQL 间纠结不定,到底该选用哪种最好?今天的你就是昨天的我,这也是我写这篇文章的初衷。 作为互联网从业人员,我们要知道关系型数据库…

蓝桥杯 第 2 场算法双周赛 第4题 通关【算法赛】c++ 优先队列 + 小根堆 详解注释版

题目 通关【算法赛】https://www.lanqiao.cn/problems/5889/learning/?contest_id145 问题描述 小蓝最近迷上了一款电玩游戏“蓝桥争霸”。这款游戏由很多关卡和副本组成,每一关可以抽象为一个节点,整个游戏的关卡可以抽象为一棵树形图,每…

群晖上搭建teamspeak3语音服务器

什么是 TeamSpeak ? TeamSpeak (简称 TS)是一款团队语音通讯工具,但比一般的通讯工具具有更多的功能而且使用方便。它由服务器端程序和客户端程序两部分组成,如果不是想自己架设 TS 服务器,只需下载客户端程…

SQL Server Management Studio (SSMS)的安装教程

文章目录 SQL Server Management Studio (SSMS)的安装教程从Microsoft官网下载SQL Server Management Studio安装程序。选中安装程序右键并选择“以管理员的身份运行”选项选择安装目录,单击“安装”按钮开始安装过程安装成功界面安装完成后,您可以启动S…

LaTeX:在标题section中添加脚注footnote

命令讲解 先导包: \usepackage{footmisc} 设原标题为: \section{标题内容} 更改为: \section[标题内容]{标题内容\protect\footnote{脚注内容}} 语法讲解: \section[]{} []内为短标题,作为目录和页眉中的标题。…

Java面向对象(进阶)-- this关键字的使用

文章目录 一、引子(1) this是什么?(2)什么时候使用this1.实例方法或构造器中使用当前对象的成员2. 同一个类中构造器互相调用 二、探讨(1)问题(2)解决 三、this关键字&am…