政安晨:【Keras机器学习实践要点】(二十)—— 使用现代 MLP 模型进行图像分类

目录

简介

设置

准备数据

配置超参数

建立分类模型

定义实验

使用数据增强

将补丁提取作为一个图层来实施

将位置嵌入作为一个图层来实施

MLP 混频器模型

FNet 模式

gMLP 模式

实施 gMLP 模块


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:为 CIFAR-100 图像分类实施 MLP-Mixer、FNet 和 gMLP 模型。

简介


本示例实现了三种基于多层感知器(MLP)的现代无注意力图像分类模型,并在 CIFAR-100 数据集上进行了演示:

1. Ilya Tolstikhin 等人基于两种类型 MLP 的 MLP-Mixer 模型。
2. James Lee-Thorp 等人基于非参数化傅立叶变换的 FNet 模型。
3. gMLP 模型,由 Hanxiao Liu 等人提出,基于带门控的 MLP。

本示例的目的并不是要比较这些模型,因为它们在不同数据集上的表现可能不同,而且超参数都经过了很好的调整。相反,它是为了展示这些模型主要构建模块的简单实现。

设置

import numpy as np
import keras
from keras import layers

准备数据

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

演绎如下:

配置超参数

weight_decay = 0.0001
batch_size = 128
num_epochs = 1  # Recommended num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

演绎如下:

建立分类模型

我们采用一种方法,根据处理模块构建分类器。

def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        x = x + PositionEmbedding(sequence_length=num_patches)(x)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

定义实验

我们实现了一个实用功能,用于编译、训练和评估给定模型。

def run_experiment(model):
    # Create Adam optimizer with weight decay.
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
        verbose=0,
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history

使用数据增强

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

将补丁提取作为一个图层来实施

class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch_size = keras.ops.shape(patches)[0]
        num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
        patch_dim = keras.ops.shape(patches)[3]
        out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))
        return out

将位置嵌入作为一个图层来实施

class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape

MLP 混频器模型

MLP 混频器是一种完全基于多层感知器(MLP)的架构,包含两种类型的 MLP 层

1. 一种是独立应用于图像斑块,混合每个位置的特征。
2. 另一种应用于跨斑块(沿通道),混合空间信息。

这类似于基于深度可分离卷积的模型,如 Xception 模型,但有两个链式密集变换,没有最大池化,以及层归一化而不是批归一化。

实施 MLP 混频器模块

class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = keras.ops.transpose(x, axes=(0, 2, 1))
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independtenly.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x

构建、训练和评估 MLP-Mixer 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 8 秒钟。

mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)

演绎结果如下:

与卷积模型和基于变换器的模型相比,MLP-Mixer 模型的参数数量要少得多,这就减少了训练和计算成本。

正如 MLP-Mixer 论文中提到的,当在大型数据集上进行预训练或使用现代正则化方案时,MLP-Mixer 可获得与最先进模型相当的分数。您可以通过增加嵌入维度、增加混合块数量和延长模型训练时间来获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。

FNet 模式

FNet 使用与 Transformer 模块类似的模块。不过,FNet 用一个无参数的二维傅立叶变换层取代了 Transformer 模块中的自注意层:

1. 一个一维傅里叶变换沿斑块应用。
2. 沿通道进行一次一维傅里叶变换。

实施 FNet 模块

class FNetLayer(layers.Layer):
    def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply fourier transformations.
        real_part = inputs
        im_part = keras.ops.zeros_like(inputs)
        x = keras.ops.fft2((real_part, im_part))[0]
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply Feedfowrad network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)

构建、训练和评估 FNet 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 8 秒钟。

演绎如下;

如 FNet 论文所示,通过增加嵌入维度、增加 FNet 块数和延长模型训练时间,可以获得更好的结果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。FNet 可以非常高效地扩展到较长的输入,运行速度比基于注意力的 Transformer 模型快得多,并能产生具有竞争力的准确性结果。

gMLP 模式

gMLP 是一种以空间门控单元(SGU)为特色的 MLP 架构。空间门控单元(SGU)可通过以下方式实现跨空间(通道)维度的跨门控互动:

