The `checkpoint_blocks` function implements a chunked gradient-checkpointing mechanism: it executes a list of neural-network blocks in chunks (chunking) to reduce memory usage. In deep-learning training, gradient checkpointing (activation checkpointing) is a memory-optimization technique that recomputes intermediate activations during the backward pass instead of storing them. This code can:
- Split a network's blocks into chunks on demand and apply gradient checkpointing to each chunk.
- Tune the trade-off between compute overhead and memory footprint via the chunk size.
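As background for the implementation below, here is a minimal sketch of what `torch.utils.checkpoint.checkpoint` does for a sequence of blocks: the checkpointed run discards each block's activations after the forward pass and recomputes them during backward, yet produces the same gradients as a plain run. The two-block toy model here is illustrative, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

torch.manual_seed(0)

# Two small "blocks" standing in for the blocks passed to checkpoint_blocks
blocks = [nn.Linear(8, 8), nn.Linear(8, 8)]
x = torch.randn(4, 8, requires_grad=True)

# Baseline: a plain forward/backward keeps every activation in memory
y = x
for b in blocks:
    y = b(y)
y.sum().backward()
plain_grad = x.grad.clone()

# Checkpointed: each block's activations are dropped after forward
# and recomputed during backward, trading compute for memory
x.grad = None
y = x
for b in blocks:
    y = torch.utils.checkpoint.checkpoint(b, y, use_reentrant=False)
y.sum().backward()

# Gradients are identical to the non-checkpointed run
assert torch.allclose(x.grad, plain_grad)
```

Checkpointing every block maximizes memory savings at the cost of one extra forward pass per block; `checkpoint_blocks` generalizes this by letting `blocks_per_ckpt` control how many blocks share a single checkpoint.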
1. Source code:
from typing import Any, Tuple, List, Callable, Optional
import torch
import torch.utils.checkpoint
import functools

try:
    import deepspeed
    deepspeed_is_installed = True
except ImportError:
    deepspeed_is_installed = False

BLOCK_ARG = Any
BLOCK_ARGS = Tuple[BLOCK_ARG, ...]


def get_checkpoint_fn():
    return torch.utils.checkpoint.checkpoint  # or deepspeed.checkpointing.checkpoint


def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """
    Chunk a list of b