TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras

"""
加载数据集工具类
"""


class DatasetLoader:
    def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):
        self.path_url = path_url
        self.image_size = image_size
        self.batch_size = batch_size
        self.class_mode = class_mode

    # 不使用图像增强
    def load_data(self):
        # 加载训练数据集
        train_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/train',  # 训练数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签
        )

        # 加载验证数据集
        val_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/validation',  # 验证数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签
        )
        # 加载测试数据集
        test_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/test',  # 验证数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签
        )
        class_names = train_data.class_names
        return train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layers

from utils.dataset_loader import DatasetLoader

"""
使用MobileNetV2,实现图像多分类
"""

# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'

#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16


# 训练模型
def train():
    # 实例化数据集加载工具类
    dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)
    train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()

    # 构建 MobileNet 模型
    base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)
    # 将模型的主干参数进行冻结
    base_model.trainable = False
    model = keras.Sequential([
        layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
        # 设置主干模型
        base_model,
        # 对主干模型的输出进行全局平均池化
        layers.GlobalAveragePooling2D(),
        # 通过全连接层映射到最后的分类数目上
        layers.Dense(len(class_total), activation='softmax')
    ])
    # 编译模型
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # 模型结构
    model.summary()
    # 指明训练的轮数epoch,开始训练
    model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 测试
    loss, accuracy = model.evaluate(test_ds)
    # 输出结果
    print('Mobilenet test accuracy :', accuracy, ',loss :', loss)
    # 保存模型 savedModel格式
    model.export(filepath=SAVED_MODEL_DIR)


if __name__ == '__main__':
    train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():
    # 实例化数据集加载工具类
    dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)
    train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()

    inputs = keras.Input(shape=IMG_SHAPE)
    # 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接
    base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)

    # 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新
    for layer in base_model.layers:
        layer.trainable = False
    # 在 MobileNetV2 之后添加自定义的顶层分类器
    x = layers.GlobalAveragePooling2D()(base_model.output)
    predictions = layers.Dense(len(class_total), activation='softmax')(x)
    # 构建最终模型
    model = keras.Model(inputs=base_model.input, outputs=predictions)
    # 编译模型
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # 查看模型结构
    model.summary()
    model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 测试
    loss, accuracy = model.evaluate(test_ds)
    # 输出结果
    print('Mobilenet test accuracy :', accuracy, ',loss :', loss)
    # 保存模型 savedModel格式
    model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型
    history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 保存曲线图
    Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatch

from utils.dataset_loader import DatasetLoader

"""
模型工具类
"""


class ModelUtil:
    def __init__(self, saved_model_dir, path_url):
        self.save_model_dir = saved_model_dir  # savedModel 模型保存地址
        self.path_url = path_url  # 模型训练数据地址

    # 批量识别 进行可视化显示
    def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
        dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
        train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()
        # 加载savedModel模型
        tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)
        # 创建一个新的 Keras 模型,包含 TFSMLayer
        model = keras.Sequential([
            keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状
            tfs_layer
        ])

        plt.figure(figsize=(10, 10))
        for images, labels in test_ds.take(1):
            # 使用模型进行预测
            outputs = model.predict(images)
            for i in range(num_images):
                plt.subplot(5, 5, i + 1)
                image = np.array(images[i]).astype("uint8")
                plt.imshow(image)
                index = int(np.argmax(outputs[i]))
                prediction = outputs[i][index]
                percentage_str = "{:.2f}%".format(prediction * 100)
                plt.title(f"{class_names[index]}: {percentage_str}")
                plt.axis("off")
        plt.subplots_adjust(hspace=0.5, wspace=0.5)
        plt.show()

使用工具类:

if __name__ == '__main__':
    # train()
    model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)
    model_util.batch_evaluation()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/797742.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

H5的Canvas如何画N叉树数据结构

大家好。我是猿码叔叔,一位有着 5 年Java工作经验的北漂,业余时间喜欢瞎捣鼓,学习一些新东西来丰富自己。看过上一篇 Java 方法调用关系的老铁们,也许遗留了不少疑问,这Java方法调用关系可视化页面就这?这方…

护网HW面试——redis利用方式即复现

参考:https://xz.aliyun.com/t/13071 面试中经常会问到ssrf的打法,讲到ssrf那么就会讲到配合打内网的redis,本篇就介绍redis的打法。 未授权 原理: Redis默认情况下,会绑定在0.0.0.0:6379,如果没有采用相关…

基于SpringBoot的校园志愿者管理系统

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:SpringBoot框架 工具:MyEclipse、Tomcat 系统展示 首页 个人中心 志愿者管理 活动信息…

黑马头条微服务学习day01-环境搭建、SpringCloud微服务(注册发现、网关)

文章目录 项目介绍环境搭建项目背景业务功能技术栈说明 nacos服务器环境准备nacos安装 初始工程搭建环境准备主体结构 app登录需求分析表结构分析手动加密微服务搭建接口定义功能实现登录功能实现 Swagger使用app端网关nginx配置 项目介绍 环境搭建 项目背景 业务功能 技术栈说…

数据结构(Java):树二叉树

目录 1、树型结构 1.1 树的概念 1.2 如何判断树与非树 1.3 树的相关概念 1.4 树的表示形式 1.4.1 孩子兄弟表示法 2、二叉树 2.1 二叉树的概念 2.2 特殊的二叉树 2.3 二叉树的性质 2.4 二叉树的存储 2.5 二叉树的遍历 1、树型结构 1.1 树的概念 树型结构是一种非线…

