【AI实践】基于TensorFlow/Keras的CNN(卷积神经网络)简单实现:手写数字识别的工程实践

深度神经网络系列文章

  • 【AI深度学习网络】卷积神经网络(CNN)入门指南:从生物启发的原理到现代架构演进
  • 【AI实践】基于TensorFlow/Keras的CNN(卷积神经网络)简单实现:手写数字识别的工程实践

引言

在深度学习的广阔天地中,卷积神经网络(CNN)是计算机视觉领域的经典模型,卷积神经网络(CNN)凭借其强大的特征提取能力,成为了图像识别领域的中流砥柱。今天,就带大家深入剖析一个基于TensorFlow/Keras实现的简单CNN模型,看看它是如何在手写数字识别任务(MNIST数据集)中大显身手的。

本文以MNIST手写数字识别任务为例,演示如何通过TensorFlow/Keras工程化实现一个轻量级CNN,代码包含完整的数据处理、模型训练与推理流程,并特别注重实际开发中的可维护性设计。

一、环境与工具

好的工程始于好的框架。在这个项目中,我们使用Python作为编程语言,借助TensorFlow/Keras库来构建CNN模型。

环境与工具

  • Python 3.8+
  • TensorFlow 2.10+
  • Matplotlib(可视化支持)

为了使整个工程结构清晰、易于维护,我们将代码划分为多个功能模块。

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import builtins
import datetime

以上是项目的基本信息和依赖库的导入。

二、数据预处理

数据是模型的粮食,高质量的数据预处理是模型成功的关键。MNIST数据集是一个经典的手写数字数据集,包含了60000张训练图像和10000张测试图像,每张图像的大小为28x28像素。

def main():
    """
    主函数
    :return:
    """
    # 1. 加载并预处理数据
    print('加载并预处理数据')
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    # 归一化并调整形状(添加通道维度)
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255

    # 转换标签为one-hot编码
    train_labels = tf.keras.utils.to_categorical(train_labels)
    test_labels = tf.keras.utils.to_categorical(test_labels)

在这里,我们首先加载了MNIST数据集,并对图像数据进行了归一化处理,将像素值从0-255的范围缩放到0-1之间。这样做的目的是为了加速模型的收敛。同时,我们还调整了图像的形状,添加了一个通道维度,以满足CNN模型的输入要求。对于标签数据,我们将其转换为one-hot编码格式,以便于模型的分类任务。

MNIST加载的数据集的train_images为60000张像素大小为28x28,内容如下:
MNIST数据集

其中to_categorical 用于将整数类别标签转换为 one-hot 编码,而one-hot编码是一种方便计算机处理的二元编码,适用于多分类任务中标签的格式化处理。

  • 输入:一维整数数组(如 [0, 2, 1, 2])。
  • 输出:二维矩阵,每一行对应一个样本的 one-hot 向量(如 [[1,0,0], [0,0,1], [0,1,0], [0,0,1]])。
  • 入参y:待转换的整数标签数组。
  • 入参num_classes(可选):总类别数。若不指定,自动根据标签最大值推断(max(y) + 1)。

三、模型构建

接下来,我们开始构建CNN模型。这个模型由几个基本的层组成:卷积层、池化层、展平层和全连接层。

    # 2. 构建CNN模型
    print('构建CNN模型')
    model = models.Sequential([
        # 卷积层:32个3x3滤波器,激活函数ReLU
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        # 最大池化层:2x2窗口
        layers.MaxPooling2D((2, 2)),
        # 展平层:将3D特征转换为1D向量
        layers.Flatten(),
        # 全连接层:128个神经元
        layers.Dense(128, activation='relu'),
        # 输出层:10个类别(数字0-9)
        layers.Dense(10, activation='softmax')
    ])

卷积层使用了32个3x3的滤波器,激活函数采用了ReLU(Rectified Linear Unit,修正线性单元,f(x)=max(0,x)),它能够有效地解决梯度消失问题,加速模型的训练。最大池化层使用2x2的窗口,对特征图进行下采样,减少参数数量,提高模型的计算效率。展平层将三维的特征图转换为一维向量,以便于全连接层的处理。全连接层包含了128个神经元,最后的输出层有10个神经元,对应着10个数字类别,激活函数使用了softmax,用于多分类任务。

四、模型编译

在模型构建完成后,我们需要对其进行编译,指定优化器、损失函数以及评估指标。

    # 3. 编译模型
    print('编译模型')
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

