【深度学习】pytorch pth模型转为onnx模型后出现冗余节点“identity”,onnx模型的冗余节点“identity”

情況描述

onnx模型的冗余节点“identity”如下图。
在这里插入图片描述

解决方式

首先,确保您已经安装了onnx-simplifier库:

pip install onnx-simplifier

然后,您可以按照以下方式使用onnx-simplifier库:

import onnx
from onnxsim import simplify

# 加载导出的 ONNX 模型
onnx_model = onnx.load("your_model.onnx")

# 简化模型
simplified_model, check = simplify(onnx_model)

# 保存简化后的模型
onnx.save_model(simplified_model, "simplified_model.onnx")

通过这个过程,onnx-simplifier库将会检测和移除不必要的"identity"节点,从而减少模型中的冗余。

请注意,使用onnx-simplifier库可能会改变模型的计算图,因此在使用简化后的模型之前,务必进行测试和验证以确保其功能没有受到影响。

对比两者结果是否一样的代码:

    import torch.onnx
    dummy_input = torch.randn(1, 3, 64, 64)
    net = MobileNetV3_Small_050().eval()
    # 比较onnx模型和pytorch模型的输出
    import onnxruntime
    import numpy as np
    sess = onnxruntime.InferenceSession("simplified_mobilenetv3_small_050.onnx")
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    onnx_result = sess.run([output_name], {input_name: dummy_input.numpy()})[0]
    pytorch_result = net(dummy_input).detach().numpy()
    print(np.allclose(onnx_result, pytorch_result, rtol=1e-03, atol=1e-05))

问题原因

在将 PyTorch 模型转换为 ONNX 格式时,有时会出现冗余的"identity"节点的问题。这是因为 PyTorch 和 ONNX 在计算图构建和表示方式上存在一些差异。

在 PyTorch 中,计算图是动态构建的,其中包含了很多临时变量和操作。但在 ONNX 中,计算图是静态定义的,每个操作都显式地表示为一个节点。这种差异可能导致在将 PyTorch 模型转换为 ONNX 格式时引入一些不必要的中间"identity"节点。

一个常见的原因是,PyTorch 中的某些操作或模型结构在 ONNX 中没有直接的等价表示。为了保持模型结构的一致性,转换过程中可能会引入额外的"identity"节点,用于保留原始模型中的特定计算图结构或操作。

另外,有时候这些"identity"节点并不会对模型的性能或功能产生任何影响,它们只是在图形表示上引入了一些冗余。这些冗余节点在模型尺寸较小的情况下可能并不明显,但对于大型模型来说可能会显著增加模型文件的大小。

通过使用onnx-simplifier库,您可以对导出的 ONNX 模型进行后处理,去除这些不必要的"identity"节点,从而减少模型的冗余。

需要注意的是,由于 PyTorch 和 ONNX 之间的差异,无法完全避免所有的冗余节点。但大部分情况下这些冗余节点并不会对模型的性能或功能产生实质性的影响。

我的模型代码

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import init


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


