Focal Loss-解决样本标签分布不平衡问题

文章目录

    • 背景
      • 交叉熵损失函数
      • 平衡交叉熵函数
    • Focal Loss损失函数
      • Focal Loss vs Balanced Cross Entropy
      • Why does Focal Loss work?
    • 针对VidHOI数据集
    • Reference

背景

Focal Loss由何凯明提出,最初用于图像领域解决数据不平衡造成的模型性能问题。

交叉熵损失函数

L o s s = L ( y , p ^ ) = − y l o g ( p ^ ) − ( 1 − y ) l o g ( 1 − p ^ ) Loss=L(y,\hat{p})=-ylog(\hat{p})-(1-y)log(1-\hat{p}) Loss=L(y,p^)=ylog(p^)(1y)log(1p^)

其中, p ^ \hat{p} p^为预测概率大小。y为label,二分类中对应0和1。
L c e ( y , p ^ ) = { − l o g ( p ^ ) , if  y = 1 − l o g ( 1 − p ^ ) , if  y = 0 L_{ce}(y,\hat{p})= \left\{ \begin{array}{ll} -log(\hat{p}), & \text{if } y = 1 \\ -log(1-\hat{p}), & \text{if }y=0 \end{array} \right. Lce(y,p^)={log(p^),log(1p^),if y=1if y=0
对于所有样本,需要求平均作为最终的结果:
L = 1 N ∑ i = 1 N l ( y i , p ^ i ) L=\frac{1}{N}\sum_{i=1}^{N}l(y_i,\hat{p}_i) L=N1i=1Nl(yi,p^i)
对于二分类问题,可以改写成:
L = 1 N ( ∑ y i = 1 m − l o g ( p ^ ) + ∑ y i = 0 n − l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-log(\hat{p})+\sum_{y_i=0}^{n}-log(1-\hat{p})) L=N1(yi=1mlog(p^)+yi=0nlog(1p^))
其中,N为样本总数,m和n为正、负样本数, m + n = N m+n=N m+n=N

当样本分布不平衡时,损失函数L的分布也会发生倾斜,若m>>n时,正样本就会在损失函数中占据主导地位,由于损失函数的倾斜,训练的模型会倾向于样本较多的类别,导致对较少样本类别的性能较差。

平衡交叉熵函数

对于样本不平衡造成的损失函数倾斜,最直接的方法就是添加权重因子,提高少数类别在损失函数中的权重,从而平衡损失函数的分布。还是以之前的二分类问题为例,我们添加权重参数 α ∈ [ 0 , 1 ] \alpha∈[0,1] α[0,1]
L = 1 N ( ∑ y i = 1 m − α l o g ( p ^ ) + ∑ y i = 0 n − ( 1 − α ) l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-\alpha log(\hat{p})+\sum_{y_i=0}^{n}-(1-\alpha)log(1-\hat{p})) L=N1(yi=1mαlog(p^)+yi=0n(1α)log(1p^))
其中, α 1 − α = n m \frac{\alpha}{1-\alpha}=\frac{n}{m} 1αα=mn,权重大小由正负样本数量比来设置。

Focal Loss损失函数

Focal Loss从loss角度提供了一种样本不均衡的解决方案:
L f o c a l ( y , p ^ ) = { − ( 1 − p ^ ) γ l o g ( p ^ ) , if  y = 1 − p ^ γ l o g ( 1 − p ^ ) , if  y = 0 L_{focal}(y,\hat{p})= \left\{ \begin{array}{ll} -(1-\hat{p})^\gamma log(\hat{p}), & \text{if } y = 1 \\ -\hat{p}^\gamma log(1-\hat{p}), & \text{if }y=0 \end{array} \right. Lfocal(y,p^)={(1p^)γlog(p^),p^γlog(1p^),if y=1if y=0
p t = { p ^ , if  y = 1 1 − p ^ , otherwise.  p_t= \left\{ \begin{array}{ll} \hat{p}, & \text{if } y = 1 \\ 1-\hat{p}, & \text{otherwise. } \end{array} \right. pt={p^,1p^,if y=1otherwise. 

则表达式统一为:
L f o c a l = − ( 1 − p t ) γ l o g ( p t ) L_{focal}=-(1-p_t)^\gamma log(p_t) Lfocal=(1pt)γlog(pt)
与交叉熵表达式对照: L c e = − l o g ( p t ) L_{ce}=-log(p_t) Lce=log(pt),仅仅多了一个可变系数 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ.

其中, p t p_t pt反应了与ground truth的接近程度,越大表示分类越准。 γ > 0 \gamma>0 γ>0为调节因子。

对于分类不准确的样本, p t → 0 p_t→0 pt0 ( 1 − p t ) γ → 1 (1-p_t)^\gamma→1 (1pt)γ1 L f o c a l → L c e L_{focal}→L_{ce} LfocalLce;对于分类准确的样本, p t → 1 p_t→1 pt1 ( 1 − p t ) γ → 0 (1-p_t)^\gamma→0 (1pt)γ0 L f o c a l → 0 L_{focal}→0 Lfocal0;因此,Focal Loss对于分类不准确的样本,损失没有改变;对于分类准确的样本,损失会变小。整体来看,Focal Loss增加了分类不准确样本在损失函数中的权重。

如下是不同调节因子 γ \gamma γ对应的Loss-proba分布图,可以看出Cross Entropy(CE)和Focal Loss(FL)之间的区别,Focal Loss使损失函数更倾向于难分的样本。

在这里插入图片描述

Focal Loss vs Balanced Cross Entropy

  • Focal Loss是从样本分类难易程度出发,让Loss聚焦于难分类的样本;
  • Balanced Cross Entropy是从样本分布角度对Loss添加权重因子。
    • 缺点:仅仅考虑样本分布,有些难以区分的类别的样本数可能也比较多,此时被BCE赋予了较低的权重,会导致模型很难识别该类别!

Why does Focal Loss work?

Focal Loss从样本难易分类的角度出发,解决了样本不平衡导致模型性能较低的问题。

WHY?

样本不平衡造成的问题就是,样本数少的类别分类难度大,因此Focal Loss聚焦于难分样本,解决了样本少的类别分类精度不高的问题,对于难分样本中样本多的类别,也会被Focal Loss聚焦。因此,它不仅解决了样本不平衡问题,还提升了模型整体性能。

但是,要使模型训练过程中聚焦于难分类样本,仅仅将Loss倾向于难分类样本是不够的,因为模型参数更新取决于Loss的梯度:
w = w − α ∂ L ∂ w w=w-\alpha\frac{\partial L}{\partial w} w=wαwL
若Loss中难分类样本的权重较高,但是难分类样本的Loss梯度为0,难分类样本就不会影响到模型的参数更新。对于梯度问题,Focal Loss中的梯度与 x t x_t xt的关系如下所示,其中 x t = y x x_t=yx xt=yx y ∈ { − 1 , 1 } y∈\{-1,1\} y{1,1}为类别, p t = σ ( x t ) p_t=\sigma(x_t) pt=σ(xt),对于易分样本, x t > 0 x_t>0 xt>0,即 p t > 0.5 p_t>0.5 pt>0.5,由下图可知,此时的导数趋于0。对于难分样本,导数数值较大,因此,学习过程中更聚焦于难分样本。

在这里插入图片描述

难易分类样本是动态的, p t p_t pt在训练的过程中,可能会在难易之间相互转换。

在Loss梯度中,难训练样本起主导作用,参数朝着优化难训练样本的方向改变,变化之后可能会导致原本易训练的样本 p t p_t pt变化,即变成难训练样本。若发生了这种情况会导致模型收敛速度较慢。

为了防止这种难易样本的频繁变化,应该选择较小的学习率。

针对VidHOI数据集

因为VidHOI数据集中的一个人-物对会被多个交互标签同时标注,如< human,next to & watch & hold, cup >,所以会面临multi-class multi-label的分类问题。以往常常使用Binary cross-entropy,能够计算每个交互类别独立于其他类别的损失。但是,VidHOI数据集分布不均且具有长尾分布,为了解决这个不均衡问题同时避免过分强调最频繁类别的重要性,我们采用class-balanced Focal loss:
C B f o c a l ( p i , y i ) = − 1 − β 1 − β n i ( 1 − p y i ) γ l o g ( p y i ) w i t h   p y i = { p i , if  y i = 1 1 − p i , otherwise. CB_{focal}(p_i,y_i)=-\frac{1-\beta}{1-\beta^{n_i}}(1-p_{y_i})^{\gamma}log(p_{y_i}) \\ with \ p_{y_i} = \left\{ \begin{array}{ll} p_i, & \text{if } y_i = 1 \\ 1-p_i, & \text{otherwise.} \end{array} \right. CBfocal(pi,yi)=1βni1β(1pyi)γlog(pyi)with pyi={pi,1pi,if yi=1otherwise.

其中的 − ( 1 − p y i ) γ l o g ( p y i ) -(1-p_{y_i})^{\gamma}log(p_{y_i}) (1pyi)γlog(pyi)是Lin提出的Focal loss, p i p_i pi表示预估为第i个类别的可能性, y i ∈ { 0 , 1 } y_i∈\{0,1\} yi{0,1}表示Ground Truth的label。变量 n i n_i ni表示第i个类别在Ground Truth下的样本量, β ∈ [ 0 , 1 ) \beta∈[0,1) β[0,1)是可调节参数。所有类别的平均损失作为一个预测的损失。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class FocalBCEWithLogitLoss(nn.modules.loss._Loss):
    """Focal Loss with binary cross-entropy
    Implement the focal loss with class-balanced loss, using binary cross-entropy as criterion
    Following paper "Class-Balanced Loss Based on Effective Number of Samples" (CVPR2019)

    Args:
        gamma (int, optional): modulation factor gamma in focal loss. Defaults to 2.
        alpha (int, optional): modulation factor alpha in focal loss. If a integer, apply to all;
            if a list or array or tensor, regard as alpha for each class; if none, no alpha. Defaults to None.
        weight (Optional[torch.Tensor], optional): weight to each class, !not the same as alpha. Defaults to None.
        size_average (_type_, optional): _description_. Defaults to None.
        reduce (_type_, optional): _description_. Defaults to None.
        reduction (str, optional): _description_. Defaults to "mean".
    """

    def __init__(
        self,
        gamma=2,
        alpha=None,
        weight: Optional[torch.Tensor] = None,
        size_average=None,
        reduce=None,
        reduction: str = "mean",
        pos_weight: Optional[torch.Tensor] = None,
    ):
        super(FocalBCEWithLogitLoss, self).__init__(size_average, reduce, reduction)
        self.gamma = gamma
        # a number for all, or a Tensor with the same num_classes as input
        if isinstance(alpha, (list, np.ndarray)):
            self.alpha = torch.Tensor(alpha)
        else:
            self.alpha = alpha
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)
        self.weight: Optional[torch.Tensor]
        self.pos_weight: Optional[torch.Tensor]

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                alpha_t = self.alpha.repeat(input.shape[0], 1)
            else:
                alpha_t = torch.ones_like(input) * self.alpha
        else:
            alpha_t = None
		# 二元交叉熵
        ce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        # pt = torch.exp(-ce)
        # modulator = ((1 - pt) ** self.gamma)
        # following author's repo https://github.com/richardaecn/class-balanced-loss/blob/master/src/cifar_main.py#L226-L266
        # explaination https://github.com/richardaecn/class-balanced-loss/issues/1
        # A numerically stable implementation of modulator.
        if self.gamma == 0.0:
            modulator = 1.0
        else:
            # e^(-gamma*target*input - gamma*log(1+e^(-input)))
            modulator = torch.exp(-self.gamma * target * input - self.gamma * torch.log1p(torch.exp(-input)))
        # focal loss
        fl_loss = modulator * ce
        # alpha
        if alpha_t is not None:
            alpha_t = alpha_t * target + (1 - alpha_t) * (1 - target)
            fl_loss = alpha_t * fl_loss
        # pos weight
        if self.pos_weight is not None:
            fl_loss = self.pos_weight * fl_loss
        # reduction
        if self.reduction == "mean":
            return fl_loss.mean()
        elif self.reduction == "sum":
            return fl_loss.sum()
        else:
            return fl_loss

C B f o c a l ( p i , y i ) = − 1 − β 1 − β n i ( 1 − p y i ) γ l o g ( p y i ) w i t h   p y i = { p i , if  y i = 1 1 − p i , otherwise. CB_{focal}(p_i,y_i)=-\frac{1-\beta}{1-\beta^{n_i}}(1-p_{y_i})^{\gamma}log(p_{y_i}) \\ with \ p_{y_i} = \left\{ \begin{array}{ll} p_i, & \text{if } y_i = 1 \\ 1-p_i, & \text{otherwise.} \end{array} \right. CBfocal(pi,yi)=1βni1β(1pyi)γlog(pyi)with pyi={pi,1pi,if yi=1otherwise.

原始版本的代码:

def focal_loss(labels, logits, alpha, gamma):
  """Compute the focal loss between `logits` and the ground truth `labels`.

  Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
  where pt is the probability of being classified to the true class.
  pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

  Args:
    labels: A float32 tensor of size [batch, num_classes].
    logits: A float32 tensor of size [batch, num_classes].
    alpha: A float32 tensor of size [batch_size]
      specifying per-example weight for balanced cross entropy.
    gamma: A float32 scalar modulating loss from hard and easy examples.
  Returns:
    focal_loss: A float32 scalar representing normalized total loss.
  """
  with tf.name_scope('focal_loss'):
    logits = tf.cast(logits, dtype=tf.float32)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)

    # positive_label_mask = tf.equal(labels, 1.0)
    # probs = tf.sigmoid(logits)
    # probs_gt = tf.where(positive_label_mask, probs, 1.0 - probs)
    # # With gamma < 1, the implementation could produce NaN during back prop.
    # modulator = tf.pow(1.0 - probs_gt, gamma)

    # A numerically stable implementation of modulator.
    if gamma == 0.0:
      modulator = 1.0
    else:
      modulator = tf.exp(-gamma * labels * logits - gamma * tf.log1p(
          tf.exp(-1.0 * logits)))

    loss = modulator * cross_entropy

    weighted_loss = alpha * loss
    focal_loss = tf.reduce_sum(weighted_loss)
    # Normalize by the total number of positive samples.
    focal_loss /= tf.reduce_sum(labels)
  return focal_loss

Reference

  1. https://zhuanlan.zhihu.com/p/266023273
  2. https://github.com/nizhf/hoi-prediction-gaze-transformer
  3. https://github.com/richardaecn/class-balanced-loss/blob/master/src/cifar_main.py#L226-L266

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/93418.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用AI重构的钉钉,“钱”路在何方?

点击关注 文&#xff5c;郝 鑫&#xff0c;编&#xff5c;刘雨琦 钉钉2023年生态大会&#xff0c;离开了两年的无招&#xff0c;遇到了单飞9天的钉钉。 “做小钉钉、做好钉钉、做酷钉钉”&#xff0c;无招重申了钉钉的方向。 无招提到的三点&#xff0c;再加上“高质量增长”…

Doris异常处理

1、decimal 字段异常 修改为 2、连接超时 Caused by: com.mysql.cj.exceptions.CJCommunicationsException: Communications link failure The last packet successfully received from the server was 1,068 milliseconds ago. The last packet sent successfully to the ser…

Redis限流实践:实现用户消息推送每天最多通知2次的功能

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;CSDN领军人物&#xff0c;全栈领域优质创作者✌&#xff0c;CSDN博客专家&#xff0c;阿里云社区专家博主&#xff0c;2023年6月CSDN上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师…

Docker数据管理(数据卷与数据卷容器)

目录 一、数据卷&#xff08;Data Volumes&#xff09; 1、概述 2、原理 3、作用 4、示例&#xff1a;宿主机目录 /var/test 挂载同步到容器中的 /data1 二、数据卷容器&#xff08;DataVolumes Containers&#xff09; 1、概述 2、作用 3、示例&#xff1a;创建并使用…

AIGC ChatGPT 实现动态多维度分析雷达图制作

雷达图在多维度分析中是一种非常实用的可视化工具,主要有以下优势: 易于理解:雷达图使用多边形或者圆形的形式展示多维度的数据,直观易于理解。多维度对比:雷达图可以在同一张图上比较多个项目或者实体在多个维度上的表现。数据关系明显:通过雷达图,可以直观的看出各个数…

C++贪吃蛇(控制台版)

C自学精简实践教程 目录(必读) 目录 主要考察 需求 输入文件 运行效果 实现思路 枚举类型 enum class 启动代码 输入文件data.txt 的内容 参考答案 学生实现的效果 主要考察 模块划分 文本文件读取 UI与业务分离 控制台交互 数据抽象 需求 用户输入字母表示方…

朋友圈也可以定时定量发送?

场景1&#xff1a;明天要搞活动&#xff0c;早中晚都得发朋友圈&#xff0c;一天要发3次朋友圈&#xff0c;要在手机上定好3个闹钟&#xff0c;这是一件非常麻烦的事。 场景2&#xff1a;有朋友是房产信息的&#xff0c;每天要发布很多二手房源&#xff0c;手动发圈太耗时间&a…

记录:yolov8训练自己的数据集

一、LabelImg标注自己的原图数据集 .xml标注格式 二、带标签的数据增强 先将原始数据&#xff08;图片&#xff0c;标注&#xff09;转移到项目根目录&#xff0c;然后再数据增强&#xff0c;避免标注内容路径错误。 亮度变换加旋转 # 一、亮度 img_dir multi/images # 原始…

CSS基础选择器及常见属性

文章目录 一、CSS1、CSS简介2、CSS语法规范 二、CSS基础选择器1、选择器的作用2、选择器分类3、基础选择器标签选择器类选择器id选择器通配符选择器 三、CSS常见属性1、字体属性字体系列字体大小字体粗细文字样式 2、文本属性文本颜色对齐文本装饰文本文本缩进行间距 四、CSS引…

python编写四画面同时播放swap视频

当代技术让我们能够创建各种有趣和实用的应用程序。在本篇博客中&#xff0c;我们将探索一个基于wxPython和OpenCV的四路视频播放器应用程序。这个应用程序可以同时播放四个视频文件&#xff0c;并将它们显示在一个GUI界面中。 C:\pythoncode\new\smetimeplaymp4.py 准备工作…

2023最新任务悬赏平台源码uniapp+Thinkphp新款悬赏任务地推拉新充场游戏试玩源码众人帮威客兼职任务帮任务发布分销机

新款悬赏任务地推拉新充场游戏试玩源码众人帮威客兼职任务帮任务发布分销机制 后端是&#xff1a;thinkphpFastAdmin 前端是&#xff1a;uniapp 1.优化首页推荐店铺模块如有则会显示此模块没有则隐藏。 2修复首页公告&#xff0c;更改首页公告逻辑。&#xff08;后台添加有公…

redis 6个节点(3主3从),始终一个节点不能启动

redis节点&#xff0c;始终有一个节点不能启动起来 1.修改了配置文件 protected-mode no&#xff0c;重启 修改了配置文件 protected-mode no&#xff0c;重启redis问题依然存在 2、查看/var/log/message的redis日志 Aug 21 07:40:33 redisMaster kernel: Out of memory: K…

Jumpserver堡垒机管理(安装和相关操作)-------从小白到大神之路之学习运维第89天

第四阶段 时 间&#xff1a;2023年8月28日 参加人&#xff1a;全班人员 内 容&#xff1a; Jumpserver堡垒机管理 目录 一、堡垒机简介 &#xff08;一&#xff09;运维常见背黑锅场景 &#xff08;二&#xff09;背黑锅的主要原因 &#xff08;三&#xff09;解决背黑…

SSM框架的学习与应用(Spring + Spring MVC + MyBatis)-Java EE企业级应用开发学习记录(第三天)动态SQL

动态SQL—SSM框架的学习与应用(Spring Spring MVC MyBatis)-Java EE企业级应用开发学习记录&#xff08;第三天&#xff09;Mybatis的动态SQL操作 昨天我们深入学习了Mybatis的核心对象SqlSessionFactoryBuilder&#xff0c;掌握MyBatis核心配置文件以及元素的使用,也掌握My…

4-1-netty

非阻塞io 服务端就一个线程&#xff0c;可以处理无数个连接 收到所有的连接都放到集合channelList里面 selector是有事件集合的 对server来说优先关注连接事件 遍历连接事件

小研究 - Java虚拟机性能及关键技术分析

利用specJVM98和Java Grande Forum Benchmark suite Benchmark集合对SJVM、IntelORP,Kaffe3种Java虚拟机进行系统测试。在对测试结果进行系统分析的基础上&#xff0c;比较了不同JVM实现对性能的影响和JVM中关键模块对JVM性能的影响&#xff0c;并提出了提高JVM性能的一些展望。…

Leetcode 2651.计算列车到站时间

给你一个正整数 arrivalTime 表示列车正点到站的时间&#xff08;单位&#xff1a;小时&#xff09;&#xff0c;另给你一个正整数 delayedTime 表示列车延误的小时数。 返回列车实际到站的时间。 注意&#xff0c;该问题中的时间采用 24 小时制。 示例 1&#xff1a; 输入&…

什么样的人适合开抖店?最后一个条件必须满足!抖店开通门槛如下

我是王路飞。 作为现在热门的电商项目&#xff0c;抖店显然已经取代直播带货&#xff0c;成为了普通人在抖音卖货的新渠道&#xff0c;毕竟做账号和开直播对普通人来说&#xff0c;门槛太高了。 那么&#xff0c;在抖音开店&#xff0c;是谁都可以开吗&#xff1f;开店有什么…

K8S最新版本集群部署(v1.28) + 容器引擎Docker部署(上)

温故知新 &#x1f4da;第一章 前言&#x1f4d7;背景&#x1f4d7;目的&#x1f4d7;总体方向 &#x1f4da;第二章 基本环境信息&#x1f4d7;机器信息&#x1f4d7;软件信息&#x1f4d7;部署用户kubernetes &#x1f4da;第三章 Kubernetes各组件部署&#x1f4d7;安装kube…

基于MATLAB/Simulink的三相并网逆变器dq阻抗建模及扫频仿真

目录 整体系统介绍理论模型MATLAB实现 基于Simulink的阻抗扫频仿真整体思路注意事项流程框图 其他 本文主要介绍三相并网逆变器dq阻抗建模的相关知识&#xff0c;和大家分享一下怎么使用MATLAB/Simulink来进行理论模型的搭建以及如何通过扫频获取阻抗模型&#xff0c;一方面是给…