政安晨:【Keras机器学习示例演绎】(三十二)—— 在 Vision Transformers 中学习标记化

目录

导言

导入

超参数

加载并准备 CIFAR-10 数据集

数据扩增

位置嵌入模块

变压器的 MLP 模块

令牌学习器模块

变换器组

带有 TokenLearner 模块的 ViT 模型

培训实用程序

使用 TokenLearner 培训和评估 ViT

实验结果

参数数量

最终说明


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:为 "视觉变换器 "自适应生成较少数量的令牌。

导言


视觉变换器(Dosovitskiy 等人)和许多其他基于变换器的架构(Liu 等人、Yuan 等人)在图像识别方面取得了显著成果。

下面将简要介绍用于图像分类的视觉变换器架构所涉及的组件:

—— 从输入图像中提取小块图像。
—— 线性投影这些斑块。
—— 为这些线性投影添加位置嵌入。
—— 通过一系列 Transformer(Vaswani 等人)模块运行这些投影。
—— 最后,从最后的 Transformer 模块中提取表示并添加分类头。

如果我们获取 224x224 的图像并提取 16x16 的补丁,那么每张图像总共会得到 196 个补丁(也称为标记)。

随着分辨率的提高,补丁的数量也会增加,从而导致内存占用增加。

我们能否在不影响性能的情况下减少补丁的数量呢?Ryoo 等人在 TokenLearner 中研究了这个问题:视频的自适应时空标记化》中研究了这个问题。他们引入了一个名为 TokenLearner 的新模块,该模块能以自适应的方式帮助减少视觉转换器(ViT)使用的补丁数量。将 TokenLearner 纳入标准 ViT 架构后,他们能够减少模型使用的计算量(以 FLOPS 衡量)。

在本示例中,我们实现了 TokenLearner 模块,并用迷你 ViT 和 CIFAR-10 数据集演示了其性能。

导入

import keras
from keras import layers
from keras import ops
from tensorflow import data as tf_data


from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

import math

超参数


请随时更改超参数并检查结果。对架构产生直觉的最好方法就是进行实验。

# DATA
BATCH_SIZE = 256
AUTO = tf_data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 1

# AUGMENTATION
IMAGE_SIZE = 48  # We will resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]

# TOKENLEARNER
NUM_TOKENS = 4

加载并准备 CIFAR-10 数据集

# Load the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

# Convert to tf.data.Dataset objects.
train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
Training samples: 40000
Validation samples: 10000
Testing samples: 10000

数据扩增

扩增管道包括:

重新缩放
调整大小
随机裁剪(固定大小或随机大小)
随机水平翻转

请注意,图像数据增强层在推理时不应用数据转换。这意味着在调用这些层时,如果训练=假,它们的行为会有所不同。

位置嵌入模块

Transformer 架构由多头自我关注层和全连接前馈网络(MLP)作为主要组成部分。这两个组件都具有排列不变性:它们不考虑特征顺序。

为了克服这一问题,我们为标记注入了位置信息。

position_embedding 函数将位置信息添加到线性投影的标记中。

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        encoded = patch + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config

变换器的 MLP 模块


这是变换器的全连接前馈模块。

def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=ops.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

令牌学习器模块

下图是该模块的图示概览(来源)。

TokenLearner 模块将图像形状的张量作为输入。然后,它将其通过多个单通道卷积层,提取不同的空间注意力图,并将注意力集中在输入的不同部分。然后,将这些注意力图按元素顺序与输入相乘,并对结果进行汇集。汇集后的输出可以看作是输入的汇总,其补丁数量(例如 8 个)远远少于原始输出(例如 196 个)。

使用多个卷积层有助于提高表现力。施加一种空间注意力有助于保留输入的相关信息。这两个部分对 TokenLearner 的运行都至关重要,尤其是当我们要大幅减少贴片数量时。

def token_learner(inputs, number_of_tokens=NUM_TOKENS):
    # Layer normalize the inputs.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)  # (B, H, W, C)

    # Applying Conv2D => Reshape => Permute
    # The reshape and permute is done to help with the next steps of
    # multiplication and Global Average Pooling.
    attention_maps = keras.Sequential(
        [
            # 3 layers of conv with gelu activation as suggested
            # in the paper.
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=ops.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=ops.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=ops.gelu,
                padding="same",
                use_bias=False,
            ),
            # This conv layer will generate the attention maps
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation="sigmoid",  # Note sigmoid for [0, 1] output
                padding="same",
                use_bias=False,
            ),
            # Reshape and Permute
            layers.Reshape((-1, number_of_tokens)),  # (B, H*W, num_of_tokens)
            layers.Permute((2, 1)),
        ]
    )(
        x
    )  # (B, num_of_tokens, H*W)

    # Reshape the input to align it with the output of the conv block.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, H*W, C)

    # Element-Wise multiplication of the attention maps and the inputs
    attended_inputs = (
        ops.expand_dims(attention_maps, axis=-1) * inputs
    )  # (B, num_tokens, H*W, C)

    # Global average pooling the element wise multiplication result.
    outputs = ops.mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs

