yolov8-cls的onnx与tensorrt推理

本文不生产技术,只做技术的搬运工!

前言

  最近需要使用yolov8-cls进行模型分类任务,但是使用ultralytics框架去部署非常不方便,因此打算进行onnx或者tensorrt去部署,查看了很多网上的帖子,并没有发现有完整复现yolov8-cls前处理(不需要后处理)的"轮子",通过自己debug找到并复现了前处理代码,在这里做一下代码记录

环境配置

训练环境

ultralytics-8.2.0:https://github.com/ultralytics/ultralytics/tree/v8.2.0

推理环境

tensorrt:Tensorrt安装及分类模型tensorrt推理,并生成混淆矩阵_linux命令 tensorrt转化模型 trtexec-CSDN博客文章浏览阅读368次,点赞7次,收藏4次。分类模型使用tensorrt推理,包括tensorrt安装及推理_linux命令 tensorrt转化模型 trtexechttps://blog.csdn.net/qq_44908396/article/details/143628108onnxruntime-gpu:1.18.1

python:3.9

torch:1.13.1+cu117

推理代码框架

onnx推理:分类模型onnx推理,并生成混淆矩阵-CSDN博客文章浏览阅读148次。onnx推理分类模型https://blog.csdn.net/qq_44908396/article/details/143507869tensorrt推理:

Tensorrt安装及分类模型tensorrt推理,并生成混淆矩阵_linux命令 tensorrt转化模型 trtexec-CSDN博客文章浏览阅读368次,点赞7次,收藏4次。分类模型使用tensorrt推理,包括tensorrt安装及推理_linux命令 tensorrt转化模型 trtexechttps://blog.csdn.net/qq_44908396/article/details/143628108其实作者的这两篇博客已经搭建了好推理框架,这里我们只需要修改一下前处理代码即可

前处理分析

俗话说授人以鱼不如授人以渔,这里作者讲述一下怎样复现分类的前处理代码,纯搞工程的朋友可以跳过这一段,直接去下一段拿代码

推理demo编写

这里我们需要编写一个推理demo方便接下来的debug

from ultralytics import YOLO

model = YOLO("./ultralytics-8.2.0/runs/classify/train/weights/last.pt")  # load a pretrained model (recommended for training)

results = model("/home/workspace/temp/1111/14.jpg")  # predict on an image

debug

这里我们先找到ultralytics-8.2.0/ultralytics/models/yolo/classify/predict.py文件,同理,如果大家需要检测或者分割的前处理,只要把classify换成detect或者segment即可,前处理代码如下:

    def preprocess(self, img):
        """Converts input image to model-compatible data type."""
        if not isinstance(img, torch.Tensor):
            is_legacy_transform = any(
                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
            )
            if is_legacy_transform:  # to handle legacy transforms
                img = torch.stack([self.transforms(im) for im in img], dim=0)
            else:
                img = torch.stack(
                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
                )
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

我们将断点设置在这个函数体上,然后逐行执行,发现它走的其实是else,也就是说is_legacy_transform变量是false,那么我们只需要复现else内的语句即可,这里边的内容很简单,我们只需要知道self.transforms是个什么东西就可以了,这里我们可以通过debug监视器查看,也可以简单粗暴加打印,作者更喜欢打印

执行我们上面的demo查看打印

这里详细打印了self.transforms的内容及类型,走到这一步我们基本就知道了该如何复现yolov8-cls的前处理了,前处理代码如下:

def read_image(image_path):
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    # 使用 InterpolationMode.BILINEAR 指定双线性插值
    transform = transforms.Compose([
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])
    ])
    # 将图像转换为 PIL 图像并应用变换
    pil_image = Image.fromarray(img)
    normalized_image = transform(pil_image)

    return np.expand_dims(normalized_image.numpy(), axis=0), src

这样我们就完美复现了yolov8-cls的前处理

整体代码

onnx模型转换

yolo classify export model=runs/classify/train/weights/last.pt format="onnx"

onnx推理

