【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪

本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~
https://github.com/LFF8888/FF-Studio-Resources
在这里插入图片描述

在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监控训练过程中的性能变化,清晰的记录和直观的可视化都能显著提升开发效率。然而,许多开发者在实际操作中往往忽视了这一点,导致实验结果难以复现,或者在项目协作中出现混乱。今天,笔者将介绍如何利用 PyTorch Lightning 和 Weights & Biases 这一强大的工具组合,轻松构建和训练一个图像分类模型。通过本文,你将学会如何高效地组织数据管道、定义模型架构,并利用 W&B 实现实验跟踪和结果可视化,让每一次实验都清晰可溯,每一次优化都有据可依。

使用 PyTorch Lightning ⚡️ 进行图像分类

我们将使用 PyTorch Lightning 构建一个图像分类管道。我们将遵循这个 风格指南 来提高代码的可读性和可重复性。这里有一个很酷的解释:使用 PyTorch Lightning 进行图像分类。

设置 PyTorch Lightning 和 W&B

对于本教程,我们需要 PyTorch Lightning(这不是很明显吗!)和 Weights and Biases。

!pip install lightning torchvision -q
# 安装 weights and biases
!pip install wandb -qU

你需要这些导入。

import lightning.pytorch as pl
# 你最喜欢的机器学习跟踪工具
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

import wandb

现在你需要登录到你的 wandb 账户。

wandb.login()

🔧 DataModule - 我们应得的数据管道

DataModules 是一种将数据相关的钩子与 LightningModule 解耦的方式,以便你可以开发与数据集无关的模型。
它将数据管道组织成一个可共享和可重用的类。一个 datamodule 封装了 PyTorch 中数据处理的五个步骤:

  • 下载 / 分词 / 处理。
  • 清理并(可能)保存到磁盘。
  • 加载到 Dataset 中。
  • 应用转换(旋转、分词等)。
  • 包装到 DataLoader 中。

了解更多关于 datamodules 的信息 这里。让我们为 Cifar-10 数据集构建一个 datamodule。

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.num_classes = 10

    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # 为 dataloaders 分配训练/验证数据集
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # 为 dataloader(s) 分配测试数据集
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

📱 Callbacks

回调是一个独立的程序,可以在项目之间重用。PyTorch Lightning 提供了一些 内置回调,这些回调经常被使用。
了解更多关于 PyTorch Lightning 中的回调 这里。

内置回调

在本教程中,我们将使用 Early Stopping 和 Model Checkpoint 内置回调。它们可以传递给 Trainer

自定义回调

如果你熟悉自定义 Keras 回调,那么在 PyTorch 管道中实现相同功能的能力只是锦上添花。
由于我们正在进行图像分类,能够可视化模型对一些样本图像的预测可能很有帮助。这种形式的回调可以帮助在早期阶段调试模型。

class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        # 将张量带到 CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # 获取模型预测
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # 将图像记录为 wandb Image
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                           for x, pred, y in zip(val_imgs[:self.num_samples],
                                                 preds[:self.num_samples],
                                                 val_labels[:self.num_samples])]
            })

🎺 LightningModule - 定义系统

LightningModule 定义了一个系统,而不是一个模型。在这里,系统将所有研究代码分组到一个类中,使其自包含。LightningModule 将你的 PyTorch 代码组织成 5 个部分:

  • 计算 (__init__)。
  • 训练循环 (training_step)
  • 验证循环 (validation_step)
  • 测试循环 (test_step)
  • 优化器 (configure_optimizers)

因此,可以构建一个与数据集无关的模型,并且可以轻松共享。让我们为 Cifar-10 分类构建一个系统。

class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        # 记录超参数
        self.save_hyperparameters()
        self.learning_rate = learning_rate

        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)

        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    # 返回从卷积块进入线性层的输出张量的大小。
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input)
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size

    # 返回卷积块的特征张量
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x

    # 将在推理期间使用
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)

       return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 训练指标
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 验证指标
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # 验证指标
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

🚋 训练和评估

现在我们已经使用 DataModule 组织了数据管道,并使用 LightningModule 组织了模型架构和训练循环,PyTorch Lightning Trainer 为我们自动化了其他所有内容。

Trainer 自动化了以下内容:

  • Epoch 和 batch 迭代
  • 调用 optimizer.step()backwardzero_grad()
  • 调用 .eval(),启用/禁用梯度
  • 保存和加载权重
  • Weights and Biases 日志记录
  • 多 GPU 训练支持
  • TPU 支持
  • 16 位训练支持
dm = CIFAR10DataModule(batch_size=32)
# 要访问 x_dataloader,我们需要调用 prepare_data 和 setup。
dm.prepare_data()
dm.setup()

# 自定义 ImagePredictionLogger 回调所需的样本,用于记录图像预测。
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)

