政安晨:【Keras机器学习实践要点】(二十九)—— 半监督图像分类使用具有SimCLR对比性预训练的方法

目录

介绍

半监督学习

对比学习

设置

超参数设置

数据集

图像增强

编码器结构

有监督基线模型

用于对比预训练的自我监督模型

对预训练编码器进行有监督微调

与基准线的比较

进一步改进

架构

超参数

相关工作


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用SimCLR的对比预训练方法进行STL-10数据集的半监督图像分类。

介绍

半监督学习

半监督学习是一种处理部分标记数据集的机器学习范式。在实际应用深度学习时,通常需要收集大量数据集以使其良好运行。然而,标记成本与数据集大小成线性关系(标记每个示例的时间是恒定的),而模型性能只与数据集大小成亚线性关系。这意味着标记越来越多的样本成本效益越来越低,而收集未标记的数据通常便宜,因为通常有大量可获得的未标记数据。

半监督学习通过只要求部分标记数据集并通过利用未标记示例进行学习来解决这个问题。

在这个例子中,我们将在STL-10半监督数据集上使用对比学习在没有任何标签的情况下预训练一个编码器,然后仅使用其标记的子集进行微调。

对比学习

从最高层面上讲,对比学习的主要思想是以自监督的方式学习对图像增强不变的表示。这个目标的一个问题是它有一个微不足道的退化解:表示是常数,不依赖输入图像。

对比学习通过以下方式修改目标来避免这个陷阱它将同一图像的增强版本/视图的表示拉近(收缩正样本),同时将不同的图像在表示空间中推开(对比负样本)。

一种这样的对比方法是SimCLR,它本质上识别出了优化这个目标所需的核心组件,并且通过扩展这种简单方法可以实现高性能。

另一种方法是SimSiam(Keras示例),它与SimCLR的主要区别在于前者在损失函数中不使用任何负样本。

因此,它并没有明确地避免微不足道的解,而是通过架构设计隐含地避免它(在最后几层使用具有预测网络和批归一化(BatchNorm)的非对称编码路径)。

设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"


# Make sure we are able to handle large datasets
import resource

low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import ops
from keras import layers

超参数设置

# Dataset hyperparameters
unlabeled_dataset_size = 100000
labeled_dataset_size = 5000
image_channels = 3

# Algorithm hyperparameters
num_epochs = 20
batch_size = 525  # Corresponds to 200 steps per epoch
width = 128
temperature = 0.1
# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
classification_augmentation = {
    "min_area": 0.75,
    "brightness": 0.3,
    "jitter": 0.1,
}

数据集

在训练过程中,我们将同时加载一个大批未标记的图像和一个较小的标记图像的批次。

def prepare_dataset():
    # Labeled and unlabeled samples are loaded synchronously
    # with batch sizes selected accordingly
    steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
    unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
    labeled_batch_size = labeled_dataset_size // steps_per_epoch
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )

    # Turning off shuffle to lower resource usage
    unlabeled_train_dataset = (
        tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * unlabeled_batch_size)
        .batch(unlabeled_batch_size)
    )
    labeled_train_dataset = (
        tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * labeled_batch_size)
        .batch(labeled_batch_size)
    )
    test_dataset = (
        tfds.load("stl10", split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    # Labeled and unlabeled datasets are zipped together
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, labeled_train_dataset, test_dataset


# Load STL10 dataset
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

演绎展示:

batch size is 500 (unlabeled) + 25 (labeled)

图像增强

对于对比学习,最重要的两种图像增强方式如下:

裁剪:通过RandomTranslation和RandomZoom图层,迫使模型对同一图像的不同部分进行相似编码。

颜色抖动:通过扭曲颜色直方图,防止基于颜色直方图的简单解决方案。在颜色空间中使用仿射变换是一种实现的原则性方法。

在这个例子中,我们还使用了随机水平翻转。对于对比学习,我们会应用更强的增强方式,并且对于监督分类,我们会使用更弱的增强方式,以避免对少量标记示例的过拟合。

我们将随机颜色抖动作为自定义的预处理层进行实施。使用数据增强的预处理层具有以下两个优点:

数据增强将在GPU上以批处理方式运行,因此在具有受限CPU资源的环境中(如Colab Notebook或个人机器)训练不会受到数据管道的瓶颈限制。

部署更加简单,因为数据预处理流程被封装在模型中,不需要在部署时重新实现。

# Distorts the color distibutions of images
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)

        self.seed_generator = keras.random.SeedGenerator(1337)
        self.brightness = brightness
        self.jitter = jitter

    def get_config(self):
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config

    def call(self, images, training=True):
        if training:
            batch_size = ops.shape(images)[0]

            # Same for all colors
            brightness_scales = 1 + keras.random.uniform(
                (batch_size, 1, 1, 1),
                minval=-self.brightness,
                maxval=self.brightness,
                seed=self.seed_generator,
            )
            # Different for all colors
            jitter_matrices = keras.random.uniform(
                (batch_size, 1, 3, 3), 
                minval=-self.jitter, 
                maxval=self.jitter,
                seed=self.seed_generator,
            )

            color_transforms = (
                ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
                * brightness_scales
                + jitter_matrices
            )
            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
        return images


# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    zoom_factor = 1.0 - math.sqrt(min_area)
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
        ]
    )


