政安晨:【Keras机器学习示例演绎】(二十七)—— 利用 NNCLR 进行自我监督对比学习

目录

简介

自我监督学习

对比学习

NNCLR

设置

超参数

加载数据集

增强

准备扩增模块

编码器结构

用于对比预训练的 NNCLR 模型

预训练 NNCLR


政安晨的个人主页:政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:计算机视觉自监督学习方法 NNCLR 的实现。

简介


自我监督学习


自我监督表示学习旨在从原始数据中获取稳健的样本表示,而无需昂贵的标签或注释。这一领域的早期方法侧重于定义预训练任务,这些任务涉及在一个有大量弱监督标签的领域中的代用任务。为解决此类任务而训练的编码器有望学习到一般特征,这些特征可能对其他需要昂贵注释的下游任务(如图像分类)有用。

对比学习


自监督学习技术的一大类别是使用对比损失的技术,这些技术已被广泛应用于图像相似性、降维(DrLIM)和人脸验证/识别等计算机视觉应用中。这些方法学习一个潜在空间,将正样本聚类在一起,同时将负样本推开。

NNCLR


在本示例中,我们实现了论文 With a Little Help from My Friends 中提出的 NNCLR:谷歌研究院和 DeepMind 共同发表的论文《视觉表征的最近邻对比学习》(With Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations)中提出的 NNCLR。

NNCLR 学习的是自我监督表征,它超越了单一实例的正向性,可以学习到更好的特征,这些特征不受不同视角、变形甚至类内变化的影响。基于聚类的方法提供了一种超越单一实例正向性的好方法,但假设整个聚类都是正向性,可能会由于早期过度泛化而影响性能。取而代之的是,NNCLR 将所学表示空间中的近邻作为正例。此外,NNCLR 还提高了 SimCLR(Keras 示例)等现有对比学习方法的性能,并减少了自监督方法对数据增强策略的依赖。

下面是论文作者提供的一个很好的可视化演示,展示了 NNCLR 如何以 SimCLR 的理念为基础:

我们可以看到,SimCLR 使用同一图像的两个视图作为正对。这两个视图是使用随机数据增强生成的,通过编码器得到正向嵌入对,我们最终使用了两个增强。而 NNCLR 则保留了一个代表完整数据分布的嵌入支持集,并使用最近邻方法形成正对。在训练过程中,支持集被用作内存,类似于 MoCo 中的队列(即先进先出)。

此示例需要使用 tensorflow_datasets,可通过此命令安装:

!pip install tensorflow-datasets

设置

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers

超参数


如原始论文所示,更大的队列规模很可能意味着更好的性能,但会带来巨大的计算开销。作者指出,NNCLR 的最佳结果是在队列大小为 98,304 时(这是他们实验过的最大队列大小)取得的。我们在这里使用 10,000 来展示一个工作示例。

AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer = 5000
# The below two values are taken from https://www.tensorflow.org/datasets/catalog/stl10
labelled_train_images = 5000
unlabelled_images = 100000

temperature = 0.1
queue_size = 10000
contrastive_augmenter = {
    "brightness": 0.5,
    "name": "contrastive_augmenter",
    "scale": (0.2, 1.0),
}
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}
input_shape = (96, 96, 3)
width = 128
num_epochs = 5  # Use 25 for better results
steps_per_epoch = 50  # Use 200 for better results

加载数据集


我们从 TensorFlow 数据集加载 STL-10 数据集,这是一个用于开发无监督特征学习、深度学习和自学算法的图像识别数据集。它受到 CIFAR-10 数据集的启发,并做了一些修改。

dataset_name = "stl10"


def prepare_dataset():
    unlabeled_batch_size = unlabelled_images // steps_per_epoch
    labeled_batch_size = labelled_train_images // steps_per_epoch
    batch_size = unlabeled_batch_size + labeled_batch_size

    unlabeled_train_dataset = (
        tfds.load(
            dataset_name, split="unlabelled", as_supervised=True, shuffle_files=True
        )
        .shuffle(buffer_size=shuffle_buffer)
        .batch(unlabeled_batch_size, drop_remainder=True)
    )
    labeled_train_dataset = (
        tfds.load(dataset_name, split="train", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=shuffle_buffer)
        .batch(labeled_batch_size, drop_remainder=True)
    )
    test_dataset = (
        tfds.load(dataset_name, split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=AUTOTUNE)
    )
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=AUTOTUNE)

    return batch_size, train_dataset, labeled_train_dataset, test_dataset


batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

增强


