rknn转换后精度差异很大,失真算子自纠

下面是添加了详细注释的优化代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference

def get_all_node_names(model):
    """
    获取模型中所有节点的名称。

    参数:
    model (onnx.ModelProto): ONNX 模型。

    返回:
    list: 包含所有节点名称的列表。
    """
    return [node.name for node in model.graph.node]

def remove_node_and_following(model, node_name):
    """
    删除指定节点及其后续节点,并返回新的模型。

    参数:
    model (onnx.ModelProto): 原始 ONNX 模型。
    node_name (str): 要删除的节点名称。

    返回:
    onnx.ModelProto: 修改后的 ONNX 模型。
    """
    nodes_to_keep = []  # 要保留的节点
    nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点
    start_removal = False  # 是否开始删除节点
    output = []  # 输出节点列表

    for node in model.graph.node:
        if node.name == node_name:
            start_removal = True
        if start_removal:
            nodes_to_remove.add(node.name)
        else:
            nodes_to_keep.append(node)
            output.extend(node.output)

    for node in model.graph.value_info:
        if node.name in output:
            shape = [
                dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else None
                for dim in node.type.tensor_type.shape.dim
            ]
            output_tensor = helper.make_tensor_value_info(
                node.name,
                onnx.TensorProto.FLOAT,
                shape
            )
            model.graph.output.append(output_tensor)

    new_graph = helper.make_graph(
        nodes_to_keep,
        model.graph.name,
        model.graph.input,
        [output for output in model.graph.output if output.name not in nodes_to_remove],
        model.graph.initializer,
    )

    new_model = helper.make_model(new_graph, producer_name=model.producer_name)
    new_model = shape_inference.infer_shapes(new_model)

    return new_model

def preprocess_image(image_path, target_shape):
    """
    加载并预处理图像。

    参数:
    image_path (str): 图像文件路径。
    target_shape (tuple): 目标形状 (宽, 高)。

    返回:
    np.ndarray: 预处理后的图像数组。
    """
    im = cv2.imread(image_path)
    im = cv2.resize(im, target_shape)
    im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGB
    return np.ascontiguousarray(im)

def main():
    model_path = 'yolov5s.onnx'
    model = onnx.load(model_path)

    dtype_map = {
        'tensor(float)': np.float32,
        'tensor(int32)': np.int32,
        'tensor(int64)': np.int64,
    }

    all_node_names = get_all_node_names(model)
    im = preprocess_image('./sample.jpg', (640, 640))

    for node_name in all_node_names:
        if node_name:
            new_model = remove_node_and_following(model, node_name)
            new_model_path = 'test.onnx'
            onnx.save(new_model, new_model_path)

            sess = rt.InferenceSession(new_model_path)
            output_names = [output.name for output in sess.get_outputs()]
            input_data = {
                input.name: im.reshape(input.shape).astype(dtype_map[input.type])
                for input in sess.get_inputs()
            }

            outputs = sess.run(output_names, input_data)

            # TODO: 补充 rknn 转换、加载和获取输出的代码

            # TODO: 补充评估对比两个输出 onnx 的 target 和 rknn 的 predict

            # TODO: 补充当评估准确率低于 0.8 时终止循环,此时的模型最后一个算子就是 rknn 不正确的算子

if __name__ == "__main__":
    main()

在这里插入图片描述

解释

在补充 TODO 部分的代码之前,我们需要确保你有正确的 RKNN API 环境,并且你知道如何使用它来转换、加载和获取输出。以下是一个示例代码,展示了如何使用 RKNN API 完成这些任务。

1. 安装 RKNN Toolkit

首先,确保你已经安装了 RKNN Toolkit。你可以通过以下命令安装:

pip install rknn-toolkit

2. 补充代码

以下是补充了 TODO 部分的代码:

import cv2
import numpy as np
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference
from rknn.api import RKNN

def get_all_node_names(model):
    """
    获取模型中所有节点的名称。

    参数:
    model (onnx.ModelProto): ONNX 模型。

    返回:
    list: 包含所有节点名称的列表。
    """
    return [node.name for node in model.graph.node]

