语义分割(3):损失函数解析

文章目录

    • 1. 常见语义分割损失
      • 1.1 Cross Entropy
      • 1.2 dice Loss
        • 1.2.1 为什么使用Dice loss
        • 1.2.2 公式
        • 1.2.3 Dice loss 和 F1-score代码
      • 1.3 focal loss
        • 1.3.1 公式:
        • 1.3.2 代码
    • 2. 语义分割损失应用
    • 参考

语义分割任务实际上是一种像素层面上的分类,需要识别图像中存在的内容和位置,同样也存在与分类类似问题-样本类别不平衡,对于语义分割更多的是前景区域的样本远小于背景区域。针对类别不平衡问题,在loss层面上有不同的选择。

1. 常见语义分割损失

1.1 Cross Entropy

用于图像语义分割任务的最常用损失函数是像素级别的交叉熵损失,这种损失会逐个检查每个像素,将对每个像素类别的预测结果(概率分布向量)与我们的one-hot标签向量进行比较。
p i x e l − l o s s = − ∑ c l a s s e s y t r u e l o g ( y p r e d ) pixel-loss=- \sum_{classes}y_{true}log(y_{pred}) pixelloss=classesytruelog(ypred)
整个图像的损失就是对每个像素的损失求平均值

PytorchCrossEntropyLoss()函数的主要是将log_softmax NLLLoss最小化负对数似然函数)合并到一块得到的结果

CrossEntropyLoss()=log_softmax() + NLLLoss() 

在这里插入图片描述

  • (1) 首先对预测值pred进行softmax计算:其中softmax函数又称为归一化指数函数,它可以把一个多维向量压缩在(0,1)之间,并且它们的和为1
    在这里插入图片描述
  • (2) 然后对softmax计算的结果,再取log对数。
  • (3) 最后再利用NLLLoss() 计算CrossEntropyLoss, 其中NLLLoss() 的计算过程为:将经过log_softmax计算的结果与target 相乘并求和,然后取反。

其中(1),(2)实现的是log_softmax计算,(3)实现的是NLLLoss(), 经过以上3步计算,得到最终的交叉熵损失的计算结果。

详见: 深度学习loss总结:nn.CrossEntropyLoss,nn.MSELoss,Focal_Loss,nn.KLDivLoss等

1.2 dice Loss

  • Dice Loss 最先是在VNet 这篇文章中被提出,后来被广泛的应用在了医学影像分割之中。
  • Dice Loss,也叫Soft Dice Coefficient,是一种广泛用于图像分割任务的损失函数。它基于目标分割图像模型输出结果之间的重叠区域的比例计算出分数。与交叉熵损失函数相比,它更适合于处理难分割的目标
  • Dice Loss是由Dice系数而得名的,Dice系数是一种用于评估两个样本集合相似性的度量函数,其值越大意味着这两个样本集越相似。
1.2.1 为什么使用Dice loss

Dice Loss在处理类别不平衡目标小但多的图像分割任务时有着很好的性能。交叉熵损失函数忽略了预测值和目标值之间的相似性,并且对于极端的像素值不够敏感。而Dice Loss是基于相似性的评价指标,它看重相同的像素值,可以很好地处理像素值不平衡的情况。

1.2.2 公式

D i c e = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice =\frac{2|X\cap{Y}|}{|X|+|Y|} Dice=X+Y2∣XY

其中 ∣ X ∩ Y ∣ |X\cap{Y}| XY表示X和Y之间交集元素的个数, ∣ X ∣ |X| X ∣ Y ∣ |Y| Y分别表示X、Y中元素的个数。分子乘2为了保证分母重复计算后取值范围在$[0,1]之间, D i c e L o s s Dice Loss DiceLoss表达式如下:
D i c e l o s s = 1 − D i c e = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ Dice loss= 1-Dice =1-\frac{2|X\cap{Y}|}{|X|+|Y|} Diceloss=1Dice=1X+Y2∣XY

  • Dice Loss常用于语义分割问题中,X表示真实分割图像的像素标签,Y表示模型预测分割图像的像素类别, ∣ X ∩ Y ∣ |X\cap Y| XY 近试为预测图像的像素与真实标签图像的像素之间的点乘,并将点乘结果相加, ∣ X ∣ |X| X ∣ Y ∣ |Y| Y分别近似为他们各自对应图像像素相加。

