记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
    // BEGIN image_mat
    cv::Mat image_mat(/*rows=*/image.size(0),
        /*cols=*/image.size(1),
        /*type=*/CV_32FC1,
        /*data=*/image.data_ptr<float>());
    // END image_mat

    // BEGIN warp_mat
    cv::Mat warp_mat(/*rows=*/warp.size(0),
        /*cols=*/warp.size(1),
        /*type=*/CV_32FC1,
        /*data=*/warp.data_ptr<float>());
    // END warp_mat

    // BEGIN output_mat
    cv::Mat output_mat;
    cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });
    // END output_mat

    // BEGIN output_tensor
    torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });
    return output.clone();
    // END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0


 torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {
    m.def("warp_perspective", warp_perspective);
}


2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)

set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")


set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)

# Opencv
find_package(OpenCV REQUIRED)

# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)

# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)

# libtorch库文件
target_link_libraries(warp_perspective 
    # CPU
    c10 
    torch_cpu
    # GPU
    # c10_cuda 
    # torch_cuda
    
)


# opencv库文件
target_link_libraries(warp_perspective
    ${OpenCV_LIBS}
)

add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as np

torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")

im=cv2.imread("/data/xxx/mylib/cat.jpg",0)

pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)


in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)

out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)

dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as np

torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):
    def __init__(self, name):
        super(MyNet, self).__init__()
        self.model_name = name

    def forward(self, in_data, warp_data):
        return torch.ops.my_ops.warp_perspective(in_data, warp_data)


def my_custom(g, in_data, warp_data):
    return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)


if __name__ == "__main__":
    net = MyNet("my_ops")
    in_data = torch.randn((32, 32))
    warp_data = torch.rand((3, 3))

    out = net(in_data, warp_data)
    print("out: ", out)

    # export onnx
    torch.onnx.export(net,
            (in_data, warp_data),
            "./my_ops_export_model2.onnx",
            input_names=["img_data", "warp_mat"],
            output_names=["out_img"],
            custom_opsets={"cus_ops": 11},
            )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/142833.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【GlobalMapper精品教程】064:点云提取(按范围裁剪)

本文讲解Globalmapper中进行点云数据提取(按范围裁剪)的方法。 文章目录 一、加载点云及范围数据二、点云裁剪三、注意事项一、加载点云及范围数据 加载配套实验数据包中的实验数据data064.rar中的point.las点云与bound.shp面状范围数据,如下图所示: 二、点云裁剪 接下来…

【C/PTA——8.数组2(课内实践)】

C/PTA——8.数组2&#xff08;课内实践&#xff09; 7-1 求矩阵的局部极大值7-2 求矩阵各行元素之和7-3 判断上三角矩阵7-4 点赞 7-1 求矩阵的局部极大值 #include<stdio.h> int main() {int m, n, i, j;int arr[100][100];scanf("%d %d", &m, &n);for…

PHP在自己框架中引入composer

目录 1、使用composer之前先安装环境 2、 在项目最开始目录添加composer.json文本文件 3、写入配置文件 composer.json 4、使用composer安装whoops扩展 5、引入composer类并且使用安装异常显示类 1、使用composer之前先安装环境 先安装windows安装composer并更换国内镜像…

Linux内存管理 | 五、物理内存空间布局及管理

我的圈子&#xff1a; 高级工程师聚集地 我是董哥&#xff0c;高级嵌入式软件开发工程师&#xff0c;从事嵌入式Linux驱动开发和系统开发&#xff0c;曾就职于世界500强企业&#xff01; 创作理念&#xff1a;专注分享高质量嵌入式文章&#xff0c;让大家读有所得&#xff01; …

Linux 性能调优之硬件资源监控

写在前面 考试整理相关笔记博文内容涉及 Linux 硬件资源监控常见的命名介绍&#xff0c;涉及硬件基本信息查看查看硬件错误信息查看虚拟环境和云环境资源理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在心中坚守其一生&#x…

如何通过把setTimeout异步转为同步

