YOLOv5模型转ONNX,ONNX转TensorRT Engine

系列文章目录

第一章 YOLOv5模型训练集标注、训练流程
第二章 YOLOv5模型转ONNX,ONNX转TensorRT Engine
第三章 TensorRT量化

文章目录

  • 系列文章目录
  • 前言
  • 一、yolov5模型导出ONNX
    • 1.1 工作机制
    • 1.2 修改yolov5代码,输出ONNX
  • 二、TensorRT部署
    • 2.1 模型部署
    • 2.2 模型推理
  • 总结


前言

学习笔记–恩培老师


一、yolov5模型导出ONNX

1.1 工作机制

使用tensort deconde plugin 来替代yolov5代码中的deconde操作,需要修改yolov5代码导出onnx模型的部分。

在这里插入图片描述

1.2 修改yolov5代码,输出ONNX

批量修改

#将patch复制到yolov5文件夹
cp export.patch yolov5/
#进入yolov5文件夹
cd yolov5/
#应用patch
git am export.patch

安装需要依赖

pip install seaborn
pip install onnx-graphsurgeon
pip install opencv-python==4.5.5.64
pip install onnx-simplifier==0.3.10

apt update
apt install -y libgl1-mesa-glx

安装完成后,准备训练好的模型文件,默认为yolov5s.pt,然后执行下列代码,生成Onnx文件。

安装不上onnx-graphsurgeon,使用下面的命令再次安装

pip install nvidia-pyindex
pip install onnx-graphsurgeon
python export.py --weights weights/yolov5s_person.pt --include onnx --simplify

这里的yolov5s_person.pt文件就是我们刚刚训练好的best.pt复制过来的。
可视化模型工具

pip install netron
netron ./weights/yolov5s_person.onnx

二、TensorRT部署

使用TensorRT docker容器:

docker run --gpus all -it --name env_trt -v ${pwd}: /app nvcr.io/nvidia/tensorrt:22.08-py3

2.1 模型部署

推荐博客TensorRT部署流程
yolov5转到onnx后进行模型的构建并保存序列化后的模型为文件。

  • 模型导出成 ONNX 格式。
  • 把 ONNX 格式模型输入给 TensorRT,并指定优化参数。
  • 使用 TensorRT 优化得到 TensorRT Engine。
  • 使用 TensorRT Engine 进行 inference。
  1. 创建builder
    这里使用了std::unqique_ptr,只能指针包装我们的builder,实现自动管理指针生命周期。
//**************1.创建builder***************//

auto builder = std::unique_ptr<nvinferl::IBuilder>
(nvinfer1::IBuilder::createInferBuilder(sampelr::gLogger.getTRTLogger())));
if (!builder)
{
    std::cerr<<"Failed to create builder"<<std::endl;
    return -1;
}

  1. 创建网络。这里指定了explicitBatch

  2. 创建onnxparser,用于解析onnx文件

4.配置网络参数。
我们需要告诉tensorrt我们最终运行时,输入图像的范围,batch size范围。

#include <iostream>
#include "NvInfer.h"

int main() {
    // Create a logger
    nvinfer1::ILogger* logger = new nvinfer1::ILogger();

    // Create a builder
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(*logger);
    if (!builder) {
        std::cerr << "Failed to create builder" << std::endl;
        return -1;
    }

    // Set up builder configurations (optional)
    builder->setMaxBatchSize(1);
    builder->setMaxWorkspaceSize(1 << 30); // 1GB

    // Create a network definition
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);

    // ... Add layers and define the network ...

    // Build the engine
    nvinfer1::ICudaEngine* engine = builder->buildCudaEngine(*network);

    if (!engine) {
        std::cerr << "Failed to build engine" << std::endl;
        return -1;
    }

    // Clean up
    network->destroy();
    engine->destroy();
    builder->destroy();
    logger->log(nvinfer1::ILogger::Severity::kINFO, "Finished successfully!");

    delete logger;

    return 0;
}


2.2 模型推理

推理过程

  • 读取模型文件
  • 对输入进行预处理
  • 读取模型输出
  • 后处理(NMS)

