[论文粗读]A Simple Framework for Contrastive Learning of Visual Representations

引言

今天带来一篇经典论文A Simple Framework for Contrastive Learning of Visual Representations的笔记。

本篇工作提出了SimCLR,一种用于视觉表征对比学习的简单框架。提出(1)数据增强组合在定义有效预测任务中起到至关重要的作用;(2)在表示和对比损失之间引入可学习的非线性变换,可以显著提高学习表示的质量;(3)与监督学习相比,对比学习受益于更大的批大小和更多的训练步数。

为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。

虽然本篇工作是基于视觉表征的但这里的思考同样可以应用于NLP,可以简单地把相似度替换为句子中的单词来理解。

1. 总体介绍

在无监督的情况下学习有效的视觉表征是一个长期存在的问题,大多数方法属于两类之一: 生成式或判别式。

生成式方法学习生成或以其他方式对输入空间中的像素进行建模,然而像素级生成计算量巨大,对于表征学习可能并非必要。

判别式方法使用类似于监督学习中的目标函数来学习表征,但训练网络时其中输入和标签都来自未标记的数据集。许多此类方法依赖于启发式方式来设计预训练任务,这可能会限制学习到的表征的泛化能力。

基于潜在空间对比学习的判别方法近期显示出巨大的潜力,取得了先进的结果。

image-20250101154503896

图 1. 在不同自监督方法学习的表示上训练的线性分类器的 ImageNet Top-1 准确率(在 ImageNet 上预训练)。灰色十字表示监督 ResNet-50。 SimCLR 以粗体显示。

为了理解是什么使对比表示有效,我们系统地研究了我们框架的主要组成部分,表明:

  1. 多种数据增强操作的组合对于定义产生有效表示的对比预测任务至关重要。此外,无监督对比学习比监督对比学习受益于更强大的数据增强。
  2. 在表示和对比损失之间引入可学习的非线性变换,极大地提高了学习表示的质量。
  3. 利用对比交叉熵损失进行表示学习,受益于归一化的嵌入和适当调整的温度参数。
  4. 与监督学习相比,对比学习受益于更大的批大小和更长的训练时间,也受益于更深更宽的网络。

2. 方法

2.1 对比学习框架

image-20250101181653099

图 2. 用于视觉表征对比学习的简单框架。从同一个增强族中采样两个独立的数据增强算子( t ∼ τ 和 t ′ ∼ τ t ∼ \tau 和 t^\prime ∼ \tau tτtτ),并将其应用于每个数据样本以获得两个相关的视图。一个基础编码器网络 f ( ⋅ ) f(\cdot) f() 和一个投影头 g ( ⋅ ) g(\cdot) g()被训练以使用对比损失最大化一致性。训练完成后,我们丢弃投影头 g ( ⋅ ) g(\cdot) g(),并使用编码器 f ( ⋅ ) f(\cdot) f() 和表征 h \pmb h h用于下游任务。

SimCLR通过在潜在空间中使用对比损失,最大化相同数据样本的不同增强视图之间的一致性来表示学习。如图2所示,该框架包含以下四个主要组成部分。

  • 一个随机数据增强模块,它随机转换给定的数据样本,从而产生同一个样本的两个相关视图,分别表示为 x ~ i \tilde{ \pmb x}_i x~i x ~ j \tilde{ \pmb x}_j x~j,我们将其设为一对正样本。在本工作中,依次应用三种简单的增强方法:随机裁剪后恢复到原始大小,随机颜色失真和随机高斯模糊。
  • 一个神经网络基础编码器 f ( ⋅ ) f(\cdot) f(),用于从增强数据样本中提取表示向量。我们的框架允许在没有任何约束的情况下选择各种网络架构,我们简单地选择常用的ResNet来获得 h i = f ( x ~ i ) = ResNet ( x ~ i ) \pmb h_i = f(\tilde {\pmb x}_i)=\text{ResNet}(\tilde {\pmb x}_i) hi=f(x~i)=ResNet(x~i),其中 h i ∈ R d \pmb h_i \in \R ^d hiRd是平均池化后的输出。
  • 一个小型神经网络投影头 g ( ⋅ ) g(\cdot) g(),它将表示映射到应用对比损失的空间。我们使用一个具有一个隐藏层的MLP来获得 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) \pmb z_i = g(\pmb h_i) = W^{(2)} \sigma(W^{(1)} \pmb h_i) zi=g(hi)=W(2)σ(W(1)hi)其中 σ \sigma σ是一个ReLU非线性,我们发现将对比损失定义在 z i \pmb z_i zi上而不是 h i \pmb h_i hi上是有益的。
  • 针对对比预测任务定义的对比损失函数。给定一个包含正样本对 x ~ i \tilde {\pmb x}_i x~i x ~ j \tilde {\pmb x}_j x~j的集合 { x ~ k } \{\tilde {\pmb x}_k\} {x~k},对比预测任务旨在为给定的 x ~ i \tilde {\pmb x}_i x~i { x ~ k } k = i \{\tilde {\pmb x}_k\}_{k=i} {x~k}k=i 中识别 x ~ j \tilde {\pmb x}_j x~j