变换器组

def transformer(encoded_patches):
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
    )(x1, x1)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

    # Skip connection 2.
    encoded_patches = layers.Add()([x4, x2])
    return encoded_patches

带有 TokenLearner 模块的 ViT 模型

def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
    inputs = layers.Input(shape=INPUT_SHAPE)  # (B, H, W, C)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Create patches and project the pathces.
    projected_patches = layers.Conv2D(
        filters=PROJECTION_DIM,
        kernel_size=(PATCH_SIZE, PATCH_SIZE),
        strides=(PATCH_SIZE, PATCH_SIZE),
        padding="VALID",
    )(augmented)
    _, h, w, c = projected_patches.shape
    projected_patches = layers.Reshape((h * w, c))(
        projected_patches
    )  # (B, number_patches, projection_dim)

    # Add positional embeddings to the projected patches.
    encoded_patches = PatchEncoder(
        num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
    )(
        projected_patches
    )  # (B, number_patches, projection_dim)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)

    # Iterate over the number of layers and stack up blocks of
    # Transformer.
    for i in range(NUM_LAYERS):
        # Add a Transformer block.
        encoded_patches = transformer(encoded_patches)

        # Add TokenLearner layer in the middle of the
        # architecture. The paper suggests that anywhere
        # between 1/2 or 3/4 will work well.
        if use_token_learner and i == NUM_LAYERS // 2:
            _, hh, c = encoded_patches.shape
            h = int(math.sqrt(hh))
            encoded_patches = layers.Reshape((h, h, c))(
                encoded_patches
            )  # (B, h, h, projection_dim)
            encoded_patches = token_learner(
                encoded_patches, token_learner_units
            )  # (B, num_tokens, c)

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

如令牌学习器论文所示,将令牌学习器模块置于网络中间几乎总是有利的。

培训实用程序

def run_experiment(model):
    # Initialize the AdamW optimizer.
    optimizer = keras.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Compile the model with the optimizer, loss function
    # and the metrics.
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Define callbacks
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # Train the model.
    _ = model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

使用 TokenLearner 培训和评估 ViT

vit_token_learner = create_vit_classifier()
run_experiment(vit_token_learner)
 157/157 ━━━━━━━━━━━━━━━━━━━━ 303s 2s/step - accuracy: 0.1158 - loss: 2.4798 - top-5-accuracy: 0.5352 - val_accuracy: 0.2206 - val_loss: 2.0292 - val_top-5-accuracy: 0.7688
 40/40 ━━━━━━━━━━━━━━━━━━━━ 5s 133ms/step - accuracy: 0.2298 - loss: 2.0179 - top-5-accuracy: 0.7723
Test accuracy: 22.9%
Test top 5 accuracy: 77.22%

实验结果


我们进行了实验,在我们实现的迷你 ViT 中使用和不使用 TokenLearner(超参数与本例中介绍的相同)。以下是我们的结果:

TokenLearner# tokens in
TokenLearner
Top-1 Acc
(Averaged across 5 runs)
GFLOPsTensorBoard
N-56.112%0.0184Link
Y856.55%0.0153Link
N-56.37%0.0184Link
Y456.4980%0.0147Link
N- (# Transformer layers: 8)55.36%0.0359Link

在没有模块的情况下,TokenLearner 的性能始终优于我们的迷你 ViT。同样有趣的是,它也能超越我们的迷你 ViT 的更深版本(有 8 层)。咱们以前也看到类似的观察结果,并将其归功于 TokenLearner 的适应性。

我们还应该注意到,随着令牌学习器模块的加入,FLOPs 数量大大减少。随着 FLOPs 数的减少,TokenLearner 模块能够提供更好的结果。这与作者的研究结果非常吻合。

此外,咱们还为较小的训练数据机制推出了更新版本的 TokenLearner。

该版本不再使用 4 个具有小通道的卷积层来实现空间注意力,而是使用 2 个具有更多通道的分组卷积层。它还使用了 softmax 而不是 sigmoid。我们证实,在训练数据有限的情况下,比如使用 ImageNet1K 从头开始训练时,该版本的效果更好。

我们对该模块进行了实验,并在下表中对实验结果进行了总结:

# Groups# TokensTop-1 AccGFLOPsTensorBoard
4454.638%0.0149Link
8854.898%0.0146Link
4855.196%0.0149Link

请注意,我们在本例中使用了相同的超参数。我们的实现可以在本笔记本中找到。我们承认,使用这个新的 TokenLearner 模块得出的结果比预期的略有偏差,这可能会随着超参数的调整而有所缓解。

(注:为了计算模型的 FLOPs,我们使用了这个软件库中的实用程序。)

参数数量


你可能已经注意到,添加 TokenLearner 模块会增加基础网络的参数数量。但这并不意味着效率会降低,正如 Dehghani 等人的研究所示。贝洛等人也报告了类似的发现。令牌学习器模块有助于减少整个网络的 FLOPS,从而有助于减少内存占用。

最终说明


TokenFuser论文作者还提出了另一个名为 TokenFuser 的模块。

该模块有助于将 TokenLearner 输出的表示重映射回其原始空间分辨率。

要在 ViT 架构中重复使用 TokenLearner,TokenFuser 是必不可少的。

我们首先从令牌学习器中学习令牌,从变换器层建立令牌的表示,然后将表示重映射到原始空间分辨率,这样令牌学习器就能再次使用它。

请注意,在整个 ViT 模型中,如果不与 TokenFuser 配对,只能使用一次 TokenLearner 模块。
在视频中使用这些模块:咱们还建议 TokenFuser 与 Vision Transformers for Videos(阿纳布等人)搭配使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/591247.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32单片机实战开发笔记-EXIT外部中断检测

嵌入式单片机开发实战例程合集: 链接:https://pan.baidu.com/s/11av8rV45dtHO0EHf8e_Q0Q?pwd28ab 提取码:28ab EXIT模块测试 功能描述 外部中断/事件控制器由19个产生事件/中断要求的边沿检测器组成。每个输入线可以独立地配置输入类型&a…

政安晨:【Keras机器学习示例演绎】(三十三)—— 知识提炼

目录 设置 构建 Distiller() 类 创建学生和教师模型 准备数据集 培训教师 将教师模型蒸馏给学生模型 从头开始训练学生进行比较 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战 希望政安晨的博客能够…

基于Spring Boot的医疗服务系统设计与实现

基于Spring Boot的医疗服务系统设计与实现 开发语言:Java框架:springbootJDK版本:JDK1.8数据库工具:Navicat11开发软件:eclipse/myeclipse/idea 系统部分展示 医疗服务系统首页界面图,公告信息、医疗地图…

快速幂笔记

快速幂即为快速求出一个数的幂&#xff0c;这样可以避免TLE&#xff08;超时&#xff09;的错误。 传送门&#xff1a;快速幂模板 前置知识&#xff1a; 1) 又 2) 代码&#xff1a; #include <bits/stdc.h> using namespace std; int quickPower(int a, int b) {int…

容斥原理以及Nim基础(异或,SG函数)

容斥原理&#xff1a; 容斥的复杂度为O&#xff08;2^m)&#xff0c;所以可以通过&#xff0c;对于实现&#xff0c;一共2^n-1种&#xff0c;我们可以用二进制来实现 下面是AC代码&#xff1a; #include<bits/stdc.h> using namespace std; typedef long long LL; cons…

uniapp 文字转语音(文字播报、语音合成)、震动提示插件 Ba-TTS

简介&#xff08;下载地址&#xff09; Ba-TTS 是一款uniapp语音合成&#xff08;tts&#xff09;插件&#xff0c;支持文本转语音&#xff08;无服务费&#xff09;&#xff0c;支持震动提示。 支持语音合成&#xff0c;文本转语音支持震动&#xff08;可自定义任意震动效果…

【云原生】Docker 实践(四):使用 Dockerfile 文件的综合案例

【Docker 实践】系列共包含以下几篇文章&#xff1a; Docker 实践&#xff08;一&#xff09;&#xff1a;在 Docker 中部署第一个应用Docker 实践&#xff08;二&#xff09;&#xff1a;什么是 Docker 的镜像Docker 实践&#xff08;三&#xff09;&#xff1a;使用 Dockerf…

【LinuxC语言】信号的基本概念与基本使用

文章目录 前言一、信号的概念二、信号的使用2.1 基本的信号类型2.2 signal函数 总结 前言 在Linux环境下&#xff0c;信号是一种用于通知进程发生了某种事件的机制。这些事件可能是由操作系统、其他进程或进程本身触发的。对于C语言编程者来说&#xff0c;理解信号的基本概念和…

36.Docker-Dockerfile自定义镜像

镜像结构 镜像是将应用程序及其需要的系统函数库、环境、配置、依赖打包而成。 镜像是分层机构&#xff0c;每一层都是一个layer BaseImage层&#xff1a;包含基本的系统函数库、环境变量、文件系统 EntryPoint:入口&#xff0c;是镜像中应用启动的命令 其他&#xff1a;在…

从0开始学习制作一个微信小程序 学习部分(6)组件与事件绑定

系列文章目录 学习篇第一篇我们讲了编译器下载&#xff0c;项目、环境建立、文件说明与简单操作&#xff1a;第一篇链接 第二、三篇分析了几个重要的配置json文件&#xff0c;是用于对小程序进行的切换页面、改变图标、控制是否能被搜索到等的操作第二篇链接、第三篇链接 第四…

QT的TcpServer

Server服务器端 QT版本5.6.1 界面设计 工程文件&#xff1a; 添加 network 模块 头文件引入TcpServer类和TcpSocket&#xff1a;QTcpServer和QTcpSocket #include <QTcpServer> #include <QTcpSocket>创建server对象并实例化&#xff1a; /*h文件中*/QTcpServer…

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统 系统功能 首页 图片轮播 宾馆信息 饮食美食 休闲娱乐 新闻资讯 论坛 留言板 登录注册 个人中心 后台管理 登录注册 个人中心 用户管理 客房登记管理 客房调整管理 休闲娱乐管理 类型信息管理 论坛管理 系统管理 新闻资讯…

Docker-Compose编排LNMP并部署WordPress

前言 随着云计算和容器化技术的快速发展&#xff0c;使用 Docker Compose 编排 LNMP 环境已经成为快速部署 Web 应用程序的一种流行方式。LNMP 环境由 Linux、Nginx、MySQL 和 PHP 组成&#xff0c;为运行 Web 应用提供了稳定的基础。本文将介绍如何通过 Docker Compose 编排 …

BUUCTF---misc---被偷走的文件

1、题目描述 2、下载附件&#xff0c;是一个流量包&#xff0c;拿去wireshark分析&#xff0c;依次点开流量&#xff0c;发现有个流量的内容显示flag.rar 3、接着在kali中分离出压缩包&#xff0c;使用下面命令&#xff0c;将压缩包&#xff0c;分离出放在out3文件夹中 4、在文…

Java -- (part21)

一.File类 1.概述 表示文件或者文件夹的路径抽象表示形式 2.静态成员 static String pathSeparator:路径分隔符:; static String separator:名称分隔符:\ 3.构造方法 File(String parent,String child) File(File parent,String child) Flie(String path) 4.方法 获…

在M1芯片安装鸿蒙闪退解决方法

在M1芯片安装鸿蒙闪退解决方法 前言下载鸿蒙系统安装完成后&#xff0c;在M1 Macos14上打开闪退解决办法接下来就是按照提示一步一步安装。 前言 重新安装macos系统后&#xff0c;再次下载鸿蒙开发软件&#xff0c;竟然发现打不开。 下载鸿蒙系统 下载地址&#xff1a;http…

Android使用kts上传aar到JitPack仓库

Android使用kts上传aar到JitPack 之前做过sdk开发&#xff0c;需要将仓库上传到maven、JitPack或JCenter,但是JCenter已停止维护&#xff0c;本文是讲解上传到JitPack的方式,使用KTS语法&#xff0c;记录使用过程中遇到的一些坑. 1.创建项目(library方式) 由于之前用鸿神的w…

每日OJ题_贪心算法二⑥_力扣409. 最长回文串

目录 力扣409. 最长回文串 解析代码 力扣409. 最长回文串 409. 最长回文串 难度 简单 给定一个包含大写字母和小写字母的字符串 s &#xff0c;返回 通过这些字母构造成的 最长的回文串 。 在构造过程中&#xff0c;请注意 区分大小写 。比如 "Aa" 不能当做一个…

【Java从入门到精通】Java继承

继承的概念 继承是java面向对象编程技术的一块基石&#xff0c;因为它允许创建分等级层次的类。 继承就是子类继承父类的特征和行为&#xff0c;使得子类对象&#xff08;实例&#xff09;具有父类的实例域和方法&#xff0c;或子类从父类继承方法&#xff0c;使得子类具有父…

Linux 麒麟系统安装

国产麒麟系统官网地址&#xff1a; https://www.openkylin.top/downloads/ 下载该镜像后&#xff0c;使用VMware部署一个虚拟机&#xff1a; 完成虚拟机创建。点击&#xff1a;“开启此虚拟机” 选择“试用试用开放麒麟而不安装&#xff08;T&#xff09;”&#xff0c;进入op…