【自然语言处理】多头注意力Multi-Head Attention机制

多头注意力(Multi-Head Attention)机制是Transformer模型中的一个关键组件,广泛用于自然语言处理任务(如机器翻译、文本生成等)以及图像处理任务。它的核心思想是通过多个不同的注意力头来捕获输入的不同特征,从而提高模型的表现力。以下是详细的解释:

一、多头注意力机制(Multi-Head Attention)

多头注意力机制是对单个注意力机制(详见【模型】Self-Attention)的扩展,它允许模型从多个角度“看”待输入数据。

具体来说,多头注意力机制通过以下步骤进行:

  1. 线性映射:首先,对查询 Q、键 K 和值 V 通过不同的线性变换(矩阵乘法),将它们分别投影到 h 个不同的子空间(即不同的头)。假设有 h 个注意力头,每个头的维度是 dk。

对于每个头 i,分别计算:
在这里插入图片描述其中 WiQ、WiK​、WiV 是头 i 的投影矩阵;
查询 Q、键 K 和值 V从输入中通过线性变换生成,X 是输入序列,WQ、WK和WV分别是Q、K、V 的权重矩阵。
在这里插入图片描述

  1. 并行计算多个注意力:对于每个头,分别使用缩放点积注意力机制来计算注意力输出。
    在这里插入图片描述

  2. 拼接注意力头的输出:将所有 h 个注意力头的输出拼接起来,形成一个大向量。
    在这里插入图片描述

  3. 线性变换:将拼接后的结果通过另一个线性变换 WO 进行投影,得到最终的多头注意力输出。
    在这里插入图片描述
    投影矩阵 WiQ​, WiK, WiV​ 以及用于拼接头输出后的权重矩阵 WO,最初都是随机初始化的,然后通过训练逐渐学习得到。

二、多头注意力的优点

  • 捕获不同的特征:每个注意力头都可以关注输入的不同部分,从而捕获更多元的信息,提升模型的表示能力。
  • 并行计算:多个注意力头可以并行计算,因此在计算效率上比单个注意力机制更高效。
  • 扩展表示能力:通过多个头,模型能够学习到更加复杂的关系,适合处理更复杂的任务。

三、在 Transformer 中的应用

在 Transformer 模型中,多头注意力机制被广泛应用在**编码器(Encoder)解码器(Decoder)**中。具体来说:

  • 自注意力(Self-Attention):输入序列中的每个元素都可以与序列中的其他元素建立联系。多头自注意力允许每个位置“查看”其他位置的信息。
  • 编码器-解码器注意力(交叉注意力):在解码器部分,模型通过多头注意力机制来“关注”编码器的输出,从而生成目标序列。

使用pytorch实现Multi-Head Attention 在机器翻译中的简单示例:

mport torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 英语和法语的词汇表
english_vocab = {"<pad>": 0, "i": 1, "am": 2, "a": 3, "student": 4, "he": 5, "is": 6, "teacher": 7,
                 "she": 8, "loves": 9, "apples": 10, "we": 11, "are": 12, "friends": 13}
french_vocab = {"<pad>": 0, "je": 1, "suis": 2, "un": 3, "étudiant": 4, "il": 5, "est": 6,
                "professeur": 7, "elle": 8, "aime": 9, "les": 10, "pommes": 11, "nous": 12, "sommes": 13, "amis": 14}

# 翻转字典以便通过索引查找单词
english_idx2word = {i: w for w, i in english_vocab.items()}
french_idx2word = {i: w for w, i in french_vocab.items()}

# 数据对
pairs = [
    ["i am a student", "je suis un étudiant"],
    ["he is a teacher", "il est un professeur"],
    ["she loves apples", "elle aime les pommes"],
    ["we are friends", "nous sommes amis"]
]

# Tokenization: 将句子转为词汇索引
def tokenize_sentence(sentence, vocab):
    return [vocab[word] for word in sentence.split()]

# 将句子对转为索引表示
tokenized_pairs = [(tokenize_sentence(p[0], english_vocab), tokenize_sentence(p[1], french_vocab)) for p in pairs]

# Padding函数
def pad_sequence(seq, max_len):
    return seq + [0] * (max_len - len(seq))

# 获取最大序列长度
max_len_src = max([len(pair[0]) for pair in tokenized_pairs])
max_len_tgt = max([len(pair[1]) for pair in tokenized_pairs])

