CNN实现手写数字识别(Pytorch)

CNN结构

CNN(卷积神经网络)主要包括卷积层、池化层和全连接层。输入数据经过多个卷积层和池化层提取图片信息后,最后经过若干个全连接层获得最终的输出。
在这里插入图片描述CNN的实现主要包括以下步骤:

  1. 数据加载与预处理
  2. 模型搭建
  3. 定义损失函数、优化器
  4. 模型训练
  5. 模型测试

以下基于Pytorch框架搭建一个CNN神经网络实现手写数字识别。

CNN实现

此处使用MNIST数据集,包含60000个训练样本和10000个测试样本。分为图片和标签,每张图片是一个 28 × 28 28 \times 28 28×28 的像素矩阵,标签是0~9一共10种数字。每个样本的格式为[data, label]。

1. 导入相关库

import numpy as np
import torch 
from torch import nn
from torchvision import datasets, transforms,utils
from PIL import Image
import matplotlib.pyplot as plt

2. 数据加载与预处理

# 定义超参数
batch_size = 128 # 每个批次(batch)的样本数

# 对输入的数据进行标准化处理
# transforms.ToTensor() 将图像数据转换为 PyTorch 中的张量(tensor)格式,并将像素值缩放到 0-1 的范围内。
# 这是因为神经网络需要的输入数据必须是张量格式,并且需要进行归一化处理,以提高模型的训练效果。
# transforms.Normalize(mean=[0.5],std=[0.5]) 将图像像素值进行标准化处理,使其均值为 0,标准差为 1。
# 输入数据进行标准化处理可以提高模型的鲁棒性和稳定性,减少模型训练过程中的梯度爆炸和消失问题。
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5],std=[0.5])])

# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transform, 
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transform, 
                                          download=True)
                                          
# 创建数据加载器(用于将数据分次放进模型进行训练)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, # 装载过程中随机乱序
                                           num_workers=2) # 表示2个子进程加载数据
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False,
                                          num_workers=2) 

加载完数据后,可以得到60000个训练样本和10000个测试样本

print(len(train_dataset))
print(len(test_dataset))

在这里插入图片描述

以及469个训练批次和79测试批次

# batch=128
# train_loader=60000/128 = 469 个batch
# test_loader=10000/128=79 个batch
print(len(train_loader))
print(len(test_loader))

在这里插入图片描述

打印前5个手写数字样本看看

for i in range(0,5):
    oneimg,label = train_dataset[i]
    grid = utils.make_grid(oneimg)
    grid = grid.numpy().transpose(1,2,0) 
    std = [0.5]
    mean = [0.5]
    grid = grid * std + mean
    # 可视化图像
    plt.subplot(1, 5, i+1)
    plt.imshow(grid)
    plt.axis('off')

plt.show()

在这里插入图片描述
这里用了 make_grid() 函数将多张图像拼接成一张网格图像,并调整了网格图像的形状,使得它可以直接作为 imshow() 函数的输入。这种方式可以在一张图中同时显示多张图像,比单独显示每张图像更加方便,常用于可视化深度学习中的卷积神经网络(CNN)中的特征图、卷积核等信息。
在 PyTorch 中,默认的图像张量格式是 (channel, height, width),即通道维度在第一个维度。 torchvision.transforms.ToTensor() 函数会将 PIL 图像对象转换为 PyTorch 张量,并将通道维度放在第一个维度。因此,当我们使用 ToTensor() 函数加载图像数据时,得到的 PyTorch 张量的格式就是 (channel, height, width)。代码中的 oneimg.numpy().transpose(1,2,0) 就是将 PyTorch 张量 oneimg 转换为 NumPy 数组,然后通过 transpose 函数将图像数组中的通道维度从第一个维度(channel-first)调整为最后一个维度(channel-last),即将 (channel, height, width) 调整为 (height, width, channel),以便于 Matplotlib 库正确处理通道信息。

2. 模型搭建

我们将使用Pytorch构建一个如下图所示的CNN,包含两个卷积层,和全连接层,并使用Relu作为激活函数。
在这里插入图片描述
接下来看以下不同层的参数。

卷积层: Connv2d

  • in_channels ——输入数据的通道数目
  • out_channels ——卷积产生的通道数目
  • kernel_size ——卷积核的尺寸
  • stride——步长
  • padding——输入数据的边缘填充0的层数

池化层: MaxPool2d

  • kernel_siez ——池化核大小
  • stride——步长
  • padding——输入数据的边缘填充0的层数

全连接层: Linear

  • in_features:输入特征数
  • out_features:输出特征数

代码实现如下:

class CNN(nn.Module):
    # 定义网络结构
    def __init__(self):
        super(CNN, self).__init__()
        # 图片是灰度图片,只有一个通道
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, 
                               kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, 
                               kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=7*7*32, out_features=256)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=256, out_features=10)
	
    # 定义前向传播过程的计算函数
    def forward(self, x):
        # 第一层卷积、激活函数和池化
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        # 第二层卷积、激活函数和池化
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # 将数据平展成一维
        x = x.view(-1, 7*7*32)
        # 第一层全连接层
        x = self.fc1(x)
        x = self.relu3(x)
        # 第二层全连接层
        x = self.fc2(x)
        return x

