模型部署onnx入门

一、定义

1.定义
2. 环境安装
3. 案例
4. 可视化界面
5. 参考网址
6. 推理引擎 onnx Runtime 进行单张图片推理,本地部署
7. 推理引擎onnx Runtime 进行单张图片推理,调用摄像头获取画面
8. 推理引擎onnx Runtime 进行图片推理,调用摄像头获取画面
9. ModelInferBench 推理速度测试

二、实现

  1. 定义
    模型pytorch\ keras\ tensorflow\ pp飞桨 等训练框架
    onnx 标准化
    nvidia\ intel\ 高通\ 升腾 等硬件。 在这里插入图片描述
    深度学习框架------->中间件--------->推理引擎(不同设备,选择的推理引擎不同),加速推理
    推理引擎:在这里插入图片描述
  2. 环境安装
    onnx 算子 网址https://onnx.ai/onnx/operators/
pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
import onnx
print('ONNX 版本', onnx.__version__)
import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)
  1. 案例
import torch
from torchvision import models

devices=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=models.resnet18(pretrained=True)
model=model.eval().to(devices)

x=torch.randn((1,3,256,256)).to(devices)
output=model(x)
print(output.shape)
#pytorch onnx 格式转换
with torch.no_grad():
    torch.onnx.export(model,
                      x,
                      "resnet18_imagenet.onnx",   #导出的onnx 名称
                      opset_version=11,            #算子版本          https://onnx.ai/onnx/operators/
                      input_names=["input"],       #输入tensor 名称,自己起的
                      output_names=["output"]      #输出tensor 名称, 自己起的
                      )

#验证是否成功
import onnx
onnx_model=onnx.load("resnet18_imagenet.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")
  1. 可视化界面
    https://netron.app/ #浏览器打开
    导入对应的onnx 文件在这里插入图片描述

  2. 参考网址
    https://github.com/TommyZihao/Train_Custom_Dataset/blob/main/%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB/7-ONNX%20Runtime%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E9%83%A8%E7%BD%B2/1-Pytorch%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E8%BD%ACONNX/%E3%80%90Z%E3%80%91%E6%89%A9%E5%B1%95%E9%98%85%E8%AF%BB.ipynb
    https://zhuanlan.zhihu.com/p/498425043

  3. 推理引擎 onnx Runtime 进行单张图片推理,本地部署

import torch
import onnxruntime
import numpy as np
import torch.nn.functional as F
import pandas as pd

#onnxruntime 推理
ort_session=onnxruntime.InferenceSession("resnet18_imagenet.onnx")    #加载模型
x=torch.randn((1,3,256,256)).numpy()
ort_input={"input":x}
output=ort_session.run(["output"],ort_input)[0]   #推理
print(output.shape)

###########图像推理#########

from PIL import Image   #加载图片
img_pil=Image.open("banana1.jpg")
#img_pil.show()
#数据处理
from torchvision import transforms

test_transform=transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                    ])

input_img=test_transform(img_pil)
input_tensor=input_img.unsqueeze(0).numpy()

ort_input={"input":input_tensor}
pred_logits=ort_session.run(["output"],ort_input)[0]
pred_logits=torch.tensor(pred_logits)

pred_softmax=F.softmax(pred_logits,dim=1)

#解析,取前n个
n=3
topn=torch.topk(pred_softmax,n)
pred_ids=topn.indices.numpy()[0]
conf=topn.values.numpy()[0]
print(pred_ids,conf)

#将id与字典对应
df=pd.read_csv("imagenet_class_index.csv")
id_to_labels={}
for idx,row in df.iterrows():
    id_to_labels[row["ID"]]=row["class"]  #英文

for i in range(n):
    class_name=id_to_labels[pred_ids[i]]
    confidence=conf[i]*100
    print(class_name,confidence)

  1. 推理引擎onnx Runtime 进行单张图片推理,调用摄像头获取画面
import torch
import onnxruntime
import numpy as np
import torch.nn.functional as F
import pandas as pd
from PIL import Image,ImageFont,ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
#font=ImageFont.truetype("SimHei",32)


ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                    ])