1.创建运行时
2.反序列化模型得到推理Engine
3.创建执行上下文
4.创建输入输出缓冲区管理器
5.读取视频文件,并逐帧读取图像送入模型,进行推理

#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <chrono>
#include <opencv2/opencv.hpp>
#include "NvInfer.h"

int main() {
    // Create a logger
    nvinfer1::ILogger* logger = new nvinfer1::ILogger();
    
    // Create a runtime
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(*logger);
    if (!runtime) {
        std::cerr << "Failed to create runtime" << std::endl;
        return -1;
    }

    // Deserialize the engine
    const std::string engineFilePath = "path/to/your/engine.plan";
    std::ifstream engineFile(engineFilePath, std::ios::binary);
    if (!engineFile) {
        std::cerr << "Failed to open engine file" << std::endl;
        return -1;
    }
    engineFile.seekg(0, engineFile.end);
    const int engineSize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);
    char* engineData = new char[engineSize];
    engineFile.read(engineData, engineSize);
    engineFile.close();

    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engineData, engineSize, nullptr);
    if (!engine) {
        std::cerr << "Failed to deserialize engine" << std::endl;
        return -1;
    }

    delete[] engineData;

    // Create an execution context
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();
    if (!context) {
        std::cerr << "Failed to create execution context" << std::endl;
        return -1;
    }

    // Create input and output buffer managers
    const int maxBatchSize = engine->getMaxBatchSize();
    nvinfer1::Dims inputDims = engine->getBindingDimensions(0);
    const int inputSize = inputDims.d[1] * inputDims.d[2] * inputDims.d[3];
    nvinfer1::Dims outputDims = engine->getBindingDimensions(1);
    const int outputSize = outputDims.d[1];

    nvinfer1::IHostMemory* inputMemory = engine->createHostMemory(engine->getBindingDataType(0), maxBatchSize * inputSize);
    void* inputBuffer = inputMemory->data();

    nvinfer1::IHostMemory* outputMemory = engine->createHostMemory(engine->getBindingDataType(1), maxBatchSize * outputSize);
    void* outputBuffer = outputMemory->data();

    // Open the video file
    const std::string videoFilePath = "path/to/your/video.mp4";
    cv::VideoCapture cap(videoFilePath);
    if (!cap.isOpened()) {
        std::cerr << "Failed to open video file" << std::endl;
        return -1;
    }

    // Loop through all frames
    cv::Mat frame;
    int frameCount = 0;
    auto startTime = std::chrono::high_resolution_clock::now();

    while (true) {
        // Read the next frame
        cap >> frame;
        if (frame.empty()) {
            break;
        }

        // Prepare the input data
        cv::Mat resizedFrame;
        cv::resize(frame, resizedFrame, cv::Size(inputDims.d[3], inputDims.d[2]));
        float* inputData = static_cast<float*>(inputBuffer) + frameCount * inputSize;

        const int channelSize = inputDims.d[2] * inputDims.d[3];
        for (int c = 0; c < inputDims.d[1]; ++c) {
            for (int h = 0; h < inputDims.d[2]; ++h) {
                for (int w = 0; w < inputDims.d[3]; ++w) {
                    const float pixel = resizedFrame.at<cv::Vec3b>(h, w)[c] / 255.0f;
                    inputData[c * channelSize + h * inputDims.d[3] + w] = pixel;
                }
            }
        }

        // Run inference
        context->executeV2(&inputBuffer, &outputBuffer);

        // Process the output data
        float* outputData = static_cast<float*>(outputBuffer) + frameCount * outputSize;

        // ... Process the output data ...

        ++frameCount;
    }

    // Measure and print the inference time
    auto endTime = std::chrono::high_resolution_clock::now();
    auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
    std::cout << "Inference time: " << elapsedTime.count() << "ms" << std::endl;

    // Clean up
    inputMemory->destroy();
    outputMemory->destroy();
    context->destroy();
    engine->destroy();
    runtime->destroy();
    logger->log(nvinfer1::ILogger::Severity::kINFO, "Finished successfully!");

    delete logger;

    return 0;
}

