通过卷积神经网络(CNN)识别和预测手写数字

一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍

卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿人类视觉系统的工作原理来处理数据,能够从图像中自动学习和提取特征。以下是CNN的一些关键特点和组成部分:

卷积层(Convolutional Layer)

卷积层是CNN的核心,它使用滤波器(或称为卷积核)在输入图像上滑动,以提取图像的局部特征。

每个滤波器负责检测图像中的特定特征,如边缘、角点或纹理等。

卷积操作会产生一个特征图(feature map),它表示输入图像在滤波器下的特征响应。

激活函数

通常在卷积层之后使用非线性激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。

激活函数帮助网络处理复杂的模式,并使网络能够学习更复杂的特征组合。

池化层(Pooling Layer)

池化层用于降低特征图的空间尺寸,减少参数数量和计算量,同时使特征检测更加鲁棒。

最常见的池化操作是最大池化(max pooling)和平均池化(average pooling)。

全连接层(Fully Connected Layer)

在多个卷积和池化层之后,CNN通常包含一个或多个全连接层,这些层将学习到的特征映射到最终的输出类别上。

全连接层中的每个神经元都与前一层的所有激活值相连。

softmax层

在网络的最后一层,通常使用softmax层将输出转换为概率分布,用于多分类任务中。

softmax函数确保输出层的输出值在0到1之间,并且所有输出值的总和为1。

卷积神经网络的训练

CNN通过反向传播算法和梯度下降法进行训练,以最小化损失函数(如交叉熵损失)。

在训练过程中,网络的权重通过大量图像数据进行调整,以提高分类或识别的准确性。

数据增强(Data Augmentation)

为了提高CNN的泛化能力,经常使用数据增强技术,如旋转、缩放、裁剪和翻转图像,以创建更多的训练样本。

迁移学习(Transfer Learning)

迁移学习是一种技术,它允许CNN利用在一个大型数据集(如ImageNet)上预训练的网络权重,来提高在小型或特定任务上的性能。

CNN在计算机视觉领域的应用非常广泛,包括但不限于图像分类、目标检测、语义分割、物体跟踪和面部识别等任务。由于其强大的特征提取能力,CNN已成为这些任务的主流方法之一。

MNIST数据集是一个广泛使用的手写数字识别数据集,可以通过TensorFlow库Pytorch库来获取, 也可以从官方网站下载:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集它包含四个部分:训练数据集、训练数据集标签、测试数据集和测试数据集标签。这些文件是IDX格式的二进制文件,需要特定的程序来读取。这个数据集包含了60,000张训练集图像和10,000张测试集图像,每张图像都是28x28像素的手写数字,范围从0到9。这些图像被处理为灰度值,其中黑色背景用0表示,手写数字用0到1之间的灰度值表示,数值越接近1,颜色越白。

MNIST数据集的图像通常被拉直为一个一维数组,每个数组包含784个元素(28x28像素)。数据集中的每个图像都有一个对应的标签,标签以one-hot编码的形式给出,例如数字5的标签表示为[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]。

在机器学习模型中,MNIST数据集常用于训练分类器,以识别和预测手写数字。例如,在深度学习中,可以使用卷积神经网络(CNN)来处理这些图像,学习从图像像素到数字标签的映射。

二:通过Pytorch库建立CNN模型训练MNIST数据集

使用Python的Pytorch库来完成一个卷积神经网络(CNN)来训练MNIST数据集,需要遵循以下步骤:

  1. 导入必要的库:我们需要导入Pytorch以及其它可能需要的库,如torchvision用于数据加载和变换。
  2. 加载MNIST数据集:使用torchvision库中的datasets和DataLoader来加载和预处理MNIST数据集。
  3. 定义卷积神经网络结构:设计一个简单的CNN结构,包括卷积层、池化层和全连接层。
  4. 定义损失函数和优化器:选择一个合适的损失函数,如交叉熵损失,以及一个优化器,如Adam或SGD。
  5. 训练模型:在训练集上训练模型,并保存训练过程中的损失和准确率。
  6. 测试模型:在测试集上评估模型的性能。

接下来,我们将按照这些步骤使用Python代码来完成这个任务。

Step1:导入必要的库

# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
  • import torch: 导入了PyTorch的主库,这是进行深度学习任务的基础。
  • import torch.nn as nn: 导入了PyTorch的神经网络模块,它包含了构建神经网络所需的许多类和函数。
  • import torch.nn.functional as F: 导入了PyTorch的功能性API,它提供了不需要维护状态的神经网络操作,例如激活函数、池化等。
  • import torchvision: 导入了PyTorch的视觉库,它提供了许多视觉任务所需的工具和数据集。
  • import torchvision.transforms as transforms: 导入了对数据进行预处理的工具。
  • from torch.utils.data import DataLoader: 导入了PyTorch的数据加载器,它可以方便地迭代数据集。

