分类模型部署-ONNX

分类模型部署-ONNX

  • 0 引入:
  • 1 模型部署实战测试:
    • 1 安装配置环境:
    • 2 Pytorch图像分类模型转ONNX-ImageNet1000类
    • 3 推理引擎ONNX Runtime部署-预测单张图像:
  • 2 扩展阅读
  • 参考

0 引入:

在软件工程中,部署指把开发完毕的软件投入使用的过程,包括环境配置、软件安装等步骤。类似地,对于深度学习模型来说,模型部署指让训练好的模型在特定环境中运行的过程。相比于软件部署,模型部署会面临更多的难题:

  1. 运行模型所需的环境难以配置。深度学习模型通常是由一些框架编写,比如 PyTorch、TensorFlow。由于框架规模、依赖环境的限制,这些框架不适合在手机、开发板等生产环境中安装。
  2. 深度学习模型的结构通常比较庞大,需要大量的算力才能满足实时运行的需求。模型的运行效率需要优化。

因为这些难题的存在,模型部署不能靠简单的环境配置与安装完成。经过工业界和学术界数年的探索,模型部署有了一条流行的流水线:

为了让模型最终能够部署到某一环境上,开发者们可以使用任意一种深度学习框架来定义网络结构,并通过训练确定网络中的参数。之后,模型的结构和参数会被转换成一种只描述网络结构的中间表示,一些针对网络结构的优化会在中间表示上进行。最后,用面向硬件的高性能编程框架(如 CUDA,OpenCL)编写,能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式,并在对应硬件平台上高效运行模型。
这一条流水线解决了模型部署中的两大问题:使用对接深度学习框架和推理引擎的中间表示,开发者不必担心如何在新环境中运行各个复杂的框架;通过中间表示的网络结构优化和推理引擎对运算的底层优化,模型的运算效率大幅提升。

1 模型部署实战测试:

这里以部署流水线: Pytorch-ONNX-ONNX Runtime/TensorRT 为流程,进行测试:
流程:分类模型,转成onnx,onnxRuntime部署;
这里以部署分类模型为例,参考同济子豪兄:

1 安装配置环境:

torch
torchvision
onnx
onnxruntime
numpy
pandas	
matplotlib
tqdm
opencv-python
pillow

2 Pytorch图像分类模型转ONNX-ImageNet1000类

把Pytorch预训练ImageNet图像分类模型,导出为ONNX格式,用于后续在推理引擎上部署。
导入工具包:

import torch
from torchvision import models

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

载入与训练pytorch图像分类模型:

model = models.resnet18(pretrained=True)
model = model.eval().to(device)

构建一个输入Tensor:

x = torch.randn(1, 3, 256, 256).to(device)

输入pytorch模型推理预测,获得1000个类别的预测结果:

output = model(x)
output.shape

torch.Size([1, 1000])

pytorch模型转onnx格式:

with torch.no_grad():
    torch.onnx.export(
        model,                       # 要转换的模型
        x,                           # 模型的任意一组输入
        'resnet18_imagenet.onnx',    # 导出的 ONNX 文件名
        opset_version=11,            # ONNX 算子集版本
        input_names=['input'],       # 输入 Tensor 的名称(自己起名字)
        output_names=['output']      # 输出 Tensor 的名称(自己起名字)
    ) 

验证onnx模型导出成功:

import onnx

# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_imagenet.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)

print('无报错,onnx模型载入成功')

以可读的形式打印计算图:

