第32周:猴痘病识别(Tensorflow实战第四周)

目录

前言

一、前期工作

1.1 设置GPU

1.2 导入数据

1.3 查看数据

二、数据预处理

2.1 加载数据

2.2 可视化数据

2.3 再次检查数据

2.4 配置数据集

2.4.1 基本概念介绍

2.4.2.代码完成

三、构建CNN网络

四、编译

五、训练模型

六、模型评估

6.1 Loss和Accuracy图

6.2 指定图片进行预测

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

1)本周任务:基于CNN模型完成对猴痘病图片的识别

2)运行环境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 设置GPU

代码如下:

# 一、前期准备
# 1.1 设置GPU
from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")
print(gpus)

1.2 导入数据

代码如下:

# 1.2 导入数据
data_dir = "./4-data/"
data_dir = pathlib.Path(data_dir)

1.3 查看数据

代码如下:

# 1.3 查看数据
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
Monkeypox = list(data_dir.glob('Monkeypox/*.jpg'))
PIL.Image.open(str(Monkeypox[0]))

输出:

图片总数为: 2142

二、数据预处理

2.1 加载数据

使用image_dataset_from_directory方法(详细可参考文章tf.keras.preprocessing.image_dataset_from_directory() 简介_tf.python.keras preprocessing在哪里-CSDN博客)将磁盘中的数据加载到tf.data.Dataset中。

测试集与验证集的关系:

  • 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  • 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

# 二、数据预处理
# 2.1 加载数据
batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

输出如下:

Found 2142 files belonging to 2 classes.
Using 1714 files for training.

​Found 2142 files belonging to 2 classes.
Using 428 files for validation.

['Monkeypox', 'Others']

2.2 可视化数据

代码如下:

# 2.2 可视化数据
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

输出:

2.3 再次检查数据

代码如下:

# 2.3再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(32, 224, 224, 3)
(32,)

情况说明:

  • Image_batch是形状的张量(32,180,180,3)。这是一批形状180x180x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

2.4 配置数据集

2.4.1 基本概念介绍

prefetch():CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态

然后使用prefetch()可显著减少空闲时间:

2.4.2.代码完成

代码如下:

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

         卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

网络结构图如下:

代码如下:

# 三、构建CNN网络
num_classes = 4
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),  # 让神经元以一定的概率停止工作,防止过拟合,提高模型的泛化能力。

    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(num_classes)  # 输出层,输出预期结果
])
model.summary()  # 打印网络结构

模型结构打印如下:

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

代码如下:

# 四、编译
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

代码如下:

# 五、训练模型
epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

训练过程打印如下:

六、模型评估

6.1 Loss和Accuracy图

代码如下:

# 六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

训练结果可视化如下:

6.2 指定图片进行预测

代码如下:

# 6.2 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np
# img = Image.open("./4-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./4-data/Others/NM15_02_11.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])
img_array = tf.expand_dims(image, 0)
predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

输出:

预测结果为: Others


总结

本周在上周的基础上增加了对指定图片进行预测的任务,并实现了正确预测,也让我更加熟练CNN模型搭建的流程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/922838.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【创建型设计模式】工厂模式

【创建型设计模式】工厂模式 创建型设计模式第二期!本期介绍简单工厂模式和工厂方法模式。 简单工厂模式 简单工厂模式(又叫作静态工厂方法模式),其属于创建型设计模式,简单工厂模式不属于设计模式中的 23 种经典模…

【Linux】安装cuda

一、安装nvidia驱动 # 添加nvidia驱动ppa库 sudo add-apt-repository ppa:graphics-drivers/ppa sudo apt update# 查找推荐版本 sudo ubuntu-drivers devices# 安装推荐版本 sudo apt install nvidia-driver-560# 检验nvidia驱动是否安装 nvidia-smi 二、安装cudatoolkit&…

上天入地 灵途科技光电技术赋能空间感知

近来,人工智能技术频频亮相各大马拉松赛事,成为引人注目的科技亮点。 11月3日,杭州马拉松首次启用了机器狗作为配速员,以稳定的节奏为选手提供科学的跑步节奏。 11月11日,亦庄半程马拉松的终点处,人形机器…

Java三大特性:封装、继承、多态【详解】

封装 定义 隐藏对象的属性和实现细节,仅对外公开接口,控制在程序中属性的读取和修改的访问级别便是封装。 在开发中造一个类就是封装,有时也会说封装一个类。封装可以隐藏一些细节或者包含数据不能被随意修改。 比如这是一个敏感的数据&a…

40分钟学 Go 语言高并发:【实战】并发安全的配置管理器(功能扩展)

【实战】并发安全的配置管理器(功能扩展) 一、扩展思考 分布式配置中心 实现配置的集中管理支持多节点配置同步实现配置的版本一致性 配置加密 敏感配置的加密存储配置的安全传输访问权限控制 配置格式支持 支持YAML、TOML等多种格式配置格式自动…