1. 通过跨补丁(沿通道)线性投影,对输入进行空间转换。
2. 对输入及其空间变换进行元素乘法运算。

实施 gMLP 模块

class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimensions.
        # Tensors u and v will in the shape of [batch_size, num_patchs, embedding_dim].
        u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = keras.ops.transpose(v, axes=(0, 2, 1))
        v_projected = self.spatial_projection(v_channels)
        v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected

建立、训练和评估 gMLP 模型

请注意,在 V100 GPU 上以当前设置训练模型,每个轮次大约需要 9 秒钟。

gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)

演绎如下:

如 gMLP 论文所示,通过增加嵌入维度、增加 gMLP 块数和延长模型训练时间,可以获得更好的效果。您还可以尝试增加输入图像的大小,并使用不同的补丁尺寸。请注意,该论文使用了高级正则化策略,如 MixUp 和 CutMix,以及 AutoAugment。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/524884.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

移动平台相关(安卓)

目录 安卓开发 Unity打包安卓 ​编辑​编辑 BuildSettings PlayerSettings OtherSettings 身份证明 配置 脚本编译 优化 PublishingSettings 调试 ReMote Android Logcat AndroidStudio的调试 Java语法 ​编辑​编辑​编辑 变量 运算符 ​编辑​编辑​编辑​…

基于JSP SSM的社区生活超市管理系统

目录 背景 技术简介 系统简介 界面预览 背景 随着时代步伐的加速,计算机技术已广泛而深刻地渗透到社会的各个层面。随着居民生活水平的持续提升,人们对社区生活超市的期望和管理要求也越来越高。随着社区生活超市数量的稳步增长,开发一个…

162 Linux C++ 通讯架构实战16,UDP/TCP协议的优缺点,使用环境对比。UDP 服务器开发

UDP/TCP协议的优缺点 TCP :面向连接的,可靠数据包传输。对于不稳定的网络层,采取完全弥补的通信方式。丢包重传 优点:稳定,数据流量稳定,速度稳定,顺序稳定 缺点:传输速度慢&…

【C语言】_文件类型,结束判定与文件缓冲区

目录 1. 文本文件和二进制文件 2. 文件读取结束的判定 3. 文件缓冲区 1. 文本文件和二进制文件 根据数据的组织形式,数据文件被称为文本文件或二进制文件; 数据在内存中以二进制的形式存储,如果不加转换地输出到外存,就是二进…

2024年最新版本的开源TwoNav网址导航系统源码 免授权

TwoNav 是一款新鲜发布的开源解密版书签(导航)管理程序。该程序采用PHP SQLite 3进行开发,具有界面简洁、安装简单、使用方便等特点,基础功能免费提供。TwoNav可以帮助用户集中管理浏览器书签,解决跨设备、跨平台和跨…

Text-Driven Object Detection 关于结合文本的目标检测

1、简单介绍 首先说明,本文目的主要是水一篇CSDN博客,顺便说一下和标题相关的认识。 近几年,在目标检测领域关于多模态的目标检测工作已成了主流,趋势仍在延续,未来仍有很大挖掘空间。这里说的多模态不是简单的多源数…

03-JAVA设计模式-建造者模式

建造者模式 什么是建造者模式 建造者模式(Builder Pattern)是一种对象构建的设计模式,它允许你通过一步一步地构建一个复杂对象,来隐藏复杂对象的创建细节。 这种模式将一个复杂对象的构建过程与其表示过程分离,使得…

Linux 线程:使用管理线程、多线程、分离线程

目录 一、使用线程 1、pthread_create创建线程 2、pthread_join等待线程 主线程获取新线程退出结果 获取新线程退出返回的数组 3、线程异常导致进程终止 4、pthread_exit 5、pthread_cancel 6、主线程可以取消新线程,新线程可以取消主线程吗 二、如何管理线…

vivado中移位寄存器的优化(二)

移位寄存器优化用于改善移位寄存器单元(SRLs)与其他逻辑单元之间的负裕量路径的时序。如果存在对移位寄存器单元(SRL16E或SRLC32E)的时序违规,优化会从SRL寄存器链的开始或结束位置提取一个寄存器,并将其放…

