政安晨:【掌握AI的深度学习工具Keras API】(二)—— 【使用内置的训练循环和评估循环】

渐进式呈现复杂性,是指采用一系列从简单到灵活的工作流程,并逐步提高复杂性。这个原则也适用于模型训练。Keras提供了训练模型的多种工作流程。这些工作流程可以很简单,比如在数据上调用fit(),也可以很高级,比如从头开始编写新的训练算法。

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: 政安晨的机器学习笔记

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!


开始

你已经熟悉compile()、fit()、evaluate()和predict()的工作流程。

咱们看下面的代码,走一下这个流程:

标准工作流程compile()  ->  fit()  ->   evaluate()  ->  predict()

from tensorflow.keras.datasets import mnist

# 创建模型(我们将其包装为一个单独的函数,以便后续复用)
def get_mnist_model(): 
    inputs = keras.Input(shape=(28 * 28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

# 加载数据,保留一部分数据用于验证
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

model = get_mnist_model()

# (本行及以下3行)编译模型,指定模型的优化器、需要最小化的损失函数和需要监控的指标
model.compile(optimizer="rmsprop", 
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# (本行及以下3行)使用fit()训练模型,可以选择提供验证数据来监控模型在前所未见的数据上的性能
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))

# 使用evaluate()计算模型在新数据上的损失和指标
test_metrics = model.evaluate(test_images, test_labels) 

# 使用predict()计算模型在新数据上的分类概率
predictions = model.predict(test_images)

要想自定义这个简单的工作流程,可以采用以下方法:

编写自定义指标;

向fit()方法传入调函数,以便在训练过程中的特定时间点采取行动。

咱们接下来进一步讨论。

编写自定义指标

指标是衡量模型性能的关键,尤其是衡量模型在训练数据上的性能与在测试数据上的性能之间的差异。常用的分类指标和回归指标内置于keras.metrics模块中。大多数情况下,你会使用这些指标。但如果想做一些不寻常的工作,你需要能够编写自定义指标

小伙伴们不要怕,这并不难,很简单!

Keras指标是keras.metrics.Metric类的子类。与层相同的是,指标具有一个存储在TensorFlow变量中的内部状态。与层不同的是,这些变量无法通过反向传播进行更新,所以你必须自己编写状态更新逻辑。这一逻辑由update_state()方法实现。

举个例子,如下代码实现了一个简单的自定义指标,用于衡量均方根误差(RMSE)

通过将Metric类子类化来实现自定义指标

import tensorflow as tf

# 将Metric类子类化
class RootMeanSquaredError(keras.metrics.Metric):

    # (本行及以下4行)在构造函数中定义状态变量。与层一样,你可以访问add_weight()方法
    def __init__(self, name="rmse", **kwargs):

        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    # 为了匹配MNIST模型,我们需要分类预测值与整数标签
    def update_state(self, y_true, y_pred, sample_weight=None):

    # 在update_state()中实现状态更新逻辑。y_true参数是一个数据批量对应的目标(或标签),y_pred则表示相应的模型预测值。你可以忽略sample_weight参数,这里不会用到
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

我们可以使用result()方法返回指标的当前值。

    def result(self):
        return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

此外,你还需要提供一种方法来重置指标状态,而无须将其重新实例化。

如此一来,相同的指标对象可以在不同的训练轮次中使用,或者在训练和评估中使用。

这可以用reset_state()方法来实现。

    def reset_state(self):
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)

自定义指标的用法与内置指标相同。下面来测试一下我们的自定义指标

model = get_mnist_model()

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])

model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)

你可以看到fit()的进度条,上面显示模型的RMSE。

使用调函数

使用model.fit()在大型数据集上启动数十轮训练,这样做有点类似于投掷纸飞机:最初给它一点推力,之后你就再也无法控制它的轨迹或着陆点。如果想避免得到不好的结果(从而避免浪费纸飞机),更聪明的做法是,不用纸飞机,而用一架无人机。它可以感知环境,向操作者发送数据,并且能够根据当前状态自主航行。

Keras的回调函数(callback)API可以让model.fit()的调用从纸飞机变为自主飞行的无人机,使其能够观察自身状态并不断采取行动。

回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用。

回调函数可以访问关于模型状态与模型性能的所有可用数据,还可以采取以下行动:中断训练、保存模型、加载一组不同的权重或者改变模型状态。

回调函数的一些用法示例如下:

模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。

提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。

在训练过程中动态调节某些参数值:比如调节优化器的学习率。

