【Loss总结】适用与弱监督语义分割中的各类loss

【Loss总结】适用与弱监督语义分割中的各类loss

文章目录

  • 【Loss总结】适用与弱监督语义分割中的各类loss
    • 交叉熵损失
      • 相对熵(KL散度)
      • 交叉熵
    • L1 Loss
    • L2Loss
    • SmoothL1Loss
    • Dice loss
      • 梯度分析
      • 语义分割代码

交叉熵损失

交叉熵损失函数(CrossEntropy Loss),它是分类问题中经常使用的一种损失函数

在模型的输出层总会接一个softmax函数

交叉熵是信息论中的一个重要概念,主要用于度量两个概率分布间的差异性

构建理论:

  • 信息是用来消除随机不确定性的东西
  • 信息量的大小与信息发生的概率成反比。概率越大,信息量越小。概率越小,信息量越大


在这里插入图片描述

I(x)表示信息量,P(x)表示概率

I(x)为正值,因为log(P(x))为负值

信息熵

信息熵也被称为熵,用来表示所有信息量的期望

期望是试验中每次可能结果的概率乘以其结果的总和。

信息量的熵可表示为:


在这里插入图片描述

在这里插入图片描述

0-1分布问题

一件事情发生的概率为P(x) ,则另一件事情发生的概率为1-P(x)

在这里插入图片描述

相对熵(KL散度)

有两个单独的概率分布P ( x ) 和Q ( x ) ,则我们可以使用KL散度来衡量这两个概率分布之间的差异。

在这里插入图片描述

P ( x ) 来表示样本的真实分布,Q ( x ) 来表示模型所预测的分布

真实分布P ( X ) = [ 1 , 0 , 0 ] , 预测分布Q ( X ) = [ 0.7 , 0.2 , 0.1 ]


在这里插入图片描述

将对应的P,Q值进行代入

KL散度越小,表示P ( x ) 与Q ( x ) 的分布更加接近,可以通过反复训练Q ( x ) 来使Q ( x ) 的分布逼近P ( x ) 。

交叉熵

KL散度公式拆开:

在这里插入图片描述

前者H ( p ( x ) ) )表示信息熵,后者即为交叉熵,KL散度 = 交叉熵 - 信息熵
当KL散度完全一致(loss=1)DKL=1 · log1

交叉熵公式表示为:

在这里插入图片描述

输入数据与标签常常已经确定,那么真实概率分布P ( x ) 也就确定下来了,所以信息熵在这里就是一个常量

在线性回归问题中,常常使用MSE(Mean Squared Error)作为loss函数,而在分类问题中常常使用交叉熵作为loss函数。

在这里插入图片描述

其中一个batch的loss为


在这里插入图片描述

n表示预测的类别数(猫,狗 ,马)n=3,m表示样本数量(batch_size=16)m=16


L1 Loss

也就是L1 Loss了,它有几个别称:

  • L1 范数损失
  • 最小绝对值偏差(LAD)
  • 最小绝对值误差(LAE)

MAE也是指L1 Loss损失函数

目标值 yi与模型输出f(xi)(估计值)做绝对值得到的误差。

在这里插入图片描述

应用场景:

  1. 回归任务
  2. 简单的模型
  3. 由于神经网络通常是解决复杂问题,所以很少使用。

公式: torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

创建一个绝对值误差损失函数

在这里插入图片描述

用法:reduction='mean','None','sum'

mean 对应(x, y)的均值,None 对应(x, y)的单独值,sum对应(x, y)的总值

L2Loss

也就是L2 Loss了,它有几个别称:

  • L2 范数损失
  • 最小均方值偏差(LSD)
  • 最小均方值误差(LSE)

最常看到的MSE也是指L2 Loss损失函数,PyTorch中也将其命名为torch.nn.MSELoss

目标值 yi与模型输出f(xi)(估计值)做差然后平方得到的误差

在这里插入图片描述

应用场景:

  1. 回归任务

  2. 数值特征不大

  3. 问题维度不高

对离群点比较敏感,如果feature是unbounded的话,需要好好调整学习率,防止出现梯度爆炸的情况。l2正则会让特征的权重不过大,使得特征的权重比较平均。

SmoothL1Loss

简单来说就是平滑版的L1 Loss

SoothL1Loss的函数如下:

在这里插入图片描述