def remove_node_and_following(model, node_name):
    """
    删除指定节点及其后续节点,并返回新的模型。

    参数:
    model (onnx.ModelProto): 原始 ONNX 模型。
    node_name (str): 要删除的节点名称。

    返回:
    onnx.ModelProto: 修改后的 ONNX 模型。
    """
    nodes_to_keep = []  # 要保留的节点
    nodes_to_remove = set(i.name for i in model.graph.output)  # 要删除的节点
    start_removal = False  # 是否开始删除节点
    output = []  # 输出节点列表

    for node in model.graph.node:
        if node.name == node_name:
            start_removal = True
        if start_removal:
            nodes_to_remove.add(node.name)
        else:
            nodes_to_keep.append(node)
            output.extend(node.output)

    for node in model.graph.value_info:
        if node.name in output:
            shape = [
                dim.dim_value if (dim.dim_value > 0 and dim.HasField('dim_value')) else None
                for dim in node.type.tensor_type.shape.dim
            ]
            output_tensor = helper.make_tensor_value_info(
                node.name,
                onnx.TensorProto.FLOAT,
                shape
            )
            model.graph.output.append(output_tensor)

    new_graph = helper.make_graph(
        nodes_to_keep,
        model.graph.name,
        model.graph.input,
        [output for output in model.graph.output if output.name not in nodes_to_remove],
        model.graph.initializer,
    )

    new_model = helper.make_model(new_graph, producer_name=model.producer_name)
    new_model = shape_inference.infer_shapes(new_model)

    return new_model

def preprocess_image(image_path, target_shape):
    """
    加载并预处理图像。

    参数:
    image_path (str): 图像文件路径。
    target_shape (tuple): 目标形状 (宽, 高)。

    返回:
    np.ndarray: 预处理后的图像数组。
    """
    im = cv2.imread(image_path)
    im = cv2.resize(im, target_shape)
    im = im.transpose((2, 0, 1))[::-1]  # HWC 转 CHW, BGR 转 RGB
    return np.ascontiguousarray(im)

def convert_onnx_to_rknn(onnx_model_path, rknn_model_path):
    """
    将 ONNX 模型转换为 RKNN 模型。

    参数:
    onnx_model_path (str): ONNX 模型路径。
    rknn_model_path (str): 转换后的 RKNN 模型路径。
    """
    rknn = RKNN()

    # 加载 ONNX 模型
    print('--> Loading model')
    ret = rknn.load_onnx(model=onnx_model_path)
    if ret != 0:
        print('Load ONNX model failed!')
        return
    print('done')

    # 配置模型
    print('--> Building model')
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build RKNN model failed!')
        return
    print('done')

    # 导出 RKNN 模型
    print('--> Export RKNN model')
    ret = rknn.export_rknn(rknn_model_path)
    if ret != 0:
        print('Export RKNN model failed!')
        return
    print('done')

def load_and_run_rknn_model(rknn_model_path, input_data):
    """
    加载 RKNN 模型并运行推理。

    参数:
    rknn_model_path (str): RKNN 模型路径。
    input_data (np.ndarray): 输入数据。

    返回:
    list: RKNN 模型的输出结果。
    """
    rknn = RKNN()

    # 加载 RKNN 模型
    print('--> Loading RKNN model')
    ret = rknn.load_rknn(rknn_model_path)
    if ret != 0:
        print('Load RKNN model failed!')
        return []
    print('done')

    # 初始化 RKNN 模型
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    if ret != 0:
        print('Init runtime environment failed!')
        return []
    print('done')

    # 运行推理
    print('--> Running model')
    outputs = rknn.inference(inputs=[input_data])
    print('done')

    rknn.release()

    return outputs

def compare_outputs(onnx_outputs, rknn_outputs, threshold=0.8):
    """
    比较 ONNX 和 RKNN 模型的输出结果。

    参数:
    onnx_outputs (list): ONNX 模型的输出结果。
    rknn_outputs (list): RKNN 模型的输出结果。
    threshold (float): 准确率阈值。

    返回:
    bool: 如果准确率低于阈值,则返回 False,否则返回 True。
    """
    # 计算准确率 (这里假设是简单的相对误差)
    accuracy = np.mean([np.allclose(onnx_out, rknn_out, rtol=threshold) for onnx_out, rknn_out in zip(onnx_outputs, rknn_outputs)])

    return accuracy >= threshold

