因果学习篇(2)-Causal Attention for Vision-Language Tasks(文献阅读)

Causal Attention for Vision-Language Tasks

引言

  这篇论文是南洋理工大学和澳大利亚莫纳什大学联合发表自2021年的CVPR顶会上的一篇文献,在当前流行的注意力机制中增加了因果推理算法,提出了一种新的注意力机制:因果注意力(CATT),使用因果推断中的“前门准则”解决训练数据中存在的虚假相关性,刨析了注意力机制在推理过程中的因果原理,在提高模型性能的前提下,加强了模型的可解释性,打开了神经网络的黑匣子,具有非常好的参考意义。
此外,本文所提的因果注意力遵守了传统K-Q-V的机制,所以可以替换现有的各种自注意力机制,如Transformer,是一个可插拔的模块。

> 代码已开源:https://github.com/yangxuntu/lxmertcatt
> 原文链接:Causal Attention for Vision-Language Tasks

瓶颈问题

  在传统的自注意力机制中,通常都是利用查询集Query和键集Key相乘得到权重,然后再更新值集Value。在该过程中,注意力的权重是无监督的,即注意力权重在训练过程中没有标注权重标签,这难免导致数据偏差。
  举个例子:
image captioning)
  如上图展示了一个图像描述的流程,在训练数据中,因为有好多图片被描述为“人骑着马”,自注意力机制自然而然的会将“骑”这个动作与“人”、“马”关联起来。那么,在测试阶段,如果给一个“人驾驶马车”的图片,注意力机制会仍然用“人”和“马”进行关联,推断出“骑”这个动作,而忽略了“马车”。然而,该问题不会因为数据规模的扩大而解决,因为现实中确实红色苹果比绿色苹果多,站着的人比跳舞的人多。
  该问题的本质原因就是混杂因子(因果推理中的专有名词)导致的,比如X和Y之间没有直接的因果关系,但是X和Y之间仍然相关。下面的因果结构图可以解释该理论:
在这里插入图片描述)
  图中,X为输入图片,Y为标签,C表示常识(如人可以骑马),C是混杂因子,M是图片X中的目标。
C->X

  • 表示视觉场景可以通过常识来生成;
  • X->M表示场景中包含着多个目标;
  • C->M表示目标可以由常识来决定;
  • M->Y表示语言生成;

  从因果图中可以看出,X->Y有两条路径:X->M->Y 和 X<-C->M->Y(含混杂因子)
因此,无论数据集多大,如果不知道混杂因子,仅使用P(Y|X)来训练模型,永远无法识别真正因果效应。
比如训练集中“人骑马”比“人驾驶马车”的数据多,测试集中后者比前者多,那么训练中的P(Y|X)将无法引用在测试中。

解决方法

  1. 提出了一个新的注意力机制:因果注意力(Causal Attention , CATT), 识别X->Y的因果效应,避免混杂因子造成的数据偏差。该方法使用了前门准则(无需混杂因子的假设知识);
  2. 提出了样本内注意力(ISATT)和跨样本注意力(CS-ATT),遵守Q-K-V操作;而且Q-K-V操作的参数也可以在IS-ATT和CS-ATT之间共享,以进一步提高某些架构中的效率。
  3. 在LSTM、Transformer和大规模视觉-语言预训练模型中进行了测试,验证了所提模块能够大幅改善模型性能。

方法

前门因果图中的注意力

  如上面图中所示,在预测标签Y时,通常采用观测相关性P(Y|X),计算如下:
在这里插入图片描述
  但是,该概率预测时可能学习到由后门路径Z<-X<->Y导致的错误相关性,而不是Z->Y的真正因果效应。
因此,需要使用因果干预切断X->Z这条后门路径,从而阻断Z<-X<->Y路径。方法是将输入X分成不同的情况{x},然后通过以下期望来测量Z对Y的平均因果效应:在这里插入图片描述
  其原理可以通过例子来解释:
  比如图像描述数据集中有大量的“人和滑雪板”的数据,那么模型会学习到错误的“人”与“滑雪板”的关联关系,而不会人的性别;