# Padding后的输入和目标序列
src_sentences = torch.tensor([pad_sequence(pair[0], max_len_src) for pair in tokenized_pairs])
tgt_sentences = torch.tensor([pad_sequence(pair[1], max_len_tgt) for pair in tokenized_pairs])

# 模型超参数
d_model = 16   # 词嵌入维度
num_heads = 2  # 注意力头的数量
d_k = d_model // num_heads  # 每个头的维度

# 词嵌入层
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(EmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x)

# Multi-Head Attention 实现
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.out_proj = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Q * K^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output

    def forward(self, Q, K, V):
        batch_size = Q.size(0)

        # 线性变换后,拆分成多个 heads
        Q = self.q_linear(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_linear(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_linear(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 每个头分别计算注意力
        attention_output = self.scaled_dot_product_attention(Q, K, V)

        # 拼接所有头的输出
        concat_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        # 最终的线性变换
        output = self.out_proj(concat_output)
        return output

class Encoder(nn.Module):
    def __init__(self, src_vocab_size, d_model, num_heads):
        super(Encoder, self).__init__()
        self.embedding = EmbeddingLayer(src_vocab_size, d_model)
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, src):
        src_embedded = self.embedding(src)
        attention_output = self.attention(src_embedded, src_embedded, src_embedded)
        output = self.fc(attention_output)
        return output

class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, d_model, num_heads):
        super(Decoder, self).__init__()
        self.embedding = EmbeddingLayer(tgt_vocab_size, d_model)
        self.self_attention = MultiHeadAttention(d_model, num_heads)  # 自注意力
        self.cross_attention = MultiHeadAttention(d_model, num_heads)  # 编码器-解码器的交叉注意力
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, tgt, encoder_output):
        tgt_embedded = self.embedding(tgt)

        # 解码器的自注意力
        tgt_self_attention_output = self.self_attention(tgt_embedded, tgt_embedded, tgt_embedded)

        # 编码器-解码器的注意力 (将编码器的输出作为 Key 和 Value)
        attention_output = self.cross_attention(tgt_self_attention_output, encoder_output, encoder_output)

        # 最后的线性层
        output = self.fc(attention_output)
        return output

# 完整的 Encoder-Decoder 结构
class Translator(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads):
        super(Translator, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_heads)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_heads)

    def forward(self, src, tgt):
        encoder_output = self.encoder(src)
        output = self.decoder(tgt, encoder_output)
        return output

# 创建基于 encoder-decoder 的模型
model = Translator(len(english_vocab), len(french_vocab), d_model, num_heads)

# 损失和优化器
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略 padding
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(100):
    optimizer.zero_grad()
    output = model(src_sentences, tgt_sentences)

    # reshape 输出以适配损失函数
    output = output.view(-1, output.size(-1))
    tgt_sentences_flat = tgt_sentences.view(-1)

    loss = criterion(output, tgt_sentences_flat)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/100], Loss: {loss.item():.4f}")

# 预测函数
def predict(model, src_sentence, max_len=10):
    model.eval()  # 进入评估模式
    
    # 将源句子转换为索引并进行 padding
    src_tensor = torch.tensor([pad_sequence(tokenize_sentence(src_sentence, english_vocab), max_len_src)])
    
    # 编码器输出:将源句子输入到编码器
    encoder_output = model.encoder(src_tensor)
    
    # 初始化目标句子的开始符号 (<pad> 实际中应替换为 <bos>)
    tgt_sentence = [french_vocab["<pad>"]]
    
    # 开始逐步生成翻译
    for _ in range(max_len):
        tgt_tensor = torch.tensor([pad_sequence(tgt_sentence, max_len_tgt)])  # 对目标句子进行 padding
        
        # 将当前目标句子输入到解码器,得到预测输出
        output = model.decoder(tgt_tensor, encoder_output)
        
        # 获取最后一个时间步的输出(即预测的下一个词)
        next_word_idx = torch.argmax(output, dim=-1).squeeze().tolist()[-1]
        
        # 如果预测的词是结束符号或 <pad>,则停止生成
        if next_word_idx == french_vocab["<pad>"]:
            break
        
        # 将预测的词添加到目标句子中
        tgt_sentence.append(next_word_idx)
    
    # 将词索引转换回句子
    translated_sentence = [french_idx2word[idx] for idx in tgt_sentence if idx != 0]
    return " ".join(translated_sentence)

