PyTorch中的交叉熵函数 CrossEntropyLoss的计算过程

CrossEntropyLoss() 函数联合调用了 nn.LogSoftmax() 和 nn.NLLLoss()。

关于交叉熵函数的公式详见:
交叉熵损失函数原理详解
CrossEntropyLoss() 函数的计算过程可以拆解为如下四个步骤:
1、对输出的结果进行softmax操作,因为softmax操作可以将所有输入值都归为[0,1]之间,且所有值之和为1,符合概率分布的特性。

2、对softmax结果进行log运算,求出都是小于0的值
3、对真实概率值进行one-hot编码
4、利用下面的公式求出最终的loss值
C r o s s E n t r o p y L o s s ( x ) = − ∑ i = 1 n O n e H o t ( t a r g e t i ) ∗ l o g s o f t m a x ( i n p u t ) i CrossEntropyLoss(x) = - \sum_{i=1}^{n} OneHot(target_i) * log^{softmax(input)_i} CrossEntropyLoss(x)=i=1nOneHot(targeti)logsoftmax(input)i
不难看出NLLloss+log+softmax就是CrossEntropyLoss(softmax版的交叉熵损失函数),而其中的NLLloss就是在做交叉熵损失函数的最后一步:预测结果的取负求和。

一段代码带你上高速:

原生手写CrossEntropyLoss()函数与PyTorch里的CrossEntropyLoss():

import torch
import torch.nn as nn
import torch.nn.functional as F


def softmax(x):
    r"""
    这里的公式是:

    .. math::
        softmax(x_i) = \frac{\exp^{x_i}}{\sum_{i=0}^N (\exp^{x_i})}
    :param x:
    :return:
    """

    # return torch.exp(x) / torch.unsqueeze(torch.sum(torch.exp(x), dim=1), dim=1)
    t = torch.zeros_like(x)
    sum_array = []
    for idx in range(len(x)):
        t[idx] = torch.exp(x[idx]) / torch.sum(torch.exp(x[idx]))
        sum_array.append(torch.sum(torch.exp(x[idx])))
    print(sum_array)
    print(torch.unsqueeze(torch.sum(torch.exp(x), dim=1), dim=1))
    return t

def manual_cross_entropy_loss(inputs, target):
    one_hot_target = F.one_hot(target, 10)
    print(one_hot_target)
    # print(softmax_)
    # print(F.softmax(inputs, dim=-1))
    # cross_entropy_loss = -torch.sum(one_hot_target * torch.log(F.softmax(inputs, dim=-1))) / len(inputs)
    cross_entropy_loss = -torch.sum(one_hot_target * torch.log(softmax(inputs))) / len(inputs)
    print(cross_entropy_loss)


def call_cross_entropy_loss(inputs, target):
    """
    调用真实的CrossEntropyLoss()
    :return:
    """
    loss = nn.CrossEntropyLoss()
    loss = loss(inputs, target)
    print(loss)


if __name__ == '__main__':
    torch.random.manual_seed(0)
    inputs = torch.randn(3, 10)
    # 获取inputs的最大值的索引
    target = torch.argmax(inputs, dim=-1)
    print(target)
    manual_cross_entropy_loss(inputs=inputs, target=target)
    call_cross_entropy_loss(inputs=inputs, target=target)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/15379.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于虚拟同步发电机的光伏混合储能并网系统MATLAB仿真

资源地址: 主要模块: 光伏电池模型(按照数学模型搭建)、蓄电池储能模块、超级电容储能模块、双向DC/DC模块、LC滤波器、逆变器VSG控制模块电压电流双环控制模块、光伏MPPT控制模块、储能系统充放电控制模块。 使用MATLAB2021b及…

09-Node.js—express框架

目录 1、express 介绍2、express 使用2.1 express 下载2.2 express 初体验 3、express 路由3.1 什么是路由3.2 路由的使用3.2.1使用Ajax发送一次post请求 3.3 获取请求参数 3.4 获取路由参数3.5 路由参数练习 4、express 响应设置5、express 中间件5.1 什么是中间件5.2 中间件的…

Linux 利用网络同步时间

yum -y install ntp ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime ntpdate ntp1.aliyun.com 创建加入crontab echo "*/20 * * * * /usr/sbin/ntpdate -u ntp.api.bz >/dev/null &" >> /var/spool/cron/rootntp常用服务器 中国国家授…

活动目录(Active Directory)安全审计

延迟响应变化的影响可能会使原本应该微不足道的颠簸滚雪球变成无法弥补的损害。这在 Windows Active Directory 环境中更为重要,因为这种延迟造成的损害可能会使组织损失数百万美元!在这种情况下,需要一个警惕的警报系统,该系统可…

(下)苹果有开源,但又怎样呢?

一开始,因为 MacOS X ,苹果与 FreeBSD 过往从密,不仅挖来 FreeBSD 创始人 Jordan Hubbard,更是在此基础上开源了 Darwin。但是,苹果并没有给予 Darwin 太多关注,作为苹果的首个开源项目,它算不上…

【多线程】线程安全问题

1. 一段线程不安全的代码 我们先来看一段代码&#xff1a; public class ThreadDemo {public static int count 0;public static void main(String[] args) {for (int i 0; i < 10_0000; i) {count;}System.out.println("count " count);} } // 打印结果&…

MySql中执行计划如何来的——Optimizer Trace | 京东云技术团队

