Pytorch-Lighting使用教程(MNIST为例)

一、pytorch-lighting简介

1.1 pytorch-lighting是什么

pytorch-lighting(简称pl),基于 PyTorch 的框架。它的核心思想是,将学术代码模型定义、前向 / 反向、优化器、验证等)与工程代码for-loop,保存、tensorboard 日志、训练策略等)解耦开来,使得代码更为简洁清晰。

工程代码经常会出现在深度学习代码中,PyTorch Lightning 对这部分逻辑进行了封装,只需要在 Trainer 类中简单设置即可调用,无需重复造轮子。

1.2 pytorch-lighting优势

  • 通过抽象出样板工程代码,可以更容易地识别和理解ML代码;
  • Lightning的统一结构使得在现有项目的基础上进行构建和理解变得非常容易;
  • Lightning 自动化的代码是用经过全面测试、定期维护并遵循ML最佳实践的高质量代码构建的;

pytorch-lighting最大的好处:

(1)是摆脱了硬件依赖,不需要在程序中显式设置.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换

(2)支持分布式训练,自动分配资源,能够很好的进行大规模的DL训练

(3)代码量较少,只需要关心关键的逻辑代码,而框架性的东西,pytorch-lighting已经帮你解决(如自动训练,自动debug)


二、基于Pytorch-Lighting框架训练MNIST模型

1、仅仅训练

下载的所有的数据集都用于训练(没有评估和测试过程,不清楚模型的好与坏)。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)

# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):

    # 3.1 加载基础模型
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    # 3.2 训练过程设置
    def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    # 3.3 优化器设置
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 4. 定义训练数据
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# 6. 开始训练
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

class LitAutoEncoder(pl.LightningModule):

  • 将模型定义代码写在__init__
  • 定义前向传播逻辑
  • 将优化器代码写在 configure_optimizers 钩子中
  • 训练代码写在 training_step 钩子中,可使用 self.log 随时记录变量的值,会保存在 tensorboard 中
  • 验证代码写在 validation_step 钩子中
  • 移除硬件调用.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换
  • 根据需要,重写其他钩子函数,例如 validation_epoch_end,对 validation_step 的结果进行汇总;train_dataloader,定义训练数据的加载逻辑
  • 实例化 Lightning Module 和 Trainer 对象,传入数据集
  • 定义训练参数和回调函数,例如训练设备、数量、保存策略,Early Stop、半精度等

运行结果

2、添加验证和测试模块

在训练之后,加入了测试和评估功能,能更好的指导模型的性能。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)

# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):

    # 3.1 加载基础模型
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    # 3.2 训练过程设置
    def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    # 3.3 测试过程设置
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    # 3.4 验证过程设置
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    # 3.5 优化器设置
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''

# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)

# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 4.3 设置种子
seed = torch.Generator().manual_seed(42)

# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=10)

# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)

# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

3、权重 & 超参的保存和加载

当模型正在训练时,性能会随着它继续看到更多数据而发生变化。

1)训练完成后,使用在训练过程中发现的最佳性能相对应的权重;

2)权重可以让训练在训练过程中断的情况下从原来的位置恢复。

保存权重:Lightning 会自动为你在当前工作目录下保存一个权重,其中包含上一次训练的状态。这能确保在训练中断的情况下恢复训练。

3.1 自动在当前目录下保存checkpoint

# simply by using the Trainer you get automatic checkpointing
trainer = Trainer()

3.2 指定checkpoint保存的目录

# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")

3.3 加载checkpoint

# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")

4、可视化

在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了了解汽车行驶方向的窗口。借助 Lightning,可以可视化任何您能想到的东西:数字、文本、图像、音频。

要跟踪指标,只需使用 LightningModule 内可用的 self.log 方法。

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        value = ...
        self.log("some_value", value)

同时记录多个指标:

values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

4.1 命令行查看

要在命令行进度栏中查看指标,请将 prog_bar 参数设置为 True。

self.log(..., prog_bar=True)

4.2 浏览器查看

默认情况下,Lightning 使用 Tensorboard(如果可用)和一个简单的 CSV 记录器

在命令行中输入(注意:一定是lightning_logs所在的目录):

tensorboard --logdir=lightning_logs/

Tensorboard界面:

Tensorboard输出分析:

完整的代码:

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

from pytorch_lightning.loggers import TensorBoardLogger


# 设置浮点矩阵乘法精度为 'medium'
torch.set_float32_matmul_precision('medium')

# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)

# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):

    # 3.1 加载基础模型
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    # 3.2 训练过程设置
    def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)

        batch_idx_value = batch_idx + 1

        print(" ")
        values = {"loss": loss, "batch_idx_value": batch_idx_value}  # add more items if needed
        self.log_dict(values)

        # 在命令行界面显示log
        '''
        sync_dist=True:分布式计算,数据同步标志
        prog_bar=True:在控制台上显示
        '''
        self.log("train_loss", loss, sync_dist=True, prog_bar=True)
        return loss

    # 3.3 测试过程设置
    def test_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss, sync_dist=True, prog_bar=True)

    # 3.4 验证过程设置
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss, sync_dist=True, prog_bar=True)

    # 3.5 优化器设置
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''

# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)

# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 4.3 设置种子
seed = torch.Generator().manual_seed(42)

# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set, batch_size=256, num_workers=5)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=5)

# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=1000)

# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)
# 7. 从checkpoint恢复状态
# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")

# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

参考:

https://zhuanlan.zhihu.com/p/659631467

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/667947.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java大厂面试题第2季

一、本课程前提要求和说明 面试题1: 面试题2: 面试题3: 面试题4: 面试题5: 高频最多的常见笔试面试题目 ArrayList HashMap 底层是什么东东 JVM/GC 多线程与高并发 java集合类

移动系统编程-Ionic 页面(Ionic Pages)

Ionic 页面 Ionic 应用程序和大多数移动应用程序使用页面来利用小显示区域。Ionic 源代码结构中,每个页面都在一个单独的目录中,以便将页面的所有信息集中在一起。例如,tabs 启动应用程序在 app 目录中的目录结构如下。请看是否能看到与 Angu…

米博无布洗地机:颠覆清洁体验,引领家庭清洁风尚

在科技日新月异的今天,家庭清洁方式也正在经历着一场翻天覆地的变化。 传统洗地机的诸多痛点 传统的洗地机虽然在一定程度上解放了人们的双手,然而其存在的诸多问题与痛点也随着时间流逝而逐渐凸显,成为了越来越多消费者心中的困扰。 传统洗地…

css网格背景样式

空白内容效果图 在百度页面测试效果 ER图效果 注意&#xff1a;要给div一个宽高 <template><div class"grid-bg"></div> </template><style scoped> .grid-bg {width: 100%;height: 100%;background: url(data:image/svgxml;base…

深入浅出Java多线程

系列文章目录 文章目录 系列文章目录前言一、多线程基础概念介绍线程的状态转换图线程的调度一些常见问题 二、Java 中线程的常用方法介绍Java语言对线程的支持Thread常用的方法三、线程初体验&#xff08;编码示例&#xff09; 前言 前些天发现了一个巨牛的人工智能学习网站&…

LeeCode热题100(爬楼梯)

爬楼梯这个题我断断续续看了不下5遍&#xff0c;哪次看都是懵逼的&#xff0c;就会说是满足动态规划&#xff0c;满足斐波那契数列&#xff0c;也不说为什么。 本文一定让你明白怎么分析这个题的规律&#xff08;利用数学的递推思想来分析&#xff09;&#xff0c;看不懂来打我…

SleepFM:利用对比学习预训练的多模态“睡眠”基础模型

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在阅读过程中有些知识点存在盲区&#xff0c;可以回到如何优雅的谈论大模型重新阅读。另外斯坦福2024人工智能报告解读为通识性读物。若对于如果…

Go微服务: 封装nacos-sdk-go的v2版本与应用

概述 基于前文&#xff1a;https://active.blog.csdn.net/article/details/139213323我们基于此SDK提供的API封装一个公共方法来用于生产环境 封装 nacos-sdk-go 我们封装一个 nacos.go 文件, 这个是通用的工具库 package commonimport ("fmt""github.com/nac…

如何查看谁连接到了你的Wi-Fi网络?这里提供几种方法或工具

序言 你知道谁连接到你路由器的Wi-Fi网络吗?查看从路由器或计算机连接到Wi-Fi网络的设备列表,找出答案。 请记住,现在很多设备都可以连接到了你的Wi-Fi,该名单包括笔记本电脑、智能手机、平板电脑、智能电视、机顶盒、游戏机、Wi-Fi打印机等。 使用GlassWire Pro查看连接…

蒙自源六一童趣献礼:纯真餐单,唤醒你的童年味蕾

