自注意力简介

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师的课程讲得最好。

自注意力就是输入一堆向量,假设称为a1,a2,a3,a4,那么这四个向量都会参与自注意力机制的运算,得到的结果仍然是四个输出,这四个输出再去做全连接运算。而每一个自注意力机制的输出都用到了a1~a4四个向量来进行运算,也就是说每个输出都是观察了所有的输入之后才得到的。

首先,输入a1需要和a2,a3,a4分别计算相关性,这个相关性可以由缩放点积方式来计算,也就称作缩放点积注意力,也可以由两个两个输入向量相加后再做非线性处理得到,称为加性注意力。

缩放点积的计算方法如下:

输入一个向量v1和一个向量v2,v1去乘上一个可训练矩阵Wq得到q,v2去乘上一个可训练矩阵Wk得到k,再把这个q和k做一个点积运算,得到的就是α,类似于相似度。

回到前面的例子中,a1这里既作为q,又作为k,又作为v。其中,Wq*a1就是q1,Wk*a1就是k1,q1和k1的点积就是α11,相当于a1自己和自己的相似度,同样的,a1和a2,a3,a4分别计算得到α12,α13,α14,然后将α11,α12,α13,α14经过softmax得到最终的四个输出,如下图所示:

然后再用一个可训练矩阵Wv去乘以a1得到v1,用计算得到的相似度α'11去乘以v1,得到一个值temp11;同样的,用可训练矩阵Wv去乘以a2得到v2,用计算得到的相似度α‘12去乘以v2,得到temp12;类似的得到temp13,temp14,然后把temp11+temp12+temp13+temp14得到b1,这个b1就是自注意力机制的第一个输出。

我们刚刚是以a1的视角做的运算,得到b1,同样可以以a2,a3,a4的视角做运算,得到b2,b3,b4。这次就得到了自注意力机制的输出。光看最后这个结构图,有点类似全连接,只是里面的运算过程比全连接要复杂。

下面,我们来看一下如何用代码实现自注意力的计算。

import torch  
import torch.nn as nn  
import torch.nn.functional as F  

class SelfAttention(nn.Module):
    # embed_size代表输入的向量维度,heads代表多头注意力机制中的头数量
    def __init__(self, embed_size, heads): 
        super(SelfAttention,self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads # 每个头的维度
        # 用assert断言机制判断
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads" 
        # 没有偏置项,其实这个线性层本质上就是为了计算值Wv*a = V
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # 最后的全连接操作,输出仍是输入的向量维度,也就是说大小是不变的
        self.fc_out = nn.Linear(heads*self.head_dim, self.embed_size)
        
    def forward(self, values, keys, query, mask):
        # 这个mask也很关键,它用于控制模型在处理序列数据时应该关注哪些部分,以及忽略哪些部分
        N = query.shape[0] # 获取输入的批量个数
        print("N:",N)
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 获取输入序列的长度
        # Split the embedding into self.heads different pieces  
        # 把k,q,v都切分为多个组
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        # 计算k,q,v
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        energy = torch.einsum("nqhd,nkhd->nhqk",[queries, keys]) # 格式转化
        print("queries.shape:", queries.shape)
        print("keys.shape:", keys.shape)
        print("energy.shape:", energy.shape)
        
        if mask is not None:
            energy = energy.masked_fill(mask==0, float("-1e20"))
            
        attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3) # softmax内部是缩放点积
        print("attention.shape:", attention.shape)
        print("values.shape:", values.shape)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        out = self.fc_out(out)
        return out
embed_size = 512
heads = 8
attention = SelfAttention(embed_size, heads)

# batch size 1, seq length 60
values = torch.rand(1,60,embed_size)
keys = torch.rand(1,60,embed_size)
queries = torch.rand(1,60,embed_size)
mask = None # 假设没有mask

out = attention(values, keys, queries, mask)
print(out.shape)

# 输出
N: 1
queries.shape: torch.Size([1, 60, 8, 64])
keys.shape: torch.Size([1, 60, 8, 64])
energy.shape: torch.Size([1, 8, 60, 60])
attention.shape: torch.Size([1, 8, 60, 60])
values.shape: torch.Size([1, 60, 8, 64])
torch.Size([1, 60, 512])