Step2:加载MNIST数据集

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  • transform = transforms.Compose(...): 创建了一个转换管道,用于对数据进行预处理。Compose是一个函数,它将多个转换步骤组合成一个转换。
  • transforms.ToTensor(): 将图像数据从PIL Image或NumPy ndarray格式转换为浮点张量,并且将像素值缩放到[0,1]范围内。
  • transforms.Normalize((0.5,), (0.5,)): 对图像进行归一化处理。给定均值(mean)和标准差(std),这个转换将张量的每个通道都减去均值并除以标准差。在这里,它将每个像素值从[0,1]范围转换为[-1,1]范围。
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  • 这两行代码分别加载了MNIST数据集的训练集和测试集。
  • root='./data': 指定数据集下载和存储的根目录。
  • train=True: 对于trainset,表示加载数据集的训练部分。
  • train=False: 对于testset,表示加载数据集的测试部分。
  • download=True: 表示如果数据集不在指定的root目录下,则从互联网上下载。
  • transform=transform: 应用之前定义的转换。
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
  • 这两行代码创建了两个DataLoader对象,用于在训练和测试时迭代数据集。
  • batch_size=64: 指定每个批次的样本数量。
  • shuffle=True: 对于trainloader,在每次迭代时打乱数据,这对于训练是有益的,因为它可以减少模型学习数据的顺序性。
  • shuffle=False: 对于testloader,不打乱数据,因为测试时不需要随机性。

得到了一个名为data的文件夹:

847242f10504407ca060290107d1bc8d.png

Step3:定义卷积神经网络结构

# 定义卷积神经网络结构
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
  • 这段代码定义了一个名为CNN的卷积神经网络类,它继承自nn.Module
  • __init__方法初始化了网络的结构:
    • self.conv1是一个2D卷积层,输入通道为1(MNIST图像为单通道),输出通道为32,卷积核大小为3x3,并带有1像素的填充。
    • self.pool是一个2x2的最大池化层,用于减小数据的维度。
    • self.conv2是第二个2D卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,并带有1像素的填充。
    • self.fc1是一个全连接层,它将64个通道的7x7图像映射到1024个特征。
    • self.fc2是另一个全连接层,它将1024个特征映射到10个输出,对应于MNIST数据集的10个类别。
  • forward方法定义了数据通过网络的前向传播路径:
    • x首先通过conv1卷积层,然后应用ReLU激活函数,并使用pool进行池化。
    • 接着,x通过conv2卷积层,再次应用ReLU激活函数和池化。
    • x.view(-1, 64 * 7 * 7)将数据扁平化,为全连接层准备。
    • x通过fc1全连接层,并应用ReLU激活函数。
    • 最后,x通过fc2全连接层,输出结果。
# 实例化网络
net = CNN()
  • 创建了一个CNN类的实例,名为net

Step4:定义损失函数和优化器

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
  • criterion是交叉熵损失函数,常用于多分类问题。
  • optimizer是Adam优化器,用于更新网络的权重。

Step5:训练模型

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/(i+1)}")

下面是这段代码的逐行解释:

  1. epochs是一个变量,表示训练过程中模型将遍历整个训练数据集的次数。这里设置为5,意味着整个训练数据集将被遍历5次。
  2. 外层for循环,它将执行epochs次。在每次迭代中,epoch变量将代表当前的迭代次数,从0开始到epochs-1结束。
  3. 在每次epoch开始时,running_loss被重置为0.0。这个变量用于累加每个epoch中的所有批次损失,以便计算平均损失。
  4. 这是一个嵌套的for循环,它遍历trainloader返回的批次数据。enumerate函数用于遍历可迭代对象,同时跟踪当前的索引(这里是i)。
  5. trainloader是之前定义的数据加载器,它负责分批加载数据,以便于训练。
  6. 参数0指定了索引的起始值。
  7. 然后解包了data元组,其中包含输入(图像)和标签(目标值)。inputs是模型的输入数据,labels是这些输入数据的正确类别标签。
  8. 在每次迭代开始时,调用optimizer.zero_grad()来清除之前梯度计算的结果。这是必要的,因为PyTorch的梯度是累加的。
  9. 输入inputs传递给神经网络net,并得到输出outputs。这是模型的前向传播步骤。
  10. 计算了模型输出的损失。criterion是之前定义的交叉熵损失函数,它比较outputs(模型的预测)和labels(实际类别标签)来计算损失。
  11. 执行了反向传播。它计算了损失相对于模型参数的梯度。
  12. 更新了模型的权重。optimizer使用计算出的梯度来调整网络参数,以减少下一次迭代的损失。
  13. 将当前的批次损失累加到running_loss变量中,用于后续计算平均损失。
  14. 在每个epoch结束时,打印出当前epoch的编号和平均损失。epoch+1是为了从1开始计数epoch,而不是从0开始。running_loss/(i+1)计算了当前epoch的平均损失,其中i+1是当前epoch中批次的数量。

