损失函数——交叉熵损失(Cross-entropy loss)

交叉熵损失(Cross-entropy loss)是深度学习中常用的一种损失函数,通常用于分类问题。它衡量了模型预测结果与实际结果之间的差距,是优化模型参数的关键指标之一。以下是交叉熵损失的详细介绍。

假设我们有一个分类问题,需要将输入数据x分为C个不同的类别。对于每个输入数据x,我们定义一个C维的向量y^​,其中y^​i​表示x属于第i个类别的概率。我们的目标是使得y^​尽可能接近真实的标签y的概率分布。

假设真实标签y是一个C维的向量,其中只有一个元素为1,其余元素为0,表示x属于第k个类别。那么,我们可以使用交叉熵损失来衡量模型预测结果和真实标签之间的差距。交叉熵损失的公式如下: 

                                                         L\left ( x,y \right ) = -\sum _{i=1}^{C}x_{_{i}} log y_{i}

其中,xi​表示真实标签的第i个元素,y​i​表示模型预测x属于第i个类别的概率。

交叉熵损失的本质是衡量两个概率分布之间的距离。其中一个概率分布是真实标签y的分布,另一个是模型预测的概率分布y^​。对于每个类别i,yi​表示真实标签x属于第i个类别的概率,y^​i​表示模型预测x属于第i个类别的概率。当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

交叉熵损失是一种凸函数,通常使用梯度下降等优化算法来最小化它。在深度学习中,交叉熵损失是常见的分类损失函数之一,广泛应用于图像分类、语音识别等任务中。

在PyTorch中,交叉熵损失可以使用torch.nn.CrossEntropyLoss实现。该函数将输入数据视为模型输出的概率分布,将目标标签视为类别索引,并计算这些概率与实际标签之间的交叉熵损失。

以下是一个示例代码片段,说明如何使用torch.nn.CrossEntropyLoss计算交叉熵损失:

import torch

# 创建模型输出和目标标签
output = torch.randn(10, 5)  # 10个样本,5个类别
target = torch.tensor([1, 0, 4, 2, 3, 1, 0, 4, 2, 3])  # 目标类别索引

# 创建交叉熵损失函数
criterion = torch.nn.CrossEntropyLoss()

# 计算损失
loss = criterion(output, target)

print(loss)

在训练中,你可以使用torch.nn.CrossEntropyLoss作为损失函数来优化模型。假设你已经有一个PyTorch模型和训练数据集,以下是一个简单的训练循环示例,它使用交叉熵损失函数来训练模型:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
    
    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

model = MyModel()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 训练循环
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

在这个例子中,MyModel是一个简单的两层全连接神经网络。训练循环通过从数据集中加载数据批次,使用optimizer.zero_grad()清空梯度,计算模型输出和损失,使用loss.backward()计算梯度并使用optimizer.step()更新模型参数。每个epoch结束时,模型将在测试集上进行评估,以检查其在新数据上的泛化能力。

在这个训练循环中,我们使用nn.CrossEntropyLoss()作为损失函数,并传递模型输出和目标标签作为参数。loss.backward()计算梯度并将梯度传播回模型中的参数,从而使优化器能够更新这些参数以最小化损失。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/24231.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

L-shape 方法