print(onnx.helper.printable_graph(onnx_model.graph))
graph main_graph (
  %input[FLOAT, 1x3x256x256]
) initializers (
  %fc.weight[FLOAT, 1000x512]
  %fc.bias[FLOAT, 1000]
  %onnx::Conv_193[FLOAT, 64x3x7x7]
  %onnx::Conv_194[FLOAT, 64]
  %onnx::Conv_196[FLOAT, 64x64x3x3]
  %onnx::Conv_197[FLOAT, 64]
  %onnx::Conv_199[FLOAT, 64x64x3x3]
  %onnx::Conv_200[FLOAT, 64]
  %onnx::Conv_202[FLOAT, 64x64x3x3]
  %onnx::Conv_203[FLOAT, 64]
  %onnx::Conv_205[FLOAT, 64x64x3x3]
  %onnx::Conv_206[FLOAT, 64]
  %onnx::Conv_208[FLOAT, 128x64x3x3]
  %onnx::Conv_209[FLOAT, 128]
  %onnx::Conv_211[FLOAT, 128x128x3x3]
  %onnx::Conv_212[FLOAT, 128]
  %onnx::Conv_214[FLOAT, 128x64x1x1]
  %onnx::Conv_215[FLOAT, 128]
  %onnx::Conv_217[FLOAT, 128x128x3x3]
  %onnx::Conv_218[FLOAT, 128]
  %onnx::Conv_220[FLOAT, 128x128x3x3]
  %onnx::Conv_221[FLOAT, 128]
  %onnx::Conv_223[FLOAT, 256x128x3x3]
  %onnx::Conv_224[FLOAT, 256]
  %onnx::Conv_226[FLOAT, 256x256x3x3]
  %onnx::Conv_227[FLOAT, 256]
  %onnx::Conv_229[FLOAT, 256x128x1x1]
  %onnx::Conv_230[FLOAT, 256]
  %onnx::Conv_232[FLOAT, 256x256x3x3]
  %onnx::Conv_233[FLOAT, 256]
  %onnx::Conv_235[FLOAT, 256x256x3x3]
  %onnx::Conv_236[FLOAT, 256]
  %onnx::Conv_238[FLOAT, 512x256x3x3]
  %onnx::Conv_239[FLOAT, 512]
  %onnx::Conv_241[FLOAT, 512x512x3x3]
  %onnx::Conv_242[FLOAT, 512]
  %onnx::Conv_244[FLOAT, 512x256x1x1]
  %onnx::Conv_245[FLOAT, 512]
  %onnx::Conv_247[FLOAT, 512x512x3x3]
  %onnx::Conv_248[FLOAT, 512]
  %onnx::Conv_250[FLOAT, 512x512x3x3]
  %onnx::Conv_251[FLOAT, 512]
) 

也可以使用netron可视化模型结构:
image.png

3 推理引擎ONNX Runtime部署-预测单张图像:

使用推理引擎,ONNX Runtime,读取 ONNX 格式的模型文件,对单张图像文件进行预测。

应用场景:
以下代码在需要部署的硬件上运行(本地PC、嵌入式开发板、树莓派、Jetson Nano、服务器)
只需把`onnx`模型文件发到部署硬件上,并安装 ONNX Runtime 环境,用下面几行代码就可以运行模型了。

导入工具包:

import onnxruntime
import numpy as np
import torch
import torch.nn.functional as F

import pandas as pd

载入Onnx模型,获取 ONNX Runtime 推理器

ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')

构造随机输入,获取输出结果:

x = torch.randn(1, 3, 256, 256).numpy()
x.shape
# onnx runtime 输入
ort_inputs = {'input': x}

# onnx runtime 输出
ort_output = ort_session.run(['output'], ort_inputs)[0]

注意,输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应

ort_output.shape

(1, 3, 256, 256)
(1, 1000)

载入真实图片预处理;

img_path = 'banana1.jpg'

# 用 pillow 载入
from PIL import Image
img_pil = Image.open(img_path)

from torchvision import transforms

#预处理
# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

input_img = test_transform(img_pil)
input_img.shape
input_tensor = input_img.unsqueeze(0).numpy()
input_tensor.shape



推理预测:

# ONNX Runtime 输入
ort_inputs = {'input': input_tensor}

# ONNX Runtime 输出
pred_logits = ort_session.run(['output'], ort_inputs)[0]
pred_logits = torch.tensor(pred_logits)
pred_logits.shape
# 对 logit 分数做 softmax 运算,得到置信度概率
pred_softmax = F.softmax(pred_logits, dim=1) 
pred_softmax.shape

解析结果:将数字映射到物品种类上:

# 取置信度最高的前 n 个结果
n = 3
top_n = torch.topk(pred_softmax, n)
top_n

# 预测类别
pred_ids = top_n.indices.numpy()[0]
pred_ids

# 预测置信度
confs = top_n.values.numpy()[0]

载入类别ID和类别名称对应关系:

df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = row['class']   # 英文
#     idx_to_labels[row['ID']] = row['Chinese'] # 中文

打印预测结果:

for i in range(n):
    class_name = idx_to_labels[pred_ids[i]] # 获取类别名称
    confidence = confs[i] * 100             # 获取置信度
    text = '{:<20} {:>.3f}'.format(class_name, confidence)
    print(text)

image.png
imagenet_class_index.csv

2 扩展阅读