使用cmake进行构建,cmake相关知识可看cmake学习笔记

cmake -S .-B build
cmake --build build
./build/build
./build/build ./weights/yolo5s_person.onnx
#执行推理
./build/runtime

视频文件

./weights/yolov5.engine ./media/c3.mp4

总结

接下来是了解TensorRT插件,Int8量化流程。

推荐视频链接:https://www.bilibili.com/video/BV1jj411Z7wG/?spm_id_from=333.337.search-card.all.click&vd_source=ce674108fa2e19e5322d710724193487

推荐链接:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/cookbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/319620.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习每日小知识】Computer Vision 计算机视觉

计算机视觉是人工智能的一个领域&#xff0c;涉及算法和系统的开发&#xff0c;使计算机能够解释、理解和分析来自周围世界的视觉数据。这包括从静态图像到视频流甚至 3D 环境的一切。 使用对象检测和特征提取等方法&#xff0c;计算机视觉本质上需要从视觉输入中提取有用信息…

TensorRT(C++)基础代码解析

TensorRT(C)基础代码解析 文章目录 TensorRT(C)基础代码解析前言一、TensorRT工作流程二、C API2.1 构建阶段2.1.1 创建builder2.1.2 创建网络定义2.1.3 定义网络结构2.1.4 定义网络输入输出2.1.5 配置参数2.1.6 生成Engine2.1.7 保存为模型文件2.1.8 释放资源 2.2 运行期2.2.1…

STM32的USB设备库

适用范围&#xff1a;“on the STM32F10xxx,STM32F37xxx, STM32F30xxx and STM32L15xxx devices.” STM32_USB-FS-Device_Lib_V4.0.0.rar&#xff08;访问密码&#xff1a;1666&#xff09;https://url48.ctfile.com/f/33868548-1000799917-a5409d?p1666 适用范围&#xff1…

服务器配置SSL证书到nginx基于Fdfs存储服务器或者直接阿里云绑定SSL