def main():
    model_path = 'yolov5s.onnx'
    model = onnx.load(model_path)

    dtype_map = {
        'tensor(float)': np.float32,
        'tensor(int32)': np.int32,
        'tensor(int64)': np.int64,
    }

    all_node_names = get_all_node_names(model)
    im = preprocess_image('./sample.jpg', (640, 640))

    for node_name in all_node_names:
        if node_name:
            new_model = remove_node_and_following(model, node_name)
            new_model_path = 'test.onnx'
            onnx.save(new_model, new_model_path)

            sess = rt.InferenceSession(new_model_path)
            output_names = [output.name for output in sess.get_outputs()]
            input_data = {
                input.name: im.reshape(input.shape).astype(dtype_map[input.type])
                for input in sess.get_inputs()
            }

            onnx_outputs = sess.run(output_names, input_data)

            # 转换 ONNX 模型为 RKNN 模型
            rknn_model_path = 'test.rknn'
            convert_onnx_to_rknn(new_model_path, rknn_model_path)

            # 加载并运行 RKNN 模型
            rknn_outputs = load_and_run_rknn_model(rknn_model_path, im)

            # 比较 ONNX 和 RKNN 模型的输出结果
            if not compare_outputs(onnx_outputs, rknn_outputs):
                print(f'Node {node_name} is the incorrect operator in RKNN model.')
                break

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/730923.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何正确理解和评估品牌价值?

在当今这个品牌林立的商业世界里,我们常常听到企业家们满怀憧憬地谈论品牌梦想。 但究竟是什么驱使这些企业去打造一个品牌,到底是市场的激烈竞争,还是内心的情感寄托?亦或是社会发展的必然趋势,引领我们追求超越产品…

【shell脚本速成】函数

文章目录 一、函数1.1、函数介绍1.2、函数定义1.3、函数调用 🌈你好呀!我是 山顶风景独好 🎈欢迎踏入我的博客世界,能与您在此邂逅,真是缘分使然!😊 🌸愿您在此停留的每一刻&#xf…

Java用文件流mask文本文件某些特定字段

思路 在Java中,如果你想要掩码(mask)文本文件中的某些特定字段,你可以按照以下步骤进行: 读取文本文件内容。找到并识别需要掩码的字段。用特定的掩码字符(如星号*)替换这些字段。将修改后的内…

Kubernates容器化JVM调优笔记(内存篇)

Kubernates容器化JVM调优笔记(内存篇) 先说结论背景思路方案 先说结论 1、首先如果是JDK8,需要使用JDK8_191版本以上,才支持容器化环境和以下参数,否则就更新到JDK10以上,选择对应的镜像构建就行了 2、在容…

cd 命令特殊路径符 mkdir命令

cd 特殊路径符 cd . 表示当前目录,比如 cd ./Desktop表示切换到当前目录下的Desktop目录内,和 cd Desktop效果一致。cd … 表示上一级目录,比如 cd … 即可切换到上一级目录,cd…/…切换到上二级目录。cd ~ 表示 HOME 目录&#…

【自动驾驶】运动底盘状态数据:里程计、IMU、运动学分析、串口通信协议

文章目录 控制器与运动底盘状态数据:里程计、IMU运动学分析与轮子运动学分析公式串口通信控制与反馈通讯协议串口通信反馈上行数据帧解析串口通信控制下行数据帧解析代码实现IMU、里程计数据的获取、解析、计算控制器与运动底盘状态数据:里程计、IMU 控制器需要负责外发底盘…

智慧园区解决方案PPT(53页)

## 1.1 智慧园区背景及需求分析 - 智慧园区的发展历程包括园区规划、经济、产业、企业、管理、理念的转变,强调管理模式创新,关注业务综合化、管理智慧化等发展。 ## 1.2 国家对智慧园区发展的政策 - 涉及多个国家部门,如工信部、住建部、…

【机器学习300问】129、RNN如何在情感分析任务中起作用的?

