自然语言处理: 第五章Attention注意力机制

自然语言处理: 第五章Attention注意力机制

理论基础

Attention(来自2017年google发表的[1706.03762] Attention Is All You Need (arxiv.org) ),顾名思义是注意力机制,字面意思就是你所关注的东西,比如我们看到一个非常非常的故事的时候,但是其实我们一般能用5W2H就能很好的归纳这个故事,所以我们在复述或者归纳一段文字的时候,我们肯定有我们所关注的点,这些关注的点就是我们的注意力,而类似How 或者when 这种不同的形式就成为了Attention里的多头的机制。 下图是引自GPT3.5对注意力的一种直观的解释,简而言之其实就是各种不同(多头)我们关注的点(注意力)构成了注意力机制,这个奠定现代人工智能基石的基础。
在这里插入图片描述



那么注意力机制的优点是什么呢? (下面的对比是相对于上一节的Seq2Seq模型)

  1. 解决了长距离依赖问题,由于Seq2Seq模型一般是以时序模型eg RNN / Lstm / GRU 作为基础, 所以就会必然导致模型更倾向新的输入 – 多头注意力机制允许模型在解码阶段关注输入序列中的不同部分
  2. 信息损失:很难将所有信息压缩到一个固定长度的向量中(encorder 输出是一个定长的向量) – 注意力机制动态地选择输入序列的关键部分
  3. 复杂度和计算成本:顺序处理序列的每个时间步 – 全部网络都是以全连接层或者点积操作,没有时序模型
  4. 对齐问题:源序列和目标序列可能存在不对齐的情况 – 注意力机制能够为模型提供更精细的词汇级别对齐信


注意力可以拆解成下面6个部分,下面会在代码实现部分逐个解释

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-npjJuAYN-1689687821801)(image/06_attention/1689603947092.png)]


  1. 缩放点积注意力

两个向量相乘可以得到其相似度, 一般常用的是直接点积也比较简单,原论文里还提出里还提出了下面两种计算相似度分数的方式, 也可以参考下图。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sAngeG15-1689687821802)(image/06_attention/1689604308442.png)]



其实相似度分数,直观理解就是两个向量点积可以得到相似度的分数,加权求和得到输出。[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-so4FVaol-1689687821802)(image/06_attention/1689604224729.png)]
在这里插入图片描述



然后就是对点积注意力的拆解首先我们要明确目标我们要求解的是X1关于X2的注意力输出,所以首先需要确定的是X1 和 X2 的特征维度以及batch_size肯定要相同,然后是seq长度可以不同。 然后我们计算原始注意力权重即X1(N , seq1 , embedding) · X2(N , seq2 , embedding) -> attention( N , seq1 , seq2) , 可以看到我们得到了X1中每个单词对X2中每个单词的注意力权重矩阵所以维度是(seq1 , seq2)。

当特征维度尺寸比较大时,注意力的值会变得非常大,从而导致后期计算softmax的时候梯度消失,所以这里会对注意力的值做一个缩小,也就是将点积的结果/scaler_factor, 这个scaler_factor一般是embedding_size 的开根号。

然后我们对X2的维度做一个softmax 得到归一化的注意力权重,至于为什么是X2,是因为我们计算的是X1关于X2的注意力,所以在下一步我会会让整个attention 权重与X2做点积也就是加权求和,这里需要把X2对应的权重做归一化所以要对X2的权重做归一化。

也就是X1 与 X2 之间相互每个单词的相似度(score) 因为我们求的是X1 关于X2的注意力,所以最后我们将归一化后的权重与X2做一个加权求和(点积)即Attention_scaled(N , seq1 , seq2) · X2(N , seq2 , embedding) -> X1_attention_X2( N , seq1 , embedding) 这个时候我们可以看到最后的输出与X1的维度相同,但是里面的信息已经是整合了X2的信息的的X1’。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2cDP4HFU-1689687821802)(image/06_attention/1689604378729.png)]



  1. 编解码注意力
    这个仅存在Seq2Seq的架构中,也就是将编码器的最后输出与解码器的隐藏状态相互结合,下图进行了解释可以看到encoder将输入的上下文进行编码后整合成一个context 向量,由于我们最后的输出是decoder 所以这里X1 是解码器的隐层状态,X2 是编码器的隐层状态。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ti3ZkUEp-1689687821803)(image/06_attention/1689684006567.png)]



4. QKV

