tensflow模型转onnx实践

一、基础知识介绍

1、TensorFlow介绍

TensorFlow™是一个基于数据流编程(dataflow programming)的符号数学系统,被广泛应用于各类机器学习(machine learning)算法的编程实现,其前身是谷歌的神经网络算法库DistBelief [1]。Tensorflow拥有多层级结构,可部署于各类服务器、PC终端和网页并支持GPU和TPU高性能数值计算,被广泛应用于谷歌内部的产品开发和各领域的科学研究 [1-2]。TensorFlow由谷歌人工智能团队谷歌大脑(Google Brain)开发和维护,拥有包括TensorFlow Hub、TensorFlow Lite、TensorFlow Research Cloud在内的多个项目以及各类应用程序接口(Application Programming Interface, API)。自2015年11月9日起,TensorFlow依据阿帕奇授权协议(Apache 2.0 open source license)开放源代码 。

2、keras介绍

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化 。
Keras在代码结构上由面向对象方法编写,完全模块化并具有可扩展性,其运行机制和说明文档有将用户体验和使用难度纳入考虑,并试图简化复杂算法的实现难度 [1]。Keras支持现代人工智能领域的主流算法,包括前馈结构和递归结构的神经网络,也可以通过封装参与构建统计学习模型 。在硬件和开发环境方面,Keras支持多操作系统下的多GPU并行计算,可以根据后台设置转化为Tensorflow、Microsoft-CNTK等系统下的组件 。

3、onnx

ONNX是一种开放格式,专门用于表示机器学习模型。它定义了一套通用的运算符,这些运算符是构建机器学习和深度学习模型的基础单元,同时ONNX还定义了一种通用的文件格式。这些特性使得AI开发者能够跨多种框架、工具、运行时和编译器使用模型。

二、基础环境介绍

实际工作中,模型使用的版本和框架可能各不相同,在做模型转换或者模型迁移工作的过程中,一般先讲各个框架的模型格式转换为通用的中间格式,比如onnx,正如我们所做的,將 TensorFlow框架编写训练的模型转换为onnx格式。

1、框架版本

tensflow版本为1.14,keras版本为2.2.4

三、环境搭建

1、在 anconda当中新建一个虚拟环境,指定python版本为3.7

conda create -n py37_tf14 python=3.7

在这里插入图片描述
检查python和pip版本:

python
pip -V

在这里插入图片描述

2、安装依赖

1)安装tensflow和 keras

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow==1.14.0
Requirement already satisfied: setuptools>=41.0.0 in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from tensorboard<1.15.0,>=1.14.0->tensorflow==1.14.0) (65.6.3)
Collecting importlib-metadata>=4.4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ff/94/64287b38c7de4c90683630338cf28f129decbba0a44f0c6db35a873c73c4/importlib_metadata-6.7.0-py3-none-any.whl (22 kB)
Collecting MarkupSafe>=2.1.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9b/c1/9f44da5ca74f95116c644892152ca6514ecdc34c8297a3f40d886147863d/MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl (17 kB)
Collecting typing-extensions>=3.6.4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl (33 kB)
Collecting zipp>=0.5
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/5b/fa/c9e82bbe1af6266adf08afb563905eb87cab83fde00a0a08963510621047/zipp-3.15.0-py3-none-any.whl (6.8 kB)
Installing collected packages: tensorflow-estimator, zipp, wrapt, typing-extensions, termcolor, six, protobuf, numpy, MarkupSafe, grpcio, gast, astor, absl-py, werkzeug, keras-preprocessing, importlib-metadata, h5py, google-pasta, markdown, keras-applications, tensorboard, tensorflow
Successfully installed MarkupSafe-2.1.3 absl-py-1.4.0 astor-0.8.1 gast-0.5.4 google-pasta-0.2.0 grpcio-1.57.0 h5py-3.8.0 importlib-metadata-6.7.0 keras-applications-1.0.8 keras-preprocessing-1.1.2 markdown-3.4.4 numpy-1.21.6 protobuf-4.24.2 six-1.16.0 tensorboard-1.14.0 tensorflow-1.14.0 tensorflow-estimator-1.14.0 termcolor-2.3.0 typing-extensions-4.7.1 werkzeug-2.2.3 wrapt-1.15.0 zipp-3.15.0

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras==2.2.4

Requirement already satisfied: keras-applications>=1.0.6 in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from keras==2.2.4) (1.0.8)
Requirement already satisfied: h5py in d:\programs\anaconda3\envs\py37_tf14\lib\site-packages (from keras==2.2.4) (3.8.0)
Installing collected packages: scipy, pyyaml, keras
Successfully installed keras-2.2.4 pyyaml-6.0.1 scipy-1.7.3

测试tensflow是否安装完成

import tensorflow as tf
hello = tf.constant('hello Tensorflow')
sess = tf.Session()
print(sess.run(hello))

