损失函数:Cross Entropy Loss (交叉熵损失函数)

损失函数:Cross Entropy Loss (交叉熵损失函数)

  • 前言
  • 相关介绍
  • Softmax函数
    • 代码实例
  • Cross Entropy Loss (交叉熵损失函数)
    • Cross Entropy Loss与BCE loss区别
    • 代码实例

前言

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

相关介绍

损失函数(Loss Function)在机器学习和深度学习中扮演着至关重要的角色,它是一个评估模型预测输出与真实标签之间差异程度的函数。损失函数量化了模型预测错误的程度,并在训练过程中作为优化的目标,模型通过不断地调整内部参数以最小化损失函数的值,从而实现更好的拟合数据和泛化能力。

主要特性与作用:

  1. 量化误差:损失函数将模型预测值与实际目标值之间的差异转化为数值,这样就可以通过数值大小直观地衡量模型的预测效果。

  2. 优化导向:在训练神经网络时,优化算法(如梯度下降法)会根据损失函数的梯度来更新模型参数,使损失函数朝着最小化方向移动。

  3. 种类多样:根据不同的任务和需求,有多种不同的损失函数可供选择。例如,在二分类任务中常用的有二元交叉熵损失函数(Binary Cross-Entropy Loss/BCE Loss),在多分类任务中有softmax交叉熵损失函数,在回归任务中常见的是均方误差(Mean Squared Error/MSE)和绝对误差(Mean Absolute Error/MAE)等。

常见的损失函数包括:

  • 二元交叉熵损失(Binary Cross-Entropy Loss / BCE Loss):适用于二分类问题,衡量的是sigmoid函数输出的概率与真实标签间的距离。

  • 多分类交叉熵损失(Categorical Cross-Entropy Loss):对于多分类问题,每个样本可能属于多个类别之一,使用softmax函数和交叉熵损失。

  • 均方误差(Mean Squared Error / MSE):在回归问题中常用,计算预测值与真实值之差的平方平均。

  • 均方根误差(Root Mean Squared Error / RMSE):MSE的平方根,也是回归任务中的损失函数。

  • Huber损失:一种既能兼顾均方误差又能容忍较大误差的混合损失函数,常用于回归问题中。

  • Dice系数损失(Dice Loss):在图像分割任务中广泛使用,衡量的是预测分割区域与真实分割区域的重叠程度。

  • IoU(Intersection over Union)损失:也是在图像分割领域常用的损失函数,计算的是预测区域与真实区域交集与其并集的比例。

  • Focal Loss:在目标检测中应对类别不平衡问题的损失函数,对易分类的样本给予较小的权重,强调难分类样本的训练。

每种损失函数都有其适用的情境和优缺点,选择合适的损失函数是优化模型性能的关键因素之一。

交叉熵(Cross-Entropy)之所以能够用于分类问题,是因为它能够很好地衡量模型预测的概率分布与实际标签分布之间的相似度,而且它拥有几个非常适合分类任务的重要特性:

  1. 信息论基础:交叉熵源于信息论中的概念,表示一个概率分布 (p) 与另一个概率分布 (q) 的差异。在分类问题中,我们可以把 (p) 视为真实数据的标签分布,(q)视为模型预测的概率分布。交叉熵可以衡量模型预测概率与实际类别标签之间的信息差异。

  2. 最大似然估计的自然延伸:在机器学习中,我们通常倾向于最大化模型对数据的似然性,即模型预测给定数据标签的概率。交叉熵损失函数实际上是负对数似然函数在多项式分布(对于多分类问题)或伯努利分布(对于二分类问题)下的特殊情况,通过最小化交叉熵损失,相当于最大化数据的对数似然性。

  3. 梯度稳定性:交叉熵损失函数是连续且可微的,其梯度容易计算且对于大多数情况是有意义的。这意味着在训练过程中,模型可以根据损失函数的梯度进行有效的参数更新。

  4. 稀疏性惩罚:对于多分类问题,softmax函数与交叉熵损失组合使用时,不仅鼓励模型正确预测每个样本的类别,同时也通过归一化机制惩罚了预测概率分布的不均匀性,即模型不能过于肯定任何一个错误类别。

  5. 处理多类别和二类别问题:交叉熵既可以用于处理二分类问题(通过二元交叉熵,Binary Cross-Entropy),也可以处理多分类问题(通过多类别交叉熵,Multiclass
    Cross-Entropy)。在二分类问题中,通常搭配Sigmoid函数输出概率;在多分类问题中,通常配合Softmax函数生成类别概率分布。

总的来说,交叉熵损失函数因其良好的理论基础、优化目标清晰以及在实践中的优秀表现,成为了分类问题中最常用的损失函数之一。

