keras深度学习框架通过简单神经网络实现手写数字识别

   背景

     keras深度学习框架,并不是一个独立的深度学习框架,它后台依赖tensorflow或者theano。大部分开发者应该使用的是tensorflow。keras可以很方便的像搭积木一样根据模型搭出我们需要的神经网络,然后进行编译,训练,测试,预测。

    今天介绍的手写数字识别实验,主要是熟悉keras搭建神经网络的流程,以及大体的思路。现如今,手写数字识别实验的代码各种各样,对于初学者而言,我们需要的是类似helloworld那样简单的示例。通过示例,我们可以了解神经网络的搭建过程。

    这里使用的手写数字识别,通过搭建网络,构建模型,最后保存模型,然后我们加载模型,通过真实的图片来预测,也检验一下神经网络的能力。

     这里手写数字识别数据来源于官方自带mnist数据集,这个数据集包含60000个训练集和10000个测试集。每个数据是由28 * 28 = 784个矩阵元素组成。所以我们自己用来测试的图片最后应该也要按照这个28*28的尺寸来制作,并且最后进行预测predict的时候,也要像训练集或者测试集一样,把图片转为一个784元素的数组。

    准备代码

import keras
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation
from tensorflow.keras import datasets, utils
import matplotlib.pyplot as plt


(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape((-1, 28*28))
x_train = x_train.astype('float32')/255
x_test = x_test.reshape((-1, 28*28))
x_test = x_test.astype('float32')/255

y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)

print('x_train.shape', x_train.shape)
print('x_test.shape', x_test.shape)
print('y_train.shape', y_train.shape)
print('y_test.shape', y_test.shape)
"""
layer = [Dense(32, input_shape=(784,)),
         Activation('relu'),
         Dense(10),
         Activation('softmax')]

model = Sequential(layer)
"""
model = Sequential()
# model.add(Dense(units=784, activation="relu", input_dim=784))
model.add(Dense(512, activation="relu", input_shape=(28*28, )))
model.add(Dense(10, activation="softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

history = model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label="Training accuracy")
plt.plot(epochs, val_acc, 'b', label="Validation accuracy")
plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
model.save("mnist.h5")
prediction = model.predict(x_test[:1], batch_size=32)
print(x_test[:1])
print(y_test[:1])
print(prediction)
print(np.argmax(prediction, axis=1))

    这个代码在引入了相关库之后,进行的第一件事就是数据处理:

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape((-1, 28*28))
x_train = x_train.astype('float32')/255
x_test = x_test.reshape((-1, 28*28))
x_test = x_test.astype('float32')/255
y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)

print('x_train.shape', x_train.shape)
print('x_test.shape', x_test.shape)
print('y_train.shape', y_train.shape)
print('y_test.shape', y_test.shape)

     我们的数据集x_train,x_test就是我们的图片数据,这个数据是784个元素组成的数组,我们先进行转矩阵,然后对像素点取模,得到0-1之间的值。我们代码最后打印了x_test[:1],可以看看它的样子:

    这里我们还使用了utils.to_categorical(y_test,num_classes=10) 对我们的目标进行了one-hot转码。通过这个图我们也看到了,数字 7 转了one-hot编码之后,变为了[0,0,0,0,0,0,0,1,0,0]。

    这个代码构建了一个简单的神经网络,也就两层,

     第一层输入层 Dense(512,activation="relu",input_shape=(28*28, ))  #512个节点,relu激活函数,输入形状或者维度 28*28=784。代码中也给出了另一种通过input_dim来指定维度的方法,意思是一样的,但是那种写法model.add(Dense(units=784, activation="relu", input_dim=784))指定的网络节点units=784。这个数字可以随便定义。手写数字识别里面,设置512,784都可以。

    第二层输出层 Dense(10, activation="softmax") #这里指定对应十个分类,也就是数字0,1,2,3,4,5,6,7,8,9的个数。手写数字识别是一个多分类问题。

     没有隐藏层,也没有其他的Dropout。就是简单神经网络。

     另外,代码中还给出了一种构建神经网络的办法:

layer = [Dense(32, input_shape=(784,)),
         Activation('relu'),
         Dense(10),
         Activation('softmax')]

model = Sequential(layer)

    意思是一样的,只不过,这里units=32,也就是输入层由32个神经网络节点组成。 

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

    这是编译神经网络和打印神经网络概要。

    编译神经网络传入loss="categorical_cressentropy" 表示损失函数求的是交叉熵。optimizer="adam",表示优化器是adam,表示自适应算法,另外,也有可能会看到sgd,随机梯度下降算法,或者rmsprop也是一种自适应算法。metrics=["accuracy"]统计指标,这里指定成功率。 

   通过model.summary()我们可以看到神经网络节点信息: 

   

history = model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test))

    这里是把训练和测试神经网络放在一起了,我们传入的validation_data指定了测试数据集。如果不指定validation_data,那么后面,我们通过model.evaluate(x_test,y_test) 也可以得到loss,acc等数据。

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label="Training accuracy")
plt.plot(epochs, val_acc, 'b', label="Validation accuracy")
plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

    我们通过matplot来展示acc,val_acc等信息,结果如下图所示:

       我们还通过model.save("mnist.h5")保存模型,后面我们会加载这个模型来进行预测。

prediction = model.predict(x_test[:1], batch_size=32)
print(x_test[:1])
print(y_test[:1])
print(prediction)
print(np.argmax(prediction, axis=1))

    我们简单通过测试集的第一个数字7来进行了一个验证,这个验证,主要是要知道我们将来传入图片需要什么类型的数据,以及得到预测结果之后,怎么取值。这里prediction是一个按照概率来进行组装的数组,哪个概率大,最终的结果就是谁。我们通过np.argmax(prediction, axis=1)指定获取一个数组中按行(axis=1)来统计最大的那个数。

***************************************************************

    预测

     很多代码示例里面,基本上到了model.evaluate()对算法进行评估之后,就没有了,对于刚入门的人来说,神经网络创建了,测试了,好不好用也不知道。因为这个训练集和测试机都是官网给出的例子,对于程序员来说,通过实践来验证一个猜测,那才是最重要的,至于这是什么不重要。

     上面的代码最后,我们通过测试集x_test[:1]也就是第一个测试数字简单做了一个预测,大概知道了要预测,需要的数据是一个[28*28=784]的数组。而我们准备的测试图片应该也要和官方给出的测试数据对应上,也即是前面提到的图片是28*28像素的数字图片,如下所示:

    同样的给出代码:

import keras
import numpy as np
import cv2
from keras.models import load_model

model = load_model("mnist.h5")


def predict(img_path):
    img = cv2.imread(img_path, 0)
    img = img.reshape(28, 28).astype("float32") / 255  # 0 1
    img = img.reshape(-1, 784)  # 28 * 28 -> 784
    label = model.predict(img)
    label = np.argmax(label, axis=1)
    print('{} -> {}'.format(img_path, label[0]))


if __name__ == '__main__':
    for _ in range(10):
        predict("number_images/b_{}.png".format(_))

    这些图片我们放在number_images目录下,命名规则是b_0.png,b_1.png这样子。

    最后,我们加载模型,并通过opencv库加载图片,并转换图片矩阵为784个元素的数组。然后交给模型预测,预测结果是一个概率数组,取概率最大的那个数组元素。

     预测结果如下:

    结果很感人,并没有达到很高的概率,准确率60%,而且这个概率对于手写图片识别来说,还有点偏高,因为实际上很多数字图片识别错误。 

    这篇文章,主要就keras构建简单神经网络,并进行训练,测试,最后还通过我们自己手写的数字图片来进行预测验证,也过了一把深度学习的瘾。

    本文keras和tensorflow版本是2.8.0,可能有几个api与其他地方有区别,比如datasets,这里使用的是tensorflow.keras.datasets。另外在计算成功率acc的时候,使用的是history['accuracy'],有的地方可能直接是history['acc'],应该是版本的问题,根据自己的版本找到合适的方法就行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/97352.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于RabbitMQ的模拟消息队列之四——内存管理

