机器学习----交叉熵(Cross Entropy)如何做损失函数

目录

一.概念引入

1.损失函数

2.均值平方差损失函数

3.交叉熵损失函数

3.1信息量

3.2信息熵

3.3相对熵

二.交叉熵损失函数的原理及推导过程

表达式

二分类

联立

取对数

补充

三.交叉熵函数的代码实现


一.概念引入

1.损失函数

损失函数是指一种将一个事件(在一个样本空间中的一个元素)映射到一个表达与其事件相关的经济成本或机会成本的实数上的一种函数。在机器学习中,损失函数通常作为学习准则与优化问题相联系,即通过最小化损失函数求解和评估模型。
 
不同的任务类型需要不同的损失函数,例如在回归问题中常用均方误差作为损失函数,分类问题中常用交叉熵作为损失函数。

2.均值平方差损失函数

定义如下: L(y,f(x;\Theta )) = \frac{1}{N}\sum_{i = 1}^{N}(yi - f(xi;\Theta ))^{2}

意义:N为样本数量。公式表示为每一个真实值与预测值相减的平方去平均值。均值平方差的值越小,表明模型越好。

对于回归问题,均方差的损失函数的导数是局部单调的,可以找到最优解。但是对于分类问题,损失函数可能是坑坑洼洼的,很难找到最优解。故均方差损失函数适用于回归问题

3.交叉熵损失函数

交叉熵是信息论中的一个重要概念,主要用于度量两个概率分布间的差异性。在机器学习中,交叉熵表示真实概率分布与预测概率分布之间的差异。其值越小,模型预测效果就越好。
 
交叉熵损失函数的公式为:
L = -(y log \hat{y} +(1-y)log(1- \hat{y}))
 
其中,y表示样本的真实标签,\hat{y}表示模型预测的标签。当y=1时,表示样本属于正类;当y=0时,表示样本属于负类。

3.1信息量

信息量是指信息多少的量度。

比如说

  • 1:太阳从东边升起,这个信息量就是0,因为这个是一句废话。没有不确定性的东西。
  • 2:今天会下雨。从直觉上来看,这个信息量就比较大了,因为今天天气具有不确定性,但是这句话消除了不确定性。

根据上述总结如下:信息量的大小与信息发生的概率成反比。概率越大,信息量就越小,概率越小,信息量就越大。设某件事发生的概率为p(xi),则信息量为:

I(xj) = -ln(p(xi))

3.2信息熵

信息熵是信息论中的一个重要概念,用于衡量一个系统或信号中信息量的不确定性或随机性。
 
信息熵的定义可以用数学公式表示。假设有一个离散的随机变量X,它可以取n个不同的可能值x_1,x_2,\ldots,x_n,每个可能值的概率为p_1,p_2,\ldots,p_n,则信息熵H(X)的计算公式为:
 
H(X)=-\sum_{i=1}^{n}p_i\log_2p_i
 
其中,\log_2表示以2为底的对数。
 
信息熵的物理意义是:它表示了在给定概率分布的情况下,系统的平均不确定性或信息量。信息熵的值越大,表示系统的不确定性越高;信息熵的值越小,表示系统的不确定性越低。

3.3相对熵

相对熵,也称为KL 散度(Kullback-Leibler Divergence),是一种用于比较两个概率分布差异的度量。它衡量了一个概率分布P与另一个参考概率分布Q之间的差异程度。
 
相对熵的定义为:
 
D_{KL}(P||Q)=\sum_{x}P(x)\log\frac{P(x)}{Q(x)}
 
其中,P(x)和Q(x)分别是概率分布P和Q在事件x上的概率。
 
相对熵的物理意义是:它表示了将概率分布P表示为参考概率分布Q的编码时所需的额外信息量。如果P和Q非常接近,相对熵的值会比较小;如果P和Q差异较大,相对熵的值会比较大。
KL散度=交叉熵-信息熵
相对熵在机器学习、信息论和统计学中有广泛的应用。它可以用于评估两个模型或概率分布的相似性,比较数据分布的差异,以及在熵最小化的框架下进行优化等。
 
