Supervised Contrastive 损失函数详解

在这里插入图片描述
有什么不对的及时指出,共同学习进步。(●’◡’●)

有监督对比学习将自监督批量对比方法扩展到完全监督设置,能够有效地利用标签信息。属于同一类的点簇在嵌入空间中被拉到一起,同时将来自不同类的样本簇推开。这种损失显示出对自然损坏很稳健,并且对优化器和数据增强等超参数设置更稳定。

有监督对比学习论文的贡献

  1. 提出了对比损失函数一种新的扩展,允许每个锚点都有多个正样本,使对比学习适应完全监督设置。
  2. 该损失为很多数据集的top-1的准确率带来了提升,对自然损坏有稳健性。
  3. 损失函数的梯度鼓励从硬正样本和硬的负样本中学习。(硬的正样本与锚点图像不相似的正样本,硬的负样本就是与锚点图像相似的负样本,都是难以学习的那种)
  4. 对比损失函数不如交叉熵损失函数对超参数敏感。

自监督对比学习损失
在这里插入图片描述
有监督对比学习损失
在这里插入图片描述
文中对交叉熵损失训练,自监督对比损失训练和有监督对比损失训练进行比较
在这里插入图片描述
推理模型中的参数个数始终保持不变,应该是推理的时候就是编码器+分类头都一样。
上图是训练的时候,交叉熵损失不必说。
自监督损失一般采用的是个体判别代理任务,正样本是自身经过数据增强后的图像(一般一个正样本),其他的都是负样本,训练编码器的时候让正样本和锚点图像经过编码器得到的特征尽可能接近,与负样本之间的特征尽可能拉远。
有监督对比学习,有标签信息,正样本除了自身数据增强后的之外还有这个类别中的其他样本(一般这个batch_size中)。
stage1就是训练编码器。
stage2是训练分类头,作者指出不需要训练线性分类器,并且先前的工作已经使用k -最近邻分类或原型分类来评估分类任务上的表示。线性分类器也可以与编码器联合训练,只要不将梯度传播回编码器即可,就是分类头和编码器之间训练要分开。
有监督对比学习损失代码
对比学习对比的是特征,所以损失函数的输入是特征,有监督对比学习损失还要输入标签信息。
损失函数就是模型的输出和标签(这里是mask)之间的差距,输出和标签差距越大,那么loss就越大。
输出这里是编码器的输出就是特征,标签就是类别标签。标签是如何起作用的呢?就是让损失函数区分这个batchsize中的正负样本,属于同一类就是正样本,其他都是负样本。
其中标签mask怎么获得,一个是通过label,另一个直接输入。label是每个数据的类别信息,label.view(1,-1)变成列向量然后再与它的转置进行torch.eq(),得到一个矩阵mask,mask(i,j)如果第i个数据和第j个数据类别相同那么这个位置是True,否则为False,float就变成0,1。后面乘了一个对角线元素为0,其他位置元素为1的矩阵,就是不让每个feature与自身对比。
我们看它self.contrast_mode="one"的时候只是比较feature中第0个特征(也就是平常的第一个特征),那么锚点特征就是所有数据的第0个特征;"all"就是所有的特征都要对比;锚点特征就是所有数据的所有特征。 torch.cat(torch.unbind(features, dim=1), dim=0)把feature按照第1维拆开,然后在第0维上cat,然后比较的feature的形式就是每一个数据的第1个特征|每个数据的第2个特征|…|每个数据的第n个特征,排列,这些特征是排在一起的在一个维度上。锚点特征要么是输入特征组的每个数据的第0个特征要么就是这些比较的特征。(不太理解为什么one的时候比较特征还是所有的)
锚点特征与比较特征的转置相乘,得到的就是batch_size*channel个相似矩阵,每两个数据在这个特征下的相似度。然后这个相似度矩阵要和我们得到的mask进行比较,就是上面的第二个式子。
下面是详细解释。

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode#设置对比的模式有one和all两种,代表对比一个channel还是所有,个人理解
        self.base_temperature = base_temperature #设置的温度

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')#设置设备
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:# batch_size, channel,H,W,平铺变成batch_size, channel, (H,W)
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:#只能存在一个
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:#如果两个都没有就是无监督对比损失,mask就是一个单位阵
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:#有标签,就把他变成mask
            labels = labels.contiguous().view(-1, 1)#contiguous深拷贝,与原来的labels没有关系,展开成一列,这样的话能够计算mask,否则labels一维的话labels.T是他本身捕获发生转置
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask =  torch.eq(labels, labels.T).float().to(device)#label和label的转置比较,感觉应该是广播机制,让label和label.T都扩充了然后进行比较,相同的是1,不同是0.
            #这里就是由label形成mask,mask(i,j)代表第i个数据和第j个数据的关系,如果两个类别相同就是1, 不同就是0
        else:
            mask = mask.float().to(device)#有mask就直接用mask,mask也是代表两个数据之间的关系

        contrast_count = features.shape[1]#对比数是channel的个数
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)#把feature按照第1维拆开,然后在第0维上cat,(batch_size*channel,h*w..)#后面就是展开的feature的维度
        #这个操作就和后面mask.repeat对上了,这个操作是第一个数据的第一维特征+第二个数据的第一维特征+第三个数据的第一维特征这样排列的与mask对应
        if self.contrast_mode == 'one':#如果mode=one,比较feature中第1维中的0号元素(batch, h*w)
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':#all就(batch*channel, h*w)
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),#两个相乘获得相似度矩阵,乘积值越大代表越相关
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)#计算其中最大值
        logits = anchor_dot_contrast - logits_max.detach()#减去最大值,都是负的了,指数就小于等于1

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)#repeat它就是把mask复制很多份
        # mask-out self-contrast cases
        logits_mask = torch.scatter(#生成一个mask形状的矩阵除了对角线上的元素是0,其他位置都是1, 不会对自身进行比较
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask#定义其中的相似度
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))#softmax

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point. 
        # Edge case e.g.:- 
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan] 
        mask_pos_pairs = mask.sum(1)#mask的和
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)#满足返回1,不满足返回mask_pos_pairs.保证数值稳定
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos#类似蒸馏temperature温度越高,分布曲线越平滑不易陷入局部最优解,温度低,分布陡峭
        loss = loss.view(anchor_count, batch_size).mean()#计算平均

        return loss

