深度学习模型格式转换:pytorch2onnx(包含自定义操作符)

       将PyTorch模型转换为ONNX(Open Neural Network Exchange)格式是实现模型跨平台部署和优化推理性能的一种常见方法。PyTorch 提供了多种方式来完成这一转换,以下是几种主要的方法: 

一、静态模型转换

使用 torch.onnx.export()

   torch.onnx.export() 是 PyTorch 官方推荐的最常用方法,适用于大多数情况。它允许你将一个 PyTorch 模型及其输入数据一起导出为 ONNX 格式。

基本用法
import torch
import torch.onnx

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入,形状取决于模型的输入要求

# 设置模型为评估模式
model.eval()

# 导出为 ONNX 文件
torch.onnx.export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "model.onnx",             # 输出文件名
    export_params=True,       # 是否导出模型参数
    opset_version=11,         # ONNX 操作集版本
    do_constant_folding=True, # 是否执行常量折叠优化
    input_names=['input'],    # 输入节点名称
    output_names=['output'],  # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴,支持可变批次大小
                  'output': {0: 'batch_size'}}
)
关键参数说明
  • model: 要导出的 PyTorch 模型。
  • dummy_input: 一个与模型输入形状匹配的张量,用于模拟实际输入。
  • export_params: 是否导出模型的参数(权重和偏置)。通常设置为 True
  • opset_version: 指定要使用的 ONNX 操作集版本。不同的版本可能支持不同的操作符。建议使用较新的版本(如 11 或 13)。
  • do_constant_folding: 是否执行常量折叠优化,可以减少模型的计算量。
  • input_names 和 output_names: 指定 ONNX 模型的输入和输出节点的名称,方便后续加载和调用。
  • dynamic_axes: 指定哪些维度是动态的(即可以在推理时变化),例如批次大小或序列长度。

二、复杂模型转换

       对于一些复杂的模型,特别是包含控制流(如条件语句、循环等)的模型,torch.onnx.export() 可能无法直接处理。这时可以先使用 torch.jit.trace() 将模型转换为 TorchScript 格式,然后再导出为 ONNX。

基本用法

 

import torch
import torch.onnx

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入

# 设置模型为评估模式
model.eval()

# 使用 torch.jit.trace 将模型转换为 TorchScript
traced_model = torch.jit.trace(model, dummy_input)

# 导出为 ONNX 文件
torch.onnx.export(
    traced_model,            # 已经转换为 TorchScript 的模型
    dummy_input,             # 模型的输入张量
    "traced_model.onnx",     # 输出文件名
    export_params=True,      # 是否导出模型参数
    opset_version=11,        # ONNX 操作集版本
    do_constant_folding=True,# 是否执行常量折叠优化
    input_names=['input'],   # 输入节点名称
    output_names=['output'], # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴
                  'output': {0: 'batch_size'}}
)

三、动态模型转换

使用 torch.onnx.dynamo_export()

   torch.onnx.dynamo_export() 是 PyTorch 2.0 引入的新功能,基于 PyTorch 的 Dynamo 编译器。它旨在提供更好的性能和更广泛的模型支持,尤其是对于那些包含动态控制流的模型。

基本用法
import torch

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入

# 设置模型为评估模式
model.eval()

# 使用 dynamo_export 导出为 ONNX 文件
torch.onnx.dynamo_export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "dynamo_model.onnx"       # 输出文件名
)

        注意torch.onnx.dynamo_export() 是 PyTorch 2.0 中引入的功能,确保你使用的是最新版本的 PyTorch。

四、自定义操作符模型转换

       自定义操作符(Custom Operator)是指那些不在标准 PyTorch 或 ONNX 操作集中的操作符。当你需要实现某些特定的功能或优化时,可能需要编写自定义的操作符,并将其注册到 ONNX 中以便在导出和推理时使用。

例子:实现一个自定义的 ReLU6 操作符

假设我们想要实现一个自定义的 ReLU6 操作符。ReLU6 是一种常用的激活函数,它与标准的 ReLU 类似,但有一个上限值 6。其数学表达式为:

1. 实现自定义操作符

       首先,我们需要在 C++ 中实现这个自定义操作符,并编译成一个共享库。PyTorch 提供了 torch::jit::custom_ops 接口来注册自定义操作符,而 ONNX 则提供了 onnxruntime 来注册自定义操作符。

1.1 在 PyTorch 中实现自定义操作符

       我们可以在 C++ 中实现 ReLU6 操作符,并通过 PyTorch 的 torch::jit::custom_ops 接口将其注册到 PyTorch 中:

// custom_relu6.cpp
#include <torch/script.h>
#include <torch/custom_class.h>

// 定义自定义的 ReLU6 操作符
torch::Tensor custom_relu6(const torch::Tensor& input) {
    return torch::clamp(input, 0, 6);
}