# 测试翻译
for pair in pairs:
    english_sentence = pair[0]
    prediction = predict(model, english_sentence)
    print(f"English: {english_sentence} -> French (Predicted): {prediction}")

四、总结

多头注意力机制通过多个不同的头来并行计算多个注意力分布,使得模型能够从多个角度捕获输入序列中的不同特征。其关键优势在于增强了模型的表示能力,同时可以有效并行化计算,在自然语言处理、计算机视觉等领域取得了广泛的成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/891330.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MedSAM2调试安装与使用记录

目录 前言一、环境准备多版本cuda切换切换cuda版本二 安装CUDNN2.1 检查cudnn 二、使用步骤1.安装虚拟环境2.测试Gradio3.推理 总结 前言 我们在解读完MedSAM之后&#xff0c;迫不及待想尝尝这个技术带来的福音&#xff0c;因此验证下是否真的那么6。这不&#xff0c;新鲜的使…

使用 KVM 在 Xubuntu 上创建 Windows 10 虚拟机

目录 前言说明注意准备 iso官网思博主(嘻嘻)拖动到虚拟机里面启动 virt-manager创建虚拟机选择本地安装介质选择 iso配置 内存 和 CPU选择 创建的虚拟机 保存的位置启动虚拟机看到熟悉的 Win10界面点击现在安装点击我没有产品密钥选择 Win10 专业工作站版勾选接受许可条款选择自…

《智慧博物馆:科技与文化的完美融合》

《智慧博物馆&#xff1a;科技与文化的完美融合》 一、智慧博物馆的兴起与发展 随着科技的飞速发展&#xff0c;智慧博物馆应运而生。进入新时代&#xff0c;大数据、人工智能、信息化的进步以及智能产品的应用&#xff0c;改变了人们接收信息和学习的习惯。为顺应社会变革&am…

【超详细】UDP协议

UDP传输层协议的一种&#xff0c;UDP(User Datagram Protocol 用户数据报协议)&#xff1a; 传输层协议无连接不可靠传输面向数据报 UDP协议端格式 定长报头&#xff0c;8字节源端口号和目的端口号来定位16位UDP长度, 表示整个数据报(UDP首部UDP数据)的最大长度如果校验和出错…

【C++】线程库常用接口

1.创建线程&#xff0c;等待线程&#xff0c;获取线程id 2.全局变量&#xff0c;局部变量&#xff0c;互斥锁 要让不同的线程访问同一个变量和同一把锁&#xff0c;有两种方法&#xff1a; 2.1方法一 定义全局的变量和全局的锁&#xff0c;这样自然就能访问到。 但全局变量在…

电能表预付费系统-标准传输规范(STS)(5)

5.5MeterFunctionObjects / companion specifications配套规格 With reference to Figure 1 it can be seen that the TokenCarrierToMeterInterface, which also includes the TokenCarrier, is dealt with in the IEC 62055-4x and IEC 62055-5x series. The remaining Mete…

论文 | OpenICL: An Open-Source Framework for In-context Learning

主要内容&#xff1a; 2. 提供多种 ICL 方法&#xff1a; 3. 完整的教程&#xff1a; 4. 评估和验证&#xff1a; 背景&#xff1a; 随着大型语言模型 (LLM) 的发展&#xff0c;上下文学习 (ICL) 作为一种新的评估范式越来越受到关注。问题&#xff1a; ICL 的实现复杂&#xf…

[ABC367C] Enumerate Sequences

1.注意输入的是哪个数组&#xff0c;输出的是哪个 2.dfs函数可以带两个参数&#xff0c;方便记录&#xff0c;一个记录第几个位置&#xff0c;一个记录题目的要求&#xff0c;例如求和。 3.注意递归出口输出后一定要return. #include<bits/stdc.h> using namespace std; …

Unity XR PICO 手势交互 Demo APK

效果展示 用手抓取物体&#xff0c;调整物体位置和大小等 亲测pico4 企业版可用&#xff0c; 其他设备待测试 下载链接&#xff1a; 我标记的不收费 https://download.csdn.net/download/qq_35030499/89879333

AI 编译器学习笔记之七三 -- 应用配置测试

