AMP 混合精度训练中的动态缩放机制: grad_scaler.py函数解析( torch._amp_update_scale_)

AMP 混合精度训练中的动态缩放机制

在深度学习中,混合精度训练(AMP, Automatic Mixed Precision)是一种常用的技术,它利用半精度浮点(FP16)计算来加速训练,同时使用单精度浮点(FP32)来保持数值稳定性。为了在混合精度训练中避免数值溢出,PyTorch 提供了一种动态缩放机制来调整 “loss scale”(损失缩放值)。本文将详细解析动态缩放机制的实现原理,并通过代码展示其内部逻辑。


动态缩放机制简介

动态缩放机制的核心思想是通过一个可动态调整的缩放因子(scale factor)放大 FP16 的梯度,从而降低舍入误差对训练的影响。当检测到数值不稳定(例如 NaN 或无穷大)时,缩放因子会被降低;当连续多步未检测到数值问题时,缩放因子会被提高。其调整策略基于以下两个参数:

  • growth_factor: 连续成功步骤后用于增加缩放因子的乘数(通常大于 1,如 2.0)。
  • backoff_factor: 检测到数值溢出时用于减少缩放因子的乘数(通常小于 1,如 0.5)。

此外,动态缩放还使用 growth_interval 参数控制连续成功步骤的计数阈值。当达到这个阈值时,缩放因子才会增加。


AMP 缩放更新核心代码解析

PyTorch 实现了一个用于更新缩放因子的 CUDA 核函数以及相关的 Python 包装函数。以下是核心代码解析:

CUDA 核函数实现