// 注册自定义操作符
static auto registry = torch::RegisterOperators("custom_ops::relu6", &custom_relu6);
1.2 编译自定义操作符

       接下来,我们需要将这个 C++ 文件编译成一个共享库(例如 .so 文件),以便在 Python 中加载:

# 使用 PyTorch 提供的工具进行编译
python -m pip install torch torchvision torchaudio
python -m torch.utils.cpp_extension.build_ext --inplace custom_relu6.cpp

这会生成一个名为 custom_relu6.so 的共享库文件;

2. 在 PyTorch 中使用自定义操作符

       现在我们可以在 Python 中加载并使用这个自定义操作符;

import torch
import torch.nn as nn
import custom_relu6  # 加载编译后的共享库

# 定义一个使用自定义 ReLU6 操作符的模型
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        # 调用自定义的 ReLU6 操作符
        x = torch.ops.custom_ops.relu6(x)
        return x

# 创建模型实例
model = CustomModel()
model.eval()

# 准备示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 运行模型
output = model(dummy_input)
print(output.shape)  # 输出形状应为 (1, 16, 224, 224)
3. 将自定义操作符导出为 ONNX

       为了将包含自定义操作符的模型导出为 ONNX 格式,我们需要告诉 ONNX 如何处理这个自定义操作符。我们可以使用 torch.onnx.register_custom_op_symbolic 来定义 ONNX 符号函数,从而在导出时正确处理自定义操作符。

3.1 定义 ONNX 符号函数

       我们需要定义一个符号函数,告诉 ONNX 如何表示 custom_ops::relu6 操作符。这个符号函数会生成相应的 ONNX 操作符节点。

import torch.onnx
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# 定义 ONNX 符号函数
@parse_args('v')
def symbolic_custom_relu6(g, input):
    # 使用 ONNX 的 Clip 操作符来实现 ReLU6
    return g.op("Clip", input, min_f=0.0, max_f=6.0)

# 注册自定义操作符的符号函数
register_custom_op_symbolic('custom_ops::relu6', symbolic_custom_relu6, 9)  # 9 表示 ONNX 操作集版本
3.2 导出为 ONNX

       现在我们可以将模型导出为 ONNX 格式,并确保自定义操作符被正确处理。

# 导出为 ONNX 文件
torch.onnx.export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "custom_model.onnx",      # 输出文件名
    export_params=True,       # 是否导出模型参数
    opset_version=9,          # ONNX 操作集版本
    do_constant_folding=True, # 是否执行常量折叠优化
    input_names=['input'],    # 输入节点名称
    output_names=['output'],  # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴
                  'output': {0: 'batch_size'}}
)
4. 在 ONNX Runtime 中使用自定义操作符

       为了在 ONNX Runtime 中使用自定义操作符,我们需要将自定义操作符的实现编译成一个 ONNX Runtime 扩展库,并在推理时加载该扩展库。

4.1 实现 ONNX Runtime 自定义操作符

       我们需要在 C++ 中实现 ReLU6 操作符,并将其注册到 ONNX Runtime 中。

// custom_relu6_onnxruntime.cpp
#include "onnxruntime/core/providers/cpu/cpu_provider_factory.h"
#include "onnxruntime/core/framework/op_kernel.h"

namespace onnxruntime {

class CustomRelu6 : public OpKernel {
public:
  explicit CustomRelu6(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* context) const override {
    // 获取输入张量
    const Tensor* input_tensor = context->Input<Tensor>(0);
    if (!input_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is null");

    // 获取输出张量
    Tensor* output_tensor = context->Output(0, input_tensor->Shape());
    if (!output_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Output tensor is null");

    // 获取输入和输出的数据指针
    float* input_data = input_tensor->template Data<float>();
    float* output_data = output_tensor->template Data<float>();

    // 计算 ReLU6
    size_t size = input_tensor->Shape().Size();
    for (size_t i = 0; i < size; ++i) {
      output_data[i] = std::min(std::max(input_data[i], 0.0f), 6.0f);
    }

    return Status::OK();
  }
};

ONNX_OPERATOR_KERNEL(
    Relu6,  // 操作符名称
    kOnnxDomain,  // 命名空间
    9,  // 操作集版本
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),  // 数据类型约束
    CustomRelu6);  // 自定义操作符类
}
4.2 编译 ONNX Runtime 自定义操作符

       我们将上述代码编译成一个动态链接库(例如 .so 文件),以便在 ONNX Runtime 中加载。

# 使用 ONNX Runtime 提供的工具进行编译
g++ -shared -fPIC -o custom_relu6_onnxruntime.so custom_relu6_onnxruntime.cpp -lonnxruntime
4.3 在 ONNX Runtime 中加载自定义操作符