def visualize_augmentations(num_images):
    # Sample a batch from a dataset
    images = next(iter(train_dataset))[0][0][:num_images]

    # Apply augmentations
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
    )
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(row_titles[row], loc="left")
            plt.axis("off")
    plt.tight_layout()


visualize_augmentations(num_images=8)

编码器结构

# Define the encoder architecture
def get_encoder():
    return keras.Sequential(
        [
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )

有监督基线模型

使用随机初始化训练基线监督模型。

# Baseline supervised training with random initialization
baseline_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),
        get_encoder(),
        layers.Dense(10),
    ],
    name="baseline_model",
)
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)

演绎展示:

Epoch 1/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 9s 25ms/step - acc: 0.2031 - loss: 2.1576 - val_acc: 0.3234 - val_loss: 1.7719
Epoch 2/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.3476 - loss: 1.7792 - val_acc: 0.4042 - val_loss: 1.5626
Epoch 3/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4060 - loss: 1.6054 - val_acc: 0.4319 - val_loss: 1.4832
Epoch 4/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.4347 - loss: 1.5052 - val_acc: 0.4570 - val_loss: 1.4428
Epoch 5/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.4600 - loss: 1.4546 - val_acc: 0.4765 - val_loss: 1.3977
Epoch 6/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4754 - loss: 1.4015 - val_acc: 0.4740 - val_loss: 1.4082
Epoch 7/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4901 - loss: 1.3589 - val_acc: 0.4761 - val_loss: 1.4061
Epoch 8/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5110 - loss: 1.2793 - val_acc: 0.5247 - val_loss: 1.3026
Epoch 9/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5298 - loss: 1.2765 - val_acc: 0.5138 - val_loss: 1.3286
Epoch 10/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5514 - loss: 1.2078 - val_acc: 0.5543 - val_loss: 1.2227
Epoch 11/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5520 - loss: 1.1851 - val_acc: 0.5446 - val_loss: 1.2709
Epoch 12/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5851 - loss: 1.1368 - val_acc: 0.5725 - val_loss: 1.1944
Epoch 13/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.5738 - loss: 1.1411 - val_acc: 0.5685 - val_loss: 1.1974
Epoch 14/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 21ms/step - acc: 0.6078 - loss: 1.0308 - val_acc: 0.5899 - val_loss: 1.1769
Epoch 15/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.6284 - loss: 1.0386 - val_acc: 0.5863 - val_loss: 1.1742
Epoch 16/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.6450 - loss: 0.9773 - val_acc: 0.5849 - val_loss: 1.1993
Epoch 17/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6547 - loss: 0.9555 - val_acc: 0.5683 - val_loss: 1.2424
Epoch 18/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6593 - loss: 0.9084 - val_acc: 0.5990 - val_loss: 1.1458
Epoch 19/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6672 - loss: 0.9267 - val_acc: 0.5685 - val_loss: 1.2758
Epoch 20/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6824 - loss: 0.8863 - val_acc: 0.5969 - val_loss: 1.2035
Maximal validation accuracy: 59.90%