# 初始化 wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# 初始化 Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# 初始化一个 trainer
trainer = pl.Trainer(max_epochs=2,
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# 训练模型 ⚡🚅⚡
trainer.fit(model, dm)

# 在保留的测试集上评估模型 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# 关闭 wandb run
wandb.finish()

最终想法

我来自 TensorFlow/Keras 生态系统,发现 PyTorch 虽然是一个优雅的框架,但有点让人不知所措。这只是我的个人经验。在探索 PyTorch Lightning 时,我意识到几乎所有让我远离 PyTorch 的原因都得到了解决。以下是我兴奋的快速总结:

  • 过去:传统的 PyTorch 模型定义通常分散在各个地方。模型在某个 model.py 脚本中,训练循环在 train.py 文件中。需要来回查看才能理解管道。
  • 现在:LightningModule 作为一个系统,模型定义与 training_stepvalidation_step 等一起定义。现在它是模块化的且可共享的。
  • 过去:TensorFlow/Keras 最棒的部分是输入数据管道。他们的数据集目录丰富且不断增长。PyTorch 的数据管道曾经是最大的痛点。在普通的 PyTorch 代码中,数据下载/清理/准备通常分散在许多文件中。
  • 现在:DataModule 将数据管道组织成一个可共享和可重用的类。它只是 train_dataloaderval_dataloader(s)、test_dataloader(s) 以及匹配的转换和数据处理/下载步骤的集合。
  • 过去:使用 Keras,可以调用 model.fit 来训练模型,调用 model.predict 来运行推理。model.evaluate 提供了一个简单而有效的测试数据评估。这在 PyTorch 中不是这样。通常会找到单独的 train.pytest.py 文件。
  • 现在:有了 LightningModuleTrainer 自动化了一切。只需调用 trainer.fittrainer.test 来训练和评估模型。
  • 过去:TensorFlow 喜欢 TPU,PyTorch…嗯!
  • 现在:使用 PyTorch Lightning,可以轻松地在多个 GPU 上训练相同的模型,甚至在 TPU 上。哇!
  • 过去:我是回调的忠实粉丝,更喜欢编写自定义回调。像 Early Stopping 这样简单的事情曾经是传统 PyTorch 的讨论点。
  • 现在:使用 PyTorch Lightning,使用 Early Stopping 和 Model Checkpointing 是小菜一碟。我甚至可以编写自定义回调。

🎨 结论和资源

我希望你觉得这份报告有帮助。我鼓励你玩一下代码,并使用你选择的数据集训练一个图像分类器。

以下是一些学习更多关于 PyTorch Lightning 的资源:

  • 逐步演练 - 这是官方教程之一。他们的文档写得非常好,我强烈推荐它作为学习资源。
  • 使用 PyTorch Lightning 与 Weights & Biases - 这是一个快速 colab,你可以通过它学习如何使用 W&B 与 PyTorch Lightning。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/968466.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

异位妊娠唯一相关的是年龄(U型曲线)

异位妊娠唯一相关的是年龄(U型曲线) 简介 异位妊娠,俗称宫外孕,是指受精卵在子宫体腔以外着床发育的异常妊娠过程 。正常情况下,受精卵会在子宫内着床并发育成胎儿,但在异位妊娠中,受精卵却在…

ESM3(1)-介绍:用语言模型模拟5亿年的进化历程

超过30亿年的进化在天然蛋白质空间中编码形成了一幅生物学图景。在此,作者证明在进化数据上进行大规模训练的语言模型,能够生成与已知蛋白质差异巨大的功能性蛋白质,并推出了ESM3,这是一款前沿的多模态生成式语言模型,…

CondaValueError: Malformed version string ‘~‘: invalid character(s)

CondaValueError: Malformed version string ‘~‘: invalid character(s) 送一张 GPT plus 、 deepseek-R1 满血 体验卡~ https://bbs.csdn.net/topics/619568415 ​ 报错原因 使用conda安装一些库时出现以下报错: CondaValueError: Malformed versio…

01、单片机上电后没有正常运行怎么办

单片机上电后没有运转, 首先要检查什么? 1、单片机供电是否正常? &电路焊接检查 如果连最基本的供电都没有,其它都是空谈啊!检查电路断路了没有?短路了没有?电源合适吗?有没有虚焊? 拿起万用表之前,预想一下测量哪里?供电电压应该是多少?对PCB上电压测量点要…

基于Java的分布式系统架构设计与实现

Java在大数据处理中的应用:基于Java的分布式系统架构设计与实现 随着大数据时代的到来,数据处理的规模和复杂性不断增加。为了高效处理海量数据,分布式系统成为了必不可少的架构之一。而Java,凭借其平台独立性、丰富的生态系统以…

【含文档+PPT+源码】基于Python的全国景区数据分析以及可视化实现

项目介绍 本课程演示的是一款基于Python的全国景区数据分析以及可视化实现,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 带你从零开始部署运行本套系统 该…

Apache Kafka 中的认证、鉴权原理与应用

编辑导读:本篇内容将进一步介绍 Kafka 中的认证、鉴权等概念。AutoMQ 是与 Apache Kafka 100% 完全兼容的新一代 Kafka,可以帮助用户降低 90%以上的 Kafka 成本并且进行极速地自动弹性。作为 Kafka 生态的忠实拥护者,我们也会持续致力于传播 …

初阶数据结构:树---二叉树的链式结构

目录 一、二叉树的链式结构 (一)、概念 二、二叉树链式结构的实现 (一)、二叉树链式结构的遍历 1、前序遍历 2、中序遍历 3、后序遍历 4、层序遍历 (二)、二叉树的构建 (三&#xff0…

SurfGen爬虫:解析HTML与提取关键数据

一、SurfGen爬虫框架简介 SurfGen是一个基于Swift语言开发的爬虫框架,它提供了丰富的功能,包括网络请求、HTML解析、数据提取等。SurfGen的核心优势在于其简洁易用的API和高效的性能,使得开发者能够快速构建爬虫程序。以下是SurfGen的主要特…

pyrender 渲染报错解决

pyrender渲染后,出来的图样子不对: 正确的图: 解决方法: pip install numpy1.26 下面的不是必须的: pip install pyrender0.1.45 os.environ["PYOPENGL_PLATFORM"] "egl" os.environ[EGL_DEVI…

C++,STL容器,unordered_map/unordered_multimap:无序映射/无序多重映射深入解析

文章目录 一、容器概览与核心特性核心特性对比二、底层实现原理:哈希表架构1. 哈希表核心结构2. 动态扩容机制三、核心操作详解1. 容器初始化与配置2. 元素插入与更新3. 元素访问与查找4. 元素删除策略四、实战应用场景1. 缓存系统实现2. 分布式系统路由表五、性能优化策略1. …

Qt 控件整理 —— 按钮类

一、PushButton 1. 介绍 在Qt中最常见的就是按钮,它的继承关系如下: 2. 常用属性 3. 例子 我们之前写过一个例子,根据上下左右的按钮去操控一个按钮,当时只是做了一些比较粗糙的去演示信号和槽是这么连接的,这次我们…

python-leetcode 27.合并两个有序链表

题目: 将两个升序链表合并为一个新的升序链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4] 方法一:递归 函数在运行时调用自己,这个函数叫递归函数…

Unity中实现动态图集算法

在 Unity 中,动态图集(Dynamic Atlas)是一种在运行时将多个纹理合并成一个大纹理图集的技术,这样可以减少渲染时的纹理切换次数,提高渲染效率。 实现原理: 动态图集的核心思想是在运行时动态地将多个小纹理…

公然上线传销项目,Web3 的底线已经被无限突破

作者:Techub 热点速递 撰文:Yangz,Techub News 今天早些时候,OKX 将上线 PI 的消息在圈内引起轩然大波,对于上线被板上钉钉为传销盘子的「项目」 ,Techub News 联系了 OKX 公关,但对方拒绝置评…

元宵节快乐

早上吃的一碗小颗粒汤圆。 晚上做了三个小菜,一碗米饭和一杯饮料。 整理了Chrome浏览器收藏夹书签,删除了太多不需要的书签,重新分类,更加细化。 看到某博主推荐的5本书,下载这学期看看。点击此处下载 看来这段关系…

SAP系统常见的接口方式及特点介绍

【SAP系统研究】 在SAP系统中,接口主要用于系统间或系统与外部应用的数据交换和集成。以下是常见的接口方式及其特点: 一、IDoc方式 IDoc,Intermediate document,是SAP历史很悠久的接口技术,是一种系统间通用的数据交换媒介文件。IDoc基于XML的标准格式,常用于EDI、系…

【嵌入式Linux应用开发基础】open函数与close函数

目录 一、open函数 1.1. 函数原型 1.2 参数说明 1.3 返回值 1.4. 示例代码 二、close函数 2.1. 函数原型 2.2. 示例代码 三、关键注意事项 3.1. 资源管理与泄漏防范 3.2. 错误处理的严谨性 3.3. 标志(flags)与权限(mode&#xff…

LabVIEW国内外开发的区别

LabVIEW作为全球领先的图形化编程平台,在国内外工业测控领域均占据重要地位。本文从开发理念、技术生态、应用深度及自主可控性四个维度,对比分析国内外LabVIEW开发的差异,并结合国内实际应用场景,探讨其未来发展趋势。 ​ 一、开…

【大模型】阿里云百炼平台对接DeepSeek-R1大模型使用详解

目录 一、前言 二、DeepSeek简介 2.1 DeepSeek 是什么 2.2 DeepSeek R1特点 2.2.1 DeepSeek-R1创新点 2.3 DeepSeek R1应用场景 2.4 与其他大模型对比 三、阿里云百炼大平台介绍 3.1 阿里云百炼大平台是什么 3.2 阿里云百炼平台主要功能 3.2.1 应用场景 3.3 为什么选…