最终得到每个epoch的平均损失如下:

49592faa38b84b699f4458f2cf76a433.png

Step6:测试模型

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct / total}%")
  1. correcttotal是两个变量,分别用于跟踪模型在测试数据集上正确预测的样本数量和总的样本数量。
  2. with torch.no_grad()是一个上下文管理器,用于在测试阶段禁用梯度计算。因为测试阶段不需要计算梯度,这样可以节省内存并加快计算速度。
  3. for循环,遍历testloader返回的测试数据集的批次数据。
  4. 这行代码解包了data元组,其中包含测试图像images和它们对应的真实标签labels
  5. 这行代码将测试图像images输入到训练好的神经网络net中,并得到输出outputs
  6. torch.max(outputs.data, 1)返回两个值:第一个是每个批次中最大值的元素,第二个是这些最大值的索引。在这里,最大值代表模型对每个图像的预测类别,而索引则代表预测的类别标签。
  7. predicted是模型预测的类别标签的向量。
  8. 这行代码累加测试集中总的样本数量。labels.size(0)给出了当前批次中样本的数量。
  9. (predicted == labels)是一个布尔表达式,它比较模型的预测predicted和真实标签labels,并返回一个布尔张量,其中正确预测的位置为True,否则为False。
  10. .sum()计算布尔张量中True的数量,即正确预测的样本数量。
  11. .item()将计算得到的张量(只有一个元素)转换为Python的标量值。
  12. 这行代码计算并打印出模型在测试数据集上的准确率。准确率是通过将正确预测的样本数量correct除以总样本数量total,然后乘以100来得到的百分比。这里假设测试数据集包含10000个样本。

得到准确率如下:

9eaaa375532f47f496aa265cb2d0d615.png

使用这个建立好的卷积神经网络(CNN)模型,主要用于训练分类器。具体来说,这个模型能够识别手写数字图像,并将它们分类为0到9中的一个类别。它适用于MNIST数据集。这个示例能够帮助更好的了解卷积神经网络(CNN)的原理。

 

想要探索更多元化的数据分析视角,可以关注之前发布的相关内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/873260.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

8. GIS数据分析师岗位职责、技术要求和常见面试题

本系列文章目录: 1. GIS开发工程师岗位职责、技术要求和常见面试题 2. GIS数据工程师岗位职责、技术要求和常见面试题 3. GIS后端工程师岗位职责、技术要求和常见面试题 4. GIS前端工程师岗位职责、技术要求和常见面试题 5. GIS工程师岗位职责、技术要求和常见面试…

【高等代数笔记】线性空间(一到四)

