BYOL(NeurIPS 2020)原理解读

paper:Bootstrap your own latent: A new approach to self-supervised Learning

third-party implementation:https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/selfsup/byol.py

本文的创新点

本文提出了一种新的自监督学习方法,Bootstrap Your Own Latent(BYOL),和以往需要大量负样本的对比学习方法如SimCLR不同,BYOL不依赖于负样本对。此外,和之前需要精心设计增强策略的对比方法相比,BYOL对图像增强的敏感度较低。BYOL在ImageNet上的linear evaluation取得了新的SOTA,并且在迁移学习和半监督学习的基准测试中表现优异。

方法介绍

BYOL的目标是学习一个可以用于下游任务的表示 \(y_{\theta}\)。BYOL使用online和target两个神经网络来学习。在线网络由一组权重 \(\theta\) 定义并由三个阶段组成:encoder \(f_{\theta}\)、projector \(g_{\theta}\)、predictor \(q_{\theta}\)。如图2所示。目标网络和在线网络的架构相同,但使用不同的权重 \(\xi\)。目标网络提供了用于训练在线网络的regression target,其参数 \(\xi\) 是在线网络参数 \(\theta\) 的指数移动平均值。给定一个衰减率 \(\tau \in[0,1]\),在每个训练step后,我们执行如下更新

