14、保存与加载PyTorch训练的模型和超参数

文章目录

  • 1. state_dict
  • 2. 模型保存
  • 3. check_point
  • 4. 详细保存
  • 5. Docker
  • 6. 机器学习常用库

1. state_dict

nn.Module 类是所有神经网络构建的基类,即自己构建一个深度神经网络也是需要继承自nn.Module类才行,并且nn.Module中的state_dict包含神经网络中的权重weight ,偏置bias,过程量buffer,举例说明:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :NN_Embedding.py
# @Time      :2024/11/26 22:50
# @Author    :Jason Zhang
import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(3, 4)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4, 5)
        self.batch_norm = nn.BatchNorm2d(4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        y = self.linear2(x)
        return y


if __name__ == "__main__":
    my_test = MyModel()
    my_keys = my_test.state_dict().keys()
    print(f"my_keys={my_keys}")
  • 结果:
    从结果中看出,跟说明的一样,不仅存的是weight,bias ,还有buffer
y_keys=odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'batch_norm.weight', 'batch_norm.bias', 'batch_norm.running_mean', 'batch_norm.running_var', 'batch_norm.num_batches_tracked'])

2. 模型保存

保存和加载

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :torch_save.py
# @Time      :2024/11/27 21:33
# @Author    :Jason Zhang
import torch
import torchvision.models as models

if __name__ == "__main__":
    run_code = 0
    model = models.vgg16(weights='IMAGENET1K_V1')
    torch.save(model.state_dict(), 'model_weights.pth')
    model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
    model.eval()
    torch.save(model, 'model.pth')

3. check_point

# Define model
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F


class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

4. 详细保存

在训练过程中,我们希望详细保存,以至于我们可以在中断训练中恢复训练。
保存模型

5. Docker

关于Docker方式搭建深度神经网络环境和配置
在这里插入图片描述

6. 机器学习常用库

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/926757.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在鸿蒙应用中 Debug 对开发者的帮助

文章目录 摘要引言Debug 的意义与挑战案例:页面渲染性能优化中的 Bug 排查Debug 过程详解问题定位问题解决优化布局与渲染逻辑 代码详细讲解示例代码详细讲解1. 导入必要模块2. 数据生成3. 使用虚拟列表组件items 属性itemHeight 属性renderItem 属性 4. 返回完整组…

基于多VSG独立微网的多目标二次控制MATLAB仿真模型

“电气仔推送”获得资料(专享优惠) 模型简介 本文将一致性算法引入微电网的二次频率和电压控制,自适应调节功率参考值和补偿电压,同时实现频率电压恢复、有功 无功功率的比例均分以及功率振荡抑制,提高系统的暂态和稳…

洛谷 P2415 集合求和 C语言

题目&#xff1a; https://www.luogu.com.cn/problem/P2415 思路从大佬学来的思路。 如图&#xff1a; 我们可以发现&#xff0c;集合最后出现过的数字是2的&#xff08;n-1&#xff09;次方&#xff0c;所以就很好计算了。 代码如下&#xff1a; #include <iostream&g…

leaflet 的基础使用

目录 一、创建dom节点 二、创建地图 三、添加底图&#xff08;天地图&#xff09;&#xff0c;在地图创建完成后添加底图 本章主要讲述leaflet在vue中的使用&#xff1a; leaflet 详情总目录&#xff1a;传送 一、创建dom节点 <div class"map" id"map_…

Springboot 2.x升级到3.x

运维在扫描项目的时候发现了官方发布的漏洞&#xff0c;https://spring.io/security/cve-2024-38816 我们使用的是spring框架的2.x系列&#xff0c;WebMvc依赖于5.3系列&#xff0c;描述说需要更新到5.3.40&#xff0c;但是官方迟迟不再更新。同时发现官方说5.3系列也就更新到…

【HarmonyOS】@Observed和@ObjectLink嵌套对象属性更改UI不刷新问题

【HarmonyOS】Observed和ObjectLink嵌套对象属性更改UI不刷新问题 一、问题背景 使用了Observed和ObjectLink&#xff0c;修改嵌套对象的属性&#xff0c;UI还是不刷新&#xff0c;常见的问题有以下三种形式&#xff1a; 1.多级嵌套&#xff0c;嵌套对象的类并没有添加Observ…

【rustdesk】客户端和服务端的安装和部署(自建服务器,docker,远程控制开源软件rustdesk)

【rustdesk】客户端和服务端的安装和部署&#xff08;自建服务器&#xff0c;docker&#xff09; 一、官方部署教程 https://rustdesk.com/docs/zh-cn/client/mac/ 官方服务端下载地址 https://github.com/rustdesk/rustdesk-server/releases 我用的docker感觉非常方便&am…

otter 高可用策略

