初探深度学习-手写字体识别

前言

手写数字的神经网络识别通常指的是通过训练有素的神经网络模型来识别和分类手写数字图像的任务。这种类型的任务是机器学习和计算机视觉领域的一个经典问题,经常作为入门级的图像识别问题来展示和测试各种机器学习算法的能力。
在实际应用中,手写数字识别可以用于处理和分析用户书写的数字信息,比如自动识别和输入手写数字文档中的数据,或者用于教育领域的练习评分系统,其中系统可以自动识别学生书写的数学题答案并提供反馈。
MNIST数据集是手写数字识别中最常使用的数据集之一。它包含了大量不同书写风格的手写数字图片,以及与之对应的标签。这个数据集被广泛用于训练和测试各种机器学习模型,包括深度学习模型。
在训练过程中,神经网络会学习如何识别和区分不同的手写数字,这个过程涉及到从大量的样本中提取特征,并调整网络内部参数,以便在看到一个新的手写数字图片时,能够准确地预测它所代表的数字。

训练过程

1-数据处理

本次训练的数据集选自 MNIST 数据集,该数据集保存在torchvision包中,因此我们只需要在训练开始前导入该包并进行数据处理即可:

# 导入训练集相关数据包
import torchvision
from torchvision.transforms import ToTensor
# 导入相关数据,并保存到本地,以便于下次训练
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds = torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)
#对数据进行相关的处理
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=64)

在导入数据集合的时候,transform=Tensor()做的事情是将图片和标签转换为PyTorch张量,这样为我们后续进行相关模型处理提供了方便。
其中DataLoader函数对我们导入的相关数据进行了处理,其中进行了两个重要的步骤:

  1. 对训练集和测试集的数据设置了batch_size=64,这是每个批量包含64个样本的意思是,在每一次训练迭代中,模型会同时处理64个样本。这些样本会一起通过模型的前向传播(forward pass)过程,然后计算损失(如何预测不准确)并更新模型的参数。
  2. 对训练集的数据设置了打乱,这样会增加了模型的泛化能力,模型能更好地从整体上学习数据的分布,而不是记住每个单独的数据点。

2-模型构建

在导入完数据后,我们需要设计一个简单的神经网络,因为我们处理的是图像,因此输入维度是28*28,又因为我们处理的是手写字体识别,这是一个十分类的问题,因此我们的输出层的维度是10。在此基础上,我们对这个三层全连接的神经网络中间的隐藏层进行相关的设置,具体的代码如下:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 第一层输入展平后长度为28X28,创建128个神经元
        self.liner_1 = nn.Linear(28*28,128)
        # 第二层是前一层的输出,创建84个神经元
        self.liner_2 = nn.Linear(128,84)
        # 输出层接收第二层的输入84,输出分类个数为10
        self.liner_3 = nn.Linear(84,10)

在简单设置完这个三层的神经网络之后,我们需要设计这个神经网络的前向传播过程,具体代码如下:

    def forward(self,input):
        x = input.view(-1,28*28) #将输入展平为二维(1,28,28)->(28*28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x

这段代码实现的是将每个批量的多维张量的大小进行改变,以保证保持批次大小不变,但同时得到图像的像素总数,作为第一个线性层的输入,然后将输入通过第一个全连接层(self.liner_1),然后应用ReLU(Rectified Linear Unit)激活函数。
输入完成后再进行传入到第二个全连接层,这是另一个非线性变换,进一步帮助网络提取和组合特征。
x = self.liner_3(x) 这行代码将上一个层的输出通过输出层(self.liner_3)。这一层通常用于生成最终的预测结果。在这个例子中,self.liner_3是一个分类层,它将输出一个10维的向量,表示10个类别的得分。
最后,函数返回通过整个网络前向传播后的输出x。这个输出将用于后续的损失计算和梯度下降步骤,以训练网络的权重。
在设置完相关的神经网络模型后,我们还需要设置相关的损失函数和优化器,相关代码如下:

model = Model().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001)