// amp_update_scale_cuda_kernel 核函数实现
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval) {
  if (*found_inf) {
    // 如果发现梯度中存在 NaN 或 Inf,缩放因子乘以 backoff_factor,并重置 growth_tracker。
    *current_scale = (*current_scale) * backoff_factor;
    *growth_tracker = 0;
  } else {
    // 未发现数值问题,增加 growth_tracker 的计数。
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      // 当 growth_tracker 达到 growth_interval,尝试增长缩放因子。
      auto new_scale = static_cast<float>((*current_scale) * growth_factor);
      if (isfinite_ensure_cuda_math(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}
核函数逻辑
  1. 发现数值溢出(found_inf > 0):

    • 缩放因子 current_scale 乘以 backoff_factor
    • 重置成功计数器 growth_tracker 为 0。
  2. 未发现数值溢出:

    • 增加成功计数器 growth_tracker
    • 如果 growth_tracker 达到 growth_interval,则将缩放因子乘以 growth_factor
    • 保证缩放因子不会超过 FP32 的数值上限。

C++ 包装函数实现

在 PyTorch 中,这一 CUDA 核函数通过 C++ 包装函数 _amp_update_scale_cuda_ 被调用。以下是实现代码:

Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

Python 调用入口

AMP 的 GradScaler 类通过 _amp_update_scale_ 函数更新缩放因子,以下是相关代码:
代码来源:anaconda3/envs/xxxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

具体调用过程可以参考笔者的另一篇博文:PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
    """更新缩放因子"""
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # 设置用户定义的新缩放因子。
        self._scale.fill_(new_scale)
    else:
        # 收集所有优化器中的 found_inf 数据。
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        # 更新缩放因子。
        torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

总结

PyTorch 的动态缩放机制通过 CUDA 核函数和 Python 包装函数协作完成。其核心逻辑是:

  1. 检测数值不稳定(如 NaN 或 Inf),通过缩小缩放因子提高数值稳定性。
  2. 当连续多次未出现数值不稳定时,逐步增大缩放因子以充分利用 FP16 的动态范围。
  3. 所有更新操作都在 GPU 上异步完成,最大限度地减少同步开销。

通过动态调整缩放因子,AMP 有效地加速了深度学习模型的训练,同时避免了梯度溢出等数值问题。


推荐阅读

  • PyTorch 官方文档
  • 混合精度训练介绍

后记

2025年1月2日15点38分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/947474.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微博_14.12.2-内置猪手 会员版

微博猪手是一款作用于微博的 XposedLsposed 模块&#xff0c;可以支持未root用户和已root用户使用。进入【我的】页面&#xff0c;点击【右上角的设置】&#xff0c;点击【微博猪手】即可进一步设置其他功能。通过微博猪手模块可以实现去除各种广告&#xff08;开屏、信息流等&…

计算机网络 (21)网络层的几个重要概念

前言 计算机网络中的网络层是OSI&#xff08;开放系统互连&#xff09;模型中的第三层&#xff0c;也是TCP/IP模型中的第二层&#xff0c;它位于数据链路层和传输层之间&#xff0c;负责数据包从源主机到目的主机的路径选择和数据转发。 一、网络层的主要功能 路由选择&#xf…

openwrt nginx UCI配置过程

openwrt 中nginx有2种配置方法&#xff0c;uci nginx uci /etc/config/nginx 如下&#xff1a; option uci_enable true‘ 如果是true就是使用UCI配置&#xff0c;如果 是false&#xff0c;就要使用/etc/nginx/nginx.conf&#xff0c;一般不要修改。 如果用UCI&#xff0c;其…

【深度学习进阶】基于CNN的猫狗图片分类项目

介绍 基于卷积神经网络&#xff08;CNN&#xff09;的猫狗图片分类项目是机器学习领域中的一种常见任务&#xff0c;它涉及图像处理和深度学习技术。以下是该项目的技术点和流程介绍&#xff1a; 技术点 卷积神经网络 (CNN): CNN 是一种专门用于处理具有类似网格结构的数据的…

uni-app 页面生命周期及组件生命周期汇总(Vue2、Vue3)

文章目录 一、前言&#x1f343;二、页面生命周期三、Vue2 页面及组件生命周期流程图四、Vue3 页面及组件生命周期流程图4.1 页面加载时序介绍4.2 页面加载常见问题4.3 onShow 和 onHide4.4 onInit4.5 onLoad4.6 onReachBottom4.7 onPageScroll4.8 onBackPress4.9 onTabItemTap…

缓存淘汰算法:次数除以时间差

记录缓存中的每一项的访问次数、最后访问时间&#xff0c;获取当前时间&#xff0c;可算出时间差&#xff0c;然后&#xff0c;用次数除以时间差&#xff0c;取最小的淘汰。 这一算法比较慢&#xff0c;需配合多级缓存。一级缓存不很大&#xff0c;使用此算法。二级缓存可以大…

uniapp 微信小程序开发使用高德地图、腾讯地图

一、高德地图 1.注册高德地图开放平台账号 &#xff08;1&#xff09;创建应用 这个key 第3步骤&#xff0c;配置到项目中locationGps.js 2.下载高德地图微信小程序插件 &#xff08;1&#xff09;下载地址 高德地图API | 微信小程序插件 &#xff08;2&#xff09;引入项目…

Mac iTerm2集成DeepSeek AI

1. 去deepseek官网申请api key&#xff0c;DeepSeek 2. 安装iTerm2 AI Plugin插件&#xff0c;https://iterm2.com/ai-plugin.html&#xff0c;插件解压后直接放到和iTerms相同的位置&#xff0c;默认就在/Applications 下 3. 配置iTerm2 4. 重启iTerm2,使用快捷键呼出AI对话…

树莓派 Pico RP2040 教程点灯 双核编程案例

双核点亮不同的 LED 示例&#xff0c;引脚分别是GP0跟GP1。 #include "pico/stdlib.h" #include "pico/multicore.h"#define LED1 0 // 核心 0 控制的 LED 引脚 #define LED2 1 // 核心 1 控制的 LED 引脚// the setup function runs once when you press …

简单使用linux

1.1 Linux的组成 Linux 内核&#xff1a;内核是系统的核心&#xff0c;是运行程序和管理 像磁盘和打印机等硬件设备的核心程序。 文件系统 : 文件存放在磁盘等存储设备上的组织方法。 Linux 能支持多种目前浒的文件系统&#xff0c;如 ext4 、 FAT 、 VFAT 、 ISO9660 、 NF…

ACM算法模板

ACM算法模板 起手式基础算法前缀和与差分二分查找三分查找求极值分治法&#xff1a;归并排序 动态规划基本线性 d p dp dp最长上升子序列I O ( n 2 ) O(n ^ 2) O(n2)最长上升子序列II O ( n l o g n ) O(nlogn) O(nlogn) 贪心二分最长公共子序列 背包背包求组合种类背包求排列…

《Vue3实战教程》19:Vue3组件 v-model

如果您有疑问&#xff0c;请观看视频教程《Vue3实战教程》 组件 v-model​ 基本用法​ v-model 可以在组件上使用以实现双向绑定。 从 Vue 3.4 开始&#xff0c;推荐的实现方式是使用 defineModel() 宏&#xff1a; vue <!-- Child.vue --> <script setup> co…

Docker 环境中搭建 Redis 哨兵模式集群的步骤与问题解决

在 Docker 环境中搭建 Redis 哨兵模式集群的步骤与问题解决 在 Redis 高可用架构中&#xff0c;哨兵模式&#xff08;Sentinel&#xff09;是确保 Redis 集群在出现故障时自动切换主节点的一种机制。通过使用 Redis 哨兵&#xff0c;我们可以实现 Redis 集群的监控、故障检测和…

数据结构:时间复杂度和空间复杂度

我们知道代码和代码之间算法的不同&#xff0c;一定影响了代码的执行效率&#xff0c;那么我们该如何评判算法的好坏呢&#xff1f;这就涉及到了我们算法效率的分析了。 &#x1f4d6;一、算法效率 所谓算法效率的分析分为两种&#xff1a;第一种时间效率&#xff0c;又称时间…

《Vue3实战教程》39:Vue3无障碍访问

如果您有疑问&#xff0c;请观看视频教程《Vue3实战教程》 无障碍访问​ Web 无障碍访问 (也称为 a11y) 是指创建可供任何人使用的网站的做法——无论是身患某种障碍、通过慢速的网络连接访问、使用老旧或损坏的硬件&#xff0c;还是仅处于某种不方便的环境。例如&#xff0c;…

GESP2024年6月认证C++五级( 第三部分编程题(1)黑白格)

参考程序&#xff08;二维前缀和&#xff09; #include <iostream> #include <vector> #include <algorithm> using namespace std;int main() {int n, m, k;cin >> n >> m >> k;// 输入网格图vector<vector<int>> grid(n, v…

二、SQL语言,《数据库系统概念》,原书第7版

文章目录 一、概览SQL语言1.1 SQL 语言概述1.1.1 SQL语言的提出和发展1.1.2 SQL 语言的功能概述 1.2 利用SQL语言建立数据库1.2.1 示例1.2.2 SQL-DDL1.2.2.1 CREATE DATABASE1.2.2.2 CREATE TABLE 1.2.3 SQL-DML1.2.3.1 INSERT INTO 1.3 用SQL 语言进行简单查询1.3.1 单表查询 …

js按日期按数量进行倒序排序,然后再新增一个字段,给这个字段赋值 10 到1

效果如下图&#xff1a; 实现思路&#xff1a; 汇总数据&#xff1a;使用 reduce 方法遍历原始数据数组&#xff0c;将相同日期的数据进行合并&#xff0c;并计算每个日期的总和。创建日期映射&#xff1a;创建一个映射 dateMap&#xff0c;存储每个日期的对象列表。排序并添加…

用uniapp写一个播放视频首页页面代码

效果如下图所示 首页有导航栏&#xff0c;搜索框&#xff0c;和视频列表&#xff0c; 导航栏如下图 搜索框如下图 视频列表如下图 文件目录 视频首页页面代码如下 <template> <view class"video-home"> <!-- 搜索栏 --> <view class…

【three.js】光源

光源 光源特点 当使用MeshLambertMaterial材质时&#xff0c;会受到光线的影响&#xff0c; 我们代码里面如果没有设置光线&#xff0c;则使用MeshLambertMaterial材质修饰的模型不可见&#xff0c;这个时候&#xff0c;我们添加光线后&#xff0c;便可以看见。 环境光 定义&a…