MNIST手写字符分类

MNIST手写字符分类

文章目录

  • MNIST手写字符分类
    • 1 数据集
    • 2 模型构建
    • 3 训练
    • 4 模型保存
    • 5 推理
    • 6 模型导出
    • 7 导出模型测试

1 数据集

  MNIST手写字符集包括60000张用于训练的训练集图片和10000张用于测试的测试集图片,所有图片均归一化为28*28的灰度图像。其中字符区域为白色,非字符区域为黑色。
  该数据集可以直接通过pytorch dataset进行下载。

2 模型构建

  本次试验采用图像分类的方案进行手写字符分类。分类网络采用线性层->非线性函数堆叠的方式实现。本次试验设计了两个网络,分别是3层网络模型以及4层网络模型。

  模型定义如下:

import torch
from torch import nn
from torch.utils.data import DataLoader

class ZKNNNet_3Layer(nn.Module):
    def __init__(self):
        super(ZKNNNet_3Layer, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

class ZKNNNet_5Layer(nn.Module):
    def __init__(self):
        super(ZKNNNet_5Layer, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

3 训练

  模型训练过程主要包括确定数据集、构建dataloader、确定训练设备(device)、生成模型对象、定义优化器、定义目标函数、设定模型为训练模式、获取输入与标签、进行模型推理、计算损失函数、优化器梯度清零、损失反向传播、优化器步进等。

  训练代码如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_3Layer, ZKNNNet_5Layer
import os

# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ZKNNNet_3Layer()
if os.path.exists("model/model_3layer.pth"):
    model.load_state_dict(torch.load("model/model_3layer.pth"))
model = model.to(device)
print(model)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Train
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_3layer.pth")
print("Done!")

  以上脚本训练的是3层网络,经过200个epoch的训练,模型精度可以达到96%左右。

  以下为训练5层网络模型的训练代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_3Layer, ZKNNNet_5Layer
import os
# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ZKNNNet_5Layer()
if os.path.exists("model/model_5layer.pth"):
    model.load_state_dict(torch.load("model/model_5layer.pth"))
model = model.to(device)
print(model)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Train
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_5layer.pth")
print("Done!")

4 模型保存

  在pytorch中,模型保存通过torch.save即可完成模型的保存。通常使用的方式是torch.save(model.state_dice(),'save_model_path.pth')
  这里需要注意的时候,通常我们只需要保存在测试集上准确率最高的模型,此时,可以结合训练过程中的测试过程,在计算了在测试集上的准确率之后,根据准确率进行保存。

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_5layer.pth")

这部分代码显示了在test()测试过程中计算模型在测试集上的准确率并返回,在保存模型时,判断当前测试集准确率是否优于历史最优准确率,高于的话就保存模型。

5 推理

  在进行完模型训练并保存了最优模型之后,我们需要对模型进行推理测试。通常用于在模型训练完成后,使用少量数据进行模型推理结果可视化,方便排查模型性能。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_3Layer

import matplotlib.pyplot as plt

# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))

# Load the trained model
model = ZKNNNet_3Layer()
model.load_state_dict(torch.load("model/model_3layer.pth"))
model.to(device)
model.eval()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)

# Perform inference
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i].cpu()
            label = labels[i].cpu()
            prediction = predicted[i].cpu()

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

    accuracy = 100 * correct / total
    print("Accuracy on test set: {:.2f}%".format(accuracy))

在这里插入图片描述

6 模型导出

  在模型训练完成,经过推理可视化,证明模型精度可用的时候,可以对模型进行导出。因为pytorch模型结构为pth结构,该模型文件不利于部署。通常可以将pth模型文件导出为onnx模型文件。onnx模型文件通常是各种推理框架支持的中间模型文件格式。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer

device = "cpu"
print("Using {} device".format(device))
model_3Layer = ZKNNNet_3Layer()
if os.path.exists('./model/model_3layer.pth'):
    model_3Layer.load_state_dict(torch.load('./model/model_3layer.pth'))
model_3Layer = model_3Layer.to(device)

model_3Layer.eval()

# export pytorch model to onnx
torch.onnx.export(model_3Layer, torch.randn(1, 1, 28, 28), './model/model_3layer.onnx', verbose=True)

model_5Layer = ZKNNNet_5Layer()
if os.path.exists('./model/model_5layer.pth'):
    model_5Layer.load_state_dict(torch.load('./model/model_5layer.pth'))
model_5Layer = model_5Layer.to(device)
model_5Layer.eval()
torch.onnx.export(model_5Layer,torch.randn(1,1,28,28),'./model/model_5layer.onnx',verbose=True)