这段代码创建了一个损失函数CrossEntropyLoss的实例。这是一个用于分类问题的损失函数,它计算预测的概率分布与真实标签之间的交叉熵。
然后创建了一个优化器SGD(Stochastic Gradient Descent)的实例。这个优化器用于更新模型的参数,以最小化损失函数。
通过将模型中模型中所有可训练参数的列表进行返回,并且设置学习率为0.001进行优化器每次更新参数时所使用的步长的设定。

3-训练函数&测试函数

设计完模型后,我们需要通过对模型的使用,来设计我们的训练函数和测试函数,其中训练函数代码如下:

#编写训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)
    num_batchs = len(dataloader)
    train_loss,correct = 0,0
    for X,y in dataloader:
        X,y =X.to(device),y.to(device)
        pred = model(X)
        loss = loss_fn(pred,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= num_batchs
    correct /= size
    return train_loss,correct

这段代码首先获取数据加载器中整个数据集的大小和批量,将其赋值给sizenum_batchs,同时将train_losscorrect置空。
然后遍历数据加载器中的每个批次,将其数据和标签赋值给Xy,并将其移动到GPU中。
在此基础上,进行模型的训练,首先通过将X传入模型,得到前向传播的结果pred,然后计算预测pred和真实标签y之间的损失。
其中以下三行代码极其关键:

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

这三段代码实现的功能是:首先清除优化器中先前计算的梯度,为新的梯度更新做准备。然后计算损失关于模型参数的梯度。最后使用计算出的梯度更新模型的参数。
最后通过以下代码更新损失和准确度,并将其返回:

        with torch.no_grad():
            correct += (pred.argmax(1)==y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= num_batchs
    correct /= size
    return train_loss,correct

测试函数与训练函数大体相似,不同的是没有使用相关的传播代码,因为它只需要测试,不需要训练:

def test(dataloader,model):
    size = len(dataloader.dataset)
    num_batches =len(dataloader)
    test_loss,correct = 0,0
    with torch.no_grad():
        for X,y in dataloader:
            X,y =X.to(device),y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred,y).item()
            correct +=(pred.argmax(1)==y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return test_loss,correct

4-模型训练

模型训练的过程就是通过持续调用模型的训练函数和测试函数,进行多轮训练,此时我们设置训练轮数为 50 轮,相关代码如下:

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss,epoch_acc=train(train_dl,model,loss_fn,optimizer)
    epoch_test_loss,epoch_test_acc=test(test_dl,model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_loss)
    template=("epoch:{:2d},train_loss:{:.5f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")
    print(template.format(epoch,epoch_loss,epoch_acc*100,epoch_test_loss,epoch_test_acc*100))
print("Done!")  

5-模型使用

在训练完模型后,我们需要对模型进行保存,以便后续将模型赋值给相关函数和功能,保存模型的代码如下,该代码会将训练的模型以model_complete_pth方式进行保存:

torch.save(model,'model_complete.pth')

接下来我们调用该模型,对传入的图形数据进行预测,整体代码如下:

import torch
import torchvision
from torchvision.transforms import ToTensor
# 对训练数据进行处理
from random import shuffle
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 第一层输入展平后长度为28X28,创建128个神经元
        self.liner_1 = nn.Linear(28*28,128)
        # 第二层是前一层的输出,创建84个神经元
        self.liner_2 = nn.Linear(128,84)
        # 输出层接收第二层的输入84,输出分类个数为10
        self.liner_3 = nn.Linear(84,10)
    def forward(self,input):
        x = input.view(-1,28*28) #将输入展平为二维(1,28,28)->(28*28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x
device = 'cuda' if torch.cuda.is_available() else "cpu"
model = Model().to(device)
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds = torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=64)
def process_image(image):
    image = image.unsqueeze(0)  # 在第一维增加一个维度,因为模型期望的输入形状是 (batch_size, channels, height, width)
    image = image.to(device)
    return image
def predict(image):
    image = process_image(image)
    with torch.no_grad():
        prediction = model(image)
        _,predicted_digit = torch.max(prediction, 1)
    return predicted_digit.item()
device = 'cuda' if torch.cuda.is_available() else "cpu"
model = torch.load('model_complete.pth')
for i in range(5):
    test_image, _ = test_ds[i]
    plt.imshow(test_image.squeeze(), cmap='gray')
    plt.show()
    predicted_digit = predict(test_image)
    print(f"The predicted digit is: {predicted_digit}")

程序运行后,会根据输入的图片进行相关的分类,准确度较高:image.png

总结

本篇文章我们记录了对手写字体识别的训练任务,包括了数据处理、模型构建、训练函数构建、测试函数构建、模型训练及使用五个部分,作为初级的深度学习训练任务,这是一个十分类的任务,通过这个任务,我们可以具体的去感知神经网络任务的设计、使用、损失函数的使用、优化器的使用,这背后涉及到很多的数学知识需要我们去学习和调优。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/445388.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

寒假作业Day 09

寒假作业Day 09 一、选择题 因为一开始的for循环&#xff0c;k<2NN&#xff0c;所以复杂度为2N方&#xff0c;而后面的M10的while循环&#xff0c;则是10&#xff0c;复杂度为常数级&#xff0c;所以2N方10&#xff0c;近似于N方&#xff0c;即O(N^2) 这是一个计算阶乘的递…

excel批量数据导入时用poi将数据转化成指定实体工具类

1.实现目标 excel进行批量数据导入时&#xff0c;将批量数据转化成指定的实体集合用于数据操作&#xff0c;实现思路&#xff1a;使用注解将属性与表格中的标题进行同名绑定来赋值。 2.代码实现 2.1 目录截图如下 2.2 代码实现 package poi.constants;/*** description: 用…

一键部署Tesseract-OCR环境C++版本(Windows)

环境&#xff1a;Windows 10 工具&#xff1a;git vcpkg vscode cmake 库&#xff1a;Tesseract 一键部署Tesseract-OCR环境C版本&#xff08;Windows&#xff09; 分享这篇文章的原因很简单&#xff0c;就是为了让后续的朋友少走弯路。自己在搜索相关C版本的tesseract部署时…

【python量化】基于okex API开发的海龟策略

介绍 基于okex api开发的海龟策略&#xff0c;okex海龟策略python实现方式。该程序目前只支持单品种&#xff0c;比如设置ETH后&#xff0c;只对ETH进行做多做空。该程序运行需要两样东西&#xff1a;apikey 和 标的 运行该程序之前&#xff0c;用户需要到okex网站去申请apiK…

虚函数与纯虚函数有什么区别?

总的来说有两点区别&#xff1a; 1.虚函数的作用主要是矫正指针&#xff08;口语化的说法&#xff09; 2.虚函数不一定要重新定义&#xff0c;纯虚函数一定要定义&#xff08;口语化的说法&#xff09; 1&#xff09;. 虚函数的作用主要是矫正指针&#xff0c;使得基类的指针…

【Python数据结构与判断1/7】复杂的多向选择

目录 导入 举个栗子 代码优化 elif 栗子 执行顺序 情况一 情况二 情况三 if-elif-else特性 三种判断语句小结 if if-else if-elif-else 嵌套语句 if嵌套 栗子 执行顺序 相互嵌套 Tips Debug 总结 导入 在前面&#xff0c;我们学习了单向选择的if语句和多项…

Decontam去污染:一个尝试

为了程序运行的便利性&#xff0c;不想将Decontam放到windows的Rstudio里面运行&#xff0c;需要直接在Ubuntu中运行&#xff0c;并且为了在Decontam时进行其他操作&#xff0c;使用python去运行R 首先你需要有一个conda环境&#xff0c;安装了R&#xff0c;Decontam&#xff0…

day 49 动态规划 part 10● 121. 买卖股票的最佳时机 ● 122.买卖股票的最佳时机II

看了题解&#xff0c;第一种暴力&#xff0c;两个for循环。 class Solution { public:int maxProfit(vector<int>& prices) {int result 0;for (int i 0; i < prices.size(); i) {for (int j i 1; j < prices.size(); j){result max(result, prices[j] -…

使用scrapy爬取蜻蜓FM

创建框架和项目 ### 1. 创建虚拟环境 conda create -n spiderScrapy python3.9 ### 2. 安装scrapy pip install scrapy2.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple### 3. 生成一个框架并进入框架 scrapy startproject my_spider cd my_spider### 4. 生成项目 scrapy …

LeetCode:143.重排链表

143. 重排链表 解题过程 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val; }* ListNode(int val, ListNode next) { this.val val; this.next next; …

数据结构——堆的应用 Topk问题

&#x1f49e;&#x1f49e; 前言 hello hello~ &#xff0c;这里是大耳朵土土垚~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#x…

实验一:华为VRP系统的基本操作

1.1实验介绍 1.1.1关于本实验 本实验通过配置华为设备&#xff0c;了解并熟悉华为VRP系统的基本操作 1.1.2实验目的 理解命令行视图的含义以及进入离开命令行视图的方法 掌握一些常见的命令 掌握命令行在线帮助的方法 掌握如何撤销命令 掌握如何使用命令快捷键 1.1.3实验组网 …

挑战杯 基于设深度学习的人脸性别年龄识别系统

文章目录 0 前言1 课题描述2 实现效果3 算法实现原理3.1 数据集3.2 深度学习识别算法3.3 特征提取主干网络3.4 总体实现流程 4 具体实现4.1 预训练数据格式4.2 部分实现代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习机器视觉的…

React组件(函数式组件,类式组件)

函数式组件 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>React Demo</title> <!-- 引…

阿里云服务器9元1个月优惠价格表

阿里云服务器9元1个月优惠价格表&#xff0c;用不上9元&#xff0c;又降价了&#xff0c;只要5元。阿里云服务器一个月多少钱&#xff1f;最便宜5元1个月。阿里云轻量应用服务器2核2G3M配置61元一年&#xff0c;折合5元一个月&#xff0c;2核4G服务器30元3个月&#xff0c;2核2…

深入探讨AI团队的角色分工

目录 前言1. 软件工程师&#xff1a;构建系统基石的关键执行者2. 机器学习工程师&#xff1a;数据与模型的塑造专家3. 机器学习研究员&#xff1a;引领算法创新的智囊4. 机器学习应用科学家&#xff1a;理论与实践的巧妙连接5. 数据分析师&#xff1a;洞察数据&#xff0c;智慧…

C语言初学10:typedef

一、作用 为用户定义的数据类型取一个新名字 二、对结构体使用typedef定义新的数据类型名字 #include <stdio.h> #include <string.h>typedef struct Books //使用 typedef 来定义一个新的数据类型名字 {char title[50];} book;int main( ) {//book是typedef定…

图片在div完全显示

效果图&#xff1a; html代码&#xff1a; <div class"container" style" display: flex;width: 550px;height: 180px;"><div class"box" style" color: red; background-color:blue; width: 50%;"></div><div …

基于冠豪猪优化算法(Crested Porcupine Optimizer,CPO)的无人机三维路径规划(MATLAB)

一、无人机路径规划模型介绍 无人机三维路径规划是指在三维空间中为无人机规划一条合理的飞行路径&#xff0c;使其能够安全、高效地完成任务。路径规划是无人机自主飞行的关键技术之一&#xff0c;它可以通过算法和模型来确定无人机的航迹&#xff0c;以避开障碍物、优化飞行…

企业如何安全参与开源项目?

【开源三句半】 企业参与开源潮&#xff0c; 安全创新都重要&#xff0c; 持续投入不可少&#xff0c; 眼光独到。 开源已经成为构建现代软件的常见方式&#xff0c;这不仅局限于IT技术本身&#xff0c;更推动了多个行业的数字化发展。企业决定引入开源项目打造商业软件时&…