Tensorflow神经网络模型-鲜花种类识别

在这里插入图片描述必应壁纸供图

Tensorflow神经网络模型-鲜花种类识别

数据集:https://download.csdn.net/download/weixin_53742691/87982215

导入相关依赖

import warnings
import re
from IPython.display import clear_output, display
from tkinter import Tk, filedialog
from ipywidgets import Button
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


warnings.filterwarnings("ignore")

数据探索

flower_category = "flowers"
categorys = 0
categorys_list = []
for category in os.listdir(flower_category):
    categorys += 1
    categorys_list.append(category)
print("种类总数为:%d" % categorys)
print(categorys_list)
种类总数为:5
['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
file_path = "flowers/sunflower/"
file_count = 0
for file in os.listdir(file_path):
    if re.match(r'\S*\.?[jpg,png,jpeg]', file):
        file_count += 1
print("文件总数是:%d" % file_count)
文件总数是:733

图片处理器

def img_deal(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    img = cv2.resize(img, (224, 224))
    return img

图片预览

sample_list = []
num = 0
for sample in os.listdir(file_path):
    num += 1
    sample = "flowers/sunflower/"+sample
    sample_list.append(sample)
    if num == 5:
        break
print(sample_list)
['flowers/sunflower/1008566138_6927679c8a.jpg', 'flowers/sunflower/1022552002_2b93faf9e7_n.jpg', 'flowers/sunflower/1022552036_67d33d5bd8_n.jpg', 'flowers/sunflower/10386503264_e05387e1f7_m.jpg', 'flowers/sunflower/10386522775_4f8c616999_m.jpg']
plt.figure(figsize=(20, 20))
for i in range(5):
    plt.subplot(1, 5, i+1)
    img = img_deal(sample_list[i])
    plt.imshow(img)
    plt.xlabel("sunflower "+str(i+1))
    plt.xticks([])
    plt.yticks([])
plt.show()

png

数据预处理

# 输入图片大小
img_size = (224, 224)
# 图像数据生成
gen = tf.keras.preprocessing.image.ImageDataGenerator(
    img_size,
    validation_split=0.25,
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)

设置训练集

train_generator = gen.flow_from_directory(
    # 设置图片加载路径
    "flowers/",
    # 设置加载图片大小
    img_size,
    # 设置批次大小
    batch_size=32,
    class_mode="categorical",
    subset="training"
)
Found 3238 images belonging to 5 classes.

设置验证集

validation_generator = gen.flow_from_directory(
    "flowers/",
    img_size,
    batch_size=32,
    class_mode="categorical",
    subset="validation"
)
Found 1079 images belonging to 5 classes.

处理后的图片预览shuffle

plt.figure(figsize=(26, 10))
for i in range(32):
    plt.subplot(4, 8, i+1)
    sample = train_generator[0][0][i]
    # 设置图片色彩通道最小值
    sample = np.maximum(sample, 0)
    # 设置图片标签
    plt.imshow(sample)
    plt.xlabel(i)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

模型搭建和训练

# 基础模型
base_model = tf.keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3)
)
# 锁定其他节点
for layers in base_model.layers:
    layers.trainable = False
