PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录

    • pytorch 神经网络训练demo
    • Result
    • 参考来源

pytorch 神经网络训练demo

数据集:MNIST

该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片

在这里插入图片描述
图片来源:https://tensornews.cn/mnist_intro/

神经网络:RNN, GRU, LSTM

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2

# Create a RNN
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connected
    
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward Prop
        out, _ = self.rnn(x, h0)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return out
    
# Create a GRU
class RNN_GRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN_GRU, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connected
    
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward Prop
        out, _ = self.gru(x, h0)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return out


# Create a LSTM
class RNN_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN_LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connected
    
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward Prop
        out, _ = self.lstm(x, (h0, c0))
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out
    

# Load data
train_dataset = datasets.MNIST(root='dataset/', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', 
                              train=False, 
                               transform=transforms.ToTensor(),
                               download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Initialize network 选择一个即可
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_GRU(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train network
for epoch in range(num_epochs):
    # data: images, targets: labels
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device).squeeze(1) # 删除一个张量中所有维数为1的维度 (N, 1, 28, 28) -> (N, 28, 28)
        targets = targets.to(device)

        # forward
        scores = model(data) # 64*10
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()


# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):
    if loader.dataset.train:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad(): # 不计算梯度
        for x, y in loader:
            x = x.to(device).squeeze(1)
            y = y.to(device)
            # x = x.reshape(x.shape[0], -1) # 64*784

            scores = model(x)# 64*10
            _, predictions = scores.max(dim=1) #dim=1,表示对每行取最大值,每行代表一个样本。
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0) # 64

        print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')

    model.train()

check_accuracy(train_loader, model)
check_accuracy(test_loader, model)


Result

RNN Result
Checking accuracy on training data
Got 57926 / 60000 with accuracy 96.54%
Checking accuracy on test data
Got 9640 / 10000 with accuracy 96.40%


GRU Result
Checking accuracy on training data
Got 59058 / 60000 with accuracy 98.43%
Checking accuracy on test data
Got 9841 / 10000 with accuracy 98.41%

LSTM Result
Checking accuracy on training data
Got 59248 / 60000 with accuracy 98.75%
Checking accuracy on test data
Got 9849 / 10000 with accuracy 98.49%

参考来源

【1】https://www.youtube.com/watch?v=Gl2WXLIMvKA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/38947.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

每日一题2023.7.19|ACM模式

文章目录 C的输入方式介绍cin>>cin.get(字符变量名)cin.get(数组名,接收字符数目)cin.get()cin.getline() getline()gets()getchar() AB问题|AB问题||AB问题|||ABⅣAB问题ⅤAB问题Ⅵ C的输入方式介绍 参考博客 cin>> 最基本,最常用的字符或者数字的输…

TMS FlexCel for VCL FMX Crack

TMS FlexCel for VCL & FMX Crack 强大、广泛和灵活的组件套件,用于VCL和FireMonkey的本地Excel报告、文件生成和操作。 FlexCel for VCL/FireMonkey是一套允许操作Excel文件的Delphi组件。它包括一个广泛的API,允许本地读/写Excel文件。如果您需要在…

c#调用cpp库,debug时不进入cpp函数

选中c#的项目,右击属性,进入属性页,点击调试,点击打开调试启动配置文件UI,打开启用本机代码调试。

uniapp 集成七牛云,上传图片

1 创建项目 我是可视化创建项目的 ,cli创建的项目可以直接使用npm安装七牛云。 2 拷贝qiniuUploader.js到项目,下面的回复 放了qiniuUploader.js百度云链接。 3 在需要使用qiniuUploader的vue文件 引入。 4 相册选择照片,或者拍照后&#xff…

工欲善其事,必先利其器之—react-native-debugger调试react native应用

调试react应用通常利用chrome的inspector的功能和两个最常用的扩展 1、React Developer Tools (主要用于debug组件结构) 2、Redux DevTools (主要用于debug redux store的数据) 对于react native应用,我们一般就使用react-nativ…

vue 项目优化

去除冗余的css 消除框架中未使用的CSS,初步达到按需引入的效果 使用背景:vue2.x, webpack3.x 使用插件:purifycss-webpack 安装: npm i purifycss-webpack purify-css glob-all -D安装后各个插件的版本: “glob-all”: “^3.3.…

轻松实现数据一体化:轻易云数据集成平台全解析

在当今快速发展的商业环境中,企业面临着大量来自多样数据源的数据。如何将这些数据进行高效集成和利用,成为企业数字化转型的关键挑战。轻易云数据集成平台提供了一个一站式的解决方案,帮助企业实现数据的无缝集成和高效利用。下面我们将通过…

