TensorFlow项目练手(二)——猫狗熊猫的分类任务

项目介绍

通过猫狗熊猫图片来对图片进行识别,分类出猫狗熊猫的概率,文章会分成两部分,从基础网络模型->利用卷积网络经典模型Vgg。

基础网络模型

基础的网络模型主要是用全连接层来分类,比较经典的方法,也是祖先最先使用的方法,目前已经在这类问题上,被卷积网络模型所替代,学习这部分是为了可以了解到最简单的分类任务的写法。

一、准备数据

  • 准备猫狗熊猫的训练数据集,各自1000张图片,分别放在/train/cats/train/dogs/train/panda
  • 准备猫狗熊猫的测试数据集,各5-10张,统一放在/test目录下,后续通过随机取出来测试

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

二、开始编写

1、获取数据

数据的获取主要包含2部分

  1. 先读取图片数据
  2. 对图片数据进行预处理
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
lb = LabelBinarizer()

import matplotlib.pyplot as plt
import random
import os
import numpy as np
np.set_printoptions(threshold=10000)
import cv2
import pickle

# 遍历所有文件名
def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            yield f

# 数据切分
def split_train(data,label,test_ratio):
    np.random.seed(43)
    shuffled_indices=np.random.permutation(len(data))
    test_set_size=int(len(data)*test_ratio)
    test_indices =shuffled_indices[:test_set_size]
    train_indices=shuffled_indices[test_set_size:]
    return data[train_indices],data[test_indices],label[train_indices],label[test_indices]

image_dir = ("./train/cats/", "./train/dogs/", "./train/panda/")
image_path = []
data = []
labels = []

# 读取图像路径
for path in image_dir:
    for i in findAllFile(path):
        image_path.append(path+i)

# 随机化数据
random.seed(43)
random.shuffle(image_path)

# 读取图像数据,读取label文件名数据
for j in image_path:
    image = cv2.imread(j)
    image = cv2.resize(image,(32,32)).flatten()
    data.append(image)
    label = j.split("/")[-2]
    labels.append(label)

# 数据预处理:规格化数据
data = np.array(data,dtype="float") / 255.0
labels = np.array(labels)
# 数据切分
(trainX,testX,trainY,testY) = split_train(data,labels,test_ratio=0.25)
# 将cat、dog、panda规格化数据
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)

# 最终数据结果
print(trainX)
print(data)
print(data.shape) # (3000, 3072)32x32x3=3072,其图片3通道被拉长成一条操作
print(lb.classes_) # ['cats' 'dogs' 'panda']

将所有图片读取,并保存他们的数据集数据和训练结果,每张图片都会被规整到32x32并且进行拉长操作flatten(),最终输出的数据是一组图片的RGB数据

  • 数据集:我们将数据进行切分,25%作为验证集,75%数据作为训练集
  • 训练结果(label):我们按照文件名上进行分割,分割出对应的名字作为label

在这里插入图片描述

2、构建网络模型

  • 网络模型:采用全连接层
  • 优化器:使用梯度下降法SGD
  • 损失函数:使用分类算法
  • 权重初始化:高斯截断分布函数
# 2、创建模型层
EPOCHS = 200
model = Sequential()
model.add(Dense(512,input_shape=(3072,),activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
model.add(Dropout(0.5))
model.add(Dense(256,activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_),activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
# 损失函数和优化器,正则惩罚
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=0.001),metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)

在这里插入图片描述

3、模型评估

模型训练后之后,对模型进行评估,可以看到当前的分类情况

# 3、模型评估
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=lb.classes_))

在这里插入图片描述

4、数据可视化

将数据绘制在图上,看看其训练和预测的准确率情况,并将其保存起来

# 4、数据可视化
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig("./plot.png")

在这里插入图片描述

5、保存模型

# 5、保存模型到本地
model.save("./model")
f = open("./label.pickle", "wb")
f.write(pickle.dumps(lb))
f.close()

6、结果输出预测

随机获取测试集中的图片,对数据进行预处理后,进行预测,将结果显示出来

# 6、测试模型
test_image_dir =  "./test/"
test_image_path = []
for i in findAllFile(test_image_dir):
    test_image_path.append(test_image_dir+i)
test_image = random.sample(test_image_path, 1)[0]

# 数据预处理
image = cv2.imread(test_image)
output = image.copy()
image = image.astype("float") / 255.0
image = cv2.resize(image,(32,32)).flatten()
image = image.reshape((1, image.shape[0]))

