三. TensorRT基础入门-ONNX注册算子的方法

目录

    • 前言
    • 0. 简述
    • 1. 执行一下我们的python程序
    • 2.转换swin-tiny时候出现的不兼容op的例子
    • 3. 当出现导出onnx不成功的时候,我们需要考虑的事情
    • 4. unsupported asinh算子
    • 5. unsupported deformable conv算子
    • 总结
    • 参考

前言

自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》,链接。记录下个人学习笔记,仅供自己参考

本次课程我们来学习课程第三章—TensorRT 基础入门,一起来学习 ONNX 注册算子的方法

课程大纲可以看下面的思维导图

在这里插入图片描述

0. 简述

本小节目标:学习 pytorch 导出 onnx 不成功的时候如何解决(without plugin篇)

这节我们学习第三章节第六小节—ONNX 注册算子的方法,学习当 pytorch 导出 onnx 失败时的解决方法,比较常见的就是我们在使用开源代码将它导出 ONNX 时发现算子不兼容,还有导出成功但是转成 TensorRT 的 engine 时又出现算子不兼容,这里就简单过一下遇到这些情况我们应该怎么做

因为现阶段我们还没有讲如何利用 C++ 去写一个 plugin,所以我们先不考虑这个算子的 C++ 是如何实现的,我们现在主要是 pytorch 出现一些不兼容的算子导出 onnx 的时候该怎么做

1. 执行一下我们的python程序

源代码获取地址:https://github.com/kalfazed/tensorrt_starter

这个小节的案例主要是 3.4-export-unsupported-node,如下所示:

在这里插入图片描述

2.转换swin-tiny时候出现的不兼容op的例子

先给大家做一个简单的背景介绍,按照 swin-transformer 官方文档导出 onnx 时可能会出现如下问题:

在这里插入图片描述

roll 这个算子在 opset9 不支持,那我们再看看 opset12:

在这里插入图片描述

可以发现在 opset12 时依旧出现同样的问题

在这里插入图片描述

swin/models/swin_transformer.py

我们到对应的代码中可以找到 torch.roll 它在 onnx opset 中是不兼容的,那我们应该怎么办呢?我们在下一个小节会跟大家去讲

3. 当出现导出onnx不成功的时候,我们需要考虑的事情

当出现导出 onnx 不成功的时候,我们主要有以下几个方法,难易度从低到高:

  • 修改 opset 的版本
    • 查看不支持的算子在新的 opset 中是否被支持
    • 如果不考虑自己搭建 plugin 的话,也需要看看 onnx-trt 中这个算子是否被支持
    • 因为 onnx 是一种图结构表示,并不包含各个算子的实现。除非我们是要在 onnxruntime 上测试,否则我们更看重 onnx-trt 中这个算子的支持情况
    • 算子支持文档:https://github.com/onnx/onnx/blob/main/docs/Operators.md
  • 替换 pytorch 中的算子组合
    • 把某些计算替换成 onnx 可以识别的
  • 在 pytorch 登记 onnx 中某些算子
    • 有可能 onnx 中有支持,但没有被登记,比如 Asinh
  • 直接修改 onnx 创建 plugin
    • 使用 onnxsurgeon
    • 一般是用在加速某些算子上使用

这个小节主要是给大家介绍第三种方法,注册登记算子

4. unsupported asinh算子

我们先执行下 sample_asinh.py 案例代码,输出如下所示:

在这里插入图片描述

src/sample_asinh.py

可以看到导出 onnx 出错了,原因是 aten::asinh 在 ONNX opset9 是不支持的,代码如下:

import torch
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = torch.asinh(x)
        return x

def infer():
    input = torch.rand(1, 5)
    
    model = Model()
    x = model(input)
    print("input is: ", input.data)
    print("result is: ", x.data)

def export_norm_onnx():
    input   = torch.rand(1, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-asinh.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 9)
    print("Finished normal onnx export")