一.封装定时器函数 function delayed(time){return new Promise((resolve,reject)>{setTimeout( () > {resolve(time)}, time);}) }二调用的时候通过async await 修饰 async function demo() {console.log(new Date().getMinutes(): new Date().getSeconds())await del…

Transformers 中原生支持的量化方案概述

本文旨在对 transformers 支持的各种量化方案及其优缺点作一个清晰的概述&#xff0c;以助于读者进行方案选择。 目前&#xff0c;量化模型有两个主要的用途: 在较小的设备上进行大模型推理对量化模型进行适配器微调 到目前为止&#xff0c;transformers 已经集成并 原生 支持了…

VScode不打开浏览器实时预览html

下载Microsoft官方的Live Preview就行了 点击预览按钮即可预览

深圳联强优创手持PDA身份证阅读器 身份证核验手持机

身份证手持机外观比较小巧&#xff0c;方便携带&#xff0c;支持条码识别、人脸识别、NFC卡刷卡、内置二代证加密模块&#xff0c;可离线采集和识别二代身份证&#xff0c;进行身份识别。信息读取更便捷、安全高效。采用IP65高防护等级&#xff0c;1.5M防摔&#xff0c;可以适应…

RFID汽车制造工业系统解决方案

随着物联网技术的不断发展&#xff0c;汽车行业的信息化水平也在不断提高&#xff0c;随着近几年国产汽车的带动&#xff0c;汽车配件配套市场也已形成了一定的规模&#xff0c;初步形成比较完整成熟的零部件配套体系。 RFID汽车制造工业系统解决方案 与其他行业对比&#xff0…

人工智能与发电玻璃:未来能源技术的融合

人工智能与发电玻璃&#xff1a;未来能源技术的融合 摘要&#xff1a;本文探讨人工智能与发电玻璃这两项技术的结合&#xff0c;共同推动能源领域的创新。本文将介绍发电玻璃工作原理及应用、人工智能在发电玻璃的应用领域以及共同为可持续能源发展做出贡献。 一、引言 随着科…

Android自定义控件:一款多特效的智能loadingView

先上效果图&#xff08;如果感兴趣请看后面讲解&#xff09;&#xff1a; 1、登录效果展示 2、关注效果展示 1、【画圆角矩形】 画图首先是onDraw方法&#xff08;我会把圆代码写上&#xff0c;一步一步剖析&#xff09;&#xff1a; 首先在view中定义个属性&#xff1a;priv…

虹科示波器 | 汽车免拆检修 | 2010款奥迪A5车怠速时发动机偶尔自动熄火

一、故障现象 一辆2010款奥迪A5车&#xff0c;搭载CDN发动机&#xff0c;累计行驶里程约为16.3万km。车主进厂反映&#xff0c;发动机怠速偶尔出现抖动&#xff0c;紧接着自动熄火&#xff1b;重新起动&#xff0c;发动机又能正常工作&#xff1b;故障频率较低&#xff0c;有时…

Elastcsearch入门案例之 —— 搜索聚合

前言 在前面的Mall项目脚手架整合中涉及到的Elasticsearch的内容仅仅只是在表面给出了一个在SpringBoot中的使用示例&#xff0c;但其实对于Elasticsearch的一些基础概念和底层的原理并没有过多的涉及&#xff0c;这种学习方式是浮躁的&#xff0c;所以这篇文章荔枝会对其中欠缺…

NSSCTF题库——web

[SWPUCTF 2021 新生赛]gift_F12 f12后ctrlf找到flag [SWPUCTF 2021 新生赛]jicao——json_decode() 加密后的格式 $json {"a":"php","b":"mysql","c":3}; json必须双引号传输 构造&#xff1a;GET里json{"x"…

蓝桥杯 冒泡排序

冒泡排序的思想 冒泡排序的思想是每次将最大的一下一下移动到最右边&#xff0c;然后将最右边这个确定下来。 再来确定第二大的&#xff0c;再确定第三大的… 对于数组a[n]&#xff0c;具体来说&#xff0c;每次确定操作就是从左往右扫描&#xff0c;如果a[i]>a[i1],我们将…

Linux-AWK(应用最广泛的文本处理程序)

目录 一、awk基础 二、awk工作原理 三、OFS输出分隔符 四、awk的格式化输出 五、awk模式pattern 一、awk基础 使用案例&#xff1a; 1.准备工作 请在Linux中执行以下指令 cat -n /etc/passwd > ./passwd 练习&#xff1a; 1.从文件 passwd 中提取并打印出第五行的内…

基于混合蛙跳算法优化概率神经网络PNN的分类预测 - 附代码

基于混合蛙跳算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于混合蛙跳算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于混合蛙跳优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神…

《深入浅出进阶篇》洛谷P4147 玉蟾宫——悬线法dp

上链接&#xff1a;P4147 玉蟾宫 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)https://www.luogu.com.cn/problem/P4147 上题干&#xff1a; 有一个NxM的矩阵&#xff0c;每个格子里写着R或者F。R代表障碍格子&#xff0c;F代表无障碍格子请找出其中的一个子矩阵&#xff0c…

金蝶云星空设置单据体行高

文章目录 金蝶云星空设置单据体行高表单插件Python脚本 金蝶云星空设置单据体行高 表单插件 新建类继承AbstractBillPlugIn&#xff0c;重写OnInitialize方法进行设置 public override void OnInitialize(InitializeEventArgs e){base.OnInitialize(e);this.View.GetControl&…