PyTorch Lightning:通过分布式训练扩展深度学习工作流

 

一、介绍

        欢迎来到我们关于 PyTorch Lightning 系列的第二篇文章!在上一篇文章中,我们向您介绍了 PyTorch Lightning,并探讨了它在简化深度学习模型开发方面的主要功能和优势。我们了解了 PyTorch Lightning 如何为组织和构建 PyTorch 代码提供高级抽象,使研究人员和从业者能够更多地关注模型设计和实验,而不是样板代码。

        在本文中,我们将深入研究 PyTorch Lightning,并探索它如何通过分布式训练实现深度学习工作流的扩展。分布式训练对于在海量数据集上训练大型模型至关重要,因为它允许我们利用多个 GPU 或机器的强大功能来加速训练过程。然而,分布式训练往往伴随着一系列挑战和复杂性。

二、安装 Pytorch Lightning & Torchvision

pip install torch torchvision pytorch-lightning 

三、实现

        首先,我们需要从 PyTorch 和 PyTorch Lightning 导入必要的模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms

import pytorch_lightning as pl

        接下来,我们使用 PyTorch 的类定义我们的神经网络架构。在这个例子中,我们使用一个简单的卷积神经网络,其中包含两个卷积层和三个全连接层:nn.Module

class Net(pl.LightningModule):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

        然后,我们为 .在该方法中,我们接收一批输入和标签,将它们通过我们的神经网络来获取 logits,计算交叉熵损失,并使用该方法记录训练损失。在该方法中,我们执行与 相同的操作,但不记录损失:LightningModuletraining_stepxyself.logvalidation_steptraining_step

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("val_loss", loss)
        return loss

        我们还在方法中定义了优化器和学习率调度器:configure_optimizers

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]

        接下来,我们使用 PyTorch 和 定义数据加载和预处理步骤:DataLoadertransforms

    def prepare_data(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        CIFAR10(root='./data', train=True, download=True, transform=transform)
        CIFAR10(root='./data', train=False, download=True, transform=transform)

    def train_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_dataset = CIFAR10(root='./data', train=True, download=False, transform=transform)
        return DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)

    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        val_dataset = CIFAR10(root='./data', train=False, download=False, transform=transform)
        return DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)
  1. prepare_data(self):此函数负责在训练模型之前准备数据。它首先使用该类定义一系列转换。转换包括将数据转换为张量并对其进行规范化。定义转换后,该函数将下载用于训练和测试拆分的 CIFAR10 数据集。数据集将下载到目录,并将指定的转换应用于数据。transforms.Compose'./data'
  2. train_dataloader(self):此函数为训练数据集创建数据加载器。它首先定义与函数中相同的转换。接下来,它为训练拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用训练数据集创建一个对象。数据加载程序配置为 64 的批大小,对数据进行随机排序,并使用 8 个工作线程进行数据加载。它返回数据加载器。prepare_data'./data'DataLoader
  3. val_dataloader(self):此函数为验证数据集创建数据加载器。它遵循与函数类似的结构。它首先使用 定义转换,这些转换与前面的函数相同。然后,为验证拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用验证数据集创建一个对象。数据加载器配置为 64 的批大小,无需随机处理数据,并使用 8 个工作线程进行数据加载。它返回数据加载器。train_dataloadertransforms.Compose'./data'DataLoader

        该函数将模型作为输入,并对测试数据集执行评估。它首先对测试数据应用转换,将其转换为张量并规范化。然后,它为测试数据集创建数据加载程序。模型将移动到相应的设备(GPU,如果可用)。评估标准设置为交叉熵损失。evaluate_model

