详细解析Barlow Twins:自监督学习中的创新方法

首先先简单了解一下机器学习中,主要有三种学习范式:监督学习、无监督学习和自监督学习:

  • 监督学习:依赖带标签的数据,通过输入输出映射关系进行训练。
  • 无监督学习:不依赖标签,关注数据的内在结构和模式。
  • 自监督学习:利用数据本身生成标签,通过预训练任务学习有效的特征表示。

Barlow Twins

Barlow Twins是一种基于信息论的自监督学习方法,其目标是减少神经元之间的冗余。该方法要求神经元对数据增强具有不变性,但彼此独立。

在实际训练中,通过反向传播(backpropagation)调整神经网络的参数,使得交叉相关矩阵的对角线元素尽可能大,而非对角线元素尽可能小——接近单位矩阵,从而达到上述目标。

1 例子

假设我们有一张图片 X X X ,经过两个不同的数据增强得到图像 Y A Y^A YA Y B Y^B YB ,其再通过相同的神经网络得到特征表示 Z A Z^A ZA Z B Z^B ZB (假设有RGB三维)。由于是同一张图片, Z A Z^A ZA 的蓝色与 Z B Z^B ZB 的蓝色应该相似(红绿同理),同时为了最大限度减少冗余,我们希望特征彼此本身不同(即 Z A Z^A ZA 中的蓝绿红彼此不同) —— 对数据增强保持不变,但独立于其他

image-20240530211739303

数学上描述即为:计算特征表示 Z A Z^A ZA Z B Z^B ZB 的交叉相关矩阵,目标为使该矩阵接近单位矩阵。

这张图展示了Barlow Twins方法的主要流程。具体步骤如下:

  1. 数据增强
    • 从输入图像 X X X 出发,使用不同的数据增强变换 T T T 生成两组扭曲图像 Y A Y^A YA Y B Y^B YB。这些变换包括随机裁剪、翻转、颜色抖动等。
  2. 特征提取
    • 将扭曲图像 Y A Y^A YA Y B Y^B YB 输入相同的神经网络 f θ f_\theta fθ,生成对应的特征表示 Z A Z^A ZA Z B Z^B ZB
  3. 计算交叉相关矩阵
    • 计算特征表示 Z A Z^A ZA Z B Z^B ZB交叉相关矩阵。目标是使该矩阵接近单位矩阵,从而:
      • 对角线元素:希望在不同数据增强下,相同神经元的特征表示具有高度相关性(接近1)。
      • 非对角线元素:希望不同神经元之间没有冗余(接近0)。

2 Loss计算

交叉相关矩阵 C i j C_{ij} Cij​ 的计算

衡量了不同增强视图下神经元之间的相关性
C i j = ∑ b z b , i A z b , j B ∑ b ( z b , i A ) 2 ∑ b ( z b , j B ) 2 C_{ij} = \frac{\sum_b z^A_{b,i} z^B_{b,j}}{\sqrt{\sum_b (z^A_{b,i})^2} \sqrt{\sum_b (z^B_{b,j})^2}} Cij=b(zb,iA)2 b(zb,jB)2 bzb,iAzb,jB

  • z b , i A z^A_{b,i} zb,iA z b , j B z^B_{b,j} zb,jB 分别表示第 b b b 个样本在增强视图 A A A B B B 中第 i i i 和第 j j j 个神经元的特征表示。
损失函数 L B T \mathcal{L}_{BT} LBT

L B T = ∑ i ( 1 − C i i ) 2 + λ ∑ i ∑ j ≠ i C i j 2 \mathcal{L}_{BT} = \sum_i (1 - C_{ii})^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2 LBT=i(1Cii)2+λij=iCij2

  • 不变性项:
    ∑ i ( 1 − C i i ) 2 \sum_i (1 - C_{ii})^2 i(1Cii)2 这个部分希望对角线上的元素 C i i C_{ii} Cii 尽可能接近1,表示在不同增强视图下,相同神经元的特征表示高度相关。

  • 冗余减少项:
    λ ∑ i ∑ j ≠ i C i j 2 \lambda \sum_i \sum_{j \neq i} C_{ij}^2 λij=iCij2 这个部分希望非对角线上的元素 C i j C_{ij} Cij 尽可能接近0,表示不同神经元之间没有冗余。系数 λ \lambda λ 是一个超参数,用来平衡这两个项的权重。

整个Barlow Twins的关键即损失函数:

返回方阵非对角线元素的扁平(一维)视图函数:

  1. x.flatten()[:-1]:首先,将方阵x扁平化(即将其转换为一维数组),然后删除最后一个元素。扁平化后的数组中,最后一个元素是方阵的最后一个对角线元素。

  2. .view(n - 1, m + 1):然后,将扁平化后的数组重新塑形为一个(n - 1, m + 1)的矩阵。这个矩阵的每一行都包含了原方阵的一行元素。

  3. [:, 1:]:接着,删除矩阵的第一列。这一列包含了原方阵的剩余所有对角线元素。

  4. .flatten():最后,再次将矩阵扁平化。这样,得到的就是一个包含了原方阵所有非对角线元素的一维数组。

