Pytorch-以数字识别更好地入门深度学习

目录

一、数据介绍

二、下载数据 

三、可视化数据

四、模型构建

五、模型训练

六、模型预测


一、数据介绍

MNIST数据集是深度学习入门的经典案例,因为它具有以下优点:

1. 数据量小,计算速度快。MNIST数据集包含60000个训练样本和10000个测试样本,每张图像的大小为28x28像素,这样的数据量非常适合在GPU上进行并行计算。

2. 标签简单,易于理解。MNIST数据集的标签只有0-9这10个数字,相比其他图像分类数据集如CIFAR-10等更加简单易懂。

3. 数据集已标准化。MNIST数据集中的图像已经被归一化到0-1之间,这使得模型可以更快地收敛并提高准确率。

4. 适合初学者练习。MNIST数据集是深度学习入门的最佳选择之一,因为它既不需要复杂的数据预处理,也不需要大量的计算资源,可以帮助初学者快速上手深度学习。

综上所述,MNIST数据集是深度学习入门的经典案例,它具有数据量小、计算速度快、标签简单、数据集已标准化、适合初学者练习等优点,因此被广泛应用于深度学习的教学和实践中。

手写数字识别技术的应用非常广泛,例如在金融、保险、医疗、教育等领域中,都有很多应用场景。手写数字识别技术可以帮助人们更方便地进行数字化处理,提高工作效率和准确性。此外,手写数字识别技术还可以用于机器人控制、智能家居等方面  。

使用torch.datasets.MNIST下载到指定目录下:./data,当download=True时,如果已经下载了不会再重复下载,同train选择下载训练数据还是测试数据

官方提供的类:

class MNIST(
    root: str,
    train: bool = True,
    transform: ((...) -> Any) | None = None,
    target_transform: ((...) -> Any) | None = None,
    download: bool = False
)
Args:
    root (string): Root directory of dataset where MNIST/raw/train-images-idx3-ubyte
        and MNIST/raw/t10k-images-idx3-ubyte exist.
    train (bool, optional): If True, creates dataset from train-images-idx3-ubyte,
        otherwise from t10k-images-idx3-ubyte.
    download (bool, optional): If True, downloads the dataset from the internet and
        puts it in root directory. If dataset is already downloaded, it is not downloaded again.
    transform (callable, optional): A function/transform that takes in an PIL image
        and returns a transformed version. E.g, transforms.RandomCrop
    target_transform (callable, optional): A function/transform that takes in the
        target and transforms it.

二、下载数据 

# 导入数据集
# 训练集
import torch
from torchvision import datasets,transforms
from torch.utils.data import Dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root="./data",
                   train=True,
                   download=True,
                   transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),
                    batch_size=64,
                    shuffle=True)

# 测试集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("./data",train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),
    batch_size=64,shuffle=True
)

pytorch也提供了自定义数据的方法,根据自己数据进行处理

使用PyTorch提供的Dataset和DataLoader类来定制自己的数据集。如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。

以下是一个简单的例子,展示如何创建一个自定义的数据集并将其传递给模型进行训练:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y

data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

my_dataset = MyDataset(data, labels)
my_dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True)

详细完整流程可参考: Pytorch快速搭建并训练CNN模型?

三、可视化数据

mport matplotlib.pyplot as plt
import numpy as np
import torchvision
def imshow(img):
    img = img / 2 + 0.5 # 逆归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.title("Label")
    plt.show()

# 得到batch中的数据
dataiter = iter(train_loader)
images,labels = next(dataiter)
# 展示图片
imshow(torchvision.utils.make_grid(images))

四、模型构建

定义模型类并继承nn.Module基类