例如,在机器学习中,相对熵常用于比较真实数据的分布和模型预测的分布之间的差异,以评估模型的性能。较小的相对熵值表示模型预测的分布与真实分布更接近。

二.分类问题中的交叉熵

1.二分类问题中的交叉熵

把二分类的交叉熵公式 4 分解开两种情况:

  • 当 y=1 时,即标签值是 1 ,是个正例,加号后面的项为: loss=-log(a)
  • 当 y=0 时,即标签值是 0 ,是个反例,加号前面的项为 0 : loss=-log(1-a)

横坐标是预测输出,纵坐标是损失函数值。 y=1 意味着当前样本标签值是1,当预测输出越接近1时,损失函数值越小,训练结果越准确。当预测输出越接近0时,损失函数值越大,训练结果越糟糕。此时,损失函数值如下图所示。

 2.多分类问题中的交叉熵

假设希望根据图片动物的轮廓、颜色等特征,来预测动物的类别,有三种可预测类别:猫、狗、猪。假设我们训练了两个分类模型,其预测结果如下:

模型1:

预测值标签值是否正确
0.3 0.3 0.40 0 1(猪)正确
0.3 0.4 0.40 1 0(狗)正确
0.1 0.2 0.71 0 0(猫)错误

每行表示不同样本的预测情况,公共 3 个样本。可以看出,模型 1 对于样本 1 和样本 2 以非常微弱的优势判断正确,对于样本 3 的判断则彻底错误。

模型2:

预测值标签值是否正确
0.1 0.2 0.70 0 1(猪)正确
0.1 0.7 0.20 1 0(狗)正确
0.3 0.4 0.41 0 0(猫)错误

可以看出,模型 2 对于样本 1 和样本 2 判断非常准确(预测概率值更趋近于 1),对于样本 3 虽然判断错误,但是相对来说没有错得太离谱(预测概率值远小于 1)。

结合多分类的交叉熵损失函数公式可得,模型 1 的交叉熵为:

sample 1 loss = -(0 * log(0.3) + 0 * log(0.3) + 1 * log(0.4)) = 0.91

sample 1 loss = -(0 * log(0.3) + 1 * log(0.4) + 0 * log(0.4)) = 0.91

sample 1 loss = -(1 * log(0.1) + 0 * log(0.2) + 0 * log(0.7)) = 2.30

对所有样本的 loss 求平均:

L=\frac{0.91+0.91+2.3}{3}=1.37

模型 2 的交叉熵为:

sample 1 loss = -(0 * log(0.1) + 0 * log(0.2) + 1 * log(0.7)) = 0.35

sample 1 loss = -(0 * log(0.1) + 1 * log(0.7) + 0 * log(0.2)) = 0.35

sample 1 loss = -(1 * log(0.3) + 0 * log(0.4) + 0 * log(0.4)) = 1.20

对所有样本的 loss 求平均:

L=\frac{0.35+0.35+1.2}{3}=0.63

可以看到,0.63 比 1.37 的损失值小很多,这说明预测值越接近真实标签值,即交叉熵损失函数可以较好的捕捉到模型 1 和模型 2 预测效果的差异。交叉熵损失函数值越小,反向传播的力度越小

参考文章-损失函数|交叉熵损失函数。

三.交叉熵损失函数的原理及推导过程

表达式

输出标签表示为10,1}时,损失函数表达式为:L = -(y log \hat{y} +(1-y)log(1- \hat{y}))

二分类

二分类问题,假设y\epsilon (0,1)
正例:P(y = 1 |x) = \hat{y}                                                                 公式1

反例:P(y = 0|x) = 1-\hat{y}                                                         公式2

联立

将上述两式连乘。
P(y | x) = \hat{y}^{y}*(1-\hat{y})^{1-y};       其中y\epsilon (0,1)                            公式3