用于对比预训练的自我监督模型

我们使用对比损失对无标签图像进行编码器预训练。编码器顶部安装了一个非线性投影头,因为它能提高编码器的表示质量。

我们使用的是 InfoNCE/NT-Xent/N-pairs loss(对比度损失),可以用以下方式解释:

1. 我们将批次中的每幅图像都视为有自己的类别。
2. 然后,我们为每个 "类 "提供两个示例(一对增强视图)。
3. 每个视图的表示都要与每对可能的视图(两个增强版本)进行比较。
4. 我们使用比较表征的温标余弦相似度作为对数。
5. 最后,我们使用分类交叉熵作为 "分类 "损失

以下两个指标用于监测预训练性能:

对比准确率(SimCLR表格5):自监督指标,表示在当前批次中,图像表示与其不同增强版本的表示相比,比与任何其他图像的表示更相似的情况的比例。即使没有标记的样本,自监督指标也可以用于超参数调整。

线性探测准确率:线性探测是评估自监督分类器的常用指标。它是在编码器特征之上训练的逻辑回归分类器的准确率。在我们的情况下,这是通过在冻结的编码器之上训练一个单独的密集层来实现的。

需要注意的是,与传统方法不同,其中分类器是在预训练阶段之后训练的,在这个例子中,我们在预训练期间训练它。这可能会略微降低它的准确性,但这样我们可以在训练过程中监测它的值,这有助于实验和调试。

另一个广泛使用的有监督指标是KNN准确率,它是在编码器特征之上训练的KNN分类器的准确率,该指标在此示例中未实施。

