ALBEF算法解读

ALBEF论文全名Align before Fuse: Vision and Language Representation Learning with Momentum Distillation,来自于Align before Fuse,作者团队为Salesforce Research。
论文地址:https://arxiv.org/pdf/2107.07651.pdf
论文代码:GitCode - 开发者的代码家园

1、灵感来源

a)图 (a) 是VSE架构,它们的文本端就是直接抽一个文本特征,但是它们的视觉端非常大,需要的计算量非常多,因为它通常是一个目标检测器。当得到了文本特征和视觉特征之后,它最后只能做一个很简单的模态之间的交互,从而去做多模态的任务。

b)图(b)是CLIP的结构,视觉端和文本端都用同等复杂度的encoder进行特征提取,再做一个简单的模态交互,结构优点是对检索任务极其有效,因为它可以提前把特征都抽好,接下来直接算Similarity矩阵乘法就可以,极其适合大规模的图像文本的检索,非常具有商业价值。缺点是只计算Cosine Similarity无法做多模态之间深度融合,难一些的任务性能差。

c)图(c)是Oscar或者ViLBERT、Uniter采用的架构,因为对于多模态的任务,最后的模态之间的交互非常重要,只有有了模态之间更深层的交互,VQA、VR、VE这些任务效果才会非常好,所以他们就把最初的简单的点乘的模态之间的交互,变成了一个Transformer的Encoder,或者变成别的更复杂的模型结构去做模态之间的交互,所以这些方法的性能都非常的好,但是随之而来的缺点也很明显:所有的这一系列的工作都用了预训练的目标检测器,再加上这么一个更大的模态融合的部分,模型不论是训练还是部署都非常的困难。

d)图 (d) 是ViLT的架构。当Vision Transformer出来之后,ViLT这篇论文就应运而生了,因为在Vision Transformer里,基于Patch的视觉特征与基于Bounding Box的视觉特征没有太大的区别,它也能做图片分类或者目标检测的任务,因此就可以将这么大的预训练好的目标检测器换成一层Patch Embedding就能去抽取视觉的特征,所以大大的降低了运算复杂度,尤其是在做推理的时候。但是如果文本特征只是简单Tokenize,视觉特征也只是简单的Patch Embedding是远远不够的,所以对于多模态任务,后面的模态融合非常关键,所以ViLT就直接借鉴 c类里的模态融合的方法,用一个很大的Transformer Encoder去做模态融合,从而达到了还不错的效果。因为移除了预训练的目标检测器,换成了可学习的Patch Embedding Layer。

ViLT的优点:模型极其简单。它虽然是一个多模态学习的框架,但跟NLP的框架没什么区别,就是先Tokenized,然后送到一个Transformer去学习,所以非常的简单易学。
ViLT的缺点:1)它的性能不够高,ViLT在很多任务上是比不过 © 类里的这些方法的,原因是对于现有的多模态任务,需要更多的视觉能力(可能是由于数据集的bias),因此视觉模型需要比文本模型要大,最后的效果才能好,但是在ViLT里,文本端用的Tokenizer很好,但是Visual Embedding是Random Initialized,所以它的效果就很差2)ViLT虽然推理时间很快,但它的训练时间非常慢,在非常标准的一个4 million的数据集set上,ViLT需要64张32G的GPU训练三天,它训练的复杂度和训练的时间丝毫不亚于 © 类的方法,所以它只是结构上简化了多模态学习,但训练难度并没有降低。

从上面可以发现,a和b的模态融合是比较弱的,比如b的代表就是clip,他是使用余弦相似度计算文本和图像的相似性,这对于VQA,VE之类的任务效果就比较差。c的代表是ViLBERT,除了多模态融合使用较大模型外,在图像端也使用了较大的模型。d的代表是ViLT,只在多模态融合时使用较大模型。

那么我们发现,b类模型的模态融合方面较弱,适合做检索但不适合做推理问答;cd类模型在模态融合方面较强,适合做推理问答但不适合做检索。