if __name__ == "__main__":
    infer()

    # 这里导出asinh会出现错误。
    # Pytorch可以支持asinh的同时,
    #   def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...

    # 从onnx支持的算子里面我们可以知道自从opset9开始asinh就已经被支持了
    #   asinh is suppored since opset9

    # 所以我们可以知道,问题是出现在PyTorch与onnx之间没有建立asinh的映射
    # 我们需要建立这个映射。这里涉及到了注册符号函数的概念,详细看PPT
    export_norm_onnx()

出现导出问题我们先去寻找官方文档 https://github.com/onnx/onnx/blob/main/docs/Operators.md,如下图所示:

在这里插入图片描述

这里我们可以看到 asinh 算子在 opset9 这个版本已经开始支持了,但是导出就是不行,为什么呢?其实问题是出现在 pytorch 与 onnx 之间没有建立 asinh 的映射,我们需要自己来建立这个映射

我们可以在 torch/onnx/symbolic_opset9.py 找到 pytorch 注册到 onnx 中支持的算子,如下图所示:

在这里插入图片描述

Note:低版本 torch 下面的 symbolic_opset9.py 文件中可能并没有上图中的内容,博主 torch 版本是 2.0.1

针对不同的算子我们可以看到如下代码:

# torch/onnx/symbolic_opset9.py/L317

@_onnx_symbolic("aten::_shape_as_tensor")
@_beartype.beartype
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    return g.op("Shape", input)


@_onnx_symbolic("aten::_reshape_from_tensor")
@_beartype.beartype
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    if isinstance(shape, list):
        shape = g.op("Concat", *shape, axis_i=0)
    return reshape(g, input, shape)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    return symbolic_helper._reshape_helper(g, self, shape)

...

这些代码是 PyTorch 中用于将特定的 aten (PyTorch 的内部操作库) 操作符转换为 ONNX 操作符的符号化函数。这些符号化函数定义了如何将 PyTorch 的操作符映射到等效的 ONNX 操作符,其实它们就是做了算子注册这件事

我们以 reshape 为例简单分析下 pytorch 到 onnx 的算子注册是怎么做的:(from ChatGPT)

1. @_onnx_symbolic(“aten::reshape”)

  • 这是一个装饰器,用于将 PyTorch 的 aten::reshape 操作符映射到 ONNX 的等效操作符。
  • @_onnx_symbolic 是一个装饰器函数,用于注册和标记 PyTorch 操作符与其 ONNX 对应操作符之间的关系。
  • "aten::reshape" 指定了 PyTorch 中的 reshape 操作符。

2. @symbolic_helper.quantized_args(True)

  • 指定函数是否处理量化参数。
  • @symbolic_helper.quantized_args 是一个装饰器,用于指示该函数是否支持量化参数。
  • True 表示该函数支持量化参数。如果模型中有量化的操作,该装饰器会确保这些操作被正确处理和导出。

3. @_beartype.beartype

  • 类型检查装饰器,确保函数的输入参数和返回类型符合预期。
  • @_beartype.beartypebeartype 库提供的装饰器,用于在运行时进行类型检查,确保参数类型和返回值类型正确。

4. def reshape(g: jit_utils.GraphContext, self, shape):

  • 定义 reshape 函数。
  • g: 表示 jit_utils.GraphContext 类型的对象,代表当前的图上下文,用于在图中添加节点。
  • self: 输入的张量,代表要重塑的张量。
  • shape: 表示新的形状,可以是一个表示形状的张量或一个列表。

5. return symbolic_helper._reshape_helper(g, self, shape)

  • 调用 _reshape_helper 函数执行实际的重塑操作,并返回结果。
  • symbolic_helper._reshape_helper(g, self, shape) 是一个帮助函数,用于执行重塑逻辑。
  • g: 传递当前的图上下文。
  • self: 传递要重塑的输入张量。
  • shape: 传递目标形状。

其中的 aten::xxx 是 C++ 的一个 namespace,pytorch 的很多算子的底层都是在 aten 这个命名空间下进行实现的,而 aten(a Tensor Library)是一个实现张量运算的 C++ 库。onnx_symblic 负责绑定,使得 pytorch 中的算子与 aten 命名空间下的算子一一对应