情感分析是自然语言处理(NLP)领域的一个重要分支,它的目标是自动检测和提取出非结构化文本数据中的主观信息(比如:情绪、意见、评价等) 一、情感分析任务案例 分析电商产品评论的情感倾向(三分类…

OS复习笔记ch11-4

磁盘调度 磁盘的物理结构 经典的温彻斯特盘 其中的几个概念: 盘面:可以看成是一个操场的平面,不同的盘面通过中间的轴串在一起磁道:磁道可以看成是操场的跑道,我们知道操场上有外道和内道,最内道中间往…

homework 2024.06.17 math, UI

A的宽度225 B的宽度150 这样画出来就比较标准, 225 * 2 150 * 3 2A 3B

ASP.NET Core 6.0 多种部署方式

IIS 环境准备和部署 安装并配置 IIS 安装 IIS,在搜索输入并打开 启用或关闭 Windows 功能。 配置IIS 需要配置 ASPNETCore 部署IS 程序包安装 (ASP.NET Core Module v2) Download .NET 6.0 (Linux, macOS, and Windows).NET 6.0 downloads…

搭建一个简单的xxljob

数据库表结构: YyJobInfo: public class YyJobInfo {//定时任务idprivate int id;//该定时任务所属的执行器的idprivate int jobGroup;//定时任务描述private String jobDesc;//定时任务添加的时间private Date addTime;//定时任务的更新时间private D…

TIM: A Time Interval Machine for Audio-Visual Action Recognition

标题:TIM:一种用于视听动作识别的时间间隔机器 源文链接:openaccess.thecvf.com/content/CVPR2024/papers/Chalk_TIM_A_Time_Interval_Machine_for_Audio-Visual_Action_Recognition_CVPR_2024_paper.pdfhttps://openaccess.thecvf.com/cont…

Redis 持久化策略

Redis 提供了多种持久化机制,用于将数据保存到磁盘中,以防止因服务器重启或故障而导致的数据丢失。主要的持久化策略有两种:RDB (Redis Database) 和 AOF (Append Only File),即当 Redis 服务器重新启动时,会读取相应的…

SEGGER Embedded Studio IDE移植embOS

SEGGER Embedded Studio IDE移植embOS 一、背景介绍二、任务目标三、技术实现3.1 获得embOS3.2 创建SES工程3.2.1 创建初始Solution和Project3.2.2 制作项目文件结构3.2.3 移植embOS库和有关头文件3.2.3.1 头文件3.2.3.2 库文件3.2.3.3 创建RTOSInit.c源文件3.2.3.4 OS_Error.c…

鸿蒙HarmonyOS NEXT角落里的知识:ArkTS高性能编程实践

概述 本文主要提供应用性能敏感场景下的高性能编程的相关建议,助力开发者开发出高性能的应用。高性能编程实践,是在开发过程中逐步总结出来的一些高性能的写法和建议,在业务功能实现过程中,我们要同步思考并理解高性能写法的原理…

信息学奥赛初赛天天练-31-CSP-J2022基础题-指针、数组、链表、进制转换、深度优先搜索、广度优先搜索、双栈实现队列应用

PDF文档公众号回复关键字:20240621 2022 CSP-J 选择题 单项选择题(共15题,每题2分,共计30分:每题有且仅有一个正确选项) 3.运行以下代码片段的行为是 ( ) int x 101; int y 201; int * p &x; int * q &y;…

【Java】已解决java.net.ProtocolException异常

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决java.net.ProtocolException异常 在Java的网络编程中,java.net.ProtocolException异常通常表示在网络通信过程中,客户端或服务器违反了某种协议规则。…

ASP.NET Core 6.0 启动方式

启动方式 Visualstudio 2022启动 IIS Express IIS Express 是一个专为开发人员优化的轻型独立版本的 IIS。 借助 IIS Express,可以轻松地使用最新版本的 IIS 开发和测试网站。 控制台版面 直接在浏览器输入监听的地址,监听的是 http://localhost:5137 脚本启动 dotnet run…

Java中将文件转换为Base64编码的字节码

在Java中,将文件转换为Base64编码的字节码通常涉及以下步骤: 读取文件内容到字节数组。使用java.util.Base64类对字节数组进行编码。 下面是一个简单的Java示例代码,演示如何实现这个过程: import java.io.File; import java.io…