【yolov5】onnx的INT8量化engine

GitHub上有大佬写好代码,理论上直接克隆仓库里下来使用

git clone https://github.com/Wulingtian/yolov5_tensorrt_int8_tools.git

然后在yolov5_tensorrt_int8_tools的convert_trt_quant.py 修改如下参数

BATCH_SIZE 模型量化一次输入多少张图片

BATCH 模型量化次数

height width 输入图片宽和高

CALIB_IMG_DIR 训练图片路径,用于量化

onnx_model_path onnx模型路径

engine_model_path 模型保存路径

其中这个batch_size不能超过照片的数量,然后跑这个convert_trt_quant.py

出问题了吧@_@

这是因为tensor的版本更新原因,这个代码的tensorrt版本是7系列的,而目前新的tensorrt版本已经没有了一些属性,所以我们需要对这个大佬写的代码进行一些修改

如何修改呢,其实tensorrt官方给出了一个caffe量化INT8的例子

https://github.com/NVIDIA/TensorRT/tree/master/samples/python/int8_caffe_mnist

如果足够NB是可以根据官方的这个例子修改一下直接实现onnx的INT8量化的

但是奈何我连半桶水都没有,只有一滴水,但是这个例子中的tensorrt版本是新的,于是我尝试将上面那位大佬的代码修改为使用新版的tensorrt

居然成功了??!!

成功量化后的模型大小只有4MB,相比之下的FP16的大小为6MB,FP32的大小为9MB

再看看检测速度,速度和FP16差不太多

但是效果要差上一些了

那肯定不能忘记送上修改的代码,折腾一晚上的结果如下,主要是 util_trt程序

# tensorrt-lib

import os
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
from calibrator import Calibrator
from torch.autograd import Variable
import torch
import numpy as np
import time
# add verbose
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # ** engine可视化 **

# create tensorrt-engine
  # fixed and dynamic
def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="",\
               fp16_mode=False, int8_mode=False, calibration_stream=None, calibration_table_path="", save_engine=False):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine(max_batch_size, save_engine):
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(1) as network,\
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            
            # parse onnx model file
            if not os.path.exists(onnx_file_path):
                quit('ONNX file {} not found'.format(onnx_file_path))
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                parser.parse(model.read())
                assert network.num_layers > 0, 'Failed to parse ONNX model. \
                            Please check if the ONNX model is compatible '
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))        
            
            # build trt engine
            builder.max_batch_size = max_batch_size
            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 20
            if int8_mode:
                config.set_flag(trt.BuilderFlag.INT8)
                assert calibration_stream, 'Error: a calibration_stream should be provided for int8 mode'
                config.int8_calibrator  = Calibrator(calibration_stream, calibration_table_path)
                print('Int8 mode enabled')
            runtime=trt.Runtime(TRT_LOGGER)
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            if engine is None:
                print('Failed to create the engine')
                return None   
            print("Completed creating the engine")
            if save_engine:
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
            return engine
        
    if os.path.exists(engine_file_path):
        # If a serialized engine exists, load it instead of building a new one.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine(max_batch_size, save_engine)

唔,convert_trt_quant.py的代码也给一下吧

import numpy as np
import torch
import torch.nn as nn
import util_trt
import glob,os,cv2

BATCH_SIZE = 1
BATCH = 79
height = 640
width = 640
CALIB_IMG_DIR = '/content/drive/MyDrive/yolov5/ikunData/images'
onnx_model_path = "runs/train/exp4/weights/FP32.onnx"
def preprocess_v1(image_raw):
    h, w, c = image_raw.shape
    image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
    # Calculate widht and height and paddings
    r_w = width / w
    r_h = height / h
    if r_h > r_w:
        tw = width
        th = int(r_w * h)
        tx1 = tx2 = 0
        ty1 = int((height - th) / 2)
        ty2 = height - th - ty1
    else:
        tw = int(r_h * w)
        th = height
        tx1 = int((width - tw) / 2)
        tx2 = width - tw - tx1
        ty1 = ty2 = 0
    # Resize the image with long side while maintaining ratio
    image = cv2.resize(image, (tw, th))
    # Pad the short side with (128,128,128)
    image = cv2.copyMakeBorder(
        image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, (128, 128, 128)
    )
    image = image.astype(np.float32)
    # Normalize to [0,1]
    image /= 255.0
    # HWC to CHW format:
    image = np.transpose(image, [2, 0, 1])
    # CHW to NCHW format
    #image = np.expand_dims(image, axis=0)
    # Convert the image to row-major order, also known as "C order":
    #image = np.ascontiguousarray(image)
    return image


def preprocess(img):
    img = cv2.resize(img, (640, 640))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose((2, 0, 1)).astype(np.float32)
    img /= 255.0
    return img