因此,在模型的选择上:

        1、因为有图像的输入和文本的输入,模型应有两个分支分别抽取图像文本特征。
        2、在多模态学习里,视觉特征重要性远远要大于这个文本特征,所以应该使用更大更强的视觉模型。
        3、多模态学习模态之间的融合也非常关键,因此需要模态融合的模型尽可能大,所以好的多模态学习网络结构应该像 ©,也就是文本编码器比图像编码器小,多模态融合的部分尽可能大。

2、ALBEF预训练

ALBEF有三个损失函数:图像文本对比损失(ITC)、图像文本匹配损失(ITM)和掩码语言建模损失(MLM)。

图像和文本对在经过图像编码器和文本编码器后会算一个ITC,促使学习特征对齐的能力。对齐后的文本和图像特征再经过多模态编码器会算一个ITM,促使学习特征交互的能力。仿照BERT,文本mask后再单独过一遍文本编码器和多模态编码器来算MLM,促使学习语言和特征交互的能力。

另外,还有两个训练技巧:

        1、使用ITC相似度矩阵,挖掘困难的负样本,来提升ITM的难度,进而督促模型学习。

        2、由于数据中有噪声(有的负文本对更能表述图像),仿照MoCo,使用动量模型作为teacher模型,在学习的过程中student模型以teacher的软标签作为label进行学习,能够抵抗数据质量不高和一图多义的问题。

图像端:给定任何一张图片,按照Vision Transformer的做法,把它打成patch,然后通过patch embedding layer送给一个Vision Transformer。这里是一个非常标准的12层的Vision Transformer的base模型,如果图片是224x224,那这里的sequence length就是196,然后加上额外的一个CLS token就是197,它的特征维度是768,所以这里绿黄色的特征就是197乘以768。但论文在预训练阶段用的图片是256x256,所以这里绿色的sequence length就会相应的再长一些。它的预训练参数用的DEiT,也就是Data Efficient Vision Transformer在ImageNet 1K数据集上训练出来的初始化参数。
文本端:文本端为保持计算量与clip类似,并且增强模态融合的部分,用前六层做文本编码,剩下的六层transformer encoder作为multi-model fusion的过程。文本模型用BERT模型做初始化,它中间的特征维度也是768,它也有一个CLS token代表了整个句子的文本信息。
momentum model:ALBEF为了做momentum distillation,而且为了给ITC loss提供更多negative,增加了momentum model,这个模型参数由左边训练的模型参数通过moving average得到的(和MoCo一模一样),通过把moving average的参数设的非常高(论文里是0.995)来保证momentum model不会那么快更新,产生的特征更加稳定,不仅可以做更稳定的negative sample,而且还可以去做momentum distillation。

ITC

ITC目的是在融合前学习更好的单模态表示,并且对齐图像和文本特征。它学习一个相似度函数,使得相关联的图像-文本对具有更高的相似度得分。受MoCo的启发,维护了两个队列来存储来自动量单模态编码器的最新M个图像-文本表示。

ITC loss:对比学习就是首先定义一个正样本对,然后定义很多负样本对,对比使正样本对之间的距离越来越近,正负样本对之间的距离越来越远。在ALBEF里首先将图像I通过vision transformer之后得到图像的全局特征,图中黄色CLS token作为全局特征,也就是一个768×1的向量,文本这边先做tokenize,将文本text变成tokenization的序列,再输入BERT的前六层,得到了一系列的特征,文本端的CLS token作为文本的全局特征,也是一个768×1的向量。接下来与MoCo相同,图像的特征向量先做一下downsample和normalization,将768×1变成了256×1的向量,文本特征向量也是768变成256×1。正样本这两个特征距离尽可能的近,它的负样本全都存在一个q里,含有65536个负样本(因为它由momentum model产生的,没有gradient,所以它并不占很多内存),正负样本之间的对比学习,使得这两个特征距离尽可能的远。这个过程就是align before fuse的align,也就是说在图像特征和这个文本特征输入Multi-model Fusion的encoder之前,就已经通过对比学习的ITC loss让这个图像特征和文本特征尽可能的拉近,在同一个embedding space里,具体使用了cross entropy loss。

