nn.KLDivLoss,nn.CrossEntropyLoss,nn.MSELoss,Focal_Loss

  1. KL loss:https://blog.csdn.net/qq_50001789/article/details/128974654

https://pytorch.org/docs/stable/nn.html

1. nn.L1Loss

1.1 公式

L1Loss: 计算预测 x和 目标y之间的平均绝对值误差MAE, 即L1损失
l o s s = 1 n ∑ i = 1 , . . . n ∣ x i − y i ∣ loss=\frac{1}{n} \sum_{i=1,...n}|x_i-y_i| loss=n1i=1,...nxiyi

1.2 语法

torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
  • size_averagereduce已经被弃用,具体功能可由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sum。none:损失不做任何处理,直接输出一个数组mean:将得到的损失求平均值再输出,会输出一个数sum:将得到的损失求和再输出,会输出一个数

注意:输入的x y 可以是任意维数的数组,但是二者形状必须一致

1.3 应用案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torch

x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
L1_none = nn.L1Loss(reduction='none')
L1_mean = nn.L1Loss(reduction='mean')
L1_sum = nn.L1Loss(reduction='sum')
out_none = L1_none(x, y)
out_mean = L1_mean(x, y)
out_sum = L1_sum(x, y)
print(x)
print(y)
print(out_none)
print(out_mean)
print(out_sum)

在这里插入图片描述

2.nn.MSELoss

2.1 公式

MSELoss也叫L2 loss, 即计算预测x和目标y的平方误差损失。MSELoss的计算公式如下:

l o s s = 1 n ∑ i = 1 , . . n ( x i − y i ) 2 loss=\frac{1}{n} \sum_{i=1,..n}(x_i-y_i)^2 loss=n1i=1,..n(xiyi)2

:输入x 与y 可以是任意维数的数组,但是二者shape大小一致

2.2 语法

torch.nn.MSELoss(reduction = 'mean')

其中:

  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数

2.3 应用案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torch

x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
mse_none = nn.MSELoss(reduction='none')
mse_mean = nn.MSELoss(reduction='mean')
mse_sum = nn.MSELoss(reduction='sum')
out_none = mse_none(x, y)
out_mean = mse_mean(x, y)
out_sum = mse_sum(x, y)


print('x:',x)
print('y:',y)
print("out_none:",out_none)
print("out_mean:",out_mean)
print("out_sum:",out_sum)

在这里插入图片描述

3 nn.SmoothL1Loss

3.1 公式

SmoothL1Loss是结合L1 lossL2 loss改进的,其数学定义如下:

在这里插入图片描述
在这里插入图片描述
如果绝对值误差低于 β \beta β, 则使用 L 2 L2 L2损失,,否则使用绝对值损失 L 1 L1 L1, ,此损失对异常值的敏感性低于 L 2 L2 L2 ,即当 x x x y y y相差过大时,该损失数值要小于 L 2 L2 L2损失,在某些情况下该损失可以防止梯度爆炸。

3.2 语法

torch.nn.SmoothL1Loss( reduction='mean', beta=1.0)
  • reduction:指定损失输出的形式,有三种选择:none|mean|sum。none:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数
  • beta:损失在 L 1 L1 L1 L 2 L2 L2之间切换的阈值,默认beta=1.0

3.3 应用案例

import torch.nn as nn
import torch

# reduction设为none便于逐元素对比损失值
loss_none = nn.SmoothL1Loss(reduction='none')
loss_sum = nn.SmoothL1Loss(reduction='sum')
loss_mean = nn.SmoothL1Loss(reduction='mean')
x = torch.randn(10)
y = torch.randn(10)
out_none = loss_none(x, y)
out_sum = loss_sum(x, y)
out_mean = loss_mean(x, y)
print('x:',x)
print('y:',y)
print("out_none:",out_none)
print("out_mean:",out_mean)
print("out_sum:",out_sum)

在这里插入图片描述

4. nn.CrossEntropyLoss

nn.CrossEntropyLoss 在pytorch中主要用于多分类问题的损失计算。

