pytorch中常用的损失函数

1 损失函数的作用

损失函数是模型训练的基础,并且在大多数机器学习项目中,如果没有损失函数,就无法驱动模型做出正确的预测。 通俗地说,损失函数是一种数学函数或表达式,用于衡量模型在某些数据集上的表现。损失函数在深度学习主要作用如下:

  • 衡量模型性能:损失函数用于评估模型的预测结果与真实结果之间的误差程度。较小的损失值表示模型的预测结果与真实结果更接近,反之则表示误差较大。因此,损失函数提供了一种度量模型性能的方式。
  • 参数优化:在训练机器学习和深度学习模型时,损失函数被用作优化算法的目标函数。通过最小化损失函数,可以调整模型的参数,使模型能够更好地逼近真实结果。
  • 反向传播:在深度学习中,通过反向传播算法计算损失函数对模型参数的梯度。这些梯度被用于参数更新,以便优化模型。损失函数在反向传播中扮演着重要的角色,指导参数的调整方向。
  • 防止过拟合:过拟合是指模型在训练数据上表现良好,但在新数据上表现较差的现象。损失函数可以帮助在训练过程中监控模型的过拟合情况。通过观察训练集和验证集上的损失,可以及早发现模型是否过拟合,从而采取相应的措施,如正则化等。

2 pytorch中常见的损失函数

损失函数名称适用场景
torch.nn.MSELoss()均方误差损失回归
torch.nn.L1Loss()平均绝对值误差损失回归
torch.nn.CrossEntropyLoss()交叉熵损失多分类
torch.nn.NLLLoss()负对数似然函数损失多分类
torch.nn.NLLLoss2d()图片负对数似然函数损失图像分割
torch.nn.KLDivLoss()KL散度损失回归
torch.nn.BCELoss()二分类交叉熵损失二分类
torch.nn.MarginRankingLoss()评价相似度的损失
torch.nn.MultiLabelMarginLoss()多标签分类的损失多标签分类
torch.nn.SmoothL1Loss()平滑的L1损失回归
torch.nn.SoftMarginLoss()多标签二分类问题的损失

多标签二分类

2.1 L1损失函数

预测值与标签值进行相差,然后取绝对值,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.L1Loss

import torch
import torch.nn as nn

Loss_fn = nn.L1Loss(size_average=None, reduce=None, reduction='mean')

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = Loss_fn(input, target)
print(output)

运行结果显示如下:

tensor(1.4177, grad_fn=<MeanBackward0>)

2.2 L2损失函数

预测值与标签值进行相差,然后取平方,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.MSELoss

import torch.nn as nn
import torch

loss = nn.MSELoss(size_average=None, reduce=None, reduction='mean')

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
print(output)

运行结果显示如下:

tensor(1.7956, grad_fn=<MseLossBackward0>)

2.3 Huber Loss损失函数

简单来说就是L1和L2损失函数的综合版本,结合了两者的优点,公式可见下,Pytorch调用函数:nn.SmoothL1Loss

import matplotlib.pyplot as plt
import torch

# 定义函数和参数
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')
x = torch.linspace(-1, 1, 10000)
y = smooth_l1_loss(torch.zeros(10000), x)

# 绘制图像
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('SmoothL1Loss')
plt.title('SmoothL1Loss Function')
plt.show()

 运行结果显示如下:

2.4 二分类交叉熵损失函数

简单来说,就是度量两个概率分布间的差异性信息,在某一程度上也可以防止梯度学习过慢,公式可见下,Pytorch调用函数有两个,一个是nn.BCELoss函数,用的时候要结合Sigmoid函数,另外一个是nn.BCEWithLogitsLoss()

import torch.nn as nn
import torch

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
print(output)

运行结果显示如下:

tensor(0.6214, grad_fn=<BinaryCrossEntropyBackward0>)
import torch
import torch.nn as nn

label = torch.empty((2, 3)).random_(2)
x = torch.randn((2, 3), requires_grad=True)

bce_with_logits_loss = nn.BCEWithLogitsLoss()
output = bce_with_logits_loss(x, label)

print(output)

 运行结果显示如下:

