【机器学习】机器学习的基本分类-自监督学习-对比学习(Contrastive Learning)

对比学习是一种自监督学习方法,其目标是学习数据的表征(representation),使得在表征空间中,相似的样本距离更近,不相似的样本距离更远。通过设计对比损失函数(Contrastive Loss),模型能够有效捕捉数据的语义结构。


核心思想

对比学习的关键在于:

  1. 正样本(Positive Pair):具有相似语义或来源的样本对,例如同一图像的不同增强版本。
  2. 负样本(Negative Pair):语义不同或来源不同的样本对,例如不同图像。

通过对比正负样本对,模型能够学习区分不同数据点的特征。


方法流程

  1. 数据增强:对一个样本 x 应用两种不同的增强方法,生成 x_1, x_2​,作为正样本对。
  2. 特征提取:通过编码器(如卷积神经网络)将数据映射到潜在特征空间,得到表征 z_1, z_2
  3. 对比损失:设计损失函数,使正样本对的表征距离最小化,负样本对的表征距离最大化。

对比学习的损失函数

1. 对比损失(Contrastive Loss)

对比损失鼓励正样本对的距离更小,负样本对的距离更大。

L = \frac{1}{N} \sum_{i=1}^N \left[ y_i \cdot d(z_i, z_j)^2 + (1 - y_i) \cdot \max(0, m - d(z_i, z_j))^2 \right]

  • y_i:样本对是否为正样本(1 表示正样本,0 表示负样本)。
  • d(z_i, z_j):样本对在表征空间中的距离(通常使用欧氏距离)。
  • m:负样本对的最小距离(margin)。
2. InfoNCE 损失

用于最大化正样本对的相似性,同时将负样本对的相似性最小化。

L = - \log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z_i, z_k) / \tau)}

  • \text{sim}(z_i, z_j) = \frac{z_i \cdot z_j}{\|z_i\| \|z_j\|}:余弦相似度。
  • \tau:温度参数,用于控制分布的平滑程度。
  • N:批量中样本数量。

典型方法

1. SimCLR

SimCLR 是对比学习的经典方法之一:

  • 核心思想:通过数据增强生成正样本对,并利用 InfoNCE 损失函数进行优化。
  • 数据增强:随机裁剪、颜色抖动、模糊等。
2. MoCo(Momentum Contrast)

通过维护一个动态更新的“字典”,解决负样本数量不足的问题。

  • 核心思想:使用动量编码器(momentum encoder)生成更多的负样本。
3. BYOL(Bootstrap Your Own Latent)

无需显式的负样本,通过自回归(self-prediction)学习特征表征。

  • 核心思想:一个在线网络(Online Network)和一个目标网络(Target Network)协同训练。
4. SWAV(Swapping Assignments Between Views)

结合聚类和对比学习,利用图像的多视图表征。

  • 核心思想:通过在线分配伪标签,避免显式使用负样本。

示例代码:SimCLR

以下是一个实现 SimCLR 的示例代码:

import tensorflow as tf
from tensorflow.keras import layers, models


# 图像增强函数
def augment_image(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, size=(32, 32, 3))
    image = tf.image.random_brightness(image, max_delta=0.5)
    return image


# 定义编码器
def create_encoder():
    base_model = tf.keras.applications.ResNet50(include_top=False, pooling='avg', input_shape=(32, 32, 3))
    return models.Model(inputs=base_model.input, outputs=base_model.output)


# SimCLR 模型
class SimCLRModel(tf.keras.Model):
    def __init__(self, encoder, projection_dim):
        super(SimCLRModel, self).__init__()
        self.encoder = encoder
        self.projection_head = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(projection_dim)
        ])

    def call(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return tf.math.l2_normalize(projections, axis=1)


# 构建模型
encoder = create_encoder()
simclr_model = SimCLRModel(encoder, projection_dim=128)


# InfoNCE 损失
def info_nce_loss(features, temperature=0.5):
    batch_size = tf.shape(features)[0]
    labels = tf.range(batch_size)
    similarity_matrix = tf.matmul(features, features, transpose_b=True)
    logits = similarity_matrix / temperature
    return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))


# 训练
(X_train, _), _ = tf.keras.datasets.cifar10.load_data()
X_train = tf.image.resize(X_train, (32, 32)) / 255.0


def preprocess_data(image):
    return augment_image(image), augment_image(image)


