第T9周:使用TensorFlow实现猫狗识别2

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、前期工作
      • 1.设置GPU(如果使用的是CPU可以忽略这步)
      • 2. 导入数据
    • 二、数据预处理
      • 1、加载数据
      • 2、再次检查数据
      • 3、可视化数据
      • 3、数据增强、配置数据集
      • 4、 显示数据增强后的数据
    • 三、构建CNN网络
    • 四、编译
    • 五、训练模型
    • 六、模型评估
    • 七、预测
    • 八、总结

电脑环境:
语言环境:Python 3.8.0
编译器:Jupyter Notebook
深度学习环境:tensorflow 2.15.0

一、前期工作

1.设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

import os,PIL,pathlib

#隐藏警告
import warnings
warnings.filterwarnings('ignore')

data_dir = "./365-7-data"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*')))

print("图片总数为:",image_count)

二、数据预处理

1、加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中。

batch_size = 64
img_height = 224
img_width = 224

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)

输出:

[‘cat’, ‘dog’]

2、再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(64, 224, 224, 3)
(64,)

3、可视化数据

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

3、数据增强、配置数据集

在上期的文章中,我们没有对数据进行数据增强,本次尝试数据增强改善模型性能。

AUTOTUNE = tf.data.AUTOTUNE

# 定义数据增强层
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomContrast(0.1)
])

def preprocess_image(image, label):
    image = image / 255.0
    image = data_augmentation(image)
    return image, label

def preprocess_val_image(image, label):
    image = image / 255.0
    return image, label

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
# 验证集不需要增强
val_ds   = val_ds.map(preprocess_val_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
  • RandomFlip:随机翻转,可以水平翻转(horizontal)和垂直翻转(vertical);
  • RandomRotation:随机旋转;
  • RandomZoom:随机缩放;
  • RandomContrast:随机对比度调整,增加或减少亮暗差异;

4、 显示数据增强后的数据

在这里插入图片描述

三、构建CNN网络

直接调用官方VGG16

from keras.applications import VGG16

model = VGG16(weights='imagenet')

model.summary()

四、编译

model.compile(optimizer="adam",
              loss     ='sparse_categorical_crossentropy',
              metrics  =['accuracy'])

五、训练模型

from tqdm import tqdm
import tensorflow.keras.backend as K

epochs = 10
lr     = 1e-4

# 记录训练数据,方便后面的分析
history_train_loss     = []
history_train_accuracy = []
history_val_loss       = []
history_val_accuracy   = []

for epoch in range(epochs):
    train_total = len(train_ds)
    val_total   = len(val_ds)
    
    """
    total:预期的迭代数目
    ncols:控制进度条宽度
    mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)
    """
    with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=1,ncols=100) as pbar:
        
        lr = lr*0.92
        K.set_value(model.optimizer.lr, lr)
        
        train_loss     = []
        train_accuracy = []
        for image,label in train_ds:   
            """
            训练模型,简单理解train_on_batch就是:它是比model.fit()更高级的一个用法

            想详细了解 train_on_batch 的同学,
            可以看看我的这篇文章:https://www.yuque.com/mingtian-fkmxf/hv4lcq/ztt4gy
            """
             # 这里生成的是每一个batch的acc与loss
            history = model.train_on_batch(image,label)
            
            train_loss.append(history[0])
            train_accuracy.append(history[1])
            
            pbar.set_postfix({"train_loss": "%.4f"%history[0],
                              "train_acc":"%.4f"%history[1],
                              "lr": K.get_value(model.optimizer.lr)})
            pbar.update(1)
            
        history_train_loss.append(np.mean(train_loss))
        history_train_accuracy.append(np.mean(train_accuracy))
            
    print('开始验证!')
    
    with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=0.3,ncols=100) as pbar:

        val_loss     = []
        val_accuracy = []
        for image,label in val_ds:      
            # 这里生成的是每一个batch的acc与loss
            history = model.test_on_batch(image,label)
            
            val_loss.append(history[0])
            val_accuracy.append(history[1])
            
            pbar.set_postfix({"val_loss": "%.4f"%history[0],
                              "val_acc":"%.4f"%history[1]})
            pbar.update(1)
        history_val_loss.append(np.mean(val_loss))
        history_val_accuracy.append(np.mean(val_accuracy))
            
    print('结束验证!')
    print("验证loss为:%.4f"%np.mean(val_loss))
    print("验证准确率为:%.4f"%np.mean(val_accuracy))