其他自监督技术,如 SimCLR、BYOL、SwAV 等,在很大程度上依赖于精心设计的数据扩增管道,以获得最佳性能。然而,NNCLR 对复杂扩增的依赖性较低,因为近邻数据已经提供了丰富的样本变化。扩增管道通常包括以下几种常见技术:

随机调整作物大小
多种颜色变形
高斯模糊

由于 NNCLR 对复杂增强的依赖性较低,我们将只使用随机裁剪和随机亮度来增强输入图像。

准备扩增模块

def augmenter(brightness, name, scale):
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            keras_cv.layers.RandomCropAndResize(
                target_size=(input_shape[0], input_shape[1]),
                crop_area_factor=scale,
                aspect_ratio_factor=(3 / 4, 4 / 3),
            ),
            keras_cv.layers.RandomBrightness(factor=brightness, value_range=(0.0, 1.0)),
        ],
        name=name,
    )

编码器结构


使用 ResNet-50 作为编码器结构是文献中的标准结构。在原论文中,作者使用 ResNet-50 作为编码器架构,并对 ResNet-50 的输出进行空间平均。不过,请记住,功能更强大的模型不仅会增加训练时间,还会需要更多内存,并限制您可以使用的最大批次规模。在本例中,我们只使用四个卷积层。

def encoder():
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )

用于对比预训练的 NNCLR 模型


我们在无标签图像上训练一个有对比损失的编码器。编码器顶部安装了一个非线性投影头,因为它能提高编码器的表征质量。

class NNCLR(keras.Model):
    def __init__(
        self, temperature, queue_size,
    ):
        super().__init__()
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_augmenter = augmenter(**contrastive_augmenter)
        self.classification_augmenter = augmenter(**classification_augmenter)
        self.encoder = encoder()
        self.projection_head = keras.Sequential(
            [
                layers.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )
        self.temperature = temperature

        feature_dimensions = self.encoder.output_shape[1]
        self.feature_queue = keras.Variable(
            keras.utils.normalize(
                keras.random.normal(shape=(queue_size, feature_dimensions)),
                axis=1,
                order=2,
            ),
            trainable=False,
        )

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

    def nearest_neighbour(self, projections):
        support_similarities = ops.matmul(projections, ops.transpose(self.feature_queue))
        nn_projections = ops.take(
            self.feature_queue, ops.argmax(support_similarities, axis=1), axis=0
        )
        return projections + ops.stop_gradient(nn_projections - projections)

    def update_contrastive_accuracy(self, features_1, features_2):
        features_1 = keras.utils.normalize(features_1, axis=1, order=2)
        features_2 = keras.utils.normalize(features_2, axis=1, order=2)
        similarities = ops.matmul(features_1, ops.transpose(features_2))
        batch_size = ops.shape(features_1)[0]
        contrastive_labels = ops.arange(batch_size)
        self.contrastive_accuracy.update_state(
            ops.concatenate([contrastive_labels, contrastive_labels], axis=0),
            ops.concatenate([similarities, ops.transpose(similarities)], axis=0),
        )

    def update_correlation_accuracy(self, features_1, features_2):
        features_1 = (features_1 - ops.mean(features_1, axis=0)) / ops.std(
            features_1, axis=0
        )
        features_2 = (features_2 - ops.mean(features_2, axis=0)) / ops.std(
            features_2, axis=0
        )

        batch_size = ops.shape(features_1)[0]
        cross_correlation = (
            ops.matmul(ops.transpose(features_1), features_2) / batch_size
        )

        feature_dim = ops.shape(features_1)[1]
        correlation_labels = ops.arange(feature_dim)
        self.correlation_accuracy.update_state(
            ops.concatenate([correlation_labels, correlation_labels], axis=0),
            ops.concatenate(
                [cross_correlation, ops.transpose(cross_correlation)], axis=0
            ),
        )

    def contrastive_loss(self, projections_1, projections_2):
        projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
        projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)

        similarities_1_2_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_1), ops.transpose(projections_2)
            )
            / self.temperature
        )
        similarities_1_2_2 = (
             ops.matmul(
                projections_2, ops.transpose(self.nearest_neighbour(projections_1))
            )
            / self.temperature
        )

        similarities_2_1_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_2), ops.transpose(projections_1)
            )
            / self.temperature
        )
        similarities_2_1_2 = (
            ops.matmul(
                projections_1, ops.transpose(self.nearest_neighbour(projections_2))
            )
            / self.temperature
        )

        batch_size = ops.shape(projections_1)[0]
        contrastive_labels = ops.arange(batch_size)
        loss = keras.losses.sparse_categorical_crossentropy(
            ops.concatenate(
                [
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                ],
                axis=0,
            ),
            ops.concatenate(
                [
                    similarities_1_2_1,
                    similarities_1_2_2,
                    similarities_2_1_1,
                    similarities_2_1_2,
                ],
                axis=0,
            ),
            from_logits=True,
        )

        self.feature_queue.assign(
            ops.concatenate([projections_1, self.feature_queue[:-batch_size]], axis=0)
        )
        return loss

    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        augmented_images_1 = self.contrastive_augmenter(images)
        augmented_images_2 = self.contrastive_augmenter(images)

        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1)
            features_2 = self.encoder(augmented_images_2)
            projections_1 = self.projection_head(features_1)
            projections_2 = self.projection_head(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)
        preprocessed_images = self.classification_augmenter(labeled_images)

        with tf.GradientTape() as tape:
            features = self.encoder(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        return {
            "c_loss": contrastive_loss,
            "c_acc": self.contrastive_accuracy.result(),
            "r_acc": self.correlation_accuracy.result(),
            "p_loss": probe_loss,
            "p_acc": self.probe_accuracy.result(),
        }

    def test_step(self, data):
        labeled_images, labels = data

        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)

        self.probe_accuracy.update_state(labels, class_logits)
        return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}

