动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili

本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation

本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb


多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 h 组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 h 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (="https://zh.d2l.ai/chapter_references/zreferences.html#id174">Vaswaniet al., 2017)。 对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询 \mathbf{q} \in \mathbb{R}^{d_q} 、 键 \mathbf{k} \in \mathbb{R}^{d_k} 和 值 \mathbf{v} \in \mathbb{R}^{d_v} , 每个注意力头 \mathbf{h}_i( i = 1, \ldots, h )的计算方法为:

\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} 、 \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} 和 \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} , 以及代表注意力汇聚的函数 f 。 f 可以是 10.3节 中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着 h 个头连结后的结果,因此其可学习参数是 \mathbf W_o\in\mathbb R^{p_o\times h p_v} :

\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中通常[选择缩放点积注意力作为每一个注意力头]。 为了避免计算代价和参数代价的大幅增长, 我们设定 p_q = p_k = p_v = p_o / h 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p_q h = p_k h = p_v h = p_o , 则可以并行计算 h 个头。 在下面的实现中, p_o 是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

为了能够[使多个头并行计算], 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    # 输入X的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状:((batch_size,查询或者“键-值”对的个数,num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来[测试]我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

练习

  1. 分别可视化这个实验中的多个头的注意力权重。

解:
代码如下:

attention.attention.attention_weights.shape
# (batch_size*num_heads,查询的个数,“键-值”对的个数)

输出结果:
torch.Size([10, 4, 6])

d2l.show_heatmaps(
    attention.attention.attention_weights.reshape((2,5,4,6)), 
    xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 6)],
    figsize=(8, 3.5))

输出结果:

2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

解:
首先定义评判注意力头重要性的指标,比如预测速度等;
然后采用单一变量法,修剪某一个头或某几个头的组合,重新训练模型,并在验证集上评估重要性指标的变化; 最后根据重要性指标的变化,判断最不重要的一个或几个注意力头,并修剪。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928425.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大R玩家流失预测在休闲社交游戏中的应用

摘要 预测玩家何时会离开游戏为延长玩家生命周期和增加收入贡献创造了独特的机会。玩家可以被激励留下来,战略性地与公司组合中的其他游戏交叉链接,或者作为最后的手段,通过游戏内广告传递给其他公司。本文重点预测休闲社交游戏中高价值玩家…

软件质量保证——单元测试之白盒技术

笔记内容及图片整理自XJTUSE “软件质量保证” 课程ppt,仅供学习交流使用,谢谢。 程序图 程序图定义 程序图P(V,E),V是节点的集合(节点是程序中的语句或语句片段),E是有向边的集合…

Node.js:开发和生产之间的区别

Node.js 中的开发和生产没有区别,即,你无需应用任何特定设置即可使 Node.js 在生产配置中工作。但是,npm 注册表中的一些库会识别使用 NODE_ENV 变量并将其默认为 development 设置。始终在设置了 NODE_ENVproduction 的情况下运行 Node.js。…

Zookeeper的通知机制是什么?

大家好,我是锋哥。今天分享关于【Zookeeper的通知机制是什么?】面试题。希望对大家有帮助; Zookeeper的通知机制是什么? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Zookeeper的通知机制主要通过Watcher实现,它是Zookeeper客…

Request method ‘POST‘ not supported(500)