下面介绍的就是注意力中一个经常被弄混的概念,QKV, 根据前面的只是其实query 就是X1也就是我们需要查询的目标, Key 和 Value 也就是X2,只是X2不同的表现形式,K 可以等于 V 也可以不等,上面的做法都是相等的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HiChaI3W-1689687821803)(image/06_attention/1689684377799.png)]

  1. 自注意力

    最后就是注意力机制最核心的内容,也就是自注意力机制,那么为什么多了一个自呢?其实就是X1 = X2 ,换句话说就是自己对自己做了一个升华,文本对自己的内容做了一个类似summary的机制,得到了精华。就如同下图一样,自注意力的QKV 都来自同一个输入X向量,最后得到的X’, 它是自己整合了自己全部信息向量,它读完了自己全部的内容,并不只是单独的一个字或者一段话,而是去其糟粕后的X。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VwoyKvPH-1689687821804)(image/06_attention/1689684493893.png)]

而多头自注意力也就是同样的X切分好几个QKV, 可以捕捉不同的重点(类似一个qkv捕捉when , 一个qkv捕捉how),所以多头是有助于网络的表达,然后这里需要注意的是多头是将embedding - > (n_head , embedding // n_head ) , 不是(n_head , embedding)。

所以多头中得到的注意力权重的shape也会变成(N , n_head , seq , seq ) 这里由于是自注意力 所以seq1 = seq2 = seq 。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5GkFKSvL-1689687821804)(image/06_attention/1689684868618.png)]



代码实现

这里只介绍了核心代码实现, 下面是多头注意力的实现:

import torch.nn as nn # 导入torch.nn库
# 创建一个Attention类,用于计算注意力权重
class Mult_attention(nn.Module):
    def __init__(self, n_head):
        super(Mult_attention, self).__init__()
        self.n_head = n_head

    def forward(self, decoder_context, encoder_context , dec_enc_attn_mask): # decoder_context : x1(q) , encoder_context : x2(k , v)
        # print(decoder_context.shape, encoder_context.shape) # X1(N , seq_1 , embedding) , X2(N , seq_2 , embedding)
        # 进行切分 , (N , seq_len_X , emb_dim) -> (N , num_head , seq_len_X , head_dim) / head_dim * num_head  = emb_dim
        Q = self.split_heads(decoder_context)  # X1
        K = self.split_heads(encoder_context)  # X2
        V = self.split_heads(encoder_context)  # X2
        # print(Q.shape , 0)

        # 将注意力掩码复制到多头 attn_mask: [batch_size, n_heads, len_q, len_k]
        attn_mask = dec_enc_attn_mask.unsqueeze(1).repeat(1, self.n_head, 1, 1)

        # 计算decoder_context和encoder_context的点积,得到多头注意力分数,其实就是在原有的基础上多加一个尺度
        scores = torch.matmul(Q, K.transpose(-2, -1)) # -> (N , num_head , seq_len_1 , seq_len_2)
        scores.masked_fill_(attn_mask , -1e9)
        # print(scores.shape ,1 )

        # 自注意力原始权重进行缩放
        scale_factor = K.size(-1) ** 0.5
        scaled_weights = scores / scale_factor # -> (N , num_head , seq_len_1 , seq_len_2)
        # print(scaled_weights.shape , 2)

        # 归一化分数
        attn_weights = nn.functional.softmax(scaled_weights, dim=-1) # -> (N , num_head , seq_len_1 , seq_len_2)
        # print(attn_weights.shape , 3)

        # 将注意力权重乘以encoder_context,得到加权的上下文向量
        attn_outputs = torch.matmul(attn_weights, V) # -> (N , num_head , seq_len_1 , embedding // num_head)
        # print(attn_outputs.shape , 4) 

        # 将多头合并下(output  & attention)
        attn_outputs  = self.combine_heads(attn_outputs) # 与Q的尺度是一样的
        attn_weights = self.combine_heads(attn_weights)
        # print(attn_weights.shape , attn_outputs.shape , 5) #
        return attn_outputs, attn_weights
  
    # 将所有头的结果拼接起来,就是把n_head 这个维度去掉,
    def combine_heads(self , tensor):
        # print(tensor.size())
        batch_size, num_heads, seq_len, head_dim = tensor.size()
        feature_dim = num_heads * head_dim
        return tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)


    def split_heads(self , tensor):
        batch_size, seq_len, feature_dim = tensor.size()
        head_dim = feature_dim // self.n_head
        # print(tensor.shape, head_dim , self.n_head)
        return tensor.view(batch_size, seq_len, self.n_head, head_dim).transpose(1, 2)


多头注意力的解码器,这里添加了mask机制。

