TensorFlow2实战-系列教程3:猫狗识别1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、项目介绍

基本流程:

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

在我们的数据中,有训练和验证,训练集中分别有猫狗两个类别,都有1000张图像,验证集则有500张

2、数据读取

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  1. 导包
  2. 指定数据路径
  3. 训练数据路径
  4. 验证数据路径
  5. 训练数据猫类别路径
  6. 训练数据狗类别路径
  7. 验证数据猫类别路径
  8. 训练数据狗类别路径

3、构建卷积神经网络

model = tf.keras.models.Sequential([
    #如果训练慢,可以把数据设置的更小一些
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    
    #为全连接层准备
    tf.keras.layers.Flatten(),
    
    tf.keras.layers.Dense(512, activation='relu'),
    # 二分类sigmoid就够了
    tf.keras.layers.Dense(1, activation='sigmoid')
])

3个3x3卷积,穿插3个2x2池化,拉平操作,两个全连接层

model.summary()

打印一下模型架构:

配置训练器:

model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['acc'])

4、数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        train_dir,  # 文件夹路径
        target_size=(64, 64),  # 指定resize成的大小
        batch_size=20,
        # 如果one-hot就是categorical,二分类用binary就可以
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

打印结果:
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

5、模型训练

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=20,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

部分打印结果:
Epoch 1/20 100/100 - 9s - loss: 0.6909 - acc: 0.5240 - val_loss: 0.6952 - val_acc: 0.5000
Epoch 2/20 100/100 - 9s - loss: 0.6645 - acc: 0.5960 - val_loss: 0.6906 - val_acc: 0.5360

Epoch 19/20 100/100 - 9s - loss: 0.1750 - acc: 0.9460 - val_loss: 0.6277 - val_acc: 0.7390
Epoch 20/20 100/100 - 9s - loss: 0.1593 - acc: 0.9505 - val_loss: 0.5901 - val_acc: 0.7490

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

在这里插入图片描述
在这里插入图片描述
将训练损失、准确率和对应的epoch分别画图展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/351021.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring 的执行流程以及 Bean 的作用域和生命周期

文章目录 Bean 的作用域更改作用域的方式singletonprototype Spring 执行流程Bean 的生命周期 Bean 的作用域 Spring 容器在初始化⼀个 Bean 的实例时,同时会指定该实例的作用域。Bean 有6种作用域 singleton:单例作用域prototype:原型作用域…

Hadoop-MapReduce-MRAppMaster启动篇

一、源码下载 下面是hadoop官方源码下载地址&#xff0c;我下载的是hadoop-3.2.4&#xff0c;那就一起来看下吧 Index of /dist/hadoop/core 二、上下文 在上一篇<Hadoop-MapReduce-源码跟读-客户端篇>中已经将到&#xff1a;作业提交到ResourceManager&#xff0c;那…

首发:2024全球DAO组织发展研究

作者&#xff0c;张群&#xff08;专注DAO及区块链应用研究&#xff0c;赛联区块链教育首席讲师&#xff0c;工信部赛迪特邀资深专家&#xff0c;CSDN认证业界专家&#xff0c;微软认证专家&#xff0c;多家企业区块链产品顾问&#xff09; DAO&#xff08;去中心化自治组织&am…

adb测试冷启动和热启动 Permission Denial解决

先清理日志 adb shell logcat -c 打开手机模拟器中的去哪儿网&#xff0c;然后日志找到包名和MainActivity adb shell logcat |grep Main com.Qunar/com.mqunar.atom.alexhome.ui.activity.MainActivity 把手机模拟器的去哪儿的进程给杀掉 执行 命令 adb shell am start -W…

TensorFlow2实战-系列教程1:回归问题预测

&#x1f9e1;&#x1f49b;&#x1f49a;TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、环境测试 import tensorflow as tf import numpy as np tf.__version__打印结果 ‘…

深入理解Redis:如何设置缓存数据的过期时间及其背后的机制

目录 Redis 给缓存数据设置过期时间 Redis是如何判断数据是否过期的呢&#xff1f; 过期的数据的删除策略 Redis 内存淘汰机制 Redis 给缓存数据设置过期时间 一般情况下&#xff0c;我们设置保存的缓存数据的时候都会设置一个过期时间。为什么呢&#xff1f; 因为内存是有…

4小时精通MyBatisPlus框架

目录 1.介绍 2.快速入门 2.1.环境准备 2.2.快速开始 2.2.1引入依赖 2.2.2.定义Mapper ​编辑 2.2.3.测试 2.3.常见注解 ​编辑 2.3.1.TableName 2.3.2.TableId 2.3.3.TableField 2.4.常见配置 3.核心功能 3.1.条件构造器 3.1.1.QueryWrapper 3.1.2.UpdateWra…

Redis(八)哨兵机制(sentinel)