CS-Sampling则将人和其他样本中的对象结合,比如自行车、镜子。
最终,预测概率如下:
在这里插入图片描述

  以上过程称为前门准则。

IS-Sampling和CS-Sampling
IS-Sampling

  对于上述的概率公式,可以将P(Y|Z,X)用一个softmax层的网络g()来表示,因为许多视觉-语言模型都是做分类的,其公式如下:
在这里插入图片描述
  因此,需要采样X和Z,并送到网络中完成P(Y|do(X)),do表示对X进行干预(因果学习的内容)。
为降低复杂度,使用归一化加权几何平均数(NormalizedWeighted Geometric Mean,NWGM)近似地将外采样吸收为特征级别,实现只需要一次前向传播(我没懂,有懂得记得评论告诉我一哈)。具体公式如下:
在这里插入图片描述
  其中h和f都是特征编码函数,这里,我的理解可能就是将概率公式中的x,同样使用神经网络层来进行特征嵌入表示。xz加粗表示向量。
  IS-Sampling在实际计算中是按一个传统的注意力网络计算的,简单的表示为Q-K-V操作,结构下图蓝色部分。
在这里插入图片描述
  对应的,In-Sampling attention(IS-ATT)算法如下:
在这里插入图片描述
  下标I表示IS-ATT。其中,所有的 K I K_I KI V I V_I VI都来源于当前输入样本特征, Q I Q_I QI来源于h(X)。交叉模态i注意力中,查询向量表示的是句子上下文,而自注意力机制中查询向量表示的仍然是输入样本特征。对于 A I A_I AI而言,每个注意力向量 a I a_I aI都是IS-Sampling的P(Z=z|h(X))概率估计,输出 Z ^ \hat{Z} Z^是IS-Sampling评估向量。
与IS-ATT类似,交叉样本注意力(Cross-Sample attention,CS-ATT)结构如上图红色部分,算法如下:
在这里插入图片描述
  其中, K C K_C KC V C V_C VC都来源于训练集中的其他样本, Q C Q_C QC来源于 f ( X ) f(X) f(X) a C a_C aC近似 P ( X = x ∣ f ( X ) ) P(X=x|f(X)) P(X=xf(X)),且 X ^ \hat{X} X^是CS-Sampling评估向量。
最后,单一的因果注意力分别由IS-ATT和CS-ATT得到,然后,拼接这两个值作为最终 P ( Y ∣ d o ( X ) ) P(Y|do(X)) P(Ydo(X))的值。

因果注意力(CATT)在堆叠注意力网络中的应用
  • Transformer+CATT
    在这里插入图片描述
      从图中可以看出,视觉-语言Transformer模型包含视觉编码器和文本解码器。在实现中,共包含六个蓝色和紫色的模块,编码器的输入包括当前图片和一个全局的图像嵌入字典(数据集所有文本的嵌入向量),编码器中IS-ATT和CS-ATT的输出输入到解码器中,用于学习视觉语言的表示。在解码器中,首先输入IS-ATT和CS-ATT的是当前的句子嵌入和全局嵌入字典,输出包括IS-Sampling和CS-Sampling两部分内容,将其进行拼接将作为最终的预测。

实验分析

  本文在不同技术实现的视觉问答、图像描述两个任务上进行了实验。数据集有MS COCO、VQA2.0和Pre-training and Fine-tuning Datasets for VLP,实验结果如下。

图像描述
  • 相似度度量
    在这里插入图片描述
  • 偏差度量
    在这里插入图片描述
视觉问答
  • 基于LSTM和Transformer的VQA模型性能对比
    在这里插入图片描述
  • 不同题型对测试标准分割的准确性
    在这里插入图片描述

