Pytorch个人学习记录总结 09

目录

损失函数与反向传播

L1Loss

MSELOSS

CrossEntropyLoss


损失函数与反向传播

所需的Loss计算函数都在torch.nn的LossFunctions中,官方网址是:torch.nn — PyTorch 2.0 documentation。举例了L1Loss、MSELoss、CrossEntropyLoss。


在这些Loss函数的使用中,有以下注意的点:
(1) 参数reduction='mean',默认是'mean'表示对差值的和求均值,还可以是'sum'则不会求均值。
(2) 一定要注意Input和target的shape。 

L1Loss

创建一个标准,用于测量中每个元素之间的Input: x xx 和 target: y yy。

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的平均绝对误差(MAE)(L 1 L_1L 1范数)。

 

 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度的loss求均值
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。

假设 a aa 是标量,则有:

type(a) = torch.Tensor
a.shape = torch.Size([])
a.dim = 0
 

MSELOSS

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的均方误差(平方L2范数)。


 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度求loss
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。


CrossEntropyLoss

该标准计算 input 和 target 之间的交叉熵损失。

非常适用于当训练 C CC 类的分类问题(即多分类问题,若是二分类问题,可采用BCELoss)。如果要提供可选参数 w e i g h t weightweight ,那 w e i g h t weightweight 应设置为1维tensor去为每个类分配权重。这在训练集不平衡时特别有用。

期望的 input应包含每个类的原始的、未标准化的分数。input必须是大小为C CC(input未分批)、(m i n i b a t c h , C minibatch,Cminibatch,C) or (m i n i b a t c h , C , d 1 , d 2 , . . . d kminibatch,C,d_1,d_2,...d_kminibatch,C,d 1,d 2,...d k )的Tensor。最后一种方法适用于高维输入,例如计算2D图像的每像素交叉熵损失。

期望的 target应包含以下内容之一:

(1) (target包含了)在[ 0 , C ) [0,C)[0,C)区间的类别索引,C CC是类别总数量。如果指定了 ignore_index,则此损失也接受此类索引(此索引不一定在类别范围内)。reduction='none'情况下的loss为

注意:l o g loglog默认是以10为底的。
 

x是input,y yy是target,w ww是权重weight,C CC是类别数量,N NN涵盖minibatch维度且d 1 , d 2 . . . , d k d_1,d_2...,d_kd 1,d 2...,d k分别表示第k个维度。(N太难翻译了,总感觉没翻译对)如果reduction='mean'或'sum',

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):  # 模型前向传播
        return self.model(x)


model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数

for data in dataloader:
    imgs, targets = data
    outputs = model(imgs)
    # print(outputs)    # 先打印查看一下结果。outputs.shape=(2, 10) 即(N,C)
    # print(targets)    # target.shape=(2) 即(N)
    # 观察outputs和target的shape,然后选择使用哪个损失函数
    res_loss = loss_cross(outputs, targets)
    res_loss.backward()  # 损失反向传播
    print(res_loss)

#
# inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
# targets = torch.tensor([1, 2, 5], dtype=torch.float32)
#
# inputs = torch.reshape(inputs, (1, 1, 1, 3))
# targets = torch.reshape(targets, (1, 1, 1, 3))
#
# # -------------L1Loss--------------- #
# loss = nn.L1Loss()
# res = loss(inputs, targets)  # 返回的是一个标量,ndim=0
# print(res)  # tensor(1.6667)
#
# # -------------MSELoss--------------- #
# loss_mse = nn.MSELoss()
# res_mse = loss_mse(inputs, targets)
# print(res_mse)
#
# # -------------CrossEntropyLoss--------------- #
# x = torch.tensor([0.1, 0.2, 0.3])  # (N,C)
# x = torch.reshape(x, (1, 3))
# y = torch.tensor([1])  # (N)
# loss_cross = nn.CrossEntropyLoss()
# res_cross = loss_cross(x, y)
# print(res_cross)


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/55982.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