补充g.op 函数是 PyTorch 中用于在 ONNX 计算图中添加操作节点的函数。它是 jit_utils.GraphContext 类的一个方法,用于定义和插入新的 ONNX 操作节点。其参数定义如下:

  • op_type: 操作的类型,例如 "Reshape", "Concat", "Shape" 等。这些操作类型是 ONNX 操作符集中的名称。
  • inputs: 传递给操作节点的输入张量。可以是一个或多个张量。
  • attributes: 可选参数,指定操作的属性,通常以 key=value 的形式传递。

通过这么一套操作我们就可以把 ONNX 中的某个算子和 Pytorch 底层 aten 命名空间下的算子实现给绑定起来,这就是一个算子的注册

因此之前 asinh 导出问题就是因为 pytorch 中的底层算子实现和 onnx 中的算子没有绑定,我们手动绑定下即可。所以我们来看 sample_asinh_register.py 案例代码,如下所示:

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic

# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
#   g: 就是graph,计算图
#   也就是说,在计算图中添加onnx算子
#   由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
#   symblic的参数需要与Pytorch的asinh接口函数的参数对齐
#       def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的

# 那么aten是什么呢?
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)


# 这里容易混淆的地方:
# 1. register_op中的第一个参数是PyTorch中的算子名字: aten::asinh
# 2. g.op中的第一个参数是onnx中的算子名字: Asinh

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = torch.asinh(x)
        return x