class DecoderWithMutliHeadAttention(nn.Module):
    def __init__(self, hidden_size, output_size , n_head):
        super(DecoderWithMutliHeadAttention, self).__init__()
        self.hidden_size = hidden_size # 设置隐藏层大小
        self.n_head = n_head # 多头
        self.embedding = nn.Embedding(output_size, hidden_size) # 创建词嵌入层
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 创建RNN层
        self.multi_attention = Mult_attention(n_head = n_head)
        self.out = nn.Linear(2 * hidden_size, output_size)  # 修改线性输出层,考虑隐藏状态和上下文向量


    def forward(self, inputs, hidden, encoder_outputs , encoder_input):
        embedded = self.embedding(inputs)  # 将输入转换为嵌入向量
        rnn_output, hidden = self.rnn(embedded, hidden)  # 将嵌入向量输入RNN层并获取输出 
        dec_enc_attn_mask = self.get_attn_pad_mask(inputs, encoder_input) # 解码器-编码器掩码
        context, attn_weights = self.multi_attention(rnn_output, encoder_outputs , dec_enc_attn_mask)  # 计算注意力上下文向量
        output = torch.cat((rnn_output, context), dim=-1)  # 将上下文向量与解码器的输出拼接
        output = self.out(output)  # 使用线性层生成最终输出
        return output, hidden, attn_weights
  
    def get_attn_pad_mask(self , seq_q, seq_k):
        #-------------------------维度信息--------------------------------
        # seq_q 的维度是 [batch_size, len_q]
        # seq_k 的维度是 [batch_size, len_k]
        #-----------------------------------------------------------------
        # print(seq_q.size(), seq_k.size())
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        # 生成布尔类型张量[batch_size,1,len_k(=len_q)]
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  #<PAD> Token的编码值为0 
        # 变形为何注意力分数相同形状的张量 [batch_size,len_q,len_k]
        pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k)
        #-------------------------维度信息--------------------------------
        # pad_attn_mask 的维度是 [batch_size,len_q,len_k]
        #-----------------------------------------------------------------
        return pad_attn_mask # [batch_size,len_q,len_k]


结果

整体实验结果如下,可能是因为整体语料库太小了,所以翻译结果不是太好,但是多头注意力机制还是都跑通了:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B9v9o7Ep-1689687821804)(image/06_attention/1689687316028.png)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/42539.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【计算机组成原理】24王道考研笔记——第二章 数据的表示和运算

第二章 数据的表示和运算 一、数值与编码 1.1 进制转换 任意进制->十进制&#xff1a; 二进制<->八进制、十六进制&#xff1a; 各种进制的常见书写方式&#xff1a; 十进制->任意进制&#xff1a;&#xff08;用拼凑法最快&#xff09; 真值&#xff1a;符合人…

嵌入式软件和硬件的安全性:保护连接世界的数字盾牌

引言&#xff1a; 随着嵌入式系统的广泛应用和物联网的快速发展&#xff0c;嵌入式软件和硬件的安全性问题越来越引起人们的关注。安全性是确保嵌入式系统能够抵御恶意攻击和数据泄露的关键。本文将深入探讨嵌入式软件和硬件的安全性问题&#xff0c;包括技术原理、应用场景、学…

Edge 中比较独特的调试技巧

背景 大家日常开发基本都会使用 Chrome&#xff0c;毕竟确实好用。但是基于 Chromium 的新版 Microsoft Edge 已于 2020 年 1 月 15 日发布。 Edge 目前的使用基本跟 Chrome 差不多了&#xff0c;但显然&#xff0c;Edge 团队不仅仅想当 Chrome 的备用。他们也提供了一些特有…

【C++】-模板进阶(让你更好的使用模板创建无限可能)

&#x1f496;作者&#xff1a;小树苗渴望变成参天大树&#x1f388; &#x1f389;作者宣言&#xff1a;认真写好每一篇博客&#x1f4a4; &#x1f38a;作者gitee:gitee✨ &#x1f49e;作者专栏&#xff1a;C语言,数据结构初阶,Linux,C 动态规划算法&#x1f384; 如 果 你 …

MySQL 8.0 OCP (1Z0-908) 考点精析-性能优化考点6:MySQL Enterprise Monitor之Query Analyzer

文章目录 MySQL 8.0 OCP (1Z0-908) 考点精析-性能优化考点6&#xff1a;MySQL Enterprise Monitor之Query AnalyzerMySQL Enterprise Monitor之Query AnalyzerQuery Response Time index (QRTi)例题例题1: Query Analyzer答案与解析1 参考 【免责声明】文章仅供学习交流&#x…

【本地电脑搭建Web服务器并用cpolar发布至公网

本地电脑搭建Web服务器并用cpolar发布至公网访问 随着互联网的快速发展&#xff0c;网络也成为我们生活中不可缺少的必要条件&#xff0c;为了能在互联网世界中有自己的一片天地&#xff0c;建立一个属于自己的网页就成为很多人的选择。但互联网行业作为资本密集的行业&#x…

一)Stable Diffusion使用教程:安装

目前AI绘画最火的当属Midjorney和Stable Diffusion&#xff0c;但是由于Midjourney没有开源&#xff0c;因此我们主要分享下Stable Diffusion&#xff0c;后面有望补上Midjourney教程。 本节主要讲解Stable Diffusion&#xff08;以下简述SD&#xff09;的下载和安装。 1&…