ONNX主页: https://onnx.ai
ONNX Github主页: https://github.com/onnx/onnx
ONNX支持的深度学习框架:https://onnx.ai/supported-tools.html
ONNX支持的计算平台:https://onnx.ai/supported-tools.html
ONNX算子:https://onnx.ai/onnx/operators
模型在线可视化网站 Netron:https://netron.app
Netron视频教程:https://www.bilibili.com/video/BV1TV4y1P7AP

周奕帆博客:
模型部署入门教程(一):模型部署简介
https://zhuanlan.zhihu.com/p/477743341
模型部署入门教程(二):解决模型部署中的难题
https://zhuanlan.zhihu.com/p/479290520
模型部署入门教程(三):PyTorch 转 ONNX 详解
https://zhuanlan.zhihu.com/p/498425043

参考

https://mmdeploy.readthedocs.io/zh-cn/latest/tutorial/01_introduction_to_model_deployment.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/715181.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

消息队列-RabbitMQ-延时队列实现

死信队列 DLX,全称为Dead-Letter-Exchange,死信交换机&#xff0c;死信邮箱。当消息在一个队列中变成死信之后&#xff0c;它能重新发送到另外一个交换器中&#xff0c;这个交换器就是DLX&#xff0c;绑定DLX的队列就称为死信队列。 导致死信的几种原因&#xff1a; ● 消息…

Swift开发——简单函数实例

函数是模块化编程的基本单位,将一组完成特定功能的代码“独立”地组成一个执行单位,称为函数。函数的基本结构如下所示: 其中,func为定义函数的关键字;“函数名”是调用函数的入口;每个函数可以有多个参数,即可以有多个“参数标签 参数名称:参数类型”,一般地,各个参…

边缘微型AI的宿主?—— RISC-V芯片

一、RISC-V技术 RISC-V&#xff08;发音为 "risk-five"&#xff09;是一种基于精简指令集计算&#xff08;RISC&#xff09;原则的开放源代码指令集架构&#xff08;ISA&#xff09;。它由加州大学伯克利分校在2010年首次发布&#xff0c;并迅速获得了全球学术界和工…

跟着刘二大人学pytorch(第---13---节课之RNN高级篇)

文章目录 0 前言0.1 课程视频链接&#xff1a;0.2 课件下载地址&#xff1a; 1 本节课任务描述模型的处理过程训练循环初始化分类器是否使用GPU构造损失函数和优化器每个epoch所要花费的时间遍历每个epoch时进行训练和测试记录每次测试的准确率加入到列表中 具体实现&#xff0…

中国最著名的起名大师颜廷利:父亲节与之相关的真实含义

今天是2024年6月16日&#xff0c;这一天被广泛庆祝为“父亲节”。在汉语中&#xff0c;“父亲”这一角色常以“爸爸”、“大大”&#xff08;da-da&#xff09;或“爹爹”等词汇表达。有趣的是&#xff0c;“爸爸”在汉语拼音中表示为“ba-ba”&#xff0c;而当我们稍微改变“b…

DeepDriving | 经典的目标检测算法:CenterNet

本文来源公众号“DeepDriving”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;经典的目标检测算法&#xff1a;CenterNet 1 前言 CenterNet是2019年发表的一篇文章《Objects as Points》中提出的一个经典的目标检测算法&#xf…

MySQL-创建表~数据类型

070-创建表 create table t_user(no int,name varchar(20),gender char(1) default 男);071-插入数据 语法格式&#xff1a; insert into 表名(字段名1, 字段名2, 字段名3,......) values (值1,值2,值3,......);insert into t_user(no, name, gender) values(1, Cupid, 男);字…

监控异地组网的方法?

监控异地组网是一项关键的技术&#xff0c;能够实现远程连接和访问。在复杂的网络环境中&#xff0c;使用传统的方法可能会遭遇网络限制和访问速度较慢的问题。而采用新兴的监控异地组网方法&#xff0c;如【天联】组网技术&#xff0c;可以克服这些问题并提供更好的用户体验。…

计算机缺失d3dcompiler_43.dll怎么办,介绍5种靠谱的解决方法

在电脑使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是“找不到d3dcompiler43.dll”的错误。那么&#xff0c;d3dcompiler43.dll到底是什么&#xff1f;为什么会出现丢失的情况&#xff1f;它对计算机有什么具体影响&#xff1f;如何解决这个问题&a…

黄仁勋最新建议:找到一门技艺,用一生去完善、磨炼!

“你可能会找到你的英伟达。我希望你们将挫折视为新的机遇。” 黄仁勋职业生涯中最大的教诲并非来自导师或科技公司 CEO&#xff0c;而是来自他在国际旅行时遇到的一位园丁。 近日在加州理工学院毕业典礼上发表演讲时&#xff0c;黄仁勋向毕业生分享了自己在日本京都的小故事。…

