基于Keras的手写数字识别(附源码)

目录

引言

为什么要创建虚拟环境,好处在哪里?

源码 

我修改的部分

调用本地数据

修改第二层卷积层


引言

本文是博主为了记录一个好的开源代码而写,下面是代码出处!强烈建议收藏!【深度学习实战—1】:基于Keras的手写数字识别(非常详细、代码开源)

写的非常好,但是复现这篇博客却让我吃了很多苦头, 大家要先下载Anaconda3然后创建一个虚拟环境,在虚拟环境里面主要下载以下三个东西版本号只要对应好,肯定能运行,其他的库少什么安装什么!如果用显卡跑模型,原博客有提及配置!

版本号
Python版本3.7.3
Keras版本2.4.3
tensorflow版本2.4.0

为什么要创建虚拟环境,好处在哪里?

在进行机器学习项目时,我们经常会遇到需要为不同的模型安装不同版本的Python或相关库的情况。这是因为每个模型可能依赖于特定版本的库,这些版本之间可能存在兼容性差异。如果不使用虚拟环境,而是在主环境中直接安装这些库,可能会遇到以下问题:

首先,当你为新的模型安装特定版本的库时,可能会覆盖掉主环境中已经存在的其他模型所需的库版本,导致之前的模型无法正常运行。

其次,不同的Python版本之间也可能存在兼容性问题。如果你直接在主环境中升级或降级Python版本,可能会影响到依赖于特定Python版本的其他项目。

为了避免这些问题,使用虚拟环境变得尤为重要。虚拟环境是一个隔离的Python环境,其中可以安装特定版本的Python和库,而不会影响到主环境或其他虚拟环境。这样,你可以为每个机器学习模型创建一个独立的虚拟环境,并在其中安装所需的Python版本和库版本,从而确保每个模型都能在其特定的环境中稳定运行。

通过这种方法,你可以轻松地管理多个项目,而无需担心库版本冲突或Python版本不兼容的问题。希望这样的解释能帮助大家更好地理解虚拟环境在机器学习项目中的重要性。

源码 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.datasets import mnist
from sklearn.metrics import confusion_matrix
import seaborn as sns
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.utils import np_utils
import tensorflow as tf

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

# 设定随机数种子,使得每个网络层的权重初始化一致
# np.random.seed(10)

# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
(x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
# file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径
#
# # 加载npz文件
# with np.load(file_path, allow_pickle=True) as f:
#     x_train_original = f['x_train']
#     y_train_original = f['y_train']
#     x_test_original = f['x_test']
#     y_test_original = f['y_test']

"""
数据可视化
"""


# 单张图像可视化
def mnist_visualize_single(mode, idx):
    if mode == 0:
        plt.imshow(x_train_original[idx], cmap=plt.get_cmap('gray'))
        title = 'label=' + str(y_train_original[idx])
        plt.title(title)
        plt.xticks([])  # 不显示x轴
        plt.yticks([])  # 不显示y轴
        plt.show()
    else:
        plt.imshow(x_test_original[idx], cmap=plt.get_cmap('gray'))
        title = 'label=' + str(y_test_original[idx])
        plt.title(title)
        plt.xticks([])  # 不显示x轴
        plt.yticks([])  # 不显示y轴
        plt.show()


# 多张图像可视化
def mnist_visualize_multiple(mode, start, end, length, width):
    if mode == 0:
        for i in range(start, end):
            plt.subplot(length, width, 1 + i)
            plt.imshow(x_train_original[i], cmap=plt.get_cmap('gray'))
            title = 'label=' + str(y_train_original[i])
            plt.title(title)
            plt.xticks([])
            plt.yticks([])
        plt.show()
    else:
        for i in range(start, end):
            plt.subplot(length, width, 1 + i)
            plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))
            title = 'label=' + str(y_test_original[i])
            plt.title(title)
            plt.xticks([])
            plt.yticks([])
        plt.show()