Epoch 1/20: 100%|███| 43/43 [00:56<00:00,  1.31s/it, train_loss=0.7041, train_acc=0.4531, lr=9.2e-5]
开始验证!
Epoch 1/20: 100%|██████████████████| 11/11 [00:02<00:00,  3.79it/s, val_loss=0.7103, val_acc=0.5000]
结束验证!
验证loss为:0.7073
验证准确率为:0.5085
Epoch 2/20: 100%|██| 43/43 [00:10<00:00,  4.30it/s, train_loss=0.6984, train_acc=0.5312, lr=8.46e-5]
开始验证!
Epoch 2/20: 100%|██████████████████| 11/11 [00:01<00:00,  8.82it/s, val_loss=0.6984, val_acc=0.5000]
结束验证!
验证loss为:0.6955
验证准确率为:0.5085
Epoch 3/20: 100%|██| 43/43 [00:09<00:00,  4.48it/s, train_loss=0.6942, train_acc=0.4688, lr=7.79e-5]
开始验证!
Epoch 3/20: 100%|██████████████████| 11/11 [00:01<00:00,  8.85it/s, val_loss=0.6934, val_acc=0.5000]
结束验证!
验证loss为:0.6942
验证准确率为:0.4915
Epoch 4/20: 100%|██| 43/43 [00:09<00:00,  4.33it/s, train_loss=0.6976, train_acc=0.4531, lr=7.16e-5]
开始验证!
Epoch 4/20: 100%|██████████████████| 11/11 [00:01<00:00,  9.86it/s, val_loss=0.6944, val_acc=0.5000]
结束验证!
验证loss为:0.6959
验证准确率为:0.5043
....................................................................................................
Epoch 20/20: 100%|| 43/43 [00:09<00:00,  4.49it/s, train_loss=0.3254, train_acc=0.7656, lr=1.89e-5]
开始验证!
Epoch 20/20: 100%|█████████████████| 11/11 [00:01<00:00,  9.75it/s, val_loss=0.3012, val_acc=0.9500]
结束验证!
验证loss为:0.1548
验证准确率为:0.9670

六、模型评估

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、预测

import numpy as np

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy())
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")

输出:

1/1 [==============================] - 0s 303ms/step
1/1 [==============================] - 0s 26ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step

在这里插入图片描述

八、总结

本次使用了自定义数据增强方式,对dataset进行操作,也可以使用数据增强生成器ImageDataGenerator进行数据增强。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/872097.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++题解】1053 - 求100+97+……+4+1的值。

欢迎关注本专栏《C从零基础到信奥赛入门级&#xff08;CSP-J&#xff09;》 问题&#xff1a;1053 - 求10097……41的值。 类型&#xff1a;简单循环 题目描述&#xff1a; 求 10097⋯41 的值。 输入&#xff1a; 无。 输出&#xff1a; 输出一行&#xff0c;即求到的和…

Linux--网络层 IP协议

目录 0.往期文章 1.IP基本概念 2. IP协议报头格式 3.网段划分 两种网段划分的方式 为什么要进行网段划分 4.特殊的IP 地址 5.IP 地址的数量限制 6.私有 IP 地址和公网 IP 地址*** NAT技术 认识公网 运营商扮演的角色 7.路由 8.16位标识&#xff0c;3为标志和13位…

加速自动驾驶模型迭代,数据存算一体是关键

自动驾驶的每一个业务阶段都会涉及到 AI 深度学习算法和算力的参与&#xff0c;机器视觉&#xff0c;深度学习&#xff0c;传感器技术等均在自动驾驶领域发挥着重要的作用。自动驾驶系统不断迭代的前提是算法的持续优化&#xff0c;目前&#xff0c;自动驾驶发展的瓶颈主要在于…

解决ubuntu22.04无法识别CH340/CH341和vscode espidf插件无法选择串口设备节点问题

文章目录 解决ubuntu22.04无法识别CH340/CH341和vscode espidf插件无法选择串口设备节点问题不识别CH340/CH341报错解决办法升级驱动编译安装 卸载brltty程序 vscode espidf插件无法选择串口设备节点问题解决办法编译安装 解决ubuntu22.04无法识别CH340/CH341和vscode espidf插…

坐标大连!提交EI、Scopus、知网检索!第五届经济管理与大数据应用国际学术会议(ICEMBDA 2024)

合作ACM出版-EI稳检索 高录用&#xff0c;快见刊&#xff01; 管理、经济、金融、计算机相关主题均可投稿 目前仍有口头汇报名额&#xff0c;如有需要请尽快报名 重要信息 会议官网&#xff1a;www.icembda.org 会议时间&#xff1a;2024年10月25日-27日 会议地点&#x…

【Python系列】方法返回2个值

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

spring security 入门基础,表单认证web页面跳转

一、导入所需依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.6.2</version></parent><!-- web 支持 --><dependency><groupId>…

FPGA 综合笔记

仿真时阻塞赋值和非阻塞赋值 Use of Non-Blocking Assignment in Testbench : Verilog Use of Non-Blocking Assignment in Testbench : Verilog - Stack Overflow non-blocking assignment does not work as expected in Verilog non-blocking assignment does not work a…