       最后,我们在 Python 中使用 onnxruntime 加载自定义操作符,并运行推理。

import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
ort_session = ort.InferenceSession("custom_model.onnx", providers=['CPUExecutionProvider'])

# 加载自定义操作符的扩展库
ort_session.load_custom_ops_library("custom_relu6_onnxruntime.so")

# 准备输入数据
ort_inputs = {'input': dummy_input.numpy()}  # 将 PyTorch 张量转换为 NumPy 数组

# 运行推理
ort_outs = ort_session.run(None, ort_inputs)

# 获取 PyTorch 模型的输出
with torch.no_grad():
    torch_out = model(dummy_input)

# 比较 ONNX 和 PyTorch 的输出
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)

print("ONNX 模型验证通过!")

总结

  • 自定义操作符:当你的模型中包含不在标准 PyTorch 或 ONNX 操作集中的操作符时,你可以通过编写自定义操作符来实现这些功能。
  • PyTorch 中的自定义操作符:可以使用 torch::jit::custom_ops 接口在 C++ 中实现自定义操作符,并通过共享库加载到 PyTorch 中。
  • ONNX 中的自定义操作符:可以通过 torch.onnx.register_custom_op_symbolic 定义符号函数,告诉 ONNX 如何处理自定义操作符。然后,在 ONNX Runtime 中,可以通过编译自定义操作符的实现并加载扩展库来支持推理。
  • 复杂性:实现自定义操作符通常比较复杂,因为它涉及到跨语言编程(C++ 和 Python)、编译和链接等多个步骤。然而,这对于实现特定功能或优化模型是非常有用的。

       通过这个例子,你可以看到如何从头实现一个自定义操作符,并将其集成到 PyTorch 和 ONNX 中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/947207.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GPU 进阶笔记(一):高性能 GPU 服务器硬件拓扑与集群组网

记录一些平时接触到的 GPU 知识。由于是笔记而非教程&#xff0c;因此内容不求连贯&#xff0c;有基础的同学可作查漏补缺之用 1 术语与基础 1.1 PCIe 交换芯片1.2 NVLink 定义演进&#xff1a;1/2/3/4 代监控1.3 NVSwitch1.4 NVLink Switch1.5 HBM (High Bandwidth Memory) 由…

在Unity中用Ab包加载资源(简单好抄)

第一步创建一个Editor文件夹 第二步编写BuildAb&#xff08;这个脚本一点要放在Editor中因为这是一个编辑器脚本&#xff0c;放在其他地方可能会报错&#xff09; using System.IO; using UnityEditor; using UnityEngine;public class BuildAb : MonoBehaviour {// 在Unity编…

【贪心算法】贪心算法七

贪心算法七 1.整数替换2.俄罗斯套娃信封问题3.可被三整除的最大和4.距离相等的条形码5.重构字符串 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f…

(五)人工智能进阶:基础概念解释

前面我们介绍了人工智能是如何成为一个强大函数。接下来&#xff0c;搞清损失函数、优化方法和正则化等核心概念&#xff0c;才能真正驾驭它&#xff01; 1. 什么是网络模型&#xff1f; 网络模型就像是一个精密的流水线工厂&#xff0c;由多个车间&#xff08;层&#xff0…

SpringMVC(二)原理

目录 一、配置Maven&#xff08;为了提升速度&#xff09; 二、流程&&原理 SpringMVC中心控制器 完整流程&#xff1a; 一、配置Maven&#xff08;为了提升速度&#xff09; 在SpringMVC&#xff08;一&#xff09;配置-CSDN博客的配置中&#xff0c;导入Maven会非…

2、redis的持久化

redis的持久化 在redist当中&#xff0c;高可用的技术包括持久化&#xff0c;主从复制&#xff0c;哨兵模式&#xff0c;集群。 持久化是最简单的高可用的方法&#xff0c;作用就是备份数据。即将数据保存到硬盘&#xff0c;防止进程退出导致数据丢失。 redis持久化方式&…

【算法】模拟退火算法学习记录

写这篇博客的原因是博主本人在看某篇文章的时候&#xff0c;发现自己只是知道SGD这个东西&#xff0c;但是到底是个啥不清楚&#xff0c;所以百度了一下&#xff0c;然后在通过博客学习的时候看到了退火两个字&#xff0c;想到了本科做数模比赛的时候涉猎过&#xff0c;就上bil…

Visual Point Cloud Forecasting enables Scalable Autonomous Driving——点云论文阅读(12)

此内容是论文总结,重点看思路!! 文章概述 这篇文章介绍了一个名为 ViDAR 的视觉点云预测框架,它通过预测历史视觉输入生成未来点云,作为自动驾驶的预训练任务。ViDAR 集成了语义、三维几何和时间动态信息,有效提升了感知、预测和规划等自动驾驶核心任务的性能。实验表明…

