【Pytorch】(十五)模型部署:ONNX和ONNX Runtime

文章目录

  • (十五)模型部署:ONNX和ONNX Runtime
    • ONNX 和 ONNX Runtime的关系
    • 将PyTorch模型导出为ONNX格式
    • 使用Netron可视化ONNX模型图
    • 检查ONNX模型
    • 验证ONNX Runtime推理结果
    • 使用ONNX Runtime运行超分模型

(十五)模型部署:ONNX和ONNX Runtime

ONNX 和 ONNX Runtime的关系

ONNX(模型表示格式):Open Neural Network Exchange(ONNX)一种用于表示深度学习模型的标准格式。这个格式允许将模型从一个深度学习框架转移到另一个框架,以及在不同平台上进行推理。

ONNX Runtime(推理引擎):ONNX Runtime(ORT) 是一个用于运行和执行 ONNX 模型的推理引擎。ONNX Runtime 提供了高性能、低延迟的深度网络模型推理,并且是跨平台的,支持各种操作系统和设备。ONNX Runtime已被证明可以显著提高多个模型的推理性能。

想用ONNX和ONNX Runtime进行Pytorch模型部署,首先需要安装以下Python包:

pip install --upgrade onnx onnxscript onnxruntime

将PyTorch模型导出为ONNX格式

Pytorch中torch.onnx模块提供API来从PyTorch的 torch.nn.Module模块捕获计算图,并将其转换为ONNX格式。从PyTorch 2.1开始,ONNX Exporter有两个版本。

torch.onnx.dynamo_export是基于PyTorch 2.0发布的TorchDynamo技术的最新Exporter(仍处于测试版)

torch.onnx.export则基于TorchScript,自PyTorch 1.2.0以来一直可用

本文只介绍torch.onnx.export,关于torch.onnx.dynamo_export,可阅读:
https://pytorch.org/tutorials/beginner/onnx/intro_onnx.html
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

下面将以一个图像超分模型为例,介绍如何使用基于TorchScript 的torch.onnx.export将PyTorch中定义的模型转换为ONNX格式。

import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init

# 1 搭建一个超分模型。
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# 创建一个模型实例
torch_model = SuperResolutionNet(upscale_factor=3)

# 2 训练模型或者直接导入预训练的模型参数,这里采用后者:
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# 3 将模型转换为推理模式。
# 这是必需的,因为像dropout或batchnorm这样的运算符在推理和训练模式中表现不同。
# set the model to inference mode
torch_model.eval()

# 4 导出ONNX模型

batch_size = 1    # just a random number
## 首先需要提供一个输入张量x。只要它是正确的类型和大小,其中的值就可以是随机的。
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

## 导出模型
torch.onnx.export(torch_model,               # 模型
                  x,                         # 模型输入
                  "super_resolution.onnx",   # onnx文件保存路径
                  export_params=True,        # 将经过训练的参数权重存储在模型文件中
                  opset_version=10,          # ONNX的版本
                  do_constant_folding=True,  # 执行常量折叠(constant folding)进行优化
                  input_names = ['input'],   # 模型输入的名字
                  output_names = ['output'], #  模型输出的名字
                  dynamic_axes={'input' : {0 : 'batch_size'}, # 将第一个维度指定为dynamic
                                'output' : {0 : 'batch_size'}})
# 计算原始Pytorch模型的输出,用于验证导出的ONNX 模型是否能计算出相同的值。                     
torch_out = torch_model(x)  # 计算原始Pytorch模型的输出

请注意,除非在dynamic_axes指定,否则ONNX模型中输入和输出的尺寸大小都是固定的。在本例中,在torch.onnx.export()中的dynamic_axies参数中将第一个维度指定为dynamic。这使得导出的模型接受大小为 [batch_size, 1, 224, 224]的输入,其中batch_size是可变的。

使用Netron可视化ONNX模型图

Netron可以对ONNX模型图进行可视化。Netron除了可以安装在macos、Linux或Windows系统的计算机上,还可以在浏览器上运行:https://netron.app/

打开Netron后,我们可以将.onnx文件拖放到浏览器中,也可以在单击“打开模型”从文件目录选择它,进行可视化:

检查ONNX模型

在使用ONNX Runtime进行推理之前,我们先使用ONNX API检查ONNX模型。