tensor(0.7346, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

2.5 多分类交叉熵损失函数

也是度量两个概率分布间的差异性信息,Pytorch调用函数也有两个,一个是nn.NLLLoss,用的时候要结合log softmax处理,另外一个是nn.CrossEntropyLoss

import torch
import torch.nn.functional as F

input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = F.nll_loss(F.log_softmax(input, dim=1), target)
print(output)

运行结果显示如下:

tensor(2.9503, grad_fn=<NllLossBackward0>)
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
inputs = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(inputs, target)

print(output)

运行结果显示如下:

tensor(1.6307, grad_fn=<NllLossBackward0>)

2.6 自定义损失

通过对 nn 模块进行子类化,将损失函数创建为神经网络图中的节点。 这意味着我们的自定义损失函数是一个 PyTorch 层,与卷积层完全相同。

class Custom_MSE(nn.Module):
  def __init__(self):
    super(Custom_MSE, self).__init__();

  def forward(self, predictions, target):
    square_difference = torch.square(predictions - target)
    loss_value = torch.mean(square_difference)
    return loss_value
  
  # def __call__(self, predictions, target):
  #   square_difference = torch.square(y_predictions - target)
  #   loss_value = torch.mean(square_difference)
  #   return loss_value

可以在“forward”函数调用或“call”内部定义损失的实际实现。

3 总结

损失函数在人工智能领域中起着至关重要的作用,它不仅是模型训练和优化的基础,也是评估模型性能、解决过拟合问题以及指导模型选择的重要工具。不同的损失函数适用于不同的问题和算法,选择合适的损失函数对于取得良好的模型性能至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/124186.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络期末复习-Part2

1、网络应用程序体系结构 &#xff08;1&#xff09;客户端/服务器&#xff08;C/S&#xff09;体系结构&#xff1a; 客户端/服务器&#xff08;C/S&#xff09;应用程序&#xff1a; Web浏览器与Web服务器&#xff1a;当您使用Web浏览器&#xff08;客户端&#xff09;访问…

Scala爬虫如何实时采集天气数据?

这是一个基本的Scala爬虫程序&#xff0c;使用了Scala的http library来发送HTTP请求和获取网页内容。在爬取天气预报信息时&#xff0c;我们首先需要创建一个代理对象proxy&#xff0c;并将其用于发送HTTP请求。然后&#xff0c;我们使用http库的GET方法获取网页内容&#xff0…

linux修改rocketmq的日志文件位置

文章目录 &#x1f50a;修改rocketmq的日志文件位置&#x1f4d5;原来的文件&#x1f4cc;修改后文件&#x1f4c7;rocketmq中的Rocketmq_client.log文件在配置文件中改不了 需要在代码logback文件中进行修改&#x1f58a;️最后总结 &#x1f50a;修改rocketmq的日志文件位置 …

【Linux】 awk命令使用

AWK 是一种处理文本文件的语言&#xff0c;是一个强大的文本分析工具。 之所以叫 AWK 是因为其取了三位创始人 Alfred Aho&#xff0c;Peter Weinberger, 和 Brian Kernighan 的 Family Name 的首字符。 语法 awk [选项] [文件] awk [选项] [程序] [文件] awk命令 -Linux手…

cookie 里面都包含什么属性?

结论先行&#xff1a; Cookie 中除了名称和值外&#xff0c;还有几个比较常见的&#xff0c;例如&#xff1a; Domain 域&#xff1a;指定了 cookie 可以发送到哪些域&#xff0c;只有发送到指定域或其子域的请求才会携带该cookie&#xff1b; Path 路径&#xff1a;指定哪些…

Spring Boot 3系列之-启动类详解

Spring Boot是一个功能强大、灵活且易于使用的框架&#xff0c;它极大地简化了Spring应用程序的开发和部署流程&#xff0c;使得开发人员能够更专注于业务逻辑的实现。在我们的Spring Boot 3系列之一&#xff08;初始化项目&#xff09;文章中&#xff0c;我们使用了Spring官方…

gitlab-ce-12.3.5 挖矿病毒及解决方案

前言 最近发现在使用gitlab提交代码的时候总是失败&#xff0c;一访问gitlab还时常报503&#xff0c;于是使用 top 命令查看了内存占用情况&#xff0c;发现了一个git进程内存使用了2.3g&#xff0c;cpu还一直占用300-400%&#xff0c; 以前不知道gitlab还有病毒&#xff0c;只…

[论文阅读]PV-RCNN++

PV-RCNN PV-RCNN: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection 论文网址&#xff1a;PV-RCNN 论文代码&#xff1a;PV-RCNN 简读论文 这篇论文提出了两个用于3D物体检测的新框架PV-RCNN和PV-RCNN,主要的贡献如下: 提出P…

ubuntu 安装redis详细教程

下载redis安装包 链接如下&#xff1a; http://redis.io/download 本例版本为&#xff1a;redis-7.2.3.tar.gz 下载安装包到目录/opt下&#xff0c;路径可修改&#xff0c;本例为/opt wget https://github.com/redis/redis/archive/7.2.3.tar.gz 解压安装包&#xff0c;并…

解决 SSLError: HTTPSConnectionPool(host=‘huggingface.co‘, port=443)

看我的回答&#xff1a; https://github.com/huggingface/transformers/issues/17611#issuecomment-1794486960 能问这个问题的都是网络不太好的&#xff0c;你懂的&#xff0c;所以答案全是解决网络的。 得益于这个回答&#xff1a;#17611 (comment) 看了一下代码&#xf…

基于ssm的校园快递物流管理系统(java+jsp+ssm+javabean+mysql+tomcat)

博主24h在线&#xff0c;想要源码文档部署视频直接私聊&#xff0c;9.9拿走&#xff01; 基于javawebmysql的ssm校园快递物流管理系统(javajspssmjavabeanmysqltomcat) 运行环境&#xff1a; Java≥8、MySQL≥5.7、Tomcat≥8 开发工具&#xff1a; eclipse/idea/myeclipse/s…

万物社用户运营工具:无代码开发下的电商平台和CRM集成

简介&#xff1a;万物社与集简云的引领式连接 万物社&#xff0c;隶属于厦门头号云信息科技有限公司&#xff0c;是一家专注于互联网和相关服务的企业。在日常的业务运营中&#xff0c;万物社通过与集简云的无代码集成&#xff0c;实现了业务流程的自动化和智能化&#xff0c;…

Linux系统下数据同步服务RSYNC

一、RSYNC概述 1、什么是rsync rsync的好姐妹 sync 同步&#xff1a;刷新文件系统缓存&#xff0c;强制将修改过的数据块写入磁盘&#xff0c;并且更新超级块。 async 异步&#xff1a;将数据先放到缓冲区&#xff0c;再周期性&#xff08;一般是30s&#xff09;的去同步到磁…

【Git】安装和常用命令的使用与讲解及项目搭建和团队开发的出现的问题并且给予解决

目录 Git的简介 介绍 Git的特点及概念 Git与SVN的区别 图解 ​编辑 命令使用 安装 使用前准备 搭建项目环境 ​编辑 团队开发 Git的简介 介绍 Git 是一种分布式版本控制系统&#xff0c;是由 Linux 之父 Linus Torvalds 于2005年创建的。Git 的设计目标是为了更好地管…

图文并茂解读联合索引底层存储结构及索引查找过程

文章目录 前言版本数据准备SQL数据创建结果有无联合索引执行情况无联合索引存在联合索引 底层存储结构查询过程最左匹配原则查询过程解析 联合索引优势支持复杂查询索引覆盖查询提高排序和分组性能减少索引数量 使用建议联合索引的列顺序十分重要建议能使用联合索引尽量使用联合…

Mybatis-Plus同时使用逻辑删除和唯一索引的问题及解决办法

1 问题背景 在开发中&#xff0c;我们经常会有逻辑删除和唯一索引同时使用的情况。但当使用mybatis plus时&#xff0c;如果同时使用逻辑删除和唯一索引&#xff0c;会报数据重复Duplicate entry的问题。 举例来说&#xff0c;有表user&#xff0c;建立唯一索引&#xff08;u…

Docker本地部署Drupal并实现公网访问

文章目录 前言1. Docker安装Drupal2. 本地局域网访问3 . Linux 安装cpolar4. 配置Drupal公网访问地址5. 公网远程访问Drupal6. 固定Drupal 公网地址 前言 Dupal是一个强大的CMS&#xff0c;适用于各种不同的网站项目&#xff0c;从小型个人博客到大型企业级门户网站。它的学习…

ZYNQ_project:key_led

条件里是十进制可以不加进制说明&#xff0c;编译器默认是10进制&#xff0c;其他进制要说明。 实验目标&#xff1a; 模块框图&#xff1a; 时序图&#xff1a; 代码&#xff1a; include "para.v"module key_filter (input wire …

2020 ICPC 澳门(G,J,I)详解

链接&#xff1a;The 2020 ICPC Asia Macau Regional Contest G Game on Sequence 题意 给定长度为 n n n 数组 a i a_i ai​&#xff0c;A与G博弈&#xff0c;G先手&#xff0c;给定初始位置 k k k&#xff0c;若当前在 i i i 点转移到 j j j&#xff0c;满足 i <…

在虚拟机中新安装的Linux无法联网解决办法

1、我们在虚拟机中新安装了linux&#xff0c;默认是无法连接网络的&#xff0c;这个时候&#xff0c;需要配置自动获取ip的网设置。 2、我们在VMware Workstatio需要配置net网络&#xff0c;如下图 3、进入linux系统&#xff0c;找到 /etc/sysconfig/network-scripts/ [rootn…