onnx 图像分类

参考文章:

【netron】模型可视化工具netron-CSDN博客

Pytorch图像分类模型部署-ONNX Runtime本地终端推理_哔哩哔哩_bilibili

使用netron可视化模型结构

1)使用在线版

浏览器访问:Netron

点击 “Open Model” 按钮,选择要可视化的模型文件即可

2)下载本地版

终端进行安装: pip install netron
安装完成后,在脚本中 调用包 import netron
运行程序 netron.start("model.onnx"), 会自动打开浏览器进行可视化 (最后有例子)

我习惯用 pytorch,但是 netron 对 pytorch 的 .pt 和 .pth 文件不是很友好,所以,我都是先转换为 onnx 格式,再进行可视化,下面举例。

简化可视化模型:

ONNX模型转换及使用指南-CSDN博客

onnx的概念 

训练的归训练,部署的归部署。 onnx的存在极大地降低了部署的难度。

安装配置环境:

安装Pytorch

!pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113

安装onnx

!pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple

安装推理引擎 ONNX Runtime

!pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple

安装其他第三方工具包

!pip install numpy pandas matplotlib tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple

验证安装配置成功

import torch
print('PyTorch 版本', torch.__version__)
import onnx
print('ONNX 版本', onnx.__version__)
import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)

转ONNX模型文件

导入工具包

import torch
from torchvision import models

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

载入ImageNet预训练PyTorch图像分类模型

model = models.resnet18(pretrained=True)
model = model.eval().to(device)

 构造一个输入图像Tensor

x = torch.randn(1, 3, 256, 256).to(device)

 输入Pytorch模型推理预测,获得1000个类别的预测结果

output = model(x)
output.shape

 Pytorch模型转ONNX格式

with torch.no_grad():
    torch.onnx.export(
        model,                       # 要转换的模型
        x,                           # 模型的任意一组输入
        'resnet18_imagenet.onnx',    # 导出的 ONNX 文件名
        opset_version=11,            # ONNX 算子集版本
        input_names=['input'],       # 输入 Tensor 的名称(自己起名字)
        output_names=['output']      # 输出 Tensor 的名称(自己起名字)
    ) 

我只修改了自己代码的这个地方↓

        # '''pth模型文件转为onnx格式'''
        with torch.no_grad():
            torch.onnx.export(
                model,                       # 要转换的模型
                img,                           # 模型的任意一组输入
                'resnet34_classification.onnx',    # 导出的 ONNX 文件名
                opset_version=11,            # ONNX 算子集版本
                input_names=['input'],       # 输入 Tensor 的名称(自己起名字)
                output_names=['output']      # 输出 Tensor 的名称(自己起名字)
            ) 

验证onnx模型导出成功

import onnx

# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_imagenet.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)

print('无报错,onnx模型载入成功')

推理引擎ONNX Runtime部署-预测单张图像

导入工具包

import onnxruntime
import numpy as np
import torch
import torch.nn.functional as F

import pandas as pd

载入 onnx 模型,获取 ONNX Runtime 推理器

ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')

构造随机输入,获取输出结果

x = torch.randn(1, 3, 256, 256).numpy()
x.shape # (1, 3, 256, 256)
# onnx runtime 输入
ort_inputs = {'input': x}

# onnx runtime 输出
ort_output = ort_session.run(['output'], ort_inputs)[0]

注意,输入输出张量的名称需要和 torch.onnx.export 中设置的输入输出名对应

ort_output.shape # (1, 1000)

载入一张真正的测试图像

img_path = 'banana1.jpg'
# 用 pillow 载入
from PIL import Image
img_pil = Image.open(img_path)

预处理函数

from torchvision import transforms

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

运行预处理

input_img = test_transform(img_pil)
input_img.shape # torch.Size([3, 256, 256])
input_tensor = input_img.unsqueeze(0).numpy()
input_tensor.shape # (1, 3, 256, 256)

推理预测

# ONNX Runtime 输入
ort_inputs = {'input': input_tensor}

# ONNX Runtime 输出
pred_logits = ort_session.run(['output'], ort_inputs)[0]
pred_logits = torch.tensor(pred_logits)
pred_logits.shape # torch.Size([1, 1000]) 因为ImageNet有1000类
# 对 logit 分数做 softmax 运算,得到置信度概率
pred_softmax = F.softmax(pred_logits, dim=1) 
pred_softmax.shape # torch.Size([1, 1000])

