Tensorflow入门实战 T05-运动鞋识别

目录

一、完整代码

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

(4)训练结果的精确度

三、遇到的问题


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这篇博客的主要内容是,关于运动鞋的识别。

运动鞋数据集包含训练集和测试集,共578张。

一、完整代码

from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")

# 导入数据集
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p05_sport/sport_data"
data_dir = pathlib.Path(data_dir)  # 打印文件夹目录
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:",image_count)
roses = list(data_dir.glob('train/nike/*.jpg'))
result = PIL.Image.open(str(roses[0]))
# result.show()

# 数据预处理
batch_size = 32
img_height = 224
img_width = 224

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_dir = str(data_dir) + "/train"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
test_dir = str(data_dir) + "/test/"
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)  # 打印结果: ['adidas', 'nike']

# 可视化数据
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])

        plt.axis("off")
plt.show()

# 检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (32, 224, 224, 3)
    print(labels_batch.shape)  # (32,)
    break

# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# 搭建神经网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Dropout(0.3),
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),

    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(len(class_names))  # 输出层,输出预期结果
])

# model.summary()  # 打印网络结构


# 设置动态学习率
# 设置初始学习率
initial_learning_rate = 0.1

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])


from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 50

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

# 模型训练
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

# 模型评估图
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')

from PIL import Image
import numpy as np

# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  # 这里选择你需要预测的图片
img = Image.open(str(data_dir) + "/test/nike/1.jpg")  # 这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) # /255.0  # 记得做归一化处理(与训练集处理方式保持一致)

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

模型早听了。

(4)训练结果的精确度

三、遇到的问题

在进行字符串拼接的时候,出现unsupported operand type(s) for +: 'PosixPath' and 'str'

查了下相关资料,添加str( )就可以。

原因:因为前面的data_dir 是经过pathlib.Path() 处理的。

添加str( ) 完美解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/725739.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker环境离线安装

Docker环境离线安装 下载下列.deb包 sudo *.deb

【PyQt5】python桌面级应用开发:PyQt5介绍,开发环境搭建快速入门

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

天津媒体邀约,及媒体名单?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 媒体宣传加速季,100万补贴享不停,一手媒体资源,全国100城线下落地执行。详情请联系胡老师。 天津作为中国北方的重要城市,拥有丰富的媒体资…

Jenkins+K8s实现持续集成(二)

部署前呢,要先把jenkins搭建好。 同时呢已经有了k8s的环境。 基于以上两步已经有了的情况,继续要实现jenkinsk8s持续集成,需要先准备四个文件: Dockerfile首先要准备好一个Dockerfile文件,用于构建Docker镜像的文本…

最新版本IntelliJ IDEA安装与“坤活”使用

最新版本IntelliJ IDEA安装与“科学”使用 IntelliJ IDEA安装与坤活下载安装坤活idea1.将下面两个压缩文件解压到安装位置,注意路径不要包含中文空格等特殊符号2.双击 install-all-users.vbs ,然后点击确定,等到出现 Done的弹窗3. 打开idea复…

函数依赖集等价、最小函数依赖集

一、函数依赖集等价 1、定义 假设F、G为一个关系模式上的两个函数依赖集,若,则称F和G是等价的,也可称F和G 互相覆盖。 2、判断 (1)引理3: 的充分必要条件是且 (2)两步走&…

密码学及其应用——为什么选择接近的质数因子对RSA加密算法不安全?

RSA加密算法是一种广泛使用的非对称加密算法,它的安全性依赖于大整数分解的难度。具体来说,RSA算法生成的公钥包含一个大整数N,这是两个大质数p和q的乘积。然而,如果这两个质数p和q太接近,则可以相对容易地对N进行因式…

Study--Oracle-04-SQL练习

一、SQL语句思维导图 二、SQL练习 -- 以employee_id 为排序,列出前5个人 -- FETCH select employee_id,first_name from employees order by employee_id FETCH FIRST 5 rows only; -- 以employee_id 为排序,从第6个人开始 到第10个人 -- offset …

