政安晨:【Keras机器学习示例演绎】(十六)—— 用于图像分类的混合增强

目录

简介

设置

准备数据集

定义超参数

将数据转换为 TensorFlow 数据集对象

定义混合技术函数

可视化新的增强数据集

模型制作

1.使用混合数据集训练模型

2.在没有混合数据集的情况下训练模型

说明


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:利用混合技术对图像分类进行数据扩增。

简介


mixup 是由 Zhang 等人在 mixup.Beyond Empirical Risk Minimization 一书中提出的一种与领域无关的数据增强技术。

这项技术的名称相当系统。从字面上看,我们是在混合特征及其相应的标签。实施起来很简单。神经网络很容易记住错误的标签。mixup 通过将不同的特征相互组合(标签也是如此)来放松这一点,这样网络就不会对特征及其标签之间的关系过于自信。

当我们不确定是否要为给定数据集(例如医学影像数据集)选择一组增强变换时,mixup 特别有用。

设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# TF imports related to tf.data preprocessing
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma

准备数据集

在本例中,我们将使用 FashionMNIST 数据集。但同样的方法也可用于其他分类数据集。

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = keras.ops.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = keras.ops.one_hot(y_test, 10)

定义超参数

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

将数据转换为 TensorFlow 数据集对象

# Put aside a few samples to create our validation set
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we will be
# combining two shuffled datasets from the same training data.
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

定义混合技术函数


为了执行混合例程,我们使用同一数据集的训练数据创建新的虚拟数据集,并应用从 Beta 分布采样的 [0, 1] 范围内的 lambda 值,例如,new_x = lambda * x1 + (1 - lambda) * x2(其中 x1 和 x2 为图像),同样的等式也应用于标签。

def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
    # Unpack two datasets
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = keras.ops.shape(images_one)[0]

    # Sample lambda and reshape it to do the mixup
    l = sample_beta_distribution(batch_size, alpha, alpha)
    x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
    y_l = keras.ops.reshape(l, (batch_size, 1))

    # Perform mixup on both images and labels by combining a pair of images/labels
    # (one from each dataset) into one image/label
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)

请注意,这里我们是将两幅图像合并为一幅理论上,我们可以组合任意多的图像,但这样做会增加计算成本。在某些情况下,这可能也无助于提高性能。

可视化新的增强数据集

# First create the new dataset using our `mix_up` utility
train_ds_mu = train_ds.map(
    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
    num_parallel_calls=AUTO,
)

# Let's preview 9 samples from the dataset
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    print(label.numpy().tolist())
    plt.axis("off")

演绎展示:

[0.0, 0.9964277148246765, 0.0, 0.0, 0.003572270041331649, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.9794676899909973, 0.02053229510784149, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.9536369442939758, 0.0, 0.0, 0.0, 0.04636305570602417, 0.0]
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7631776928901672, 0.0, 0.0, 0.23682232201099396]
[0.0, 0.0, 0.045958757400512695, 0.0, 0.0, 0.0, 0.9540412425994873, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 2.8015051611873787e-08, 0.0, 0.0, 1.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0003173351287841797, 0.0, 0.9996826648712158, 0.0, 0.0, 0.0, 0.0]

模型制作

def get_training_model():
    model = keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model

为了提高可重复性,我们将浅层网络的初始随机权重序列化。

initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")

