tensorRT C++使用pt转engine模型进行推理

目录

  • 1. 前言
  • 2. 模型转换
  • 3. 修改Binding
  • 4. 修改后处理

1. 前言

本文不讲tensorRT的推理流程,因为这种文章很多,这里着重讲从标准yolov5的tensort推理代码(模型转pt->wts->engine)改造成TPH-yolov5(pt->onnx->engine)的过程。

2. 模型转换

请查看上一篇文章https://blog.csdn.net/wyw0000/article/details/139737473?spm=1001.2014.3001.5502

3. 修改Binding

如果不修改Binding,会报下图中的错误。
在这里插入图片描述
该问题是由于Binding有多个,而代码中只申请了input和output,那么如何查看engine模型有几个Bingding呢?代码如下:

int get_model_info(const string& model_path) {
    // 创建 logger
    Logger gLogger;

    // 从文件中读取 engine
    std::ifstream engineFile(model_path, std::ios::binary);
    if (!engineFile) {
        std::cerr << "Failed to open engine file." << std::endl;
        return -1;
    }

    engineFile.seekg(0, engineFile.end);
    long int fsize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);

    std::vector<char> engineData(fsize);
    engineFile.read(engineData.data(), fsize);
    if (!engineFile) {
        std::cerr << "Failed to read engine file." << std::endl;
        return -1;
    }

    // 反序列化 engine
    auto runtime = nvinfer1::createInferRuntime(gLogger);
    auto engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);

    // 获取并打印输入和输出绑定信息
    for (int i = 0; i < engine->getNbBindings(); ++i) {
        nvinfer1::Dims dims = engine->getBindingDimensions(i);
        nvinfer1::DataType type = engine->getBindingDataType(i);

        std::cout << "Binding " << i << " (" << engine->getBindingName(i) << "):" << std::endl;
        std::cout << "  Type: " << (int)type << std::endl;
        std::cout << "  Dimensions: ";
        for (int j = 0; j < dims.nbDims; ++j) {
            std::cout << (j ? "x" : "") << dims.d[j];
        }
        std::cout << std::endl;
        std::cout << "  Is Input: " << (engine->bindingIsInput(i) ? "Yes" : "No") << std::endl;
    }

    // 清理资源
    engine->destroy();
    runtime->destroy();

    return 0;
}

下图是我的tph-yolov5的Binding,可以看到有5个Binding,因此在doInference推理之前,要给5个Binding都申请空间,同时要注意获取BindingIndex时,名称和dimension与查询出来的对应。
在这里插入图片描述

//for tph-yolov5
    int Sigmoid_921_index = trt->engine->getBindingIndex("onnx::Sigmoid_921");
    int Sigmoid_1183_index = trt->engine->getBindingIndex("onnx::Sigmoid_1183");
    int Sigmoid_1367_index = trt->engine->getBindingIndex("onnx::Sigmoid_1367");
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_921_index], BATCH_SIZE * 3 * 192 * 192 * 7 * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1183_index], BATCH_SIZE * 3 * 96 * 96 * 7 * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1367_index], BATCH_SIZE * 3 * 48 * 48 * 7 * sizeof(float)));

    trt->data = new float[BATCH_SIZE * 3 * INPUT_H * INPUT_W];
    trt->prob = new float[BATCH_SIZE * OUTPUT_SIZE];
    trt->inputIndex = trt->engine->getBindingIndex(INPUT_BLOB_NAME);
    trt->outputIndex = trt->engine->getBindingIndex(OUTPUT_BLOB_NAME);

还有推理的部分也要做修改,原来只有input和output两个Binding时,那么输出是buffers[1],而目前是有5个Binding那么输出就变成了buffers[4]

void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output, int batchSize) {
    // infer on the batch asynchronously, and DMA output back to host
    context.enqueueV2(buffers, stream, nullptr);
    //CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaMemcpyAsync(output, buffers[4], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    cudaStreamSynchronize(stream);
}

4. 修改后处理

之前的yolov5推理代码是将pt模型转为wts再转为engine的,输出维度只有一维,而TPH输出维度为145152*7,因此要对原来的后处理代码进行修改。

struct BoundingBox {
    //bbox[0],bbox[1],bbox[2],bbox[3],conf, class_id
    float x1, y1, x2, y2, score, index;
};

