机器学习基础--基于常用分类算法实现手写数字识别

# 1.数据介绍

>MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.
>MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/  获取, 它包含了四个部分:

* Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
* Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
* Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
* Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

数据文件是二进制格式的,所以要按字节读取。代码如下:

导入函数包

import struct,os
import numpy as np
from array import array as pyarray
from numpy import append, array, int8, uint8, zeros
import matplotlib.pyplot as plt
def load_mnist(image_file, label_file, path="mnist"):
    digits=np.arange(10)

    fname_image = os.path.join(path, image_file)
    fname_label = os.path.join(path, label_file)

    flbl = open(fname_label, 'rb')
    magic_nr, size = struct.unpack(">II", flbl.read(8))
    lbl = pyarray("b", flbl.read())
    flbl.close()

    fimg = open(fname_image, 'rb')
    magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
    img = pyarray("B", fimg.read())
    fimg.close()

    ind = [ k for k in range(size) if lbl[k] in digits ]
    N = len(ind)

    images = zeros((N, rows*cols), dtype=uint8)
    labels = zeros((N, 1), dtype=int8)
    for i in range(len(ind)):
        images[i] = array(img[ ind[i]*rows*cols : (ind[i]+1)*rows*cols ]).reshape((1, rows*cols))
        labels[i] = lbl[ind[i]]

    return images, labels


train_image, train_label = load_mnist("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
test_image, test_label = load_mnist("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")

定义加载mnist数据集的函数,传入参数为文件名称和路径,首先将路径拼接为文件所在路径,然后以二进制读方式打开文件

magic_nr, size = struct.unpack(">II", flbl.read(8))

">II" 表示总共读取两个 4 字节无符号整数。I:表示读取一个无符号的 4 字节整数(unsigned int),>:表示大端字节序(Big Endian)。flbl.read(8)表示读取前八个字节

lbl = pyarray("b", flbl.read())

使用 pyarray 将标签数据存储为一个高效的数组。flbl.read()读取剩余所有内容,这个是接着上面的八字节继续读取

ind = [k for k in range(size) if lbl[k] in digits]

N = len(ind)

  • 仅选择标签在 digits 范围(0 到 9)内的图像索引,存储到 ind
  • N 是符合条件的图像数量。

images = zeros((N, rows*cols), dtype=uint8) labels = zeros((N, 1), dtype=int8)

  • images: 用于存储图像数据,形状为 (N, rows*cols),每个图像展平成一维向量。
  • labels: 用于存储标签,形状为 (N, 1)

for i in range(len(ind)): images[i] = array(img[ind[i]*rows*cols : (ind[i]+1)*rows*cols]).reshape((1, rows*cols)) labels[i] = lbl[ind[i]]

  • 遍历筛选后的图像索引 ind
  • 根据索引从 img 中提取相应的图像像素数据,展平为一维并存储到 images
  • 将对应的标签存储到 labels

算法介绍 

1. K-Nearest Neighbors (KNN)
  • 概念:
    K-Nearest Neighbors(KNN)是一种基于实例的学习算法。当给定一个新的数据点时,KNN 算法通过计算这个点与训练数据中所有点的距离,找到 K 个距离最近的点,然后通过这些 K 个邻居的标签来决定新的数据点的分类标签。

  • 特点:

    • 非参数方法:KNN 不对数据做任何假设。
    • 计算开销大:需要计算与所有训练样本的距离,因此在数据量大的时候计算成本高。
    • 适合小规模数据集,且对噪声敏感。
  • 工作原理:

    • 选择合适的 K 值(K 是邻居的数量)。
    • 计算测试样本与所有训练样本的距离(通常使用欧氏距离)。
    • 选出 K 个最近的邻居。
    • 通过邻居的标签投票决定分类。
from sklearn.metrics import accuracy_score,classification_report
from sklearn.neighbors import KNeighborsClassifier

knc = KNeighborsClassifier(n_neighbors=10)
#初始化,设定要分类的个数为10
knc.fit(train_image,train_label.ravel())
#使用训练数据来训练Knn模型
predict = knc.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))

 