文章目录 一、设计数据结构二、管理集合1.交换机2.队列3.绑定4.消息5.队列上的消息6.待确认消息7.恢复数据 一、设计数据结构 针对交换机、队列、绑定、消息、待确认消息设计数据结构。 交换机集合 exchangeMap 数据结构:ConcurrentHashMap key:交换机name value:交…

rpm打包

文章目录 rpm打包 1. rpm打包步骤0)安装打包工具rpm-build和rpmdevtools1)创建初始化目录2)准备打包内容3)编写打包脚本 spec文件 rpm打包 1. rpm打包步骤 0)安装打包工具rpm-build和rpmdevtools yum install rpm-bu…

C++信息学奥赛1176:谁考了第k名

#include <iostream> #include <string> using namespace std; int main() {int n, a;cin >> n >> a; // 输入整数 n 和 aint arr[n]; // 创建大小为 n 的整型数组 arrdouble btt[n]; // 创建大小为 n 的双精度浮点型数组 bttfor (int i 0; i < n;…

03-基础例程3

基础例程3 01、外部中断 ESP32的外部中断有上升沿、下降沿、低电平、高电平触发模式。 实验目的 使用外部中断功能实现按键控制LED的亮灭 按键按下为0。【即下降沿】 * 接线说明&#xff1a;按键模块-->ESP32 IO* (K1-K4)-->(14,27,26,25)* * …

2023年天府杯——C 题:码头停靠问题

问题背景&#xff1a; 某个港口有多个不同类型的码头&#xff0c;可以停靠不同种类的船只。每 艘船只需要一定的时间来完成装卸货物等任务&#xff0c;并且每个码头有容量 限制和停靠时间限制。港口需要在保证收益的情况下&#xff0c;尽可能地提高 运营效率和降低成本。同…

【中危】Spring Kafka 反序列化漏洞 (CVE-2023-34040)

zhi.oscs1024.com​​​​​ 漏洞类型反序列化发现时间2023-08-24漏洞等级中危MPS编号MPS-fed8-ocuvCVE编号CVE-2023-34040漏洞影响广度小 漏洞危害 OSCS 描述Spring Kafka 是 Spring Framework 生态系统中的一个模块&#xff0c;用于简化在 Spring 应用程序中集成 Apache Kaf…

新型安卓恶意软件使用Protobuf协议窃取用户数据

近日有研究人员发现&#xff0c;MMRat新型安卓银行恶意软件利用protobuf 数据序列化这种罕见的通信方法入侵设备窃取数据。 趋势科技最早是在2023年6月底首次发现了MMRat&#xff0c;它主要针对东南亚用户&#xff0c;在VirusTotal等反病毒扫描服务中一直未被发现。 虽然研究…

零知识证明(zk-SNARK)(一)

全称为 Zero-Knowledge Succinct Non-Interactive Argument of Knowledge&#xff0c;简洁非交互式零知识证明&#xff0c;简洁性使得运行该协议时&#xff0c;即便statement非常大&#xff0c;它的proof大小也仅有几百个bytes&#xff0c;并且验证一个proof的时间可以达到毫秒…

ExpressLRS开源之RC链路性能测试

ExpressLRS开源之RC链路性能测试 1. 源由2. 分析3. 测试方案4. 测试设计4.1 校准测试4.2 实验室测试4.3 拉距测试4.4 遮挡测试 5. 总结6. 参考资料 1. 源由 基于ExpressLRS开源基本调试验证方法&#xff0c;对RC链路性能进行简单的性能测试。 修改设计总能够满足合理的需求&a…

达梦数据库管理用户和创建用户介绍

概述 本文主要对达梦数据库管理用户和创建用户进行介绍和总结。 1.管理用户介绍 1.1 达梦安全机制 任何数据库设计和使用都需要考虑安全机制&#xff0c;达梦数据库采用“三权分立”或“四权分立”的安全机制&#xff0c;将系统中所有的权限按照类型进行划分&#xff0c;为每…

