ONNX系列: ONNX模型修改

ONNX 模型修改

        当我们熟悉了ONNX模型各个层级的结构后,我们便可以针对各个结构来对模型进行修改,从而使其更好的适配后端运行时或者特定硬件平台的编译器。对模型的修改通常可以概括为"增删改查"的操作。"增"是增加相应结构,"删"是删除相应结构,"改"是修改相应结构,"查"是获取到指定的模型结构。修改ONNX模型通常有两种思路,一是使用ONNX官方提供的Python API;二是使用第三方ONNX模型修改工具,例如onnx-graphsurgeon工具。本文将聚焦第一种方案,介绍如何使用ONNX官方API来对ONNX模型进行"增删改查"修改。完整的ONNX官方API文档可以参考:https://onnx.ai/onnx/index.html。

1. ONNX 模型的"查"

我们想要修改ONNX模型,首先需要知道如何定位到自己感兴趣的位置,比如如何找到具体某个节点、某个 initializer,、计算图的input/output, 某个节点的 input/output以及某个 value_info。参考下面的代码,我们可以发现定位某个元素的基本思路就是遍历该元素的列表,然后根据该元素在计算图中独有的属性名称来实现定位。下面的代码实现了定位下图模型的各个元素。

# 根据算子的名字来找到目标节点

for item in model.graph.node:

    if item.name == 'Conv_1':

        print(item)

# 有的onnx模型中算子没有name属性,可以根据算子类型和输出的名字来组合找到目标节点

for item in model.graph.node:

    if item.op_type == 'Conv':

        if '1338' in item.output:

            print(item)

# 找到目标 intializer

for i in model.graph.initializer:

    if i.name == '1339':

        print(i.dims)

        print(i.dims)

        print(i.data_type)

        # 二进制形式打印,可能比较长

        print(i.raw_data)

# 找到 graph 的input和output

for i in model.graph.input:

    if i.name == 'input':

        print(i.name)

        print(i.type)

# 找到 graph 的valueinfo

for i in model.graph.value_info:

    if i.name == '9':

        print(i.name)

        print(i.type)

2. ONNX 模型的"删"

在了解了如何定位到需要修改的部分后,我们就可以对ONNX模型进行魔改了。我们首先了解如何删除ONNX模型中的指定节点或元素。下面的代码实现了删除图中标注的节点。

import onnx

# 加载模型

model = onnx.load('./super-resolution-10.onnx')

# 根据输入获取指定节点

def get_node_with_input(model, input_name):

    res = []

    for i in model.graph.node:

        if input_name in i.input:

            res.append(i)

    return res

# 根据输出获取指定节点

def get_node_with_output(model, output_name):

    res = []

    for i in model.graph.node:

        if output_name in i.output:

            res.append(i)

    return res

# 删除指定节点并将前后节点连接起来

remove_nodes = []

p = None

n = None

for i in model.graph.node:

    if '10' in i.input:

        # p = find_node_with_output(i.input[0])

        p = get_node_with_output(model, i.input[0])[0]

        remove_nodes.append(i)

    if '11' in i.input:

        # n = find_node_with_input(i.output[0])

        n = get_node_with_input(model, i.output[0])[0]

        remove_nodes.append(i)

n.input[0] = p.output[0]

for i in remove_nodes:

    model.graph.node.remove(i)

onnx.checker.check_model(model)

onnx.save(model, 'super-resolution-10-delete.onnx')

3. ONNX 模型的"增"

"增"是指在ONNX模型指定位置添加节点。在了解添加节点之前,我们首先需要了解如何创建 ONNX 节点。下面以创建一个2D卷积算子和一个ReLu算子为例,并尝试将上一步骤中删除的这两个节点重新添加回模型当中(注意我们权重没有与原模型保持一致)。