import onnx
# 加载onnx模型
onnx_model = onnx.load("super_resolution.onnx")
# 验证ONNX模型的有效性,包括通过检查模型的版本、图的结构和节点及其输入和输出
onnx.checker.check_model(onnx_model)

验证ONNX Runtime推理结果

现在,让我们使用ONNX Runtime的Python API来进行推理。

这一部分通常是在另一个进程中或在另一台机器上完成。为了验证ONNX Runtime和PyTorch原始网络模型计算的值是否近似,我们在一个进程进行。

import onnxruntime
# 创建一个推理会话
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 使用ONNX Runtime进行推理
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# ONNX Runtime和PyTorch原始网络模型输出的近似程度没有达到指定精度(rtol=1e-03和atol=1e-05),将抛出异常。
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

使用ONNX Runtime运行超分模型


import numpy as np
import onnxruntime
from PIL import Image
import torchvision.transforms as transforms

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


# 创建一个推理会话
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

# 加载图像与预处理
img = Image.open("cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)

# 在ONNX Runtime中运行超分辨率模型
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

# 从输出张量构造最终输出图像,并保存
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")  # Cr, Cb通道通过插值发大

final_img.save("cat_superres_with_ort.jpg")

参考:
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/573957.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

外贸干货|真正的销售高手,都很会提问

你的产品性价比很高,为什么客户没有买单呢? 最重要的原因是你没有了解到他真正的需求。 真正的销售高手,应该是一个提问高手,至少要连续问对方6个问题,问出客户的真实需求。 假如他回答你的问题,你有一种&a…

git 冲突与解决冲突

目录 1.使用 git 解决冲突 GIT 常用命令 制造冲突 解决冲突 2.使用 IDEA 解决冲突 产生冲突 解决冲突 1.使用 git 解决冲突 GIT 常用命令 命令作用git clone克隆git init初始化git add 文件名添加到暂存区git commit -m " 日志信息" 文件名提交到本地库git st…

【运维自动化-配置平台】如何通过模板创建集群和模块

通过【每天掌握一个功能点】配置平台如何创建业务机拓扑(集群-模块)我们知道了直接创建集群和模块的操作方法,直接创建的方式适合各集群模块都相对独立的场景,那大量的、标准规范的集群模块如何快速创建呢,这里就引入了…

条件生成对抗网络(cGAN)在AI去衣技术中的应用探索

随着深度学习技术的飞速发展,生成对抗网络(GAN)作为其中的一个重要分支,在图像生成、图像修复等领域展现出了强大的能力。其中,条件生成对抗网络(cGAN)通过引入条件变量来控制生成模型的输出&am…

面试十五 容器

一、vector容器 template<typename T> class Allocator{ public:T* allocator(size_t size){// 负责内存开辟return (T*)malloc(sizeof(T) * size);}void deallocate(void * p){free(p);}void construct(T*p,const T&val){// 定位newnew (p) T(val);}void destroy(…

Golang对接Ldap(保姆级教程:概念搭建实战)

Golang对接Ldap&#xff08;保姆级教程&#xff1a;概念&搭建&实战&#xff09; 最近项目需要对接客户的LDAP服务&#xff0c;于是趁机好好了解了一下。LDAP实际是一个协议&#xff0c;对应的实现&#xff0c;大家可以理解为一个轻量级数据库。用户查询。比如&#xff…

DiT论文精读Scalable Diffusion Models with Transformers CVPR2023

Scalable Diffusion Models with Transformers CVPR2023 Abstract idea 将UNet架构用Transformer代替。并且分析其可扩展性。 并且实验证明通过增加transformer的宽度和深度&#xff0c;有效降低FID 我们最大的DiT-XL/2模型在classconditional ImageNet 512、512和256、256基…

switch语句深讲

一。功能 1.选择&#xff0c;由case N:完成 2.switch语句本身没有分支功能&#xff0c;分支功能由break完成 二。注意 1.switch语句如果不加break&#xff0c;在一次判断成功后会执行下面全部语句并跳过判断 2.switch的参数必须是整形或者是计算结果为整形的表达式,浮点数会…

centos 7 yum install -y nagios

centos 7 systemctl disable firewalld --now vi /etc/selinux/config SELINUXdisabled yum install -y epel-release httpd nagios yum install -y httpd nagios systemctl enable httpd --now systemctl enable nagios --now 浏览器 IP/nagios 用户名&#xff1a;…

stack,queue的模拟实现以及优先级队列

这篇博客用来记录stack&#xff0c;queue的学习。 stack的模拟实现 stack的模拟实现比较简单&#xff0c;先上代码 #pragma once #include<vector> #include<list> #include<deque> #include<iostream> using std::deque; using namespace std;name…

【STM32HAL库】外部中断

目录 一、中断简介 二、NVIC 1.寄存器 2.工作原理 3.优先级 4.使用NVIC 三、EXTI 1.简介 2.AFIO&#xff1a;复用功能IO&#xff0c;主要用于重映射和外部中断映射配置​编辑 3. 中断使用 4.HAL库配置使用 一、中断简介 中断的意义&#xff1a;高效处理紧急程序&#xff0c;不会…

树莓派学习笔记--串口通信(配置硬件串口进行通信)

树莓派串口知识点 树莓派4b的外设一共包含两个串口&#xff1a;硬件串口&#xff08;/dev/ttyAMA0&#xff09;,mini串口&#xff08;/dev/ttyS0&#xff09; 硬件串口由硬件实现&#xff0c;有单独的波特率时钟源&#xff0c;性能高&#xff0c;可靠&#xff1b;而mini串口性能…

Java-AQS的原理

文章目录 基本概述1. 设计思想2. 基本实现 一些关键词语以及常用术语&#xff0c;主要如下&#xff1a; 信号量(Semaphore): 是在多线程环境下使用的一种设施&#xff0c;是可以用来保证两个或多个关键代码段不被并发调用&#xff0c;也是作系统用来解决并发中的互斥和同步问题…

数据挖掘 | Count数据去除批次效应后不是整数甚至还出现负值导致无法进行差异分析怎么办?

之前咱们介绍过数据挖掘 | 批次效应的鉴定与处理 | 附完整代码 注释 | 看完不会来揍我&#xff0c;但是很多小伙伴遇到了Count数据批次处理后不是整数甚至还出现负值的问题&#xff0c;这就导致无法使用某些包包进行差异分析&#xff08;对差异分析感兴趣的小伙伴可以查看&…

MySQL中如何随机获取一条记录

点击上方蓝字关注我 随机获取一条记录是在数据库查询中常见的需求&#xff0c;特别在需要展示随机内容或者随机推荐的场景下。在 MySQL 中&#xff0c;有多种方法可以实现随机获取一条记录&#xff0c;每种方法都有其适用的情况和性能特点。在本文中&#xff0c;我们将探讨几种…

word添加行号

打开页面设置&#xff0c;找到行号

2018-2023年上市公司富时罗素ESG评分数据

2018-2023年上市公司富时罗素ESG评分数据 1、时间&#xff1a;2018-2023年 2、来源&#xff1a;整理自WIND 3、指标&#xff1a;证券代码、简称、ESG评分 4、范围&#xff1a;上市公司 5、指标解释&#xff1a; 富时罗素将公司绿色收入的界定和计算作为公司ESG 评级打分结…

「白嫖」开源的后果就是供应链攻击么?| 编码人声

「编码人声」是由「RTE开发者社区」策划的一档播客节目&#xff0c;关注行业发展变革、开发者职涯发展、技术突破以及创业创新&#xff0c;由开发者来分享开发者眼中的工作与生活。 面对网络安全威胁日益严重的今天&#xff0c;软件供应链安全已经成为开发者领域无法避免的焦点…

OpenWRT设置自动获取IP,作为二级路由器

前言 上一期咱们讲了在OpenWRT设置PPPoE拨号的教程&#xff0c;在光猫桥接的模式下&#xff0c;OpenWRT如果不设置PPPoE拨号&#xff0c;就无法正常上网。 OpenWRT设置PPPoE拨号教程 但现在很多新装的宽带&#xff0c;宽带师傅为了方便都会把光猫设置为路由模式。如果你再外…

【A-024】基于SSH的房屋租赁管理系统(含论文)

【A-024】基于SSH的房屋租赁管理系统&#xff08;含论文&#xff09; 开发环境&#xff1a; Jdk7(8)Tomcat7(8)MySQLIntelliJ IDEA(Eclipse) 数据库&#xff1a; MySQL 技术&#xff1a; SpringStruts2HiberanteBootstrapJquery 适用于&#xff1a; 课程设计&#xff0c;毕…