机器学习深度学习——注意力提示、注意力池化(核回归)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——常见循环神经网络结构(RNN、LSTM、GRU)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

机器学习&&深度学习——注意力提示、注意力池化(核回归)

  • 注意力提示
    • 引入
    • 查询、键和值
    • 注意力的可视化
    • 小结
  • 注意力池化:Nadaraya-Watson核回归
    • 生成数据集
    • 平均池化
    • 非参数注意力池化
    • 带参数注意力池化
      • 批量矩阵乘法
      • 定义模型
      • 训练
    • 小结

注意力提示

引入

之前讲过的CNN和RNN模型,容易发现的一个点是,他们并没有刻意的、主观的针对某个点去做训练和预测,这是因为他们并没有注意力。讲到注意力我们就要知道心理学中的非自主性提示自主性提示
什么叫非自主性提示呢?举个例子,一群杀马特中间的老大是个光头,根本不需要别人告知你,你就会觉得这人很突兀,很不一般,也就是说人的注意力没有受到认知和意识的控制。
而自主性提示就是人的注意力受到认知和意识的控制,因此注意力在基于自主性提示去辅助选择时将更为谨慎。受试者的主观意愿推动,选择的力量也就更强大。

查询、键和值

自主性的与非自主性的注意力提示解释了人类的注意力的方式,下面来看看如何通过这两种注意力提示,用神经网络来设计注意力机制的框架:
首先,考虑简单的非自主性提示。要想将选择偏向于感官输入,可以简单使用参数化的全连接层,甚至是非参数化的最大池化层或平均池化层。
所以,“是否包含自主性提示”将注意力机制与全连接层或池化层区别开来。
在注意力机制下,自主性提示被称为查询(给定任何查询,注意力机制通过注意力池化将选择引导至感官输入)。在注意力机制中,这些感官输入被称为,每个值都有一个配对,可以想象成感官输入的非自主性提示。如下所示:
在这里插入图片描述
注意:汇聚就是池化。

注意力的可视化

平均池化层可以被视为输入的加权平均值,其中各输入的权重是一样的。实际上,注意力池化得到的是加权平均的总和值,其中权重是在给定的查询和不同的键之间计算得出的。
下面来进行注意力可视化,我们通过热力图的方式来展示,定义一个show_heatmaps,其输入matrices的形状是(要显示的行数,要显示的列数,查询的数目,键的数目)。

import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6)

用简单例子来演示一下:仅当查询和键相同时,注意力权重为1,否则为0

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
d2l.plt.show()

输出结果:
在这里插入图片描述

小结

1、人类的注意力是有限的、有价值和稀缺的资源。
2、使用非自主性和自主性提示有选择性地引导注意力,前者基于突出性,后者则依赖于意识。
3、注意力机制与全连接层或者池化层的区别源于增加的自主提示。
4、由于包含了自主性提示,注意力机制与全连接的层或池化层不同。
5、注意力机制通过注意力池化使选择偏向于值(感官输入),其中包含查询(自主性提示)和键(非自主性提示)。键和值是成对的。
6、可视化查询和键之间的注意力权重是可行的。

注意力池化:Nadaraya-Watson核回归

上面已经介绍了注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间的交互形成了注意力池化。注意力池化有选择地聚合了值(感官输入)以生成最终的输出。后面将会以简单例子来进行讲解。

import torch
from torch import nn
from d2l import torch as d2l

生成数据集

简单起见,考虑下面这个回归问题:给定的成对的“输入-输出”数据集,如何学习f来预测任意新输入的x,其输出y_hat=f(x)。
生成一个人工数据集,加入噪声项σ:
y i = 2 s i n ( x i ) + x i 0.8 + σ y_i=2sin(x_i)+x_i^{0.8}+\sigma yi=2sin(xi)+xi0.8+σ
其中σ服从均值为0,标准差0.5的正态分布。在这里生成了50个训练样本和50个测试样本。为了更好地可视化之后的注意力模式,需要将训练样本进行排序。

import torch
from torch import nn
from d2l import torch as d2l

n_train = 50  # 训练样本数
x_train = torch.sort(torch.rand(n_train) * 5)[0]   # 排序后的训练样本

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

下面函数将绘制所有的训练样本(圆圈表示),不带噪声项的真实数据生成函数f(标记为Truth),以及学习得到的预测函数(标记为Pred):

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)

平均池化