# 注意力机制
class SeModule(nn.Module):
    def __init__(self, in_channel, reduction=4):
        super(SeModule, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channel, in_channel // reduction, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(in_channel // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_channel // reduction, in_channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.hs = hsigmoid()

    def forward(self, x):
        out = self.avgpool(x)
        out = self.fc1(out)
        out = self.bn(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.hs(out)
        return x * out


# 线性瓶颈和反向残差结构
class Block(nn.Module):
    def __init__(self, kernel_size, in_channel, expand_size, out_channel, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule
        # 1*1展开卷积
        self.conv1 = nn.Conv2d(in_channel, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # 3*3(或5*5)深度可分离卷积
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        # 1*1投影卷积
        self.conv3 = nn.Conv2d(expand_size, out_channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channel),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # 注意力模块
        if self.se != None:
            out = self.se(out)
        # 残差链接
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Small_050(nn.Module):
    def __init__(self):
        super(MobileNetV3_Small_050, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = nn.ReLU(inplace=True)
        self.bneck = nn.Sequential(
            Block(3, 16, 8, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 40, 16, nn.ReLU(inplace=True), None, 2),
            Block(3, 16, 56, 16, nn.ReLU(inplace=True), None, 1),
            Block(5, 16, 64, 24, hswish(), SeModule(24), 2),
            Block(5, 24, 144, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 144, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 72, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 72, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 144, 48, hswish(), SeModule(48), 2),
            Block(5, 48, 288, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 48, hswish(), SeModule(48), 1),
        )
        self.conv2 = nn.Conv2d(48, 288, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(288)
        self.hs2 = hswish()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(288, 6)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = out.view(-1, 288)
        out = self.fc(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(576, 6)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = out.view(-1, 576)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # from torchsummary import summary
    # net = MobileNetV3_Small_050().train()
    # summary(net, (3, 64, 64))
    #
    # from torchstat import stat
    # net = MobileNetV3_Small_050().train()
    # stat(net, input_size=(3, 64, 64))  # 输出模型的FLOPs和参数数量

    # 转为onnx
    import torch.onnx

    dummy_input = torch.randn(1, 3, 64, 64)
    net = MobileNetV3_Small_050().eval()
    torch.onnx.export(net, dummy_input, "mobilenetv3_small_050.onnx", input_names=["input"], output_names=["output"],
                      opset_version=11, )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/28063.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE)

STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE) 文章目录 STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE)模拟I2C读写的实现mpu6050_iic.cmpu6050_iic.h代码分析 复位,读取温度,角度等函数封装mpu6050.cmpu…

QT学习07:五种按钮控件

文章首发于我的个人博客:欢迎大佬们来逛逛 文章目录 抽象类:QAbstractButtonQPushButtonQToolButtonQCommandLinkButtonQRadioButtonQCheckBoxQButtonGroup 抽象类:QAbstractButton 是所有按钮类的祖先。 QAbstractButton的信号&#xff1a…

深入理解CSS字符转义行为

深入理解CSS字符转义行为 深入理解CSS字符转义行为 前言为什么要转义&#xff1f;CSS 转义什么是合法css的表达式 左半部分右半部分 练习参考链接 前言 在日常的开发中&#xff0c;我们经常写css。比如常见的按钮: <button class"btn"></button>&am…

【MySQL】 IS NOT NULL 和 != NULL 的区别?

背景 最近在开发小伙伴的需求&#xff0c;遇到了一个数据库统计的问题&#xff0c; is not null 结果正确 &#xff01;null 结果就不对&#xff0c;然后就激发了获取真理的想法&#xff0c;那必须的查查 咋回事嘞&#xff1f; 开整 在用MySQL的过程中&#xff0c;你是否存…

大学物理(上)-期末知识点结合习题复习(4)——质点运动学-动能定理 力做功 保守力与非保守力 势能 机械能守恒定律 完全弹性碰撞

目录 1.力做功 恒力作用下的功 变力的功 2.动能定理 3.保守力与非保守力 4.势能 引力的功与弹力的功 引力势能与弹性势能 5.保守力做功与势能的关系 6.机械能守恒定律 7.完全弹性碰撞 题1 题目描述 题解 题2 题目描述 题解 1.力做功 物体在力作用下移动做功…

AWS CodeWhisperer 简单介绍

一、何为AWS CodeWhisperer Amazon CodeWhisperer能够理解以自然语言&#xff08;英语&#xff09;编写的注释&#xff0c;并能实时生成多条代码建议&#xff0c; 以此提高开发人员生产力。 二、主要功能 Amazon CodeWhisperer 的主要功能&#xff0c;包括代码生成、引用追踪…

36.SpringBoot实用篇—运维

目录 一、实用篇—运维。 &#xff08;1&#xff09;程序打包与运行&#xff08;Windows版&#xff09;。 &#xff08;2&#xff09;spring-boot-maven-plugin插件作用。 &#xff08;3&#xff09;程序打包与运行&#xff08;Linux版&#xff09;。 &#xff08;4&#…

chatgpt赋能python:Python中如何处理多个输入

Python中如何处理多个输入 在编写Python程序时&#xff0c;我们经常需要从用户那里获取多个输入来执行某些操作。本文将介绍Python中的各种方法来处理多个输入。 从终端获取多个输入 Python中最简单的方式是从终端获取多个输入。下面是一个基本的例子&#xff1a; input_st…

SpringSecurity实现前后端分离登录token认证详解

目录 1. SpringSecurity概述 1.1 权限框架 1.1.1 Apache Shiro 1.1.2 SpringSecurity 1.1.3 权限框架的选择 1.2 授权和认证 1.3 SpringSecurity的功能 2.SpringSecurity 实战 2.1 引入SpringSecurity 2.2 认证 2.2.1 登录校验流程 2.2.2 SpringSecurity完整流程 2.2.…

Splashtop 与 Pax8 合作为 MSP 提供简化的远程支持解决方案

2023年4月27日 科罗拉多州丹佛 Pax8 是一个行业领先的云商务市场&#xff0c;该公司今天宣布将通过 Pax8 市场在全球推出其全新运营供应商 Splashtop。Splashtop 的远程访问、支持以及端点监控和管理解决方案极具成本效益&#xff0c;而且功能强大&#xff0c;可以助力托管服务…

002、体系结构之TiDB Server

TiDB Server 1、TiDB总览1.1、TiDB Server架构1.2、TiDB Server 主要功能&#xff1a; 2、SQL语句处理语句的解析和编译SQL层协议层上下文解析层逻辑优化器物理优化器本地执行器分布式执行器 3、如何将表的数据转成kv形式4、在线DDL相关模块5、GC机制与相关模块6、TiDB Server …

你真的会写软件测试简历吗?为什么面试约不到,测试老鸟的建议...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 写好一份简历的三…

Frida技术—逆向开发的屠龙刀

简介 Frida是一种基于JavaScript的动态分析工具&#xff0c;可以用于逆向开发、应用程序的安全测试、反欺诈技术等领域。Frida主要用于在已安装的应用程序上运行自己的JavaScript代码&#xff0c;从而进行动态分析、调试、修改等操作&#xff0c;能够绕过应用程序的安全措施&a…

mac下部署和访问 Kubernetes 仪表板(Dashboard)

简介 Dashboard 是基于网页的 Kubernetes 用户界面。 你可以使用 Dashboard 将容器应用部署到 Kubernetes 集群中&#xff0c;也可以对容器应用排错&#xff0c;还能管理集群资源。 你可以使用 Dashboard 获取运行在集群中的应用的概览信息&#xff0c;也可以创建或者修改 Kub…

QT--配置Opencv

提示&#xff1a;本文为学习记录&#xff0c;若有疑问&#xff0c;请及时联系作者。 文章目录 前言一、下载已编译的opencv1..解压2..path路径 二、使用步骤1..pro文件2..h文件 总结 前言 只做第一个我&#xff0c;不做第二个谁。 一、下载已编译的opencv 适用于mingw编译器…

NoSQL数据库

NoSQL数据库 NoSQL简介NoSQL兴起的原因NoSQL与关系数据库的对比NoSQL的四大类型键值数据库列族数据库文档数据库图形数据库不同类型数据库比较分析RedisMongoDBCassandraNeo4j NoSQL三大基石CAPBASE最终一致性 NoSQL简介 “Not Only SQL”泛指非关系型的数据库&#xff0c;区别…

07_scrapy的应用——获取电影数据(通过excel保存静态页面scrapy爬虫数据的模板/通过数据库保存)

0、前言: 一般我们自己创建的一些python项目,我们都需要创建虚拟环境,其中会下载很多包,也叫做依赖。但是我们在给他人分享我们的项目时,不能把虚拟环境打包发送给别人,因为每个人电脑系统不同,我们可以把依赖导出为依赖清单,然后别人有了我们的依赖清单,就可以用一条…

项目使用tensorflow2会出错,下载并使用tensorflow1

背景&#xff1a;使用pycharm安装总显示安装失败&#xff0c;使用pip安装也不行&#xff0c;只能使用conda配置虚拟环境手动安装 1、下载安装anaconda 官网下载&#xff0c;双击安装。用anaconda就是想使用虚拟环境&#xff0c;万一没弄好直接删了重新搞就行。 2、创建虚拟环境…

Https加密超文本传输协议的运用

1.https的相关知识 1.1 https的简介 HTTPS &#xff08;全称&#xff1a;Hypertext Transfer Protocol Secure &#xff09;&#xff0c;是以安全为目标的 HTTP 通道&#xff0c;在HTTP的基础上通过传输加密和身份认证保证了传输过程的安全性 。HTTPS 在HTTP 的基础下加…

使用POI实现JAVA操作Excel

Apache POI POI提供API给JAVA程序对Microsoft Office格式档案读和写的功能 POI工具介绍 POI 是用Java编写的免费开源的跨平台的 Java API&#xff0c;Apache POI提供API给Java程式对Microsoft Office格式档案读和写的功能。主要是运用其中读取和输出excel的功能。 POI官网地…