ITM

ITM预测一个图文对是正对(匹配)还是负对(不匹配)。我们使用多模态编码器的[CLS] token的输出嵌入作为图像-文本对的联合表示,并附加一个全连接(FC)层,然后是softmax来预测两类概率。

本文提出一种策略,为ITM任务采样较难的负样本,而且计算开销为零。如果图像-文本具有相似的语义,但在细节上存在差异,这就是一个比较难的图像-文本对。我们使用公式1中的对比相似性来找到批量内的困难的负对。对于小批量中的每个图像,我们按照对比相似性分布从同一批次中采样一个负文本,其中与图像更相似的文本有更高的机会被采样。同样,我们还为每个文本采样一个难的负图像。

ITM loss:Image Text Matching,属于一个二分类任务,就是给定一个图片,给定一个文本,图像文本通过ALBEF的模型之后输出一个特征,在这个特征之后加一个分类头,也就是一个FC层,然后去判断I和T是不是一个对,这个loss虽然很合理,但是实际操作的时候发现这个loss太简单,所以这个分类任务,很快它的准确度就提升得很高无法进一步优化。因此ALBEF通过某种方式选择最难的负样本(最接近于正样本的那个负样本),具体来说ALBEF的batch size是512,所以ITM loss正样本对就是512个,对于mini batch里的每一张图像,把这张图片和同一个batch里所有的文本都算一遍cos similarity,然后它在这里选择除了它自己之外相似度最高的文本当做negative,这样ITM loss就非常有难度。

MLM

MLM利用图像和对应的文本来预测被遮盖的token。我们以15%的概率随机遮盖输入token,并用特殊标记[mask]替换它们。

MLM Loss:Mask Language Modeling,它把原来完整的句子text T变成一个T’,也就是有些单词被mask掉了,然后它把缺失的句子和图片一起通过ALBEF的模型,最后把之前的完整的句子给预测出来,它这里也借助了图像的信息去更好的恢复被mask掉的单词。

momentum distillation动量蒸馏

用于预训练的图像-文本对大部分是从web上收集的,它们往往有噪声。正相关对通常是弱相关的:文本可能包含与图像无关的单词,或者图像可能包含文本中没有描述的实体。对于ITC学习,图像的负面文本也可能与图像内容相匹配。对于MLM,可能存在与标注不同的其他单词,这些单词同样好地描述了图像(或更好)。然而,ITC和MLM的one-hot标签会惩罚所有负面预测,无论其正确性如何

为解决这个问题,本文建议从动量模型生成的伪目标中学习。动量模型是一个不断进化的老师,由单模态和多模态编码器的指数移动平均版本组成。在训练过程中,我们训练基础模型,使其预测与动量模型的预测相匹配。

具体的,作者认为one hot label(就是图片和文本就是一对,其他跟它都不是一对)对于ITC和MLM这两个loss来说不好,因为有的负样本也包含了很多的信息。所以作者采取了自训练方式,先构建一个momentum model然后用这个动量模型去生成pseudo targets伪目标(其实就是一个softmax score),这样它就不再是一个one hot label。
动量模型在已有模型之上做exponential moving average EMA,目的是在原始模型训练的时候,不仅希望模型预测与ground truth的one hot label去尽可能的接近,还希望模型预测与动量模型出来的pseudo targets尽可能的匹配,这样就能达到一个比较好的折中点。因为当one hot label正确时,可以学习到很多信息,但当one hot label是错误的,或者是noisy的时候,作者希望稳定的momentum model能够提供一些改进。
以ITC loss为例,它是基于这个one hot label的,所以这里再算一个pseudo target loss去弥补它的一些缺陷和不足。这个loss跟前面equation1里的ITC loss的不同,这里将ground truth换成这个q,就是pseudo targets,q不再是one hot label,而是softmax score,所以这里计算KL divergence而不是cross entropy。