#调用摄像头   获取摄像头,传入0表示获取系统默认摄像头
import cv2
import time
cap=cv2.VideoCapture(1)
cap.open(0)    #打开摄像头
time.sleep(1)
success,img_bgr=cap.read()
cap.release()                   #关闭摄像头
cv2.destroyAllWindows()       #关闭图像接口
print(img_bgr.shape)

img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # BGR转RGB
img_pil = Image.fromarray(img_rgb)
img_pil.show()

input_img=test_transform(img_pil)
input_tensor=input_img.unsqueeze(0).numpy()

#预测
ort_input={"input":input_tensor}
pred_logits=ort_session.run(["output"],ort_input)[0]
pred_logits=torch.tensor(pred_logits)
pred_softmax=F.softmax(pred_logits,dim=-1)

#topn 类别
n=5
topn=torch.topk(pred_softmax,k=n)
pred_ids=topn[1].cpu().detach().numpy().squeeze()
pred_conf=topn[0].cpu().detach().numpy().squeeze()

df=pd.read_csv("imagenet_class_index.csv")
id_to_labels={}
for idx,row in df.iterrows():
    id_to_labels[row["ID"]]=row["class"]  #英文

for i in range(n):
    class_name=id_to_labels[pred_ids[i]]
    confidence=pred_conf[i]*100
    print(class_name,confidence)

draw = ImageDraw.Draw(img_pil)
# 在图像上写字
for i in range(len(pred_conf)):
    pred_class = id_to_labels[pred_ids[i]]
    text = '{:<15} {:>.3f}'.format(pred_class, pred_conf[i])
    # 文字坐标,中文字符串,字体,rgba颜色
    draw.text((50, 100 + 50 * i), text, fill=(255, 0, 0, 1))
img = np.array(img_pil)  # PIL 转 array
plt.imshow(img)
plt.show()
  1. 推理引擎onnx Runtime 进行图片推理,调用摄像头获取画面
import torch
import onnxruntime
import numpy as np
import torch.nn.functional as F
import pandas as pd
from PIL import Image,ImageFont,ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
import time
import cv2

#加载引擎
ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                    ])
#加载标签
df=pd.read_csv("imagenet_class_index.csv")
id_to_labels={}
for idx,row in df.iterrows():
    id_to_labels[row["ID"]]=row["class"]  #英文



def process_frame(img):
    '''
    输入摄像头拍摄画面bgr-array,输出图像分类预测结果bgr-array
    '''

    # 记录该帧开始处理的时间
    start_time = time.time()

    ## 画面转成 RGB 的 Pillow 格式
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR转RGB
    img_pil = Image.fromarray(img_rgb)  # array 转 PIL

    ## 预处理
    input_img = test_transform(img_pil)  # 预处理
    input_tensor = input_img.unsqueeze(0).numpy()

    ## onnx runtime 预测
    ort_inputs = {'input': input_tensor}  # onnx runtime 输入
    pred_logits = ort_session.run(['output'], ort_inputs)[0]  # onnx runtime 输出
    pred_logits = torch.tensor(pred_logits)
    pred_softmax = F.softmax(pred_logits, dim=1)  # 对 logit 分数做 softmax 运算

    ## 解析top-n预测结果的类别和置信度
    n = 5
    top_n = torch.topk(pred_softmax, n)  # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze()  # 解析出类别
    confs = top_n[0].cpu().detach().numpy().squeeze()  # 解析出置信度

    ## 在图像上写中文
    draw = ImageDraw.Draw(img_pil)
    for i in range(len(confs)):
        pred_class = id_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
        # 文字坐标,中文字符串,字体,rgba颜色
        draw.text((50, 100 + 50 * i), text, fill=(255, 0, 0, 1))
    img = np.array(img_pil)  # PIL 转 array
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # RGB转BGR

    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1 / (end_time - start_time)
    # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
    img = cv2.putText(img, 'FPS  ' + str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4,
                      cv2.LINE_AA)
    return img


import cv2
import time

# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)

# 打开cap
cap.open(0)