import onnxruntime
import numpy as np
import os
import cv2
import argparse
import time
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch
from torchvision import transforms
from PIL import Image


labels = ["0", "1", "2", "3", "4", "5", "6", "7"]

def sigmoid(x):
    """Sigmoid function for a scalar or NumPy array."""
    return 1 / (1 + np.exp(-x))

def getFileList(dir, Filelist, ext=None):
    """
    获取文件夹及其子文件夹中文件列表
    输入 dir:文件夹根目录
    输入 ext: 扩展名
    返回: 文件路径列表
    """
    newDir = dir
    if os.path.isfile(dir):
        if ext is None:
            Filelist.append(dir)
        else:
            if ext in dir:
                Filelist.append(dir)

    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            newDir = os.path.join(dir, s)
            getFileList(newDir, Filelist, ext)

    return Filelist

def read_image(image_path):
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    # 使用 InterpolationMode.BILINEAR 指定双线性插值
    transform = transforms.Compose([
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])
    ])
    # 将图像转换为 PIL 图像并应用变换
    pil_image = Image.fromarray(img)
    normalized_image = transform(pil_image)

    return np.expand_dims(normalized_image.numpy(), axis=0), src


def load_onnx_model(model_path):
    providers = ['CUDAExecutionProvider']  # 使用 GPU
    # providers = ['CPUExecutionProvider']
    #providers = ['TensorrtExecutionProvider']
    session = onnxruntime.InferenceSession(model_path, providers=providers)
    print("ONNX模型已成功加载。")
    return session

def main(image_path, session):
    image,_ = read_image(image_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    pred = session.run([output_name], {input_name: image})[0]
    pred = np.squeeze(pred)
    pred = pred.tolist()
    return pred.index(max(pred)), max(pred), labels[pred.index(max(pred))]

def plot_confusion_matrix(y_true, y_pred, labels):
    """
    绘制混淆矩阵
    输入 y_true: 真实标签
    输入 y_pred: 预测标签
    输入 labels: 标签名称
    """
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images_path', type=str, default="/home/workspace/temp/test", help='images_path')
    parser.add_argument('--model_path', type=str, default="/home/workspace/temp/last.onnx", help='model_path')
    args = parser.parse_args()
    img_list = []
    img_list = getFileList(args.images_path, img_list)
    count = 0
    session = load_onnx_model(args.model_path)
    start = time.time()
    y_true = []
    y_pred = []
    count_time = 0
    for img in img_list:
        #true_label = int(img.split('/')[-2].split('-')[0])
        true_label = img.split('/')[-2]
        start_1 = time.time()
        predicted_index, score, label = main(img, session)
        print(img,label, score)
        count_time += time.time() - start_1
        y_true.append(true_label)
        #y_pred.append(predicted_index)
        y_pred.append(label)
        if label == true_label:
            count += 1
        # else:
        #     dst_path = img.replace('test', 'test_out')
        #     dst_dir = os.path.dirname(dst_path)
        #     if not os.path.exists(dst_dir):
        #         os.makedirs(dst_dir)
        #     shutil.copy(img, dst_path.replace('.jpg', "-" + label + '.jpg'))

    accuracy = count / len(img_list) * 100
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Correct predictions: {count}, Total images: {len(img_list)}")
    print(f"Time taken: {time.time() - start:.6f} seconds")
    print("推理", len(img_list), "张图像用时", count_time)
    # 绘制混淆矩阵
    plot_confusion_matrix(y_true, y_pred, labels)

tensorrt模型转换

/home/tensorrt8.6/TensorRT-8.6.1.6/bin/trtexec --onnx=last.onnx --saveEngine=last-fp16.engine --workspace=3000 --verbose --fp16

tensorrt代码

import os
import cv2
import tensorrt as trt
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import time
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision import transforms
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

def plot_confusion_matrix(y_true, y_pred, labels):
    """
    绘制混淆矩阵
    输入 y_true: 真实标签
    输入 y_pred: 预测标签
    输入 labels: 标签名称
    """
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

def sigmoid(x):
    """Sigmoid function for a scalar or NumPy array."""
    return 1 / (1 + np.exp(-x))

def getFileList(dir, Filelist, ext=None):
    """
    获取文件夹及其子文件夹中文件列表
    输入 dir:文件夹根目录
    输入 ext: 扩展名
    返回: 文件路径列表
    """
    newDir = dir
    if os.path.isfile(dir):
        if ext is None:
            Filelist.append(dir)
        else:
            if ext in dir:
                Filelist.append(dir)

    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            newDir = os.path.join(dir, s)
            getFileList(newDir, Filelist, ext)

    return Filelist

def read_image(image_path):
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    # 使用 InterpolationMode.BILINEAR 指定双线性插值
    transform = transforms.Compose([
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])
    ])
    # 将图像转换为 PIL 图像并应用变换
    pil_image = Image.fromarray(img)
    normalized_image = transform(pil_image)

    return np.expand_dims(normalized_image.numpy(), axis=0), src

