第五章 模型篇: 模型保存与加载

参考教程
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

文章目录

  • pytorch中的保存与加载
    • torch.save()
    • torch.load()
    • 代码示例
  • 模型的保存与加载
    • 保存 state_dict()
    • nn.Module().load_state_dict()
    • 加载模型参数
    • 保存模型本身
    • 加载模型本身
  • checkpoint
    • 保存与读取
    • 多个模型的保存与读取

训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。
为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。

这个模型里面包含一个名为self.p1的Parameter和一个名为conv1的卷积层。我们没有给模型定义forward()函数,是因为暂时不需要用到该方法。假如你想使用这个模型对数据进行前向传播,会返回 “NotImplementedError: Module [Model] is missing the required “forward” function”

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = torch.randn((3,2))
        self.p1 = nn.Parameter(self.t1)
        self.conv1 = nn.Conv2d(1, 1, 5)
net = Model()

pytorch中的保存与加载

首先我们来看一下pytorch中的保存和加载的方法是怎么实现的。

torch.save()

参考文档:https://pytorch.org/docs/stable/generated/torch.save.html
首先来看一下torch.save()函数。

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

torch.save()函数传入的第一个参数,就是我们要保存的对象,它的类别要求是object,而没有限定在nn.Module()或者nn.Parameters()等等之间。说明它可以保存的类型是多种多样的,很灵活。
传入的第二个参数是f,f是一个file-like object或者文件路径,也就是我们想要保存的位置。
后面的几个参数可以不用管它,一般也不会用到。从参数名称可以看到,我们想要保存的object是以pickle的形式保存的。因为pickle支持多种数据类型。
在源码中给了两个使用torch.save的例子。

  >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)

第一个例子把一个tensor保存在了‘tensor.pt’中,第二个则是将tensor保存在一个buffer中。这都是允许的。

torch.load()

参考文档:https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
再来看一下torch.load()函数。

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

torch.load()传入的第一个参数f对应着torch.save()中的f,它可以是一个路径,也可以是一个file-like object。
因为我们的模型训练支持cpu也支持gpu等设备,所以我们保存的object也可能处于多种设备环境中,在torch.load()时,这个object会现在CPU上进行反序列化,然后移动到其保存时所处的设备上。假如当前的系统不支持这个设备,就会出现问题,这个时候就需要使用map_location参数,这个参数可以指定你想要放置object的设备,假如没有特别指定,在设备不能实现时就会报错。
weights_only参数可以限定你先要unpickle的object的种类,在使用weights_only参数的同时,你必须明确定义pickle_moduel这个参数(默认为pickle,这也是对的),否则就会报错RuntimeError(“Can not safely load weights when explicit pickle_module is specified”。一般情况下我们也不需要管这个参数。

代码示例

给出一个简单的例子,我们将一个tensor保存在’tensor.pt’中,又使用torch.load()加载进来。

因为保存支持的输入是object,所以我们即使只保存一个字符串也是可以的。(可以,但没必要)
在这里插入图片描述

模型的保存与加载

保存 state_dict()

在之前的章节中有说过,调用model.state_dict()方法时,得到的返回结果是一个orderdict,这个字典的key是模型中参数的名字,value是模型的参数值。
我们通常说的保存模型,保存的就是模型的state_dict(),也就是只保存了模型的参数名和参数值,因此我们是不知道模型的正确结构和forward()中的运算顺序的,你也没有办法直接使用这个state_dict()进行预测。
现在我们保存最开始定义的笨蛋小模型的state_dict()
在这里插入图片描述
我们只保存了模型的参数名和参数值,这个’test.pth’的大小只有1.39 KB (1,428 字节)。

nn.Module().load_state_dict()

def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True):

load_state_dict()传入的参数是一个key和value的mapping。这里的keys对应的当前模型自己的state_dict的key,或者说参数名。
在使用load_state_dict()时,该方法会对传入的mapping中的key和模型本身的key进行对比。如果key可以匹配上,就会进行一些操作后,更改模型的key对应的参数值。假如没有匹配上,这个key就会被放进missing_keys或者unexpected_keys中去。
strict这个参数默认是True,所以当有不匹配的key时,就会返回报错。

加载模型参数

我们只保存的模型的参数,所以想要使用这个参数,就需要把它放置在一个现有的模型中去。比如说我们现在有一个新模型model2,它和model1有着一样的结构,但是因为初始化的随机性,它们的参数值可能是不一样的。
在这里插入图片描述
可以看到我们的model2中的参数名和model1一样,但是对应的值不一样。
我们可以使用load_state_dict()方法将model1的参数值根据参数名放到model2中去。
在这里插入图片描述
现在model1和model2中的参数值也都变得一样了。
假如我们手动修改一下我们使用torch.load()加载的state_dict,给它增加一个新的值。加载时就会报错,出现了unexpected_keys。相应地,假如给它删除一个值,就会出现Missing key(s) 的错误,在这里不举例子。