仔细观察可以看到,当预测值和ground truth差别较小的时候(绝对值差小于1),其实使用的是L2 Loss;而当差别大的时候,是L1 Loss的平移

SooothL1Loss其实是L2Loss和L1Loss的结合,它同时拥有L2 Loss和L1 Loss的部分优点。

优点:

  1. 当预测值和ground truth差别较小的时候(绝对值差小于1),梯度不至于太大。(损失函数相较L1 Loss比较圆滑)
  2. 当差别大的时候,梯度值足够小(较稳定,不容易梯度爆炸)。

应用:

  1. 回归
  2. 当特征中有较大的数值
  3. 适合大多数问题

Dice loss

dice loss 来自 dice coefficient,是一种用于评估两个样本的相似性的度量函数,取值范围在0到1之间,取值越大表示越相似

在这里插入图片描述

dice loss可以写为:

在这里插入图片描述

梯度分析

从dice loss的定义可以看出,dice loss 是一种区域相关的loss。意味着某像素点的loss以及梯度值不仅和该点的label以及预测值相关,和其他点的label以及预测值也相关,这点和ce (交叉熵cross entropy) loss 不同

从loss曲线和求导曲线对单点输出方式分析。然后对于多点输出的情况,利用模拟预测输出来分析其梯度

Dice Loss可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标

语义分割代码

class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes  # 物体的输入数量

    # 没有问题,但是需要的是进行一个one_hot_的解码,来满足6个特征图
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)  # 12, 6, 256, 256
        target = self._one_hot_encoder(target)  # [12, 6, 256, 256]
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(),
                                                                                                  target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/447765.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大模型优化——重排序模型

检索增强生成(RAG)技术作为自大模型兴起后爆火的方向之一,已经广受研发者们追捧,大型语言模型(LLMs)如GPT系列和LLama系列在自然语言处理领域取得了显著的成功,但它们面临着幻觉、过时知识和不透明、不可追溯的推理过程等挑战。检索增强生成(RAG)通过整合外部数据库的…

VideoDubber时长可控的视频配音方法

本次分享由中国人民大学、微软亚洲研究院联合投稿于AAAI 2023的一篇专门为视频配音任务定制的机器翻译的工作《VideoDubber: Machine Translation with Speech-Aware Length Control for Video Dubbing》。这个工作将电影或电视节目中的原始语音翻译成目标语言。 论文地址&…

【Python】【Matplotlib】解决使用 plt.savefig() 保存的图片出现一片空白的问题

【Python】【Matplotlib】解决使用 plt.savefig() 保存的图片出现一片空白的问题 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#…

Leetcode : 1137. 高度检查器

学校打算为全体学生拍一张年度纪念照。根据要求,学生需要按照 非递减 的高度顺序排成一行。 排序后的高度情况用整数数组 expected 表示,其中 expected[i] 是预计排在这一行中第 i 位的学生的高度(下标从 0 开始)。 给你一个整数…

CV论文--2024.3.7

1、FAR: Flexible, Accurate and Robust 6DoF Relative Camera Pose Estimation 中文标题:FAR:灵活、准确和稳健的6DoF相机相对姿态估计 简介:在计算机视觉领域,估计图像之间的相对相机姿态一直是一个关键问题。通常,…

php导出excel文件