mnist_visualize_multiple(mode=0, start=0, end=4, length=2, width=2)
# 原始数据量可视化
print('训练集图像的尺寸:', x_train_original.shape)
print('训练集标签的尺寸:', y_train_original.shape)
print('测试集图像的尺寸:', x_test_original.shape)
print('测试集标签的尺寸:', y_test_original.shape)

"""
数据预处理
"""
#
# 从训练集中分配验证集
x_val = x_train_original[50000:]
y_val = y_train_original[50000:]
x_train = x_train_original[:50000]
y_train = y_train_original[:50000]
print('======================')
# 打印验证集数据量
print('验证集图像的尺寸:', x_val.shape)
print('验证集标签的尺寸:', y_val.shape)
print('======================')
# 将图像转换为四维矩阵(nums,rows,cols,channels), 这里把数据从unint类型转化为float32类型, 提高训练精度。
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')
#
# 原始图像的像素灰度值为0-255,为了提高模型的训练精度,通常将数值归一化映射到0-1。
x_train = x_train / 255
x_val = x_val / 255
x_test = x_test / 255
#
print('训练集传入网络的图像尺寸:', x_train.shape)
print('验证集传入网络的图像尺寸:', x_val.shape)
print('测试集传入网络的图像尺寸:', x_test.shape)
# #
# 图像标签一共有10个类别即0-9,这里将其转化为独热编码(One-hot)向量
y_train = np_utils.to_categorical(y_train)
print(y_train[0])

y_val = np_utils.to_categorical(y_val)
y_test = np_utils.to_categorical(y_test_original)


#
# """
# 定义网络模型
# """
#
#
def CNN_model():
    model = Sequential()
    model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层
    # model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层
    model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层
    model.add(Flatten())  # 平铺层
    model.add(Dense(100, activation='relu'))  # 全连接层
    model.add(Dense(10, activation='softmax'))  # 全连接层

    print(model.summary())
    return model


#
#
# """
# 训练网络
# """
#
model = CNN_model()
# #
# 编译网络(定义损失函数、优化器、评估指标)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 开始网络训练(定义训练数据与验证数据、定义训练代数,定义训练批大小) 原来20
train_history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, batch_size=32, verbose=2)

# 模型保存
model.save('handwritten_numeral_recognition.h5')


#
#
# #
# #
# 定义训练过程可视化函数(训练集损失、验证集损失、训练集精度、验证集精度)
def show_train_history(train_history, train, validation):
    plt.plot(train_history.history[train])
    plt.plot(train_history.history[validation])
    plt.title('Train History')
    plt.ylabel(train)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='best')
    plt.show()


show_train_history(train_history, 'accuracy', 'val_accuracy')
show_train_history(train_history, 'loss', 'val_loss')

# 输出网络在测试集上的损失与精度
score = model.evaluate(x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# 测试集结果预测
predictions = model.predict(x_test)
predictions = np.argmax(predictions, axis=1)
print('前9张图片预测结果:', predictions[:9])


# 预测结果图像可视化
def mnist_visualize_multiple_predict(start, end, length, width):
    for i in range(start, end):
        plt.subplot(length, width, 1 + i)
        plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))
        title_true = 'true=' + str(y_test_original[i])
        title_prediction = ',' + 'prediction' + str(model.predict_classes(np.expand_dims(x_test[i], axis=0)))
        title = title_true + title_prediction
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.show()


mnist_visualize_multiple_predict(start=0, end=9, length=3, width=3)

# 混淆矩阵
cm = confusion_matrix(y_test_original, predictions)
cm = pd.DataFrame(cm)
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


