神经网络 torch.nn---损失函数与反向传播

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.nn — PyTorch 2.3 documentation

Loss Function的作用

  • 每次训练神经网络的时候都会有一个目标,也会有一个输出。目标和输出之间的误差,就是用Loss Function来衡量的。所以,误差Loss是越小越好的。

  • 此外,我们可以根据误差Loss,指导输出output接近目标target。即我们可以以target为依据,不断训练神经网络,优化神经网络中各个模块,从而优化output

Loss Function的作用

  1. 计算实际输出和目标之间的差距
  2. 为我们更新输出提供一定的依据,这个提供依据的过程也叫反向传播

nn.L1Loss

创建一个衡量输入x(模型预测输出)和目标y之间差的绝对值的平均值的标准。

class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

参数说明:

  • reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

需要注意的是,计算的数据必须为浮点数

程序代码:

import torch
from torch.nn import L1Loss

input=torch.tensor([1,2,3],dtype=torch.float32)
target=torch.tensor([1,2,5],dtype=torch.float32)

input=torch.reshape(input,(1,1,1,3))
target=torch.reshape(target,(1,1,1,3))

loss1=L1Loss()  #reduction='mean'
loss2=L1Loss(reduction='sum')  
result1=loss1(input,target)
result2=loss2(input,target)

print(result1,result2)

输出:

nn.MSELoss

创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准。

  • x 和 y 可以是任意形状,每个包含n个元素。

  • n个元素对应的差值的绝对值求和,得出来的结果除以n

  • 如果在创建MSELoss实例的时候在构造函数中传入size_average=False,那么求出来的平方和将不会除以n

class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

参数说明:

reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

程序代码:

import torch
from torch.nn import L1Loss,MSELoss

input = torch.tensor([1,2,3],dtype=torch.float32)
target = torch.tensor([1,2,5],dtype=torch.float32)

input = torch.reshape(input,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))

loss_mse1 = MSELoss()  #reduction='mean'
loss_mse2 = MSELoss(reduction='sum')
result_mse1 = loss_mse1(input, target)
result_mse2 = loss_mse2(input, target)

print(result_mse1, result_mse2)

输出:

nn.CrossEntropyLoss(交叉熵)

当训练一个分类问题的时候,假设这个分类问题有C个类别,那么有:

 当weight参数被指定的时候,loss的计算公式变为:

计算出的lossmini-batch的大小取了平均。

形状(shape):

  • Input: (N,C)    其中N代表batch_size,C 是类别的数量即数据要分成几类(或有几个标签)。

  • Target: (N)     Nmini-batch的大小,0 <= targets[i] <= C-1

举个例子:

  • 我们对包含了人、狗、猫的图片进行分类,其标签的索引分别为0、1、2。这时候将一张的图片输入神经网络,即目标(target)为1(对应标签索引)。输出结果为[0.1,0.2,0.3],该列表中的数字分别代表分类标签对应的概率。

  • 根据上述分类结果,图片为的概率更大,即0.3。对于该分类的Loss Function,我们可以通过交叉熵去计算,即:

那么如何验证这个公式的合理性呢?根据上面的例子,分类结果越准确,Loss应该越小。这条公式由两个部分组成:

  • 1、log(∑jexp(x[j])

log(∑jexp(x[j])主要作用是控制或限制预测结果的概率分布。比如说,预测出来的人、狗、猫的概率均为0.9,每个结果概率都很高,这显然是不合理的。此时 log(∑jexp(x[j]) 的值会变大,误差loss(x,class)也会随之变大。同时该指标也可以作为分类器性能评判标准。

  • 2、−x[class]:在已知图片类别的情况下,预测出来对应该类别的概率x[class]越高,其预测结果误差越小。

程序代码:

import torch
from torch import nn
from torch.nn import L1Loss

inputs = torch.tensor([1, 2, 3], dtype=torch.float)
targets = torch.tensor([1, 2, 5], dtype=torch.float)

inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

x = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))

loss_cross_1 = nn.CrossEntropyLoss(reduction='mean')
result_cross_1 = loss_cross_1(x, y)
loss_cross_2 = nn.CrossEntropyLoss(reduction='sum')
result_cross_2 = loss_cross_2(x, y)
print(result_cross_1, result_cross_2)

输出:

反向传播

如何根据Loss Function为更新神经网络数据提供依据?

  • 对于每个卷积核当中的参数,设置一个grad(梯度)。

  • 当我们进行反向传播的时候,对每一个节点的参数都会求出一个对应的梯度。之后我们根据梯度对每一个参数进行优化,最终达到降低Loss的一个目的。比较典型的一个方法——梯度下降法

代码举例:

 result_loss = loss(outputs, targets)
 result_loss.backward()
  • 上面就是反向传播的使用方法,它的主要作用是计算一个grad。使用debug功能并删掉上面这行代码,会发现单纯由result_loss=loss(output,targets)计算出来的结果,是没有grad这个参数的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/686031.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

美国签证办理需要带哪些材料?

在申请美国签证时&#xff0c;准备充分的材料至关重要。以下知识人网整理的关于您可能需要携带的一些常见材料&#xff1a; 1.护照&#xff1a;您的护照必须是有效的&#xff0c;并且在签证申请过程中至少有六个月的有效期。 2.签证申请表&#xff1a;您需要填写并提交签证申请…

29 - 买下所有产品的客户(高频 SQL 50 题基础版)

29 - 买下所有产品的客户 selectc.customer_id fromCustomer c group byc.customer_id havingcount(c.product_key)(select count(distinct product_key) from Product);

Windows下安装和配置Redis

目录 1、下载redis压缩包 2、解压redis文件 3、启动redis临时服务 4、打开Redis客户端进行连接 5、使用一些基础操作来测试 5.1、输入ping命令来检测redis服务器与redis客户端的连通性 5.2、使用set和get命令测试redis数据库进行数据存储和获取 5.3、在命令中通过shut…

Easy 同学:AI 时代将加速计算机专业和程序员职业的分化

一、原贴 2024 年 6 月 5 日 拥有 60多万粉丝的方糖气球&#xff08;ftqq.com&#xff09;博主 、独立开发者&#xff1a;Easy 发表了一篇 AI 对计算机专业和程序员行业影响的新浪博客&#xff0c;看后很有启发&#xff0c;故而将原文摘录于此&#xff1a; 单独开个贴说一下吧…

项目实战系列——WebSocket——websock简介

最近项目中需要用到mes和本地客户端进行实时通讯&#xff0c;本来想用webapi进行交互的&#xff0c;但是考虑到高效和实时性&#xff0c;就采用这一项技术。 以往采用的方式——长轮询 客户端主动向服务器发送一个请求&#xff0c;如果服务器没有更新的数据&#xff0c;客户端…

我的python管理

目前环境 Anaconda&#xff1a;python3.9 python2.7 IDA&#xff1a;python3.8 pycharm&#xff1a;&#xff1f;&#xff1f; 以后应该会补吧… 因为某些文件似乎用的python2决定整个python2 安装python2.7 打开anaconda命令行输入 conda create --name python27 python2…

六、Docker Swarm、Docker Stack和Portainer的使用

六、Docker swarm和Docker stack的使用 系列文章目录1.Docker swarm1.简介2.docker swarm常用命令3.docker node常用命令4.docker service常用命令5.实战案例6.参考文章 2.Docker stack1.简介3.Docker stack常用命令4.实战案例5.常见问题及调错方式1.查看报错信息并尝试解决&am…

【简报】VITA 74 (VNX)总结

参考源 VITA 74 &#xff08;VNX&#xff09;A VITA 74 &#xff08;VNX&#xff09;B VITA 74 &#xff08;VNX&#xff09;C VITA 74 &#xff08;VNX&#xff09;D VNX&#xff0c;也称为 VITA 74&#xff0c;在 VITA 标准组织管理的规范中定义。VNX目前已进入“试用”状…

1104 天长地久(测试点1,2,3)

solution 测试点3超时&#xff1a;直接暴力搜超时。m和m1的最大公约数一定是1&#xff0c;则A的个位一定是9才有可能gcd(m, m1)大于1&#xff0c;步长变为10。测试点1&#xff0c;3&#xff1a;m和n的最大公约数是大于2的素数测试点2&#xff1a;按照n从小到大排序&#xff0c…

生死战的主场,国足与泰国队1比1战平,遗憾失去了出线主动权

6月6日晚&#xff0c;于中国足球曾经的福地沈阳&#xff0c;中国男子国家足球队迎来了一场至关重要的对决&#xff0c;与泰国队展开了一场生死较量。在此前&#xff0c;中国队只需取胜即可提前一轮晋级&#xff0c;然而遗憾的是&#xff0c;国足未能把握住机会&#xff0c;再次…

Android Studio安装详细教程

&#x1f4d6;Android Studio安装教程 ✅1. 下载✅2. 安装✅3. 第一次启动导向 ✅1. 下载 官方地址&#xff1a;https://developer.android.google.cn/studio?hlzh-cn 下载最新稳定版即可 ✅2. 安装 1.点击运行android-studio-2023.3.1.18-windows.exe安装包&#xff0c;直…

四川汇聚荣聚荣科技有限公司综合实力如何?

在探讨一个公司的综合实力时&#xff0c;我们不仅关注其经济表现&#xff0c;还应深入分析其技术实力、市场地位、创新能力、团队素质以及社会责任感等多个维度。四川汇聚荣聚荣科技有限公司作为一家立足于科技领域的企业&#xff0c;其实力究竟如何呢?接下来的内容将围绕这一…

神了,Suno创作的《不期待》三个版本,你最喜欢哪一个?

Suno AI 最近尝试了下 Suno AI 的v3.5模型&#xff0c;有点惊艳 做了三首版本的AI歌曲&#xff0c;词是我写的&#xff0c;其他全都交给Suno了&#xff0c;欢迎大家来听听 B站链接 Youtube链接

电能质量在线监测装置

安科瑞电气股份有限公司 祁洁 15000363176 一、装置概述 APView500电能质量在线监测装置采用了高性能多核平台和嵌入式操作系统&#xff0c;遵照IEC61000-4-30《测试和测量技术-电能质量测量方法》中规定的各电能质量指标的测量方法进行测量&#xff0c;集谐波分析、波形采…

【Stable Diffusion】(基础篇一)—— Stable Diffusion的安装

本系列笔记主要参考B站nenly同学的视频教程&#xff0c;传送门&#xff1a;B站第一套系统的AI绘画课&#xff01;零基础学会Stable Diffusion&#xff0c;这绝对是你看过的最容易上手的AI绘画教程 | SD WebUI 保姆级攻略_哔哩哔哩_bilibili **Stable Diffusion&#xff08;简称…

体验SmartEDA:颠覆传统,设计流程更流畅,超越Multisim与Proteus!

在电子设计自动化&#xff08;EDA&#xff09;领域&#xff0c;传统软件如Multisim和Proteus一直是工程师们的得力助手。然而&#xff0c;随着科技的飞速发展和用户需求的不断升级&#xff0c;一个全新的EDA平台——SmartEDA正崭露头角&#xff0c;凭借其更为流畅的设计流程&am…

C++ 并发编程指南(5)线程状态及切换

文章目录 一、多线程状态及切换1、线程状态2、状态切换 前言&#xff1a; C中的线程状态及切换是操作系统和C线程库&#xff08;如POSIX线程或C11及之后的<thread>库&#xff09;共同管理的。线程的状态和切换是多线程编程中的重要概念&#xff0c;下面将简要介绍C线程的…

太阳能航空障碍灯在航空安全发挥什么作用_鼎跃安全

随着我国经济的快速发展&#xff0c;空域已经成为经济发展的重要领域。航空运输、空中旅游、无人机物流、飞行汽车等经济活动为空域经济发展提供了巨大潜力。然而&#xff0c;空域安全作为空域经济发展的关键因素&#xff0c;受到了广泛关注。 随着空域经济活动的多样化和密集…

Java实现经纬度坐标转换

一、坐标系统简介 坐标系统&#xff0c;是描述物质存在的空间位置&#xff08;坐标&#xff09;的参照系&#xff0c;通过定义特定基准及其参数形式来实现。 坐标是描述位置的一组数值&#xff0c;按坐标的维度一般分为一维坐标&#xff08;公路里程碑&#xff09;和二维坐标…

SpringBoot社区配送服务系统小程序-计算机毕业设计源码88705

摘要 随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;社区当然也不例外。社区配送服务系统小程序是以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;采…