3. 线性空间 令 K n : { ( a 1 , a 2 , . . . , a n ) ∣ a i ∈ K , i 1 , 2 , . . . , n } \textbf{K}^{n}:\{(a_{1},a_{2},...,a_{n})|a_{i}\in\textbf{K},i1,2,...,n\} Kn:{(a1​,a2​,...,an​)∣ai​∈K,i1,2,...,n},称为 n n n维向量 规定(规定…

【技术前沿】智能反向寻车解决方案:提升停车场用户体验与运营效率

亲爱的技术员及停车场管理者们,您是否曾遇到过车主在庞大的停车场中迷失方向,耗费大量时间寻找爱车的困境?这不仅影响了车主的停车体验,也无形中增加了停车场的管理难度和运营成本。本文专为解决这一痛点而生,介绍最新…

基于人工智能的手写数字识别系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 手写数字识别是一种经典的计算机视觉任务,目标是让机器能够识别手写数字。通过人工智能技术,特别是卷积神经网…

京东物流查询|开发者调用API接口实现

快递聚合查询的优势 1、高效整合多种快递信息。2、实时动态更新。3、自动化管理流程。 聚合国内外1500家快递公司的物流信息查询服务,使用API接口查询京东物流的便捷步骤,首先选择专业的数据平台的快递API接口:物流快递查询API接口-单号查询…

【论文分享】MyTEE: Own the Trusted Execution Environment on Embedded Devices 23‘NDSS

目录 AbstractINTRODUCTIONBACKGROUNDARMv8 ArchitectureSecurity statesTrustZone extensionsVirtualization Communication with Peripherals MOTIVATIONATTACK MODEL AND ASSUMPTIONSYSTEM DESIGNOverviewExecution Environments IsolationDMA FilterExternal DMA controlle…

flutter的入口和原生交互

从今天起,笔者要开始从flutter列表页面向原生页面跳转了 首先遇到了N个No such module "Flutter" 因为笔者的公司其实是从前往后改造Flutter的,所以也不需要引擎组,但是笔者搞不懂,只能照着葫芦画瓢,以后等…

【重学 MySQL】十六、算术运算符的使用

【重学 MySQL】十六、算术运算符的使用 加法 ()减法 (-)乘法 (*)除法 (/ 或 div )取模(求余数) (% 或 mod )注意事项 在 MySQL 中,算术运算符用于执行数学运算,如加法、减法、乘法、除法和取模(求余数)等。…

2024数学建模国赛选题建议+团队助攻资料(已更新完毕)

目录 一、题目特点和选题建议 二、模型选择 1、评价模型 2、预测模型 3、分类模型 4、优化模型 5、统计分析模型 三、white学长团队助攻资料 1、助攻代码 2、成品论文PDF版 3、成品论文word版 9月5日晚18:00就要公布题目了,根据历年竞赛题目…

自动化运维之WGCLOUD入门到掌握 - ubuntu服务器进行入侵检测分析

WGCLOUD监测平台,有个日志监控模块,我们本文就用它来进行ubuntu的入侵检测分析 准备:ubuntu 20,WGCLOUD v3.5.4 ubuntu的登录日志文件,用于分析用户登录行为:/var/log/auth.log,我们今天就用u…

AWS SES服务 Golang接入教程(排坑版)

因为刚来看的时候 也迷迷糊糊的 所以 先讲概念 再上代码 一 基础设置 这里需要完成两个最基础的设置任务 1 是验证至少一个收件电子邮箱 2 【很关键】是验证发送域。即身份里的域类型的身份。(可以理解为配置你的域名邮箱服务器(SMPT)为亚马…

微软出品的一款管理多个远程桌面连接的工具

RDCMan(Remote Desktop Connection Manager)是微软官方出品的一款用于管理多个远程桌面连接的工具。它可以帮助用户集中管理和分类远程桌面,特别适用于需要同时管理大量服务器或在不同计算机间切换操作的场景。 RDCMan的主要功能包括&#x…

[深度学习][LLM]:浮点数怎么表示,什么是混合精度训练?

混合精度训练 混合精度训练1. 浮点表示法:[IEEE](https://zh.wikipedia.org/wiki/电气电子工程师协会)二进制浮点数算术标准(IEEE 754)1.1 浮点数剖析1.2 举例说明例子 1:例子 2: 1.3 浮点数比较1.4 浮点数的舍入 2. 混合精度训练2.1 为什么需…

继收购西门子物流自动化后,丰田又投资一家AGV公司,智能物流版图已极其夸张...

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 继成功将西门子物流自动化(机场物流业务)纳入麾下后,丰田并未停下其征伐的步伐,而是再度出手,与新兴科技巨头Gideon携手,共同绘制了一幅…

副本集 Election succeeded

目录 1. 分析mongo副本集 Election succeeded 的全过程:2. 从日志里面看到数据库一致性的对比吗?3. 模拟主备不同步,副本集切换步骤注意事项: not master and slaveOkfalse解释: 其他方案方法一:使用 rs.st…

时间同步服务

多主机协作工作时,各个主机的时间同步很重要,时间不一致会造成很多重要应用的故障,如:加密协 议,日志,集群等。 利用NTP(Network Time Protocol) 协议使网络中的各个计算机时间达到…

全英文地图/天地图和谷歌瓦片地图杂交/设备分布和轨迹回放/无需翻墙离线使用

一、前言说明 随着风云局势的剧烈变化,对我们搞软件开发的人员来说,影响也是越发明显,比如之前对美对欧的软件居多,现在慢慢的变成了对大鹅和中东以及非洲的居多,这两年明显问有没有俄语或者阿拉伯语的输入法的增多&a…

vmware用ghost镜像ios、esd格式装系统

1、需要下载一个pe.iso镜像,可以用大白菜,老毛桃什么的,vmware选择从光盘启动 然后在PE里面把磁盘分为两个区,C,D盘 然后修改ISO镜像,变成要恢复的ghost包 把iso里面文件拷贝到D盘,用桌面PE工具开始重…

鸿蒙开发中实现自定义弹窗 (CustomDialog)

效果图 #思路 创建带有 CustomDialog 修饰的组件 ,并且在组件内部定义controller: CustomDialogController 实例化CustomDialogController,加载组件,open()-> 打开对话框 , close() -> 关闭对话框 #定义弹窗 (CustomDial…

视频汇聚平台LntonAIServer视频质量诊断功能--偏色检测与噪声检测

随着视频监控技术的不断进步,视频质量成为了决定监控系统性能的关键因素之一。LntonAIServer新增的视频质量诊断功能,特别是偏色检测和噪声检测,进一步强化了视频监控系统的可靠性和实用性。下面我们将详细介绍这两项功能的技术细节、应用场景…