【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)


文章目录

  • 【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)
  • 1. 交叉注意力的起源与提出
  • 2. 交叉注意力的原理
  • 3. 交叉注意力的数学表示
  • 4. 交叉注意力的应用场景与发展
  • 5. 代码实现
  • 6. 代码解释
  • 7. 总结


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算: 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:
    在这里插入图片描述
    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出: 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1RT1×d X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2RT2×d,其中 T 1 T_1 T1 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:
在这里插入图片描述
其中, W q W_q Wq W k W_k Wk W v W_v Wv ∈ R d × d ∈R^{d×d} Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:
在这里插入图片描述

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # 线性变换,用于生成 Q, K, V 矩阵
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # 输出的线性变换
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x1, x2):
        # x1 是 Query,x2 是 Key 和 Value
        B, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]
        _, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]

        # 生成 Q, K, V 矩阵
        Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力得分
        attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = self.softmax(attn_scores)  # 注意力权重
        attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合

        # 使用注意力权重加权值矩阵
        attn_output = attn_weights @ V
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)

        # 输出线性变换
        output = self.out_proj(attn_output)
        return output

# 测试交叉注意力机制
if __name__ == "__main__":
    B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channels
    x1 = torch.randn(B, T1, C)  # Query 输入
    x2 = torch.randn(B, T2, C)  # Key 和 Value 输入

    cross_attn = CrossAttention(dim=C, num_heads=4)
    output = cross_attn(x1, x2)
    
    print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/905399.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot响应文件流文件给浏览器+前端下载

springboot响应文件流文件给浏览器前端下载 1.controller: Api(tags {"【样本提取系统】-api"}) RestController("YbtqYstbtqController") RequiredArgsConstructor RequestMapping("/ybtq-ystbtq") Slf4j public class YbtqYstbtqController …

DAY67WEB 攻防-Java 安全JNDIRMILDAP五大不安全组件RCE 执行不出网

知识点: 1、Java安全-RCE执行-5大类函数调用 2、Java安全-JNDI注入-RMI&LDAP&高版本 3、Java安全-不安全组件-Shiro&FastJson&JackJson&XStream&Log4j Java安全-RCE执行-5大类函数调用 Java中代码执行的类: Groovy Runti…

vue下载安装

目录 vue工具前置要求:安装node.js并配置好国内镜像源下载安装 vue 工具 系统:Windows 11 前置要求:安装node.js并配置好国内镜像源 参考:本人写的《node.js下载、安装、设置国内镜像源(永久)&#xff…

书生实战营第四期-第四关 玩转HF/魔搭/魔乐社区

一、任务1:模型下载 使用魔搭社区平台下载文档中提到的模型 1.创建开发机 2.环境配置 # 激活环境 conda activate /root/share/pre_envs/pytorch2.1.2cu12.1# 安装 modelscope pip install modelscope -t /root/env/maas pip install numpy1.26.0 -t /root/env/m…

【Blender】 学习笔记(一)

文章目录 参考概念原点 Origin游标 轴心点坐标操作默认快捷键两个比较好用的功能渲染器元素不可选(防止误选)关联材质 参考 参考b站视频:【Kurt】Blender零基础入门教程 | Blender中文区新手必刷教程(已完结) 概念 模型、灯光、摄像机 原点…

Java中的反射(Reflection)

先上两张图来系统的看一下反射的作用和具体的实现方法 接下来详细说一下反射的步骤以及之中使用的方法: 获取Class对象: 要使用反射,首先需要获得一个Class对象,该对象是反射的入口点。可以通过以下几种方式获取Class对象&#x…

号码认证是什么意思?有什么用?

随着通信环境越来越复杂,各种骚扰、推销电话层出不穷。许多企业为了取信于客户,提高电话的接听率,纷纷选择了申请号码认证,试图通过这种方法来与客户建立更加高效的沟通。 不可否认,这种方法是极其有效的。号码认证可…

Android 圆形进度条CircleProgressView 基础版

