pytorch导出rot90算子至onnx

如何导出rot90算子至onnx

    • 1 背景描述
    • 2 等价替换
      • 2.1 rot90替换(NCHW)
      • 2.2 rot180替换(NCHW)
      • 2.3 rot270替换(NCHW)
    • 3 rot导出ONNX

1 背景描述

在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1opset_version为17):

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(2, 3))
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90算子,onnx官方github链接:
https://github.com/onnx/onnx


2 等价替换

导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。

2.1 rot90替换(NCHW)

废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()来对比两个Tensor是否一致,结果一致,不信自己试试。

import torch


def self_rot90_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[3]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=1, dims=[2, 3])
        y1 = self_rot90_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.2 rot180替换(NCHW)

rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:

import torch


def self_rot180_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2, 3])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=2, dims=[2, 3])
        y1 = self_rot180_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.3 rot270替换(NCHW)

rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:

import torch


def self_rot270_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=3, dims=[2, 3])
        y1 = self_rot270_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()


3 rot导出ONNX

这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        # x = torch.rot90(x, k=1, dims=(2, 3))
        x = x.flip(dims=[3]).permute([0, 1, 3, 2])
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

使用netron打开生成的rot90_counterclockwise.onnx文件,如下所示:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/191437.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于51单片机直流电机PWM控制设计

直流电机驱动 🎶基于51单片机的PWM控制直流电机设计( proteus仿真程序报告讲解视频)🎶主要功能:🎶仿真🎶程序设计: 🎶设计报告🎶资料清单:资料网盘下载链接&a…

hivesql 将json格式字符串转为数组

