第33周:运动鞋识别(Tensorflow实战第五周)

目录

前言

一、前期工作

1.1 设置GPU

1.2 导入数据

1.3 查看数据

二、数据预处理

2.1 加载数据

2.2 可视化数据

2.3 再次检查数据

2.4 配置数据集

2.4.1 基本概念介绍

2.4.2 代码完成

三、构建CNN网络

四、训练模型

4.1 设置动态学习率

4.2 早停与保存最佳模型参数

4.3 模型训练

五、模型评估

5.1 Loss和Accuracy图

5.2 指定图片进行预测

六、改进

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

1)本周任务:基于CNN模型完成对两种品牌运动鞋图片的识别

2)运行环境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 设置GPU

代码如下:

# 一、前期准备
# 1.1 设置GPU
from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")
print(gpus)

1.2 导入数据

代码如下:

# 1.2 导入数据
data_dir = "./data/"
data_dir = pathlib.Path(data_dir)

1.3 查看数据

代码如下:

# 1.3 查看数据
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:", image_count)
shoses = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(shoses[0]))

输出:

图片总数为: 578

二、数据预处理

2.1 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset,tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模块中的一个函数,用于从目录中创建一个图像数据集(dataset)。这个函数可以以更方便的方式加载图像数据,用于训练和评估神经网络模型

测试集与验证集的关系:

  • 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  • 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
  • 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

# 二、数据预处理
# 2.1 加载数据集
batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/test/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

输出如下:

Found 502 files belonging to 2 classes.

Found 76 files belonging to 2 classes.

['adidas', 'nike']

2.2 可视化数据

代码如下:

# 2.2 可视化数据
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

输出:

2.3 再次检查数据

代码如下:

# 2.3再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(32, 224, 224, 3)
(32,)

情况说明:

  • Image_batch是形状的张量(32,2224,224,3)。这是一批形状224x224x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

2.4 配置数据集

2.4.1 基本概念介绍

prefetch():CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态

​​

然后使用prefetch()可显著减少空闲时间:

​​

2.4.2 代码完成

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

代码如下:

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

         卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

网络结构图如下:

​​

代码如下:

# 三、构建CNN网络
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Dropout(0.3),
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),
    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(len(class_names))  # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

模型结构打印如下:

四、训练模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

4.1 设置动态学习率

ExponentialDecay函数tf.keras.optimizers.schedules.ExponentialDecay是 TensorFlow 中的一个学习率衰减策略,用于在训练神经网络时动态地降低学习率。学习率衰减是一种常用的技巧,可以帮助优化算法更有效地收敛到全局最小值,从而提高模型的性能

🔍主要参数:

  • decay_steps(衰减步数):学习率衰减的步数。在经过 decay_steps 步后,学习率将按照指数函数衰减。例如,如果 decay_steps 设置为 10,则每10步衰减一次。
  • decay_rate(衰减率):学习率的衰减率。它决定了学习率如何衰减。通常,取值在 0 到 1 之间。
  • staircase(阶梯式衰减):一个布尔值,控制学习率的衰减方式。如果设置为 True,则学习率在每个 decay_steps 步之后直接减小,形成阶梯状下降。如果设置为 False,则学习率将连续衰减

代码如下:

initial_learning_rate = 0.0001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

⚠️⚠️⚠️注:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

Tips学习率大与学习率小的优缺点分析

  • 学习率大:优点--加快学习速率,有助于跳出局部最优值;缺点--导致模型训练不收敛,单单使用大学习率容易导致模型不精确;
  • 学习率小:优点--有助于模型收敛、模型细化,提高模型精度;缺点--很难跳出局部最优值,收敛缓慢

4.2 早停与保存最佳模型参数

🔍EarlyStopping()参数说明:

  • monitor:被监测的数据
  • min_delta:在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
  • patience:没有进步的训练轮数,在这之后训练就会被停止。
  • verbose:详细信息模式。
  • mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
  • baseline:要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
  • estoe_best_weights:是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重

代码如下:

# 4.2 早停与保存最佳模型参数
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
epochs = 50
# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)
# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

4.3 模型训练

代码如下:

# 4.3 模型训练
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

打印训练过程:

五、模型评估

5.1 Loss和Accuracy图

代码如下:

# 五、模型评估
# 5、1 Loss与Accuracy图
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(loss))
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

训练结果可视化如下:

5.2 指定图片进行预测

代码如下:

# 5.2 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np

# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./data/test/nike/1.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])
img_array = tf.expand_dims(image, 0)/255.0  # 记得做归一化处理(与训练集处理方式保持一致)
predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

输出:

预测结果为: addidas

六、改进

从训练过程和指定预测的结果可以发现效果并不好,这里考虑修改initial_learning_rate,原始设置的是0.1,这里考虑将其调至0.001或0.0001


总结

本周在上周的基础上增加了动态学习率的设置,优化了训练过程,并实现了正确预测,也让我更加熟练基于Tensorflow框架下CNN模型搭建的流程,并对动态学习率初始值设置对结果影响进行了分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/926858.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

云轴科技ZStack助力 “上科大智慧校园信创云平台”入选上海市2024年优秀信创解决方案

近日,为激发创新活⼒,促进信创⾏业⾼质量发展,由上海市经济信息化委会同上海市委网信办、上海市密码管理局、上海市国资委等主办的“2024年上海市优秀信创解决方案”征集遴选活动圆满落幕。云轴科技ZStack支持的“上科大智慧校园信创云平台”…

Linux 服务器使用指南:诞生与演进以及版本(一)

一、引言 在当今信息技术的浪潮中,Linux 操作系统无疑是一个关键的支柱😎。无论是在服务器管理、软件开发还是大数据处理领域,Linux 都以其卓越的适应性和优势脱颖而出👍。然而,对于许多新手而言,Linux 系统…

基于树莓派3B+的简易智能家居小项目(WiringPi库 + C语言开发)

github主页:https://github.com/snqx-lqh 本项目github地址:https://github.com/snqx-lqh/RaspberryPiSmartHome 硬件开源地址:https://oshwhub.com/from_zero/shu-mei-pai-kuo-zhan-ban 欢迎交流 树莓派智能家居项目,学习树莓派的…

YOLOv8-Pose NCNN安卓部署

YOLOv8-Pose NCNN安卓部署 前言 YOLOv8-Pose NCNN安卓部署 目前的帧率可以稳定在30帧左右,下面是这个项目的github地址:https://github.com/gaoxumustwin/ncnn-android-yolov8-pose 介绍 在做YOLOv8-Pose NCNN安卓部署的时候,在github上…

【热门主题】000077 物联网智能项目:开启智能未来的钥匙

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 【热…

【STM32学习】TB6612FNG驱动芯片的学习,驱动电路的学习

目录 1、TB6612电机驱动芯片 1.1如下是芯片的引脚图: 1.2如下图是电机的控制逻辑: 1.3MOS管运转逻辑 1.3典型应用电路 2、H桥驱动电路 2.1、单极模式 2.2、双极模式 2.3、高低端MOS管导通条件 2.4、H桥电路设计 2.5、自举电路 3、电气特性 3…

day01(Linux底层)基础知识

目录 导学 基础知识 1、Bootloader是什么 2、Bootloader的基本作用 3、入式中常见的Bootloader有哪些 4、Linux系统移植为什么要使用bootloader 5、uboot和Bootloader之间的关系 6.Uboot的获取 7、uboot版本命名 8、uboot版本选择 9、uboot的特点 10.Uboot使用 导学…

OpenFeign 服务调用

1.简介 微服务架构中使用 OpenFeign 进行服务调用, OpenFeign 提供了一种简洁的方式来定义和处理服 务间的调用。 OpenFeign 作为一个声明式的、模块化的 HTTP 客户端,通过 「接口」 的定义和 「注解」 的使用,简化了微服务之间的通信调用。…

superset load_examples加载失败解决方法