报错:

TypeError: Descriptors cannot not be created directly

需要降低 protobuf版本到3.19.0

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple protobuf==3.19.0

成功运行测试代码!

2)安装转换时需要的依赖

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple h5py
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras2onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tf2onnx

3)检查 安装包

(py37_tf14) PS C:\Users\lenovo> pip list
Package              Version
-------------------- ---------
absl-py              1.4.0
astor                0.8.1
certifi              2022.12.7
charset-normalizer   3.2.0
coloredlogs          15.0.1
fire                 0.5.0
flatbuffers          23.5.26
gast                 0.5.4
google-pasta         0.2.0
grpcio               1.57.0
h5py                 2.10.0
humanfriendly        10.0
idna                 3.4
importlib-metadata   6.7.0
Keras                2.2.4
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.2
keras2onnx           1.7.0
Markdown             3.4.4
MarkupSafe           2.1.3
mpmath               1.3.0
numpy                1.21.6
onnx                 1.8.0
onnxconverter-common 1.13.0
onnxruntime          1.14.1
onnxtk               0.0.1
packaging            23.1
pip                  22.3.1
protobuf             3.20.3
pyreadline           2.1
PyYAML               6.0.1
requests             2.31.0
scipy                1.7.3
setuptools           65.6.3
six                  1.16.0
sympy                1.10.1
tensorboard          1.14.0
tensorflow           1.14.0
tensorflow-estimator 1.14.0
termcolor            2.3.0
tf2onnx              1.15.1
typing_extensions    4.7.1
urllib3              2.0.4
Werkzeug             2.2.3
wheel                0.38.4
wincertstore         0.2
wrapt                1.15.0
zipp                 3.15.0
(py37_tf14) PS C:

三、转换方法问题及解决方案总结

1、将keras的h5模型转化为onnx

import tensorflow as tf



from keras.models import load_model
import onnx
import os
os.environ["TF_KERAS"]='1' 

import keras2onnx
keras_model = tf.keras.models.load_model('model/ar_crnn.h5',compile=False)

onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name)

# tf2onnx.save_model(onnx_model, "ar_crnn.onnx")


'''
报错信息
File "D:\Programs\anaconda3\envs\py37_tf14\lib\site-packages\onnxconverter_common\onnx_ops.py", line 815, in apply_reshape
raise ValueError('There can only be one -1 in the targeted shape of a Reshape but got %s' % desired_shape)
ValueError: There can only be one -1 in the targeted shape of a Reshape but got [-1, -1, 1152]
'''

暂时没有找到解决方案,github官方说已经解决,但提升到最新版本还是报错!

参考官方文档

https://pypi.org/project/keras2onnx/

2、使用tf2onnx 转换为onnx

import tensorflow as tf
import tf2onnx

# Load the Keras model
keras_model = tf.keras.models.load_model('model/ar_crnn.h5')

# Convert the model to ONNX format
onnx_model = tf2onnx.convert.from_keras(keras_model)

# Save the ONNX model to a file
# tf2onnx.save_model(onnx_model, 'my_model.onnx')



'''
raise ValueError("Tensor name '{0}' is invalid.".format(node.input[0]))
ValueError: Tensor name 'batch_normalization_1/cond/ReadVariableOp/Switch:1' is invalid.
'''

没有batch_normalization算子,可能因为tf2onnx不支持keras的h5模型

3、尝试通过 tensorflow模型中介的方式转换:h5→tf→onnx

import tensorflow as tf

model_path = './model/ar_crnn.h5'                    # 模型文件
model = tf.keras.models.load_model(model_path)
model.save('tfmodel', save_format='tf')

'''
 'Saving the model as SavedModel is not supported in TensorFlow 1.X'
'''

报错,表示tensorflow1.x不支持保存tf文件,为避免升级tensorflow2.x后带来其他版本不兼容的问题,未再对这种方式做尝试

4、使用h5py文件直接转成onnx的方式,暂时没有探索成功

import h5py
import onnx
from onnx import helper, shape_inference
import json

model_file = 'model/ar_crnn.h5'
# 打开h5文件
with h5py.File(model_file, 'r') as f:
    # 获取所有子集
    model_config = json.loads(f.attrs['model_config'].decode('utf-8'))
    
    print(model_config)
    weights = []
    f.visit(lambda name: weights.append(name) if isinstance(f[name],h5py.Dataset) else None)


onnx_model = helper.make_model(model_config)
'''
TypeError: Parameter to CopyFrom() must be instance of same class: expected onnx.GraphProto got dict.
'''

报错,在网上的例子中,直接使用f.attrs[‘model_config’]的方法获取的模型结构就可以直接作为helper.make_model()的参数,初始化模型; 但是实际运行时,helper.make_model()期待的是onnx.GraphProto这种类型的入参,在做onnx.GraphProto实例化的时候,仍然缺少必要参数,仍然失败。

