MoCo v1(CVPR 2020)原理与代码解读

paper:Momentum Contrast for Unsupervised Visual Representation Learning

official implementation:https://github.com/facebookresearch/moco

背景

最近的一些研究提出使用对比损失相关的方法进行无监督视觉表征学习并取得了不错的结果。尽管是受到不同motivation的启发,这些方法都可以看做是在构建一个动态字典。字典中的"keys"(tokens)从数据(图片或图片的patch)中采样并用一个编码器encoder网络来表示。无监督学习训练encoder来执行字典查找:一个encoded "query"应该与它匹配的key相似,而与其它的key不同。学习过程表述为最小化对比损失的过程。

存在的问题

从构建动态字典的角度来看,作者假设构建的字典应该具备两个特点:

  1. large即字典要足够大
  2. 在训练期间字典要保持一致性

从直觉上来说,一个更大的字典可以更好地对连续的、高维的视觉空间进行采样。而字典中的键应该由相同或相似的编码器表示,以便它们与query的比较是一致的。然而,一些使用对比损失的现有方法受限于这两个方面中的一个(具体将在后续的方法介绍中讨论)。

本文的创新点

本文提出了动量对比(Momentum Contrast,MoCo)作为一种构建大型和一致的字典的方法,用于对比损失的无监督学习,如图1所示。

作者维护了一个数据样本的队列作为字典,当前mini-batch的encoded representation进队,队列中最老的表示出队。队列将字典大小和batch size进行解耦从而使得字典可以非常大。此外由于字典的key来源于之前若干个mini-batch,作者提出了一个缓慢变化的key encoder,具体实现为query encoder的基于动量的移动平均值,从而保持一致性。 

无监督学习的一个主要目的是得到一个预训练表示,通过微调可以tranfer到下游任务中。作者通过实验表明,在7个与检测和分割相关的下游任务中,MoCo无监督预训练可以超过ImageNet有监督预训练。

方法介绍

Contrastive Learning as Dictionary Look-up

对比学习可以用来为字典查找任务训练一个编码器。对于一个encoded query \(q\) 和一组encoded样本 \(\{k_0,k_1,k_2,...\}\),后者是字典的keys。假设字典中有一个单独的key(表示为 \(k_+\))与 \(q\) 匹配,对比损失作为一个函数,当 \(q\) 和的positive key \(k_+\) 相似并与所有其它的key(被认为是 \(q\) 的negative keys)不相似时对比损失的值很小。用点积来表示相似性,本文采用了对比损失的一种形式,InfoNCE,如下

其中 \(\tau\) 是是温度超参,结果对一个正样本和 \(K\) 个负样本求和。从直觉上来说,这个损失是一个 \((K+1)\) 类基于softmax分类器的log损失,这个分类器试图将 \(q\) 分为 \(k_+\) 类。对比损失还有其它形式,比如margin-based loss和NCE loss的一些变种。

对比损失作为无监督的目标函数用来训练encoder network来表示queries和keys。一般来说,query representation是 \(q=f_q(x_q)\) 其中 \(f_q\) 是encoder网络,\(x_q\) 是一个query样本(同样,\(k=f_k(x_k)\))。初始化取决于具体的代理任务,输入 \(x_q\) 和 \(x_k\) 可以是图像、patches、或包含一组patches的context。网络 \(f_q\) 和 \(f_k\) 可以是相同的、部分共享的、或不同的。

Momentum Contrast

Dictionary as a queue

本文方法的核心是维护一个数据样本的队列作为字典,这使得我们可以重用前面mini-batch中的encoded keys,队列的引入将字典大小与batch大小进行了解耦,我们的字典可以比普通的batch size大得多,并且可以灵活独立的作为一个超参来设置。

字典中的样本被逐步替换掉,当前mini-batch进入队列,而队列中最老的mini-batch被删除。字典总是表示所有数据的一个采样子集,而维护字典的额外计算是可控的。此外删除最早的mini-batch也是有好处的,因为它的encoded keys是最老的,与最新的编码最不一致。