前端路径检查 查看前端的请求路径地址、请求类型、方法名是否正确,结果没问题 后端服务检查 查看后端的传参uri、传参类型、方法名,结果没问题 nacos服务名检查 检查注册的服务是否对应(我这里是后端的服务名是‘ydlh-gatway’,服务列表走…

ESP32-S3模组上跑通ES8388(13)

接前一篇文章:ESP32-S3模组上跑通ES8388(12) 二、利用ESP-ADF操作ES8388 2. 详细解析 上一回解析了es8388_init函数中的第6段代码,本回继续往下解析。为了便于理解和回顾,再次贴出es8388_init函数源码,在…

Ubuntu 包管理

APT&dpkg 查看已安装包 查看所有已经安装的包 dpkg -l 查找包 apt search <package_name>搜索软件包列表&#xff0c;找到与搜索关键字匹配的包 dpkg与grep结合查找特定的包 dpkg -s <package>&#xff1a;查看某个安装包的详细信息 安装包 apt安装命令 更新…

泛化调用 :在没有接口的情况下进行RPC调用

什么是泛化调用&#xff1f; 在RPC调用的过程中&#xff0c;调用端向服务端发起请求&#xff0c;首先要通过动态代理&#xff0c;动态代理可以屏蔽RPC处理流程&#xff0c;使得发起远程调用就像调用本地一样。 RPC调用本质&#xff1a;调用端向服务端发送一条请求消息&#x…

开源 - Ideal库 - Excel帮助类,ExcelHelper实现(四)

书接上回&#xff0c;前面章节已经实现Excel帮助类的第一步TableHeper的对象集合与DataTable相互转换功能&#xff0c;今天实现进入其第二步的核心功能ExcelHelper实现。 01、接口设计 下面我们根据第一章中讲解的核心设计思路&#xff0c;先进行接口设计&#xff0c;确定Exce…

Android Studio 右侧工具栏 Gradle 不显示 Task 列表

问题&#xff1a; android studio 4.2.1版本更新以后AS右侧工具栏Gradle Task列表不显示&#xff0c;这里需要手动去设置 解决办法&#xff1a; android studio 2024.2.1 Patch 2版本以前的版本设置&#xff1a;依次打开 File -> Settings -> Experimental 选项&#x…

力扣hot100道【贪心算法后续解题方法心得】(三)

力扣hot100道【贪心算法后续解题方法心得】 十四、贪心算法关键解题思路1、买卖股票的最佳时机2、跳跃游戏3、跳跃游戏 | |4、划分字母区间 十五、动态规划什么是动态规划&#xff1f;关键解题思路和步骤1、打家劫舍2、01背包问题3、完全平方式4、零钱兑换5、单词拆分6、最长递…

计算机网络常见面试题总结(上)

计算机网络基础 网络分层模型 OSI 七层模型是什么&#xff1f;每一层的作用是什么&#xff1f; OSI 七层模型 是国际标准化组织提出的一个网络分层模型&#xff0c;其大体结构以及每一层提供的功能如下图所示&#xff1a; 每一层都专注做一件事情&#xff0c;并且每一层都需…

编译MT7620 OpenWrt的所有机型的固件

前置条件&#xff1a;准备VMware16Ubuntu16.04的编译环境 安装编译需要的插件 sudo apt-get install gcc g binutils patch bzip2 flex bison make autoconf gettext texinfo unzip sharutils libncurses5-dev ncurses-term zlib1g-dev gawk asciidoc libz-dev git-core uuid…

vue2+svg+elementui实现花瓣图自定义el-select回显色卡图片

项目需要实现花瓣图&#xff0c;但是改图表在echarts&#xff0c;highCharts等案例中均未出现&#xff0c;有类似的韦恩图&#xff0c;但是和需求有所差距&#xff1b; 为实现该效果&#xff0c;静态图表上采取svg来手动绘制花瓣&#xff1a; 确定中心点&#xff0c;以该点为中…

释放超凡性能,打造鸿蒙原生游戏卓越体验

11月26日在华为Mate品牌盛典上&#xff0c;全新Mate70系列及多款全场景新品正式亮相。在游戏领域&#xff0c;HarmonyOS NEXT加持下游戏的性能得到充分释放。HarmonyOS SDK为开发者提供了软硬协同的系统级图形加速解决方案——Graphics Accelerate Kit&#xff08;图形加速服务…

Docker应用

一、需要知道什么&#xff08;先去了解一下&#xff09; Docker 概述 Docker 安装&#xff08;配置阿里云镜像加速器&#xff0c;要不 pull 镜像下载的慢&#xff09; portainer 可视化界面 关于镜像、容器、仓库的命令等 数据卷 Docker 网络 DockerFile DockerCompose …

Qt桌面应用开发 第十天(综合项目二 翻金币)

目录 1.主场景搭建 1.1重载绘制事件&#xff0c;绘制背景图和标题图片 1.2设置窗口标题&#xff0c;大小&#xff0c;图片 1.3退出按钮对应关闭窗口&#xff0c;连接信号 2.开始按钮创建 2.1封装MyPushButton类 2.2加载按钮上的图片 3.开始按钮跳跃效果 3.1按钮向上跳…

VScode离线下载扩展安装

在使用VScode下在扩展插件时&#xff0c;返现VScode搜索不到插件&#xff0c;网上搜了好多方法&#xff0c;都不是常规操作&#xff0c;解决起来十分麻烦&#xff0c;可以利用离线下载安装的方式安装插件&#xff01;亲测有效&#xff01;&#xff01;&#xff01; 1.找到VScod…

浅谈网络 | 应用层之HTTPS协议

目录 对称加密非对称加密数字证书HTTPS 的工作模式重放与篡改 使用 HTTP 协议浏览新闻虽然问题不大&#xff0c;但在更敏感的场景中&#xff0c;例如支付或其他涉及隐私的数据传输&#xff0c;就会面临巨大的安全风险。如果仍然使用普通的 HTTP 协议&#xff0c;数据在网络传输…

MySQL有哪些日志?

MySQL主要有三种日志&#xff1a;undo log、redo log、binlog。前两种是InnoDB特有的&#xff0c;binlog是MySQL的Server层中的。 Buffer Pool buffer pool是MySQL的缓冲池&#xff0c;里面存储了数据页、索引页、undo页等&#xff08;与数据库不一致的即为脏页&#xff09;。…