# Define the contrastive model with model-subclassing
class ContrastiveModel(keras.Model):
    def __init__(self):
        super().__init__()

        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        self.encoder = get_encoder()
        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # Single dense layer for linear probing
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)],
            name="linear_probe",
        )

        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss will be defined as a method
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

    @property
    def metrics(self):
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]

    def contrastive_loss(self, projections_1, projections_2):
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = ops.normalize(projections_1, axis=1)
        projections_2 = ops.normalize(projections_2, axis=1)
        similarities = (
            ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views
        batch_size = ops.shape(projections_1)[0]
        contrastive_labels = ops.arange(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, ops.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, ops.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data

        # Both labeled and unlabeled images are used, without labels
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        # Labels are only used in evalutation for an on-the-fly logistic regression
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        with tf.GradientTape() as tape:
            # the encoder is used in inference mode here to avoid regularization
            # and updating the batch normalization paramers if they are used
            features = self.encoder(preprocessed_images, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        labeled_images, labels = data

        # For testing the components are used with a training=False flag
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        # Only the probe metrics are logged at test time
        return {m.name: m.result() for m in self.metrics[2:]}


# Contrastive pretraining
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

pretraining_history = pretraining_model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)

演绎展示:

演绎展示:

Epoch 1/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 34s 134ms/step - c_acc: 0.0880 - c_loss: 5.2606 - p_acc: 0.1326 - p_loss: 2.2726 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.2579 - val_p_loss: 2.0671
Epoch 2/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.2808 - c_loss: 3.6233 - p_acc: 0.2956 - p_loss: 2.0228 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3440 - val_p_loss: 1.9242
Epoch 3/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 136ms/step - c_acc: 0.4097 - c_loss: 2.9369 - p_acc: 0.3671 - p_loss: 1.8674 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3876 - val_p_loss: 1.7757
Epoch 4/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 30s 142ms/step - c_acc: 0.4893 - c_loss: 2.5707 - p_acc: 0.3957 - p_loss: 1.7490 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3960 - val_p_loss: 1.7002
Epoch 5/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 136ms/step - c_acc: 0.5458 - c_loss: 2.3342 - p_acc: 0.4274 - p_loss: 1.6608 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4374 - val_p_loss: 1.6145
Epoch 6/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.5949 - c_loss: 2.1179 - p_acc: 0.4410 - p_loss: 1.5812 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4444 - val_p_loss: 1.5439
Epoch 7/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.6273 - c_loss: 1.9861 - p_acc: 0.4633 - p_loss: 1.5076 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4695 - val_p_loss: 1.5056
Epoch 8/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.6566 - c_loss: 1.8668 - p_acc: 0.4817 - p_loss: 1.4601 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4790 - val_p_loss: 1.4566
Epoch 9/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.6726 - c_loss: 1.7938 - p_acc: 0.4885 - p_loss: 1.4136 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4933 - val_p_loss: 1.4163
Epoch 10/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.6931 - c_loss: 1.7210 - p_acc: 0.4954 - p_loss: 1.3663 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5140 - val_p_loss: 1.3677
Epoch 11/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 137ms/step - c_acc: 0.7055 - c_loss: 1.6619 - p_acc: 0.5210 - p_loss: 1.3376 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5155 - val_p_loss: 1.3573
Epoch 12/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 30s 145ms/step - c_acc: 0.7215 - c_loss: 1.6112 - p_acc: 0.5264 - p_loss: 1.2920 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5232 - val_p_loss: 1.3337
Epoch 13/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 31s 146ms/step - c_acc: 0.7279 - c_loss: 1.5749 - p_acc: 0.5388 - p_loss: 1.2570 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5217 - val_p_loss: 1.3155
Epoch 14/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7435 - c_loss: 1.5196 - p_acc: 0.5505 - p_loss: 1.2507 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5460 - val_p_loss: 1.2640
Epoch 15/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 40s 135ms/step - c_acc: 0.7477 - c_loss: 1.4979 - p_acc: 0.5653 - p_loss: 1.2188 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5594 - val_p_loss: 1.2351
Epoch 16/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.7598 - c_loss: 1.4463 - p_acc: 0.5590 - p_loss: 1.1917 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5551 - val_p_loss: 1.2411
Epoch 17/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.7633 - c_loss: 1.4271 - p_acc: 0.5775 - p_loss: 1.1731 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5502 - val_p_loss: 1.2428
Epoch 18/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7666 - c_loss: 1.4246 - p_acc: 0.5752 - p_loss: 1.1805 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5633 - val_p_loss: 1.2167
Epoch 19/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.7708 - c_loss: 1.3928 - p_acc: 0.5814 - p_loss: 1.1677 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5665 - val_p_loss: 1.2191
Epoch 20/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7806 - c_loss: 1.3733 - p_acc: 0.5836 - p_loss: 1.1442 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5640 - val_p_loss: 1.2172
Maximal validation accuracy: 56.65%

对预训练编码器进行有监督微调

我们通过在其顶部附加一个单独随机初始化的全连接分类层,对编码器进行标记示例的微调。

# Supervised finetuning of the pretrained encoder
finetuning_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),
        pretraining_model.encoder,
        layers.Dense(10),
    ],
    name="finetuning_model",
)
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

finetuning_history = finetuning_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)
Epoch 1/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 5s 18ms/step - acc: 0.2104 - loss: 2.0930 - val_acc: 0.4017 - val_loss: 1.5433
Epoch 2/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4037 - loss: 1.5791 - val_acc: 0.4544 - val_loss: 1.4250
Epoch 3/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4639 - loss: 1.4161 - val_acc: 0.5266 - val_loss: 1.2958
Epoch 4/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5438 - loss: 1.2686 - val_acc: 0.5655 - val_loss: 1.1711
Epoch 5/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5678 - loss: 1.1746 - val_acc: 0.5775 - val_loss: 1.1670
Epoch 6/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6096 - loss: 1.1071 - val_acc: 0.6034 - val_loss: 1.1400
Epoch 7/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6242 - loss: 1.0413 - val_acc: 0.6235 - val_loss: 1.0756
Epoch 8/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6284 - loss: 1.0264 - val_acc: 0.6030 - val_loss: 1.1048
Epoch 9/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6491 - loss: 0.9706 - val_acc: 0.5770 - val_loss: 1.2818
Epoch 10/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6754 - loss: 0.9104 - val_acc: 0.6119 - val_loss: 1.1087
Epoch 11/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - acc: 0.6620 - loss: 0.8855 - val_acc: 0.6323 - val_loss: 1.0526
Epoch 12/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - acc: 0.7060 - loss: 0.8179 - val_acc: 0.6406 - val_loss: 1.0565
Epoch 13/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - acc: 0.7252 - loss: 0.7796 - val_acc: 0.6135 - val_loss: 1.1273
Epoch 14/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7176 - loss: 0.7935 - val_acc: 0.6292 - val_loss: 1.1028
Epoch 15/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7322 - loss: 0.7471 - val_acc: 0.6266 - val_loss: 1.1313
Epoch 16/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7400 - loss: 0.7218 - val_acc: 0.6332 - val_loss: 1.1064
Epoch 17/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7490 - loss: 0.6968 - val_acc: 0.6532 - val_loss: 1.0112
Epoch 18/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7491 - loss: 0.6879 - val_acc: 0.6403 - val_loss: 1.1083
Epoch 19/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7802 - loss: 0.6504 - val_acc: 0.6479 - val_loss: 1.0548
Epoch 20/20
 200/200 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - acc: 0.7800 - loss: 0.6234 - val_acc: 0.6409 - val_loss: 1.0998
