python3+TensorFlow 2.x(三)手写数字识别

目录

代码实现

模型解析:

1、加载 MNIST 数据集:

2、数据预处理:

3、构建神经网络模型:

4、编译模型:

5、训练模型:

6、评估模型:

7、预测和可视化结果:

输出结果:

总结:


代码实现

TensorFlow 2.x 实现手写数字识别(MNIST 数据集)。MNIST 数据集包含了 28x28 像素的手写数字图像,任务是将这些图像分类为 10 个类别(0-9) 

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

# 1. 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# 2. 数据预处理:归一化和改变形状
train_images = train_images / 255.0  # 将图像像素值归一化到 [0, 1]
test_images = test_images / 255.0

# 调整形状,使得每张图片的维度是 [28, 28, 1],因为模型需要3D输入
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))

# 3. 构建神经网络模型
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')  # 10类分类问题
])

# 4. 编译模型:选择优化器、损失函数和评价指标
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # 因为标签是整数,所以使用 sparse_categorical_crossentropy
              metrics=['accuracy'])

# 5. 训练模型
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))

# 6. 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")

# 7. 可视化训练过程中的损失和准确率变化
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 8. 使用模型进行预测
predictions = model.predict(test_images)

# 显示一些预测结果
for i in range(5):
    plt.imshow(test_images[i].reshape(28, 28), cmap='gray')
    plt.title(f"Predicted Label: {predictions[i].argmax()}, Actual Label: {test_labels[i]}")
    plt.show()

模型解析:

1、加载 MNIST 数据集:

使用 tf.keras.datasets.mnist.load_data() 函数来加载 MNIST 数据集。返回的数据包括训练集和测试集。训练集有 60,000 张图像,测试集有 10,000 张图像。

2、数据预处理:

将图像的像素值从 [0, 255] 归一化到 [0, 1],使每个像素的值在 0 到 1 之间,提升模型的训练效果。将每张图像的形状调整为 (28, 28, 1),即每个图像是 28x28 的灰度图像。

3、构建神经网络模型:

使用卷积神经网络(CNN)构建模型:Conv2D 层用于提取图像的特征,使用了 ReLU 激活函数。MaxPooling2D 层用于下采样,减少计算量。Flatten 层将卷积层的输出展平,进入全连接层。Dense 层用于输出分类结果,其中最后一层使用了 softmax 激活函数,将模型的输出转换为 10 类的概率分布。

4、编译模型:

使用 adam 优化器,sparse_categorical_crossentropy 作为损失函数(适用于类别标签是整数的情况),并使用 accuracy 作为评价指标。

5、训练模型:

使用 model.fit 训练模型,设置了 5 个 epoch,使用训练集进行训练,并验证模型在测试集上的表现。

6、评估模型:

使用 model.evaluate 在测试集上评估模型的准确性。并可视化训练过程中的损失和准确率变化:使用 matplotlib 绘制训练过程中的损失和准确率变化曲线,查看模型的学习进度。

7、预测和可视化结果

使用训练好的模型对测试集进行预测,展示一些预测结果,并与真实标签进行对比。

输出结果

训练和验证准确率:随着训练的进行,准确率应该逐渐提高。
测试准确率:训练完成后,模型在测试集上的准确率会显示出来,通常可以达到 98% 以上。
预测图像:展示一些手写数字图像,标注预测的标签和实际标签。

预测可视化展示

总结:

该模型使用了卷积层、池化层以及全连接层,在 MNIST 数据集上训练,最终达到了很好的分类效果。你可以调整模型的超参数(例如卷积层的数量、神经元的数量等)以提高性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/959614.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LKT4304新一代算法移植加密芯片,守护 物联网设备和云服务安全

凌科芯安作为一家在加密芯片领域深耕18年的企业,主推的LKT4304系列加密芯片集成了身份认证、算法下载、数据保护和完整性校验等多方面安全防护功能,可以为客户的产品提供一站式解决方案,并且在调试和使用过程提供全程技术支持,针对…

