PyTorch系列教程:编写高效模型训练流程

当使用PyTorch开发机器学习模型时,建立一个有效的训练循环是至关重要的。这个过程包括组织和执行对数据、参数和计算资源的操作序列。让我们深入了解关键组件,并演示如何构建一个精细的训练循环流程,有效地处理数据处理,向前和向后传递以及参数更新。

模型训练流程

PyTorch训练循环流程通常包括:

  • 加载数据
  • 批量处理
  • 执行正向传播
  • 计算损失
  • 反向传播
  • 更新权重

一个典型的训练流程将这些步骤合并到一个迭代过程中,在数据集上迭代多次,或者在训练的上下文中迭代多个epoch。
在这里插入图片描述

1. 搭建环境

在编写代码之前,请确保在本地环境中设置了PyTorch。这通常需要安装PyTorch和其他依赖项:

pip install torch torchvision

下面演示为建立一个有效的训练循环奠定了基本路径的示例。

2. 数据加载

数据加载是使用DataLoader完成的,它有助于数据的批量处理:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
data_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(data_train, batch_size=64, shuffle=True)

DataLoader在这里被设计为以64个为单位的批量获取数据,在数据传递中进行随机混淆。

3. 模型初始化

一个使用PyTorch的简单神经网络定义如下:

import torch.nn as nn
import torch.nn.functional as F

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

这里,784指的是输入维度(28x28个图像),并创建一个输出大小为10个类别的顺序前馈网络。

4. 建立训练循环

定义损失函数和优化器:为了改进模型的预测,必须定义损失和优化器:

import torch.optim as optim

model = SimpleNN()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

5. 实现训练循环

有效的训练循环的本质在于正确的步骤顺序:

epochs = 5
for epoch in range(epochs):
    running_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Zero the parameter gradients
        output = model(images)  # Forward pass
        loss = criterion(output, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Optimize weights
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader)}")

注意,每次迭代都需要重置梯度、通过网络处理输入、计算误差以及调整权重以减少该误差。

性能优化

使用以下策略提高循环效率:

  • 使用GPU:将计算转移到GPU上,以获得更快的处理速度。如果GPU可用,使用to(‘cuda’)转换模型和输入。

  • 数据并行:利用多gpu设置与dataparlele模块来分发批处理。

  • FP16训练:使用自动混合精度(AMP)来加速训练并减少内存使用,而不会造成明显的精度损失。

在 PyTorch 中使用 FP16(半精度浮点数)训练 可以显著减少显存占用、加速计算,同时保持模型精度接近 FP32。以下是详细指南:

1. FP16 的优势

  • 显存节省:FP16 占用显存是 FP32 的一半(例如,1024MB 显存在 FP32 下可容纳约 2000 万参数,在 FP16 下可容纳约 4000 万)。
  • 计算加速:NVIDIA 的 Tensor Core 支持 FP16 矩阵运算,速度比 FP32 快数倍至数十倍。
  • 适合大规模模型:如 Transformer、Vision Transformer(ViT)等参数量大的模型。

2. 实现 FP16 训练的两种方式

(1) 自动混合精度(Automatic Mixed Precision, AMP)

PyTorch 的 torch.cuda.amp 自动管理 FP16 和 FP32,减少手动转换的复杂性。

python

import torch
from torch.cuda.amp import autocast, GradScaler

model = model.to("cuda")  # 确保模型在 GPU 上
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()  # 梯度缩放器

for data, target in dataloader:
    data = data.to("cuda").half()  # 输入转为 FP16
    target = target.to("cuda")

    with autocast():  # 自动切换 FP16/FP32 计算
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()  # 梯度缩放
    scaler.step(optimizer)         # 更新参数
    scaler.update()               # 重置缩放器

关键点

  • autocast() 内部自动将计算转换为 FP16(若 GPU 支持),梯度累积在 FP32。
  • GradScaler() 解决 FP16 下梯度下溢问题。
(2) 手动转换(低级用法)

直接将模型参数、输入和输出转为 FP16,但需手动管理精度和稳定性。

python

model = model.half()  # 模型参数转为 FP16
for data, target in dataloader:
    data = data.to("cuda").half()  # 输入转为 FP16
    target = target.to("cuda")

    output = model(data)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

缺点

  • 可能因数值不稳定导致训练失败(如梯度消失)。
  • 不支持动态精度切换(如部分层用 FP32)。

3. FP16 训练的注意事项

