【神经网络】基于CNN(卷积神经网络)构建猫狗分类模型

文章目录

    • 解决问题
    • 数据集
    • 探索性数据分析
    • 数据预处理
      • 数据集分割
      • 数据预处理
    • 构建模型并训练
      • 构建模型
      • 训练模型
    • 结果分析与评估
    • 模型保存
    • 结果预测
    • 经验总结

解决问题

针对经典猫狗数据集,基于卷积神经网络,构建猫狗二元分类模型,使用数据集进行参数训练,模型评估,然后使用模型进行分类预测,最后对模型进行保存,供后续使用。

数据集

数据集来源

猫狗数据集

探索性数据分析

查看待训练识别图片

from matplotlib import pyplot as plt
import os
import random

# 获取文件名
_,_,cat_images = next(os.walk('../../dataset/kagglecatsanddogs_5340/PetImages/Cat'))

# 准备3*3 图表
fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# 随机选择一幅图像并绘制
for idx, img in enumerate(random.sample(cat_images, 9)):
    img_read = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Cat/' + img)
    ax[int(idx / 3), idx % 3].imshow(img_read)
    ax[int(idx / 3), idx % 3].set_title('cat/' + img)
    ax[int(idx / 3), idx % 3].axis('off')
plt.show()

查看狗图片类似,将Cat目录换成Dog即可

image-20240614224011275

数据预处理

数据集分割

由于下载的图片猫和狗各在一个文件夹内,如下:

image-20240613223710122

需要将数据按80%:20%进行分割,分为训练集和测试集。目录结构如下:

image-20240613223517479

下面进行数据拆分,核心代码(以猫图片为例)如下:

# 训练数据集80% 测试数据集20%
train_size = 0.8
# 获取猫图像数量
_, _, cat_images = next(os.walk(src_folder+'Cat/'))
num_cat_images = len(cat_images)
num_cat_images_train = int(train_size * num_cat_images)
num_cat_images_test = num_cat_images - num_cat_images_train
# 分割猫图像
cat_train_images = random.sample(cat_images, num_cat_images_train)
for img in cat_train_images:
	shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Train/Cat/')
cat_test_images  = [img for img in cat_images if img not in cat_train_images]
for img in cat_test_images:
	shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Test/Cat/')
	

数据预处理

这一步要将分割后的数据集转成和模型结构匹配的数据类型。使用keras提供的ImageDataGenerator类和flow_from_directory()方法

ImageDataGenerator类:图像增强类,可以进行图像旋转、图像平移、水平翻转、图像缩放等操作;

flow_from_directory()方法:ImageDataGenerator类的方法,支持以图像路径为输入,按批次加载图像到内存,防止训练数据量过大,机器内存不足问题;还支持对图像进行预处理操作,例如尺寸缩放和图像增强

# 训练数据预处理
training_data_generator = ImageDataGenerator(rescale=1./255)
training_set = training_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/train/',target_size=(32, 32),batch_size=16,class_mode='binary')

# 测试数据预处理
testing_data_generator = ImageDataGenerator(rescale= 1./255)
testing_set = testing_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/test/',target_size=(32, 32),batch_size=16, class_mode='binary')

构建模型并训练

构建模型

# 定义超参数
# 特征滤波器尺寸
FILTER_SIZE = 3
# 特征滤波器数量
FILTER_NUM = 32
# 图片输入尺寸
INPUT_SIZE = 32
# 最大池化尺寸
MAXPOOL_SIZE = 2
# 批量处理图片的大小
BATCH_SIZE = 16
STEPS_PER_EPOCH = 20000 // BATCH_SIZE
# 训练轮次
EPOCHS = 10
# 定义模型
model = Sequential()
# 添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 再添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 对输出结果进行降维处理,转成一维张量
model.add(Flatten())
# 添加全链接层,根据特征进行分类预测
model.add(Dense(units=128, activation='relu'))
# 添加dropout层,随机将一部分输入设置为0,防止模型复杂,出现过拟合现象
model.add(Dropout(0.5))
# 添加输出层,一个节点
model.add(Dense(units=1, activation='sigmoid'))

