MoCo v3(ICCV 2021)

paper:An Empirical Study of Training Self-Supervised Vision Transformers

official implementation:https://github.com/facebookresearch/moco-v3

出发点

本文并没有提出一种新的方法,而是对计算机视觉领域最近进展中的一个重要且基础的情况进行研究:即视觉Transformers(ViT)的自监督学习。尽管标准卷积网络的训练方法已经非常成熟和稳健,但ViT的训练方法尚未建立,尤其是在自监督学习场景中,训练变得更加具有挑战性。

解决了什么问题

论文解决了自监督学习中ViT训练的不稳定性问题。作者观察到,尽管在某些情况下训练结果看似不错,但实际上这种不稳定性会导致准确度下降,并且这种下降在没有更稳定对照组的情况下很难被察觉。论文通过实验揭示了这种不稳定性,并提出了改进稳定性的方法。

创新点

  1. 对ViT在自监督学习框架下的训练基础组件进行了深入研究,包括批量大小、学习率和优化器。
  2. 提出了一种简单的改进方法,即在ViT中冻结patch projection层,使用固定的随机patch projection来提高训练稳定性。
  3. 在多个自监督框架中对ViT进行了基准测试和消融研究,提供了不同架构设计的ViT结果,并探讨了其影响。

MoCo v3

本文提出了MoCo v3,它是对MoCo v1/2的增量改进版本,伪代码如Alg 1所示

具体来说,我们对同一张图片进行两次随机的数据增强,并取出两个crop。它们由两个编码器 \(f_k\) 和 \(f_k\) 进行编码,得到输出向量 \(q\) 和 \(k\),直觉上,\(q\) 的行为就像是一个“query”,而学习的目标是检索相应的“key”。这个过程表述为最小化一个对比损失函数,本文采用InfoNCE

 

这里 \(k^+\) 是 \(f_k\) 对 \(q\) 同一张图片的输出,作为 \(q\) 的正样本,集和 \(\{k^-\}\) 由 \(f_k\) 对其它图片的输出组成,作为 \(q\) 的负样本,\(\tau\) 是 \(\ell_2\) 归一化的 \(q,k\) 的温度超参。 

MoCo v3采用同一batch中那些自然共存的keys,作者放弃了memory queue,因为作者发现如果batch size足够大(例如4096)它的增益会变少。通过这种简化,式(1)中的对比损失可以通过几行代码实现:见Alg 1中的ctr(q, k)。本文采用了对称损失ctr(q1, k2) + ctr(q2, k1)

编码器 \(f_q\) 由一个backbone(例如ResNet、ViT)、一个projection head以及一个额外的prediction head组成。编码器 \(f_k\) 由backbone和projection head组成,但没有prediction head。\(f_k\) 根据 \(f_q\) 的移动平均值进行更新,不包括prediction head。

作者用ResNet-50来测试了MoCo v3的精度,下表对比了ImageNet上的linear probing精度

改进主要是因为额外的预测头和大批量(4096)训练。

Stability of Self-Supervised ViT Training

原则上,在contrastive/Siamese自监督框架中,可以直接用ViT主干替换ResNet主干。但在实践中,作者遇到的一个主要挑战是训练的不稳定性。

作者观察到,不稳定问题不能简单地用精度来反映。实际上,如实验结果展示的那样,训练是“相当好的”并且结果也不错,即使它可能是不稳定的。为了揭示这种不稳定性,作者在训练期间监测了KNN曲线。作者首先研究了基本要素如何影响稳定性,这些曲线表明,训练可以是“部分成功的”或者换句话说是“部分失败的”。然后又探索了一个可以提高稳定性的简单技巧,从而在各种情况下都提高了模型精度。

Empirical Observations on Basic Factors

Batch size

ViT模型本身的计算量就很大,因此大批量训练是大型ViT模型的理想解决方案。在最近的一些自监督方法中,大的batch也有利于模型的精度。图1展示了不同batch size下的训练曲线。

当batch size为1k和2k时曲线相当平滑,linear probing准确率分别为71.5%和72.6%。在这个范围内,由于有更多的负样本,更大的batch size提高了精度。4k的曲线变得明显的不稳定,它的linear probing精度为72.2%,尽管和2k的精度相比只有轻微的下降,但它的精度受到了不稳定性的影响。6k的曲线就更差了,作者假设是训练部分重启了,跳出了局部最优,然后寻找一个新的轨迹。因此训练不会发散,但精度取决于重启的效果。当这种部分的失败发生了,仍然能够得到一个还不错的结果(69.7%),但这对于研究是有害的:和那些容易发现的灾难性的失败不同,这种小的退化可能会被忽略。

Learning rate

作者研究了学习率的影响,如图2所示。学习率越小,训练越稳定,但可能会导致欠拟合。如图2所示,学习率lr=0.5e-4比lr=1.0e-4的精度低了1.8%(70.4% vs. 72.2%)。当学习率变大时,训练会变得不稳定,当lr=1.5e-4时曲线有更多的抖动,精度也降低了。