浅谈 Pytest+HttpRunner 如何展开接口测试!

软件测试有多种多样的方法和技术&#xff0c;可以从不同角度对它们进行分类。其中&#xff0c;根据软件生命周期&#xff0c;针对不同的测试对象与目标&#xff0c;可将测试过程分为 4 个阶段&#xff1a;单元测试、集成测试、系统测试和验收测试。本文着重介绍了如何借用 pyte…

基于机器学习的fNIRS信号质量控制方法

摘要 尽管功能性近红外光谱(fNIRS)在神经系统研究中的应用越来越广泛&#xff0c;但fNIRS信号处理仍未标准化&#xff0c;并且受到经验和手动操作的高度影响。在任何信号处理过程的开始阶段&#xff0c;信号质量控制(SQC)对于防止错误和不可靠结果至关重要。在fNIRS分析中&…

Endnote中查看一个文献的分组的具体方法——以Endnote X8为例

Endnote中查看一个文献的分组的具体方法——以Endnote X8为例 一、问题 当Endnote中使用分类方法对文献进行分组管理后&#xff0c;有时需要重新调整该文献的分组&#xff0c;则需要找到这个文献在哪个分组中。本文阐述怎样寻找一个文献的分组的位置信息。 二、解决方法 1.选…

基于单片机的智能数字电子秤proteus仿真设计

一、系统方案 1、当电子称开机时&#xff0c;单片机会进入一系列初始化&#xff0c;进入1602显示模式设定&#xff0c;如开关显示、光标有无设置、光标闪烁设置&#xff0c;定时器初始化&#xff0c;进入定时器模式&#xff0c;如初始值赋值。之后液晶会显示Welcome To Use Ele…

Rust之自动化测试(一):如何编写测试

开发环境 Windows 10Rust 1.71.1 VS Code 1.81.1 项目工程 这里继续沿用上次工程rust-demo 编写自动化测试 Edsger W. Dijkstra在他1972年的文章《谦逊的程序员》中说&#xff0c;“程序测试可以是一种非常有效的方法来显示错误的存在&#xff0c;但它对于显示它们的不存在…

【业务功能篇90】微服务-springcloud-检索服务-ElasticSearch实战运用-DSL语句

商城检索服务 1.检索页面的搭建 商品检索页面我们放在search服务中处理&#xff0c;首页我们需要在mall-search服务中支持Thymeleaf。添加对应的依赖 <!-- 添加Thymeleaf的依赖 --><dependency><groupId>org.springframework.boot</groupId><artifa…

Solidity 小白教程:4. 函数输出 return

Solidity 小白教程&#xff1a;4. 函数输出 return 这一讲&#xff0c;我们将介绍Solidity函数输出&#xff0c;包括&#xff1a;返回多种变量&#xff0c;命名式返回&#xff0c;以及利用解构式赋值读取全部和部分返回值。 返回值 return 和 returns Solidity有两个关键字与…

java八股文面试[JVM]——类初始化过程

回顾类加载过程&#xff1a; 知识来源&#xff1a; 【2023年面试】Class初始化过程是什么_哔哩哔哩_bilibili

物联网智慧种植农业大棚系统

一、项目背景 智慧农业是是将物联网技术和农业生产箱管理的新型农业&#xff0c;依托部署在农业生产现场的各种传感节点&#xff0c;以物联网网关为通道形成数据传输网络&#xff0c;可以实现控制柜、环境监测传感器、气象监测机器等设备的远程监控&#xff0c;达到及时高校的…

LRU代码实现

LRU&#xff1a;最久未使用置换原则 均以3个内存块举例。 FIFO&#xff1a;先进先出 LRU算法代码实现&#xff1a; /* 双链表&#xff1a;最底端是最久未使用的 哈希表&#xff1a;通过缓存数据的键映射到其在双向链表中位置对hash表做put和get&#xff1a;给LRU的cache用map…