Pytorch训练时报nan

0. 引言

Pytorch训练时在batch=N时loss为nan。经过断点检查发现在batch=N-1时,网络参数非nan,输出非nan,但梯度为nan,导致网络参数已经全部被更新为nan,遇到这种情况应该如何排查,如何避免?由于导致nan的情况较为繁多,本文给出的不是一个个例的解决方案,而是一种通用的抽象解决方案。

1. 排查

最简单的排查的方式就是检查parameter的参数值:

# model
for name, param in model.named_parameters(recurse=True):
	if not torch.isfinite(param.mean()):
		print(name)

通过该种方法可以打印出网络参数中数值非有限值的参数所在层。

第二种方法是检查parameter的梯度值,该方法需要retain_graph=True (Pytorch默认不保存图结构以节省GPU内存)

# compute loss
loss.backward(retain_graph=True)
# model
for name, param in model.named_parameters(recurse=True):
	if not torch.isfinite(param.grad.mean()):
		print(name)

检查梯度和参数值的方式都是从后往前查(和反向传播的顺序一致),子节点出现问题会导致其根节点必定出现问题,因此优先排查子节点是否是导致nan的原因。

最后提醒一下,如果nan排查成功,别忘了把retain_graph=True给删了,因为这条命令占用额外的GPU内存。

2. 规避

在这里介绍的方法是基于Pytorch 1.13的,Pytorch 2.x的用户也不想要担心,因为本教程中设置的参数在Pytorch 2.x里面已经设为默认参数,完全兼容。

# compute loss
# optimizer, model
clip_grad = 1.0 # maximum value to clip grad_norm
try:
	nn.utils.clip_grad_norm_(model.parameters(), clip_grad, norm_type=2, error_if_nonfinite=True) # 遇到nonfinite的梯度报错
	optimizer.step()
except:
	print("nan detected in grad, skip batch")
	optimizer.zero_grad()  # 所有梯度置0,保证下一个batch的正常训练
	continue  # 跳过这个batch的训练

这个代码的思想就是利用clip_grad_norm_自带的梯度检查功能在反向传播前对model的每个参数梯度进行检查,如若出现梯度异常值,则跳过batch(且不会对网络进行梯度更新)。需要的注意的是,optimizer.zero_grad()除了在本代码中出现,应该在主循环里面也另外有一个,但是此处省略了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/911268.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[免费]SpringBoot+Vue(高校)学籍管理系统【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的SpringBootVue(高校)学籍管理系统,分享下哈。 项目视频演示 【免费】SpringBootVue(高校)学籍管理系统 Java毕业设计_哔哩哔哩_bilibili 项目介绍 对在线学籍管理的流程进行科学整理、归纳和…

<项目代码>YOLOv7 草莓叶片病害识别<目标检测>

YOLOv7是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv7具有更高的…

Pinia小菠萝(状态管理器)

Pinia 是一个专为 Vue 3 设计的状态管理库,它借鉴了 Vuex 的一些概念,但更加轻量灵活。下面将详细介绍如何使用 Pinia 状态管理库: 安装 Pinia 使用 npm:在项目目录下运行npm install pinia。使用 yarn:在项目目录下运…

【深度学习】多分类任务评估指标sklearn和torchmetrics对比