node1 = onnx.helper.make_node(

        name="Conv_0",   # 节点名字,不要和op_type搞混了

        op_type="Conv",  # 节点的算子类型, 比如'Conv'、'Relu'、'Add'这类,详细可以参考onnx给出的算子列表

        inputs=["image", "conv.weight", "conv.bias"],  # 各个输入的名字,结点的输入包含:输入和算子的权重。必有输入X和权重W,偏置B可以作为可选。

        outputs=["11"], 

        pads=[1, 1, 1, 1], # 其他字符串为节点的属性,attributes在官网被明确的给出了,标注了default的属性具备默认值。

        group=1,

        dilations=[1, 1],

        kernel_shape=[3, 3],

        strides=[1, 1]

    )

initializer_w = onnx.helper.make_tensor(

        name="conv.weight",

        data_type=onnx.helper.TensorProto.DataType.FLOAT,

        dims=[64, 64, 3, 3],

        vals=np.ones([64,64,3,3], dtype=np.float32).tobytes(),

        raw=True

    )

initializer_b = onnx.helper.make_tensor(

        name="conv.bias",

        data_type=onnx.helper.TensorProto.DataType.FLOAT,

        dims=[64],

        vals=np.ones([64], dtype=np.float32).tobytes(),

        raw=True

    )

node2 = onnx.helper.make_node(

        name="ReLU_1",

        op_type="Relu",

        inputs=["11"],

        outputs=["12"]

    )

下面代码将上述创建的两个节点插入到模型指定位置。

for i in range(len(model.graph.node)):

    if '10' in model.graph.node[i].output:

        model.graph.node[i].output[0] = 'pre_output'

        model.graph.node[i+1].input[0] = 'relu_output'

        model.graph.node.insert(i+1, node1)

        model.graph.node.insert(i+2, node2)

model.graph.initializer.append(initializer_w)

model.graph.initializer.append(initializer_b)

input = model.graph.input[0]

new_input = onnx.helper.make_tensor_value_info(input.name, onnx.TensorProto.FLOAT, [1,1,224,224])

model.graph.input[0].CopyFrom(new_input)

onnx.checker.check_model(model)

model = onnx.shape_inference.infer_shapes(model)

onnx.save(model, 'super-resolution-10-insert.onnx')

4. ONNX 模型的"改"

通常来说修改 ONNX 模型可以概括为一下两种:

  • 修改模型节点
  • 修改权重(initializer)

修改模型的节点可以通过上述的删除 + 添加节点组合操作来实现,这里不再赘述。下面将介绍如何修改节点权重。节点权重通常保存在initializer中,下面代码尝试将Conv算子中的bias缩小10倍。

import onnx

model = onnx.load("./super-resolution-10.onnx")

# 得到所有 initializer

all_initializer = model.graph.initializer

# 定位到目标 initializer

target_initializer = 'conv1.bias'

idx = ''

scale_factor = 10

for i, j in enumerate(all_initializer):

    if j.name == target_initializer:

        idx = i

        break

# 将 conv1 算子的 bias 缩小10倍

model.graph.initializer[idx].raw_data = (onnx.numpy_helper.to_array(all_initializer[idx]) / scale_factor).tobytes()

onnx.save(model,'super-resolution-10-scale.onnx')

总结

当我们在实际部署模型时,会根据具体硬件特性来在 ONNX 模型层面做相应的优化修改,使其能在特定的硬件平台上获得更好的推理性能。本文简单介绍了如何调用 ONNX 官方API来对 ONNX 模型进行增删改查,更加复杂的模型修改操作通常是上述四种操作的各种组合。

使用ONNX 官方API需要我们对 ONNX 模型的定义和Proto结构足够熟悉,并且通过本文中的示例代码可以看到,繁多复杂的API在使用过程中也不是很方便。在实际工作中,我们一般使用NV提供的onnx-graphsurgeon工具来快速对ONNX模型进行修改验证。这个工具在官方ONNX API的基础上提供了更为友好的高级API封装,大大提升了我们修改ONNX模型的效率,在之后的文章中我们将进一步详细介绍这个工具的使用。

作者:高通工程师,阮慧源(Huiyuan Ruan)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/530809.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SAP 采购订单预制发票不让重复开立增强(包含:LMR1MF6S)<转载>