Maximal validation accuracy: 65.32%

与基准线的比较

# The classification accuracies of the baseline and the pretraining + finetuning process:
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
        plt.figure(figsize=(8, 5), dpi=100)
        plt.plot(
            baseline_history.history[f"val_{metric_key}"],
            label="supervised baseline",
        )
        plt.plot(
            pretraining_history.history[f"val_p_{metric_key}"],
            label="self-supervised pretraining",
        )
        plt.plot(
            finetuning_history.history[f"val_{metric_key}"],
            label="supervised finetuning",
        )
        plt.legend()
        plt.title(f"Classification {metric_name} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {metric_name}")


plot_training_curves(pretraining_history, finetuning_history, baseline_history)

通过比较训练曲线,我们可以看到在使用对比预训练时,可以达到更高的验证准确率,并伴随着更低的验证损失,这意味着预训练网络在只看到少量标记样本时能够更好地泛化。

进一步改进

架构

原始论文中的实验证明,增加模型的宽度和深度比监督学习的改进效果更高。此外,在文献中使用ResNet-50编码器是相当常见的。然而,请记住,更强大的模型不仅会增加训练时间,还需要更多的内存,并且会限制您可以使用的最大批量大小。

据报道,使用BatchNorm层有时可能会降低性能,因为它在样本之间引入了一个批内依赖关系,这就是为什么我在这个示例中没有使用。然而,在我的实验中,尤其是在投影头部使用BatchNorm可以提高性能。

超参数

在这个示例中使用的超参数是针对这个任务和架构手动调整过的。因此,如果不改变它们,进一步的超参数调整只能带来较小的收益。

然而,对于不同的任务或模型架构,这些参数需要进行调整,因此以下是我对其中最重要的几个参数的注释:

批处理大小:由于目标可以被解释为对一批图像进行分类(粗略地说),批处理大小实际上比通常更重要的超参数。批处理大小越大越好。

温度:温度定义了在交叉熵损失中使用的softmax分布的“软度”,是一个重要的超参数。较小的值通常会导致更高的对比精度。最近的一个技巧(在ALIGN中)是学习温度的值(可以通过将其定义为tf.Variable,并对其应用梯度来实现)。尽管这提供了一个良好的基准值,但在我的实验中,学习得到的温度稍低于最优值,因为它是针对对比损失进行优化的,而对比损失并不是表示质量的完美代理。

图像增强强度:在预训练期间,强化增强会增加任务的难度,但是在一定程度上,强化增强太强会降低性能。在微调期间,强化增强可以减少过拟合,但是根据我的经验,强化增强太强会降低预训练带来的性能提升。整个数据增强流程可以看作是算法的一个重要超参数,在这个存储库中可以找到Keras中其他自定义图像增强层的实现。

学习率调度:这里使用的是恒定的调度,但在文献中使用余弦衰减调度是很常见的,它可以进一步提高性能。

