深度学习之pth转换为onnx时修改模型定义‌

文章目录

    • 概述
    • 实现步骤
    • python代码

概述

在将PyTorch模型(.pth文件)转换为ONNX格式时,通常的转换过程是通过torch.onnx.export函数来实现的。这个过程主要是将PyTorch模型的计算图导出为ONNX格式,以便在其他框架或环境中使用。

在转换过程中,你通常不能直接在原有的PyTorch模型前后“添加函数”,因为ONNX导出的是静态计算图,它表示的是模型在某一时刻的结构和参数,而不是动态的执行过程。不过,你可以通过‌修改模型定义‌的方式来实现类似的功能。

在导出模型之前,你可以修改模型的定义,将你想要添加的功能集成到模型本身中。例如,如果你想要在模型的前向传播过程中添加某些预处理或后处理步骤,你可以直接将这些步骤写入模型类的forward方法中。

实现步骤

  1. 定义新模型类
  2. 将原模型添加为新模型的成员
  3. 在新模型的forward中,在原有模型之前或之后添加新的层
  4. 初始化新模型
  5. 加载原有模型参数
  6. 导出onnx

python代码

from model import *
from utils import *
from data import *
import cv2


# 这是你修改后的模型定义,集成了额外功能
class ModifiedModel(nn.Module):
    def __init__(self):
        super(ModifiedModel, self).__init__()
        num_classes = 3
        self.original_model = UNet(3, num_classes)
        # 新增的层或修改后的层
        # self.new_layer = torch.argmax()

    def forward(self, x):
        # 在原始模型前添加预处理(如果需要)
        x = self.original_model(x)
        # 在原始模型后添加后处理或新增层的逻辑
        # x = self.new_layer(x)
        x = torch.argmax(x[0], dim=0).unsqueeze(0) * 255
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
weight_path = 'params/unet_CXR.pth'

pretrained_dict = torch.load(weight_path)
# 初始化修改后的模型,并加载原始模型的参数
modified_model = ModifiedModel()
modified_model.to(device)
# 假设我们只关心原始模型的参数,可以直接将其赋值给修改后的模型中的对应部分
modified_model.original_model.load_state_dict(pretrained_dict)
modified_model.eval()

img_data = torch.randn(1, 3, 256, 256)
img_data = img_data.to(device)
out_data = modified_model(img_data)

out_data = out_data.cpu().detach().numpy()
out_data = np.array(out_data, dtype='uint8')

cv2.imshow('out', out_data[0, :, :])
cv2.waitKey(0)

# 将模型导出为 ONNX 格式
is_dynamic_axes = False
if is_dynamic_axes:
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(modified_model,
                      img_data,
                      r"params/net_model_modify.onnx",
                      opset_version=11,
                      input_names=[input_name], 
                      output_names=[output_name], 
                      dynamic_axes={
                          input_name: {0: 'batch_size', 2: 'in_width', 3: 'int_height'},
                          output_name: {0: 'batch_size', 2: 'out_width', 3: 'out_height'}},
                      verbose=True)
else:
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(modified_model,
                      img_data,
                      r"params/net_model_modify.onnx",
                      opset_version=11,
                      input_names=[input_name], 
                      output_names=[output_name],  
                      verbose=True)

原有模型和修改后的模型onnx计算图如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/929224.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LinuxTCP编程详解

目录 一、创建套接字 二、绑定套接字 示例 三、监听套接字 四、等待套接字 五、服务器端示例 六、连接套接字 七、客户端示例 八、Send和Recv C/S模式:Client客户端、Server服务器 TCP编程基于socket套接字实现,因此也习惯称为Socket编程 一、…

深入解析级联操作与SQL完整性约束异常的解决方法

目录 前言1. 外键约束与级联操作概述1.1 什么是外键约束1.2 级联操作的实际应用场景 2. 错误分析:SQLIntegrityConstraintViolationException2.1 错误场景描述2.2 触发错误的根本原因 3. 解决方法及优化建议3.1 数据库级别的解决方案3.2 应用层的解决方案 4. 友好提…

「Mac畅玩鸿蒙与硬件41」UI互动应用篇18 - 多滑块联动控制器

本篇将带你实现一个多滑块联动的控制器应用。用户可以通过拖动多个滑块,动态控制不同参数(如红绿蓝三色值),并实时显示最终结果。我们将以动态颜色调节为例,展示如何结合状态管理和交互逻辑,打造一个高级的…

数字IC前端学习笔记:脉动阵列的设计方法学(以串行FIR滤波器为例)

相关阅读数字IC前端_日晨难再的博客-CSDN博客https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 引言 脉动结构(也称为脉动阵列)表示一种有节奏地计算并通过系统传输数据的处理单元(PEs)网络。这些处理单元有规律地…

图片预处理技术介绍4——降噪

图片预处理 大家好,我是阿赵。   这一篇将两种基础的降噪算法。   之前介绍过均值模糊和高斯模糊。如果从降噪的角度来说,模糊算法也算是降噪的一类,所以之前介绍的两种模糊可以称呼为均值降噪和高斯降噪。不过模糊算法对原来的图像特征的…