定义损失函数和优化函数

import torch.optim as optim

learning_rate = 0.001 # 学习率

# 定义损失函数,计算模型的输出与目标标签之间的交叉熵损失
criterion = nn.CrossEntropyLoss()
# 训练过程通常采用反向传播来更新模型参数,这里使用的是SDG(随机梯度下降)优化器
# momentum 表示动量因子,可以加速优化过程并提高模型的泛化性能。
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
#也可以选择Adam优化方法
# optimizer = torch.optim.Adam(net.parameters(),lr=1e-2)

3. 模型训练

model = CNN() # 实例化CNN模型
num_epochs = 10 # 定义迭代次数

# 如果可用的话使用 GPU 进行训练,否则使用 CPU 进行训练。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 将神经网络模型 net 移动到指定的设备上。
model = model.to(device)
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images,labels) in enumerate(train_loader):
        images=images.to(device)
        labels=labels.to(device)
        optimizer.zero_grad() # 清空上一个batch的梯度信息
        # 将输入数据 inputs 喂入神经网络模型 net 中进行前向计算,得到模型的输出结果 outputs。
        outputs=model(images) 
        # 使用交叉熵损失函数 criterion 计算模型输出 outputs 与标签数据 labels 之间的损失值 loss。
        loss=criterion(outputs,labels)
        # 使用反向传播算法计算模型参数的梯度信息,并使用优化器 optimizer 对模型参数进行更新。
        loss.backward()
         # 更新梯度
        optimizer.step()
        # 输出训练结果
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

print('Finished Training')

在这里插入图片描述

保存模型

# 模型保存
PATH = './mnist_net.pth'
torch.save(model.state_dict(), PATH)

4. 模型测试

# 测试CNN模型
with torch.no_grad(): # 进行评测的时候网络不更新梯度
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

在这里插入图片描述
这里训练的模型准确率达到了98%,非常高,如果还想继续提高模型准确率,可以调整迭代次数、学习率等参数或者修改CNN网络结构实现。

可视化检验一个批次测试数据的准确性

# 将 test_loader 转换为一个可迭代对象 dataiter
dataiter = iter(test_loader)
# 使用 next(dataiter) 获取 test_loader 中的下一个 batch 的图像数据和标签数据
images, labels = next(dataiter)

# print images
test_img = utils.make_grid(images)
test_img = test_img.numpy().transpose(1,2,0)
std = [0.5]
mean =  [0.5]
test_img = test_img*std+0.5
plt.imshow(test_img)
plt.show()
plt.savefig('./mnist_net.png')
print('GroundTruth: ', ' '.join('%d' % labels[j] for j in range(128)))

在这里插入图片描述

参考来源:
使用Pytorch框架的CNN网络实现手写数字(MNIST)识别
PyTorch初探MNIST数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/21750.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

应用现代化中的弹性伸缩

作者:马伟,青云科技容器顾问,云原生爱好者,目前专注于云原生技术,云原生领域技术栈涉及 Kubernetes、KubeSphere、KubeKey 等。 2019 年,我在给很多企业部署虚拟化,介绍虚拟网络和虚拟存储。 2…

微服务架构初探

大家好,我是易安!我们今天来谈一谈微服务架构的前世今生。 我们先来看看维基百科是如何定义微服务的。微服务的概念最早是在2014年由Martin Fowler和James Lewis共同提出,他们定义了微服务是由单一应用程序构成的小服务,拥有自己的…

建立在Safe生态的—GameFi SocialFi双赛道项目No.1头号玩家

最近大家关注的重点在BRC-20和MEME项目,人们似乎更在意短期的投机回报。而在这之外,一个web3的游戏——No.1头号玩家却得到了大量的玩家支持。 据了解,No.1是一个GameFi & SocialFi的双赛道web3游戏,中文名称为头号玩家。它是…

光纤衰减器作用及使用说明

在光纤通信中,光信号的强度过大或过小都会对信号的传输和接收产生不良的影响,因此光衰减器在光通信系统中起到了重要的作用。那什么是光衰减器呢?它又有什么作用呢?下面跟着小易一起来了解一下吧! 一、什么是光纤衰减…

HUSTOJ中添加初赛练习系统

文章目录 0. 基于hustoj二开的初赛练习系统,QQ4705852261. 主界面2. 练习界面3. 模拟考试界面4. 查看试卷回放5. 后台操作界面6. 后台试题分类-列表7.后台试题-列表8. 后台试题-添加9. 后台试卷结构-设置 0. 基于hustoj二开的初赛练习系统,QQ470585226 …

[笔记]渗透测试工具Burpsuit《一》Burpsuit介绍