1.如果用FDFS存储服务器内置nginx设置SSL证书 1.验证当前nginx是否存在 http_ssl_modulehttp_ssl_module模块 如果存在直接配置就行 server {listen 80 default backlog2048;listen 443 ssl; server_name 域名; ssl_certificate /usr/local/nginx_fdfs/ssl/xxxx.top.crt; ssl…

【C++】内联函数

前言 在C语言中&#xff0c;我们学习过宏的用法。宏通常被用于进行简单的文本替换来执行一系列的操作&#xff0c;比如一些简单的运算。使用宏可以避免函数调用时建立栈帧的开销&#xff0c;提高程序的性能。我们首先来写一个实现加法功能的宏&#xff1a; #define ADD(x, y)…

5、C语言:结构

结构 结构的基本知识结构与函数传递结构 结构数组、指向结构的指针自引用结构&#xff08;二叉树&#xff09;表查找类型定义&#xff08;typedef&#xff09;联合位字段 结构也是一种数据类型。类似于int、char、double、float等。 结构是一个或多个变量的集合&#xff0c;这些…

Linux系统——远程访问及控制

目录 一、OpenSSH服务器 1.SSH&#xff08;Secure Shell&#xff09;协议 2.OpenSSH 2.SSH原理 2.1公钥传输原理 2.2加密原理 &#xff08;1&#xff09;对称加密 &#xff08;2&#xff09;非对称加密 2.3远程登录 2.3.1延伸 2.3.2登录用户 3.SSH格式及选项 3.1延…

node(express.js创建项目)+连接mysql数据库

1.node npm的安装 2.express的安装 全局安装:npm install express -gnpm install -g express-generator// ps: 4.0版本把generator分离出来了&#xff0c;需要单独安装3.创建express项目 express 项目名称 cd 项目名称 npm install npm start4.项目中安装数据库 npm install…

Python读取log文件报错“UnicodeDecodeError”

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 问题描述&#xff1a; 写了一个读取log文件的Python脚本&#xff1a; # -*- coding:utf-8 -*- import os import numpy as np …

第01章_Java语言概述拓展练习(为什么要设置path?)

文章目录 第01章_Java语言概述拓展练习1、System.out.println()和System.out.print()有什么区别&#xff1f;2、一个".java"源文件中是否可以包括多个类&#xff1f;有什么限制&#xff1f;3、Something类的文件名叫OtherThing.java是否可以&#xff1f;4、为什么要设…

【Maven】009-Maven 简单父子工程搭建

【Maven】009-Maven 简单父子工程搭建 文章目录 【Maven】009-Maven 简单父子工程搭建一、需求说明1、结构2、第三方库 二、工程搭建1、父工程第一步&#xff1a;创建父工程第二步&#xff1a;引入公共依赖 lombok 和管理 hutool 依赖版本 2、公共子模块第一步&#xff1a;创建…

服务器出现500、502、503错误的原因以及解决方法

服务器我们经常会遇到访问不了的情况有的时候是因为我们服务器被入侵了所以访问不了&#xff0c;有的时候是因为出现了服务器配置问题&#xff0c;或者软硬件出现问题导致的无法访问的问题&#xff0c;这时候会出现500、502、503等错误代码。基于以上问题我们第一步可以先重启服…

MySQL核心SQL

一.结构化查询语言 SQL是结构化查询语言&#xff08;Structure Query Language&#xff09;&#xff0c;它是关系型数据库的通用语言。 SQL 主要可以划分为以下 3 个类别&#xff1a; DDL&#xff08;Data Definition Languages&#xff09;语句 数据定义语言&#xff0c;这…

vue2配置教程

5.12.3 Vue Cli 文档地址: https://cli.vuejs.org/zh/ IDEA 打开项目&#xff0c;运行项目

【年终总结】回首2023的精彩,迈向2024的未来

文章目录 一、历历在目&#xff0c;回首成长之路&#x1f3c3;‍1、坚持输出&#xff0c;分享所学2、积土成山&#xff0c;突破万粉3、不断精进&#xff0c;向外涉足 二、雅俗共赏&#xff0c;阅历百般美好&#x1f3bb;1、音乐之声&#xff0c;声声入耳2、书海遨游&#xff0c…

select...for update锁详解

select…for update锁详解 select…for update的作用就是&#xff1a;如果A事务中执行了select…for update&#xff0c;那么在其提交或回滚事务之前&#xff0c;B&#xff0c;C&#xff0c;D…事务是无法操作&#xff08;写&#xff09;A事务select…for update所命中的数据的…

php时间函数date()、getdate()、time()

目录 1. 时区修改 2. date() 3. getdate() 4. time() 1. 时区修改 位于东八区&#xff0c;修改php.ini 。date.timezone Asia/Shanghai 2. date() 获取时间函数 <?php header("Content-Type: text/html; charsetutf-8");$d date(H:i:s);//小时H&#xf…

linux驱动(五):framebuffer

本文主要探讨210的framebuffer驱动知识。 frameBuffer 用户态进程直接调用显卡写屏,framebuffer接口是给用户态进程用于写屏 framebuffer设备文件为fbx 清屏:dd if/dev/zero of/dev/fbx 清屏&#xff1a;$ dd if/dev/zero of/dev/fb0 bs1024 …

2023 IoTDB Summit:天谋科技高级开发工程师苏宇荣《汇其流:如何用 IoTDB 流处理框架玩转端边云融合》...

12 月 3 日&#xff0c;2023 IoTDB 用户大会在北京成功举行&#xff0c;收获强烈反响。本次峰会汇集了超 20 位大咖嘉宾带来工业互联网行业、技术、应用方向的精彩议题&#xff0c;多位学术泰斗、企业代表、开发者&#xff0c;深度分享了工业物联网时序数据库 IoTDB 的技术创新…

一键调整播放倍速,调整播放倍速的软件

你是否曾因为长时间的视频而感到厌烦&#xff1f;或者因为视频播放得太快而错过了一些重要内容&#xff1f;现在&#xff0c;有了我们的【媒体梦工厂】&#xff0c;这些问题都将得到完美解决。不论你是想快速浏览长视频&#xff0c;还是想让视频慢下来以便更好地学习或欣赏&…