Momentum update

使用队列可以使字典更大,但也使得通过反向传播更新key encoder变得困难(梯度应该传播到队列中的所有样本)。一个天真的解决方法是忽略key encoder \(f_k\) 的梯度直接拷贝query encoder \(f_q\),但这种解决方案在实验中得到的结果很差,作者推测这是由于快速变化的encoder减少了key representation的一致性导致的。因此提出了动量更新来解决这个问题。

我们将 \(f_k\) 的参数表示为 \(\theta_k\),\(f_q\) 的参数表示为 \(\theta_q\),然后通过下式更新 \(\theta_k\)

其中 \(m\in[0,1)\) 是动量系数,只有参数 \(\theta_q\) 通过反向传播更新,式(2)中的动量更新使得 \(\theta_k\) 比 \(\theta_q\) 的更新更平滑。因此,尽管队列中的keys是通过不同的encoder编码的(不同的mini-batch),这些encoder之间的差异非常小。后续实验表明,一个更大的动量(例如 \(m=0.999\))比更小的动量(例如 \(m=0.9\))表现得更好,表明一个缓慢更新的key encoder是使用队列的核心。

Relations to previous mechanisms

MoCo是使用对比损失的一种机制,作者将其与其它两种机制进行了对比,如图2所示,它们在字典大小和一致性上表现出不同的属性。

图2(a)是通过反向传播进行end-to-end更新的一种机制,它使用当前mini-batch中的样本作为字典,因此key的编码是一致的(通过相同的一组encoder参数)。但是字典的大小和mini-batch的大小耦合,受限于GPU的内存。同时也受到大mini-batch优化问题的挑战。

另外一种机制是采用memory bank,如图2(b)所示。memory back包含了数据集中所有样本的representation,每个mini-batch的字典是从memory bank中随机采样得到的,且没有反向传播,因此字典的size可以很大。但是,memory bank中一个样本的表示在它最后一次被看到时就更新了,因此采样的keys是过去一个epoch中不同step的encoder得到的,从而缺乏了一致性。

Pretext Task

对比学习可以使用不同的代理任务,由于本文的重点不是设计一个新的代理任务,本文遵循instance discrimination任务使用了一个简单的代理任务。如果一个query和一个key来源于同一张图像,则将它们视为positive pair,否则视为negative pair。我们对同一张图像进行两次随机数据增强得到一个postive pair,queries和keys分别由各自的encoder \(f_q\) 和 \(f_k\) 进行编码,encoder可以是任何的卷积网络。

MoCo的伪代码如下所示,对当前的mini-batch,我们对postive pair分别进行编码得到queries和对应的keys,负样本来源于队列。

Shuffling BN

编码器 \(f_q\) 和 \(f_k\) 中都使用了BN,作者在实验中发现使用BN会阻止模型学习好的表示,模型似乎“欺骗”了代理任务并很容易地找到了一种low-loss的解决方法。这可能是样本之间的batch内的通信(BN引起的)泄露了信息。

作者通过shuffle BN来解决这个问题。具体训练是在多个GPU上进行的,每个GPU独立的对样本执行BN。对于key encoder \(f_k\),在将当前mini-batch分配到不同GPU之前打乱样本顺序(并在编码之后还原顺序),query encoder \(f_q\) 不进行打乱顺序。这保证了用于计算query和对应的positve key的统计信息来自于不同的子集,有效解决了欺骗问题。

代码解析

下面是官方实现,基本上和文章中的伪代码一致,没有什么难以理解的地方。其中encoder_k的参数更新顺序和伪代码不一样,伪代码是f_q和f_k分别forward,然后f_q的loss反向传播,更新f_q的参数,最后f_k进行动量更新。而代码中是f_q先forward,然后f_k更新参数,接着f_k进行forward,最后再根据反向传播更新f_q。

另外,这里包含了MoCo v2的代码,主要的区别就是v2借鉴SimCLR的做法,在encoder的avg pooling层后多加了一层projection layer,即一个MLP。

# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]  # 2048
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        # 将张量或缓冲区注册为 nn.Module 的一部分,但不会被视为模型的可学习参数。
        # 通常情况下,这用于存储模型中的固定参数或状态,例如均值、方差等,这些参数在训练过程中不会被更新。
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr: ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()
        # 打乱索引顺序,比如batch_size_all=8, idx_shuffle=[1,3,5,2,0,4,7,6]

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)
        # 将生成的随机索引序列从GPU 0(src=0)广播到所有其他的GPU设备上,以便在分布式训练时,每个GPU都能够获得相同的随机索引序列,以保持数据的同步性。

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)  # tensor([4, 0, 3, 1, 5, 2, 7, 6])

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            # 和论文中伪代码的顺序不一样,论文中encoder_k是先forward后更新参数,这里是先更新参数后forward

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

实验结果

无监督模型的常见评估方法是将训练好的encoder的权重freeze,后面接一层全连接层和softmax,然后在目标数据上只训练全连接层,最后在测试集上评估得到的模型效果。下面是MoCo和之前的无监督模型的结果对比,可以看到MoCo取得了最优的结果。

无监督模型的另一个作用是当做下游任务的预训练权重。在VOC目标检测任务上和监督预训练的对比如下,可以看到MoCo比监督预训练权重的效果更好。

 

下面是在COCO数据的目标检测任务和实例分割任务上与随机初始化权重、监督预训练权重的结果对比

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/541083.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springcloud第4季 springcloud-alibaba之nacos篇

一 nacos 1.1 nacos作用介绍 nacos是一个分布式的配置中心和注册发现中心。 nacos是 dynamic naming configuration service nacosconfigbus 实现动态刷新;nacosconsul 1.2 各个注册中心对比 注册中心CAP模型控制台管理社区活跃度EureakaAp支持低zkcp不支持中…

leetcode73 矩阵置零