解析预测结果

# 取置信度最高的前 n 个结果
n = 3

top_n = torch.topk(pred_softmax, n)

top_n
'''torch.return_types.topk(
values=tensor([[9.9669e-01, 2.6005e-03, 3.0254e-04]]),
indices=tensor([[954, 939, 941]]))'''

# 预测类别
pred_ids = top_n.indices.numpy()[0]

pred_ids
'''array([954, 939, 941])'''

用自己转换好的ONNX模型文件类型预测

        # 加载ONNX模型
        onnx_model_path = 'resnet34_classification.onnx'
        onnx_model = onnx.load(onnx_model_path)

        # 创建ONNX Runtime会话
        ort_session = onnxruntime.InferenceSession(onnx_model_path)

        # 将PyTorch Tensor转换为NumPy数组
        input_data = img.numpy()

        # 运行预测
        ort_inputs = {ort_session.get_inputs()[0].name: input_data}
        ort_outputs = ort_session.run(['output'], ort_inputs)

        # 获取预测结果
        output = ort_outputs[0]
        predict = torch.from_numpy(output)
        predict_cla = torch.argmax(predict).item()

        grading_result = int(class_indict[str(predict_cla)])
        end_time = time.time()
        # 打印预测时间
        prediction_time = end_time - start_time
        print("模型预测时间:{:.4f}秒".format(prediction_time))

        return grading_result

注意:相比于直接用pt文件预测图片,输入的图像转成numpy的数据格式

       # 将PyTorch Tensor转换为NumPy数组
        input_data = img.numpy()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/254556.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据库灾难应对:MySQL误删除数据的救赎之道,技巧get起来!之binlog

《数据库灾难应对:MySQL误删除数据的救赎之道,技巧get起来!之binlog》 数据意外删除是数据库管理中常见的问题之一。MySQL作为广泛使用的数据库管理系统,当数据意外删除时,有几种方法可以尝试恢复数据。以下是binlog方…

基于SpringBoot的房屋租赁系统 附源码

基于SpringBoot的房屋租赁系统 附源码 文章目录 基于SpringBoot的房屋租赁系统 附源码 一.引言二.系统设计三.技术架构四.功能实现五.界面展示六.源码获取 一.引言 本文介绍了一个基于SpringBoot的房屋租赁系统。该系统利用SpringBoot框架的优势,实现了用户注册、登…

采购oled屏幕,应注意什么

在采购OLED屏幕时,应注意以下几点: 规格和参数:了解OLED屏幕的规格和参数,包括尺寸、分辨率、亮度、对比度、响应时间等。确保所采购的屏幕符合项目的需求和预期效果。 品质和可靠性:选择具有可靠品质和稳定性的OLED屏…

分布式事务seata使用示例及注意事项

分布式事务seata使用示例及注意事项 示例说明代码调用方(微服务A)服务方(微服务B) 测试测试一 ,seata发挥作用,成功回滚!测试二:处理feignclient接口的返回类型从Integer变成String,…

Ngnix之反向代理、负载均衡、动静分离

目录 1. Ngnix 1.1 Linux系统Ngnix下载安装 1.2 反向代理 正向代理(Forward Proxy): 反向代理(Reverse Proxy): 1.3 负载均衡 1.4 动静分离 1. Ngnix Nginx是一个高性能的开源Web服务器&#xff0…

python识别增强静脉清晰度 opencv-python图像处理案例

一.任务说明 用python实现静脉清晰度提升。 二.代码实现 import cv2 import numpy as npdef enhance_blood_vessels(image):# 调整图像对比度和亮度enhanced_image cv2.convertScaleAbs(image, alpha0.5, beta40)# 应用CLAHE(对比度受限的自适应直方图均衡化&…

Future CompleteFuture

前言 Java8 中的 completeFuture 是对 Future 的扩展实现,主要是为了弥补 Future 没有相应的回调机制的缺陷。 Callable、Runnable、Future、CompletableFuture 之间的关系: Callable,有结果的同步行为,比如做蛋糕,…

python程序打包成exe全流程纪实(windows)