这里我们选择了Adam优化器,它是一种自适应学习率的优化算法,能够根据模型的训练情况自动调整学习率。损失函数选择了 categorical_crossentropy,它适用于多分类问题。评估指标我们选择了准确率(accuracy),用于衡量模型的分类性能。

五、模型训练

现在,我们开始对模型进行训练。在训练过程中,我们指定了训练数据、标签、训练轮数(epochs)、批量大小(batch_size)以及验证集的比例。

    # 4. 训练模型
    print('训练模型...')
    history = model.fit(
        train_images, train_labels,
        epochs=2,
        batch_size=64,
        validation_split=0.2
    )

在这里,我们设置了训练轮数为2,批量大小为64,验证集的比例为0.2,即从训练数据中划分出20%的数据作为验证集,用于在训练过程中评估模型的性能,防止过拟合。

参数选择依据

  • epochs=2:MNIST数据简单,2轮即可快速验证流程正确性
  • batch_size=64:在GPU显存允许范围内最大化批次提升训练速度

六、模型评估

训练完成后,我们需要对模型在测试集上的性能进行评估。

    # 5. 评估模型
    print('评估模型...')
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f'\n测试准确率: {test_acc:.4f}')

通过 model.evaluate 方法,我们可以得到测试集上的损失值和准确率。这能够让我们直观地了解模型在未见过的数据上的表现。

七、模型预测与可视化

最后,我们使用模型对测试集中的第一个样本进行预测,并将预测结果与真实标签进行比较,同时绘制图像。

    # 取测试集第一个样本
    test_image = test_images[0]
    true_label = test_labels[0].argmax()
    prediction = model.predict(test_image.reshape(1, 28, 28, 1)).argmax()
    plot_prediction(test_image, true_label, prediction)

def plot_prediction(image, true_label, prediction):
    plt.figure()
    plt.imshow(image.squeeze(), cmap='gray')
    # 设置字体,支持中文显示
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.title(f'真实: {true_label}, 预测: {prediction}')
    plt.axis('off')
    plt.show()

通过这个过程,我们可以直观地看到模型的预测结果是否正确,同时也能对模型的性能有一个更直观的感受。

八、完整代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @ProjectName: Ai
# @Name: 20250305-CNN.py
# @Auth: arbboter
# @Date: 2025/3/5-9:44
# @Desc: 使用Python和TensorFlow/Keras实现的简单卷积神经网络(CNN),用于手写数字识别(MNIST数据集),代码包含训练、评估和预测示例。
# @Ver : 0.0.0.1
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import builtins
import datetime

def main():
    """
    主函数
    :return:
    """
    # 1. 加载并预处理数据
    print('加载并预处理数据')
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    # 归一化并调整形状(添加通道维度)
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255

    # 转换标签为one-hot编码
    train_labels = tf.keras.utils.to_categorical(train_labels)
    test_labels = tf.keras.utils.to_categorical(test_labels)

    # 2. 构建CNN模型
    print('构建CNN模型')
    model = models.Sequential([
        # 卷积层:32个3x3滤波器,激活函数ReLU
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        # 最大池化层:2x2窗口
        layers.MaxPooling2D((2, 2)),
        # 展平层:将3D特征转换为1D向量
        layers.Flatten(),
        # 全连接层:128个神经元
        layers.Dense(128, activation='relu'),
        # 输出层:10个类别(数字0-9)
        layers.Dense(10, activation='softmax')
    ])

    # 3. 编译模型
    print('编译模型')
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # 4. 训练模型
    print('训练模型...')
    history = model.fit(
        train_images, train_labels,
        epochs=2,
        batch_size=64,
        validation_split=0.2
    )

    # 5. 评估模型
    print('评估模型...')
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f'\n测试准确率: {test_acc:.4f}')

    # 取测试集第一个样本
    test_image = test_images[0]
    true_label = test_labels[0].argmax()
    prediction = model.predict(test_image.reshape(1, 28, 28, 1)).argmax()
    plot_prediction(test_image, true_label, prediction)

def plot_prediction(image, true_label, prediction):
    plt.figure()
    plt.imshow(image.squeeze(), cmap='gray')
    # 设置字体,支持中文显示
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.title(f'真实: {true_label}, 预测: {prediction}')
    plt.axis('off')
    plt.show()

