tf2使用savemodel保存之后转化为onnx适合进行om模型部署

tf2使用savemodel保存之后转化为onnx适合进行om模型部署

  • tf保存为kears框架h5文件
  • 将h5转化为savemodel格式,方便部署
  • 查看模型架构
  • 将savemodel转化为onnx格式
  • 使用netron
  • onnx模型细微处理
  • 代码转化为om以及推理代码,要么使用midstudio

tf保存为kears框架h5文件

前提环境是tf2.2及其版本以上的框架,模型训练结果保存为h5(也就是kears框架)
Pasted image 20240507233042

将h5转化为savemodel格式,方便部署

之后将h5文件转化为savemodel的格式
Pasted image 20240507233120

custom是在保存模型的时候需要的自定义函数,如果没有则不需要添加

保存结果如下
Pasted image 20240507233222

这个地方记得验证一下savemodel格式是否能成功搭载测试代码

import  os  
import pandas as pd  
import numpy as np  
from sklearn.metrics import accuracy_score  
from sklearn.model_selection import train_test_split  
from tensorflow.keras.models import Sequential  
from tensorflow.keras.layers import LSTM,Dense,Dropout  
from keras.utils import to_categorical  
import tensorflow as tf  
from tensorflow.python.keras.layers import Activation  
  
os.chdir('D:/software_project/心电信号分类/')  
  
  
# 加载 SavedModel 目录  
loaded_model = tf.saved_model.load('tfmodel_save')  
  
# 获取默认的服务签名  
infer = loaded_model.signatures['serving_default']  
print(infer.structured_input_signature)  
print(infer.structured_outputs)  
  
# 加载CSV文件  
file_path = 'data2/shuffled_merged_data.csv'  
data = pd.read_csv(file_path)  
from sklearn.preprocessing import StandardScaler  
  
# 创建StandardScaler实例  
scaler = StandardScaler()  
features = data.iloc[0:1, :-1]  
  
# 获取最后一列作为标签  
labels = data.iloc[0:1, -1]  
features1 = scaler.fit_transform(features)  
# features1 = features1.astype(np.float32)  
  
# # 转化为numpy  
# features = features.to_numpy()  
trainX3 = features1.reshape((features1.shape[0], features1.shape[1], 1))  
  
# # 将数据转换为Tensor  
input_data = tf.convert_to_tensor(trainX3, dtype=tf.float32)  
  
  
output = infer(conv1d_input=input_data)  
output4=output['dense_3']  
print(output4.numpy())  
  
# 为了确定每个样本的预测标签,我们找到概率最高的类别的索引  
predicted_indices = np.argmax(output4.numpy(), axis=1)  
accuracy = accuracy_score(labels, predicted_indices)  
print(accuracy)  
  
output2=output["dense_8"]  
print(output["dense_8"])  
predicted_indices2= np.argmax(output2.numpy(), axis=1)  
accuracy2 = accuracy_score(labels, predicted_indices2)  
print(accuracy2)  
  
  
output2_1=output["dense_8_1"]  
print(output["dense_8_1"])  
predicted_indices2= np.argmax(output2_1.numpy(), axis=1)  
accuracy3 = accuracy_score(labels, predicted_indices2)  
print(accuracy3)  
  
print('nihao')  
# 不可用  
# print(output["StatefulPartitionedCall:0"])

查看模型架构

可以使用这个代码查看模型架构,输入输出的名字

 saved_model_cli show --dir D:\software_project\心电信号分类\tfmodel_save --tag_set serve --sig
nature_def serving_default

结构如下
Pasted image 20240507233536

如果可以用咱们继续进行下一步

将savemodel转化为onnx格式

之后将保存的savemodel格式转化为onnx格式

这里直接上大佬博客
在Atlas 200 DK中部署深度学习模型

基本把每个步骤过一遍即可

注意安装tensorflowgpu的版本是很高的
Pasted image 20240507233803

转换指令

python -m tf2onnx.convert --saved-model tensorflow-model-path --output model.onnx

使用netron

把模型放入到netron中
Netron

导出的onnx模型如下
Pasted image 20240507234346

onnx模型细微处理

获得的onnx模型放入netron中进行查看,发现有些未知输出量需要修改
【tensorflow onnx】TensorFlow2导出ONNX及模型可视化教程_tf2onnx-CSDN博客

主要是这种未知量
Pasted image 20240507233928

代码转化为om以及推理代码,要么使用midstudio

之后即可使用代码进行模型的转化为om

转化成功之后,放到atlks200dk板子中进行模型的推理
代码

import numpy as np  
import acllite_utils as utils  
import constants as const  
from acllite_model import AclLiteModel  
from acllite_resource import AclLiteResource  
import time  
import csv  
import numpy as np  
  
