政安晨:【Keras机器学习实践要点】(十)—— 自定义保存和序列化

目录

导言

涵盖的API

Setup

状态保存自定义

构建和编译保存自定义

结论


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

相较于上一篇文章而言,这是一篇更高级的图层和模型自定义保存指南。

导言

本文涵盖可在 Keras 保存中自定义的高级方法。对于大多数用户来说,上文的初级序列化、保存和导出文章中概述的方法已经足够

涵盖的API

我们将介绍以下应用程序接口:

  • save_assets() 和 load_assets()
  • save_own_variables() 和 load_own_variables()
  • get_build_config() 和 build_from_config()
  • get_compile_config() 和 compile_from_config()

还原模型时,这些操作将按以下顺序执行:

  • build_from_config()
  • compile_from_config()
  • load_own_variables()
  • load_assets()

Setup

import os
import numpy as np
import keras

状态保存自定义

这些方法决定了调用 model.save() 时如何保存模型图层的状态。您可以重载这些方法来完全控制状态保存过程。

save_own_variables() & load_own_variables()

这些方法分别在调用 model.save() 和 keras.models.load_model() 时保存和加载层的状态变量。默认情况下,保存和加载的状态变量是层的权重(可训练和不可训练)。

以下是 save_own_variables() 的默认实现:

def save_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        store[f"{i}"] = v.numpy()

这些方法使用的存储空间是一个字典,可以填充层变量。

下面我们来看一个自定义的示例:

示例:

@keras.utils.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        super().save_own_variables(store)
        # Stores the value of the variable upon saving
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        # Assigns the value of the variable upon loading
        self.my_variable.assign(store["variables"])
        # Load the remaining weights
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # Note: You must specify how all variables (including layer weights)
        # are loaded in `load_own_variables.`

    def call(self, inputs):
        dense_out = super().call(inputs)
        return dense_out + self.my_variable


model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
ref_output = np.random.random((8, 10))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)

执行结果:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 101ms/step - loss: 0.2908

save_assets() 和 load_assets()


这些方法可以添加到模型类定义中,以存储和加载模型所需的任何附加信息。

例如,文本矢量化层(TextVectorization layer)和索引查找层(IndexLookup layer)等 NLP 领域层可能需要在保存时将其相关词汇(或查找表)存储到文本文件中。

让我们用一个简单的 assets.txt 文件来了解一下这个工作流程的基本情况。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        # Writes the vocab (sentence) to text file at save time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        # Reads the vocab (sentence) from text file at load time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f:
            text = f.read()
        self.vocab = text.replace("<unk>", "little")


model = keras.Sequential(
    [LayerWithCustomAssets(vocab="Mary had a <unk> lamb.", units=5)]
)

x = np.random.random((10, 10))
y = model(x)

model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    restored_model.layers[0].vocab, "Mary had a little lamb."
)

构建和编译保存自定义

get_build_config() 和 build_from_config()

这些方法可共同保存图层的构建状态,并在加载时恢复这些状态。

默认情况下,这只包括一个包含图层输入形状的构建配置字典,但重载这些方法可用于包含更多变量和查找表,这些变量和查找表对于恢复构建的模型非常有用。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return dict(units=self.units, **super().get_config())

    def build(self, input_shape, layer_init):
        # Note the overriding of `build()` to add an extra argument.
        # Therefore, we will need to manually call build with `layer_init` argument
        # before the first execution of `call()`.
        super().build(input_shape)
        self._input_shape = input_shape
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        self.layer_init = layer_init

    def get_build_config(self):
        build_config = {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }  # Stores our initializer for `build()`
        return build_config

    def build_from_config(self, config):
        # Calls `build()` with the parameters at loading time
        self.build(config["input_shape"], config["layer_init"])


custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [
        custom_layer,
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

x = np.random.random((16, 8))
y = model(x)

model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal")
np.testing.assert_equal(restored_model.built, True)

get_compile_config() 和 compile_from_config()


这些方法可共同保存编译模型时使用的信息(优化器、损耗等),并使用这些信息恢复和重新编译模型。

重载这些方法对于用自定义优化器、自定义损耗等编译恢复后的模型非常有用,因为在调用 compile_from_config() 中的 model.compile 之前需要对这些信息进行反序列化。

下面我们来看一个例子。

@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    loss = keras.ops.square(y_pred - y_true)
    loss = loss / 10.0
    loss = keras.ops.sum(loss, axis=1)
    return loss


@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    return keras.ops.mean(y_pred)


@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def compile(self, optimizer, loss_fn, metrics):
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        # These parameters will be serialized at saving time.
        return {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }

    def compile_from_config(self, config):
        # Deserializes the compile parameters (important, since many are custom)
        optimizer = keras.utils.deserialize_keras_object(config["model_optimizer"])
        loss_fn = keras.utils.deserialize_keras_object(config["loss_fn"])
        metrics = keras.utils.deserialize_keras_object(config["metric"])

        # Calls compile with the deserialized parameters
        self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics)