D i c e L o s s = 1 − 2 ∑ i = 1 N y i y 1 ^ ∑ i = 1 N y i + ∑ i = 1 N y i ^ DiceLoss = 1-\frac{2\sum_{i=1}^{N}y_{i}\hat{y_1}}{\sum_{i=1}^N y_i +\sum_{i=1}^N \hat{y_i}} DiceLoss=1i=1Nyi+i=1Nyi^2i=1Nyiy1^

可以说Dice Loss是直接优化F1 score而来的,是对F1 score的高度抽象,可用于多分类分割问题上。F1 score就被提出,其公式如下:

F 1 s c o r e = 2 P R P + R = 2 T P 2 R P + F P + F N F1 score = \frac{2PR}{P+R} = \frac {2TP}{2RP+FP+FN} F1score=P+R2PR=2RP+FP+FN2TP

在二分类问题中,Dice系数也可以写成 D i c e = 2 T P 2 T P + F P + F N = F 1 s c o r e Dice = \frac {2TP}{2TP+FP+FN}=F1score Dice=2TP+FP+FN2TP=F1score

D i c e L o s s = 1 − D i c e = 1 − F 1 s c o r e Dice_{Loss} =1- Dice= 1 - F1_{score} DiceLoss=1Dice=1F1score

1.2.3 Dice loss 和 F1-score代码

(1) F1-score

def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice系数
    #--------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score
  • 上述是F-score的代码,F1-ScoreF-score的一种特殊形式,即当 beta=1, F-score等价于F1-Score
    在这里插入图片描述
  • inputs为分割模型的推理预测输出shape为(n,h,w,c), 其中c为cls_nums,未经过softmax处理, 因此在计算F1-score时,需要进行softmax处理。target为 为真实的分割标签one-hot编码格式,shape为(n,h,w,c)
  • 可以利用如下代码,将标签mask图片png转为one-hot编码seg_labels
#-------------------------------------------------------#
#   转化成one_hot的形式
#   在这里需要+1是因为voc数据集有些标签具有白边部分
#   我们需要将白边部分进行忽略,+1的目的是方便忽略。
#-------------------------------------------------------#
seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]
seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
  • png.reshape([-1])得到shape大小为(h*w,1), 其中每个元素值为类别的索引, 然后根据类别索引从np.eye(self.num_classes + 1)取对应的行,这样就将mask转化为one-hot形式的seg_labels
  • 然后将seg_labels reshape为(h,w,self.num_classes + 1)
  • 在这里需要+1是因为voc数据集有些标签具有白边部分, +1的目的是方便忽略
    (2) Dice loss
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
  • inputs分割模型的推理预测输出shape为(n,h,w,c), target为 为真实的分割标签one-hot编码格式shape为(n,h,w,c)

  • 可以看到dice_loss的实现,跟F1-score基本上是一模一样的, 将torch.mean(score)求得的F1-soce, 然后通过dice_loss = 1- F1-score 来实现。

1.3 focal loss

上面针对不同类别的像素数量不均衡提出了改进方法,但有时还需要将像素分为难学习容易学习这两种样本。

容易学习的样本模型可以很轻松地将其预测正确,模型只要将大量容易学习的样本分类正确,loss就可以减小很多,从而导致模型不怎么顾及难学习的样本,所以我们要想办法让模型更加关注难学习的样本

1.3.1 公式:

(1) 交叉熵损失:
L c e − ∑ q t l o g 2 ( p t ) = − l o g ( p t ) L_{ce} - \sum q_t log_2(p_t)= -log(p_t) Lceqtlog2(pt)=log(pt)

  • 其中:log一般以e或者2为底都是可以的。参考:KL散度、CrossEntropy详解,
  • 因为对于分类或者分割任务, q t q_t qtone-hot编码,只在真实类别处为1,其他都是0, 所以交叉熵等效为 − l o g ( p t ) -log(p_t) log(pt)

