使用TensorFlow实现简化版 GoogLeNet 模型进行 MNIST 图像分类

        在本文中,我们将使用 TensorFlow 和 Keras 实现一个简化版的 GoogLeNet 模型来进行 MNIST 数据集的手写数字分类任务。GoogLeNet 采用了 Inception 模块,这使得它在处理图像数据时能更高效地提取特征。本教程将详细介绍如何在 MNIST 数据集上训练和测试这个模型。

项目结构

        我们的代码将分为两个部分:

  1. 训练部分 (train.py): 包含模型定义、数据加载、模型训练等。
  2. 测试部分 (test.py): 用于加载训练好的模型,并在测试集上评估其性能。

训练部分:train.py

1. 数据加载与预处理

        首先,我们需要加载 MNIST 数据集并进行预处理。预处理包括调整图像形状、归一化以及 One-Hot 编码标签。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def load_and_preprocess_data():
    # 加载 MNIST 数据集
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # 数据预处理:将图像形状调整为 [28, 28, 1],并归一化到 [0, 1] 范围
    train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)) / 255.0
    test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0

    # One-Hot 编码标签
    train_labels = to_categorical(train_labels, 10)
    test_labels = to_categorical(test_labels, 10)

    return train_images, train_labels, test_images, test_labels

2. 创建简化版 GoogLeNet 模型

        接下来,我们定义一个简化版的 GoogLeNet 模型。该模型包括卷积层、Inception 模块和全连接层。

from tensorflow.keras import layers, models

def googlenet(input_shape=(28, 28, 1), num_classes=10):
    inputs = layers.Input(shape=input_shape)

    # 第一卷积层 + 池化层
    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)

    # 第二卷积层 + 池化层
    x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    # 第三卷积层 + 池化层
    x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    # Inception 模块
    inception1 = layers.Conv2D(64, (1, 1), activation='relu', padding='same')(x)
    inception2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    inception3 = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(x)

    # 拼接 Inception 模块的输出
    x = layers.concatenate([inception1, inception2, inception3], axis=-1)

    # 全局平均池化层
    x = layers.GlobalAveragePooling2D()(x)

    # 全连接层
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.5)(x)  # Dropout 层减少过拟合
    outputs = layers.Dense(num_classes, activation='softmax')(x)  # 输出层,使用 softmax 激活函数进行多分类

    model = models.Model(inputs=inputs, outputs=outputs)
    return model

3. 模型训练

        定义好模型之后,我们使用 Adam 优化器和交叉熵损失函数来训练模型,并保存训练好的模型。

def train_model(model, train_images, train_labels, epochs=5, batch_size=64):
    # 训练模型
    history = model.fit(train_images, train_labels,
                        epochs=epochs,
                        batch_size=batch_size)

    return history

def save_model(model, filename='googlenet_mnist.h5'):
    model.save(filename)
    print(f"Model saved to {filename}")

4. 主程序

        最后,在主程序中,我们加载数据、创建模型并开始训练。

def main():
    train_images, train_labels, test_images, test_labels = load_and_preprocess_data()

    model = googlenet(input_shape=(28, 28, 1), num_classes=10)

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    train_model(model, train_images, train_labels, epochs=5, batch_size=64)

    save_model(model)

if __name__ == '__main__':
    main()


测试部分:test.py

1. 加载训练好的模型

        在测试部分,我们将加载训练好的模型,并在测试集上进行评估。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def load_and_preprocess_data():
    (_, _), (test_images, test_labels) = mnist.load_data()
    test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0
    test_labels = to_categorical(test_labels, 10)
    return test_images, test_labels

def load_model(model_path='googlenet_mnist.h5'):
    model = tf.keras.models.load_model(model_path)
    return model

2. 评估模型

        我们通过 evaluate 方法评估模型的损失和准确率。

def evaluate_model(model, test_images, test_labels):
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc * 100:.2f}%")
    return test_loss, test_acc

3. 显示预测结果

        使用 Matplotlib 可视化前几张图片的预测结果。

import matplotlib.pyplot as plt

