动手学深度学习-深度学习计算-5读写文件

目录

加载和保存张量

加载和保存模型参数

小结


到目前为止,我们讨论了如何处理数据, 以及如何构建、训练和测试深度学习模型。 然而,有时我们希望保存训练的模型, 以备将来在各种环境中使用(比如在部署中进行预测)。 此外,当运行一个耗时较长的训练过程时, 最佳的做法是定期保存中间结果, 以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。 因此,现在是时候学习如何加载和存储权重向量和整个模型了。

加载和保存张量

对于单个张量,我们可以直接调用loadsave函数分别读写它们。 这两个函数都要求我们提供一个名称,save要求将要保存的变量作为输入。

import torch
from torch import nn
from torch.nn import functional as F

x = torch.arange(4)
torch.save(x, 'x-file')

我们现在可以将存储在文件中的数据读回内存。

x2 = torch.load('x-file')
x2

 

tensor([0, 1, 2, 3])

我们可以存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

 

 我们甚至可以写入或读取从字符串映射到张量的字典。 当我们要读取或写入模型中的所有权重时,这很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

 

加载和保存模型参数

保存单个权重向量(或其他张量)确实有用, 但是如果我们想保存整个模型,并在以后加载它们, 单独保存每个向量则会变得很麻烦。 毕竟,我们可能有数百个参数散布在各处。 因此,深度学习框架提供了内置函数来保存和加载整个网络。 需要注意的一个重要细节是,这将保存模型的参数而不是保存整个模型。 例如,如果我们有一个3层多层感知机,我们需要单独指定架构。 因为模型本身可以包含任意代码,所以模型本身难以序列化。 因此,为了恢复模型,我们需要用代码生成架构, 然后从磁盘加载参数。 让我们从熟悉的多层感知机开始尝试一下。

 

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

接下来,我们将模型的参数存储在一个叫做“mlp.params”的文件中。

torch.save(net.state_dict(), 'mlp.params')

为了恢复模型,我们实例化了原始多层感知机模型的一个备份。 这里我们不需要随机初始化模型参数,而是直接读取文件中存储的参数。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

 

由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。 让我们来验证一下。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

 

 

小结

  • saveload函数可用于张量对象的文件读写。

  • 我们可以通过参数字典保存和加载网络的全部参数。

  • 保存架构必须在代码中完成,而不是在参数中完成。

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/949397.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

什么是VLAN?

VLAN(Virtual Local Area Network,虚拟局域网)是一种将物理局域网划分成多个逻辑上独立的虚拟网络的技术。VLAN不依赖于设备的物理位置,而是通过逻辑划分,将局域网内的设备虚拟地组织到同一组。这种技术允许网络管理员…

【君正T31开发记录】12.编译工具相关总结及介绍

移植交叉工具包的时候,发现这是很多工具的集合包;以及写makefile的时候,也需要了解下这些工具的作用及用法,这里总结记录一下常见的工具及相关用法。 g C编译器,用于编译C源代码文件,这个很常见&#xff0…

Appium(一)--- 环境搭建

一、Android自动化环境搭建 1、JDK 必须1.8及以上(1) 安装:默认安装(2) 环境变量配置新建JAVA_HOME:安装路径新建CLASSPath%JAVA_HOME%\lib\dt.jar;%JAVA_HOME%\lib\tools.jar在path中增加:%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin;(3) 验证…

猫的眼睛有几种颜色?

在猫咪神秘而迷人的世界里,它们的眼睛犹如璀璨星辰,闪烁着各异的光芒,颜色丰富多样,令人着迷。 猫眼睛的颜色,粗略一数,常见的便有黄色、蓝色、绿色、棕色,还有那神秘的异瞳。这些色彩并非无端生…

PHP框架+gatewayworker实现在线1对1聊天--接收消息(7)

文章目录 接收消息的原理接收消息JavaScript代码 接收消息的原理 接收消息,就是接受服务器转发的客户端消息。并不需要单独创建函数,因为 ws.onmessage会自动接收消息。我们需要在这个函数里进行处理。因为初始化的时候,已经处理的init类型的…

校园周边美食探索及分享平台的设计与实现(源码+数据库+文档)

亲测完美运行带论文:文末获取源码 文章目录 项目简介(论文摘要)运行视频包含的文件列表(含论文)前台运行截图后台运行截图 项目简介(论文摘要) : 美食一直是与人们日常生活息息相关…

基于深度学习的视觉检测小项目(七) 开始组态界面

开始设计和组态画面。 • 关于背景和配色 在组态画面之前,先要确定好画面的风格和色系。如果有前端经验和美术功底,可以建立自己的配色体系。像我这种工科男,就只能从网络上下载一些别人做好的优秀界面,然后在photo shop中抠取色…