class DataLoader:
    def __init__(self):
        self.index = 0
        self.length = BATCH
        self.batch_size = BATCH_SIZE
        # self.img_list = [i.strip() for i in open('calib.txt').readlines()]
        self.img_list = glob.glob(os.path.join(CALIB_IMG_DIR, "*.jpg"))
        assert len(self.img_list) > self.batch_size * self.length, '{} must contains more than '.format(CALIB_IMG_DIR) + str(self.batch_size * self.length) + ' images to calib'
        print('found all {} images to calib.'.format(len(self.img_list)))
        self.calibration_data = np.zeros((self.batch_size,3,height,width), dtype=np.float32)

    def reset(self):
        self.index = 0

    def next_batch(self):
        if self.index < self.length:
            for i in range(self.batch_size):
                assert os.path.exists(self.img_list[i + self.index * self.batch_size]), 'not found!!'
                img = cv2.imread(self.img_list[i + self.index * self.batch_size])
                img = preprocess_v1(img)
                self.calibration_data[i] = img

            self.index += 1

            # example only
            return np.ascontiguousarray(self.calibration_data, dtype=np.float32)
        else:
            return np.array([])

    def __len__(self):
        return self.length

def main():
    # onnx2trt
    fp16_mode = False
    int8_mode = True 
    print('*** onnx to tensorrt begin ***')
    # calibration
    calibration_stream = DataLoader()
    engine_model_path = "runs/train/exp4/weights/int8.engine"
    calibration_table = 'yolov5_tensorrt_int8_tools/models_save/calibration.cache'
    # fixed_engine,校准产生校准表
    engine_fixed = util_trt.get_engine(BATCH_SIZE, onnx_model_path, engine_model_path, fp16_mode=fp16_mode, 
        int8_mode=int8_mode, calibration_stream=calibration_stream, calibration_table_path=calibration_table, save_engine=True)
    assert engine_fixed, 'Broken engine_fixed'
    print('*** onnx to tensorrt completed ***\n')
    
if __name__ == '__main__':
    main()
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/136387.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

操作系统(二)内存管理的基础知识

文章目录 前言内存管理地址空间与地址生成连续内存分配内存碎片连续分配算法碎片整理 非连续内存分配虚拟内存管理虚拟内存地址内存分段内存分页段页式内存管理虚拟内存的覆盖技术虚拟内存的交换技术 缺页异常内存页面置换算法局部页面置换算法Belady现象全局页面置换算法抖动和…

Mybatis-Plus入门

Mybatis-Plus入门 MyBatis-Plus 官网&#xff1a;https://mp.baomidou.com/ 1、简介 MyBatis-Plus (简称 MP) 是一个 MyBatis 的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、 提高效率而生。 https://github.com/baomidou/mybatis-p…

【MySQL系列】第二章 · SQL(上)

写在前面 Hello大家好&#xff0c; 我是【麟-小白】&#xff0c;一位软件工程专业的学生&#xff0c;喜好计算机知识。希望大家能够一起学习进步呀&#xff01;本人是一名在读大学生&#xff0c;专业水平有限&#xff0c;如发现错误或不足之处&#xff0c;请多多指正&#xff0…

深入理解强化学习——多臂赌博机:知识总结

分类目录&#xff1a;《深入理解强化学习》总目录 我们在《深入理解强化学习——多臂赌博机》系列文章中介绍了几种平衡试探和开发的简单方法。 ϵ − \epsilon- ϵ−贪心方法在一小段时间内进行随机的动作选择&#xff0c;而UCB方法虽然采用确定的动作选择&#xff0c;却可以通…

Leetcode100128. 高访问员工

Every day a Leetcode 题目来源&#xff1a;100128. 高访问员工 解法1&#xff1a;模拟 把名字相同的员工对应的访问时间&#xff08;转成分钟数&#xff09;分到同一组中。 对于每一组的访问时间 accessTime&#xff0c;排序后&#xff0c;判断是否有 accessTime[i] - ac…

吃透 Spring 系列—Web部分

目录 ◆ Spring整合web环境 - Javaweb三大组件及环境特点 - Spring整合web环境的思路及实现 - Spring的web开发组件spring-web ◆ web层MVC框架思想与设计思路 ◆ Spring整合web环境 - Javaweb三大组件及环境特点 在Java语言范畴内&#xff0c;web层框架都是基于J…

win环境Jenkins部署前端项目

今天分享win环境Jenkins部署前端vue项目&#xff0c;使用的版本jenkins版本Jenkins 2.406版本。 前提是jenkins安装好了&#xff0c;通用配置已经配置好了&#xff0c;可以参考上两篇博客。 1、前端项目依赖nodejs&#xff0c;需要安装相关插件 点击进入 安装成功标准 jenki…

【Vue3】scoped 和样式穿透

我们使用很多 vue 的组件库&#xff08;element-plus、vant&#xff09;&#xff0c;在修改样式的时候需要进行其他操作才能成功更改样式&#xff0c;此时就用到了样式穿透。 而不能正常更改样式的原因就是 scoped 标记。 scoped 的渲染规则&#xff1a; <template>&l…