(2) Focal Loss
F L ( p t ) = − a t ( 1 − p t ) r l o g ( p t ) FL(p_t)=-a_t(1-p_t)^rlog(p_t) FL(pt)=at(1pt)rlog(pt)

  • a t a_t at 是用来平衡正负样本数量的:基于样本非平衡造成的损失函数倾斜,一个直观的做法就是在损失函数中添加权重因子,提高少数类别在损失函数中的权重,平衡损失函数的分布。 p t p_t pt 表示预测的概率值或者置信度 p t p_t pt越大说明预测越接近focal loss 设置了一个modulating factor: ( 1 − p t ) r (1-p_t)^r (1pt)r用来区分样本预测的难易程度, 当预测的概率 p t p_t pt越接近于1,说明样本容易预测,此时 ( 1 − p t ) r (1-p_t)^r (1pt)r趋近于0;当预测概率 p t p_t pt越接近于0时,说明样本比较难预测,此时 ( 1 − p t ) r (1-p_t)^r (1pt)r趋近于1。 整体而言,通过modulating factor因子的调节,相当于增加了难分样本在损失函数中的权重

可以看出Focal Loss是在交叉熵Cross Entropy基础上演进而来的,相比于交叉熵损失,多了一个类别平衡系数 a t a_t at,以及区分样本难分程度的因子modulating factor: ( 1 − p t ) r (1-p_t)^r (1pt)r

1.3.2 代码
def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    temp_target = target.view(-1)

    logpt  = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt *= alpha
    loss = -((1 - pt) ** gamma) * logpt
    loss = loss.mean()
    return loss
  • inputs为分割模型的推理预测输出shape为(n,h,w,c)
  • 由于 nn.CrossEntropyLoss内部已经包含了Softmax处理,因此不需要对模型的预测输出inputs进行Softmax计算。
  • 对于分类、分割任务而言,focal loss相比于交叉熵损失CE − l o g ( p t ) -log(p_t) log(pt),多了一个类别平衡系数 a t a_t at,以及区分样本难分程度的因子modulating factor: ( 1 − p t ) r (1-p_t)^r (1pt)r,因此可以先求CrossEntropyLoss,然后乘以类别平衡系数以及区分样本难分程度的因子modulating factor: ( 1 − p t ) r (1-p_t)^r (1pt)r, 就可以求出focal loss
logpt  = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
 pt = torch.exp(logpt)
 if alpha is not None:
     logpt *= alpha
 loss = -((1 - pt) ** gamma) * logpt
 loss = loss.mean()
  • cls_weights: 衡量每个类别的重要程度, cls_weights数组元素个数与类别数一样。代码中默认设置各个类别重要度一样:
cls_weights     = np.ones([num_classes], np.float32)
  • 在计算CrossEntropyLoss, target可以是one-hot格式,也可以直接输出类别不需要进行one-hot处理(此时pytorch 内部会自动帮忙进行one-hot编码)
temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
temp_target = target.view(-1)

logpt  = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
  • CrossEntropyLoss根据预测输出和真实的target,就可以计算出交叉熵损失, 注意预测的输出需要reshape到(-1,c), target需要reshape[-1](一维),每个元素为像素的类别索引

2. 语义分割损失应用

  • 在语义分割中,focal_lossCross entropy loss以及dice_loss,通常需要结合一起使用。代码参考:https://github.com/bubbliiiing/deeplabv3-plus-pytorch/blob/main/train.py
  • 可以看到作者定义语义分割的损失中,结合了这三类损失,默认只使用的是Cross entropy loss, 如果存在素类别不平衡,以及难分的像素类别,可以将focal_lossdice_loss两个损失一起叠加使用。
dice_loss       = False
focal_loss      = Fals
cls_weights     = np.ones([num_classes], np.float32)