Optimizer

默认情况下,训练ViT采用AdamW优化器,但最近的自监督方法都采用LARS优化器进行large-batch的训练。图3作者研究了LAMB优化器,它是LARS的AdamW版本。

如图3所示,当学习率合适时(lr=5e-4),LAMB的精度比AdamW稍好(72.5%)。但当lr大于最优值时精度迅速下降,有趣的是,训练曲线仍然是平滑的,但在中间部分开始逐渐下降。由于对lr较为敏感当网络结构不同时,LAMB需要额外的lr search,因此作者在本文中还是采用AdamW。

A Trick for Improving Stability

如图4所示,作者发现训练过程中梯度的突然变化会导致训练曲线的下降。通过比较所有层的梯度,作者发现梯度的峰值发生到第一层(patch projection),并通过若干迭代后传递到网络后面的层。因此作者探索了冻结patch projection层,即随机初始化后,就通过stop-gradient不再更新权重了。

 

图5展示了patch projection的参数可学习和随机初始化并冻结两种情况的结果,可以看到random patch projection稳定了训练,训练曲线更加平滑。这种稳定性对精度也有帮助,当lr=1.5e-4时,精度提高了1.7%达到了73.4%。这一实验证明了训练的不稳定性是影响精度的一个主要问题。

此外,作者还发现其它自监督方法可能也是不稳定的。图6展示了SimCLR和BYOL使用ViT的训练曲线,随机patch projection提高了两者的稳定性,精度也分别提高了0.8%和1.3%。SwAV也存在不稳定性,但当它不稳定时损失会发散(NaN),当使用一个较大的学习率并用random patch projection可以帮助SwAV收敛,并在使用最大的稳定学习率时将精度从65.8%提升到66.4%。总之,这一技巧在所有自监督框架中都是有效的。

实验结果

表4展示了不同四种不同的自监督框架使用ViT的性能对比,为了公平比较,每个框架的wd和lr都是单独搜索出来的,可以看到MoCo v3的效果是最好的。

下面是一些消融实验的结果

位置编码sin-cos的效果最好

当不用class token精度从76.5%下降到了69.7%,ViT在最后一个block的后面额外有一个LN层,如果把这层LN也去掉,精度又涨到了76.3%。

 

ViT中没有BN,只在MLP head中有BN,去掉BN时必须将batch size设为2048否则模型不收敛,此时精度下降了2.1%,这表明BN对于对比学习 不是必要的,但可以提升精度。

 

去掉prediction head会导致性能轻微的下降,表明预测头不是必须的。

 

动量m=0.99时精度最高,m=0类似SimCLR的做法(加上prediction head并停止keys上的梯度传播),使用动量编码器精度提升了2.2%。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/710918.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 日志(一)

本篇主要介绍MySQL日志的相关内容。 目录 一、日志简介 常用日志 一般查询日志和慢查询日志的输出形式 日志表 二、一般查询日志 三、慢查询日志 四、错误日志 一、日志简介 常用日志 在MySQL中常用的日志主要有如下几种: 这些日志通常情况下都是关闭的&a…

我用AI绘画Stable Diffusion 一个月后,竟然能做出惊艳所有人的效果!

大家好,我是设计师阿威 如今要拍摄一组写真,需要服装、道具、灯光、场地、布景、拍摄、后期等过程。整个过程需要统一才能形成好的写真效果。现在有了AI绘图技术,我们可以实现通过AI绘图,只用计算机计算就得到一组接近真实的写真照…

Python 中国象棋游戏【含Python源码 MX_011期】

简介: 中国象棋是一种古老而深受喜爱的策略棋类游戏,也被称为中国的国粹之一。它在中国有着悠久的历史,起源可以追溯到几个世纪以前。Python 中国象棋游戏是一个用Python编程语言编写的软件程序,旨在模拟和提供中国象棋的游戏体验…

Github 2024-06-10开源项目周报 Top15

根据Github Trendings的统计,本周(2024-06-10统计)共有15个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目8Jupyter Notebook项目2Go项目2C++项目1Shell项目1Lua项目1JavaScript项目1MDX项目1C项目1HTML项目1Python - 100天从新手到大师 创建…

Maven 项目的创建(导入依赖、仓库、maven的配置、配置国内源、以及可能遇到的问题)