# 重建模型
model = tf.keras.Sequential([
    base_model,
    # 展平
    tf.keras.layers.Flatten(),
    # 添加神经元
    tf.keras.layers.Dense(units=128, activation="relu"),
    tf.keras.layers.Dense(units=64, activation="relu"),
    tf.keras.layers.Dense(units=5, activation="softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 flatten (Flatten)           (None, 62720)             0         
                                                                 
 dense (Dense)               (None, 128)               8028288   
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 10,294,853
Trainable params: 8,036,869
Non-trainable params: 2,257,984
_________________________________________________________________
model.compile(loss="categorical_crossentropy",
              optimizer="adam", metrics=['accuracy'])
 history = model.fit(train_generator,
                     epochs=5,
                     validation_data=validation_generator)
Epoch 1/5
102/102 [==============================] - 49s 320ms/step - loss: 1.1481 - accuracy: 0.7712 - val_loss: 0.5897 - val_accuracy: 0.8360
Epoch 2/5
102/102 [==============================] - 35s 343ms/step - loss: 0.1766 - accuracy: 0.9469 - val_loss: 0.6906 - val_accuracy: 0.8573
Epoch 3/5
102/102 [==============================] - 30s 289ms/step - loss: 0.0371 - accuracy: 0.9864 - val_loss: 0.6850 - val_accuracy: 0.8703
Epoch 4/5
102/102 [==============================] - 28s 273ms/step - loss: 0.0144 - accuracy: 0.9957 - val_loss: 0.7199 - val_accuracy: 0.8703
Epoch 5/5
102/102 [==============================] - 29s 282ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.6943 - val_accuracy: 0.8749
model.save("models/flower_model.h5")

自主测试

model = tf.keras.models.load_model("models/flower_model.h5")
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 flatten (Flatten)           (None, 62720)             0         
                                                                 
 dense (Dense)               (None, 128)               8028288   
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 10,294,853
Trainable params: 8,036,869
Non-trainable params: 2,257,984
_________________________________________________________________
def select_file(b):
    clear_output()
    root = Tk()
    root.withdraw()
    root.call('wm', 'attributes', '.', '-topmost', True)
    b.files = filedialog.askopenfilename(multiple=True)
    print(b.files)


fileselect = Button(description="选择文件")
fileselect.on_click(select_file)
display(fileselect)
Button(description='选择文件', style=ButtonStyle())
len(fileselect.files)
25
plt.figure(figsize=(20, 20))
for i in range(25):
    plt.subplot(5, 5, i+1)
    img = img_deal(fileselect.files[i])
    plt.imshow(img)
    plt.xlabel(i+1)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

# 图片进行打包
from tensorflow.keras.applications.densenet import preprocess_input
test_img = []
for i in range(25):
    img = img_deal(fileselect.files[i])
    test_img.append(img)
test_img = np.asarray(test_img)
test_pre_image = preprocess_input(test_img)
test_pre_image.shape
(25, 224, 224, 3)
decoder_dict = dict(zip(train_generator.class_indices.values(),
                    train_generator.class_indices.keys()))
decoder_dict
{0: 'daisy', 1: 'dandelion', 2: 'rose', 3: 'sunflower', 4: 'tulip'}
predictions = model.predict(test_pre_image)
for prediction in predictions:
    print(decoder_dict[prediction.argmax()], end=" ")
sunflower tulip tulip rose rose rose rose tulip daisy sunflower dandelion rose daisy dandelion dandelion rose tulip tulip tulip daisy daisy sunflower dandelion dandelion rose 

整体输出可视化测试

font = {
    "size": "22",
    "color": "red"
}
plt.figure(figsize=(20, 20))
for i in range(25):
    plt.subplot(5, 5, i+1)
    img = img_deal(fileselect.files[i])
    plt.imshow(img)
    img = preprocess_input(img)
    img = np.expand_dims(img, 0)
    result = model.predict(img)
    label = decoder_dict[result.argmax()]
    plt.xlabel(label, font)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/33707.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

wampServer安装Redis 扩展

第一步:查看php版本信息 使用 phpinfo() 函数查看 PHP 的版本信息(用于选择扩展包) 版本信息:PHP版本为 8.0.26,编译器版本 Visual C 2019,CPU架构 x64 。 第二步:根据第一步信息的版本选择扩…

基于树莓派4B的YOLOv5-Lite目标检测的移植与部署(含训练教程)

前言:本文为手把手教学树莓派4B项目——YOLOv5-Lite目标检测,本次项目采用树莓派4B(Cortex-A72)作为核心 CPU 进行部署。该篇博客算是深度学习理论的初步实战,选择的网络模型为 YOLOv5 模型的变种 YOLOv5-Lite 模型。Y…

【AI底层逻辑】——篇章3(上):数据、信息与知识香农信息论信息熵

目录 引入 一、数据、信息、知识 二、“用信息丈量世界” 1、香农信息三定律 2、一条信息的价值 3、信息的熵 总结 引入 AI是一种处理信息的模型,我们把信息当作一种内容的载体,计算机发明以前很少有人思考它的本质是什么。随着通信技术的发展&a…

【ISO26262】汽车功能安全第3部分:概念阶段

GB/T34590《道路车辆 功能安全》分为以下部分: 需要文档的朋友,可以和我联系! tommi_wei@163.com GB/T34590的本部分规定了车辆在概念阶段的要求: ———相关项定义; ———安全生命周期启动; ———危害分析和风险评估;及 ———功能安全概念。 危害事件分类 对于每一个…

wsl子系统Ubuntu18.04,cuDNN安装

如果觉得本篇文章对您的学习起到帮助作用,请 点赞 关注 评论 ,留下您的足迹💪💪💪 本文主要wls子系统Ubuntu18.04安装cuDNN,安装cudnn坑巨多,因此记录以备日后查看,同时&#xff0…

GaussDB WDR报告分析

标题 问题描述问题现象告警业务影响原因分析处理方法步骤 1步骤 2步骤 3步骤 4步骤 6步骤 7步骤 8步骤9步骤 10步骤 11步骤 12 问题描述 CPU使用率高。 问题现象 出现CPU使用率超过阈值,CPU使用率快速上涨或短时间持续较高水平等现象。 告警 CPU使用率告警。 …

uniapp的表单校验方式整理

uniapp的表单校验方式整理 这里我使用的模板为: 第一种: uniapp本身自带表单校验的js文件,代码写的很简洁,也是比较全面的 只要按照规则校验即可,下面是对应的校验代码: /** 数据验证(表…

PyQt中数据库的访问(一)

访问数据库的第一步是确保ODBC数据源配置成功,我接下来会写数据源配置的文章,请继续关注本栏! (一)数据库连接 self.DBQSqlDatabase.addDatabase("QODBC") self.DB.setDatabaseName("Driver{sqlServer…

ModaHub AI模型开源社区——向量数据库Milvus存储操作教程

目录 存储操作 数据插入 数据落盘 定时触发 客户端触发 缓冲区达到上限触发 数据合并 建立索引 删除 删除集合 删除分区 删除实体 数据段整理 数据读取 常见问题 存储操作 阅读本文前,请先阅读 存储相关概念。 数据插入 客户端通过调用 insert 接…

【计算机视觉】DINO

paper:Emerging Properties in Self-Supervised Vision Transformers 源码:https://github.com/facebookresearch/dino 20230627周二目前只把第一部分看完了。 论文导读:DINO -自监督视觉Transformers - deephub的文章 - 知乎 综述类型&a…

线程不安全举例

1、举例说明集合类线程不安全 &#xff08;1&#xff09;查看源码可证明 看ArrayList源码 没有sync、lock&#xff0c;线程不安全 &#xff08;2&#xff09;创建多个线程写入读取数据 List<String> list new ArrayList<>(); for (int i 1; i <30 ; i) {n…

【Android】Android虚拟机

虚拟机 Android的虚拟机主要有两种&#xff1a;Dalvik 虚拟机和 ART&#xff08;Android Runtime&#xff09;虚拟机。 Dalvik 虚拟机 Dalvik 虚拟机是 Android 早期使用的虚拟机&#xff0c;它基于寄存器架构。从Android 2.2版本开始&#xff0c;支持JIT即时编译&#xff08…

基于多站点集中汇聚需求的远程调用直播视频汇聚平台解决方案

一、行业背景 随着视频汇聚需求的不断提升&#xff0c;智慧校园、智慧园区等项目中需要将各分支机构的视频统一汇聚到总部&#xff0c;进行统一管控&#xff0c;要满足在监控内部局域网、互联网、VPN网络等TCP/IP环境下&#xff0c;为用户提供低成本、高扩展、强兼容、高性能的…

【SpringBoot】基于SSM框架的题库系统的设计与实现

文章结构 课题&#xff1a;一、项目简介主要功能技术选型 二、 模块介绍学生端教师端(一)考试管理(二)试题管理(三)学生成绩管理 管理员三、 B站项目演示地址 四、本项目其余相关博客 课题&#xff1a; 题库系统的设计与实现一、项目简介 简介&#xff1a;主要分为三个端&…

DAY38——动态规划

步骤&#xff1a; 确定dp数组&#xff08;dp table&#xff09;以及下标的含义确定递推公式dp数组如何初始化确定遍历顺序举例推导dp数组 题目一. 斐波那契数列 1. 确定dp数组以及下标的含义 dp[i]的定义为&#xff1a;第i个数的斐波那契数值是dp[i] 2. 确定递推公式 状态…

【Zookeeper】win安装随笔

目录 下载地址下载目标解压后目录结构配置文件配置文件详情伪分布式安装LinuxZooKeeper audit is disabled启动解决报错&#xff1a;SLF4J: Class path contains multiple SLF4J bindings. _ 下载地址 https://zookeeper.apache.org/releases.html 下载目标 记住选择带bin的…

一步一步学OAK之四:实现如何在低延迟下使用高分辨率视频

目录 Setup 1: 创建文件Setup 2: 安装依赖Setup 3: 导入需要的包Setup 4: 创建pipelineSetup 5: 创建节点Setup 6: 设置节点的属性和参数。Setup 7: 建立链接关系Setup 8: 连接设备并启动管道Setup 9: 创建与DepthAI设备通信的输入队列和输出队列Setup 10: 主循环获取视频帧显示…

【C++】定制删除器和特殊类设计(饿汉和懒汉~)

文章目录 定制删除器一、设计一个只能在堆上(或栈上)创建的类二、单例模式 1.饿汉模式2.懒汉模式总结 定制删除器 我们在上一篇文章中讲到了智能指针&#xff0c;相信大家都会有一个问题&#xff0c;智能指针该如何辨别我们的资源是用new int开辟的还是new int[]开辟的呢&…

html5前端学习2

一篇思维题题解&#xff1a; 第五周任务 [Cloned] - Virtual Judge (vjudge.net) http://t.csdn.cn/SIHdM 快捷键&#xff1a; CtrlAltDown 向下选取 CtrlAltUp 向上选取&#xff08;会出现多个光标&#xff0c;可以同时输入&#xff09; CtrlEnter …

【Java】Java核心 78:Git 教程(1)Git 概述

文章目录 01.GIT概述目标内容小结 02.GIT相关概念目标内容小结 01.GIT概述 Git是一个分布式版本控制系统&#xff0c;常用于协同开发和版本管理的工具。它可以跟踪文件的修改、记录历史版本&#xff0c;并支持多人协同工作。通过Git&#xff0c;你可以轻松地创建和切换分支、合…