在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

keras.callbacks模块包含许多内置的回调函数,下面列出了其中一些,还有很多没有列出来:

keras.callbacks.ModelCheckpoint
keras.callbacks.EarlyStopping
keras.callbacks.LearningRateScheduler
keras.callbacks.ReduceLROnPlateau
keras.callbacks.CSVLogger

下面介绍两个回调函数EarlyStopping  ModelCheckpoint,让你大致了解回调函数的用法。

调函数EarlyStopping和ModelCheckpoint

训练模型时,很多事情一开始无法预测,尤其是你无法预测需要多少轮才能达到最佳验证损失。

前面所有例子都采用这样一种策略:训练足够多的轮次,这时模型已经开始过拟合,利用第一次运行确定最佳训练轮数,然后用这个最佳轮数从头开始重新训练一次。当然,这种方法很浪费资源。一种更好的处理方法是,发现验证损失不再改善时,停止训练。这可以通过EarlyStopping回调函数来实现。

如果监控的目标指标在设定的轮数内不再改善,那么可以用EarlyStopping回调函数中断训练。比如,这个回调函数可以在刚开始过拟合时就立即中断训练,从而避免用更少的轮数重新训练模型。这个回调函数通常与ModelCheckpoint结合使用,后者可以在训练过程中不断保存模型(你也可以选择只保存当前最佳模型,即每轮结束后具有最佳性能的模型)。

如下代码展示了如何在fit()方法中使用callbacks参数。

在fit()方法中使用callbacks参数

# 通过fit()的callbacks参数将回调函数传入模型中,该参数接收一个回调函数列表,可以传入任意数量的回调函数
callbacks_list = [
    # 如果不再改善,则中断训练
    keras.callbacks.EarlyStopping(
        # 监控模型的验证精度
        monitor="val_accuracy",
        # 如果精度在两轮内都不再改善,则中断训练
        patience=2,
    ),
    # 在每轮过后保存当前权重
    keras.callbacks.ModelCheckpoint(
        # 模型文件的保存路径
        filepath="checkpoint_path.keras",
        # (本行及以下1行)这两个参数的含义是,只有当val_loss改善时,才会覆盖模型文件,这样就可以一直保存训练过程中的最佳模型
        monitor="val_loss",
        save_best_only=True,
    )
]
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              # 监控精度,它应该是模型指标的一部分
              metrics=["accuracy"])
# (本行及以下3行)因为回调函数要监控验证损失和验证指标,所以在调用fit()时需要传入validation_data(验证数据)
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=callbacks_list,
          validation_data=(val_images, val_labels))

注意,你也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。

要重新加载已保存的模型,只需使用下面这行代码:

model = keras.models.load_model("checkpoint_path.keras")

编写自定义调函数

如果想在训练过程中采取特定行动,而这些行动又没有包含在内置回调函数中,那么你可以编写自定义回调函数。

回调函数的实现方式是将keras.callbacks.Callback类子类化。

然后,你可以实现下列方法(从名称中即可看出这些方法的作用),它们在训练过程中的不同时间点被调用。

# 在每轮开始时被调用
on_epoch_begin(epoch, logs)

# 在每轮结束时被调用
on_epoch_end(epoch, logs)

# 在处理每个批量之前被调用
on_batch_begin(batch, logs)

# 在处理每个批量之后被调用
on_batch_end(batch, logs)

# 在训练开始时被调用
on_train_begin(logs)

# 在训练结束时被调用
on_train_end(logs)

调用这些方法时,都会用到参数logs。

这个参数是一个字典,它包含前一个批量、前一个轮次或前一次训练的信息,比如训练指标和验证指标等。on_epoch_*方法和on_batch_*方法还将轮次索引或批量索引作为第一个参数(整数)。

如下代码给出了一个简单示例,它在训练过程中保存每个批量损失值组成的列表,还在每轮结束时保存这些损失值组成的图。

如下代码通过对Callback类子类化来创建自定义回调函数

from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        self.per_batch_losses = []

咱们来测试一下

model = get_mnist_model()

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

自定义回调函数LossHistory的输出图像

利用TensorBoard进行监控和可视化

要想做好研究或开发出好的模型,你在实验过程中需要获得丰富且频繁的反馈,从而了解模型内部发生了什么。

这正是运行实验的目的:获取关于模型性能好坏的信息,并且越多越好。

取得进展是一个反复迭代的过程,或者说是一个循环:

首先,你有一个想法,并将其表述为一个实验,用于验证你的想法是否正确;