【ChatGPT大模型开发调用】如何获得 OpenAl API Key?

如何获取 OpenAI API Key 获取 OpenAI API Key 主要有以下三种途径: OpenAI 官方平台 (推荐): 开发者用户可以直接在 OpenAI 官方网站 (platform.openai.com) 注册并申请 API Key。 通常,您可以在账户设置或开发者平台的相关页面找到申请入口。 Azure…

苹果系统中利用活动监视器来终止进程

前言 苹果系统使用的时候总是感觉不太顺手。特别是转圈的彩虹球出现的时候,就非常令人恼火。如何找到一个像Windows那样任务管理器来终止掉进程呢? 解决办法 Commandspace 弹出搜索框吗,如下图: 输入“活动”进行搜索&#xff…

实战项目负载均衡式在线 OJ

> 作者:დ旧言~ > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:能自己实现负载均衡式在线 OJ。 > 毒鸡汤:有些事情,总是不明白,所以我不会坚持。早安! > 专栏选自&#xff1…

python Flask指定IP和端口

from flask import Flask, request import uuidimport json import osapp Flask(__name__)app.route(/) def hello_world():return Hello, World!if __name__ __main__:app.run(host0.0.0.0, port5000)

burp suite-1

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…

【Spring boot】微服务项目的搭建整合swagger的fastdfs和demo的编写

文章目录 1. 微服务项目搭建2. 整合 Swagger 信息3. 部署 fastdfsFastDFS安装环境安装开始图片测试FastDFS和nginx整合在Storage上安装nginxnginx安装不成功排查:4. springboot 整合 fastdfs 的demodemo编写1. 微服务项目搭建 版本总结: spring boot: 2.6.13springfox-boot…

【区块链】深入理解椭圆曲线密码学(ECC)

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 深入理解椭圆曲线密码学(ECC)1. 概述2. 椭圆曲线的数学基础2.1 基本定义2.2 有限…

【Qt流式布局改造支持任意位置插入和删除】

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、源代码二、删除代码三、扩展总结 前言 最近在做一个需求需要流式布局,虽然官方example里有一个流式布局范例,但是不能满足我的需求…

JQuery -- 第九课

文章目录 前言一、JQuery是什么?二、JQuery的使用步骤1.引入2.书写位置3. 表示方法 三、JQuery选择器1.层级选择器2. 筛选选择器3. 排他思想4. 精品展示 四、jQuery样式操作1. 修改样式2.类操作1. 添加2. 移除3. 切换 五、jQuery动画1. 显示和隐藏2. 滑动1. slide2.…

Python 版本的 2024详细代码

2048游戏的Python实现 概述: 2048是一款流行的单人益智游戏,玩家通过滑动数字瓷砖来合并相同的数字,目标是合成2048这个数字。本文将介绍如何使用Python和Pygame库实现2048游戏的基本功能,包括游戏逻辑、界面绘制和用户交互。 主…

在Elasticsearch中,是怎么根据一个词找到对应的倒排索引的?

大家好,我是锋哥。今天分享关于【在Elasticsearch中,是怎么根据一个词找到对应的倒排索引的?】面试题。希望对大家有帮助; 在Elasticsearch中,是怎么根据一个词找到对应的倒排索引的? 在 Elasticsearch 中…

C# 数据结构之【图】C#图

1. 图的概念 图是一种重要的数据结构,用于表示节点(顶点)之间的关系。图由一组顶点和连接这些顶点的边组成。图可以是有向的(边有方向)或无向的(边没有方向),可以是加权的&#xff…

Mac 系统上控制台常用性能查看命令

一、top命令显示 在macOS的控制台中,top命令提供了系统当前运行的进程的详细信息以及整体系统资源的利用情况。下面是对输出中各个字段的解释: Processes: 483 total: 系统上总共有483个进程。 2 running: 当前有2个进程正在运行。 481 sleeping: 当前有…

Docker--通过Docker容器创建一个Web服务器

Web服务器 Web服务器,一般指网站服务器,是驻留于因特网上某种类型计算机的程序。 Web服务器可以向浏览器等Web客户端提供文档,也可以放置网站文件以供全世界浏览,或放置数据文件以供全世界下载。 Web服务器的主要功能是提供网上…

Linux网络——NAT/代理服务器

一.NAT技术 1.NAT IP转换 之前我们讨论了, IPv4 协议中, IP 地址数量不充足的问题,NAT 技术就是当前解决 IP 地址不够用的主要手段, 是路由器的一个重要功能。 NAT 能够将私有 IP 对外通信时转为全局 IP. 也就是一种将私有 IP 和全局IP 相互转化的技术方法: 很…