当y=1时,公式3和公式1一样。
当y=0时,公式3和公式2一样。

取对数

取对数,方便运算,也不会改变函数的单调性。

logp(y|x) = ylog\hat{y}+(1-y)log(1-\hat{y})                                公式4
我们希望P(y|x)越大越好,即让负值-logP(y|x)越小越好,

得到损失函数为L = -(y log \hat{y} +(1-y)log(1- \hat{y}))              公式5

补充

上面说的都是一个样本的时候,多个样本的表达式是:多个样本的概率即联合概率,等于每个的乘积。

p(y |x) = \prod_{i}^{m}p(y^{(i)}|x^{(i)})
logp(y|x)= \sum_{i}^{m}logp(y^{(i)}x^{(i)})
由公式4和公式5得到
logp(y^{(i)} |x^{(i)})=-L(y^{(i)}|x^{(i)})

logp(y^{(i)}|x^{(i)})=-\sum_{i}^{m}L(y^{(i)}|x^{(i)})
加上\frac{1}{m}对式子进行缩放。便于计算。
Cost(min):J(w,b) = \frac{1}{m}\sum_{i}^{m}L(y^{(i)}|x^{(i)})
或者写作

J=-\frac{1}{m}\sum_{i=1}^{m}[y^{(i)}log\hat{y}^{(i)}+(1-y^{(i)})log(1-\hat{y}^{(i)})]

四.交叉熵函数的代码实现

在Python中,可以使用NumPy库或深度学习框架(如TensorFlow、PyTorch)来计算交叉熵损失函数。以下是使用NumPy计算二分类和多分类交叉熵损失函数的示例代码:

import numpy as np

# 二分类交叉熵损失函数
def binary_cross_entropy_loss(y_true, y_pred):
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# 多分类交叉熵损失函数
def categorical_cross_entropy_loss(y_true, y_pred):
    num_classes = y_true.shape[1]
    return -np.mean(np.sum(y_true * np.log(y_pred + 1e-9), axis=1))

# 示例用法
# 二分类
y_true_binary = np.array([[0], [1], [1], [0]])
y_pred_binary = np.array([[0.1], [0.9], [0.8], [0.4]])
loss_binary = binary_cross_entropy_loss(y_true_binary, y_pred_binary)
print("Binary Cross-Entropy Loss:", loss_binary)

# 多分类
y_true_categorical = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
y_pred_categorical = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
loss_categorical = categorical_cross_entropy_loss(y_true_categorical, y_pred_categorical)
print("Categorical Cross-Entropy Loss:", loss_categorical)

请注意,上述代码示例仅用于演示目的,实际使用中可能会使用深度学习框架提供的交叉熵损失函数,因为它们通常更加优化和稳定。例如,在TensorFlow中,可以使用tf.keras.losses.BinaryCrossentropy和tf.keras.losses.CategoricalCrossentropy类来计算二分类和多分类交叉熵损失函数。在PyTorch中,可以使用torch.nn.BCELoss和torch.nn.CrossEntropyLoss类来计算相应的损失函数。

代码来自于https://blog.csdn.net/qlkaicx/article/details/136100406

五.交叉熵函数优缺点

1.优点

在用梯度下降法做参数更新的时候,模型学习的速度取决于两个值:

1、学习率

2、偏导值;

其中,学习率是我们需要设置的超参数,所以我们重点关注偏导值。从上面的式子中,我们发现,偏导值的大小取决于 和 ,我们重点关注后者,后者的大小值反映了我们模型的错误程度,该值越大,说明模型效果越差,但是该值越大同时也会使得偏导值越大,从而模型学习速度更快。所以,使用逻辑函数得到概率,并结合交叉熵当损失函数时,在模型效果差的时候学习速度比较快,在模型效果好的时候学习速度变慢。

2.缺点