最终ALBEF的训练loss有五个:两个ITC、两个MLM、一个ITM,其中:
1)ITC有两个loss:一个是原始的ITC,一个是基于pseudo target的ITC,所以分别加权(1-α)和α的loss weight,最终得到momentum版本的ITC loss。
2)MLM loss有两个loss:一个是原始的MLM,一个是基于pseudo target的MLM,用新生成的pseudo target去代替了原来的ground truth,分别加权(1-α)和α的loss weight,最终得到momentum版本的MLM loss。
3)ITM有一个loss:ITM没有动量版本,因为本身它就是基于ground truth,它就是一个二分类任务,而且在ITM里又做了hard negative,这跟momentum model有冲突。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/397670.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SICTF round#3 web

1.100&#xff05;_upload url可以进行文件包含&#xff0c;但是flag被过滤 看一下源码 <?phpif(isset($_FILES[upfile])){$uploaddir uploads/;$uploadfile $uploaddir . basename($_FILES[upfile][name]);$ext pathinfo($_FILES[upfile][name],PATHINFO_EXTENSION);$t…

大模型量化技术原理-LLM.int8()、GPTQ

近年来&#xff0c;随着Transformer、MOE架构的提出&#xff0c;使得深度学习模型轻松突破上万亿规模参数&#xff0c;从而导致模型变得越来越大&#xff0c;因此&#xff0c;我们需要一些大模型压缩技术来降低模型部署的成本&#xff0c;并提升模型的推理性能。 模型压缩主要分…

react开发者必备vscode插件【2024最新】

React开发者必备VSCode插件及使用教程 Visual Studio Code&#xff08;VSCode&#xff09;是当今最流行的代码编辑器之一&#xff0c;特别是在前端开发者中。对于使用React的开发者来说&#xff0c;VSCode不仅因其轻量和高度可定制而受到欢迎&#xff0c;还因为其强大的插件生…

Java项目,营销抽奖系统设计实现

作者&#xff1a;小傅哥 博客&#xff1a;https://bugstack.cn 项目&#xff1a;https://gaga.plus 沉淀、分享、成长&#xff0c;让自己和他人都能有所收获&#xff01;&#x1f604; 大家好&#xff0c;我是技术UP主&#xff0c;小傅哥。 经过这个假期的嘎嘎卷&#x1f9e8;…

8 大内部排序算法图文讲解

排序算法可以分为内部排序和外部排序&#xff0c;内部排序是数据记录在内存中进行排序&#xff0c;而外部排序是因排序的数据很大&#xff0c;一次不能容纳全部的排序记录&#xff0c;在排序过程中需要访问外存。常见的内部排序算法有&#xff1a;插入排序、希尔排序、选择排序…

软件测试面试题常见一百道【含答案】

1、问&#xff1a;你在测试中发现了一个bug&#xff0c;但是开发经理认为这不是一个bug&#xff0c;你应该怎样解决? 首先&#xff0c;将问题提交到缺陷管理库里面进行备案。 然后&#xff0c;要获取判断的依据和标准&#xff1a; 根据需求说明书、产品说明、设计文档等&am…

75.SpringMVC的拦截器和过滤器有什么区别?执行顺序?

75.SpringMVC的拦截器和过滤器有什么区别&#xff1f;执行顺序&#xff1f; 区别 拦截器不依赖与servlet容器&#xff0c;过滤器依赖与servlet容器。拦截器只能对action请求(DispatcherServlet 映射的请求)起作用&#xff0c;而过滤器则可以对几乎所有的请求起作用。拦截器可…

Redis基础和高级使用

文章目录 Redis概述Redis简介Redis特点Redis适合于做Redis不适合于做Redis安装 Redis命令Redis命令Redis的键 Redis数据类型Redis支持的数据类型字符串及相关命令字符串应用场景&#xff1a;列表及相关命令列表应用场景&#xff1a;集合及相关命令集合应用场景&#xff1a;有序…

环信IM Android端实现华为推送详细步骤

首先我们要参照华为的官网去完成 &#xff0c;以下两个配置都是华为文档为我们提供的 1.https://developer.huawei.com/consumer/cn/doc/HMSCore-Guides/android-config-agc-0000001050170137#section19884105518498 2.https://developer.huawei.com/consumer/cn/doc/HMSCore…