linux学习:gcc编译

编译.c gcc hello.c -o hello 用gcc 这个工具编译 hello.c,并且使之生成一个二进制文件 hello。 其中 –o 的意义是 output,指明要生成的文件的名称,如果不写 –o hello 的话会生成默 认的一个 a.out 文件 获得 C 源程序经过预处理之后的文…

【深度学习】StableDiffusion的组件解析,运行一些基础组件效果

文章目录 前言vaeclipUNetunet训练帮助、问询 前言 看了篇文: https://zhuanlan.zhihu.com/p/617134893 运行一些组件试试效果。 vae 代码: import torch from diffusers import AutoencoderKL import numpy as np from PIL import Image# 加载模型…

【Redis 知识储备】读写分离/主从分离架构 -- 分布系统的演进(4)

读写分离/主从分离架构 简介出现原因架构工作原理技术案例架构优缺点 简介 将数据库读写操作分散到不同的节点上, 数据库服务器搭建主从集群, 一主一从, 一主多从都可以, 数据库主机负责写操作, 从机只负责读操作 出现原因 数据库成为瓶颈, 而互联网应用一般读多写少, 数据库…

zdpdjango_argonadmin Django后台管理系统中的常见功能开发

效果预览 首先&#xff0c;看一下这个项目最开始的样子&#xff1a; 左侧优化 将左侧优化为下面的样子&#xff1a; 代码位置&#xff1a; 代码如下&#xff1a; {% load i18n static admin_argon %}<aside class"sidenav bg-white navbar navbar-vertical na…

SpringCloud Alibaba Sentinel 创建流控规则

一、前言 接下来是开展一系列的 SpringCloud 的学习之旅&#xff0c;从传统的模块之间调用&#xff0c;一步步的升级为 SpringCloud 模块之间的调用&#xff0c;此篇文章为第十四篇&#xff0c;即介绍 SpringCloud Alibaba Sentinel 创建流控规则。 二、基本介绍 我们在 senti…

Golang | Leetcode Golang题解之第16题最接近的三数之和

题目&#xff1a; 题解&#xff1a; func threeSumClosest(nums []int, target int) int {sort.Ints(nums)var (n len(nums)best math.MaxInt32)// 根据差值的绝对值来更新答案update : func(cur int) {if abs(cur - target) < abs(best - target) {best cur}}// 枚举 a…

2024/4/1—力扣—最小高度树

代码实现&#xff1a; /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ struct TreeNode* buildTree(int *nums, int l, int r) {if (l > r) {return NULL; // 递归出口}struct…

加州大学欧文分校英语基础语法专项课程01:Word Forms and Simple Present Tense 学习笔记

Word Forms and Simple Present Tense Course Certificate 本文是学习Coursera上 Word Forms and Simple Present Tense 这门课程的学习笔记。 文章目录 Word Forms and Simple Present TenseWeek 01: Introduction & BE VerbLearning Objectives Word FormsWord Forms (P…

云原生安全当前的挑战与解决办法

云原生安全作为一种新兴的安全理念&#xff0c;不仅解决云计算普及带来的安全问题&#xff0c;更强调以原生的思维构建云上安全建设、部署与应用&#xff0c;推动安全与云计算深度融合。所以现在云原生安全在云安全领域越来受到重视&#xff0c;云安全厂商在这块的投入也是越来…

工业网络自动化控制赛项分析

时间过去很久了,我突然想起来这篇文章还没写… 设备 它实际上是一个药盒装盖然后再进行一个归类码垛 左侧是供料,主要将盒子推出然后传送带送至中间工作站 中间工作站进行对料盒进行钢珠装填 再通过图像处理,判断大小,然后将数据传送到云服务器,最后通过伺服电机进行分类 …

飞书文档如何在不同账号间迁移

今天由于个人需要新建了一个飞书账号&#xff0c;遇到个需求就是需要把老帐号里面的文档迁移到新的账号里面。在网上搜了一通&#xff0c;发现关于此的内容似乎不多&#xff0c;只好自己动手解决&#xff0c;记录一下过程以便分享&#xff0c;主要有以下几个步骤。 1. 添加新账…