【数据中心建设资料】数据中心安全建设解决方案,数据中心整理解决方案,数据中心如何做到安全保障,数据中台全方案(Word全原件)

第一章 解决方案 1.1 建设需求 1.2 建设思路 1.3 总体方案 信息安全系统整体部署架构图 1.3.1 IP准入控制系统 1.3.2 防泄密技术的选择 1.3.3 主机账号生命周期管理系统 1.3.4 数据库账号生命周期管理系统 1.3.5 双因素认证系统 1.3.6 数据库审计系统 1.3.7 数据脱敏系统 1.3.8…

十,[极客大挑战 2019]Secret File1

点击进入靶场 查看源代码 有个显眼的紫色文件夹,点击 点击secret看看 既然这样,那就回去查看源代码吧 好像没什么用 抓个包 得到一个文件名 404 如果包含"../"、"tp"、"input"或"data",则输出"…

UE5 C++ 不规则按钮识别,复选框不规则识别 UPIrregularWidgets

插件名称:UPIrregularWidgets 插件包含以下功能 你可以点击任何图片,而不仅限于矩形图片。 UPButton、UPCheckbox 基于原始的 Button、Checkbox 扩展。 复选框增加了不规则图像识别功能,复选框增加了悬停事件。 欢迎来到我的博客 记录学习过…

【数据结构】手搓链表

一、定义 typedef struct node_s {int _data;struct node_s *_next; } node_t;typedef struct list_s {node_t *_head;node_t *_tail; } list_t;节点结构体(node_s): int _data;存储节点中的数据struct node_s *_next;:指向 node…

【Win11的Bug】无法在文件夹中创建txt文件

问题 右键只能新建文件夹 , 无法新建txt文本文档 解决办法 将注册表中的一个参数从1改为0即可. 具体内容: WinR输入regeditHKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System 将1改为0(下面这张图我已改过) 4.然后重新启动电脑即可 小技…

基于Matlab三点雨流计数法的载荷时间历程分析与循环疲劳评估

随着工程领域中机械设备和结构系统的复杂性不断增加,疲劳分析成为评估其可靠性与使用寿命的关键环节。载荷时间历程数据在疲劳分析中扮演着重要角色,而雨流计数法作为经典的循环计数方法,能够有效地从载荷时间历程中提取疲劳载荷循环信息。本…

二、部署docker

二、安装与部署 2.1 安装环境概述 Docker划分为CE和EE,CE为社区版(免费,支持周期三个月),EE为企业版(强调安全,付费使用)。 Docker CE每月发布一个Edge版本(17.03&…

python + PPT

ppt转化为word 对于PPT中的文本内容,如何转化为word内容呢?下面的代码可以实现如下功能 代码如下: from pptx import Presentation from docx import Documentdef clean_text(text):"""清理文本,移除控制字符&quo…

linux运维命令

防火墙相关命令 防火墙规则查看 firewall-cmd --list-all 禁ping firewall-cmd --permanent --add-rich-rulerule protocol valueicmp drop firewall-cmd --reload 执行完以上命令后,通过firewall-cmd --list-all查看规则生效情况 firewall-cmd --list-all 其…

论文笔记:Asymptotic Midpoint Mixup for Margin Balancing and Moderate Broadening

1. Motivation 在特征空间中,特征之间的collapse会导致representation learning 中的关键问题,这是因为特征之间不可区分。基于线性插值的增强方法(例如 mixup)已经显示出它们在缓解类间塌陷(称为inter-class collaps…

Elasticsearch之索引的增删改查(6.x版本)-yellowcong

1. 节点信息查看 #查看集群健康情况 curl -X GET localhost:9200/_cat/health?v&pretty#查看节点信息 curl -X GET localhost:9200/_cat/nodes?v&pretty 2. 索引管理 在es中,索引就相当于是mysql中的库了。 #查看索引列表 curl -X GET localhost:9200/…

Linux红帽认证有哪些等级?RHCE含金量如何?

工 仲 好:IT运维大本营哈喽,大家好! 红帽认证,作为一个备受瞩目的认证体系,其完善程度在行业内有口皆碑。 它清晰地划分为三个等级,分别是初级、中级和高级,每个等级都具有独特的要求和价值。…

ArcGIS求取多个点距离线要素的最近距离以及距离倒数

本文介绍在ArcMap软件中,对于点要素中的每一个点,求取其距离最近的道路的距离、距离倒数的方法。 首先,看一下本文的需求。现在已知一个点要素,其中含有多个点,假设每一个点表示城市中的一家商店;同时&…

大数据实验E5HBase:安装配置,shell 命令和Java API使用

实验目的 熟悉HBase操作常用的shell 命令和Java API使用; 实验要求 掌握HBase的基本操作命令和函数接口的使用; 实验平台 操作系统:Linux(建议Ubuntu16.04或者CentOS 7 以上);Hadoop版本:3…

跑一下pyapp

文档:How-to - PyApp 首先没有rust要安装 安装 Rust - Rust 程序设计语言 查看是否安装成功 然后clone下pyapp https://github.com/ofek/pyapp/releases/latest/download/source.zip -OutFile pyapp-source.zip 进入目录中,cmd,设置环境…