给定一组图片 \(\mathcal{D}\),从 \(\mathcal{D}\) 中均匀采样得到一张图片 \(x\sim\mathcal{D}\),以及两组增强分布 \(\mathcal{T}\) 和 \(\mathcal{T}'\),BYOL通过对 \(x\) 分别应用增强 \(t\sim\mathcal{T}\) 和 \(t'\sim\mathcal{T}'\) 得到两个增强视图 \(v \triangleq t(x)\) 和 \(v' \triangleq t'(x)\)。从第一个增强视图 \(v\) 来看,在线网络输出一个表示 \(y \triangleq f_{\theta}(v)\) 和一个映射 \(z_{\theta} \triangleq g_{\theta}(y)\)。目标网络从第二个增强视图 \(v'\) 输出 \(y_{\xi} \triangleq f_{\xi}(v')\) 和目标映射 \(z_{\xi}' \triangleq g_{\xi}(y')\)。然后输出一个 \(z'_{\xi}\) 的预测 \(q_{\theta}(z_{\theta})\) 并对 \(q_{\theta}(z_{\theta})\) 和 \(z'_{\xi}\) 进行 \(\ell_2\) 标准化得到 \(\overline{q_\theta}\left(z_\theta\right) \triangleq q_\theta\left(z_\theta\right) /\left\|q_\theta\left(z_\theta\right)\right\|_2\) 和 \(\bar{z}_{\xi}^{\prime} \triangleq z_{\xi}^{\prime} /\left\|z_{\xi}^{\prime}\right\|_2\)。注意这个预测网络只应用在在线分支,使得在线和目标分支是非对称的结构。最后我们定义标准化预测和目标映射之间的均方根误差

然后我们再将 \(v'\) 输入在线网络,将 \(v\) 输入目标网络得到式(2)中 \(\mathcal{L}_{\theta, \xi}\) 的对称损失 \(\widetilde{\mathcal{L}}_{\theta, \xi}\)。在每个训练step,通过梯度下降根据 \(\mathcal{L}^{BYOL}_{\theta, \xi}=\mathcal{L}_{\theta, \xi}+\widetilde{\mathcal{L}}_{\theta, \xi}\) 来优化 \(\theta\),但不更新 \(\xi\)。如图2中的stop-gradient所示,BYOL的整体优化如下

在训练完成后,我们只保留 \(f_{\theta}\)。

BYOL的伪代码如下

 

实验结果

在ImageNet上的linear evaluation如下,可以看到BYOL取得了SOTA的表现。

 

作者还比较了减小batch size以及减少增强方法时BYOL和SimCLR性能。可以看到,BYOL比SimCLR性能下降的更慢,尤其是在减少数据增强时,表明BYOL对batch size和数据增强的不敏感。

 

代码解析

class BYOL(BaseSelfSupervisor):
    """BYOL.

    Implementation of `Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features
            to compact feature vectors.
        head (dict): Config dict for module of head functions.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.004.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 base_momentum: float = 0.004,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.target_net = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        assert isinstance(inputs, list)
        img_v1 = inputs[0]
        img_v2 = inputs[1]
        # compute online features
        proj_online_v1 = self.neck(self.backbone(img_v1))[0]
        proj_online_v2 = self.neck(self.backbone(img_v2))[0]
        # compute target features
        with torch.no_grad():
            # update the target net
            self.target_net.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            proj_target_v1 = self.target_net(img_v1)[0]
            proj_target_v2 = self.target_net(img_v2)[0]

        loss_1 = self.head.loss(proj_online_v1, proj_target_v2)
        loss_2 = self.head.loss(proj_online_v2, proj_target_v1)

        losses = dict(loss=2. * (loss_1 + loss_2))
        return losses

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/559621.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux配置环境变量_推荐的方式

Linux配置环境变量_推荐以下两种方法&#xff1a; (1)用户环境变量&#xff1a;编辑用户目录下 ~/.bashrc、~/.bash_profile 或 ~/.profile文件 (2)系统环境变量&#xff1a;在/etc/profile.d/目录&#xff0c;创建独立的.sh文件 环境变量脚本文件的执行顺序 /etc/profile-&g…

【Java集合进阶】数据结构(平衡二又树旋转机制)数据结构(红黑树、红黑规则、添加节点处理方案详解)

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收藏 …

记一次 Java 应用内存泄漏的定位过程

问题现象 最近&#xff0c;笔者负责测试的某个算法模块机器出现大量报警&#xff0c;报警表现为机器CPU持续高占用。该算法模块是一个优化算法&#xff0c;本身就是CPU密集型应用&#xff0c;一开始怀疑可能是算法在正常运算&#xff0c;但很快这种猜测就被推翻&#xff1a;同…

如何使用云数据库GaussDB管理平台进行实例安装?

前言 随着数字经济的蓬勃发展&#xff0c;数据库也成为企业的关键技术生产力&#xff0c;也是各行各业数字化转型的必要根基。GaussDB作为新一代分布式数据库&#xff0c;核心代码100%自主创新&#xff0c;具备高可用、高安全、高性能、高弹性、高智能、易部署、易迁移的特性&…

Java作业6-Java类的基本概念三

编程1 import java.util.*;abstract class Rodent//抽象类 {public abstract String findFood();//抽象方法public abstract String chewFood(); } class Mouse extends Rodent {public String findFood(){ return "大米"; }public String chewFood(){ return "…

shm 共享内存

shm 共享内存 0,命令1&#xff0c;了解&#xff1a;2&#xff0c;程序: 0,命令 ipcs 查看分配的共享内存ipcrm -m shmid 删掉分配的共享内存1&#xff0c;了解&#xff1a; 1&#xff09;&#xff0c;进程通信的一种 2&#xff09;&#xff0c;地址映射出来后&#xff0c;就不…

C语言数据结构之顺序表

目录 1.线性表2.顺序表2.1顺序表相关概念及结构2.2增删查改等接口的实现 3.数组相关例题 1.线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性&#xff08;数据类型相同&#xff09;的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构&#xff…

Github 2024-04-20 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-04-20统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量非开发语言项目2Python项目2Swift项目2HTML项目1CSS项目1Go项目1C项目1C++项目1Rust项目1编程面试大学:成为软件工程师的全面学习计划 创建周期…

半导体材料(三)——P-N结和金属-半导体接触

本篇为西安交通大学本科课程《电气材料基础》的笔记。 本篇为这一单元的第三篇笔记&#xff0c;上一篇传送门。 p-n结和金属-半导体接触 p-n结 无偏压开路状态 如图a所示&#xff0c;左边是n型掺杂&#xff0c;右边是p型掺杂&#xff0c;在n区和p区之间形成了一个不连续的…

WARNING: No swap limit support——查看docker状态时提示警告

环境&#xff1a;Ubuntu 20.04 1、警告详情 执行命令 service docker status如下图 2、解决办法 2.1 修改文件 执行命令 vim /etc/default/grub在GRUB_CMDLINE_LINUX中追加cgroup_enablememory swapaccount1&#xff0c;如下&#xff1a; # If you change this file…

【蓝桥杯嵌入式】蓝桥杯嵌入式第十四届省赛程序真题,真题分析与代码讲解

&#x1f38a;【蓝桥杯嵌入式】专题正在持续更新中&#xff0c;原理图解析✨&#xff0c;各模块分析✨以及历年真题讲解✨都已更新完毕&#xff0c;欢迎大家前往订阅本专题&#x1f38f; &#x1f38f;【蓝桥杯嵌入式】蓝桥杯第十届省赛真题 &#x1f38f;【蓝桥杯嵌入式】蓝桥…

攻防世界18.fileclude

18.fileclude include函数&#xff1a;包含并执行变量或者文件。 if&#xff1a;是if语句用来判断。 isset&#xff1a;判断变量是否存在&#xff0c;值是否为NULL。 $_GET&#xff1a;接收表单提交数据&#xff0c;并把数据附加到url链接当中。 逻辑运算符&&&#xff…

【提示学习论文】BlackVIP: Black-Box Visual Prompting for Robust Transfer Learning论文原理

BlackVIP: Black-Box Visual Prompting for Robust Transfer Learning BlackVIP:稳健迁移学习的黑盒视觉提示 问题 黑盒白盒&#xff1f; 黑盒和白盒的概念与对预训练模型内部参数的了解程度相关。黑盒指的是对预训练模型的参数和结构缺乏详细了解&#xff0c;通常只能通过使…

NAT基本配置

配置IP完成及缺省的路由如下&#xff1b; 此时R1pingISP是ping不通的&#xff0c;因为缺省是可以将数据传给R3&#xff0c;但是R3传不回去&#xff0c;知道目标IP地址但因其是私有内部IP&#xff0c;而自己的是公有IP&#xff0c;所以传不过去&#xff0c;此时就需要R2这个边界…

2024 发布Maven项目到中央仓库

注册sonatype账号 Maven中央仓库并不支持直接发布jar包&#xff0c;sonatype是其指定的第三方仓库之一&#xff0c;发布到此的项目会被定时同步到中央仓库 官方教程地址&#xff1a;https://central.sonatype.org/register/central-portal/ 访问网址&#xff1a;https://centra…

文件操作和IO

1.认识文件 我们先来认识狭义上的⽂件(file)。针对硬盘这种持久化存储的I/O设备&#xff0c;当我们想要进⾏数据保存时&#xff0c;往往不是保存成⼀个整体&#xff0c;⽽是独⽴成⼀个个的单位进⾏保存&#xff0c;这个独⽴的单位就被抽象成⽂件的概念&#xff0c;就类似办公桌…

# 从浅入深 学习 SpringCloud 微服务架构(三)注册中心 Eureka(2)

从浅入深 学习 SpringCloud 微服务架构&#xff08;三&#xff09;注册中心 Eureka&#xff08;2&#xff09; 段子手168 1、搭建 EurekaServer 注册中心&#xff0c;使用 Eureka 的步骤&#xff1a; 1&#xff09;搭建 EurekaServer 创建工程&#xff0c;导入依赖坐标&…

Python-VBA函数之旅-globals函数

目录 一、globals函数的常见应用场景&#xff1a; 二、globals函数与locals函数对比分析&#xff1a; 1、globals函数&#xff1a; 1-1、Python&#xff1a; 1-2、VBA&#xff1a; 2、推荐阅读&#xff1a; 个人主页&#xff1a;https://blog.csdn.net/ygb_1024?spm101…

基于springboot+vue+Mysql的广场舞团管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

牛客小白月赛91

A.Bingbong的化学世界 链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 时间限制&#xff1a;C/C 1秒&#xff0c;其他语言2秒 空间限制&#xff1a;C/C 262144K&#xff0c;其他语言524288K 64bit IO Format: %lld 题目描述 &#x1f319;“上…