# 加载模型
model = load_model("./model")
lb = pickle.loads(open("./label.pickle", "rb").read())
# 开始预测
preds = model.predict(image)

# 查看预测结果
text1 = "{}: {:.2f}% ".format(lb.classes_[0], preds[0][0] * 100)
text2 = "{}: {:.2f}% ".format(lb.classes_[1], preds[0][1] * 100)
text3 = "{}: {:.2f}% ".format(lb.classes_[2], preds[0][2] * 100)
cv2.putText(output, text1, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text2, (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text3, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.imshow("Image", output)
cv2.waitKey(0)

通过我们的测试集合看出来,其准确率还是有点不尽如意,主要是数据集较小,且训练次数不足的原因导致

在这里插入图片描述

vgg模型

在后续的技术迭代中,卷积神经网络基本上已经覆盖了图像识别技术,使用卷积神经网络结合vgg的架构,可以更准确地提高准确率

一、准备数据

跟上面基础模型一样,所有数据都是一样的

二、开始编写

1、获取数据

获取数据的代码跟基础网络模型完全一致,唯一区别在于# image = cv2.resize(image,(32,32)).flatten() # 将图片resize到64,且去掉拉长操作

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
lb = LabelBinarizer()

import matplotlib.pyplot as plt
import random
import os
import numpy as np
np.set_printoptions(threshold=10000)
import cv2
import pickle

# 遍历所有文件名
def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            yield f

# 数据切分
def split_train(data,label,test_ratio):
    np.random.seed(43)
    shuffled_indices=np.random.permutation(len(data))
    test_set_size=int(len(data)*test_ratio)
    test_indices =shuffled_indices[:test_set_size]
    train_indices=shuffled_indices[test_set_size:]
    return data[train_indices],data[test_indices],label[train_indices],label[test_indices]

image_dir = ("./train/cats/", "./train/dogs/", "./train/panda/")
image_path = []
data = []
labels = []

# 1、数据预处理
# 读取图像路径
for path in image_dir:
    for i in findAllFile(path):
        image_path.append(path+i)

# 随机化数据
random.seed(43)
random.shuffle(image_path)

# 读取图像数据,读取label文件名数据
for j in image_path:
    image = cv2.imread(j)
    image = cv2.resize(image,(64,64))
    # image = cv2.resize(image,(32,32)).flatten() # 将图片resize到64,且去掉拉长操作
    data.append(image)
    label = j.split("/")[-2]
    labels.append(label)

# 规格化数据
data = np.array(data,dtype="float") / 255.0
labels = np.array(labels)
# 数据切分
(trainX,testX,trainY,testY) = split_train(data,labels,test_ratio=0.25)
# 将cat、dog、panda规格化数据
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)

# 最终数据结果
print(trainX)
print(data)
print(data.shape) # (3000, 3072)32x32x3=3072,其图片3通道被拉长成一条操作
print(lb.classes_) # ['cats' 'dogs' 'panda']

2、构建网络模型

采用vgg的框架,搭建最简单的vgg层数的网络模型

from keras.models import Sequential
from keras.layers.normalization.batch_normalization_v1 import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.initializers import TruncatedNormal
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense


model = Sequential()
chanDim = 1
inputShape = (64, 64, 3)

model.add(Conv2D(32, (3, 3), padding="same",input_shape=inputShape))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# (CONV => RELU) * 2 => POOL 
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# (CONV => RELU) * 3 => POOL 
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# FC层
model.add(Flatten())
model.add(Dense(512))
model.add(Activation("relu"))
model.add(BatchNormalization())
#model.add(Dropout(0.6))

# softmax 分类,kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)
model.add(Dense(len(lb.classes_)))
model.add(Activation("softmax"))

# 损失函数和优化器,正则惩罚
EPOCHS = 200
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=0.001),metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)

在这里插入图片描述

3、数据可视化

同样的操作,将数据绘制在图上,看看其训练和预测的准确率情况,并将其保存起来

# 4、数据可视化
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig("./plot.png")

在这里插入图片描述

4、保存模型

# 5、保存模型到本地
model.save("./model")
f = open("./label.pickle", "wb")
f.write(pickle.dumps(lb))
f.close()

5、结果输出预测

同样的操作,唯一的区别在于

  • # image = cv2.resize(image,(32,32)).flatten() #不拉平,且改为64x64
  • # image = image.reshape((1, image.shape[0])) #数据改为数组
# 6、测试模型
test_image_dir =  "./test/"
test_image_path = []
for i in findAllFile(test_image_dir):
    test_image_path.append(test_image_dir+i)
test_image = random.sample(test_image_path, 1)[0]

