政安晨:【Keras机器学习实践要点】(七)—— 使用TensorFlow自定义fit()

目录

前言

导入

来一个简单例子

下沉到更低的级别

支持样本权重和类别权重

提供您自己的评估步骤

总结:一个端到端的GAN示例


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

在TensorFlow中,fit()是一个非常强大和常用的训练函数,它可以批次地训练模型并监测其性能。虽然fit()提供了很多有用的默认行为,但有时您可能想自定义fit()中发生的操作。

一种自定义fit()的方法是使用回调函数。回调函数是在训练过程中的特定时间点被调用的函数,您可以编写自己的回调函数来执行特定的操作。

TensorFlow提供了许多内置的回调函数,如EarlyStoppingCallback、ModelCheckpointCallback等,您也可以编写自己的自定义回调函数。通过将回调函数传递给fit()函数的callbacks参数,您可以自定义在每个训练批次或训练周期结束时发生的操作。

另一种自定义fit()的方法是编写自定义的训练循环。

默认情况下,fit()函数使用单个步骤来执行训练循环,但您可以重写这个步骤来实现自己的训练逻辑。您可以使用TensorFlow的GradientTape来手动计算梯度,并使用优化器来更新模型的权重。通过这种方式,您可以完全控制训练过程中的每个步骤。

最后,您还可以自定义fit()的行为通过设置fit()函数的其他参数。

例如,您可以通过设置batch_size参数来定义每个训练批次的大小,或者通过设置epochs参数来指定训练周期的数量。除了这些参数,您还可以设置其他参数,如学习率、损失函数等,以进一步自定义fit()的行为。

总而言之,TensorFlow提供了多种方法来自定义fit()函数中发生的操作。无论您是通过回调函数、自定义训练循环还是设置fit()的参数,您都可以根据自己的需求来定制训练过程。

这使得TensorFlow成为一个非常灵活和强大的深度学习框架。

今天我们讲的就是keras api 在这里的应用方法


前言

当你进行监督学习时,你可以使用fit(),一切都运行得很顺利。

当你需要控制每一个细节时,你可以完全从头开始编写自己的训练循环。

但是如果你需要一个自定义的训练算法,但仍然希望从fit()的便利功能中受益,比如回调函数,内置的分发支持或步骤融合,该怎么办呢?

Keras的一个核心原则是逐步揭示复杂性。您应该总是能够逐渐进入更低级的工作流程。如果高级功能与您的使用情况不完全匹配,您不应该感到掉入悬崖。您应该能够在保留相应数量高级便利的同时获得对细节的更多控制。

当你需要自定义fit()函数的行为时,你应该重写Model类的训练步骤函数。这是fit()函数在处理每个数据批次时调用的函数。然后,你仍然可以像往常一样调用fit()函数 - 它将会运行你自己的学习算法。

请注意这种模式并不妨碍您使用函数式API构建模型。无论您是构建顺序模型、函数式API模型还是子类化模型,都可以使用这种模式。

现在让我们看看这一切都是怎么工作的吧。

导入

import os

# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np

来一个简单例子

我们创建一个新的类,继承自keras.Model。

我们只需重写train_step(self, data)方法。

我们返回一个将指标名称(包括损失)映射到它们当前值的字典。

输入参数data是传递给fit作为训练数据的内容:

如果你通过调用fit(x, y, ...)传递NumPy数组,则data将是元组(x, y)

如果你通过调用fit(dataset, ...)传递tf.data.Dataset,则data将是每个批次由数据集产生的内容。

在train_step()方法的主体中,我们实现了一个常规的训练更新,与你已经熟悉的类似。重要的是,我们通过self.compute_loss()计算损失,它包装了传递给compile()的损失函数。

类似地,我们对self.metrics中的指标调用metric.update_state(y, y_pred)来更新在compile()中传递的指标的状态,并在最后从self.metrics查询结果以获取它们的当前值。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

现在让咱们试试:

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)

下沉到更低的级别

当然,你可以在compile()中跳过传递损失函数,而是在train_step中手动完成所有操作。同样,在指标方面也是如此。

这是一个低级示例,只使用compile()配置优化器:

我们首先创建度量实例来跟踪我们的损失和MAE得分(在__init__()中)。