关于otter高可用在设计之初&#xff0c;提供了这样几个基本的需求&#xff1a; 1.网络不可靠&#xff0c;异地机房尤为明显. 2.manager/node的jvm不可靠&#xff0c;需要考虑异常crash情况 3.node的jvm不可靠&#xff0c;需要考虑异常crash的情况 4.数据库不可靠&#xff0c;需…

数据库日志

MySQL中有哪些日志 1&#xff0c;redo log重做日志 redo log是物理机日志&#xff0c;因为它记录的是对数据页的物理修改&#xff0c;而不是SQL语句。 作用是确保事务的持久性&#xff0c;redo log日志记录事务执行后的状态&#xff0c;用来恢复未写入 data file的已提交事务…

STL算法之set相关算法

STL一共提供了四种与set(集合)相关的算法&#xff0c;分别是并集(union)、交集(intersection)、差集(difference)、对称差集(symmetric difference)。 目录 set_union set_itersection set_difference set_symmetric_difference 所谓set&#xff0c;可细分为数学上定义的和…

【连接池】.NET开源 ORM 框架 SqlSugar 系列

.NET开源 ORM 框架 SqlSugar 系列 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列【Code First】.NET开源 ORM 框架 SqlSugar 系列【数据事务…

Redis(4):主从复制

一、主从复制概述 主从复制&#xff0c;是指将一台Redis服务器的数据&#xff0c;复制到其他的Redis服务器。前者称为主节点(master)&#xff0c;后者称为从节点(slave)&#xff1b;数据的复制是单向的&#xff0c;只能由主节点到从节点。   默认情况下&#xff0c;每台Redis…

游戏引擎学习第27天

仓库:https://gitee.com/mrxiao_com/2d_game 欢迎 项目的开始是从零开始构建一款完整的游戏&#xff0c;完全不依赖任何库或引擎。这样做有两个主要原因&#xff1a;首先&#xff0c;因为这非常有趣&#xff1b;其次&#xff0c;因为它非常具有教育意义。了解游戏开发的低层次…

WebSocket协议解析 : 双向实时通信的利器

1. WebSocket是什么 WebSocket是一种在单个TCP连接上进行全双工通信的协议。 WebSocket允许客户端和服务器之间进行实时的双向通信。这意味着服务器可以主动推送数据到客户端&#xff0c;而不需要客户端每次都发送请求来获取数据。这种通信方式通过长连接实现&#xff0c;即连…

分页查询日期格式不对

方式一:在属性上加入注解&#xff0c;对日期进行格式化 方式二:在 WebMvcConfiguration 中扩展Spring MVC的消息转换器&#xff0c;统一对日期类型进行格式化处理 /*** 统一转换处理扩展spring mvc* 后端返回前端的进行统一转化处理* param converters*/Overrideprotected voi…

Mybatis:CRUD数据操作之多条件查询及动态SQL

Mybatis基础环境准备请看&#xff1a;Mybatis基础环境准备 本篇讲解Mybati数据CRUD数据操作之多条件查询 1&#xff0c;编写接口方法 在 com.itheima.mapper 包写创建名为 BrandMapper 的接口。在 BrandMapper 接口中定义多条件查询的方法。 而该功能有三个参数&#xff0c;…

接口测试工具:reqable

背景 在众多接口测试工具中挑选出一个比较好用的接口测试工具。使用过很多工具&#xff0c;如Postman、Apifox、ApiPost等&#xff0c;基本上是同类产品&#xff0c;一般主要使用到的功能就是API接口和cURL&#xff0c;其他的功能目前还暂未使用到。 对比 性能方面&#xff…

Chapter 17 v-model进阶

欢迎大家订阅【Vue2Vue3】入门到实践 专栏&#xff0c;开启你的 Vue 学习之旅&#xff01; 文章目录 1 v-model原理2 表单类组件封装3 v-model简化代码 1 v-model原理 1. 基本原理 v-model 本质上是一个语法糖&#xff0c;它将 value 属性 和 input 事件 的绑定合并为一个指令…

区块链技术如何改变我们的日常生活?

区块链技术&#xff0c;这个听起来很高科技的词&#xff0c;其实已经在悄悄地改变我们的日常生活了。不信&#xff1f;让我给你举几个例子&#xff1a; 防作弊&#xff1a;想象一下&#xff0c;如果有一个账本&#xff0c;每个人都能随时查看&#xff0c;那想在里面作弊就难了…

nacos安装部署

nacos安装部署 1.安装nacos 1.安装nacos nacos的安装很简单下载后解压启动即可&#xff0c;但是在启动前请确保jdk环境正常&#xff1b; 1.首先我们要下载nacos安装包&#xff1a;可以到官网下载&#xff0c;注意我这里使用的是2.1.0版本&#xff1b; 2.下载完成后&#xff0…