5、使用pb中介转换的方式: h5 → pb → onnx, 成功!

1)、下载工具类:https://link.zhihu.com/?target=https%3A//github.com/amir-abdi/keras_to_tensorflow

因为没有官方工具,所以这里下载的是在git开源的工具

2)、使用工具类进行模型转换, 将h5模型转换为pb模型

python keras_to_tensorflow.py  --input_model="model\ar_crnn.h5"  --output_model="crnn.pb"

3)、使用tf2onnx,将pb转换为onnx

python -m tf2onnx.convert --graphdef crnn.pb --output ar_crnn_1.onnx --inputs the_input:-1 --outputs output_new/truediv:0

四、转换完成后,对比精度,检查onnx

编写精度测试代码: 随机生成一张图片, 分别用keras模型和 onnx模型进行推理,计算差值, 求最大值

# 读取h5模型和 onnx模型进行推理,对比结果
import onnxruntime
from keras.models import load_model
import numpy as np
def test_h5_onnx_precision(h5_path, onnx_path, batch_size):
    # 读取h5模型
    keras_model = load_model(h5_path,compile=False)
    input_data = np.random.random(size=(batch_size, 48, 16, 3)).astype(np.float32)
    h5_res = keras_model.predict([input_data])
    # onnx推理
    ort_session = onnxruntime.InferenceSession(onnx_path)
    model_inputs = ort_session.get_inputs()
    ort_inputs = {model_inputs[0].name: input_data}
    onnx_output = ort_session.run(['output_new/truediv:0'], ort_inputs)[0]

    res = h5_res - onnx_output
    print(res)
    print(res.max())