train_data = tf.data.Dataset.from_tensor_slices(X_train)
train_data = train_data.map(preprocess_data).batch(32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in range(10):
    for x1, x2 in train_data:
        with tf.GradientTape() as tape:
            z1 = simclr_model(x1)
            z2 = simclr_model(x2)
            loss = info_nce_loss(tf.concat([z1, z2], axis=0))
        gradients = tape.gradient(loss, simclr_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, simclr_model.trainable_variables))
    print(f"Epoch {epoch + 1}, Loss: {loss.numpy()}")

输出结果

Epoch 1, Loss: 3.465735912322998
Epoch 2, Loss: 3.465735912322998
Epoch 3, Loss: 3.465735912322998
Epoch 4, Loss: 3.465735912322998
Epoch 5, Loss: 3.465735912322998

对比学习的优势与挑战

优势
  1. 无需标签数据:适用于大规模无标签数据集。
  2. 高质量特征:学习的表征具有很强的迁移能力。
  3. 通用性强:适用于图像、文本、语音等多种模态。
挑战
  1. 负样本选择:负样本数量和质量对性能影响大。
  2. 计算成本:对比学习需要大量计算资源,尤其是在大规模数据上训练。
  3. 超参数调整:温度参数等对模型表现至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/948575.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Flutter-插件 scroll-to-index 实现 listView 滚动到指定索引位置

scroll-to-index 简介 scroll_to_index 是一个 Flutter 插件,用于通过索引滚动到 ListView 中的某个特定项。这个库对复杂滚动需求(如动态高度的列表项)非常实用,因为它会自动计算需要滚动的目标位置。 使用 安装插件 flutte…

国内Ubuntu环境Docker部署CosyVoice

国内Ubuntu环境Docker部署CosyVoice 本文旨在记录在 国内 CosyVoice项目在 Ubuntu 环境下如何使用 dockermin-conda进行一键部署。 源项目地址: https://github.com/FunAudioLLM/CosyVoice 如果想要使用 dockerpython 进行部署,可以参考我另一篇博客中的…

计算机网络•自顶向下方法:网络层介绍、路由器的组成

网络层介绍 网络层服务:网络层为传输层提供主机到主机的通信服务 每一台主机和路由器都运行网络层协议 发送终端:将传输层报文段封装到网络层分组中,发送给边缘路由器路由器:将分组从输入链路转发到输出链路接收终端&#xff1…

信创云之天翼云:引领信创云时代的先锋力量

数据显示,2024年中国云服务市场规模已达到4242.5亿元,显示出各行业对信息技术软硬件的依赖程度不断加深。在国家政策的持续支持下,数字化转型为云服务行业带来了前所未有的发展机遇。预计到2025年,中国云服务市场规模将突破4795.4…

Elasticsearch:基础概念

一、什么是Elasticsearch Elasticsearch是基于 Apache Lucene 构建的分布式搜索和分析引擎、可扩展数据存储和矢量数据库。它针对生产规模工作负载的速度和相关性进行了优化。使用 Elasticsearch 可以近乎实时地搜索、索引、存储和分析各种形状和大小的数据。Elasticsearch 是…

智慧工地解决方案 1

建设背景与挑战 工地施工现场环境复杂,人员管理难度大,多工种交叉作业导致管理混乱,事故频发。传统管理方式难以实现科学、有效、集中式的管理,特别是在环境复杂、地点分散的情况下,监管困难,取证复杂。施…

若依中Feign调用的具体使用(若依微服务版自身已集成openfeign依赖,并在此基础上定义了自己的注解)

