24/11/12 算法笔记<强化学习> 自注意力机制

自注意力机制(Self-Attention Mechanism),也称为内部注意力机制,是一种在深度学习模型中,特别是在自然语言处理(NLP)和计算机视觉领域中广泛使用的机制。它允许模型在处理序列数据时,能够动态地聚焦于序列的不同部分,从而捕捉到序列内部的长距离依赖关系。

自注意力机制的核心思想是,序列中的每个元素都与其他所有元素相关,模型需要学习如何根据上下文信息来分配不同的注意力权重。这种机制最早在Transformer模型中被提出,并在随后的研究中被广泛应用于各种任务。

自注意力机制的工作原理

  1. 输入表示:模型首先将输入序列(如句子或图像)转换为一系列向量表示,这些向量通常通过嵌入层(Embedding Layer)得到。

  2. 查询(Query)、键(Key)和值(Value):对于序列中的每个元素,模型会生成三个向量:查询(Q)、键(K)和值(V)。在原始的Transformer模型中,这些向量是通过输入向量与三个不同的权重矩阵相乘得到的。

  3. 计算注意力分数:模型计算每个查询向量与所有键向量之间的相似度或匹配程度,得到一个注意力分数矩阵。这个分数矩阵通常通过点积(Dot Product)或缩放点积(Scaled Dot-Product)得到。

  4. 应用 softmax 函数:注意力分数矩阵通过softmax函数进行归一化,使得每一行的和为1。这样,每个查询向量都会得到一个概率分布,表示对其他元素的注意力权重。

  5. 加权和:每个查询向量根据学到的权重,对所有值向量进行加权求和,得到最终的输出向量。

  6. 输出:自注意力层的输出可以是序列中的每个元素对应的加权和向量,这些向量可以被用作后续任务的输入,如分类、翻译等。

多头自注意力

为了捕捉不同子空间中的信息,Transformer模型引入了多头自注意力机制。在多头自注意力中,模型并行地执行多次自注意力操作,每个“头”使用不同的权重矩阵来生成查询、键和值。最后,所有头的输出被拼接在一起,并通过一个线性层进行处理,以产生最终的输出。

自注意力机制的优势在于其能够处理序列数据中的长距离依赖,并且不受传统循环神经网络(RNN)中序列长度的限制。此外,由于其并行化的特性,自注意力模型通常比RNN模型训练得更快。

我们来看下它的代码

1.导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.nn模块用于构建神经网络层和初始化参数,torch.nn.functional包含了各种不带有权重的函数式接口。
2.定义自注意力类

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

embed_size表示嵌入向量的维度,heads表示注意力头的数量。head_dim是每个头的维度,它通过将embed_size除以heads得到。

3.检查嵌入维度是否能被头数整除

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

这里使用断言语句来确保嵌入维度可以被头数整除,这是实现多头注意力机制的前提条件。

4.初始化线性层

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

这部分代码初始化了四个线性层(全连接层),分别用于计算值(values)、键(keys)、查询(queries)和输出。由于我们使用的是多头注意力机制,所以需要将输入的嵌入向量分割成多个头,每个头都有自己的线性层。最后一个线性层fc_out用于将多头的输出合并回原始的嵌入维度。

5.前向传播

    def forward(self, value, key, query):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

forward方法中,我们首先获取输入张量的形状信息。N是批大小,value_lenkey_lenquery_len分别是值、键和查询序列的长度。

6.分割嵌入向量

        values = self.values(value).view(N, value_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

这里我们使用线性层处理输入的值、键和查询,并将结果分割成多个头。view方法用于重新塑形张量,以适应多头注意力机制的需要。

  1. 通过.view()方法,将变换后的数据重塑为一个新的形状。这里的新形状是(N, value_len, self.heads, self.head_dim),其中:

    • N是批次大小(batch size)。
    • value_lenkey_lenquery_len分别是值、键和查询序列的长度。
    • self.heads是注意力头的数量。
    • self.head_dim是每个头的维度。

7.调整张量维度以适应多头注意力

        values = values.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        queries = queries.permute(0, 2, 1, 3)

通过permute方法,我们调整张量的维度

8.计算注意力分数

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

使用torch.einsum计算查询和键之间的点积,得到注意力分数。然后,我们对这些分数应用softmax函数,以获得每个头的注意力权重。除以embed_size的平方根是缩放点积的一种常见做法,有助于稳定训练过程

9.应用注意力权重

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).permute(0, 2, 1, 3)

