目录
- 前言
- 0. 简述
- 1. 执行一下我们的python程序
- 2.转换swin-tiny时候出现的不兼容op的例子
- 3. 当出现导出onnx不成功的时候,我们需要考虑的事情
- 4. unsupported asinh算子
- 5. unsupported deformable conv算子
- 总结
- 参考
前言
自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》,链接。记录下个人学习笔记,仅供自己参考
本次课程我们来学习课程第三章—TensorRT 基础入门,一起来学习 ONNX 注册算子的方法
课程大纲可以看下面的思维导图
0. 简述
本小节目标:学习 pytorch 导出 onnx 不成功的时候如何解决(without plugin篇)
这节我们学习第三章节第六小节—ONNX 注册算子的方法,学习当 pytorch 导出 onnx 失败时的解决方法,比较常见的就是我们在使用开源代码将它导出 ONNX 时发现算子不兼容,还有导出成功但是转成 TensorRT 的 engine 时又出现算子不兼容,这里就简单过一下遇到这些情况我们应该怎么做
因为现阶段我们还没有讲如何利用 C++ 去写一个 plugin,所以我们先不考虑这个算子的 C++ 是如何实现的,我们现在主要是 pytorch 出现一些不兼容的算子导出 onnx 的时候该怎么做
1. 执行一下我们的python程序
源代码获取地址:https://github.com/kalfazed/tensorrt_starter
这个小节的案例主要是 3.4-export-unsupported-node,如下所示:
2.转换swin-tiny时候出现的不兼容op的例子
先给大家做一个简单的背景介绍,按照 swin-transformer 官方文档导出 onnx 时可能会出现如下问题:
roll 这个算子在 opset9 不支持,那我们再看看 opset12:
可以发现在 opset12 时依旧出现同样的问题
我们到对应的代码中可以找到 torch.roll
它在 onnx opset 中是不兼容的,那我们应该怎么办呢?我们在下一个小节会跟大家去讲
3. 当出现导出onnx不成功的时候,我们需要考虑的事情
当出现导出 onnx 不成功的时候,我们主要有以下几个方法,难易度从低到高:
- 修改 opset 的版本
- 查看不支持的算子在新的 opset 中是否被支持
- 如果不考虑自己搭建 plugin 的话,也需要看看 onnx-trt 中这个算子是否被支持
- 因为 onnx 是一种图结构表示,并不包含各个算子的实现。除非我们是要在 onnxruntime 上测试,否则我们更看重 onnx-trt 中这个算子的支持情况
- 算子支持文档:https://github.com/onnx/onnx/blob/main/docs/Operators.md
- 替换 pytorch 中的算子组合
- 把某些计算替换成 onnx 可以识别的
- 在 pytorch 登记 onnx 中某些算子
- 有可能 onnx 中有支持,但没有被登记,比如 Asinh
- 直接修改 onnx 创建 plugin
- 使用 onnxsurgeon
- 一般是用在加速某些算子上使用
这个小节主要是给大家介绍第三种方法,注册登记算子
4. unsupported asinh算子
我们先执行下 sample_asinh.py 案例代码,输出如下所示:
可以看到导出 onnx 出错了,原因是 aten::asinh
在 ONNX opset9 是不支持的,代码如下:
import torch
import torch.onnx
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.asinh(x)
return x
def infer():
input = torch.rand(1, 5)
model = Model()
x = model(input)
print("input is: ", input.data)
print("result is: ", x.data)
def export_norm_onnx():
input = torch.rand(1, 5)
model = Model()
model.eval()
file = "../models/sample-asinh.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 9)
print("Finished normal onnx export")
if __name__ == "__main__":
infer()
# 这里导出asinh会出现错误。
# Pytorch可以支持asinh的同时,
# def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
# 从onnx支持的算子里面我们可以知道自从opset9开始asinh就已经被支持了
# asinh is suppored since opset9
# 所以我们可以知道,问题是出现在PyTorch与onnx之间没有建立asinh的映射
# 我们需要建立这个映射。这里涉及到了注册符号函数的概念,详细看PPT
export_norm_onnx()
出现导出问题我们先去寻找官方文档 https://github.com/onnx/onnx/blob/main/docs/Operators.md,如下图所示:
这里我们可以看到 asinh 算子在 opset9 这个版本已经开始支持了,但是导出就是不行,为什么呢?其实问题是出现在 pytorch 与 onnx 之间没有建立 asinh 的映射,我们需要自己来建立这个映射
我们可以在 torch/onnx/symbolic_opset9.py 找到 pytorch 注册到 onnx 中支持的算子,如下图所示:
Note:低版本 torch 下面的 symbolic_opset9.py 文件中可能并没有上图中的内容,博主 torch 版本是 2.0.1
针对不同的算子我们可以看到如下代码:
# torch/onnx/symbolic_opset9.py/L317
@_onnx_symbolic("aten::_shape_as_tensor")
@_beartype.beartype
def _shape_as_tensor(g: jit_utils.GraphContext, input):
return g.op("Shape", input)
@_onnx_symbolic("aten::_reshape_from_tensor")
@_beartype.beartype
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
if isinstance(shape, list):
shape = g.op("Concat", *shape, axis_i=0)
return reshape(g, input, shape)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
return symbolic_helper._reshape_helper(g, self, shape)
...
这些代码是 PyTorch 中用于将特定的 aten
(PyTorch 的内部操作库) 操作符转换为 ONNX 操作符的符号化函数。这些符号化函数定义了如何将 PyTorch 的操作符映射到等效的 ONNX 操作符,其实它们就是做了算子注册这件事
我们以 reshape 为例简单分析下 pytorch 到 onnx 的算子注册是怎么做的:(from ChatGPT)
1. @_onnx_symbolic(“aten::reshape”)
- 这是一个装饰器,用于将 PyTorch 的
aten::reshape
操作符映射到 ONNX 的等效操作符。 @_onnx_symbolic
是一个装饰器函数,用于注册和标记 PyTorch 操作符与其 ONNX 对应操作符之间的关系。"aten::reshape"
指定了 PyTorch 中的reshape
操作符。
2. @symbolic_helper.quantized_args(True)
- 指定函数是否处理量化参数。
@symbolic_helper.quantized_args
是一个装饰器,用于指示该函数是否支持量化参数。True
表示该函数支持量化参数。如果模型中有量化的操作,该装饰器会确保这些操作被正确处理和导出。
3. @_beartype.beartype
- 类型检查装饰器,确保函数的输入参数和返回类型符合预期。
@_beartype.beartype
是beartype
库提供的装饰器,用于在运行时进行类型检查,确保参数类型和返回值类型正确。
4. def reshape(g: jit_utils.GraphContext, self, shape):
- 定义 reshape 函数。
- g: 表示
jit_utils.GraphContext
类型的对象,代表当前的图上下文,用于在图中添加节点。 - self: 输入的张量,代表要重塑的张量。
- shape: 表示新的形状,可以是一个表示形状的张量或一个列表。
5. return symbolic_helper._reshape_helper(g, self, shape)
- 调用
_reshape_helper
函数执行实际的重塑操作,并返回结果。
symbolic_helper._reshape_helper(g, self, shape)
是一个帮助函数,用于执行重塑逻辑。- g: 传递当前的图上下文。
- self: 传递要重塑的输入张量。
- shape: 传递目标形状。
其中的 aten::xxx 是 C++ 的一个 namespace,pytorch 的很多算子的底层都是在 aten 这个命名空间下进行实现的,而 aten(a Tensor Library)是一个实现张量运算的 C++ 库。onnx_symblic 负责绑定,使得 pytorch 中的算子与 aten 命名空间下的算子一一对应
补充:g.op
函数是 PyTorch 中用于在 ONNX 计算图中添加操作节点的函数。它是 jit_utils.GraphContext
类的一个方法,用于定义和插入新的 ONNX 操作节点。其参数定义如下:
- op_type: 操作的类型,例如
"Reshape"
,"Concat"
,"Shape"
等。这些操作类型是 ONNX 操作符集中的名称。 - inputs: 传递给操作节点的输入张量。可以是一个或多个张量。
- attributes: 可选参数,指定操作的属性,通常以
key=value
的形式传递。
通过这么一套操作我们就可以把 ONNX 中的某个算子和 Pytorch 底层 aten 命名空间下的算子实现给绑定起来,这就是一个算子的注册
因此之前 asinh 导出问题就是因为 pytorch 中的底层算子实现和 onnx 中的算子没有绑定,我们手动绑定下即可。所以我们来看 sample_asinh_register.py 案例代码,如下所示:
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic
# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
# g: 就是graph,计算图
# 也就是说,在计算图中添加onnx算子
# 由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
# symblic的参数需要与Pytorch的asinh接口函数的参数对齐
# def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的
# 那么aten是什么呢?
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
# 这里容易混淆的地方:
# 1. register_op中的第一个参数是PyTorch中的算子名字: aten::asinh
# 2. g.op中的第一个参数是onnx中的算子名字: Asinh
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.asinh(x)
return x
def validate_onnx():
input = torch.rand(1, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-asinh.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: ", x)
def export_norm_onnx():
input = torch.rand(1, 5)
model = Model()
model.eval()
file = "../models/sample-asinh.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished normal onnx export")
if __name__ == "__main__":
export_norm_onnx()
# 自定义完onnx以后必须要进行一下验证
validate_onnx()
这段代码展示了如何在 PyTorch 中创建和注册一个自定义的符号函数 (symbolic function
),以便将 PyTorch 的 asinh
操作导出到 ONNX 格式,并在 ONNX Runtime 中进行推理验证
函数 register_custom_op_symbolic
是 PyTorch ONNX 导出工具中用来注册自定义算子的符号映射函数。此函数对于在 ONNX 中支持 PyTorch 中的自定义或特殊算子至关重要。下面我会逐一解释函数参数和作用:(from CHatGPT)
参数解释
-
symbolic_name (str):
- 这是需要注册的自定义算子的名称。格式通常为
<domain>::<op>
,其中<domain>
表示算子所属的命名空间,而<op>
是算子的名称。例如,aten::asinh
表示来自aten
域的asinh
算子。
- 这是需要注册的自定义算子的名称。格式通常为
-
symbolic_fn (Callable):
-
这是一个函数,用于定义如何将 PyTorch 中的操作转换为 ONNX 图中的节点。这个函数通常会接收几个参数:ONNX 图对象、当前算子的输入参数等,并且基于这些输入构造并返回 ONNX 中相应的算子节点。
-
symbolic_fn
应返回一个或多个用于替换原有 PyTorch 算子的 ONNX 操作节点。这个函数的实现需要考虑输入输出的匹配、算子属性的转换等。
-
-
opset_version (int):
- 这是指定算子应当注册到的 ONNX 操作集版本。ONNX 的操作集(opset)定义了算子的支持和行为,不同版本的 opset 可能支持不同的算子或者同一算子的不同行为。这个参数确保你的自定义算子兼容特定版本的 ONNX。
执行后输出如下图所示:
导出的 ONNX 如下图所示:
我们再来看另外一种写法,案例 sample_asinh_register2.py 代码如下所示:
import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
# 另外一个写法
# 这个是类似于torch/onnx/symbolic_opset*.py中的写法
# 通过torch._internal中的registration来注册这个算子,让这个算子可以与底层C++实现的aten::asinh绑定
# 一般如果这么写的话,其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.asinh(x)
return x
def validate_onnx():
input = torch.rand(1, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-asinh2.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: ", x)
def export_norm_onnx():
input = torch.rand(1, 5)
model = Model()
model.eval()
file = "../models/sample-asinh2.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished normal onnx export")
if __name__ == "__main__":
export_norm_onnx()
# 自定义完onnx以后必须要进行一下验证
validate_onnx()
这段代码展示了另一种注册自定义符号函数的方法,使用了 torch.onnx._internal
中的 registration
模块,通过装饰器的方式注册 asinh
操作符
与之前方式不同的是它通过两个步骤来注册自定义符号函数:(from ChatGPT)
1. 创建 _onnx_symbolic
的局部函数
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
-
使用
functools.partial
创建一个偏函数_onnx_symbolic
,指定 opset 为 9。 -
该偏函数用于注册符号函数,简化了装饰器的使用。
2. 定义 asinh_symbolic
符号函数并注册
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
- 使用装饰器
@_onnx_symbolic('aten::asinh')
注册asinh_symbolic
符号函数。 asinh_symbolic
函数用于在计算图 (graph) 中添加 ONNX 的 Asinh 操作。- 参数 g 是计算图对象,input 是输入张量。
g.op("Asinh", input)
表示在计算图中添加一个 Asinh 操作。
执行后输出如下图所示:
我们再来看下一个案例 sample_custom_op_autograd.py,代码如下所示:
import torch
import torch.onnx
import onnxruntime
OperatorExportTypes = torch._C._onnx.OperatorExportTypes
# 我们按照正常的方式,创建一个图。不考虑自己设计算子的话,我们其实是直接导出这个onnx的
# 只不过onnx会为这个实现根据自己内部注册的各个算子,追踪每一个节点生成图
# 我们可以观察一下onnx,会比较复杂,我们管这个叫做算子的inline autograd function
class CustomOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(x)
x = x.clamp(min=0)
return x / (1 + torch.exp(-x))
customOp = CustomOp.apply
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = customOp(x)
return x
def validate_onnx():
input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :\n", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-customOp.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: \n", x)
def export_norm_onnx():
input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
model = Model()
model.eval()
# 我们可以在导出onnx的时候添operator_export_type的限制,防止导出的onnx进行inline
file = "../models/sample-customOp.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
# operator_export_type = OperatorExportTypes.ONNX_FALLTHROUGH)
print("Finished normal onnx export")
if __name__ == "__main__":
export_norm_onnx()
# 自定义完onnx以后必须要进行一下验证
validate_onnx()
该代码定义了一个自定义的 PyTorch 算子 CustomOp
,并在一个简单模型中使用它。然后将模型导出为 ONNX 格式,并通过 onnxruntime 验证导出的 ONNX 模型是否与原始 PyTorch 模型输出一致。
执行后输出如下图所示:
导出的 ONNX 如下图所示:
可以看到如果不加以特殊处理,自定义算子可能会被内联到标准 ONNX 算子中,这个过程可能会导致模型结构的复杂性增加,会导致产生一些多余的算子,而这些算子都将被 trace,这是我们不希望看到的,我们希望它导出更加简洁一点
那我们该怎么做,我们来看 sample_custom_op_autograd_register.py 案例,代码如下所示:
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic
OperatorExportTypes = torch._C._onnx.OperatorExportTypes
class CustomOp(torch.autograd.Function):
@staticmethod
def symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:
return g.op("custom_domain::customOp2", x)
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(x)
x = x.clamp(min=0)
return x / (1 + torch.exp(-x))
customOp = CustomOp.apply
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = customOp(x)
return x
def validate_onnx():
input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :\n", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-customOp2.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: \n", x)
def export_norm_onnx():
input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
model = Model()
model.eval()
file = "../models/sample-customOp2.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished normal onnx export")
if __name__ == "__main__":
export_norm_onnx()
# 自定义完onnx以后必须要进行一下验证
validate_onnx()
它与之前的代码相比增加了对自定义算子的符号化注册,使得导出的 ONNX 模型更加简洁,主要区别有:
1. 增加 symbolic
方法:
- 在
CustomOp
类中添加了symbolic
方法,用于定义自定义算子在 ONNX 中的符号化表示。
2. 注册自定义算子符号:
- 通过
register_custom_op_symbolic
注册自定义算子,使其在导出 ONNX 时使用定义的符号化操作。
输出如下图所示:
导出的 ONNX 如下图所示:
可以看到导出的 ONNX 就一个我们的自定义算子非常简洁,但是我们在利用 onnxruntime 进行推理时发生了错误,这其实是因为 onnxruntime 在执行推理时无法识别我们的自定义算子 custom_domain::customOp2
,我们需要在 onnxruntime 中进行相关自定义算子的实现才能正确推理
5. unsupported deformable conv算子
我们再来看 deformable conv 这个案例,先看 sample_deformable_conv.py 代码如下所示:
import torch
import torch.nn as nn
import torchvision
import torch.onnx
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 18, 3)
self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)
def forward(self, x):
x = self.conv2(x, self.conv1(x))
return x
def infer():
input = torch.rand(1, 3, 5, 5)
model = Model()
x = model(input)
print("input is: ", input.data)
print("result is: ", x.data)
def export_norm_onnx():
input = torch.rand(1, 3, 5, 5)
model = Model()
model.eval()
file = "../models/sample-deformable-conv.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished normal onnx export")
if __name__ == "__main__":
infer()
# 这里导出deformable-conv会出现错误。
# torchvision支持deformable_conv的
# 但是我们在onnx中是没有找到有关deformable conv的支持
# 所以这个时候,我们需要做两件事情
export_norm_onnx()
执行完输出如下图所示:
可以看到导出失败了,deformable conv 其实在 torchvision 中有实现,但是我们在 onnx 中没有找到有关deformable conv 的支持
这个时候我们其实注册下算子就行了,来看案例 sample_deformable_conv_register.py,代码如下:
import torch
import torch.nn as nn
import torchvision
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
# 注意
# 这里需要把args的各个参数的类型都指定
# 这里还没有实现底层对deform_conv2d的实现
# 具体dcn的底层实现是在c++完成的,这里会在后面的TensorRT plugin中回到这里继续讲这个案例
# 这里先知道对于不支持的算子,onnx如何导出即可
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i","i", "i", "i", "none")
def dcn_symbolic(
g,
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps,
use_mask):
return g.op("custom::deform_conv2d", input, offset)
register_custom_op_symbolic("torchvision::deform_conv2d", dcn_symbolic, 12)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 18, 3)
self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)
def forward(self, x):
x = self.conv2(x, self.conv1(x))
return x
def validate_onnx():
input = torch.rand(1, 3, 5, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-deformable-conv.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: ", x)
def infer():
input = torch.rand(1, 3, 5, 5)
model = Model()
x = model(input)
print("input is: ", input.data)
print("result is: ", x.data)
def export_norm_onnx():
input = torch.rand(1, 3, 5, 5)
model = Model()
model.eval()
file = "../models/sample-deformable-conv.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished normal onnx export")
if __name__ == "__main__":
# infer()
export_norm_onnx()
validate_onnx()
这段代码通过定义和注册自定义算子符号,将包含 deformable convolution 算子的 PyTorch 模型导出为 ONNX 格式,值得注意的是 dcn_symbolic
中的 g.op("custom::deform_conv2d", input, offset)
仅定义了一个自定义操作,但并没有实际实现 deform_conv2d
的计算逻辑,这需要在后续的 onnxruntime/TensorRT 插件中实现。
另外 @parse_args()
是一个装饰器,用于指明 dcn_symbolic
函数各个参数的类型。具体解释如下:
- “v” (variable): 表示一个张量(Tensor)参数。
- “i” (integer): 表示一个整型参数。
- “none”: 表示一个可以为
None
的参数。
执行后输出如下图所示:
导出的 ONNX 如下图所示:
总结
本次课程我们主要学习了 ONNX 注册算子的方法,主要是利用 register_custom_op_symbolic 函数来注册自定义算子的符号映射函数,从而将自定义的算子导出到 ONNX,值得注意的是这个小节我们并没有具体的实现自定义算子的计算逻辑,这需要在后续的 TensorRT 插件中实现,这个我们在之后的 plugin 小节再来详细讲解
OK,以上就是第 6 小节有关 ONNX 注册算子方法的全部内容了,下节我们来学习 onnx-graph-surgeon,敬请期待😄
参考
- https://github.com/kalfazed/tensorrt_starter
- https://github.com/onnx/onnx/blob/main/docs/Operators.md