我们随机抽取一个包含 N N N个样本的小批量数据,并在从该小批量数据中派生的增强样本对上定义对比预测任务,从而产生 2 N 2N 2N个数据点。没有明确地采样负样本,相反,给定一个正样本对,将小批量数据中的其他 2 ( N − 1 ) 2(N-1) 2(N1)个增强样本视为负样本。令 sim ( u , v ) \text{sim}(\pmb u,\pmb v) sim(u,v)表示它们之间的余弦相似度,一对正样本 ( i , j ) (i,j) (i,j)的损失函数定义为:
l i , j = − exp ⁡ ( sim ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k = i ] exp ⁡ ( sim ( z i , z k ) / τ ) (1) \mathscr l_{i,j} = - \frac{\exp(\text{sim}(\pmb z_i, \pmb z_j)/\tau)}{\sum_{k=1}^{2N} \Bbb 1_{[k=i]} \exp(\text{sim}(\pmb z_i, \pmb z_k)/\tau)} \tag 1 li,j=k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)(1)
其中 1 [ k = i ] ∈ { 0 , 1 } \Bbb 1_{[k=i]} \in \{0,1\} 1[k=i]{0,1}是一个指示函数,当且仅当 k = i k=i k=i时取值为 1 1 1 τ \tau τ表示一个温度参数。

最终的损失是在小批量中所有正样本对 ( i , j ) (i,j) (i,j) ( j , i ) (j,i) (j,i)上计算的。下面的算法总结了所提出的方法:

image-20250101210730920

2.2 大批量训练

为了简化,没有使用记忆库训练模型。而是改变了训练批大小 N N N,从256到8192。批大小为8192为我们提供了来自两个增强视图的每个正样本对16382个负样本。

3. 对比表征学习的数据增强

数据增强定义了预测任务 虽然数据增强已经广泛应用于监督和无监督表示学习,但它尚未被视为定义对比预测任务的系统方法。

3.1 数据增强操作的组合对于学习良好的表示至关重要

为了系统地研究数据增强的影响,这里考虑了几种常见的增强方法。一种设计数据的空间几何变换,比如裁剪和调整大小、旋转、剪切。另一种设计外观变换,例如颜色失真(颜色下降、高度、对比度、饱和度、色调)、高斯模糊和索贝尔滤波。图4展示了本工作的增强方法。

image-20250101211617733

图 4. 研究的数据增强算子的示意图。每个增强都可以通过一些内部参数(例如旋转角度、噪声水平)随机地变换数据。需要注意的是,我们只在消融实验中测试了这些算子,用于训练我们模型的增强策略只包括随机裁剪(带翻转和缩放)、颜色失真和高斯模糊。

3.2 对比学习需要比监督学习更强的数据增强

image-20250101211730328

表 1. 在不同的颜色失真强度和其他数据变换下,使用线性评估的无监督 ResNet-50 和有监督 ResNet-50 的 Top-1 准确率。Stength 1 (+ Blur) 是我们的默认数据增强策略。

为了进一步证明颜色增强的重要性,我们调整了颜色增强的强度,如表1所示,更强的颜色增强显著提高了学习到的无监督模型的线性评估。

4. 编码器和Head架构

4.1 无监督对比学习从更大的模型中获益

image-20250101212043905

图 7. 不同深度和宽度模型的线性评估。蓝色圆点表示我们训练 100 个 epochs 的模型,红色星号表示我们训练 1000 个 epochs 的模型,绿色十字表示监督 ResNets 训练 90 个 epochs。

图7显示,增加深度和宽度都能提高性能。随着模型尺寸的增加,监督模型和在无监督模型上训练的线性分类器之间的差距正在缩小,表明无监督学习比其监督对照物从更大模型中获益更多。

4.2 非线性投影头提高了之前层的表示质量

image-20250101212511115

图 8. 使用不同投影头 g ( ⋅ ) g(\cdot) g()和不同维度 z = g ( h ) \pmb z=g(\pmb h) z=g(h)对表示进行线性评估。这里表示 h \pmb h h(投影前)是2048维的。