【深度学习】多分类任务评估指标sklearn和torchmetrics对比 说明sklearn代码torchmetrics代码两个MultiClassReport类的对比分析1. 代码结构与实现方式2. 数据处理与内存使用3. 性能与效率 二分类任务评估指标1. 准确率(Accuracy)2. 精确率(P…

[CUDA] 设置sync模式cudaSetDeviceFlags

文章目录 1. 设置cuda synchronize的等待模式2 设置函数3. streamQuery方式实现stream sync等待逻辑Reference 1. 设置cuda synchronize的等待模式 参考资料:https://docs.nvidia.com/cuda/pdf/CUDA_Runtime_API.pdf cuda的 synchronize等待模式分为: Y…

jdk安装升级到jdk17

百度安全验证 有些项目编译不过 找不到类 ,实际有,需要升级jdk到17 https://blog.csdn.net/qq_44866828/article/details/130557027 sudo apt-get update sudo apt-get install openjdk-17-jdk 然后修改一下配置路径 也就是环境变量 11 改成17 重新…

cuda、pytorch-gpu安装踩坑!!!

前提:已经安装了acanoda cuda11.6下载 直接搜索cuda11.6 acanoda操作 python版本3.9 conda create -n pytorch python3.9conda activate pytorch安装Pytorch-gpu版本等包 要使用pip安装,cu116cuda11.6版本 pip install torch1.13.1cu116 torchvi…

H.265流媒体播放器EasyPlayer.js网页web无插件播放器:如何优化加载速度

在当今的网络环境中,用户对于视频播放体验的要求越来越高,尤其是对于视频加载速度的期待。EasyPlayer.js网页web无插件播放器作为一款专为现代Web环境设计的流媒体播放器,它在优化加载速度方面采取了多种措施,以确保用户能够享受到…

C语言 | Leetcode C语言题解之第542题01矩阵

题目: 题解: /*** Return an array of arrays of size *returnSize.* The sizes of the arrays are returned as *returnColumnSizes array.* Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().*/ type…

Transformer究竟是什么?预训练又指什么?BERT

目录 Transformer究竟是什么? 预训练又指什么? BERT的影响力 Transformer究竟是什么? Transformer是一种基于自注意力机制(Self-Attention Mechanism)的神经网络架构,它最初是为解决机器翻译等序列到序列(Seq2Seq)任务而设计的。与传统的循环神经网络(RNN)或卷…

【春秋云镜】CVE-2023-23752

目录 CVE-2023-23752漏洞细节漏洞利用示例修复建议 春秋云镜:解法一:解法二: CVE-2023-23752 是一个影响 Joomla CMS 的未授权路径遍历漏洞。该漏洞出现在 Joomla 4.0.0 至 4.2.7 版本中,允许未经认证的远程攻击者通过特定 API 端…

51单片机教程(七)- 蜂鸣器

1 项目分析 利用P2.3引脚输出电平变化,控制蜂鸣器的鸣叫。 2 技术准备 1 蜂鸣器介绍 有绿色电路板的一种是无源蜂鸣器,没有电路板而用黑胶封闭的一种是有源蜂鸣器。 有源蜂鸣器和无源蜂鸣器 这里的“源”不是指电源。而是指震荡源。也就是说有源蜂鸣…

十六 MyBatis使用PageHelper

十六、MyBatis使用PageHelper 16.1 limit分页 mysql的limit后面两个数字: 第一个数字:startIndex(起始下标。下标从0开始。)第二个数字:pageSize(每页显示的记录条数) 假设已知页码pageNum&…

汽车和飞机研制过程中“骡车”和“铁鸟”

在汽车和飞机的研制过程中,“骡车”和“铁鸟”都扮演着至关重要的角色。 “骡车”在汽车研制中,是一种处于原型车和量产车之间的过渡阶段产物。它通常由不同的零部件组合而成,就像骡子是马和驴的杂交后代一样,取各家之长。“骡车…

MySQL存储目录与配置文件(ubunto下)

mysql的配置文件: 在这个目录下,直接cd /etc/mysql/mysql.conf.d mysql的储存目录: /var/lib/mysql Ubuntu版本号:

RibbitMQ-安装

本文主要介绍RibbitMQ的安装 RabbitMQ依赖于Erlang,因此首先需要安装Erlang环境。分别下载erlang-26.2.5-1.el7.x86_64.rpm、rabbitmq-server-4.0.3-1.el8.noarch.rpm 官网地址:https://www.rabbitmq.com/ 官网文档:https://www.rabbitmq.c…

【Linux】解锁操作系统潜能,高效线程管理的实战技巧

目录 1. 线程的概念2. 线程的理解3. 地址空间和页表4. 线程的控制4.1. POSIX线程库4.2 线程创建 — pthread_create4.3. 获取线程ID — pthread_self4.4. 线程终止4.5. 线程等待 — pthread_join4.6. 线程分离 — pthread_detach 5. 线程的特点5.1. 优点5.2. 缺点5.3. 线程异常…

WPF+MVVM案例实战(二十二)- 制作一个侧边弹窗栏(CD类)

文章目录 1、案例效果1、侧边栏分类2、CD类侧边弹窗实现1、样式代码实现2、功能代码实现3 运行效果4、源代码获取1、案例效果 1、侧边栏分类 A类 :左侧弹出侧边栏B类 :右侧弹出侧边栏C类 :顶部弹出侧边栏D类 :底部弹出侧边栏2、CD类侧边弹窗实现 1、样式代码实现 在原有的…

如何对LabVIEW软件进行性能评估?

对LabVIEW软件进行性能评估,可以从以下几个方面着手,通过定量与定性分析,全面了解软件在实际应用中的表现。这些评估方法适用于确保LabVIEW程序的运行效率、稳定性和可维护性。 一、响应时间和执行效率 时间戳测量:使用LabVIEW的时…

stm32使用串口DMA实现数据的收发

前言 DMA的作用就是帮助CPU来传输数据,从而使CPU去完成更重要的任务,不浪费CPU的时间。 一、配置stm32cubeMX 这两个全添加上。参数配置一般默认即可 代码部分 只需要把上期文章里的HAL_UART_Transmit_IT(&huart2,DATE,2); 全都改为HAL_UART_Tra…