实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as np

import einops.layers.torch as elt

#载入数据

x_train = np.load("../dataset/mnist/x_train.npy")

y_train_label = np.load("../dataset/mnist/y_train_label.npy")

x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充

print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):
    def __init__(self):
        super(MnistNetword, self).__init__()
        #前置的特征提取模块
        self.convs_stack = nn.Sequential(
            nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层
            nn.ReLU(),
            nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层
            nn.ReLU(),
            nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层
        )
        #最终分类器层
        self.logits_layer = nn.Linear(in_features=1536,out_features=10)
    def forward(self,inputs):
        image = inputs
        x = self.convs_stack(image)        
        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作
        x = elt.Rearrange("b c h w -> b (c h w)")(x)
        logits = self.logits_layer(x)
        return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):
    def __init__(self):
        super(MnistNetword, self).__init__()
        self.convs_stack = nn.Sequential(
            nn.Conv2d(1,12,kernel_size=7),
            nn.ReLU(),
            nn.Conv2d(12,24,kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(24,6,kernel_size=3)
        )
        self.logits_layer = nn.Linear(in_features=1536,out_features=10)
    def forward(self,inputs):
        image = inputs
        x = self.convs_stack(image)
        x = elt.Rearrange("b c h w -> b (c h w)")(x)
        logits = self.logits_layer(x)
        return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):
    train_num = len(x_train)//128
    train_loss = 0.
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size
        x_batch = torch.tensor(x_train[start:end]).to(device)
        y_batch = torch.tensor(y_train_label[start:end]).to(device)
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # 记录每个批次的损失值
    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98

Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/100876.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【USRP】集成化仪器系列2 :示波器,基于labview实现

USRP 示波器 1、设备IP地址:默认为192.168.10.2,请勿 修改,运行阶段无法修改。 2、中心频率:当需要生成不同频率单载波的 时候请直接修改中心频率,在运行的时候您 也可以直接修改中心频率。 3、接收增益:…

Java与其他编程语言比较分析,编程语言选择与优点、缺点和适用场景详解

原文地址:Java与其他编程语言比较分析,编程语言选择与优点、缺点和适用场景详解 Java 擅长可移植性和可靠性,Python 擅长通用性和简单性,JavaScript 擅长 Web 开发,C 擅长性能,Go 擅长效率。网址:yii666.c…

大数据Flink简介与架构剖析并搭建基础运行环境

文章目录 前言Flink 简介Flink 集群剖析Flink应用场景Flink基础运行环境搭建Docker安装docker-compose文件编写创建并运行容器访问Flink web界面 前言 前面我们分别介绍了大数据计算框架Hadoop与Spark,虽然他们有的有着良好的分布式文件系统和分布式计算引擎,有的有…

内部类总结

内部类 1、内部类介绍: 外 2、成员内部类: 3、静态内部类 4、局部内部类: 5、匿名内部类:

以antd为例 React+Typescript 引入第三方UI库

本文 我们来说说 第三方UI库 其实应用市场上的 第三方UI库都是非常优秀的 那么 react 我们比较熟的肯定还是 antd 我们还是来用它作为演示 这边 我们先访问他的官网 https://3x.ant.design/index-cn 点击开始使用 在左侧 有一个 在 TypeScript 中使用 通过图标我们也可以看出…

前端面试必备 | uni-app 篇(P1-15)

文章目录 1. 请简述一下uni-app的定义和特点。2. uni-app兼容哪些前端框架?请列举几个。3. 请简述一下uni-app的跨平台工作原理。4. 什么是条件编译?在uni-app中如何实现条件编译?5. uni-app中的页面生命周期有哪些?请简要介绍。6…

word 调整列表缩进

word 调整列表缩进的一种方法,在试了其他方法无效后,按下图所示顺序处理,编号和文字之间的空白就没那么大了。 即右键word上方样式->点击修改格式->定义新编号格式->字体->取消勾选 “……对齐到网格”->确定

Redis-监听过期key-JAVA实现方案

一、创建监听配置类 RedisListenerConfig。 import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.d…

目录扫描+JS文件中提取URL和子域+403状态绕过+指纹识别(dirsearch_bypass403)

