PyTorch开放神经网络交换(Open Neural Network Exchange)ONNX通用格式模型的熟悉

我们在深度学习中可以发现有很多不同格式的模型文件,比如不同的框架就有各自的文件格式:.model、.h5、.pb、.pkl、.pt、.pth等等,各自有标准就带来互通的不便,所以微软、Meta和亚马逊在内的合作伙伴社区一起搞一个ONNX(Open Neural Network Exchange)文件格式的通用标准,这样就可以使得模型在不同框架之间方便的进行互操作了。
在上节的YOLOv8的目标对象的分类,分割,跟踪和姿态估计的多任务检测实践(Netron模型可视化) 我们在最后有接触到这个格式文件,而且给出了一个 https://netron.app/站点,可以将这个文件的计算图和相关属性都能可视化的给呈现出来。YOLO的模型有自带的方法:

from ultralytics import YOLO
model = YOLO('yolov8n-cls.pt')
model.export(format="onnx")

或者命令行的方式
yolo export model=yolov8n-cls.pt format=onnx

1、PyTorch演示onnx

这里我们来熟悉下在PyTorchonnx是怎么样的。看一个简单示例:

import torch

class myModel(torch.nn.Module):
    def __init__(self):
        super(myModel, self).__init__()

    def forward(self, x):
        return x.reshape(1,3,64,64)

model = myModel()
x = torch.randn(64,64,3)
torch.onnx.export(model, x, 'newmodel.onnx', input_names=['images'], output_names=['newimages'])

修改输入进来的形状,可以看到通过torch.onnx.export就可以生成onnx格式的模型文件。我们将它上传到上面的站点,将会生成如下的流程图:

 我们打印看下export方法有哪些参数:

export(model: 'Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction]', args: 'Union[Tuple[Any, ...], torch.Tensor]', f: 'Union[str, io.BytesIO]', export_params: 'bool' = True, verbose: 'bool' = False, training: '_C_onnx.TrainingMode' = <TrainingMode.EVAL: 0>, input_names: 'Optional[Sequence[str]]' = None, output_names: 'Optional[Sequence[str]]' = None, operator_export_type: '_C_onnx.OperatorExportTypes' = <OperatorExportTypes.ONNX: 0>, opset_version: 'Optional[int]' = None, do_constant_folding: 'bool' = True, dynamic_axes: 'Optional[Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]]' = None, keep_initializers_as_inputs: 'Optional[bool]' = None, custom_opsets: 'Optional[Mapping[str, int]]' = None, export_modules_as_functions: 'Union[bool, Collection[Type[torch.nn.Module]]]' = False) -> 'None'

2、加载onnx模型

2.1、安装onnx库

为了能够正确的加载onnx格式文件,需要先安装onnx库,依然推荐加豆瓣镜像安装

pip install onnx -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

安装好了之后,我们看下是否成功安装:

import onnx
dir(onnx)
'''
['Any', 'AttributeProto', 'EXPERIMENTAL', 'FunctionProto', 'GraphProto', 'IO', 'IR_VERSION', 'IR_VERSION_2017_10_10', 'IR_VERSION_2017_10_30', 'IR_VERSION_2017_11_3', 'IR_VERSION_2019_1_22', 'IR_VERSION_2019_3_18', 'IR_VERSION_2019_9_19', 'IR_VERSION_2020_5_8', 'IR_VERSION_2021_7_30', 'MapProto', 'ModelProto', 'NodeProto', 'ONNX_ML', 'OperatorProto', 'OperatorSetIdProto', 'OperatorSetProto', 'OperatorStatus', 'Optional', 'OptionalProto', 'STABLE', 'SequenceProto', 'SparseTensorProto', 'StringStringEntryProto', 'TensorAnnotation', 'TensorProto', 'TensorShapeProto', 'TrainingInfoProto', 'TypeProto', 'TypeVar', 'Union', 'ValueInfoProto', 'Version', '_Proto', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_deserialize', '_get_file_path', '_load_bytes', '_save_bytes', '_serialize', 'checker', 'compose', 'convert_model_to_external_data', 'defs', 'external_data_helper', 'gen_proto', 'google', 'helper', 'hub', 'load', 'load_external_data_for_model', 'load_from_string', 'load_model', 'load_model_from_string', 'load_tensor', 'load_tensor_from_string', 'mapping', 'numpy_helper', 'onnx_cpp2py_export', 'onnx_data_pb', 'onnx_data_pb2', 'onnx_ml_pb2', 'onnx_operators_ml_pb2', 'onnx_operators_pb', 'onnx_pb', 'os', 'parser', 'printer', 'save', 'save_model', 'save_tensor', 'shape_inference', 'typing', 'utils', 'version', 'version_converter', 'write_external_data_tensors']
'''