计算机组成原理---Cache的基本工作原理习题

对应知识点: Cache的基本原理 1.某存储系统中,主存容量是Cache容量的4096倍,Cache 被分为 64 个块,当主存地址和Cache地址采用直接映射方式时,地址映射表的大小应为()(假设不考虑一致维护和替…

【Redis】如何保证缓存和数据库的一致性

目录 背景问题思路 三个经典的缓存模式Cache-Aside读缓存写缓存为什么是删除旧缓存而不是更新旧缓存?为什么不先删除旧的缓存,然后再更新数据库? 延迟双删如何确保原子性 Read-Through/Write-ThroughRead-ThroughWrite-Through Write Behind …

Ubuntu22.04 下安装Curl库

1. apt 安装: sudo apt-get install curl 2. 官网压缩包: 下载地址:curl downloads wget https://curl.haxx.se/download/curl-7.78.0.tar.gz tar -xzvf curl-7.78.0.tar.gz cd curl-7.78.0 ./configure --with-openssl make sudo make i…

ubuntu系统上快速直接获取ip地址

文章目录 前言总结 前言 ubuntu系统上查看ip地址 $ hostname -I示例输出: 192.168.1.10总结 作者:加辣椒了吗? 简介:憨批大学生一枚,喜欢在博客上记录自己的学习心得,也希望能够帮助到你们!

基于若依的ruoyi-nbcio流程管理系统增加所有任务功能(一)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码: h…

CMake个人理解和使用

100编程书屋_孔夫子旧书网 前言 CMake是一个构建工具,通过它可以很容易创建跨平台的项目。通常使用它构建项目要分两步,通过源代码生成工程文件,通过工程文件构建目标产物(可能是动态库,静态库,也可能是可执行程序)。使用CMake的一个主要优势是在多平台或者多人协作的…

Python学习打卡:day10

day10 笔记来源于:黑马程序员python教程,8天python从入门到精通,学python看这套就够了 目录 day1073、文件的读取操作文件的操作步骤open()打开函数mode常用的三种基础访问模式读操作相关方法read()方法readlines()方法readline()方法for循…

Golang | Leetcode Golang题解之第168题Excel表列名称

题目&#xff1a; 题解&#xff1a; func convertToTitle(columnNumber int) string {ans : []byte{}for columnNumber > 0 {columnNumber--ans append(ans, Abyte(columnNumber%26))columnNumber / 26}for i, n : 0, len(ans); i < n/2; i {ans[i], ans[n-1-i] ans[n…

VSCode 安装NeoVim扩展(详细)

目录 1、安装NeoVim扩展 2、windows安装Neovim软件 3、优化操作相关的配置&#xff1a; 5、Neovim最好的兼容性配置 6、技巧和特点 6.1 故障排除 6.2、Neovim 插件组合键设置 6.3、跳转列表 1、安装NeoVim扩展 在扩展商店搜索NeoVim&#xff0c;安装扩展 2、windows安装…

吉时利Keithley2602B数字源表

吉时利Keithley2602B数字源表 2601B、2602B、2604B 系统 Sourcemeter SMU 仪器 2601B、2602B 和 2604B 系统 Sourcemeter SMU 仪器为 40W DC / 200W 脉冲 SMU&#xff0c;支持 10A 脉冲&#xff0c;3A 至 100fA 和 40V 至 100nV DC。它们将精密电源、实际电流源、6 位数字万用…

[Linux] 文件系统

UNIX操作系统将文件组织成一个有层次的树形结构&#xff1a; 标准目录&#xff1a; 根目录&#xff1a; /tmp目录 主目录&#xff1a; 这就是主目录 一般与系统有关的信息都存放在etc目录下 注意&#xff1a; /etc/passwd存放的是用户账户信息&#xff0c;不是密码信息&#xf…