8-pytorch-损失函数与反向传播

b站小土堆pytorch教程学习笔记

根据loss更新模型参数
1.计算实际输出与目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

在这里插入图片描述

1 MSEloss

import torch
from torch.nn import L1Loss
from torch import nn

inputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)

inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))

loss=L1Loss()
result=loss(inputs,targets)

loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)

print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

在这里插入图片描述

x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=1)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:
    imgs,target=data
    output=han(imgs)
    # print(target)
    # print(output)
    result_loss=loss(output,target)
    print(result_loss)

*tensor([7])
tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)…

tensor(2.2664, grad_fn=)…

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=64)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)

for epoch in range(5):
    running_loss=0.0#一个epoch结束的loss和
    for data in dataloader:
        imgs,target=data
        output=han(imgs)

        result_loss=loss(output,target)#每次迭代的loss
        optim.zero_grad()#将网络中每个可调节参数对应的梯度调为0
        result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得
        optim.step()#对每个参数调优
        running_loss=running_loss+result_loss
    print(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/408522.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL——基础内容

目录 第01章_数据库概述 关系型数据库(RDBMS)——表、关系模型 非关系型数据库(非RDBMS) 表、记录、字段 表的关联关系 一对一关联 一对多关系 多对多 自我引用 第02章_MySQL环境搭建 登录命令 常用命令 show databases; create database use 数据库名 show tables 第03章…

五种多目标优化算法(MOCS、MOFA、NSWOA、MOAHA、MOPSO)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 多目标优化算法是用于解决具有多个目标函数的优化问题的一类算法。其求解流程通常包括以下几个步骤: 1. 定义问题:首先需要明确问题的目标函数和约束条件。多目标优化问题通常涉及多个目标函数,这些目标函数可能…

云原生之容器编排实践-kubectl get pod -A没有coredns

背景 前面搭建的3节点 Kubernetes 集群,其实少了一个组件: CoreDNS ,这也是我后面拿 ruoyi-cloud 项目练手时,部署了 MySQL 和 Nacos 服务后才意识到的:发现Nacos无法通过服务名连接MySQL,这里 Nacos 选择…

Mac OS 搭建C++开发环境【已解决】

Mac OS 搭建C开发环境 文章目录 Mac OS 搭建C开发环境一、安装命令行工具:二、安装vscode三、安装gcc3.1 安装Homebrew3.2 安装gcc3.3 修改配置 四、更改VSCode默认编译器五、安装gdb六、安装Cmake && git七、编译运行 本地环境: Mac OS Sonoma …

Python中回调函数的理解与应用

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站零基础入门的AI学习网站~。 目录 前言 回调函数的概念 回调函数的基本用法 回调函数的实现方式 1 使用函数 2 使用类方法 3 使用类实…

QWidget: Must construct a QApplication before a QWidget 13:25:48: 程序异常结束。

QWidget: Must construct a QApplication before a QWidget 13:25:48: 程序异常结束。 你的插件是release,而你用了debug模式、

一元函数微分学——刷题(19

目录 1.题目:2.解题思路和步骤:3.总结:小结: 1.题目: 2.解题思路和步骤: 题目给的是个复合函数,因此需要根据复合函数求得原本f’(x)的表达式: 3.总结: 题目给的是个…

Java EE改名Jakarta EE,jakarta对程序开发的影响

一、前言 很多Java程序员在使用新版本的Spring6或者springboot3版本的时候,发现了一些叫jakarta的包。我在阅读开源工作流引擎camunda源代码的时候,也发展了大量jakarta的工程包。 比如:camunda的webapps编译工程就提供了2种方式javax和jaka…

linux系统---httpd

目录 Internet的起源 一、http协议——超文本传输协议 1.http相关概念 二、HTTP请求访问的完整过程 1、 建立连接 2、 接收请求 3、 处理请求 常用请求Method: GET、POST、HEAD、PUT、DELETE、TRACE、OPTIONS 3.1 常见的HTTP方法 3.2 GET和POST比较 4、访问资源 …

【YOLO系列算法人员摔倒检测】

YOLO系列算法人员摔倒检测 模型和数据集下载YOLO系列算法的人员摔倒检测数据集可视化数据集图像示例: 模型和数据集下载 yolo行人跌倒检测一: 1、训练好的行人跌倒检测权重以及PR曲线,loss曲线等等,map达90%多,在行人跌…

黑马程序员——接口测试——day02

目录: Postman基础使用 简介和安装案例一:案例二:案例三:接口用例设计 接口测试的测试点 功能测试性能测试安全测试接口用例设计方法 单接口测试业务场景测试单接口测试用例 模版分析测试点 登录添加员工查询员工业务场景测试用例…

操作系统-复试笔记

http://t.csdnimg.cn/PJLWh 操作系统基础 什么是操作系统? 操作系统(Operating System,简称 OS)是管理计算机硬件与软件资源的程序,是计算机的基石。操作系统本质上是一个运行在计算机上的软件程序 ,用于…

【问答】接地原则

以前谈到电源去耦,我警告过糟糕的去耦会增加放大器的失真。一位读者问了一个有趣的问题,去耦电容的接地脚应该在哪里接地才能消除这个问题呢? 这个问题升级到关于正确接地的技术。题目太大了,不过我也许能够提供一些启发性的例子…

doxygen 注释->API接口文档

1.doxygen官方下载 2.工程根目录下(如qt工程.pro同级目录) #doxygen file_name 用以注释生成接口文档的输入(生成file_name文件) doxygen安装目录\bin\doxygen -g doxygen_pro #生成API接口文档,该步骤会生成html目录与latex目录 doxygen安装目录\bin\doxygen axis…

森歌深化体育营销战略,揭晓2024奥运新代言人,携手共创影响力奇迹

2024年,奥运龙年的春节将将过去,各大高端品牌便纷纷开始激烈博弈。森歌有备而来!布局早,积累深,以其深入骨髓的体育情怀和独具匠心的品牌策略,成为厨电行业的佼佼者。2月27日-2月28日,森歌将在杭…

MyBatis进阶

目录 一、实现多表查询 二、#{}和${} 1、#{}和${}的使用 2、#{}和${}的区别 3、${}的使用场景 三、数据库连接池 1、数据库连接池概念 2、常见数据库连接池 3、修改连接池为Hikari 四、动态sql语句--xml 1、if标签 2、tirm标签 3、where标签 4、set标签 5、fore…

Docker复习笔记

Centos7安装Docker Docker官网:www.docker.com Docker官网仓库:hub.docker.com Docker文档是比较详细的 安装相关依赖 yum -y install gcc gcc-c yum install -y yum-utils 设置docker镜像仓库 yum-config-manager --add-repo https://download.docker.com/linux/centos/do…

云尚办公-0.0.1

1. 核心技术 基础框架:SpringBoot数据缓存:Redis数据库:MySQL权限控制:SpringSecurity工作流引擎:Activiti前端技术:vue-admin-template Node.js Npm Vue ElementUI Axios微信公众号:公众…

【FreeRTOS】任务创建

参考博客: ESP-IDF FreeRTOS 任务创建分析 - [Genius] - 博客园 (cnblogs.com) 1.什么是任务 1)独立的无法返回的函数称为任务 2)任务是无线循环 3)无返回数据 2.任务的实现过程 1.定义任务栈 裸机程序:统一分配到一…

centos7部署单机项目和自启动

centos7部署单机项目和服务器自启动 1.安装jdk和tomact1.1上传jdk、tomcat安装包1.2解压两个工具包1.3.配置并且测试jdk安装1.4.启动tomcat1.5.防火墙设置1.6配置tomcat自启动 2.安装mysql2.1卸载mariadb,否则安装MySql会出现冲突(先查看后删除再查看)2.2在线下载My…