# 构建模型
import torch.nn as nn
import torch
import torch.nn.functional as F
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        # 输入图像为单通道,输出为六通道,卷积核大小为5×5
        self.conv1 = nn.Conv2d(1,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        # 将16×4×4的Tensor转换为一个120维的Tensor,因为后面需要通过全连接层
        self.fc1 = nn.Linear(16*4*4,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
    
    def forward(self,x):
        # 在(2,2)的窗口上进行池化
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        # 将维度转换成以batch为第一维,剩余维数相乘为第二维
        x = x.view(-1,self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self,x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
    
net = MyNet()
print(net)

输出: 

MyNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

简单的前向传播

# 前向传播
print(len(images))
image = images[:2]
label = labels[:2]
print(image.shape)
print(image.size())
print(label)
out = net(image)
print(out)

输出: 

16
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
tensor([6, 0])
tensor([[ 1.5441e+00, -1.2524e+00,  5.7165e-01, -3.6299e+00,  3.4144e+00,
          2.7756e+00,  1.1974e+01, -6.6951e+00, -1.2850e+00, -3.5383e+00],
        [ 6.7947e+00, -7.1824e+00,  8.8787e-01, -5.2218e-01, -4.1045e+00,
          4.6080e-01, -1.9258e+00,  1.8958e-01, -7.7214e-01, -6.3265e-03]],
       grad_fn=<AddmmBackward0>)

计算损失:

# 计算损失
# 因为是多分类,所有采用CrossEntropyLoss函数,二分类用BCELoss
image = images[:2]
label = labels[:2]
out = net(image)
criterion = nn.CrossEntropyLoss()
loss = criterion(out,label)
print(loss)

输出:

tensor(2.2938, grad_fn=<NllLossBackward0>)

五、模型训练

# 开始训练
model = MyNet()
# device = torch.device("cuda:0")
# model = model.to(device)
import torch.optim as optim
optimizer = optim.SGD(model.parameters(),lr=0.01) # lr表示学习率
criterion = nn.CrossEntropyLoss()
def train(epoch):
    # 设置为训练模式:某些层的行为会发生变化(dropout和batchnorm:会根据当前批次的数据计算均值和方差,加速模型的泛化能力)
    model.train()
    running_loss = 0.0
    for i,data in enumerate(train_loader):
        # 得到输入和标签
        inputs,labels = data
        # 消除梯度
        optimizer.zero_grad()
        # 前向传播、计算损失、反向传播、更新参数
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        # 打印日志
        running_loss += loss.item()
        if i % 100 == 0:
            print("[%d,%5d] loss: %.3f"%(epoch+1,i+1,running_loss/100))
            running_loss = 0

train(10)

输出:

[11,    1] loss: 0.023
[11,  101] loss: 2.302
[11,  201] loss: 2.294
[11,  301] loss: 2.278
[11,  401] loss: 2.231
[11,  501] loss: 1.931
[11,  601] loss: 0.947
[11,  701] loss: 0.601
[11,  801] loss: 0.466
[11,  901] loss: 0.399

六、模型预测

# 模型预测结果
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images,labels = data
        outputs = model(images)
        # 最大的数值及最大值对应的索引
        value,predicted = torch.max(outputs.data,1)
        total += labels.size(0)
        # 对bool型的张量进行求和操作,得到所有预测正确的样本数,采用item将整数类型的张量转换为python中的整型对象
        correct += (predicted == labels).sum().item()
    print("predicted:{}".format(predicted[:10].tolist()))
    print("label:{}".format(labels[:10].tolist()))
    print("Accuracy of the network on the 10 test images: %d %%"% (100*correct/total))

imshow(torchvision.utils.make_grid(images[:10],nrow=len(images[:10])))

输出:

predicted:[1, 0, 7, 6, 5, 2, 4, 3, 2, 6]
label:[1, 0, 7, 6, 5, 2, 4, 8, 2, 6]
Accuracy of the network on the 10 test images: 91 %

对应类别的准确率:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
classes = [i for i in range(10)]

with torch.no_grad():# model.eval()
    for data in test_loader:
        images,labels = data
        outputs = model(images)
        value,predicted = torch.max(outputs,1)
        c = (predicted == labels).squeeze()
        # 对所有labels逐个进行判断
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
    print("class_correct:{}".format(class_correct))
    print("class_total:{}".format(class_total))

# 每个类别的指标
for i in range(10):
    print('Accuracy of -> class %d : %2d %%'%(classes[i],100*class_correct[i]/class_total[i]))

输出:

class_correct:[958.0, 1119.0, 948.0, 938.0, 901.0, 682.0, 913.0, 918.0, 748.0, 902.0]
class_total:[980.0, 1135.0, 1032.0, 1010.0, 982.0, 892.0, 958.0, 1028.0, 974.0, 1009.0]
Accuracy of -> class 0 : 97 %
Accuracy of -> class 1 : 98 %
Accuracy of -> class 2 : 91 %
Accuracy of -> class 3 : 92 %
Accuracy of -> class 4 : 91 %
Accuracy of -> class 5 : 76 %
Accuracy of -> class 6 : 95 %
Accuracy of -> class 7 : 89 %
Accuracy of -> class 8 : 76 %
Accuracy of -> class 9 : 89 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/95938.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

直播预告|博睿学院第四季即将开讲:博睿数据资深运维团队现身说法!

博睿学院第四季开讲啦&#xff01;本季博睿学院的课程将于本周四&#xff08;8月31日&#xff09;16点正式启动。本季我们邀请到了博睿数据平台支撑中心的四位资深运维专家现身说法&#xff0c;来为我们分享一体化智能可观测平台Bonree ONE的实践干货。 他们&#xff0c;见多识…

docker 学习-- 04 实践2 (lnpmr环境)

docker 学习 系列文章目录 docker 学习-- 01 基础知识 docker 学习-- 02 常用命令 docker 学习-- 03 环境安装 docker 学习-- 04 实践 1&#xff08;宝塔&#xff09; docker 学习-- 04 实践 2 &#xff08;lnpmr环境&#xff09; 文章目录 docker 学习 系列文章目录1. 配…

模型的保存加载、模型微调、GPU使用及Pytorch常见报错

序列化与反序列化 序列化就是说内存中的某一个对象保存到硬盘当中&#xff0c;以二进制序列的形式存储下来&#xff0c;这就是一个序列化的过程。 而反序列化&#xff0c;就是将硬盘中存储的二进制的数&#xff0c;反序列化到内存当中&#xff0c;得到一个相应的对象&#xff…

全球选手逐鹿清华!首届AI药物研发算法大赛完美收官

8月26日&#xff0c;首届全球AI药物研发算法大赛决赛答辩暨颁奖典礼&#xff0c;在清华大学生物医学馆举行。来自微软研究院、中国科学院上海药物研究所、上海交通大学等单位的十五支团队&#xff0c;从全球878支团队中脱颖而出&#xff0c;进入了决赛答辩环节。 产教融合&…

UDP 多播(组播)

前言&#xff08;了解分类的IP地址&#xff09; 1.组播&#xff08;多播&#xff09; 单播地址标识单个IP接口&#xff0c;广播地址标识某个子网的所有IP接口&#xff0c;多播地址标识一组IP接口。单播和广播是寻址方案的两个极端&#xff08;要么单个要么全部&#xff09;&am…

性能测试常见的测试指标

一、什么是性能测试 先看下百度百科对它的定义 性能测试是通过自动化的测试工具模拟多种正常、峰值以及异常负载条件来对系统的各项性能指标进行测试。我们可以认为性能测试是&#xff1a;通过在测试环境下对系统或构件的性能进行探测&#xff0c;用以验证在生产环境下系统性能…

2023最新Python重点知识万字汇总

这是一份来自于 SegmentFault 上的开发者 二十一 总结的 Python 重点。由于总结了太多的东西&#xff0c;所以篇幅有点长&#xff0c;这也是作者"缝缝补补"总结了好久的东西。 **Py2 VS Py3** * print成为了函数&#xff0c;python2是关键字* 不再有unicode对象…

百度垂类离线计算系统发展历程

作者 | 弘远君 导读 本文以百度垂类离线计算系统的演进方向为主线&#xff0c;详细描述搜索垂类离线计算系统发展过程中遇到的问题&#xff0c;以及对应的解决方案。架构演进过程中一直奉行“没有最好的架构&#xff0c;只有最合适的架构”的宗旨&#xff0c;面对不同阶段遇到的…

微信小程序使用stomp.js实现STOMP传输协议的实时聊天

简介&#xff1a; uniapp开发的小程序中使用 本来使用websocket&#xff0c;后端同事使用了stomp协议&#xff0c;导致前端也需要对应修改。 如何使用 在static/js中新建stomp.js和websocket.js&#xff0c;然后在需要使用的页面引入监听代码发送代码即可 代码如下&#x…

request+python操作文件导入

业务场景&#xff1a; 通常我们需要上传文件或者导入文件如何操作呢&#xff1f; 首先通过f12或者通过抓包查到请求接口的参数&#xff0c;例如&#xff1a; 图中标注的就是我们需要的参数&#xff0c;其中 name是参数名&#xff0c;filename是文件名&#xff0c;Content-Type是…

keil5 报错no target connected

场景&#xff1a;用ST_Link V2 在 keil5 中下载stm32程序 原因&#xff1a;线路连接错误 正确连接 注意&#xff1a;江科大stm32和stlink的接线&#xff0c;一定要对齐&#xff0c;我买的一个不是按照顺序接线的&#xff0c;需要仔细查看

OpenHarmony设备截屏的5种方式

本文转载自《OpenHarmony设备截屏的5种方式 》&#xff0c;作者westinyang 目录 方式1&#xff1a;系统控制中心方式2&#xff1a;OHScrcpy投屏工具方式3&#xff1a;DevEcoStudio截屏功能方式4&#xff1a;hdc shell snapshot_display方式5&#xff1a;hdc shell wukong持续关…

C++新经典 | C语言

目录 一、基础之查漏补缺 1.float精度问题 2.字符型数据 3.变量初值问题 4.赋值&初始化 5.头文件之<> VS " " 6.逻辑运算 7.数组 7.1 二维数组初始化 7.2 字符数组 8.字符串处理函数 8.1 strcat 8.2 strcpy 8.3 strcmp 8.4 strlen 9.函数 …

(笔记四)利用opencv识别标记视频中的目标

预操作&#xff1a; 通过cv2将视频的某一帧图片转为HSV模式&#xff0c;并通过鼠标获取对应区域目标的HSV值&#xff0c;用于后续的目标识别阈值区间的选取 img cv.imread(r"D:\data\123.png") img cv.cvtColor(img, cv.COLOR_BGR2HSV) plt.figure(1), plt.imshow…

JixiPix Artista Impresso Pro for mac(油画滤镜效果软件)

JixiPix Artista Impresso pro Mac是一款专业的图像编辑软件&#xff0c;专为Mac用户设计。它提供了各种高质量的图像编辑工具&#xff0c;可以帮助您创建令人惊叹的图像。该软件具有直观的用户界面&#xff0c;使您可以轻松地浏览和使用各种工具。 它还支持多种文件格式&…

CSS中如何实现弹性盒子布局(Flexbox)的换行和排序功能?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 换行&#xff08;Flexbox Wrapping&#xff09;⭐ 示例&#xff1a;实现换行⭐ 排序&#xff08;Flexbox Ordering&#xff09;⭐ 示例&#xff1a;实现排序⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得…

大模型开发05:PDF 翻译工具开发实战

大模型开发实战05:PDF 翻译工具开发实战 PDF-Translator 机器翻译是最广泛和基础的 NLP 任务 PDF-Translator PDF 翻译器是一个使用 AI 大模型技术将英文 PDF 书籍翻译成中文的工具。这个工具使用了大型语言模型 (LLMs),如 ChatGLM 和 OpenAI 的 GPT-3 以及 GPT-3.5 Turbo 来…

长胜证券:股票配资什么意思

股票配资是指假贷的方法来进行股票出资&#xff0c;是指出资者经过向配资公司或个人假贷&#xff0c;以增加其自有资金的杠杆份额&#xff0c;然后到达更高的收益。股票配资可以用于股票、期货、外汇等多种金融市场&#xff0c;一起也是一种危险较大的出资方法。本文将从多个视…

ios开发 swift5 苹果系统自带的图标 SF Symbols

文章目录 1.官网app的下载和使用2.使用代码 1.官网app的下载和使用 苹果官网网址&#xff1a;SF Symbols 通过上面的网址可以下载dmg, 安装到自己的mac上 貌似下面这样不能展示出动画&#xff0c;还是要使用动画的代码 .bounce.up.byLayer2.使用代码 UIKit UIImage(system…

RESTful API 面试必问

RESTful API是一种基于 HTTP 协议的 API 设计风格&#xff0c;它提供了一组规范和约束&#xff0c;使得客户端&#xff08;如 Web 应用程序、移动应用等&#xff09;和服务端之间的通信更加清晰、简洁和易于理解。 RESTful API 的设计原则 使用 HTTP 协议&#xff1a;RESTful …