def hook_print():
    def my_print(*args, **kwargs):
        old_print('[', datetime.datetime.now(), end="] ")
        old_print(*args, **kwargs)

    old_print = builtins.print
    builtins.print = my_print

if __name__ == '__main__':
    hook_print()
    main()

九、运作结果

结果输出
注意:首次运行程序会自动下载训练和测试数据集,比较费时间。

十、工程实践中的注意事项

在实际的工程实践中,我们需要注意以下几个方面:

  1. 数据预处理:数据的质量直接影响模型的性能。除了归一化和one-hot编码外,还可以考虑对数据进行增强,如旋转、平移、缩放等操作,以增加模型的泛化能力。

  2. 模型结构:根据实际任务的需求,合理设计模型的结构。可以尝试增加卷积层的数量、调整滤波器的大小和数量,以及改变全连接层的神经元数量,以提高模型的性能。

  3. 模型训练:选择合适的优化器、学习率和训练轮数。可以使用早停(early stopping)、学习率衰减等技巧,防止过拟合,提高模型的泛化能力。

  4. 模型评估:除了准确率外,还可以考虑使用其他评估指标,如精确率(precision)、召回率(recall)、F1值等,从多个角度评估模型的性能。

  5. 模型部署:在模型训练完成后,需要将其部署到实际的应用场景中。可以使用TensorFlow Serving等工具,将模型封装为API,供其他应用程序调用。

总之,基于TensorFlow/Keras实现的简单CNN模型,为我们提供了一种高效、便捷的手写数字识别解决方案。在实际的工程实践中,我们需要根据具体的需求和数据特点,灵活调整模型的结构和训练策略,以实现最佳的性能。

结语

本实现仅用35行核心代码完成端到端的CNN训练与验证,准确率达98%+。通过模块化设计、日志增强和可视化组件,展现了工业级代码的雏形。读者可在此基础上扩展更复杂的网络结构或部署功能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/981835.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习threejs,使用LineBasicMaterial基础线材质

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.LineBasicMaterial1.…

【0010】Python流程控制结构-分支结构详解

如果你觉得我的文章写的不错,请关注我哟,请点赞、评论,收藏此文章,谢谢! 本文内容体系结构如下: 分支结构是编程中的基本控制结构之一,它允许程序根据条件判断执行不同的代码路径。通过本文&…

python网络爬虫开发实战之基本库使用

目录 第二章 基本库的使用 2.1 urllib的使用 1 发送请求 2 处理异常 3 解析链接 4 分析Robots协议 2.2 requests的使用 1 准备工作 2 实例引入 3 GET请求 4 POST请求 5 响应 6 高级用法 2.3 正则表达式 1 实例引入 2 match 3 search 4 findall 5 sub 6 com…

pytest框架 核心知识的系统复习