model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred]
)

x = np.random.random((4, 8))
y = np.random.random((4,))

model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)
np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)
np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)

执行如下:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - accuracy: 0.0000e+00 - loss: 0.0627 - mean_metric_wrapper: 0.2500

结论

使用本文中学到的方法可以实现多种使用情况,保存和加载具有特殊资产和状态元素 的复杂模型。

总结一下:

save_own_variables 和 load_own_variables 决定了保存和加载状态的方式。
save_assets 和 load_assets 可用于存储和加载模型所需的任何附加信息。
get_build_config 和 build_from_config 用于保存和恢复模型的构建状态。
get_compile_config 和 compile_from_config 保存和恢复模型的编译状态。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/503844.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

线程的安全问题

目录 导言&#xff1a; 正文&#xff1a; 1.共享资源&#xff1a; 2.非原子操作&#xff1a; 3.执行顺序不确定&#xff1a; 4.可见性&#xff1a; 5.死锁和饥饿&#xff1a; 6.指令重排序&#xff1a; 总结&#xff1a; 导言&#xff1a; 线程安全是并发编程中的一个…

文献阅读:使用 CellChat 推理和分析细胞-细胞通信

文献介绍 「文献题目」 Inference and analysis of cell-cell communication using CellChat 「研究团队」 聂青&#xff08;加利福尼亚大学欧文分校&#xff09; 「发表时间」 2021-02-17 「发表期刊」 Nature Communications 「影响因子」 16.6 「DOI」 10.1038/s41467-0…

Vue3 使用 v-bind 动态绑定 CSS 样式

在 Vue3 中&#xff0c;可以通过 v-bind 动态绑定 CSS 样式。 语法格式&#xff1a; color: v-bind(数据); 基础使用&#xff1a; <template><h3 class"title">我是父组件</h3><button click"state !state">按钮</button>…

解析CUDA FATBIN格式

参考文档&#xff1a; https://pdfs.semanticscholar.org/5096/25785304410039297b741ad2007e7ce0636b.pdf CUDA Pro Tip: Understand Fat Binaries and JIT Caching | NVIDIA Technical Blog cuda二进制文件中到底有什么 - 知乎 NVIDIA CUDA Compiler Driver NVIDIA CUDA…

HSP_04章_扩展: 进制、位运算

文章目录 10. 扩展: 进制11. 位运算11.1 二进制在运算中的说明11.2 原码 反码 补码11.3位运算符11.3.1 ~按位取反11.3.2 &按位与11.3.3 ^按位异或11.3.4 |按位或11.3.5 << 左移11.3.6 >> 右移 10. 扩展: 进制 进制介绍 进制的转换 2.1 其他进制转十进制 二进…

(八)目标跟踪中参数估计(似然、贝叶斯估计)理论知识

目录 前言 一、统计学基础知识 &#xff08;一&#xff09;随机变量 &#xff08;二&#xff09;全概率公式 &#xff08;三&#xff09;高斯分布及其性质 二、似然是什么&#xff1f; &#xff08;一&#xff09;概率和似然 &#xff08;二&#xff09;极大似然估计 …

国内顶级大牛整理:分布式消息中间件实践笔记+分布式核心原理解析

XMPP JMS RabbitMQ 简介 工程实例 Java 访问RabbitMQ实例 Spring 整合RabbitMQ 基于RabbitMQ的异步处理 基于RabbitMQ的消息推送 RabbitMQ实践建议 虚拟主机 消息保存 消息确认模式 消费者应答 流控机制 通道 总结 ActiveMQ 简介 工程实例 Java 访问ActiveMQ实例…

机器人寻路算法双向A*(Bidirectional A*)算法的实现C++、Python、Matlab语言

机器人寻路算法双向A*&#xff08;Bidirectional A*&#xff09;算法的实现C、Python、Matlab语言 最近好久没更新&#xff0c;在搞华为的软件挑战赛&#xff08;软挑&#xff09;&#xff0c;好卷只能说。去年还能混进32强&#xff0c;今年就比较迷糊了&#xff0c;这东西对我…

EasyRecovery2024汉化精简版,无需注册

EasyRecovery2024是世界著名数据恢复公司 Ontrack 的技术杰作&#xff0c;它是一个威力非常强大的硬盘数据恢复软件。能够帮你恢复丢失的数据以及重建文件系统。 EasyRecovery不会向你的原始驱动器写入任何东东&#xff0c;它主要是在内存中重建文件分区表使数据能够安全地传输…

