Pytorch学习 day10(L1Loss、MSELoss、交叉熵Loss、反向传播)

Loss

  • loss的作用如下:
    • 计算实际输出和真实值之间的差距
    • 为我们更新模型提供一定的依据(反向传播)

L1Loss

  • 绝对值损失函数:在每一个batch_size内,求每个输入x和标签y的差的绝对值,最后返回他们平均值
    在这里插入图片描述

MSELoss

  • 均方损失函数:在每一个batch_size内,求每个输入x和标签y的差的平方,最后返回他们的平均值
    在这里插入图片描述

交叉熵Loss

  • 当我们在处理分类问题时,经常使用交叉熵损失函数。
    • 交叉熵能够衡量同一个随机变量中的两个不同概率分布的差异程度,在机器学习中就表示为真实概率分布与预测概率分布之间的差异。交叉熵的值越小,模型预测效果就越好。
    • 交叉熵在分类问题中常常与softmax是标配,softmax将输出的结果进行处理,使其多个分类的预测值和为1,再通过交叉熵来计算损失。
  • 由于以下内容需要理解Softmax函数和交叉熵损失函数,所以先回顾一遍:
  • Softmax函数:
    • 首先,分类任务的目标是通过比较每个类别的概率大小来判断预测的结果。但是,我们不能选择未规范化的线性输出作为我们的预测。原因有两点。
1. 线性输出的总和不一定为1
2. 线性输出可能有负值
  • 因此我们采用Softmax规范手段来保证输出的非负、和为1,公式和举例如下:
    • 左侧为Softmax函数公式,右侧的o为线性输出,y为Softmax规范后的输出
      在这里插入图片描述
  • 交叉熵损失函数:
    • 下图为交叉熵损失函数公式,P(x)为真实概率分布,q(x)为预测概率分布:
      在这里插入图片描述
  • 我们将Softmax规范后的输出代入交叉熵损失函数中,可得:
    • 在训练中,我们已知该样本的类别,那么在该样本的真实概率分布中,只有该类别为1,其他都为0。
    • 在计算机中的log,默认都是ln。
      请添加图片描述
  • 这就是Pytorch官网中的交叉熵损失函数公式:
    在这里插入图片描述
  • 注意:给此公式的交叉熵损失函数传入的input,不需要进行规范化,即不需要进行Softmax变换
  • 我们仍然使用该类的对象函数来调用forward方法,而forward方法需要满足以下条件:
    • input:第一位为batch_size,第二位为输入的class数量
    • target:只有一位,为batch_size
      在这里插入图片描述
  • 代码如下:
import torch

x = torch.tensor([0.1, 0.2, 0.3])
print(x.shape)  # torch.Size([3])
print(x)    # tensor([0.1000, 0.2000, 0.3000])
y = torch.tensor([1])   
x = torch.reshape(x, (1,3)) # 由于交叉熵损失函数的forward方法要求输入是二维,且第一位是batch_size,第二位是class的数量
print(x.shape)  # torch.Size([1, 3])
print(x)    # tensor([[0.1000, 0.2000, 0.3000]])
loss_cross = torch.nn.CrossEntropyLoss()    # 交叉熵损失函数
result_loss = loss_cross(x, y)  # 计算交叉熵损失
print(result_loss)  

# 输出结果:
# torch.Size([3])
# tensor([0.1000, 0.2000, 0.3000])
# torch.Size([1, 3])
# tensor([[0.1000, 0.2000, 0.3000]])
# tensor(1.1019)
  • 计算器的输出结果如下:
    • 代码中的log默认为ln
      在这里插入图片描述

