ONNX 模型修改
当我们熟悉了ONNX模型各个层级的结构后,我们便可以针对各个结构来对模型进行修改,从而使其更好的适配后端运行时或者特定硬件平台的编译器。对模型的修改通常可以概括为"增删改查"的操作。"增"是增加相应结构,"删"是删除相应结构,"改"是修改相应结构,"查"是获取到指定的模型结构。修改ONNX模型通常有两种思路,一是使用ONNX官方提供的Python API;二是使用第三方ONNX模型修改工具,例如onnx-graphsurgeon工具。本文将聚焦第一种方案,介绍如何使用ONNX官方API来对ONNX模型进行"增删改查"修改。完整的ONNX官方API文档可以参考:https://onnx.ai/onnx/index.html。
1. ONNX 模型的"查"
我们想要修改ONNX模型,首先需要知道如何定位到自己感兴趣的位置,比如如何找到具体某个节点、某个 initializer,、计算图的input/output, 某个节点的 input/output以及某个 value_info。参考下面的代码,我们可以发现定位某个元素的基本思路就是遍历该元素的列表,然后根据该元素在计算图中独有的属性名称来实现定位。下面的代码实现了定位下图模型的各个元素。
# 根据算子的名字来找到目标节点
for item in model.graph.node:
if item.name == 'Conv_1':
print(item)
# 有的onnx模型中算子没有name属性,可以根据算子类型和输出的名字来组合找到目标节点
for item in model.graph.node:
if item.op_type == 'Conv':
if '1338' in item.output:
print(item)
# 找到目标 intializer
for i in model.graph.initializer:
if i.name == '1339':
print(i.dims)
print(i.dims)
print(i.data_type)
# 二进制形式打印,可能比较长
print(i.raw_data)
# 找到 graph 的input和output
for i in model.graph.input:
if i.name == 'input':
print(i.name)
print(i.type)
# 找到 graph 的valueinfo
for i in model.graph.value_info:
if i.name == '9':
print(i.name)
print(i.type)
2. ONNX 模型的"删"
在了解了如何定位到需要修改的部分后,我们就可以对ONNX模型进行魔改了。我们首先了解如何删除ONNX模型中的指定节点或元素。下面的代码实现了删除图中标注的节点。
import onnx
# 加载模型
model = onnx.load('./super-resolution-10.onnx')
# 根据输入获取指定节点
def get_node_with_input(model, input_name):
res = []
for i in model.graph.node:
if input_name in i.input:
res.append(i)
return res
# 根据输出获取指定节点
def get_node_with_output(model, output_name):
res = []
for i in model.graph.node:
if output_name in i.output:
res.append(i)
return res
# 删除指定节点并将前后节点连接起来
remove_nodes = []
p = None
n = None
for i in model.graph.node:
if '10' in i.input:
# p = find_node_with_output(i.input[0])
p = get_node_with_output(model, i.input[0])[0]
remove_nodes.append(i)
if '11' in i.input:
# n = find_node_with_input(i.output[0])
n = get_node_with_input(model, i.output[0])[0]
remove_nodes.append(i)
n.input[0] = p.output[0]
for i in remove_nodes:
model.graph.node.remove(i)
onnx.checker.check_model(model)
onnx.save(model, 'super-resolution-10-delete.onnx')
3. ONNX 模型的"增"
"增"是指在ONNX模型指定位置添加节点。在了解添加节点之前,我们首先需要了解如何创建 ONNX 节点。下面以创建一个2D卷积算子和一个ReLu算子为例,并尝试将上一步骤中删除的这两个节点重新添加回模型当中(注意我们权重没有与原模型保持一致)。
node1 = onnx.helper.make_node(
name="Conv_0", # 节点名字,不要和op_type搞混了
op_type="Conv", # 节点的算子类型, 比如'Conv'、'Relu'、'Add'这类,详细可以参考onnx给出的算子列表
inputs=["image", "conv.weight", "conv.bias"], # 各个输入的名字,结点的输入包含:输入和算子的权重。必有输入X和权重W,偏置B可以作为可选。
outputs=["11"],
pads=[1, 1, 1, 1], # 其他字符串为节点的属性,attributes在官网被明确的给出了,标注了default的属性具备默认值。
group=1,
dilations=[1, 1],
kernel_shape=[3, 3],
strides=[1, 1]
)
initializer_w = onnx.helper.make_tensor(
name="conv.weight",
data_type=onnx.helper.TensorProto.DataType.FLOAT,
dims=[64, 64, 3, 3],
vals=np.ones([64,64,3,3], dtype=np.float32).tobytes(),
raw=True
)
initializer_b = onnx.helper.make_tensor(
name="conv.bias",
data_type=onnx.helper.TensorProto.DataType.FLOAT,
dims=[64],
vals=np.ones([64], dtype=np.float32).tobytes(),
raw=True
)
node2 = onnx.helper.make_node(
name="ReLU_1",
op_type="Relu",
inputs=["11"],
outputs=["12"]
)
下面代码将上述创建的两个节点插入到模型指定位置。
for i in range(len(model.graph.node)):
if '10' in model.graph.node[i].output:
model.graph.node[i].output[0] = 'pre_output'
model.graph.node[i+1].input[0] = 'relu_output'
model.graph.node.insert(i+1, node1)
model.graph.node.insert(i+2, node2)
model.graph.initializer.append(initializer_w)
model.graph.initializer.append(initializer_b)
input = model.graph.input[0]
new_input = onnx.helper.make_tensor_value_info(input.name, onnx.TensorProto.FLOAT, [1,1,224,224])
model.graph.input[0].CopyFrom(new_input)
onnx.checker.check_model(model)
model = onnx.shape_inference.infer_shapes(model)
onnx.save(model, 'super-resolution-10-insert.onnx')
4. ONNX 模型的"改"
通常来说修改 ONNX 模型可以概括为一下两种:
- 修改模型节点
- 修改权重(initializer)
修改模型的节点可以通过上述的删除 + 添加节点组合操作来实现,这里不再赘述。下面将介绍如何修改节点权重。节点权重通常保存在initializer中,下面代码尝试将Conv算子中的bias缩小10倍。
import onnx
model = onnx.load("./super-resolution-10.onnx")
# 得到所有 initializer
all_initializer = model.graph.initializer
# 定位到目标 initializer
target_initializer = 'conv1.bias'
idx = ''
scale_factor = 10
for i, j in enumerate(all_initializer):
if j.name == target_initializer:
idx = i
break
# 将 conv1 算子的 bias 缩小10倍
model.graph.initializer[idx].raw_data = (onnx.numpy_helper.to_array(all_initializer[idx]) / scale_factor).tobytes()
onnx.save(model,'super-resolution-10-scale.onnx')
总结
当我们在实际部署模型时,会根据具体硬件特性来在 ONNX 模型层面做相应的优化修改,使其能在特定的硬件平台上获得更好的推理性能。本文简单介绍了如何调用 ONNX 官方API来对 ONNX 模型进行增删改查,更加复杂的模型修改操作通常是上述四种操作的各种组合。
使用ONNX 官方API需要我们对 ONNX 模型的定义和Proto结构足够熟悉,并且通过本文中的示例代码可以看到,繁多复杂的API在使用过程中也不是很方便。在实际工作中,我们一般使用NV提供的onnx-graphsurgeon工具来快速对ONNX模型进行修改验证。这个工具在官方ONNX API的基础上提供了更为友好的高级API封装,大大提升了我们修改ONNX模型的效率,在之后的文章中我们将进一步详细介绍这个工具的使用。
作者:高通工程师,阮慧源(Huiyuan Ruan)