通过这个程序,我们可以看到,自注意力机制是不改变输入和输出的形状的,输入的Q,K,V格式是[1,60,512],输出的结果的仍然是[1,60,512]。

下面是几点说明:

1. 这里的embed_size代表的是输入到自注意力层中的每个元素的向量维度。在Transformer模型中,输入数据首先会被转换成一个固定长度的向量,这个向量的长度就称为embed_size。

2. mask表示的是模型在处理序列数据时,应该忽略掉哪部分,我这里设置为None,也就是全部参与计算。

3. einsum,称为爱因斯坦求和,起源是爱因斯坦在研究广义相对论时,需要处理大量求和运算,为了简化这种繁复的运算,提出了求和约定,推动了张量分析的发展。einsum 可以计算向量、矩阵、张量运算,如果利用得当,sinsum可完全代替其他的矩阵计算方法。

例如,C = einsum('ij,jk->ik', A, B),就相当于两个矩阵求内积:cik = Σj AijBjk。

通过输出可以看到,在计算前queries的形状是[1,60,8,64],keys的形状是[1,60,8,64],在表达式"nqhd,nkhd->nhqk"中,n=1,q=60,h=8,d=64,k=60,两个矩阵进行内积,因此得到的结果是nhqk,也就是[1,8,60,60]。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/796841.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FFmpeg开发环境搭建

FFmpeg是音视频开发必备的库,也是唯一的库。本文主要讲解在ubuntu22和macOS14环境下的编译安装。 为什么要自己编译呢?其中一个很重要的原因就是ffmpeg在编译时可以加入很多插件,这种特定的库网络上可能找不到编译好的版本,另外如…

在linux中查找 / 目录下的以.jar结尾的文件(find / -name *.jar)

文章目录 1、查找 / 目录下的以.jar结尾的文件 1、查找 / 目录下的以.jar结尾的文件 [rootiZuf6332h890vozldoxcprZ ~]# find / -name *.jar /etc/java/java-1.8.0-openjdk/java-1.8.0-openjdk-1.8.0.342.b07-1.el9_0.x86_64/lib/security/policy/limited/US_export_policy.ja…

Spring MVC 中 HttpMessageConverter 转换器

