深入理解二分类和多分类CrossEntropy Loss和Focal Loss

深入理解二分类和多分类CrossEntropy Loss和Focal Loss

二分类交叉熵

在二分的情况下,模型最后需要预测的结果只有两种情况,对于每个类别我们的预测得到的概率为 p p p 1 − p 1-p 1p,此时表达式为( 的 log ⁡ \log log底数是 e e e):
L = 1 N ∑ i L i = 1 N ∑ i − [ y i ⋅ log ⁡ ( p i ) + ( 1 − y i ) ⋅ log ⁡ ( 1 − p i ) ] L=\frac{1}{N} \sum_{i} L_i =\frac{1}{N} \sum_{i} -[y_i \cdot \log (p_i) +(1-y_i) \cdot \log (1-p_i)] L=N1iLi=N1i[yilog(pi)+(1yi)log(1pi)]
其中:

  • y i y_i yi —— 表示样本 i i i的label,正类为1 ,负类为0
  • p i p_i pi—— 表示样本 i i i预测为正类的概率

由于二分类交叉熵很容易理解,在此就不做举例了。

多分类交叉熵

多分类交叉熵就是对二分类交叉熵的扩展,在计算公式中和二分类稍微有些许区别,但是还是比较容易理解,具体公式如下所示:
L = 1 N ∑ i L i = − 1 N ∑ i ∑ c = 1 M y i c log ⁡ ( p i c ) L=\frac{1}{N} \sum_{i} L_i=-\frac{1}{N} \sum_{i} \sum_{c=1}^M y_{ic} \log(p_{ic}) L=N1iLi=N1ic=1Myiclog(pic)
其中:

  • M M M——类别的数量
  • y i c y_{ic} yic——符号函数(0或1 ),如果样本 i i i的真实类别等于 c c c取 1,否则取 0
  • p i c p_{ic} pic——观测样本 i i i属于类别 c c c的预测概率

举例说明

预测(已经经过softmax归一化)真实
0.1 0.2 0.70 0 1
0.3 0.4 0.30 1 0
0.1 0.2 0.71 0 0

现在我们利用这个表达式计算上面例子中的损失函数值:
sample 1 loss = − ( 0 × log ⁡ 0.1 + 0 × log ⁡ 0.2 + 1 × log ⁡ 0.7 ) = 0.35 , sample 2 loss = − ( 0 × log ⁡ 0.1 + 1 × log ⁡ 0.7 + 0 × log ⁡ 0.2 ) = 0.35 , sample 3 loss = − ( 1 × log ⁡ 0.3 + 0 × log ⁡ 0.4 + 0 × log ⁡ 0.4 ) = 1.20 , L = 0.35 + 0.35 + 1.2 3 = 0.63 \text{sample 1 loss}=-(0 \times \log 0.1+0 \times \log 0.2 + 1 \times \log 0.7)=0.35 ,\\ \text{sample 2 loss}=-(0 \times \log 0.1+1 \times \log 0.7 + 0 \times \log 0.2)=0.35 ,\\ \text{sample 3 loss}=-(1 \times \log 0.3+0 \times \log 0.4 + 0 \times \log 0.4)=1.20,\\ L=\frac{0.35+0.35+1.2}{3}=0.63 sample 1 loss=(0×log0.1+0×log0.2+1×log0.7)=0.35,sample 2 loss=(0×log0.1+1×log0.7+0×log0.2)=0.35,sample 3 loss=(1×log0.3+0×log0.4+0×log0.4)=1.20,L=30.35+0.35+1.2=0.63
其实可以看到,多分类交叉熵只计算正确标签对应概率的损失值,相对错误标签其 y i c = 0 y_{ic}=0 yic=0,所以导致错误标签对应的损失值为0。

Pytorch的CrossEntropyLoss分析

参数设定

CrossEntropyLoss在Pytorch官网中,我们可以看到整个文档已经对该函数CrossEntropyLoss进行了较充分的解释。所以我们简要介绍其参数和传入的值的格式,特别是针对多分类的情况。