2. Naive Bayes
  • 概念:
    Naive Bayes 是一种基于贝叶斯定理的分类算法,假设特征之间是条件独立的(即“朴素”假设)。尽管这种假设在实际中很少成立,但 Naive Bayes 在许多实际问题中表现得很好。

  • 特点:

    • 计算速度快。
    • 对小数据集表现良好。
    • 假设特征独立:适用于条件独立的特征,如文本分类中的词频。
  • 工作原理:

    • 通过贝叶斯定理计算每个类别的后验概率:P(Y∣X)=P(X∣Y)P(Y)P(X)P(Y∣X)=P(X)P(X∣Y)P(Y)​其中 P(Y∣X)P(Y∣X) 是给定特征 XX 后属于某类别 YY 的概率,P(X∣Y)P(X∣Y) 是特征 XX 在类别 YY 下的条件概率,P(Y)P(Y) 是类别 YY 的先验概率。
    • 由于假设特征条件独立,可以将特征的联合概率分解为单个特征的条件概率的乘积。
  • 适用场景:

    • 文本分类、垃圾邮件过滤、情感分析。
from sklearn.naive_bayes import MultinomialNB

mnb = MultinomialNB()
mnb.fit(train_image,train_label)
predict = mnb.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))
print("Classification report for classifier %s:\n%s\n" % (mnb, classification_report(test_label, predict)))

 

3. Decision Tree(决策树)
  • 概念:
    决策树是一种树形结构的分类模型,它通过将数据集划分为多个子集,最终将数据分类。每个节点代表一个特征的条件,边代表特征值的取值,叶子节点代表类别标签。

  • 特点:

    • 易于理解和解释。
    • 可以处理数值型和类别型数据。
    • 对噪声数据和不平衡数据敏感。
  • 工作原理:

    • 决策树通过递归地选择最佳的特征来分割数据,使得每个分支尽可能纯(即数据类别单一)。
    • 使用某些标准(如信息增益、基尼指数)来选择最优的特征。
    • 树的构建会一直持续到所有数据被正确分类,或者达到停止条件(如最大深度、最小样本数等)。
  • 优点:

    • 可解释性强,容易理解。
    • 不需要特征缩放。
  • 缺点:

    • 容易过拟合,尤其是树很深时。
    • 对数据中小的变化非常敏感。
from sklearn.tree import DecisionTreeClassifier

dtc = DecisionTreeClassifier()
dtc.fit(train_image,train_label)
predict = dtc.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))
print("Classification report for classifier %s:\n%s\n" % (dtc, classification_report(test_label, predict)))

 

4. Random Forest(随机森林)
  • 概念:
    随机森林是一种集成学习方法,它通过构建多棵决策树并将其结果进行投票或平均,从而得到更稳健的分类结果。每棵树都是在不同的随机子集上训练出来的。

  • 特点:

    • 通过集成多个决策树来降低过拟合的风险。
    • 可以处理大规模数据,具有较好的性能。
    • 在很多应用中效果优异。
  • 工作原理:

    • 随机选择数据的子集来训练每一棵树(Bootstrap Aggregating,简称 Bagging)。
    • 每棵树在训练时只考虑特征的一个随机子集,这进一步增加了随机性并减少了过拟合。
  • 优点:

    • 非常强大的分类器,通常比单一的决策树更准确。
    • 处理高维数据时也表现良好。
    • 对噪声不敏感。
  • 缺点:

    • 模型较复杂,较难解释。
    • 训练和预测时计算开销较大。
from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier()
rfc.fit(train_image,train_label)
predict = rfc.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))
print("Classification report for classifier %s:\n%s\n" % (rfc, classification_report(test_label, predict)))

 

5. Logistic Regression(逻辑回归)
  • 概念:
    逻辑回归是一种广泛使用的分类算法,尽管它的名字中有“回归”,但它实际上是一种分类方法。它通过学习特征与类别之间的关系,输出一个类别的概率值。

  • 特点:

    • 用于二分类问题(也可以扩展到多分类问题)。
    • 模型简单,计算效率高。
    • 输出概率值,适合进行概率预测。
  • 工作原理:

    • 假设特征和类别之间存在线性关系,通过 Sigmoid 函数将线性组合映射到 0 到 1 之间。
    • 目标是最小化损失函数(如交叉熵),找到最佳的回归系数。
  • 公式:

    P(y=1∣X)=11+e−(wTX+b)P(y=1∣X)=1+e−(wTX+b)1​

    其中,XX 是特征向量,ww 是模型权重,bb 是偏置项,yy 是类别。

  • 优点:

    • 输出类别概率,适合概率预测。
    • 训练速度快,且容易理解。
  • 缺点:

    • 仅适用于线性可分问题,对于非线性问题性能较差。
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
lr.fit(train_image,train_label)
predict = lr.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))
print("Classification report for classifier %s:\n%s\n" % (lr, classification_report(test_label, predict)))

 