1. Spring MVC 中 HttpMessageConverter 转换器 文章目录 1. Spring MVC 中 HttpMessageConverter 转换器2. 补充:什么是 HTTP 消息3. 转换器3.1 转换器转换的是什么 4. Spring MVC中的 AJAX 请求5. ResponseBody 将服务器端的 return 返回值转化为“字符串(JSON格式…

RIP路由协议概述

RIP【Routing Information Protocol】 RIP是为TCP/IP 环境中开发的第一个路由选择协议标准 RIP是一个【距离——矢量】路由选择协议 当路由器的更新周期为30s到来时,向邻居发送路由表 RIP以跳数作为唯一度量值 RIP工作原理 RIP路由协议向邻居发送整个路由表信息RI…

Linux 复现Docker NAT网络

Linux 复现Docker NAT网络 docker 网络的构成分为宿主机docker0网桥和为容器创建的veth 对构成。这个默认网络命名空间就是我们登陆后日常使用的命名空间 使用ifconfig命令查看到的就是默认网络命名空间,docker0就是网桥,容器会把docker0当成路由&…

JavaDS —— 单链表 与 LinkedList

顺序表和链表区别 ArrayList : 底层使用连续的空间,可以随机访问某下标的元素,时间复杂度为O(1) 但是在插入和删除操作的时候,需要将该位置的后序元素整体往前或者向后移动,时间复杂度为O&…

代码随想录二刷7.22|977.有序数组的平方

暴力解法: ——如果想暴力解决这个问题的话,可以像题目那样,先将每一个元素平方,然后再排序 双指针: ——从题目中找到的信息:这是一个非递减顺序的整数数组,从例子中,可以容易看…

excel 百分位函数 学习

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、函数说明PERCENTILE 函数PERCENTILE.inc 函数PERCENTILE.exc 函数QUARTILE.EXC 函数 二、使用步骤总结 前言 excel 百分位函数 Excel提供了几个函数用于…

数据库MySQL下载安装

MySQL下载安装地址如下: MySQL :: Download MySQL Community Server 1、下载界面 2、点击下载 3、解压记住目录 4、配置my.ini文件 未完..

vector 介绍

1.简述vector 首先我们要大致弄明白vector是一个什么东西,其实vector就是之前我们学过的顺序表,这里直接使用就行了. 定义vector-------->vector<typename> arr 此时的这种定义vector可以理解成为一个数组,而typename可以是各种数据类型,比如string,int,double....…

QT实现自定义带有提示信息的透明环形进度条

1. 概述 做界面开发的童鞋可能都会遇到这样的需求&#xff0c;就是有一些界面点击了之后比较耗时的操作&#xff0c;需要界面给出一个环形进度条的进度反馈信息. 如何来实现这样的需求呢&#xff0c;话不多说&#xff0c;上效果 透明进度条 2. 代码实现 waitfeedbackprogressba…

3-2 多层感知机的从零开始实现

import torch from torch import nn from d2l import torch as d2lbatch_size 256 # 批量大小为256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # load进来训练集和测试集初始化模型参数 回想一下&#xff0c;Fashion-MNIST中的每个图像由 28 28 784…

一文了解5G新通话技术演进与业务模型

5G新通话简介 5G新通话&#xff0c;也被称为VoNR&#xff0c;是基于R16及后续协议产生的一种增强型语音通话业务。 它在IMS网络里新增数据通道&#xff08;Data Channel&#xff09;&#xff0c;承载通话时的文本、图片、涂鸦、菜单等信息。它能在传统话音业务基础上提供更多服…

红日靶场----(三)1.漏洞利用

上期已经信息收集阶段已经完成&#xff0c;接下来是漏洞利用。 靶场思路 通过信息收集得到两个吧靶场的思路 1、http://192.168.195.33/phpmyadmin/&#xff08;数据库的管理界面&#xff09; root/root 2、http://192.168.195.33/yxcms/index.php?radmin/index/login&am…

练习 6.7:⼈们 在为练习 6.1 编写的程序中,再创建两个表⽰⼈的字典,然后将这三个字典都存储在⼀个名为 people 的列表中。

练习 6.7&#xff1a;⼈们 在为练习 6.1 编写的程序中&#xff0c;再创建两个表⽰⼈的字典&#xff0c;然后将这三个字典都存储在⼀个名为 people 的列表中。 要求 遍历这个列表&#xff0c;将其中每个⼈的所有信息都打印出来。 代码 human {shuicc: {first_name: shui,la…

linux下安装cutecom串口助手;centos安装cutecom串口助手;rpm安装包安装cutecom串口助手

在支持apt-get的系统下安装 在终端命令行中输入&#xff1a; sudo apt-get install cutecom 安装好后输入 sudo cutecom 就可以了 关于如何使用&#xff0c;可以看这个https://www.cnblogs.com/xingboy/p/14388610.html 如果你的电脑不支持apt-get。 那我们就通过安装包…

druid(德鲁伊)数据线程池连接MySQL数据库

文章目录 1、druid连接MySQL2、编写JDBCUtils 工具类 1、druid连接MySQL 初学JDBC时&#xff0c;连接数据库是先建立连接&#xff0c;用完直接关闭。这就需要不断的创建和销毁连接&#xff0c;会消耗系统的资源。 借鉴线程池的思想&#xff0c;数据连接池就这么被设计出来了。…

【Unity2D 2022:NPC】制作任务系统

一、接受任务 1. 编辑NPC对话脚本&#xff1a; &#xff08;1&#xff09;创建静态布尔变量用来判断ruby是否接受到任务 public class NPCDialog : MonoBehaviour {// 创建全局变量用来判断ruby是否接到任务public static bool receiveTask false; } &#xff08;2&#xff…

git自动pull同步远程若干分支与本地若干分支

git自动pull同步远程若干分支与本地若干分支 假设远程代码仓库有100个分支&#xff0c;而本地只有10个本地分支与远程分支一一对应&#xff0c;现在要保持本地的这个10个分支与远程一致&#xff0c;最笨的方法是checkout到每个分支&#xff0c;然后一个一个的 git pull origin…

展望未来:在【PyCharm】中结合【机器学习】实现高效的图形化处理

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 文章目录 引言一、PyCharm简介与配置1.1 PyCharm基础1.2 图形化库配置 二、机器学习项目中的数据可视化2.1 数据加载与预处理2.2 数据探索性可视化2.3 模型训练与结果可视化 三、PyCharm中的图形化调试四、高级…