4.1 交叉熵定义

交叉熵主要是用来判定实际的输出与期望的输出的接近程度,也就是交叉熵的值越小,两个概率分布就越接近。
假设概率分布p为期望输出(target),概率分布q为实际输出(pred), H ( p , q ) H(p,q) H(p,q)为交叉熵, 则:
在这里插入图片描述
在这里插入图片描述

Pytorch中的CrossEntropyLoss()函数

Pytorch中计算的交叉熵并不是采用交叉熵定义的公式得到的,其中q为预测值,p为target值:
在这里插入图片描述
而是交叉熵的另外一种方式计算得到的:
在这里插入图片描述
PytorchCrossEntropyLoss()函数的主要是将log_softmax NLLLoss最小化负对数似然函数)合并到一块得到的结果

CrossEntropyLoss()=log_softmax() + NLLLoss() 

在这里插入图片描述

  • (1) 首先对预测值pred进行softmax计算:其中softmax函数又称为归一化指数函数,它可以把一个多维向量压缩在(0,1)之间,并且它们的和为1
    在这里插入图片描述
  • (2) 然后对softmax计算的结果,再取log对数。
  • (3) 最后再利用NLLLoss() 计算CrossEntropyLoss, 其中NLLLoss() 的计算过程为:将经过log_softmax计算的结果与target 相乘并求和,然后取反。

其中(1),(2)实现的是log_softmax计算,(3)实现的是NLLLoss(), 经过以上3步计算,得到最终的交叉熵损失的计算结果。

4.2 语法

torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction='mean',label_smoothing=0.0)
  • 最常用的参数为 reduction(str, optional) ,可设置其值为 mean, sum, none ,默认为 mean。该参数主要影响多个样本输入时,损失的综合方法。mean表示损失为多个样本的平均值,sum表示损失的和,none表示不综合。
  • weight: 可手动设置每个类别的权重,weight的数组大小和类别数需保持一致

4.3 应用案例

import torch.nn as nn
import torch
loss_func = nn.CrossEntropyLoss()
pre = torch.tensor([[0.8, 0.5, 0.2, 0.5]], dtype=torch.float)
tgt = torch.tensor([[1, 0, 0, 0]], dtype=torch.float)
print('----------------手动计算---------------------')
print("1.softmax")
print(torch.softmax(pre, dim=-1))
print("2.取对数")
print(torch.log(torch.softmax(pre, dim=-1)))
print("3.与真实值相乘")
print(-torch.sum(torch.mul(torch.log(torch.softmax(pre, dim=-1)), tgt), dim=-1))
print('-------------调用损失函数-----------------')
print(loss_func(pre, tgt))
print('----------------------------------------')

在这里插入图片描述
由此可见:

  • ①交叉熵损失函数会自动对输入模型的预测值进行softmax。因此在多分类问题中,如果使用nn.CrossEntropyLoss()则预测模型的输出层无需添加softmax层

  • nn.CrossEntropyLoss()=nn.LogSoftmax()+nn.NLLLoss().

nn.CrossEntropyLoss() 的target可以是one-hot格式,也可以直接输出类别,不需要进行one-hot处理,如下示例:

import torch
import torch.nn as nn
x_input=torch.randn(3,3)#随机生成输入 
print('x_input:\n',x_input) 
y_target=torch.tensor([1,2,0])#设置输出具体值 print('y_target\n',y_target)

#计算输入softmax,此时可以看到每一行加到一起结果都是1
softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)

#在softmax的基础上取log
log_output=torch.log(soft_output)
print('log_output:\n',log_output)

#对比softmax与log的结合与nn.LogSoftmaxloss(负对数似然损失)的输出结果,发现两者是一致的。
logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)

#pytorch中关于NLLLoss的默认参数配置为:reducetion=True、size_average=True
nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)