然后,你运行这个实验并处理生成的信息;

这又激发了你的下一个想法。在这个循环中,重复实验的次数越多,你的想法就会变得越来越精确、越来越强大。

Keras可以帮你尽快将想法转化成实验,高速GPU则可以帮你尽快得到实验结果。但如何处理实验结果呢?这就需要TensorBoard发挥作用了,如下图所示:

TensorBoard是一个基于浏览器的应用程序,可以在本地运行。它是在训练过程中监控模型的最佳方式。利用TensorBoard,你可以做以下工作:

在训练过程中以可视化方式监控指标

将模型架构可视化

将激活函数和梯度的直方图可视化

以三维形式研究嵌入

如果监控除模型最终损失之外的更多信息,则可以更清楚地了解模型做了什么、没做什么,并且能够更快地取得进展。

要将TensorBoard与Keras模型和fit()方法一起使用,最简单的方式就是使用keras.callbacks.TensorBoard回调函数。

在最简单的情况下,只需指定让回调函数写入日志的位置即可。

model = get_mnist_model()

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

tensorboard = keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_log_dir",
)

model.fit(train_images, train_labels,
          epochs=10,
          validation_data=(val_images, val_labels),
          callbacks=[tensorboard])

一旦开始运行,模型就将在目标位置写入日志。

如果在本地计算机上运行Python脚本,那么可以使用下列命令来启动TensorBoard本地服务器。(注意,如果你是通过pip安装TensorFlow的,那么tensorboard可执行文件应该已经可用;如果不可用,你可以通过pip install tensorboard手动安装TensorBoard。

tensorboard --logdir /full_path_to_your_log_dir

然后可以访问该命令返回的URL,以显示TensorBoard界面。

如果在Colab笔记本中运行脚本,则可以使用以下命令,将TensorBoard嵌入式实例作为笔记本的一部分运行。

%load_ext tensorboard
%tensorboard --logdir /full_path_to_your_log_dir

在TensorBoard界面中,你可以实时监控训练指标和评估指标的图像,如下图所示:

TensorBoard可用于监控训练指标和评估指标


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/418404.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ShardingSphere inline表达式线程安全问题定位

ShardingSphere inline表达式线程安全问题定位 问题背景 春节期间发现 ShardingSphere 事务 E2E 偶发执行失败问题,并且每次执行失败需要执行很久,直到超时。最终定位发现 inline 表达式存在线程安全问题。本文记录定位并解决 inline 表达式线程安全问…

实验笔记之——Ubuntu20.04配置nvidia以及cuda并测试3DGS与SIBR_viewers

之前博文测试3DGS的时候一直用服务器进行开发,没有用过笔记本,本博文记录下用笔记本ubuntu20.04配置过程~ 学习笔记之——3D Gaussian Splatting源码解读_3dgs运行代码-CSDN博客文章浏览阅读3.2k次,点赞34次,收藏62次…

编写科技项目验收测试报告需要注意什么?第三方验收测试多少钱?

科技项目验收测试是一个非常重要的环节,它对于确保科技项目的质量和可用性起着至关重要的作用。在项目完成后,进行科技项目验收测试可以评估项目的功能、性能和可靠性等方面,并生成科技项目验收测试报告,以提供给项目的相关方参考…

keil uv5 map文件解析

map参考博客:https://www.csdn.net/tags/MtjaYgwsMTY2NzUtYmxvZwO0O0OO0O0O.html 配置外部flash存储代码:https://strongerhuang.blog.csdn.net/article/details/51485903?spm1001.2101.3001.6650.4&utm_mediumdistribute.pc_relevant.none-task-bl…

使用 Helm 安装 极狐GitLab

本篇作者 徐晓伟 使用 Helm 简便快捷的部署与管理 极狐GitLab 前提条件 k8s 完成 helm 的配置 k8s 完成 ingress 的配置 内存至少 10G 演示环境是 龙蜥 Anolis 8.4(即:CentOS 8.4)最小化安装k8s 版本 1.28.2calico 版本 3.26.1nginx ingre…

Dockerfile(5) - CMD 指令详解

CMD 指定容器默认执行的命令 # exec 形式,推荐 CMD ["executable","param1","param2"] CMD ["可执行命令", "参数1", "参数2"...]# 作为ENTRYPOINT的默认参数 CMD ["param1","param…

高瓴张磊入籍新加坡,这代表了什么?

文|新熔财经 作者|显洋 这两天,海外媒体报道了中国投资大佬与企业家拿到新加坡永居的事儿。本来乏善可陈的文章,却因为一个人名的出现变得有趣起来——高瓴创始人张磊,一位曾经在国内如日中天,但今天鲜少…

论文阅读:2020GhostNet华为轻量化网络

创新:(1)对卷积进行改进(2)加残差连接 1、Ghost Module 1、利用1x1卷积获得输入特征的必要特征浓缩。利用1x1卷积对我们输入进来的特征图进行跨通道的特征提取,进行通道的压缩,获得一个特征浓…

解放设计师的创造力:免版的图片素材

title: 解放设计师的创造力:免版的图片素材 date: 2024/2/29 15:10:19 updated: 2024/2/29 15:10:19 tags: 版权无忧创意自由设计效率视觉提升广告设计UI/UX素材移动应用 在设计领域,设计师常常需要使用图片素材来增加作品的视觉效果。然而,…

Docker技术概论(1):Docker与虚拟化技术比较

Docker技术概论(1) Docker与虚拟化技术比较 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https:…

从 Flask 切到 FastAPI 后,起飞了!

我这几天上手体验 FastAPI,感受到这个框架易用和方便。之前也使用过 Python 中的 Django 和 Flask 作为项目的框架。Django 说实话上手也方便,但是学习起来有点重量级框架的感觉,FastAPI 带给我的直观体验还是很轻便的,本文就会着…

LeetCode34.在排序数组中查找元素的第一个和最后一个位置

题目 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1, -1]。 你必须设计并实现时间复杂度为 O(log n) 的算法解决此问题。 示例 输入…