优化器:在这个例子中使用的是Adam,因为它在默认参数下提供了良好的性能。带有动量的SGD需要更多的调整,但可能会略微提高性能。

相关工作

其他实例级(图像级)对比学习方法:

MoCo(v2,v3):同样使用动量编码器,其权重是目标编码器的指数移动平均值

SwAV:使用聚类而不是成对比较

BarlowTwins:使用基于交叉相关的目标而不是成对比较 MoCo和BarlowTwins的Keras实现可以在此存储库中找到,其中包括一个Colab笔记本。

还有一种新的方法,优化类似的目标,但不使用任何负样本

BYOL:动量编码器+无负样本

SimSiam(Keras示例):无动量编码器+无负样本

根据我的经验,这些方法更加脆弱(它们可能会崩溃为常数表示,我无法使用此编码器架构使它们正常工作)。尽管它们通常更依赖于模型架构,但它们可以在较小的批量大小下提高性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/541803.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Guava里一些比较常用的工具

随着java版本的更新提供了越来越多的语法和工具来简化日常开发,但是我们一般用的比较早的版本所以体验不到。这时就用到了guava这个包。guava提供了很多方便的工具方法,solar框架就依赖了guava的16.0.1版本,这里稍微介绍下。 一、集合工具类…

20240413,类和对象

对象:一切都可为对象,类:相同特性的对象;面向对象特性:封装,继承,多态 一,封装 CLASS 类名 { 访问权限 :属性/行为 } 实例化:通过一个类,创建一…

0101tomcat部署war访问mysql失败-容器间通信-docker项目部署

文章目录 一、简介二、部署1、mysql数据迁移2、docker部署redis3、docker部署tomcat并运行war包 三、报错四、解决1 分析2 解决 结语 一、简介 最近参与开发一个项目,其中一部分系统需要迁移。从阿里云迁移到实体服务器,使用docker部署。系统使用Java语…

11 Php学习:函数

PHP 内建函数Array 函数 PHP Array 函数是 PHP 核心的组成部分。无需安装即可使用这些函数。 创建 PHP 函数 当您需要在 PHP 中封装一段可重复使用的代码块时,可以使用函数。下面详细解释如何创建 PHP 函数并举例说明。 创建 PHP 函数的语法 PHP 函数的基…

TCP-IP详解卷一:协议——阅读总结

该内容适合程序员查看 第1章 概述 1.1 引言 WAN全称是 Wide Area Network,中文名为广域网。 LAN全称是 Local Area Network,中文名为局域网。 1.2分层 ICP/IP协议族通常被认为是一个四层协议系统 分层协议应用层Telnet、FTP和e-mail运输层TCP和UDP网…

shell 调用钉钉通知

使用场景:机器能访问互联网,运行时间任务后通知使用 钉钉建立单人群 手机操作,只能通过手机方式建立单人群 电脑端 2. 配置脚本 #!/bin/bash set -e## 上图中 access_token字段 TOKEN KEYWORDhello # 前文中设置的关键字 function call_…

机器学习——自动驾驶

本章我们主要学习以下内容: 阅读自动驾驶论文采集数据根据论文搭建自动驾驶神经网络训练模型在仿真环境中进行自动驾驶 论文介绍 本文参考自2016年英伟达发表的论文《End to End Learning for Self-Driving Cars》 📎end2end.pdf

Redis从入门到精通(十七)多级缓存(二)Lua语言入门、OpenResty集群的安装与使用

文章目录 前言6.4 Lua语法入门6.4.1 初识Lua6.4.2 Hello World6.4.3 变量6.4.3.1 Lua的数据类型6.4.3.2 声明变量 6.4.4 循环6.4.5 函数6.4.6 条件控制 6.5 实现多级缓存6.5.1 安装和启动OpenResty6.5.2 实现ajax请求反向代理至OpenResty集群6.5.2.1 反向代理配置6.5.2.2 OpenR…

chromium 协议栈 cronet ios 踩坑案例

1、请求未携带 Accept-Language http header 出现图片加载失败 现象: 访问 https://www.huawei.com/cn/?ic_mediumdirect&ic_sourcesurlent 时出现图片加载失败的问题 预期结果: 原因: 网络库删除了添加 Accept-Language header 的逻…

