OpenCV读取tensorflow神经网络模型:SavedModel格式转为frozen graph的方法

  本文介绍基于Pythontensorflow库,将tensorflowkeras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++ 等其他语言中将其打开的方法。

  如果我们需要训练使用一个神经网络模型,一般情况下都是首先借助Python语言中完善的神经网络模型API对其加以训练,训练完毕后在C++Java等语言环境下高效、快速地使用它。最近,就需要在C++ 中打开、使用几个前期已经在Pythontensorflow库中训练好的神经网络模型。但是,由于训练模型时使用的是2.X版本的tensorflow库(且用的是keras的框架),所以训练模型后保存的是SavedModel格式的神经网络模型文件——就是包含3.pb格式文件,以及assetsvariables2个文件夹那种形式的模型;如下图所示。

  而在C++ 中读取神经网络模型,首先是可以借助tensorflow库的C++ API来实现,但是这种方法非常复杂——完整的TensorFlow C++ API部署起来非常困难——需要系统盘至少40 G50 G的剩余空间、动辄0.5 h1 h的编译时长,经常需要花费一周的时间才可以配置成功;所以如果仅仅是需要在C++ 中读取已经训练好的神经网络模型的话,没必要花费这么大功夫去配置TensorFlow C++ API。而同时,基于OpenCV库,我们则可以在简单、快速地配置完其环境后,就基于1个函数对训练好的tensorflow库神经网络模型加以读取、使用。这里如果大家需要配置C++ 环境的OpenCV库,可以参考文章C++计算机视觉库OpenCV在Visual Studio 2022的配置方法(https://blog.csdn.net/zhebushibiaoshifu/article/details/128260507)。

  但是,还有一个问题——OpenCV库自身目前仅支持读取tensorflowfrozen graph格式的神经网络模型,不支持读取SavedModel格式的模型。因此,如果希望基于OpenCV库读取tensorflowSavedModel格式的模型,就需要首先将其转换为frozen graph格式;那么,本文就介绍一下这个操作的具体方法,并给出2种实现这一转换功能的Python代码。

  首先,本文神经网络模型格式转换的代码是基于Python环境中tensorflow库实现的,因此需要配置好这一个库(大家都已经需要转换神经网络模型的格式了,那Python环境中tensorflow库肯定早已经配置好了);如果没有配置,可以参考文章Anaconda配置Python新版本tensorflow库(CPU、GPU通用)的方法(https://blog.csdn.net/zhebushibiaoshifu/article/details/129285815)。

  第1种代码如下。

# -*- coding: utf-8 -*-
"""
Created on Sat Mar  9 14:31:18 2024

@author: fkxxgis
"""

import tensorflow as tf
from tensorflow.keras import models
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# model_save_model = tf.saved_model.load("F:/Data_Reflectance_Rec/model/model_blue/")
model_save_model = models.load_model("F:/Data_Reflectance_Rec/model/model_blue/")

signatures = model_save_model.signatures["serving_default"]
graph = tf.function(lambda x: model_save_model(x))
graph = graph.get_concrete_function(tf.TensorSpec(signatures.inputs[0].shape.as_list(), signatures.inputs[0].dtype.name))
frozen_variable = convert_variables_to_constants_v2(graph)
frozen_variable.graph.as_graph_def();

tf.io.write_graph(graph_or_graph_def = frozen_variable.graph, 
                  logdir = "F:/Data_Reflectance_Rec/model/model_blue_new", 
                  name = "frozen_graph.pb", 
                  as_text = False)
# tf.io.write_graph(graph_or_graph_def = frozen_variable.graph, 
#                   logdir = "F:/Data_Reflectance_Rec/model/model_blue_new", 
#                   name = "frozen_graph.pbtxt", 
#                   as_text = True)

  其中,我们首先需要导入对应的Python模块和convert_variables_to_constants_v2()函数。

  随后,加载我们待转换的、SavedModel格式的tensorflow神经网络模型。这里需要注意,我写了2句不同的代码来加载初始的模型——其中,如果用第1句代码加载模型,倒也可以不报错地运行完成上述代码,但是等到用C++ 环境的OpenCV库读取这个转换后的模型时,会出现Microsoft C++ 异常: cv::Exception字样的报错,如下图所示;而如果用第2句代码加载模型,就没有问题。之所以会这样,应该是因为我当初训练这个神经网络模型时,用的是tensorflowkeras模块的Model,所以导致加载模型时,就不能用传统的加载SavedModel格式模型的方法了(可能是这样)。

  接下来,我们从初始模型中获取其签名tensorflow库中的签名(Signature),是用于定义模型输入、输出的一种机制——其定义了模型接受的输入参数和返回的输出结果的名称、数据类型和形状等信息;这个默认签名为serving_default,我们这里获取这个默认的签名即可。

  接下来,这个graph = tf.function(lambda x: model_save_model(x))表示将模型封装在tensorflow的图函数中;随后,get_concrete_function()获取具体函数并指定输入张量的形状和数据类型。说实话,这里的2行代码我也搞不太清楚具体详细含义是什么——但大体上,这些内容应该是tensorflow1.X版本中的一些操作与名词(因为frozen graph格式的模型本来就是tensorflow1.X版本中用的,而SavedModel格式则是2.X版本中常用的)。

  再次,通过convert_variables_to_constants_v2()函数,将图中的变量转换为常量,并基于as_graph_def()定义1个冻结图。

  最后,就可以通过tf.io.write_graph()函数,将冻结图写入指定的目录中,输出文件名为frozen_graph.pbas_text = False表示以二进制格式保存这个模型(如果不加这个参数,就相当于成了.pbtxt文件了,导致后续用C++环境的OpenCV库还是读取不了这个模型)。代码末尾,还有一段注释的部分——如果取消注释,将以文本格式保存冻结图,也就是.pbtxt文件。因为我们只要.pb文件就够了,所以就不需要这段代码了。

  执行上述代码,在结果文件夹中,我们将看到1.pb格式的神经网络模型结果文件,如下图所示。

  接下来,在C++Python等语言的OpenCV库中,我们都可以基于cv::dnn::readNetFromTensorflow()这个函数,来读取我们的神经网络模型了。

  除此之外,再给出另一个版本的转换代码;这个代码其实和前述代码的含义差不多,如果前述代码不能执行,大家可以再尝试尝试下面这个。

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

loaded = tf.saved_model.load('F:/Data_Reflectance_Rec/model/model_nir/')
infer = loaded.signatures['serving_default']

f = tf.function(infer).get_concrete_function(tf.TensorSpec(infer.inputs[0].shape.as_list(), dtype=tf.float32))
f2 = convert_variables_to_constants_v2(f)
graph_def = f2.graph.as_graph_def()

with tf.io.gfile.GFile('frozen_graph.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())

  至此,大功告成。

欢迎关注:疯狂学习GIS

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/446772.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[mmucache]-ARMV8-aarch64的虚拟内存(mmutlbcache)介绍-概念扫盲

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 思考: 1、cache的entry里都是有什么? 2、TLB的entry里都是有什么? 3、MMU操作…

QT给QLabel设置背景颜色

1.选中label 2.右键点击"改变样式表" 3.填写样式,点击apply,ok 注意 #{QLabel名称},例如名称是label就是QLabel#label

opencv人脸识别实战3:多线程和GUI界面设计(PyCharm实现)

一、多线程设计 1、在一个新线程中调用了 scan_face() 函数来进行人脸识别操作。根据识别结果,更新界面显示结果,最后释放资源。 def f_scan_face_thread():var.set(刷脸)ans scan_face()if ans 0:print("最终结果:无法识别")va…

【个人开发】llama2部署实践(三)——python部署llama服务(基于GPU加速)

1.python环境准备 注:llama-cpp-python安装一定要带上前面的参数安装,如果仅用pip install装,启动服务时并没将模型加载到GPU里面。 # CMAKE_ARGS"-DLLAMA_METALon" FORCE_CMAKE1 pip install llama-cpp-python CMAKE_ARGS"…

UE4开个头-简易小汽车

跟着谌嘉诚学的小Demo,记录一下 主要涉及到小白人上下车和镜头切换操作 1、动态演示效果 2、静态展示图片 3、蓝图-上下车

如何轻松打造属于自己的水印相机小程序?

水印相机小程序源码 描述:微信小程序。本文将为您详细介绍小程序水印相机源码的搭建过程,教您如何轻松打造属于自己的水印相机小程序。无论您是初学者还是有一定基础的开发者,都能轻松掌握这个教程。 一:水印相机搭建教程 1 隐…

Ubuntu23.10安装FFmpeg及编译FFmpeg源码

安装FFmpeg: 打开终端: 输入 sudo apt install ffmpeg 安装成功: 验证FFmpeg 默认安装位置与库与头文件位置 使用FFmpeg源码编译: 1.安装YASM sudo apt-get install yasm

鸿蒙开发学习:【ets_frontend组件】

简介 ets_frontend组件是方舟运行时子系统的前端工具,结合ace-ets2bundle组件,支持将ets文件转换为方舟字节码文件。 ets_frontend组件架构图 目录 /arkcompiler/ets_frontend/ ├── test262 # test262测试配置和运行脚本 ├── testTs…

Mysql 死锁案例2-间隙锁与意向插入锁冲突

死锁复现 CREATE TABLE t (id int(11) NOT NULL,c int(11) DEFAULT NULL,d int(11) DEFAULT NULL,PRIMARY KEY (id),KEY c (c) ) ENGINEInnoDB DEFAULT CHARSETutf8;/*Data for the table t */insert into t(id,c,d) values (0,0,0),(5,5,5),(10,10,10) 事务1事务2T1START …

React-路由小知识

1.默认路由 说明:当访问的是一级路由时,默认的二级路由组件可以得到渲染,只需要在二级路由的位置去掉path,设置index.属性为true。 2.404路由 说明:当浏览器输入ul的路径在整个路由配置中都找不到对应的pth,为了用户体验&#x…

Django简易用户登入系统示例

Django简易用户登入系统示例 1)添加url和函数的对应关系(urls.py) urlpatterns [ path(login/, views.login), #login:url路径,views.login:对应的函数 ]2)添加视图函数(views.py) def login(req):if…

React useMemo钩子指南:优化计算性能

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

绳牵引并联机器人动态避障方法

绳牵引并联机器人在受限空间中如何躲避动态障碍物,是个有挑战的课题。 来自哈尔滨工业大学(深圳)的熊昊老师团队,开展了一项有趣的研究,论文《Dynamic Obstacle Avoidance for Cable-Driven Parallel Robots With Mob…

GitOps实践之Argo CD (2)

argocd 【-1】argocd可以解决什么问题? helm 部署是手动的?依赖流水线。而有时候仅仅更新一个小东西,流水线跑好久,CD真的不应该和CI耦合。不同环境的helm配置不同,手动修改问题多,可以用git管理起来,例如分不同环境用目录区分。argocd创建应用可以不通环境部署到不同集…

C++ STL--Vector 详细剖析

目录 1.vector的介绍及使用 1.1 vector的介绍 1.2 vector的使用 1.2.1 vector的定义 1.2.2 vector iterator 的使用 1.2.3 vector 空间增长问题 1.2.3 vector 增删查改 1.2.4 vector 迭代器失效问题 2.vector深度剖析及模拟实现 2.1 std::vector的核心框架接口的模拟实…

探索云原生数据库技术:构建高效可靠的云原生应用

数据库是应用开发中非常重要的组成部分,可以进行数据的存储和管理。随着企业业务向数字化、在线化和智能化的演进过程中,面对指数级递增的海量存储需求和挑战以及业务带来的更多的热点事件、突发流量的挑战,传统的数据库已经很难满足和响应快…

利用GPT开发应用007:警惕人工智能幻觉,局限与注意事项

文章目录 一、人工智能幻觉二、计算案例三、斑马案例四、总结 正如您所见,一个大型语言模型通过基于给定的输入提示逐个预测下一个单词(或标记)来生成答案。在大多数情况下,模型的输出对您的任务来说是相关的,并且完全…

Windows电脑安装Linux(Ubuntu 22.04)系统(图文并茂)

Windows电脑安装Ubuntu 22.04系统,其它版本的Ubuntu安装方法相同 Ubuntu 16.04、Ubuntu 18.04安装方法相同,制作U盘启动项的镜像文件下载你需要的版本即可! Ubuntu的中文官网网址:https://cn.ubuntu.com/,聪明的你一定…

03-安装配置jenkins

一、安装部署jenkins 1,上传软件包 为了方便学习,本次给大家准备了百度云盘的安装包 链接:https://pan.baidu.com/s/1_MKFVBdbdFaCsOTpU27f7g?pwdq3lx 提取码:q3lx [rootjenkins ~]# rz -E [rootjenkins ~]# yum -y localinst…

SpringMVC08、Json

8、Json 8.1、什么是JSON? JSON(JavaScript Object Notation, JS 对象标记) 是一种轻量级的数据交换格式,目前使用特别广泛。采用完全独立于编程语言的文本格式来存储和表示数据。简洁和清晰的层次结构使得 JSON 成为理想的数据交换语言。易于人阅读和…