js/ts数值计算精度丢失问题及解决方案

文章目录 概念及问题问题分析解决方案方案一方案二方案其它——用成熟的库 概念及问题 js中处理浮点数运算时会出现精度丢失。js中整数和浮点数都属于Number数据类型,所有的数字都是以64位浮点数形式存储,整数也是如此。所以打印x.00这样的浮点数的结果…

SSM框架探秘:Spring 整合 SpringMVC 框架

搭建和测试 SpringMVC 的开发环境&#xff1a; web.xml 元素顺序&#xff1a; 在 web.xml 中配置 DisPatcherServlet 前端控制器&#xff1a; <!-- 配置前端控制器 --> <servlet><servlet-name>dispatcherServlet</servlet-name><servlet-class>…

算法11(力扣496)-下一个更大元素I

1、问题 nums1 中数字 x 的 下一个更大元素 是指 x 在 nums2 中对应位置 右侧 的 第一个 比 x 大的元素。 给你两个 没有重复元素 的数组 nums1 和 nums2 &#xff0c;下标从 0 开始计数&#xff0c;其中nums1 是 nums2 的子集。 对于每个 0 < i < nums1.length &#xf…

2024年博客之星主题创作|2024年蓝桥杯与数学建模年度总结与心得

引言 2024年&#xff0c;我在蓝桥杯编程竞赛和数学建模竞赛中投入了大量时间和精力&#xff0c;这两项活动不仅加深了我对算法、数据结构、数学建模方法的理解&#xff0c;还提升了我的解决实际问题的能力。从蓝桥杯的算法挑战到数学建模的复杂应用&#xff0c;我在这些竞赛中…

Spring FatJar写文件到RCE分析

背景 现在生产环境部署 spring boot 项目一般都是将其打包成一个 FatJar&#xff0c;即把所有依赖的第三方 jar 也打包进自身的 app.jar 中&#xff0c;最后以 java -jar app.jar 形式来运行整个项目。 运行时项目的 classpath 包括 app.jar 中的 BOOT-INF/classes 目录和 BO…

初阶数据结构:链表(二)

目录 一、前言 二、带头双向循环链表 1.带头双向循环链表的结构 &#xff08;1)什么是带头&#xff1f; (2)什么是双向呢&#xff1f; &#xff08;3&#xff09;那什么是循环呢&#xff1f; 2.带头双向循环链表的实现 &#xff08;1&#xff09;节点结构 &#xff08;2…

Java Web-Request与Response

在 Java Web 开发中&#xff0c;Request 和 Response 是两个非常重要的对象&#xff0c;用于在客户端和服务器之间进行请求和响应的处理&#xff0c;以下是详细介绍&#xff1a; Request&#xff08;请求对象&#xff09; Request继承体系 在 Java Web 开发中&#xff0c;通…

mysql 学习2 MYSQL数据模型,mysql内部可以创建多个数据库,一个数据库中有多个表;表是真正放数据的地方,关系型数据库 。

在第一章中安装 &#xff0c;启动mysql80 服务后&#xff0c;连接上了mysql&#xff0c;那么就要 使用 SQL语句来 操作mysql数据库了。那么在学习 SQL语言操作 mysql 数据库 之前&#xff0c;要对于 mysql数据模型有一个了解。 MYSQL数据模型 在下图中 客户端 将 SQL语言&…

微信小程序date picker的一些说明

微信小程序的picker是一个功能强大的组件&#xff0c;它可以是一个普通选择器&#xff0c;也可以是多项选择器&#xff0c;也可以是时间、日期、省市区选择器。 官方文档在这里 这里讲一下date picker的用法。 <view class"section"><view class"se…

【学习笔记】计算机网络(二)