def off_diagonal(x):
    '''
    返回方阵非对角线元素的扁平(一维)视图
    '''
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, m + 1)[:, 1:].flatten()

barlow_loss计算函数:

def barlow_loss(z1, z2, bn, lambd):
    '''
    返回一对特征的Barlow Twins的loss

    :param z1:第一个输入特征
    :param z2:第二个输入特征
    :param bn:应用于 z1 和 z2 的 nn.BatchNorm1d 层
    :param lambd:权衡超参数 lambda
    '''
	# 批量归一化
    z1_norm = bn(z1)
    z2_norm = bn(z2)

    batch_size = z1.size(0)

    # 计算 z1 和 z2 的协方差矩阵
    c = torch.mm(z1_norm, z2_norm.t()) / batch_size

    # loss
    c_diff = (c - torch.eye(c.size(0), device=c.device)).pow(2)
    c_diff = off_diagonal(c_diff).mul_(lambd)
    loss = c_diff.sum()

    return loss

3 整体流程

整体流程的伪代码如下:

# 训练循环
for x in loader:  # 加载一个批次包含N个样本
    # 对每个样本生成两个随机增强版本
    y_a, y_b = augment(x)  # augment函数生成数据增强版本
    
    # 计算表征
    z_a = f(y_a)  # NxD
    z_b = f(y_b)  # NxD
    
    # 沿批次维度标准化表征
    z_a_norm = (z_a - z_a.mean(dim=0)) / z_a.std(dim=0)  # NxD
    z_b_norm = (z_b - z_b.mean(dim=0)) / z_b.std(dim=0)  # NxD
    
    # 计算交叉相关矩阵
    c = torch.mm(z_a_norm.T, z_b_norm) / N  # DxD
    
    # 计算损失
    c_diff = (c - torch.eye(D, device=c.device)).pow(2)  # DxD
    # 将非对角线元素乘以lambda
    off_diagonal(c_diff).mul_(lambda_off_diag)
    loss = c_diff.sum()
    
    # 优化步骤
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/665070.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

整合Spring Boot 框架集成Knife4j

本次示例使用Spring Boot作为脚手架来快速集成Knife4j,Spring Boot版本2.3.5.RELEASE ,Knife4j版本2.0.7 POM.XML完整文件代码如下&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0…

美创CTO周杰受邀参加2024省级现代服务业高研班,分享“人工智能数据安全与防护技术”

近日&#xff0c;为期三天的省级现代服务业“模型生态应用与安全治理”高级研修班在杭州成功举办。 本次高研班由浙江省人社厅、浙江省委网信办指导&#xff0c;浙江省网络空间安全协会主办&#xff0c;旨在抢抓新一轮人工智能带来的科技革命与产业变革新机遇&#xff0c;助推浙…

C++入门——类和对象【3】(6)

前言 本节是C类和对象中的最后一节&#xff0c;学完本节内容并且能够掌握之前所学的所有内容的话&#xff0c;C就可以说是入门了&#xff0c;那我们废话不多说&#xff0c;正式进入今天的学习 1. 再谈构造函数 1.1 引入 我们在栈的背景下来看 栈的代码&#xff1a; ​type…

Docker Hub 国内镜像源配置

Docker Hub 国内镜像源配置 Docker Hub 国内镜像源是指在国内境内提供 Docker 镜像服务的镜像源。由于国际网络带宽等问题&#xff0c;国内用户下载 Docker 镜像通常速度较慢。因此&#xff0c;为了解决这个问题&#xff0c;一些国内的公司和组织提供了 Docker 镜像的国内镜像…

什么是机器人离线编程? 衡祖仿真

一、什么是机器人离线编程&#xff1f; 机器人离线编程是自动化生产的重要一环。离线编程指&#xff0c;在建立了机器人的三维模拟场景后&#xff0c;经由软件仿真计算&#xff0c;生成控制机器人运动轨迹&#xff0c;进而生成机器人的控制指令。工程师可以由此来控制物理环境…

基于SSM框架的手机商城项目

后端: 订单管理 客户管理&#xff1a; 商品管理 类目管理 前端&#xff1a; 首页&#xff1a;

算法(十一)贪婪算法

文章目录 算法简介算法概念算法举例 经典问题 -背包问题 算法简介 算法概念 贪婪算法&#xff08;Greedy&#xff09;是一种在每一步都采取当前状态下最好的或者最优的选择&#xff0c;从而希望导致结果也是全局最好或者最优的算法。贪婪算法是当下局部的最优判断&#xff0c…

java向上转型