2.2、onnx模型信息

正确安装之后,就可以加载模型了。

import onnx

myModel = onnx.load("newmodel.onnx")
#打印整个模型信息
#print(myModel)
output = myModel.graph.output
#打印输出层的信息
print(output)

[name: "newimages"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 64
      }
      dim {
        dim_value: 64
      }
    }
  }
}
]

2.3、加载外部数据

需要注意的是,外部数据如果跟模型文件不是在同一个目录里面的话,需要使用load_external_data_for_model来加载外部数据

import onnx
from onnx.external_data_helper import load_external_data_for_model

mymodel = onnx.load('xx/model.onnx', load_external_data=False)
load_external_data_for_model(mymodel, 'datasets/')

这样就可以通过mymodel这个模型来加载指定目录的数据了。

3、保存新模型

我们可以试着在模型的基础上做一些修改,然后保存为一个新的模型文件。

import onnx
from onnx import helper

model = onnx.load('newmodel.onnx')
prob_info = helper.make_tensor_value_info('xx',onnx.TensorProto.FLOAT, [1,3,128,128])
model.graph.output.insert(0, prob_info)
onnx.save(model, 'newmodel2.onnx')

将一个新的节点插入到输出层,我们来看下生成的图:

其中helper里面有很多用法,比如增加节点:

node = helper.make_node("Range",inputs=["start", "limit", "delta"],outputs=["output"])
start = np.float32(1)
limit = np.float32(15)
delta = np.float32(2)
output = np.arange(start, limit, delta, dtype=np.float32)
print(output)#[ 1.  3.  5.  7.  9. 11. 13.]

这里的Range是其中一个操作类型(算子),用法:Operators 更多的操作类型,有兴趣的可以进去查阅

4、常见对象

在onnix常出现的几个对象,AttributeProto,TensorProto,GraphProto,NodeProto一起来了解下:

4.1、AttributeProto 

属性

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto

arg = helper.make_attribute("this_is_an_int", 1701)
print(arg)
'''
name: "this_is_an_int"
type: INT
i: 1701
'''

当然类型还可以是浮点数、字符串、数组等

4.2、NodeProto

节点

node_proto = helper.make_node("Relu", ["X"], ["Y"])
print(node_proto)
'''
input: "X"
output: "Y"
op_type: "Relu"
'''

其中op_type的算子挺多的,比如上面那个Range的用法。

4.3、AttributeProto和NodeProto结合

看下两者结合使用的一个例子,卷积核为3,步幅为1,填充为1的一个卷积。 

node_proto = helper.make_node(
    "Conv", ["X", "W", "B"], ["Y"],
    kernel=3, stride=1, pad=1)
# 属性按顺序打印
node_proto.attribute.sort(key=lambda attr: attr.name)
print(node_proto)
'''
input: "X"
input: "W"
input: "B"
output: "Y"
op_type: "Conv"
attribute {
  name: "kernel"
  type: INT
  i: 3
}
attribute {
  name: "pad"
  type: INT
  i: 1
}
attribute {
  name: "stride"
  type: INT
  i: 1
}
'''

另一种更有可读性的打印:

print(helper.printable_node(node_proto))
%Y = Conv[kernel = 3, pad = 1, stride = 1](%X, %W, %B)

4.4、TensorProto和GraphProto

