import numpy as np
import onnx
import onnxruntime
import onnxruntime. backend as backend
# Method 1: prepend ``input_image * 2.0 + (-1.0)`` to the graph, with the two
# constants supplied by Constant nodes, so the original layers receive the
# rescaled input (e.g. values in [0, 1] become values in [-1, 1]).
model = onnx.load('test.onnx')
graph = model.graph

# Tensor names introduced by this edit (kept from the original script,
# including the "mutiply" spelling). None of them may already exist in the
# graph: a duplicate definition breaks single static assignment. Some
# exporters may name intermediate tensors with bare numbers such as "1".
new_tensor_names = {"1", "mutiply", "2", "add"}
used_names = {value.name for value in graph.input}
used_names.update(init.name for init in graph.initializer)
used_names.update(out for existing in graph.node for out in existing.output)
clashes = new_tensor_names & used_names
if clashes:
    raise ValueError(f"tensor names already used in the graph: {sorted(clashes)}")

# Point every existing consumer of the raw input at the preprocessed tensor
# "add". The previous version patched only node 0's first input. That missed
# any other consumer and overwrote an unrelated input whenever node 0 did not
# read "input_image" in that slot. This runs before the Mul below is
# inserted, so the Mul remains the only reader of "input_image".
rewired = 0
for existing in graph.node:
    for slot, name in enumerate(existing.input):
        if name == "input_image":
            existing.input[slot] = "add"
            rewired += 1
if not rewired:
    raise ValueError('no node in the graph consumes "input_image"')

# Scalar constant 2.0 (dims=[]), broadcast by Mul.
scale_const = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["1"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [2.0]),
)
mul_node = onnx.helper.make_node(
    "Mul",
    inputs=["input_image", "1"],
    outputs=["mutiply"],  # (sic) tensor name kept from the original script
)
# Scalar constant -1.0 (dims=[]), broadcast by Add.
offset_const = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["2"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [-1.0]),
)
add_node = onnx.helper.make_node(
    "Add",
    inputs=["mutiply", "2"],
    outputs=["add"],
)

# Graph nodes must be topologically sorted. The four new nodes therefore go in
# front, and the existing nodes keep their relative order after them.
for position, new_node in enumerate((scale_const, mul_node, offset_const, add_node)):
    graph.node.insert(position, new_node)

onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')
# Smoke test: print the device string reported by this onnxruntime build.
# Then run the edited in-memory model once on an all-ones float32 input of
# shape (1, 1, 128, 128) and print the outputs.
device_name = onnxruntime.get_device()
print(device_name)

rt = backend.prepare(model, "CPU")
ones_batch = np.ones([1, 1, 128, 128], dtype=np.float32)
out = rt.run(ones_batch)
print(out)
# 第二种：使用可供训练的初始化参数 — Method 2: store the constants as (trainable) initializer tensors instead of Constant nodes.
import numpy as np
import onnx
import onnxruntime
import onnxruntime. backend as backend
# Method 2: the same ``input_image * 2.0 + (-1.0)`` preprocessing, but the two
# constants are stored as graph initializers named "1" and "2" instead of
# being produced by Constant nodes.
model = onnx.load('test.onnx')
graph = model.graph

# Tensor names introduced by this edit (kept from the original script,
# including the "mutiply" spelling). None of them may already exist in the
# graph: a duplicate definition breaks single static assignment. Some
# exporters may name intermediate tensors with bare numbers such as "1".
new_tensor_names = {"1", "mutiply", "2", "add"}
used_names = {value.name for value in graph.input}
used_names.update(init.name for init in graph.initializer)
used_names.update(out for existing in graph.node for out in existing.output)
clashes = new_tensor_names & used_names
if clashes:
    raise ValueError(f"tensor names already used in the graph: {sorted(clashes)}")

# Point every existing consumer of the raw input at the preprocessed tensor
# "add". The previous version patched only node 0's first input. That missed
# any other consumer and overwrote an unrelated input whenever node 0 did not
# read "input_image" in that slot. This runs before the Mul below is
# inserted, so the Mul remains the only reader of "input_image".
rewired = 0
for existing in graph.node:
    for slot, name in enumerate(existing.input):
        if name == "input_image":
            existing.input[slot] = "add"
            rewired += 1
if not rewired:
    raise ValueError('no node in the graph consumes "input_image"')

# One-element float tensors; Mul/Add broadcast them over the input.
# NOTE(review): models with ir_version < 4 also require initializers to be
# listed in graph.input. Confirm this for test.onnx.
scale_init = onnx.helper.make_tensor(
    name='1',
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([2.0], dtype=np.float32),
)
offset_init = onnx.helper.make_tensor(
    name='2',
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([-1.0], dtype=np.float32),
)
graph.initializer.append(scale_init)
graph.initializer.append(offset_init)

mul_node = onnx.helper.make_node(
    "Mul",
    inputs=["input_image", "1"],
    outputs=["mutiply"],  # (sic) tensor name kept from the original script
)
add_node = onnx.helper.make_node(
    "Add",
    inputs=["mutiply", "2"],
    outputs=["add"],
)

# Graph nodes must be topologically sorted. The two new nodes therefore go in
# front, and the existing nodes keep their relative order after them.
for position, new_node in enumerate((mul_node, add_node)):
    graph.node.insert(position, new_node)

onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')
# Smoke test: print the device string reported by this onnxruntime build.
# Then run the edited in-memory model once on an all-ones float32 input of
# shape (1, 1, 128, 128) and print the outputs.
device_name = onnxruntime.get_device()
print(device_name)

rt = backend.prepare(model, "CPU")
ones_batch = np.ones([1, 1, 128, 128], dtype=np.float32)
out = rt.run(ones_batch)
print(out)