常见的传入参数如下所示:

  • weight:传入的是一个list或者tensor,其检索对应位置的值为该类的权重。注意,如果是GPU的环境下,则传入的值必须是tensor,并且其应该在GPU中。

  • reduction:传入的是一个字符串,有三种形式可以选择,分别是mean/sum/none,默认是meanmeansum如字面意思所示,代表损失值取平均,损失值求和的形式。none是计算每个位置对应的损失值,返回和label对应的形状。

更多参数解释如下图所示:

使用方法

CrossEntropyLoss传入的值为两个,分别是inputtarget。输出只有一个Output

  • input的形状为 ( N , C ) / ( N , C , d 1 , d 1 , … ) (N,C)/(N,C,d_1,d_1,\ldots) (N,C)/(N,C,d1,d1,),前者对应二维情况,后者对应高维情况,值得注意的是 C C C是在dim=1的位置上,可能在高维的情况下很多人都以为默认应该是最后一个维度dim=-1

  • target的形状为 ( N ) / ( N , d 1 , d 1 , … ) (N)/(N,d_1,d_1,\ldots) (N)/(N,d1,d1,),前者对应二维情况,后者对应高维情况。注意的是target的值对应的是类别对应的索引,不是one-hot的形式

  • Output的形状和target的形状一致。

更多参数解释如下图所示:

二维情况下对应的5分类交叉熵损失计算(官网示例):

>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

高维情况下对应的交叉熵计算:

input = torch.randn(2,3,5,5,4)#最后一个维度对应的是类别
target = torch.empty(2,3,5,5, dtype=torch.long).random_(4) #四分类
loss_fn=CrossEntropyLoss(reduction='sum')
_input=torch.permute(input,dims=(0,-1,1,2,3))
loss=loss_fn(_input,target)#输入的类别一定是在dim=1的位置上
print(loss)
# 当然也可以将输入先转为2维的形式在计算,结果是一样的
_input=input.view(-1,4)
_target=target.view(-1)
loss=loss_fn(_input,_target)
print(loss)

内在原理

Pytorch中的CrossEntropyLoss()是将logSoftmax()NLLLoss()函数进行合并的,也就是说其内在实现就是基于logSoftmax()NLLLoss()这两个函数。

input=torch.rand(3,5)
target=torch.empty(3,dtype=torch.long).random_(5)
loss_fn=CrossEntropyLoss(reduction='sum')
loss=loss_fn(input,target)
print(loss)
_input=torch.nn.LogSoftmax(dim=1)(input)
loss=torch.nn.NLLLoss(reduction='sum')(_input,target)
print(loss)

其实也就是和官网上所说的一样,CrossEntropyLoss()是对输出计算softmax(),在对结果取log()对数,最后使用NLLLoss()得到对应位置的索引值。

Focal Loss原理和实现

Focal Loss来自于论文Focal Loss for Dense Object Detection,用于解决类别样本不平衡以及困难样本挖掘的问题,其公式非常简洁:
F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) FL(p_t)=- \alpha_t (1-p_t) ^{\gamma} \log (p_t) FL(pt)=αt(1pt)γlog(pt)
p t p_t pt是模型预测的结果的类别概率值。 − log ⁡ ( p t ) - \log (p_t) log(pt)和交叉熵损失函数一致,因此当前样本类别对应的那个 p t p_t pt如果越小,说明预测越不准确, 那么 ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ 这一项就会增大,这一项也作为困难样本的系数,预测越不准,Focal Loss越倾向于把这个样本当作困难样本,这个系数也就越大,目的是让困难样本对损失和梯度的贡献更大。

前面的 α t \alpha_t αt是类别权重系数。如果你有一个类别不平衡的数据集,那么你肯定想对数量少的那一类在loss贡献上赋予一个高权重,这个 α t \alpha_t αt就起到这样的作用。因此, α t \alpha_t αt应该是一个向量,向量的长度等于类别的个数,用于存放各个类别的权重。一般来说 α t \alpha_t αt中的值为每一个类别样本数量的倒数,相当于平衡样本的数量差距。