def validate_onnx():
    input = torch.rand(1, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-asinh.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    ", x)

def export_norm_onnx():
    input   = torch.rand(1, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-asinh.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

    # 自定义完onnx以后必须要进行一下验证
    validate_onnx()

这段代码展示了如何在 PyTorch 中创建和注册一个自定义的符号函数 (symbolic function),以便将 PyTorch 的 asinh 操作导出到 ONNX 格式,并在 ONNX Runtime 中进行推理验证

函数 register_custom_op_symbolic 是 PyTorch ONNX 导出工具中用来注册自定义算子的符号映射函数。此函数对于在 ONNX 中支持 PyTorch 中的自定义或特殊算子至关重要。下面我会逐一解释函数参数和作用:(from CHatGPT)

参数解释

  • symbolic_name (str):

    • 这是需要注册的自定义算子的名称。格式通常为 <domain>::<op>,其中 <domain> 表示算子所属的命名空间,而 <op> 是算子的名称。例如,aten::asinh 表示来自 aten 域的 asinh 算子。
  • symbolic_fn (Callable):

    • 这是一个函数,用于定义如何将 PyTorch 中的操作转换为 ONNX 图中的节点。这个函数通常会接收几个参数:ONNX 图对象、当前算子的输入参数等,并且基于这些输入构造并返回 ONNX 中相应的算子节点。

    • symbolic_fn 应返回一个或多个用于替换原有 PyTorch 算子的 ONNX 操作节点。这个函数的实现需要考虑输入输出的匹配、算子属性的转换等。

  • opset_version (int):

    • 这是指定算子应当注册到的 ONNX 操作集版本。ONNX 的操作集(opset)定义了算子的支持和行为,不同版本的 opset 可能支持不同的算子或者同一算子的不同行为。这个参数确保你的自定义算子兼容特定版本的 ONNX。

执行后输出如下图所示:

在这里插入图片描述

导出的 ONNX 如下图所示:

在这里插入图片描述

我们再来看另外一种写法,案例 sample_asinh_register2.py 代码如下所示:

import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)

# 另外一个写法
#    这个是类似于torch/onnx/symbolic_opset*.py中的写法
#    通过torch._internal中的registration来注册这个算子,让这个算子可以与底层C++实现的aten::asinh绑定
#    一般如果这么写的话,其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = torch.asinh(x)
        return x


def validate_onnx():
    input = torch.rand(1, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-asinh2.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    ", x)

def export_norm_onnx():
    input   = torch.rand(1, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-asinh2.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

    # 自定义完onnx以后必须要进行一下验证
    validate_onnx()

这段代码展示了另一种注册自定义符号函数的方法,使用了 torch.onnx._internal 中的 registration 模块,通过装饰器的方式注册 asinh 操作符

与之前方式不同的是它通过两个步骤来注册自定义符号函数:(from ChatGPT)

1. 创建 _onnx_symbolic 的局部函数

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
  • 使用 functools.partial 创建一个偏函数 _onnx_symbolic,指定 opset 为 9。

  • 该偏函数用于注册符号函数,简化了装饰器的使用。

2. 定义 asinh_symbolic 符号函数并注册

@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)
  • 使用装饰器 @_onnx_symbolic('aten::asinh') 注册 asinh_symbolic 符号函数。
  • asinh_symbolic 函数用于在计算图 (graph) 中添加 ONNX 的 Asinh 操作。
  • 参数 g 是计算图对象,input 是输入张量。
  • g.op("Asinh", input) 表示在计算图中添加一个 Asinh 操作。

执行后输出如下图所示:

在这里插入图片描述

我们再来看下一个案例 sample_custom_op_autograd.py,代码如下所示:

import torch
import torch.onnx
import onnxruntime

OperatorExportTypes = torch._C._onnx.OperatorExportTypes
# 我们按照正常的方式,创建一个图。不考虑自己设计算子的话,我们其实是直接导出这个onnx的
# 只不过onnx会为这个实现根据自己内部注册的各个算子,追踪每一个节点生成图
# 我们可以观察一下onnx,会比较复杂,我们管这个叫做算子的inline autograd function

class CustomOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        x = x.clamp(min=0)
        return x / (1 + torch.exp(-x))

customOp = CustomOp.apply

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = customOp(x)
        return x


def validate_onnx():
    input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :\n", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-customOp.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    \n", x)

def export_norm_onnx():
    input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
    model   = Model()
    model.eval()

    # 我们可以在导出onnx的时候添operator_export_type的限制,防止导出的onnx进行inline
    file    = "../models/sample-customOp.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
        # operator_export_type = OperatorExportTypes.ONNX_FALLTHROUGH)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

    # 自定义完onnx以后必须要进行一下验证
    validate_onnx()

该代码定义了一个自定义的 PyTorch 算子 CustomOp,并在一个简单模型中使用它。然后将模型导出为 ONNX 格式,并通过 onnxruntime 验证导出的 ONNX 模型是否与原始 PyTorch 模型输出一致。

执行后输出如下图所示:

在这里插入图片描述

导出的 ONNX 如下图所示:

在这里插入图片描述

可以看到如果不加以特殊处理,自定义算子可能会被内联到标准 ONNX 算子中,这个过程可能会导致模型结构的复杂性增加,会导致产生一些多余的算子,而这些算子都将被 trace,这是我们不希望看到的,我们希望它导出更加简洁一点

那我们该怎么做,我们来看 sample_custom_op_autograd_register.py 案例,代码如下所示:

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic

OperatorExportTypes = torch._C._onnx.OperatorExportTypes

class CustomOp(torch.autograd.Function):
    @staticmethod 
    def symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:
        return g.op("custom_domain::customOp2", x)

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        x = x.clamp(min=0)
        return x / (1 + torch.exp(-x))

customOp = CustomOp.apply

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        x = customOp(x)
        return x


def validate_onnx():
    input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :\n", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-customOp2.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    \n", x)

def export_norm_onnx():
    input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-customOp2.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

    # 自定义完onnx以后必须要进行一下验证
    validate_onnx()

它与之前的代码相比增加了对自定义算子的符号化注册,使得导出的 ONNX 模型更加简洁,主要区别有:

1. 增加 symbolic 方法:

  • CustomOp 类中添加了 symbolic 方法,用于定义自定义算子在 ONNX 中的符号化表示。

2. 注册自定义算子符号:

  • 通过 register_custom_op_symbolic 注册自定义算子,使其在导出 ONNX 时使用定义的符号化操作。

输出如下图所示:

在这里插入图片描述

导出的 ONNX 如下图所示:

在这里插入图片描述

可以看到导出的 ONNX 就一个我们的自定义算子非常简洁,但是我们在利用 onnxruntime 进行推理时发生了错误,这其实是因为 onnxruntime 在执行推理时无法识别我们的自定义算子 custom_domain::customOp2,我们需要在 onnxruntime 中进行相关自定义算子的实现才能正确推理

5. unsupported deformable conv算子

我们再来看 deformable conv 这个案例,先看 sample_deformable_conv.py 代码如下所示:

import torch
import torch.nn as nn
import torchvision
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 18, 3)
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)
    
    def forward(self, x):
        x = self.conv2(x, self.conv1(x))
        return x

def infer():
    input = torch.rand(1, 3, 5, 5)
    
    model = Model()
    x = model(input)
    print("input is: ", input.data)
    print("result is: ", x.data)

def export_norm_onnx():
    input   = torch.rand(1, 3, 5, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-deformable-conv.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    infer()

    # 这里导出deformable-conv会出现错误。
    # torchvision支持deformable_conv的
    # 但是我们在onnx中是没有找到有关deformable conv的支持
    # 所以这个时候,我们需要做两件事情
    export_norm_onnx()

执行完输出如下图所示:

在这里插入图片描述

可以看到导出失败了,deformable conv 其实在 torchvision 中有实现,但是我们在 onnx 中没有找到有关deformable conv 的支持

这个时候我们其实注册下算子就行了,来看案例 sample_deformable_conv_register.py,代码如下:

import torch
import torch.nn as nn
import torchvision
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# 注意
#   这里需要把args的各个参数的类型都指定
#   这里还没有实现底层对deform_conv2d的实现
#   具体dcn的底层实现是在c++完成的,这里会在后面的TensorRT plugin中回到这里继续讲这个案例
#   这里先知道对于不支持的算子,onnx如何导出即可
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i","i", "i", "i", "none")
def dcn_symbolic(
        g,
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h, stride_w,
        pad_h, pad_w,
        dil_h, dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask):
    return g.op("custom::deform_conv2d", input, offset)

register_custom_op_symbolic("torchvision::deform_conv2d", dcn_symbolic, 12)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 18, 3)
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)
    
    def forward(self, x):
        x = self.conv2(x, self.conv1(x))
        return x

def validate_onnx():
    input = torch.rand(1, 3, 5, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-deformable-conv.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    ", x)

def infer():
    input = torch.rand(1, 3, 5, 5)
    
    model = Model()
    x = model(input)
    print("input is: ", input.data)
    print("result is: ", x.data)

def export_norm_onnx():
    input   = torch.rand(1, 3, 5, 5)
    model   = Model()
    model.eval()

    file    = "../models/sample-deformable-conv.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished normal onnx export")

if __name__ == "__main__":
    # infer()
    export_norm_onnx()
    validate_onnx()

这段代码通过定义和注册自定义算子符号,将包含 deformable convolution 算子的 PyTorch 模型导出为 ONNX 格式,值得注意的是 dcn_symbolic 中的 g.op("custom::deform_conv2d", input, offset) 仅定义了一个自定义操作,但并没有实际实现 deform_conv2d 的计算逻辑,这需要在后续的 onnxruntime/TensorRT 插件中实现。

另外 @parse_args() 是一个装饰器,用于指明 dcn_symbolic 函数各个参数的类型。具体解释如下:

  • “v” (variable): 表示一个张量(Tensor)参数。
  • “i” (integer): 表示一个整型参数。
  • “none”: 表示一个可以为 None 的参数。

执行后输出如下图所示:

在这里插入图片描述

导出的 ONNX 如下图所示:

在这里插入图片描述

总结

本次课程我们主要学习了 ONNX 注册算子的方法,主要是利用 register_custom_op_symbolic 函数来注册自定义算子的符号映射函数,从而将自定义的算子导出到 ONNX,值得注意的是这个小节我们并没有具体的实现自定义算子的计算逻辑,这需要在后续的 TensorRT 插件中实现,这个我们在之后的 plugin 小节再来详细讲解