我们实现了一个自定义的train_step()函数,该函数通过调用update_state()来更新这些度量的状态,然后通过调用result()来查询它们的当前平均值,以便在进度条中显示并传递给任何回调函数。

注意,我们需要在每个epoch之间调用reset_states()来重置我们的度量!

否则,调用result()将返回训练开始以来的平均值,而我们通常使用每个epoch的平均值。

幸运的是,框架可以为我们做到这一点:只需将要重置的任何度量列在模型的metrics属性中。

模型将在每次fit()的epoch开始或调用evaluate()的开始时调用此处列出的任何对象的reset_states()函数。

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

支持样本权重和类别权重

您可能已经注意到,我们的第一个基本示例没有提及样本加权。

如果您想支持fit()函数的sample_weight和class_weight参数,只需按照以下步骤操作:

从data参数中解包sample_weight 将其传递给compute_loss和update_state函数(当然,如果您不依赖compile()函数来计算损失和指标,您也可以手动应用它) 就是这样。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)

提供您自己的评估步骤

如果你想对model.evaluate()的调用做同样的操作,那么你可以通过同样的方式覆盖test_step。下面是实现的样例:

class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)

总结:一个端到端的GAN示例

让我们通过一个端到端的示例来演示你刚刚学到的所有内容。

让我们考虑以下内容:

一个生成器网络,用于生成28x28x1的图像。

一个判别器网络,用于将28x28x1的图像分为两类("假"和"真")。

每个网络都有一个优化器。 一个损失函数,用于训练判别器。

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

这是一个完整的GAN类,它重写了compile()方法以使用自己的签名,并在train_step中使用17行代码实现了整个GAN算法。

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(
            tf.shape(labels), seed=self.seed_generator
        )

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

让我们试试吧:

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)

深度学习背后的思想很简单吧。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/506458.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】TCP网络套接字编程+守护进程

文章目录 日志类(完成TCP/UDP套接字常见连接过程中的日志打印)单进程版本的服务器客户端通信多进程版本和多线程版本守护进程化的多线程服务器 日志类(完成TCP/UDP套接字常见连接过程中的日志打印) 为了让我们的代码更规范化&…

3万字80道Java基础经典面试题总结(2024修订版)

大家好,我是哪吒。 本系列是《10万字208道Java经典面试题总结(附答案)》的2024修订版。 目录 1、说说跨平台性2、Java是如何实现跨平台性的?3、JDK 和 JRE 有什么区别?4、为何要配置Java环境变量?5、Java都有哪些特性&#xff1f…

(八)Gateway服务网关

Gateway服务网关 Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关,它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式。…

Linux/Headless