图8显示了使用三种不同头架构的线性评估结果:(1)恒等映射; (2)线性投影; (3)具有一个额外隐藏层(和ReLU激活)的默认非线性投影。

我们观测到非线性投影比线性投影效果更好,并且比不使用投影效果好得多。

5. 损失函数和批量大小

5.1 带可调节温度的归一化交叉熵损失优于其他方法

image-20250101212933556

表 2. 负损失函数及其梯度。所有输入向量,即 u , v + , v − \pmb u, \pmb v^+, \pmb v^- u,v+,v,都经过 ℓ2 归一化。NT-Xent 是归一化温度缩放交叉熵(Normalized Temperature-scaled Cross Entropy)的缩写。不同的损失函数对正负样本施加不同的权重。

我们将本文提出的NT-Xent损失与其他常用的对比损失函数进行比较。表2显示了目标函数以及损失函数输入的梯度。观察梯度,我们发现 1) 余弦相似度以及温度有效地对不同的样本进行加权,适当的温度可以帮助模型从难样本中学习; 2) 与交叉熵不同,其他目标函数不会根据其相对难度对负样本进行加权。因此必须对这些损失函数应用半硬负样本挖掘:与其对所有损失项计算梯度,不如使用半硬负样本项计算梯度。

5.2 对比学习从更大的批大小和更长的训练中获益

image-20250101213420685

图 9. 使用不同批次大小和 epoch 训练的线性评估模型 (ResNet-50)。每个条形代表从头开始的一次运行。

图9展示了在不同训练轮数下,批大小对模型的影响。当训练轮数较少(100轮)时,较大的批大小比较小的批大小具有显著优势。随着训练步数/轮次的增加,不同批大小之间的差距会减小或消失,前提是批次是随机重采用的。

对比学习中更大的批次提供了更多的负样本,促进了收敛。更长的训练时间也提供了更多的负样本,从而改善了结果。

6. 结论

我们提出了一种简单的对比视觉表征学习框架及其实例化,与标准监督学习不同之处在于数据增强方式的选择,网络末端非线性头的使用以及损失函数。效果显著优于以前的方法。

总结

⭐ 作者提出了一种对比学习框架,虽然论文是基于视觉探讨的,但后续也影响了很多NLP方面的工作。在更深更宽的模型和更大的批次、更长的训练时间基础上,首先通过多种数据增强产生有效的正样本用于对比学习。其次在表示和对比损失之间引入可学习的非线性头。最后利用归一化温度缩放交叉熵损失进行对比学习。在训练结束后,该引入的非线性头会被丢弃。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/946363.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

(leetcode算法题)188. 买卖股票的最佳时机 IV

题目中要求最多可以完成k次交易,很多时候不要把问题搞复杂了, 按照题目要求,研究对象是最后一天结束后最多进行了 k 次交易获得的最大利润 那么就可以把问题拆分成 第 1 天结束后完成 0 次交易获得的最大利润,第 1 天结束后完成…

使用 Docker 搭建 Hadoop 集群

1.1. 启用 WSL 与虚拟机平台 1.1.1. 启用功能 启用 WSL并使用 Moba 连接-CSDN博客 1.2 安装 Docker Desktop 最新版本链接:Docker Desktop: The #1 Containerization Tool for Developers | Docker 指定版本链接:Docker Desktop release notes | Do…

win32汇编环境,对话框程序模版,含文本框与菜单简单功能

;运行效果 ;win32汇编环境,对话框程序模版,含文本框与菜单简单功能 ;直接抄进RadAsm可编译运行。 ;下面为asm文件 ;>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>&g…

【赵渝强老师】MongoDB文档级别的并发控制

MongoDB在执行写操作时,WiredTiger存储引擎会在文档级别进行并发控制。换句话说在同一时间点上,多个写操作能够修改同一个集合中的不同文档;而当多个写操作修改同一个文档时,必须以序列化方式执行。这意味着如果当前文档正在被修改…

Java开发 PDF文件生成方案

业务需求背景 业务端需要能够将考试答卷内容按指定格式呈现并导出为pdf格式进行存档,作为紧急需求插入。导出内容存在样式复杂性,包括特定的字体(中文)、字号、颜色,页面得有页眉、页码,数据需要进行表格聚…

C++文件流 例题

问题: 设计一个留言类,实现以下的功能: 1) 程序第一次运行时,建立一个 message.txt 文本文件,并把用 户输入的信息存入该文件; 2) 以后每次运行时,都先读取该文件的内容并显示给用户&#xf…

Xilinx DCI技术

