使用MNIST数据集训练手写数字识别模型

一、MNIST数据集介绍
MNIST 数据集(手写数字数据集)是一个公开的公共数据集,任何人都可以免费获取它。目前,它已经是一个作为机器学习入门的通用性特别强的数据集之一,所以对于想要学习机器学习分类的、深度神经网络分类的、图像识别与处理的小伙伴,都可以选择MNIST数据集入门。

二、MNIST数据集结构
MNIST 数据集包含70000(60000+10000)个样本,其中有60000个训练样本和10000个测试样本,每个样本的像素大小为28*28。

1.MNIST数据集下载方式


方法一
下载地址:http://yann.lecun.com/exdb/mnist/

可以直接下载这四个文件,这四个文件分别为:
①训练样本的图像(60000个)
②对应训练样本上每一张图像上数字的标签(0~9)(60000个)
③测试样本的图像(10000个)
④对应测试样本上每一张图像上数字的标签(0~9)(10000个)

方法二
在Keras中已经内置了多种公共数据集,其中就包含MNIST数据集,如图所示。

所以可以直接调用 tf.keras.datasets.mnist,直接下载数据集。


2.开始训练

可以跟着一步一步做,不会出错

(1)导包

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time

(2)打印开始时间

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)

 效果展示:

(3)预处理

#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

#加载数据
mnist = tf.keras.datasets.mnist

(train_x,train_y),(test_x,test_y) = mnist.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape)) 

#数据预处理
#X_train = train_x.reshape((60000,28*28))
#Y_train = train_y.reshape((60000,28*28))       #后面采用tf.keras.layers.Flatten()改变数组形状
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

效果展示:

(4)建立模型查看结构

# 建立模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
print('\n',model.summary())     #查看网络结构和参数信息

    效果展示:       

(5)开始训练

#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])   

#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))

效果展示:

(6)评估模型

#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力

效果展示:

(7)结果可视化

#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()

#使用模型
plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)

    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,28,28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))

#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

展示效果:

3.测试模型

1.修改测试图片的路径

2.修改保存模型的路径

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2

# 建立模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))     # 添加Flatten层说明输入数据的形状
model.add(tf.keras.layers.Dense(128, activation='relu'))     # 添加隐含层,为全连接层,128个节点,relu激活函数
model.add(tf.keras.layers.Dense(10, activation='softmax'))   # 添加输出层,为全连接层,10个节点,softmax激活函数

# 加载模型参数
model.load_weights('mnist_weights.h5') # 路径根据文件实际位置修改,不然会报错

# 定义一个函数来预处理图片
def preprocess_image(image_path):
    # 读取图片,转换为灰度图
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"无法加载图片:{image_path}")
        return None
    # 调整图片大小为28x28像素
    img = cv2.resize(img, (28, 28))
    # 归一化图片像素值到0-1范围
    img = img / 255.0
    # 转换图片形状以匹配模型输入
    img = img.reshape(1, 28, 28)
    return img

# 使用模型进行预测
plt.figure()
# 这里替换为你的图片路径列表
image_paths = ['shouxieti_img/test/0_10.jpg']
for i, image_path in enumerate(image_paths):
    # 预处理图片
    img = preprocess_image(image_path)
    if img is not None:
        # 使用模型进行预测
        y_pred = np.argmax(model.predict(img))
        
        # 显示图片和预测结果
        plt.subplot(1, len(image_paths), i+1)
        plt.imshow(img[0], cmap='gray')
        plt.axis('off')
        plt.title('预测值:' + str(y_pred))
    
plt.show()

 效果展示:

话不多说 源码奉上!

 4.全部代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)

#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

#加载数据
mnist = tf.keras.datasets.mnist

(train_x,train_y),(test_x,test_y) = mnist.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape)) 

#数据预处理
#X_train = train_x.reshape((60000,28*28))
#Y_train = train_y.reshape((60000,28*28))       #后面采用tf.keras.layers.Flatten()改变数组形状
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

# 建立模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
print('\n',model.summary())     #查看网络结构和参数信息

#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])   

#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))

#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力


#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()

#使用模型
plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)

    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,28,28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))

#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

 转载于:【神经网络与深度学习】使用MNIST数据集训练手写数字识别模型——[附完整训练代码]_使用mnist数据集进行模型训练时-CSDN博客

 谢谢支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/706029.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抓包工具 Wireshark 的下载、安装、使用、快捷键

目录 一、什么是Wireshark?二、Wireshark下载三、Wireshark安装四、Wireshark使用4.1 基本使用4.2 过滤设置1)捕获过滤器2)显示过滤器 4.3 过滤规则1)捕获过滤器-规则语法2)显示过滤器-规则语法 4.4 常用的显示过滤器规…

js实现一个数据结构——栈