结论和展望

  在本文中,利用因果推理来分析注意机制易被数据集偏差误导的原因,发现了注意机制是一个不正确的近似前门准则,无法捕捉输入和输出之间真正的因果关系。在此基础上提出了一种新的注意力机制-因果注意力(CATT),该机制通过消除混杂因子来提高注意机制的质量。具体来说,CATT包含样本内和样本间注意力,用于估计前门平差中的样本内和样本间注意力,两个注意力网络都遵循Q-K-V操作。
  将CATT应用到各种流行的基于注意力的视觉语言模型,实验结果表明,它可以改善这些模型的相当大的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/435273.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【kubernetes】关于k8s集群的pod控制器

目录 一、deployment控制器 二、statefulset控制器 1、验证数据可以持久化 2、验证删除后名称不会改变&#xff0c;数据还会一直存在 3、验证扩容的创建过程是升序串行执行&#xff0c;并且自动创建pv 4、验证滚动更新的时候也是升序执行&#xff0c;数据持久化还在 5、验…

9、Linux驱动开发:驱动-控制接口的实现(ioctl)

目录 &#x1f345;点击这里查看所有博文 随着自己工作的进行&#xff0c;接触到的技术栈也越来越多。给我一个很直观的感受就是&#xff0c;某一项技术/经验在刚开始接触的时候都记得很清楚。往往过了几个月都会忘记的差不多了&#xff0c;只有经常会用到的东西才有可能真正记…

DailyNotes个人笔记管理工具

DailyNotes 是记录笔记和跟踪任务的应用程序&#xff0c;使用markdown进行编辑 部署 下载镜像 docker pull m0ngr31/dailynotes创建目录并授权 mkdir -p /data/dailynotes/config_dir chmod -R 777 /data/dailynotes启动容器 docker run -d --restart always --name mynot…

ruoyi-vue框架密码加密传输

先看一下改造后的样子&#xff0c;输入的密码不会再以明文展示。 下面我主要把前后端改造的代码贴出来。 1.后端代码 RsaUtils类 在com.ruoyi.common.utils包下新建RsaUtils类&#xff0c;RsaUtils添加了Component注解 generateKeyPair()构建密钥对添加了Bean注解 在项目启动…

【MATLAB】语音信号识别与处理:卷积滑动平均滤波算法去噪及谱相减算法呈现频谱

1 基本定义 卷积滑动平均滤波算法是一种基于卷积操作的滤波方法&#xff0c;它通过对信号进行卷积运算来计算移动平均值&#xff0c;以消除噪声。该算法的主要思想是将滤波窗口的加权系数定义为一个卷积核&#xff0c;对信号进行卷积运算来得到平滑后的信号。这样可以有效地去…

【论文精读】【Yolov1】You Only Look Once Unified, Real-Time Object Detection

0.论文摘要 我们提出了YOLO&#xff0c;一种新的目标检测方法。先前关于目标检测的工作重新利用分类器来执行检测。相反&#xff0c;我们将目标检测框架确定为空间分离的边界框和相关类别概率的回归问题。单个神经网络在一次评估中直接从完整图像预测边界框和类别概率。由于整…

