MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical

(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',
			  loss='categorical_crossentropy',
              metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,
			  y=y_train,
			  validation_split=0.2,
			  epochs=20,
		      batch_size=128,
			  verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as plt

def show_train(history,train,validation):
    plt.plot(history.epoch, history.history[train],label=train)
    plt.plot(history.epoch, history.history[validation],label=validation)
    plt.title(train)
    plt.legend()
    plt.show()
    
show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/590969.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)

使用函数式API构建模型,使得模型可以处理多输入多输出。 1、查看tensorflow版本 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())2、fashion_mnist数据集分类模型 2.1 使用Sequential构建…

pygame鼠标绘制

pygame鼠标绘制 Pygame鼠标绘制效果代码 Pygame Pygame是一个开源的Python库,专为电子游戏开发而设计。它建立在SDL(Simple DirectMedia Layer)的基础上,允许开发者使用Python这种高级语言来实时开发电子游戏,而无需被…

自定义数据上的YOLOv9分割训练

原文地址:yolov9-segmentation-training-on-custom-data 2024 年 4 月 16 日 在飞速发展的计算机视觉领域,物体分割在从图像中提取有意义的信息方面起着举足轻重的作用。在众多分割算法中,YOLOv9 是一种稳健且适应性强的解决方案&#xff0…

环形链表 [两道题目](详解)

环形链表(详解) 第一题: 题目: 题目链接 给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表…

使用protoc-jar-maven-plugin生成grpc项目

在《使用protobuf-maven-plugin生成grpc项目》中我们使用protobuf-maven-plugin完成了grpc代码的翻译。本文我们将只是替换pom.xml中的部分内容,使用protoc-jar-maven-plugin来完成相同的功能。总体来说protoc-jar-maven-plugin方案更加简便。 环境 见《使用proto…

【数据结构(邓俊辉)学习笔记】列表01——从向量到列表

文章目录 0.概述1. 从向量到列表1.1 从静态到动态1.2 从向量到列表1.3 从秩到位置1.4 列表 2. 接口2.1 列表节点2.1.1 ADT接口2.1.2 ListNode模板类 2.2 列表2.2.1 ADT接口2.2.2 List模板类 0.概述 学习了向量,再介绍下列表。先介绍下列表里的概念和语义&#xff0…

global IoT SIM解决方案

有任何关于GSMA\IOT\eSIM\RSP\业务应用场景相关的问题,欢迎W: xiangcunge59 一起讨论, 共同进步 (加的时候请注明: 来自CSDN-iot). Onomondo提供的全球IoT SIM卡解决方案具有以下特点和优势: 1. **单一全球配置文件**:Onomondo的SIM卡拥…

hive-row_number() 和 rank() 和 dense_rank()

row_number() 是无脑排序 rank() 是相同的值排名相同,相同值之后的排名会继续加,是我们正常认知的排名,比如学生成绩。 dense_rank()也是相同的值排名相同,接下来的排名不会加。不会占据排名的坑位。

C#知识|将选中的账号信息展示到控制台(小示例)

哈喽,你好啊,我是雷工! 上篇学习了控件事件的统一关联, 本篇通过实例练习继续学习事件统一处理中Tag数据获取、对象的封装及泛型集合List的综合运用。 01 实现功能 在上篇的基础上实现,点击选中喜欢的账号&#xff0…

WebSocket 多屏同显和异显

介绍 多屏同显:通过在一个应用上进行操作之后,另一个应用也能跟着一起发生改变,例如app1播放了晴天这首音乐,那么app2也要同步播放这首音乐,确保所有屏幕显示的内容完全相同。多屏异显:每个屏幕可以显示不同的内容,或者在内容更新时存在一定的延迟,而不需要严格保持同步…

MyBatis 使用 XML 文件映射

在MyBatis中 我们可以使用各种注解来配置我们Mapper 类中的方法 我们为什么要使用XML文件呢? 如果我们是一条非常长的SQL 语句 使用 注解配置的话, 会非常不利于阅读 如下 所以,就需要使用到一个XML文件来对SQL语句进行映射,那么 …

力扣763. 划分字母区间

Problem: 763. 划分字母区间 文章目录 题目描述思路复杂度Code 题目描述 思路 1.创建一个名为 last 的数组,用于存储每个字母在字符串 s 中最后出现的位置。然后,获取字符串 s 的长度 len。 2.计算每个字母的最后位置:遍历字符串 s&#xff0…

(五)SQL系列练习题(上)创建、导入与查询 #CDA学习打卡

目录 一. 创建表 1)创建课程表 2)创建学生表 3)创建教师表 4)创建成绩表 二. 导入数据 1)导入课程科目数据 2)导入课程成绩数据 3)导入学生信息数据 4)导入教师信息数据 …

RocketMQ SpringBoot 3.0不兼容解决方案

很多小伙伴在项目升级到springBoot3.0版本以上之后,整合很多中间件会有很多问题,下面带小伙伴解决springBoot3.0版本以上对于RocketMQ 不兼容问题 报错信息 *************************** APPLICATION FAILED TO START *************************** Des…

电脑崩溃了,之前备份的GHO文件怎么恢复到新硬盘?

前言 之前咱们说到用WinPE系统给电脑做一个GHO镜像备份,这个备份可以用于硬盘完全崩溃换盘的情况下使用。 那么这个GHO镜像文件怎么用呢? 咱们今天详细来讲讲! 如果你的电脑系统硬盘崩溃了或者是坏掉了,那么就需要使用之前备份…

ubuntu 小工具

随着在专业领域待得越久,发现总会遇到各种奇怪的需求。因此,这里介绍一些ubuntu上的一些小工具。 meld 对比文件 sudo apt-get install meld sudo apt --fix-broken install # maybeHow to compare two files

pytorch笔记:ModuleDict

1 介绍 在 PyTorch 中,nn.ModuleDict 是一个方便的容器,用于存储一组子模块(即 nn.Module 对象)的字典这个容器主要用于动态地管理多个模块,并通过键来访问它们,类似于 Python 的字典 2 特点 组织性 nn…

《金融研究》:普惠金融改革试验区DID工具变量数据(2012-2023年)

数据简介:本数据集包括普惠金融改革试验区和普惠金融服务乡村振兴改革试验区两类。 其中,河南兰考、浙江宁波、福建龙岩和宁德、江西赣州和吉安、陕西铜川五省七地为普惠金融改革试验区。山东临沂、浙江丽水、四川成都三地设立的是普惠金融服务乡村振兴…

可视化大屏应用(20):智慧水务、水利、防汛

hello,我是大千UI工场,本篇分享智智慧水务的大屏设计,关注我们,学习N多UI干货,有设计需求,我们也可以接单。 在智慧水务领域,可视化大屏具有以下几个方面的价值: 数据展示和监控 可…

华为手机 鸿蒙系统-android studio识别调试设备,开启adb调试权限

1.进入设置-关于手机-版本号,连续点击7次 认证:有锁屏密码需要输入密码, 开启开发者配置功能ok 进入开发者配置界面 打开调试功能 重新在androd studio查看可运行running devices显示了, 不行的话,重启一下android …