预训练 NNCLR


我们按照论文中的建议,使用 0.1 的温度和前面解释过的 10,000 的队列大小来训练网络。我们使用 Adam 作为对比和探测优化器。在本例中,我们只对模型进行了 30 个历元的训练,但为了获得更好的性能,我们应该对模型进行更多历元的训练。

以下两个指标可用于监控预训练性能,我们也会记录这些指标(摘自 Keras 示例):

对比准确度:自监督指标,即图像表示与其不同增强版本的图像表示更相似的情况比与当前批次中任何其他图像的表示更相似的情况的比率。即使在没有标注示例的情况下,自监督指标也可用于超参数调整。


线性探测准确率:线性探测是评估自监督分类器的常用指标。它的计算方法是在编码器特征的基础上训练逻辑回归分类器的准确率。在我们的例子中,是通过在冻结编码器上训练单个密集层来实现的。需要注意的是,与在预训练阶段后训练分类器的传统方法不同,在本例中,我们在预训练阶段就对分类器进行了训练。这可能会略微降低分类器的准确性,但这样我们就可以在训练过程中监控其值,从而有助于实验和调试。

model = NNCLR(temperature=temperature, queue_size=queue_size)
model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
    jit_compile=False,
)
pretrain_history = model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)

正如 SEER、SimCLR、SwAV 等以前的方法所显示的那样,当你只能获得非常有限的标注训练数据,但却能设法建立大量未标注数据的语料库时,自我监督学习就特别有用。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/588963.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DRF限流组件源码分析

DRF限流组件源码分析 开发过程中,如果某个接口不想让用户访问过于频繁,可以使用限流的机制 限流,限制用户访问频率,例如:用户1分钟最多访问100次 或者 短信验证码一天每天可以发送50次, 防止盗刷。 对于…

Spring - 7 ( 13000 字 Spring 入门级教程 )

一:Spring Boot 日志 1.1 日志概述 日志对我们来说并不陌生,我们可以通过打印日志来发现和定位问题, 或者根据日志来分析程序的运行过程,但随着项目的复杂度提升, 我们对日志的打印也有了更高的需求, 而不仅仅是定位排查问题 比如有时需要…

【LDAP】LDAP 和 AD 介绍及使用 LDAP 操作 AD 域

