Scikit Learn识别手写数字 -- 机器学习项目基础篇(6)

Scikit learn是机器学习社区中使用最广泛的机器学习库之一,其背后的原因是代码的易用性和机器学习开发人员构建机器学习模型所需的几乎所有功能的可用性。在本文中,我们将学习如何使用sklearn在手写数字数据集上训练MLP模型。
其优势是:

  • 它提供分类、回归和聚类算法,如SVM算法、随机森林、梯度提升和k-means。
  • 它还被设计为与Python的科学和数值库NumPy和SciPy一起操作。

导入库和数据集

让我们开始导入模型所需的库并加载数据集数字。

# importing the hand written digit dataset
from sklearn import datasets
 
# digit contain the dataset
digits = datasets.load_digits()
 
# dir function use to display the attributes of the dataset
dir(digits)

输出:

['DESCR', 'data', 'feature_names', 'frame', 'images', 'target', 'target_names']

打印一组图像

  • digits.image是三维数组。第一个维度索引图像,我们可以看到总共有1797张。
  • 以下两个维度涉及每个图像的像素的x和y坐标。
  • 每个图像的大小为8×8 = 64像素。换句话说,该阵列可以在3D中表示为8×8像素图像的堆叠。
# outputting the picture value as a series of numbers
print(digits.images[0])

输出:

[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]

原始数字的分辨率要高得多,在为scikit-learn准备数据集时,分辨率被降低,以便训练机器学习系统更容易、更快地识别这些数字。因为在如此低的分辨率下,即使是人类也很难识别其中的一些数字。输入照片的低质量也会限制在这些设置中的表现。

# importing the matplotlib libraries pyplot function
import matplotlib.pyplot as plt
# defining the function plot_multi
 
def plot_multi(i):
    nplots = 16
    fig = plt.figure(figsize=(15, 15))
    for j in range(nplots):
        plt.subplot(4, 4, j+1)
        plt.imshow(digits.images[i+j], cmap='binary')
        plt.title(digits.target[i+j])
        plt.axis('off')
    # printing the each digits in the dataset.
    plt.show()
 
    plot_multi(0)

在这里插入图片描述

用数据集训练神经网络

神经网络是一组算法,它试图使用类似于人脑工作方式的技术来识别一批数据中的潜在关系。在这种情况下,神经网络是自然界中可能是有机或人工的神经元系统。

  • 一个由64个节点组成的输入层,每个节点对应输入图片中的每个像素,它们只是将其输入值发送到下一层的神经元。
  • 这是一个密集的神经网络,意味着每一层中的每个节点都链接到前面和后面的所有节点。

输入层需要一维数组,而图像数据集是二维的。因此,展平所有图像过程:

# converting the 2 dimensional array to one dimensional array
y = digits.target
x = digits.images.reshape((len(digits.images), -1))
 
# gives the  shape of the data
x.shape

输出:

(1797, 64)

通过将8个像素的行依次合成,将8×8图像的二维合并为一维。我们之前讨论过的第一个图像现在表示为具有8×8 = 64个插槽的一维阵列。

# printing the one-dimensional array's values
x[0]

输出:

array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
       15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
       12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
        0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
       10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.])

划分训练和测试数据

当机器学习算法用于基于未用于训练模型的数据进行预测时,训练测试拆分过程用于测量其性能。
这是一种快速而简单的技术,允许您比较机器学习算法的性能,以应对预测建模挑战。

# Very first 1000 photographs and
# labels will be used in training.
x_train = x[:1000]
y_train = y[:1000]
 
# The leftover dataset will be utilised to
# test the network's performance later on.
x_test = x[1000:]
y_test = y[1000:]

多层感知分类器MLP

MLP代表多层感知器。它由密集连接的层组成,这些层将任何输入尺寸转换为所需尺寸。多层感知是具有多个层的神经网络。为了构建神经网络,我们连接神经元,使它们的输出成为其他神经元的输入。

# importing the MLP classifier from sklearn
from sklearn.neural_network import MLPClassifier
 
# calling the MLP classifier with specific parameters
mlp = MLPClassifier(hidden_layer_sizes=(15,),
                    activation='logistic',
                    alpha=1e-4, solver='sgd',
                    tol=1e-4, random_state=1,
                    learning_rate_init=.1,
                    verbose=True)

现在是时候在训练数据上训练我们的MLP模型了。

mlp.fit(x_train, y_train)

输出:

Iteration 185, loss = 0.01147629
Iteration 186, loss = 0.01142365
Iteration 187, loss = 0.01136608
Iteration 188, loss = 0.01128053
Iteration 189, loss = 0.01128869
Training loss did not improve more than tol=0.000100 for 10 consecutive epochs.
Stopping.
MLPClassifier(activation='logistic', hidden_layer_sizes=(15,),
              learning_rate_init=0.1, random_state=1, solver='sgd',
              verbose=True)

上所示为MLPClassifier及其相应配置的最后五个时期的损失。

fig, axes = plt.subplots(1, 1)
axes.plot(mlp.loss_curve_, 'o-')
axes.set_xlabel("number of iteration")
axes.set_ylabel("loss")
plt.show()

在这里插入图片描述

模型评估

现在,让我们使用识别数据集检查模型的性能,或者它只是记住了它。我们将通过使用剩余的测试数据来做到这一点,这样我们就可以检查模型是否已经学习了数字中的实际模式 。

predictions = mlp.predict(x_test)
predictions[:50]

输出:

array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
       4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 3, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 5, 0])

但是真实的标签如下所示。

y_test[:50]

输出:

array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
       4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 0])

因此,通过使用预测标签和真实标签,我们可以找到模型的准确性。

# importing the accuracy_score from the sklearn
from sklearn.metrics import accuracy_score
 
# calculating the accuracy with y_test and predictions
accuracy_score(y_test, predictions)

输出:

0.9146800501882058

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/58980.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Springboot 多数据源 dynamic-datasource动态添加移除数据源

0.前言 上一篇文章我们讲了如何通过多数据源组件,在Spring boot Druid 连接池项目中配置多数据源,并且通过DS注解的方式切换数据源,《Spring Boot 配置多数据源【最简单的方式】》。但是在多租户的业务场景中,我们通常需要手动的…

【方法】Excel表格如何拆分数据?

当需要把多个数据逐个填到Excel单元格的时候,我们可以利用Excel的数据拆分功能,可以节省不少时间。 小编以下面的数据为例,看看如何进行数据拆分。 首先,要选择数字所在的单元格,然后依次点击菜单栏中的“数据”>…

FFmpeg解码详细流程

介绍 FFmpeg的 libavcodec 模块完成音视频多媒体的编解码模块。老版本的 FFmpeg 将avcodec_decode_video2()作为视频的解码函数 API,将avcodec_decode_audio4()作为音频的解码函数 API;从 3.4版本开始已经将二者标记为废弃过时 API(attribut…

Dockerfile构建lamp镜像

1、构建目录 [rootdocker ~]# mkdir compose_lamp [rootdocker ~]# cd compose_lamp/ 2、编写Docekerfile [rootdocker compose_lamp]# vim Dockerfile #基础镜像 FROM centos:7#维护该镜像的用户信息 MAINTAINER Crushlinux <crushlinux163.com>#安装httpd RUN yum -…

反转链表的两种方法

反转链表的两种方法 题目介绍 题目链接 206. 反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 效果图如下所示 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]解法一&#xff1a;迭代方法 解题思路&#xf…

Chapter 10: Dictionaries | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介DictionariesDictionariesDictionary as a set of countersDictionaries and filesLooping and dictionariesAdvanced text parsingDebuggingGlossary Python for Everybody Exploring Data Using Python 3 Dr. Charles R. Severance 课程…

Spring:JDBCTemplate

JDBCTemplate 概述 概述 JDBC&#xff08;Java DataBase Connectivity&#xff0c;Java 数据库连接&#xff09;&#xff0c; 一 种用于执行 SQL 语句的 Java API&#xff08;Application Programming Interface &#xff0c; 应用程序设计接口 &#xff09;&#xff0c;可以为…

计算机网络(2) --- 网络套接字UDP

计算机网络&#xff08;1&#xff09; --- 网络介绍_哈里沃克的博客-CSDN博客https://blog.csdn.net/m0_63488627/article/details/131967378?spm1001.2014.3001.5501 目录 1.端口号 2.TCP与UDP协议 1.TCP协议介绍 1.TCP协议 2.UDP协议 3.理解 2.网络字节序 发送逻辑…

小白到运维工程师自学之路 第六十六集 (docker 网络模型)

