【Pytorch】理解自动混合精度训练

【Pytorch】理解自动混合精度训练

  • 混合精度概述
  • 实验对比

  更大的深度学习模型需要更多的计算能力和内存资源。一些新技术的提出,可以更快地训练深度神经网络。我们可以使用 FP16(半精度浮点数格式)来代替 FP32(全精度浮点数格式),研究人员发现串联使用它们是更好的选择。有的 GPU(例如 Paperspace 提供的 Ampere GPU)甚至可以利用较低级别的精度,例如 INT8。

  混合精度允许半精度训练,同时仍保留大部分单精度网络精度。术语“混合精度技术”是指该方法同时使用单精度和半精度表示。

  在使用 PyTorch 进行自动混合精度 (Amp) 训练的概述中,我们演示了该技术的工作原理,逐步介绍使用 Amp 的过程,并通过代码讨论 Amp 技术的应用。

混合精度概述

  在深度学习的世界里,使用 FP16 进行计算不仅能显著提升性能,还能节省内存。然而,这种方法也带来了两个主要问题:精度溢出和舍入误差。这两个问题是深度学习中 FP16 计算的关键挑战。

   精度溢出(Precision Overflow)

  在 FP16 格式下,由于位宽较小,可表示的数值范围远小于 FP32 或 FP64。这容易导致数值过大或过小而无法在 FP16 的表示范围内精确表示。在深度学习中,这可能引起梯度消失或梯度爆炸,因为一些小的梯度值可能变成零(下溢),而一些大的梯度值可能变得无限大(上溢)。这种溢出问题会严重影响模型训练的稳定性和最终性能。

  舍入误差(Rounding Error)

  FP16 由于其16位的表示限制,相比于 FP32 或 FP64,舍入误差更加明显。在深度学习中,每次计算的舍入误差会累积,尤其是在多层和复杂运算中。这可能导致模型输出与使用更高精度计算时存在显著差异。对于那些对精确度要求极高的应用(比如金融或医疗领域),这种误差可能造成不可接受的后果。

  为了缓解这些问题,混合精度训练方法在关键部分(如权重更新)使用 FP32 来保持精度,而在其他操作(如前向传播)中使用 FP16 来提高效率。混合精度训练中,我特别注意到了权重备份(Weight Backup)、损失放大(Loss Scaling)、精度累加(Precision Accumulated)这三种技术的重要性。

  权重备份(Weight Backup):

  在混合精度训练中,为了确保数值稳定性,模型的权重通常会在 FP16 和 FP32 两种格式下同时维护。权重备份是指保留 FP32 格式的权重副本,这样即使在大部分使用 FP16 格式的计算过程中出现数值不稳定现象,我们仍然能依靠 FP32 权重副本保持稳定和精确。这对于更新模型参数时的准确性至关重要。

  损失放大(Loss Scaling):

  在混合精度训练中,由于 FP16 的表示范围限制,梯度值可能太小而无法在 FP16 中准确表示,导致有效梯度变为零。损失放大是通过在计算梯度前将损失函数的值乘以一个较大的常数(放大因子),从而放大梯度值,使其在 FP16 范围内可表示且非零。在反向传播后,再将放大的梯度除以相同的放大因子,恢复原始比例,这样可以有效减少梯度下溢问题。

  精度累加(Precision Accumulation):

  精度累加是指在权重更新过程中,即使梯度计算是在 FP16 下完成的,但权重更新则在 FP32 精度下进行。这有助于减少舍入误差和累积误差,尤其在训练过程中涉及大量累积操作时。由于 FP32 提供更高的数值精度和更大的表示范围,可以更准确地累积小梯度值,避免更新权重时的数值不稳定性。

  综上所述,通过将这些技术相结合,混合精度训练能够有效地利用 FP16 带来的性能优势,同时最大限度地减少精度损失和计算不稳定性。

实验对比

  为了进一步验证这些技术的有效性,我设计了两个实验对比使用混合精度和传统的 FP32 两种方式进行训练。以下是我在这两个实验中使用的代码片段:

FP16与FP32混合训练代码:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time