Softmax函数

Softmax函数是深度学习和机器学习中广泛使用的激活函数,特别是在多分类问题中。它的目的是将一个线性变换的输出(通常称为logits)映射为一个概率分布,使得所有类别的概率总和为1,每个类别的概率都在0到1之间。

Softmax函数的形式:

对于一个向量 ( z ) ,其中包含每个类别的原始得分(logits),Softmax函数的计算公式如下:

s o f t m a x ( z ) i = e z i ∑ j = 1 K e z j softmax(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}} softmax(z)i=j=1Kezjezi

其中:

  • ( K ) 表示类别总数。
  • ( z_i ) 表示第 ( i ) 个类别的得分。
  • ( softmax(z)_i ) 表示第 ( i ) 个类别的归一化概率。

整个Softmax函数的结果是一个概率分布向量,其中每个元素都是原得分经过指数函数变换后再除以所有得分指数函数值之和,因此所有元素的和为1。

Softmax函数的特性:

  1. 概率性质:Softmax函数确保输出的每个元素都是非负数,并且所有元素的和为1,满足概率分布的要求。
  2. 竞争性:Softmax函数会使得分最高的类别获得最大的概率值,其余类别的概率按比例递减,形成了一种“赢家通吃”的效应。
  3. 平滑连续:由于使用了指数函数和平滑的除法运算,Softmax函数输出是平滑且连续的,便于在训练过程中梯度的计算和传播。

应用场景

在深度学习的多分类问题中,例如图像分类、文本分类等任务,Softmax函数通常与交叉熵损失函数一起使用。模型最后一层通常会产生一个logits向量,接着通过Softmax函数得到每个类别的概率,最后计算与实际标签之间的交叉熵损失,以此指导模型参数的更新。

代码实例

在PyTorch中,你可以直接使用torch.softmax()函数来实现Softmax操作。下面是一个简单的实例:

import torch

# 假设我们有一个代表logits的张量
logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# 使用torch.softmax函数计算Softmax值
probs = torch.softmax(logits, dim=1)

print(probs)
'''
tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])
'''

上述代码中,logits是一个2x3的张量,代表两个样本的三个类别的原始得分。dim=1表示我们在每个样本的类别间计算Softmax,也就是对每一行进行操作。执行torch.softmax()后,probs张量将包含每个样本各类别的归一化概率。

注意,如果你正在训练一个多分类模型并且使用了nn.CrossEntropyLoss()损失函数,通常不需要单独调用torch.softmax(),因为该损失函数内部已经包含了对logits计算Softmax的过程。在多数情况下,你只需将模型的原始输出(logits)传递给损失函数,并配合真实类别标签即可。

Cross Entropy Loss (交叉熵损失函数)

在这里插入图片描述

nn.CrossEntropyLoss是PyTorch中用于多分类问题的一种损失函数,特别适用于输出层是softmax激活函数后的分类任务。它结合了softmax函数和交叉熵损失(Cross-Entropy Loss)的操作,简化了模型训练过程中的计算步骤和代码实现。

基本概念:

  • 交叉熵损失(Cross-Entropy Loss)源于信息论中的熵概念,用于衡量两个概率分布之间的差异。在机器学习和深度学习中,它用来量化模型预测的概率分布与真实标签分布之间的差距。

  • softmax函数:在多分类问题中,softmax函数将模型的线性输出(logits)转换为一个概率分布,确保所有类别的概率和为1。softmax函数的输出可以用作模型预测的概率分布。

nn.CrossEntropyLoss的工作方式:

  • PyTorch中的nn.CrossEntropyLoss接收两个输入:

    • input:模型的原始输出(logits),通常是未经过softmax激活的张量。
    • target:真实的一维标签张量,包含了每个样本所属类别的索引,通常采用LongTensor类型。
  • 内部处理流程

    • 对于每个样本,首先计算其对应的softmax概率分布。
    • 然后,根据真实标签计算交叉熵损失。损失是对每个样本的损失值进行平均得到的,如果没有特殊指定,损失默认会在批次(batch)层面求平均。
  • 损失函数计算公式

    • 对于单个样本,交叉熵损失是 -∑(yi * log(pi)),其中 yi 是实际标签的one-hot编码(在实际情况中,由于标签是索引形式,nn.CrossEntropyLoss内部会处理one-hot编码),pi 是模型预测的该类别概率。
    • 对于整个批次,损失则是各样本损失的平均。

Cross Entropy Loss与BCE loss区别

  • 关于BCE Loss(二元交叉熵损失函数)的相关知识,可查阅损失函数:BCE Loss(二元交叉熵损失函数)、Dice Loss(Dice相似系数损失函数)