原文链接:https://blog.csdn.net/LH26988/article/details/136802631 之前博主有介绍过通过配置来控制不让采购发票重复开立,然是这个方式有点缺陷(跳转) 今天介绍,通过增强来彻底搞定这个问题的办法: 问题…

数组与链表:JavaScript中的数据结构选择

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

环境监测站升级选择ARM网关驱动精准数据采集

物联网技术的深入发展和环保需求的不断攀升,API调用网关在环境监测领域的应用正成为科技创新的重要推手。其中,集成了API调用功能的ARM工控机/网关,以其出色的计算性能、节能特性及高度稳定性,成功搭建起连接物理世界与数字世界的…

vue3移动端H5 瀑布流显示列表

以上效果 是之前发送的改进版 waterList <template><view class"pro-cons" v-if"data.length"><view class"cons-left"><template v-for"(item, index) in data"><template v-if"(index 1) % 2 1…

wangEditor 测试环境对,但是生产环境无法显示

package.json 文件版本 "wangeditor": "4.3.0"开发环境 new Editor(#${this.id});出来的数据 正式环境 new Editor(#${this.id});出来的数据 原因&#xff1a; vue.config 文件 打包策略的时候 const assetsCDN {css: [https://lf6-cdn-tos.bytecd…

【分析 GClog 的吞吐量和停顿时间、heapdump 内存泄漏分析】

文章目录 &#x1f50a;博主介绍&#x1f964;本文内容GClog分析以优化吞吐量和停顿时间步骤1: 收集GClog步骤2: 分析GClog步骤3: 优化建议步骤4: 实施优化 Heapdump内存泄漏分析步骤1: 获取Heapdump步骤2: 分析Heapdump步骤3: 定位泄漏对象步骤4: 分析泄漏原因步骤5: 修复泄漏…

基于YOLOv8的摄像头下铁路工人安全作业检测系统

&#x1f4a1;&#x1f4a1;&#x1f4a1;本文摘要&#xff1a;基于YOLOv8的铁路工人安全作业检测系统&#xff0c;属于小目标检测范畴&#xff0c;并阐述了整个数据制作和训练可视化过程&#xff0c; 博主简介 AI小怪兽&#xff0c;YOLO骨灰级玩家&#xff0c;1&#xff0…

物联网实战--驱动篇之(六)4G通讯(Air780E)

目录 一、4G模块简介 二、AIR780E驱动程序 三、AIR780使用注意事项 四、结合MQTT传输测试 一、4G模块简介 4G应该是我们日常生活最常见的一种互联网通讯方式了&#xff0c;每个智能手机都配置了&#xff0c;不过手机的4G跟我们物联网领域要用的4G有点区别。首先是物联网采用…

Docker容器嵌入式开发:MySQL表的外键约束及其解决方法

本文内容涵盖了使用MySQL创建数据库和表、添加数据、处理字符集错误、解决外键约束问题以及使用SQL查询数据的过程。通过创建表、插入数据和调整字符集等操作&#xff0c;成功解决了数据库表中的字符集问题&#xff0c;并使用INSERT语句向各个表中添加了示例数据。同时&#xf…

乘苏州金龙客车,览西北无边胜境

2023年&#xff0c;甘肃省共接待游客3.88亿人次&#xff0c;实现旅游收入2745.8亿元&#xff0c;分别较上年同期增长187.8%和312.9%&#xff0c;分别恢复到2019年同期的104%和102.4%。随着旅游市场的持续火爆&#xff0c;甘肃保利旅游客运有限责任公司&#xff08;简称“甘肃保…

Day105:代码审计-PHP原生开发篇SQL注入数据库监控正则搜索文件定位静态分析

目录 代码审计-学前须知 Bluecms-CNVD-1Day-常规注入审计分析 emlog-CNVD-1Day-常规注入审计分析 emlog-CNVD-1Day-2次注入审计分析 知识点&#xff1a; 1、PHP审计-原生态开发-SQL注入&语句监控 2、PHP审计-原生态开发-SQL注入&正则搜索 3、PHP审计-原生态开发-SQ…

vulhub之Webmin篇

Webmin是功能最强大的基于Web的Unix系统管理工具。管理员通过浏览器访问Webmin的各种管理功能并完成相应的管理动作。Webmin支持绝大多数的Unix系统&#xff0c;这些系统除了各种版本的linux以外还包括&#xff1a;AIX、HPUX、Solaris、Unixware、Irix和FreeBSD等。 影响版本&…

Codigger Desktop:使用体验与获得收益双赢的革新之作(二)

昨天&#xff0c;我们介绍了Codigger Desktop的最大亮点在于&#xff0c;它不仅仅是一个普通的桌面应用程序&#xff0c;更是一个能够产生实际价值的平台。无论您是开发者还是使用者&#xff0c;Desktop都能给您带愉快体验&#xff1a; 首先&#xff0c;Codigger Desktop具备零…

51单片机入门_江协科技_25~26_OB记录的笔记_蜂鸣器教程

25. 蜂鸣器 25.1. 蜂鸣器介绍 •蜂鸣器是一种将电信号转换为声音信号的器件&#xff0c;常用来产生设备的按键音、报警音等提示信号 •蜂鸣器按驱动方式可分为有源蜂鸣器和无源蜂鸣器&#xff08;开发板上用的无源蜂鸣器&#xff09; •有源蜂鸣器&#xff1a;内部自带振荡源&a…

卡塔尔世界杯中的“先进技术”

2010年&#xff0c;卡塔尔赢得了世界杯的主办权&#xff0c;在这个年轻而雄心勃勃的国家历史上留下了金色的足迹。从那时起&#xff0c;卡塔尔开始与时间赛跑&#xff0c;以履行东道主的义务&#xff0c;并履行组织一届特殊世界杯的承诺&#xff0c;这是世界上最重要的体育赛事…

P5356 [Ynoi2017] 由乃打扑克

我手把手教她打扑克 qwq 综合分析一下2个操作&#xff0c;查找区间第k小的值&#xff0c;感觉可以用主席树&#xff0c;区间修改那没事了 考虑分块做法,块长B 分析第一个操作 只需要维护数列的单调性&#xff0c;然后二分答案上二分就ok了 分析第二个操作 维护一个加法懒…

抖音电商发布2024春茶消费数据报告,90后00后占春茶消费数量四成

清明已过&#xff0c;茶香正浓。进入4月&#xff0c;各类春茗进入销售旺季。近日&#xff0c;抖音电商发布《2024春茶消费数据报告》&#xff08;以下简称“报告”&#xff09;&#xff0c;展示平台用户在原叶茶消费领域的最新趋势。 报告显示&#xff0c;在刚过去的3月&#…

dayjs 判断是否今天、本周内、本年内、本年外显示周几、月份等

效果: 判断是否今天需从 dayjs 中引入 isToday 插件&#xff1b; 判断是否两个日期之间需从 dayjs 中引入 isBetween 插件 import dayjs from dayjs import isToday from dayjs/plugin/isToday import isBetween from dayjs/plugin/isBetween// 注册插件 dayjs.extend(isBet…

论文学习D2UNet:用于地震图像超分辨率重建的双解码器U-Net

标题&#xff1a;&#xff1a;Dual Decoder U-Net for Seismic Image Super-Resolution Reconstruction ——D2UNet&#xff1a;用于地震图像超分辨率重建的双解码器U-Net 期刊&#xff1a;IEEE Transactions on Geoscience and Remote Sensing 摘要&#xff1a;从U-Net派生…

Python实现滑块验证码识别,最简单的一种,没有任何加密

网址链接&#xff1a;衣丰 & 2010-聚衣网(juyi5.cn) - 常熟市聚衣网&#xff0c;聚衣网女装&#xff0c;江苏省女装批发&#xff0c;苏州市女装批发&#xff0c;常熟市女装批发&#xff0c;网销女装一件代发&#xff0c;全国最低价 平时采集数据&#xff0c;频率过快&…