深度学习笔记(九)——tf模型导出保存、模型加载、常用模型导出tflite、权重量化、模型部署

文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
本篇博客主要是工具性介绍,可能由于软件版本问题导致的部分内容无法使用。

首先介绍tflite: TensorFlow Lite 是一组工具,可帮助开发者在移动设备、嵌入式设备和 loT 设备上运行模型,以便实现设备端机器学习。
框架具有的主要特性:

  • 延时(数据无需往返服务器)
  • 隐私(没有任何个人数据离开设备)
  • 连接性(无需连接互联网)
  • 大小(缩减了模型和二进制文件的大小)
  • 功耗(高效推断,且无需网络连接)

官方目前支持了大约130中可以量化的算子,在查阅大量资料后目前自定义的算子使用tflite导出任然存在较多问题。就解决常见的算法,使用支持的算子基本可以覆盖。tflite的压缩能力极强:使用官方算子构建的模型,导出TensorFlow Lite 二进制文件的大小约为 1 MB(针对 32 位 ARM build);如果仅使用支持常见图像分类模型(InceptionV3 和 MobileNet)所需的运算符,TensorFlow Lite 二进制文件的大小不到 300 KB。在后文的实例中我们用iris数据集的分类演示,可以导出一个仅仅只有2kb大小的模型权重相比未压缩的70kb模型缩小了30多倍。
同时tflite还实验性的在支持导出极轻量化的TFLM模型(TensorFlow Lite for Microcontrollers),这些模型可以直接在嵌入式单片机上进行推理,不过现阶段支持的算子还很少,简单的可以利用全连接和低向量卷积实现一些传感器参数的识别任务。现在主要的实例场景是MCU+IMU组合,识别IMU连续数据,来判断人体特定动作。同时开可以在MCU上离线运行语音命令识别,可以实现一个关键字的识别。

好了那我们继续看一下怎么保存模型,加载模型,保存tflite,加载tflite

保存权重或TF格式标准模型

通常情况下当完成了网络结构设计,数据处理,网络训练和评价之后需要及时的保存数据。先看到前面博客中已经介绍过的iris数据集实现网络分类任务。当时通过添加保存回调函数实现了网络权重的保存,这样保存下的是网络权重模型,需要配合网络结构的实例化使用。当然tf还提供了很多种模型的保存方式,tf2官方推荐使用tf形式保存,通过这种方式相关文件会保存到一个指定文件夹中,包含模型的权重参数模型结构信息。

ckpt格式

通过回调函数实现动态权重的保存。

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

# 定义网络结构
class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        y = self.d1(x)
        return y
# 实例化化模型
model = IrisModel()
# 定义保存和记录数据的回调器
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="./checkpoint/iris/iris.ckpt",  # 保存模型权重参数
                                                 save_weights_only=True,
                                                 save_best_only=True)
# 初始化模型
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'],
              callbacks=[cp_callback])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

在上面的代码中通过tf.keras.callbacks.ModelCheckpoint设置了一个回调器,会动态的在网络训练的过程中保存下参数表现效果最好的权重参数。这里主要保存网络中各个可变参数的值和网络的当前图数据。这个模型无法直接用于推理,应为其不包含网络完整的图信息。所以我们需要在训练结束时保存网络整体的图信息。

.pd格式

pd格式是tf保存静态模型的专用权重文件。在训练完成后直接执行:

model.save('./yor_save_path/model', save_format='tf')    # 保存模型为静态权重

这样就可以把model的全部图信息保存下来了那么怎么保存最好的呢?可以结合上一个.ckpt文件使用。
对比两个的保存方式差距,前者在动态的训练过程中存储数据,后者针对某一个节点的网络状态完整保存。所以可以在训练过程中保存下最好的参数,当训练结束后再加载回最好的动态权重,然后再保存为.pd文件。

加载权重或TF格式标准模型

动态权重和静态图模型的保存不同,加载也不同。加载.ckpt时,需要先实例化网络结构,然后再读取权重参数给实例化的模型赋值。对于静态模型文件则不需要实例化模型,也就是无需关注网络的内部,直接读取加载模型就会完成网络构建和参数赋值两个任务,在部署时明显静态模型模型的程序文件会更加简单。

model = userModer()
model.compile()
checkpoint_save_path = "./yor_file_path.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):     #
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

上面的程序展示从动态图加载权重。下面的程序则直接从静态图加载模型。

model_path = './yor_model_path'      
new_model = tf.keras.models.load_model(model_path)  # 从tf模型加载,无需重新实例化网络