目录 前言准备工作安装python(必须)安装vs平台或conda(非必须) 详细步骤Step1.创建python虚拟环境方法一、裸装(windows下)方法二、借助工具(windows下) Step2.安装打包必须的python包Step3.准备好程序logo(非必须&…

51单片机定时器

51单片机有两个16位定时器,今天复习了一下使用方法,发现当初刚开始学习51单片机时并没有记录,特此今天补上这篇博客。 下面是定时器的总览示意图,看到这个图就能想到定时器怎么设置,怎么开始工作。 第一步&#xff1a…

刷完这个笔记,18K不能再少了....

大家好,最近有不少小伙伴在后台留言,得准备年后面试了,又不知道从何下手!为了帮大家节约时间,特意准备了一份面试相关的资料,内容非常的全面,真的可以好好补一补,希望大家在都能拿到…

EmbedAI:一个可以上传文件训练自己ChatGPT的AI工具,妈妈再也不用担心我的GPT不会回答问题

功能介绍: 个性化定制:提供灵活的训练选项,用户能够通过文件、网站、Notion文档甚至YouTube等多种数据源对ChatGPT进行训练,以满足不同领域和需求的个性化定制。广泛应用场景:ChatGPT支持多种用例,包括智能…

Jmeter吞吐量控制器使用小结

吞吐量控制器(Throughput Controller)场景: 在同一个线程组里, 有10个并发, 7个做A业务, 3个做B业务,要模拟这种场景,可以通过吞吐量模拟器来实现.。 添加吞吐量控制器 用法1: Percent Executions 在一个线程组内分别建立两个吞吐量控制器, 分别放业务A和业务B 吞吐量控制器采…

【算法系列篇】递归、搜索和回溯(三)

文章目录 前言什么是决策树1. 全排列1.1 题目要求1.2 做题思路1.3 代码实现 2. 子集2.1 题目要求2.2 做题思路2.3 代码实现 3. 找出所有子集的异或总和再求和3.1 题目要求3.2 做题思路3.3 代码实现 4. 全排列II4.1 题目要求4.2 做题思路4.3 代码实现 前言 前面我们通过几个题目…

蚂蚁集团5大开源项目获开放原子 “2023快速成长开源项目”

12月16日,在开放原子开源基金会主办的“2023开放原子开发者大会”上,蚂蚁集团主导开源的图数据库TuGraph、时序数据库CeresDB、隐私计算框架隐语SecretFlow、前端框架OpenSumi、数据域大模型开源框架DB-GPT入选“2023快速成长开源项目”。 (图…

Kafka中Ack应答级别和数据去重

在Kafka中,保证数据安全可靠的条件是: 数据完全可靠条件 ACK级别设置为-1 分区副本大于等于2 ISR里应答的最小副本数量大于等于2; Ack应答级别 可靠性总结: acks0,生产者发送过来数据就不管了,可靠性差…

2023年国赛高教杯数学建模D题圈养湖羊的空间利用率解题全过程文档及程序

2023年国赛高教杯数学建模 D题 圈养湖羊的空间利用率 原题再现 规模化的圈养养殖场通常根据牲畜的性别和生长阶段分群饲养,适应不同种类、不同阶段的牲畜对空间的不同要求,以保障牲畜安全和健康;与此同时,也要尽量减少空间闲置所…

人工智能深度学习:探索智能的深邃奥秘

导言 人工智能深度学习作为当今科技领域的明星,正引领着智能时代的浪潮。深度学习和机器学习作为人工智能领域的两大支柱,它们之间的关系既有协同合作,又存在着显著的区别。本文将深入研究深度学习在人工智能领域的角色,以及其在各…

Android Termux安装MySQL数据库并通过内网穿透实现公网远程访问

文章目录 前言1.安装MariaDB2.安装cpolar内网穿透工具3. 创建安全隧道映射mysql4. 公网远程连接5. 固定远程连接地址 前言 Android作为移动设备,尽管最初并非设计为服务器,但是随着技术的进步我们可以将Android配置为生产力工具,变成一个随身…

鸿蒙端H5容器化建设——JSB通信机制建设

1. 背景 2023年鸿蒙开发者大会上,华为宣布为了应对国外技术封锁的潜在风险,2024年的HarmonyOS NEXT版本中将不再兼容Android,并推出鸿蒙系统以及其自研的开发框架,形成开发生态闭环。同时,在更高维度上华为希望将鸿蒙…

GPT-4V被超越?SEED-Bench多模态大模型测评基准更新

📖 技术报告 SEED-Bench-1:https://arxiv.org/abs/2307.16125 SEED-Bench-2:https://arxiv.org/abs/2311.17092 🤗 测评数据 SEED-Bench-1:https://huggingface.co/datasets/AILab-CVC/SEED-Bench SEED-Bench-2&…