【自然语言处理】多头注意力机制 Multi-Head Attention

多头注意力(Multi-Head Attention)机制是Transformer模型中的一个关键组件,广泛用于自然语言处理任务(如机器翻译、文本生成等)以及图像处理任务。它的核心思想是通过多个不同的注意力头来捕获输入的不同特征,从而提高模型的表现力。以下是详细的解释:

一、多头注意力机制(Multi-Head Attention)

多头注意力机制是对单个注意力机制(详见【模型】Self-Attention)的扩展,它允许模型从多个角度“看”待输入数据。

具体来说,多头注意力机制通过以下步骤进行:

  1. 线性映射:首先,对查询 Q、键 K 和值 V 通过不同的线性变换(矩阵乘法),将它们分别投影到 h 个不同的子空间(即不同的头)。假设有 h 个注意力头,每个头的维度是 dk。

对于每个头 i,分别计算:
在这里插入图片描述其中 WiQ、WiK​、WiV 是头 i 的投影矩阵;
查询 Q、键 K 和值 V从输入中通过线性变换生成,X 是输入序列,WQ、WK和WV分别是Q、K、V 的权重矩阵。
在这里插入图片描述

  1. 并行计算多个注意力:对于每个头,分别使用缩放点积注意力机制来计算注意力输出。
    在这里插入图片描述

  2. 拼接注意力头的输出:将所有 h 个注意力头的输出拼接起来,形成一个大向量。
    在这里插入图片描述

  3. 线性变换:将拼接后的结果通过另一个线性变换 WO 进行投影,得到最终的多头注意力输出。
    在这里插入图片描述
    投影矩阵 WiQ​, WiK, WiV​ 以及用于拼接头输出后的权重矩阵 WO,最初都是随机初始化的,然后通过训练逐渐学习得到。

二、多头注意力的优点

  • 捕获不同的特征:每个注意力头都可以关注输入的不同部分,从而捕获更多元的信息,提升模型的表示能力。
  • 并行计算:多个注意力头可以并行计算,因此在计算效率上比单个注意力机制更高效。
  • 扩展表示能力:通过多个头,模型能够学习到更加复杂的关系,适合处理更复杂的任务。

三、在 Transformer 中的应用

在 Transformer 模型中,多头注意力机制被广泛应用在**编码器(Encoder)解码器(Decoder)**中。具体来说:

  • 自注意力(Self-Attention):输入序列中的每个元素都可以与序列中的其他元素建立联系。多头自注意力允许每个位置“查看”其他位置的信息。
  • 编码器-解码器注意力(交叉注意力):在解码器部分,模型通过多头注意力机制来“关注”编码器的输出,从而生成目标序列。

使用pytorch实现Multi-Head Attention 在机器翻译中的简单示例:

mport torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 英语和法语的词汇表
english_vocab = {"<pad>": 0, "i": 1, "am": 2, "a": 3, "student": 4, "he": 5, "is": 6, "teacher": 7,
                 "she": 8, "loves": 9, "apples": 10, "we": 11, "are": 12, "friends": 13}
french_vocab = {"<pad>": 0, "je": 1, "suis": 2, "un": 3, "étudiant": 4, "il": 5, "est": 6,
                "professeur": 7, "elle": 8, "aime": 9, "les": 10, "pommes": 11, "nous": 12, "sommes": 13, "amis": 14}

# 翻转字典以便通过索引查找单词
english_idx2word = {i: w for w, i in english_vocab.items()}
french_idx2word = {i: w for w, i in french_vocab.items()}

# 数据对
pairs = [
    ["i am a student", "je suis un étudiant"],
    ["he is a teacher", "il est un professeur"],
    ["she loves apples", "elle aime les pommes"],
    ["we are friends", "nous sommes amis"]
]

# Tokenization: 将句子转为词汇索引
def tokenize_sentence(sentence, vocab):
    return [vocab[word] for word in sentence.split()]

# 将句子对转为索引表示
tokenized_pairs = [(tokenize_sentence(p[0], english_vocab), tokenize_sentence(p[1], french_vocab)) for p in pairs]

# Padding函数
def pad_sequence(seq, max_len):
    return seq + [0] * (max_len - len(seq))

# 获取最大序列长度
max_len_src = max([len(pair[0]) for pair in tokenized_pairs])
max_len_tgt = max([len(pair[1]) for pair in tokenized_pairs])