dirsearch_bypass403 在安全测试时,安全测试人员信息收集中时可使用它进行目录枚举,目录进行指纹识别,枚举出来的403状态目录可尝试进行绕过,绕过403有可能获取管理员权限。不影响dirsearch原本功能使用 运行流程 dirsearch进行…

docker常见面试问题详解

在面试的时候,面试官常常会问一些问题: docker是什么,能做什么?docker和虚拟机的区别是什么呢?docker是用什么做隔离的?docke的网络类型?docker数据之间是如何通信的?docker的数据保…

pom.xml配置文件失效,显示已忽略的pom.xml --- 解决方案

现象: 在 Maven 创建模块Moudle时,由于开始没有正确创建好,所以把它删掉了,然后接着又创建了与一个与之前被删除的Moudle同名的Moudle时,出现了 Ignore pom.xml,并且新创建的 Module 的 pom.xml配置文件失效&#xf…

简述SpringMVC

一、典型的Servlet JSP JavaBean UserServlet看作业务逻辑处理(Controller)User看作模型(Model)user.jsp看作渲染(View) 二、高级MVC 由DispatcherServlet对请求统一处理 三、SpringMVC MVC与Spr…

字符串匹配的Rabin–Karp算法

leetcode-28 实现strStr() 更熟悉的字符串匹配算法可能是KMP算法, 但在Golang中,使用的是Rabin–Karp算法 一般中文译作 拉宾-卡普算法,由迈克尔拉宾与理查德卡普于1987年提出 “ 要在一段文本中找出单个模式串的一个匹配,此算法具有线性时间的平均复杂度&#xff0…

设计模式行为模式-命令模式

文章目录 前言定义结构工作原理优点适用场景消息队列模式Demo实现分写业务总结 前言 定义 命令模式(Command Pattern)是一种行为型设计模式,用于将请求封装为对象,从而使你可以使用不同的请求、队列或者日志请求来参数化其他对象…

PY32F003F18点灯

延时函数学习完之后,可以学习PY32F003F18的GPIO输出功能。 1、Debug引脚默认被置于复用功能上拉或下拉模式:PA14默认为SWCLK: 置于下拉模式PA13默认为SWDIO: 置于上拉模式PF4默认为Boot:Boot引脚默认置于输入下拉模式 2、GPIO输出状态&#…

亚马逊云科技生成式AI技术辅助教学领域,近实时智能应答2D数字人搭建

早在大语言模型如GPT-3.5等的兴起和被日渐广泛的采用之前,教育行业已经在AI辅助教学领域有过各种各样的尝试。在教育行业,人工智能技术的采用帮助教育行业更好地实现教学目标,提高教学质量、学习效率、学习体验、学习成果。例如,人…

sql各种注入案例

目录 1.报错注入七大常用函数 1)ST_LatFromGeoHash (mysql>5.7.x) 2)ST_LongFromGeoHash &#xff08;mysql>5.7.x&#xff09; 3)GTID (MySQL > 5.6.X - 显错<200) 3.1 GTID 3.2 函数详解 3.3 注入过程( payload ) 4)ST_Pointfromgeohash (mysql>5.…

蓝桥杯 2240. 买钢笔和铅笔的方案数c++解法

最近才回学校。在家学习的计划不翼而飞。但是回到学校了&#xff0c;还是没有找回状态。 现在是大三了&#xff0c;之前和同学聊天&#xff0c;说才大三无论是干什么&#xff0c;考研&#xff0c;找工作&#xff0c;考公&#xff0c;考证书 还都是来的及的。 但是心里面…

css换行

强制显示一行&#xff0c;超出... .box{white-space: nowrap; /* 强制显示一行 */overflow: hidden;text-overflow: ellipsis; /* 超出... */ } 自动换行 一般默认制动换行 .box1{word-wrap:break-word; } 显示2行&#xff0c;超出... .box2 {overflow: hidden;display: -…

LabVIEW计算测量路径输出端随机变量的概率分布密度

LabVIEW计算测量路径输出端随机变量的概率分布密度 今天&#xff0c;开发算法和软件来解决计量综合的问题&#xff0c;即为特定问题寻找最佳测量算法。提出了算法支持&#xff0c;以便从计量上综合测量路径并确定所开发测量仪器的测量误差。测量路径由串联的几个块组成&#x…