float iou(const BoundingBox&  box1, const BoundingBox& box2) {
	float max_x = max(box1.x1, box2.x1);  // 找出左上角坐标哪个大
	float min_x = min(box1.x2, box2.x2);  // 找出右上角坐标哪个小
	float max_y = max(box1.y1, box2.y1);
	float min_y = min(box1.y2, box2.y2);
	if (min_x <= max_x || min_y <= max_y) // 如果没有重叠
		return 0;
	float over_area = (min_x - max_x) * (min_y - max_y);  // 计算重叠面积
	float area_a = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
	float area_b = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
	float iou = over_area / (area_a + area_b - over_area);
	return iou;
}

std::vector<BoundingBox> nonMaximumSuppression(std::vector<std::vector<float>>& boxes, float overlapThreshold) {
	std::vector<BoundingBox> convertedBoxes;

	// 将数据转换为BoundingBox结构体
	for (const auto&  box: boxes) {
		if (box.size() == 6) { // Assuming [x1, y1, x2, y2, score]
			BoundingBox bbox;
			bbox.x1 = box[0];
			bbox.y1 = box[1];
			bbox.x2 = box[2];
			bbox.y2 = box[3];
			bbox.score = box[4];
			bbox.index = box[5];
			convertedBoxes.push_back(bbox);
		}
		else {
			std::cerr << "Invalid box format!" << std::endl;
		}
	}

	// 对框按照分数降序排序
	std::sort(convertedBoxes.begin(), convertedBoxes.end(), [](const BoundingBox& a, const BoundingBox&  b) {
		return a.score > b.score;
		});

	// 非最大抑制
	std::vector<BoundingBox> result;
	std::vector<bool> isSuppressed(convertedBoxes.size(), false);

	for (size_t i = 0; i < convertedBoxes.size(); ++i) {
		if (!isSuppressed[i]) {
			result.push_back(convertedBoxes[i]);

			for (size_t j = i + 1; j < convertedBoxes.size(); ++j) {
				if (!isSuppressed[j]) {
					float overlap = iou(convertedBoxes[i], convertedBoxes[j]);

					if (overlap > overlapThreshold) {
						isSuppressed[j] = true;
					}
				}
			}
		}
	}
#if 0
	// 输出结果
	std::cout << "NMS Result:" << std::endl;
	for (const auto& box: result) {
		std::cout << "x1: " << box.x1 << ", y1: " << box.y1
			<< ", x2: " << box.x2 << ", y2: " << box.y2
			<< ", score: " << box.score << ",index:" << box.index << std::endl;
	}
#endif 
	return result;
}

void post_process(float *prob_model, float conf_thres, float overlapThreshold, std::vector<Yolo::Detection> & detResult)
{
	int cols = 7, rows = 145152;
	//  ========== 8. 获取推理结果 =========
	std::vector<std::vector<float>> prediction(rows, std::vector<float>(cols));

	int index = 0;
	for (int i = 0; i < rows; ++i) {
		for (int j = 0; j < cols; ++j) {
			prediction[i][j] = prob_model[index++];
		}
	}

	//  ========== 9. 大于conf_thres加入xc =========
	std::vector<std::vector<float>> xc;
	for (const auto& row : prediction) {
		if (row[4] > conf_thres) {
			xc.push_back(row);
		}
	}
	//  ========== 10. 置信度 = obj_conf * cls_conf =========
	//std::cout << xc[0].size() << std::endl;
	for (auto& row: xc) {
		for (int i = 5; i < xc[0].size(); i++) {
			row[i] *= row[4];
		}
	}

	// ========== 11. 切片取出xywh 转为xyxy=========
	std::vector<std::vector<float>> xywh;
	for (const auto& row: xc) {
		std::vector<float> sliced_row(row.begin(), row.begin() + 4);
		xywh.push_back(sliced_row);
	}
	std::vector<std::vector<float>> box(xywh.size(), std::vector<float>(4, 0.0));

	xywhtoxxyy(xywh, box);
	
	// ========== 12. 获取置信度最高的类别和索引=========
	std::size_t mi = xc[0].size();
	std::vector<float> conf(xc.size(), 0.0);
	std::vector<float> j(xc.size(), 0.0);

	for (std::size_t i = 0; i < xc.size(); ++i) {
		// 模拟切片操作 x[:, 5:mi]
		auto sliced_x = std::vector<float>(xc[i].begin() + 5, xc[i].begin() + mi);

		// 计算 max
		auto max_it = std::max_element(sliced_x.begin(), sliced_x.end());

		// 获取 max 的索引
		std::size_t max_index = std::distance(sliced_x.begin(), max_it);

		// 将 max 的值和索引存储到相应的向量中
		conf[i] = *max_it;
		j[i] = max_index;  // 加上切片的起始索引
	}

	// ========== 13. concat x1, y1, x2, y2, score, index;======== =
	for (int i = 0; i < xc.size(); i++) {
		box[i].push_back(conf[i]);
		box[i].push_back(j[i]);
	}

	std::vector<std::vector<float>> output;
	for (int i = 0; i < xc.size(); i++) {
		output.push_back(box[i]); // 创建一个空的 float 向量并
	}

	// ==========14 应用非最大抑制 ==========
	std::vector<BoundingBox>  result = nonMaximumSuppression(output, overlapThreshold);
	for (const auto& r : result)
	{
		Yolo::Detection det;
		det.bbox[0] = r.x1;
		det.bbox[1] = r.y1;
		det.bbox[2] = r.x2;
		det.bbox[3] = r.y2;
		det.conf = r.score;
		det.class_id = r.index;
		detResult.push_back(det);
	}

}

