01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别(二分类问题)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0

1、识别目标

识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。

2、Tensorflow算法实现

STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。

from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。


STEP2:屏蔽无用警告并允许中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。

tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。

plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。


STEP3:导入并划分数据集

划分10%作为测试:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型构建与训练

# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential(
    [
        tf.keras.Input(shape=(400,)),    #specify input size
        ### START CODE HERE ###
        Dense(100, activation='sigmoid'),
        Dense(10, activation='sigmoid'),
        Dense(1, activation='sigmoid')
        ### END CODE HERE ###
    ], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(
    X_train,y_train,
    epochs=20
)


STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])

    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T

    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')

    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    if prediction >= 0.5:
        yhat = 1
    else:
        yhat = 0
    # Display the label above the image
    ax.set_title(f"{y_test[random_index, 0]},{yhat}")
    ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()

# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

3、运行结果

按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。
在这里插入图片描述

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现二元手写数字识别(二分类问题)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_data/y.npy")
    X = X[0:1000]
    y = y[0:1000]
    return X, y


# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 将1*400的数据转换为20*20的图像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()

# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential(
    [
        tf.keras.Input(shape=(400,)),    #specify input size
        ### START CODE HERE ###
        Dense(100, activation='sigmoid'),
        Dense(10, activation='sigmoid'),
        Dense(1, activation='sigmoid')
        ### END CODE HERE ###
    ], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(
    X_train,y_train,
    epochs=20
)

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])

    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T

    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')

    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    if prediction >= 0.5:
        yhat = 1
    else:
        yhat = 0
    # Display the label above the image
    ax.set_title(f"{y_test[random_index, 0]},{yhat}")
    ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()

# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/186296.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C#,《小白学程序》第一课:初识程序,变量,数据与显示

曰:扫地僧练就绝世武功的目的是为了扫地更干净。 1 引言 编程只是一项技术,如包包子,不是什么高深的科学。 学习程序最不好的方法是先学习枯燥的语法。 学习程序主要是用代码解决问题。因此,我们抛开所有的语法与诸多废物&…

【Tiny_CD】Tiny_CD变化检测网络详解(含python代码)

题目:TinyCD: A (Not So) Deep Learning Model For Change Detection 论文:paper 代码:code 目录 🍟 🍟1.摘要 🍗🍗 2.贡献 🍖🍖 3.网络结构

classifier-free-guidance 扩散模型引导生成

浅谈扩散模型的有分类器引导和无分类器引导 - 知乎这篇文章主要比较一下扩散模型的引导生成的三种做法的区别。它们分别是用显式分类器引导生成的做法,用隐式无分类器引导的做法和用CLIP计算跨模态间的损失来引导生成的做法。 Classifier-Guidance: Diffusion Mode……

React + BraftEditor 实现富文本编辑

Braft Editor 是一个基于 React 和 Draft-js 开发的富文本编辑器,提供了丰富的基础功能,如基本文本格式化、列表、链接、图片上传、视频插入等,并且还支持扩展。 首先,确保你已经在项目中安装了 Braft Editor 和它的依赖项&#x…

腾讯云发布新一代基于AMD处理器的星星海云服务器实例SA5

基础设施的硬实力,愈发成为云厂商的核心竞争力。 11月24日,腾讯云发布了全新一代星星海服务器。基于自研服务器的高密设计与硬件升级,对应云服务器SA5是全球首家搭载第四代AMD EPYC处理器(Bergamo)的公有云实例&#…

【机器学习】平滑滤波

平滑滤波技术 平滑滤波,顾名思义就是对信号进行处理使之整体显得更加平滑,降低噪声影响,提高信号质量,它常见于数字信号处理和图像处理,一般意义上的数字信号多体现于一维数据,图像信号多体现于二维数据。…

大众博客系统测试报告【改】

一、项目背景 大众博客系统采用前后端分离的方法来实现,同时使用了数据库来存储相关的数据,同时将其部署到云服务器上。前端主要有四个页面构成:登录页、列表页、详情页以及编辑页,以上模拟实现了最简单的大众博客系统。其结合后端…

DGL在异构图上的GraphConv模块

回顾同构图GraphConv模块 首先回顾一下同构图中实现GraphConv的主要思路(以GraphSAGE为例): 在初始化模块首先是获取源节点和目标节点的输入维度,同时获取输出的特征维度。根据SAGE论文提出的三种聚合操作,需要获取所…

Day40力扣打卡