1. pytest 介绍 是什么:Python 最流行的单元测试框架之一,支持复杂的功能测试和插件扩展。 优点: 语法简洁(用 assert 替代 self.assertEqual)。 自动发现测试用例。 丰富的插件生态(如失败重试、并发执…

Java 大视界 -- Java 大数据在智慧交通信号灯智能控制中的应用(116)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

Electron桌面应用开发:自定义菜单

完成初始应用的创建Electron桌面应用开发:创建应用,随后我们就可以自定义软件的菜单了。菜单可以帮助用户快速找到和执行命令,而不需要记住复杂的快捷键,通过将相关功能组织在一起,用户可以更容易地发现和使用应用程序…

探索低空,旅游景区无人机应用技术详解

在低空领域,无人机技术在旅游景区中的应用已经日益广泛,为旅游业带来了前所未有的变革。以下是对旅游景区无人机应用技术的详细解析: 一、无人机景区巡检系统 1. 高清拍摄与实时监控:无人机搭载高清摄像头,能够对景区…

Python-07PDF转Word

2025-03-04-PDF转Word DeepSeek等大模型从来都不是简单的写一个静态博客这么肤浅(太多博主都只讲这个内容了)借助全网大神的奇思妙想,拓展我狭隘的思维边界。 文章目录 2025-03-04-PDF转Word [toc]1-参考网址2-学习要点3-核心逻辑4-核心代码 …

【c语言函数精选题】

c语言函数精选题 一、易错概念题1.1💡建立函数的目的1.2💡函数的定义1.3💡return语句1.4💡函数的参数1.5💡复合语句声明变量 二、代码填空题2.1💡四舍五入2.2💡二分法求方程根2.3💡输…

储油自动化革命,网关PROFINET与MODBUS网桥的无缝融合,锦上添花

储油行业作为能源供应链的关键环节,其自动化和监控系统的可靠性和效率至关重要。随着工业4.0的推进,储油设施越来越多地采用先进的自动化技术以提高安全性、降低成本并优化运营。本案例探讨了如何通过使用稳联技术PROFINET转MODBUS模块网关网桥&#xff…

不同类型光谱相机的技术差异比较

一、波段数量与连续性 ‌多光谱相机‌ 波段数:通常4-9个离散波段,光谱范围集中于400-1000nm‌。 数据特征:光谱呈阶梯状,无法连续覆盖,适用于中等精度需求场景(如植被分类)‌。 ‌高光谱相机…

Linux纯命令行界面下SVN的简单使用教程

诸神缄默不语-个人技术博文与视频目录 我用的VSCode插件是这个: 可以在文件中用色块显示代码修改了什么地方,点击色块还可以显示修改内容。 文章目录 1. SVN安装2. checkout3. update1. 将文件加入版本控制 4. commit5. 查看SVN信息:info6.…

STM32单片机芯片与内部114 DSP-变换运算 实数 复数 FFT IFFT 不限制点数

目录 一、ST 官方汇编 FFT 库(64点, 256 点和 1024 点) 1、cr4_fft_xxx_stm32 2、计算幅频响应 3、计算相频响应 二、复数浮点 FFT、IFFT(支持单精度和双精度) 1、基础支持 2、单精度函数 arm_cfft_f32 3、双精…

在IDEA中进行git回滚操作:Reset current branch to here‌或Reset HEAD

问题描述 1)在本地修改好的代码,commit到本地仓库,突然发觉有问题不想push推到远程仓库了,但它一直在push的列表中存在,那该怎么去掉push列表中的内容呢? 2)合并别的分支到当前分支&#xff0…

【五.LangChain技术与应用】【14.LangChain与MoonShot、通义千问:多模型融合的实战】

兄弟们,今天咱们来唠点硬核的——当国产大模型双雄(MoonShot和通义千问)碰上LangChain这个万能胶水,会擦出什么火花?这可不是简单的API调用教程,而是实打实的多模型组合拳打法,保准看完你也能搞出个企业级AI系统!(全程大白话,放心食用) 一、为什么非得搞多模型? 先…

33.C++二叉树进阶1(二叉搜索树两种模型及其应用)

⭐上篇文章:32.C二叉树进阶1(二叉搜索树)-CSDN博客 ⭐本篇代码:c学习/18.二叉树进阶-二叉搜索树 橘子真甜/c-learning-of-yzc - 码云 - 开源中国 (gitee.com) ⭐标⭐是比较重要的部分 在上篇文章中,实现了一个简单的二…

CSS—属性继承与预处理器:2分钟掌握预处理器

个人博客:haichenyi.com。感谢关注 1. 目录 1–目录2–属性继承3–预处理器 2. 属性继承 像Android里面继承extends,类继承,子类可以使用父类的public和protected的属性和方法。子类可以直接用。   在CSS里面也是类似的。CSS里面是布局里面…

Ansys Zemax | 使用衍射光学器件模拟增强现实 (AR) 系统的出瞳扩展器 (EPE):第 4 部分

附件下载 联系工作人员获取附件 在 OpticStudio 中使用 RCWA 工具为增强现实(AR)系统设置出瞳扩展器(EPE)的示例中,首先解释了k空间中光栅的规划,并详细讨论了设置每个光栅的步骤。 介绍 本文是该四篇文…

【数据结构】堆和priority_queue

堆的定义 堆是什么?实际上堆是一种特殊的(受限制的)完全二叉树,它在完全二叉树的基础上要求每一个节点都要大于等于或者小于等于它的子树的所有节点。这个大于小于体现在节点的值或者权重。 如图所示: 根节点大于等于…

大语言模型学习--本地部署DeepSeek

本地部署一个DeepSeek大语言模型 研究学习一下。 本地快速部署大模型的一个工具 先根据操作系统版本下载Ollama客户端 1.Ollama安装 ollama是一个开源的大型语言模型(LLM)本地化部署与管理工具,旨在简化在本地计算机上运行和管理大语言模型…