(表征学习论文阅读)A Simple Framework for Contrastive Learning of Visual Representations

Chen T, Kornblith S, Norouzi M, et al. A simple framework for contrastive learning of visual representations[C]//International conference on machine learning. PMLR, 2020: 1597-1607.

1. 前言

本文作者为了了解对比学习是如何学习到有效的表征,对本文所提出的三大组件进行了全面的研究:

  1. 各种数据增强手段的组合在表征学习中起到了重要作用;
  2. 在表征和对比损失之间引入非线性变换能够有效提高表征质量;
  3. 对比学习相较于监督学习需要更大的batch size和更多的训练步数。

在没有人类标注或者监督的情况下学习数据的有效表征是一个长期存在的难题,目前的主要工作可以分为两类:

  1. 基于生成模型的方法
    例如VQ-VAE,MAE,BERT
  2. 基于判别模型的方法
    例如MoCo,CLIP

2. 方法

本文提出了一个框架SimCLR,通过最大化同一数据的不同数据增强处理后的两个视角之间的相似度来学习有效表征。
在这里插入图片描述

  1. 如图所示,本文首先将数据 x x x进行两个不同的增强,这里作者使用了三种简单的数据增强方法:随机裁剪后再调整到原始大小、随机颜色失真、高斯模糊。
  2. f ( ∙ ) f(\bullet) f()代表编码器,这里作者使用的是同一个编码器来对两个视角数据进行编码
  3. 最后编码器输出的结果通过非线性变换 g ( ∙ ) g(\bullet) g()得到 z i z_i zi z j z_j zj,两个向量构成了一组正例,进行相似度计算,也就是简单的单位向量内积计算出余弦相似度。目标就是最大化两者的余弦相似度。同时,一个batch中其他的数据构成了负例,最小化与负例的相似度。注意最终训练完成的编码器我们是需要舍弃掉非线性变换的。
    本文使用的损失函数就是最基本的InfoNCE损失,具体可以参考我的另一篇讲解InfoNCE的博文。
    在这里插入图片描述
    在这里插入图片描述

3. 代码

这里仅提供文章提到的两个点的代码:

  1. 数据增强
    高斯模糊
import numpy as np
import torch
from torch import nn
from torchvision.transforms import transforms

np.random.seed(0)


class GaussianBlur(object):
    """blur a single image on CPU"""
    def __init__(self, kernel_size):
        radias = kernel_size // 2
        kernel_size = radias * 2 + 1
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radias

        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radias),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        img = self.pil_to_tensor(img).unsqueeze(0)

        sigma = np.random.uniform(0.1, 2.0)
        x = np.arange(-self.r, self.r + 1)
        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
        x = x / x.sum()
        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)

        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))

        with torch.no_grad():
            img = self.blur(img)
            img = img.squeeze()

        img = self.tensor_to_pil(img)

        return img

组合各类增强手段

class ContrastiveLearningDataset:
    def __init__(self, root_folder=r"D:\pyproject\representation_learning\data"):
        self.root_folder = root_folder

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a set of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * size)),
                                              transforms.ToTensor()])
        return data_transforms

    def get_dataset(self, name, n_views):
        valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                              transform=ContrastiveLearningViewGenerator(
                                                                  self.get_simclr_pipeline_transform(32),
                                                                  n_views),
                                                              download=True),

                          'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(96),
                                                              n_views),
                                                          download=True)}

        try:
            dataset_fn = valid_datasets[name]
        except KeyError:
            raise InvalidDatasetSelection()
        else:
            return dataset_fn()
  1. 非线性变换