from torch.cuda.amp import GradScaler, autocast


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20000),

            nn.Linear(20000, 20000),
            # nn.Dropout(0.1),

            nn.Linear(20000, 200),
            # nn.LayerNorm(20),

            nn.Linear(200, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.0001
iteration = 1000


x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')

model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

scaler = GradScaler()

start_time = time.time()
writer = SummaryWriter(comment='_FP16')

for iter in range(iteration):
    with autocast():
        y_pred = model(x)
        loss = loss_function(y, y_pred.squeeze())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    writer.add_scalar('loss', loss, iter)
    optimizer.zero_grad()

    if iter % 100 == 0:
        # 获取 GPU 的内存使用情况
        print("GPU Memory Allocated:", torch.cuda.memory_allocated())
        print("GPU Memory Cached:   ", torch.cuda.memory_reserved())

print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp16.pth')


FP32训练代码:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time

from torch.cuda.amp import GradScaler, autocast


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20000),

            nn.Linear(20000, 20000),
            # nn.Dropout(0.1),

            nn.Linear(20000, 200),
            # nn.LayerNorm(20),

            nn.Linear(200, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.0001
iteration = 1000


x1 = torch.arange(-1000, 1000).float().to('cuda')
x2 = torch.arange(0, 2000).float().to('cuda')
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = (2*x1 - x2 + 1).to('cuda')

model = Model().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

scaler = GradScaler()

start_time = time.time()
writer = SummaryWriter(comment='_FP32')

for iter in range(iteration):

    y_pred = model(x)
    loss = loss_function(y, y_pred.squeeze())
    loss.backward()
    optimizer.step()


    writer.add_scalar('loss', loss, iter)
    optimizer.zero_grad()

    if iter % 100 == 0:
        # 获取 GPU 的内存使用情况
        print("GPU Memory Allocated:", torch.cuda.memory_allocated())
        print("GPU Memory Cached:   ", torch.cuda.memory_reserved())

print("Time: ", time.time() - start_time)
torch.save(model.state_dict(), 'model_state_dict_fp32.pth')


  最终两者的效果如下:

实验名占用GPU消耗时间
FP16与FP32混合训练4867271680 bytes78.73 s
FP32训练4867274752 bytes140.18 s

  实验分析如下:

  1、在内存占用方面,两种训练方法几乎相同。这可能是因为模型结构和数据集大小相同,所以内存占用没有显著差异。然而,通常情况下,FP16训练应该占用更少的内存,因为它使用的是半精度浮点数。

  2、在训练时间方面,混合精度训练明显快于纯FP32训练。这是因为FP16训练可以加快计算速度并降低内存需求,从而允许模型更快地运行。混合精度训练结合了FP16的高效率和FP32的数值稳定性,提供了一个平衡的解决方案。

在这里插入图片描述

  在提供的损失图中,我们可以看到蓝色曲线代表使用全精度(FP32)训练的模型损失,而橙色曲线代表使用混合精度(FP16与FP32)训练的模型损失。以下是对两种训练方法损失曲线的进一步分析:

  · 在下降到一定程度后,两条曲线都达到平稳状态,这表明模型已基本收敛。在这个平稳阶段,损失变化不大,说明模型在训练集上的表现已经稳定。

  · 没有明显的过拟合迹象,因为损失曲线没有再次上升的趋势。

  · 混合精度训练似乎在时间效率略优于全精度训练,尽管最终损失值的差异不大,但在追求快速迭代和高效训练的情况下,选择混合精度训练会更有优势;而在收敛速度上,混合精度训练在140epoch时收敛完毕,而全精度训练早在100epoch就收敛完,说明全精度训练收敛较于混合精度训练,在step层面更快,而它在时间层面更慢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/223920.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 输入输出流02

六. Java.IO 流类库 1:io 流的四个基本类 java.io 包中包含了流式 I/O 所需要的所有类。在 java.io 包中有四个基本类:InputStream、OutputStream 及 Reader、Writer 类,它们分别处理字节流和字符流: 基本数据流的 I/O 输入 / …

多向通信----多人聊天

package 多人聊天; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.PrintStream; import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList; publ…

信道的极限容量

目录 信道的最高码元传输速率 限制码元在信道上的传输速率的因素: (1)信道能够通过的频率范围 (2) 信噪比 任何实际的信道都不是理想的,在传输信号时会产生各种失真以及带来多种干扰。 码元传输的速率越…

Avalonia中如何将View事件映射到ViewModel层

前言 前面的文章里面我们有介绍在Wpf中如何在View层将事件映射到ViewModel层的文章,传送门,既然WPF和Avalonia是两套不同的前端框架,那么WPF里面实现模式肯定在这边就用不了,本篇我们将分享一下如何在Avalonia前端框架下面将事件…

如何配置WinDbg和VMware实现内核的调试

设置 VMware 的虚拟串口 运行 VMware,首先将 Guest OS 系统电源关闭,这样才能修改该系统的虚拟机设置。 单击界面上的“编辑虚拟机设置”选项对虚拟机的属性进行设置。 单击“添加”按钮,打开 VMware 的 添加硬件向导 对话框 选择“串行端口…

构建第一个事件驱动型 Serverless 应用

我相信,我们从不缺精彩的应用创意,我们缺少的把这些想法变成现实的时间和付出。 我认为,无服务器技术真的有助于最大限度节省应用开发和部署的时间,并且无服务器技术用可控的成本,实现了我的那些有趣的想法。 在我 2…

c语言笔记之文件操作

16 文件操作 嵌入式开发中基本用不上,这章不重要 a 字符集:泛泛意义上的文本文件中的数据与磁盘中保存的二进制之间的映射关系。 常见的字符集:ASCLL,Latin,GB2312,GBK,UTF-8 解码过程:从看不懂到看得懂的过程。 ​ 如果操作时…

02Docker容器卷

Docker容器卷 1.数据卷是什么 简而言之: 就是Docker用来存储数据的,在镜像被删除的时候,卷中数据不会被删除,就是相当于一个数据库备份数据,相当于Windows中的目录或文件 2.目的 解决数据持久化 独立容器的生存周期,帮助容器间继承和共享数据 3.数据卷的使用 1.直接添加 doc…

Linux 多线程(C语言) 备查

基础 1)线程在运行态和就绪态不停的切换。 2)每个线程都有自己的栈区和寄存器 1)进程是资源分配的最小单位,线程是操作系统调度执行的最小单位 2)线程的上下文切换的速度比进程快得多 3)从应用程序A中启用应…