第2章 物理层 文章目录 第2章 物理层2.1物理层的基本概念2.2 数据通信的基础知识2.2.1 数据通信系统的模型2.2.2 有关信道的几个基本概念2.2.3 信道的极限容量 2.3物理层下面的传输媒体2.3.1 导引型传输媒体2.3.2 非导引型传输媒体 2.4 信道复用技术2.4.1 频分复用、时分复用和…

总结8..

#include <stdio.h> // 定义结构体表示二叉树节点&#xff0c;包含左右子节点编号 struct node { int l; int r; } tree[100000]; // 全局变量记录二叉树最大深度&#xff0c;初始为0 int ans 0; // 深度优先搜索函数 // pos: 当前节点在数组中的位置&#xff0c…

多智能体中的理论与传统智能体理论有何异同?

多智能体系统与传统单智能体理论在多个方面存在异同&#xff0c;多智能体系统在理论上扩展了单智能体系统的研究范畴&#xff0c;强调智能体之间的交互和协作。随着人工智能、人机智能、人机环境系统智能的发展&#xff0c;多智能体系统在机器人群体、分布式计算、资源管理等领…

RKNN_C++版本-YOLOV5

1.背景 为了实现低延时&#xff0c;所以开始看看C版本的rknn的使用&#xff0c;确实有不足的地方&#xff0c;请指正&#xff08;代码借鉴了rk官方的仓库文件&#xff09;。 2.基本的操作流程 1.读取模型初始化 // 设置基本信息 // 在postprocess.h文件中定义&#xff0c;详见…

FlinkSql使用中rank/dense_rank函数报错空指针

问题描述 在flink1.16(甚至以前的版本)中&#xff0c;使用rank()或者dense_rank()进行排序时&#xff0c;某些场景会导致报错空指针NPE(NullPointerError) 报错内容如下 该报错没有行号/错误位置&#xff0c;无法排查 现状 目前已经确认为bug&#xff0c;根据github上的PR日…

csapp2.4节——浮点数

目录 二进制小数 十进制小数转二进制小数 IEEE浮点表示 规格化表示 非规格化表示 特殊值 舍入 浮点运算 二进制小数 类比十进制中的小数&#xff0c;可定义出二进制小数 例如1010.0101 小数点后的权重从-1开始递减。 十进制小数转二进制小数 整数部分使用辗转相除…

【2024年华为OD机试】 (A卷,200分)- 开放日活动、取出尽量少的球(JavaScriptJava PythonC/C++)

一、问题描述 题目描述 某部门开展Family Day开放日活动,其中有个从桶里取球的游戏,游戏规则如下: 有N个容量一样的小桶等距排开。每个小桶默认装了数量不等的小球,记录在数组 bucketBallNums 中。游戏开始时,要求所有桶的小球总数不能超过 SUM。如果小球总数超过 SUM,…

《安富莱嵌入式周报》第349期:VSCode正式支持Matlab调试,DIY录音室级麦克风,开源流体吊坠,物联网在军工领域的应用,Unicode字符压缩解压

周报汇总地址&#xff1a;嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版&#xff1a; 《安富莱嵌入式周报》第349期&#xff1a;VSCode正式支持Matlab调试&#xff0c;DIY录音室级麦克风…

nacos(基于docker最详细安装)

1、什么是Spring Cloud Spring Cloud是一系列框架的集合。它利用Spring Boot的开发便利性巧妙地简化了分布式系统基础设施的开发&#xff0c;如服务发现注册、配置中心、消息总线、负载均衡、断路器、数据监控等&#xff0c;都可以用Spring Boot的开发风格做到一键启动和部署。…

李沐vscode配置+github管理+FFmpeg视频搬运+百度API添加翻译字幕

终端输入nvidia-smi查看cuda版本 我的是12.5&#xff0c;在网上没有找到12.5的torch&#xff0c;就安装12.1的。torch&#xff0c;torchvision&#xff0c;torchaudio版本以及python版本要对应 参考&#xff1a;https://blog.csdn.net/FengHanI/article/details/135116114 创…