通过上述脚本,可以将之前训练好的3层模型和5层模型文件进行导出,生成对应的onnx模型文件。

7 导出模型测试

  在onnx模型文件生成之后,需要对onnx模型文件进行测试,判断整个模型文件导出过程、推理过程是否正确。

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets

import matplotlib.pyplot as plt

from PIL import Image

sess = rt.InferenceSession("model/model_3layer.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)

image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)

outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()

prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.numpy()
        labels = labels.numpy()
        outputs = sess.run(None,{input_name:images})[0]
        outputs = np.array(outputs).flatten()
        prediction = np.argmax(outputs)

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i]
            label = labels[i]

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

以上推理onnx模型的过程中,分别选择了直接图片输入和来自dataloader输入两种方式。这里需要注意的是,在使用图片进行输入时,需要注意数据范围需要与训练时的dataloader输入方式时一致。在使用dataloader时,图像数据是归一化到0-1的范围内的,而使用PIL读取图片之后是uint8的范围0-255的数据,因此在推理之前需要先进行数据范围转化,将输入数据转换成范围0-1的float32类型的数据。否则可能导致推理结果错误。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/706532.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue31-自定义指令:总结

一、自定义函数的陷阱 1-1、自定义函数名 自定义函数名&#xff0c;不能用驼峰式&#xff01;&#xff01;&#xff01; 示例1&#xff1a; 示例2&#xff1a; 1-2、指令回调函数的this 【回顾】&#xff1a; 所有由vue管理的函数&#xff0c;里面的this直接就是vm实例对象。…

小主机折腾记26

双独立显卡调用问题 前两天将tesla p4从x99大板上拆了下来&#xff0c;将880G5twr上的rx480 4g安装到了x99大板上&#xff0c;预计是dg1输出&#xff0c;rx480做3d运算。安装完驱动后&#xff0c;还想着按照之前tesla p4的设置方法去设置rx480&#xff0c;结果果然&#xff0c…

边坡监测规范:确保边坡工程安全稳定的专业准则

边坡工程是土木工程中不可或缺的一部分&#xff0c;其安全性直接关系到工程整体的质量与稳定性。因此&#xff0c;在边坡工程中实施有效的监测措施&#xff0c;遵循一系列专业的监测规范&#xff0c;对于预防边坡失稳、滑坡等灾害的发生&#xff0c;保障人民群众的生命财产安全…

使用 Vue 和 Ant Design 实现抽屉效果的模块折叠功能

功能描述&#xff1a; 有两个模块&#xff0c;点击上面模块的收起按钮时&#xff0c;上面的模块可以折叠&#xff0c;下面的模块随之扩展 代码实现&#xff1a; 我们在 Vue 组件中定义两个模块的布局和状态管理&#xff1a; const scrollTableY ref(560); // 表格初始高度…

ssm161基于web的资源共享平台的共享与开发+jsp

资源共享平台设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本资源共享平台就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处…

tkinter滚动条Scrollbar

tkinter滚动条Scrollbar 滚动条Scrollbar滚动条的基本用法效果代码 滚动条Scrollbar 在Tkinter中&#xff0c;**滚动条&#xff08;Scrollbar&#xff09;**是一个允许用户在较大的内容区域内平移的组件。**滚动条通常与文本框&#xff08;Text&#xff09;、列表框&#xff0…

cspccf备考

13年12月CCF计算机软件能力认证 3192. 出现次数最多的数 给定 n n n个正整数&#xff0c;找出它们中出现次数最多的数。 如果这样的数有多个&#xff0c;请输出其中最小的一个。 输入格式 输入的第一行只有一个正整数 n n n,表示数字的个数。 输入的第二行有 n n n个整数 …

什么是基于风险的漏洞管理RBVM,及其优势

文章目录 一、什么是漏洞管理二、什么是基于风险的漏洞管理RBVM三、RBVM的基本流程四、RBVM的特点和优势 一、什么是漏洞管理 安全漏洞是网络或网络资产的结构、功能或实现中的任何缺陷或弱点&#xff0c;黑客可以利用这些缺陷或弱点发起网络攻击&#xff0c;获得对系统或数据…

FFMpeg解复用流程

文章目录 解复用流程图复用器与解复用器小结 解复用流程图 流程图&#xff0c;如上图所示。 复用器与解复用器 复用器&#xff0c;就是视频流&#xff0c;音频流&#xff0c;字幕流&#xff0c;其他成分&#xff0c;按照一定规则组合成视频文件&#xff0c;视频文件可以是mp4…