一、概述 Docker网络模型是指Docker容器在网络中的通信方式和组织结构。Docker容器通过网络连接&#xff0c;使得容器之间可以相互通信&#xff0c;并与主机和外部网络进行交互。 在Docker中&#xff0c;有几种不同的网络模型可供选择&#xff1a; 1、主机模式&#xff08;H…

网络安全(黑客)自学建议一一附学习路线

温馨提示&#xff1a;为了避免误入歧途&#xff0c;自学请优先看《网络安全法》。 下面是一些学习建议&#xff1a; 1、多请教有经验的人 切忌钻牛角尖&#xff0c;特别是刚入门的什么都不了解的情况下&#xff0c;可能你花好几天研究的一个东西&#xff0c;人10分钟就能搞定…

DBeaver连MySQL库报错public key retrieval is not allowed

连接报错: public key retrieval is not allowed解决办法&#xff1a; 右击你连接的库进行编辑连接&#xff08;或者直接按F4打开编辑&#xff09; 然后点击驱动属性里面进行设置 找到allowPublicKeyRetrieval属性&#xff0c;把值由false改为true 注&#xff1a;连接成功后如…

部署SpringBoot项目在服务器上,并通过swagger登录

1.项目编译打包 2.上传jar包到服务器并启动 xftp将打包好后的jar包传到虚拟机指定路径 java -jar执行该jar包 3.通过swagger登录 输入后点击下面的执行按钮 会得到一个tocken 4.将tocken放到postman的Headers中 5.修改url 例如我本地测试是http://localhost:8080/接口名&am…

【C++】继承的基本特性(定义,赋值转换,友元,静态成员,虚拟继承,默认成员函数,作用域)

文章目录 一、继承的定义1.定义格式2.继承基类成员访问方式的变化 二、基类和派生类对象赋值转换三、继承的作用域1. 在继承体系中基类和派生类都有独立的作用域。2.子类和父类中有同名成员3.成员函数的隐藏4.注意在实际中在继承体系里面最好不要定义同名的成员。 四、派生类的…

【WebRTC---序篇】(七)RTC多人连麦方案

服务端可以选择mediasoup&#xff0c;作为SFU服务器&#xff0c;只负责转发数据 下图举例三个Client (browser或者客户端)同时加入一个房间&#xff0c;每个app同时发布一路视频和一路音频&#xff0c;并且接受来自其他app的音视频流&#xff0c;mediasoup内部的结构如下&…

【诺依管理系统-前端】对话框

1.指定 v-model"openAllScene" 2.按钮代码 3.调用方法handleAddAllScene里面&#xff0c;只是openAllScene的值为true&#xff0c;调用显示对话框

unity tolua热更新框架教程(1)

git GitHub - topameng/tolua: The fastest unity lua binding solution 拉取到本地 使用unity打开&#xff0c;此处使用环境 打开前几个弹窗(管线和api升级)都点确定 修改项目设置 切换到安卓平台尝试打包编译 设置包名 查看报错 打开 屏蔽接口导出 重新生成 编译通过 …

Git常见问题

git clone 提示OpenSSL SSL_read git clone 时提示Connection was reset, errno 10054类错误 fatal: unable to acce ss https://github.com/fex-team/ueditor.git/: OpenSSL SSL_read: Connection was reset, errno 10054 备注&#xff1a;以下方法只是归纳整理&#xff0c;…

JVM GC ROOT分析

GC root原理:通过对枚举GCroot对象做引用可达性分析,即从GC root对象开始,向下搜索,形成的路径称之为 引用链。如果一个对象到GC roots对象没有任何引用,没有形成引用链,那么该对象等待GC回收,换而言之,如果减少内存泄漏,也就是切断引用链,常见的GCRoot对象如下: 1、…

【深度学习】在 MNIST实现自动编码器实践教程

一、说明 自动编码器是一种无监督学习的神经网络模型&#xff0c;主要用于降维或特征提取。常见的自动编码器包括基本的单层自动编码器、深度自动编码器、卷积自动编码器和变分自动编码器等。 其中&#xff0c;基本的单层自动编码器由一个编码器和一个解码器组成&#xff0c;编…

简单易懂的生鲜蔬果小程序开发指南

随着人们对健康意识的提高&#xff0c;越来越多的人开始注重饮食健康&#xff0c;选择新鲜的果蔬产品。为了满足市场需求&#xff0c;制作一个果蔬配送小程序成为了一个不错的选择。本文将详细介绍如何快速制作一个果蔬配送小程序。 第一步&#xff1a;登录乔拓云网后台&#x…