pytorch笔记:自动混合精度(AMP)

1 理论部分

1.1 FP16 VS FP32

  • FP32具有八个指数位和23个小数位,而FP16具有五个指数位和十个小数位
  • Tensor内核支持混合精度数学,即输入为半精度(FP16),输出为全精度(FP32)

1.1.1 使用FP16的优缺点

  • 优点
    • FP16需要较少的内存,因此更易于训练和部署大型神经网络,同时还减少了数据移动(同时可以使用更大的batch)
    • 数学运算的运行速度大大降低了
      • NVIDIA提供的Volta GPU的确切数量是:FP16中为125 TFlops,而FP32中为15.7 TFlops(加速8倍)
  • 缺点:
    • 从FP32转到FP16时,必然会降低精度
      • 但有的时候,这个精度的降低可以忽略不计
      • FP16实际上可以很好地表示大多数权重和渐变。
      • ——>拥有存储和使用FP32所需的所有这些额外位只是浪费。
    • 溢出错误
      • 由于FP16的动态范围比FP32位的狭窄很多,因此,在计算过程中很容易出现上溢出和下溢出
      • 溢出之后就会出现"NaN"的问题

1.2 解决上述FP16的问题

1.2.1 混合精度训练

  • 用FP16做储存和乘法,而用FP32做累加避免舍入误差
  • ——>混合精度训练的策略有效地缓解了舍入误差的问题

1.2.2 损失放大(Loss scaling)

  • 即使使用了混合精度训练,还是存在无法收敛的情况
    • 原因是激活梯度的值太小,造成了溢出。
  • ——>通过使用torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的下溢出
    • 只在BP时传递梯度信息使用,真正更新权重时还是要把放大的梯度再unscale回去
      • 反向传播前,将损失变化手动增大2^k倍

        • 因此反向传播时得到的中间变量(激活函数梯度)不会溢出;

      • 反向传播后,将权重梯度缩小2^k倍,恢复正常值。

2 torch.cuda.amp

  • AMP(自动混合精度)的关键词有两个:
    • 自动
      • Tensor的dtype类型会自动变化,框架按需自动调整tensor的dtype,当然有些地方还需手动干预
    • 混合精度
      • 采用不止一种精度的Tensor,torch.FloatTensor和torch.HalfTensor

2.1 Pytorch中不同类型的tensor

类型名称位数
torch.DoubleTensor64bit
torch.LongTensor64bit
torch.FloatTensor(默认)32bit
torch.IntTensor32bit
torch.HalfTensor16bit
torch.BFloat16Tensor16bit
torch.ShortTensor16bit
torch.ByteTensor(无符号)8bit
torch.CharTensor8bit
torch.BoolTensorBoolean

2.2 在AMP上下文中,被自动转化为半精度浮点型的参数:

__matmul__
addbmm
addmm
addmv
addr
baddbmm
bmm
chain_matmul
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
linear
matmul
mm
mv
prelu

2.3 autocast

from torch.cuda.amp import autocast as autocast


model = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensor

optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    '''
    自动混合精度环境

    包含了前向过程(模型的输出)和loss的计算

    把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算

    进入autocast的上下文时,tensor可以是任何类型
        不需要在model或者input上手工调用.half() ,框架会自动做
    '''

    
    loss.backward()
    optimizer.step()
    # 反向传播在autocast上下文之外

 2.4 GradScaler

在2.3的基础上增加,反向传播时增加梯度,以防止下溢出

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler


model = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensor

optimizer = optim.SGD(model.parameters(), ...)


