softmax函数与交叉熵损失详解

文章目录

  • 一、softmax函数
    • 1.1 引入指数形式的优点
    • 1.2 引入指数形式的缺点
  • 二、交叉熵损失函数
    • 2.1 交叉熵损失函数
    • 2.2 softmax与交叉熵损失
  • 参考资料

一、softmax函数

softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!

假设我们有一个数组,Z,Zi表示Z中的第i个元素,那么这个元素的softmax值就是:
在这里插入图片描述
更形象的如下图表示:
在这里插入图片描述
softmax直白来说就是将原来输出是3,1,-3通过softmax函数一作用,就映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),那么我们就可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率最大(也就是值对应最大的)结点,作为我们的预测目标!

1.1 引入指数形式的优点

从上述 y = e x y = e^{x} y=ex函数图像指数函数中可以看出,曲线呈现递增趋势,最重要的是斜率逐渐增大,也就是说在x轴上一个很小的变化,可以导致y轴上很大的变化。这种函数曲线能够将输出的数值拉开距离。

假设拥有三个输出节点的输出值为为[2, 3, 5]。我们来尝试 不使用指数函数使用指数函数的Softmax函数 计算会产生什么区别。

import tensorflow as tf

print(tf.__version__) # 2.0.0
a = tf.constant([2, 3, 5], dtype = tf.float32)

b1 = a / tf.reduce_sum(a) # 不使用指数
print(b1) # tf.Tensor([0.2 0.3 0.5], shape=(3,), dtype=float32)

b2 = tf.nn.softmax(a) # 使用指数的Softmax
print(b2) # tf.Tensor([0.04201007 0.11419519 0.8437947 ], shape=(3,), dtype=float32)

结果还是挺明显的,经过使用指数形式的Softmax函数能够将差距大的数值距离拉的更大。

另外,在深度学习中通常使用反向传播求解梯度进而使用梯度下降进行参数更新的过程,使用指数函数在求导的时候比较方便。

1.2 引入指数形式的缺点

指数函数的曲线斜率逐渐增大虽然能够将输出值拉开距离,但是也带来了缺点,当 z i z_{i} zi 值非常大的话,计算得到的数值也会变的非常大,数值可能会溢出。

import numpy as np

scores = np.array([123, 456, 789])
softmax = np.exp(scores) / np.sum(np.exp(scores))
print(softmax) # [ 0.  0. nan]

当然针对数值溢出有其对应的优化方法,将每一个输出值减去输出值中最大的值。
在这里插入图片描述

import numpy as np

scores = np.array([123, 456, 789])
scores -= np.max(scores)
p = np.exp(scores) / np.sum(np.exp(scores))
print(p) # [5.75274406e-290 2.39848787e-145 1.00000000e+000]

二、交叉熵损失函数

2.1 交叉熵损失函数

交叉熵函数为在处理分类问题中常用的一种损失函数,用于描述模型预测值与真实值的差距大小,其具体公式为:

在这里插入图片描述

其中,p为真实值,q为预测值。

这里,关于交叉熵的具体介绍参见博客:损失函数——交叉熵损失函数(CrossEntropy Loss)

交叉熵损失函数在Python中的具体实现如下:

# tensorflow version
loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
 
# numpy version
loss = np.mean(-np.sum(y_*np.log(y), axis=1))
 
# pytorch version
entroy=nn.CrossEntropyLoss()

2.2 softmax与交叉熵损失

当使用Softmax函数作为输出节点的激活函数的时候,一般使用交叉熵作为损失函数。

由于Softmax函数的数值计算过程中,很容易因为输出节点的输出值比较大而发生数值溢出的现象,在计算交叉熵的时候也可能会出现数值溢出的问题。为了数值计算的稳定性,TensorFlow提供了一个统一的接口,将Softmax与交叉熵损失函数同时实现,同时也处理了数值不稳定的异常,使用TensorFlow深度学习框架的时候,一般推荐使用这个统一的接口,避免分开使用Softmax函数与交叉熵损失函数。