# 数据预处理
image = cv2.imread(test_image)
output = image.copy()
image = image.astype("float") / 255.0
# image = cv2.resize(image,(32,32)).flatten() #不拉平,且改为64x64
# image = image.reshape((1, image.shape[0])) #数据改为数组
image = cv2.resize(image,(64,64))
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

# 加载模型
model = load_model("./model")
lb = pickle.loads(open("./label.pickle", "rb").read())
# 开始预测
preds = model.predict(image)

# 查看预测结果
text1 = "{}: {:.2f}% ".format(lb.classes_[0], preds[0][0] * 100)
text2 = "{}: {:.2f}% ".format(lb.classes_[1], preds[0][1] * 100)
text3 = "{}: {:.2f}% ".format(lb.classes_[2], preds[0][2] * 100)
cv2.putText(output, text1, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text2, (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text3, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.imshow("Image", output)
cv2.waitKey(0)

通过我们的结果展示,可以发现分类的精准度可以达到80-90%以上,说明这个模型是比基础网络模型50%左右的准确度好很多,但是,也会精准的分类错误,原因在于我们只有1000的数据集,比较容易分类错误,当然你的数据量越大,就可以解决当前的问题。
在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/35217.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MinGW编译OpenCV 过程记录

1.下载源码opencv-3.4.10.zip ,可以在OpenCV官网下载Releases - OpenCV 解压缩如下: 2.下载Mingw64工具,需要支持posix 并设置系统环境目录,下载的文件名x86_64-8.1.0-release-posix-sjlj-rt_v6-rev0.7z (可以在网上找) 3.使用Cmake工具构建…

微信小程序个人中心展示样式(2)

这是之前的详细的看这里 因为这是好多年前写的了,好多人私信我代码有问题。正好今天有时间简单的还原下代码 话不多说先看图(图片样式自己搞奥~~~~我也好久没弄了这就是个参考demo) 以下是一个使用微信小程序开发的个人中心展示详情的示例: 在微信开发…

基于PyQt5的桌面图像调试仿真平台开发(10)色彩矩阵

系列文章目录 基于PyQt5的桌面图像调试仿真平台开发(1)环境搭建 基于PyQt5的桌面图像调试仿真平台开发(2)UI设计和控件绑定 基于PyQt5的桌面图像调试仿真平台开发(3)黑电平处理 基于PyQt5的桌面图像调试仿真平台开发(4)白平衡处理 基于PyQt5的桌面图像调试仿真平台开发(5)…

解决问题:通配符的匹配很全面, 但无法找到元素 ‘context:component-scan‘ 的声明~

异常描述如下&#xff1a; 产生异常原因&#xff1a; 因为在配置文件中没有找到<context:component-scan />元素的声明&#xff0c;解决办法&#xff1a;将XML配置文件中的声明改为下述代码&#xff1a; <beans xmlns"http://www.springframework.org/schema/b…

01 | 一条 SQL 查询语句是如何执行的?

以下内容出自 《MySQL 实战 45 讲》 一条 SQL 查询语句是如何执行的&#xff1f; 下面是 MySQL 的基本架构示意图&#xff0c;从中可以清楚地看到 SQL 语句在 MySQL 的各个功能模块中的执行过程。 大体来说&#xff0c;MySQL 可以分为 Server 层和存储引擎层两部分。 Server …

Python如何批量将图片以超链接的形式插入Excel

【研发背景】 在日常办公中&#xff0c;我们经常需要将图片插入进Excel中&#xff0c;但是如果插入的图片太多的话&#xff0c;就会导致Excel的文件内存越来越大&#xff0c;但是如果我直插入图片的路径&#xff0c;或者只是更改某一列的数据设置为超链接&#xff0c;这样的话&…

Spring底层核心架构

Spring底层核心架构 相关的配置类 1. user类 package com.zhouyu.service;import org.springframework.stereotype.Component;public class User { }2. AppConfig类 package com.zhouyu;import org.springframework.context.annotation.*; import org.springframework.sched…

open*w*r*t +dnspod ddns动态解析ipv6 远程控制移动内网路由器

1.修改openw*r*t web https管理端口为8443 修改ipv6 https 监听端口list listen_https [::]:8443 cd /etc/config/vi uhttpdvi /etc/config/uhttpdconfig uhttpd mainlist listen_http 0.0.0.0:80list listen_http [::]:80list listen_https 0.0.0.0:443list listen_https [:…

前端Vue一款基于canvas的精美商品海报生成组件 根据个性化数据生成商品海报图 长按保存图片

前端Vue一款基于canvas的精美商品海报生成组件 根据个性化数据生成商品海报图 长按保存图片&#xff0c;下载完整代码请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13326 效果图如下: # cc-beautyPoster #### 使用方法 使用方法 <!-- pos…

Java虚拟机(JVM)、垃圾回收器

一、Java简介 1、Java开发及运行版本 JRE(Java Runtime Environment&#xff0c;运行环境) 所有的程序都要在JRE下才能够运行。包括JVM和Java核心类库和支持文件。JDK(Java Development Kit&#xff0c;开发工具包) 用来编译、调试Java程序的开发工具包。包括Java工具(javac/…

【Redis】3、Redis 作为缓存(Redis中的穿透、雪崩、击穿、工具类)

目录 一、什么是缓存二、给业务添加缓存&#xff08;减少数据库访问次数&#xff09;三、给店铺类型查询业务添加缓存(1) 使用 String 类型(2) 使用 List 类型 四、缓存的更新策略(1) 主动更新(2) 最佳实现方案(3) 给查询商铺的缓存添加超时剔除和主动更新的策略① 存缓存&…

泰迪智能科技基于产业技能生态链学生学徒制的双创工作室--促进学生高质量就业

据悉&#xff0c;6月28日&#xff0c;广东省人力资源和社会保障厅在广东岭南现代技师学院举行广东省“产教评”技能生态链建设对接活动。该活动以“新培养、新就业、新动能”为主题&#xff0c;总结推广“产教评”技能人才培养新模式&#xff0c;推行“岗位培养”学徒就业新形式…

Tomcat多实例部署

1、关闭防火墙&#xff0c;将安装 Tomcat 所需软件包传到/opt目录下2、安装好 JDK3、设置JDK环境变量4、安装 tomcat5、配置 tomcat 环境变量6、修改 tomcat2 中的 server.xml 文件&#xff0c;要求各 tomcat 实例配置不能有重复的端口号7、修改各 tomcat 实例中的 startup.sh …

更便捷的人体三维模型制作方法

人体三维模型是一种以计算机辅助设计技术为基础的创新工具&#xff0c;它在医学、生物学、运动学等领域具有广泛的应用价值。这种模型通过将人体的形态、结构与功能等要素进行数字化处理和计算&#xff0c;能够以立体图像的形式展现出来。它可以精确地模拟人体的各种部位&#…

【C】数据在内存中的存储

前言 > 在内存中&#xff0c;整型和浮点型存储的方式是不同的&#xff0c;从内存中读取的方式也是有所差异的&#xff0c;这篇文章主要介绍整型和浮点型在内存中存储的方式。 整型在内存中的存储 计算机中有符号数有3种表示方式&#xff1a; 原码&#xff1a;直接将二进制按…

查看虚拟机主机IP

虚拟机主机ip 文章目录 ifconfigip addr图形化界面 ifconfig 失败了 ip addr 图形化界面

Windows远程连接linux中mysql数据库

我没有mysql并且没有把mysql配置到环境变量中&#xff0c;所以现在我要下载mysql 一.下载mysql Mysql官网下载地址&#xff1a;https://downloads.mysql.com/archives/installer 二.安装mysql 1. 选择设置类型 双击运行mysql-installer-community-8.0.26.msi&#xff0c;这…

MySQL简单查询操作

系列文章目录 前言SELECT子句SELECT后面之间跟列名DISTINCT,ALL列表达式列更名 WHERE子句WHERE子句中可以使用的查询条件比较运算特殊比较运算符BETWEEN...AND...集合查询&#xff1a;IN模糊查询&#xff1a;LIKE空值比较&#xff1a;IS NULL 多重条件查询 ORDER BY子句排序复杂…

【基于容器的部署、扩展和管理】3.10 云原生容器运行时环境和配置管理

往期回顾&#xff1a; 第一章&#xff1a;【云原生概念和技术】 第二章&#xff1a;【容器化应用程序设计和开发】 第三章&#xff1a;【3.1 容器编排系统和Kubernetes集群的构建】 第三章&#xff1a;【3.2 基于容器的应用程序部署和升级】 第三章&#xff1a;【3.3 自动…

Anaconda详细安装及配置教程(Windows)

Anaconda详细安装及配置教程&#xff08;Windows&#xff09; 一、下载方式1、官网下载2、网盘下载 二、安装三、配置四、创建虚拟环境 一、下载方式 1、官网下载 点击下载 点击window下载即可。 2、网盘下载 点击下载 二、安装 双击运行 点next 点I agree next 如…