wps版excel中如何快速生成倒序序号?

使用wps办公软件打开的excel文件: 效果如下: 方法: 如:想生成此列序号从101~13序号,倒序排列。 在第1个格子中输入开头的最小数字:13 点击一下【13】这个单元格,然后鼠标放在右下角&#xff…

jupyter出现“.ipynb appears to have died. It will restart automatically.”解决方法

原因 解决方法:更新jupyter的版本 1.打开anaconda prompt 2、更新jupyter版本 在anaconda prompt输入以下指令 conda update jupyter如图:

【Flink CDC】Flink CDC的Schema Evolution表结构演变的源码分析和流程图

Flink CDC版本:3.2.1 说明:本文从SchemaOperator接收到,表结构变更事件开始,表结构变更事件应由source端产生,本文不讨论。 可以先看流程图,研究源码。 参考文章: Flink cdc3.0动态变更表结构—…

【编译原理与技术(李文生第二版)】期末复习

第五章 语法制导定义第五章 设计翻译方案√第六章 语义分析-类型表达式(仅记录,没说考)第七章 参数传递 √第七章 运行栈、display表 √例题1:来源:课件例题2:来源:教材7.4例题3:来源…

SpringBoot环境和Maven配置

SpringBoot环境和Maven配置 1. 环境准备2. Maven2.1 什么是Maven2.2 为什么要学 Maven2.3 创建一个 Maven项目2.4 Maven核心功能2.4.1 项目构建2.4.2 依赖管理2.4.3 Maven Help插件 2.5 Maven 仓库2.5.1本地仓库2.5.2 中央仓库2.5.3 私有服务器, 也称为私服 2.6 Maven设置国内源…

五个不同类型的数据库安装

一、 官方首页下载 打开 MySQL 官方首页,链接为: MySQL 进去社区后选择合适的版本进行安装 安装细节 依图一路next 点击finish结束安装 二、 在线YUM仓库 将该安装包的下载链接在 Linux 操作系统中按照以下命令直接进行下载 三、 二进制本地 通过该链接…

决定系数(R²分数)——评估回归模型性能的一个指标

目录 1.定义 2.计算举例 3. 结果分析 1.定义 R(R平方)分数,也称为决定系数,是用来评估回归模型性能的一个指标。它表示自变量解释因变量变异性的比例。R分数的取值范围通常在0到1之间,其值越接近1,说明…

基于单片机的直流稳压电源的设计(论文+源码)

1.系统方案设计 在本次直流稳压电源的设计中,其关键指标如下: 系统输入电压220V交流系统输出直流0到12V可调,步进可以达到0.1V电流最大输出可以到2A具有短路保护功能可以通过液晶或者数码管等显示设备显示当前输出电压 2. 电路图

排序算法——堆排序

什么是堆 堆就是一种特殊的二叉树,他有以下特点: 堆中某个节点的值总是不大于或不小于其父节点的值; 堆总是一棵完全二叉树。 堆又可以分为大根堆和小根堆 大根堆:根节点最大,每个节点都小于或等于父节点 小跟堆&am…

数据挖掘——聚类

数据挖掘——聚类 聚类K-meansKNN VS K-meansK-Nearest Neighbors (KNN)K-means K中心算法PAM算法 K-modes算法——解决数据敏感的问题KMeans算法 ——解决初始点选择问题K-中心点层次方法AGNES算法——最小距离单链接全链接平均链接 聚类评估K均值和K中心点的优缺点层次化聚类…

在线机考|2024华为实习秋招春招编程题(最新)——第3题_个性化歌单推荐系统_300分(十一)

题目内容 假设你是音乐服务的开发者,为了提高用户体验需要解决推荐歌单的同质化问题,保证推荐给用户的所有歌单不包含相同歌曲的。给定一个包含N个歌单和M条歌单重复记录,每个歌单用一个从1到N的整数编号,歌单重复记录包含两个歌单的ID,表示两个歌单有相同的歌曲。 你的任…

每日OJ_牛客_宵暗的妖怪_DP_C++_Java

目录 牛客_宵暗的妖怪_DP 题目解析 C代码 Java代码 牛客_宵暗的妖怪_DP 宵暗的妖怪 描述: 露米娅作为宵暗的妖怪,非常喜欢吞噬黑暗。这天,她来到了一条路上,准备吞噬这条路上的黑暗。这条道路一共被分为n 部分&…

开源架构的自动化测试策略优化版

最近四篇文章推荐: 开源架构的容器化部署优化版(New) 开源架构的微服务架构实践优化版(New) 开源架构中的数据库选择优化版(New) 开源架构学习指南:文档与资源的智慧锦囊&#xff08…