从静态图文件加载模型有便捷之处,但是也需要注意模型的输入和输出结构,要保证预测时输入网络的数据维度是符合要求的,同时根据网络输出的模式接收输出数据做相应处理。

转化模型到tflite

转化模型主要有三种方式:

  • 使用现有的 TensorFlow Lite 模型
  • 创建 TensorFlow Lite 模型
  • 将 TensorFlow 模型转换为TensorFlow Lite 模型

模型的保存就分别对应三个主要函数:
在这里插入图片描述

后续主要介绍使用tf构建网络后从tf模型保存到tflite,并以keras model为主。
首先我们需要上面iris数据集分类的例子,当网络训练结束后,可以使用如下的程序导出:

tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the model.
with open(tflite_save_path+'/model.tflite', 'wb') as f:
  f.write(tflite_model)

导出后可以看到model.tflite的模型文件。可以比较上面直接导出的完整模型,这个模型的体积小了很多,更加适合在低算力和存储能力的设备上运行。

从tflite加载模型并执行推理

从tflite上加载模型并推理主要有两个手段:使用完整tf框架加载tf.lite读取;或使用tflite_runtime,这是 TensorFlow Lite 解释器,无需安装所有 TensorFlow 软件包,但是对python版本和系统,硬件有一定的要求。目前tf-runtime支持的平台有:
在这里插入图片描述
在此以外的模型需要拉取完整源码在本地设备上编译执行。
安装了相关的软件环境后,可以使用如下的代码来加载模型并推理:

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=tflite_save_path+'/model.tflite')  # 加载模型
interpreter.allocate_tensors()  # 为模型分配张量参数

# Get input and output tensors.
input_details = interpreter.get_input_details()     # 设置网络输入
output_details = interpreter.get_output_details()   # 设置网络输出

# Test the model on set input data.
input_shape = input_details[0]['shape'] # 获取输入层(第一层)的数据维度
print(input_shape)	# 输出维度结构,便于调试

input_data = np.array([6.0,3.4,4.5,1.3], dtype=np.float32)	# 手动给一组鸢尾花数据
input_data = input_data.reshape([1,4])  # 确保维度相同
print(input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data)   # 将数据输入到网络中
interpreter.invoke()  # 运行推理
output_data = interpreter.get_tensor(output_details[0]['index'])    # 获得网络输出

print(output_data)

pred = tf.argmax(output_data, axis=1)   # 网络输出层是softmax,需要找到最大值
print(int(pred))    # 输出最大位置的index

通过上面几行简单的代码就可以在终端设备实现预测推理。使用tf-runtim时只需要做简单修改,将包名替换即可。例如:

import tensorflow as tf  改为:import tflite_runtime.interpreter as tflite
interpreter = tf.lite.Interpreter(model_path=args.model_file) 改为: interpreter = tflite.Interpreter(model_path=args.model_file)

模型量化

对模型执行量化可以进一步解决嵌入式终端设备的痛点。量化模型可以实现:

  • 较小的存储大小:小模型在用户设备上占用的存储空间更少
  • 较小的下载大小:小模型下载到用户设备所需的时间和带宽较少
  • 更少的内存用量:小模型在运行时使用的内存更少,从而释放内存供应用的其他部分使用,并可以转化为更好的性能和稳定性
    tflite支持的量化形式有:
    在这里插入图片描述

训练后量化

训练后量化是一种转换技术,它可以在改善 CPU 和硬件加速器延迟的同时缩减模型大小,且几乎不会降低模型准确率。使用 TensorFlow Lite 转换器将已训练的浮点 TensorFlow 模型转换为 TensorFlow Lite 格式后,可以对该模型进行量化。

动态范围量化

动态范围量化能使模型大小缩减至原来的四分之一,在量化时激活函数始终以浮点格式保存,其它支持的算子会根据损失动态保存为8位整形,以此减小模型体积。在导出时量化模型,设置 optimizations 标记以优化大小:

# 上续训练后的模型
tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"opt_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

加载模型时使用相同的方式加载即可

全整数量化

全整型量化相对更加复杂一些。
在上面导出tflite的过程中,实际是将tf默认的协议缓冲区模型压缩为FlatBuffers的格式,这种格式具有多种优势,例如可缩减大小(代码占用的空间较小)以及提高推断速度(可直接访问数据,无需执行额外的解析/解压缩步骤)。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

上面的两行代码实质是做了模型格式的转换和压缩,并没有调整权重参数和计算格式。在上面的基础上,可以进一步使用动态范围量化:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()