graph_proto = helper.make_graph(
    [
        helper.make_node("FC", ["X", "W1", "B1"], ["H1"]),
        helper.make_node("Relu", ["H1"], ["R1"]),
        helper.make_node("FC", ["R1", "W2", "B2"], ["Y"]),
    ],
    "MLP",
    [
        helper.make_tensor_value_info("X" , TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info("W1", TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info("B1", TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info("W2", TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info("B2", TensorProto.FLOAT, [1]),
    ],
    [
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1]),
    ]
)
print(helper.printable_graph(graph_proto))
'''
graph MLP (
  %X[FLOAT, 1]
  %W1[FLOAT, 1]
  %B1[FLOAT, 1]
  %W2[FLOAT, 1]
  %B2[FLOAT, 1]
) {
  %H1 = FC(%X, %W1, %B1)
  %R1 = Relu(%H1)
  %Y = FC(%R1, %W2, %B2)
  return %Y
'''

5、解析器

5.1、parse_graph

onnx.parser.parse_graph将文本表示,创建成ONNX图形

input = """
   agraph (float[N, 128] X, float[128, 10] W, float[10] B) => (float[N, 10] C)
   {
        T = MatMul(X, W)
        S = Add(T, B)
        C = Softmax(S)
   }
"""
graph = onnx.parser.parse_graph(input)
print(helper.printable_graph(graph))
'''
graph agraph (
  %X[FLOAT, Nx128]
  %W[FLOAT, 128x10]
  %B[FLOAT, 10]
) {
  %T = MatMul(%X, %W)
  %S = Add(%T, %B)
  %C = Softmax(%S)
  return %C
}
'''

 5.2、parse_model

onnx.parser.parse_model将文本表示,创建成ONNX模型

input = """
   <
     ir_version: 7,
     opset_import: ["" : 10]
   >
   agraph (float[N, 128] X, float[128, 10] W, float[10] B) => (float[N, 10] C)
   {
      T = MatMul(X, W)
      S = Add(T, B)
      C = Softmax(S)
   }
"""
model = onnx.parser.parse_model(input)
print(model)

'''
ir_version: 7
opset_import {
  domain: ""
  version: 10
}
graph {
  node {
    input: "X"
    input: "W"
    output: "T"
    op_type: "MatMul"
    domain: ""
  }
  node {
    input: "T"
    input: "B"
    output: "S"
    op_type: "Add"
    domain: ""
  }
  node {
    input: "S"
    output: "C"
    op_type: "Softmax"
    domain: ""
  }
  name: "agraph"
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "N"
          }
          dim {
            dim_value: 128
          }
        }
      }
    }
  }
  input {
    name: "W"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 128
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  input {
    name: "B"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  output {
    name: "C"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "N"
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
}
'''

引用来源
github:https://github.com/onnx/onnx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/33767.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Spring Cloud系列】-负载均衡(Load Balancer,LB)

【Spring Cloud系列】-负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09; 文章目录 【Spring Cloud系列】-负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09;一、什么是负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09;二、负载均衡的主要…

vue2、vue3分别配置echarts多图表的同步缩放

文章目录 ⭐前言⭐使用dataZoom api实现echart的同步缩放&#x1f496; vue2实现echarts多图表同步缩放&#x1f496; vue3实现echarts多图表同步缩放 ⭐结束 ⭐前言 大家好&#xff01;我是yma16&#xff0c;本文分享在vue2和vue3中配置echarts的多图表同步缩放 背景&#xf…

教你如何使用Nodejs搭建HTTP web服务器并发布上线公网

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 转载自内网穿透工具的文章&#xff1a;使用Nodejs搭建HTTP服务&#xff0c;并实现公网远程访问「内网穿透」 前言 Node.js…

自动驾驶开源数据集(附下载链接)

自动驾驶是带动新兴产业的一个突破点&#xff0c;也是中国结合新能源汽车&#xff0c;实现汽车产业弯道超车的不二手段&#xff0c;是打破国外燃油车技术壁垒的关键一步&#xff01;它不会停止&#xff0c;只是在蓄势待发&#xff01; 数据集介绍&#xff1a;点击 自动驾驶场…

使用MATLAB画SCI论文图

从gcf和gca说起 不论是 Python 绘图还是Matlab绘图&#xff0c;想要获得更好看的图&#xff0c;都会用到这两个单词。 gcf&#xff1a;get current figure&#xff0c;是目标图像的图形句柄对象 gca&#xff1a;get current axes&#xff0c;是目标图像的坐标轴句柄对象 Mat…

java 全局、局部异常处理详解及result结果封装

1、引入spring-boot-starter-web依赖和new-swagger依赖 <dependency><groupId>com.jjw</groupId><artifactId>new-swagger</artifactId><version>1.0-SNAPSHOT</version> </dependency> <dependency><groupId>or…

【Vue2】Vant2上传文件使用formData方式,base64图片转Blob再转File上传

文章目录 前言一、base64转换为 Blob 对象的方法二、使用步骤1.引入工具类js2.编写formData上传方法3.api方法中的request代码 三、实际操作1.html代码2.js代码 总结 前言 vant2上传组件传送门 使用vant2组件中的uploader组件 <van-uploader v-model"fileList" …

耳挂式骨传导耳机哪个牌子好,分享几个品牌的骨传导耳机

骨传导耳机就是利用震动来传递声音的耳机&#xff0c;在运动时佩戴骨传导耳机&#xff0c;可以听歌也能听周围的声音&#xff0c;提高了运动时的安全性。目前市面上的骨传导耳机也是琳琅满目。今天就来给大家分享下目前市面上比较常见的几款骨传导耳机。希望对正在选购骨传导耳…

什么是kafka,如何学习kafka,整合SpringBoot

目录 一、什么是Kafka&#xff0c;如何学习 二、如何整合SpringBoot 三、Kafka的优势 一、什么是Kafka&#xff0c;如何学习 Kafka是一种分布式的消息队列系统&#xff0c;它可以用于处理大量实时数据流。学习Kafka需要掌握如何安装、配置和运行Kafka集群&#xff0c;以及如…

代码随想录算法训练营第五十三天|1143.最长公共子序列、1035.不相交的线、53. 最大子序和 动态规划

最长公共子序列 确定dp数组&#xff08;dp table&#xff09;以及下标的含义 和上一题一样&#xff0c;dp[i][j]代表&#xff1a; 长度为[0, i - 1]的字符串text1与长度为[0, j - 1]的字符串text2的最长公共子序列为dp[i][j]确定递推公式 主要就是两大情况&#xff1a; text1[…

一步一步学OAK之三:实现RGB相机场景切换

目录 Setup 1: 创建文件Setup 2: 安装依赖Setup 3: 导入需要的包Setup 4: 遍历所有场景模式和特效模式Setup 5: 创建pipelineSetup 6: 创建节点Setup 7: 连接设备并启动管道Setup 8: 创建与DepthAI设备通信的输入队列和输出队列Setup 9: 定义putText函数Setup 10: 主循环获取视…

uni-app

uni-app 一、准备工作1.新建项目2.配置浏览器3.兼容4.新建页面 二、上手1.pages.json文件的页面配置与全局配置2.rpx尺寸单位3.内置组件4.vue写法 一、准备工作 uni-app文档 HBuilderX&#xff0c;H是HTML的首字母&#xff0c;Builder是构造者&#xff0c;X是HBuilder的下一代版…

实例005 可以拉伸的菜单界面

实例说明 如果管理程序功能菜单非常多&#xff0c;而用户只使用一些常用菜单&#xff0c;这时&#xff0c;可以将主菜单项下的不常用菜单隐藏起来。此种显示方式类似于对菜单进行拉伸。使用时&#xff0c;只需单击展开菜单&#xff0c;即可显示相应菜单功能。运行本例&#xf…

使用Python批量进行数据分析

案例01 批量升序排序一个工作簿中的所有工作表——产品销售统计表.xlsx import xlwings as xw import pandas as pd app xw.App(visible False, add_book False) workbook app.books.open(产品销售统计表.xlsx) worksheet workbook.sheets # 列出工作簿中的所有工作表 fo…

VVIC搜款网API接口:获取商品详情数据API

VVIC电商平台汇集了数千家优质品牌和供应商&#xff0c;包括服装、家居用品、电子产品、美妆产品、食品和饮料等各种商品。消费者可以在VVIC上找到各类品牌和产品&#xff0c;满足他们的购物需求。VVIC还提供了多种付款方式和物流配送服务&#xff0c;确保消费者的购物过程顺利…

第27章 uView 内置路由使用注意事项

1 uView 内置路由不支持通过“localhost”域名直接获取数据。 在前后分离开发中“axios” 路由支持使用“localhost”域名或IP地址获取后端的数据&#xff0c;所以不管是IIS部署还是后端调试通过“axios” 路由都能获取数据&#xff0c;对于.NetCore的前后端分离开发来说“axio…

NLP学习笔记(二)

文章目录 &#xff08;一&#xff09;负采样&#xff08;二&#xff09;GloVe1.带全局语料库的跳元模型2.GloVe模型3.问题4.跳元模型与GloVe模型的比较 &#xff08;三&#xff09;问题1.参数初始化2.梯度下降3.下游任务4.句法信息5.似然估计6.词向量表示 &#xff08;一&#…

2023 中兴捧月算法挑战赛-自智网络-参赛总结

“中兴捧月”是由中兴通讯面向在校大学生举办的全球性系列赛事活动&#xff0c;致力于培养学生建模编程、创新、方案策划和团队合作能力。今年是在学校的宣传下了解到比赛&#xff0c;最初抱着学习的态度报名了比赛&#xff0c;最终进入了决赛&#xff0c;完成了封闭的开发与赛…

Jenkins+Gitlab+Springboot项目部署Jar和image两种方式

Springboot环境准备 利用spring官网快速创建springboot项目。 添加一个controller package com.example.demo;import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController;RestController public class…

【结构型设计模式】桥接模式

一、写在前面 桥接模式&#xff08;Bridge&#xff09;&#xff1a;桥接模式是一种结构型设计模式&#xff0c;其目的是将抽象部分和实现部分分离&#xff0c;允许它们可以独立地变化。该模式通过创建一个桥接类&#xff0c;连接抽象和实现&#xff0c;使得它们可以独立地进行…