这里提供一个二维/高维的Focal Loss的实现:

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5,1])):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        logpt = nn.functional.log_softmax(input, dim=1) #计算softmax后在计算log
        pt = torch.exp(logpt) #对log_softmax去exp,把log取消就是概率
        alpha=self.alpha[target].unsqueeze(dim=1) # 去取真实索引类别对应的alpha
        logpt = alpha*(1 - pt) ** self.gamma * logpt #focal loss计算公式
        loss = nn.functional.nll_loss(logpt, target,reduction='sum') # 最后选择对应位置的元素
        return loss

参考资料

CrossEntropy官网详细说明。

Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss

详解PyTorch实现多分类Focal Loss——带有alpha简洁实现

最近工作

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/18839.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Osek网络管理及ETAS实现

OSEK/VDX(Offene Systeme und deren Schnittstellen fr die Elektronik in Kraftfahrzeugen / Vehicle Distributed eXecutive)是一种用于嵌入式系统(尤其是汽车电子控制单元)的开放标准。它旨在提供一种统一、可互操作的软件架构…

Origin如何绘制三维图形?

文章目录 0.引言1.使用矩阵簿窗口2.三维数据转换3.三维绘图4.三维曲面图5.三维XYY图6.三维符号、条状、矢量图7.等高线图 0.引言 因科研等多场景需要,绘制专业的图表,笔者对Origin进行了学习,本文通过《Origin 2022科学绘图与数据》及其配套素…

三分钟教你如何定义自己的ChatGPT

三分钟教你如何定义自己的ChatGPT 成品预览材料准备MyChatGPT自定义AI 成品预览 材料准备 总共有两种方式: 一种是使用自己的OpenAI账号,这种方式是可控性比较强,同时也会有很多问题,比如你需要准备国外的手机号和Visa卡&#x…

Java 动态原理详解

Java 动态代理是一种非常重要的编程技术,它在很多场景下都有着广泛的应用。本文将介绍 Java 动态代理的实现原理,并附上相应的源码,以帮助读者更好地理解和应用这一技术。 一、什么是 Java 动态代理? Java 动态代理是一种在运行时…

在docker上安装运行Python文件

目录 一、在docker中安装python 1.1 输入镜像拉取命令 1.2 查看镜像 1.3 运行 1.4 查看是否成功 1.5 查看python版本 二、运行py文件 2.1准备运行所需文件 2.2 准备文件夹 2.3 大概是这幅模样 2.4 打包上传到服务器上 2.5 构建镜像示例 2.6 查看镜像 2.7 优化镜像的…

Spring MVC自定义拦截器--Spring MVC异常处理

目录 自定义拦截器 什么是拦截器 ● 说明 自定义拦截器执行流程分析图 ● 自定义拦截器执行流程说明 自定义拦截器应用实例 ● 应用实例需求 创建MyInterceptor01 创建FurnHandler类 在 springDispatcherServlet-servlet.xml 配置拦截器 第一种配置方式 第二种配置方…

【Linux】网络---->套接字编程(TCP)

套接字编程TCP TCP的编程流程TCP的接口TCP的代码(单线程、多进程、多线程代码)单线程多进程多线程 TCP的编程流程 TCP的编程流程:大致可以分为五个过程,分别是准备过程、连接建立过程、获取新连接过程、消息收发过程和断开过程。 …

《花雕学AI》ChatGPT 的 Prompt 用法,不是随便写就行的,这 13 种才是最有效的

ChatGPT 是一款基于 GPT-3 模型的人工智能写作工具,它可以根据用户的输入和要求,生成各种类型和风格的文本内容,比如文章、故事、诗歌、对话、摘要等。ChatGPT 的强大之处在于它可以灵活地适应不同的写作场景和目的,只要用户给出合…

MySQL多表查询之连接查询

