Pytorch转onnx

pytorch 转 onnx 模型需要函数 torch.onnx.export。

def export(
    model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    export_params: bool = True,
    verbose: bool = False,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
    opset_version: Optional[int] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]
    ] = None,
    keep_initializers_as_inputs: Optional[bool] = None,
    custom_opsets: Optional[Mapping[str, int]] = None,
    export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
) -> None:

常用参数说明

model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
f——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
opset_version——算子指令集合
do_constant_folding——是否使用常量折叠,默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道

参数说明
ONNX算子文档
ONNX 算子的定义情况,都可以在官方的算子文档中查看
这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。
在这里插入图片描述
练习

import torch
import torch.nn as nn
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self, in_features, out_features, weights, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        with torch.no_grad():
            self.linear.weight.copy_(weights)
    
    def forward(self, x):
        x = self.linear(x)
        return x

def infer():
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    
    model = Model(4, 3, weights)
    x = model(in_features)
    print("result is: ", x)

def export_onnx():
    input   = torch.zeros(1, 1, 1, 4)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    model   = Model(4, 3, weights)
    model.eval() #添加eval防止权重继续更新

    # pytorch导出onnx的方式,参数有很多,也可以支持动态size
    # 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = "../models/example.onnx",
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished onnx export")


if __name__ == "__main__":
    infer()
    export_onnx()

然后使用netron打开onnx文件,如果没有安装netron,在终端使用pip install netron。
在这里插入图片描述

参考链接
模型部署入门教程(三):PyTorch 转 ONNX 详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/520602.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《QT实用小工具·十六》IP地址输入框控件

1、概述 源码放在文章末尾 该项目为IP地址输入框控件,主要包含如下功能: 可设置IP地址,自动填入框。 可清空IP地址。 支持按下小圆点自动切换。 支持退格键自动切换。 支持IP地址过滤。 可设置背景色、边框颜色、边框圆角角度。 下面…

网址打包微信小程序源码 wap转微信小程序 网站转小程序源码 网址转小程序开发

内容目录 一、详细介绍二、效果展示2.效果图展示 三、学习资料下载 一、详细介绍 我们都知道微信小程序是无法直接打开网址的。 这个小程序源码提供了一种将网址直接打包成微信小程序的方法, 使得用户可以在微信小程序中直接访问这些网址内容。 这个源码没有进行加…

Python3 Ubuntu

一、安装中文输入法 1.sudo apt install ibus-sunpinyin 2.点击右上角输入法,然后点击加号,输入yin添加进来,最后选中输入法即可 二、安装截屏软件 1.sudo apt install gnome-screenshot 三、安装opencv-python 1.pip3 install --upgrade…

第九讲 Join 算法

1. 为什么我们需要 Join 我们对关系数据库中的表【tables】进行规范化【normalize】,这样我们就减少了信息的冗余和浪费的空间,但是现在我们为了可以响应传入的查询【Query】,我们必须把这些分离的东西重新组合在一起,以重建原始…

瑞吉外卖实战学习--15、批量启售和批量禁售

批量启售和批量禁售 前言代码实现 前言 代码实现 通过url我们可以获取到传过来的ids和状态值&#xff0c;现根据状态值查询出来相关数据然后直接附加状态值最后通过updateBatchById来进行修改 PostMapping("/status/{status}")public R<String> updateStatus(…

嵌入式学习48-单片机1

51单片机—————8位单片机 裸机驱动 无系统 linux驱动 有系统 驱动-----反映硬件变化 MCU 微控器 MPU CPU GPU 图像处理 IDE 集成开发环境 peripheral 外设 SOC&#xff1a; system on chip P0&#xff1a;8bit——8个引脚 位运算 & …

彩虹聚合DNS管理系统v1.0全新发布

聚合DNS管理系统&#xff08;https://github.com/netcccyun/dnsmgr&#xff09;可以实现在一个网站内管理多个平台的域名解析&#xff0c;目前已支持的域名平台有&#xff1a;阿里云、腾讯云、华为云、西部数码、CloudFlare。本系统支持多用户&#xff0c;每个用户可分配不同的…

python 01操作符与流程控制

一、算术运算符 , , *, /, %, **, // 二、赋值运算符 , , -, *, /, %, **, // 三、比较运算符 四、逻辑操作符 五、变量与赋值 赋值运算符是 &#xff0c;与比较运算符 进行区分 需要注意的是&#xff0c;python的变量是不可变对象&#xff0c;如果变量的值发生改变&…

