交叉熵、Focal Loss以及其Pytorch实现

交叉熵、Focal Loss以及其Pytorch实现

本文参考链接:https://towardsdatascience.com/focal-loss-a-better-alternative-for-cross-entropy-1d073d92d075

文章目录

  • 交叉熵、Focal Loss以及其Pytorch实现
    • 一、交叉熵
    • 二、Focal loss
    • 三、Pytorch
      • 1.[交叉熵](https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html?highlight=nn+crossentropyloss#torch.nn.CrossEntropyLoss)
      • 2.[Focal loss](https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py)

一、交叉熵

损失是通过梯度回传用来更新网络参数是之产生的预测结果和真实值之间相似。不同损失函数有着不同的约束作用,不同的数据对损失函数有着不同的影响。

交叉熵是常见的损失函数,常见于语义分割、对比学习等。其函数表达式如下,其中 Y i 和 p i Y_i和p_i Yipi分别表示真实值和预测结果:
C r o s s E n t r o p y = − ∑ i = 1 i = n Y i l o g ( p i ) CrossEntropy=-\sum_{i=1}^{i=n}Y_ilog(p_i) CrossEntropy=i=1i=nYilog(pi)
因为 p i p_i pi值在0~1之间,故交叉熵大于等于0。这个函数什么时候最小呢?数学证明结果表明当 Y i = p i Y_i=p_i Yi=pi时交叉熵最小。下面,我们取二分类情况来进行简单证明:
B C E L o s s = − y l o g x − ( 1 − y ) l o g ( 1 − x ) BCELoss=-ylogx-(1-y)log(1-x) BCELoss=ylogx(1y)log(1x)
对BCELoss求导可得:
− y x + 1 − y 1 − x = y − x x − x 2 -\frac{y}{x}+\frac{1-y}{1-x}=\frac{y-x}{x-x^2} xy+1x1y=xx2yx
所以当 y = x y=x y=x时,二分类交叉熵取最小值。

那么,交叉熵有啥子问题?

  • 从表达式可以看出,交叉熵只针对单个像素进行比较,像素和像素之间并没有联系,这就需要我们在模型中使用空间注意力机制等使得特征在空间上进行交互。针对这个问题,不少论文提出了改进方案,如Context Prior for Scene Segmentation这篇论文就使用了和precision、recall等类似的损失函数(就像Dice loss和F1 score指标一样)。

  • 类别不平衡:这个问题比较常见,语义分割中类别在图片上总像素占比是不平衡。如果类别不平衡比较严重,交叉熵损失就会偏向于占比较高的类别,导致对占比较少的类别预测结果较差。解决这一方法为给交叉熵损失添加权重(平衡交叉熵)等,如下式:
    B a l a n c e d C r o s s E n t r o p y = − ∑ i = 1 i = n α i Y i l o g ( p i ) BalancedCrossEntropy=-\sum_{i=1}^{i=n}\alpha_iY_ilog(p_i) BalancedCrossEntropy=i=1i=nαiYilog(pi)

  • 困难样本:首先,我们要知道困难样本是那些模型反复出现巨大损失的例子,而简单样本是那些容易分类的例子。交叉熵对于所有样本同等对待,导致无法辨别困难样本和简单样本。解决这一问题就是接下来的损失函数Focal loss

二、Focal loss

Focal loss关注的是模型出错的例子,而不是它可以自信地预测的例子,确保对困难的例子的预测随着时间的推移而改善,而不是对容易的例子变得过于自信。

这到底是怎么做到的呢?Focal loss是通过一个叫做Down Weighting的东西来实现的。下调权重是一种技术,它可以减少容易的例子对损失函数的影响,从而使人们更加关注困难的例子。这种技术可以通过在交叉熵损失中加入一个调节因子来实现。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=n(1pi)γlogpi
不同的 γ \gamma γ对损失有什么影响呢?如下图所示

img

不同的 γ \gamma γ ( 1 − p i ) γ (1-p_i)^{\gamma} (1pi)γ有什么影响呢,如下:

img

  • 在误分类样本的情况下, p i pi pi很小,使得调制因子大约或非常接近于1,这使损失函数不受影响。此时,Focal Loss和交叉熵损失相似。
  • 随着模型置信度的提高,即 p i → 1 pi→1 pi1,调制因子将趋于0,从而降低了分类良好的例子的损失值。聚焦参数, γ \gamma γ≥1,将重新调整调制因子,使容易的例子比困难的例子降权更多,减少它们对损失函数的影响。例如,考虑预测概率为0.9和0.6。考虑到 γ \gamma γ=2,对0.9计算出的损失值是4.5e-4,降权系数( 1 / ( 1 − q i ) 2 1/(1-q_i)^2 1/(1qi)2)为100,对0.6则是3.5e-2,降权系数为6.25。从实验来看, γ \gamma γ=2来说效果最好。
  • γ \gamma γ=0时,Focal Loss等同于Cross Entropy。

此外,加入平衡因子 α \alpha α,用来平衡正负样本本身的比例不均:文中 α \alpha α取0.25,即正样本要比负样本占比小,这是因为负例易分。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n α i ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}\alpha_i(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=nαi(1pi)γlogpi
Focal Loss自然地解决了阶级不平衡的问题,(1因为来自多数类别的例子通常容易预测,而来自少数类别的例子由于缺乏数据或来自多数类别的例子在损失和梯度过程中占主导地位而难以预测。由于这种相似性,Focal Loss可能能够解决这两个问题。

三、Pytorch

1.交叉熵

Pytorch可以直接调用交叉熵损失函数nn.CrossEntropyLoss(),其功能还是比较全的。其中weight可以用了进行权重平衡,ignore_index可以用来忽略特定类别。输入的标签不需要进行one hot编码,其内部已经实现。nn.CrossEntropyLoss()=nn.NLLoss() + nn.LogSoftmax。
在这里插入图片描述

2.Focal loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/32246.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python 动态生成系统数据库设计到word文档

背景 经常需要交付一些系统文档而且基本都是word的,其中又有系统数据库介绍模块, 看着数据库里的几百张表于是我开始怀疑人生, 所以咱手写一个 涉及知识 pymysql 操作数据库 -tkinter GUI图形库threading 线程queue 阻塞队列pandas python数据计算…

layui(5)——内置模块分页模块

模块加载名称&#xff1a;laypage laypage 的使用非常简单&#xff0c;指向一个用于存放分页的容器&#xff0c;通过服务端得到一些初始值&#xff0c;即可完成分页渲染&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset&quo…

RocketMQ --- 实战篇

一、案例介绍 1.1、业务分析 模拟电商网站购物场景中的【下单】和【支付】业务 1.1.1、下单 流程 用户请求订单系统下单 订单系统通过RPC调用订单服务下单 订单服务调用优惠券服务&#xff0c;扣减优惠券 订单服务调用调用库存服务&#xff0c;校验并扣减库存 订单服务调…

长尾关键词有什么作用?要怎么用?

长尾关键词很多的网站都会忽略其存在&#xff0c;其实你不要小看长尾关键词&#xff0c;他将带给网站的流量也是极其可观的&#xff0c;所说比不上那些重点关键词的流量&#xff0c;但是对提升网站的权重还是有着重要的作用。 长尾关键词有什么用&#xff1f;长尾关键词的3…

Gitlab群组及项目仓库搭建

1、新建群组 2、新建项目 3、克隆到Visualstudio 复制克隆地址&#xff0c;克隆到本地 这里会让你登录账号 可以添加成员并邀请ta进项目组 从已注册用户列表中选择 4、Git工作流 回顾一下Git工作流&#xff0c;工程人员只需要从Develop分支新建自己的分支即可。分支命名以姓名…

CadLib 4.0.2023.31601 net for Windows Crack

CadLib 4.0 for Windows&#xff1a;在 C# VB .NET 中读取、写入和显示 AutoCAD DWG 和 DXF 文件 CadLib 4.0 for Windows仅在Windows上运行&#xff0c;并且基于.NET 4.x。 CadLib 4.0读取、写入和显示 C#、VB.NET 或任何其他 .NET 语言的 AutoCAD™ DWG 和 DXF 文件。下载试…

2-css-3

一 选择器 1 结构伪类选择器 作用&#xff1a;根据元素的结构关系查找元素。 选择器说明E:first-child查找第一个E元素E:last-child查找最后一个E元素E:nth-child(N)查找第N个E元素&#xff08;第一个元素N值为1&#xff09; li:first-child {background-color: green; }2 :…

5.6.3 套接字

5.6.3 套接字 我们先以示例引入套接字的基本内容&#xff0c;我们知道在邮政通信的时候我们需要在信封上写明我们的收件地址&#xff0c;比如北京市海淀区双清路30号清华大学8444号某某某收&#xff0c;这其中我们需要一个物理地址“北京市海淀区双清路30号”&#xff0c;一个…

6.22 驱动开发作业

字符设备驱动内部实现原理 1.字面理解解析&#xff1a; 字符设备驱动的内部实现有两种情况&#xff1a; 情况1.应用层调用open函数的内部实现&#xff1a; open函数的第一个参数是要打开的文件的路径&#xff0c;根据这个路径 虚拟文件系统层VFS 可以找到这个文件在文件系统…

openeuler22.03系统salt-minion启动报“Invalid version: ‘cpython‘“错的问题处理

某日&#xff0c;检查发现一台openeuler22.03 SP1系统的服务器上之前正常运行的saltstack客户端minion未运行&#xff0c;查看服务状态&#xff0c;报"Invalid version: cpython"错&#xff0c;无法正常运行&#xff0c;本文记录问题处理过程。 一、检查salt-minion…

uniapp中小程序的生命周期

一、uni-app应用生命周期 函数名说明onLuaunch当uni-app 初始化完成时触发&#xff08;全局只触发一次&#xff09;onShow当 uni-app 启动&#xff0c;或从后台进入前台显示onHide当 uni-app 从前台进入后台onError当 uni-app 报错时触发onUniNViewMessage对 nvue 页面发送的数…

android jetpack Room的基本使用(java)

数据库的基本使用 添加依赖 //roomdef room_version "2.5.0"implementation "androidx.room:room-runtime:$room_version"annotationProcessor "androidx.room:room-compiler:$room_version"创建表 Entity表示根据实体类创建数据表&#xff0c…

发送图文并茂的html格式的邮件

本文介绍如何生成和发送包含图表和表格的邮件&#xff0c;涉及echarts图表转换为图片、图片内嵌到html邮件内容中、html邮件内容生成、邮件发送方法等 一、图表处理 因为html格式的邮件不支持echarts,也不支持js执行&#xff0c;所以图表需要转换为图片内嵌在邮件内容中 因为平…

【Java】Java核心 73:XML (中)

文章目录 5 XML的组成&#xff1a;字符区(了解)**6** **DTD约束(能够看懂即可)****1** **什么是DTD****2** **DTD约束的实现和语法规则&#xff08;看懂dtd约束&#xff0c;书写符合规范的xml文件&#xff09;** 5 XML的组成&#xff1a;字符区(了解) 当大量的转义字符出现在x…

ansible实训-Day1(Liunx基础问题总结及ansible安装环境前置部署)

一、前言 该篇是对本学期Ansible实训第一天内容的原理性总结&#xff0c;主要包括Liunx相关问题等基础性的问题总结以及ansible安装环境的前置部署。 二、Liunx是什么 Linux是一种自由和开放源代码的Unix操作系统&#xff0c;最初由芬兰人Linus Torvalds于1991年创建。与其他许…

浅谈Spring Cloud Gateway

网关:用户和微服务的桥梁 网关的核心是一组过滤器&#xff0c;按照先后顺序执行过滤操作。 Spring Cloud Gateway是基于webFlux框架实现&#xff0c;而webFlux框架底层则使用了高性能的Reactor模式通信框架的Netty Spring Cloud Gateway是Spring Cloud生态系统中的一个API网…

图解transformer中的自注意力机制

本文将将介绍注意力的概念从何而来&#xff0c;它是如何工作的以及它的简单的实现。 注意力机制 在整个注意力过程中&#xff0c;模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。 假设有一个数据库&#xff0c…

使用Servlet完成单表的增删改查功能以及使用模板方法设计模式解决类爆炸问题(重写service模板方法)

使用Servlet做一个单表的CRUD操作 开发前的准备 导入sql脚本创建一张部门表 drop table if exists dept; create table dept(deptno int primary key,dname varchar(255),loc varchar(255) ); insert into dept(deptno, dname, loc) values(10, XiaoShouBu, BeiJing); inser…

Python小游戏集合(开源、开源、免费下载)

Python小游戏集合 0. 前言1. 为什么用Python做游戏2. 小游戏集合及源代码&#xff08;整理不易&#xff0c;一键三连&#xff09;2.1 外星人小游戏2.2 塔防小游戏2.3 三国小游戏2.4 打飞机游戏2.5 飞机大战小游戏2.6 玛丽快跑小游戏2.7 涂鸦跳跃小游戏2.8 猜数字小游戏2.9 坦克…

吃透JAVA的Stream流操作,多年实践总结

在JAVA中&#xff0c;涉及到对数组、Collection等集合类中的元素进行操作的时候&#xff0c;通常会通过循环的方式进行逐个处理&#xff0c;或者使用Stream的方式进行处理。 例如&#xff0c;现在有这么一个需求&#xff1a; 从给定句子中返回单词长度大于5的单词列表&#xf…