如何在ModelScope社区魔搭下载所需的模型

本篇文章介绍如何在ModelScope社区下载所需的模型。 若您需要在ModelScope平台上有感兴趣的模型并希望能下载至本地&#xff0c;则ModelScope提供了多种下载模型的方式。 使用Library下载模型 若该模型已集成至ModelScope的Library中&#xff0c;则您只需要几行代码即可加载…

STM32之DMA

一、DMA概述 DMA:直接寄存器访问 Direction:直接 Memory:存储器 Access:访问 就是一个外设用于搬运数据&#xff0c;就是一个搬运工。 在串口发送数据的时候&#xff1a;这种效率并不高 如何想要发送大量的数据的时候可以利用DMA 1、DMA工作流程 没有DMA参与…

【友提】2023年“思维100”编程比赛开始报名,名额有限报名抓紧

根据官方昨天发布的通知&#xff0c;2023年上海市“科学小公民”实践展示活动之“思维100”STEM应用能力编程活动&#xff08;秋季&#xff09;开始报名了&#xff0c;为便于大家了解&#xff0c;六分成长为大家整理关键信息如下。为便于叙述&#xff0c;该活动简称为思维100编…

RGB颜色空间与BMP格式图片

RGB颜色空间 RGB可以分为两大类&#xff1a;一种是索引形式&#xff0c;一种是像素形式&#xff1a; 索引形式&#xff1a;存储每个像素在调色板中的索引 RGB1&#xff1a;每个像素用1bit表示&#xff0c;调色板中只包含两种颜色&#xff08;黑白&#xff09;RGB4&#xff1a…

卸载本地开发环境,拥抱容器化开发

以前在公司的时候&#xff0c;使用同事准备的容器化环境&#xff0c;直接在 Docker 内进行开发&#xff0c;爽歪歪呀。也是在那时了解了容器化开发的知识&#xff0c;可惜了&#xff0c;现在用不到那种环境了。所以打算自己在本地也整一个个人的开发环境&#xff0c;不过因为我…

吴恩达《机器学习》8-3->8-4:模型表示I、模型表示II

8.3、模型表示I 一、大脑神经网络的基本原理 为了构建神经网络模型&#xff0c;首先需要理解大脑中的神经网络是如何运作的。每个神经元都可以被看作是一个处理单元或神经核&#xff0c;它包含多个输入&#xff08;树突&#xff09;和一个输出&#xff08;轴突&#xff09;。…

在vue3中使用Element-plus的图标

首先安装Element-Plus-icon # 选择一个你喜欢的包管理器# NPM $ npm install element-plus/icons-vue # Yarn $ yarn add element-plus/icons-vue # pnpm $ pnpm install element-plus/icons-vue 如何使用 Element-Plus-icon官方文档链接Icon 图标 | Element Plus (element-…

【操作系统】1.1 操作系统的基础概念、功能以及特性

&#x1f4e2;&#xff1a;如果你也对机器人、人工智能感兴趣&#xff0c;看来我们志同道合✨ &#x1f4e2;&#xff1a;不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 &#x1f4e2;&#xff1a;文章若有幸对你有帮助&#xff0c;可点赞 &#x1f44d;…

Unity之NetCode多人网络游戏联机对战教程(8)--玩家位置同步

文章目录 前言添加相机玩家添加对应组件服务端权威&#xff08;server authoritative&#xff09;客户端权威&#xff08;client authoritative&#xff09;服务端同步位置阅读与理解PlayerTransformSync.csNetworkVariableUploadTransformSyncTransform 后话 前言 承接上篇&a…

LOW-POWER AUDIO KEYWORD SPOTTING USING TSETLIN MACHINES

基于TM的低功耗语音关键字识别 摘要1介绍2TM的介绍3KWS的音频预处理技术4实验结果MFC4.1C设置分位数数量4.3增加关键词数量4.4 声音相似的关键词4.5 每个类别的子句数量对KWS-TM的比较学习收敛和复杂性分析 摘要 在本文中&#xff0c;我们探讨了一种基于TM的关键词识别&#x…

《算法通关村——透彻理解二叉树中序遍历的应用》

《算法通关村——透彻理解二叉树中序遍历的应用》 直接上题 108. 将有序数组转换为二叉搜索树 给你一个整数数组 nums &#xff0c;其中元素已经按 升序 排列&#xff0c;请你将其转换为一棵 高度平衡 二叉搜索树。 高度平衡 二叉树是一棵满足「每个节点的左右两个子树的高…

同一个Unity项目打开两个Unity Editor实例

特殊情况下&#xff0c;同一个项目需要同时打开两个编辑器做测试&#xff0c;如多人在线游戏&#xff0c;或者有通信功能的时候就有这样的需求。同时也为了方便调试和观察日志。并且修改的是同一份代码。 命令介绍&#xff1a; 实现思路&#xff1a; 使用 mklink 命令 分别创建…