def evaluate_model(model):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100.0 * correct / total
    average_loss = test_loss / len(test_loader)

    print(f"Test Loss: {average_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

        将模型置于评估模式,并初始化测试损失、正确预测和总数据点的变量。在无梯度上下文中,该函数遍历测试数据加载器,通过模型转发成批的输入,计算损失并累积测试损失。它还计算正确预测的数量和数据点的总数。最后,它计算并打印平均测试损失和测试精度。

        最后,我们实例化我们的模型和来自 PyTorch Lightning,指定用于分布式训练的所需数量的 GPU 或机器:NetTrainer

net = Net()

trainer = pl.Trainer(
    
    num_nodes=1,  # Change to the number of machines in your distributed setup
    accelerator="auto",  # Distributed Data Parallel, Available names are: auto, cpu, cuda, hpu, ipu, mps, tpu.
    max_epochs=5, 
    devices=1 # Change to the desired number of GPUs or use `None` for CPU training
)

trainer.fit(net)

evaluate_model(net)
  • num_nodes:它指定分布式设置中的计算机数量。在这种情况下,它设置为 ,表示单台计算机设置。1
  • accelerator:它确定训练的加速器类型。该值允许 PyTorch Lightning 根据硬件和软件环境自动选择适当的加速器。其他可能的值包括 、 和 ,它们对应于特定的硬件加速器。"auto""cpu""cuda""hpu""ipu""mps""tpu"
  • max_epochs:它设置用于训练模型的最大周期数(通过训练数据集的完整遍历)。在本例中,它设置为 。5
  • devices:它指定用于训练的 GPU 数量。将其设置为 表示使用单个 GPU 进行训练。如果要在 CPU 上进行训练,可以将其设置为 。1None

        这些选项允许您控制训练过程的各个方面,例如分布式训练、加速器选择以及用于训练的周期数和设备数。

        设置好所有内容后,我们只需调用对象的方法,传入我们的模型、训练数据加载器和验证数据加载器。fitTrainerNet

四、输出

 

五、结论

        PyTorch Lightning 通过分布式训练简化了扩展深度学习工作流的过程。通过抽象化分布式训练的复杂性,PyTorch Lightning 使我们能够专注于设计和实现我们的深度学习模型,而不必担心低级细节。在本文中,我们演练了一个使用 PyTorch Lightning 进行分布式训练的示例代码实现。通过利用多个GPU或机器的强大功能,我们可以显著减少大型深度学习模型的训练时间。

六、引用

  • PyTorch Lightning: Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.1.0.rc0 documentation
  • PyTorch: PyTorch
  • torchvision.datasets.CIFAR10: Datasets — Torchvision 0.15 documentation
  • torch.utils.data.DataLoader: torch.utils.data — PyTorch 2.0 documentation
  • 火炬亚当:Adam — PyTorch 2.0 documentation
  • torch.optim.lr_scheduler。步长:StepLR — PyTorch 2.0 documentation
  • Torch.nn.CrossEntropyLoss: CrossEntropyLoss — PyTorch 2.0 documentation
  • torch.cuda.is_available:torch.cuda — PyTorch 2.0 documentation

阿奈·东格雷

皮托奇

分布式系统

深度学习
皮托奇闪电
计算机视觉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/81715.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

详解junit

目录 1.概述 2.断言 3.常用注解 3.1.Test 3.2.Before 3.3.After 3.4.BeforeClass 3.5.AfterClass 4.异常测试 5.超时测试 6.参数化测试 1.概述 什么是单元测试: 单元测试,是针对最小的功能单元编写测试代码,在JAVA中最小的功能单…

openpose姿态估计【学习笔记】

文章目录 1、人体需要检测的关键点2、Top-down方法3、Openpose3.1 姿态估计的步骤3.2 PAF(Part Affinity Fields)部分亲和场3.3 制作PAF标签3.4 PAF权值计算3.5 匹配方法 4、CPM(Convolutional Pose Machines)模型5、Openpose5.1 …

博客系统之功能测试

博客系统共有:用户登录功能、发布博客功能、查看文章详情功能、查看文章列表功能、删除文章功能、退出功能 1.登录功能: 1.1测试对象:用户登录 1.2测试用例 方法:判定表 用例 编号 操作步骤预期结果实际结果截图1 1.用户名正确…

【C++从0到王者】第二十一站:继承

文章目录 前言一、继承的概念及定义1. 继承的概念2.继承的格式3.继承关系与访问限定符 二、基类和派生类的赋值转换三、继承中的作用域四、派生类的默认成员函数五、继承与友元六、继承与静态成员 前言 继承是面向对象的三大特性之一。我们常常会遇到这样的情况。很多角色的信…

一、docker及mysql基本语法

文章目录 一、docker相关命令二、mysql相关命令 一、docker相关命令 &#xff08;1&#xff09;拉取镜像&#xff1a;docker pull <镜像ID/image> &#xff08;2&#xff09;查看当前docker中的镜像&#xff1a;docker images &#xff08;3&#xff09;删除镜像&#x…

Python web实战之细说 Django 的单元测试

关键词&#xff1a; Python Web 开发、Django、单元测试、测试驱动开发、TDD、测试框架、持续集成、自动化测试 大家好&#xff0c;今天&#xff0c;我将带领大家进入 Python Web 开发的新世界&#xff0c;深入探讨 Django 的单元测试。通过本文的实战案例和详细讲解&#xff…

DNS域名解析服务器

一、DNS简介 1、因特网的域名结构 2、域名服务器的类型划分 二、DNS域名解析的过程 三、DNS服务器配置 两个都定义&#xff0c;ttl的优先&#xff1a; 能解析&#xff0c;不能拼通&#xff08;没有13这个主机&#xff09; 别名&#xff1a; 测试&#xff1a; 主&#xff08;192…

AI在日常生活中的应用:从语音助手到自动驾驶

文章目录 AI的定义和发展AI在日常生活中的应用1. **智能语音助手**2. **智能家居**3. **智能医疗**4. **自动驾驶** 代码示例&#xff1a;使用Python实现基于机器学习的图片分类AI的未来前景结论 &#x1f389;欢迎来到AIGC人工智能专栏~探索AI在日常生活中的应用 ☆* o(≧▽≦…

fiddler抓包问题记录,支持https、解决 tunnel to 443

fiddler下载安装步骤及基本配置 fiddler抓包教程&#xff0c;如何抓取HTTPS请求&#xff0c;详细教程 可能遇到的问题及解决方案 1. 不能正常访问页面&#xff08;所有https都无法访问&#xff09; 解决方案&#xff1a;查看下面配置是否正确 Rules-customization 找到 OnB…

Django进阶:DRF(Django REST framework)

什么是DRF&#xff1f; DRF即Django REST framework的缩写&#xff0c;官网上说&#xff1a;Django REST framework是一个强大而灵活的工具包&#xff0c;用于构建Web API。 简单来说&#xff1a;通过DRF创建API后&#xff0c;就可以通过HTTP请求来获取、创建、更新或删除数据(…

常见排序集锦-C语言实现数据结构

目录 排序的概念 常见排序集锦 1.直接插入排序 2.希尔排序 3.选择排序 4.堆排序 5.冒泡排序 6.快速排序 hoare 挖坑法 前后指针法 非递归 7.归并排序 非递归 排序实现接口 算法复杂度与稳定性分析 排序的概念 排序 &#xff1a;所谓排序&#xff0c;就是使一串记录&#…

Win11中zookeeper的下载与安装

下载步骤 打开浏览器&#xff0c;前往 Apache ZooKeeper 的官方网站&#xff1a;zookeeper官方。在主页上点击"Project"选项&#xff0c;并点击"Release" 点击Download按钮&#xff0c;跳转到下载目录 在下载页面中&#xff0c;选择版本号&#xff0c;并点…

什么是Pytorch?

当谈及深度学习框架时&#xff0c;PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库&#xff0c;PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问&#xff0c;PyTorch 是什么&#xff0c;它有什么特点&#xf…

springboot、java实现调用企业微信接口向指定用户发送消息

因为项目的业务逻辑需要向指定用户发送企业微信消息&#xff0c;所以在这里记录一下 目录 引入相关依赖创建配置工具类创建发送消息类测试类最终效果 引入相关依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-…

【Python机器学习】实验14 手写体卷积神经网络(PyTorch实现)

文章目录 LeNet-5网络结构&#xff08;1&#xff09;卷积层C1&#xff08;2&#xff09;池化层S1&#xff08;3&#xff09;卷积层C2&#xff08;4&#xff09;池化层S2&#xff08;5&#xff09;卷积层C3&#xff08;6&#xff09;线性层F1&#xff08;7&#xff09;线性层F2 …

数据可视化-canvas-svg-Echarts

数据可视化 技术栈 canvas <canvas width"300" height"300"></canvas>当没有设置宽度和高度的时候&#xff0c;canvas 会初始化宽度为 300 像素和高度为 150 像素。切记不能通过样式去设置画布的宽度与高度宽高必须通过属性设置&#xff0c;…

Gateway网关路由以及predicates用法(项目中使用场景)

1.Gatewaynacos整合微服务 服务注册在nacos上&#xff0c;通过Gateway路由网关配置统一路由访问 这里主要通过yml方式说明&#xff1a; route: config: #type:database nacos yml data-type: yml group: DEFAULT_GROUP data-id: jeecg-gateway-router 配置路由&#xff1a;…

Liunx系统编程:进程信号的概念及产生方式

目录 一. 进程信号概述 1.1 生活中的信号 1.2 进程信号 1.3 信号的查看 二. 信号发送的本质 三. 信号产生的四种方式 3.1 按键产生信号 3.2 通过系统接口发送信号 3.2.1 kill -- 向指定进程发送信号 3.2.2 raise -- 当自身发送信号 3.2.3 abort -- 向自身发送进程终止…

使用 Elasticsearch 轻松进行中文文本分类

本文记录下使用 Elasticsearch 进行文本分类&#xff0c;当我第一次偶然发现 Elasticsearch 时&#xff0c;就被它的易用性、速度和配置选项所吸引。每次使用 Elasticsearch&#xff0c;我都能找到一种更为简单的方法来解决我一贯通过传统的自然语言处理 (NLP) 工具和技术来解决…

基于Python的HTTP代理爬虫开发初探

前言 随着互联网的发展&#xff0c;爬虫技术已经成为了信息采集、数据分析的重要手段。然而在进行爬虫开发的过程中&#xff0c;由于个人或机构的目的不同&#xff0c;也会面临一些访问限制或者防护措施。这时候&#xff0c;使用HTTP代理爬虫可以有效地解决这些问题&#xff0…