CrossEntropyLossBCELoss 都是 PyTorch 中用于监督学习分类任务的损失函数,它们分别适用于不同的分类场景:

BCELoss (Binary Cross Entropy Loss)

  • BCELoss 是二元交叉熵损失函数,专门用于二分类问题,即输出只有两类(0或1,正面或负面,真或假等)。
  • 使用 BCELoss 时,模型的输出一般是通过 Sigmoid 函数得到的概率值,介于0和1之间。
  • 计算公式为 -y * log(p) - (1-y) * log(1-p),其中 y 是真实的标签(0或1),p 是模型预测的概率。
  • 输入要求是经过Sigmoid激活函数之后的输出张量和相应的真实标签张量,二者形状必须相同。

CrossEntropyLoss (Multinomial Cross Entropy Loss 或者 Softmax Cross Entropy Loss)

  • CrossEntropyLoss 适用于多分类问题,它可以处理任何数量的类别,不仅仅是二分类。
  • 对于多分类问题,模型的输出通常是一个 logits(未归一化的预测值),然后CrossEntropyLoss内部会先通过Softmax函数将其转换为概率分布,然后再计算交叉熵。
  • 使用 CrossEntropyLoss 时,不需要手动在输出层之前添加Sigmoid或Softmax函数,因为它已经包含了Softmax运算步骤。
  • 它结合了Softmax函数和交叉熵损失的功能,简化了多分类任务的训练流程,其计算公式基于交叉熵和类别间的互斥性(即对于每个样本,所有类别的概率之和为1)。
  • 输入要求是未经Softmax激活函数处理的logits张量和one-hot编码形式的真实标签张量。

总结来说,两者的主要区别在于:

  • BCELoss用于二分类任务,而CrossEntropyLoss适用于多分类任务。
  • BCELoss前接Sigmoid,CrossEntropyLoss前接Softmax(但这一步在使用CrossEntropyLoss时由损失函数内部自动完成)。
  • BCELoss处理的是二元概率分布,而CrossEntropyLoss处理的是多类别概率分布。

代码实例

import torch
import torch.nn as nn

# 假设模型输出和真实标签
output_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 假设输出是两样本的三个类别的logits
targets = torch.tensor([1, 2])  # 假设第一样本是第二类,第二样本是第三类

# 创建交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(output_logits, targets)

print(loss.item())  # 输出损失值 # 0.9076058864593506

在上述代码中,nn.CrossEntropyLoss()函数内部处理了softmax激活和交叉熵损失计算,直接返回了模型预测与真实标签之间的交叉熵损失。

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/554946.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

密码学 | 椭圆曲线密码学 ECC 入门(二)

目录 4 椭圆曲线:更好的陷门函数 5 奇异的对称性 6 让我们变得奇特 ⚠️ 原文地址:A (Relatively Easy To Understand) Primer on Elliptic Curve Cryptography ⚠️ 写在前面:本文属搬运博客,自己留着学习。如果你和我一样…

【Linux】应用层协议:HTTP

URL 在之前的文章中我们实现了一个网络版本的计算器,在那个计算器中揉合了协议定制以及序列化反序列化的内容,我们当时也自己定制了一套协议标准,比如请求和响应的格式应该是什么?如何读到一个完整的报文?支持的运算符…

【XR806开发板试用】在 xr806 上移植 LVGL