呼吸灯——FPGA

文章目录 前言一、呼吸灯是什么&#xff1f;1、介绍2、占空比调节示意图 二、系统设计1、系统框图2、RTL视图 三、源码四、效果五、总结六、参考资料 前言 环境&#xff1a; 1、Quartus18.0 2、vscode 3、板子型号&#xff1a;EP4CE6F17C8 要求&#xff1a; 将四个LED灯实现循环…

Redis源码篇 - Ziplist数据结构

Ziplist是一种内存优化的list存储结构&#xff0c;通过使用连续的内存空间存储&#xff0c;来减少内存碎片化&#xff0c;同时和链表的不同还有&#xff0c;它不存储前后指针&#xff0c;而是通过变长的字节存储前节点元素长度&#xff0c;通过计算长度来实现节点的查找。它是一…

Google 登录支付,Firebase 相关设置

登录sdk: https://developers.google.com/identity/sign-in/android/start?hlzh-cn 支付sdk: https://developers.google.com/pay/api/android/overview?hlzh-cn Firebase sdk: https://firebase.google.com/docs/android/setup?hlzh-cn 登录设置&#xff1a; 创建凭据&…

机器学习-线性代数-5-空间中的向量投影与最小二乘法

空间中的向量投影与最小二乘法 文章目录 空间中的向量投影与最小二乘法一、引入二、投影和投影的描述1、投影描述最近2、利用矩阵描述投影(1)向一维直线投影(2)向二维平面投影(3)向n维子空间投影的一般情况 三、最小二乘法1、重要的子空间(1)互补的子空间(2)正交的子空间(3)相互…

12.面板问题

面板问题 html部分 <h1>Lorem ipsum dolor sit, amet consectetur adipisicing.</h1><div class"container"><div class"faq"><div class"title-box"><h3 class"title">Lorem, ipsum dolor.<…

(转载)神经网络遗传算法函数极值寻优(matlab实现)

本博客的完整代码获取&#xff1a; https://www.mathworks.com/academia/books/book106283.html 1案例背景 对于未知的非线性函数,仅通过函数的输入输出数据难以准确寻找函数极值。这类问题可以通过神经网络结合遗传算法求解,利用神经网络的非线性拟合能力和遗传算法的非线性…

make/makefile的使用

make/makefile 文章目录 make/makefile初步认识makefile的工作流程依赖关系和依赖方法make的使用 总结 make是一个命令&#xff0c;是一个解释makefile中指令的命令工具&#xff0c;makefile是一个文件&#xff0c;当前目录下的文件&#xff0c;两者搭配使用&#xff0c;完成项…

数据预处理matlab

matlab数据的获取、预处理、统计、可视化、降维 数据的预处理 - MATLAB & Simulink - MathWorks 中国https://ww2.mathworks.cn/help/matlab/preprocessing-data.html 一、数据的获取 1.1 从Excel中获取 使用readtable() 例1&#xff1a; 使用spreadsheetImportOption…

【AutoSAR 架构介绍】

AutoSAR简介 AUTOSAR是Automotive Open System Architecture&#xff08;汽车开放系统架构&#xff09;的首字母缩写&#xff0c;是一家致力于制定汽车电子软件标准的联盟。 AUTOSAR是由全球汽车制造商、部件供应商及其他电子、半导体和软件系统公司联合建立&#xff0c;各成…

npm link 实现全局运行package.json中的指令

packages.json "name":"testcli","bin": {"itRun": "index.js"},执行命令 npm link如果要解绑定 npm unlink testcli 现在你可以输入 itRun试一下

MySQL高阶语句之二

目录 一、子查询 1.1语法 1.2select 1.3insert 1.3update 1.4delete 1.5 exists 1.6别名as 二、MySQL视图 2.1功能 2.2区别 2.3联系 2.4 创建视图(单表) 2.5 创建视图(多表) 2.6修改原表数据 2.7修改视图数据 三、NULL值 四、连接查询 4.1内连接 4.1.1语法 4.1.…

LangChain+LLM大模型问答能力搭建与思考

1. 背景 最近&#xff0c;大模型&#xff08;LLMs&#xff0c;Large Language Models&#xff09;可谓是NLP领域&#xff0c;甚至整个科技领域最火热的技术了。凑巧的是&#xff0c;我本人恰好就是NLP算法工程师&#xff0c;面临着被LLMs浪潮淘汰的窘境&#xff0c;决定在焦虑…

配置jenkins 服务器与目标服务器自动化部署

在配置完远程构建后可以通过添加post-build step 执行shell脚本的方式将包传到远程服务器等一系列操作。 通过scp传输打包好的项目到目标服务器 按照链接 方式配置免密操作&#xff0c;需要注意的是要在jenkins 用户目录下配置生成私钥密钥&#xff0c;配置jenkins 的免密&…