1、通过jit_compile来进行算子调用控制 (不同的模型对推理的时间影响巨大) 昇腾pytorch代码地址&#xff1a;https://gitee.com/ascend/pytorch jit_compile true&#xff1a;走的是GEIR&#xff0c;进行了在线编译&#xff0c;可以用到的算子包含了 ascendC 、tbe、tik、aicpu…

Spring事务管理:应用实战案例和规则

背景 想象一下&#xff0c;如果没有Spring框架对事务的支持&#xff0c;我们得自行对事物进行管理&#xff1a; 获得JDBC连接、 关闭JDBC连接、 执行JDBC事务提交、 执行JDBC事务回滚操作 有了Spring事务框架&#xff0c;我们再也不需要在与事务相关的方法中处理大量的try.…

Faker:自动化测试数据生成利器

Faker&#xff1a;自动化测试数据生成利器 前言1. 安装2. 多语言支持3. 常用方法3.1 生成姓名和地址3.2 生成电子邮件和电话号码3.3 生成日期和时间3.4 生成公司名称和职位3.5 生成文本和段落3.6 生成图片和颜色3.7 生成用户代理和浏览器信息3.8 生成文件和目录3.9 生成UUID和哈…

4 机器学习之归纳偏好

通过学习得到的模型对应了假设空间中的一个假设。于是&#xff0c;图1.2的西瓜版本空间给我们带来一个麻烦&#xff1a;现在有三个与训练集一致的假设&#xff0c;但与它们对应的模型在面临新样本的时候&#xff0c;却会产生不同的输出。例如&#xff0c;对&#xff08;色泽青绿…

Excel日期导入数据库变为数字怎么办

在Excel导入到数据库的过程中&#xff0c;经常会碰到Excel里面的日期数据&#xff0c;导进去过后变成了数字。 如下图&#xff1a; 使用navicate等数据库编辑器导入数据库后&#xff1a; 原因分析&#xff1a;这是因为日期和时间在excel中都是以数字形式存储的&#xff0c;这个…

PolarCTF靶场[web]file、ezphp WP

[WEB]file 知识点&#xff1a;文件上传漏洞 工具&#xff1a;Burp Suite、dirsearch 方法一&#xff1a; 根据页面提示&#xff0c;先用dirsearch工具扫一扫 访问/upload.php&#xff0c;发现一个上传区 在访问/uploaded/,再点击Parent Directory&#xff0c;发现链接到首页…

带隙基准Bandgap电路学习(三)

一、导入器件到版图中 从原理图中导入器件&#xff1a; Connectivity——>Generate——>All From Source I/O Pins暂不添加&#xff0c;后面自己画 PR&#xff08;Primary Region&#xff09;Boundary: 通常是用来定义芯片设计中某些关键区域的轮廓&#xff0c;比…

用Eclipse运行第一个Java程序

1.左键双击在桌面“软件 (文件夹)”&#xff0c;打开该文件夹 2.左键双击“eclipse (文件夹)”&#xff0c;打开该文件夹 3.左键双击“eclipse (文件夹)”&#xff0c;打开该文件夹 4.左键双击“eclipse.exe”&#xff0c;运行这个可执行程序 5.左键单击“Ok&#xff08;按下按…

【软件部署安装】OpenOffice转换PDF字体乱码

现象与原因分析 执行fc-list查看系统字体 经分析发现&#xff0c;linux默认不带中文字体&#xff0c;因此打开我们本地的windows系统的TTF、TTC字体安装到centos机器上。 安装字体 将Windows的路径&#xff1a; C:\Windows\Fonts 的中文字体&#xff0c;如扩展名为 TTC 与TT…

电影《荒野机器人》观后感

上上周看了电影《荒野机器人》&#xff0c;电影整体是比较偏向温馨的&#xff0c;通过动物与机器人视角&#xff0c;展现人类为情感。 &#xff08;1&#xff09;承载-托举-学习-感情 在电影中&#xff0c;有个场景让自己感觉特别温馨&#xff0c;就是机器人为了让大雁宝宝学习…

Linux系统之dig命令的基本使用

Linux系统之dig命令的基本使用 一、dig命令介绍二、本次实践环境三、dig命令的使用帮助3.1 dig的语法解释3.2 dig的帮助信息 四、dig命令的基本使用4.1 查询对应域名的ip4.2 查询域名的MX记录4.3 查询域名的NS记录4.4 查询域名的A记录4.5 查询详细信息4.6 对目标ip进行反向解析…