OK,以上就是第 6 小节有关 ONNX 注册算子方法的全部内容了,下节我们来学习 onnx-graph-surgeon,敬请期待😄

参考

  • https://github.com/kalfazed/tensorrt_starter
  • https://github.com/onnx/onnx/blob/main/docs/Operators.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/618445.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

新能源行业网间数据交换,更好用更专业的工具是什么?

新能源行业涵盖了多个方面&#xff0c;包括但不限于新能源汽车、可再生能源技术等。新能源行业发展具有重要的意义&#xff0c;新能源企业的研发数据极其重要&#xff0c;为了保障网络安全和数据安全&#xff0c;许多新能源企业采用逻辑隔离的方式进行网络隔离&#xff0c;此时…

C#【进阶】泛型

1、泛型 文章目录 1、泛型1、泛型是什么2、泛型分类3、泛型类和接口4、泛型方法5、泛型的作用思考 泛型方法判断类型 2、泛型约束1、什么是泛型2、各泛型约束3、约束的组合使用4、多个泛型有约束思考1 泛型实现单例模式思考2 ArrayList泛型实现增删查改 1、泛型是什么 泛型实现…

08 - hive的集合函数、高级聚合函数、炸裂函数以及窗口函数

目录 1、集合函数 1.1、size&#xff1a;集合中元素的个数 1.2、map&#xff1a;创建map集合 1.3、map_keys&#xff1a; 返回map中的key 1.4、map_values: 返回map中的value 1.5、array 声明array集合 1.6、array_contains: 判断array中是否包含某个元素 1.7、sort_a…

SBM模型、超效率SBM模型代码及案例数据(补充操作视频)

01、数据简介 SBM&#xff08;Slack-Based Measure&#xff09;模型是一种数据包络分析&#xff08;Data Envelopment Analysis, DEA&#xff09;的方法&#xff0c;用于评估决策单元&#xff08;Decision Making Units, DMUs&#xff09;的效率。而超效率SBM模型是对SBM模型的…

告别数据泥潭:PySpark性能调优的黄金法则

阿佑今天给大家带来个一张藏宝图——使用PySpark进行性能调优的黄金法则&#xff0c;从内存管理到执行计划&#xff0c;再到并行度设置&#xff0c;每一步都是提升数据处理速度的关键&#xff01; 文章目录 Python Spark 详解1. 引言2. 背景介绍2.1 大数据处理技术演变2.2 Apac…

Flutter-加载中动画

效果 考察内容 AnimationControllerTweenAnimatedBuilderTransformMatrix4 实现 ///源码&#xff1a;https://github.com/yixiaolunhui/flutter_xy class LoadingView extends StatefulWidget {const LoadingView({Key? key}) : super(key: key);overrideState<LoadingV…

AI算法-高数5-线性代数1-基本概念、向量

线性代数&#xff1a;主要研究1、张量>CV计算机视觉 2、研究张量的线性关系。 深度学习的表现之所以能够超过传统的机器学习算法离不开神经网络&#xff0c;然而神经网络最基本的数据结构就是向量和矩阵&#xff0c;神经网络的输入是向量&#xff0c;然后通过每个矩阵对向量…

Vue3项目打包部署到云服务器的Nginx中

文章目录 一、打包vue3项目二、dist文件夹上传到服务器三、改nginx配置文件Docker安装nginx 一、打包vue3项目 npm run build 二、dist文件夹上传到服务器 将dist文件夹放到docker安装的nginx中的html目录下 三、改nginx配置文件 然后重启nginx【改了配置文件重启nginx才能…

Cloudflare国内IP地址使用教程

Cloudflare国内IP地址使用教程 加速网站&#xff1a; 首先我们添加一个 A 记录解析&#xff0c;解析 IP 就是我们服务器真实 IP&#xff1a; 然后侧边栏 SSL/TLS - 自定义主机名&#xff1a; 回退源这里填写你刚刚解析的域名&#xff0c;保存后回退源状态为有效再来接下的操作…