介绍 代码 父类 package b;public class father_ {//father classString name"动物";int age10;public void sleep() {System.out.println("睡");}public void run() {System.out.println("跑");}public void eat() {System.out.println("…

案例|开发一个美业小程序,都有什么功能

随着移动互联网的迅猛发展&#xff0c;美业连锁机构纷纷寻求数字化转型&#xff0c;以小程序为载体&#xff0c;提升服务效率&#xff0c;增强客户体验。 线下店现在面临的困境&#xff1a; 客户到店排队时间过长&#xff0c;体验感受差 新客引流难&#xff0c;老用户回头客…

实验---DC-AC逆变器(1)---EG8010+NSI6602驱动IGBT实验

一、设计电路 1.LCC 主回路模块原理图 1.1 电源部分 这个电源部分电路图是一个简单而有效的DC-DC转换器设计&#xff0c;包含输入保护和滤波、电源模块、以及输出滤波和稳定。 a. 输入电源部分 输入电源 (E12V): 电路从E12V端子接收12V的直流电源。这是整个电路的输入电源。…

【知识拓展】机器学习基础(二):什么是模型、自定义模型、模型训练、模型调优

前言 接上文&#xff0c;前文对模型没有过多介绍&#xff0c;随着看的资料增多&#xff0c;对模型有了更多的自我认识&#xff0c;记录一下。要了解模型&#xff0c;我们先从零开始创建一个模型开始&#xff1a; 最简单的方法是使用Python和scikit-learn库。关于scikit-learn库…

Maven 中的 classifier 属性用过没?

最近训练营有小伙伴问到松哥一个关于 Maven 依赖的问题&#xff0c;涉及到 classifier 属性&#xff0c;随机问了几个小伙伴&#xff0c;都说工作中没用到过&#xff0c;因此简单整篇文章和小伙伴们分享下。 Maven 大家日常开发应该都有使用&#xff0c;Maven 中有一个比较好玩…

深入理解 Go 语言中的字符串不可变性与底层实现

文章目录 前言1 字符串类型的数据结构组成2 为什么要这么设计数据结构&#xff1f;3 为什么说字符串类型不可修改&#xff1f;4 如何实现字符串的修改&#xff1f;5 为什么字符串修改的字面量用单引号&#xff1f;6 如何判断字符串的修改新建了一个字符串&#xff1f;7 字符串的…

DevExpress开发WPF应用实现对话框总结

说明&#xff1a; 完整代码Github​&#xff08;https://github.com/VinciYan/DXMessageBoxDemos.git&#xff09;DevExpree v23.2.4&#xff08;链接&#xff1a;https://pan.baidu.com/s/1eGWwCKAr8lJ_PBWZ_R6SkQ?pwd9jwc 提取码&#xff1a;9jwc&#xff09;使用Visual St…

Rust之函数式语言特性:迭代器和闭包(一):概述

开发环境 Windows 11Rust 1.78.0 VS Code 1.89.1 项目工程 这次创建了新的工程minigrep. 函数式语言特性:迭代器和闭包 Rust的设计从许多现有语言和技术中获得了灵感&#xff0c;其中一个重要影响是函数式编程。函数式编程通常包括通过在参数中传递函数、从其他函数返回函数、…

CameraProvider启动流程

从Android 8.0之后&#xff0c;Android 引入Treble机制&#xff0c;主要是为了解决目前Android 版本之间升级麻烦的问题&#xff0c;将OEM适配的部分vendor与google 对android 大框架升级的部分system部分做了分离&#xff0c;一旦适配了一个版本的vendor信息之后&#xff0c;之…

告别低效提问:掌握BARD技巧,让AI成为你的智能助手!

今天只聊一个主题&#xff1a;提示词 Prompt。 说到提示词&#xff0c;大家可能都看过GPT的高级示例&#xff0c;那些几百字的提示词&#xff0c;写起来确实不容易。 那么&#xff0c;如何写出同样效果的提示词呢&#xff1f; 有没有什么公式或者系统学习的方法&#xff1f;…

在CentOS7下构建TeamSpeak服务器并增加网易云点歌插件

文章目录 部署TeamSpeak创建一个新用户下载并解压服务端下载解压 启动服务端同意许可协议启动与配置开放端口设置开机自启 客户端连接 部署TS3AudioBot并添加网易云插件安装ffmpeg下载TS3AudioBot本体与插件并解压配置TS3AudioBot启动设置开机自启 部署网易云API安装git安装Nod…

5.23R语言-参数假设检验

理论 方差分析&#xff08;ANOVA, Analysis of Variance&#xff09;是统计学中用来比较多个样本均值之间差异的一种方法。它通过将总变异分解为不同来源的变异来检测因子对响应变量的影响。方差分析广泛应用于实验设计、质量控制、医学研究等领域。 方差分析的基本模型 方差…

ReDos攻击浅析

DOS为拒绝服务攻击&#xff0c;re则是由于正则表达式使用不当&#xff0c;陷入正则引擎的回溯陷阱导致服务崩溃&#xff0c;大量消耗后台性能 正则 ​ 探讨redos攻击之前&#xff0c;首先了解下正则的一些知识 执行过程 大体的执行过程分为: 编译 -> 执行编译过程中&…