F-score 和 Dice Loss 原理及代码

文章目录

    • 1. F-score
    • 1. 1 原理
    • 1. 2 代码
    • 2. Dice Loss
      • 2.1 原理
      • 2.2 代码

通过看开源图像语义分割库的源码,发现它对 Dice Loss 的实现方式,是直接调用 F-score 函数,换言之,Dice LossF-score的特殊情况。于是就研究了一下这背后的原理,作文以记之。

1. F-score

1. 1 原理

首先介绍 F-score:
在这里插入图片描述
要理解F-score,就要先回顾一下 PrecisionRecall,首先给出公式:

在这里插入图片描述
两个指标衡量算法的准确性时,通常是相互排斥的。例如,输入一个数据,算法根据数据预测一个分数,现在为该分数设定阈值,大于阈值的预测为真,小于该阈值的预测为假。

  • 如果这个阈值得过低,低到测试集中所有的样本均判定为真,那么此时,FN=0(False negative, 压根就没有预测出来 negative 的样本),代入公式 (2) 得 Recall = 1。但此时,预测为真的样本中,包含大量的 FP,即 False Positive,将会导致 Precision 过低
  • 如果这个阈值设置得过高,使得所有被判定为正的样本都是真的,那么 FP=0,Precision=1,此时将不可避免有很多本应被判定为正的样本,被错误地判定为负,也就是 FN 很大,导致 Recall 过低

不同的应用场景下,对这两个指标的侧重不同。例如新冠感染者检测,就应该尽量提高 Recall,务求没有漏网之鱼。但在检测垃圾邮件时,应该尽量提升 Precision,即每个被判定为垃圾邮件的,都是板上钉钉毫无争议的,防止出现误伤,把正常邮件当成垃圾邮件处理。

F-score 则是将这两个指标综合起来:
在这里插入图片描述

  • β \beta β控制 Precision 和 Recall 的重要程度, 当 β = 1 \beta=1 β=1, 对应 F1-score,此时 Precision 和 Recall 同样重要。

  • β \beta β两个常用的取值是 0.5 2,当取 0.5 时,Precision 对 F-score 的影响更大,当取 2 时,Recall 对 F-score 的影响更大。(可以考虑得更极端一点,当 β → 0 \beta\rightarrow0 β0,公式(3)趋于 Precision;当 β → ∞ \beta\rightarrow\infty β,公式(3)上下同除以分子,易知其将趋于 Recall)

最后,把 (1) (2) 代入 (3) 得:
在这里插入图片描述

1. 2 代码

