如何用更少的内存训练你的PyTorch模型?深度学习GPU内存优化策略总结

在训练大规模深度学习模型时,GPU 内存往往成为关键瓶颈,尤其是面对大型语言模型(LLM)和视觉 Transformer 等现代架构时。由于大多数研究者和开发者难以获得配备海量 GPU 内存的高端计算集群,掌握高效的内存优化技术至关重要。本文将系统介绍多种优化策略,这些方法在组合应用的情况下,可将训练过程中的内存占用降低近 20 倍,而不会影响模型性能和预测精度。此外,大多数技术可以相互结合,以进一步提升内存效率。

目录

一、自动混合精度训练

二、低精度训练

三、梯度检查点

四、使用梯度累积减少批次大小

五、张量分片和分布式训练

六、高效的数据加载

七、使用原地操作

八、激活和参数卸载

九、使用更精简的优化器

十、进阶优化技术

内存分析和缓存管理

使用TorchScript进行JIT编译

自定义内核融合

使用 torch.compile() 进行动态内存分配

总结


一、自动混合精度训练

混合精度训练利用16位 (FP16) 和32位 (FP32) 浮点格式来保持准确性。通过以16位计算梯度,与使用完整的32位分辨率相比,该过程变得更快,并且内存使用量减少。

screenshot_2025-03-04_16-22-58.png

该过程首先将权重转换为较低精度(FP16)以加快计算速度。然后计算梯度,将其转换回更高精度(FP32)以确保数值稳定性,最后使用这些缩放后的梯度来更新原始权重。

使用 torch.cuda.amp.autocast()可轻松实现混合精度训练:

import torch
from torch.cuda.amp import autocast, GradScaler

# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for data, target in data_loader:
    optimizer.zero_grad()
    # Enable mixed precision
    with autocast():
        output = model(data)
        loss = loss_fn(output, target)
    
    # Scale the loss and backpropagate
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

二、低精度训练

由于16位浮点数的表示范围限制,这种方法可能导致NaN值出现。为了进一步降低精度,可采用BF16(Brain Floating Point),该格式相较FP16提供更大的动态范围,使其更适合深度学习应用。

2.png

NVIDIA Ampere及更新架构的GPU已支持BF16,用户可使用以下命令检查支持情况:

import torch
print(torch.cuda.is_bf16_supported())  # should print True

三、梯度检查点

即使使用混合精度和低精度,这些大型模型也会生成许多中间张量,这些张量会消耗大量内存。梯度检查点(Gradient Checkpointing)通过选择性地存储部分中间激活值,并在反向传播时重新计算其余激活值,以换取计算成本来减少内存占用。

通过策略性地选择要检查哪些层,您可以通过动态重新计算激活而不是存储它们来减少内存使用量。这种权衡对于具有深度架构的模型尤其有益,因为中间激活占内存消耗的很大一部分。如何使用它的简单代码片段如下:

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_segment(input_tensor):
    # This function represents a portion of your model
    # which will be recomputed during the backward pass.
    # You can create a custom forward pass for this segment.
    return model_segment(input_tensor)

# Instead of a conventional forward pass, wrap the segment with checkpoint.
output = checkpoint(checkpointed_segment, input_tensor)

四、使用梯度累积减少批次大小

简单减小批量大小虽然能显著降低内存消耗,但往往会对模型准确率产生不良影响。

梯度累积(Gradient Accumulation)通过累积多个小批量的梯度,以实现较大的“虚拟”批次大小,从而降低对GPU内存的需求。其核心原理是为较小的批量计算梯度,并在多次迭代中累积这些梯度(通常通过求和或平均),而不是在每个批次后立即更新模型权重。

然而需要注意,这种技术的主要缺点是显著增加了训练时间。


五、张量分片和分布式训练

对于超大规模模型,可以使用完全分片数据并行(FSDP)技术,将模型参数、梯度和优化器状态拆分至多个GPU,以降低单 GPU 的内存压力。

3.png

FSDP不会在每个GPU上维护模型的完整副本,而是将模型的参数划分到可用设备中。执行前向或后向传递时,只有相关分片才会加载到内存中。这种分片机制大大降低了每个设备的内存需求,与上述任何一种技术相结合,在某些情况下甚至可以实现高达10倍的减少。

4.png

使用以下方式启用它:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()

# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(model)

六、高效的数据加载

内存优化中常被忽视的一个方面是数据加载效率。虽然大部分优化关注点集中在模型内部结构和计算过程,但低效的数据处理同样可能造成不必要的瓶颈,影响内存利用和计算速度。作为经验法则,当处理数据加载器时,应始终启用Pinned Memory和配置适当的Multiple Workers,如下所示:

from torch.utils.data import DataLoader

# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,      # Adjust based on your CPU capabilities
    pin_memory=True     # Enables faster host-to-device transfers
)

七、使用原地操作

避免不必要的张量复制,可以通过原地操作减少临时内存分配。例如:

import torch

x = torch.randn(100, 100, device='cuda')
y = torch.randn(100, 100, device='cuda')

# Using in-place addition
x.add_(y)  # Here x is modified directly instead of creating a new tensor