代码参考:
https://blog.csdn.net/rooftopstars/article/details/136771496
https://blog.csdn.net/qq_73794703/article/details/132147879

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/737362.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux - 探秘 Linux 的 /proc/sys/vm 常见核心配置

文章目录 PreLinux 的 /proc/sys/vm 简述什么是 /proc/sys/vm&#xff1f;主要的配置文件及其用途参数调整对系统的影响dirty_background_ratio 和 dirty_ratioswappinessovercommit_memory 和 overcommit_ratiomin_free_kbytes 实例与使用建议调整 swappiness设置 min_free_kb…

6.XSS-钓鱼攻击展示(存储型xss)

xss钓鱼演示 钓鱼攻击利用页面 D:\phpStudy\WWW\pikachu\pkxss\xfish 修改配置文件里面对应自己的入侵者的IP地址或者域名&#xff0c;对应的路径下的fish.php脚本如下&#xff1a; <?php error_reporting(0); // var_dump($_SERVER); if ((!isset($_SERVER[PHP_AUTH_USE…

分类预测 | Matlab实现GWO-CNN-SVM灰狼冰算法优化卷积支持向量机分类预测

分类预测 | Matlab实现GWO-CNN-SVM灰狼冰算法优化卷积支持向量机分类预测 目录 分类预测 | Matlab实现GWO-CNN-SVM灰狼冰算法优化卷积支持向量机分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现GWO-CNN-SVM灰狼冰算法优化卷积支持向量机分类预测&…

嵌入式学习——数据结构(队列)——day50

1. 查找二叉树、搜索二叉树、平衡二叉树 2. 哈希表——人的身份证——哈希函数 3. 哈希冲突、哈希矛盾 4. 哈希代码 4.1 创建哈希表 4.2 5. 算法设计 5.1 正确性 5.2 可读性&#xff08;高内聚、低耦合&#xff09; 5.3 健壮性 5.4 高效率&#xff08;时间复杂度&am…

线程封装,互斥

文章目录 线程封装线程互斥加锁、解锁认识接口解决问题理解锁 线程封装 C/C代码混编引起的问题 此处pthread_create函数要求传入参数为void * func(void * )类型,按理来说ThreadRoutine满足,但是 这是在内类完成封装,所以ThreadRoutine函数实际是两个参数,第一个参数Thread* …

Python | Leetcode Python题解之第174题地下城游戏

题目&#xff1a; 题解&#xff1a; class Solution:def calculateMinimumHP(self, dungeon: List[List[int]]) -> int:n, m len(dungeon), len(dungeon[0])BIG 10**9dp [[BIG] * (m 1) for _ in range(n 1)]dp[n][m - 1] dp[n - 1][m] 1for i in range(n - 1, -1, …

Redis的实战常用一、验证码登录(解决session共享问题)(思路、意识)

一、基于session实现登录功能 第一步&#xff1a;发送验证码&#xff1a; 用户在提交手机号后&#xff0c;会校验手机号是否合法&#xff1a; 如果不合法&#xff0c;则要求用户重新输入手机号如果手机号合法&#xff0c;后台此时生成对应的验证码&#xff0c;同时将验证码进行…

C语言 | Leetcode C语言题解之第187题重复的DNA序列

题目&#xff1a; 题解&#xff1a; #define MAXSIZE 769/* 选取一个质数即可 */ typedef struct Node {char string[101];int index;struct Node *next; //保存链表表头 } List;typedef struct {List *hashHead[MAXSIZE];//定义哈希数组的大小 } MyHashMap;List * …

【百问大模型02】一文讲透RAG实战全解析

