【PyTorch】进阶学习:一文详细介绍 torch.save() 的应用场景、实战代码示例

【PyTorch】进阶学习:一文详细介绍 torch.save() 的应用场景、实战代码示例

在这里插入图片描述

🌈 个人主页:高斯小哥
🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅和支持~
💡 创作高质量博文(平均质量分92+),分享更多关于深度学习、PyTorch、Python领域的优质内容!(希望得到您的关注~)


🌵文章目录🌵

  • 💾一、模型训练过程中的检查点保存
  • 🚀二、模型部署与推理加速
  • 📚三、模型迁移学习与微调
  • 🔄四、模型版本控制与共享
  • 🎨五、模型的可视化与调试
  • 📚六、模型的序列化与反序列化
  • 🌈七、总结与展望
  • 🤝 期待与你共同进步
  • 相关博客

本文旨在深入探讨PyTorch框架中torch.save()的应用场景,并通过实战代码示例展示其具体应用。如果您对torch.save()的基础知识尚存疑问,博主强烈推荐您首先阅读博客文章《【PyTorch】基础学习:一文详细介绍 torch.save() 的用法和应用》,以全面理解其基本概念和用法。通过这篇文章,您将更好地掌握torch.save()在PyTorch框架中的实际运用,为您的深度学习之旅增添更多助力。期待您的阅读,一同探索PyTorch的无限魅力!

💾一、模型训练过程中的检查点保存

  在深度学习模型的训练过程中,我们经常需要保存模型的中间状态,以便在训练中断时能够恢复训练进度,或者在模型性能达到某个要求时保存当前的最佳模型。torch.save() 在这个场景下发挥着至关重要的作用。

  • 以下是一个简单的例子,展示了如何在训练循环中使用 torch.save() 保存模型的检查点:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    # 假设我们有一个简单的模型
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc = nn.Linear(10, 1)
            
        def forward(self, x):
            return self.fc(x)
    
    model = SimpleModel()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    
    # 模拟一些训练数据
    x_train = torch.randn(100, 10)
    y_train = torch.randn(100, 1)
    
    # 训练循环
    for epoch in range(100):
        optimizer.zero_grad()
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        
        # 每训练几个epoch保存一次模型检查点
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
            }, f'checkpoint_epoch_{epoch+1}.pth')
    

    在这个例子中,我们每10个epoch保存一次模型的检查点,包括当前的epoch数、模型的参数、优化器的状态以及当前的损失值。这样,即使训练过程中遇到中断,我们也可以从最近的检查点恢复训练。

🚀二、模型部署与推理加速

  在模型部署阶段,我们通常需要将模型加载到特定的设备(如CPU或GPU)上进行推理。torch.save() 可以帮助我们保存已经优化过的模型,以便在部署时快速加载并运行。

  • 通过保存和加载模型的参数,我们可以快速地在不同的环境中部署模型,而无需重新训练。此外,将模型加载到GPU上还可以加速推理过程,提高模型的响应速度。

    # 训练完成后,保存最终模型
    final_model_state_dict = model.state_dict()
    torch.save(final_model_state_dict, 'final_model.pth')
    
    # 在部署时加载模型
    loaded_model_state_dict = torch.load('final_model.pth')
    model.load_state_dict(loaded_model_state_dict)
    model.eval()  # 设置模型为评估模式
    
    # 将模型移动到指定设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # 进行推理...
    

📚三、模型迁移学习与微调

  迁移学习是一种利用预训练模型在新任务上进行微调的技术。torch.save() 可以帮助我们保存预训练模型,以便在其他任务中进行迁移学习。

  • 通过保存预训练模型和微调后的模型,我们可以方便地在新任务上利用已有的知识,加速模型的训练过程并提高性能。

    # 假设我们有一个预训练的模型
    pretrained_model = SomePretrainedModel()
    pretrained_model.load_state_dict(torch.load('pretrained_model.pth'))
    
    # 在新任务的数据集上进行微调
    # ...(这里省略了数据加载和训练循环的代码)
    
    # 保存微调后的模型
    finetuned_model_state_dict = pretrained_model.state_dict()
    torch.save(finetuned_model_state_dict, 'finetuned_model.pth')
    