class ResNetSimCLR(nn.Module):

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        # 修改resnet最后一层的全连接层即可
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        return self.backbone(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/527860.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Disk Drill Enterprise for Mac v5.5.1515数据恢复软件中文版

Disk Drill 是 Mac 操作系统固有的Mac数据恢复软件:使用 Recovery Vault 轻松保护文件免遭意外删除,并从 Mac 磁盘恢复丢失的数据。支持大多数存储设备,文件类型和文件系统。 软件下载:Disk Drill Enterprise for Mac v5.5.1515激…

【YOLOV8】项目目录重点部分介绍和性能评估指标

目录 一 项目目录重点部分介绍 二 性能评估指标 一 项目目录重点部分介绍 1 ultralytics

3D医疗图像配准 | 基于Vision-Transformer+Pytorch实现的3D医疗图像配准算法

项目应用场景 面向医疗图像配准场景,项目采用 Pytorch ViT 来实现,形态为 3D 医疗图像的配准。 项目效果 项目细节 > 具体参见项目 README.md (1) 模型架构 (2) Vision Transformer 架构 (3) 量化结果分析 项目获取 https://download.csdn.net/down…

.NET 设计模式—装饰器模式(Decorator Pattern)

简介 装饰者模式(Decorator Pattern)是一种结构型设计模式,它允许你在不改变对象接口的前提下,动态地将新行为附加到对象上。这种模式是通过创建一个包装(或装饰)对象,将要被装饰的对象包裹起来…

基于Springboot考研资讯平台的设计与实现(论文+源码)_kaic

摘 要 随着现在网络的快速发展,网络的应用在各行各业当中它很快融入到了许多学校的眼球之中,他们利用网络来做这个电商的服务,随之就产生了“考研资讯平台”,这样就让学生考研资讯平台更加方便简单。 对于本考研资讯平台的设计来…

npm包安装与管理:深入解析命令行工具的全方位操作指南,涵盖脚本执行与包发布流程

npm,全称为Node Package Manager,是专为JavaScript生态系统设计的软件包管理系统,尤其与Node.js平台紧密关联。作为Node.js的默认包管理工具,npm为开发者提供了便捷的方式来安装、共享、分发和管理代码模块。 npm作为JavaScript世…

C++ | Leetcode C++题解之第18题四数之和

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<vector<int>> fourSum(vector<int>& nums, int target) {vector<vector<int>> quadruplets;if (nums.size() < 4) {return quadruplets;}sort(nums.begin(), nums.en…

谷歌seo自然搜索排名怎么提升快?

要想在谷歌上排名快速上升&#xff0c;关键在于运用GPC爬虫池跟高低搭配的外链组合 首先你要做的&#xff0c;就是让谷歌的蜘蛛频繁来你的网站&#xff0c;网站需要被谷歌蜘蛛频繁抓取和索引&#xff0c;那这时候GPC爬虫池就能派上用场了&#xff0c;GPC爬虫池能够帮你大幅度提…

GD32零基础教程第一节(开发环境搭建及工程模板介绍)

文章目录 前言一、MDK keil5安装二、设备支持包安装三、CH340串口驱动安装四、STLINIK驱动安装五、工程风格介绍总结 前言 本篇文章正式带大家开始学习GD32F407VET6国产单片机的学习&#xff0c;国产单片机性能强&#xff0c;而且价格也便宜&#xff0c;下面就开始带大家来介绍…

致远互联-OA 前台fileUpload.do 绕过文件上传漏洞复现

0x01 产品简介 致远互联-OA 是数字化构建企业数字化协同运营中台,面向企业各种业务场景提供一站式大数据分析解决方案的协同办公软件。 0x02 漏洞概述 致远互联-OA 接口 fileUpload.do 接口处存在文件上传漏洞,未经身份验证的远程攻击者可通过目录遍历的方式绕过上传接口限…

阿里通义千问开源 320 亿参数模型;文字和音频自动翻译成手语Hand Talk拉近人与人的距离

✨ 1: Qwen1.5-32B Qwen1.5-32B是Qwen1.5系列中性能与效率兼顾的最新语言模型&#xff0c;内存占用低&#xff0c;运行速度快。 Qwen1.5-32B是Qwen1.5语言模型系列的最新成员&#xff0c;这个模型是基于先进的技术研发的&#xff0c;旨在提供一种既高效又经济的AI语言理解和生…

Unity面经(自整)——Unity基础知识

Unity基础知识 1. Image和RawImage的区别 Image比RawImage更耗性能。Image只能使用sprite属性的图片。而RawImage什么都可以使用 2. Unity3D中的碰撞器Collider和触发器Trigger的区别 碰撞器是触发器的载体&#xff0c;而触发器是碰撞器上的一个属性。 如果IsTrigger为fal…

Maven POM元素解析

这是对Maven中使用的Maven项目描述符的引用。 <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/…

列车调度

描述 火车站的列车调度铁轨的结构如下图所示。 两端分别是一条入口&#xff08;Entrance&#xff09;轨道和一条出口&#xff08;Exit&#xff09;轨道&#xff0c;它们之间有N条平行的轨道。每趟列车从入口可以选择任意一条轨道进入&#xff0c;最后从出口离开。在图中有9趟列…

C#操作MySQL从入门到精通(7)——对查询数据进行简单过滤

前言 我们在查询数据库中数据的时候,有时候需要剔除一些我们不想要的数据,这时候就需要对数据进行过滤,比如学生信息中,我只需要年龄等于18的,类似这种操作,本文就是详细介绍如何对查询的数据进行初步的过滤。 1、等于操作符 本次查询student_age 等于20的数据,使用我…

Open CASCADE学习|平面上的PCurve

曲面上的曲线PCurve&#xff0c;字面上理解即为参数曲线(Parametric Curve)。在几何建模中&#xff0c;PCurve通常被描述为附加在参数曲面之间公共边上的数据结构。从更具体的定义来看&#xff0c;当给定一个曲面方程&#xff0c;并且其参数u和v是另一个参数t的函数时&#xff…

PyCharm双击无法打开 安装新旧版本pycharm同时启动失败的解决办法

由于2019版本无法直接升级到2023版本 所以下载了两个版本的PyCharm 且两个都是专业版的 一个是2019的&#xff0c; 一个是2024新版的其中2019版本是破解版&#xff01; 然后现在想要打开2024的新版&#xff0c;发现双击无法启动&#xff0c;到文件所在位置打开也无法启动&a…

二维数组及其内存图解

二维数组 在一维数组的介绍当中曾说&#xff0c;数组中可以储存任何同类型的元素&#xff0c;那么这个元素是不是可以也是数组呢&#xff1f;答案是可以&#xff0c;即在数组之中储存数组元素。这种情况就是多维数组&#xff0c;当一个数组中的元素是数组时叫做二维数组&#x…

如何使用校园网——Win10笔记本,台式机互开热点

当我们使用校园网的时候&#xff0c;往往只能连接一个电脑端&#xff0c;但是又想两个机子同时连接WIFI怎么办呢&#xff1f; 当然&#xff0c;前提条件是你先得其中一台电脑有网络哈 1、打开想开共享热点的电脑的设置 A、点击WIN&#xff0c;再点击设置 2、点击网络和Inte…

论如何在小程序展示超链接在线网页

在工作中遇到一个需求&#xff0c;就是在小程序中展示超链接网页&#xff0c;起初我是直接使用web-view标签 <web-view src"https://www.baidu.com/"/>但是web-view只能在开发阶段手机上展示&#xff0c;一旦小程序发布线上&#xff0c;就会出现下面这种情况“…