Merging ONNX models directly: the inputs and outputs of the two models do not have to match; each sub-model keeps its own inputs and outputs in the combined graph.
Example code:
import onnx
import onnx.compose  # explicit import, in case the installed onnx version does not re-export compose
import onnxruntime  # tested with onnxruntime 1.16.0
def log_model(model):
    # print the graph-level output and input names of a ModelProto
    outs = {o.name for o in model.graph.output}
    ins = {i.name for i in model.graph.input}
    print("outputs:", outs)
    print("inputs: ", ins)
# Model 1: load it and prefix every name with "G1_" so nothing collides with model 2
model_1 = onnx.load("model1.onnx")
model_1_new = onnx.compose.add_prefix(model_1, prefix="G1_")
log_model(model_1)
log_model(model_1_new)
# Model 2: same treatment with the "G2_" prefix
model_2 = onnx.load("model2.onnx")
model_2_new = onnx.compose.add_prefix(model_2, prefix="G2_")
log_model(model_2)
log_model(model_2_new)
# Merge the two graphs field by field into model_1_new
# merge inputs
model_1_new.graph.input.extend(model_2_new.graph.input)
# merge nodes
model_1_new.graph.node.extend(model_2_new.graph.node)
# merge initializers
model_1_new.graph.initializer.extend(model_2_new.graph.initializer)
# merge outputs
model_1_new.graph.output.extend(model_2_new.graph.output)
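# Caveat (an addition, not in the original snippet): the field-by-field merge
# above keeps only model_1's opset imports. If model_2 was exported with a
# different opset version or extra operator domains, copy those over as well,
# keeping the highest version per domain, so every node in the combined graph
# is still covered.
known_domains = {oi.domain: oi for oi in model_1_new.opset_import}
for oi in model_2_new.opset_import:
    if oi.domain not in known_domains:
        model_1_new.opset_import.append(oi)
    elif oi.version > known_domains[oi.domain].version:
        known_domains[oi.domain].version = oi.version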
# save the merged model (the test below loads this same file)
onnx.save(model_1_new, "combined.onnx")
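# Optional validation (an extra step, assumed useful rather than required):
# the ONNX checker flags structural problems such as duplicate names or
# dangling node inputs that a plain save would not catch.
from onnx import checker
checker.check_model(model_1_new)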
# sanity check: load the merged model with onnxruntime
session = onnxruntime.InferenceSession(
    "combined.onnx", providers=["CPUExecutionProvider"]
)
input_names = [i.name for i in session.get_inputs()]
output_names = [o.name for o in session.get_outputs()]
print("\n--- new model ---")
print(input_names)
print(output_names)
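To confirm that the two sub-graphs really run side by side, a quick smoke test can feed random data into every input of the merged session and execute it once. This is only a sketch: the actual input names and shapes come from whatever model1.onnx and model2.onnx declare, and it assumes every input is float32, with any dynamic dimension replaced by 1.

import numpy as np

feeds = {}
for inp in session.get_inputs():
    # replace symbolic/dynamic dimensions with 1 (assumption for the test)
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    feeds[inp.name] = np.random.rand(*shape).astype(np.float32)

# run both sub-graphs in a single call; each consumes only its own inputs
results = session.run(output_names, feeds)
for name, value in zip(output_names, results):
    print(name, value.shape)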