将PyTorch模型转换为ONNX(Open Neural Network Exchange)格式是实现模型跨平台部署和优化推理性能的一种常见方法。PyTorch 提供了多种方式来完成这一转换,以下是几种主要的方法:
一、静态模型转换
使用 torch.onnx.export()
torch.onnx.export()
是 PyTorch 官方推荐的最常用方法,适用于大多数情况。它允许你将一个 PyTorch 模型及其输入数据一起导出为 ONNX 格式。
基本用法
import torch
import torch.onnx
# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ... # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224) # 示例输入,形状取决于模型的输入要求
# 设置模型为评估模式
model.eval()
# 导出为 ONNX 文件
torch.onnx.export(
model, # 要导出的模型
dummy_input, # 模型的输入张量
"model.onnx", # 输出文件名
export_params=True, # 是否导出模型参数
opset_version=11, # ONNX 操作集版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴,支持可变批次大小
'output': {0: 'batch_size'}}
)
关键参数说明
model
: 要导出的 PyTorch 模型。dummy_input
: 一个与模型输入形状匹配的张量,用于模拟实际输入。export_params
: 是否导出模型的参数(权重和偏置)。通常设置为True
。opset_version
: 指定要使用的 ONNX 操作集版本。不同的版本可能支持不同的操作符。建议使用较新的版本(如 11 或 13)。do_constant_folding
: 是否执行常量折叠优化,可以减少模型的计算量。input_names
和output_names
: 指定 ONNX 模型的输入和输出节点的名称,方便后续加载和调用。dynamic_axes
: 指定哪些维度是动态的(即可以在推理时变化),例如批次大小或序列长度。
二、复杂模型转换
对于一些复杂的模型,特别是包含控制流(如条件语句、循环等)的模型,torch.onnx.export()
可能无法直接处理。这时可以先使用 torch.jit.trace()
将模型转换为 TorchScript 格式,然后再导出为 ONNX。
基本用法
import torch
import torch.onnx
# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ... # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224) # 示例输入
# 设置模型为评估模式
model.eval()
# 使用 torch.jit.trace 将模型转换为 TorchScript
traced_model = torch.jit.trace(model, dummy_input)
# 导出为 ONNX 文件
torch.onnx.export(
traced_model, # 已经转换为 TorchScript 的模型
dummy_input, # 模型的输入张量
"traced_model.onnx", # 输出文件名
export_params=True, # 是否导出模型参数
opset_version=11, # ONNX 操作集版本
do_constant_folding=True,# 是否执行常量折叠优化
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴
'output': {0: 'batch_size'}}
)
三、动态模型转换
使用 torch.onnx.dynamo_export()
torch.onnx.dynamo_export()
是 PyTorch 2.0 引入的新功能,基于 PyTorch 的 Dynamo 编译器。它旨在提供更好的性能和更广泛的模型支持,尤其是对于那些包含动态控制流的模型。
基本用法
import torch
# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ... # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224) # 示例输入
# 设置模型为评估模式
model.eval()
# 使用 dynamo_export 导出为 ONNX 文件
torch.onnx.dynamo_export(
model, # 要导出的模型
dummy_input, # 模型的输入张量
"dynamo_model.onnx" # 输出文件名
)
注意:torch.onnx.dynamo_export()
是 PyTorch 2.0 中引入的功能,确保你使用的是最新版本的 PyTorch。
四、自定义操作符模型转换
自定义操作符(Custom Operator)是指那些不在标准 PyTorch 或 ONNX 操作集中的操作符。当你需要实现某些特定的功能或优化时,可能需要编写自定义的操作符,并将其注册到 ONNX 中以便在导出和推理时使用。
例子:实现一个自定义的 ReLU6 操作符
假设我们想要实现一个自定义的 ReLU6
操作符。ReLU6
是一种常用的激活函数,它与标准的 ReLU
类似,但有一个上限值 6。其数学表达式为:
1. 实现自定义操作符
首先,我们需要在 C++ 中实现这个自定义操作符,并编译成一个共享库。PyTorch 提供了 torch::jit::custom_ops
接口来注册自定义操作符,而 ONNX 则提供了 onnxruntime
来注册自定义操作符。
1.1 在 PyTorch 中实现自定义操作符
我们可以在 C++ 中实现 ReLU6
操作符,并通过 PyTorch 的 torch::jit::custom_ops
接口将其注册到 PyTorch 中:
// custom_relu6.cpp
#include <torch/script.h>
#include <torch/custom_class.h>
// 定义自定义的 ReLU6 操作符
torch::Tensor custom_relu6(const torch::Tensor& input) {
return torch::clamp(input, 0, 6);
}
// 注册自定义操作符
static auto registry = torch::RegisterOperators("custom_ops::relu6", &custom_relu6);
1.2 编译自定义操作符
接下来,我们需要将这个 C++ 文件编译成一个共享库(例如 .so
文件),以便在 Python 中加载:
# 使用 PyTorch 提供的工具进行编译
python -m pip install torch torchvision torchaudio
python -m torch.utils.cpp_extension.build_ext --inplace custom_relu6.cpp
这会生成一个名为 custom_relu6.so
的共享库文件;
2. 在 PyTorch 中使用自定义操作符
现在我们可以在 Python 中加载并使用这个自定义操作符;
import torch
import torch.nn as nn
import custom_relu6 # 加载编译后的共享库
# 定义一个使用自定义 ReLU6 操作符的模型
class CustomModel(nn.Module):
def __init__(self):
super(CustomModel, self).__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
def forward(self, x):
x = self.conv(x)
# 调用自定义的 ReLU6 操作符
x = torch.ops.custom_ops.relu6(x)
return x
# 创建模型实例
model = CustomModel()
model.eval()
# 准备示例输入
dummy_input = torch.randn(1, 3, 224, 224)
# 运行模型
output = model(dummy_input)
print(output.shape) # 输出形状应为 (1, 16, 224, 224)
3. 将自定义操作符导出为 ONNX
为了将包含自定义操作符的模型导出为 ONNX 格式,我们需要告诉 ONNX 如何处理这个自定义操作符。我们可以使用 torch.onnx.register_custom_op_symbolic
来定义 ONNX 符号函数,从而在导出时正确处理自定义操作符。
3.1 定义 ONNX 符号函数
我们需要定义一个符号函数,告诉 ONNX 如何表示 custom_ops::relu6
操作符。这个符号函数会生成相应的 ONNX 操作符节点。
import torch.onnx
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
# 定义 ONNX 符号函数
@parse_args('v')
def symbolic_custom_relu6(g, input):
# 使用 ONNX 的 Clip 操作符来实现 ReLU6
return g.op("Clip", input, min_f=0.0, max_f=6.0)
# 注册自定义操作符的符号函数
register_custom_op_symbolic('custom_ops::relu6', symbolic_custom_relu6, 9) # 9 表示 ONNX 操作集版本
3.2 导出为 ONNX
现在我们可以将模型导出为 ONNX 格式,并确保自定义操作符被正确处理。
# 导出为 ONNX 文件
torch.onnx.export(
model, # 要导出的模型
dummy_input, # 模型的输入张量
"custom_model.onnx", # 输出文件名
export_params=True, # 是否导出模型参数
opset_version=9, # ONNX 操作集版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴
'output': {0: 'batch_size'}}
)
4. 在 ONNX Runtime 中使用自定义操作符
为了在 ONNX Runtime 中使用自定义操作符,我们需要将自定义操作符的实现编译成一个 ONNX Runtime 扩展库,并在推理时加载该扩展库。
4.1 实现 ONNX Runtime 自定义操作符
我们需要在 C++ 中实现 ReLU6
操作符,并将其注册到 ONNX Runtime 中。
// custom_relu6_onnxruntime.cpp
#include "onnxruntime/core/providers/cpu/cpu_provider_factory.h"
#include "onnxruntime/core/framework/op_kernel.h"
namespace onnxruntime {
class CustomRelu6 : public OpKernel {
public:
explicit CustomRelu6(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override {
// 获取输入张量
const Tensor* input_tensor = context->Input<Tensor>(0);
if (!input_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is null");
// 获取输出张量
Tensor* output_tensor = context->Output(0, input_tensor->Shape());
if (!output_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Output tensor is null");
// 获取输入和输出的数据指针
float* input_data = input_tensor->template Data<float>();
float* output_data = output_tensor->template Data<float>();
// 计算 ReLU6
size_t size = input_tensor->Shape().Size();
for (size_t i = 0; i < size; ++i) {
output_data[i] = std::min(std::max(input_data[i], 0.0f), 6.0f);
}
return Status::OK();
}
};
ONNX_OPERATOR_KERNEL(
Relu6, // 操作符名称
kOnnxDomain, // 命名空间
9, // 操作集版本
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), // 数据类型约束
CustomRelu6); // 自定义操作符类
}
4.2 编译 ONNX Runtime 自定义操作符
我们将上述代码编译成一个动态链接库(例如 .so
文件),以便在 ONNX Runtime 中加载。
# 使用 ONNX Runtime 提供的工具进行编译
g++ -shared -fPIC -o custom_relu6_onnxruntime.so custom_relu6_onnxruntime.cpp -lonnxruntime
4.3 在 ONNX Runtime 中加载自定义操作符
最后,我们在 Python 中使用 onnxruntime
加载自定义操作符,并运行推理。
import onnxruntime as ort
import numpy as np
# 加载 ONNX 模型
ort_session = ort.InferenceSession("custom_model.onnx", providers=['CPUExecutionProvider'])
# 加载自定义操作符的扩展库
ort_session.load_custom_ops_library("custom_relu6_onnxruntime.so")
# 准备输入数据
ort_inputs = {'input': dummy_input.numpy()} # 将 PyTorch 张量转换为 NumPy 数组
# 运行推理
ort_outs = ort_session.run(None, ort_inputs)
# 获取 PyTorch 模型的输出
with torch.no_grad():
torch_out = model(dummy_input)
# 比较 ONNX 和 PyTorch 的输出
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print("ONNX 模型验证通过!")
总结
- 自定义操作符:当你的模型中包含不在标准 PyTorch 或 ONNX 操作集中的操作符时,你可以通过编写自定义操作符来实现这些功能。
- PyTorch 中的自定义操作符:可以使用
torch::jit::custom_ops
接口在 C++ 中实现自定义操作符,并通过共享库加载到 PyTorch 中。 - ONNX 中的自定义操作符:可以通过
torch.onnx.register_custom_op_symbolic
定义符号函数,告诉 ONNX 如何处理自定义操作符。然后,在 ONNX Runtime 中,可以通过编译自定义操作符的实现并加载扩展库来支持推理。 - 复杂性:实现自定义操作符通常比较复杂,因为它涉及到跨语言编程(C++ 和 Python)、编译和链接等多个步骤。然而,这对于实现特定功能或优化模型是非常有用的。
通过这个例子,你可以看到如何从头实现一个自定义操作符,并将其集成到 PyTorch 和 ONNX 中。