作者&#xff1a;京东物流 籍磊 1.前言 当谈到MySQL的执行计划时&#xff0c;会有很多同学想&#xff1a;“我就觉得使用其他的执行方案比EXPLAIN语句输出的方案强&#xff0c;凭什么优化器做的决定与我得不一样&#xff1f;”。这个问题在MySQL 5.6之前或许自己很难解决&…

滑动奇异频谱分析:数据驱动的非平稳信号分解工具(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

每天一道算法练习题--Day15 第一章 --算法专题 --- -----------二叉树的遍历

概述 二叉树作为一个基础的数据结构&#xff0c;遍历算法作为一个基础的算法&#xff0c;两者结合当然是经典的组合了。很多题目都会有 ta 的身影&#xff0c;有直接问二叉树的遍历的&#xff0c;有间接问的。比如要你找到树中满足条件的节点&#xff0c;就是间接考察树的遍历…

【Java校招面试】基础知识(一)——Java常用类库

目录 前言一、编程时常用的Java类库1. 异常捕获模块(try-catch-finally, Error, Exception)2. boolean / short / int / long / float / double / char / byte及其对应的引用类型 二、面试时常考的Java类库1. 一切类型的父类Object及其equals / hashCode / toString方法2. 常用…

scratch拆礼物游戏 中国电子学会图形化编程 少儿编程 scratch编程等级考试三级真题和答案解析2023年3月

目录 scratch拆礼物游戏 一、题目要求 1、准备工作 2、功能实现 二、案例分析 <

perf工具报错,升级ubuntu子系统linux内核

文章目录 1&#xff0c;运行perf工具报错1.1&#xff0c;可能的原因有&#xff1a; 2&#xff0c;我选择升级linux内核&#xff0c;和当前perf工具版本保持一致2.1&#xff0c;下载6.2.12内核源码2.2&#xff0c;安装6.2.12内核 1&#xff0c;运行perf工具报错 1.1&#xff0c;…

网络安全基础入门学习路线

在大多数的思维里总觉得学习网络安全得先收集资料、学习编程、学习计算机基础&#xff0c;这样不是不可以&#xff0c;但是这样学效率太低了&#xff01; 你要知道网络安全是一门技术&#xff0c;任何技术的学习一定是以实践为主的。也就是说很多的理论知识其实是可以在实践中…

十、ElasticSearch 实战 - 源码运行

一、概述 想深入理解 Elasticsearch&#xff0c;了解其报错机制&#xff0c;并有针对性的调整参数&#xff0c;阅读其源码是很有必要的。此外&#xff0c;了解优秀开源项目的代码架构&#xff0c;能够提高个人的代码架构能力 阅读 Elasticsearch 源码的第一步是搭建调试环境&…

Qt5下Qxlsx模块安装及使用

Qt5下Qxlsx模块安装及使用 一、Qt5下Qxlsx模块安装及使用1. 未安装Qxlsx的程序效果2. 安装Perl&#xff08;编译Qxlsx源码用&#xff09;2.1 下载 ActivePerl 5.282.2 安装 ActivePerl 5.28 3. 下载并编译Qxlsx源码3.1 下载Qxlsx源码3.2 编译Qxlsx源码 4. 将编译好的文件复制到…

26- OCR 基于PP-OCRv3的液晶屏读数识别

要点&#xff1a; 液晶屏识别示例github 地址 1. 简介 本项目基于PaddleOCR开源套件&#xff0c;以PP-OCRv3检测和识别模型为基础&#xff0c;针对液晶屏读数识别场景进行优化。主要是针对各种仪表进行识别&#xff1a; 2 安装环境 安装Git&#xff1a;Git 详细安装教程 # 首…

Git基础

文章目录 1. Git基础1.1 版本管理1.1.1 什么是版本管理1.1.2 人为维护文档版本的问题 1.2 Git 是什么1.3 Git 安装1.4 Git基本工作流程1.5 Git 的使用1.5.1 Git 使用前配置1.5.2 提交步骤1.5.3 撤销 2. Git进阶2.1 分支2.1.1 分支细分2.1.2 分支命令 2.2 暂时保存更改 1. Git基…

鸿蒙Hi3861学习三-第一个实例程序Hello_world

一、简介 前两章介绍了环境搭建、烧录和编译。这一节&#xff0c;来介绍实现第一个经典代码“hello world”。 先介绍小熊派的目录结构&#xff0c;该目录结构延续了OpenHarmony官方目录结构。 二、实操 1.搭建代码架构 1).新建项目文件夹hello_world cd bearpi-hm_nano/appli…

【VM服务管家】VM4.0平台SDK_2.3 控件嵌入类

目录 2.3.1 渲染结果&#xff1a;通过绑定流程或模块获取渲染结果的方法2.3.2 渲染控件&#xff1a;渲染控件加载本地图像的方法2.3.3 渲染控件&#xff1a;渲染控件上自定义图形的方法2.3.4 参数控件&#xff1a;参数配置控件绑定模块的方法2.3.5 控件颜色&#xff1a;控件颜色…

Java新提案,最终还是靠近C#了

Java是一门非常优秀的编程语言&#xff0c;特别是生态繁荣&#xff0c;成熟的轮子很多&#xff0c;各种解决方案都有&#xff0c;要开发一个项目&#xff0c;只需把轮子组装&#xff0c;并根据自己的项目&#xff0c;进行自定义修改&#xff0c;可以极大地提升开发效率。 曾经…