通过动态范围量化之后模型已经缩小了,但是任然有一部分模型参数是浮点格式,这对存储有效,计算能力有限的设备还是存在限制。
在进行整形量化时需要量化模型内部层和输入输出层。tflite给出了两种量化的方式,第一种量化兼容性相对广泛,但是需要输入一组足够大的代表数据集用来推理量化。这样得到的模型任然有小部分参数会是浮点,这无法支持纯整形计算的硬件。
这里转述官方给出的第二种整形量化方式:
为了量化输入和输出张量,并让转换器在遇到无法量化的运算时引发错误,使用一些附加参数再次转换模型:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_data).batch(yordatabatch).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()

上面的第一个函数是量化的一个必要步骤,要量化可变数据(例如模型输入/输出和层之间的中间体),需要提供 RepresentativeDataset。这是一个生成器函数,它提供一组足够大的输入数据来代表典型值。转换器可以通过该函数估算所有可变数据的动态范围。(相比训练或评估数据集,此数据集不必唯一。)为了支持多个输入,每个代表性数据点都是一个列表,并且列表中的元素会根据其索引被馈送到模型。
通过转化给定数据推理量化,现在模型的输入层和输出层数据已经是整形格式:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

此时模型已经完全支持全整形设备的计算。
那继续的,将模型文件保存下来:

import pathlib
tflite_models_dir = pathlib.Path("/tmp/user_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"user_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"user_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)

执行推理时使用的程序结构和上文介绍的从tflite加载模型并执行推理的内容相同。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/341450.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java21 + SpringBoot3集成easy-captcha实现验证码显示和登录校验

文章目录 前言相关技术简介easy-captcha 实现步骤引入maven依赖定义实体类定义登录服务类定义登录控制器前端登录页面实现测试和验证 总结附录使用Session缓存验证码前端登录页面实现代码 前言 近日心血来潮想做一个开源项目,目标是做一款可以适配多端、功能完备的…

Android.mk和Android.bp的区别和转换详解

Android.mk和Android.bp的区别和转换详解 文章目录 Android.mk和Android.bp的区别和转换详解一、前言二、Android.mk和Android.bp的联系三、Android.mk和Android.bp的区别1、语法:2、灵活性:3、版本兼容性:4、向后兼容性:5、编译区…

鸿蒙开发笔记(二十三):图形展示 Image,Shape,Canvas

1. Image 在应用中显示图片需要使用Image组件实现,Image支持多种图片格式,包括png、jpg、bmp、svg和gif,具体用法请参考Image组件。 Image通过调用接口来创建,接口调用形式如下: Image(src: string | Resource | me…

力扣第92题——反转链表 II(C语言题解)

题目描述 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&#xff1…

精品基于Uniapp+springboot智慧农业环境监测App

《[含文档PPT源码等]精品基于Uniappspringboot智慧农业环境监测App》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程、包运行成功&#xff01; 软件开发环境及开发工具&#xff1a; 开发语言&#xff1a;Java 后台框架&#xff1a;springboot、ssm …

Android双指缩放ScaleGestureDetector检测放大因子大图移动到双指中心点ImageView区域中心,Kotlin

Android双指缩放ScaleGestureDetector检测放大因子大图移动到双指中心点ImageView区域中心&#xff0c;Kotlin 在 Android双击图片放大移动图中双击点到ImageView区域中心&#xff0c;Kotlin-CSDN博客 基础上&#xff0c;这次使用ScaleGestureDetector检测两根手指的缩放动作&a…

一起玩儿物联网人工智能小车(ESP32)——44. 利用红外测距模块GP2Y0E03实现避障小车

摘要&#xff1a;本文介绍使用红外测距模块GP2Y0E03实现避障小车 在前边已经介绍了两种非接触测距的办法&#xff0c;分别是超声波测距和激光测距&#xff0c;在这里&#xff0c;再介绍另一种常用的测距传感器——红外测距传感器。红外测距的工作原理是&#xff0c;利用红外信号…

HuoCMS|免费开源可商用CMS建站系统HuoCMS 2.0下载(thinkphp内核)

HuoCMS是一套基于ThinkPhp6.0Vue 开发的一套HuoCMS建站系统。 HuoCMS是一套内容管理系统同时也是一套企业官网建设系统&#xff0c;能够帮过用户快速搭建自己的网站。可以满足企业站&#xff0c;外贸站&#xff0c;个人博客等一系列的建站需求。HuoCMS的优势: 可以使用统一后台…

数学建模--Radar图绘制

1.Radar图简介 最近在数学建模中碰见需要绘制Radar图(雷达图)的情况来具体分析样本的各个特征之间的得分与优劣关系&#xff0c;这样的情况比较符合雷达图的使用场景&#xff0c;一般来说&#xff0c;雷达图适用于展示多个维度的数据&#xff0c;并在一个平面上直观地呈现出不同…

前端每日一练 “文字穿透效果”

前言 我都不知道用什么样的词来描述这个效果&#xff0c;反正你看吧&#xff01;这个效果看上去很简单&#xff0c;但是一旦实现起来你会发现也不复杂&#xff0c;废话不多说直接上源码&#xff0c;喜欢的点个关注、留个免费的 html源码 <!DOCTYPE html> <html>&…

13.8.1异步、异步、异步 Page720~721

#include <iostream> #include <thread> #include <future>using namespace std;///定时炸弹第一波 void sync_sleep(int s) {cout << "sync_sleep----不使用异步" << endl;///启动定时this_thread::sleep_for(chrono::seconds(s)); /…

《WebKit 技术内幕》学习之十(2): 插件与JavaScript扩展

2 Chromium PPAPI插件 2.1 原理 插件其实是一种统称&#xff0c;表示一些动态库&#xff0c;这些动态库根据定义的一些标准接口可以跟浏览器进行交互&#xff0c;至于这个标准接口是什么都可以&#xff0c;重要的是大家都遵循它们&#xff0c;NPAPI接口标准只是其中的一种&a…

FastDFS分布式文件存储

为什么会有分布式文件系统&#xff1f; 分布式文件系统是面对互联网的需求而产生。因为互联网时代要对海量数据进行存储。很显然靠简单的增加硬盘个数已经满足不了我们的要求。因为硬盘传输速度有限但是数据在急剧增长&#xff0c;另外我们还要要做好数据备份、数据安全等。采用…

初识k8s(概述、原理、安装)

文章目录 概述由来主要功能 K8S架构架构图组件说明ClusterMasterNodekubectl 组件处理流程 K8S概念组成PodPod控制器ReplicationController&#xff08;副本控制器&#xff09;ReplicaSet &#xff08;副本集&#xff09;DeploymentStatefulSet &#xff08;有状态副本集&#…

6 时间序列(不同位置的装置如何建模): GRU+Embedding

很多算法比赛经常会遇到不同的物体产生同含义的时间序列信息&#xff0c;比如不同位置的时间序列信息&#xff0c;风力发电、充电桩用电。经常会遇到该如此场景&#xff0c;对所有数据做统一处理喂给模型&#xff0c;模型很难学到区分信息&#xff0c;因此设计如果对不同位置的…

【Linux】常见指令(一)

前言: Linux有许多的指令&#xff0c;通过学习这些指令&#xff0c;可以对目录及文件进行操作。 文章目录 一、基础指令1. ls—列出目录内容2. pwd—显示当前目录3. cd—切换目录重新认识指令4. touch—创建文件等5. mkdir—创建目录6. rmdir指令 && rm 指令7. man—显…

linux源码编译安装llvm

目录 1 建立文件夹llvm 2 下载源码到llvm文件夹 3 解压上述文件 4 将解压后的3个文件夹改名&#xff0c;并移动到llvm-9.0.0.src中&#xff1a; 5 在llvm文件夹内建立build文件夹&#xff0c;并进入该文件夹&#xff1a; 6 执行cmake命令 7 make 8 安装 9 安装成功后…

[晓理紫]每日论文分享(有中文摘要,源码或项目地址)--机器人、强化学习

专属领域论文订阅 VX 扫吗关注{晓理紫|小李子}&#xff0c;每日更新论文&#xff0c;如感兴趣&#xff0c;请转发给有需要的同学&#xff0c;谢谢支持 如果你感觉对你有帮助可以扫吗关注&#xff0c;每日准时为你推送最新论文 分类: 大语言模型LLM视觉模型VLM扩散模型视觉导航…

Git Docker 学习笔记

注意&#xff1a;该文章摘抄之百度&#xff0c;仅当做学习笔记供小白使用&#xff0c;若侵权请联系删除&#xff01; 目录 列举工作中常用的几个git命令&#xff1f; 提交时发生冲突&#xff0c;你能解释冲突是如何产生的吗&#xff1f;你是如何解决的&#xff1f; git的4个…

安全通信网络

1.网络架构 1&#xff09;应保证网络设备的业务处理能力满足业务高峰期需要。 设备CPU和内存使用率的峰值不大于设备处理能力的70%。 在有监控环境的条件下&#xff0c;应通过监控平台查看主要设备在业务高峰期的资源&#xff08;CPU、内存等&#xff09;使用 情况&#xff…