1.使用混合数据集训练模型

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/10
  62/907 ━[37m━━━━━━━━━━━━━━━━━━━  2s 3ms/step - accuracy: 0.2518 - loss: 2.2072

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699655923.381468   16749 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

 907/907 ━━━━━━━━━━━━━━━━━━━━ 13s 9ms/step - accuracy: 0.5335 - loss: 1.4414 - val_accuracy: 0.7635 - val_loss: 0.6678
Epoch 2/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 12s 4ms/step - accuracy: 0.7168 - loss: 0.9688 - val_accuracy: 0.7925 - val_loss: 0.5849
Epoch 3/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - accuracy: 0.7525 - loss: 0.8940 - val_accuracy: 0.8290 - val_loss: 0.5138
Epoch 4/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.7742 - loss: 0.8431 - val_accuracy: 0.8360 - val_loss: 0.4726
Epoch 5/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.7876 - loss: 0.8095 - val_accuracy: 0.8550 - val_loss: 0.4450
Epoch 6/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8029 - loss: 0.7794 - val_accuracy: 0.8560 - val_loss: 0.4178
Epoch 7/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.8039 - loss: 0.7632 - val_accuracy: 0.8600 - val_loss: 0.4056
Epoch 8/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7465 - val_accuracy: 0.8510 - val_loss: 0.4114
Epoch 9/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7364 - val_accuracy: 0.8645 - val_loss: 0.3983
Epoch 10/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8182 - loss: 0.7237 - val_accuracy: 0.8630 - val_loss: 0.3735
 157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8610 - loss: 0.4030
Test accuracy: 85.82%

2.在没有混合数据集的情况下训练模型

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Notice that we are NOT using the mixed up dataset here
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.5690 - loss: 1.1928 - val_accuracy: 0.7585 - val_loss: 0.6519
Epoch 2/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 2ms/step - accuracy: 0.7525 - loss: 0.6484 - val_accuracy: 0.7860 - val_loss: 0.5799
Epoch 3/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - accuracy: 0.7895 - loss: 0.5661 - val_accuracy: 0.8205 - val_loss: 0.5122
Epoch 4/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.8148 - loss: 0.5126 - val_accuracy: 0.8415 - val_loss: 0.4375
Epoch 5/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.8306 - loss: 0.4636 - val_accuracy: 0.8610 - val_loss: 0.3913
Epoch 6/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - accuracy: 0.8433 - loss: 0.4312 - val_accuracy: 0.8680 - val_loss: 0.3734
Epoch 7/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.8544 - loss: 0.4072 - val_accuracy: 0.8750 - val_loss: 0.3606
Epoch 8/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.8577 - loss: 0.3913 - val_accuracy: 0.8735 - val_loss: 0.3520
Epoch 9/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.8645 - loss: 0.3803 - val_accuracy: 0.8725 - val_loss: 0.3536
Epoch 10/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8686 - loss: 0.3597 - val_accuracy: 0.8745 - val_loss: 0.3395
 157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.8705 - loss: 0.3672
Test accuracy: 86.92%

我们鼓励读者在不同领域的不同数据集上试用 mixup,并尝试使用 lambda 参数。我们强烈建议您同时查看原始论文--作者介绍了几项关于 mixup 的消融研究,展示了 mixup 如何提高泛化效果,并展示了他们将两幅以上图像合并为一幅图像的结果。

说明


× 通过混合,您可以创建合成示例,尤其是在缺乏大型数据集的情况下,而不会产生高昂的计算成本。
× 标签平滑和 mixup 通常不能很好地结合使用,因为标签平滑已经对硬标签进行了一定程度的修改。
× 在使用有监督对比学习(SCL)时,mixup 不能很好地发挥作用,因为 SCL 在预训练阶段就需要真实标签。
× mixup 的其他一些优势还包括(如论文所述)对对抗性示例的鲁棒性和稳定的生成对抗网络(GAN)训练。
× 有许多数据增强技术可以扩展 mixup,如 CutMix 和 AugMix。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/574713.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

BGP的基本配置

l 按照以下步骤配置BGP协议: 第1步:设备基本参数配置,AS内配置IGP确保内部网络连通性; l 配置IGP(OSPF协议等)路由解决peer对等体的源和目标IP之间连通性,确保peer之间TCP(179&a…

SpringBoot学习之Redis下载安装启动【Mac版本】(三十七)

一、下载Redis 1、下载地址:Downloads - Redis 往下滑,找到Downloads区域,这里有若干版本,这里我们选择了7.0的稳定版本 2、我们下载的是redis-7.0.15.tar.gz,这是一个压缩包,我们双击解压这个压缩包,可以得到如下文件 二、安装Redis 1、我们进入redis根目录安装mak…

IDEA使用技巧(常用设置、快捷键等)

IDEA使用技巧 一、IDEA常用基本设置设置代码背景颜色/主题/字体Ctrl鼠标滚轮缩放字体大小设置字符编码左右两侧的Project,Structure,Maven等按钮消失新增类似sout,psvm的模版切换某个模块编译的JDK版本 二、常用快捷键CtrlAltT包裹代码Alt回车联想补全Ct…

Qt : 禁用控件默认的鼠标滚轮事件

最近在写一个模拟器,在item中添加了很多的控件,这些控件默认是支持鼠标滚动事件的。在数据量特别大的时候,及容易不小心就把数据给修改了而不自知。所有,我们这里需要禁用掉这些控件的鼠标滚轮事件。 实现的思想很简单&#xff0c…

数据结构篇其二---单链表(C语言+超万字解析)

目录 前言: 一、顺序表的缺点和链表的引入。 二、链表概述 实现一个简单的链表 空链表的概念 三、链表的功能实现 链表的打印 链表节点的创建 链表的头插(自上而下看完分析,相信你会有所收获) 头插的前置分析 传值调用和…

一、路由基础

1.路由协议的优先级 路由器分别定义了外部优先级和内部优先级(越小越优) 路由选择顺序:外部优先级>>内部优先级(相同时) ①外部优先级:用户可以手工为各路由协议配置的优先级 ②内部优先级&#xf…

汽车信息安全--如何理解TrustZone(2)

目录 1.概述 2 如何切换安全状态 3 TrustZone里实现了什么功能? 4. 与HSM的比较 1.概述 汽车信息安全--如何理解TrustZone(1)-CSDN博客讲解了什么是Trustzone,下面我们继续讲解与HSM的区别。 2 如何切换安全状态 在引入安全扩展后,Arm…

太速科技-多路PCIe的阵列计算全国产化服务器

多路PCIe的阵列计算全国产化服务器 多路PCIe的阵列计算全国产化服务器以国产化处理器(海光、飞腾ARM、算能RSIC V)为主板,扩展6-8路PCIe3.0X4计算卡; 计算卡为全国产化的AI处理卡(瑞星微ARM,算能AI&#x…

MMSeg搭建模型的坑

Input type(torch.suda.FloatTensor) and weight type (torch.FloatTensor) should be same 自己搭建模型的时候,经常会遇到二者不匹配,以这种情况为例,是因为部分模型没有加载到CUDA上面造成的。 注意搭建模型的时候,所有层都应…

Oracle之RMAN联机和脱机备份(二)

rman脱机备份,首先使用rman登入数据库服务器,然后关闭数据库后,启动数据库到mount状态,在执行backup database指定备份整个数据库。 1、启动mount归档模式 sys@ORCL>archive log list; Database log mode Archive Mode Automatic archival Enabl…

STM32H750外设ADC之开始和结束数据转换功能

目录 概述 1 开始转换 1.1 使能ADSTART 1.2 使能JADSTART 1.3 ADSTART 通过硬件清零 2 转换时序 3 停止正在进行的转换( ADSTP、 JADSTP) 3.1 停止转换功能实现 3.2 停止转换流程图 概述 本文主要讲述了STM32H750外设ADC之开始和结束数据转换…

SPRD Android 14 通过属性控制系统设置显示双栏或者单栏

SPRD Android 14 通过属性控制系统设置显示双栏或者单栏 第一步 确认有添加静态库第二步 验证第三步 修改源码在合适的地方配置 ro.product.is_support_SettingsSplitEnabled 即可。第一步 确认有添加静态库 --- a/packages/apps/Settings/Android.bp +++ b/packages/apps/Set…

[集群聊天项目] muduo网络库

目录 网络服务器编程常用模型什么是muduo网络库什么是epoll muduo网络库服务器编程 网络服务器编程常用模型 【方案1】 : accept read/write 不是并发服务器 【方案2】 : accept fork - process-pre-connection 适合并发连接数不大,计算任…

Centos的一些基础命令

CentOS是一个基于开源代码构建的免费Linux发行版,它由Red Hat Enterprise Linux (RHEL) 的源代码重新编译而成。由于 CentOS是基于RHEL构建的,因此它与RHEL具有非常类似的特性和功能,包括稳定性、安全性和可靠性。并且大部分的 Linux 命令在C…

Apache Doris 2.x 版本【保姆级】安装+使用教程

Doris简介 Apache Doris 是一个基于 MPP 架构的高性能、实时的分析型数据库,以极速易用的特点被人们所熟知,仅需亚秒级响应时间即可返回海量数据下的查询结果,不仅可以支持高并发的点查询场景,也能支持高吞吐的复杂分析场景。基于…

COOIS 生产订单显示系统增强

需求说明:订单系统显示页面新增批量打印功能 增强点:CL_COIS_DISP_LIST_NAVIGATION -->TOOLBAR方法中新增隐式增强添加自定义打印按钮 增强点:BADI-->WORKORDER_INFOSYSTEM新增增强实施 实现位置:IF_EX_WORKORDER_INFOSYS…

频裂变加群推广强制分享引流源码

视频裂变加群推广强制分享引流源码,用户达到观看次数后需要分享给好友或者群,好友必须点击推广链接后才会增加观看次数。 引导用户转发QV分享,达到快速裂变引流的效果! 视频裂变推广程序,强制分享链接,引导用户转发,…

数据库MySQL的初级基础操作

文章目录 1. 介绍2. 数据库相关概念3. 启动4. 数据模型5. SQL6. DDL数据库DDL-表操作DDL-表操作-数据类型DDL-表操作-修改DDL-表操作-删除 7. 图形化界面工具DataGrip8. DML(数据操作语言)DML-添加数据DML-修改数据 9. DQL(数据查询语言)基本查询条件查询…

MemFire案例-政务应急物联网实时监测预警项目

客户背景 党的十八大以来,中央多次就应急管理工作做出重要指示:要求坚持以防为主、防抗救相结合,全面提升综合防灾能力;坚持生命至上、安全第一,完善安全生产责任制,坚决遏制重特大安全事故。 面对新形势…

小白学习SpringCloud之Eureka

前言 需要搭建springcloud项目,eureka是其中的一个模块,依赖主要继承父依赖 学习视频:b站狂神说 便于理解,我修改了本地域名》这里!!! 127.0.0.1 eureka7001.com 127.0.0.1 eureka7002.com 127.0.0.1 eureka7003.comEureka入门案例 eureka…