# Padding后的输入和目标序列
src_sentences = torch.tensor([pad_sequence(pair[0], max_len_src) for pair in tokenized_pairs])
tgt_sentences = torch.tensor([pad_sequence(pair[1], max_len_tgt) for pair in tokenized_pairs])

# 模型超参数
d_model = 16   # 词嵌入维度
num_heads = 2  # 注意力头的数量
d_k = d_model // num_heads  # 每个头的维度

# 词嵌入层
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(EmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x)

# Multi-Head Attention 实现
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.out_proj = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        # Q * K^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output

    def forward(self, Q, K, V):
        batch_size = Q.size(0)

        # 线性变换后,拆分成多个 heads
        Q = self.q_linear(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_linear(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_linear(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 每个头分别计算注意力
        attention_output = self.scaled_dot_product_attention(Q, K, V)

        # 拼接所有头的输出
        concat_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        # 最终的线性变换
        output = self.out_proj(concat_output)
        return output

class Encoder(nn.Module):
    def __init__(self, src_vocab_size, d_model, num_heads):
        super(Encoder, self).__init__()
        self.embedding = EmbeddingLayer(src_vocab_size, d_model)
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, src):
        src_embedded = self.embedding(src)
        attention_output = self.attention(src_embedded, src_embedded, src_embedded)
        output = self.fc(attention_output)
        return output

class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, d_model, num_heads):
        super(Decoder, self).__init__()
        self.embedding = EmbeddingLayer(tgt_vocab_size, d_model)
        self.self_attention = MultiHeadAttention(d_model, num_heads)  # 自注意力
        self.cross_attention = MultiHeadAttention(d_model, num_heads)  # 编码器-解码器的交叉注意力
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, tgt, encoder_output):
        tgt_embedded = self.embedding(tgt)

        # 解码器的自注意力
        tgt_self_attention_output = self.self_attention(tgt_embedded, tgt_embedded, tgt_embedded)

        # 编码器-解码器的注意力 (将编码器的输出作为 Key 和 Value)
        attention_output = self.cross_attention(tgt_self_attention_output, encoder_output, encoder_output)

        # 最后的线性层
        output = self.fc(attention_output)
        return output

# 完整的 Encoder-Decoder 结构
class Translator(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads):
        super(Translator, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_heads)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_heads)

    def forward(self, src, tgt):
        encoder_output = self.encoder(src)
        output = self.decoder(tgt, encoder_output)
        return output

# 创建基于 encoder-decoder 的模型
model = Translator(len(english_vocab), len(french_vocab), d_model, num_heads)

# 损失和优化器
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略 padding
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(100):
    optimizer.zero_grad()
    output = model(src_sentences, tgt_sentences)

    # reshape 输出以适配损失函数
    output = output.view(-1, output.size(-1))
    tgt_sentences_flat = tgt_sentences.view(-1)

    loss = criterion(output, tgt_sentences_flat)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/100], Loss: {loss.item():.4f}")

# 预测函数
def predict(model, src_sentence, max_len=10):
    model.eval()  # 进入评估模式
    
    # 将源句子转换为索引并进行 padding
    src_tensor = torch.tensor([pad_sequence(tokenize_sentence(src_sentence, english_vocab), max_len_src)])
    
    # 编码器输出:将源句子输入到编码器
    encoder_output = model.encoder(src_tensor)
    
    # 初始化目标句子的开始符号 (<pad> 实际中应替换为 <bos>)
    tgt_sentence = [french_vocab["<pad>"]]
    
    # 开始逐步生成翻译
    for _ in range(max_len):
        tgt_tensor = torch.tensor([pad_sequence(tgt_sentence, max_len_tgt)])  # 对目标句子进行 padding
        
        # 将当前目标句子输入到解码器,得到预测输出
        output = model.decoder(tgt_tensor, encoder_output)
        
        # 获取最后一个时间步的输出(即预测的下一个词)
        next_word_idx = torch.argmax(output, dim=-1).squeeze().tolist()[-1]
        
        # 如果预测的词是结束符号或 <pad>,则停止生成
        if next_word_idx == french_vocab["<pad>"]:
            break
        
        # 将预测的词添加到目标句子中
        tgt_sentence.append(next_word_idx)
    
    # 将词索引转换回句子
    translated_sentence = [french_idx2word[idx] for idx in tgt_sentence if idx != 0]
    return " ".join(translated_sentence)