当岁月的车轮滚滚向前&#xff0c;我们总会怀念那些逝去的时光&#xff0c;尤其是那段纯真无瑕的童年。当六一儿童节来临&#xff0c;心底的那份童趣与回忆总会被轻轻触动。从5月25日起&#xff0c;蒙自源旗下各大门店为所有小朋友&#xff0c;以及那些心怀童真的大人们&#x…

Mac vm虚拟机激活版:VMware Fusion Pro for Mac支持Monterey 1

相信之前使用过Win版系统的朋友们对这款VMware Fusion Pro for Mac应该都不会陌生&#xff0c;这款软件以其强大的功能和适配能力广受用户的好评&#xff0c;在Mac端也同样是一款最受用户欢迎之一的虚拟机软件&#xff0c;VM虚拟机mac版可以让您能够轻松的在Apple的macOS和Mac的…

华为交换机、路由器配置查询、用户界面常见配置及安全加固

华为交换机、路由器配置查询、用户界面常见配置及安全加固。 一、查询命令 1.常用的查询命令 查看当前生效的配置信息&#xff1a; display current-configuration //正在生效的配置&#xff0c;默认参数不显示。查看当前视图下生效的配置信息&#xff1a; display this //常…

数字IC基础:主要的FPGA厂商

相关阅读 数字IC基础https://blog.csdn.net/weixin_45791458/category_12365795.html?spm1001.2014.3001.5482 Xilinx&#xff08;现已被AMD收购&#xff09; Xilinx, 成立于1984年&#xff0c;是FPGA&#xff08;现场可编程门阵列&#xff09;技术的创始者和市场领导者。该公…

Mac下载Homebrew

通过command空格搜索终端打开 直接输入 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" 然后输入电脑密码 然后直接回车等待安装完成 注意⚠️&#xff1a;如果出现报错/opt/homebrew/bin is not in your PATH…

python操作mongodb底层封装并验证pymongo是否应该关闭连接

一、pymongo简介 github地址&#xff1a;https://github.com/mongodb/mongo-python-driver mymongo安装命令&#xff1a;pip install pymongo4.7.2 mymongo接口文档&#xff1a;PyMongo 4.7.2 Documentation PyMongo发行版包含Python与MongoDB数据库交互的工具。bson包是用…

【设计模式深度剖析】【1】【行为型】【模板方法模式】| 以烹饪过程为例加深理解

&#x1f448;️上一篇:结构型设计模式对比 文章目录 模板方法模式定义英文原话直译如何理解呢&#xff1f; 2个角色类图代码示例 应用优点缺点使用场景 示例解析&#xff1a;以烹饪过程为例类图代码示例 模板方法模式 模板方法模式&#xff08;Template Method Pattern&…

生信分析进阶4 - 比对结果的FLAG和CIGAR信息含义与BAM文件指定区域提取

BAM文件时存储比对数据的常用格式&#xff0c;可用于短reads和长reads数据。BAM是二进制压缩格式&#xff0c;SAM文件为其纯文本格式&#xff0c;CRAM为BAM的高压缩格式&#xff0c;IO效率相比于BAM略差&#xff0c;但是占用存储空间更小。 1. BAM文件的比对信息 BAM的核心信…

软件测试期末复习

第四章 边界黑盒测试续 4.3边界值设计方法 1.边界值设计方法&#xff1a;故障往往出现在定义域或边界值上。通常边界值分析法是作为对等价类划分法的补充。其测试用例来自等价类的边界。是对输入或输出的边界值进行测试的一种黑盒测试方法。 2.边界值分析法和等价类划分法的…

在Visual Studio2022中同一个项目里写作业,有多个cpp文件会报错

为了省事&#xff0c;在同一个项目里写很多个题目&#xff0c;结果只有一个cpp文件时没出错&#xff0c;写了2个cpp文件再想运行时就出错了&#xff1b; 将不相关的cpp文件移出去 在源文件中对其点击右键&#xff0c;找到“从项目中排除”&#xff1b; 结果如图&#xff0c;剩…

网络原理-三

一、连接管理 建立连接,断开连接 建立连接,TCP有连接的. 客户端执行 socket new Socket(SeverIP,severPort); -> 这个操作就是在建立连接. 上述只是调用socket api,真正建立连接的过程,实在操作系统内核完成的. 内核是怎样完成上述的 " 建立连接 "过程的…