其表达式为:
f ( x ) = 1 n ∑ i = 1 n y i f(x)=\frac{1}{n}\sum_{i=1}^ny_i f(x)=n1i=1nyi
在绘制预测结果之前,先介绍一下repeat_interleave函数,可以看下面的这一篇博客,看懂用法:
【PyTorch】repeat_interleave()方法详解
现在我们可以绘制预测结果:

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
d2l.plt.show()

在这里插入图片描述
可以发现真实函数和预测函数的差距很大,并不是很好的方式。

非参数注意力池化

很显然,平均池化只注意y而没有注意x,于是提出了更好的想法,根据输入的位置对输出y进行加权:
f ( x ) = ∑ i = 1 n K ( x − x i ) ∑ j = 1 n K ( x − x j ) y i f(x)=\sum_{i=1}^n\frac{K(x-x_i)}{\sum_{j=1}^nK(x-x_j)}y_i f(x)=i=1nj=1nK(xxj)K(xxi)yi
其中K是核,上面的估计器被称为Nadaraya-Watson核回归。我们可以从注意力机制的框架角度重写上式,成为一个更通用的注意力池化公式:
f ( x ) = ∑ i = 1 n α ( x , x i ) y i 其中, x 是查询, ( x i , y i ) 是键值对 f(x)=\sum_{i=1}^nα(x,x_i)y_i\\ 其中,x是查询,(x_i,y_i)是键值对 f(x)=i=1nα(x,xi)yi其中,x是查询,(xi,yi)是键值对
显然,注意力池化就是yi的加权平均,将查询x和键xi之间的关系建模为注意力权重α。容易知道,对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布:它们是非负的,并且总和为1。
比如,我们可以考虑一个高斯核,定义为:
K ( u ) = 1 2 π e x p ( − u 2 2 ) K(u)=\frac{1}{\sqrt{2π}}exp(-\frac{u^2}{2}) K(u)=2π 1exp(2u2)
将高斯核带入上式,得:
f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n e x p ( − 1 2 ( x − x i ) 2 ) ∑ j = 1 n e x p ( − 1 2 ( x − x j ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i f(x)=\sum_{i=1}^nα(x,x_i)y_i\\ =\sum_{i=1}^n\frac{exp(-\frac{1}{2}(x-x_i)^2)}{\sum_{j=1}^nexp(-\frac{1}{2}(x-x_j)^2)}y_i\\ =\sum_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i f(x)=i=1nα(x,xi)yi=i=1nj=1nexp(21(xxj)2)exp(21(xxi)2)yi=i=1nsoftmax(21(xxi)2)yi
可以看出,如果一个键xi越是接近给定的查询x,那么分配给这个键对应值的yi的注意力权重就会越大,也就“获得了更多的注意力”。
需要注意,Nadaraya-Watson核回归是一个非参数模型,也因此,上式是非参数的注意力池化模型。基于该模型绘制预测结果,可以发现预测线是平滑的,且比平均池化的预测更接近真实:

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)
d2l.plt.show()

运行结果:
在这里插入图片描述
现在我们来观察注意力的权重。这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近,注意力汇聚的注意力权重就越高:

# unsqueeze(0)表示增加第一维度,增加两次维度
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
d2l.plt.show()

运行结果:
在这里插入图片描述

带参数注意力池化

非参数的Nadaraya-Watson核回归具有一致性的优点:若有足够数据,该模型会收敛到最优结果。尽管如此,我们还是可以轻松地将可学习的参数集成到注意力池化中。
与之前不同,在下面的查询x和键xi之前的距离乘以可学习的参数w:
f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n e x p ( − 1 2 ( ( x − x i ) w ) 2 ) ∑ j = 1 n e x p ( − 1 2 ( ( x − x j ) w ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( ( x − x i ) w ) 2 ) ) y i f(x)=\sum_{i=1}^nα(x,x_i)y_i\\ =\sum_{i=1}^n\frac{exp(-\frac{1}{2}((x-x_i)w)^2)}{\sum_{j=1}^nexp(-\frac{1}{2}((x-x_j)w)^2)}y_i\\ =\sum_{i=1}^nsoftmax(-\frac{1}{2}((x-x_i)w)^2))y_i f(x)=i=1nα(x,xi)yi=i=1nj=1nexp(21((xxj)w)2)exp(21((xxi)w)2)yi=i=1nsoftmax(21((xxi)w)2))yi
下面将通过训练这个模型来学习注意力汇聚的参数。