1.实时性无法更新&#xff0c;知识容易自相矛盾 2.大模型的缺点有哪些&#xff1f; 3.一个人的能力可以分为两种&#xff1a; 1&#xff09;大模型&#xff1a;推理能力&#xff0c;聪明&#xff0c;知识&#xff1b;很聪明但是缺少知识 2&#xff09;知识库&#xff1a;辅…

第一个Flask程序

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 一切准备就绪&#xff0c;现在我们开始编写第一个Flask程序&#xff0c;由于是第一个Flask程序&#xff0c;当然要从最简单的“Hello World&#xff…

打印机状态显示错误是什么原因?这5个有效方法要记好!

打印机是现代办公中不可或缺的设备之一&#xff0c;但在使用过程中&#xff0c;打印机状态显示错误是一个常见的问题。本文将详细探讨打印机状态显示错误的原因及其解决方法。 摘要 打印机状态显示错误的原因及解决方法如下&#xff1a; 1、网络连接问题&#xff1a;原因&…

【python】python基于微博互动数据的用户类型预测(随机森林与支持向量机的比较分析)(源码+数据集+课程论文)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

皇河将相董事长程灯虎出席第二十四届世纪大采风并获奖

仲夏时节,西子湖畔。第二十四届世纪大采风品牌人物年度盛典于6月16日至17日在杭州东方文化园隆重举行。本届盛典由亿央网、《华夏英才》电视栏目联合多家媒体共同主办,中世采文化发展集团承办,意尔康股份有限公司、宸咏集团协办,汇聚了来自全国政、商、产、学、研、媒等各界代表…

图像编辑技术的新篇章:基于扩散模型的综述

在人工智能的浪潮中&#xff0c;图像编辑技术正经历着前所未有的变革。随着数字媒体、广告、娱乐和科学研究等领域对高质量图像编辑需求的不断增长&#xff0c;传统的图像编辑方法已逐渐无法满足日益复杂的视觉内容创作需求。尤其是在AI生成内容&#xff08;AIGC&#xff09;的…

YIA主题侧边栏如何添加3D旋转标签云?

WordPress站点侧边栏默认的标签云排版很一般&#xff0c;而3D旋转标签云就比较酷炫了。下面boke112百科就以YIA主题为例&#xff0c;跟大家说一说如何将默认的标签云修改成3D旋转标签云&#xff0c;具体步骤如下&#xff1a; 1、点此下载3d标签云文件&#xff08;密码&#xf…

开源项目推荐-vue2+element+axios 个人财务管理系统

文章目录 financialmanagement项目简介项目特色项目预览卫星的实现方式&#xff1a;首次进入卫星效果的实现方式&#xff1a;卫星跟随鼠标滑动的随机效果实现方式&#xff1a;环境准备项目启动项目部署项目地址 financialmanagement 项目简介 vue2elementaxios 个人财务管理系…

java学习--集合(大写二.2)

看尚硅谷视频做的笔记 2.collection接口及方法 jdk8里有一些默认的方法&#xff0c;更多的是体现的是一种规范&#xff0c;规范更多关注的是一些抽象方法。 看接口里面的抽象方法&#xff0c;选一个具体的实现类。 测试collection的方法&#xff0c;存储一个一个数据都有哪些…

记录:[android] SSLHandshakeException: Handshake failed 问题;已解决!

1、问题描述&#xff1a;在使用Retrofit2 时在安卓老设备上&#xff08;安卓6.0&#xff09;网络无法请求、安卓 10 、 11 未出现此问题&#xff1f;what? 原因&#xff1a;服务端 TLS 版本过高 2、废话不多说、解决方案A 、添加依赖&#xff1a;implementation org.conscrypt…

安徽理工大学2计算机考研情况,招收计算机专业的学院和联培都不少!

安徽理工大学&#xff08;Anhui University of Science and Technology&#xff09;&#xff0c;位于淮南市&#xff0c;是安徽省和应急管理部共建高校&#xff0c;安徽省高等教育振兴计划“地方特色高水平大学”建设高校&#xff0c;安徽省高峰学科建设计划特别支持高校&#…

Java面试八股之myBatis与myBatis plus的对比

myBatis与myBatis plus的对比 基础与增强&#xff1a; MyBatis 是一个成熟的Java持久层框架&#xff0c;它允许开发者通过XML文件或注解来配置SQL语句和数据库映射&#xff0c;提供了一个灵活的方式来操作数据库&#xff0c;但需要手动编写所有的SQL语句和结果集映射。 MyBa…