for iteration, batch in enumerate(train_dataloader):
        imgs, pngs, labels = batch

        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            if cuda:
                imgs    = imgs.cuda(local_rank)
                pngs    = pngs.cuda(local_rank)
                labels  = labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
        #----------------------#
        #   清零梯度
        #----------------------#
        optimizer.zero_grad()
		outputs = model_train(imgs)
		 #----------------------#
		 #   计算损失
		 #----------------------#
		 if focal_loss:
		     loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
		 else:
		     loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
		
		 if dice_loss:
		     main_dice = Dice_loss(outputs, labels)
		     loss      = loss + main_dice
  • 其中imgs为图片数据,shape大小为(n,c,h,w) ,
  • pngs 为mask标签图片,shape大小为(n,h,w), 为8位单通道图像,每个像素值对应类别信息
    利用imgspngs可以计算focal_loss 和 CE_Loss
  • labels 为标签pngs的one-hot形式,shape大小为(n,h,w,c+1), 在计算Dice_loss时,需要将标签转换为one_hot格式。
  • weights是用来,给各个类别附加权重的,weight数组元素需要和类别数一样

参考

https://github.dev/bubbliiiing/deeplabv3-plus-pytorch
https://zhuanlan.zhihu.com/p/101773544?utm_id=0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/354047.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

回归预测 | Matlab基于SSA-SVR麻雀算法优化支持向量机的数据多输入单输出回归预测

回归预测 | Matlab基于SSA-SVR麻雀算法优化支持向量机的数据多输入单输出回归预测 目录 回归预测 | Matlab基于SSA-SVR麻雀算法优化支持向量机的数据多输入单输出回归预测预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab基于SSA-SVR麻雀算法优化支持向量机的数据…

Qlik Sense 使用Join合并表格

Join | Windows 版 Qlik Sense帮助 什么是Qlik Sense的Join join 前缀可连接加载的表格和现有已命名的表格或最近创建的数据表。本质上跟SQL的Join很类似。 联接数据的效果是通过一组额外的字段或属性扩展目标表,即目标表中不存在的字段或特性。源数据集和目标表之间…

牛客——只能吃土豆的牛牛(进制转化)