L-shape 方法是求解两阶段随机规划的一种常用方法,基本思想是利用切平面将第二阶段的反馈函数线性化,在构造切平面条件时有点类似 bender’s 方法。 注:这个图形中黑实线 Q ( x ) \mathcal{Q}(x) Q(x) 就是下面模型中的 L ( x ) \mathscr{L…

为什么要用线程池?

线程池是一种管理和复用线程资源的机制,它由一个线程池管理器和一组工作线程组成。线程池管理器负责创建和销毁线程池,以及管理线程池中的工作线程。工作线程则负责执行具体的任务。 线程池的主要作用是管理和复用线程资源,避免了线程的频繁…

IOS开发指南之自定义TableViewCell使用

演示效果: 1.自定义TableViewCell创建 File->new->File... 在iOS模板中选择Empty来创建一个空的XIB文件,然后点击下一步 输入XIB文件名Cell,然后点击Create创建 创建XIB文件成功后如下: 同时按钮Shift+command+L弹出库,然后输入 table筛选,选择Table View Cell 拖到下…

LSP:里氏替换原则

系列文章目录 C高性能优化编程系列 深入理解设计原则系列 深入理解设计模式系列 高级C并发线程编程 LSP:里氏替换原则 系列文章目录1、里氏替换原则的定义和解读2、里氏替换原则可以用于哪些设计模式中?3、如何使用里氏替换原则来降低代码耦合度&#…

ChatGPT原理简介

承接上文GPT前2代版本简介 GPT3的基本思想 GPT2没有引起多大轰动,真正改变NLP格局的是第三代版本。 GPT3训练的数据包罗万象,上通天文下知地理,所以它会胡说八道,会说的贼离谱,比如让你穿越到唐代跟李白对诗,不在一…

windows里怎么杀死一个进程?

我们可以使用 taskkill 命令,可以使用该工具按照进程 ID (PID) 或映像名称终止任务。 显示帮助消息: taskkill /?参数列表: /S:system:指定要连接的远程系统。/U:[domain\]user:指定应该在哪…

第五篇:强化学习基础之马尔科夫决策过程

你好,我是zhenguo(郭震) 今天总结强化学习第五篇:马尔科夫决策过程 基础 马尔科夫决策过程(MDP)是强化学习的基础之一。下面统一称为:MDP MDP提供了描述序贯决策问题的数学框架。 它将决策问题建模为: 状态…

7种PCB走线方式

01电源布局布线相关 数字电路很多时候需要的电流是不连续的,所以对一些高速器件就会产生浪涌电流。 如果电源走线很长,则由于浪涌电流的存在进而会导致高频噪声,而此高频噪声会引入到其他信号中去。 而在高速电路中必然会存在寄生电感和寄…

STM32G4 比较器COMPx(寄存器开发)

目录 1. 特性1.1 框图1.2 比较器输入信号SEL1.3 比较器滞回选择HYST1.4 比较器的输出1.5 LOCK机制 2. 编程2.1 初始化步骤2.2 举例 STM内部的比较器是模拟量的比较器,其与APB2时钟同步,在RCC时钟控制器中没有COMx时钟使能标志位,其时钟的使能…

RTOS专栏(一) —— rt-thread简单介绍和qemu使用

本期主题: 简单介绍rt-thread介绍qemu和rt-thread怎么配合使用qemu的简单例子 rt-thread & qemu 1.rt-thread介绍2.qemu介绍3.搭建rt-thread和qemu开发环境4.简单例子 1.rt-thread介绍 RT-Thread 是一款完全由国内团队开发维护的嵌入式实时操作系统&#xff0…

JAVA POI excel 添加下拉字典的方式与案例 以及图文详解及个人理解

场景 原有的Excel 某一个 sheet 页中某些列需要添加指定的字典下拉,而这些字典的值又是确认的。 有两种思路: 一、如果给定的下拉字典值是确定的而且关联原有列的位置也不会变,那么这些数据可以固定写死在代码中,也是最简单的一…

北邮22信通:利用BF算法解决实际问题:题目实战(超详解)设计函数 char *locatesubstr(char *str1,char *str2)

北邮22信通一枚~ 跟随课程进度每周更新数据结构与算法的代码和文章 持续关注作者 解锁更多邮苑信通专属代码~ 获取更多文章 请访问专栏~ 北邮22信通_青山如墨雨如画的博客-CSDN博客 目录 题干描述 解析 1.string库函数 2.使用KMP算法思想 注解1 注解2 注解3 题…

学懂缓存雪崩,缓存击穿,缓存穿透仅需一篇,基于Redis讲解

在了解缓存雪崩、击穿、穿透这三个问题前,我们需要知道为什么我们需要缓存。在了解这三个问题后,我们也必须知道使用Redis时,如何解决这些问题。 所以我将按照"为什么我们需要缓存"、"什么是缓存雪崩、击穿、穿透"、&qu…

​字创未来 方正字库第十二届“方正奖”设计大赛正式来袭

传承汉字文化精髓,方正字库在字体行业不断探索深耕。方正字库一直致力于弘扬中华汉字文化,不断促进行业字体设计创新发展。于2001年在行业最艰难的时候,怀揣着对字体设计未来的美好向往,首届“北大方正奖”印刷字体设计大赛&#…

家政服务预约APP的系统设计与实现

摘 要:针对家政行业蓬勃发展,老套的家政服务方式已经跟不上互联网时代的步伐这个问题。基于Android移动平台的分析和设计过程、C/S模式、Eclipse平台,采用Java语言进行开发设计,设计了基于MVC架构的实现方案。安卓客户端与服务器…

Flume系列:Flume通道拓扑结构

目录 Apache Hadoop生态-目录汇总-持续更新 1: 基础架构 2:简单串联 3:复制(Replicating)和多路复用(Multiplexing) 4:负载均衡和故障转移 5:聚合 Apache Hadoop生态-目录汇总-持续更新 系统环境:centos7 Java环境…

IDEA 创建 Springmvc 项目

一、概述 在18年的时候就开始接触 SpringBoot ,然后就一直在使用它。众所周知 SpringBoot 内嵌 Tomcat,后续再也没有单独新建过Web 项目。作为IDEA 的用户,总想要用它来建一个Web 项目自己跑一跑,但建项目不是我最终目的~~ &…

好用的自动化框架-Allure

概述 报告主要包含总览、类别、测试套件、图表、时间刻度、功能、包等7大部分,支持自定义诸多信息,包括附件添加、缺陷链接、案例链接、测试步骤、Epic、Feature、Story、Title、案例级别等,相当强大。 allure与pytest的结合使用可以呈现完…

ProtoBuf 语法(一)

系列文章 ProtoBuf 语法(二) ProtoBuf 语法(三) 文章目录 前言一、字段规则二、消息类型的定义与使用2.1 定义2.2 使用 三、enum 类型3.1 定义规则3.2 注意事项 四、any 类型4.1 类型说明4.2 类型使用 五、oneof 类型六、map 类型…

【Vue】二:Vue核心处理---计算属性 监视属性

文章目录 1.计算属性示例2. 监听属性3.补充 1.计算属性示例 实际上计算属性与methods中定义方法基本上没有什么区别,只是计算属性基于响应式依赖缓存,只要数据没有发生改变,计算属性从缓存中取值,只有当数据发送改变,才…