题目描述 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用原地算法。 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 输入:matrix [[0,1,2,0],[3,4…

使用 Python 标记具有相同名称的条目

如果大家想在 Python 中标记具有相同名称的条目,可以使用字典(Dictionary)或集合(Set)来实现。这取决于你们希望如何存储和使用这些条目。下面我将提供两种常见的方法来实现这个目标。 1、问题背景 在处理数据时&…

PE文件的分析和构造超详细过程

本文详细讲述如何从0构造一个PE文件,运行该文件会弹出一个HelloPE的窗口 目录 预备知识 1. 构造DOS头IMAGE_DOS_HEADER 1.1 构造DOS_MZ头 1.2 构造DOS_STUB 2、构造PE头IMAGE_NT_HEADERS 248字节 2.1 signature 2.2 IMAGE_FILE_HEADER 2.3 IMAGE_OPTI…

Python爬虫:蝉妈妈返回参数data解密

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据开发、数据分析等。 🐴欢迎小伙伴们点赞👍🏻、收藏⭐️、…

Spring Boot | Spring Boot 整合 “Servlet三大组件“ ( Servlet / Filter / Listene )

目录: Spring Boot 整合 "Servlet三大组件" :1. 使用 "组件注册" 的方式 "整合Servlet三大组件" ( 实际操作为 : 创建自定义的"三大组件"对象 结合刚创建"的自定义组件对象"来 将 XxxRegistrationBean对象 通过…

sparkSql join 关联机制

💐💐扫码关注公众号,回复 spark 关键字下载geekbang 原价 90 元 零基础入门 Spark 学习资料💐💐 join 实现机制 Join 有 3 种实现机制,分别是 NLJ(Nested Loop Join)、SMJ&#xf…

【VUE】使用Vue和CSS动画创建滚动列表

使用Vue和CSS动画创建滚动列表 在这篇文章中,我们将探讨如何使用Vue.js和CSS动画创建一个动态且视觉上吸引人的滚动列表。这个列表将自动滚动显示项目,类似于轮播图的方式,非常适合用于仪表盘、排行榜或任何需要在有限空间内展示项目列表的应…

【Altium Designer 20 笔记】隐藏PCB上的信号线(连接线)

使用网络类隐藏特定类型的信号线 如果你想要隐藏特定类型的信号线(例如电源类),你可以首先创建一个网络类。使用快捷键DC调出对象类浏览器,在Net Classes中右击添加类,并重命名(例如为“Power”&#xff0…

【Qt 学习笔记】QWidget的geometry属性及window frame的影响

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ QWidget的geometry属性 文章编号:Qt 学习笔记 / 16 文章目…

spring boot学习第十七篇:OAuth2概述及使用GitHub登录第三方网站

0. 导言 我们在浏览器上可以访问成百上千个网站,使用每个网站的服务一般都要先注册账号,那么我们为了更好地记忆,一般都会在多个网站使用相同的账号和密码进行注册。那么问题就来了,如果在你注册的网站中有某些个网站的系统设计不…

C++进阶03 模板与群体数据

听课笔记简单整理,供小伙伴们参考~🥝🥝 第1版:听课的记录代码~🧩🧩 编辑:梅头脑🌸 审核:文心一言 目录 🐳课程来源 🐋模板 🐋8.…

小区烟火AI检测/楼道杂物堆积消防隐患AI智能识别方案

一、背景需求 据新闻报道,今年4月7日,安徽省合肥市肥东县一民房发生火灾,致1死11伤,起火点是“一楼楼道杂物间”。 因为小区居民楼楼道堆积大量杂物而导致的消防火灾事故也不在少数。楼道堆积杂物是一个长期存在的问题&#xff…

安装ODBC方法

1、运行 搜索 ODBC数据源管理程序 32位或者 64位 2、在用户DSN或者系统DSN选择添加(建议前者),此处以添加access数据库的odbc驱动为例 3、安装成功

2024妈妈杯数学建模A 题思路分析-移动通信网络中 PCI 规划问题

# 1 赛题 A 题 移动通信网络中 PCI 规划问题 物理小区识别码(PCI)规划是移动通信网络中下行链路层上,对各覆盖 小区编号进行合理配置,以避免 PCI 冲突、 PCI 混淆以及 PCI 模 3 干扰等 现象。 PCI 规划对于减少物理层的小区间互相干扰(ICI),增…

jenkins通过pipeline部署springboot项目

部署方案: 1、springboot项目不保存部署的pipeline或dockerfile构建脚本等与部署相关的问文件,业务项目只需关心业务,能够正常构建为jar包即可 2、新建一个代码仓库,用于保存项目需要构建的Jenkinsfile 3、jenkins配置pipeline地址…

Element ui 动态展示表格列,动态格式化表格列的值

需求 后台配置前端展示的表格列,遇到比如 文件大小这样的值,如果后台存的是纯数字,需要进行格式化展示,并且能控制显示的小数位数,再比如,部分列值需要加单位等信息,此外还有状态类&#xff0…

【心路历程】初次参加蓝桥杯实况

送给大家一句话: 寂静的光辉平铺的一刻,地上的每一个坎坷都被映照得灿烂。 – 史铁生 《我与地坛》 初次参加蓝桥杯有感 一点小小的震撼难评的做题过程A题 艺术与篮球问题描述解题 B 题 五子棋问题描述解题 C题 训练士兵问题描述解题 D题 团建解题 E题 …

基于SpringBoot+Vue的毕业设计管理系统(源码+文档+部署+讲解)

一.系统概述 二十一世纪我们的社会进入了信息时代,信息管理系统的建立,大大提高了人们信息化水平。传统的管理方式对时间、地点的限制太多,而在线管理系统刚好能满足这些需求,在线管理系统突破了传统管理方式的局限性。于是本文针…

【前端】layui table表格勾选事件,以及常见模块

欢迎来到《小5讲堂》,大家好,我是全栈小5。 这是《前端》系列文章,每篇文章将以博主理解的角度展开讲解, 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 表格勾选事…