0. 数据源 /*Navicat Premium Data TransferSource Server : localhost_3306Source Server Type : MySQLSource Server Version : 80016Source Host : localhost:3306Source Schema : tempdbTarget Server Type : MySQLTarget Server Version…

两小时让你全方位的认识文件(一)

想必友友们在生活中经常会使用到各种各样的文件,那么我们是否了解它其中的奥秘呢,今天阿博就带领友友们深入地走入文件🛩️🛩️🛩️ 文章目录 一.为什么使用文件二.什么是文件三.文件的打开和关闭四.文件的顺序读写 一…

时间复杂度

学习《代码随想录》 时间复杂度为什么要引入时间复杂度和空间复杂度?什么是时间复杂度?这个O是什么意思?时间复杂度越低越好? 内存管理什么是内存空间?(C为例)为什么总说C/C更偏向底层&#xff…

T-SQL游标的使用

一.建表 INSERT INTO cloud VALUES( 你 ) INSERT INTO cloud VALUES( 一会看我 ) INSERT INTO cloud VALUES( 一会看云 ) INSERT INTO cloud VALUES( 我觉得 ) INSERT INTO cloud VALUES( 你看我时很远 ) INSERT INTO cloud VALUES( 你看云时很近 ) 二.建立游标 1.游标的一般格…

判断大小端的错误做法

这里不详细讲解大小端的区别,只讲解判断大小端的方法。 1.大端,小端的区别 0x123456 在内存中的存储方式 大端是高字节存放到内存的低地址 小端是高字节存放到内存的高地址 2.大小端的判断 1.错误的做法 int main() {int a0x1234;char c(char)a;if(…

2022年宜昌市网络搭建与应用竞赛样题(三)

网络搭建与应用竞赛样题(三) 技能要求 (总分1000分) 竞赛说明 一、竞赛内容分布 “网络搭建与应用”竞赛共分三个部分,其中: 第一部分:网络搭建及安全部署项目(500分&#xff0…

基于SpringBoot3从零配置SpringDoc

为了方便调试,更好的服务于前后端分离式的工作模式,我们给项目引入Swagger。 系列文章指路👉 系列文章-基于SpringBoot3创建项目并配置常用的工具和一些常用的类 文章目录 1. SpringFox2. SpringDoc2.1 引入依赖2.2 配置文件2.3 语法2.4 使…

PCL学习八:Keypoints-关键点

参考引用 Point Cloud Library黑马机器人 | PCL-3D点云 PCL点云库学习笔记(文章链接汇总) 1. 引言 关键点也称为兴趣点,它是 2D 图像或 3D 点云或曲面模型上,可以通过检测标准来获取的具有稳定性、区别性的点集。从技术上来说,关键…

Microsoft Edge新功能测评体验

Microsoft Edge使用体验 Microsoft Edge是一款现代化的浏览器,它拥有众多功能和强大的性能,为用户带来更加流畅的浏览体验。 Edge最近推出了分屏功能,支持一个窗口同时显示两个选项卡,这可以大大提高生产力和多任务处理能力。 一…

API接口对程序员的帮助有哪些,参考值简要说明

API接口对程序员的帮助有哪些 提高开发效率:通过API接口,程序员能够在不用重复编写代码的情况下,直接获取其他应用程序提供的服务或数据,极大地提高了开发效率。 减少错误率:使用API接口可以避免手动输入数据容易出现…

洛谷P5047 [Ynoi2019 模拟赛] Yuno loves sqrt technology II(离线区间逆序对+莫队二次离线)

题目 给你一个长为n(1<n<1e5)的序列a(0<ai<1e9)&#xff0c; m(1<m<1e5)次询问&#xff0c;每次查询一个区间[l,r]的逆序对数&#xff0c;可离线。 思路来源 登录 - 洛谷 三道经典分块题的更优复杂度解法&[Ynoi2019模拟赛]题解 - 博客 - OldDriverT…

Flutter性能分析工具使用

使用前提 flutter常用的性能分析工具&#xff0c;这些工具都集成在android studio中&#xff0c;基本能满足我们的需求了。在下面介绍的几个工具中&#xff0c;Flutter Performance和Flutter Inspector都能够直接在debug模式下使用&#xff0c;但是DevTools只能在profile模式下…