链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 旅行完了的牛牛又胖了,于是他终于下决心要戒掉零食,所以他带着他最爱的土豆回到了牛星,开始了在牛星种土豆和只吃土豆减肥的日子。(吃土豆能减肥…

Future模式先给您提货单

Future模式是一种设计模式,用于在处理耗时操作时提高程序的响应性。 角色介绍: Main类: 负责向Host发出请求并获取数据的类。 Host类: 负责向请求返回FutureData的实例的类,起到调度的作用。 Data接口: 表示访问数据的方法的接口,由FutureD…

S275智慧煤矿4G物联网网关:矿山开采的未来已来

随着经济发展煤矿需求不断激增,矿山矿井普遍处于偏远山区,生产管理、人员安全、生产效率是每个矿山矿井都需要考虑的问题,利用网关对现场终端设备连接组网,实现智慧煤矿远程管理。 各矿山矿井分布范围比较广泛,户外环…

python内置函数有哪些?整理到了7大分类48个函数,都是工作中常用的函数

python内置函数 一、入门函数 1.input() 功能: 接受标准输入,返回字符串类型 语法格式: input([提示信息])实例: # input 函数介绍text input("请输入信息:") print("收到的数据是:%s" % (text))#输出…

Qt Design Studio+Pyside项目

Qt Design Studio设计出的项目结构有多个层级的目录,我们直接用类似Qt Creator工具的方式加载main.qml文件时会报错提示module "content" is not installed,将content加入importPath后还是报同样的错误。 Qt Design Studio生成的文件包含了.qm…

lv14 内核内存管理、动态分频及IO访问 12

一、内核内存管理框架 内核将物理内存等分成N块4KB,称之为一页,每页都用一个struct page来表示,采用伙伴关系算法维护 补充: Linux内存管理采用了虚拟内存机制,这个机制可以在内存有限的情况下提供更多可用的内存空…

路由、组件目录存放

文章目录 单页应用程序:SPA- Single Page Application路由的介绍VuePouter的介绍VueRouted 的使用 组件目录存放问题(组件分类) 单页应用程序:SPA- Single Page Application 单页应用(SPA):所有功能在一个…

Springmvc-@RequestBody

SpringBoot-2.7.12 请求的body参数无法转换,服务端没有报错信息打印,而是响应的状态码是400 PostMapping("/static/user") public User userInfo(RequestBody(required false) User user){user.setAge(19);return user; }PostMapping("…

算法设计与分析实验一:二分查找

目录 一、有序数组中的单一元素 1.1思路 1.2 代码实现 1.3 运行结果 二、长度最小的子数组 2.1思路 2.2 代码 2.3 运行结果 三、 山脉数组中查找目标值 3.1 思路 3.2 代码 3.3 运行结果 四、寻找旋转排序数组中的最小值 4.1思路 4.2代码 4.3 运行结果 一、有…

超越 Node.js:Bun 的创新与突破

1. Bun Bun 是一个全新的 JavaScript 运行时,类似于 Node.js 和 Deno,它专注于提供出色的性能和开发者体验。Bun 的一些特点包括: 快速的性能:Bun 旨在提供高性能,无论是启动时间、执行速度还是安装依赖包的速度。 兼…

ORM-02-Hibernate 对象关系映射(ORM)框架

拓展阅读 The jdbc pool for java.(java 手写 jdbc 数据库连接池实现) The simple mybatis.(手写简易版 mybatis) Hibernate Hibernate ORM 允许开发者更轻松地编写那些数据在应用程序进程结束后仍然存在的应用程序。 作为一个对象关系映射&#xff08…

蓝桥杯省赛无忧 编程14 肖恩的投球游戏加强版

#include <stdio.h> #define MAX_N 1003 int a[MAX_N][MAX_N], d[MAX_N][MAX_N]; // 差分数组的初始化 void init_diff(int n, int m) {for (int i 1; i < n; i) {for (int j 1; j < m; j) {d[i][j] a[i][j] - a[i-1][j] - a[i][j-1] a[i-1][j-1];}} } // 对差…

【王道数据结构】【chapter2线性表】【P44t17~t20】【统考真题】

目录 2009年统考 2012年统考 2015年统考 2019年统考 2009年统考 #include <iostream>typedef struct node{int data;node* next; }node,*list;list Init() {list head(list) malloc(sizeof (node));head->next nullptr;head->data-1;return head; }list Buyne…

QA-GNN: 使用语言模型和知识图谱的推理问答

Abstract 使用预训练语言模型&#xff08;LMs&#xff09;和知识图谱&#xff08;KGs&#xff09;的知识回答问题的问题涉及两个挑战&#xff1a;在给定的问答上下文&#xff08;问题和答案选择&#xff09;中&#xff0c;方法需要&#xff08;i&#xff09;从大型知识图谱中识…

C++:auto 关键字 范围for

目录 auto 关键字&#xff1a; 起源&#xff1a; auto的使用细则&#xff1a; auto不能推导的场景&#xff1a; 范围for&#xff1a; 范围for的使用条件&#xff1a; C的空指针&#xff1a; 注意&#xff1a; auto 关键字&#xff1a; 起源&#xff1a; 随着程序越…

蜡烛图采用PictureBox控件绘制是实现量化的第一步

股票软件中的蜡烛图是非常重要的一个东西&#xff0c;这里用VB6.0自带的Picture1控件的Line方法就可以实现绘制。 关于PictureBox 中的line 用法 msdn 上的说明为如下所示 object.Line [Step] (x1, y1) [Step] - (x2, y2), [color], [B][F] 然…

【Axure教程0基础入门】02高保真基础

02高保真基础 1.高保真原型的要素 &#xff08;1&#xff09;静态高保真原型图 尺寸&#xff1a;严格按照截图比例&#xff0c;参考线 色彩&#xff1a;使用吸取颜色&#xff0c;注意渐变色 贴图&#xff1a;矢量图/位图&#xff0c;截取&#xff0c;覆盖等 &#xff08;…

【Java Kubernates】Java调用kubernates提交Yaml到SparkOperator

背景 目前查询框架使用的是trino&#xff0c;但是trino也有其局限性&#xff0c;需要准备一个备用的查询框架。考虑使用spark&#xff0c;spark operator也已经部署到k8s&#xff0c;现在需要定向提交spark sql到k8s的sparkoperator上&#xff0c;使用k8s资源执行sql。 对比 …