机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0

# 构建神经网络模型
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=5)

# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)

# 保存模型
model.save('mnist_model.h5')

# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')

# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image

# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')

# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像

# 调整图片大小为 28x28 像素
image = image.resize((28, 28))

# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0

# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)

# 使用模型进行预测
predictions = loaded_model.predict(input_image)

# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)

print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/380049.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

空气质量预测 | Matlab实现基于SVR支持向量机回归的空气质量预测模型

文章目录 效果一览文章概述源码设计参考资料效果一览 文章概述 政府机构使用空气质量指数 (AQI) 向公众传达当前空气污染程度或预测空气污染程度。 随着 AQI 的上升,公共卫生风险也会增加。 不同国家有自己的空气质量指数,对应不同国家的空气质量标准。 基于支持向量机(Su…

真正免费的文件恢复软件easyrecovery2024中文版

easyrecovery数据恢复软件是一款广受好评的数据恢复工具,它能够有效地帮助用户恢复各种类型的文件。无论是照片、视频、音乐还是文档,都能轻松地找回这些重要文件。操作安全、用户可自主操作的数据恢复方案,它支持从各种各样的存储介质恢复删…

备战蓝桥杯---动态规划之背包问题引入

先看一个背包问题的简单版: 如果我们暴力枚举可能会超时。 但我们想一想,我们其实不关心怎么放,我们关心的是放后剩下的体积。 用可行性描述即可。 于是我们令f[i][j]表示前i个物品能否放满体积为j的背包。 f[i][j]f[i-1][j]||f[i-1][j-v…

Mysql为什么使用B+Tree作为索引结构

B树和B树 一般来说,数据库的存储引擎都是采用B树或者B树来实现索引的存储。首先来看B树,如图所示: B树是一种多路平衡树,用这种存储结构来存储大量数据,它的整个高度会相比二叉树来说,会矮很多。 而对于数…

【Spring源码解读!底层原理高级进阶】【上】探寻Spring内部:BeanFactory和ApplicationContext实现原理揭秘✨

🎉🎉欢迎光临🎉🎉 🏅我是苏泽,一位对技术充满热情的探索者和分享者。🚀🚀 🌟特别推荐给大家我的最新专栏《Spring 狂野之旅:底层原理高级进阶》 &#x1f680…

C#上位机与三菱PLC的通信05--MC协议之QnA-3E报文解析

1、MC协议回顾 MC是公开协议 ,所有报文格式都是有标准 ,MC协议可以在串口通信,也可以在以太网通信 串口:1C、2C、3C、4C 网口:4E、3E、1E A-1E是三菱PLC通信协议中最早的一种,它是一种基于二进制通信协…

windows10安装配置nvm以达到切换nodejs的目的

前言 各种各样的项目,各种node环境,还有node_modules这个庞然大物。。想想都觉得恐怖。 所以现在有了:nvm-切换node环境,pnpm–解决重复下载同样类库的问题。 下面将就如何在win10下配置进行说明 nvm下载配置 nvm的github下载地…

Flink从入门到实践(三):数据实时采集 - Flink MySQL CDC

文章目录 系列文章索引一、概述1、版本匹配2、导包 二、编码实现1、基本使用2、更多配置3、自定义序列化器4、Flink SQL方式 三、踩坑1、The MySQL server has a timezone offset (0 seconds ahead of UTC) which does not match the configured timezone Asia/Shanghai. 参考资…

kmeans聚类选择最优K值python实现

Kmeans算法中K值的确定是很重要的。 下面利用python中sklearn模块进行数据聚类的K值选择 数据集自制数据集,格式如下: 维度为3。 ①手肘法 手肘法的核心指标是SSE(sum of the squared errors,误差平方和), 其中,Ci是第…

微调LLM或使用RAG,开发RAG管道的12个痛点

论文地址:archive.is/bNbZo Pain Point 1: Missing Content 内容缺失 Pain Point 2: Missed the Top Ranked Documents 错过排名靠前的文档 Pain Point 3: Not in Context — Consolidation Strategy Limitations 不在上下文中 — 整合战略的局限性 Pain Point …

免费生成ios证书的方法(无需mac电脑)

使用hbuilderx的uniapp框架开发移动端程序很方便,可以很方便地开发出移动端的小程序和app。但是打包ios版本的app的时候却很麻烦,官方提供的教程需要使用mac电脑来生成证书,但是mac电脑却不便宜,一般的型号都差不多上万。 因此&a…

【Java面试】数据类型常见面试题

什么是包装类型 将基本类型包装进了对象中得到的类型 基本类型和包装类型有什么区别 用途不同:基本类型一般用于局部变量,包装类型用于其他地方存储方式不同:用于局部变量的基本类型存在虚拟机栈中的局部变量表中,用于成员变量…

火星符号运算 - 华为OD统一考试

OD统一考试(C卷) 分值: 100分 题解: Java / Python / C 题目描述 已知火星人使用的运算符号为 #和$ 其与地球人的等价公式如下 x#y2*x3*y4 x$y3*xy2x y是无符号整数。地球人公式按照c语言规则进行计算。火星人公式中&#xff0…

基于鲲鹏服务器的LNMP配置

基于鲲鹏服务器的LNMP配置 系统 Centos8 # cat /etc/redhat-release CentOS Linux release 8.0.1905 (Core) 卸载已经存在的旧版本的安装包 # rpm -qa | grep php #查看已经安装的PHP旧版本# rpm -qa | grep php | xargs rpm -e #卸载已经安装的旧版,如果提示有…

113.路径总和 II

给你二叉树的根节点 root 和一个整数目标和 targetSum ,找出所有 从根节点到叶子节点 路径总和等于给定目标和的路径。 叶子节点 是指没有子节点的节点。 示例 1: 输入:root [5,4,8,11,null,13,4,7,2,null,null,5,1], targetSum 22 输出&a…

RocketMQ(二):领域模型(生产者、消费者)

1 生产者(Producer) 本节介绍Apache RocketMQ 中生产者的定义、模型关系、内部属性、版本兼容和使用建议。 1.1 定义 生产者是Apache RocketMQ 系统中用来构建并传输消息到服务端的运行实体。 生产者通常被集成在业务系统中,将业务消息按照要…

513. 找树左下角的值 - 力扣(LeetCode)

题目描述 给定一个二叉树的 根节点 root,请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 题目示例 输入: root [2,1,3] 输出: 1 解题思路 深度优先搜索 使用 depth 记录遍历到的节点的深度,result 记录深度在 depth 的最…

C++ 动态规划 记忆化搜索 滑雪

给定一个 R 行 C 列的矩阵,表示一个矩形网格滑雪场。 矩阵中第 i 行第 j 列的点表示滑雪场的第 i 行第 j 列区域的高度。 一个人从滑雪场中的某个区域内出发,每次可以向上下左右任意一个方向滑动一个单位距离。 当然,一个人能够滑动到某相…

C++:二叉搜索树模拟实现(KV模型)

C:二叉搜索树模拟实现(KV模型) 前言模拟实现KV模型1. 节点封装2、前置工作(默认构造、拷贝构造、赋值重载、析构函数等)2. 数据插入(递归和非递归版本)3、数据删除(递归和非递归版本…

【芯片设计- RTL 数字逻辑设计入门 15 -- 函数实现数据大小端转换】

文章目录 函数实现数据大小端转换函数语法函数使用的规则Verilog and Testbench综合图VCS 仿真波形 函数实现数据大小端转换 在数字芯片设计中,经常把实现特定功能的模块编写成函数,在需要的时候再在主模块中调用,以提高代码的复用性和提高设…