该模型结构分为,卷积池化层,卷积池化层,Flatten层,全链接层1,全链接层2(输出层)如下:

image-20240614221747896

其中,第一列是神经网络的层,第二列是每层的输出形状,第三层是每层训练的参数

可以看到,该模型图像输入尺寸是(32,32),经过一层卷积(32个特征过滤器)输出为(30,30,32),经过一层最大池化层,输出为(15,15,32);其中特征滤波器尺寸为3*3,所以滤波后的尺寸会是32-(3-1)=30,经过最大池化(2x2)尺寸减半,为15。

训练模型

# 模型训练
model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1)
image-20240614123217943

结果分析与评估

model.evaluate(testing_set,steps=len(testing_set),verbose=1)
image-20240614222619817

准确度达到了0.7856

模型保存

from joblib import dump, load

# 模型持久化 到磁盘
dump(model, './猫狗分类.onnx')

结果预测

引入保存模型,随机选取一张图片进行预测分类

from matplotlib import pyplot as plt
fig, ax = plt.subplots()
img = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg')
ax.imshow(img)
plt.show()
image-20240614223543720
from joblib import dump, load
model = load('./猫狗分类.onnx')

from tensorflow.keras.preprocessing.image import img_to_array,load_img

img = load_img('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg',target_size=(32,32))
img = img_to_array(img)
img /= 255
import numpy as np
img_array = np.expand_dims(img, axis=0)
print(img_array.shape)
model.predict(img_array)

在这里插入图片描述

由于是二元分类,0和1分别表示猫狗,输出概率接近表示是狗,接近0表示是猫狗。但具体为啥0表示猫1表示狗而不是反过来表示,还待研究。

经验总结

1 在使用next()加载图像时,要确保路径正确,否则会报StopIteration错误,原因是路径错误,找不到可迭代的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/720544.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RoboDK试用期间提示无效或过期的许可证

问题描述 RoboDK下载下来在试用期间提示如下信息,不知道什么原因 临时解决方法 将C:\Users\${username}\AppData\Roaming\RoboDK该目录下的文件全部删除掉,便可以正常使用RoboDK应用了,但是等软件关闭后还是会出现上面的问题,…

IPython 使用技巧整理

IPython 是一个强大的交互式 Python shell,广泛用于数据分析、科学计算和开发工作。本文将整理一些 IPython 的实用技巧,帮助你更高效地使用 IPython。 目录 快速启动和退出魔法命令高效的代码编写变量和对象信息历史命令IPython 扩展错误调试与 Jupy…

js中!emailPattern.test(email) 的test是什么意思

test 是 JavaScript 正则表达式(RegExp)对象的方法之一,用于测试一个字符串是否与正则表达式匹配。正则表达式是一种用于匹配字符串的模式,通常用于验证输入数据、查找和替换文本等。 使用 test 方法 test 方法语法如下&#xf…

YOLOv8旋转目标检测Yolov8n-obb详细实例+rolabelimg

一、Yolov8环境搭建 首先创建虚拟环境下载安装(其实就是yolov8的环境)再大概写一下步骤,没有想详细的看本人另外一篇:YOLOv8环境搭建_yolov8环境配置-CSDN博客 1、下载安装anaconda 2、创建虚拟环境 conda create -n my_yolov8…

油猴 脚本如何添加包含哪个网址 执行脚本

油猴 脚本如何添加包含哪个网址 执行脚本 在这里面加上就可以 // include *://blog.csdn.net/*/article/details/* // include *.blog.csdn.net/article/details/*

虚拟货币投资指南|XEX交易所

什么是虚拟货币? 虚拟货币是一种基于区块链技术的数字资产,具有去中心化、透明性和安全性等特点。比特币(BTC)、以太坊(ETH)和莱特币(LTC)等是目前较为知名的虚拟货币。 虚拟货币投…

安卓软件自动运行插件的开发源代码介绍!

随着移动互联网的快速发展,安卓操作系统凭借其开放性和灵活性,成为了众多开发者们的首选平台,在安卓应用的开发中,为了实现各种复杂的功能,插件化技术逐渐受到青睐。 其中,自动运行插件作为一种能够实现应…