Deng在2019年提出了ArcFace Loss,并在论文里说了Softmax Loss的两个缺点:

  • 1、随着分类数目的增大,分类层的线性变化矩阵参数也随着增大;
  • 2、对于封闭集分类问题,学习到的特征是可分离的,但对于开放集人脸识别问题,所学特征却没有足够的区分性。对于人脸识别问题,首先人脸数目(对应分类数目)是很多的,而且会不断有新的人脸进来,不是一个封闭集分类问题。

另外,sigmoid(softmax)+cross-entropy loss 擅长于学习类间的信息,因为它采用了类间竞争机制,它只关心对于正确标签预测概率的准确性,忽略了其他非正确标签的差异,导致学习到的特征比较散。基于这个问题的优化有很多,比如对softmax进行改进,如L-Softmax、SM-Softmax、AM-Softmax等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/478231.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SW工具下没有URDF

解决方案,下载

阿里云安装宝塔后面板打不开

前言 按理来说装个宝塔面板应该很轻松的,我却装了2天,真挺恼火的,网上搜的教程基本上解决不掉我的问题点,问了阿里云和宝塔客服,弄了将近2天,才找出问题出在哪里,在此记录一下问题的处理。 服…

十分钟快速入门 Python

本文以 Eric Matthes 的《Python编程:从入门到实践》为基础,以有一定其他语言经验的程序员视角,对书中内容提炼总结,化繁为简,将这本书的精髓融合成一篇10分钟能读完的文章。 读完本篇文章后,可对 Python …

【学习】软件测试行业未来的发展趋势预测

近年来,随着中国数字经济的蓬勃发展,软件测试行业也迎来了新的春天。从早期的手工测试到自动化测试,再到持续集成和持续交付,中国的软件测试行业经历了快速的发展和变革。各行各业均对软件测试提出了更高的要求,尤其在…

按摩师C语言

题干出现“接或不接”,“最优”&#xff0c;仔细一想&#xff0c;该用动态规划了。 #include<stdio.h> int max(int a,int b) {if(a>b)return a;elsereturn b; } int massage(int* nums,int numSize) {if(numSize 0)return 0;else if(numSize 1)return nums[0];els…

js 输出负数的个数,和所有正整数的平均值。

首先输入要输入的整数个数n&#xff0c;然后输入n个整数。输出为n个整数中负数的个数&#xff0c;和所有正整数的平均值&#xff0c;结果保留一位小数。 0即不是正整数&#xff0c;也不是负数&#xff0c;不计入计算。如果没有正数&#xff0c;则平均值为0。输入 11 1 2 3 4 5 …

zabbix6.4监控mysql数据库

目录 一、前提二、配置mysql数据库模板三、配置监控的mysql主机 一、前提 已经搭建好zabbix-server 在需要监控的mysql服务器上安装zabbix-agent2 上述安装步骤参考我的上篇文章&#xff1a;通过docker容器安装zabbix6.4.12图文详解&#xff08;监控服务器docker容器&#xf…

2016年认证杯SPSSPRO杯数学建模A题(第二阶段)洗衣机全过程文档及程序

2016年认证杯SPSSPRO杯数学建模 A题 洗衣机 原题再现&#xff1a; 洗衣机是普及率极高的家用电器&#xff0c;它给人们的生活带来了很大的方便。家用洗衣机从工作方式来看&#xff0c;有波轮式、滚筒式、搅拌式等若干种类。在此基础上&#xff0c;各厂商也推出了多种具体方案…

二、Web3 学习(区块链)

区块链基础知识 一、基础知识1. 区块链可以做什么&#xff1f;2. 区块链的三个特点 二、区块链的类型概括1. PoW2. PoS3. 私有链和联盟链 三、智能合约1. 什么是智能合约2. 如何使用智能合约 四、困境1. 三难选择的基本要素2. 这真的是一个三难选择吗? 五、比特币1. 什么是比特…

2024年服装erp系统排名,都有哪些品牌?