[OpenAI]继ChatGPT后发布的Sora模型解析与体验通道

前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家&#xff1a;https://www.captainbed.cn/z ChatGPT体验地址 文章目录 前言OpenAI体验通道Spacetime Latent Patches 潜变量时空碎片, 建构视觉语言系统…

HTTPS(超文本传输安全协议)被恶意请求该如何处理。

HTTPS&#xff08;超文本传输安全协议&#xff09;端口攻击通常是指SSL握手中的一些攻击方式&#xff0c;比如SSL握手协商过程中的暴力破解、中间人攻击和SSL剥离攻击等。 攻击原理 攻击者控制受害者发送大量请求&#xff0c;利用压缩算法的机制猜测请求中的关键信息&#xf…

【压缩感知基础】Nyquist采样定理

Nyquist定理&#xff0c;也被称作Nyquist采样定理&#xff0c;是由哈里奈奎斯特在1928年提出的&#xff0c;它是信号处理领域的一个重要基础定理。它描述了连续信号被离散化为数字信号时&#xff0c;采样的要求以避免失真。 数学表示 Nyquist定理的核心内容可以描述如下&…

java+vue_springboot企业设备安全信息系统14jbc

企业防爆安全信息系统采用B/S架构&#xff0c;数据库是MySQL。网站的搭建与开发采用了先进的java进行编写&#xff0c;使用了vue框架。该系统从三个对象&#xff1a;由管理员、人员和企业来对系统进行设计构建。主要功能包括&#xff1a;个人信息修改&#xff0c;对人员管理&am…

目录IO 2月19日学习笔记

1. lseek off_t lseek(int fd, off_t offset, int whence); 功能: 重新设定文件描述符的偏移量 参数: fd:文件描述符 offset:偏移量 whence: SEEK_SET 文件开头 SEE…

Expected class selector “.menuChildMall“ to be kebab-case报错原因

![在这里插入图片描述](https://img-blog.csdnimg.cn/dire ct/6b72bda760a2497a90558d48bd0a4de3.png) 使用stylelint格式化css文件时候报上述错误&#xff1a; 原因&#xff1a; css类名未使用-分隔符 将类名修改为&#xff1a; .menu-child-mall形式即可

C++11---(2)

目录 一、新增容器 1.1、array 1.2、forward_list 1.3、unordered系列 二、右值引用和移动语义 2.1、什么是左值&#xff0c;什么是左值引用 2.2、什么是右值&#xff0c;什么是右值引用 2.3、左值引用和右值引用比较 2.4、右值引用使用场景和意义 2.5、右值引用引用…

【教程】详解相机模型与坐标转换

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 由于复制过来&#xff0c;如果有格式问题&#xff0c;推荐大家直接去我原网站上查看&#xff1a; 相机模型与坐标转换 - 生活大爆炸 目录 经纬度坐标系 转 地球直角坐标系大地直角坐标系 转 经纬度坐标系地理坐标…

MLP-Mixer: AN all MLP Architecture for Vision

发表于NeurIPS 2021, 由Google Research, Brain Team发表。 Mixer Architecture Introduction 当前的深度视觉结构包含融合特征(mix features)的层:(i)在一个给定的空间位置融合。(ii)在不同的空间位置&#xff0c;或者一次融合所有。在CNN中&#xff0c;(ii) 是由N x N(N &g…

服务端实时推送技术之SSE(Server-Send Events)

文章目录 前言一、解决方案&#xff1a;1、传统实时处理方案&#xff1a;2、HTML5 标准引入的实时处理方案&#xff1a;3、第三方推送&#xff1a; 二、SSE1.引入库1、客户端&#xff1a; 2.服务端&#xff1a;三、业务实践&#xff1a;能否做到精准投递&#xff1f; 总结 前言…

解决Ubuntu中vscode右键没有create catkin package

右键发现没有这个create catkin package 解决方案&#xff1a; 查了一会发现安装个拓展就可以了 效果&#xff1a;