若依中Feign调用具体使用 注意:以下所有步骤实现的前提是需要在启动类上加入注解 EnableRyFeignClients 主要是为开启feign接口扫描 1.创建服务提供者(provider) 导入依赖(我在分析依赖时发现若依本身已经引入openfeign依赖,并在此基础上自定义了自己的EnableRyF…

基于Springboot +Vue 实验课程预约管理系统

基于Springboot Vue 实验课程预约管理系统 前言 在现代教育领域,实验课程预约管理系统扮演着至关重要的角色。随着教学资源的日益紧张和学生需求的多样化,传统的人工管理方式已难以满足高效、透明的课程安排需求。基于SpringBootVue的实验课程预约管理…

CSS2笔记

一、CSS基础 1.CSS简介 2.CSS的编写位置 2.1 行内样式 2.2 内部样式 2.3 外部样式 3.样式表的优先级 4.CSS语法规范 5.CSS代码风格 二、CSS选择器 1.CSS基本选择器 通配选择器元素选择器类选择器id选择器 1.1 通配选择器 1.2 元素选择器 1.3 类选择器 1.4 ID选择器 1.5 基…

【偏好对齐】通过ORM直接推导出PRM

论文地址:https://arxiv.org/pdf/2412.01981 相关博客 【自然语言处理】【大模型】 ΨPO:一个理解人类偏好学习的统一理论框架 【强化学习】PPO:近端策略优化算法 【偏好对齐】PRM应该奖励单个步骤的正确性吗? 【偏好对齐】通过OR…

springmvc--请求参数的绑定

目录 一、创建项目,pom文件 二、web.xml 三、spring-mvc.xml 四、index.jsp 五、实体类 Address类 User类 六、UserController类 七、请求参数解决中文乱码 八、配置tomcat,然后启动tomcat 1. 2. 3. 4. 九、接收Map类型 1.直接接收Map类型 &#x…

第五届电网系统与绿色能源国际学术会议(PGSGE 2025)

2025年第五届电网系统与绿色能源国际学术会议(PGSGE 2025) 定于2025年01月10-12日在吉隆坡召开。 第五届电网系统与绿色能源国际学术会议(PGSGE 2025) 基本信息 会议官网:www.pgsge.org【点击投稿/了解会议详情】 会议时间:202…

CSS——4. 行内样式和内部样式(即CSS引入方式)

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title>方法1&#xff1a;行内样式</title></head><body><!--css引入方式&#xff1a;--><!--css的引入的第一种方法叫&#xff1a;行内样式将css代码写…

彩色图像分割—香蕉提取

实验任务 彩色图像分割—香蕉提取 利用香蕉和其它水果及其背景颜色在R,G,B分量上的差异进行识别,根据香 蕉和其它水果在R,G,B分量的二值化处理&#xff0c;获得特征提取的有效区域&#xff0c;然后提取 特征&#xff0c;达到提取香蕉的目的。附&#xff1a;统计各种水果及个数…

【算法】克里金(Kriging)插值原理及Python应用

文章目录 [toc] 前言一、克里金插值原理1.1 概述1.2 基本公式1.2 权重 w i w_i wi​的确定1.3 拟合函数的确定 二、Python建模与可视化2.1 Demo2.1.1 随机生成已知格网点2.1.2 拟合2.1.3 评估内符合精度2.1.3 内插未知格网点2.1.4 画图 2.2 结果图 参考文献 前言 最近学习了一下…

QML自定义滑动条Slider的样式

代码展示 import QtQuick 2.9 import QtQuick.Window 2.2 import QtQuick.Controls 2.1Window {visible: truewidth: 640height: 480title: qsTr("Hello World")Slider {id: controlvalue: 0.5background: Rectangle {x: control.leftPaddingy: control.topPadding …

Android Studio学习笔记

01-课程前面的话 02-Android 发展历程 03-Android 开发机器配置要求 04-Android Studio与SDK下载安装 05-创建工程与创建模拟器 在 Android Studio 中显示 “Device Manager” 有以下几种方法&#xff1a; 通过菜单选项 打开 Android Studio&#xff0c;确保已经打开了一个…

Qt天气预报系统设计界面布局第四部分右边

Qt天气预报系统 1、第四部分右边的第一部分1.1添加控件 2、第四部分右边的第二部分2.1添加控件 3、第四部分右边的第三部分3.1添加控件3.2修改控件名字 1、第四部分右边的第一部分 1.1添加控件 拖入一个widget&#xff0c;改名为widget04r作为第四部分的右边 往widget04r再拖…

Spring boot + Hibernate + MySQL实现用户管理示例

安装MySQL Windows 11 Mysql 安装及常用命令_windows11 mysql-CSDN博客 整体目录 pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLS…

Spring Boot 整合 Keycloak

1、概览 本文将带你了解如何设置 Keycloak 服务器&#xff0c;以及如何使用 Spring Security OAuth2.0 将Spring Boot应用连接到 Keycloak 服务器。 2、Keycloak 是什么&#xff1f; Keycloak是针对现代应用和服务的开源身份和访问管理解决方案。 Keycloak 提供了诸如单点登…