6. Support Vector Machine (SVM)
  • 概念:
    支持向量机(SVM)是一种基于最大间隔分类的算法。它通过在特征空间中找到一个超平面来分隔不同类别的数据点,并尽可能最大化分类边界(即最大化类别之间的间隔)。

  • 特点:

    • 强大的分类能力,尤其在高维空间中表现良好。
    • 适用于线性可分和非线性可分的情况(通过核技巧)。
    • 对于小样本数据集有较好的表现。
  • 工作原理:

    • SVM 试图找到一个超平面,使得数据点到超平面的距离尽可能大(即最大化间隔)。这种超平面被称为最优超平面
    • 核技巧(Kernel Trick)可以将数据映射到更高维的空间,使得非线性可分问题转化为线性可分问题。
  • 优点:

    • 在高维空间中依然有效。
    • 能够处理线性和非线性问题。
    • 强大的分类性能,尤其是在数据量较少时。
  • 缺点:

    • 训练速度较慢,尤其是在大规模数据集上。
    • 对于大数据集的处理较为困难。

 

from sklearn.svm import SVC

svc = SVC()
svc.fit(train_image,train_label)
predict = svc.predict(test_image)
print("accuracy_score: %.4lf" % accuracy_score(predict,test_label))
print("Classification report for classifier %s:\n%s\n" % (svc, classification_report(test_label, predict)))

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/924090.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决jupyter notebook 新建或打开.ipynb 报500 : Internal Server Error(涉及jinja2兼容性问题)

报错: [E 10:09:52.362 NotebookApp] 500 GET /notebooks/Untitled16.ipynb?kernel_namepyt hon3 (::1) 93.000000ms refererhttp://localhost:8888/tree ...... 重点是: from .exporters import * File "C:\ProgramData\Anaconda3\lib\site-p…

基于Springboot企业级工位管理系统【附源码】

基于Springboot企业级工位管理系统 效果如下: 系统登录页面 员工主页面 部门信息页面 员工管理页面 部门信息管理页面 工位信息管理页面 工位分配管理页面 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所。…

GoogleTest做单元测试

目录 环境准备GoogleTest 环境准备 git clone https://github.com/google/googletest.git说cmkae版本过低了,解决方法 进到googletest中 cmake CMakeLists.txt make sudo make installls /usr/local/lib存在以下文件说明安装成功 中间出了个问题就是,…

Android 11 三方应用监听关机广播ACTION_SHUTDOWN

前言 最近有项目过程中,有做app的同事反馈,三方应用无法监听关机广播。特地研究了下关机广播为啥监听不到。 1.原因:发送关机广播的类是ShutdownThread.java,添加了flag:Intent.FLAG_RECEIVER_FOREGROUND | Intent.FLAG_RECEIVER…

一篇文章了解Linux

目录 一:命令 1 ls命令作用 2 目录切换命令(cd/pwd) (1)cd切换工作目录命令 3 相对路径、绝对路径和特殊路径 (1)相对路径和绝对路径的概念和写法 (2)几种特殊路径的表示符 (3)练习题: 4 创建目录命令&#x…

css—动画

一、背景 本文章是用于解释上一篇文章中的问题,如果会动画的小伙伴就不用再次来看了,本文主要讲解一下动画的设定规则,以及如何在元素中添加动画,本文会大篇幅的讲解一下,动画属性。注意,这是css3的内容&am…

MATLAB下的RSSI定位程序,二维平面上的定位,基站数量可自适应

文章目录 引言程序概述程序代码运行结果待定位点、锚点、计算结果显示待定位点和计算结果坐标 引言 随着无线通信技术的发展,基于 R S S I RSSI RSSI(接收信号强度指示)的方法在定位系统中变得越来越流行。 R S S I RSSI RSSI定位技术特别适…

排序算法之选择排序堆排序

算法时间复杂度辅助空间复杂度稳定性选择排序O(N^2)O(1)不稳定堆排序O(NlogN)O(1)不稳定 1.选择排序 这应该算是最简单的排序算法了,每次在右边无序区里选最小值,没有无序区时,就宣告排序完毕 比如有一个数组:[2,3,2,6,5,1,4]排…

