深度学习中,用损失的均值或者总和反向传播的区别

如深度学习中代码:

def train_epoch_ch3(net, train_iter, loss, updater):
    """The training loop defined in Chapter 3."""
    # Set the model to training mode
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, sum of training accuracy, no. of examples
    metric = Accumulator(3)
    for X, y in train_iter:
        # Compute gradients and update parameters
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # Using PyTorch in-built optimizer & loss criterion
            updater.zero_grad()
            l.backward()
            updater.step()
            metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
        else:
            # Using custom built optimizer & loss criterion
            l.sum().backward()
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # Return training loss and training accuracy
    return metric[0] / metric[2], metric[1] / metric[2]

对于循环for X, y in train_iter:每一次迭代都是一次小批量训练 

当用torch里的优化器来更新参数时,第一个if语句执行:

1.此时由于torch默认累计梯度,所以每次循环都得梯度置零

2.然后torch进行反向传播(如果是使用了torch定义的loss,那么得到的l是标量,即平均损失,次此时直接反向传播即可)。

loss = nn.CrossEntropyLoss(reduction='mean'),reduction='mean'意味着损失会返回均值,reduction理解成降维

loss = nn.CrossEntropyLoss(reduction='None'),reduction='None'意味着损失会返回原来的形式,即矢量

为什么不用总和形式呢?均值会好一点,因为每次训练有不同的batch_size(如果我把样本分成三个批量,一次训练中,首先用批量一训练,然后批量二,批量三,然后批量都训练完后,再进行新的一轮训练,直到收敛),用均值的话会默认使用l.sum()/batch_size,batch_size由内部求出,这个不用自己写,就实现了batch_size的解耦,不容易出错,更新参数时形式就变成了w_i=w_i-\alpha *loss,因为此时的损失是均值了

3.进行参数更新

4.统计损失与精度,为什么要用float(l)*len(y)呢

 因为l=loss(y_hat,y),这里的损失函数是从外部传进来的,所以如果传进来的是torch自己定义的损失函数,应该计算的是平均损失,那么要变成整体损失就得float(l)*len(y)。

如果loss是自己定义的,可能只是计算了损失,没求平均(看上面的函数,此时的loss应该是矢量,所以才会有l.sum())

当自己实现优化器来更新参数时,else语句执行:

1.此时自己可能没实现累计梯度,所以不用梯度清零

2.反向传播,但由于此时自己实现了loss函数,可能loss是矢量,要转成标量来反向传播,一个矢量转标量的好方法就是求和,即l.sum().backward(),

3.更新参数,注意到updater传入了一个参数X.shape[0],即样本数量batch_size,外边要注意更新参数时为w_i =w_i-\alpha /batchSize*loss,因为此时的l是总和

4.统计损失与精度

sum和mean其实主要影响了“梯度的大小”,反向传播时,依据损失求梯度,如果是sum,则梯度会比mean大n倍,那么在学习率不变的情况下,步子会迈得很长,体现到图形上就是正确率提升不了。所以需要缩小学习率。

l.mean().backward()和l.sum().backward()的区别在于它们计算梯度的方式。l.mean().backward()计算的是平均损失对权重的梯度,而l.sum().backward()计算的是总损失对权重的梯度。

在实践中,这两种方式通常会得到相似的结果,因为它们都是在尝试最小化损失。然而,使用l.mean().backward()可能会使得梯度的大小更稳定,因为它不会因为批量大小的变化而变化。这可能会使得训练过程更稳定,特别是在批量大小可能变化的情况下。

另一方面,使用l.sum().backward()可能会使得梯度的大小更大,这可能会导致训练过程更快,但也可能导致训练过程更不稳定。

总的来说,哪种方式更好取决于你的具体情况。如果你的批量大小是固定的,那么你可能会发现l.sum().backward()和l.mean().backward()在实践中没有太大的区别。如果你的批量大小可能变化,那么你可能会发现l.mean().backward()在实践中更稳定