class Reasoning(object):  
    """  
    class for reasoning    """    def __init__(self, model_path):  
        self._model_path = model_path  
        self.device_id = 0  
        self._model = None  
    def init(self):  
        """  
        Initialize        """  
        # Load model  
        self._model = AclLiteModel(self._model_path)  
  
        return const.SUCCESS  
    def inference(self, one_dim_data):  
        """  
        model inference        """        return self._model.execute(one_dim_data)  
  
def main():  
    model_path = 'model_dim_replace.om'  
    # 打开 CSV 文件  
    with open('shuffled_merged_data.csv', newline='') as csvfile:  
        # 创建 CSV 读取器对象  
        csvreader = csv.reader(csvfile, delimiter=',')  
        # 跳过第一行(标题行)  
        next(csvreader)  
        # 读取第二行数据  
        second_row = next(csvreader)  
        # 移除最后一个数据  
        second_row_without_last = second_row[:-1]  
        # 将数据转换为 NumPy 数组  
        np_array = np.array(second_row_without_last, dtype=np.float32)  
        print(np_array.dtype)  
        # 输出转换后的 NumPy 数组  
    acl_resource = AclLiteResource()  
    acl_resource.init()  
    reasoning = Reasoning(model_path)  
    # init  
    ret = reasoning.init()  
    utils.check_ret("Reasoning.init ", ret)  
    start_time = time.time()  
  
    # 假设你有一个名为 input_data 的 NumPy 数组,它包含模型的输入数据  
    input_data = np.array([np_array])  # 替换为你的输入数据  
    result_class = reasoning.inference(input_data)  
  
  
    end_time = time.time()  
    execution_time = end_time - start_time  
  
    print(result_class)  
if __name__ == '__main__':  
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/601235.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设计严谨,思路绝妙!这篇高级孟德尔随机化研究:药靶、共定位,发文一区(IF=8.9)!...

现在越来越多的学者在用孟德尔随机化高级方法发文,今天我们看的这篇这篇药靶孟德尔随机化,还用了共定位分析方法,亮点在于它的设计严谨,思路绝妙,一起看下去吧! 2024年4月21日,四川大学华西医院…

机器人码垛机的主体结构及技术特点

在现代物流和生产线上,机器人码垛机以其高效、准确的特点,成为了不可或缺的重要设备。那么,这个神奇的机器人究竟由哪些部分组成?它的内部结构又有哪些奥秘呢?接下来,就让我们一起揭开它的神秘面纱! 一、机器人码垛机的主体结构…

每日OJ题_贪心算法三②_力扣553. 最优除法

目录 力扣553. 最优除法 解析代码 力扣553. 最优除法 553. 最优除法 难度 中等 给定一正整数数组 nums,nums 中的相邻整数将进行浮点除法。例如, [2,3,4] -> 2 / 3 / 4 。 例如,nums [2,3,4],我们将求表达式的值 "…

【Leetcode每日一题】 穷举vs暴搜vs深搜vs回溯vs剪枝_全排列 - 子集(解法2)(难度⭐⭐)(72)

1. 题目解析 题目链接:78. 子集 这个问题的理解其实相当简单,只需看一下示例,基本就能明白其含义了。 2.算法原理 为了生成一个给定数组 nums 的所有子集,我们可以利用一种称为回溯(backtracking)的算法…

美国纽扣电池UL4200A及16CFR1262标准亚马逊要求

2023年9月21日,美国消费品安全委员会CPSC(Consumer Product Safety Commission) 决定采用UL 4200A-2023(包含纽扣电池或硬币电池的产品安全标准)作为包含纽扣电池或硬币电池的消费品的强制性消费品安全规则,相关要求同时被编入到1…

C++中的异常处理方式

目录 一、异常 二、C语言中对错误的处理 三、C中的异常处理 四、异常的抛出和捕获 五、异常的重新抛出 六、C标准库中的异常体系 七、异常的规范 一、异常 在C中,异常是程序运行期间发生的意外或错误情况。这些情况可能会导致程序无法继续正常执行,…

STM32接入CH340芯片的初始化进入升级模式(死机)问题处理

目录 1. 问题描述2. 问题分析2.1 CH340G/K 的初始化波形2.2 第1种USB升级电路2.3 第2种USB升级电路2.4 第3种USB升级电路2.5 第4种USB升级电路 3. 总结 1. 问题描述 我所用的CH340G(CH340K也用过)接在MCU的电路中,在插入CH340G/K 的接插件&a…

基于点灯Blinker的ESP8266远程网络遥控LED

