pytorch小记(七):pytorch中的保存/加载模型操作

pytorch小记(七):pytorch中的保存/加载模型操作

  • 1. 加载模型参数 (`state_dict`)
    • 1.1 保存模型参数
    • 1.2 加载模型参数
    • 1.3 常见变种
      • 1.3.1 指定加载设备
      • 1.3.2 非严格加载(跳过部分层)
      • 1.3.3 打印加载的参数
  • 2. 加载整个模型
    • 2.1 保存整个模型
    • 2.2 加载整个模型
    • 2.3 注意事项
  • 3. 总结
  • 4. 加载模型的完整代码示例
    • 4.1 保存和加载参数
    • 4.2 保存和加载整个模型
    • 4.3 加载到不同设备
    • 4.4 忽略部分参数(非严格加载)
    • 5. 检查模型是否加载成功


在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict)加载整个模型。以下是加载模型的所有相关操作及其详细步骤:


1. 加载模型参数 (state_dict)

当仅保存了模型的参数时(使用 model.state_dict() 保存),加载模型的步骤如下:

1.1 保存模型参数

torch.save(model.state_dict(), 'model.pth')
  • 文件内容:只保存模型的参数(权重和偏置)。
  • 优点
    • 节省存储空间。
    • 灵活性更高,可以与不同的模型架构配合使用。
  • 缺点
    • 需要手动重新定义模型结构。

1.2 加载模型参数

  1. 重新定义模型架构:

    model = MyModel()  # 替换为你的模型类
    
  2. 加载参数:

    state_dict = torch.load('model.pth')  # 加载参数字典
    model.load_state_dict(state_dict)    # 加载参数到模型
    
  3. 选择运行设备:

    model.to('cuda')  # 如果需要运行在 GPU 上
    

1.3 常见变种

1.3.1 指定加载设备

  • 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用 map_location
    state_dict = torch.load('model.pth', map_location='cpu')
    

1.3.2 非严格加载(跳过部分层)

  • 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用 strict=False
    model.load_state_dict(state_dict, strict=False)
    

1.3.3 打印加载的参数

  • 可以检查参数字典的内容:
    print(state_dict.keys())
    

2. 加载整个模型

当模型是通过 torch.save(model) 保存时,文件包含了模型的结构和参数,加载更为简单。

2.1 保存整个模型

torch.save(model, 'model_full.pth')
  • 文件内容:包含模型的架构和参数。
  • 优点
    • 无需重新定义模型结构。
    • 直接加载并使用。
  • 缺点
    • 文件依赖于保存时的代码版本(如模型定义)。
    • 文件体积较大。

2.2 加载整个模型

model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上运行

2.3 注意事项

  • 动态定义的模型
    • 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
    • 确保在加载时导入了与保存时相同的模型类。

3. 总结

操作使用场景优点缺点
保存参数 (state_dict)推荐大多数情况文件小、灵活性高需要手动定义模型架构
保存整个模型模型复杂且固定时不需要重新定义模型,直接加载文件大、依赖保存时的代码版本

4. 加载模型的完整代码示例

4.1 保存和加载参数

import torch
import torch.nn as nn

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')

# 加载参数
model = MyModel()  # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 运行在 GPU

4.2 保存和加载整个模型

# 保存整个模型
torch.save(model, 'model_full.pth')

# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda')  # 运行在 GPU

4.3 加载到不同设备

# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)

# 加载到 GPU
model.to('cuda')

4.4 忽略部分参数(非严格加载)

# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 检查模型是否加载成功

  1. 验证权重是否加载

    for name, param in model.named_parameters():
        print(f"{name}: {param.data}")
    
  2. 进行推理验证

    x = torch.randn(1, 10).to('cuda')  # 假设输入维度为 10
    output = model(x)
    print(output)
    

通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/953157.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Windows 10 ARM工控主板连接I2S音频芯片

在Windows工控主板应用中,音频功能是一项基本的需求,USB声卡在x86/x64 Windows系统上就可直接免驱使用,但这些USB声卡通常不提供ARM上的Windows系统驱动。本文将介绍如何利用安装在ARM上的Windows工控主板——ESM8400的I2S接口、连接WM8960音…

【Rust】错误处理机制

目录 思维导图 引言 一、错误处理的重要性 1.1 软件中的错误普遍存在 1.2 编译时错误处理要求 二、错误的分类 2.1 可恢复错误(Recoverable Errors) 2.2 不可恢复错误(Unrecoverable Errors) 三、Rust 的错误处理机制 3…

提升租赁效率的租赁小程序全解析

内容概要 在如今快节奏的生活中,租赁小程序俨然成为了提升租赁效率的一把利器。无论是个人还是企业,都会因其便捷的功能而受益。简单来说,租赁小程序能让繁琐的租赁流程变得轻松、高效。在这里,我们将带您畅游租赁小程序的海洋&a…

SSM商城设计与实现

摘 要 本文的主要工作是对基于B/S模式及JSP技术的基于智能推荐的b2c销售网站进行了研究与设计。本文首先介绍了基于智能推荐的b2c销售网站的背景,分析比较了国内外相关基于智能推荐的b2c销售网站的运行模式、系统特点与开发技术。然后分析了目前热点的各种Web应用开…

drawDB docker部属

docker pull xinsodev/drawdb docker run --name some-drawdb -p 3000:80 -d xinsodev/drawdb浏览器访问:http://192.168.31.135:3000/

CentOS7下Hadoop集群分布式安装详细图文教程