跟着刘二大人学pytorch(第---12---节课之RNN基础篇)

文章目录 0 前言0.1 课程视频链接&#xff1a;0.2 课件下载地址&#xff1a; 1 Basic RNN1.1 复习DNN和CNN1.2 直观认识RNN1.3 RNN Cell的内部计算方式 2 具体什么是一个RNN&#xff1f;3 使用pytorch构造一个RNN3.1 手动构造一个RNN Cell来实现RNN3.2 直接使用torch中现有的RN…

赶紧转行大模型,预计风口就今年一年,明年市场就饱和了!不是开玩笑

恕我直言&#xff0c;就这几天&#xff0c;各大厂都在裁员&#xff0c;什么开发测试运维都裁&#xff0c;只有大模型是急招人。 你说你不知道大模型是什么&#xff1f;那可太对了&#xff0c;你不知道说明别人也不知道&#xff0c;就是要趁只有业内部分人知道的时候入局&#…

深度学习1 -- 开头

感觉用这玩意越来越多&#xff0c;所以想学学。不过没想好怎么学&#xff0c;也没有提纲&#xff0c;买了两本书&#xff0c;一本是深度学习入门&#xff0c;小日子写的。还有一本就是花书。还有就是回Gatech参加线上课程&#xff0c;提纲大概是这样的。 https://omscs.gatech…

生产中的 RAG:使你的生成式 AI 项目投入运营

作者&#xff1a;来自 Elastic Tim Brophy 检索增强生成 (RAG) 为组织提供了一个采用大型语言模型 (LLM) 的机会&#xff0c;即通过将生成式人工智能 (GenAI) 功能应用于其自己的专有数据。使用 RAG 可以降低固有风险&#xff0c;因为我们依赖受控数据集作为模型答案的基础&…

比利时海外媒体宣发,发稿促进媒体通稿发布新形势-大舍传媒

引言 随着全球化的推进&#xff0c;海外媒体的影响力也日益增强。在这一背景下&#xff0c;比利时海外媒体的宣发工作成为了媒体通稿发布的新形势。大舍传媒作为一家专注于宣传推广的公司&#xff0c;一直致力于与比利时博伊克邮报&#xff08;boicpost&#xff09;合作&#…

ModuleNotFoundError: No module named ‘distutils‘的解决办法

最近想试试odoo17&#xff0c;在windows环境下&#xff0c;想安装试验一下&#xff0c;结果老出现oduleNotFoundError: No module named ‘distutils‘错误。查了一下&#xff0c;以为是python版本导致的&#xff0c;结果试了很多版本如下&#xff1a; 试了几个&#xff0c;每个…

4-异常-log4j配置日志滚动覆盖出现日志丢失问题

4-异常-log4j配置日志打印滚动覆盖出现日志丢失问题(附源码分析) 更多内容欢迎关注我&#xff08;持续更新中&#xff0c;欢迎Star✨&#xff09; Github&#xff1a;CodeZeng1998/Java-Developer-Work-Note 技术公众号&#xff1a;CodeZeng1998&#xff08;纯纯技术文&…

浪潮信息内存故障预警技术再升级 服务器稳定性再获提升

浪潮信息近日对其内存故障智能预警修复技术进行了全面升级&#xff0c;再次取得技术突破。此次升级后&#xff0c;公司服务器的宕机率实现了80%锐降&#xff0c;再次彰显了浪潮信息在服务器技术领域的卓越能力。 浪潮信息全新升级服务器内存故障智能预警修复技术MUPR (Memory …

大数据开发流程解析

大数据开发是一个复杂且系统的过程&#xff0c;涉及需求分析、数据探查、指标管理、模型设计、ETL开发、数据验证、任务调度以及上线管理等多个阶段。本文将详细介绍每个阶段的内容&#xff0c;并提供相关示例和代码示例&#xff0c;帮助理解和实施大数据开发流程。 本文中的示…

学习记录:VS2019+OpenCV3.4.1实现SURF库函数的调用

最近在学习opencv的使用&#xff0c;在参照书籍《OpenCV3编程入门》实现SURF时遇到不少问题&#xff0c;下面做归纳总结。 错误 LNK2019 无法解析的外部符号 “public: static struct cv::Ptr __cdecl cv::xfeatures2d::SURF::create(double,int,int,bool,bool)” (?createSUR…