TensorFlow提供的统一函数式接口为:

import tensorflow as tf

print(tf.__version__) # 2.0.0
tf.keras.losses.categorical_crossentropy(y_true, y_pred, from_logits = False)

中y_true代表了One-hot编码后的真实标签,y_pred表示网络的实际预测值:

  • 当from_logits设置为True时,y_pred表示未经Softmax函数的输出值;
  • 当from_logits设置为False时,y_pred表示为经过Softmax函数后的输出值;

为了在计算Softmax函数时候数值的稳定,一般将from_logits设置为True,此时 tf.keras.losses.categorical_crossentropy 将在内部进行Softmax的计算,所以在不需要在输出节点上添加Softmax激活函数。

参考资料

  • Softmax函数详解与推导
  • 一文详解Softmax函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/622153.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

simulink-仿真以及PID参数整定/PID tuner 的使用流程

控制器搭建与参数整定 搭建一个前馈PID控制器控制系统PID tuner使用 一个懂点控制但不多的小白,因为需要利用simulink仿真,所以不得不学习一些仿真的知识,这篇文章适合和我一样的新手入门,有理解错误的地方希望大手们能够指出来共…

背完这些软件测试核心面试题,offer轻松拿捏了!

你赞同过 软件测试和开发 相关内容 01、您所熟悉的测试用例设计方法都有哪些?请分别以具体的例子来说明这些方法在测试用例设计工作中的应用。 答:有黑盒和白盒两种测试种类,黑盒有等价类划分法,边界分析法,因果图法和…

spacy微调BERT-NER模型

数据准备 加载数据集 from tqdm.notebook import tqdm import osdataset [] with open(train_file, r) as file:for line in tqdm(file.readlines()):data json.loads(line.strip())dataset.append(data)你可以按照 CLUENER 的格式准备训练数据, 例如&#xff1…

彻底搞定找不到msvcp100.dll,无法继续执行代码的问题

当您在使用电脑过程中遇到程序运行异常,提示“缺失msvcp100.dll文件”时,不必过于焦虑,此问题可通过一系列简单步骤得到有效解决。MSVCP100.dll是Microsoft Visual C库的一部分,主要用于支持某些应用程序运行所需的特定功能。如果…

ohmyzsh的安装过程中失败拒绝连接问题的解决

1.打开官网Oh My Zsh - a delightful & open source framework for Zsh 在官网能看到下面的界面 有这两种自动安装的方式 个人本次选择的是: wget https://raw.github.com/ohmyzsh/ohmyzsh/master/tools/install.sh -O - 1.打开终端输入安装的指令 sh -c "$(wget…

(done) 什么是马尔可夫链?Markov Chain

参考视频:https://www.bilibili.com/video/BV1ko4y1P7Zv/?spm_id_from333.337.search-card.all.click&vd_source7a1a0bc74158c6993c7355c5490fc600 如下图所示,马尔可夫链条实际上就是 “状态机”,只不过状态机里不同状态之间的边上是 “…

向银行家应用程序添加日期

● 首先我们将下面图片上的时间更换成现在的时间 const now new Date(); const day now.getDate(); const month now.getMonth() 1; const year now.getFullYear(); const hour now.getHours(); const min now.getMinutes();labelDate.textContent ${day}/${month}/$…

数据降维-主成分分析PCA

1.背景: 在以前计算能力还很弱的年代,我们要分析经济数据是一件很困难的事情,所以我们需要对指标特征进行降维; 2.数据降维的意义: 一般我们降维的特征数据彼此之间是存在一定的相关性的, 二维降至一维…

车载电子电器架构 —— Vector对于车载以太网的解决方案(协议栈)

车载电子电器架构 —— Vector对于车载以太网的解决方案(协议栈) 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你…

252 基于MATLAB的自适应差分阈值法检测心电信号的QRS波

