MNIST手写数字识别

        本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在Pytorch

       Pytorch(2)---MNIST手写数字识别》

MNIST手写数字识别

目录

一、 实验目的

二、 实验内容

2.1 MNIST数据集介绍

2.2 代码解析

三、运行结果


一、 实验目的

        掌握利用卷积神经网络CNN实现对MNIST手写数字的识别。一个简单的神经网络实验


二、 实验内容

2.1 MNIST数据集介绍

        MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。官方下载网站:http://yann.lecun.com/exdb/mnist/,下载得到的数据集一共包含4个文件,训练集、训练集标签、测试集、测试集标签。

type Markdown and LaTeX: α2α2

        直接下载得到的数据集是无法通过普通的解压或者应用程序打开的,因为这些文件不是任何标准的图像格式而是以字节的形式进行存储的,所以必须编写相关程序将其打开。

        而torchvision.datasets包中已经包含MNIST数据集,可以通过在编译器中输入代码进行数据集的获取,步骤如下:

        Step1:归一化,softmax归一化指数函数(https://blog.csdn.net/lz_peter/article/details/84574716),其中0.1307是mean均值和0.3081是std标准差 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        Step2:下载/获取数据集,其中root为数据集存放路径,train=True即训练集否则为测试集。 train_dataset = datasets.MNIST(root='./data/', train=True, download=False, transform=transform) # train=True训练集,=False测试集 test_dataset = datasets.MNIST(root='./data/', train=False, download=Falese, transform=transform) # download=True表示可自动下载,此处=False表示直接从本地导入

        Step3:实例化一个dataset后,然后用Dataloader 包起来,即载入数据集。这里的batch_size为超参数;shuffle=True即打乱数据集,这里我们打乱训练集进行训练,而对测试集进行顺序测试。 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        可以按照上述尝试下载数据集,但本次实验过程为节约时间,可直接导入本地下载好的数据集,本地保存相对路径为:root='./data/'

2.2 代码解析

首先导入运行过程中需要的文件包:

import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
"""
卷积运算 使用mnist数据集,和10-4,11类似的,只是这里:1.输出训练轮的acc 2.模型上使用torch.nn.Sequential
添加载入测试集数据
"""
# Super parameter ------------------------------------------------------------------------------------
batch_size = 64
learning_rate = 0.01
momentum = 0.5
EPOCH = 10
# Prepare dataset ------------------------------------------------------------------------------------
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# softmax归一化指数函数(https://blog.csdn.net/lz_peter/article/details/84574716),其中0.1307是mean均值和0.3081是std标准差

train_dataset = datasets.MNIST(root='./data/', train=True, download=False, transform=transform) 
# download=False加载本地数据集,否则自动下载;train=True训练集,=False测试集
#从本地导入,transform的方式与训练集数据一致
test_dataset =  datasets.MNIST(root='./data/', train=False, download=True, transform=transform) 

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) #shuffle=True 表示数据集打乱
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