test_h5_onnx_precision('model/ar_crnn.h5', 'ar_crnn.onnx', 1000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/513773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WebGIS 之 Openlayer

1.导入第三方依赖 <link rel"stylesheet" href"https://lib.baomitu.com/ol3/4.6.5/ol.css"> <script src"https://lib.baomitu.com/ol3/4.6.5/ol.js"></script>2.初始化地图 初始化地图new ol.Map({}) 参数target:制定初始化…

阻塞队列(BlockingQueue)

何为阻塞队列 当阻塞队列是空时,从队列中获取元素的操作将被阻塞当阻塞队列是满时,往队列中添加元素将会被阻塞试图从空的阻塞队列中获取元素的线程将会被阻塞,直到其他线程往空的队列中插入新的元素试图往满的队列中,添加新的元素的线程也会被阻塞,直到其他线程从队列中移除…

基于SSM的社区疫情防控管理信息系统

目录 背景 技术简介 系统简介 界面预览 背景 随着时代的进步&#xff0c;计算机技术已经全方位地影响了社会的发展。随着居民生活质量的持续上升&#xff0c;人们对社区疫情防控管理信息系统的期望和要求也在同步增长。在社区疫情防控日益受到广泛关注的背景下&#xff0c…

OpenHarmony实战:Makefile方式组织编译的库移植

以yxml库为例&#xff0c;其移植过程如下文所示。 源码获取 从仓库获取yxml源码&#xff0c;其目录结构如下表&#xff1a; 表1 源码目录结构 名称描述yxml/bench/benchmark相关代码yxml/test/测试输入输出文件&#xff0c;及测试脚本yxml/Makefile编译组织文件yxml/.gitat…

Python基础之pandas:字符串操作与透视表

文章目录 一、字符串操作备注&#xff1a;如果想要全部行都能输出&#xff0c;可输入如下代码 1、字符检索2、字符转换3、字符类型判断4、字符调整5、字符对齐与填充6、字符检索7、字符切割8、字符整理 二、透视表1、pd.pivot_table2、多级透视表 一、字符串操作 备注&#xf…

黄锈水过滤器 卫生热水工业循环水色度水处理器厂家工作原理动画

​ 1&#xff1a;黄锈水处理器介绍 黄锈水处理器是一种专门用于处理“黄锈水”的设备&#xff0c;它采用机电一体化设计&#xff0c;安装方便&#xff0c;操作简单&#xff0c;且运行费用极低。这种处理器主要由数码射频发生器、射频换能器、活性过滤体三部分组成&#xff0c;…

2024年第九届亚太智能机器人系统国际会议即将召开!

2024年第九届亚太智能机器人系统国际会议 (ACIRS 2024) 将于2024年7月18-20日在中国大连举办&#xff0c;由大连理工大学主办&#xff0c;高性能精密制造全国重点实验室、辽宁黄海实验室和智能制造龙城实验联合承办。该会议旨在为智能机器人系统等领域的专家学者建立一个广泛有…

实现顺序表(增、删、查、改)

引言&#xff1a;顺序表是数据结构中的一种形式&#xff0c;就是存储数据的一种结构。 这里会用到动态内存开辟&#xff0c;指针和结构体的知识 1.什么是数据结构 数据结构就是组织和存储数据的结构。 数据结构的特性&#xff1a; 物理结构&#xff1a;在内存中存储的数据是否连…

k8s calico由IPIP模式切换为BGP模式

按照官网calico.yaml部署后&#xff0c;默认是IPIP模式 查看route -n &#xff0c; 看到是tunl0口进行转发 怎么切换到BGP模式呢&#xff1f; kubectl edit ippool 将ipipMode由Always修改为Never &#xff0c;修改后保存文件即可。无需做任何操作&#xff0c;自动就切换为BG…

picgo启动失败解决

文章目录 报错信息原因分析解决方案 报错信息 打开Picgo&#xff0c;显示报错 A JavaScript error occurred in the main process Uncaught Exception: Error:ENOENT:no such file or directory,open ‘C:\Users\koko\AppData\Roaming\picgo\data.json\picgo.log’ 原因分析…

绝不忽视!List.add方法揭秘:你绝对需要了解的覆盖现象

文章目录 引言一、背景介绍1.1 事件背景1.2 List.add()方法简介示例影响 二、覆盖现象解决方案1. 每次循环创建新对象2. 使用工厂方法或建造者模式3. 深拷贝4. 不可变对象 三、解决方案1. 使用深拷贝2. 创建新对象3. 避免直接修改原对象 四、 结论 引言 在 Java 编程中&#x…

MyBatis的基本应用

源码地址 01.MyBatis环境搭建 添加MyBatis的坐标 <!--mybatis坐标--><dependency><groupId>org.mybatis</groupId><artifactId>mybatis</artifactId><version>3.5.9</version></dependency><!--mysql驱动坐…

VSCode调试C++

1、环境准备 1.1、g的安装与使用 1.1.1、安装 方式一&#xff1a;Xcode安装 苹果的开发集成工具是Xcode.app&#xff0c;其中包含一堆命令行工具。 在 App store 可以看到其大小有好几个G&#xff0c;有点大。 方式二&#xff1a;Command Line Tools 安装 Command Line Too…

OpenHarmony实战:小型系统器件驱动移植

本章节讲解如何移植各类器件驱动。 LCD驱动移植 移植LCD驱动的主要工作是编写一个驱动&#xff0c;在驱动中生成模型的实例&#xff0c;并完成注册。 这些LCD的驱动被放置在源码目录//drivers/hdf_core/framework/model/display/driver/panel中。 创建Panel驱动 创建HDF驱动…

高级数据结构与算法习题(6)

一、单选题 1、In the Tic-tac-toe game, a "goodness" function of a position is defined as f(P)=Wcomputer​−Whuman​ where W is the number of potential wins at position P. In the following figure, O represents the computer and X the human. What i…

农业保险利用卫星遥感监测、理赔、农作物风险评估

​农业保险一直是农民和农业生产者面临的重要课题&#xff0c;而卫星遥感技术的不断发展正为农业保险带来全新的解决方案。通过高分辨率的卫星遥感监测&#xff0c;农业保险得以更精准、及时地评估农田状况&#xff0c;为农业经营者提供可靠的风险管理手段。 **1. 灾害监测与风…

2024年第三期丨全国高校大数据与人工智能师资研修班邀请函

2024年第三期 杭州线下班 数据采集与机器学习实战&#xff08;Python&#xff09; 线上班 八大专题 大模型技术与应用实战 数据采集与处理实战&#xff08;Python&八爪鱼&#xff09; 大数据分析与机器学习实战&#xff08;Python&#xff09; 商务数据分析实战&…

29.使线程以固定顺序输出结果

借助wait和notify方法控制线程以固定的顺序执行&#xff1a; /*** 控制输出字符的顺序&#xff0c;必须是固定顺序2,1* 采用wait-notify实现* param args*/public static void main(String[] args) {new Thread(() -> {synchronized (lock) {while (!isPrint2) {try {lock.…

【c++】STl-list使用list模拟实现

主页&#xff1a;醋溜马桶圈-CSDN博客 专栏&#xff1a;c_醋溜马桶圈的博客-CSDN博客 gitee&#xff1a;mnxcc (mnxcc) - Gitee.com 目录 1. list的介绍及使用 1.1 list的介绍 1.2 list的使用 1.2.1 list的构造 1.2.2 list iterator的使用 1.2.3 list capacity 1.2.4 …

【Java】CAS详解

一.什么是CAS CAS(compare and swap) 比较并且交换. CAS是一个cpu指令,是原子的不可再分.因此基于CAS就可以给我们编写多线程的代码提供了新的思路---->使用CAS就不用使用加锁,就不会牵扯到阻塞,也称为无锁化编程 下面是一个CAS的伪代码: address是一个内存地址,expectVal…