def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

def create_context(engine):
    return engine.create_execution_context()

def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()

    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append({'host': host_mem, 'device': device_mem})
        else:
            outputs.append({'host': host_mem, 'device': device_mem})
    return inputs, outputs, bindings, stream

def infer(context, inputs, outputs, bindings, stream, input_data):
    # Transfer input data to the GPU
    [np.copyto(i['host'], input_data.ravel().astype(np.float32)) for i in inputs]
    [cuda.memcpy_htod_async(i['device'], i['host'], stream) for i in inputs]

    # Execute the model
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)

    # Transfer predictions back from the GPU
    [cuda.memcpy_dtoh_async(o['host'], o['device'], stream) for o in outputs]

    # Synchronize the stream
    stream.synchronize()

    # Return the host output
    return [o['host'] for o in outputs]

def main(image_path, context,inputs, outputs, bindings, stream):
    input_data, src = read_image(image_path)
    pred = infer(context, inputs, outputs, bindings, stream, input_data)
    pred = np.squeeze(pred)
    pred = pred.tolist()
    return pred.index(max(pred)), max(pred), labels[pred.index(max(pred))]

if __name__ == '__main__':
    image_dir = r"/home/workspace/temp/test"
    engine_file_path = '/home/workspace/temp/last-fp16.engine'
    labels = ["0", "1", "2", "3", "4", "5", "6", "7"]
    engine = load_engine(engine_file_path)
    context = create_context(engine)
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    img_list = []
    img_list = getFileList(image_dir, img_list)
    count = 0
    start = time.time()
    y_true = []
    y_pred = []
    count_time = 0
    for img in img_list:
        # true_label = int(img.split('/')[-2].split('-')[0])
        true_label = img.split('/')[-2]
        start_1 = time.time()
        predicted_index, score, label = main(img, context,inputs, outputs, bindings, stream)
        count_time += time.time() - start_1
        y_true.append(true_label)
        # y_pred.append(predicted_index)
        y_pred.append(label)
        if label == true_label:
            count += 1
        # else:
        #     dst_path = img.replace('test', 'test_out')
        #     dst_dir = os.path.dirname(dst_path)
        #     if not os.path.exists(dst_dir):
        #         os.makedirs(dst_dir)
        #     shutil.copy(img, dst_path.replace('.jpg', "-" + label + '.jpg'))

    accuracy = count / len(img_list) * 100
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Correct predictions: {count}, Total images: {len(img_list)}")
    print(f"Time taken: {time.time() - start:.6f} seconds")
    print("推理", len(img_list), "张图像用时", count_time)
    # 绘制混淆矩阵
    plot_confusion_matrix(y_true, y_pred, labels)

注意事项

代码中虽然编写了sigmoid函数,但并为使用,主要是因为通过netron查看onnx模型时发现,其输出已经包含了softmax层,因此不需要再进行额外的分类输出函数

 如果大家使用pycharm执行tensorrt推理有可能遇到找不到链接库的问题,可以参考作者的这一篇博客