1、集群规划 主机 角色 DSS20 NameNode DataNode ResourceManager NodeManager DSS21 SecondaryNameNode NameNode NodeManager DSS22 DataNode NodeManager 1.1、环境准备 1.1.1 关闭防火墙 #查看防火墙状态 firewall-cmd --state #停止…

计算机网络——网络层-IPV4相关技术

一、网络地址转换NAT • 网络地址转换 NAT 方法于1994年提出。 • 需要在专用网连接到因特网的路由器上安装 NAT 软件。装有 NAT 软件的路由器叫做 NAT路由器,它至少有一个有效的外部全球地址 IPG。 • 所有使用本地地址的主机在和外界通信时都要在 NAT 路由器上将…

postgresql|数据库|利用sqlparse和psycopg2库批量按顺序执行SQL语句(psyconpg2新优化版本)

一、 旧版批量执行SQL脚本的python文件缺点,优点,以及更新内容 书接上回,postgresql|数据库开发|python的psycopg2库按指定顺序批量执行SQL文件(可离线化部署)_python sql psycopg2-CSDN博客 这个python脚本写了很久了,最近开始…

Node.js——http 模块(二)

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

基于element UI el-dropdown打造表格操作列的“更多⌵”上下文关联菜单

<template><div :class"$options.name"><el-table :data"tableData"><el-table-column type"index" label"序号" width"60" /><!-- 主要列 BEGIN---------------------------------------- --&g…

javascrip基础语法

为什么学习 JavaScript? JavaScript 是 web 开发人员必须学习的 3 门语言中的一门&#xff1a; HTML 定义了网页的内容CSS 描述了网页的布局JavaScript 控制了网页的行为 1. JavaScript 输出 1.1 console.log()&#xff1a;用于将信息输出到浏览器控制台&#xff0c;例如con…

大语言模型预训练、微调、RLHF

转发&#xff0c;如有侵权&#xff0c;请联系删除&#xff1a; 1.【LLM】3&#xff1a;从零开始训练大语言模型&#xff08;预训练、微调、RLHF&#xff09; 2.老婆饼里没有老婆&#xff0c;RLHF里也没有真正的RL 3.【大模型微调】一文掌握7种大模型微调的方法 4.基于 Qwen2.…

django基于Python的校园个人闲置物品换购平台

Django 基于 Python 的校园个人闲置物品换购平台 一、平台概述 Django 基于 Python 的校园个人闲置物品换购平台是专为校园师生打造的一个便捷、环保且充满活力的线上交易场所。它借助 Django 这一强大的 Python Web 开发框架&#xff0c;整合了校园内丰富的闲置物品资源&…

abap安装cl_json类

文章来自 SAP根据源码导入/ui2/cl_json类 - pikeduo - 博客园 新建一个se38程序&#xff0c;把源码放到里&#xff0c;源码如下 *----------------------------------------------------------------------* * CLASS zcl_json DEFINITION *----------------------------…

[OPEN SQL] ORDER BY排序数据

本次操作使用的数据库表为SFLIGHT&#xff0c;其字段内容如下所示 航班(SFLIGHT) 该数据库表中的部分值如下所示 OPEN SQL中的ORDER BY语句用于对数据库表中的数据进行排序 在查询数据的时候使用ORDER BY语句&#xff0c;则查询出来的结果会按照ORDER BY指定的字段进行排序 排序…

STM32F103ZET6战舰版单片机开发板PCB文件 电路原理图

资料下载地址&#xff1a;STM32战舰版单片机开发板PCB文件 电路原理图 1、原理图 2、PCB 3、板子介绍 一、核心芯片与性能 核心芯片&#xff1a;STM32F103ZET6&#xff0c;这是一款基于ARM Cortex-M3内核的高性能单片机。处理器频率&#xff1a;高达72MHz&#xff0c;确保了…

An FPGA-based SoC System——RISC-V On PYNQ项目复现

本文参考&#xff1a; &#x1f449; 1️⃣ 原始工程 &#x1f449; 2️⃣ 原始工程复现教程 &#x1f449; 3️⃣ RISCV工具链安装教程 1.准备工作 &#x1f447;下面以LOCATION代表本地源存储库的安装目录&#xff0c;以home/xilinx代表在PYNQ-Z2开发板上的目录 ❗ 下载Vivad…

GAN的应用

5、GAN的应用 ​ GANs是一个强大的生成模型&#xff0c;它可以使用随机向量生成逼真的样本。我们既不需要知道明确的真实数据分布&#xff0c;也不需要任何数学假设。这些优点使得GANs被广泛应用于图像处理、计算机视觉、序列数据等领域。上图是基于GANs的实际应用场景对不同G…

centos9设置静态ip

CentOS 9 默认使用 NetworkManager 管理网络&#xff0c;而nmcli是 NetworkManager 命令行接口的缩写&#xff0c;是一个用来进行网络配置、管理网络连接的命令工具&#xff0c;可以简化网络设置&#xff0c;尤其是在无头&#xff08;没有图形界面&#xff09;环境下。 1、 cd…

Idea日志乱码

问题描述 前提&#xff1a;本人使用windows Idea运行sh文件&#xff0c;指定了utf-8编码&#xff0c;但是运行过程中还是存在中文乱码 Idea的相关配置都已经调整 字体调整为雅黑 文件编码均调整为UTF-8 调整Idea配置文件 但是还是存在乱码&#xff0c;既然Idea相关配置已经…