随着数字经济时代的到来&#xff0c;传统的服装纺织行业正在经历深刻的变革。如何实现产业数字化升级&#xff0c;是众多服装纺织企业所面临的共同课题。当前&#xff0c;服装erp管理系统已经成为服装企业实现智能化转型的关键&#xff0c;通过将所有流程环节的数据进行统一的信…

C语言字符函数与字符串函数:编织文字的舞会之梦(下)

欢迎来到白刘的领域 Miracle_86.-CSDN博客 系列专栏 C语言知识 先赞后看&#xff0c;已成习惯 创作不易&#xff0c;多多支持&#xff01; 目录 七、strncpy的使用以及模拟实现 八、strncat的使用以及模拟实现 九、strncmp的使用以及模拟实现 十、strstr的使用以及模拟…

预防胜于治疗:六西格玛设计DFSS在新产品开发中的实践

六西格玛设计(Design for Six Sigma, DFSS)是一套旨在从一开始就设计出几乎完美的产品和服务的方法论。它不同于六西格玛改进过程&#xff0c;后者专注于改善现有的产品和流程。DFSS 的核心理念是通过彻底理解顾客需求和采用创新的解决方案来预防缺陷的发生&#xff0c;从而在产…

开源项目ChatGPT-Next-Web的容器化部署(二)-- jenkins CI构建并推送镜像

一、背景 接着上文已制作好了Dockerfile&#xff0c;接下来就是docker build/tag/push等一系列操作了。 不过在这之前&#xff0c;你还必须在jenkins等CI工具中&#xff0c;拉取源码&#xff0c;然后build构建应用。 因为本文的重点不是讲述jenkins ci工具&#xff0c;所以只…

HTTP系列之HTTP缓存 —— 强缓存和协商缓存

文章目录 HTTP缓存强缓存协商缓存状态码区别缓存优先级如何设置强缓存和协商缓存使用场景 HTTP缓存 HTTP缓存时利用HTTP响应头将所请求的资源在浏览器进行缓存&#xff0c;缓存方式分两种&#xff1a;强缓存和协商缓存。 浏览器缓存是指将之前请求过的资源在浏览器进行缓存&am…

山东省大数据局副局长禹金涛一行莅临聚合数据走访调研

3月19日&#xff0c;山东省大数据局党组成员、副局长禹金涛莅临聚合数据展开考察调研。山东省大数据局数据应用管理与安全处处长杨峰&#xff0c;副处长都海明参加调研&#xff0c;苏州市大数据局副局长汤晶陪同。聚合数据董事长左磊等人接待来访。 调研组一行参观了聚合数据展…

leetcode刷题日记-搜索插入位置

题目描述 解题思路 这种题目一眼看到二分查找的变体&#xff0c;稍微在二分查找的基础上变了一点点&#xff0c;那就是找插入的位置&#xff0c;根据二分查找的原理我们可以知道&#xff0c;当查找失败的时候&#xff0c;会退出while循环&#xff0c;跳出循环的时候左边指针所…

多数据源mybatisplus对sqlserver分页查询兼容

新增配置文件 package com.ruoyi.framework.config;import com.baomidou.mybatisplus.annotation.DbType; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor…

基于ssm的网络游戏公司官方平台设计与实现论文

摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对网络游戏信息管理混乱&#xff0c;出错率高&#xff0c;信息安全性差…

代码随想录 Day-24

动态规划理论基础 1、什么是动态规划 动态规划&#xff0c;英文&#xff1a;Dynamic Programming&#xff0c;简称DP&#xff0c;如果某一问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。 所以动态规划中每一个状态一定是由上一个状态推导出来的&#xff0c;这一点…

C#中解决字符串在编译后无法修改的情况

文章目录 一、配置文件二、使用方式对于.NET Framework应用程序&#xff08;使用app.config&#xff09;对于.NET Core和.NET 5/6应用程序&#xff08;使用appsettings.json&#xff09; 三、应用实例 一、配置文件 在C#等编程语言中&#xff0c;硬编码&#xff08;直接在代码…