#直接使用pytorch中的loss_func=nn.CrossEntropyLoss()看与经过NLLLoss的计算是不是一样
crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)
  • 其中pred为x_input=torch.randn(3,3,对应的target为y_target=torch.tensor([1,2,0]), target并没有处理为one-hot格式,也可以正常计算结果的。

5. nn.BCELoss和nn.BCEWithLogitsLoss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/166358.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leetcode算法之分治-快排

目录 1.颜色分类2.排序数组3.数组中的第k个最大元素4.最小的k个数 1.颜色分类 颜色分类 class Solution { public:void sortColors(vector<int>& nums) {int n nums.size();int left -1,rightn,i0;while(i<right){if(nums[i] 0) swap(nums[left],nums[i]);e…

Spring的后处理器

目录 引言 BeanFactoryPostProcessor 注意 BeanPostProcessor 引言 Spring的后处理器是spring对外开发的重要扩展点&#xff0c;允许我们介入到Bean的整个实例化流程来&#xff0c;以达到动态注册BeanDefintion&#xff0c;动态修改BeanDefintion&#xff0c;以及动态修改Be…

【Android】使用Retrofit2发送异步网络请求的简单案例

添加网络权限到AndroidManifest.xml清单文件 为了让你的Android应用程序能够使用互联网进行通信&#xff0c;你需要在AndroidManifest.xml文件中添加网络权限声明。<uses-permission android:name"android.permission.INTERNET"/> 这个权限应该添加到 Android…

laravel-admin导出excel全部,表中无id列导出失败

laravel-admin导出excel时&#xff0c;导出全部数据&#xff0c;但是表中没有id字段&#xff0c;然后就无法导出excel&#xff1b; 就直接显示 一开始我也很着急&#xff0c;弄了半天还是不行&#xff0c;然后重写还是有问题 最后发现底层代码排序是按照id排序的orderBy(id, a…

音视频技术在手机上的应用与挑战

// 编者按&#xff1a;随着手机相机功能日益强大&#xff0c;4k&#xff0c;8k&#xff0c;各类特色短视频的拍摄&#xff0c;编辑、播放需求日益增长&#xff0c;短视频应用的火爆也对当前的手机音视频技术提出了更高的要求&#xff0c;如何更好地提高用户体验成为了行业共同…

Android 13 - Media框架(14)- OpenMax(二)

这一节我们将来解析 media.codec 这个 HIDL service 究竟提供了什么服务&#xff0c;服务是如何启动的。 1、main 函数 我们先来看 frameworks/av/services/mediacodec/main_codecservice.cpp&#xff1a; int main(int argc __unused, char** argv) {strcpy(argv[0], "…

微服务 Spring Cloud 7,Nacos配置中心的Pull原理,附源码

目录 一、本地配置二、配置中心1、以Nacos为例&#xff1a;2、Pull模式3、也可以通过Nacos实现注册中心 三、配置中心提供了哪些功能四、如何操作配置中心1、配置注册2、配置反注册3、配置查看4、配置变更订阅 五、主流的微服务注册中心有哪些&#xff0c;如何选择&#xff1f;…

76基于matlab的免疫算法求解配送中心选址问题,根据配送地址确定最佳配送中心地址位置。

基于matlab的免疫算法求解配送中心选址问题&#xff0c;根据配送地址确定最佳配送中心地址位置。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 76matlab免疫算法配送中心选址 (xiaohongshu.com)

常见Web安全

一.Web安全概述 以下是百度百科对于web安全的解释&#xff1a; Web安全&#xff0c;计算机术语&#xff0c;随着Web2.0、社交网络、微博等等一系列新型的互联网产品的诞生&#xff0c;基于Web环境的互联网应用越来越广泛&#xff0c;企业信息化的过程中各种应用都架设在Web平台…

SpringCloud总结

注&#xff1a;本文并不涉及具体功能是怎么实现的&#xff0c;而只是微服务技术栈的整体总结和理解。 目录 一.基础概念--认识微服务 1.单体架构 2.分布式架构 3.微服务 4.SpringCloud 二.服务的拆分原则 三.RestTemplate--实现不同服务之间的通信与远程调用 四.Eurek…

5 redis的GEO操作

一、GEO Redis 3.2版本提供了GEO(地理信息定位)功能&#xff0c;支持存储地理位置信息用来实现诸如附近位置、摇一摇这类依赖于地理位置信息的功能。 有效纬度从-85.05112878度到85.05112878度 注意&#xff1a;当坐标位置超出上述指定范围时&#xff0c;将会返回一个错误。 …

RTMP协议和源码解析

一、背景 实时消息传输协议&#xff08;Real-Time Messaging Protocol&#xff09;是目前直播的主要协议&#xff0c;是Adobe公司为Flash播放器和服务器之间提供音视频数据传输服务而设计的应用层私有协议。RTMP协议是目前各大云厂商直线直播业务所公用的基本直播推拉流协议&a…

高防CDN为什么可以防DDOS攻击

CDN的全称是ContentDeliveryNetwork&#xff0c;即内容分发网络&#xff0c;顾名思义&#xff0c;它是一个分布式节点网络(也称为边缘服务器)&#xff0c;CDN节点具有缓存内容的功能&#xff0c;使用户可以在不获取源服务器数据的情况下就近获取所需内容&#xff0c;提高客户访…

【LeetCode每日一题合集】2023.9.25-2023.10.1(⭐LFU缓存Java数据流花期内花的数量)

文章目录 460. LFU 缓存⭐&#xff08;数据结构题&#xff09;解法1——平衡树 哈希表&#xff08;TreeSet HashMap&#xff09; O ( l o g n ) O(logn) O(logn)解法2——双哈希表 双向链表 O ( 1 ) O(1) O(1) &#xff08;LRU缓存的升级版&#xff09; 2582. 递枕头解法—…

SpringCloud微服务注册中心:Nacos介绍,微服务注册,Ribbon通信,Ribbon负载均衡,Nacos配置管理详细介绍

微服务注册中心 注册中心可以说是微服务架构中的”通讯录“&#xff0c;它记录了服务和服务地址的映射关系。在分布式架构中&#xff0c;服务会注册到这里&#xff0c;当服务需要调用其它服务时&#xff0c;就这里找到服务的地址&#xff0c;进行调用。 微服务注册中心 服务注…

LV.12 D18 中断处理 学习笔记

一、ARM的异常处理机制及工程代码结构 1.1异常概念 处理器在正常执行程序的过程中可能会遇到一些不正常的事件发生 这时处理器就要将当前的程序暂停下来转而去处理这个异常的事件 异常事件处理完成之后再返回到被异常打断的点继续执行程序。 1.2异常处理机制 不同的处…

2023/11/19总结

项目进度&#xff1a; 地址管理&#xff1a; 显示菜品 购物车相关功能 然后最近在看 支付宝沙盒支付的相关功能&#xff0c;打算把支付给做了 。界面做的不是很好看 &#xff0c;但是后续会改成 手机端的。

丹麦能源袭击预示着更关键的基础设施成为目标

5 月&#xff0c;22 个丹麦能源部门组织在与俄罗斯 Sandworm APT 部分相关的攻击中受到损害。 丹麦关键基础设施安全非营利组织 SektorCERT 的一份新报告描述了不同的攻击者群体利用合勤防火墙设备中的多个关键漏洞&#xff08;包括两个零日漏洞&#xff09;侵入工业机械&…

【Android Jetpack】DataStore的介绍

DataStore Jetpack DataStore是一种数据存储解决方案&#xff0c;允许您使用协议缓冲区存储键值对或类型化对象。DataStore使用Kotlin协程和Flow以异步、一致的事务方式存储数据。 注意&#xff1a;如果您需要支持大型或复杂数据集、部分更新或参照完整性&#xff0c;请考虑使…

配置iTerm2打开自动执行命令

打开iTerm2&#xff0c;commado&#xff0c;打开profies->edit profies&#xff0c;点击号&#xff0c;创建一个新的profile 在新的profile中填写 name&#xff1a;随意 command&#xff1a;Login Shell Send text at start&#xff1a;执行脚本的命令&#xff0c;不想写路…