Xilinx DCI技术 DCI技术概述Xilinx DCI技术实际使用某些Bank特殊DCI要求 DCI级联技术DCI端接方式阻抗控制驱动器(源端接)半阻抗控制阻抗驱动器(源端接)分体式DCI(戴维宁等效端接到VCCO/2)DCI和三态DCI&…

「Mac畅玩鸿蒙与硬件51」UI互动应用篇28 - 模拟记账应用

本篇教程将介绍如何创建一个模拟记账应用,通过账单输入、动态列表展示和实时统计功能,学习接口定义和组件间的数据交互。 关键词 UI互动应用接口定义动态列表实时统计数据交互 一、功能说明 模拟记账应用包含以下功能: 账单输入&#xff1…

阴阳师の新手如何速刷5个SP/SSR?!(急速育成)

目标:攒5个SP/SSR式神,参与急速育成,省四个黑蛋(想要快速升级技能而且经常上场的式神在攒够5个式神前先不升级)【理论上组成:10蓝40蓝预约召唤福利20修行or抽卡】 关键点:蓝票,新手…

Linux应用软件编程-多任务处理(进程,线程)-通信(管道,信号,内存共享)

多任务处理:让系统具备同时处理多个事件的能力。让系统具备并发性能。方法:进程和线程。这里先讲进程。 进程(process):正在执行的程序,执行过程中需要消耗内存和CPU。 进程的创建:操作系统在…

使用 TensorFlow 打造企业智能数据分析平台

文章目录 摘要引言平台架构设计核心架构技术栈选型 数据采集与预处理代码详解 数据分析与预测代码详解 数据可视化ECharts 配置 总结未来展望参考资料 摘要 在大数据时代,企业决策正越来越依赖数据分析。然而,面对海量数据,传统分析工具常因…

初始JavaEE篇 —— Maven相关配置

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程程(ಥ_ಥ)-CSDN博客 所属专栏:JavaEE 目录 介绍 创建第一个Maven项目 Maven的核心功能 项目构建 依赖管理 添加依赖 依赖排除 依赖调解 Maven仓库 配置本地仓…

Linux套接字通信学习

Linux套接字通信 代码源码:https://github.com/say-Hai/TcpSocketLearn/tree/CThreadSocket 在网络通信的时候, 程序猿需要负责的应用层数据的处理(最上层),而底层的数据封装与解封装(如TCP/IP协议栈的功能)通常由操作系统、网络协…

职场常用Excel基础01-数据验证

大家好,excel在职场中使用非常频繁,今天和大家一起分享一下excel中数据验证相关的内容~ 在Excel中,数据验证(Data Validation)是一项非常有用的功能,它可以帮助用户限制输入到单元格中的数据类型和范围&am…

建造者设计模式学习

1.介绍 建造者模式是一种创建型设计模式,它将一个复杂对象的构建过程与它的表示分离,使得相同的构建过程可以创建不同的表示。通过分步骤地构建对象,建造者模式提供了更细粒度的控制和灵活性,特别适合需要灵活创建复杂对象的场景…

ROS2+OpenCV综合应用--10. AprilTag标签码追踪

1. 简介 apriltag标签码追踪是在apriltag标签码识别的基础上,增加了小车摄像头云台运动的功能,摄像头会保持标签码在视觉中间而运动,根据这一特性,从而实现标签码追踪功能。 2. 启动 2.1 程序启动前的准备 本次apriltag标签码使…

mysql乱码、mysql数据中文问号

网上排出此错误方法的很多,但是 都不简洁,找不到根本原因 主要排查两点: 1.代码中jdbc链接的编码规则 urljdbc:mysql://localhost:3306/title?useUnicodetrue&characterEncodingutf8 将characterEncoding设置为utf8 2.设置mysq…

Presto-简单了解-230403

presto是什么了解一下: 秒级查询引擎(不做存储),GB-PB级不依赖于yarn,有自己的资源管理和执行计划支持多种数据源:hive、redis、kafka presto架构 presto优缺点 presto优点 内存到内存的传输&#xff0…

openGauss连接是报org.opengauss.util.PSQLException: 尝试连线已失败

安装好高斯数据库后然后用java连接时报如下错误: 解决方法: 在openGauss数据库的安装路径下/opt/opengauss/data/single_node(这个路径根据自己实际情况变化)有个pg_hba.conf文件,修改里面host内容如下,我这里设置的是所有ip都能…

mybatis-plus自动填充时间的配置类实现

mybatis-plus自动填充时间的配置类实现 在实际操作过程中,我们并不希望创建时间、修改时间这些来手动进行,而是希望通过自动化来完成,而mybatis-plus则也提供了自动填充功能来实现这一操作,接下来,就来了解一下mybatis…