# 无限循环,直到break被触发
while cap.isOpened():       #捕获每一帧图片
    # 获取画面
    success, frame = cap.read()
    if not success:
        print('Error')
        break

    ## !!!处理帧函数
    frame = process_frame(frame)

    # 展示处理后的三通道图像
    cv2.imshow('my_window', frame)

    if cv2.waitKey(1) in [ord('q'), 27]:  # 按键盘上的q或esc退出(在英文输入法下)
        break

# 关闭摄像头
cap.release()

# 关闭图像窗口
cv2.destroyAllWindows()

  1. ModelInferBench 推理速度测试
    网址:https://github.com/zhangchaosd/ModelInferBench

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/721917.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud 专题-配置篇(2)

关注微信公众号&#xff1a;IT技术馆 Spring Cloud 是一个基于 Spring Boot 构建的云应用开发工具集&#xff0c;它为开发人员提供了在分布式系统中快速实现常见模式&#xff08;如配置管理、服务发现、断路器等&#xff09;的能力。下面是一个简化的示例&#xff0c;展示如何…

结构思考力:让你的思维更有条理

在这个信息爆炸的时代&#xff0c;如何让自己的思维更有条理&#xff0c;更高效地沟通显得尤为重要。最近读了《结构思考力》一书。今天&#xff0c;我想和大家分享一下读后感&#xff0c;从以下几个方面展开&#xff1a;1. 什么是结构思考力及其重要性&#xff1b;2. 为什么要…

JavaScript日期对象、DOM节点操作(查找、增加、克隆、删除)

目录 1. 日期对象2. DOM节点操作2.1 查找节点2.2 增加节点2.3 克隆节点2.4 删除节点 1. 日期对象 实例化日期对象&#xff1a; 获取当前时间: new Date()获取指定时间: new Date(2023-12-1 17:12:08) 日期对象方法: 方法作用说明getFullYear()获得年份获取四位年份getMonth…

设计可持续数据中心的基本要素

在当今的商业环境中&#xff0c;数据处理中心&#xff08;DPC&#xff09;的重要性日益凸显&#xff0c;它们为海量数据提供稳定的存储、处理和传输服务。随着数据处理中心重要性的不断提升&#xff0c;构建一个具有弹性和可靠性的基础设施变得至关重要。本文将深入探讨构建可持…

Qt做群控系统

群控系统顾名思义&#xff0c;一台设备控制多台机器。首先我们来创造下界面。我们通过QT UI设计界面。设计界面如下&#xff1a; 登录界面&#xff1a; 登录界面分为两种角色&#xff0c;一种是管理员&#xff0c;另一种是超级管理员。两种用户的主界面是不同的。通过选中记住…

获取泛型,泛型擦除,TypeReference 原理分析

说明 author blog.jellyfishmix.com / JellyfishMIX - githubLICENSE GPL-2.0 获取泛型&#xff0c;泛型擦除 下图中示例代码是一个工具类用于生成 csv 文件&#xff0c;需要拿到数据的类型&#xff0c;使用反射感知数据类型的字段&#xff0c;来填充表字段名。可以看到泛型…

LabVIEW开发为什么沟通需求非常重要

在LabVIEW开发项目中&#xff0c;需求沟通是项目成功的基石。以下是需求沟通的重要性及其原因&#xff1a; 明确项目目标&#xff1a; 定义清晰的目标&#xff1a;通过与用户的沟通&#xff0c;可以明确项目的目标和范围&#xff0c;确保开发团队理解用户的实际需求&#xff0c…

JavaFX DatePicker

JavaFX DatePicker允许从给定日历中选择一天。DatePicker控件包含一个带有日期字段和日期选择器的组合框。JavaFX DatePicker控件使用JDK8日期时间API。 import javafx.application.Application; import javafx.scene.Scene; import javafx.scene.control.DatePicker; import j…

SpringBoot的事务注解

SpringBoot的事务注解 在Spring Boot应用中&#xff0c;事务管理是一个关键的部分&#xff0c;尤其是当涉及到数据库操作时。Spring Boot提供了强大的事务管理支持&#xff0c;使得开发人员可以通过简单的注解来控制事务的边界和行为。本文将介绍如何在Spring Boot中使用事务注…