电视网络机顶盒恢复出厂超级密码大全汇总

部分电视机顶盒在按遥控器设置键打开设置时,会弹出设置密码弹窗,需输入密码才能操作其中内容。 如下图所示: 部分电视机顶盒在选择恢复出厂设置时,会出现设置密码弹窗,只有输入操作密码后才能进行恢复出厂设置的操作。…

继续完善wsl相关内容:基础指令

文章目录 前言一、我们需要安装wsl,这也是安装docker desktop的前提,因此我们在这篇文章里做了介绍:二、虽然我们在以安装docker desktop为目的时,不需要安装wsl的分发(distribution),但是装一个分发也是有诸多好处的:三、在使用wsl时,不建议把东西直接放到系统里,因…

基于STM32的智能风扇控制系统

基于STM32的智能风扇控制系统 持续更新,欢迎关注!!! ** 基于STM32的智能风扇控制系统 ** 近几年,我国电风扇市场发展迅速,产品产出持续扩张,国家产业政策鼓励电风扇产业向高技术产品方向发展,国内企业新增投资项目投…

Zero to JupyterHub with Kubernetes中篇 - Kubernetes 常规使用记录

前言:纯个人记录使用。 搭建 Zero to JupyterHub with Kubernetes 上篇 - Kubernetes 离线二进制部署。搭建 Zero to JupyterHub with Kubernetes 中篇 - Kubernetes 常规使用记录。搭建 Zero to JupyterHub with Kubernetes 下篇 - Jupyterhub on k8s。 参考&…

docker-compose搭建xxl-job、mysql

docker-compose搭建xxl-job、mysql 1、搭建docker以及docker-compose2、下载xxl-job需要数据库脚本3、创建文件夹以及docker-compose文件4、坑来了5、正确配置6、验证-运行成功 1、搭建docker以及docker-compose 略 2、下载xxl-job需要数据库脚本 下载地址:https…

HTTP有哪些风险?是怎么解决的?

一、风险 HTTP是通过明文传输的,存在窃听风险、篡改风险以及冒充风险。 二、如何解决 HTTPS在HTTP的下层加了一个SSL/TLS层,保证了安全,通过混合加密解决窃听风险、数字签名解决篡改风险、数字证书解决冒充风险。 (1&#xff0…

《Django 5 By Example》阅读笔记:p339-p358

《Django 5 By Example》学习第13天,p359-p382总结,总计24页。 一、技术总结 1.session (1)session 存储方式 Database sessions File-based sessions Cached sessions Cached database sessions Cookie-based sessions (2)设置 CART_SESSION_I…

Python数据分析(OpenCV)

第一步通过pip安装依赖包,执行一下命令 pip install opencv-python 如果是Anaconda请在工具中自行下载 下载好咋们就可以在环境中使用了。 人脸识别的特征数据可以到 github上面下载,直接搜索OpenCV 然后我们在源码中通过cv2的级联分类器引入人脸的特征…

(免费送源码)计算机毕业设计原创定制:Java+ssm+JSP+Ajax SSM棕榈校园论坛的开发

摘要 随着计算机科学技术的高速发展,计算机成了人们日常生活的必需品,从而也带动了一系列与此相关产业,是人们的生活发生了翻天覆地的变化,而网络化的出现也在改变着人们传统的生活方式,包括工作,学习,社交…

工业AI质检 AI质检智能系统 尤劲恩(上海)信息科技有限公司

来的现代化工厂,将逐步被无人化车间取代,无人工厂除了产线自动化,其无人质检将是绕不开的话题。尤劲恩致力于帮助工业制造领域上下游工厂减员增效、提高品质效率,真正实现无人质检IQC/IPQC/OQC的在线质检系统。分析生产环节真实品…

C 语言数组与函数:核心要点深度剖析与高效编程秘籍

我的个人主页 我的专栏:C语言,希望能帮助到大家!!!点赞❤ 收藏❤ 目录 引言数组基础 2.1 数组的定义与初始化 2.2 一维数组的基本操作 2.3 二维数组及其应用 2.4 数组与指针的关系函数基础 3.1 函数的定义与调用 3.2…

XML JSON

XML 与 JSON 结构 XML(eXtensible Markup Language) 1. 定义 XML 是一种标记语言,用于描述数据的结构和内容。主要用于数据存储与交换。 2. 特点 可扩展性:用户可以自定义标签。层次化结构:数据以树形结构组织&…