【UE】制作地月全息投影

效果 步骤 1. 在必应国际版上搜索“purlin noise”,下载如下所示图片 再搜索“Earth Map”,下载如下所示图片 再搜索“Moon 360”,下载如下所示图片 这三张图片的资源链接如下: 链接:https://pan.baidu.com/s/1b_50q…

Leetcode每日一题学习训练——Python3版(最小化旅行的价格总和)

版本说明 当前版本号[20231206]。 版本修改说明20231206初版 目录 文章目录 版本说明目录最小化旅行的价格总和理解题目代码思路参考代码 原题可以点击此 2646. 最小化旅行的价格总和 前去练习。 最小化旅行的价格总和 现有一棵无向、无根的树,树中有 n 个节点…

【Spark学习笔记】- 5.1 IO基本实现原理

IO基本实现原理 Input& Output 字节流 InputStream in new FileInputStream("path") int i -1while ( (i in.read()) ! -1 ) {println(i); }上述为字节流 需要一个字节一个字节读取数据,读一个打印一个。功能可以实现,效率不高。 缓…

9_企业架构队列缓存中间件分布式Redis

企业架构队列缓存中间件分布式Redis 学习目标和内容 1、能够描述Redis作用及其业务适用场景 2、能够安装配置启动Redis 3、能够使用命令行客户端简单操作Redis 4、能够实现操作基本数据类型 5、能够理解描述Redis数据持久化机制 6、能够操作安装php的Redis扩展 7、能够操作实现…

AI跨界学习,不再是梦!

大家好!今天给大家推荐的 GPTs 是【行业知识脉络】,帮助大家快速了解某个领域的脉络,并提供足够的学习资料和建议。 在AI时代,从小白到专家的1万小时定律即将失效,用少于1千小时掌握行业知识树和其核心概念是如何学习的…

内核无锁队列kfifo

文章目录 1、抛砖引玉2、内核无锁队列kfifo2.1 kfifo结构2.2 kfifo分配内存2.3 kfifo初始化2.4 kfifo释放2.5 kfifo入队列2.6 kfifo出队列2.7 kfifo的判空和判满2.8 关于内存屏障 1、抛砖引玉 昨天遇到这样一个问题,有多个生产者,多个消费者&#xff0c…

使用Java网络编程,窗口,线程,IO,内部类等实现多人在线聊天1.0

1.整体思路 思路图 整体思路如上: 涉及知识点:线程网络编程集合IO等 TCP 协议 2.代码实现过程 服务端 import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyAdapter; import jav…

SQL手工注入漏洞测试(Sql Server数据库)-墨者

———靶场专栏——— 声明:文章由作者weoptions学习或练习过程中的步骤及思路,非正式答案,仅供学习和参考。 靶场背景: 来源: 墨者学院 简介: 安全工程师"墨者"最近在练习SQL手工注入漏洞&#…

大模型应用设计的10个思考

技术不是万能的,但没有技术却可能是万万不能的,对于大模型可能也是如此。基于大模型的应用设计需要聚焦于所解决的问题,在自然语言处理领域,大模型本身在一定程度上只是将各种NLP任务统一成了sequence 到 sequence 的模型。利用大…

使用 Webshell 访问 SQL Server 主机并利用 SSRS

本文将指导您使用RDS SQL Server实例的主机账号登录和管理SQL Server Reporting Services(SSRS)数据库。 背景信息 RDS SQL Server提供Webshell功能,用户可以通过Web界面登录RDS SQL Server实例的操作系统。通过Webshell,用户可…

一次重新加载所有 maven 项目产生的 OOM

1、解决什么问题? 忘了截图了,用文字描述就是由于Reload All Maven Projects导致的 OOM 异常。 2、尝试与解决 2.1、尝试 2.1.1、尝试清理idea缓存(无效) 2.1.2、重启idea(无效) 2.1.3、重启电脑&am…