scaler = GradScaler()
# 在训练最开始之前实例化一个GradScaler对象

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()


        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        '''
        自动混合精度环境

        包含了前向过程(模型的输出)和loss的计算

        把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算

        进入autocast的上下文时,tensor可以是任何类型
            不需要在model或者input上手工调用.half() ,框架会自动做
        '''

        
        scaler.scale(loss).backward()
        # Scales loss. 为了梯度放大,防止下溢出
        # 代替原来的loss.backward()
        
        scaler.step(optimizer)
        '''
        scaler.step() 首先把梯度的值unscale回来.
        
        如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,
            否则,忽略step调用,从而保证权重不更新(不被破坏)
        '''

        
        scaler.update()
        '''
        准备着,看是否要增大scaler

        '''
  •  scaler的大小在每次迭代中动态的估计
    • 为了尽可能的减少梯度underflow,scaler应该更大
    • 但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。
  • ——>动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值

3 一些tips

  • 为了保证计算不溢出,首先保证人工设定的常数不溢出。如epsilon,INF等
  • Dimension最好是8的倍数:维度是8的倍数,性能最好
  • 涉及sum的操作要小心,容易溢出
    • 比如softmax操作,建议用官方API,并定义成layer写在模型初始化里
  • 如果遇到以下的报错:
    • RuntimeError: expected scalar type float but found c10::Half
    • 需要手动在tensor上调用.float()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/674695.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大渡口数字经济产业商会暨尼伽OLED透明屏产品发布会

2024年5月31日,大渡口数字经济产业商会成功举办了一场盛大的“商会数字经济发展项目签约大会”,活动上不仅深入探讨了构建“义渡新质生产力”及如何更好地“建功重庆西部大开发”的战略议题,还正式与尼伽OLED宣布达成战略合作伙伴关系&#x…

Java版工程项目管理系统源码:技术框架与功能实现全解析

在工程行业,项目管理的高效协同和信息共享是提升管理效率和精度的关键。本文将详细介绍一款采用先进技术框架的Java版工程项目管理系统,该系统支持前后端分离,功能全面,可满足不同角色的需求。从项目进度图表到施工地图&#xff0…

10个从基础到高级的GPT提示词优化指南

为一名大模型的深度用户和微软Copilot的首批开放测试者,很多人会问我如何写出高效的提示词。同时,也有不少读者反映,像ChatGPT和Claude这样的模型并没有想象中那么神奇,无法满足他们的实际需求。 首先我想说,确实像Ch…

双指针_复写零

复写零 题目描述: 题目链接:复写零 内容: 这道题目要求我们每遇到一次0就复写一遍,并且只能在原数组上进行修改,不能越界访问。 算法原理: 思路1: 如果我们用两个指针cur,dest同时从指向第一个…

c# 输出二进制字符串

参考链接 C#二进制输出数据_c# 输出二进制 123.5的方法-CSDN博客https://blog.csdn.net/a497785609/article/details/4572112标准数字格式字符串 - .NET | Microsoft Learnhttps://learn.microsoft.com/zh-cn/dotnet/standard/base-types/standard-numeric-format-strings#BFo…

工业HMI设计,稳定压倒一切,那高颜值就不稳定了吗?

提及工业HMI设计,很多小伙伴就跳出来说,工业 HMI稳定是最重要的,颜值没比必要,花里呼哨的.我承认稳定的重要性,但是稳定与颜值并不是一对矛盾体。本文就分享为什么工业HMI稳定性重要?为什么高颜值也重要&am…

QT:QML中使用Loader加载界面

目录 一.介绍 二.实现 三.效果展示 四.代码 一.介绍 在QML中使用Loader加载界面,可以带来诸多好处,如提高应用程序的启动速度、动态地改变界面内容、根据条件加载不同的组件、更有效地使用内存以及帮助分割应用逻辑等。 1.延迟加载:QML…

动态规划2:面试题 08.01. 三步问题

动态规划解题步骤: 1.确定状态表示:dp[i]是什么 2.确定状态转移方程:dp[i]等于什么 3.初始化:确保状态转移方程不越界 4.确定填表顺序:根据状态转移方程即可确定填表顺序 5.确定返回值 题目链接:面试…

【WEB前端2024】开源智体世界:乔布斯3D纪念馆-第34课-进门播放欢迎光临的音效