八、激活和参数卸载

对于非常大的模型,即使采用了上述所有技术,由于中间激活次数过多,您仍可能会达到GPU内存的极限。

此外,可以策略性地将一些激活和/或参数卸载到主机内存(CPU), GPU 内存保留下来仅用于关键计算。将部分激活转移到CPU以节省GPU内存,如使用DeepSpeed进行自动管理:

def offload_activation(tensor):
    # Move tensor to CPU to save GPU memory
    return tensor.cpu()

def process_batch(data):
    # Offload some activations explicitly
    intermediate = model.layer1(data)
    intermediate = offload_activation(intermediate)
    intermediate = intermediate.cuda()  # Move back when needed
    output = model.layer2(intermediate)
    return output

九、使用更精简的优化器

各种优化器在内存消耗方面存在显著差异。例如,广泛使用的Adam优化器为每个模型参数维护两个额外状态参数(动量和方差),这意味着更多的内存消耗。将Adam替换为无状态优化器(如SGD)可将参数数量减少近2/3,这在处理LLM等大型模型时尤为重要。

标准SGD的缺点是收敛特性较差。为弥补这一点,可引入余弦退火学习率调度器以实现更好的收敛效果。实现示例:

# instead of this
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# use this
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_steps)

十、进阶优化技术

除上述基础技术外,以下高级策略可进一步优化GPU内存使用,充分发挥硬件潜能:

  • 内存分析和缓存管理

精确测量是有效优化的前提。PyTorch提供了多种实用工具用于监控GPU内存使用情况:

import torch

# print a detailed report of current GPU memory usage and fragmentation
print(torch.cuda.memory_summary(device=None, abbreviated=False))

# free up cached memory that’s no longer needed by PyTorch
torch.cuda.empty_cache()
  • 使用TorchScript进行JIT编译

PyTorch的即时编译器(JIT)可让使用TorchScript将Python 模型转换为优化的可序列化程序。通过优化内核启动并减少开销,此转换可同时提高内存和性能:

import torch

# Suppose `model` is an instance of your PyTorch network.
scripted_model = torch.jit.script(model)

# Now, you can run the scripted model just like before.
output = scripted_model(input_tensor)
  • 自定义内核融合

编译的另一个主要好处是将多个操作融合(如上文所述)到单个内核中。这有助于减少内存读写并提高整体吞吐量。融合操作如下所示:

5.png

  • 使用 torch.compile() 进行动态内存分配

进一步利用编译技术,JIT编译器可通过编译时优化改进动态内存分配效率。结合跟踪和计算图优化技术,这种方法可在大型模型和Transformer架构中实现更显著的内存和性能优化。


总结

通过合理组合以上优化策略,可以大幅降低GPU内存占用,提高训练效率,使得大规模深度学习模型能在有限资源下运行。随着硬件技术和深度学习框架的不断发展,进一步探索新方法将有助于更高效地训练AI模型。如果您有更好的技术方式,欢迎在评论区讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/984379.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Dify+DeepSeek | Excel数据一键可视化(创建步骤案例)(echarts助手.yml)(文档表格转图表、根据表格绘制图表、Excel绘制图表)