# 测试翻译
for pair in pairs:
    english_sentence = pair[0]
    prediction = predict(model, english_sentence)
    print(f"English: {english_sentence} -> French (Predicted): {prediction}")

四、总结

多头注意力机制通过多个不同的头来并行计算多个注意力分布,使得模型能够从多个角度捕获输入序列中的不同特征。其关键优势在于增强了模型的表示能力,同时可以有效并行化计算,在自然语言处理、计算机视觉等领域取得了广泛的成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/896932.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虚拟现实与Facebook的结合:未来社交的全新体验

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术正在逐步改变人们的社交方式。Facebook&#xff0c;作为全球最大的社交媒体平台之一&#xff0c;积极探索如何将虚拟现实融入其社交生态系统&#xff0c;创造全新的用户体验。这一结合不仅影响了用户之间…

深度解析机器学习的四大核心功能:分类、回归、聚类与降维

深度解析机器学习的四大核心功能&#xff1a;分类、回归、聚类与降维 前言分类&#xff08;Classification&#xff09;&#xff1a;预测离散标签的艺术关键算法与代码示例逻辑回归支持向量机&#xff08;SVM&#xff09; 回归&#xff08;Regression&#xff09;&#xff1a;预…

探索秘境:如何使用智能体插件打造专属的小众旅游助手『小众旅游探险家』

文章目录 摘要引言智能体介绍和亮点展示介绍亮点展示 已发布智能体运行效果智能体创意想法创意想法创意实现路径拆解 如何制作智能体可能会遇到的几个问题快速调优指南总结未来展望 摘要 本文将详细介绍如何使用智能体平台开发一款名为“小众旅游探险家”的旅游智能体。通过这…

获取非加密邮件协议中的用户名和密码——安全风险演示

引言 在当今的数字时代,网络安全变得越来越重要。本文将演示如何通过抓包工具获取非加密邮件协议中的用户名和密码,以此说明使用非加密协议的潜在安全风险。通过这个演示,我们希望能提高读者的安全意识,促使大家采取更安全的通信方式。 注意: 本文仅用于教育目的,旨在提高安全…

【MyBatis】初识MyBatis 构建简单框架

目录 MyBatis前言搭建一个简单的MyBatis创建Maven项目引入必要依赖创建数据表结构创建User实体类创建Mapper接口Mapper层Dao层 创建MyBatis的Mapper映射文件编写测试类传统测试类JUnit测试 MyBatis 介绍&#xff1a;MyBatis是一款半自动的ORM持久层框架&#xff0c;具有较高的…

Linux下ClamAV源代码安装与使用说明

Linux下ClamAV源代码安装与使用说明 ClamAV(Clam AntiVirus)是一款开源的防病毒工具,广泛应用于Linux平台上的网络安全领域。它以其高效的性能和灵活的配置选项,成为网络安全从业人员的重要工具。ClamAV支持多线程扫描,可以自动升级病毒库,并且支持多个操作系统,包括Li…

NGINX 保护 Web 应用安全之基于 IP 地址的访问

根据客户端的 IP 地址控制访问 使用 HTTP 或 stream 访问模块控制对受保护资源的访问&#xff1a; location /admin/ { deny 10.0.0.1; allow 10.0.0.0/20; allow 2001:0db8::/32; deny all; } } 给定的 location 代码块允许来自 10.0.0.0/20 中的任何 IPv4 地址访问&#xf…

可视化大屏中运用3D模型,能够带来什么好处。

现在你看到的可视化大屏&#xff0c;大都会在中间放置一些3D模型&#xff0c;比如厂房、园区、设备等等&#xff0c;那么这些3D模型的放置的确给可视化大屏带来了不一样的视觉冲击&#xff0c;本文将从以下四个方面来分析这个现象。 一、可视化大屏中越来越多使用3D模型说明了…

Linux工具的使用-【git的理解和使用】【调试器gdb的使用】

目录 Linux工具的使用-031.git1.1git是什么1.2git在linux下的操作1.2.1创建git仓库1.2.2 .gitignore1.2.3 .git&#xff08;本地仓库&#xff09;1.2.4 add (添加)1.2.5 commit(提交)1.2.6push(推送)对两个特殊情况的处理配置免密码push 1.2.7 log(获取提交记录)1.2.8 status(获…

