政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):
    def __init__(self, units, activation=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=keras.initializers.GlorotNormal(),
            name="kernel",
            trainable=True,
        )

        self.b = self.add_weight(
            shape=(self.units,),
            initializer=keras.initializers.Zeros(),
            name="bias",
            trainable=True,
        )

    def call(self, inputs):
        # Use Keras ops to create backend-agnostic layers/metrics/etc.
        x = keras.ops.matmul(inputs, self.w) + self.b
        return self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        # Use `keras.random` for random ops.
        return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv_base = keras.Sequential(
            [
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.MaxPooling2D(pool_size=(2, 2)),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.GlobalAveragePooling2D(),
            ]
        )
        self.dp = MyDropout(0.5)
        self.dense = MyDense(num_classes, activation="softmax")

    def call(self, x):
        x = self.conv_base(x)
        x = self.dp(x)
        return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=1,  # For speed
    validation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch

# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)

# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=batch_size, shuffle=False
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tf

train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/488251.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【解析几何】 【多源路径】 【贪心】1520 最多的不重叠子字符串

作者推荐 视频算法专题 本身涉及知识点 解析几何 图论 多源路径 贪心 LeetCode1520. 最多的不重叠子字符串 给你一个只包含小写字母的字符串 s ,你需要找到 s 中最多数目的非空子字符串,满足如下条件: 这些字符串之间互不重叠&#xff0…

LeetCode 面试经典150题 392.判断子序列

题目: 给定字符串 s 和 t ,判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些(也可以不删除)字符而不改变剩余字符相对位置形成的新字符串。(例如,"ace"是"abcde"…

量子计算新“尺度”:用经典计算机评估复杂量子系统!

未来的量子计算机有望在计算机科学、医疗、商业、化学、物理学等多个领域解决难题,从而超越传统计算机。然而,目前的量子计算机仍存在局限,主要是由于它们固有的错误率。为此,研究者正致力于降低这些错误率。 一种研究量子计算机误…

Linux 性能优化

性能优化 性能指标 高并发和响应快对应着性能优化的两个核心指标:吞吐和延时 应用负载角度:直接影响了产品终端的用户体验 系统资源角度:资源使用率、饱和度等 性能问题的本质就是系统资源已经到达瓶颈,但请求的处理还不够快…

Gemma开源AI指南

近几个月来,谷歌推出了 Gemini 模型,在人工智能领域掀起了波澜。 现在,谷歌推出了 Gemma,再次引领创新潮流,这是向开源人工智能世界的一次变革性飞跃。 与前代产品不同,Gemma 是一款轻量级、小型模型&…

Linux 搭建jenkins docker

jekin docker gitee docker 安装 jenkins docker run -d --restartalways \ --name jenkins -uroot -p 10340:8080 \ -p 10341:50000 \ -v /home/docker/jenkins:/var/jenkins_home \ -v /var/run/docker.sock:/var/run/docker.sock \ -v /usr/bin/docker:/usr/bin/docker je…

【NLP学习记录】Embedding和EmbeddingBag

Embedding与EmbeddingBag详解 ●🍨 本文为🔗365天深度学习训练营 中的学习记录博客 ●🍖 原作者:K同学啊 | 接辅导、项目定制 ●🚀 文章来源:K同学的学习圈子1、Embedding详解 Embedding是Pytorch中最基本…

VR全景展示:传统制造业如何保持竞争优势?

在结束不久的两会上,数字化经济和创新技术再度成为了热门话题。我国制造产业链完备,但是目前依旧面临着市场需求不足、成本传导压力加大等因素影响,那么传统制造业该如何保持竞争优势呢? 在制造行业中,VR全景展示的应用…

图解Kafka架构学习笔记(二)

kafka的存储机制 https://segmentfault.com/a/1190000021824942 https://www.lin2j.tech/md/middleware/kafka/Kafka%E7%B3%BB%E5%88%97%E4%B8%83%E5%AD%98%E5%82%A8%E6%9C%BA%E5%88%B6.html https://tech.meituan.com/2015/01/13/kafka-fs-design-theory.html https://feiz…

第十二届蓝桥杯物联网试题(省赛)

思路: 这个考了一个RTC的配置,RTC我只配过一次,所以有些生疏,还是不能大意,一些偏僻的考点还是要多练,在获取RTC时间的时候也遇到一些bug,这个后续会用一篇博客将最近遇到的BUG都总结一下 主要的难点还是…

纯前端调用本机原生Office实现Web在线编辑Word/Excel/PPT,支持私有化部署

在日常协同办公过程中,一份文件可能需要多次重复修改才能确定,如果你发送给多个人修改后再汇总,这样既效率低又容易出错,这就用到网页版协同办公软件了,不仅方便文件流转还保证不会出错。 但是目前一些在线协同Office…

好莱坞新风潮:OpenAI携手Sora AI视频生成工具探索电影制作新境界

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

24.两两交换链表中的节点

给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:head [1,2,3,4] 输出:[2,1,4…

Qt|QStringList转QString

免责:百度搜索AI自动生成,如果侵权联系我删除。 AI生成有错误,已验证修改。 文章目录 1.使用join()方法:2.使用QTextStream:3.使用QString的arg()方法:4.使用std::for_each和lambda表达式:5.使…

elasticsearch 6.8.x 索引别名、动态索引扩展、滚动索引

文章目录 引言索引别名(alias)创建索引别名查询索引别名删除索引别名重命名索引别名 动态索引(index template,动态匹配生成索引)新建索引模板新建索引并插入数据索引sys-log-202402索引sys-log-202403索引sys-log-202…

android studio忽略文件

右键文件,然后忽略,就不会出现在commit里面了 然后提交忽略文件即可

【算法专题--双指针算法】leecode-15.三数之和(medium)、leecode-18. 四数之和(medium)

🍁你好,我是 RO-BERRY 📗 致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 🎄感谢你的陪伴与支持 ,故事既有了开头,就要画上一个完美的句号,让我们一起加油 目录 前言1. 三数之和2. 解法&…

龙蜥 Anolis OS 7.9 一键安装 Oracle 11GR2(231017)单机版

前言 Oracle 一键安装脚本,演示 龙蜥 Anolis OS 7.9 一键安装 Oracle 11GR2(231017)单机版过程(全程无需人工干预):(脚本包括 ORALCE PSU/OJVM 等补丁自动安装) ⭐️ 脚本下载地址…

【编译tingsboard】出现gradle-maven-plugin:1.0.11:invoke (default)

出现的错误: [ERROR] Failed to execute goal org.thingsboard:gradle-maven-plugin:1.0.11:invoke (default) on project http: Execution default of goal org.thingsboard:gradle-maven-plugin:1.0.11:invoke failed: Plugin org.thingsboard:gradle-maven-plugi…

uni-app框架(项目创建)

1.学习说明 dcloud官方除uni-app外,还有新生的uni-app x(即下一代uni-app),如果是初学者或者刚入门同学,建议还是使用uni-app进行开发。 无论是vue还是uni,作为前端开发的一个框架学习方法是一致的&#…