本文介绍基于ESP8266模块实现的远程点灯操作,手机侧APP选用的是点灯-Blinker,完整资料及软件见文末链接 一、ESP8266模块简介 ESP8266是智能家居等物联网场景下常用的数传模块,具有强大的功能,通过串口转WIFI的方式可实现远距离…

文献速递:深度学习医学影像心脏疾病检测与诊断--CT中的深度学习用于自动钙评分:使用多个心脏CT和胸部CT协议的验证

Title 题目 Deep Learning for Automatic Calcium Scoring in CT: Validation Using Multiple Cardiac CT and Chest CT Protocols CT中的深度学习用于自动钙评分:使用多个心脏CT和胸部CT协议的验证 Background 背景 Although several deep learning (DL) calc…

微软开发新模型;YouTube 推出新AI功能;可折叠iPhone 或发布?

微软或开发新模型与 Google、OpenAI 竞争 The Information 报道,微软正在训练一种新的 AI 大模型「MAI-1」,规模上足以与 Google、Anthropic 乃至 OpenAI 的先进模型抗衡。 据报道,这个 MAI-1 模型由微软聘请的 Inflection 前 CEO Mustafa S…

unity基础(二)

debug方法 Debug.Log(" 一般日志 ");Debug.LogWarning(" 警告日志 ");Debug.LogError(" 错误日志 ");// Player Informationstring strPlayerName "Peter";int iPlayerHpValue 32500;short shPlayerLevel 10;long lAdvantureExp 1…

爱普生MCU系列语音芯片S1C31D41

随着科技的发展和产品的集成化,语音芯片已经逐渐替代了多种语音设备应用在各场合。语音芯片主要特性是功耗低,抗干扰能力强,外围器件少,控制简单,语音保存时间久(某些语音芯片可以保存内容100年),掉电不丢失…

yolo-world:”目标检测届大模型“

AI应用开发相关目录 本专栏包括AI应用开发相关内容分享,包括不限于AI算法部署实施细节、AI应用后端分析服务相关概念及开发技巧、AI应用后端应用服务相关概念及开发技巧、AI应用前端实现路径及开发技巧 适用于具备一定算法及Python使用基础的人群 AI应用开发流程概…

【Git】Git学习-16:git merge,且解决合并冲突

学习视频链接: 【GeekHour】一小时Git教程_哔哩哔哩_bilibili​编辑https://www.bilibili.com/video/BV1HM411377j/?vd_source95dda35ac10d1ae6785cc7006f365780 1 创建分支dev,并用merge合并master分支,使dev分支合并上master分支中内容为…

[Algorithm][多源BFS][矩阵][飞地的数量][地图中的最高点][地图分析] + 多源BFS原理讲解 详细讲解

目录 0.原理讲解1.矩阵1.题目链接2.算法原理详解3.代码实现 2.飞地的数量1.题目链接2.算法原理详解3.代码实现 3.地图中的最高点1.题目链接2.算法原理详解3.代码实现 4.地图分析1.题目链接2.算法原理详解3.代码实现 0.原理讲解 注意:只要是用**BFS解决的最短路径问题…

韩顺平0基础学Java——第5天

p72——p86 今天同学跟我说别学java,真的吗?唉,先把这视频干完吧。 逻辑运算符练习 x6,y6 x6,y5 x11,y6 x11,y5 z48 错了&a…

【web网页制作】html+css旅游家乡河南开封主题网页制作(4页面)【附源码】

HTMLCSS家乡河南主题网页目录 🍔涉及知识🥤写在前面🍧一、网页主题🌳二、页面效果Page1 首页Page2 开封游玩Page 3 开封美食Page4 留言 🌈 三、网页架构与技术3.1 脑海构思3.2 整体布局3.3 技术说明书 🐋四…

【Git】Git学习-14:VSCode中使用git

学习视频链接:【GeekHour】一小时Git教程_哔哩哔哩_bilibili​编辑https://www.bilibili.com/video/BV1HM411377j/?vd_source95dda35ac10d1ae6785cc7006f365780 在vscode中打开文件 code . 自行修改内容,在源代码管理器中测试下

flutter报错

组件相关 type ‘List’ is not a subtype of type ‘List’ children: CardList.map((item) > Container( 加上 *** < Widget>*** 正常 type ‘(dynamic, dynamic) > Container’ is not a subtype of type ‘(CardType) > Widget’ of ‘f’ children: CardL…

Spring Data JPA自定义Id生成策略、复合主键配置、Auditing使用

前言 在Spring Data JPA系列的第一篇文章 SpringBoot集成JPA及基本使用-CSDN博客 中讲解了实体类的Id生成策略可以通过GeneratedValue注解进行配置&#xff0c;该注解的strategy为GenerationType类型&#xff0c;GenerationType为枚举类&#xff0c;支持四种Id的生成策略&…