前言
构建 ONNX 模型通常有两种方式:
1、通过框架代码导出为 ONNX 结构,比如 PyTorch -> ONNX;
2、通过 onnx 库的 helper API 自定义节点和图,直接生成 ONNX 结构。
本文主要简单学习和使用这两种构建 ONNX 结构的方式,下面以 Equal 节点为例进行分析。
构建方式
方法一:PyTorch -> ONNX
(暂缓,本文主要研究方法二)
方法二:使用 onnx helper API 手动构建
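# Placeholder for method 1 (PyTorch -> ONNX), postponed for now.
# forward() currently returns x unchanged; no Equal op has been added yet.
# import torch
# import torch.nn as nn
# class JustEqual(nn.Module):
#     def __init__(self):
#         super(JustEqual, self).__init__()
#     def forward(self, x):
#         return x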
import onnx
from onnx import TensorProto, helper, numpy_helper
import numpy as np
def run(shape=(16, 1, 397), *, opset_version=11, ir_version=8):
    """Build an ONNX model that compares its input with zero and casts the result to float.

    Graph: ``output2 = Cast(Equal(input, equal), to=FLOAT)``. Here ``equal`` is a
    one-element float32 initializer holding 0.0. As a result, ``output2`` is 1.0
    wherever ``input == 0`` and 0.0 elsewhere.

    Args:
        shape: Static shape of the float32 graph input ``input``. It is also used
            for the intermediate bool tensor ``output1`` and the float32 graph
            output ``output2``. Because the constant has shape [1], it broadcasts,
            so the Equal output shape equals the input shape.
        opset_version: Version of the default (ai.onnx) operator set to import.
        ir_version: ONNX IR version written into the model.

    Returns:
        The constructed ``onnx.ModelProto``. It is neither validated nor saved here.
    """
    print("run start....\n")
    dims = list(shape)

    # Equal(input, equal) -> bool mask. The [1]-shaped constant is broadcast
    # against the input. Other constant shapes tried while experimenting were
    # [] (rank-0 scalar) and [1, 1, 1].
    equal_node = helper.make_node(
        "Equal",
        name="Equal_0",
        inputs=["input", "equal"],
        outputs=["output1"],
    )
    initializer = [
        helper.make_tensor("equal", TensorProto.FLOAT, [1], np.zeros(1, dtype=np.float32)),
    ]

    # Cast the bool mask back to float32 so the graph output is numeric.
    cast_node = helper.make_node(
        op_type="Cast",
        inputs=["output1"],
        outputs=["output2"],
        name="test_cast",
        to=TensorProto.FLOAT,
    )

    # Type/shape annotation for the intermediate bool tensor between the two nodes.
    value_info = helper.make_tensor_value_info("output1", TensorProto.BOOL, dims)

    graph = helper.make_graph(
        nodes=[equal_node, cast_node],
        name="test_graph",
        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, dims)],
        outputs=[helper.make_tensor_value_info("output2", TensorProto.FLOAT, dims)],
        initializer=initializer,
        value_info=[value_info],
    )

    # Domain "" is the default ai.onnx domain, the same as a bare OperatorSetIdProto.
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset_version)])
    # make_model stamps the IR version of the installed onnx package, so pin it
    # explicitly. This is presumably so older runtimes can load the file (confirm).
    model.ir_version = ir_version
    print("run done....\n")
    return model
if __name__ == "__main__":
    # Build the Equal -> Cast demo graph and serialize it to disk.
    model = run()
    onnx.save(model, "./test_equal.onnx")
    # Alternative output path kept from the original experiment:
    # onnx.save(model, "./test_equal_ori.onnx")
运行:用 onnxruntime 加载 test_equal.onnx 并执行推理
import onnx
import onnxruntime
import numpy as np
# Validate the ONNX computation graph.
def check_onnx(mdoel):
    """Validate an ONNX model with ``onnx.checker.check_model``.

    Args:
        mdoel: The ``onnx.ModelProto`` to validate. The misspelled parameter
            name is kept so that existing keyword callers keep working.

    Raises:
        onnx.checker.ValidationError: If the model fails validation.
    """
    # Validate the argument itself, not whatever a module-level ``model`` refers to.
    onnx.checker.check_model(mdoel)
    # For a human-readable dump of the graph, uncomment:
    # print(onnx.helper.printable_graph(mdoel.graph))
def run(model, input_shape=(16, 1, 397)):
    """Run one CPU inference pass of an ONNX model on random float32 input.

    Args:
        model: Path to the ``.onnx`` file (or serialized model bytes). It is
            passed straight to ``onnxruntime.InferenceSession``.
        input_shape: Shape of the random float32 tensor fed to the model's
            first input.

    Returns:
        The value of the model's first output, as returned by ``session.run``.
    """
    print('run start....\n')
    session = onnxruntime.InferenceSession(model, providers=['CPUExecutionProvider'])

    input_name = session.get_inputs()[0].name
    # Standard-normal samples. For the Equal-with-zero demo graph these are
    # almost surely never exactly 0.0, so expect an all-zeros output.
    input_data = np.random.randn(*input_shape).astype(np.float32)
    print(f'input_data1 shape:{input_data.shape}\n')

    output_name = session.get_outputs()[0].name
    pred_onx = session.run([output_name], {input_name: input_data})[0]
    print(f'pred_onx shape:{pred_onx.shape} \n')
    print('run end....\n')
    return pred_onx
if __name__ == '__main__':
    path = "./test_equal.onnx"
    # Load once for validation; run() re-opens the same file by path for onnxruntime.
    model = onnx.load(path)
    check_onnx(model)
    run(path)