文章目录 哨兵机制案例认识异常 哨兵运行流程及选举原理主观下线(Subjectively Down)ODown客观下线(Objectively Down)选举出领导者哨兵选出新master过程 哨兵使用建议 哨兵机制 吹哨人巡查监控后台master主机是否故障&#xff0c;如果故障了根据投票数自动将某一个从库转换为新…

Java 基础知识-File类

大家好我是苏麟 , 今天聊聊File . 资料来自黑马程序员 File类 java.io.File 类是文件和目录路径名的抽象表示&#xff0c;主要用于文件和目录的创建、查找和删除等操作。 构造方法 public File(String pathname) &#xff1a;通过将给定的路径名字符串转换为抽象路径名来创建…

盘古信息IMS OS 数垒制造操作系统+ 产品及生态部正式营运

启新址吉祥如意&#xff0c;登高楼再谱新篇。2024年1月22日&#xff0c;广东盘古信息科技股份有限公司新办公楼层正式投入使用并举行了揭牌仪式&#xff0c;以崭新的面貌、奋进的姿态开启全新篇章。 盘古信息总部位于东莞市南信产业园&#xff0c;现根据公司战略发展需求、赋能…

【双目】基于findChessboardCorners的双目精度评估,可以直接使用

1. 基于findChessboardCorners的双目精度评估 原理&#xff1a; 代码&#xff1a; #include <iostream> #include <opencv2/opencv.hpp>using namespace std; using namespace cv;int main() {// 加载图像auto srcimage imread("/home/oem/data/steroe_p…

Hadoop增加新节点环境配置(自用)

完成Hadoop集群增添一个新的节点配置&#xff08;文中命名为&#xff09;Hadoop106&#xff0c;没有进行继续为该节点分配身份职能的步骤 1.在VMware中安装CentOS 7 新建虚拟机 1.⾸先我们创建⼀个新的虚拟机&#xff0c;也可以点⽂件-新建虚拟机。 2.选择⾃定义&#xff0c…

网页元素圈选

从前面我们已验证配置自动化是可行的&#xff0c;接下来就实现元素选择&#xff0c;当然有了配置化&#xff0c;我们也是可以通过浏览器F12的调试工具去把元素xpath复制出来&#xff08;ps:反正又不是不能用&#xff09;&#xff0c;但是这不是我们最终目的。 其实圈选效果如下…

CSS3如何实现从右往左布局的按钮组(固定间距)

可以通过下方CSS实现&#xff0c;下面的CSS表示按钮从右往左布局&#xff0c;且间距为10px: .right-btn {position: relative;float: right;margin-right: 10px; }类似这种&#xff1a; 这种&#xff1a; 注意&#xff1a; 不能使用right:10px代替margin-right:10px&#x…

聚醚醚酮(Polyether Ether Ketone)PEEK主要应用于哪些行业领域?

聚醚醚酮&#xff08;Polyether Ether Ketone&#xff0c;PEEK&#xff09;广泛应用于以下行业&#xff1a; 1.航空航天业&#xff1a; PEEK常被用于制造航空航天组件&#xff0c;如飞机零部件、航天器构件&#xff0c;因其轻量化、高强度和耐高温性能。 2.汽车工业&#xff1…

滑木块H5小游戏

欢迎来到程序小院 滑木块 玩法&#xff1a;点击木块横着的只能左右移动&#xff0c;竖着的只能上下移动&#xff0c; 移动到箭头的位置即过关&#xff0c;不同关卡不同的木块摆放&#xff0c;快去滑木块吧^^。开始游戏https://www.ormcc.com/play/gameStart/260 html <can…

Django项目搭建

一、创建项目 在命令行中执行代码 $ django-admin startproject mysitedjango-admin 为内部命令startproject 为参数mysite 项目名 备注 避免使用 Python 或 Django 的内部保留字来命名项目。比如说&#xff0c;避免使用像 django (会和 Django 自己产生冲突)或 test (会和 P…

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包 import mathimport torch import torch.nn as nn import torch.nn.functional as F silu激活函数 class SiLU(nn.Module): # SiLU激活函数staticmethoddef forward(x):return x * torch.sigmoid(x) 归一化设置 def get_norm(norm, num_channels, num_groups)…

Java集合-Map接口(key-value)

Map接口的特点&#xff1a;①KV键值对方式存储②Key键唯一&#xff0c;Value允许重复③无序。 Map有四个实现类&#xff1a;1.HashMap类2.LinkedHashMap类3.TreeMap类4.Hashtable类 1.HashMap类&#xff1a; 存储结构&#xff1a;哈希表 数组Node[ ] 链表&#xff08;红黑…

暴力破解

暴力破解工具使用汇总 1.查看密码加密方式 在线网站&#xff1a;https://cmd5.com/ http://www.158566.com/ https://encode.chahuo.com/kali&#xff1a;hash-identifier2.hydra 用于各种服务的账号密码爆破&#xff1a;FTP/Mysql/SSH/RDP...常用参数 -l name 指定破解登录…