【TensorRT部署】pytorch模型(pt/pth)转onnx,onnx转engine(tensorRT)

1. 单帧处理

1. pt2onnx

import torch
import numpy as np
from parameters import get_parameters as get_parameters
from models._model_builder import build_model
TORCH_WEIGHT_PATH = './checkpoints/model.pth'
ONNX_MODEL_PATH = './checkpoints/model.onnx'
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_tensor_type('torch.cuda.FloatTensor')
def get_numpy_data():
    batch_size = 1
    img_input = np.ones((batch_size,1,512,512), dtype=np.float32)
    return img_input

def get_torch_model():
    # args = get_args()
    args = get_parameters()
    model = build_model(args.model, args)
    model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    model.cuda()
    #pass
    return model
#定义参数
input_name = ['data']
output_name = ['prob']
'''input为输入模型图片的大小'''
input = torch.randn(1,1,512,512).cuda()

# 创建模型并载入权重
model = get_torch_model()
model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
model.cuda()

#导出onnx
torch.onnx.export(model, input, ONNX_MODEL_PATH, input_names=input_name, output_names=output_name, verbose=False,opset_version=11)

补充:也可以对onnx进行简化

# pip install onnxsim

from onnxsim import simplify
import onnx
onnx_model = onnx.load("./checkpoints/model.onnx")  # load onnx model
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "./checkpoints/model.onnx")
print('finished exporting onnx')

2. onnx2engine

// OnnxToEngine.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <iostream>
#include <chrono>
#include <vector>
#include "cuda_runtime_api.h"
#include "logging.h"
#include "common.hpp"
#include "NvOnnxParser.h"
#include"NvCaffeParser.h"
const char* INPUT_BLOB_NAME = "data";
using namespace std;
using namespace nvinfer1;
using namespace nvonnxparser;
using namespace nvcaffeparser1;

unsigned int maxBatchSize = 1;

int main()
{
    //step1:创建logger:日志记录器
    static Logger gLogger;
    //step2:创建builder
    IBuilder* builder = createInferBuilder(gLogger);

    //step3:创建network
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);//0改成1,
    //step4:创建parser
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);

    //step5:使用parser解析模型填充network
    const char* onnx_filename = "..\\onnx\\model.onnx";
    parser->parseFromFile(onnx_filename, static_cast<int>(Logger::Severity::kWARNING));
    for (int i = 0; i < parser->getNbErrors(); ++i)
    {
        std::cout << parser->getError(i)->desc() << std::endl;
    }
    std::cout << "successfully load the onnx model" << std::endl;
    //step6:创建config并设置最大batchsize和最大工作空间
    // Create builder
   // unsigned int maxBatchSize = 1;
    builder->setMaxBatchSize(maxBatchSize);
    IBuilderConfig* config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize( (1 << int(20)));

    //step7:创建engine
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    //assert(engine);

    //step8:序列化保存engine到planfile
    IHostMemory* serializedModel = engine->serialize();
    //assert(serializedModel != nullptr);
    //std::ofstream p("D:\\TensorRT-7.2.2.322\\engine\\unet.engine");
    //p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());
    std::string engine_name = "..\\engine\\model.engine";
    std::ofstream p(engine_name, std::ios_base::out | std::ios_base::binary);
    if (!p) {
        std::cerr << "could not open plan output file" << std::endl;
        return -1;
    }
    p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());
    std::cout << "successfully build an engine model" << std::endl;
    //step9:释放资源
    serializedModel->destroy();
    engine->destroy();
    parser->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();

}

2. 多帧处理(加速)

2.1 pt2onnx

import onnx
import torch
import numpy as np
from parameters import get_parameters as get_parameters
from models._model_builder import build_model
TORCH_WEIGHT_PATH = './checkpoints/model.pth'
ONNX_MODEL_PATH = './checkpoints/model.onnx'
args = get_parameters()
def get_torch_model():
    # args = get_args()
    print(args.model)
    model = build_model(args.model, args)
    model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    model.cuda()
    #pass
    return model