打卡记录 包子凑数(裴蜀定理 DP) 根据裴蜀定理,存在 c gcd(a, b) 使不定方程ax by c满足条件,如果gcd(a, b) 1即a与b互素的情况下,就会 ax by 1,由于为1可以构造后面的无穷数字,故得到结…

项目实战详细讲解带有条件响应的 SQL 盲注、MFA绕过技术、MFA绕过技术、2FA绕过和技巧、CSRF绕过、如何寻找NFT市场中的XSS漏洞

项目实战详细讲解带有条件响应的 SQL 盲注、MFA绕过技术、MFA绕过技术、2FA绕过和技巧、CSRF绕过、如何寻找NFT市场中的XSS漏洞。 带有条件响应的 SQL 盲注 这篇文章的核心要点如下: 漏洞发现:作者在Portswigger提供的实验室中发现了一个盲SQL注入漏洞。这个漏洞存在于一个应…

【libGDX】Mesh纹理贴图

1 前言 纹理贴图的本质是将图片的纹理坐标与模型的顶点坐标建立一一映射关系。纹理坐标的 x、y 轴正方向分别朝右和朝下,如下。 2 纹理贴图 本节将使用 Mesh、ShaderProgram、Shader 实现纹理贴图,OpenGL ES 的实现见博客 → 纹理贴图。 DesktopLauncher…

Prometheus环境搭建和认识

Prometheus 环境搭建 1.prometheus 简介 Prometheus是基于go语言开发的一套开源的监控、报警和时间序列数据库的组合,是由SoundCloud公司开发的开源监控系统,Prometheus于2016年加入CNCF(Cloud Native Computing Foundation,云原生计算基金…

虾皮插件:优化Shopee商家店铺运营的利器

在如今竞争激烈的电商市场中,如何提升Shopee商家店铺的运营效率和销售业绩成为了摆在每个商家面前的一道难题。然而,幸运的是,虾皮插件-知虾的出现为商家们带来了一种全新的解决方案。本文将介绍虾皮插件的用途和优势,并详细介绍其…

【第一部也是唯一一部】3DMAX脚本语言MAXScript 中文帮助

3DMAX我们很多3D设计师和艺术家都在使用这款功能强大的三维软件,但是再强大的工具也不可能包罗万象,无所不能,所以,通常官方努力在功能和性能平衡之间的同时,也提供第三方扩展软件功能的可能—插件开发。 3DMAX插件开发…

Linux的基本指令(3)

16.cal指令 cal命令可以用来显示公历(阳历)日历。公历是现在国际通用的历法,又称格列历,通称阳历。“阳历”又名“太阳历”,系以地球绕行太阳一周为一年,为西方各国所通用,故又名“西历”。 命…

安防系统智能视频监控中出现画面异常该如何自检?

大家都知道,在当今社会,摄像头无处不在,除了常见的生活与工作场景中,在一些无法人员无法长期驻点场景,如野生动物监测、高空作业监控、高压电缆监控等场景,在这些地方安装摄像头就是为方便日常监控。但是由…

java 反射和注解1-反射详解

反射和注解本就是一家人,注解离不开反射,这里先将反射的写法,本文涉到的注解暂时可以不不用理解 1,创建一个类 public class ReflexUser {public String name;private String namePrivate;protected String nameProtected;Strin…

【自主探索】基于 rrt_exploration 的单个机器人自主探索建图

文章目录 一、rrt_exploration 介绍1、原理2、主要思想3、拟解决的问题4、优缺点 二、安装环境三、安装与运行1、安装2、运行 四、配置自己的机器人1、Robots Network2、Robots frame names in tf3、Robots node and topic names4、Setting up the navigation stack on the rob…

CSS特效018:科技动画,hover后点亮阁楼,拉伸出楼梯

CSS常用示例100专栏目录 本专栏记录的是经常使用的CSS示例与技巧,主要包含CSS布局,CSS特效,CSS花边信息三部分内容。其中CSS布局主要是列出一些常用的CSS布局信息点,CSS特效主要是一些动画示例,CSS花边是描述了一些CSS…

Linux时间命令—— 显示时间,日历等

目录 1.date显示时间 1.1 常用的标记列表: 1.2 设定时间: 2.cal显示日历 3.时间戳 1.date显示时间 date 用法:date [OPTION] ... [FORMAT] 1.1 常用的标记列表: %H : 小时 (00..23) %M : 分钟 (00..59) %S : 秒 (00..61…