使用的化就是下面这段:

loss = criterion(features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/349736.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

专业远程控制软件有哪些品牌

远程办公、远程控制类的软件很多&#xff0c;主打方向和面向的客户人群也不一样。个人用户可能更在意便捷、免费等因素&#xff1b;专业用户会更注重安全性、管理功能等。今天我们介绍几个在全球知名的专业商业远程软件。 1、TeamViewer 简介&#xff1a;TeamViewer &#xf…

EXCEL VBA调用adobe的api识别电子PDF发票里内容并登记台账

EXCEL VBA调用adobe的api识别电子PDF发票里内容并登记台账 代码如下 使用须知&#xff1a; 1、工具--引用里勾选[Adobe Acrobat 10.0 Type Library] 2、安装Adobe Acrobat pro软件Dim sht As Worksheet Function BrowseFolders() As String 浏览目录Dim objshell As ObjectDim…

暗藏危险,警惕钓鱼邮件!

叮 您有一份福利待查收 您的信息资产需要排查 您的账户异常需要验证 这些看似“重要”的邮件 都藏着攻击者的恶意嘴脸 随着网络安全防护和建设的重要性日益凸显&#xff0c;国家安全、企业安全、合规需求及业务驱动等各个方面都亟需将网络安全作为基石。在企业业务转型发展…

【C++中STL】stack和queue容器

stack和queue stack基本概念常用接口 quque基本概念常用接口 stack 基本概念 stack是一种先进后出的数据结构&#xff0c;它只有一个出口 栈中只有顶端的元素可以被外界使用&#xff0c;因此栈不允许由遍历行为 可以判断是否为空empty(),和统计个数size(); 常用接口 1、st…

服务器是什么?(四种服务器类型)

服务器 服务器定义广义: 专门给其他机器提供服务的计算机。狭义:一台高性能的计算机&#xff0c;通过网络提供外部计算机一些业务服务 个人PC内存大概8G&#xff0c;服务器内存128G起步 服务器是什么 服务器指的是 网络中能对其他机器提供某些服务的计算机系统 &#xff0c;相对…

用Yara对红队工具“打标”

前言: YARA 通常是帮助恶意软件研究人员识别和分类恶意软件样本的工具&#xff0c;它基于文本或二进制模式创建恶意样本的描述规则&#xff0c;每个规则由一组字符串和一个布尔表达式组成&#xff0c;这些表达式决定了它的逻辑。 但是这次我们尝试使用 YARA 作为一种扫描工具…

【好书推荐-第五期】《互联网大厂推荐算法实战》(异步图书出品)

&#x1f60e; 作者介绍&#xff1a;我是程序员洲洲&#xff0c;一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主、前后端开发、人工智能研究生。公粽号&#xff1a;程序员洲洲。 &#x1f388; 本文专栏&#xff1a;本文…

机器学习 | 深入探索Numpy的高性能计算能力

目录 初识numpy numpy基本操作 数组的基本操作 ndarray运算 数组间运算 矩阵 初识numpy Numpy&#xff08;Numerical Python&#xff09;是一个开源的Python科学计算库&#xff0c;用于快速处理任意维度的数组。Numpy支持常见的数组和矩阵操作。对于同样的数值计算任务&…

k8s 进阶实战笔记 | Pod 创建过程详解

Pod 创建过程详解 ​ 初始状态0 controller-manager、scheduler、kubelet组件通过 list-watch 机制与 api-server 通信并检查资源变化 第一步 用户通过 CLI 或者 WEB 端等方式向 api-server 发送创建资源的请求&#xff08;比如&#xff1a;我要创建一个replicaset资源&…

Hadoop3.x源码解析

文章目录 一、RPC通信原理解析1、概要2、代码demo 二、NameNode启动源码解析1、概述2、启动9870端口服务3、加载镜像文件和编辑日志4、初始化NN的RPC服务端5、NN启动资源检查6、NN对心跳超时判断7、安全模式 三、DataNode启动源码解析1、概述2、初始化DataXceiverServer3、初始…

聚道云软件连接器:打通金蝶云星空与招商银行CBS,提升企业财务和银行业务效率

【客户介绍】 某企业是一家从事电子商务的企业&#xff0c;随着业务的不断扩大&#xff0c;对于财务管理和银行业务的需求也越来越高。该企业希望能够实现财务和银行业务的自动化处理&#xff0c;提高工作效率。由于业务的不断发展&#xff0c;企业面临着越来越多的资金管理挑…

零基础学习数学建模——(四)备战美赛

本篇博客将讲解如何备战美赛。 什么是美赛 美赛&#xff0c;全称是美国大学生数学建模竞赛&#xff08;MCM/ICM&#xff09;&#xff0c;由美国数学及其应用联合会主办&#xff0c;是最高的国际性数学建模竞赛&#xff0c;也是世界范围内最具影响力的数学建模竞赛。 赛题内容…

Oracle触发器简单应用示例(销售与库存)

目录 一、应用描述 1、应用场景&#xff1a; 2、具体场景&#xff1a; 二、表结构介绍 1、表名介绍&#xff1a; 2、表结构&#xff1a; 三、设置触发器 四、运行示例 1、初始库存描述 2、有库存情况 2.1 1001号产品售出1件 2.2 1001号产品库存已减1 3、无库存情况…

外汇天眼:QoinTech误信假老师话术投资外汇,惨遭黑平台滑点爆仓拒出金

去年11月与12月&#xff0c;外汇天眼先后发布了「钓鱼广告诱加投资群组&#xff0c;限制出金逼迫缴分成费」与「假投顾诱导投资黄金获利&#xff0c;黑平台操作爆仓狠诈700万」这2篇文章&#xff0c;曝光黑平台QoinTech的诈骗手法&#xff0c;呼吁投资人不要上当&#xff0c;没…

[SwiftUI]修改状态栏文字颜色

问题&#xff1a; 如图&#xff0c;在项目 Info.plist 中&#xff0c;将 UIViewControllerBasedStatusBarAppearance 设置为 NO&#xff0c;将UIStatusBarStyle设置为Light Content后&#xff0c;APP的状态栏字体颜色仍然是黑色没变成白色。 修复&#xff1a; https://stacko…

uniapp vuecli项目融合[小记]:将多个项目融合,打包成一个小程序/App,拆分多个H5应用

前言&#xff1a; 目前两个uniapp vuecli开发的项目【A、B】&#xff0c;新规划的项目C&#xff1a;需要融合项目B 80%的功能模块&#xff0c;同时也需要涵盖项目A的所有功能模块。 应用需求&#xff1a; 1、新项目C【小程序】可支持切换到应用A/C界面【内部通过初始化、路由跳…

便捷接口调测:API 开发工具大比拼 | 开源专题 No.62

hoppscotch/hoppscotch Stars: 56.1k License: MIT Hoppscotch 是一个开源的 API 开发生态系统&#xff0c;主要功能包括发送请求和获取实时响应。该项目具有以下核心优势&#xff1a; 轻量级&#xff1a;采用简约的 UI 设计。快速&#xff1a;实时发送请求并获得响应。支持多…

直播项目开发

uni-aapp&#xff0c;egg.js&#xff0c;直播服务器自己搭建&#xff0c;Node.js&#xff0c;socket.io实时送礼物&#xff0c;充值&#xff0c;兼容Android&#xff0c;iOS,小程序&#xff0c;充值时用到微信支付&#xff0c;直播分为主播端和用户端&#xff0c;主播端有摄像头…

湿法蚀刻酸洗槽—— 应用半导体新能源光伏光电行业

PFA清洗槽又被称为防腐蚀槽、酸洗槽、溢流槽、纯水槽、浸泡槽、水箱、滴流槽&#xff0c;是四氟清洗桶后的升级款&#xff0c;是为半导体光伏光电等行业设计&#xff0c;一体成型&#xff0c;无需担心漏液。主要用于浸泡、清洗带芯片硅片电池片的花篮。 由于PFA的特点它能耐受…

分钟级实时数据分析的背后——实时湖仓产品解决方案

随着信息技术的深入应用&#xff0c;企业对市场的响应速度也在不断提升&#xff0c;而且这种响应速度正在变得越来越快&#xff0c;没有最快只有更快。对数据实时性要求的提高&#xff0c;是眼下很多企业遇到的一个新的挑战。 从生产侧的视角来看&#xff0c;系统实时监控与实…