批量矩阵乘法

对于矩阵乘法,我们之前已经知道对于两个形状分别为a×b和b×c的矩阵,进行矩阵乘法以后矩阵的形状就变成a×c的,现在做个推广,假设两个张量的形状分别为n×a×b和n×b×c,则进行批量矩阵乘法后,输出n×a×c。
在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
print(torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)))

输出结果:
在这里插入图片描述

定义模型

使用小批量矩阵乘法,定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

训练

将训练数据集变换为键和值用于训练注意力模型,在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出。

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))
d2l.plt.show()

运行结果:
在这里插入图片描述
接着绘制预测结果:

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
d2l.plt.show()

运行结果:
在这里插入图片描述
容易发现,在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑。
可以尝试查看热力图,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑:

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
d2l.plt.show()

运行结果:
在这里插入图片描述

小结

1、Nadaraya-Watson核回归是具有注意力机制的机器学习范例。
2、Nadaraya-Watson核回归的注意力池化是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。
3、注意力池化可以分为非参数型和带参数型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/69341.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SqlServer基础之(触发器)

概念: 触发器(trigger)是SQL server 提供给程序员和数据分析员来保证数据完整性的一种方法,它是与表事件相关的特殊的存储过程,它的执行不是由程序调用,也不是手工启动,而是由事件来触发&#x…

JVM G1垃圾回收机制介绍

G1(Garbage First)收集器 (标记-整理算法): Java堆并行收集器,G1收集器是JDK1.7提供的一个新收集器,G1收集器基于“标记-整理”算法实现,也就是说不会产生内存碎片。此外,G1收集器不同于之前的收集器的一个重要特点是&…

钓鱼攻击:相似域名识别及如何有效预防攻击

网络犯罪分子很乐意劫持目标公司或其供应商或业务合作伙伴的官方域名,但在攻击的早期阶段,他们通常没有这种选择。相反,在有针对性的攻击之前,他们会注册一个与受害组织的域名相似的域名 - 他们希望您不会发现其中的差异。此类技术…

SpringBoot 的自动装配特性

1. Spring Boot 的自动装配特性 Spring Boot 的自动装配(Auto-Configuration)是一种特性,它允许您在应用程序中使用默认配置来自动配置 Spring Framework 的各种功能和组件,从而减少了繁琐的配置工作。通过自动装配,您…

TepeScript 问题记录

问题 对object的所有属性赋值或清空&#xff0c;提示类型错误不能赋值 type VoiceParams {_id?: string | undefined;name: string;sex: string;vc_id: string;model_url: string;preview_url: string;isPrivate: boolean;visible: boolean; }const formData reactive<V…

【Minecraft】Fabric Mod开发完整流程2 - 创造模式物品栏与第一个方块

创造模式物品栏 添加到当前已有物品栏 再添加自定义的创造模式物品栏之前&#xff0c;请确保你的确有这个需求&#xff01;否则建议直接添加到当前已有的物品栏内部 创建新文件&#xff1a;com/example/item/ModItemGroup.java package com.example.item;import net.fabricmc.…

出于网络安全考虑,印度启用本土操作系统”玛雅“取代Windows

据《印度教徒报》报道&#xff0c;印度将放弃微软系统&#xff0c;选择新的操作系统和端点检测与保护系统。 备受期待的 "玛雅操作系统 "将很快用于印度国防部的数字领域&#xff0c;而新的端点检测和保护系统 "Chakravyuh "也将一起面世。 不过&#xf…

2024考研408-计算机网络 第五章-传输层学习笔记

文章目录 前言一、传输层提供的服务1.1、传输层的功能1.2、传输层的两个协议&#xff08;TCP、UDP&#xff09;1.3、传输层的寻址与端口&#xff08;常见端口介绍&#xff09; 二、UDP协议2.1、认识UDP功能和特点2.2、UDP首部格式2.3、UDP伪首部字段分析2.4、伪首部校验UDP用户…

【24择校指南】南京大学计算机考研考情分析

南京大学(A) 考研难度&#xff08;☆☆☆☆☆&#xff09; 内容&#xff1a;23考情概况&#xff08;拟录取和复试分数人数统计&#xff09;、院校概况、23初试科目、23复试详情、参考书目、各科目考情分析、各专业考情分析。 正文2178字&#xff0c;预计阅读&#xff1a;6分…

网络原理(JavaEE初阶系列11)