hivesql 将json格式字符串转为数组 完整过程SQL在文末 json 格式字符串 本案例 json 字符串参考格式,请勿使用本数据 {"data": [{"province": 11,"id_card": "110182198903224674","name": "闾丘饱乾"…

gitee仓库使用教程

下载安装git;在本地项目文件夹右击鼠标点击Git Bash Here;输入git init,这个目录变成git可以管理的仓库,会出现一个.git文件夹,如果没出现的话需要选择“显示隐藏文件”(不会的同学自行百度一下) 4.绑定本地…

MySQL数据库如何实现跨服务器访问数据

点击上方蓝字关注我 在使用MySQL数据库时,很多同学经常会问,我能跨服务器访问另一库的数据么?得到的答案很多时候是让人失望的。那么如果真的需要访问,又不想使用拷贝表及数据的方式,可以实现么,又该如何实…

3.数据结构

3.1 数据结构分类 常见的数据结构包括数组、链表、栈、队列、哈希表、树、堆、图,它们可以从“逻辑结构”和“物理结构”两个维度进行分类。 3.1.1逻辑结构:线性与非线性 逻辑结构揭示了数据元素之间的逻辑关系。在数组和链表中,数据按照…

056-第三代软件开发-软件打包

第三代软件开发-软件打包 文章目录 第三代软件开发-软件打包项目介绍软件打包1 下载 linuxdepoyqt 工具2 安装 linuxdepoyqt3 qmake配置4 打包程序 总结 关键字: Qt、 Qml、 linuxdeployqt、 Ubuntu、 AppImage 项目介绍 欢迎来到我们的 QML & C 项目&…

常见树种(贵州省):021冬青、连香树、白辛树、香合欢、云贵鹅耳枥、肥牛树、杜英、格木、黄连木、圆果化香树、南天竹

摘要:本专栏树种介绍图片来源于PPBC中国植物图像库(下附网址),本文整理仅做交流学习使用,同时便于查找,如有侵权请联系删除。 图片网址:PPBC中国植物图像库——最大的植物分类图片库 一、冬青 …

图解系列--Http

1.URI和URL 1.1.URL URL是统一资源定位符。URL正是使用 Web 浏览器等访问 Web 页面时需要输入的网页地址。比如,http://hackr.jp/就是 URL。 1.2.URI 统一资源标识符。 URI 用字符串标识某一互联网资源,而URL表示资源的地点(互联网上所处的位…

Java 基础学习(一)Java环境搭建和基本数据类型

1 Java 开发环境搭建 1.1 Java 编程语言 1.1.1 什么是Java编程语言 语言是人类进行沟通交流的各种表达符号,方便人与人之间进行沟通与信息交换;而计算机编程语言则是人与计算机之间进行信息交流沟通的一种特殊语言,也有语法规则、字符、符…

【c++随笔14】虚函数表

【c随笔14】虚函数表 一、虚函数表(Virtual Function Table)1、定义2、查看虚函数表2.1、 问题:三种类型,包含一个int类型的class、一个int类型的变量、int类型的指针:这三个大小分别是多少呢?2.2、怎么发现…

Linux7安装mysql数据库以及navicat远程连接mysql

1.下载地址:MySQL :: Download MySQL Community Server 2.创建mysql目录将压缩包上传到该目录 mkdir /opt/mysql cd /opt/mysql3.解压压缩包 gzip mysql-8.1.0-1.el7.x86_64.rpm-bundle.tar tar -zxvf mysql-8.1.0-1.el7.x86_64.rpm-bundle.tar.gz 4.前置检查 ch…

STM32 F1 串口空闲中断 + DMA实现数据发送

DMA实现数据发送 文章目录 DMA实现数据发送前言一、DMA二、代码编写1.DMA2.USART3.main 前言 当你遇到通信数据量大的时候,可以使用 空闲中断 DMA 的方案来减轻 CPU 的压力。 或者 在进行stm32开发时,有时会遇到这种情况:需要在设备间进行数…

接口测试快速入门 以飞致云平台为例

飞致云电商API地址系统来自飞致云项目。接口API地址:https://gz.fit2cloud.com/swagger-ui.html 飞致云电商系统接口文档 V1.0:见 有道云笔记 该网站可以做接口测试练习。快速了解如何测试接口,如何做关联 系统基地址:https://g…

通用电气调查网络攻击和数据盗窃指控

通用电气正在调查有关威胁行为者在网络攻击中破坏了公司开发环境并泄露据称被盗数据的指控。 通用电气 (GE) 是一家美国跨国公司,业务涉及电力、可再生能源和航空航天行业。 本月早些时候,一个名为 IntelBroker 的威胁行为者试图在黑客论坛上以 500 美…

Linuxfork,写时拷贝

1.prinf隐藏的缓冲区 1.思考:为什么会有缓冲区的存在? 2.演示及思考? 1).演示缓存区没有存在感 那为什么我们感觉不到缓冲区的存在呢?我们要打印东西直接就打印了呢? 我们用代码演示一下: 比如打开一个main.c,输入内容如下: #include <stdio.h> int main() { …

二十七、微服务案例

目录 一、实现输入搜索功能 1、下载代码&#xff0c;在idea上打开 2、新建RequestParams类&#xff0c;用于接收解析请求 3、在启动类中加入客户端地址Bean&#xff0c;以便实现服务 4、编写搜索方法 5、新建返回分页结果类 6、实现搜索方法 7、编写控制类&#xff0c;…

docker打包前端镜像

文章目录 一、构建镜像二、查看本地镜像三、启动容器四、查看启动的容器五、保存镜像六、读取镜像七、创建镜像八、最后 docker官网 一、构建镜像 -t是给镜像命名&#xff0c;.(点)是基于当前目录的Dockerfile来构建镜像 docker build -t image_web .二、查看本地镜像 docke…

【数据结构】什么是队列?

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 &#x1f4cc;队列的定义 &#x1f4cc;队列的抽象数据类型 &#x1f4cc;队列的顺序存储结构 &#x1f4cc;队列的链式存储结构 结语 人生,是一个又一个小小的队列…

极客时间:使用本地小型语言模型运行网页浏览器应用程序。

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

【Apache Doris】Manager极致丝滑地运维管理

【Apache Doris】Manager极致丝滑地运维管理 1.标准VS可视化运维管理2. 环境信息2.1.硬件信息2.2.软件信息 3.前置准备3.1.安装包准备3.2.文档手册准备 4.集群初始化4.1.系统参数预设4.2.Manager部署4.3.新集群部署4.4 监控告警4.4.1 监控4.4.2 告警 5. 集群升级5.1 新包准备5.…