(1) 设备支持
  • NVIDIA GPU:需支持 Tensor Core(如 Volta 架构以上的 GPU,包括 Tesla V100、A100、RTX 3090 等)。
  • AMD GPU:部分型号支持 FP16 计算,但 AMP 功能受限(需使用 torch.backends.cudnn.enabled = False)。
(2) 学习率调整
  • FP16 的初始学习率通常设为 FP32 的 2~4 倍(因梯度放大),需配合学习率调度器(如 CosineAnnealingLR)。
(3) 损失缩放(Loss Scaling)
  • FP16 的梯度可能过小,导致update() 时下溢。解决方案:

    • 自动缩放:使用 GradScaler()(推荐)。
    • 手动缩放:将损失乘以一个固定因子(如 1e4),反向传播后再除以该因子。
(4) 模型初始化
  • FP16 参数初始化值不宜过大,否则可能导致 nan。建议初始化时用 FP32,再转为 FP16。
(5) 检查数值稳定性
  • 训练过程中监控损失是否为 nan 或无穷大。
  • 可通过 torch.set_printoptions(precision=10) 打印中间结果。

4. FP16 vs FP32 精度对比

模型FP32 精度损失FP16 精度损失
ResNet-18微小可忽略
BERT-base微小~1-2%
GPT-2微小~3-5%

结论:多数任务中 FP16 的精度损失可接受,但需通过实验验证。

5. 常见错误及解决

错误现象解决方案
RuntimeError: CUDA error: out of memory减少 batch size 或清理缓存 (torch.cuda.empty_cache())
naninf调整学习率、检查数据预处理、启用梯度缩放
InvalidArgumentError确保输入数据已正确转换为 FP16
  • 推荐使用 autocast + GradScaler:平衡易用性和性能。
  • 优先在 NVIDIA GPU 上使用:AMD GPU 的 FP16 支持较弱。
  • 从小批量开始测试:避免显存不足或数值不稳定。

通过合理配置,FP16 可以在几乎不损失精度的情况下显著提升训练速度和显存利用率。

最后总结

高效的训练循环为优化PyTorch模型奠定了坚实的基础。通过遵循适当的数据加载过程,模型初始化过程和系统的训练步骤,你的训练设置将有效地利用GPU资源,并通过数据集快速迭代,以构建健壮的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/983761.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PX4中的DroneCAN的实现库Libuavcan及基础功能示例

简介 Libuavcan是一个用C编写的可移植的跨平台库,对C标准库的依赖小。它可以由几乎任何符合标准的C编译器编译,并且可以在几乎任何体系结构/OS上使用。 在 DroneCAN 中,Libuavcan 有一个 DSDL 编译器,将 DSDL 文件转换为 hpp 头…

计算机网络(1) 网络通信基础,协议介绍,通信框架