目录 前言&#xff1a; 1.网络原理的理解 2.应用层 2.1自定义协议的约定 2.1.1确定要传输的信息 2.1.2确定数据的格式 3.传输层 3.1UDP 3.1.1UDP报文格式 3.2TCP 3.2.1确认应答 3.2.2超时重传 3.2.3连接管理 3.2.3.1三次握手 3.2.3.2四次挥手 3.2.4滑动窗口 3.…

【JavaEE】Spring Boot - 配置文件

【JavaEE】Spring Boot 开发要点总结&#xff08;2&#xff09; 文章目录 【JavaEE】Spring Boot 开发要点总结&#xff08;2&#xff09;1. 配置文件的两种格式2. .properties 文件2.1 基本语法2.2 注释2.3 配置项2.4 主动读取配置文件的键值2.5 数据库的连接时的需要的信息配…

ChatGPT访问流量下降的原因分析

​自从OpenAI的ChatGPT于11月问世以来&#xff0c;这款聪明的人工智能聊天机器人就席卷了全世界&#xff0c;人们在试用该工具的同时也好奇该技术到底将如何改变我们的工作和生活。 但近期Similarweb表示&#xff0c;自去ChatGPT上线以来&#xff0c;该网站的访问量首次出现下…

面试热题(路径总和II)

给你二叉树的根节点 root 和一个整数目标和 targetSum &#xff0c;找出所有 从根节点到叶子节点 路径总和等于给定目标和的路径。 叶子节点 是指没有子节点的节点。 在这里给大家提供两种方法进行思考&#xff0c;第一种方法是递归&#xff0c;第二种方式使用回溯的方式进行爆…

Linux文件属性与权限管理(可读、可写、可执行)

Linux把所有文件和设备都当作文件来管理&#xff0c;这些文件都在根目录下&#xff0c;同时Linux中的文件名区分大小写。 一、文件属性 使用ls -l命令查看文件详情&#xff1a; 1、每行代表一个文件&#xff0c;每行的第一个字符代表文件类型&#xff0c;linux文件类型包括&am…

Javascript 正则

基本语法 定义 JavaScript种正则表达式有两种定义方式 构造函数 var regnew RegExp(<%[^%>]%>,g);字面量 var reg/<%[^%>]%>/g;g&#xff1a; global&#xff0c;全文搜索&#xff0c;默认搜索到第一个结果接停止i&#xff1a;ingore case&#xff0c;忽略…

小程序如何设置电子票

电子票是一种方便快捷的票务管理方式&#xff0c;可以帮助商家实现电子化的票务管理&#xff0c;提升用户体验。下面介绍&#xff1a;如何在小程序内&#xff0c;设置电子票以及用电子票购买商品。 1. 设置电子票套餐。可以新建一个商品&#xff0c;商品标题写&#xff1a;XX电…

UDP通信实验、广播与组播、本地套接字

文章目录 流程函数应用广播应用 组播&#xff08;多播&#xff09;本地套接字应用 流程 函数 返回值&#xff1a; 成功&#xff0c;返回成功发送的数据长度 失败&#xff0c;-1 返回值&#xff1a; 成功&#xff0c;返回成功接收数据长度 失败&#xff0c;-1 应用 广播 应用 …

android APP内存优化

Android为每个应用分配多少内存 Android出厂后&#xff0c;java虚拟机对单个应用的最大内存分配就确定下来了&#xff0c;超出这个值就会OOM。这个属性值是定义在/system/build.prop文件中. 例如&#xff0c;如下参数 dalvik.vm.heapstartsize8m #起始分配内存 dalvik.vm.…

搭建servlet服务

目录 servlet的生命周期 配置tomcat环境 创建web后端项目 配置web.xml http请求 get和post 其他请求 http响应 Servlet是Server Applet的简称&#xff0c;意思为用Java编写的服务器端的程序&#xff0c;它运行在web服务器中&#xff0c;web服务器负责Servlet和客户的通…

5.利用matlab完成 符号矩阵的转置和 符号方阵的幂运算(matlab程序)

1.简述 Matlab符号运算中的矩阵转置 转置向量或矩阵 B A. B transpose(A) 说明 B A. 返回 A 的非共轭转置&#xff0c;即每个元素的行和列索引都会互换。如果 A 包含复数元素&#xff0c;则 A. 不会影响虚部符号。例如&#xff0c;如果 A(3,2) 是 12i 且 B A.&#xff0…