政安晨:【Keras机器学习实践要点】(十八)—— 利用视觉转换器进行图像分类

目录

简介

设置

准备数据

配置超参数

使用数据增强

实施多层感知器(MLP)

将创建修补程序作为一个层

实施补丁编码层

建立 ViT 模型

编译、培训和评估模式


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实现图像分类的视觉变换器(ViT)模型。

简介

本示例实现了 Alexey Dosovitskiy 等人提出的用于图像分类的视觉变换器(ViT)模型,并在 CIFAR-100 数据集上进行了演示。

ViT 模型不使用卷积层,而是对图像片段序列应用了具有自我关注功能的 Transformer 架构。

设置

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import keras
from keras import layers
from keras import ops

import numpy as np
import matplotlib.pyplot as plt

准备数据

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

演绎结果:

x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)

配置超参数

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10  # For real training, use num_epochs=100. 10 is a test value
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [
    2048,
    1024,
]  # Size of the dense layers of the final classifier

使用数据增强

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

实施多层感知器(MLP)

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

将创建修补程序作为一个层

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        patches = ops.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

让我们显示一幅图像样本的修补程序:

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = ops.image.resize(
    ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
    plt.axis("off")

演绎结果如下:

Image size: 72 X 72
Patch size: 6 X 6
Patches per image: 144
Elements per patch: 108

实施补丁编码层

PatchEncoder 层会通过投射到 projection_dim 大小的矢量中对补丁进行线性变换。此外,它还会为投影向量添加一个可学习的位置嵌入。

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config

建立 ViT 模型

ViT 模型由多个变换器模块组成,这些模块使用 layers.MultiHeadAttention 层作为应用于补丁序列的自注意机制。Transformer 块会产生一个 [batch_size, num_patches, projection_dim] 张量,该张量通过一个带有 softmax 的分类器头进行处理,以产生最终的类别概率输出。

与论文中描述的将可学习嵌入预置到编码补丁序列中作为图像表示的技术不同,最终变换器块的所有输出都经过 layers.Flatten() 重塑,并用作分类器头部的图像表示输入。

请注意,layer.GlobalAveragePooling1D 层也可以用来聚合变换器块的输出,尤其是当贴片数量和投影尺寸较大时。

def create_vit_classifier():
    inputs = keras.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

编译、培训和评估模式

def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss")
plot_history("top-5-accuracy")

演绎结果:

Epoch 1/10
...
Epoch 10/10
 176/176 ━━━━━━━━━━━━━━━━━━━━ 449s 3s/step - accuracy: 0.0790 - loss: 3.9468 - top-5-accuracy: 0.2711 - val_accuracy: 0.0986 - val_loss: 3.8537 - val_top-5-accuracy: 0.3052

 313/313 ━━━━━━━━━━━━━━━━━━━━ 66s 198ms/step - accuracy: 0.1001 - loss: 3.8428 - top-5-accuracy: 0.3107
Test accuracy: 10.61%
Test top 5 accuracy: 31.51%

经过 100 次历时后,ViT 模型在测试数据上达到了约 55% 的准确率和 82% 的前五名准确率。

在 CIFAR-100 数据集上,这些结果并不具有竞争力,因为在相同数据上从头开始训练的 ResNet50V2 可以达到 67% 的准确率。

请注意,本文所报告的最新结果是通过使用 JFT-300M 数据集对 ViT 模型进行预训练,然后在目标数据集上对其进行微调而实现的。

要在不进行预训练的情况下提高模型质量,可以尝试对模型进行更多的历时训练、使用更多的变换层、调整输入图像的大小、改变补丁大小或增加投影维度。

此外,正如论文中提到的,模型的质量不仅会受到架构选择的影响,还会受到学习率安排、优化器、权重衰减等参数的影响。

在实践中,建议对使用大型高分辨率数据集预先训练好的 ViT 模型进行微调。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/521770.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android源码笔记-输入事件(二)

这一节主要了解输入事件的获取,InputReaderThread继承自C的Thread类,Thread类封装了pthread线程工具,提供了与Java层Thread类相似的API。C的Thread类提供了一个名为threadLoop()的纯虚函数,当线程开始运行后,将会在内建…

【Linux实践室】Linux高级用户管理实战指南:创建与删除用户组操作详解

🌈个人主页:聆风吟_ 🔥系列专栏:Linux实践室、网络奇遇记 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 一. ⛳️任务描述二. ⛳️相关知识2.1 🔔Linux创建用户组命令2.1.1 知识点讲解2.1.2…

Arduino的OTA在线升级

一、OTA 介绍 OTA是Over-the-Air的缩写,中文意思是空中下载技术。通过移动通信(GSM或CDMA)的空中接口对SIM卡数据及应用进行远程管理的技术。空中接口可以采用WAP、GPRS、CDMA1X及短消息技术。OTA技术的应用,使得移动通信不仅可以…

读所罗门的密码笔记12_群雄逐鹿(上)

1. 国际电信规则 1.1. 美国坚持互联网自由和极少的内容限制,这一立场肯定会遭到许多国家的反对 1.2. 除去两个各方针锋相对、无法妥协的议题,比如内容限制规定,实际上所有国家都已在打击垃圾邮件和常见网络安全威胁方…

Windows Server 2012 R2安装Web服务器IIS

文章目录 一、打开【服务器管理器】二、点击【添加角色和功能】三、点击【下一步】四、点击【下一步】五、点击【下一步】六、勾选【Web服务器(IIS)】→点击【添加功能】→点击【下一步】七、勾选【.NET Framework 3.5 功能】→点击【下一步】八、点击【下一步】九、点击【下一…

基于H2O AutoML与集成学习策略的房屋售价预测模型研究与实现

项目简述: 本项目采用H2O AutoML工具,针对加州房屋销售价格预测问题进行了深入研究与建模。项目以Kaggle提供的加州房屋 交易数据集为基础,通过数据清洗、特征工程、模型训练与评估等步骤,构建了一种基于集成学习策略的房价预测模…

LeetCode刷题记(二):31~60题

31. 下一个排列 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大的排列。…

如何确认RID池是否耗尽,以及手动增加RID池大小

确认RID池是否耗尽: 事件查看器: 在RID主控域控制器上打开事件查看器,导航至“Windows日志 > 应用程序和服务日志 > Microsoft > Windows > Directory Service > Operations”。搜索事件ID 16656和16657。事件ID 16656表明RID…

C++超高精度计时器

#include <chrono> #include <QDebug> #pragma execution_character_set("utf-8") /*1 秒&#xff08;s&#xff09; 1,000 毫秒&#xff08;ms&#xff09; 1,000,000 微秒&#xff08;μs&#xff09; 1,000,000,000 纳秒&#xff08;ns&#xff09…

Leetcode 第 389 场周赛题解

Leetcode 第 389 场周赛题解 Leetcode 第 389 场周赛题解题目1&#xff1a;3083. 字符串及其反转中是否存在同一子字符串思路代码复杂度分析 题目2&#xff1a;3084. 统计以给定字符开头和结尾的子字符串总数思路代码复杂度分析 题目3&#xff1a;3085. 成为 K 特殊字符串需要删…

Fire Smoke - Dynamic Nature

烟雾、火灾和爆炸预制件、着色器的集合。粒子支持HD、URP和标准渲染,自然制造风,因此它们对风速、方向和颤抖做出反应。 包装支持: Unity 2021.2及更高版本 Unity 2021.2 HD RP Unity 2021.2 URP Unity 2021.3及更高版本 Unity 2021.3 LTS HD RP Unity 2021.3 LTS URP Unity…

202112青少年软件编程(Scratch图形化)等级考试试卷(四级)

第1题&#xff1a;【 单选题】 小猫和小狗是非常好的朋友&#xff0c; 他们发明了一种加密方法&#xff1a; 用两位数字代表字母。比如 65 代表 A&#xff0c; 66 代表 B……&#xff0c; 75 代表 K&#xff0c; ……&#xff0c; 78 代表 N&#xff0c; 79 代表 O、 80 代表 …

zdpdjango_argonadmin使用Django开发一个美观的后台管理系统

初始代码 安装依赖 pip install -r requirements.txt生成管理员账户 迁移模型&#xff1a; python manage.py makemigrations python manage.py migrate创建超级用户&#xff1a; python manage.py createsuperuser启动服务 python manage.py runserver浏览器访问&#xf…

ChatGPT官宣新增Dynamic(动态)模式!

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;所以创建了“AI信息Gap”这个公众号&#xff0c;专注于分享AI全维度知识…

Java中IO流详解

文章目录 字符流问题导入编码表**出现乱码的原因**ASCII表Unicode表汉字存储和展示过程解析问题导入解答 介绍分类字符输出流字符输入流字符缓冲流 特殊操作流转化流对象操作流打印流 工具包Commons-io介绍分类IOUtils类FileUtils类 字符流 问题导入 既然字节流能操作所有文件…

探索Linux的挂载操作

在Linux这个强大的操作系统中&#xff0c;挂载操作是一个基本而重要的概念。它涉及到文件系统、设备和数据访问&#xff0c;对于理解Linux的工作方式至关重要。那么&#xff0c;挂载操作究竟是什么&#xff0c;为什么我们需要它&#xff0c;如果没有它&#xff0c;我们将面临什…

递归学习第一个课

一、递归定义 基本定义 函数自己调用自己&#xff08;通俗第一印象&#xff09;大问题可以拆分小问题&#xff08;拆分&#xff0c;边界&#xff09;大问题与小问题的关系&#xff08;递归关系&#xff09; 为什么拆分小问题&#xff1f; 小问题更容易求解大问题与小问题内部…

基于Spark中随机森林模型的天气预测系统

基于Spark中随机森林模型的天气预测系统 在这篇文章中&#xff0c;我们将探讨如何使用Apache Spark和随机森林算法来构建一个天气预测系统。该系统将利用历史天气数据&#xff0c;通过机器学习模型预测未来的天气情况&#xff0c;特别是针对是否下雨的二元分类问题。 简介 Ap…

SpringBoot3整合RabbitMQ之四_发布订阅模型中的fanout模型

SpringBoot3整合RabbitMQ之四_发布订阅模型中的fanout模型 文章目录 SpringBoot3整合RabbitMQ之四_发布订阅模型中的fanout模型3. 发布/订阅模型之fanout模型1. 说明1. 消息发布者1. 创建工作队列的配置类2. 发布消费Controller 2. 消息消费者One3. 消息消费者Two4. 消息消费者…

实际项目中如何使用Git做分支管理

前言 Git是一种强大的分布式版本控制系统&#xff0c;在实际项目开发中使用Git进行分支管理是非常常见的做法&#xff0c;因为它可以帮助团队高效的协作和管理项目的不同版本&#xff0c;今天我们来讲讲在实际项目中最常用的Git分支管理策略Git Flow。 常见的Git分支管理策略…