🔄四、模型版本控制与共享

  在模型开发和部署过程中,我们可能需要保存和管理不同版本的模型。torch.save() 结合文件名或路径的管理,可以帮助我们实现模型的版本控制。

  • 通过保存不同版本的模型,并在文件名中明确标注版本号,我们可以轻松地管理和追踪模型的变更历史。同时,将模型文件上传到云存储或共享给团队成员,可以方便地实现模型的共享和协作:

    # 保存不同版本的模型
    torch.save(model1.state_dict(), 'model_v1.pth')
    torch.save(model2.state_dict(), 'model_v2.pth')
    
    # 加载特定版本的模型
    def load_model_version(version):
        if version == 'v1':
            return torch.load('model_v1.pth')
        elif version == 'v2':
            return torch.load('model_v2.pth')
        else:
            raise ValueError("Invalid model version")
    
    # 使用特定版本的模型进行推理
    model_state_dict = load_model_version('v2')
    loaded_model = SimpleModel()
    loaded_model.load_state_dict(model_state_dict)
    loaded_model.eval()
    
    # 模型共享
    # 可以将保存的模型文件上传到云存储或共享给团队成员
    # 其他人可以使用 torch.load() 加载模型进行推理或进一步训练
    

🎨五、模型的可视化与调试

  除了直接用于模型的保存和加载,torch.save() 还可以与一些可视化工具结合使用,帮助我们对模型进行调试和分析。例如,我们可以保存模型的中间层输出或梯度信息,然后使用可视化工具进行展示。

  • 通过保存中间层输出或梯度信息,并结合可视化工具进行分析,我们可以更好地理解模型的内部工作机制,发现潜在的问题并进行调试:

    # 在训练循环中保存中间层输出
    def forward(self, x):
        intermediate_output = self.some_layer(x)
        # 保存中间层输出到文件或内存(这里以保存到文件为例)
        torch.save(intermediate_output, 'intermediate_output.pth')
        return self.fc(intermediate_output)
    
    
    # ...(训练循环代码)
    
    # 在训练完成后,加载中间层输出进行可视化分析
    intermediate_data = torch.load('intermediate_output.pth')
    # 使用可视化工具(如TensorBoard、Matplotlib等)展示中间层输出
    

📚六、模型的序列化与反序列化

  torch.save() 和 torch.load() 的底层机制实际上是 Python 的序列化和反序列化过程。这意味着除了保存和加载模型参数外,我们还可以利用这些函数保存和加载任何可序列化的 Python 对象。

  • 通过序列化和反序列化,我们可以将模型的参数、优化器的状态、超参数以及训练过程中的其他信息保存到一个文件中,并在需要时完整地恢复这些信息。这使得我们能够轻松地重现实验结果、分享训练数据以及进行模型的迁移和复用:

    # 保存一个字典对象
    data_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hyperparameters': {'lr': 0.01, 'batch_size': 64},
        'training_loss_history': loss_history,  # 假设这是训练过程中的损失记录
    }
    torch.save(data_dict, 'training_data.pth')
    
    # 加载字典对象
    loaded_data_dict = torch.load('training_data.pth')
    model.load_state_dict(loaded_data_dict['model_state_dict'])
    optimizer.load_state_dict(loaded_data_dict['optimizer_state_dict'])
    hyperparams = loaded_data_dict['hyperparameters']
    loss_history = loaded_data_dict['training_loss_history']
    

🌈七、总结与展望

  torch.save() 作为 PyTorch 中一个重要的函数,为模型的保存和加载提供了强大的支持。从模型训练过程中的检查点保存到模型部署与推理加速,再到模型迁移学习与微调,torch.save() 在深度学习项目的各个阶段都发挥着不可或缺的作用。此外,通过结合版本控制、模型可视化与调试以及高级序列化技术,我们可以进一步拓展 torch.save() 的应用场景,提高模型开发和部署的效率。

  展望未来,随着深度学习技术的不断发展和应用领域的拓宽,对模型保存和加载的需求也将更加多样化和复杂化。相信 PyTorch 社区会不断完善和优化 torch.save() 及相关功能,为我们提供更加高效、灵活和安全的模型序列化工具,推动深度学习领域的持续进步。