if __name__ == "__main__":
    # 设置输入参数
    Batch_size = 1
    Channel = 1
    Height = 384
    Width = 640
    input_data = torch.rand((Batch_size, Channel, Height, Width)).cuda()

    # 实例化模型
    # 创建模型并载入权重
    model = get_torch_model()
    #model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    #model.cuda()

    # 导出为静态输入
    input_name = 'data'
    output_name = 'prob'
    torch.onnx.export(model,
                      input_data,
                      ONNX_MODEL_PATH,
                      verbose=True,
                      input_names=[input_name],
                      output_names=[output_name])

    # 导出为动态输入
    torch.onnx.export(model,
                      input_data,
                      ONNX_MODEL_PATH2,
                      opset_version=11,
                      input_names=[input_name],
                      output_names=[output_name],
                      dynamic_axes={
                          #input_name: {0: 'batch_size'},
                          #output_name: {0: 'batch_size'}}
                          input_name: {0: 'batch_size', 1: 'channel', 2: 'input_height', 3: 'input_width'},
                          output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}}
                       )

2.2 onnx2engine

 OnnxToEngine.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
#include <iostream>
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "logging.h"
#include "opencv2/opencv.hpp"
#include <fstream>
#include <sstream>
#include "cuda_runtime_api.h"
static Logger gLogger;
using namespace nvinfer1;


bool saveEngine(const ICudaEngine& engine, const std::string& fileName)
{
	std::ofstream engineFile(fileName, std::ios::binary);
	if (!engineFile)
	{
		std::cout << "Cannot open engine file: " << fileName << std::endl;
		return false;
	}

	IHostMemory* serializedEngine = engine.serialize();
	if (serializedEngine == nullptr)
	{
		std::cout << "Engine serialization failed" << std::endl;
		return false;
	}

	engineFile.write(static_cast<char*>(serializedEngine->data()), serializedEngine->size());
	return !engineFile.fail();
}
void print_dims(const nvinfer1::Dims& dim)
{
	for (int nIdxShape = 0; nIdxShape < dim.nbDims; ++nIdxShape)
	{

		printf("dim %d=%d\n", nIdxShape, dim.d[nIdxShape]);

	}
}