基于MATLAB的自适应差分阈值法检测心电信号的QRS波,QRS波群反映左、右心室除极电位和时间的变化,第一个向下的波为Q波,向上的波为R波,接着向下的波是S波。通过GUI进行数据处理,展示心率和QRS。程序已调通,可…

COMSOL粗略估算计算时间

COMSOL粗略估算模型计算时间 针对反复修改调试的有限元模型,需要大致估算有限元模型的计算时间。经查阅,模型求解的自由度数和求解时间密切相关。 测试条件 测试模型为声-固耦合模型,电脑内存32G,CPU-i7-10700K,核显…

【GD32】03 - EXTI外部中断

EXTI EXTI,全称External Interrupt/Event Controller,即外部中断/事件控制器,是微控制器中的一个重要组成部分。它主要用于管理来自外部设备的中断和事件请求。以下是关于EXTI的详细介绍: 功能概述: EXTI管理了控制器的…

OpenAI 人工智能搜索产品即将推出,文本和图像都支持?!

OpenAI 人工智能搜索产品即将推出 OpenAI 计划于下周一(5 月 13 日)正式公布其人工智能搜索产品,不过报道中强调具体公告日期可能发生变化。OpenAI 拒绝对路透社的报道置评。外媒 The Information 在今年 2 月的报道中指出,OpenAI 一直在秘密开发其自家网络搜索服务,并将获…

C++初阶学习第七弹——探索STL奥秘(二)——string的模拟实现

标准库中的string:C初阶学习第六弹——string(1)——标准库中的string类-CSDN博客 前言: 在前面我们已经学习了如何使用标准库中的string类,但作为一个合格的程序员,我们不仅要会用,还要知道如…

Midjourney Imagine API 申请及使用

Midjourney Imagine API 申请及使用 申请流程 要使用 Midjourney Imagine API,首先可以到 Midjourney Imagine API 页面点击「Acquire」按钮,获取请求所需要的凭证: 如果你尚未登录或注册,会自动跳转到登录页面邀请您来注册和登…

[嵌入式系统-77]:RT-Thread-快速上手:嵌入式系统调测工具大全

目录 1. JTAG 下载调试器: 2. J-Link 仿真器: 3. ICE(In-Circuit Emulator): 4. ROM监视器(ROM Monitor): 5. 终端仿真工具: 6. 总线抓取工具: 7. 静态…

【教学类-55-02】20240512图层顺序挑战(四格长条纸加黑色边框、4*4、7张 、43200张去掉非7色有23040张,去掉重复样式有几种?)

作品展示 背景需求: 之前的代码吗存在几个问题,最大的问题是不能生成“”长条黑边框”” 【教学类-55-01】20240511图层顺序挑战(四格长条纸)(4*4)和“手工纸自制参考图”-CSDN博客文章浏览阅读485次&…

2024第八季完美童模 【星光】品牌赛区 【直通】赛 完美收官

2024年5月1日,春风徐徐的【星光品牌赛区】热闹非凡,备受瞩目的第八季完美童模【星光品牌赛区】赛区【直通赛】在这一天正式拉开了帷幕。比赛现场,童模们身着华服,在舞台上演绎了“亚特兰蒂斯”的时尚主题赛。 参赛选手们身着带有海…

【优先级队列】Leetcode 最后一块石头的重量

题目讲解 1046. 最后一块石头的重量 算法讲解 根据题目的意思,为了寻找到本次数组中的最大的两个值,我们需要使用一个数据结构:堆,使用大堆,每一次出两个数据,这两个数据就是当前数组中的两个最大值&…

拉链表实现过程+案例

第一种 1.从ODS层获取增量数据(上一天新增和更新的数据) 2.拿着DWD原始拉链表数据 left join 增量数据 ,修改原始拉链中历史数据的结束时间 3.拿着left join 的结果集 union all 增量数据 4.把最新的拉链数据优先保存到DWD对应的临时表中 5.使用insertselect 方式把临时表中…