C++ 指针 参数 静态 常 友元与组合概念

一 类类型作为函数参数 1 类类型作参数类型的三种方式 1&#xff09; 对象本身作为参数 由于C采用传值的方式传递参数&#xff0c;因此使用对象本身参数时&#xff0c;形参是实参的一个拷贝。在这种情况下&#xff0c;最好显式地为类定义一个拷贝构造函数&#xff0c;以免出…

二维费用背包分组背包

二维费用背包&分组背包 一定要做的

[Spring Cloud] (7)gateway防重放拦截器

文章目录 简述本文涉及代码已开源Fir Cloud 完整项目防重放防重放必要性&#xff1a;防重放机制作用&#xff1a; 整体效果后端进行处理 后端增加防重放开关配置签名密钥 工具类防重放拦截器 前端被防重放拦截增加防重放开关配置请求头增加防重放签名处理防重放验证处理函数bas…

HC-06 蓝牙串口从机 AT 命令详解

HC-06 蓝牙串口从机 AT 命令详解 要使用 AT 命令&#xff0c;首先要知道 HC-06 的波特率&#xff0c;然后要进入 AT 命令模式。 使用串口一定要知道三要素&#xff0c;一是波特率&#xff0c;二是串口号&#xff0c;三是数据格式, HC-06只支持一种数据格式: 数据位8 位&#…

MYSQL数据库-SQL语句

数据库相关概念 名称全称简称数据库存储数据的仓库&#xff0c;数据是有组织的进行存储DataBase(DB)数据库管理系统操纵和管理数据库的大型软件DataBase Management System(DBMS)SQL操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库统一标准Structured Quer…

第十四篇:数据库设计精粹:规范化与性能优化的艺术

数据库设计精粹&#xff1a;规范化与性能优化的艺术 1. 引言 1.1 数据库设计在现代应用中的核心地位 在数字化的浪潮中&#xff0c;数据库设计如同建筑师手中的蓝图&#xff0c;是构建信息大厦的基石。它不仅关乎数据的存储与检索&#xff0c;更是现代应用流畅运行的生命线。…

打印图形(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h>int main() {//初始化变量值&#xff1b;int i, j;//循环打印&#xff1b;for (i 0; i < 5; i){//列&#xff1b;for (j 0; j &…

Python深度学习基于Tensorflow(9)注意力机制

文章目录 注意力机制是怎么工作的注意力机制的类型 构建Transformer模型Embedding层注意力机制的实现Encoder实现Decoder实现Transformer实现 注意力机制的主要思想是将注意力集中在信息的重要部分&#xff0c;对重要部分投入更多的资源&#xff0c;以获取更多所关注目标的细节…

关于Speech processing Universal PERformance Benchmark (SUPERB)基准测试及衍生版本

Speech processing Universal PERformance Benchmark &#xff08;SUPERB&#xff09;是由台湾大学、麻省理工大学&#xff0c;卡耐基梅隆大学和 Meta 公司联合提出的评测数据集&#xff0c;其中包含了13项语音理解任务&#xff0c;旨在全面评估模型在语音处理领域的表现。这些…

贝叶斯分类器详解

1 概率论知识 1.1 先验概率 先验概率是基于背景常识或者历史数据的统计得出的预判概率&#xff0c;一般只包含一个变量&#xff0c;例如P(A)&#xff0c;P(B)。 1.2 联合概率 联合概率指的是事件同时发生的概率&#xff0c;例如现在A,B两个事件同时发生的概率&#xff0c;记…

Hotcoin Research | 市场洞察:2024年5月6日-5月12日

加密货幣市场表现 加密货幣总市值为1.24万亿&#xff0c;BTC占比53.35%。 本周行情呈现先涨后跌的一种態势&#xff0c;5月6日-9日大盘持续下跌&#xff0c;周末为震荡行情。本周的比特幣现货ETF凈流入&#xff1a;1.1262亿美元&#xff0c;其中&#xff1a;美国ETF流入&…