Java项目-基于springboot框架的逍遥大药房管理系统项目实战(附源码+文档)

作者&#xff1a;计算机学长阿伟 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、ElementUI等&#xff0c;“文末源码”。 开发运行环境 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBoot、Vue、Mybaits Plus、ELementUI工具&#xff1a;IDEA/…

Linux运维篇-误操作已经做了pv的磁盘导致pv异常

目录 故障场景排错过程小结 故障场景 在对/dev/vdb1创建了pv并扩容至vg(klas)之后&#xff0c;不小心对/dev/vdb进行了parted操作&#xff0c;删除了/dev/vdb1导致pvs查看显示异常。具体过程如下所示&#xff1a; 正常创建pv 将创建好的pv添加到系统现有的卷组中 不小心又对…

Golang | Leetcode Golang题解之第491题非递减子序列

题目&#xff1a; 题解&#xff1a; var (temp []intans [][]int )func findSubsequences(nums []int) [][]int {ans [][]int{}dfs(0, math.MinInt32, nums)return ans }func dfs(cur, last int, nums []int) {if cur len(nums) {if len(temp) > 2 {t : make([]int, len(…

【计网】理解TCP全连接队列与tcpdump抓包

希望是火&#xff0c;失望是烟&#xff0c; 生活就是一边点火&#xff0c;一边冒烟。 理解TCP全连接队列与tcpdump抓包 1 TCP 全连接队列1.1 重谈listen函数1.2 初步理解全连接队列1.3 深入理解全连接队列 2 tcpdump抓包 1 TCP 全连接队列 1.1 重谈listen函数 这里我们使用…

颜色交替的最短路径

题目链接 颜色交替的最短路径 题目描述 注意 返回长度为n的数组answer&#xff0c;其中answer[x]是从节点0到节点x的红色边和蓝色边交替出现的最短路径的长度图中每条边为红色或者蓝色&#xff0c;且可能存在自环或平行边 解答思路 可以使用广度优先遍历从0开始找到其相邻…

Java.6--多态-设计模式-抽象父类-抽象方法

一、多态 1.定义--什么是多态&#xff1f; a.同一个父类的不同子类对象&#xff0c;在做同一行为的时候&#xff0c;有不同的表现形式&#xff0c;这就是多态。&#xff08;总结为&#xff1a;一个父类下的不同子类&#xff0c;同一行为&#xff0c;不同表现形式。&#xff0…

leetcode day1 910+16

910 最小差值 给你一个整数数组 nums&#xff0c;和一个整数 k 。 在一个操作中&#xff0c;您可以选择 0 < i < nums.length 的任何索引 i 。将 nums[i] 改为 nums[i] x &#xff0c;其中 x 是一个范围为 [-k, k] 的任意整数。对于每个索引 i &#xff0c;最多 只能 …

Excel中如何进行傅里叶变换(FT),几步完成

在 Excel 中&#xff0c;虽然没有像 MATLAB 那样专门的函数库来直接进行傅里叶变换&#xff0c;但可以使用 Excel 内置的分析工具库提供的傅里叶变换&#xff08;FT &#xff0c;Fourier Transform&#xff09;功能。这个工具可以对数据进行频域分析。以下是如何在 Excel 中进行…

开源表单生成器OpnForm

什么是 OpnForm &#xff1f; OpnForm 是一个开源的表单构建工具&#xff0c;旨在简化创建自定义表单的过程&#xff0c;特别适合无编码知识的用户。它通过人工智能优化表单创建流程&#xff0c;支持多种用途&#xff0c;如联系人表单、调查表等。OpnForm 提供了一个直观的拖放…

semi-Naive Bayesian(半朴素贝叶斯)

semi-Naive Bayesian&#xff08;半朴素贝叶斯&#xff09; 引言 朴素贝叶斯算法是基于特征是相互独立这个假设开展的&#xff08;为了降低贝叶斯公式: P ( c ∣ x ) P ( c ) P ( x ∣ c ) P ( x ) P(c|x) \frac {P(c)P(x|c)}{P(x)} P(c∣x)P(x)P(c)P(x∣c)​中后验概率 P …