一、创建Maven项目 使用的编译软件:idea 软件版本: 社区版 2021.1 - 2022.4(为什么选择这个版本,因为只有这个版本里有一些插件是可以安装的) 专业版不限制(专业版功能是最全的,但是收费&am…

【会议征稿,ACM出版】2024年云计算与大数据国际学术会议(ICCBD 2024,7月26-28)

2024年云计算与大数据国际学术会议(ICCBD 2024)将于2024年7月26-28日在中国大理召开。ICCBD 2024将围绕“云计算与大数据”的最新研究领域, 旨在为从事研究的专家、学者、工程师和技术人员提供一个国际平台,分享科研成果和尖端技术,了解学术发展趋势&…

能耗分析与远程抄表是什么?

一、引言 在21世纪的数字化时代,能耗分析和远程抄表已成为现代能源管理的重要组成部分。这两项技术不仅提高了能源效率,还为企业和个人提供了更精细的能源使用数据,从而实现更科学的节能减排。 二、能耗分析的深度洞察 能耗分析是通过收集…

Python保姆级教程 数据类型—新手小白入门必看

python学习资料,下方已打包好 一、基本数据类型与变量(上) 2.1 注释 优点: 代码说明 没注释的代码 有注释的代码 不让解释器执行注释的那句话 2.2 单行注释 单行注释快捷键:ctrl ? 2.3多行注释 …

【Java】内部类、枚举、泛型

目录 1.内部类1.1概述1.2分类1.3匿名内部类(重点) 2.枚举2.1一般枚举2.2抽象枚举2.3应用1:用枚举写单例2.4应用2:标识常量 3.泛型3.1泛型认识3.2泛型原理3.3泛型的定义泛型类泛型接口泛型方法 3.4泛型的注意事项 1.内部类 1.1概述 内部类:指…

大模型:原理、进展及其影响

大模型:原理、进展及其影响---中国人民大学人工智能学院 一、大模型的背景和原理 二、大模型的飞速发展及趋势 三、大模型的深刻影响

1_常见指令【Linux中常见30个指令的学习和使用】【万字长文】

常见指令以及权限理解 开始学习linux前的注意事项 在学习linux之前,我们要知道linux是一个操作系统。 那操作系统是什么呢?(这里只做大概了解) 操作系统就是一个管理软硬件的软件。 它对上提供良好(稳定、高效、安…

服务器数据恢复—热备盘未完全启用导致raid5阵列崩溃的数据恢复案例

服务器存储故障: 一台EMC某型号存储由于存储中raid5阵列出现故障导致服务器崩溃,由于数据涉密,需要工程师到现场恢复数据。 服务器数据恢复工程师到现场后对数据进行检测,经过检测发现服务器崩溃是由于raid中某些硬盘掉线所导致。…

力扣hot100:75. 颜色分类(双指针)

75.颜色分类 本题是经典的「荷兰国旗问题」,由计算机科学家 Edsger W. Dijkstra 首先提出。 75. 颜色分类 1、遍历两遍 遍历两遍,第一遍放置0的位置,第二遍放置1的位置,我们只需要维护一个当前放置位置即可。 class Solution…

图片转Excel表格:提升数据处理效率的利器

在日常工作和生活中,我们经常遇到各种数据和信息以图片的形式存在。有时,这些数据图片中包含了重要的表格信息,例如财务报告、统计数据或调研结果。为了对这些数据进行进一步的分析和处理,我们需要将其转换为可编辑的电子表格格式…

深入理解计算机系统 家庭作业6.34

第一步先求(S,E,B,m) 题目说共C32个字节,块大小B为16个字节,那就是分为两组:0,1.然后每组存4个int 每个4字节 CB*E*S .B16 ,直接映射的E就是1,所以S2 m为啥等于7? 通过写出两个数组所有的地址可以得出m7. 得出高速缓存的参数:(S,E,B,m)(2,1,16,7),注意图6-26每个参数的定义…

如何实现 Python 源码压缩加密常用解决方案详细教程(更新中)

Python是一种高级的、解释型的、面向对象的编程语言,Python 码简洁易读,并且Python语言跨平台,拥有丰富的标准库和第三方库,深受开发人员的喜爱。 Python 程序扩展名 .py:这是 Python 程序的标准文件扩展名。当你创建…

FPGA - Verilog题目: 非整数倍数据位宽转换24to128

题目描述: 实现数据位宽转换电路,实现24bit数据输入转换为128bit数据输出。其中,先到的数据应置于输出的高bit位。 电路的接口如下图所示。valid_in用来指示数据输入data_in的有效性,valid_out用来指示数据输出data_out的有效性…

【2024算力大会分会 | SPIE独立出版 | 往届均已完成EI检索】2024云计算、性能计算与深度学习国际学术会议(CCPCDL 2024)

【2024算力大会分会 | SPIE出版】 2024云计算、性能计算与深度学习国际学术会议(CCPCDL 2024) 2024 International conference on Cloud Computing, Performance Computing and Deep Learning *CCPCDL往届均已完成EI检索,最快会后4个半月完成! 一、…

Spring Aop及事务管理

5 Spring AOP AOP概述 AOP:全称是 Aspect Oriented Programming 即:面向切面编程。简单的说它就是把我们程序重复的代码抽取出来,在需要执行的时候,使用动态代理的技术,在不修改源码的基础上,对我们的已有…

Linux--MQTT(一)简介

一、简介 MQTT ( Message Queuing Telemetry Transport,消息队列遥测传输), 是一种基于客户端服务端架构的发布/订阅模式的消息传输协议。 与 HTTP 协议一样, MQTT 协议也是应用层协议,工作在 TCP/IP 四…