int main()
{

	//	1、创建一个builder
	IBuilder* pBuilder = createInferBuilder(gLogger);
	// 2、 创建一个 network,要求网络结构里,没有隐藏的批量处理维度
	INetworkDefinition* pNetwork = pBuilder->createNetworkV2(1U << static_cast<int>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

	// 3、 创建一个配置文件
	nvinfer1::IBuilderConfig* config = pBuilder->createBuilderConfig();
	// 4、 设置profile,这里动态batch专属
	IOptimizationProfile* profile = pBuilder->createOptimizationProfile();
	// 这里有个OptProfileSelector,这个用来设置优化的参数,比如(Tensor的形状或者动态尺寸),

	profile->setDimensions("data", OptProfileSelector::kMIN, Dims4(1, 1, 512, 512));
	profile->setDimensions("data", OptProfileSelector::kOPT, Dims4(2, 1, 512, 512));
	profile->setDimensions("data", OptProfileSelector::kMAX, Dims4(4, 1, 512, 512));

	config->addOptimizationProfile(profile);

	auto parser = nvonnxparser::createParser(*pNetwork, gLogger.getTRTLogger());

	const char* pchModelPth = "..\\onnx\\model.onnx";

	if (!parser->parseFromFile(pchModelPth, static_cast<int>(gLogger.getReportableSeverity())))
	{

		printf("解析onnx模型失败\n");
	}

	int maxBatchSize = 4;
	//IBuilderConfig::setMaxWorkspaceSize

	pBuilder->setMaxWorkspaceSize(1 << 32);  //pBuilderg->setMaxWorkspaceSize(1<<32);改为config->setMaxWorkspaceSize(1<<32);
	pBuilder->setMaxBatchSize(maxBatchSize);
	//设置推理模式
	pBuilder->setFp16Mode(true);
	ICudaEngine* engine = pBuilder->buildEngineWithConfig(*pNetwork, *config);

	std::string strTrtSavedPath = "..\\engine\\model.trt";
	// 序列化保存模型
	saveEngine(*engine, strTrtSavedPath);
	nvinfer1::Dims dim = engine->getBindingDimensions(0);
	// 打印维度
	print_dims(dim);
}

3. c++调用tensorRT模型

整个工程:链接
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/184138.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GoogleNet详解

一、亮点 AlexNet、VGG都只有一个输出层。googlenet有三个&#xff08;其中两个是辅助分类层&#xff09; 二、先看看Inception结构 1、Inception 之前的网络&#xff1a; AlexNet、VGG都是串行结构 Inception&#xff1a; 并行结构 上一层的输出同时传入四个并行结构&…

中西部各省市翻译协会、公关协会会长金秋圆桌会议圆满结束

中西部翻译协会共同体、中西部公共关系协会共同体共同体创建8年来&#xff0c;已成功举办了八届翻译大赛。时值第九届中西部翻译大赛将拉开序幕&#xff0c;中西部翻译协会共同体、中西部公共关系协会共同体举办的2023年度中西部各省市翻译协会、公关协会会长金秋圆桌会议&…

如何用python画一个圣诞树

前言 距离圣诞节还有一个月啦。今天&#xff0c;我们给大家画一个圣诞树&#xff0c;我们一起来看看效果吧。 效果展示 我们先来看看最终的效果看看我们画的圣诞树怎么样吧。如果&#xff0c;感觉不错&#xff0c;我们一起来实现吧。 功能实现 功能模块 我们先看看&#x…

Bytebase 2.11.1 - 数据脱敏支持语义类型和脱敏算法

&#x1f680; 新功能 数据脱敏支持自定义脱敏算法和语义类型。 &#x1f514; 重大变更 用户页面的 URL 由 /u/{uid} 变更为 /users/{email}。工作空间的所有者和开发者分别更名为&#xff1a;管理员和成员。 &#x1f384; 改进 SQL 编辑器支持显示表的 DDL 语句&#…

HR8833 双通道H桥电机驱动芯片

HR8833为玩具、打印机和其它电机一T化应用提供一种双通道电机驱动方案。HR8833提供两种封装&#xff0c;一种是带有L露焊盘的TSSOP-16封装&#xff0c;能改进散热性能&#xff0c;且是无铅产品&#xff0c;引脚框采用100&#xff05;无锡电镀。另一种封装为SOP16&#xff0c;不…

【ARM CoreLink 系列 3.2 -- CCI-400,CCI-500, CCI-550 差异】

文章目录 CCI-400 和 CCI-500 差异ARM CCI-400ARM CCI-500ARM CCI-550CCI-400 和 CCI-500 差异 ARM的 CCI(Cache Coherent Interconnect)系列产品是用于多核处理器之间的高性能缓存一致性互连。CCI-400 和 CCI-500 是该系列中的两种设计,它们旨在允许多个处理器核心和其他资…

【腾讯云云上实验室-向量数据库】Tencent Cloud VectorDB在实战项目中替换Milvus测试

为什么尝试使用Tencent Cloud VectorDB替换Milvus向量库&#xff1f; 亮点&#xff1a;Tencent Cloud VectorDB支持Embedding&#xff0c;免去自己搭建模型的负担&#xff08;搭建一个生产环境的模型实在耗费精力和体力&#xff09;。 腾讯云向量数据库是什么&#xff1f; 腾…

函数计算的新征程:使用 Laf 构建 AI 知识库

Laf 已成功上架 Sealos 模板市场&#xff0c;可通过 Laf 应用模板来一键部署&#xff01; 这意味着 Laf 在私有化部署上的扩展性得到了极大的提升。 Sealos 作为一个功能强大的云操作系统&#xff0c;能够秒级创建多种高可用数据库&#xff0c;如 MySQL、PostgreSQL、MongoDB …

神命令tree的魅力你get到了吗?

背景 日常工作中&#xff0c;有时候为了明确表达自己的意思&#xff0c;往往需要输出对应的目录层级结构&#xff0c;手动一个个输入往往显得不那么高级&#xff0c;效率相对较低&#xff0c;这时候拥有可以一键输出目录结构并且可以快速转化为文本的工具就比较方便&#xff0…

创新指南|消费品牌2024重塑增长最值得关注的10个DTC零售策略

2023年对消费零售行业来说同样是挑战的一年&#xff0c;经济逆风和消费低迷迫在眉睫&#xff0c;而品牌零售商如何从库存积压中跳出来&#xff0c;努力应对增加的支出&#xff0c;实现可盈利的增长会是让每位CEO战略执行的第一优先级。2023年用什么策略于DTC&#xff1f;与全球…

B033-Servlet交互 JSP

目录 ServletServlet的三大职责跳转&#xff1a;请求转发和重定向请求转发重定向汇总请求转发与重定向的区别用请求转发和重定向完善登录 JSP第一个JSP概述注释设置创建JSP文件默认字符编码集 JSP的java代码书写JSP的原理三大指令九大内置对象改造动态web工程进行示例内置对象名…

易点易动固定资产管理系统:实现全面的固定资产采购管理

在现代企业中&#xff0c;固定资产采购管理是一项关键的任务。为了确保企业的正常运营和发展&#xff0c;有效管理和控制固定资产采购过程至关重要。易点易动固定资产管理系统为企业提供了一种全面的解决方案&#xff0c;整合了从采购需求、采购计划、询比价、采购合同到采购执…

【C语言】整形在内存中的存储

1、整形在内存中的存储 1.1 原码、反码、补码 计算机中整数有三种二进制表示方法&#xff0c;分别是原码、反码、补码 三种表示方法由符号位和数值位构成&#xff0c;符号位用0表示正数&#xff0c;1表示负数。 整形数据在内存中存放的是补码 正数的原码、反码、补码相同 …

JSP:MVC

Web应用 一个好的Web应用&#xff1a; 功能完善 易于实现和维护 易于扩展等 的体系结构 一个Web应用通常分为两个部分&#xff1a; m 1. 由界面设计人员完成的 表示层 &#xff08;主要做网页界面设计&#xff09; m 2. 由程序设计人员实现的 行为层 &#xff08;主要完成本…

利用企业被执行人信息查询API保障商业交易安全

前言 在当今竞争激烈的商业环境中&#xff0c;企业为了保障商业交易的安全性不断寻求新的手段。随着技术的发展&#xff0c;利用企业被执行人信息查询API已经成为了一种强有力的工具&#xff0c;能够帮助企业在商业交易中降低风险&#xff0c;提高合作的信任度。 企业被执行人…

ArcMap针对正射影像图生成切片操作

1.导入图层jpg文件 2.添加地图坐标系 右键点击地图 --》数据框属性 坐标系选项设置地图的坐标系 地图应该有对应的坐标文件 3.地理配准选项 --》去除自动校正&#xff0c; 4.选择参考坐标 在图中选三个定位坐标保存 5.地理配准选项 --》更新地理位置配准 6.管理工具下 --》…

2023仿聚合搜索程序源码/轻量级搜狗泛站群程序源码/PHP整站源码+完美SEO优化+符合搜狗算法

源码简介&#xff1a; 2023仿聚合搜索/轻量级搜狗泛站群程序整站源码&#xff0c;作为PHP源码&#xff0c;可以完美SEO优化&#xff0c;符合搜狗搜索引擎算法。 轻量级的PHP搜狗泛站群程序源码&#xff0c;完美SEO优化符合搜狗搜索引擎算法&#xff0c;无需任何采集&#xff…

如何有效减少 AI 模型的数据中心能源消耗?

在让人工智能变得更好的竞赛中&#xff0c;麻省理工学院&#xff08;MIT&#xff09;林肯实验室正在开发降低功耗、高效训练和透明能源使用的方法。 在 Google 上搜索航班时&#xff0c;您可能已经注意到&#xff0c;现在每个航班的碳排放量估算值都显示在其成本旁边。这是一种…

AI:87-基于深度学习的街景图像地理位置识别

🚀 本文选自专栏:人工智能领域200例教程专栏 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带有在本地跑过的代码,详细讲解供大家学习,希望可以帮到大家。欢迎订阅支持,正在不断更新中,…

GWAS结果批量整理:升级版算法TidyGWAS

TidyGWAS GWAS分析关键结果之一是显著性SNP位点的P值&#xff0c;通常多年份多地点多模型的GWAS分析将会产生很多结果文件&#xff0c;如何对这些数据进行整理&#xff1f; 汇总这些结果&#xff0c;并将显著性的位点或区域找出来&#xff0c;更加清晰的展示关键信息。 今天介…