一个最基础的自定义View 圆形进度条,可设置背景色、进度条颜色(渐变色)下载进度控制;可二次定制度高; 核心代码: Overrideprotected void onDraw(NonNull Canvas canvas) {super.onDraw(canvas);int mW g…

Java基础0-Java概览

Java概览 一、Java的主要特性 Java 语言是简单的: Java 丢弃了 C 中很少使用的、很难理解的、令人迷惑的那些特性,如操作符重载、多继承、自动的强制类型转换。特别地,Java 语言不使用指针,而是引用。并提供了自动分配和回收内存…

信号(四)【信号处理与捕捉】

目录 1. 信号的处理1.1 内核态 && 用户态1.2 进程地址空间第三弹1.1 内核态 && 用户态 (续) 2. 信号捕捉 1. 信号的处理 我们一直在说,进程收到信号了,可能会因为各种原因无法即使处理信号,而后选择一个合适的时机去处理。所…

Kafka 与传统 MQ 消息系统之间有三个关键区别?

大家好,我是锋哥。今天分享关于【Kafka 与传统 MQ 消息系统之间有三个关键区别?】面试题?希望对大家有帮助; Kafka 与传统 MQ 消息系统之间有三个关键区别? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 …

基于局部近似的模型解释方法

在机器学习领域中,模型解释性是一个越来越重要的议题,尤其是在复杂的深度学习模型和非线性模型广泛应用的今天。解释性不仅帮助我们理解模型的决策逻辑,还能提高模型在敏感领域(如医疗诊断、金融分析)中的可信度。基于…

img 标签的 object-fit 属性

设置图片固定尺寸后,可以通过 object-fit 属性调整图片展示的形式 object-fit: contain; 图片的长宽比不变,相应调整大小。 object-fit: cover; 当图片的长宽比与容器的长宽比不一致时,会被裁切。 object-fit: fill; 图片不再锁定长宽…

推荐一款功能强大的文字处理工具:Atlantis Word Processor

Atlantis word proCEssor是一款功能强大的文字处理工具。该软件可以让用户放心的去设计文档,并且软件的界面能够按用户的意愿去自定义,比如工具栏、字体选择、排版、打印栏等等,当然还有更多的功能,比如你还可以吧软件界面中的任何…

【算法】【优选算法】双指针(上)

目录 一、双指针简介1.1 对撞指针(左右指针)1.2 快慢指针 二、283.移动零三、1089.复写零3.1 双指针解题3.2 暴力解法 四、202.快乐数4.1 快慢指针4.2 暴力解法 五、11.盛最多⽔的容器5.1 左右指针5.2 暴力解法 一、双指针简介 常⻅的双指针有两种形式&…

Pandas DataFrame学习补充

1. 从字典创建:字典的键成为列名,值成为列数据。 import pandas as pd# 通过字典创建 DataFrame df pd.DataFrame({Column1: [1, 2, 3], Column2: [4, 5, 6]}) 2. 从列表的列表创建:外层列表代表行,内层列表代表列。 df pd.Da…

二、Go快速入门之数据类型

📅 2024年4月27日 📦 使用版本为1.21.5 Go的数据类型 📖官方文档:https://go.dev/ref/spec#Types 1️⃣ 布尔类型 ⭐️ 布尔类型只有真和假,true和false ⭐️ 在Go中整数0不会代表假,非零整数也不能代替真&#…

Springboot整合原生ES依赖

前言 Springboot整合依赖大概有三种方式: es原生依赖:elasticsearch-rest-high-level-clientSpring Data ElasticsearchEasy-es 三者的区别 1. Elasticsearch Rest High Level Client 简介: 这是官方提供的 Elasticsearch 客户端,支持…

Spark,Anconda在虚拟机实现本地模式部署

如果想要了解模式的概念部分,以及作用请看: Spark学习-CSDN博客 一.在虚拟机安装spark cd /opt/modules 把Anconda和Spark安装包拖拽进去: 解压: tar -zxf spark-3.1.2-bin-hadoop3.2.tgz -C /opt/installs 重命名&#x…

Javaee:阻塞队列和生产者消费者模型

文章目录 什么是阻塞队列java中的主要阻塞队列生产者消费者模型阻塞队列发挥的作用解耦合削峰填谷 模拟实现阻塞队列put方法take方法生产者消费者模型 什么是阻塞队列 阻塞队列是一种支持阻塞操作的队列,在多线程中实现通线程之间的通信协调的特殊队列 java中的主…