🤝 期待与你共同进步

  🌱 亲爱的读者,非常感谢你每一次的停留和阅读!你的支持是我们前行的最大动力!🙏

  🌐 在这茫茫网海中,有你的关注,我们深感荣幸。你的每一次点赞👍、收藏🌟、评论💬和关注💖,都像是明灯一样照亮我们前行的道路,给予我们无比的鼓舞和力量。🌟

  📚 我们会继续努力,为你呈现更多精彩和有深度的内容。同时,我们非常欢迎你在评论区留下你的宝贵意见和建议,让我们共同进步,共同成长!💬

  💪 无论你在编程的道路上遇到什么困难,都希望你能坚持下去,因为每一次的挫折都是通往成功的必经之路。我们期待与你一起书写编程的精彩篇章! 🎉

  🌈 最后,再次感谢你的厚爱与支持!愿你在编程的道路上越走越远,收获满满的成就和喜悦!祝你编程愉快!🎉

相关博客

博客文章标链接地址
【PyTorch】基础学习:一文详细介绍 torch.save() 的用法和应用https://blog.csdn.net/qq_41813454/article/details/136777957?spm=1001.2014.3001.5501
【PyTorch】进阶学习:一文详细介绍 torch.save() 的应用场景、实战代码示例https://blog.csdn.net/qq_41813454/article/details/136778437?spm=1001.2014.3001.5501
【PyTorch】基础学习:一文详细介绍 torch.load() 的用法和应用https://blog.csdn.net/qq_41813454/article/details/136776883?spm=1001.2014.3001.5501
【PyTorch】进阶学习:一文详细介绍 torch.load() 的应用场景、实战代码示例https://blog.csdn.net/qq_41813454/article/details/136779327?spm=1001.2014.3001.5501
【PyTorch】基础学习:一文详细介绍 load_state_dict() 的用法和应用https://blog.csdn.net/qq_41813454/article/details/136778868?spm=1001.2014.3001.5501
【PyTorch】进阶学习:一文详细介绍 load_state_dict() 的应用场景、实战代码示例https://blog.csdn.net/qq_41813454/article/details/136779495?spm=1001.2014.3001.5501

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/465317.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue组件封装方案对比——v-if方式与内置component方式

近期在准备搭建一个通用组件库,而公司现有的各个系统也已有自己的组件库只是没抽离出来,但是目前有两套不同的组件封装方案,所以对于方案的选择比较困惑,于是对两种方式进行了对比,结合网上找到的一些开源组件库进行分…

wireshark解析https数据包

Debian11环境: 在linux环境下抓取访问某个https的网址时抓取的数据包都是加密的,导致无法跟踪到数据包流,现在尝试将抓取的https包进行解密。 1、解密https数据包需要设置SSLKEYLOGFILE变量,推荐写入配置文件中。 echo "exp…

Mysql的行级锁

MySQL 中锁定粒度最小的一种锁,是 针对索引字段加的锁 ,只针对当前操作的行记录进行加锁。 行级锁能大大减少数据库操作的冲突。其加锁粒度最小,并发度高,但加锁的开销也最大,加锁慢,会出现死锁。行级锁和存…

Ps:文字工具

工具箱里的文字工具组中包含了四种工具: 横排文字工具 Horizontal Type Tool 直排文字工具 Vertical Type Tool 横排文字蒙版工具 Horizontal Type Mask Tool 直排文字蒙版工具 Vertical Type Mask Tool 快捷键:T 横排文字蒙版工具和直排文字蒙版工具…

C++第六弹---类与对象(三)

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】 目录 1、类的6个默认成员函数 2、构造函数 2.1、概念 2.2、特性 3、析构函数 3.1、概念 3.2、特性 3.3、调用顺序 总结 1、类的6个默认成员函数…

力扣hot100:33. 搜索旋转排序数组(二分的理解)

33.搜索旋转排序数组 ​ 这是一个非常有趣的问题,如果不要求使用O(logn)应该没人会想到吧。。 方法一: 极致的分类讨论。旋转排序数组,无非就是右边的增区间的数小于左边的增区间的数,然后依次排序。因此我们只需要分三类讨论即可…

【测试开发学习历程】MySQL数据类型 + MySQL表创建与操作

前言: 半夜梦到自己没有写今天的博客,结果惊醒起来看一看。 得,真的没写。QWQ 可谓垂死病中惊坐起了。 看看发博的时间6:16,而不是什么整点的,就知道我4点就起来了,不是定时发布&#xff01…