FL Studio21.2.3中文版软件新功能介绍及下载安装步骤教程

FL Studio21.2中文版的适用人群非常广泛&#xff0c;主要包括以下几类&#xff1a; FL Studio 21 Win-安装包下载如下: https://wm.makeding.com/iclk/?zoneid55981 FL Studio 21 Mac-安装包下载如下: https://wm.makeding.com/iclk/?zoneid55982 音乐制作人&#xff1a…

C#/BS手麻系统源码 手术麻醉管理系统源码 商业项目源码

C#/BS手麻系统源码 手术麻醉管理系统源码 商业项目源码 手麻系统从麻醉医生实际工作环境和流程需求方面设计&#xff0c;与HIS&#xff0c;LIS&#xff0c;PACS&#xff0c;EMR无缝连接&#xff0c;方便查看患者的信息;实现术前、术中、术后手术麻醉信息全记录;减少麻醉医师在…

Spring Boot配置⽂件的格式

1、Spring Boot 配置⽂件有以下三种&#xff1a; &#xff08;1&#xff09;application.properties &#xff08;2&#xff09;application.yml &#xff08;3&#xff09;application.yaml 2、yml 为yaml的简写, 实际开发中出现频率最⾼. yaml 和yml 的使⽤⽅式⼀样 3、 4…

Vue + .NetCore前后端分离,不一样的快速发开框架

摘要&#xff1a; 随着前端技术的快速发展&#xff0c;Vue.NetCore框架已成为前后端分离开发中的热门选择。本文将深入探讨Vue.NetCore前后端分离的快速开发框架&#xff0c;以及它如何助力开发人员提高效率、降低开发复杂度。文章将从基础功能、核心优势、适用范围、依赖环境等…

软考之零碎片段记录(一)

2023上半年选择题 一、流水线周期 注&#xff1a;&#xff08;n-1) * 流水线周期 &#xff08;取址时间分析时间执行时间&#xff09; 注&#xff1a;流水线周期&#xff1a;指令中最耗时的部分&#xff08;取址或者分析或者执行&#xff09; # 耗时最高的部分 * &#xff0…

单例设计模式(3)

单例模式&#xff08;3&#xff09; 实现集群环境下的分布式单例类 如何理解单例模式中的唯一性&#xff1f; 单例模式创建的对象是进程唯一的。以springboot应用程序为例&#xff0c;他是一个进程&#xff0c;可能包含多个线程&#xff0c;单例代表在这个进程的某个类是唯一…

Unity 基于Rigidbody2D模块的角色移动

制作好站立和移动的动画后 控制器设计 站立 移动 角色移动代码如下&#xff1a; using System.Collections; using System.Collections.Generic; using Unity.VisualScripting; using UnityEngine;public class p1_c : MonoBehaviour {// 获取动画组件private Animator …

LeetCode Python - 84. 柱状图中最大的矩形

目录 题目描述解法方法一方法二 运行结果方法一方法二 题目描述 给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1 。 求在该柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。 示例 1: 输入&#xff1a;heights …

《最小阻力之路》利用最小阻力路径,采用创造性思维模式,更有效地实现个人愿景和目标 - 三余书屋 3ysw.net

最小阻力之路 大家好&#xff0c;今天我们分享《最小阻力之路》。我们时常听到有人感叹&#xff0c;明明懂得那么多道理&#xff0c;为何生活过得不如意呢&#xff1f;这本书从某种角度回应了这个疑问&#xff0c;作者分析了我们在人生旅途中屡次失败的原因&#xff0c;提出了…

图像分割论文阅读:Automatic Polyp Segmentation via Multi-scale Subtraction Network

这篇论文的主要内容是介绍了一种名为多尺度差值网络&#xff08;MSNet&#xff09;的自动息肉分割方法。 1&#xff0c;模型整体结构 整体结构包括编码器&#xff0c;解码器&#xff0c;编码器和解码器之间是多尺度差值模块模块&#xff08;MSM&#xff09;&#xff0c;以及一…

使用Python实现ID3决策树中特征选择的先后顺序,字节跳动面试真题

def empty1(pri_data): hair [] #[‘长’, ‘短’, ‘短’, ‘长’, ‘短’, ‘短’, ‘长’, ‘长’] voice [] #[‘粗’, ‘粗’, ‘粗’, ‘细’, ‘细’, ‘粗’, ‘粗’, ‘粗’] sex [] #[‘男’, ‘男’, ‘男’, ‘女’, ‘女’, ‘女’, ‘女’, ‘女’] for o…