本文参与极术社区的《基于安谋科技STAR-MC1的XR806开发板试用》活动。 不多废话,直接开搞,先上效果图 准备 开发环境啥的,已经有很多文章了,这里就不再提搭建开发环境的相关内容了。 一个屏幕(1.8’ 128x160) LVGL源码(v8.0.2…

京东微服务microApp使用总结

前言 基于现有业务门户进行微服务基础平台搭建 主应用框架:vue3vite 子应用框架:vue2webpack / vue3vite在这里插入代码片 本地调试即可:主应用子应用进行打通(注意:两者都是vue3vite) 问题总结 1.嵌入…

压缩感知的概述梳理(3)

参考文献 Adaptive embedding: A novel meaningful image encryption scheme based on parallel compressive sensing and slant transform 文献内容 梳理 列表形式 并行压缩感知核心元素与流程 信号 x 长度:N表示:(x \sum_{i1}^{N} a_i\psi_i \su…

密码学 | 椭圆曲线密码学 ECC 入门(三)

目录 7 这一切意味着什么? 8 椭圆曲线密码学的应用 9 椭圆曲线密码学的缺点 10 展望未来 ⚠️ 原文地址:A (Relatively Easy To Understand) Primer on Elliptic Curve Cryptography ⚠️ 写在前面:本文属搬运博客,自己留…

Argus DBM 一款开源的数据库监控工具,无需部署Agent

开箱即用 无需部署Agent,开箱即用。我们使用JDBC直连您的数据库,输入IP端口账户密码即可。 全平台支持 Argus目前支持对Mysql, PostgreSQL, Oracle等数据库类型的监控,我们也会尽快适配其它数据库,致力于监控所有数据库。我们提…

ctfshow web入门 SQl注入web171--web179

从这里开始SQl建议大家去看这篇文章学习一下先 MySQl web171 法一联合查询 题目 $sql "select username,password from user where username !flag and id ".$_GET[id]." limit 1;";爆数据库名 -1 union select 1,database(),3 -- 爆表名 -1 union s…

el-input在type=“textarea“中文字换行,导出成word文档中也显示换行

el-input在type"textarea"文字中换行,是\n转义的,文字导出成word文档没有换行,只需要把\n替换成\r,这样el-input和word文档都能正常显示换行 val.replace.replace("\n", "\r")

UWB技术应用:UWB模块在室内定位系统中的关键角色

随着室内定位技术的不断发展,UWB模块在室内定位系统中扮演着关键角色,其高精度、抗干扰和多路径传输等特点,为室内定位提供了可靠的技术支持。本文将探讨UWB模块在室内定位系统中的关键角色,包括其原理、应用场景和技术优势&#…

Java项目 苍穹外卖 黑马程序员

目录 day1一、项目效果展示二、项目开发整体介绍三、项目介绍3.1 定位3.2 功能架构3.3 产品原型3.4 技术选型 四、开发环境搭建4.1 前端环境4.2 后端环境 五、导入接口文档六、Swagger6.1 介绍6.2 使用方式6.3 常用注解 day2一、新增员工二、员工分页查询三、启用禁用员工账号四…

【无标题】PHP-parse_str变量覆盖

[题目信息]: 题目名称题目难度PHP-parse_str变量覆盖1 [题目考点]: 变量覆盖指的是用我们自定义的参数值替换程序原有的变量值,一般变量覆盖漏洞需要结合程序的其它功能来实现完整的攻击。 经常导致变量覆盖漏洞场景有:$$&…

paddle gpu install

官方教程 https://www.paddlepaddle.org.cn/documentation/docs/zh/install/pip/linux-pip.html#sanyanzhenganzhuang 创建一个新的环境安装 python3 -m pip install paddlepaddle-gpu2.6.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/cudnnin/stable.htmlpython …

网络基础-TCP和UDP协议区别

什么是 TCP 协议? TCP 全称是 Transmission Control Protocol(传输控制协议),它由 IETF 的 RFC 793 定义,是一种面向连接的点对点的传输层通信协议。 TCP 通过使用序列号和确认消息,从发送节点提供有关传输…

请陪伴Kimi和GPT成长

经验的闪光汤圆 但是我想要写实的 你有吗? 岁数大了,希望如何学习新知识呢?又觉得自己哪些能力亟需补强呢? 看论文自然得用Kimi,主要是肝不动了,眼睛也顶不住了。 正好昨天跟专业人士学会了用工作流的办法跟…

【动态规划 状态机dp】3082. 求出所有子序列的能量和

算法可以发掘本质,如: 一,若干师傅和徒弟互有好感,有好感的师徒可以结对学习。师傅和徒弟都只能参加一个对子。如何让对子最多。 二,有无限多1X2和2X1的骨牌,某个棋盘若干格子坏了,如何在没有坏…

idea连接Docker数据库

我们在docker下创建了数据库,想要更方便的查看和操作该数据库,idea和DataGrip或者其他人家都可以。在数据库连接时需要填写数据库名字,主机,端口,数据库用户名和密码。 输入之后先不要点击OK和按Enter键,我…

vue+springboot实验个人信息,修改密码,忘记密码功能实现

前端部分 新增Person(个人页面),Password(修改密码页面),还需要对Manager,login页面进行修改 router文件夹下的index.js: import Vue from vue import VueRouter from vue-router i…

【运输层】TCP 的流量控制和拥塞控制

目录 1、流量控制 2、TCP 的拥塞控制 (1)拥塞控制的原理 (2)拥塞控制的具体方法 1、流量控制 一般说来,我们总是希望数据传输得更快一些。但如果发送方把数据发送得过快,接收方就可能来不及接收&#x…

Ubuntu20.04无法连接蓝牙

连接蓝牙的时候一直转圈显示搜索中,但是搜索不出来任何设备。 在终端输入 dmesg | grep -i blue 可看到有以下信息 可以看到无法识别蓝牙,处于一个unknow的状态 Bluetooth: hci0: RTL: unknown IC info, lmp subver 8852, hci rev 000b, hci ver 000b…