知识积累(五):Transformer 家族的学习笔记

文章目录 1. RNN1.1 缺点 2. Transformer2.1 组成2.2 Encoder2.2.1 Input Embedding(嵌入层)2.2.2 位置编码2.2.3 多头注意力2.2.4 Add & Norm 2.3 Decoder2.3.1 概览2.3.2 Masked multi-head attention 2.4 Transformer 模型的训练和推理2.4.1 训练…

C语言学习过程总结(16)——指针(4)

一、数组名的理解 我们直接使用%p打印出地址来看看&arr【0】 和 arr的不同: int main() {int arr[10] { 1,2,3,4,5,6,7,8,9,10 };printf("&arr[0] %p\n", &arr[0]);printf("arr %p\n", arr);} 、 很容易看出来两者的输出…

ES模块化

Node.js默认并不支持ES模块化,如果需要使用可以采用两种方式。方式一,直接将所有的js文件修改为mjs扩展名。方式二,修改package.json中type属性为module。 导出 默认导出 // 向外部导出内容 export let a 10 export const b "孙悟空…

数据分析 | NumPy

NumPy,全称是 Numerical Python,它是目前 Python 数值计算中最重要的基础模块。NumPy 是针对多维数组的一个科学计算模块,这个模块封装了很多数组类型的常用操作。 使用numpy来创建数组 import numpy as npdata np.array([1, 2, 3]) print…

Unity中UGUI中的PSD导入工具的原理和作用

先说一下PSD导入工具的作用,比如在和美术同事合作开发一个背包UI业务系统时,美术做好效果图后,程序在UGUI中制作好界面,美术说这个图差了2像素,那个图位置不对差了1像素,另外一个图大小不对等等一系列零碎的…

文件包含漏洞(input、filter、zip)

一、PHP://INPUT php://input可以访问请求的原始数据的只读流,将post请求的数据当作php代码执行。当传入的参数作为文件名打开时,可以将参数设为php://input,同时post想设置的文件内容,php执行时会将post内容当作文件内容。从而导致任意代码…

ngnix安装配置

通过yum -y install nginx的方式,有时候会出现No package nginx available的报错。迟迟无法解决。此时要通过下载安装包的方式安装。 1、下载安装包:官方网址 2、解压缩: tar -xzvf nginx-1.23.4.tar.gz cd nginx-1.23.4.tar.gz 3、源码包…

pycharm里test connection连接成功,但是无法同步服务器文件,deployment变灰

如果服务器test connection连接成功,但是无法同步文件。 可以尝试以下方式: 点击tools-deployment-browse remonte host,选择要连接的服务器的文件夹 如果能正常显示服务器文件夹,再点击tools-deployment,注意要把要…

echarts设置柱形图柱间距离

不同系柱形图柱间距离(barGap) {type: bar,itemStyle: {normal: {color: #ddd}},silent: true,barWidth: 40,barGap: 10%, //设置负值 不同系的柱形图会实现重叠效果data: [60, 60, 60, 60] },同系柱形图柱间距离(barCategoryGap&#xff…

谈谈对数据库索引的认识

索引的概念 索引是一种特殊的文件,包含着对数据表里所有记录的引用指针。 可以对表中的一列或多列创建索引,并指定索引的类型,各类索引有各自的数据结构实现。 索引的作用 默认情况下,进行条件查询操作,就是遍历表&a…

27. 移除元素 (Swift版本)

题目描述 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你不需要考虑数组中超出…

【蓝桥杯每日一题】填充颜色超详细解释!!!

为了让蓝桥杯不变成蓝桥悲,我决定在舒适的周日再来一道题。 例: 输入: 6 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 1 1 1 0 0 0 1 1 0 0 0 0 1 1 1 1 1 1 1 输出: 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 2 2 1 1 1 2 2 2 1 1 2 2 2 2 1 1…

渐开线花键环规的几种加工方法

小伙伴们大家好,今天咱们聊一聊渐开线花键环规的几种加工方法。 渐开线花键环规是在汽车、摩托车以及机械制造工业应用非常广泛的一种检测量具。它属于是一种内花键齿轮,其精度和表面粗糙度要求都比较高。采用的加工方法也比较多,下面详细看…