这里我们再次使用torch.einsum将注意力权重应用到值上,得到加权的值。然后,我们调整张量的维度,以便于后续的合并操作。

10.合并多头注意力

        out = out.reshape(N, query_len, self.heads * self.head_dim)
        out = self.fc_out(out)

11.使用自注意力机制

embed_size = 256  # 嵌入向量的维度
heads = 8  # 注意力头的数量

value = torch.rand(32, 10, embed_size)
key = torch.rand(32, 10, embed_size)
query = torch.rand(32, 10, embed_size)

attention = SelfAttention(embed_size, heads)
output = attention(value, key, query)

print(output.shape)  # 应该输出:torch.Size([32, 10, 256])

这部分代码展示了如何使用上面定义的SelfAttention类。我们创建了一个SelfAttention实例,并传入随机生成的值、键和查询张量。然后,我们打印输出张量的形状,以验证其正确性。

训练量小的化cnn好,大的化self-attention好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/916529.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前后端交互之动态列

一. 情景 在做项目时,有时候后会遇到后端使用了聚合函数,导致生成的对象的属性数量或数量不固定,因此无法建立一个与之对应的对象来向前端传递数据,这时可以采用NameDataListVO向前端传递数据。 Data Builder AllArgsConstructo…

k8s服务内容滚动升级以及常用命令介绍

查看K8S集群所有的节点信息 kubectl get nodes 删除K8S集群中某个特定节点 kubectl delete nodes/10.0.0.123 获取K8S集群命名空间 kubectl get namespace 获取K8S所有命名空间的那些部署 kubectl get deployment --all-namespaces 创建命名空间 web界面上看到的效果,但是…

【视觉SLAM】1-概述

读书笔记 文章目录 1. 经典视觉SLAM框架2. 数学表述2.1 运动方程2.2 观测方程2.3 问题抽象 1. 经典视觉SLAM框架 传感器信息读取:相机图像、IMU等多源数据;前端视觉里程计(Visual Odometry,VO):估计相机的相…

低成本出租屋5G CPE解决方案:ZX7981PG/ZX7981PM WIFI6千兆高速网络

刚搬进新租的房子,没有网络,开个热点?续航不太行。随身WIFI?大多是百兆级网络。找人拉宽带?太麻烦,退租的时候也不能带着走。5G CPE倒是个不错的选择,插入SIM卡就能直接连接5G网络,千…

如何在Typora中绘制流程图

如何在Typora中绘制流程图 在撰写文档时,清晰的流程图能极大地提升信息传递的效率。Typora是一款优秀的Markdown编辑器,支持通过Mermaid语法快速绘制流程图。本文将介绍如何在Typora中创建和自定义流程图,帮助你用更直观的方式呈现逻辑结构和…

莱特币转型MEME币:背后隐含的加密市场现象

随着加密市场的风云变幻,莱特币(LTC)这款曾经的“老牌矿币”近日以自嘲式推文宣布“自己是一个MEME币”,迅速引发了市场的广泛关注和一波围绕MEME币的炒作浪潮。这一举动看似玩笑,却反映出当前加密市场的一种微妙转变&…

【代码大模型】Is Your Code Generated by ChatGPT Really Correct?论文阅读

Is Your Code Generated by ChatGPT Really Correct? Rigorous Evaluation of Large Language Models for Code Generation key word: evaluation framework, LLM-synthesized code, benchmark 论文:https://arxiv.org/pdf/2305.01210.pdf 代码:https:…

LC12:双指针