MySQL复合查询(重点)

前面我们讲解的mysql表的查询都是对一张表进行查询,在实际开发中这远远不够。 基本查询回顾 查询工资高于500或岗位为MANAGER的雇员,同时还要满足他们的姓名首字母为大写的J mysql> select * from emp where (sal>500 or jobMANAGER) and ename l…

数据湖仓一体(一) 编译hudi

目录 一、大数据组件版本信息 二、数据湖仓架构 三、数据湖仓组件部署规划 四、编译hudi 一、大数据组件版本信息 hudi-0.14.1zookeeper-3.5.7seatunnel-2.3.4kafka_2.12-3.5.2hadoop-3.3.5mysql-5.7.28apache-hive-3.1.3spark-3.3.1flink-1.17.2apache-dolphinscheduler-3.1.9…

[Vulnhub] Sedna BuilderEngine-CMS+Kernel权限提升

信息收集 IP AddressOpening Ports192.168.8.104TCP:22, 53, 80, 110, 111, 139, 143, 445, 993, 995, 8080, 55679 $ nmap -p- 192.168.8.104 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 6.6.1p1 Ubuntu 2ubuntu2 …

C++20中的consteval说明符

在C20中,立即函数(immediate function)是指每次调用该函数都会直接或间接产生编译时常量表达式(constant expression)的函数。这些函数在其返回类型前使用consteval关键字进行声明。 立即函数是constexpr函数,具体情况取决于其要求。与constexpr相同&…

半小时获得一张ESG入门证书【详细中英文笔记一】

前些日子,有朋友转发了一则小红书的笔记给我, 标题是《半小时获CFI官方高颜值免费证书 ESG认证》。这对考证狂魔的我来说,必然不能错过啊,有免费的羊毛不薅白不薅。 ESG课程的 CFI 官方网址戳这里:CFI 于是信心满满的…

清华大学孙富春教授团队开发了多模态数字孪生环境,辅助机器人获得复杂的 3C 装配技能

中国是全球3C产品(电脑、通信和消费电子)的主要生产国,全球70%的3C产品产能集中在中国。3C智能制造装备的突破与产业化,对于提升我国制造产业的全球竞争力意义重大。 机器人在计算机、通信和消费电子 (3C) …

常用的设计模式和使用案例汇总

常用的设计模式和使用案例汇总 【一】常用的设计模式介绍【1】设计模式分类【2】软件设计七大原则(OOP原则) 【二】单例模式【1】介绍【2】饿汉式单例【3】懒汉式单例【4】静态内部类单例【5】枚举(懒汉式) 【三】工厂方法模式【1】简单工厂模式&#xf…

springboot 程序运行一段时间后收不到redis订阅的消息

springboot 程序运行一段时间后收不到redis订阅的消息 问题描述 程序启动后redis.user.two主题正常是可以收到消息的,发一条收一条,但是隔一段时间后;就收不到消息了; 此时如果你手动调用发送另外一个消息订阅redis.user.two2&…

vmware workstation 虚拟机安装

vmware workstation 虚拟机安装 VMware Workstation Pro是VMware(威睿公司)发布的一代虚拟机软件,中文名称一般称 为"VMware 工作站".它的主要功能是可以给用户在单一的桌面上同时运行不同的操作系统,它也是可进 行开发…

c# 容器变换

List<Tuple<int, double, bool>> 变为List<Tuple<int, bool>>集合 如果您有一个List<Tuple<int, double, bool>>并且您想要将其转换为一个List<Tuple<int, bool>>集合&#xff0c;忽略double值&#xff0c;您可以使用LINQ的S…

3U 与 SV630A 伺服实现 CANLINK 通讯

1、打开 AUTOSHOP&#xff0c;点击工具>系统选项&#xff0c;勾选自动生成 canlink 轴 控通讯配置和 canlink 轴控指令增强功能。 2、检查 plc 的拨码是否已经拨上去。 1 代表 485 通讯&#xff0c;2 代表 can 通讯&#xff0c;将 2 打到 ON 状态。还有9&#xff0c;10拨…

Matlab 计算一个平面与一条直线的交点

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 这里使用一种很有趣的坐标:Plucker线坐标,它的定义如下所示: 这个坐标有个很有趣的性质,将直线 L L L与由其齐次坐标 V = (

IDEA社区版使用Maven archetype 创建Spring boot 项目

1.新建new project 2.选择Maven Archetype 3.命名name 4.选择存储地址 5.选择jdk版本 6.Archetype使用webapp 7.create创建项目 创建好长这样。 检查一下自己的Maven是否是自己的。 没问题的话就开始增添java包。 [有的人连resources包也没有&#xff0c;那就需要自己添…

AI人工智能开源大模型生态体系分析

人工智能开源大模型生态体系研究 "人工智能开源大模型生态体系研究报告v1.0"揭示&#xff0c;AI(A)的飞速发展依赖于三大核心&#xff1a;数据、算法和算力。这一理念已得到业界广泛认同&#xff0c;三者兼备才能推动AI的壮大发展。随着AI大模型的扩大与普及&#xf…

el-table 动态添加删除 -- 鼠标移入移出显隐删除图标

<el-table class"list-box" :data"replaceDataList" border><el-table-column label"原始值" prop"original" align"center" ><template slot-scope"scope"><div mouseenter"showClick…