pycharm解决ImportError: libnvinfer.so.8: cannot open shared object file: No such file or directory-CSDN博客文章浏览阅读138次。解决pycharm无法识别tensorrt系统环境变量的问题_libnvinfer.so.8: cannot open shared object file: no such file or directoryhttps://blog.csdn.net/qq_44908396/article/details/143628859

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/915617.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Acrobat Pro DC 2023(pdf免费转化word)

所在位置 通过网盘分享的文件:Acrobat Pro DC 2023(64bit).tar 链接: https://pan.baidu.com/s/1_m8TT1rHTtp5YnU8F0QGXQ 提取码: 1234 --来自百度网盘超级会员v4的分享 安装流程 打开安装所在位置 进入安装程序 找到安装程序 进入后点击自定义安装,这里…

VMware和CentOS 7.6 Linux操作系统的安装使用

1. 安装VMware 安装VMware之前,有些电脑是需要去BIOS里修改设置开启cpu虚拟化设备支持才能安装。如果运气不好在安装过程中安装不了的话就自行百度吧。 打开 VMware 的官网: https://www.vmware.com/ 点击 product,往下滑找到 see desktop hypeerviso…

手把手教你:如何从零开始实施一套OA办公系统!

很多朋友都吐槽说公司的各种各样的信息啊文件啊越积越多,导致管理起来越来越麻烦。早就跟大家说过,尤其是在提高工作效率、优化资源配置和促进信息共享方面,OA(办公自动化)系统发挥着不可替代的作用,早安排…

网页web无插件播放器EasyPlayer.js播放器返回错误 Incorrect response MIME type 的解决方式

在使用EasyPlayer.js播放器进行视频流播放时,尤其是在SpringBoot环境中部署静态资源时,可能会遇到“Incorrect response MIME type”的错误,这通常与WebAssembly(WASM)文件的MIME类型配置有关。 WASM是一种新的代码格式…

element-plus <el-date-picker>日期选择器踩坑!!!!

我怎么一上午踩两个坑&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff08;大声bb&#xff09; 原来的vue2老项目是这样写的 <el-form-item label"时间" prop"time"><el-date-pickerv-model"addForm.time"typ…

# 如何查看 Ubuntu 版本?

如何查看 Ubuntu 版本&#xff1f; 要查看‌Ubuntu版本&#xff0c;你可以通过以下几种方法&#xff1a; 1. 使用‌lsb_release 命令‌查看 使用 lsb_release -a 命令可以查看Ubuntu的详细版本信息&#xff0c;包括发行版ID、版本号以及版本代号。‌ ‌### 2、查看 /etc/is…

常用的生物医药专利查询数据库及网站(很全!)

生物医药专利信息检索是药物研发前期不可或缺的一步&#xff0c;通过对国内外生物医药专利网站信息查询&#xff0c;可详细了解其专利技术&#xff0c;进而有效降低药物研发过程中的风险。 目前主要使用的生物医药专利查询网站分为两大类&#xff0c;一个是免费生物医药专利查询…

第四节-OSI-网络层

数据链路层&#xff1a;二层--MAC地址精确定位 Ethernet 2&#xff1a; 报头长度&#xff1a;18B 携带的参数&#xff1a;D MAC /S MAC/TYPE(标识上层协议)/FCS 802.3 报头长度&#xff1a;26B 携带的参数&#xff1a;D MAC/S MAC/LLC(标识上层协议)/SNAP&#xff08;标识…

Python数据分析NumPy和pandas(二十七、数据可视化 matplotlib API 入门)

数据可视化或者数据绘图是数据分析中最重要的任务之一&#xff0c;是数据探索过程的一部分&#xff0c;数据可视化可以帮助我们识别异常值、识别出需要的数据转换以及为模型生成提供思考依据。对于Web开发人员&#xff0c;构建基于Web的数据可视化显示也是一种重要的方式。Pyth…

【前端】深入浅出 - TypeScript 的详细讲解