如果在执行load_examples命令后,出现上方图片情况,或是相似报错(url error\connection error),大概率原因是python程序请求github数据,无法访问. 因此我们可以将数据下载在本地来解决. 1.下载zip压缩文件,存放到本地 官方示例地址:GitHub - apache-superset/examples-data …

k8s集成skywalking

如果能科学上网的话,安装应该不难,如果有问题可以给我留言 本篇文章我将给大家介绍“分布式链路追踪”的内容,对于目前大部分采用微服务架构的公司来说,分布式链路追踪都是必备的,无论它是传统微服务体系亦或是新一代…

深度学习Python基础(2)

二 数据处理 一般来说PyTorch中深度学习训练的流程是这样的: 1. 创建Dateset 2. Dataset传递给DataLoader 3. DataLoader迭代产生训练数据提供给模型 对应的一般都会有这三部分代码 # 创建Dateset(可以自定义) dataset face_dataset # Dataset部分自定义过的…

【Git教程 之 安装】

Git教程 之 安装 Git教程 之 安装Windows系统安装gitLinux安装gitmacOS安装Git Git教程 之 安装 Windows系统安装git 首先从官网上直接下载安装 选择对应的操作系统,我电脑是Windows,所以我演示以Windows为主 点进去 点击Click here to download 下…

ElasticSearch学习记录

服务器操作系统版本:Ubuntu 24.04 Java版本:21 Spring Boot版本:3.3.5 如果打算用GUI,虚拟机安装Ubuntu 24.04,见 虚拟机安装Ubuntu 24.04及其常用软件(2024.7)_ubuntu24.04-CSDN博客文章浏览阅读6.6k次&#xff0…

【AI】数据,算力,算法和应用(3)

三、算法 算法这个词,我们都不陌生。 从接触计算机,就知道有“算法”这样一个神秘的名词存在。象征着专业、权威、神秘、高难等等。 算法是一组有序的解决问题的规则和指令,用于解决特定问题的一系列步骤。算法可以被看作是解决问题的方法…

Java代码操作Zookeeper(使用 Apache Curator 库)

1. Zookeeper原生客户端库存在的缺点 复杂性高:原生客户端库提供了底层的 API,需要开发者手动处理很多细节,如连接管理、会话管理、异常处理等。这增加了开发的复杂性,容易出错。连接管理繁琐:使用原生客户端库时&…

SAP SD学习笔记17 - 投诉处理3 - Credit/Debit Memo依赖,Credit/Debit Memo

上一章讲了 请求书(发票)的取消。 SAP SD学习笔记16 - 请求书的取消 - VF11-CSDN博客 再往上几章,讲了下图里面的返品传票: SAP SD学习笔记14 - 投诉处理1 - 返品处理(退货处理)的流程以及系统实操&#…

Flink--API 之Transformation-转换算子的使用解析

目录 一、常用转换算子详解 (一)map 算子 (二)flatMap 算子 (三)filter 算子 (四)keyBy 算子 元组类型 POJO (五)reduce 算子 二、合并与连接操作 …

qt QToolBox详解

1、概述 QToolBox是Qt框架中的一个控件,它提供了一个带标签页的容器,用户可以通过点击标签页标题来切换不同的页面。QToolBox类似于一个带有多页选项卡的控件,但每个“选项卡”都是一个完整的页面,而不仅仅是标签。这使得QToolBo…

MySQL 复合查询

实际开发中往往数据来自不同的表,所以需要多表查询。本节我们用一个简单的公司管理系统,有三张表EMP,DEPT,SALGRADE 来演示如何进行多表查询。表结构的代码以及插入的数据如下: DROP database IF EXISTS scott; CREATE database IF NOT EXIST…

分布式系统中的Dapper与Twitter Zipkin:链路追踪技术的实现与应用

目录 一、什么是链路追踪? 二、核心思想Dapper (一)Dapper链路追踪基本概念概要 (二)Trace、Span、Annotations Trace Span Annotation 案例说明 (三)带内数据与带外数据 带外数据 带…