补:总训练函数:

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train a model (defined in Chapter 3)."""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/941686.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

K8S Ingress 服务配置步骤说明

部署Pod服务 分别使用kubectl run和kubectl apply 部署nginx和tomcat服务 # 快速启动一个nginx服务 kubectl run my-nginx --imagenginx --port80# 使用yaml创建tomcat服务 kubectl apply -f my-tomcat.yamlmy-tomcat.yaml apiVersion: apps/v1 kind: Deployment metadata:n…

基于DockerCompose搭建Redis主从哨兵模式

linux目录结构 内网配置 哨兵配置文件如下&#xff0c;创建3个哨兵配置文件 # sentinel26379.conf sentinel26380.conf sentinel26381.conf 内容如下 protected-mode no sentinel monitor mymaster redis-master 6379 2 sentinel down-after-milliseconds mymaster 60000 s…

将java项目部署到linux

命令解析 Dockerfile: Dockerfile 是一个文本文件&#xff0c;包含了所有必要的指令来组装&#xff08;build&#xff09;一个 Docker 镜像。 docker build: 根据 Dockerfile 或标准指令来构建一个新的镜像。 docker save: 将本地镜像保存为一个 tar 文件。 docker load: 从…

SQL server学习08-使用索引和视图优化查询

目录 一&#xff0c;创建和管理索引 1&#xff0c;索引的概念 2&#xff0c;索引的分类 3&#xff0c;创建索引的原则 4&#xff0c;创建索引 1&#xff09;使用SSMS的图形化界面 2&#xff09; 使用T-SQL 二&#xff0c;创建和使用视图 1&#xff0c;视图概念 2&…

ffmpeg翻页转场动效的安装及使用

文章目录 前言一、背景二、选型分析2.1 ffmpeg自带的xfade滤镜2.2 ffmpeg使用GL Transition库2.3 xfade-easing项目三、安装3.1、安装依赖([参考](https://trac.ffmpeg.org/wiki/CompilationGuide/macOS#InstallingdependencieswithHomebrew))3.2、获取ffmpeg源码3.3、融合xf…

EasyPoi 使用$fe:模板语法生成Word动态行

1 Maven 依赖 <dependency><groupId>cn.afterturn</groupId><artifactId>easypoi-spring-boot-starter</artifactId><version>4.0.0</version> </dependency> 2 application.yml spring:main:allow-bean-definition-over…

C++----类与对象(下篇)

再谈构造函数 回顾函数体内赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值。 class Date{ public: Date(int year, int month, int day) { _year year; _month month; _day day; } private: int _year; int _mo…

delve调试环境搭建—golang

原文地址&#xff1a;delve调试环境搭建—golang – 无敌牛 欢迎参观我的个人博客&#xff1a;无敌牛 – 技术/著作/典籍/分享等 由于平时不用 IDE 开发环境&#xff0c;习惯在 linux终端vim 环境下开发&#xff0c;所以找了golang的调试工具&#xff0c;delve类似gdb的调试界…

Oracle安装报错:将配置数据上载到资料档案库时出错

环境&#xff1a;联想服务器 windows2022安装Oracle11g 结论&#xff1a;禁用多余网卡先试试&#xff0c;谢谢。 以下是问题描述和处理过程&#xff1a; 网上处理方式: hosts文件添加如下&#xff1a; 关闭防火墙 暂时无法测试通过。 发现ping不是本地状态&#xff0c;而是…

数据结构:栈(顺序栈)

目录 1.栈的定义 2.栈的结构 3.栈的接口 3.1初始化 3.2栈的销毁 3.3压栈 3.4判断栈是否为空 3.5出栈 3.6得到栈顶元素 3.7栈的大小 1.栈的定义 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端…

LightGBM分类算法在医疗数据挖掘中的深度探索与应用创新(上)

一、引言 1.1 医疗数据挖掘的重要性与挑战 在当今数字化医疗时代,医疗数据呈爆炸式增长,这些数据蕴含着丰富的信息,对医疗决策具有极为重要的意义。通过对医疗数据的深入挖掘,可以发现潜在的疾病模式、治疗效果关联以及患者的健康风险因素,从而为精准医疗、个性化治疗方…

【WPS安装】WPS编译错误总结:WPS编译失败+仅编译成功ungrib等

WPS编译错误总结&#xff1a;WPS编译失败仅编译成功ungrib等 WPS编译过程问题1&#xff1a;WPS编译失败错误1&#xff1a;gfortran: error: unrecognized command-line option ‘-convert’; did you mean ‘-fconvert’?解决方案 问题2&#xff1a;WPS编译三个exe文件只出现u…

深入理解Redis

1.数据结构类型 数据结构-SDS-简单动态字符串 Redis构建了一种新字符串结构,称为简单动态字符串(Simple Dynamic String),简称SDS。 Redis未直接使用C语言的字符串,如:char* s = "hello",本质是字符数组: {h, e, l, l, o, \0}。因为C语言字符串存在很多问题…

前端开发 之 12个鼠标交互特效上【附完整源码】

前端开发 之 12个鼠标交互特效上【附完整源码】 文章目录 前端开发 之 12个鼠标交互特效上【附完整源码】一&#xff1a;彩色空心爱心滑动特效1.效果展示2.HTML完整代码 二&#xff1a;彩色实心爱心滑动特效1.效果展示2.HTML完整代码 三&#xff1a;粒子连结特效1.效果展示2.HT…

解析mysqlbinlog

一、前置设置 ps -ef | grep mysql 查看mysql进程对应的安装目录 需设置mysql binlog日志模式为 ROW 二、执行命令 [rootlocalhost bin]# mysqlbinlog --verbose --base64-outputdecode-rows /usr/local/mysql/data/binlog.000069 > 1.sql 查看文件具体内容

理解神经网络

神经网络是一种模拟人类大脑工作方式的计算模型&#xff0c;是深度学习和机器学习领域的基础。 基本原理 神经网络的基本原理是模拟人脑神经系统的功能&#xff0c;通过多个节点&#xff08;也叫神经元&#xff09;的连接和计算&#xff0c;实现非线性模型的组合和输出。每个…

基于Vue.js和SpringBoot的笔记记录分享网站的设计与实现(文末附源码)

博主介绍&#xff1a;✌全网粉丝50W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流✌ 技术范围&#xff1a;SpringBoot、Vue、SSM、HLM…

信息安全管理与评估赛题第9套

全国职业院校技能大赛 高等职业教育组 信息安全管理与评估 赛题九 模块一 网络平台搭建与设备安全防护 1 赛项时间 共计180分钟。 2 赛项信息 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第一阶段 网络平台搭建与设备安全防护 任务1 网络平台搭建 XX:XX- XX:XX 50 任务2…

怎么在idea中创建springboot项目

最近想系统学习下springboot&#xff0c;尝试一下全栈路线 从零开始&#xff0c;下面将叙述下如何创建项目 环境 首先确保自己环境没问题 jdkMavenidea 创建springboot项目 1.打开idea&#xff0c;选择file->New->Project 2.选择Spring Initializr->设置JDK->…

【计算机视觉基础CV-图像分类】05 - 深入解析ResNet与GoogLeNet:从基础理论到实际应用

引言 在上一篇文章中&#xff0c;我们详细介绍了ResNet与GoogLeNet的网络结构、设计理念及其在图像分类中的应用。本文将继续深入探讨如何在实际项目中应用这些模型&#xff0c;特别是如何保存训练好的模型、加载模型以及使用模型进行新图像的预测。通过这些步骤&#xff0c;读…