def plot_confusion_matrix(cm):
    plt.figure(figsize=(10, 10))
    sns.heatmap(cm, cmap='Oranges', linecolor='black', linewidth=1, annot=True, fmt='', xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()


plot_confusion_matrix(cm)

我修改的部分

调用本地数据

# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
# (x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径

# 加载npz文件
with np.load(file_path, allow_pickle=True) as f:
    x_train_original = f['x_train']
    y_train_original = f['y_train']
    x_test_original = f['x_test']
    y_test_original = f['y_test']

因为原来的代码是每次运行都请求下载网上的在线数据,这是没必要的,当你运行了一次,可以把数据存在本地,然后以后本地调用

修改第二层卷积层

def CNN_model():
    model = Sequential()
    model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层
    # model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层
    model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层
    model.add(Flatten())  # 平铺层
    model.add(Dense(100, activation='relu'))  # 全连接层
    model.add(Dense(10, activation='softmax'))  # 全连接层

    print(model.summary())
    return model

 原文中的第二层卷积层的输入是规定为(28,28,1),但是这是有问题的,应该是不设置参数,这样子的话,会自动将第一个池化层的输出当作输入

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/642887.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SK6812-RGBW是一个集控制电路与发光电路于一体的智能外控LED光源

产品概述: SK6812-RGBW是一个集控制电路与发光电路于一体的智能外控LED光源。其外型与一个5050LED灯珠相同,每个元件即为一个像素点。像素点内部包含了智能数字接口数据锁存信号整形放大驱动电路,电源稳压电路,内置恒流电路&#xff0…

基于51单片机的盆栽自动浇花系统

一.硬件方案 工作原理是湿度传感器将采集到的数据直接传送到ADC0832的IN端作为输入的模拟信号。选用湿度传感器和AD转换,电路内部包含有湿度采集、AD转换、单片机译码显示等功能。单片机需要采集数据时,发出指令启动A/D转换器工作,ADC0832根…

WordPress Country State City Dropdown CF7插件 SQL注入漏洞复现(CVE-2024-3495)

0x01 产品简介 Country State City Dropdown CF7插件是一个功能强大、易于使用的WordPress插件,它为用户在联系表单中提供国家、州/省和城市的三级下拉菜单功能,帮助用户更准确地填写地区信息。同时,插件的团队和支持也非常出色,为用户提供高质量的服务。 0x02 漏洞概述 …

零基础的粉丝有福了:逐键提示盲打更轻松

盲打就是不看键盘去打字,对于零基础的粉丝而言,盲打入门通常都是很难的,今天就给大家放个福利:从今天开始就能盲打,3天之后盲打就入门了。 真的有这么简单吗?是的,跟着我做就可以了。 首先&am…

结构体(位段)内存分配

结构体由多个数据类型的成员组成。那编译器分配的内存是不是所有成员的字节数总和呢? 首先,stu的内存大小并不为29个字节,即证明结构体内存不是所有成员的字节数和。   其次,stu成员中sex的内存位置不在21,即可推测…

Linux驱动开发笔记(二) 基于字符设备驱动的I/O操作

文章目录 前言一、设备驱动的作用与本质1. 驱动的作用2. 有无操作系统的区别 二、内存管理单元MMU三、相关函数1. ioremap( )2. iounmap( )3. class_create( )4. class_destroy( ) 四、GPIO的基本知识1. GPIO的寄存器进行读写操作流程2. 引脚复用2. 定义GPIO寄存器物理地址 五、…

视频汇聚平台LntonCVS视频监控系统前端错误日志记录及Debug模式详细讲解

LntonCVS作为一种支持GB28181标准的流媒体服务平台,旨在提供一个能够整合不同厂商设备、便于管理和扩展的解决方案,以适应日益复杂的视频监控环境。通过实现设备的统一管理和流媒体的高效传输,LntonCVS帮助构建更加灵活和强大的视频监控系统。…

Andoird使用Room实现持久化及使用Room进行增删查改

文章目录 Room概述Room的使用一、在gradle.build中添加依赖库kotlinJava 创建实体类创建抽象Dao层接口创建DataBase层使用创建的查看数据库 总结: 这篇文章会告诉你如何在Android中通过kotlin或者Java来实现数据持久化 Room概述 处理大量结构化数据的应用可极大地受…

强烈推荐 20.7k Star!企业级商城开源项目强烈推荐!基于DDD领域驱动设计模型,助您快速掌握技术奥秘,实现业务快速增长

更多资源请关注纽扣编程微信公众号 1 项目简介 商城是个从零到一的C端商城项目,包含商城核心业务和基础架构两大模块,推出用户、消息、商品、订单、优惠券、支付、网关、购物车等业务模块,通过商城系统中复杂场景,给出对应解决方案。使用 …

Java——简易图书管理系统

本文使用 Java 实现一个简易图书管理系统 一、思路 简易图书管理系统说白了其实就是 用户 与 图书 这两个对象之间的交互 书的属性有 书名 作者 类型 价格 借阅状态 而用户可以分为 普通用户 管理员 使用数组将书统一管理起来 用户对这个数组进行操作 普通用户可以进…

Axure RP 10汉化版修改文字

效果 安装目录 lang/default Axure 10 RP 汉化包(概览改图层)

Express 的 req 和 res 对象

新建 learn-express文件夹,执行命令行 npm init -y npm install express 新建 index.js const express require(express); const app express();app.get(/, (req, res, next) > {res.json(return get) })app.post(/, (req, res, next) > {res.json(retur…

单机一天轻松300+ 最新微信小程序拼多多+京东全自动掘金项目、

现代互联网经济的发展带来了新型的盈利方式,这种方法通过微信小程序的拼多多和京东进行商品自动巡视,以此给商家带来增加的流量,同时为使用者带来利润。实践这一手段无需复杂操作,用户仅需启动相应程序,商品信息便会被…

【东山派Vision K510开发板试用笔记】WiFi配网问题

目录 概述 WiFi配网的修改 悬而未决的问题 概述 最近试用了百问网提供的东山派Vision开发板,DongshanPI-Vision开发板是百问网针对AI应用开发设计出来的一个RSIC-V架构的AI开发板,主要用于学习使用嘉楠的K510芯片进行Linux项目开发和嵌入式AI应用开发…

闲话 .NET(5):.NET Core 有什么优势?

前言 .NET Core 并不是 .NET FrameWork 的升级版,它是一个为满足新一代的软件设计要求而从头重新开发的开发框架和平台,所以它没有 .NET FrameWork 的历史包袱,相对于 .NET FrameWork,它具备很多优势。 .NET Core 有哪些优势&am…

什么是DDoS流量清洗?

随着互联网的飞速发展,网络安全问题日益凸显,其中分布式拒绝服务(DDoS)攻击尤为引人关注。为了有效应对这一威胁,流量清洗服务应运而生,成为网络安全领域的一项重要技术。 流量清洗服务是一种专门针对DDoS…

最小生成树【做题记录】c++(Prim,Kruskal)

目录 Prim算法求最小生成树 【算法思想】 【算法实现】 【数据结构设计】 【算法步骤】 【输入输出】 【代码示例】 Kruskal算法求最小生成树 【算法思想】 判断是否会产生回路的方法 【算法描述】 【图的存储结构】 【输入输出】 【代码示例】 Prim算法求最小生…

Reactor设计模式

Reactor设计模式 Reactor模式称为反应器模式或应答者模式,是基于事件驱动的设计模式,拥有一个或多个并发输入源,有一个服务处理器和多个请求处理器,服务处理器会同步的将输入的请求事件以多路复用的方式分发给相应的请求处理器。…

Android 自定义图片进度条

用系统的Progressbar,设置图片drawable作为进度条会出现图片长度不好控制,容易被截断,或者变形的问题。而我有个需求,使用图片背景,和图片进度,而且在进度条头部有个闪光点效果。 如下图: 找了…

Nginx部署静态网页

1、首先拿到前端给的dist包,上传到服务器指定位置:/ajd/dist 2、找到nginx.conf配置文件,修改 server {listen 9300;server_name xxx.xx.xx.xx;location / {root /ajd/dist;try_files $uri $uri/ /index.html;index index.html …