mybatis单表curd笔记(尚硅谷

Mybatis 11111ibatis和mybatis不同 查询文档mybatis的日志输出id赋值输入&#xff08;向sql语句传入数据单个简单类型单个实体对象多个简单类型map类型 输出数据的指定单个简单类型单个实体类型输出map类型输出list输出类型主键回显&#xff08;自增长类型主键回显&#xff08;…

R语言中定义函数、调用函数及常用编程技巧

R语言中定义函数、调用函数及常用编程技巧 介绍定义函数调用函数常用编程循环结构apply 函数族apply()案例&#xff1a; lapply()案例&#xff1a; sapply()案例&#xff1a; vapply()案例&#xff1a; mapply()案例&#xff1a; 介绍 R语言是一种功能强大的统计分析编程语言&a…

构建阶段的软件供应链威胁

随着软件供应链生命周期从源代码发展到可执行组件&#xff0c;构建阶段是一个关键时刻。然而&#xff0c;这一变革阶段也容易受到一系列威胁的影响&#xff0c;这些威胁可能会危及软件的完整性和构建安全性。 这些威胁可以通过各种方法渗透构建过程&#xff0c;包括规避已建立…

LVS----DR模式

一、LVS-DR工作原理 1、LVS-DR数据包流向分析 客户端发送请求到Director Server (负载均衡器)&#xff0c;请求的数据报文&#xff08;源IP是CIP&#xff0c;目标IP是VIP&#xff09;到达内核空间。Director Server 和Real Server 在同一个网络中&#xff0c;数据通过二层数据…

比较 2 名无人机驾驶员:借助分析飞得更高

近年来&#xff0c;越来越多的政府和执法机构使用无人机从空中鸟瞰。为了高效执行任务&#xff0c;无人机必须能够快速机动到预定目标。快速机动使它们能够在复杂的环境中航行&#xff0c;并高效地完成任务。成为认证的无人机驾驶员的要求因国家/地区而异&#xff0c;但都要求您…

node_vue个人博客系统开发

Day01 一、导入express 1、创建node_serve服务文件夹 2、初始化项目 npm init -y3、导入express框架 npm i express4、创建一个app.js文件,为服务端的入口文件 // 导入express模块 const express = require(express); // 创建express服务 const app = express(); // 调用…

UVa11595 Crossing Streets EXTREME

题目链接 UVa11595 - Crossing Streets EXTREME 题意 平面上有 n&#xff08;n≤35&#xff09;条直线&#xff0c;各代表一条街道。街道相互交叉&#xff0c;形成一些路段&#xff08;对应于几何上的线段&#xff09;。你的任务是设计一条从A到B的路线&#xff0c;使得穿过路…

土地利用数据分类过程教学/土地利用分类/遥感解译/土地利用获取来源介绍/地理数据获取

本篇主要介绍如何对影像数据进行分类解译&#xff0c;及过程教学&#xff0c;示例数据下载链接&#xff1a;数据下载链接 一、背景介绍 土地是人类赖以生存与发展的重要资源和物质保障&#xff0c;在“人口&#xff0d;资源&#xff0d;环境&#xff0d;发展&#x…

excel中去除公式,仅保留值

1.单个单元格去除公式 双击单元格&#xff0c;按F9. 2.批量去除公式 选中列然后复制&#xff0c;选择性粘贴&#xff0c;选值粘贴

C++之类型转换

C语言中的类型转换 在C语言中, 如果赋值运算符左右两侧类型不同, 或者形参与实参类型不匹配, 或者返回值类型与 接收返回值类型不一致时, 就需要发生类型转化, C语言中总共有两种形式的类型转换: 隐式类型转换和显式类型转换 1. 隐式类型转化是关联度很强, 意义相近的类型之间…

事务 失效的八种情况

在某些业务场景下&#xff0c;如果一个请求中&#xff0c;需要同时写入多张表的数据。为了保证操作的原子性&#xff08;要么同时成功&#xff0c;要么同时失败&#xff09;&#xff0c;避免数据不一致的情况&#xff0c;我们一般都会用到 spring 事务。 确实&#xff0c;sprin…

css使用伪元素绘制带三角箭头的提示框

效果图 代码实现 使用伪元素进行绘制&#xff1a; <div class"my-tip"></div> .my-tip{width: 128px;height: 100px;background: #FFFFFF;box-shadow: 0px 1px 10px 0px rgba(0,0,0,0.05), 0px 4px 5px 0px rgba(0,0,0,0.08), 0px 2px 4px -1px rgba(0…

【开源】SpringBoot框架开发网上药店系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 药品类型模块2.3 药品档案模块2.4 药品订单模块2.5 药品收藏模块2.6 药品资讯模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 角色表3.2.2 药品表3.2.3 药品订单表3.2.4 药品收藏表3.2.5 药品留言表…

Python 快速获取PDF文件的页数

有时在处理或打印一个PDF文档之前&#xff0c;你可能需要先知道该文档包含多少页。虽然我们可以使用Adobe Acrobat这样的工具来查看页数&#xff0c;但对于程序员来说&#xff0c;编写脚本来完成这项工作会更加高效。本文就介绍一个使用Python快速获取PDF文件页数的办法。 安装…