尚硅谷Java数据结构--希尔排序

插入排序的问题🎈: arr{2,3,4,5,6,0,9,7,8}; 当0作为插入元素的时候,其待插入下标与原下标相差很远,需要进行多次比较和移动。 希尔排序则是先将下标相差一定距离gap的元素分为一组,进行插入排序;再逐渐将距…

Flutter(四):SingleChildScrollView、GridView

SingleChildScrollView、GridView 遇到的问题 以下代码会报错: class GridViewPage extends StatefulWidget {const GridViewPage({super.key});overrideState<GridViewPage> createState() > _GridViewPage(); }class _GridViewPage extends State<GridViewPage&g…

Maven下载、安装、配置教程

maven是一个项目管理的工具&#xff0c;maven自身是纯java开发的&#xff0c;可以使用maven对java项目进行构建、依赖管理。 通常我们靠手动下载jar包引入项目中是非常浪费时间的&#xff0c;我们可以通过maven工具帮我们导入jar包提高开发效率。 第一步&#xff1a;下载Mave…

Docker技术概论(3):Docker 中的基本概念

Docker技术概论&#xff08;3&#xff09; Docker 中的基本概念 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://…

vivo 在离线混部探索与实践

作者&#xff1a;来自 vivo 互联网服务器团队 本文根据甘青、黄荣杰老师在“2023 vivo开发者大会"现场演讲内容整理而成。 伴随 vivo 互联网业务的高速发展&#xff0c;数据中心的规模不断扩大&#xff0c;成本问题日益突出。在离线混部技术可以在保证服务质量的同时&…

【探索AI】十二 深度学习之第2周:深度神经网络(一)深度神经网络的结构与设计

第2周&#xff1a;深度神经网络 将从以下几个部分开始学习&#xff0c;第1周的概述有需要详细讲解的的同学自行百度&#xff1b; 深度神经网络的结构与设计 深度学习的参数初始化策略 过拟合与正则化技术 批标准化与Dropout 实践&#xff1a;使用深度学习框架构建简单的深度神…

红队基础设施建设

文章目录 一、ATT&CK二、T1583 获取基础架构2.1 匿名网络2.2 专用设备2.3 渗透测试虚拟机 三、T1588.002 C23.1 开源/商用 C23.1.1 C2 调研SliverSliver 对比 CS 3.1.2 CS Beacon流量分析流量规避免杀上线 3.1.3 C2 魔改3.1.4 C2 隐匿3.1.5 C2 准入应用场景安装配置说明工具…

安卓cpu内存监控,大厂首发

开头 很多人工作了十年&#xff0c;但只是用一年的工作经验做了十年而已。 高级工程师一直是市场所需要的&#xff0c;然而很多初级工程师在进阶高级工程师的过程中一直是一个瓶颈。 移动研发在最近两年可以说越来越趋于稳定&#xff0c;因为越来越多人开始学习Android开发&…