在这里插入图片描述

保存模型本身

torch.save()支持保存的对象是object,而我们的模型本身,作为nn.Module(),自然也是符合object的要求的。因此你也可以直接保存整个模型。
在这里插入图片描述
我们保存的是整个模型,包括了模型的结构和模型的参数名+参数值。这个’test2.pth’的大小是2.39 KB (2,457 字节)。

加载模型本身

我们在上面将整个模型都保存在了’test2.pth’中,因此我们使用torch.load('test2.pth)时,获得的结果就是模型本身,它的类型是nn.Module()。
在这里插入图片描述

checkpoint

保存与读取

假如我们现在有一个保存好的模型’model.pth’,我们想要继续当前模型的状态继续训练。这个时候我们就会发现,'model.pth’中拥有我们模型的参数名和参数值,但是随着我们之前的训练的进行,我们使用的optimizer或者lr_scheluder的状态我们是无法获取的,它们中也有一些参数可能在训练时发生了变化。
因此为了帮助我们重启训练状态,我们需要保存更多的信息,而不是只保存一个模型的state_dict。这些被保存的信息,统称为checkpoint。
在保存checkpoint时,我们同样使用torch.save()方法,在加载时,也是用torch.load()方法。因为torch.save支持保存各种格式,我们可以将想要保存的信息按照key和value组成一个dict,并将这个dict保存下来。
在下面这个例子中,被保存下来的信息包括当前的epoch数,模型的state_dict, 优化器的state_dict还有louss。

# Additional information
torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

在加载时,我们只要按照key取其中的value就可以。

# Additional information
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

多个模型的保存与读取

我们已经知道可以将key和value对应的dict保存成checkpoint的形式,帮助我们重启训练状态。当我们有多个模型时,只不过是增加了要保存到信息而已,方法是一样的。

# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

在这个checkpoint中,我们分别保存了modelA和modelB的state_dict,和它们对应的优化器optimizerA和optimizerB的state_dict。
因此在使用时,只要分别放置到对应的object中就可以。

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/29848.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

K8s 中 port, targetPort, NodePort的区别

看1个例子: 我们用下面命令去创建1个pod2, 里面运行的是1个nginx kubectl create deployment pod2 --imagenginx当这个POD被创建后, 其实并不能被外部访问, 因为端口映射并没有完成. 我们用下面这个命令去创建1个svc &#xff…

chatgpt赋能python:Python怎样让画笔变粗

Python怎样让画笔变粗 Python是一门强大的编程语言,不仅适用于数据分析和机器学习等领域,也可以用来进行图像处理。在Python中,我们可以使用Pillow库来进行图像操作。在本篇文章中,我们将介绍如何使用Python和Pillow来让画笔变粗…

vue2_markdown的内容目录生成

文章目录 ⭐前言⭐引入vue-markdown💖 全局配置💖 渲染选项💖 取出markdown的标题层级 ⭐结束 ⭐前言 大家好!我是yma16,本文分享在vue2的markdown文本内容渲染和目录生成 背景: 优化个人博客功能&#xf…

delphi的ARM架构支持与System.Win.WinRT库

delphi的ARM架构支持与System.Win.WinRT库 目录 delphi的ARM架构支持与System.Win.WinRT库 一、WinRT 二、delphi的System.Win.WinRT库 2.1、支持ARM芯片指令 2.2、基于WinRT技术的特点 2.3、所以使用默认库而未经转化的服务端应用并不支持ARM架构服务器 2.4、对默认库…

【Linux】初步认识Linux系统

Linux 操作系统 主要作用是管理好硬件设备,并为用户和应用程序提供一个简单的接口,以便于使用。 作为中间人,连接硬件和软件 常见操作系统 桌面操作系统 WindowsmacOsLinux 服务器操作系统 LinuxWindows Server 嵌入式操作系统 Linux …

深度学习图像分类、目标检测、图像分割源码小项目

​demo仓库和视频演示: 银色子弹zg的个人空间-银色子弹zg个人主页-哔哩哔哩视频 卷积网路CNN分类的模型一般使用包括alexnet、DenseNet、DLA、GoogleNet、Mobilenet、ResNet、ResNeXt、ShuffleNet、VGG、EfficientNet和Swin transformer等10多种模型 目标检测包括…

Java关键词synchronized

目录 一、通过卖票系统观察多线程的安全隐患 二、synchronized的基本知识 1.使用synchronized的原因 2.synchronized的作用 3.synchronized的基本格式 a.synchronized加在方法名前 b.synchronized用在方法中 4. Java锁机制 5.synchronized注意事项 三、使用synchronize…

Java Logback日志框架概述及logback.xml详解

日志技术具备的优势 可以将系统执行的信息选择性的记录到指定的位置(控制台、文件中、数据库中)。 可以随时以开关的形式控制是否记录日志,无需修改源代码。 日志体系结构 Logback日志框架 Logback是由log4j创始人设计的另一个开源日志组件&#xff0…

MATLAB读取OpenFOAM的二进制文件

OpenFOAM的文件格式 上面是OpenFOAM二进制文件的格式,我们可以看出,前面21行都是无关的说明文件,22开始时除了一个括号之外,其它的都是数据。 读取数据 读取数据的思路非常简单,忽略不需要的,读取需要的。…

Autoware 跑 Demo(踩坑指南)

Autoware 跑 Demo(踩坑指南) 网上的博客和官方的教程,几乎都是一样的,但实际上跑不起来 Autoware 1.12学习整理–01–运行rosbag示例 Autoware入门学习(三)——Autoware软件功能使用介绍(1/3&a…

【Unity3D】激光雷达特效

1 由深度纹理重构世界坐标 屏幕深度和法线纹理简介中对深度和法线纹理的来源、使用及推导过程进行了讲解,本文将介绍使用深度纹理重构世界坐标的方法,并使用重构后的世界坐标模拟激光雷达特效。 本文完整资源见→Unity3D激光雷达特效。 1)重构…

基于51单片机的智能火灾报警系统温度烟雾光

wx供重浩:创享日记 对话框发送:火灾报警 获取完整源码源文件电路图仿真文件论文报告等 功能简介 51单片机MQ-2烟雾传感ADC0832模数转换芯片DS18B20温度传感器数码管显示按键模块声光报警模块 具体功能: 1、实时监测及显示温度值和烟雾浓度…

管理类联考——英语二——技巧篇——写作——B节——议论文——必备替换句型

议论文必备替换句型 (一)表示很明显/众所周知的句型 It is obvious thatIt is clear thatIt is apparent thatIt is evident thatlt is self-evident thatIt is manifest thatIt is well-knownIt is known to all thatIt is widely-accepted thatIt is crystal-cl…

蓝牙客户端QBluetoothSocket的使用——Qt For Android

了解蓝牙 经典蓝牙和低功耗蓝牙差异 经典蓝牙(Bluetooth Classic):分为基本速率/增强数据速率(BR/EDR), 79个信道,在2.4GHz的(ISM)频段。支持点对点设备通信,主要用于实现无线音频流传输,已成…

Ceph:关于Ceph 集群管理的一些笔记

写在前面 准备考试,整理ceph 相关笔记博文内容涉及,Ceph 管理工具 cephadm,ceph 编排器,Ceph CLI 和 Dashboard GUI 介绍理解不足小伙伴帮忙指正 对每个人而言,真正的职责只有一个:找到自我。然后在心中坚守…

大数据分析平台释疑专用帖第二弹

不管是想要快速了解BI大数据分析平台,还是想要了解BI和自己的需求匹配度,都可关注我们的释疑专用贴。 1、可以分析直播数据吗? 严格来说,只要能够提供数据,就可以做数据可视化分析,直播数据也同理。 如果…

solr快速上手:整合SolrJ实现客户端操作(九)

0. 引言 我们前面学习了solr的服务端基础操作,实际项目中我们还需要在客户端调用solr,就像调用数据库一样,我们可以基于solrJ来实现对solr的客户端操作 1. SolrJ简介 SolrJ 是 Solr官方提供的 Java 客户端库,主要用于与 Solr 服…

Python 请求分页

文章目录 什么是 Python 中的分页带有下一个按钮的 Python 分页没有下一个按钮的 Python 分页无限滚动的 Python 分页带有加载更多按钮的分页 在本文中,我们将了解分页以及如何克服 Python 中与分页相关的问题。 读完本文后,我们将能够了解 Python 分页以…

经典目标检测YOLO系列(1)YOLO-V1算法及其在VOC2007数据集上的应用

经典目标检测YOLO系列(1)YOLO-V1算法及其在VOC2007数据集上的应用 1 YOLO-V1的简述 1.1 目标检测概述 ​ 目标检测有非常广泛的应用, 例如:在安防监控、手机支付中的人脸检测;在智慧交通,自动驾驶中的车辆检测;在智…

Parallel Desktop中按照的centos在切换root用户时,密码正确,但一直切换不成功,显示su: Authentication failure

目录 一、出现问题二、分析问题三、解决问题四、参考资料 一、出现问题 我的密码明明是输入正确的,但又一直给我报下面的错误 二、分析问题 我怀疑是我密码记错了,所以我点击Log Out,重新去输入了一下密码,发现是正确的我确认…