文章目录 125. 验证回文串 本专栏记录以后刷题碰到的有关双指针的题目。 125. 验证回文串 题目链接:125. 验证回文串 这是一个简单题目,但条件判断自己写的时候写的过于繁杂。后面参考别人写的代码,首先先将字符串s利用s.toLowerCase()将其…

MySQL5.7.37安装配置

1.下载MySQL软件包并解压 2.配置环境变量 3.新建my.ini文件并输入信息 [mysqld] #端口号 port 3306 #mysql-5.7.27-winx64的路径 basedirC:\mysql-5.7.37\mysql-5.7.37-winx64 #mysql-5.7.27-winx64的路径\data datadirC:\mysql-5.7.37\mysql-5.7.37-winx64\data #最大连接数…

python习题4

1 判断车牌归属地 输入一串车牌号,按e结束,判断车牌归属于那里 例如: 输入: jingA12345 huB34567 zheA99999 e 输出: jing hu zhe chepai input(请输入车牌号:\n) lst [] while chepai ! e:lst…

【原创】java+ssm+mysql社区疫情防控管理系统设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

《深度学习》VGG网络

文章目录 1.VGG的网络架构2.案例:手写数字识别 学习目标: 知道VGG网络结构的特点能够利用VGG网络完成图像分类 2014年,⽜津⼤学计算机视觉组(Visual Geometry Group)和GoogleDeepMind公司的研究员⼀起研发出了新的深度…

探索 Python HTTP 的瑞士军刀:Requests 库

文章目录 探索 Python HTTP 的瑞士军刀:Requests 库第一部分:背景介绍第二部分:Requests 库是什么?第三部分:如何安装 Requests 库?第四部分:Requests 库的基本函数使用方法第五部分&#xff1a…

无桥Boost-PFC 双闭环控制MATLAB仿真

一、无桥Boost-PFC原理概述 无桥 Boost-PFC(Power Factor Correction,功率因数校正)的工作原理是通过特定的电路结构和控制策略,对输入电流进行校正,使其与输入电压同相位,从而提高电路的功率因数&#xf…

数据结构Python版

2.3.3 双链表 双链表和链表一样,只不过每个节点有两个链接——一个指向后一个节点,一个指向前一个节点。此外,除了第一个节点,双链表还需要记录最后一个节点。 每个结点为DLinkNode类对象,包括存储元素的列表data、…

力扣-Hot100-二叉树其一【算法学习day.32】

前言 ###我做这类文档一个重要的目的还是给正在学习的大家提供方向(例如想要掌握基础用法,该刷哪些题?)我的解析也不会做的非常详细,只会提供思路和一些关键点,力扣上的大佬们的题解质量是非常非常高滴&am…

京东商品详情,Python爬虫的“闪电战”

在这个数字化的时代,我们每天都在和数据打交道,尤其是电商数据。想象一下,你是一名侦探,需要快速获取京东上某个商品的详细信息,但是没有超能力,怎么办?别担心,Python爬虫来帮忙&…

深度学习推荐系统的工程实现

参考自《深度学习推荐系统》——王喆,用于学习和记录。 介绍 之前章节主要从理论和算法层面介绍了推荐系统的关键思想。但算法和模型终究只是“好酒”,还需要用合适的“容器”盛载才能呈现出最好的味道,这里的“容器”指的就是实现推荐系统…

「QT」高阶篇 之 d-指针 的用法

✨博客主页何曾参静谧的博客📌文章专栏「QT」QT5程序设计📚全部专栏「Win」Windows程序设计「IDE」集成开发环境「UG/NX」BlockUI集合「C/C」C/C程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「UG/NX」NX定制…

ISUP协议视频平台EasyCVR视频设备轨迹回放平台智慧农业视频远程监控管理方案

在当今快速发展的农业领域,智慧农业已成为推动农业现代化、助力乡村全面振兴的新手段和新动能。随着信息技术的持续进步和城市化进程的加快,智慧农业对于监控安全和智能管理的需求日益增长。 视频设备轨迹回放平台EasyCVR作为智慧农业视频远程监控管理方…