Linux云计算 |【第二阶段】SECURITY-DAY3

主要内容&#xff1a; Prometheus监控服务器、Prometheus被监控端、Grafana监控可视化 补充&#xff1a;Zabbix监控软件不自带LNMP和DB数据库&#xff0c;需要自行手动安装配置&#xff1b;Prometheus监控软件自带WEB页面和DB数据库&#xff1b;Prometheus数据库为时序数据库&…

adaptive AUTOSAR UCM模块中SoftwareCluster与Software Package是什么样的关系,他们分别包含哪些元素?

在自适应AUTOSAR(Adaptive AUTOSAR)的更新和配置管理(UCM)模块中,SoftwareCluster和Software Package是软件更新过程中的两个关键概念,它们之间有着密切的关系: SoftwareCluster:通常指的是一组功能相关的软件组件,它们共同实现了车辆中的一个或多个特定功能。在UCM中…

柔性织物处理 | 山大宋锐老师 | 最新演讲

笔者是清华在读研究生&#xff0c;主要关注人形机器人、具身智能。将持续分享行业前沿动态、学者观点整理、论文阅读笔记、知识学习路线等。欢迎交流 最近听了宋老师的演讲&#xff0c;以下是学习整理。部分图截自直播&#xff0c;若模糊望见谅 演讲信息&#xff1a; 【会议】…

Python | Leetcode Python题解之第365题水壶问题

题目&#xff1a; 题解&#xff1a; class Solution:def canMeasureWater(self, x: int, y: int, z: int) -> bool:if x y < z:return Falseif x 0 or y 0:return z 0 or x y zreturn z % math.gcd(x, y) 0

实现BeanPostProcessor

文章目录 1.实现初始化方法1.目录2.InitializingBean.java3.MonsterService.java 实现初始化接口4.SunSpringApplicationContext.java 调用初始化方法5.测试 2.实现后置处理器1.目录2.BeanPostProcessor.java 后置处理器接口3.SunBeanProcessor.java 自定义后置处理器4.SunSpri…

ZMQ请求应答模型

案例一 这个案例的出处是ZMQ的官网。请求段发送Hello&#xff0c;应答端回复World。 ZMQ Request(client) #include <string> #include <iostream> #include <zmq.hpp>using namespace std; using namespace zmq; // 使用 zmq 命名空间int main() {// ini…

PHP轻创推客集淘客地推任务平台于一体的综合营销平台系统源码

&#x1f680;轻创推客&#xff0c;营销新纪元 —— 集淘客与地推任务于一体的全能平台&#x1f310; &#x1f308;【开篇&#xff1a;营销新潮流&#xff0c;轻创推客引领未来】 在瞬息万变的营销世界里&#xff0c;你还在为寻找高效、全面的营销渠道而烦恼吗&#xff1f;&…

【STM32】C语言基础补充

学习过程中发现自己好些需要用到的C语言语法、特征都不太熟练了&#xff0c;特意记录一下&#xff0c;免得忘记了&#xff0c;以后遇到了新的也会继续更新 目录 1 全局变量 2 结构体 3 静态变量 4 memset()函数 5 使用8位的存储器存16位的数 1 全局变量…

汽车冷却液温度传感器

1、冷却液温度传感器的功能 发动机冷却液温度传感器&#xff0c;也称为ECT&#xff0c;是帮助保护发动机&#xff0c;提高发动机工作效率以及帮助发动机稳定运行的非常重要的传感器之一。 发动机冷却液温度 &#xff08;ECT&#xff09; 传感器用于测量发动机的冷却液温度&…

Vue项目创建和使用

快速上手 | Vue.js (vuejs.org) nodejs.org/ vue项目实质上是index.html页面和多个js文件的集合&#xff0c;最终解析后的html和js代码可以由浏览器解析运行&#xff1a; vue项目的创建&#xff0c;需要脚手架工具来搭建&#xff1b; 在编译的源码阶段&#xff0c;文件格式为.…

集团数字化转型方案(十二)

集团数字化转型方案致力于通过构建一个集成化的数字平台&#xff0c;全面应用大数据分析、人工智能、云计算和物联网等前沿技术&#xff0c;推动业务流程、管理模式和决策机制的全面升级。该方案将从业务流程的数字化改造开始&#xff0c;优化资源配置&#xff0c;提升运营效率…

MySQL的源码安装及基本部署(基于RHEL7.9)

这里源码安装mysql的5.7.44版本 一、源码安装 1.下载并解压mysql , 进入目录: wget https://downloads.mysql.com/archives/get/p/23/file/mysql-boost-5.7.44.tar.gz tar xf mysql-boost-5.7.44.tar.gz cd mysql-5.7.44/ 2.准备好mysql编译安装依赖: yum install cmake g…