def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice系数
    #--------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score
  • inputs为分割模型的预测输出,未经过softmax, target为gt
  • temp_target中将channels维度设为num_classes+1,为了方便处理白边,因此在实际计算时需要去掉最后一个channel: temp_target[...,:-1]
    在这里插入图片描述
  • 预测分割图temp_inputs与 GT 分割图的点乘,然后再(n,hw)方向上求和作为tp
    参考自: Dice系数(Dice coefficient)与mIoU与Dice Loss
    在这里插入图片描述
  • 因为预测temp_inputs (pred) = fp+tp, 因此已知temp_inputstp, 就可以求出fp
  • 同理temp_target (gt) = fn+tp, 因此已知temp_targettp, 就可以求出`fn
  • 然后根据F-score的计算公式,在已知tp,fp,fn以及beta系数,就可以计算出F-score值了
    在这里插入图片描述
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)

2. Dice Loss

2.1 原理

Dice Loss 是语义分割中常用的一种损失,它的计算方法如下:
在这里插入图片描述
即,分子为预测值与真实值的交集元素数目的两倍,分母为两个集合元素数目之和(注意并不是并集,而是和)。而
在这里插入图片描述
因此,(6) 相当于:
1 − 2 T P 2 T P + F P + F N 1-\frac{2TP}{2TP+FP+FN} 12TP+FP+FN2TP

而上式的结果,正是公式 (5) 中 β = 1 \beta =1 β=1的情况,也就是F1 score。因此,

Dice Loss = 1 - F1 score 

2.2 代码

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
  • 可以看到dice_loss的实现,跟F-score基本上是一模一样的, 将torch.mean(score)求得的F-soce, 然后通过dice_loss = 1- F-score 来实现。
  • 代码中默认 β = 1 \beta=1 β=1, 所以更精确的说: dice_loss = 1- F1-score
  • DIce _loss的在训练损失中的使用如下:
    在这里插入图片描述

参考:

  • F-score 和 Dice Loss
  • https://github.com/bubbliiiing/deeplabv3-plus-pytorch/blob/main/utils/utils_metrics.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/319007.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python实现网络爬虫代码_python如何实现网络爬虫

python实现网络爬虫的方法:1、使用request库中的get方法,请求url的网页内容;2、【find()】和【find_all()】方法可以遍历这个html文件,提取指定信息。 python实现网络爬虫的方法: 第一步:爬取 使用reque…

定时任务-理论基础

什么是小顶堆 小顶堆(Min Heap)是一种特殊的二叉堆,它满足以下条件: 它是一个完全二叉树,即除了最后一层外,其他层的节点数都是满的,并且最后一层的节点从左到右依次排列。树中的每个节点的…

若依基于jsencrypt实现前后端登录密码加密

若依虽然有加密解密功能,然后只有前端有,在用户点击保存密码的时候,会将密码保存到本地,但是为了防止密码泄露,所以在保存的时候,进行加密,在回显密码的时候进行解密显示,用户在登录…

SpringCloud:Ribbon

文章目录 Ribbon快速入门Ribbon负载均衡算法常见的负载均衡算法更改算法规则修改配置 饥饿加载 Ribbon ribbon是一个客户端负载均衡器,会从注册中心拉取可用服务,当客户端需要获取服务请求时,ribbon能够解析服务地址并实现负载均衡 快速入门 …

Quick taxi route assignment via real-time intersection state prediction

Quick taxi route assignment via real-time intersection state prediction with a spatial-temporal graph neural network(通过时空图神经网络实时交叉口状态预测快速分配出租车路线) PAPER LINK 简单说一下: 本文采用了一种新的方法,通过使用空间-时间图神经网络(ST…

LMDeploy 的量化和部署

LMDeploy 的量化和部署 文档:https://github.com/InternLM/tutorial/blob/vansin-patch-4/lmdeploy/lmdeploy.md 视频:https://www.bilibili.com/video/BV1iW4y1A77P 一、模型量化 大模型参数量很大,运行起来非常消耗显存和内存,…

如何在电脑上免费更改 PDF 格式文档的字体大小?

对于需要编辑或修改的 PDF 文件来说,更改其字体大小是一个非常常见且必要的工作。虽然 Adobe Acrobat Pro DC 等专业的 PDF 编辑软件可以帮助您完成此任务,但他们通常都需要昂贵的恢复。幸运的是,有许多免费的 PDF 编辑工具可供选择。在本文中…

大括号内两行公式中,如何左对齐公式的条件

1. 先建立一个大括号,中间设置一个二维矩阵如下: 2. 选中整个矩阵,不要选外面的括号,进行如下操作 3. 选择左侧对齐 即可。

Docker安装Redis详细步骤

1、创建安装目录 mkdir -p /usr/local/docker/redis-docker 2、确定安装的版本 确定对应的版本,在步骤3中会用到: https://github.com/redis/redis/branches 3、配置docker-compose.yml 内容如下: version: 3 services:redis:image: r…

信息检索速通知识点

仅仅是我自己能想到的对这个分类的一个记忆。欢迎指正 首先,最重要的一点,什么是信息检索? 信息检索是从大规模无规则的数据中(主要是文档)中查询用户所需要的信息的过程。 然后,信息检索有哪几种索引呢&am…

Vue.observable详解(细到原码)

文章目录 一、Observable 是什么二、使用场景三、原理分析参考文献 一、Observable 是什么 Observable 翻译过来我们可以理解成可观察的 我们先来看一下其在Vue中的定义 Vue.observable,让一个对象变成响应式数据。Vue 内部会用它来处理 data 函数返回的对象 返回…

“一键转换PNG至BMP:轻松批量处理,高效优化图片管理“

在数字世界中,图片格式的转换是日常工作中不可或缺的一部分。你是否经常遇到需要将PNG格式的图片转换为BMP格式的需求?是否在处理大量图片时,希望能够实现一键批量转换,提高工作效率? 首先,我们进入首助编…

迎接数智时代:数字经济引领可视化转型

在数字经济的持续崛起下,企业正在进行数字化转型,其中可视化和数智化成为关键驱动力。NFC技术的应用更是为这一转型提供了新的可能性。 数字经济塑造未来: 数字经济的兴起标志着企业正进入一个全新的时代。通过数字技术,企业可…

如何使用创建时间给文件重命名,简单的批量操作教程

在处理大量文件时,有时要按照规则对文件重命名,根据文件的创建时间来重命名。那如何批量操作呢?现在一起来看云炫文件管理器如何用文件的创建时间来批量重命名。 按创建时间重命名文件的前后对比图。 用创建时间批量给文件重命名的步骤&…

数据仓库(3)-模型建设

本文从以下9个内容,介绍数据参考模型建设相关内容。 1、OLTP VS OLAP OLTP:全称OnLine Transaction Processing,中文名联机事务处理系统,主要是执行基本日常的事务处理,比如数据库记录的增删查改,例如mysql、oracle…

OpenJDK 和 OracleJDK 哪个jdk更好更稳定,正式项目用哪个呢?关注者

OpenJDK 和 OracleJDK:哪个JDK更好更稳定,正式项目应该使用哪个呢?我会从,从开源性质、更新和支持、功能差异等方面进行比较,如何选择,哪个jdk更好更稳定,正式项目用哪个呢,进行比较…

小米数据恢复软件:如何从小米手机恢复已删除的数据

“买一部小米手机,送一个移动硬盘”。人们惊叹于小米手机以非常合理的价格提供的大容量。我们甚至可以把小米手机当做一个移动硬盘来使用,存储大量的照片、视频、文档等文件。但是,在我们使用手机的过程中,误删的情况时有发生&…

AI编程可视化Java项目拆解第一弹,解析本地Java项目

之前分享过一篇使用 AI 可视化 Java 项目的文章,同步在 AI 破局星球、知乎、掘金等地方都分享了。 原文在这里AI 编程:可视化 Java 项目 有很多人感兴趣,我打算写一个系列文章拆解这个项目,大家多多点赞支持~ 今天分享的是第一…

学习使用Rainyun搭建网站

我们选择了白嫖雨云的二级域名 浏览器输入https://www.rainyun.com/z22_ 创建账号然后选择一个你喜欢的子域名我建议后缀选择ates.top的 选择自定义地址,类型选择cname 现在要选择记录值了,有a,aa,txt等 根据实际情况填写。就可以…

【CAN】CANoe添加模拟节点报错解决方法

文章目录 1. 问题现象2. 问题解决方法 >>返回总目录<< 1. 问题现象 通过CANoe添加模拟节点时&#xff0c;提示无法加载动态链接库CANOEILNLSPA.DLL。 2. 问题解决方法 右键模拟节点&#xff0c;选择Configuration选项&#xff0c;弹出Node Configuration界面&am…