Headless Enumeration nmap 用 nmap 扫描了常见的端口,发现对外开放了 22 和 5000,而且 nmap 显示 5000 端口的服务是 upnp? ┌──(kali㉿kali)-[~/vegetable/HTB/headless] └─$ nmap 10.10.11.8 Starting Nmap 7.93 ( https://nmap.or…

打造安全医疗网络:三网整体规划与云数据中心构建策略

医院网络安全问题涉及到医院日常管理多个方面,一旦医院信息管理系统在正常运行过程中受到外部恶意攻击,或者出现意外中断等情况,都会造成海量医疗数据信息的丢失。由于医院信息管理系统中存储了大量患者个人信息和治疗方案信息等,…

神奇的css radial-gradient

使用css radial-gradient属性,创造一个中间凹陷进去的形状。如下图 background: radial-gradient(circle at 50% -0.06rem, transparent 0.1rem, white 0) top left 100% no-repeat;

vue系列——v-on

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>v-on指令</title> </head> <body>…

docker部署DOS游戏

下载镜像 docker pull registry.cn-beijing.aliyuncs.com/wuxingge123/dosgame-web-docker:latestdocker-compose部署 vim docker-compose.yml version: 3 services:dosgame:container_name: dosgameimage: registry.cn-beijing.aliyuncs.com/wuxingge123/dosgame-web-docke…

贪吃蛇:从零开始搭建一个完整的小游戏

目录 导语&#xff1a; 一、游戏框架 二、蛇的实现 三、绘制游戏界面 四、食物 五、移动蛇 六.得分系统&#xff0c;是否吃到食物 七、检查碰撞 八、处理按键事件 九、得分系统 十、游戏状态管理 导语&#xff1a; 贪吃蛇这个经典的小游戏&#xff0c;我上学的时候就…

设计模式-概述篇

1. 掌握设计模式的层次 第1层&#xff1a;刚开始学编程不久&#xff0c;听说过什么是设计模式第2层&#xff1a;有很长时间的编程经验&#xff0c;自己写了很多代码&#xff0c;其中用到了设计模式&#xff0c;但是自己却不知道第3层&#xff1a;学习过了设计模式&#xff0c;…

【机器学习】无监督学习与聚类技术:解锁数据的隐藏结构

无监督学习介绍 无监督学习&#xff0c;作为机器学习的一大分支&#xff0c;专注于探索未经标记的数据集中的潜在结构。不同于有监督学习&#xff0c;无监督学习不依赖于外部提供的标签或输出结果&#xff0c;而是通过数据本身的特征来寻找模式、聚类或降维。这种学习方法在多…

03-MySQl数据库的-用户管理

一、创建新用户 mysql> create user xjzw10.0.0.% identified by 1; Query OK, 0 rows affected (0.01 sec) 二、查看当前数据库正在登录的用户 mysql> select user(); ---------------- | user() | ---------------- | rootlocalhost | ---------------- 1 row …

PI案例分享--2000A核心电源网络的设计、仿真与验证

目录 摘要 0 引言 1 为什么需要 2000A 的数字电子产品? 2 2000A 的供电电源设计 2.1 "MPM3698 2*MPM3699"的 MPS扩展电源架构 2.2 使用恒定导通时间(COT)模式输出核心电压的原因 2.3 模块化 VRM 的优势 2.4 用步进负载验证2000A的设计难点 2.4.1 电源网络 …

初始Java篇(JavaSE基础语法)(5)(类和对象(上))

个人主页&#xff08;找往期文章包括但不限于本期文章中不懂的知识点&#xff09;&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 目录 面向对象的初步认知 面向对象与面向过程的区别 类的定义和使用 类的定义格式 类的实例化 this引用 什么是this引用&#xff1f; this引用…

Python爬虫-懂车帝城市销量榜单

前言 本文是该专栏的第23篇,后面会持续分享python爬虫干货知识,记得关注。 最近粉丝留言咨询某汽车平台的汽车销量榜单数据,本文笔者以懂车帝平台为例,采集对应的城市汽车销量榜单数据。 具体的详细思路以及代码实现逻辑,跟着笔者直接往下看正文详细内容。(附带完整代码…

Gradle 使用详解

目录 一. 前言 二. 下载与安装 2.1. 下载 2.2. 配置环境变量 2.3. 配置镜像 2.3.1. 全局设置 2.3.2. 项目级设置 三. Gradle 配置文件 3.1. build.gradle 3.2. settings.gradle 3.3. gradle.properties 3.4. init.d 目录 3.5. buildSrc 目录 四. Java Library 插…

计算机网络——28自治系统内部的路由选择

自治系统内部的路由选择 RIP 在1982年发布的BSD-UNIX中实现Distance vector算法 距离矢量&#xff1a;每条链路cost 1&#xff0c;# of hops(max 15 hops)跳数DV每隔30秒和邻居交换DV&#xff0c;通告每个通告包括&#xff1a;最多25个目标子网 RIP通告 DV&#xff1a;在…

Qt笔记-解决Qt程序连不上数据库MySQL数据库(重编libqsqlmysql.so)

使用QSqlDatabase连接MySQL数据库时。在自己程序配置没有错误的情况下报这类错误&#xff1a; QSqlDatabase: QMYSQL driver not loaded QSqlDatabase::exec: database not open 造成这样的问题大多数是libqsqlmysql.so有问题。 Qt的QSqlDatabase使用的是libqsqlmysql.so&a…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《偏远地区能源自洽系统源储容量协同配置方法》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

HTTP 常见面试题(计算机网络)

HTTP 基本概念 一、HTTP 是什么&#xff1f; HTTP(HyperText Transfer Protocol) &#xff1a;超文本传输协议。 HTTP 是一个在计算机世界里专门在「两点」之间「传输」文字、图片、音频、视频等「超文本」数据的「约定和规范」。 「HTTP 是用于从互联网服务器传输超文本到本…