[java安全]URLDNS

文章目录 [java安全]URLDNS前言HashMapURLURLStreamHandler调用过程调用链流程图POC [java安全]URLDNS 前言 URLDNS利用链是一条很简单的链子,可以用来查看java反序列化是否存在反序列化漏洞,如果存在,就会触发dns查询请求 它有如下优点&a…

docker-compose搭建prometheus+grafana+钉钉告警

前言: 本文将介绍使用docker-compose部署搭建promtheus监控容器、主机、服务等相关状态; 配合granfana面板构建监控大屏; 由于grafana的报警不是很友好,使用dingtalk,配合altermanager,实现钉钉报警。 …

pico添加devmem2读写内存模块

devmem2读写内存 自定义msh命令devmem2验证msh命令devmem2读CPUID读写全局变量 devmem2模块可实现对设备寄存器的读写操作。在RT-Thread的命令行组件Fish中添加devmem2模块,用户可在终端输入devmem2相关命令,FinSH根据输入对指定寄存器进行读写&#xff…

提高LLaMA-7B的数学推理能力

概述 这篇文章探讨了利用多视角微调方法提高数学推理的泛化能力。数学推理在相对较小的语言模型中仍然是一个挑战,许多现有方法倾向于依赖庞大但效率低下的大语言模型进行知识蒸馏。研究人员提出了一种避免过度依赖大语言模型的新方法,该方法通过有效利…

JVM中类加载的过程

文章目录 一、类加载是什么二、类加载过程1.加载2.验证3.准备4.解析5.初始化 三、什么时候进行类加载四、双亲委派模型1.三大类加载器2.加载过程 总 一、类加载是什么 把.class文件加载到内存中,得到类对象的过程。 二、类加载过程 1.加载 找到.class文件&#xff…

QT Quick初学笔记---第一篇

链接: QML Book中文版(QML Book In Chinese) 1、对Qt Quick的初步认识 Qt Quick是Qt5界面开发技术的统称,是以下几种技术的集合: QML:界面标记语言JavaScript:动态脚本语言QT C:跨平台C封装库 QML是与HTML类似的一…

守护数智未来,开源网安受邀参加2023OWASP北京论坛

2023年7月14日,OWASP中国与网安加社区联合举办的“2023OWASP中国北京安全技术论坛”在北京圆满召开,开源网安受邀参加本次论坛并分享“软件供应链安全治理实践”。 本次“2023OWASP中国北京安全技术论坛”是OWASP中国北京地区年度重要活动之一&#xff…

数据库信息速递 MONGODB 6.0 的新特性,更多的查询函数,加密查询,与时序数据集合 (译)...

开头还是介绍一下群,如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题,有需求都可以加群群内有各大数据库行业大咖,CTO,可以解决你的问题。加群请联系 liuaustin3 ,在新加的朋友会分到3群(共…

ffmpeg学习之音频解码数据

音频数据经过解码后会被保存为,pcm数据格式。而对应的处理流程如下所示。 avcodec_find_encoder() /*** 查找具有匹配编解码器ID的已注册编码器.** param id AVCodecID of the requested encoder* return An encoder if one was found, NULL otherwise.*/ const A…

docker k8s

Docker docker到底与一般的虚拟机有什么不同呢? 我们知道一般的linux系统即GNU/Linux系统包括两个部分,linux系统内核GNU提供的大量自由软件,而centos就是众多GNU/Linux系统中的一个。 虚拟机会在宿主机上虚拟出一个完整的操作系统与宿主机完…

数智领航 信创强基 | GBASE南大通用携手金仕达共助金融用户合规风控

GBASE南大通用董事长丁明峰先生应邀出席大会并在主论坛发表题为《去全球化背景下的中国数据库发展策略》的主题分享。 技术的迭代发展是经济增长、产业升级的核心动力。纵观近现代社会史,信息技术和通信技术的迅猛发展,帮助人类实现了PC互联网到移动互联…

MyBatis全篇

文章目录 MyBatis特性下载持久化层技术对比 搭建MyBatis创建maven工程创建MyBatis的核心配置文件创建mapper接口创建MyBatis的映射文件测试功能加入log4j日志功能加入log4j的配置文件 核心配置文件的完善与详解MyBatis的增删改查测试功能 MyBatis获取参数值在IDEA中设置中配置文…

Java遍历集合方法分析(实现原理、算法性能、适用场合)

Java遍历集合方法分析(实现原理、算法性能、适用场合) 概述 java语言中,提供了一套数据集合框架,其中定义了一些诸如List、Set等抽象数据类型,每个抽象数据类型的各个具体实现,底层又采用了不同的实现方式…