搜维尔科技:SenseGlove虚拟训练、VR/AR 模拟和研究中的触觉反馈

训练 传统培训成本高昂且风险大,需要重复资产或停产。在培训中使用虚拟现实可以轻松解决这些问题。借助 SenseGlove,终于可以研究和评估与传统培训效果相同的虚拟培训技术。体验低成本的定制 VR 培训,同时保留现实世界的肌肉记忆和记忆力。 …

人工智能在气候变化中的应用

Hi~!这里是奋斗的小羊,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 💥💥个人主页:奋斗的小羊 💥💥所属专栏:C语言 🚀本系列文章为个人学习…

新手一次学会SpringBoot项目部署 + Docker中运行Samba服务设置共享目录

SpringBoot项目部署 1.IDEA打包,在IDEA终端,输入mvn clean package 2.将项目target中的jar包放入linux目录 3.运行jar包 前台运行(直接显示输出): java -jar data-transport-server-0.0.1-SNAPSHOT.jar后台运行&…

VBA技术资料MF160:提取文件夹中文件的详细信息

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套,分为初级、中级、高级三大部分,教程是对VBA的系统讲解&#…

shell命令(进程管理和用户管理)

一、进程处理相关命令 1、进程的概念 进程的概念主要有两点: 进程是一个实体。每一个进程都有它自己的地址空间,一般情况下,包括文本区域( text region )、数据区域( data region )和堆栈&am…

【Android 11】AOSP Settings添加屏幕旋转按钮

前言 这里是客户要求添加按钮以实现屏幕旋转。屏幕旋转使用adb的命令很容易实现: #屏幕翻转 adb shell settings put system user_rotation 1 #屏幕正常模式 adb shell settings put system user_rotation 0这里的值可以是0,1,2&#xff0c…

EasyCVR/EasyDSS无人机直播技术助力野生动物监测:开启野生动物保护新篇章

近日有新闻报道,一名挖掘机师傅在清理河道时,意外挖出一只稀有的扬子鳄,挖机师傅小心翼翼地将其放在一边,扬子鳄也顺势游回一旁的河道中。 随着人类对自然环境的不断探索和开发,野生动物及其栖息地的保护显得愈发重要。…

STM32学习和实践笔记(35):内部温度传感器实验

1.STM32F1内部温度传感器介绍 1.1 STM32F1内部温度传感器简介 STM32F1内部含有一个温度传感器,可用来测量 (STM32芯片的)CPU 及周围的温度(TA)。(实际并不用来测周围的温度,仅用来测试CPU的温度) 此温度传…

2023数模A题——定日镜场的优化问题

A题——定日镜场的优化问题 思路:该题主要考察的几何知识和天文学知识,需要不同角度下的镜面和遮挡情况。 资料获取 问题1: 若将吸收塔建于该圆形定日镜场中心,定日镜尺寸均为 6 m6 m,安装高度均为 4 m,且…

Python使用策略模式生成TCP数据包

使用策略模式(Strategy Pattern)来灵活地生成不同类型的TCP数据包。 包括三次握手、数据传输和四次挥手。 from scapy.all import * from scapy.all import Ether, IP, TCP, UDP, wrpcap from abc import ABC, abstractmethodclass TcpPacketStrategy(A…

表面声波滤波器——设计方案(4)

设计步骤 设计声表面波滤波器,首先需要分析器件的指标要求,如中心频率、使用带宽、插入损耗等,结合产线工艺水平,选择合适的衬底材料和换能器材料。确认可以满足器件性能需求的换能器设计方案,然后通过软件仿真。对叉…

Tomcat Websocket应用实例研究

概述 本文介绍了如何根据Tomcat给出的websocket实例,通过对实例的学习,定制自己基于websocket的应用。 环境及版本: Ubuntu 22.04.4 LTSApache Tomcat/10.1.20openjdk 11.0.23 2024-04-16浏览器:Chrome 相关资源及链接 Class…

vue+springboot导入Excel表格

1.创建一个excel表格,与数据库需要的表头对应 2.(前端)导入excel的按钮 <template class"importExcel"><el-button type"primary" click"chooseFile">导入<i class"el-icon-upload el-icon--right"></i><…