Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

import torch
import torch.nn as nn

# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量

# 创建一个简单模型实例
model = SimpleModel()

# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)

# 将注册名指向另一个浮点型张量
model.test = float_parameter

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)

# 打印加载后的参数
print(model.test)

# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)

# 打印加载后的参数
print(model_1.test)
输出:
tensor(0.6000)
tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

import torch

# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())

# 将张量 b 中的值复制到张量 a 中
a.copy_(b)

# 打印复制后的结果
print(a)

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],
        [7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/579321.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kafka 3.x.x 入门到精通(07)——Java应用场景——SpringBoot集成

Kafka 3.x.x 入门到精通(07)——Java应用场景——SpringBoot集成 4. Java应用场景——SpringBoot集成4.1 创建SpringBoot项目4.1.1 创建SpringBoot项目4.1.2 修改pom.xml文件4.1.3 在resources中增加application.yml文件 4.2 编写功能代码4.2.1 创建配置…

debian配置BIND DNS服务器

前言 局域网内有很多台主机,IP难以记忆。 而修改hosts文件又难以做到配置共享和统一,需要一台内网的DNS服务器。 效果展示 这里添加了一个域名hello.dog,将其指向为192.168.1.100。 同时,外网的域名不会受到影响,…

基于粒子滤波器的电池剩余使用寿命计算matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 粒子滤波器基础 4.2 电池剩余使用寿命建模与预测 4.3 粒子滤波器在电池寿命预测中的应用 5.完整工程文件 1.课题概述 基于粒子滤波器的电池剩余使用寿命计算。根据已知的数据,预测未来…

前端框架编译器之模板编译

编译原理概述 编译原理:是计算机科学的一个分支,研究如何将 高级程序语言 转换为 计算机可执行的目标代码 的技术和理论。 高级程序语言:Python、Java、JavaScript、TypeScript、C、C、Go 等。计算机可执行的目标代码:机器码、汇…

高级IO|从封装epoll服务器到实现Reactor服务器|Part1

从封装epoll_server到实现reactor服务器(part1) 项目复习:从封装epoll_server到实现reactor服务器(part1)EPOLL模式服务器初步 select, poll, epoll的优缺点epoll的几个细节封装epoll_server基本框架先写好创建监听套接字和创建epoll模型可以Accept了吗&#xff1f…

鸿蒙OpenHarmony【轻量系统 运行】 (基于Hi3861开发板)

运行 联网配置 由于Hi3861为WLAN模组,您可以在版本编译及烧录后,通过如下操作,使开发板实现联网功能。 保持Windows工作台和Hi3861 WLAN模组的连接状态,确认串口终端显示正常。 复位Hi3861 WLAN模组,终端界面显示“…

网络攻击日益猖獗,安全防护刻不容缓

“正在排队登录”、“账号登录异常”、“断线重连”......伴随着社交软件用户的一声声抱怨,某知名社交软件的服务器在更新上线2小时后,遭遇DDoS攻击,导致用户无法正常登录。在紧急维护几小时后,这款软件才恢复正常登录的情况。 这…

视频通话实时换脸:支持训练面部模型 | 开源日报 No.235

iperov/DeepFaceLive Stars: 19.7k License: GPL-3.0 DeepFaceLive 是一个用于 PC 实时流媒体或视频通话的人脸换装工具。 可以使用训练好的人脸模型从网络摄像头或视频中交换面部。提供多个公共面部模型,包括 Keanu Reeves、Mr. Bean 等。支持自己训练面部模型以…

基于数据挖掘的斗鱼直播数据可视化分析系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 随着网络直播平台的兴起,斗鱼直播作为其中的佼佼者,吸引了大量用户和观众。为了更好地理解和分析斗鱼直播中的数据,本项目介绍了一个基于数据挖掘的斗鱼直播数据…

【图解计算机网络】简单易懂的https原理解析

简单易懂的https原理解析 https与http的区别混合加密对称加密非对称加密混合加密解析混合加密问题 摘要算法数字证书数字证书原理为什么通过CA证书可以解决中间人攻击的问题呢? https握手流程 https与http的区别 http是明文传输的,非常不安全&#xff0…

呆马科技——智慧应急执法监管平台

在当今社会,安全生产的重要性日益凸显。对于各级政府和企事业单位,当务之急是如何高效地对突发事件进行执法管理。平台应运而生,旨在通过信息化、智能化技术,提升安全管理的效率与准确性。 一、平台特点 整合各类平台的信息资源&…

添加github SSH Key

添加github SSH Key 使用 SSH 协议,您可以连接远程服务器和服务并对其进行身份验证。使用 SSH 密钥,您可以连接到 GitHub,而无需在每次访问时提供您的用户名和个人访问令牌。您还可以使用 SSH 密钥来签署提交。 #3224333333qq.com替换为你自己…

6.NVIC中断配置(ST的精简ARM中断体系)

void NVIC_SetPriorityGrouping(uint32_t PriorityGroup)//设置优先级分组,整个项目共用一个分组 uint32_t NVIC_EncodePriority (uint32_t PriorityGroup, uint32_t PreemptPriority, uint32_t SubPriority) //计算优先级编码值,(组号&…

Python爬虫--Ajax异步抓取腾讯视频评论

在某些网站 ,当我们滑下去的时候才会显示出后面的内容 就像淘宝一样,滑下去才逐渐显示其他商品 这个就是采用 Ajax 做的 然后我们现在就是要编写这样的爬虫。 规律分析: 这个时候就要用到我们的 Fiddler 了 我们需要分析加载评论的规律 …

GateWay具体的使用之局部过滤器接口耗时

1.找规律 局部过滤器命名规则 XXXGatewayFilterFactory, 必须以GatewayFilterFactory结尾。 /* 注意名称约定 * AddRequestHeaderGatewayFilterFactory 配置的时候写的是 AddRequestHeader * AddRequestParameterGatewayFilterFactory 配置的时候写的是 A…

【语音识别】搭建本地的语音转文字系统:FunASR(离线不联网即可使用)

参考自: 参考配置:FunASR/runtime/docs/SDK_advanced_guide_offline_zh.md at main alibaba-damo-academy/FunASR (github.com)参考配置:FunASR/runtime/quick_start_zh.md at 861147c7308b91068ffa02724fdf74ee623a909e alibaba-damo-aca…

上位机图像处理和嵌入式模块部署(树莓派4b下使用sqlite3)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 嵌入式设备下面,有的时候也要对数据进行处理和保存。如果处理的数据不是很多,一般用json就可以。但是数据如果量比较大&…

CISSP通关学习笔记:共计 9 个章节(已完结)

1. 笔记说明 第 0 章节为开篇介绍,不包括知识点。第 1 - 8 章节为知识点梳理汇总,8 个章节的知识框架关系如下图所示: 2. 笔记目录 「 CISSP学习笔记 」0.开篇「 CISSP学习笔记 」1.安全与风险管理「 CISSP学习笔记 」2.资产安全「 CISSP…