网络结构模式 C/S-----客户端和服务器 B/S -----浏览器服务器 MAC地址 每一个网卡都拥有独一无二的48位串行号,也即MAC地址,也叫做物理地址、硬件地址或者是局域网地址 MAC地址表示为12个16进制数 如00-16-EA-AE-3C-40 (每一个数可以用四个…

PCA(主成分分析)核心原理

一、PCA(主成分分析)核心原理 即主成分分析技术,又称主分量分析技术,旨在利用降维的思想,把多指标转化为少数几个综合指标。在统计学中,主成分分析PCA是一种简化数据集的技术。它是一个线性变换。这个变换…

SpringBoot-模拟SSE对话交互

SpringBoot-模拟SSE对话交互 后端使用SSE进行会话,前端使用Html模拟大模型的问答交互->【前端】【后端】 1-学习目的 本项目代码仓库:https://gitee.com/enzoism/springboot_sse 1-核心知识点 1)什么是SSE协议->客户端发起一次请求&am…

2025DNS二级域名分发PHP网站源码

安装教程 1.程序必须使用PHP8.1 2.将扩展ixed.8.1.lin放入/www/server/php/81/lib/php/extensions/no-debug-non-zts-20210902 3.打开宝塔→软件商店→PHP8.1→配置文件 4.放入:extensionixed.8.1.lin 5.重启PHP8.1 6.新建站点(mysql5.6-5.7andPHP8.1&a…

Matlab实现车牌识别

车牌识别技术作为现代智能交通系统、安防监控以及诸多车辆管理应用场景中的关键环节,正发挥着日益重要的作用,它能够自动、快速且精准地从车辆图像或视频流中提取车牌信息,实现车辆身份的智能化识别。 技术原理 车牌识别主要依托于图像处理、…

C语言——链表

大神文献:https://blog.csdn.net/weixin_73588765/article/details/128356985 目录 一、链表概念 1. 什么是链表? 1.1 链表的构成 2. 链表和数组的区别 数组的特点: 链表的特点: 二者对比: 二…

国产化板卡设计原理图:2330-基于FMC接口的JFM7K325T PCIeX4 3U PXIe接口卡

基于FMC接口的JFM7K325T PCIeX4 3U PXIe接口卡 一、板卡概述 本板卡基于 FPGAJFM7K325T 芯片,pin_to_pin兼容FPGAXC7K410T-2FFG900 ,支持PCIeX8、64bit DDR3容量2GByte,HPC的FMC连接器,板卡支持PXIE标准协议,其中XJ3…

【网络】HTTP协议、HTTPS协议

HTTP与HTTPS HTTP协议概述 HTTP(超文本传输协议):工作在OSI顶层应用层,用于客户端(浏览器)与服务器之间的通信,B/S模式 无状态:每次请求独立,服务器不保存客户端状态(通过Cookie/Session扩展状态管理)。基于TCP:默认端口80(HTTP)、443(HTTPS),保证可靠传输。请…

设计AI芯片架构的入门 研究生入行数字芯片设计、验证的项目 opentitan

前言 这几年芯片设计行业在国内像坐过山车。时而高亢,时而低潮。最近又因为AI的热潮开始high起来。到底芯片行业的规律是如何? 我谈谈自己观点:芯片设计是“劳动密集型”行业。 “EDA和工具高度标准化和代工厂的工艺标准化之后,芯…

K8S学习之基础十七:k8s的蓝绿部署

蓝绿部署概述 ​ 蓝绿部署中,一共有两套系统,一套是正在提供服务的系统,一套是准备发布的系统。两套系统都是功能完善、正在运行的系统,只是版本和对外服务情况不同。 ​ 开发新版本,要用新版本替换线上的旧版本&…

STM32之I2C硬件外设

注意:硬件I2C的引脚是固定的 SDA和SCL都是复用到外部引脚。 SDA发送时数据寄存器的数据在数据移位寄存器空闲的状态下进入数据移位寄存器,此时会置状态寄存器的TXE为1,表示发送寄存器为空,然后往数据控制寄存器中一位一位的移送数…

Linux基础--用户管理

目录 查看用户 使用命令: id 创建用户 使用命令: useradd ​编辑 为用户设置密码 使用命令: passwd ​编辑 删除用户 使用命令: userdel 创建用户组 使用命令: groupadd 删除用户组 使用命令: groupdel 用户设置 使用命令: usermod 将用户从组中去除 使用…

VSTO(C#)Excel开发3:Range对象 处理列宽和行高

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 源码指引:github源…

本地YARN集群部署

请先完成HDFS的前置部署,部署方式可查看:本地部署HDFS集群https://blog.csdn.net/m0_73641796/article/details/145998092?spm1001.2014.3001.5502 部署说明 组件配置文件启动进程备注Hadoop HDFS需修改 需启动: NameNode作为主节点 DataNode作为从节点 Secondary…

每日一练之移除链表元素

题目: 画图解析: 方法:双指针 解答代码(注:解答代码带解析): //题目给的结构体 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* }…

黑金风格人像静物户外旅拍Lr调色教程,手机滤镜PS+Lightroom预设下载!

调色教程 针对人像、静物以及户外旅拍照片,运用 Lightroom 软件进行风格化调色工作。旨在通过软件中的多种工具,如基本参数调整、HSL(色相、饱和度、明亮度)调整、曲线工具等改变照片原本的色彩、明度、对比度等属性,将…

【GPT入门】第8课 大语言模型的自洽性

【GPT入门】第8课 大语言模型的自洽性 1.自洽性概念2.代码(观察执行结果)3.自洽性核心思想 1.自洽性概念 大模型的自洽性(self - consistency)是指在推理阶段,大模型通过生成多个答案并选择出现频率最高的那个&#x…

深度学习(斋藤康毅)学习笔记(六)反向传播3

上一篇文章介绍了反向传播的自动化,但也存在一些问题,本章用于说明这些问题,并修改原有框架,使其支持复杂计算图的运行: 问题一:重复使用一个变量,梯度不会累计 也就是说,反向传播时…

3.6c语言

#define _CRT_SECURE_NO_WARNINGS #include <math.h> #include <stdio.h> int main() {int sum 0,i,j;for (j 1; j < 1000; j){sum 0;for (i 1; i < j; i){if (j % i 0){sum i;} }if (sum j){printf("%d是完数\n", j);}}return 0; }#de…