展示MNIST数据集

        这里举例展示12幅图,包含图片内容和标签。

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i+1)
    plt.tight_layout()
    plt.imshow(train_dataset.train_data[i], cmap='gray', interpolation='none')
    plt.title("Labels: {}".format(train_dataset.train_labels[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

 构建CNN网络模型

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        # 第二层卷积核维度是[10, 20, 5 ,5], 输入通道是10, 输出通道是20, 卷积核大小是 5*5, (默认padding=0, stride = 1)
        # 第二层的激活函数是Relu
        # 第二层池化层: 采用的是MaxPooling, 大小是2*2.
        
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
            
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(320, 50),
            torch.nn.Linear(50, 10),
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)
        x = self.conv2(x)  # 再来一次
        x = x.view(batch_size, -1)  # flatten 变成全连接网络需要的输入 (batch, 20,4,4) ==> (batch,320), -1 此处自动算出的是320
        x = self.fc(x)
        return x  # 最后输出的是维度为10的,也就是(对应数学符号的0~9)
# 实例化模型:
model = Net()
# 损失函数和优化器损失函数  使用交叉熵损失参数优化使用随机梯度下降
# Construct loss and optimizer ------------------------------------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # lr学习率,momentum冲量

模型训练

Step1:前馈(forward propagation)

Step2:反馈(backward propagation)

Step3:更新(update)

def train(epoch):
    running_loss = 0.0  # 这整个epoch的loss清零
    running_total = 0
    running_correct = 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)

        loss.backward()
        optimizer.step()

        # 把运行中的loss累加起来,为了下面300次一除
        running_loss += loss.item()
        # 把运行中的准确率acc算出来
        _, predicted = torch.max(outputs.data, dim=1)
        running_total += inputs.shape[0]
        running_correct += (predicted == target).sum().item()

        if batch_idx % 300 == 299:  # 不想要每一次都出loss,浪费时间,选择每300次出一个平均损失,和准确率
            print('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                  % (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))
            running_loss = 0.0  # 这小批300的loss清零
            running_total = 0
            running_correct = 0  # 这小批300的acc清零
            torch.save(model.state_dict(), './model.pth') #保存模型参数
            torch.save(optimizer.state_dict(), './optimizer.pth') #保存优化器参数

 测试轮

        测试集不用算梯度(无需反馈),首先从test_loader中读取每一次的图片和标签,进行前馈运算后,预测每一轮的准确率

        测试轮代码

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # 测试集不用算梯度
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度,沿着行(第1个维度)去找1.最大值和2.最大值的下标
            total += labels.size(0)  # 张量之间的比较运算,表示预测标签的所有次数
            correct += (predicted == labels).sum().item() #标签预测准确的次数
    # acc表示标签预测的准确率,即预测准确次数除以预测标签的所有次数,acc的计算代码公式
    
    acc = correct/total
    print('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch+1, EPOCH, 100 * acc))  # 求测试的准确率,正确数/总数
    return acc

开始训练

        超参数:用到的超参数主要有小批量数据的batch size,梯度下降算法中用到的学习率(learning rate)和冲量(momentum),同时定义进行10轮次的训练。 主函数:共进行10轮次的训练:每训练一轮,就进行一次测试。

if __name__ == '__main__':
    acc_list_test = [] #创建保存测试数据的列表
    for epoch in range(EPOCH):
        train(epoch)
        # if epoch % 10 == 9:  #每训练10轮 测试1次
        acc_test = test()
        #接下来需要把每次的测试结果即acc_test,添加到列表acc_list_test中
        acc_list_test.append(acc_test)
        
    plt.plot(acc_list_test)
    plt.xlabel('TEST-EPOCH')
    plt.ylabel('Accuracy On TestSet')
    plt.show()

 三、运行结果 

        文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/201407.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OpenMMlab导出FCN模型并用onnxruntime推理

导出onnx文件 直接使用脚本 import torch from mmseg.apis init_modelconfig_file configs/fcn/fcn_r18-d8_4xb2-80k_cityscapes-512x1024.py checkpoint_file fcn_r18-d8_512x1024_80k_cityscapes_20201225_021327-6c50f8b4.pth model init_model(config_file, checkpoin…

YOLOv8 代码部署

一、获取代码 YOLOv8官方GitHub网址 https://github.com/ultralytics/ultralytics 获取YOLOv8代码压缩包 二、虚拟环境配置 这个就不写了,装个Anaconda,网上教程很多 三、PyCharm安装与配置(可选) 这个也不写了,…

磁环电感参数计算

磁环电感参数计算 1.什么是电感磁饱和2.电感饱和的原因3.电感饱和带来的影响3.1 感应电动势变化3.2 电感值变化3.3 功率损耗增加3.4 系统稳定性受到影响4.饱和电流计算最近在做DC/DC电源,电感是用磁环绕制的,所以关注一下磁环绕制电感参数的计算,学习学习。 某款磁环参数。 …

SpringBoot——Swagger2 接口规范

优质博文:IT-BLOG-CN 如今,REST和微服务已经有了很大的发展势头。但是,REST规范中并没有提供一种规范来编写我们的对外REST接口API文档。每个人都在用自己的方式记录api文档,因此没有一种标准规范能够让我们很容易的理解和使用该…

3dMax导出glft和glb格式模型插件Max2Babylon教程

为了满足Autodesk提供自己的导出管道之前的迫切需要,Babylon.js导出器可用于3dMax。导出器可以将3dMax场景导出为.glTF文件、.glb文件或.babylon文件。 【适用版本】 3dMax2015 - 2024 【安装方法】 1.选择和自己电脑中3dMax所对应的插件版本,解压缩。…

基于Qt MP3音频播放器示例(可制作音频播放器)

​本次MP3文件也给出来,方便大家调试。话不多说直接上源码。 整个项目下载地址:CSDN:GetCode 昵称-》Qt魔术师:https://gitcode.com/m0_45463480/QtMP3/tree/main## .pro # 指定项目类型为应用程序。TEMPLATE = app# 指定项目的名称为musicplayerTARGET = musicplayer# 添…

Matlab下载许可证文件 教程(在账号有许可证的前提下)

文章目录 Part.I IntroductionPart.II 许可证文件过期解决方案Chap.I 使用 Internet 自动激活Chap.II 在不使用 Internet 的情况下手动激活 Part.I Introduction 本文主要介绍,在 Mathwork 账号有许可证的前提下,下载许可证的操作流程。 好久没有用 Mat…

使用Redis实现接口防抖

说明:实际开发中,我们在前端页面上点击了一个按钮,访问了一个接口,这时因为网络波动或者其他原因,页面上没有反应,用户可能会在短时间内再次点击一次或者用户以为没有点到,很快的又点了一次。导…

基于springboot的电影院管理系统的设计与实现 (含论文和源码视频导入教程)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于springboot的电影院管理系统7拥有两种角色 管理员:用户管理、购票统计、电影管理、电影类型管理、放映厅管理、订单管理等 用户:登录注册、查看各种信息、购票…

uniapp打包ios有时间 uniapp打包次数

我们经常用的解决方案有,分包,将图片上传到服务器上,减少插件引入。但是还有一个方案好多刚入门uniapp的人都给忽略了,就是在源码视图中配置,开启分包优化。 1.分包 目前微信小程序可以分8个包,每个包的最大存储是2M,也就是说你文件总体的大小不能超过16M,每个包的大…

【模板】KMP算法笔记

练习链接:【模板】KMP - 洛谷 题目: 输入 ABABABC ABA 输出 1 3 0 0 1 思路: 根据题意,用到的是KMP算法,KMP算法思想是通过一个一个匹配首字母的原理进行整个匹配效果,当某个首字母不匹配的时候&#x…

箭头函数与普通函数:谁更胜一筹?

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

ASUS(华硕) B760M-AYW WIFI D4_解决wifi不能使用

1、最近新购买了一套 diy电脑主机,选用的是 ASUS B760M-AYW WIFI D4电脑主板 win10 系统,到货后 发现右下角电脑图标处及网络适配器中 没有wifi选项 首先 在官网和旗舰店客服处,确认了 该主板 有集成wifi模块,鲨鱼鳍天线未安装…

Java数据结构之《直接插入排序》问题

一、前言: 这是怀化学院的:Java数据结构中的一道难度中等的一道编程题(此方法为博主自己研究,问题基本解决,若有bug欢迎下方评论提出意见,我会第一时间改进代码,谢谢!) 后面其他编程题只要我写完…

GoLang切片

一、切片基础 1、切片的定义 切片(Slice)是一个拥有相同类型元素的可变长度的序列它是基于数组类型做的一层封装它非常灵活,支持自动扩容切片是一个引用类型,它的内部结构包含地址、长度和容量声明切片类型的基本语法如下&#…

qt-C++笔记之主线程中使用异步逻辑来处理ROS事件循环和Qt事件循环解决相互阻塞的问题

qt-C笔记之主线程中使用异步逻辑来处理ROS事件循环和异步循环解决相互阻塞的问题 code review! 文章目录 qt-C笔记之主线程中使用异步逻辑来处理ROS事件循环和异步循环解决相互阻塞的问题1.Qt的app.exec()详解2.ros::spin()详解3.ros::AsyncSpinner详解4.主线程中结合使用的示…

hyper-V操作虚拟机ubuntu 22.03

安装hyper-V 点击卸载程序 都勾选上即可 新建虚拟机,选择镜像文件 选择第一代即可 设置内存 配置网络 双击 启动安装虚拟机 输入用户名 zenglg 密码:LuoShuwen123456 按照enter键选中openssh安装 安装中 安装完成 选择重启 输入用户名、密码

Java进阶(第三期): JDK版本接口的新特性 内部类(成员类、静态类、局部类、匿名类) Lambda表达式、简写规则

Java进阶(第三期) ⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ 文章目录 Java基础(第三期)一、接口新特性1.1 JDK8版本1.2 JDK9版本 代码块二、内部类1、成员内部类1.2 内部类成员访问 2、 静态内部类3、 局部…

Python三百行代码实现一简约个人博客网站(全网最小巧)

这是全互联网最小巧的博客,没有比这更小的了。虽然小巧,但功能一点儿也不弱,支持文章的分页展示,文章表格,图片和代码语法高亮。文章无限制分类,访问量统计,按时间和按点击量排序,展…

CPU虚拟化的过程

VMCS 是Virtual Machine Control Structure。是 Intel 实现 CPU 虚拟化,记录 vCPU 状态的一个关键数据结构。VMCS 数据结构主要包含以下信息。 Guest-state area,即 vCPU 的状态信息,包括 vCPU 的基本运行环境,例如寄存器等。Hos…