def display_predictions(model, test_images, test_labels, num_images=6):
    predictions = model.predict(test_images[:num_images])

    fig, axes = plt.subplots(2, 3, figsize=(10, 6))
    axes = axes.flatten()

    for i in range(num_images):
        ax = axes[i]
        ax.imshow(test_images[i].reshape(28, 28), cmap='gray')
        ax.set_title(f"Pred: {tf.argmax(predictions[i]).numpy()} \n True: {tf.argmax(test_labels[i]).numpy()}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

4. 主程序

        在主程序中,我们加载模型,评估其性能,并显示预测结果。

def main():
    test_images, test_labels = load_and_preprocess_data()
    model = load_model('googlenet_mnist.h5')

    evaluate_model(model, test_images, test_labels)
    display_predictions(model, test_images, test_labels)

if __name__ == '__main__':
    main()


总结

        本文介绍了如何使用 TensorFlow 实现简化版 GoogLeNet,并在 MNIST 数据集上进行训练和测试。我们将代码分为训练和测试两部分,分别处理数据预处理、模型训练与评估、结果展示等工作。

        通过使用 GoogLeNet 进行图像分类,我们不仅能够提高分类性能,还能了解 Inception 模块在图像处理中的强大能力。希望这篇博客能够帮助你更好地理解深度学习模型的训练与测试过程。

完整项目:GoogLeNet-TensorFlow: 使用TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/goog-le-net-tensor-flow

qxd-ljy/GoogLeNet-TensorFlow: 使用 TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-TensorFlow 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/918649.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2.5D视觉——Aruco码定位检测

目录 1.什么是Aruco标记2.Aruco码解码说明2.1 Original ArUco2.2 预设的二维码字典2.3 大小Aruco二维码叠加 3.函数说明3.1 cv::aruco::detectMarkers3.2 cv::solvePnP 4.代码注解4.1 Landmark图说明4.2 算法源码注解 1.什么是Aruco标记 ArUco标记最初由S.Garrido-Jurado等人在…

java 根据 pdf 模板带图片文字生成pdf文件

在现代应用开发中,自动生成包含动态内容的 PDF 文档在电子发票、合同生成、表单填充等场景中有着广泛的应用。本文将介绍如何使用 iText 库动态填充 PDF 模板字段,并在指定位置插入签名和公章图片。 项目需求 假设我们有一个 PDF 模板文件,包含表单字段,如用户姓名、地址…

计算机网络-MSTP基础实验一(单域多实例)

前面我们已经大致了解了MSTP的基本概念和工作原理,但是我自己也觉得MSTP的理论很复杂不结合实验是很难搞懂的,今天来做一个配套的小实验以及一些配置命令。 一、网络拓扑 单域多实例拓扑 基本需求:SW1为VLAN10的网关,SW2为VLAN20的…

进程相关知识

#include <sys/types.h> #include <unistd.h> pid_t fork(void); 函数的作用&#xff1a;用于创建子进程。 返回值&#xff1a; fork() 的返回值会返回两次。一次是在父进程中&#xff0c;一次是在子进程中。 在父进程中返回创建的子进程的 ID, 在子进程中…

Python中的TCP

文章目录 一. 计算机网络1. 网络的概念2. IP地址① IP地址的概念② IP地址的表现形式③ IP地址的作用④ 网络查询命令Ⅰ. ifconfig/ipconfigⅡ. ping 3. 端口和端口号的概念(计算机通信原理)① 端口的概念② 端口号的概念 4. socket套接字① socket概念② socket使用场景 二. T…

本地部署Apache Answer搭建高效的知识型社区并一键发布到公网流程

文章目录 前言1. 本地安装Docker2. 本地部署Apache Answer2.1 设置语言选择简体中文2.2 配置数据库2.3 创建配置文件2.4 填写基本信息 3. 如何使用Apache Answer3.1 后台管理3.2 提问与回答3.3 查看主页回答情况 4. 公网远程访问本地 Apache Answer4.1 内网穿透工具安装4.2 创建…

【数据结构】线性表——栈与队列

写在前面 栈和队列的关系与链表和顺序表的关系差不多&#xff0c;不存在谁替代谁&#xff0c;只有双剑合璧才能破敌万千~~&#x1f60e;&#x1f60e; 文章目录 写在前面一、栈1.1栈的概念及结构1.2、栈的实现1.2.1、栈的结构体定义1.2.2、栈的初始化栈1.2.3、入栈1.2.4、出栈…

科技改变工作方式:群晖NAS安装内网穿透实现个性化办公office文档分享(1)

文章目录 前言1. 本地环境配置2. 制作本地分享链接3. 制作公网访问链接4. 公网ip地址访问您的分享相册5. 制作固定公网访问链接 前言 本文将详细介绍如何在群晖NAS上安装Synology Office和Synology Drive Server&#xff0c;并利用Cpolar内网穿透工具为本地文档配置固定的公网…

无插件H5播放器EasyPlayer.js网页web无插件播放器选择全屏时,视频区域并没有全屏问题的解决方案

EasyPlayer.js H5播放器&#xff0c;是一款能够同时支持HTTP、HTTP-FLV、HLS&#xff08;m3u8&#xff09;、WS、WEBRTC、FMP4视频直播与视频点播等多种协议&#xff0c;支持H.264、H.265、AAC、G711A、MP3等多种音视频编码格式&#xff0c;支持MSE、WASM、WebCodec等多种解码方…

rocketmq5源码系列--(一)--搭建调试环境

说在前头&#xff1a;阿里的rocketmq的文档是真他妈的烂的1b&#xff0c;很多东西都不说&#xff0c;全靠自己看源码&#xff0c;摸索&#xff0c;草&#xff0c;真的要吐血了 rocketmq的版本5而不是版本4&#xff0c;版本5比版本4多了个proxy rocketmq5 三个组件&#xff1a;…

【网页设计】CSS3 进阶(动画篇)

1. CSS3 2D 转换 转换&#xff08;transform&#xff09;是CSS3中具有颠覆性的特征之一&#xff0c;可以实现元素的位移、旋转、缩放等效果 转换&#xff08;transform&#xff09;你可以简单理解为变形 移动&#xff1a;translate旋转&#xff1a;rotate缩放&#xf…

django安装与项目创建

一、安装 在终端输入 pip install django //或者(&#xff09;指定安装版本 pip install django2.2 二、创建项目 2.1创建项目 django-admin startproject 项目名 2.2Django 项目中的关键文件 _init_.py:将目录标识为python包setting.py:核心配置文件&#xff0c;定义项目…

【redis】—— 初识redis(redis基本特征、应用场景、以及重大版本说明)

序言 本文将引导读者探索Redis的世界&#xff0c;深入了解其发展历程、丰富特性、常见应用场景、使用技巧等&#xff0c;最后会对Redis演进过程中具有里程碑意义的版本进行详细解读。 目录 &#xff08;一&#xff09;初始redis &#xff08;二&#xff09;redis特性 &#…

SpringBoot学习记录(三)之多表查询

SpringBoot学习记录&#xff08;三&#xff09;之多表查询 一、多表查询概述1、数据准备2、介绍3、分类 二、内连接三、外连接四、子查询1、标量子查询2、列子查询3、行子查询4、表子查询 三、案例1、准备环境2、需求实现3、&#xff08;附&#xff09;数据准备 一、多表查询概…

泰矽微重磅发布超高集成度车规触控芯片TCAE10

市场背景 智能按键和智能表面作为汽车智能化的重要部分&#xff0c;目前正处于快速发展阶段&#xff0c;电容式触摸按键凭借其操作便利性与小体积的优势&#xff0c;在汽车内饰表面的应用越来越广泛。对于空调控制面板、档位控制器、座椅扶手、门饰板、车顶控制器等多路开关的…

HarmonyOs学习笔记-布局单位

鸿蒙开发中布局存在很多单位 鸿蒙的默认单位是vp 下方先展示一下在RrkTsUI中我们应该怎么书写&#xff0c;然后讲一下各大单位具体的含义。 Text("这是一个文本, 用默认单位进行展示&#xff0c;也就是vp") .width(100) .height(100);//此段代码与上方代码是一样的…

操作系统实验 C++实现生产者-消费者问题

实验目的 1、进一步加深理解进程同步的概念 2、加深对进程通信的理解 3、了解Linux下共享内存的使用方法 实验内容 1、按照下面要求&#xff0c;写两个c程序&#xff0c;分别是生产者producer.c以及customer.c 2、一组生产者和一组消费者进程共享一块环形缓冲区 使用共…

Easyexcel(1-注解使用)

文章链接&#xff1a; Easyexcel&#xff08;1-注解使用&#xff09; 版本依赖 <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.3.3</version> </dependency>ExcelProperty 指定…

最新版xAI LLM 模型Grok-2 上线

xAI&#xff01;Grok-2 最新版开启公测&#xff01;”。这是我注册成功的截图&#xff0c;使用国内的邮箱就可以注册使用了&#xff01; Grok API公测与免费体验: Grok API开启公测&#xff0c;提供免费体验128k上下文支持&#xff0c;。Grok-Beta与马斯克: 马斯克庆祝特朗普当…

css数据不固定情况下,循环加不同背景颜色

<template><div><p v-for"(item, index) in items" :key"index" :class"getBackgroundClass(index)">{{ item }}</p></div> </template><script> export default {data() {return {items: [学不会1, …