Go - 4.数组和切片

目录 一.引言 二.定义 1.基础定义 2.引申理解 三.实战 1.估算切片的长度与容量 2.切片的切片长度与容量 四.拓展 1.估算切片容量的增长 2.切片底层数组的替换 五.总结 一.引言 本文主要讨论 Go 语言的数组 array 类型和切片 slice 类型。主要从二者的使用方法&…

C++ 65 之 模版的局限性

#include <iostream> #include <cstring> using namespace std;class Students05{ public:string m_name;int m_age;Students05(string name, int age){this->m_name name;this->m_name age;} };// 两个值进行对比的函数 template<typename T> bool …

PCIe学习——重点提纲

PCIe学习-重点提纲 基础知识 计算机架构基础总线系统概述PCI vs PCI-X vs PCIe PCIe 概述 PCIe 的发展历史PCIe 与其他总线的对比PCIe 的优势和应用场景 PCIe 体系结构 PCIe 分层模型 物理层&#xff08;Physical Layer&#xff09;数据链路层&#xff08;Data Link Layer&…

在 KubeSphere 上快速安装和使用 KDP 云原生数据平台

作者简介&#xff1a;金津&#xff0c;智领云高级研发经理&#xff0c;华中科技大学计算机系硕士。加入智领云 8 余年&#xff0c;长期从事云原生、容器化编排领域研发工作&#xff0c;主导了智领云自研的 BDOS 应用云平台、云原生大数据平台 KDP 等产品的开发&#xff0c;并在…

[Java基本语法] 常量变量与运算符

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏:&#x1f355; Collection与数据结构 (92平均质量分)https://blog.csdn.net/2301_80050796/category_12621348.html?spm1001.2014.3001.5482 &#x1f9c0;线程与…

解决Java项目运行时错误:“Command line is too long”

在开发Java应用的过程中&#xff0c;你可能偶尔会遇到“Error running ‘Application’: Command line is too long”的问题。这是因为Java虚拟机&#xff08;JVM&#xff09;在启动时&#xff0c;如果传递给它的类路径&#xff08;classpath&#xff09;过长&#xff0c;超过了…

1080 MOOC期终成绩

solution 输出合格的学生信息&#xff0c;其中所谓合格需要同时满足 在线编程成绩>200总评成绩>60 总评计算方法为 当期中成绩>期末成绩时&#xff0c;总评期中成绩0.4期末成绩0.6否则&#xff0c;总评期末成绩 通过map建立学号和学生信息间的关联&#xff0c;否则直…

YOLOv10改进 | 注意力篇 | YOLOv10引入Polarized Self-Attention注意力机制

1. Polarized Self-Attention介绍 1.1 摘要:像素级回归可能是细粒度计算机视觉任务中最常见的问题,例如估计关键点热图和分割掩模。 这些回归问题非常具有挑战性,特别是因为它们需要在低计算开销的情况下对高分辨率输入/输出的长期依赖性进行建模,以估计高度非线性的像素语…

AI引领数字安全新纪元,下一代身份基础设施如何帮助企业破局?

近日&#xff0c;Open AI正式发布面向未来人机交互范式的全新大模型GPT-4o&#xff0c;具有文本、语音、图像三种模态的理解力&#xff0c;无疑代表着人工智能技术的又一重大跃进。 人工智能技术领域不断创新和发展&#xff0c;为各行各业带来巨大的生产变革和经济增长的同时&…

RockChip Android12 Settings二级菜单

一:概述 本文将针对Android12 Settings的二级菜单System进行说明。 二:System 1、Activity packages/apps/Settings/AndroidManifest.xml <activityandroid:name=".Settings$SystemDashboardActivity"android:label="@string/header_category_system&quo…

关于Hutool的模块使用说明方法

Hutool包含多个模块&#xff0c;每个模块针对特定的功能需求提供支持&#xff1a;• hutool-aop&#xff1a;JDK动态代理封装&#xff0c;为非IOC环境提供切面支持。• hutool-bloomFilter&#xff1a;布隆过滤器&#xff0c;提供基于Hash算法的实现。• hutool-cache&#xff…