【WEB前端2024】开源智体世界:乔布斯3D纪念馆-第34课-进门播放欢迎光临的音效 使用dtns.network德塔世界(开源的智体世界引擎),策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智…

R语言绘图 --- 柱状图(Biorplot 开发日志 --- 3)

「写在前面」 在科研数据分析中我们会重复地绘制一些图形,如果代码管理不当经常就会忘记之前绘图的代码。于是我计划开发一个 R 包(Biorplot),用来管理自己 R 语言绘图的代码。本系列文章用于记录 Biorplot 包开发日志。 相关链接…

GITLAB常见问题总结

Troubleshooting GitLab Pages administration (FREE SELF) 原文地址 stage: Plan group: Knowledge info: To determine the technical writer assigned to the Stage/Group associated with this page, see https://about.gitlab.com/handbook/product/ux/technical-writing/…

MWORKS车辆动力性经济性与热管理联合应用篇

一、引言 随着科技的飞速发展、环保意识的日益增强以及国家政策的大力支持,新能源汽车已经不再是遥不可及的未来科技,而是逐步走进千家万户,成为我们日常生活中不可或缺的一部分。然而,每到冬季的来临,纯电动汽车面临…

Linux shell编程学习笔记56:date命令——显示或设置系统时间与日期

0 前言 2024年的网络安全检查又开始了,对于使用基于Linux的国产电脑,我们可以编写一个脚本来收集系统的有关信息。在收集的信息中,应该有一条是搜索信息的时间。 1. date命令 的功能、格式和选项说明 我们可以使用命令 date --help 来查看 d…

巧用Jmeter Debug sampler获取变量信息

Jmeter Debug sampler介绍 Jmeter Debug sampler 可以帮助我们解决如下问题: debug参数化的变量取值是否正确 debug正则表达式提取器(或json提取器)提取的值是否正确 查看 JMeter 属性 具体使用方法 前提条件:添加查看结果树…

【Python】【matLab】模拟退火算法求二元高次函数最小值

一、目标函数 求二元高次函数的最小值。目标函数选择: 用于测试算法的简单的目标函数: 二、Python代码实现 import numpy as np# 目标函数(2变量) def objective_function(x):return x[0] ** 2 2 * x[0] - 15 4 * 4 * 2 * x[…

【开发心得】三步本地化部署llama3大模型

目录 第一步:启动ollama 第二步:启动dify 第三步:配置模型(截图) 最近llama3很火,本文追击热点,做一个本地化部署的尝试,结果还成功了! 当然也是站在别人的肩膀上&…

DevOps中如何高效开展手工和自动化测试

在快速发展的软件开发行业中,DevOps实践已经成为提高软件交付速度和质量的关键。DevOps是一种文化和实践的集合,旨在促进开发(Dev)和运维(Ops)团队之间的协作和通信。测试作为DevOps生命周期中的重要组成部…

安装打开 ubuntu-22.04.3-LTS 报错 解决方案

安装打开 ubuntu-22.04.3-LTS 报错 解决方案 WslRegisterDistribution failed with error: 0x800701bc Error: 0x800701bc WSL 2 ??? https://aka.ms/wsl2kernel 1、确保【windows 功能】打开了【虚拟机】。 键盘上按 WIN R 打开【运行】,输入 【 control 】&…

树莓派4B 学习笔记2:GPIO介绍_第一个Python程序_点灯

今日开始学习树莓派4B 4G:(Raspberry Pi,简称RPi或RasPi) GPIO介绍_第一个Python程序_Python点灯 文章提供测试代码讲解、完整代码贴出、测试效果图 目录 树莓派4B 引脚与外设图: 树莓派常用命令: 第一个…

今日好料推荐(ARM嵌入式)

今日好料推荐(ARM嵌入式) 参考资料在文末获取,关注我,获取优质资源。 给我留言,会帮大家寻找需要的资料。 ARM 嵌入式系统 嵌入式系统在现代电子设备中扮演着至关重要的角色,从智能手机到工业自动化&am…