反向传播

  • 当输入不变时,我们要想让总loss最小,就是要找到一组最小的w、b序列,这时我们可以采用一种系统的方法:梯度下降方法
    • 那么找w、b序列,就转换为求学习率和loss对w、b的偏导数,形象化的表示如下:
      在这里插入图片描述
    • 梯度下降的公式如下:
      在这里插入图片描述
  • 这其中:学习率是我们手动设定的,偏导数则是模型自动计算的。
  • 由于每一个节点都需要计算偏导数,如果我们采用正向传播计算,那么针对每一个节点,我们都需要正向计算到结尾一次,而反向传播,只需要我们从头正向计算到结尾一次,之后根据节点位置,进行反向偏导数相乘即可,流程图如下:
    在这里插入图片描述
  • 在模型代码中,偏导数用grad(梯度)表示,在模型的训练过程中,通过反向传播来计算每个网络层节点的对应梯度,并通过某种算法(优化器)不断更新节点的参数,最终达到loss最小的一个结果,代码如下:
import torch
import torchvision
from torch import nn

test_dataset = torchvision.datasets.CIFAR10(root='Dataset', train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.module1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self, input):
        output = self.module1(input)
        return output

loss = nn.CrossEntropyLoss()    # 交叉熵损失函数
tudui = Tudui()

for data in test_loader:
    inputs, targets = data
    outputs = tudui(inputs)
    result_loss = loss(outputs, targets)    # 计算loss
    result_loss.backward()  # 反向传播,注意需要使用计算后的loss
    a=1 # 用于调试,设置断点
    break
  • 结果如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/458715.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python创建虚拟环境-Anaconda安装配置和使用

Anaconda提供了一个名为conda的包管理工具,可以方便地创建、管理和分享Python环境。用户可以根据自己的需要创建不同的环境,每个环境都可以拥有自己的Python版本、库和依赖项,这样就可以避免因为不同项目之间的依赖关系而导致的冲突问题。 一…

Vscode中关于Java的一些问题

前言 在使用Vscode的时候,总是会有这么一种感觉:有时得这样,有时得那样,这让我甚是困惑,于是写下来这篇解答文章 为什么java文件有时候会有class文件,有时候没有 在编写Java代码时,我会有一种…

【Java基础】IO流(二)字符集知识

目录 字符集知识 1、GBK字符集 2、Unicode字符集(万国码) 3、乱码 4、Java中编码和解码的方法 字符集知识 字符(Character):在计算机和电信技术中,一个字符是一个单位的字形、类字形单位或符号的基本信…

智能合约开发基础知识:最小信任机制、智能合约、EVM

苏泽 大家好 这里是苏泽 一个钟爱区块链技术的后端开发者 本篇专栏 ←持续记录本人自学两年走过无数弯路的智能合约学习笔记和经验总结 如果喜欢拜托三连支持~ 专栏的前面几篇详细了介绍了区块链的核心基础知识 有兴趣学习的小伙伴可以看看http://t.csdnimg.cn/fCD5E关于区块…

光伏便携式EL检测仪是什么?—科技助农

光伏便携式EL监测仪是一种专门用于检测光伏电池组件性能的高效、实用的设备。它利用电致发光(Electroluminescence,EL)原理,通过检测光伏板在受到光照后产生的电流所激发出的光线,来评估光伏板的性能。这种设备通常具有…

Linux搭建我的世界(MC)整合包服务器,All the Mods 9(ATM9)整合包开服教程

Linux使用MCSM面板搭建我的世界(Minecraft)整合包服务器,MC开服教程,All the Mods 9(ATM9)整合包搭建服务器的教程。 本教程使用Docker来运行mc服,可以方便切换不同Java版本,方便安装多个mc服版本。 视频教程:https:…

算法的渐进时间复杂度

T(n) = O(F(n)) T(n):Time 渐进时间复杂度 O:正比例关系 F(n):代码执行次数 只要代码执行的次数越来越多 所耗费的时间也就越来越高 常见的5种: O(n^2) O(n logn) O(n) O(logn) O(1):不管重复多少次1次也是这个时间,10次也是这个时间。 时间复杂度排序:由小到…

C语言指针与数组(不适合初学者版):一篇文章带你深入了解指针与数组!

🎈个人主页:JAMES别扣了 💕在校大学生一枚。对IT有着极其浓厚的兴趣 ✨系列专栏目前为C语言初阶、后续会更新c语言的学习方法以及c题目分享. 😍希望我的文章对大家有着不一样的帮助,欢迎大家关注我,我也会回…

我记不住的那些命令-xxdod