环境 php7.4hyperf3composer require phpoffice/phpspreadsheet代码 class IndexController extends AbstractController { /*** Inject* var Picture*/private $picture;public function index(){$res_data[]["robot" > 哈哈机器人,"order" > TES…

记录一下C++的学习之旅吧--C++基础

文章目录 前言using namespace std; 使用标准命名空间一、helloworld-输出表示1.1代码1.2 运行结果 二、变量2.1.1 普通变量代码2.1.2 运行结果2.2.1 常量和变量代码2.2.2 运行结果 三、sizeof---统计数据类型所占的内存大小3.1 代码3.2 运行结果 四、小数表示4.2 运行结果 总结…

02- 使用Docker安装RabbitMQ

使用Docker安装RabbitMQ 下载安装镜像 方式一: 启动docker服务,然后在线拉取 # 在线拉取镜像 docker pull rabbitmq:3-management# 使用docker images查看是否已经成功拉取方式二: 从本地加载 ,将RabbitMQ上传到虚拟机中后使用命令加载镜像即可 docker load -i mq.tar启动M…

mabatis 中

手动实现MaBatis底层机制 实现任务阶段一🍍完成读取配置文件, 得到数据库连接🥦分析 代码实现🥦完成测试 实现任务阶段二🍍编写执行器, 输入SQL语句, 完成操作🥦分析 代码实现🥦完成测试 实现任务阶段三&…

Redis缓存预热-缓存穿透-缓存雪崩-缓存击穿

什么叫缓存穿透? 模拟一个场景: 前端用户发送请求获取数据,后端首先会在缓存Redis中查询,如果能查到数据,则直接返回.如果缓存中查不到数据,则要去数据库查询,如果数据库有,将数据保存到Redis缓存中并且返回用户数据.如果数据库没有则返回null; 这个缓存穿透的问题就是这个…

使用python将数据输出为图表图片

数据示例(数组或其他): hourly_data {00:00: 10,01:00: 15,02:00: 20,03:00: 25,04:00: 30,# 添加更多数据... }示例输出(图片): python代码: 下面代码中使用了matplotlib库,如果…

Mac系统:mysql+jdk+neo4j

mysql 指令 //启动MySQL服务 sudo /usr/local/mysql/support-files/mysql.server start//停止MySQL服务 sudo /usr/local/mysql/support-files/mysql.server stop //连接MySQL数据库,在进行这一步前要先关掉服务 mysql -u root -p //检查MySQL服务状态 sudo /us…

JDK17镜像制作

背景 获取JDK17 wget https://download.oracle.com/java/17/latest/jdk-17_linux-x64_bin.tar.gz 解压JDK tar -zxvf jdk-17_linux-x64_bin.tar.gz 制作JRE 由于jdk的体积比较大,可以使用jre来作为运行环境,jdk1.8及以前版本,自带jre&#…

力扣--动态规划/回溯算法131.分割回文串

思路分析: 动态规划 (DP): 使用动态规划数组 dp,其中 dp[i][j] 表示从字符串 s[i] 到 s[j] 是否为回文子串。预处理动态规划数组: 从字符串末尾开始,遍历每个字符组合,判断是否为回文子串,填充…

后悔没有早点看到这份产品说明书模板

产品说明书是连接产品与消费者的桥梁,它对产品具有多重好处。一份设计精良、内容准确的产品说明书有助于消费者全面了解产品,确保用户正确使用产品;减少消费者因误操作导致的故障,降低企业的售后服务成本;增强消费者对…

GaLore的全称是“Gradient Low-Rank Projection“,翻译过来就是“梯度低秩投影“

鉴于大家对GaLore比较感兴趣,我今天试着结合论文做一个更深入的解读: GaLore的全称是"Gradient Low-Rank Projection",翻译过来就是"梯度低秩投影"。它的核心思想是通过降低优化器状态的秩,来大幅减少内存占用。 在训练大模型时,我们需要存储三类数据:模型…

操作系统基础

进程与线程 进程之间如何通讯 用户态与核心态 进程空间 操作系统内存管理 TBL TBL 多级页表虽然解决了空间上的问题,但是我们发现这种方式需要走多道转换才能找到映射的物理内存地址,经过的多道转换造成了时间上的开销。 程序是局部性的,即…

新质生产力简介

新质生产力简介 新质生产力概述: 新质生产力是以科技创新为核心,实现关键性颠覆性技术突破,推动社会经济发展的高效能、高质量生产力。 新质生产力的本质 新质生产力的本质是“科技创新” 新质生产力的核心是科技创新 新质生产力简介 新质…

全面对比Amazon DocumentDB 与 MongoDB

在云中部署 MongoDB 似乎有多种选择。例如,Amazon DocumentDB自称是完全支持 MongoDB API 的 AWS 原生数据库。虽然它支持一些 MongoDB 功能,但需要注意的是 DocumentDB 并不完全兼容 MongoDB。要在 AWS 上访问功能齐全的“MongoDB 即服务”,…

微服务技术栈SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式(五):分布式搜索 ES-上

文章目录 一、ElasticSearch1.1 概述1.2 倒排索引1.3 ES与MySQL的概念对比 二、 安装2.1 部署单点ES2.2 部署kibana 三、安装IK分词器3.1 在线安装ik插件(较慢)3.2 离线安装ik插件(推荐)3.3 扩展词词典3.4 停用词词典 四、索引库操…