栈的概念就不再赘述,无可厚非的先进后出,而JS又是高级语言,数组中的方法十分丰富,已经自带了push pop方法进行入栈出栈的操作。 1.基本实现 class Stack {constructor() {this.items [];}// 入栈push(item) {this.items.push(i…

【C++入门(1)】命名空间

一、C出世 我们先简单认识下C的来历,C是在C语言的基础上发展来的。 当年C的设计者Bjarne Stroustrup,本贾尼斯特劳斯特卢普先生设计C语言之初,是为了对C语言做出一些更改,弥补C语言在一些方面的不足,或者做出其他的设…

二阶段提交(2pc)协议

二阶段提交(2pc)协议 1、 简介 二阶段提交算法是一个分布式一致性算法,强一致、中心化的原子提交协议,主要用来解决分布式事务问题。在单体spring应用中我们往往通过一个Transactional注解就可以保证方法的事务性,但…

破解发展难题 台山这家合作社以农业社会化服务助推乡村振兴

风吹稻田千层浪,眼下,台山四九镇的早稻长势喜人,沉甸甸的稻穗迎风而动,已进入破口抽穗的关键期,即将在6月底陆续迎来丰收。在台山市明华汇种养专业合作社管理的稻田里,合作社负责人梁明喜正仔细观察着稻苗的…

算法第六天:力扣第977题有序数组的平方

一、977.有序数组的平方的链接与题目描述 977. 有序数组的平方的链接如下所示:https://leetcode.cn/problems/squares-of-a-sorted-array/description/https://leetcode.cn/problems/squares-of-a-sorted-array/description/ 给你一个按 非递减顺序 排序的整数数组…

#慧眼识模每日PK[话题]##用五种语言说爸爸我爱你[话题]#

#慧眼识模每日PK #用五种语言说爸爸我爱你 你觉得哪个模型回答得更好?欢迎留言 A.蓝 B.紫 更多问题,扫码体验吧~ by 国家(杭州)新型交换中心

Whisper语音识别 -- 自回归解码分析

前言 Whisper 是由 OpenAI 开发的一种先进语音识别系统。它采用深度学习技术,能够高效、准确地将语音转换为文本。Whisper 支持多种语言和口音,并且在处理背景噪音和语音变异方面表现出色。其广泛应用于语音助手、翻译服务、字幕生成等领域,为…

鸿蒙轻内核A核源码分析系列七 进程管理 (3)

本文记录下进程相关的初始化函数,如OsSystemProcessCreate、OsProcessInit、OsProcessCreateInit、OsUserInitProcess、OsDeInitPCB、OsUserInitProcessStart等。 1、LiteOS-A内核进程创建初始化通用函数 先看看一些内部函数,不管是初始化用户态进程还…

收银系统小程序商城商品详情页再升级!

本期导读 1.新增:商品详情页新增商品参数模块; 2.新增:商品详情页新增保障服务模块; 3.新增:线上商城商品新增划线价; 4.新增:线上商城分销商品新增“赚”字标签及预收收益; 5.…

Linux-笔记 全志平台OTG虚拟 串口、网口、U盘笔记

前言: 此文章方法适用于全志通用平台,并且三种虚拟功能同一时间只能使用一个,原因是此3种功能都是内核USB Gadget precomposed configurations的其中一个选项,只能单选,不能多选,而且不能通过修改配置文件去…

我的考研经历

当我写下这篇文章时,我已经从考研 的失败中走出来了,考研的整个过程都写在博客日志里面了,在整理并阅读考研的日志时,想写下一篇总结,也算是为了更好的吸取教训。 前期日志模板:时间安排的还算紧凑&#x…

安鸾学院靶场——安全基础

文章目录 1、Burp抓包2、指纹识别3、压缩包解密4、Nginx整数溢出漏洞5、PHP代码基础6、linux基础命令7、Mysql数据库基础8、目录扫描9、端口扫描10、docker容器基础11、文件类型 1、Burp抓包 抓取http://47.100.220.113:8007/的返回包,可以拿到包含flag的txt文件。…

DDei在线设计器-配置主题风格

DDeiCore-主题 DDei-Core插件提供了默认主题和黑色主题。 如需了解详细的API教程以及参数说明,请参考DDei文档 默认主题 黑色主题 使用指南 引入 import { DDeiCoreThemeBlack } from "ddei-editor";使用并修改设置 extensions: [......//通过配置&am…

【FreeRTOS】内存管理

目录 1 为什么要自己实现内存管理2 FreeRTOS的5中内存管理方法2.1 Heap_12.2 Heap_22.3 Heap_32.4 Heap_4 2.5 Heap_53 Heap相关的函数3.1 pvPortMalloc/vPortFree3.2 xPortGetFreeHeapSize 3.3 xPortGetMinimumEverFreeHeapSize3.4 malloc失败的钩子函数 参考《FreeRTOS入门与…

Python私教张大鹏 Vue3整合AntDesignVue之DatePicker 日期选择框

案例&#xff1a;选择日期 <script setup> import {ref} from "vue";const date ref(null) </script> <template><div class"p-8 bg-indigo-50 text-center"><a-date-picker v-model:value"date"/><a-divide…

原子阿波罗STM32F429程序的控制器改为STM32F407驱动LCD屏

原子大神的阿波罗开发板使用STM32F429IGT6控制器&#xff0c;编程风格也与探索者F407系列有了很大的不同&#xff0c;使用BSP功能模块编程了&#xff0c;也有点类似于安富莱的编程风格了。这种模块式程序风格的优点是更加方便移植&#xff0c;更方便泡系统。 但无奈手里只有F40…

模拟笔试 - 卡码网周赛第二十一期(23年美团笔试真题)

第一题&#xff1a;小美的排列询问 解题思路: 简单题&#xff0c;一次遍历数组&#xff0c;判断 是否有和x、y相等并且相连 即可。 可优化逻辑&#xff1a;因为x和y是后输入的&#xff0c;必须存储整个数组&#xff0c;但是上面说了 **排列是指一个长度为n的数组&#xff0…

搭建一个好玩的 RSS 订阅网站记录

全文相关链接 Github仓库创建链接Railway官网Supabase官网f-droid上的co.appreactor.news应用下载链接Railway账户使用量估算链接 全文相关代码 原文地址: https://blog.taoshuge.eu.org/p/270/ Dockerfile FROM docker.io/miniflux/miniflux:2.1.3环境变量 DATABASE_URL…