背景: 这里记录一下具体的xxd和od命令参数和使用方法,我想我肯定是记不住这些参数的。 零、文件 我们这里有一个示例文件README.txt,内容如下图所示: 一、xxd(查看、编辑二进制文件) 1. 十六进制显示(默认) 我们发现上面的显…

使用代理ip后访问网站仍然被拒该怎么办

目录 前言 一、代理IP被网站封锁 二、代理IP的质量不佳 三、代理设置不正确 总结 前言 在使用代理IP之后仍然被网站拒绝访问可能是由于多种原因引起的。这些原因包括代理IP被网站封锁、代理IP的质量不佳、代理设置不正确等。下面将详细介绍如何解决这些问题。 一、代理I…

代码学习记录16

随想录日记part16 t i m e : time: time: 2024.03.11 主要内容:今天的主要内容是二叉树的第五部分,主要涉及最大二叉树;合并二叉树;二叉搜索树的搜索;验证二叉搜索树。 654.最大二叉…

使用docker-compose部署Redis集群

一、部署三主三从的Redis集群 分别为6个节点建立挂载目录,每个目录下建立数据、配置、日志文件夹。 docker-compose内容如下: version: 3 services:redis1:image: redis:6.2.3restart: alwaysports:- "6379:6379"- "16379:16379"v…

Spring揭秘:ClassPathScanningProvider接口应用场景及实现原理!

技术应用场景 ClassPathScanningCandidateComponentProvider是Spring框架中一个非常核心的类,它主要用于在类路径下扫描并发现带有特定注解的组件,支持诸如ComponentScan、Component、Service、Repository和Controller等注解的自动扫描和注册。 ClassP…

.NET开源快速、强大、免费的电子表格组件

今天大姚给大家分享一个.NET开源(MIT License)、快速、强大、免费的电子表格组件,支持数据格式、冻结、大纲、公式计算、图表、脚本执行等。兼容 Excel 2007 (.xlsx) 格式,支持WinForm、WPF和Android平台:ReoGrid。 项…

普发Pfeiffer TPG256A MaxiGauge 真空计控制器接口通讯针脚等详情见图目录

普发Pfeiffer TPG256A MaxiGauge 真空计控制器接口通讯针脚等详情见图目录

强化学习中SARSA(State-Action-Reward-State-Action)和Q-learning的区别

SARSA(State-Action-Reward-State-Action)和Q-learning是两种经典的强化学习算法,它们都用于学习最优策略以使智能体在一个环境中获得最大的累积奖励。它们之间的主要区别在于它们更新动作值函数(Q值函数)的方式以及其…

SwiftUI组件-DatePicker

SwiftUI组件-DatePicker 本文记录一下SwiftUI组件-DatePicker import SwiftUIstruct DatePickerBootCamp: View {State var selectedDate: Date Date()var dateFormatter: DateFormatter {let formatter DateFormatter()formatter.dateStyle .shortformatter.timeStyle .…

使用kill()函数向进程发送信号

本片文章的学习记录总结来源于:https://www.bilibili.com/cheese/play/ep182660?csourcecommon_hp_history_null&t11&spm_id_from333.1007.top_right_bar_window_history.content.click 通常在Linux系统中,可以使用 kill or killall 命令向指定…

OpenCASCADE开发指南<十二>:OCC创建三维瓶子模型

在OpenCASCADE有一个例程,在 官方帮助网站中可以找到。程将教你如何使OpenCASCADE的API来进行三维建模。教程的目的不是描述所有的类,而是帮助你思考如何将OpenCASCADE作为一种工具。 1 概述 利用OpenCASCADE的API创建一个三维瓶子,形状如下…

如何在Linux部署DataEase数据分析服务并实现无公网IP远程分析内网数据信息

文章目录 前言1. 安装DataEase2. 本地访问测试3. 安装 cpolar内网穿透软件4. 配置DataEase公网访问地址5. 公网远程访问Data Ease6. 固定Data Ease公网地址 前言 DataEase 是开源的数据可视化分析工具,帮助用户快速分析数据并洞察业务趋势,从而实现业务…