c语言利用openssl实现简单客户端和服务端(观察记录层最大长度)

文章目录 前言一、客户端实现二、服务端实现总结 前言 本文是使用openssl111w实现的简单客户端和服务端&#xff0c;主要用于观察openssl一个记录层数据包的大小。 一、客户端实现 #include <stdio.h> #include <stdlib.h> #include <string.h> #inc…

搜维尔科技:Movella旗下的Xsens在人形机器人开发中得到广泛应用

人形机器人的发展正在全球范围内受到广泛关注。作为机器人领域的重要分支&#xff0c;人形机器人因其具备高度仿真的外观和动作&#xff0c;以及更贴近人类的行为模式&#xff0c;有望逐渐成为人们日常生活和工业生产中的得力助手。在中国&#xff0c;这一领域的发展尤为引人注…

湘江早报专访惟客数据李柯辰:湖南伢子返湘玩转“AI+金融”

来源 |《湘江早报》 记者 | 黄荣佳 ​ 随着数字化浪潮的到来&#xff0c;AI的风吹遍了各行各业&#xff0c;金融作为对新兴技术最敏感的行业&#xff0c;前沿技术的赋能&#xff0c;让金融科技成为行业发展的“新赛点”。作为一家以大数据和AI人工智能技术驱动的新一代数字化…

MYSQL六、存储引擎的认识

一、存储引擎 1、MySQL体系结构 连接层&#xff1a;最上层是一些客户端和链接服务&#xff0c;包含本地sock 通信和大多数基于客户端/服务端工具实现的类似于TCP/IP的通信。主要完成一些类似于连接处理、授权认证、及相关的安全方案。在该层上引入了线程池的概念&#xff0c;为…

使用大模型进行时间序列预测

今天想聊聊这周一篇关于使用语言模型进行时间序列预测的工作&#xff0c;这个工作的主要亮点有四个: 首先提出的Chronos框架将时间序列通过缩放和量化转换为token序列&#xff0c;从而可以直接使用语言模型架构(如T5, GPT-2等)来建模时间序列&#xff0c;不需要对模型架构做任…

Word菜谱制作教程

原始文本&#xff1a; 打开标尺 选中文字右键-段落&#xff0c; 制表位&#xff0c;选好字符和引导符 在文字和价格之间按下Tab 效果 参考资料好看视频-轻松有收获 Phrase&#xff1a;我觉等还是有点麻烦&#xff0c;可以插入表格&#xff0c;再把表格调整为无表框即可

60行代码加速20倍: NEON实现深度学习OD任务后处理绘框

【前言】 本文版权属于GiantPandaCV&#xff0c;未经允许&#xff0c;请勿转载&#xff01; 最近在学neon汇编加速&#xff0c;由于此前OD任务发现在检测后处理部分使用OpenCV较为占用资源且耗时&#xff0c;遂尝试使用NEON做后处理绘框&#xff0c;以达到加速并降低CPU资源消耗…

PHP简约轻型聊天室留言源码

无名轻聊是一款phptxt的轻型聊天室。 无名轻聊特点&#xff1a; 自适应电脑/手机 数据使用txt存放&#xff0c;默认显示近50条聊天记录 采用jqueryajax轮询方式&#xff0c;适合小型聊天环境。 访问地址加?zhi进入管理模式&#xff0c;发送 clear 清空聊天记录。 修改在…

品质卓越为你打造App UI 风格

品质卓越为你打造App UI 风格

【ElasticSearch】ElasticSearch基本概念

ES 是一个开源的高扩展的分布式全文检索引擎&#xff0c;它是对开源库 Luence 的封装&#xff0c;提供 REST API 接口 MySQL 更适合数据的存储和关系管理&#xff0c;即 CRUD&#xff1b;而 ES 更适合做海量数据的检索和分析&#xff0c;它可以秒级地从数据库中检索出我们感兴…

【论文复现|智能算法改进】基于改进鲸鱼优化算法的移动机器人多目标点路径规划

目录 1.算法原理2.数学模型3.改进点4.结果展示5.参考文献6.代码获取 1.算法原理 SCI二区|鲸鱼优化算法&#xff08;WOA&#xff09;原理及实现【附完整Matlab代码】 2.数学模型 使用 A* 算法生成所有目标点之间的距离矩阵U: U [ d 1 − 1 d 1 − 2 d 1 − 3 ⋯ d 1 − i d…