openssl 如何从pfx格式证书 获取证书序列号信息

已知:一个个人证书文件 test.pfx 求:如何通过openssl查看其对应证书的序列号信息? 踩坑之:unable to load certificate! openssl x509 -in xxx.cert -noout -serial 命令可查看证书序列号,但是这个-in 的输入必须是私…

Android开发:Camera2+MediaRecorder录制视频后上传到阿里云VOD

文章目录 版权声明前言1.Camera1和Camera2的区别2.为什么选择Camera2? 一、应用Camera2MediaPlayer实现拍摄功能引入所需权限构建UI界面的XMLActivity中的代码部分 二、在上述界面录制结束后点击跳转新的界面进行视频播放构建播放界面部分的XMLActivity的代码上述代…

基于Python的深度学习的中文情感分析系统(V2.0),附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

2024年MathorCup数学建模B题甲骨文智能识别中原始拓片单字自动分割与识别研究解题文档与程序

2024年第十四届MathorCup高校数学建模挑战赛 B题 甲骨文智能识别中原始拓片单字自动分割与识别研究 原题再现: 甲骨文是我国目前已知的最早成熟的文字系统,它是一种刻在龟甲或兽骨上的古老文字。甲骨文具有极其重要的研究价值,不仅对中国文…

【学习心得】神经网络知识中的符号解释②

我在上篇文章中初步介绍了一些神经网络中的符号,只有统一符号及其对应的含义才能使我自己在后续的深度学习中有着一脉相承的体系。如果对我之前的文章感兴趣可以点击链接看看哦: 【学习心得】神经网络知识中的符号解释①http://t.csdnimg.cn/f6PeJ 一、…

010、Python+fastapi,第一个后台管理项目走向第10步:ubutun 20.04下安装ngnix+mysql8+redis5环境

一、说明 先吐槽一下,ubuntu 界面还是不习惯,而且用的是云电脑,有些快捷键不好用,只能将就,谁叫我们穷呢? 正在思考怎么往后进行,突然发现没安装mysql 和redis,准备安装&#xff0…

组合预测 | Matlab实现ICEEMDAN-SMA-SVM基于改进完备集合经验模态分解-黏菌优化算法-支持向量机的时间序列预测

组合预测 | Matlab实现ICEEMDAN-SMA-SVM基于改进完备集合经验模态分解-黏菌优化算法-支持向量机的时间序列预测 目录 组合预测 | Matlab实现ICEEMDAN-SMA-SVM基于改进完备集合经验模态分解-黏菌优化算法-支持向量机的时间序列预测预测效果基本介绍程序设计参考资料预测效果 基本…

kali工具----域名IP及路由跟踪

域名IP及路由跟踪 测试网络范围内的IP地址或域名也是渗透测试的一个重要部分。通过测试网络范围内的IP地址或域名,确定是否有人入侵自己的网络中并损害系统。不少单位选择仅对局部IP基础架构进行渗透测试,但从现在的安全形势来看,只有对整个I…

未来计算机的发展趋势是什么?

未来计算机的发展趋势是多方面的,涵盖了硬件、软件、体系结构以及计算范式等多个层面。以下是一些预期的趋势: 1. 量子计算: 随着量子理论的不断成熟和技术的进步,量子计算机将可能解决传统计算机难以处理的问题,比如药物发现、材料科学、复杂系统模拟等领域。量子计算的…

UT单元测试

Tips:在使用时一定要注意版本适配性问题 一、Mockito 1.1 Mock的使用 Mock 的中文译为仿制的,模拟的,虚假的。对于测试框架来说,即构造出一个模拟/虚假的对象,使我们的测试能顺利进行下去。 Mock 测试就是在测试过程…

SQL语言自用(持续更新)+实验记录

课本:《数据库原理及其应用教程》(第四版) (主编)黄德才&(副主编)陆亿红 实验:学校实验课材料 其他: [ ]表示可以被删除,也表示可以被替换,请自行判断。如果有一些截图或照片,是暂时懒得整…