[AIGC] Spring Interceptor 拦截器详解

文章目录 什么是Spring Interceptor如何使用Spring InterceptorSpring Interceptor的影响 什么是Spring Interceptor Interceptor&#xff08;拦截器&#xff09;是Spring MVC框架中的一种特性&#xff0c;类似于Servlet开发中的Filter&#xff08;过滤器&#xff09;&#xf…

【PyQt5篇】使用QtDesigner添加控件和槽

文章目录 &#x1f354;使用QtDesigner进行设计&#x1f6f8;在代码中添加信号和槽 &#x1f354;使用QtDesigner进行设计 我们首先使用QtDesigner设计界面 得到代码login.ui <?xml version"1.0" encoding"UTF-8"?> <ui version"4.0&q…

Java_18 字符串中的单词反转

字符串中的单词反转 你在与一位习惯从右往左阅读的朋友发消息&#xff0c;他发出的文字顺序都与正常相反但单词内容正确&#xff0c;为了和他顺利交流你决定写一个转换程序&#xff0c;把他所发的消息 message 转换为正常语序。 注意&#xff1a;输入字符串 message 中可能会…

移动端适配方案总结之vw

1、vw/vh是什么&#xff1f; vw是&#xff1a;viewport width 视口宽度单位 vh是&#xff1a; viewport height 视口高度单位 实际开发中我们基本用vw&#xff1b; 2.相对视口的尺寸计算结果 1vw 1/100视口宽度 1vh 1/100视口高度 例如&#xff1a; 当前屏幕视口是 375像素…

数据结构---顺序表实现

目录 1.顺序表 2.动态顺序表的实现 &#xff08;4&#xff09;顺序表初始化 &#xff08;5&#xff09;顺序表销毁 &#xff08;6&#xff09;顺序表的插入 a.尾插 b.头插 &#xff08;7&#xff09;顺序表的删除 a.尾删 b.头删 &#xff08;8&#xff09;指定位置之…

【大数据存储】spark-编程

实验8-spark编程 实验&#xff1a;编写Spark应用程序&#xff08;掌握Spark应用程序的编写、编译打包和运行方法&#xff09; 1、对于两个输入文件A和B&#xff0c;编写Spark独立应用程序&#xff0c;对两个文件进行合并&#xff0c;并剔除其中重复的内容&#xff0c;得到一个…

VSCode如何调试C#代码?

1、启动VSCode&#xff1b; 一、创建项目 1、创建一个文件夹(workspace)&#xff1a; 2、进入这个文件夹 cd tt1 3、创建解决方案 dotnet new sln -o MyApp 4、进入解决方案 cd .\MyApp\ 5、创建项目&#xff08;在此假定为一个命令行的项目&#xff09; dotnet new …

例47:键盘事件演示

建立一个EXE工程&#xff0c;在默认窗体上放一个Image框和一一个text框。在text的按键事件中输入代码&#xff1a; Function Form1_Text1_WM_KeyDown(hWndForm As hWnd, hWndControl As hWnd,nVirtKey As Long, lKeyData As Long) As LongIf nVirtKey VK_SPACE ThenImage1.Pi…

Django的html在for遍历后显示“一、二、三...”和“1,2,3...”分级标题

例如当天的html为&#xff1a; {% load static %} {% csrf_token %} <!DOCTYPE html> <html> <head><title>生活规划师</title><link rel"stylesheet" href"{% static css/LifePlanningGuide.css %}"><script src…

FreeRtos入门-7 中断管理

中断管理 中断管理相比非中断的优势 1&#xff0c;简洁和效率。 2&#xff0c;同步和安全。提供了中断安全的操作&#xff0c;确保在中断上下文中执行时不会引发竞态条件或破坏系统状态。 3&#xff0c;通过配置中断的优先级&#xff0c;可以确保高优先级的中断能够立即响应…

AI跟踪报道第36期-新加坡内哥谈技术-这周的AI新闻铺天盖地

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

STM32重要参考资料

stm32f103c8t6 一、引脚定义图 二、时钟树 三、系统结构图 四、启动配置 &#xff08;有时候不小心短接VCC和GND&#xff0c;芯片会锁住&#xff0c;可以BOOT0拉高试试&#xff08;用跳线帽接&#xff09;&#xff09; 五、最小系统原理图 可用于PCB设计 六、常见折腾人bug…