文章目录 前言一、安装配置1.1 环境1.2 安装过程1.3 科技过程 二、常用功能2.1 Manual penetration testing features2.2 Advanced/custom automated attacks2.3 Automated scanning for vulnerabilities2.4 Productivity tools2.5 Extensions 三、拓展功能 前言 Burp Suite(b…

一、预约挂号微服务模块搭建

文章目录 一、预约挂号微服务模块搭建1、项目模块构建2、sql资源3、构建父工程(yygh-parent)3.1、添加配置pom.xml 4、搭建common父模块4.1、搭建common4.2、修改配置pom.xml 5、搭建common-util模块5.1、搭建common-util5.2、修改配置pom.xml5.3、添加公…

ELK的安装部署与使用

ELK的安装与使用 安装部署 部署环境:Elasticsearch-7.17.3 Logstash-7.17.3 Kibana-7.17.3 一、安装部署Elasticsearch 解压目录,进入conf目录下编辑elasticsearch.yml文件,输入以下内容并保存 network.host: 127.0.0.1 http.port: 9200…

计算机网络实验(ensp)-实验10:三层交换机实现VLAN间路由

目录 实验报告: 实验操作 1.建立网络拓扑图并开启设备 2.配置主机 1.打开PC机 配置IP地址和子网掩码 2.配置完成后点击“应用”退出 3.重复步骤1和2配置每台PC 3.配置交换机VLAN 1.点开交换机 2.输入命名:sys 从用户视图切换到系统视图…

Jenkins版本升级

Jenkins版本过低的时候,一些插件无法升级,会引发一系列错误,这个时候我们就要升级版本了 一、下载更新包 第一种方式:Jenkins页面下载最新包 第二种官网上下载war包(Jenkins官网) 二、打开服务器搜索jenkins.war路径 1、如果Jenk…

SQL Backup Master 6.3.6 Crack

SQL Backup Master 能够为用户将 SQL Server 数据库备份到一些简单的云存储服务中,例如 Dropbox、OneDrive、Amazon S3、Microsoft Azure、box,最后是 Google Drive。它能够将数据库备份到用户和开发者的FTP服务器上,甚至本地机器甚至网络服务…

这几款实用的电脑软件推荐给你

软件一:TeamViewer TeamViewer是一款跨平台的远程控制软件,它可以帮助用户远程访问和控制其他计算机、服务器、移动设备等,并且支持文件传输、会议功能等。 TeamViewer的主要功能包括: 远程控制:支持远程访问和控制…

HANTS时间序列滤波算法的MATLAB实现

本文介绍在MATLAB中,实现基于HANTS算法(时间序列谐波分析法)的长时间序列数据去噪、重建、填补的详细方法。 HANTS(Harmonic Analysis of Time Series)是一种用于时间序列分析和插值的算法。它基于谐波分析原理&#x…

自学黑客(网络安全)有哪些技巧——初学者篇

很多人说,要想学好黑客技术,首先你得真正热爱它。 热爱,听着多么让人激情澎湃,甚至热泪盈眶。 但很可惜,“热爱”这个词对还没入门的小白完全不管用。 如果一个人还没了解过你就说爱你,不是骗财就是骗色…

asp.net高校运动会管理系统的设计与实现

本高校运动会管理系统是针对我院当前运动会工作需要而开发的B/S模式的网络系统,涉及到运动会赛前的报名录入准备与分组编排、赛中的成绩处理、赛后的成绩汇总与团体总分的统计。它将是一个完整统一、技术先进、高效稳定、安全可靠的基于Internet/Intranet的高校运动…

Word怎么生成目录?4个方法快速生成目录!

案例:Word怎么生成目录 【想问下大家在使用Word时是怎么生成目录的呀?正在写毕业论文的我真的很急!感谢!】 Word作为我们常用的办公软件,为我们的提供了很多便利。生成目录是在Word文档中创建一个方便导航的索引。 …

Word怎么转换成PDF免费?分享适合你的Word转PDF方法

随着数字化时代的到来,将文件转换为PDF格式已经成为一个常见的需求。PDF文件格式的广泛应用使其在各个领域都非常重要,而Word文档则是最常见的文件类型之一。因此,将Word转换为PDF的方法备受关注。在下面,我将分享一种适合任何人使…

Android Service 使用

在Android应用开发中,Service是一种非常重要的组件。Service可以在后台执行长时间运行的任务,例如播放音乐、下载文件等。在本文中,我将会介绍如何使用Service组件,并通过代码实现来说明它的作用。 Android Service概述 在Androi…

如何为博客选择目标受众(+例子)

要创建免费网站和博客?从易服客建站平台开始 500M免费空间,可升级为20GB电子商务网站 创建免费网站 您是否正在寻找为您的博客选择目标受众的最佳实践? 选择目标受众可以让您创建更好的内容,引起用户的共鸣。这有助于您获得更…

【PDF软件篇】PDF轻量化电子笔记编辑利刃-Xodo软件优化

【PDF软件篇】PDF轻量化电子笔记编辑利刃-Xodo软件优化 默认配置已经够强,但是我还是推荐自定义,适合自己的就是最好的—【蘇小沐】 文章目录 【PDF软件篇】PDF轻量化电子笔记编辑利刃-Xodo软件优化1.实验环境 (一)日常办公导出无…