AI 将在今年获得“永久记忆”,2028美国会耗尽能源储备

AI的“永久记忆”时代即将来临 谷歌前CEO施密特揭示了AI技术的前景&#xff0c;他相信即将在2025年迎来一场伟大的变化。AI将实现“永久记忆”&#xff0c;改变我们与科技的互动过程。施密特将现有的AI上下文窗口比作人类的短期记忆&#xff0c;难以持久保存信息。他的设想是…

工控主板ESM7000/6800E支持远程桌面控制

英创公司ESM7000 是面向工业领域的双核 Cortex-A7 高性能嵌入式主板&#xff0c;ESM6800E则为单核Cortex-A7 高性价比嵌入式主板&#xff0c;ESM7000、ESM6800E都是公司的成熟产品&#xff0c;已广泛应用于工业很多领域。ESM7000/6800E板卡中Linux系统配置为linux-4.9.11内核、…

越权漏洞简介及靶场演示

越权漏洞简介及靶场演示 文章目录 一、什么是越权&#xff1f; &#xff08;一&#xff09;越权漏洞的概念&#xff08;二&#xff09;越权漏洞的分类&#xff08;三&#xff09;常见越权方法&#xff08;四&#xff09;未授权访问 二、越权漏洞测试过程 &#xff08;一&…

VIT:视觉transformer|学习微调记录

一、了解VIT结构 vit提出了对于图片完全采用transformer结构而不是CNN的方法&#xff0c;通过将图片分为patch&#xff0c;再将patch展开输入编码器&#xff08;grid_size网格大小&#xff09;&#xff0c;最后用MLP将输出转化为对应类预测。 详细信息可以看下面这个分享&…

coredns报错plugin/forward: no nameservers found

coredns报错plugin/forward: no nameservers found并且pod无法启动 出现该报错原因 是coredns获取不到宿主机配置的dns地址 查看宿主机是否有dns地址 resolvectl status 我这里是配置正确后&#xff0c;如果没配置过以下是不会显示出dns地址的 给宿主机增加静态dns地址之后将…

使用Diffusion Models进行图像超分辩重建

Diffusion Models专栏文章汇总:入门与实战 前言:图像超分辨率重建是一个经典CV任务,其实LR(低分辨率)和 HR(高分辨率)图像仅在高频细节上存在差异。通过添加适当的噪声,LR 图像将变得与其 HR 对应图像无法区分。这篇博客介绍一种方式巧妙利用这个规律使用Diffusion Mod…

NineData 荣获年度“创新解决方案奖”

近日&#xff0c;国内知名 IT 垂直媒体 & 技术社区 IT168 再次启动“技术卓越奖”评选&#xff0c;由行业 CIO/CTO 大咖、技术专家及 IT 媒体多方联合评审&#xff0c;NineData 凭借技术性能和产品创新等方面表现出色&#xff0c;在数据库工具领域荣获“2024 年度创新解决方…

liunx下载gitlab

1.地址&#xff1a; https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/ 安装 postfix 并启动 yum install postfix systemctl start postfix systemctl enable postfix ssh服务启动 systemctl enable sshd systemctl start sshd开放 ssh 以及 http 服务&#xff0c…

SQL—替换字符串—replace函数用法详解

SQL—替换字符串—replace函数用法详解 REPLACE() 函数——查找一个字符串中的指定子串&#xff0c;并将其替换为另一个子串。 REPLACE(str, old_substring, new_substring)str&#xff1a;要进行替换操作的原始字符串。old_substring&#xff1a;要被替换的子串。new_substri…

Android笔试面试题AI答之Android基础(11)

Android入门请看《Android应用开发项目式教程》&#xff0c;视频、源码、答疑&#xff0c;手把手教 文章目录 1.Android的权限有哪些&#xff1f;**1. 普通权限****常见普通权限** **2. 危险权限****权限分组****常见危险权限组及权限** **3. 特殊权限****常见特殊权限** **4. …

机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型

机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型 目录 机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型1 过拟合和欠拟合1.1 过拟合1.2 欠拟合 2 正则化惩罚2.1 概念2.2 函数2.3 正则化种类 3 K折交叉验证3.1 概念3.2 图片理解3.3 函数导入3.4 参数理解 4 训练模型K折交…

[AHK]用大模型写ahk脚本

问题背景 遇到程序在运行&#xff0c;但是在屏幕上看不到的窘境&#xff0c;于是想用AHK来实现一键在主屏幕上居中显示。 解决思路 手撸是不可能手撸的&#xff0c;我有豆包我有cursor&#xff0c;于是想看看她俩到底能力咋样。 提示词 用AHK v2实现&#xff1a;热键WinC …