Dify部署参考:Dify Rag部署并集成在线Deepseek教程(Windows、部署Rag、安装Ragan安装、安装Dify安装、安装ollama安装) DifyDeepSeek - Excel数据一键可视化(创建步骤案例)-DSL工程文件(可直接导入&#x…

linux下ollama离线安装

一、离线安装包下载地址 直接下载地址: https://github.com/ollama/ollama/releases/tag/v0.5.12 网络爬取地址: MacOS https://ollama.com/download/Ollama-darwin.zip Linux curl -fsSL https://ollama.com/install.sh | sh Windows https://olla…

MAC 搭建Dify+DeepSeek-R1整合部署

在开始安装之前,我们需要确保系统满足以下基本要求: CPU至少2核心内存至少4GB(建议8GB以上)硬盘空间至少20GB(为了后续扩展)操作系统支持:Windows、macOS或LinuxDocker环境 1. dify的安装步骤…

OpenManus介绍及本地部署体验

1.OpenManus介绍 OpenManus,由 MetaGPT 团队精心打造的开源项目,于2025年3月发布。它致力于模仿并改进 Manus 这一封闭式商业 AI Agent 的核心功能,为用户提供无需邀请码、可本地化部署的智能体解决方案。换句话说,OpenManus 就像…

springboot011基于springboot的课程作业管理系统(源码+包运行+LW+技术指导)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得难了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等,你想解决的问题,今天…

swift -(5) 汇编分析结构体、类的内存布局

一、结构体 在 Swift 标准库中,绝大多数的公开类型都是结构体,而枚举和类只占很小一部分 比如Bool、 Int、 Double、 String、 Array、 Dictionary等常见类型都是结构体 ① struct Date { ② var year: Int ③ var month: Int ④ …

全域网络安全防御 健全网络安全防护体系

网络安全基本概念 网络安全(Cyber Security)是指网络系统的硬件、软件及其系统中的数据受到保护,不因偶然的或者恶意的原因而遭受到破坏、更改、泄露,系统连续可靠正常地运行,网络服务不中断,使网络处于稳…

记录小白使用 Cursor 开发第一个微信小程序(二):创建项目、编译、预览、发布(250308)

文章目录 记录小白使用 Cursor 开发第一个微信小程序(二):创建项目、编译、预览、发布(250308)一、创建项目1.1 生成提示词1.2 生成代码 二、编译预览2.1 导入项目2.2 编译预览 三、发布3.1 在微信开发者工具进行上传3…

uploadlabs通关思路

目录 靶场准备 复现 pass-01 代码审计 执行逻辑 文件上传 方法一:直接修改或删除js脚本 方法二:修改文件后缀 pass-02 代码审计 文件上传 1. 思路 2. 实操 pass-03 代码审计 过程: 文件上传 pass-04 代码审计 文件上传 p…

CTFHub-FastCGI协议/Redis协议

将木马进行base64编码 <?php eval($_GET[cmd]);?> 打开kali虚拟机&#xff0c;使用虚拟机中Gopherus-master工具 Gopherus-master工具安装 git clone https://github.com/tarunkant/Gopherus.git 进入工具目录 cd Gopherus 使用工具 python2 "位置" --expl…

前端 | 向后端传数据,判断问题所在的调试过程

目录 ​编辑 1. 在 vue 文件中&#xff0c;在调用函数之前 先打印传入的数据 2. 在 js 文件中&#xff0c;打印接收到的数据 3. 在浏览器 Network 面板查看请求数据 4. 在 server.js 中查看请求数据 5. 确保 JSON 格式正确 知识点&#xff1a;JSON.stringify(req.body, …

STL之list的使用(超详解)

目录 一、list的介绍及使用 1.1 list的介绍 1.2 list的使用 1.2.1 list的构造 1.2.2 iterator的使用 1.2.3capacity&#xff08;容量相关&#xff09; 1.2.4 element access&#xff08;元素访问&#xff09; 1.2.5 modifiers&#xff08;链表修改&#xff09;…

在【k8s】中部署Jenkins的实践指南

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《Kubernetes航线图&#xff1a;从船长到K8s掌舵者》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、Jenkins简介 2、k8s简介 3、什么在…

Ae 效果详解:VR 发光

Ae菜单&#xff1a;效果/沉浸式视频/VR 发光 Immersive Video/VR Glow VR 发光 VR Glow效果用于在 VR 视频中创建光晕效果&#xff0c;并针对等距柱状投影&#xff08;Equirectangular&#xff09;进行优化&#xff0c;以确保全景画面中的光晕均匀分布&#xff0c;不受画面边缘…

猫耳大型活动提效——组件低代码化

1. 引言 猫耳前端在开发活动的过程中&#xff0c;经历过传统的 pro code 阶段&#xff0c;即活动页面完全由前端开发编码实现&#xff0c;直到 2020 年接入公司内部的低代码活动平台&#xff0c;满足了大部分日常活动的需求&#xff0c;运营可自主配置活动并上线&#xff0c;释…

ESP8266UDP透传

1. 配置 WiFi 模式 ATCWMODE3 // softAPstation mode 响应 : OK 2. PC 连⼊入 ESP8266 softAP 就是连接wifi 3.查询ESP8266设备的IP地址 ATCIFSR 响应: CIFSR: APIP, "192.168.4.1" CIFSR: APMAC, "1a: fe: 34: a5:8d: c6" CIFSR: STAIP, "192.…

【仿muduo库one thread one loop式并发服务器实现】

文章目录 一、项目介绍1-1、项目总体简介1-2、项目开发环境1-3、项目核心技术1-4、项目开发流程1-5、项目如何使用 二、框架设计2-1、功能模块划分2-1-1、SERVER模块2-1-2、协议模块 2-2、项目蓝图2-2-1、整体图2-2-2、模块关系图2-2-2-1、Connection 模块关系图2-2-2-2、Accep…

私有云基础架构与运维(二)

二.私有云基础架构 【项目概述】 经过云计算基础知识及核心技术的学习后&#xff0c;希望进一步了解 IT 基础架构的演变过 程&#xff0c;通过学习传统架构、集群架构以及私有云基础架构的相关知识&#xff0c;认识企业从传统 IT 基 础架构到私有云基础架构转型的必要性。…

DeepSeek R1-32B医疗大模型的完整微调实战分析(全码版)

DeepSeek R1-32B微调实战指南 ├── 1. 环境准备 │ ├── 1.1 硬件配置 │ │ ├─ 全参数微调:4*A100 80GB │ │ └─ LoRA微调:单卡24GB │ ├── 1.2 软件依赖 │ │ ├─ PyTorch 2.1.2+CUDA │ │ └─ Unsloth/ColossalAI │ └── 1.3 模…

vue3 vite项目安装eslint

npm install eslint -D 安装eslint库 npx eslint --init 初始化配置&#xff0c;按项目实际情况选 自动生成eslint.config.js&#xff0c;可以添加自定义rules 安装ESLint插件 此时打开vue文件就会标红有问题的位置 安装prettier npm install prettier eslint-config-pr…