6-pytorch - 网络的保存和提取

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)

# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(), 
        torch.nn.Linear(10,1)
    )
    optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 下面介绍两种不同的保存方法,方法二可能运行速度要快点
    # 保存整个网络的所有
    torch.save(net1, 'net.pkl')     
    # 保存好网络的参数
    torch.save(net1.state_dict(),'net_params.pkl')
    
    # plot result
    plt.figure(1,figsize=(10,3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
    
def restore_params():
    # 如果只是保留参数的情况,提取时需要再次定义相同网络才行
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/550748.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

树和二叉树(一)

一、树 非线性数据结构,在实际场景中,存在一对多,多对多的情况。 树( tree)是n (n>0)个节点的有限集。当n0时,称为空树。 在任意一个非空树中,有如下特点。 1.有且仅有一个特定的称为根的节点…

数字IPO:企业增长的新引擎

数字IPO作为一种新型融资方式,可以被视为企业增长的重要加速器。以下是数字IPO如何促进企业增长的几个关键方面: 1.低成本融资与知名度提升:相较于传统的借贷融资方式,数字IPO为企业提供了低成本的资金来源。同时,上市…

神经网络--反向传播算法推导

神经网络–反向传播算法推导 文章目录 神经网络--反向传播算法推导概述神经网络模型反向传导算法 概述 以监督学习为例,假设我们有训练样本集 ( x ( i ) , y ( i ) ) (x^{(i)},y^{(i)}) (x(i),y(i)),那么神经网络算法能提供一种复杂且非线性的假设模型 …

MySQL与Redis缓存一致性的实现与挑战

缓存是提高应用性能的重要手段之一,而 MySQL 和 Redis 是两种常用的数据存储和缓存技术。在许多应用中,常常将 Redis 用作缓存层,以加速对数据的访问。然而,在使用 MySQL 和 Redis 组合时,保持缓存与数据库之间的一致性…

【MATLAB源码-第54期】基于白鲸优化算法(WOA)和遗传算法(GA)的栅格地图路径规划最短路径和适应度曲线对比。

操作环境: MATLAB 2022a 1、算法描述 1.白鲸优化算法(WOA): 白鲸优化算法是一种受白鲸捕食行为启发的优化算法。该算法模拟了白鲸群体捕食的策略和行为,用以寻找问题的最优解。其基本思想主要包括以下几点&#xff…

FMEA赋能可穿戴设备:打造安全可靠的未来科技新宠!

在科技日新月异的今天,可穿戴设备已成为我们生活中不可或缺的一部分。它们以其便携性、智能化和个性化的特点,深受消费者喜爱。然而,随着可穿戴设备市场的快速扩张,其安全性和可靠性问题也日益凸显。为了确保产品质量,…

QT常量中有换行符解决方法--使用中文显示乱码或者编译报错

QT6.3常量中有换行符 int ret2QMessageBox::information(this,QString::fromLocal8Bit("提示"),QString::fromLocal8Bit(("确认启动设备吗?")),QMessageBox::Yes,QMessageBox::No); 确保显示正常,建议每次使用时,中文的前后加一个空…

从零开始写 Docker(十一)---实现 mydocker exec 进入容器内部

本文为从零开始写 Docker 系列第十一篇,实现类似 docker exec 的功能,使得我们能够进入到指定容器内部。 完整代码见:https://github.com/lixd/mydocker 欢迎 Star 推荐阅读以下文章对 docker 基本实现有一个大致认识: 核心原理&…

在线音乐网站的设计与实现

在线音乐网站的设计与实现 摘 要 在社会和互联网的快速发展中,音乐在人们生活中也产生着很大的作用。音乐可以使我们紧张的神经得到放松,有助于开启我们的智慧,可以辅助治疗,达到药物无法达到的效果,所以利用现代科学…

优秀Burp插件 提取JS、HTML中URL插件

Burp Js Url Finder 攻防演练过程中,我们通常会用浏览器访问一些资产,但很多接口/敏感信息隐匿在html、JS文件中,通过该Burp插件我们可以: 1、发现通过某接口可以进行未授权/越权获取到所有的账号密码 2、发现通过某接口可以枚举用…

STM32的GPIO端口的八种模式解析

目录 STM32的GPIO端口的八种模式解析 一、上拉输入模式 二、下拉输入模式 三、浮空输入模式 四、模拟输入模式 五、推挽输出模式 六、开漏输出模式 七、复用推挽输出模式 八、复用开漏输出模式 STM32的GPIO端口的八种模式解析 在学习STM32的过程中,GPIO端口…

【YUV】YUV图像全面详解(一)——格式详解

文章目录 一、前言二、YUV 介绍三、YUV 优点四、YUV 采样格式五、YUV 存储格式六、具体分类详解 一、前言 视频采集芯片输出的码流一般都是 YUV 格式数据流,后续视频处理也是对 YUV 数据流进行编码和解析。所以,了解 YUV 数据流对做视频领域的人而言&am…

【web开发网页制作】html+css家乡长沙旅游网页制作(4页面附源码)

家乡长沙网页制作 涉及知识写在前面一、网页主题二、网页效果Page1、主页Page2、历史长沙Page3、著名人物Page4、留言区 三、网页架构与技术3.1 脑海构思3.2 整体布局3.3 技术说明书 四、网页源码HtmlCSS 五、源码获取5.1 获取方式 作者寄语 涉及知识 家乡长沙网页制作&#x…

学习Rust的第5天:控制流

Control flow, as the name suggests controls the flow of the program, based on a condition. 控制流,顾名思义,根据条件控制程序的流。 If expression If表达式 An if expression is used when you want to execute a block of code if a condition …

如何试用 Ollama 运行本地模型 Mac M2

首先下载 Ollama https://github.com/ollama/ollama/tree/main安装完成之后,启动 ollma 对应的模型,这里用的是 qwen:7b ollama run qwen:7b命令与模型直接交互 我的机器配置是M2 Pro/ 32G,运行 7b 模型毫无压力,而且推理时是用…

C语言案例——输出以下图案(两个对称的星型三角形)

目录 图片代码 图片 代码 #include<stdio.h> int main() {int i,j,k;//先输出上半部图案for(i0;i<3;i){for(j0;j<2-i;j)printf(" ");for(k0;k<2*i;k)printf("*");printf("\n");}//再输出下半部分图案for(i0;i<2;i){for(j0;j&…

《R语言与农业数据统计分析及建模》学习——R基础包的函数

1、R的基础包 基础包是R语言的核心组成部分&#xff0c;构建了R语言的基本功能框架。是R语言默认的安装包&#xff0c;不需要额外安装&#xff0c;使用时无需加载。 2、常用函数及其功能 &#xff08;1&#xff09;数据处理函数&#xff1a;unique()、sort()、subset() # 获…

LRTimelapse for Mac:专业延时摄影视频制作利器

LRTimelapse for Mac是一款专为Mac用户设计的延时摄影视频制作软件&#xff0c;它以其出色的性能和丰富的功能&#xff0c;成为摄影爱好者和专业摄影师的得力助手。 LRTimelapse for Mac v6.5.4中文激活版下载 这款软件提供了直观易用的界面&#xff0c;用户可以轻松上手&#…

OpenHarmony、HarmonyOS和Harmony NEXT 《我们不一样》

1. OpenHarmony 定义与地位&#xff1a;OpenHarmony是鸿蒙系统的底层内核系统&#xff0c;集成了Linux内核和LiteOS&#xff0c;为各种设备提供统一的操作系统解决方案。 开源与商用&#xff1a;OpenHarmony是一个开源项目&#xff0c;允许开发者自由访问和使用其源代码&#…

# Nacos 服务发现-Spring Cloud Alibaba 综合架构实战(五) -实现 gateway 网关。

Nacos 服务发现-Spring Cloud Alibaba 综合架构实战&#xff08;五&#xff09; -实现 gateway 网关。 1、什么是网关&#xff1f; 原来的单体架构&#xff0c;所有的服务都是本地的&#xff0c;UI 可以直接调用&#xff0c;现在按功能拆分成独立的服务&#xff0c;跑在独立的…