LDAP 和 AD 介绍及使用 LDAP 操作 AD 域 1.LDAP入门1.1 定义1.2 目录结构1.3 命名格式 2.AD 入门2.1 AD 定义2.2 作用2.3 AD 域结构常用对象2.3.1 域(Domain)2.3.2 组织单位(Organization Unit)2.3.3 群组(Group&#…

服务器数据恢复—多块磁盘离线导致阵列瘫痪,上层lun不可用的数据恢复案例

服务器存储数据恢复环境: 某品牌MSA2000存储,该存储中有一组由8块SAS硬盘(其中有一块热备盘)组建的RAID5阵列,raid5阵列上层划分了6个lun,均分配给HP-Unix小型机使用,主要数据为oracle数据库和O…

Mac 上安装多版本的 JDK 且实现 自由切换

背景 当前电脑上已经安装了 jdk8; 现在再安装 jdk17。 期望 完成 jdk17 的安装,并且完成 环境变量 的配置,实现自由切换。 前置补充知识 jdk 的安装路径 可以通过查看以下目录中的内容,确认当前已经安装的 jdk 版本。 cd /Library/Java/Java…

解决WordPress无法强制转换https问题

原因:我在用cs的时候,突然老鸟校园网突然断了,客户端cs连不上了,进程也杀不死,cpu占用100%,只能重启,但是重启后我的blog网站打不开了 开始以为是Nginx的问题,重启它说配置出了问题…

基于Springboot的在线博客网站

基于SpringbootVue的在线博客网站的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 首页 博客标签 博客分类 博客列表 图库相册 后台登录 后台首页 用户管理 博客标…

Word文件导出为PDF

Word文件导出为PDF 方法一、使用Word自带另存为PDF功能 打开需要转换为PDF格式的Word文件,依次点击【文件】➡【另存为】➡选择文件保存类型为.PDF 使用这种方法导出的PDF可能存在Word中书签丢失的情况,在导出界面点击,选项进入详细设置 勾…

算法系列--BFS解决拓扑排序

💕"请努力活下去"💕 作者:Lvzi 文章主要内容:算法系列–算法系列–BFS解决拓扑排序 大家好,今天为大家带来的是算法系列--BFS解决拓扑排序 前言:什么是拓扑排序 拓扑排序–解决有顺序的排序问题(要做事情的先后顺序) …

Vulntarget-a 打靶练习

关于环境配置,这里就不在附上图片和说明了,网上一大堆,这里只针对自己练习,做一个记录。 外网信息收集 利用arpscan工具,扫描了当前局域网中都存在哪些主机: 正常来说我们不应该使用arpscan,而是…

各个硬件的工作原理

目录 前言 主存储器的基本组成 运算器的基本组成 控制器的基本组成 计算机的工作过程 前言 上个小节我们学习了现代计算机的基本构成都是基于冯诺依曼的思想来设计的,那么本章节要来看看主机内部三个组件的细节以及它们之间相互协调工作的. 主存储器的基本组成 这张图非常…

WPF基础应用

WPF参考原文 MVVM介绍 1.常用布局控件 1.1 布局控件 WPF(Windows Presentation Foundation)提供了多种布局容器来帮助开发者设计用户界面,以下是一些常用的布局: Grid: Grid是最常用的布局容器之一,它允许你通过定…

暗区突围端游海外版|暗区突围怎么玩 新手游玩攻略分享

游戏中健康系统与其它射击游戏有很大区别,根据受伤部位、伤势的不同,会有不同的表现。除了头部之外,其它部位如果损坏后继续受到伤害,那么伤害将会分摊到身体其它部位。在暗区内或者暗区外都可以对角色进行治疗,角色不…

Mybatis进阶(映射关系一对一 )

文章目录 1.基本介绍1.基本说明2.映射方式 2.配置xml方式(多表联查)1.数据库表设计2.新建子模块1.创建子模块2.创建基本结构 3.MyBatisUtils.java和jdbc.properties和mybatis-config.xml与原来的一致4.IdenCard.java5.Person.java6.IdenCardMapper.java7…

使用 uni-app 开发 iOS 应用的操作步骤

哈喽呀,大家好呀,淼淼又来和大家见面啦,上一期和大家一起探讨了使用uniapp开发iOS应用的优势及劣势之后有许多小伙伴想要尝试使用uniapp开发iOS应用,但是却不懂如何使用uniapp开发iOS应用,所以这一期淼淼就来给你们分享…

TCP三次握手,四次挥手

TCP三次握手 TCP协议 : 1。源端口 :当前的进程端口,2字节 2。目的端口:对方的端口 ,2字节 3。序号:客户端或者服务器端生成的随机数 4.确认序号:确认上一次发送给数据对方有没有收到 5.标志…

三数之和细节

这道题看着简单&#xff0c;但是有细节要注意&#xff0c;不能有重复的三元组&#xff0c;我们也不能一开始的时候把重复的元素去除&#xff0c;如果全都是0的话&#xff0c;那么就删除的只剩下一个0了&#xff0c;显然答案是[0,0,0] class Solution { public:vector<vecto…

Jetpack Compose简介

文章目录 Jetpack Compose简介概述声明式UI和命令式UIJetpack Compose和Android View对比Compose API设计原则一切皆为函数组合优于继承单一数据源 Jetpack Compose和Android View关系使用ComposesetContent()源码ComposablePreview Jetpack Compose简介 概述 Jetpack Compos…

用龙梦迷你电脑福珑2.0做web服务器

用龙梦迷你电脑福珑2.0上做web服务器是可行的。已将一个网站源码放到该电脑&#xff0c;在局域网里可以访问网站网页。另外通过在同一局域网内的一台windows10电脑上安装花生壳软件&#xff0c;也可以在外网访问该内网服务器网站网页。该电脑的操作系统属于LAMP。在该电脑上安装…

【python】python标准化考试系统[单项选择题 简易版](源码)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…