TypeScript 是一种静态类型编程语言&#xff0c;它是 JavaScript 的超集&#xff0c;添加了类型系统和编译时检查。TypeScript 的主要目标是提高大型项目的开发效率和可维护性。本文将详细介绍 TypeScript 的核心概念、语法、类型系统、高级特性以及最佳实践。 1. TypeScript…

查询DBA_FREE_SPACE缓慢问题

这个是一个常见的问题&#xff0c;理论上应该也算是一个bug&#xff0c;在oracle10g&#xff0c;到19c&#xff0c;我都曾经遇到过&#xff1b;今天在给两套新建的19C RAC添加监控脚本时&#xff0c;又发现了这个问题&#xff0c;在这里记录一下。 Symptoms 环境&#xff1a;…

The Internals of PostgreSQL 翻译版 持续更新...

为了方便自己快速学习&#xff0c;整理了翻译版本&#xff0c;目前翻译的还不完善&#xff0c;后续会边学习边完善。 文档用于自己快速参考&#xff0c;会持续修正&#xff0c;能力有限,无法确保正确!!! 《The Internals of PostgreSQL 》 不是 《 PostgreSQL14 Internals 》…

机器学习 ---模型评估、选择与验证(1)

目录 前言 一、为什么要有训练集与测试集 1、为什么要有训练集与测试集 2、如何划分训练集与测试集 二、欠拟合与过拟合 1、什么是欠拟合与欠拟合的原因 2、什么是过拟合与过拟合的原因 一些解决模型过拟合和欠拟合问题的常见方法&#xff1a; 解决过拟合问题&#…

一文简单了解Android中的input流程

在 Android 中&#xff0c;输入事件&#xff08;例如触摸、按键&#xff09;从硬件传递到应用程序并最终由应用层消费。整个过程涉及多个系统层次&#xff0c;包括硬件层、Linux 内核、Native 层、Framework 层和应用层。我们将深入解析这一流程&#xff0c;并结合代码逐步了解…

【JavaEE初阶 — 多线程】单例模式 & 指令重排序问题

目录 1. 单例模式 (1) 饿汉模式 (2) 懒汉模式 1. 单线程版本 2. 多线程版本 2. 解决懒汉模式产生的线程安全问题 (1) 产生线程安全的原因 (2) 解决线程安全问题 1. 通过加锁让读写操作紧密执行 方法一 方法二 2. 处理加锁引入的新问题 问题描述 …

二叉树搜索树(下)

二叉树搜索树&#xff08;下&#xff09; 二叉搜索树key和key/value使用场景 key搜索场景 只有key作为关键码&#xff0c;结构中只需要存储key即可&#xff0c;关键码即为需要搜索到的值&#xff0c;搜索场景只需要判断 key在不在。key的搜索场景实现的二叉树搜索树支持增删查…

Web项目版本更新及时通知

背景 单页应用&#xff0c;项目更新时&#xff0c;部分用户会出更新不及时&#xff0c;导致异常的问题。 技术方案 给出版本号&#xff0c;项目每次更新时通知用户&#xff0c;版本已经更新需要刷新页面。 版本号更新方案版本号变更后通知用户哪些用户需要通知&#xff1f;…

D64【python 接口自动化学习】- python基础之数据库

day64 SQL-DQL-基础查询 学习日期&#xff1a;20241110 学习目标&#xff1a;MySQL数据库-- 133 SQL-DQL-基础查询 学习笔记&#xff1a; 基础数据查询 基础数据查询-过滤 总结 基础查询的语法&#xff1a;select 字段列表|* from 表过滤查询的语法&#xff1a;select 字段…

Unity插件-Smart Inspector 免费的,接近虚幻引擎的蓝图Tab管理

习惯了虚幻的一张蓝图&#xff0c;关联所有Tab &#xff08;才发现Unity&#xff0c;的Component一直被人吐槽&#xff0c;但实际上是&#xff1a;本身结构Unity 的GameObject-Comp结构&#xff0c;是好的不能再好了&#xff0c;只是配上 smart Inspector就更清晰了&#xff0…

2024 年Postman 如何安装汉化中文版?

2024 年 Postman 的汉化中文版安装教程