jenkins通过sshPut传输文件的时候,报错Permission denied的SftpException

一、背景 使用jenkins的ssh插件传输文件至远程机器的指定目录,php程序打包后,经过zip压缩为oms.zip zip -rq oms.zip ./ -x .git/* -x .env然后我们求md5值 md5sum oms.zip最后执行传输。 09:03:02 Executing command on ssh[116.61.10.149]: mkdir…

C#生成dll给c++调用 方法二COM方式 vs2022 NO Make Assembly COM-Visible选错了 不需要clr

有些C项目中也用了C语言.c,用方法一就无法使用【不能使用 /clr 选项编译 C 文件】。就用方法2。 方法二:COM方式 参考: https://www.5axxw.com/questions/content/2ozion 1.C# 生成dll using System; using System.Collections.Generic; using System.Linq; usin…

Jvm的一些技巧

反编译字节码文件 找到对应的class文件所在的目录,使用javap -v -p 命令 查询运行中某个Java进程的Jvm参数 【案例】查询 MethodAreaDemo 这个类运行过程中,初始的元空间大小 MetaspaceSize jps 查询 Java 进程的进程ID 使用jinfo查看具体的参数&…

阿里云 MSE + ZadigX ,无门槛实现云原生全链路灰度发布

作者:ZadigX 企业发布现状痛点 目前企业在选择和实施发布策略时面临以下困境: 1. 缺乏云原生能力: 由于从传统部署转变为云原生模式后,技术架构改造需要具备相关能力的人才。这使得企业在发布策略方面难以入手。 2. 缺乏自动化…

手写自定义的spring-boot-start

需求&#xff1a;手写一个加密的spring-boot-start&#xff0c;按着用户定义的加密算法&#xff08;可选&#xff1a;MD5、SHA&#xff09;去加密内容 新建一个maven项目 新建好的项目结构和pom.xml如图 添加pom.xml 完整的pom.xml文件 <?xml version"1.0" …

VS中使用QT的插件:QT VS Tools

1、插件下载 &#xff08;1&#xff09;可以在VS中的管理扩展中直接搜索安装&#xff0c;但是我下载太慢&#xff0c;甚至是根本就无法安装。 &#xff08;2&#xff09;qt插件下载地址&#xff1a;Index of /official_releases/vsaddin 这个地址下载就很快&#xff0c;下载…

web集群学习

目录 简述静态网页和动态网页的区别 Webl.0 和 Web2.0 的区别 安装tomcat8&#xff0c;配置服务启动脚本&#xff0c;部署jpress应用 1、安装jdk文件&#xff1a; 2、下载tomcat的二进制包&#xff1a; 3、创建用户组和用户 创建Tomcat的登录服务脚本 此处创建的登录服务…

MyBatis查询数据库(3)

前言&#x1f36d; ❤️❤️❤️SSM专栏更新中&#xff0c;各位大佬觉得写得不错&#xff0c;支持一下&#xff0c;感谢了&#xff01;❤️❤️❤️ Spring Spring MVC MyBatis_冷兮雪的博客-CSDN博客 前面我们讲解了MyBatis增删改查基本操作&#xff0c;下面我们来深入了解M…

Mac上命令

1. block端口&#xff1a; sudo cp /etc/pf.conf /etc/pf443.conf 编辑pf443.conf&#xff0c;vim /etc/pf443.conf&#xff0c;如 block on en0 proto udp from any to any port 9000 # block UDP port 9000 block on en0 proto tcp from any to any port 5004 # bloc…

【Spring AOP + 自定义注解 + 动态数据源 实现主从库切换读写分离】—— 案例实战

&#x1f4a7; S p r i n g A O P 主从数据源切换 读写分离 自定义注解案例实战&#xff01; \color{#FF1493}{Spring AOP 主从数据源切换 读写分离 自定义注解 案例实战&#xff01;} SpringAOP主从数据源切换读写分离自定义注解案例实战&#xff01;&#x1f4a7; …

Docker Sybase修改中文编码

镜像&#xff1a;datagrip/sybase 镜像默认用户名sa&#xff0c;密码myPassword&#xff0c;服务名MYSYBASE 1.进入容器 docker exec -it <container_name> /bin/bash2.加载Sybase环境变量 source /opt/sybase/SYBASE.sh3.查看是否安装了中文字符集 isql -Usa -PmyP…

【Unity学习笔记】对象池

文章目录 设计思路总体设计从生命周期考虑 一些代码 对象池这个东西老生常谈了&#xff0c;使用它的好处在于&#xff1a;当我们需要重复创建或者销毁一些物体&#xff0c;例如限制子弹数量上限为10发&#xff0c;当射出第11发就需要使第10发消失&#xff0c;第11出现。销毁10号…

NullPointerException导致手机重启案例分析

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: 一、 Framework 层对象 空指针导致手机重启。二、解决方案&#xff0c;规避空指针三、Telecom APK 控制导致的重启举例 一、 Framework 层对象 空指针导…

XGBoost的基础思想与实现

目录 1. XGBoost VS 梯度提升树 1.1 XGBoost实现精确性与复杂度之间的平衡 1.2 XGBoost极大程度地降低模型复杂度、提升模型运行效率 1.3 保留了部分与梯度提升树类似的属性 2. XGBoost的sklearnAPI实现 2.1 sklearn API 实现回归 2.2 sklearn API 实现分类 3. XGBoost回…

MySQL 极速安装使用与卸载

目录 mysql-5.6.51 极速安装使用与卸载 sqlyog工具 mysql简化 mysql-8.1.0下载配置 再完善 mysql-5.6.51 极速安装使用与卸载 mysql-8.1.0下载安装在后 mysql中国官网 MySQLhttps://www.mysql.com/cn/ 点击MySQL社区服务器 点击历史档案 下载完 解压 用管理员运行cmd&a…

ODIN_1靶机详解

ODIN_1靶机复盘 下载地址&#xff1a;https: //download.vulnhub.com/odin/odin.ova 靶场很简单&#xff0c;一会儿就打完了。 靶场说明里提醒说加一个dns解析。 我们在/etc/hosts加一条解析 就能正常打开网站了&#xff0c;要么网站打开css是乱的。 这里看到结尾就猜测肯定…

uniapp自定义消息语音

需求是后端推送的消息APP要响自定义语音&#xff0c;利用官方插件&#xff0c;总结下整体流程 uniapp后台配置 因为2.0只支持uniapp自己的后台发送消息&#xff0c;所以要自己的后台发送消息只能用1.0 插件地址和代码 插件地址: link let isIos (plus.os.name "iOS&qu…

<C++> 三、内存管理

1.C/C内存分布 我们先来看下面的一段代码和相关问题 int globalVar 1; static int staticGlobalVar 1; void Test() {static int staticVar 1;int localVar 1;int num1[10] {1, 2, 3, 4};char char2[] "abcd";const char *pChar3 "abcd";int *ptr1…

小乌龟(TortoiseGit)连接GitLab

目录 &#x1f35f;写在前面 &#x1f35f;实验目标 &#x1f35f;安装gitlab &#x1f37f;1、安装依赖 &#x1f37f;2、下载清华gitlab包 &#x1f37f;3、安装gitlab &#x1f37f;4、修改配置文件 &#x1f37f;5、管理命令 &#x1f35f;访问gitlab &#x1f35f;界面设置…

《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(16)-Fiddler如何充当第三者再识AutoResponder标签-上

1.简介 Fiddler充当第三者&#xff0c;主要是通过AutoResponder标签在客户端和服务端之间&#xff0c;Fiddler抓包&#xff0c;然后改包&#xff0c;最后发送。AutoResponder这个功能可以算的上是Fiddler最实用的功能&#xff0c;可以让我们修改服务器端返回的数据&#xff0c…