深度学习之使用BP神经网络识别MNIST数据集

目录

补充知识点

torch.nn.LogSoftmax()

 torchvision.transforms

transforms.Compose

transforms.ToTensor

transforms.Normalize(mean, std)

torchvision.datasets

MNIST(手写数字数据集)

torch.utils.data.DataLoader

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

torch.nn.NLLLoss() 

enumerate() 函数

next() 函数

pytorch中.detach()

torch.squeeze()

torch.optim.SGD 

独热编码

代码实现

bp网络的搭建

建立我们的神经网络对象

预测

完整代码

结果极其图像显示


补充知识点

torch.nn.LogSoftmax()

torch.nn.LogSoftmax()和我们的softmax差不多只不过就是最后加入了一个log,关于softmax的详情大家可以看这一篇博客深度学习之感知机,激活函数,梯度消失,BP神经网络

我们只需要知道里面的一些常用参数就好,一般就是这个dim

下面是softmax的logsoftmax和他一样。

 torchvision.transforms

这个是一个功能强大的数据处理集合库

这里主要说一下transforms.Compose还有transforms.ToTensor以及transforms.Normalize(mean, std)

transforms.Compose

这个功能函数可以看作一个功能函数容器,它里面可以放多个功能函数(多个的话应该把这些功能函数放在一个列表内),当该功能函数定义好后,其内部的其他功能函数也会随之按我们给定的要求定义好,当我们调用compose的实例的时候,就会按照我们在容器内部摆放的顺序从左至右的依次调用功能函数

transforms.ToTensor

用于对载入的图片数据进行类型转换,将之前PIL图片的数据(应该是np.array()类型的)转换成Tensor数据类型的张量,以便于我们后续的数据使用

(补充:PIL默认输出的图片格式为 RGB)

(下面图片内容来自网络,侵权必删。)

transforms.Normalize(mean, std)

数据归一化处理。

下面是我在学习过程中遇到的问题:

1.归一化就是要把图片3个通道中的数据整理到[-1, 1]或者[0,1]区间,x = (x - mean(x))/std(x)只要输入数据集x就可以直接算出来,为什么Normalize()函数的mean和std还需要我们手动输入数值呢?

我的理解是,我们一开始就算好可以极大的减少运算量,如果我们自动的让他算的话,我们这个每一个图片都要算,这样运算量就极大。

2.RGB单个通道的值是[0, 255],所以一个通道的均值应该在127附近才对。我们接下来的代码如下图所示

所填的是0.5,0.5,这是为什么?

 因为我们应用了torchvision.transforms.ToTensor,他会将数据归一化到[0,1](是将数据除以255),transforms.ToTensor( )会把HWC会变成C *H *W(拓展:格式为(h,w,c),像素顺序为RGB),所以我们就应该输入0.5,0.5

(该图片内容来自网络,侵权璧必删)

我们一般来说只需了解一些常用的库函数,我们这里只是提一下我们这次会用到的函数,其余的函数若想了解的话,推荐大家看这位作者写的博客PyTorch之torchvision.tra()nsforms详解[原理+代码实现]-CSDN博客

torchvision.datasets

这里面包含了一些我们目前常用的一些数据集

我们这里主要讲一下mnist 

MNIST(手写数字数据集)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

在我们pytorch里面的torchvision里面是有的,我们可以直接用。

torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

相关参数:

root:就是我们从网上下载的数据集所放的目录,也就是文件路径

train:train=True意思就是下载的是训练集,如果train=Flase,那就是下载的是测试集

transform: 因为我们的数据直接下载过后一般都是要进行处理,所以这个后面跟的是一个transforms数据处理函数的一个实例化对象

download:为True就是从网络上下载数据集保存在我们的root路径中,为False就是不下载。

其余参数可自行查阅

torch.utils.data.DataLoader

(图片内容来自网络,侵权必删)

(其实我们记住常用的,dataset,batch_size,shuffle这三个常用的参数就好了)

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

首先我们要了解交叉熵损失函数,torch.nn.CrossEntropyLoss()

什么是熵?

熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。

(预测的概率就是我们的预测值的准确值)

torch.nn.NLLLoss() 

 torch.nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与torch.nn.CrossEntropyLoss的关系可以描述为:

假设有张量x,先softmax(x)得到y,然后再log(y)得到z,然后我们已知标签b,则:

NLLLoss(z,b)=CrossEntropyLoss(x,b)

代码:

nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
运行结果:tensor(0.2684)

而我们用 torch.nn.CrossEntropyLoss

cross_loss = nn.CrossEntropyLoss()

predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
运行结果: tensor(0.2684)

enumerate() 函数

这是python的一个内置函数。

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

Python 2.3. 以上版本可用,2.6 添加 start 参数。

enumerate(sequence, [start=0])
sequence -- 一个序列、迭代器或其他支持迭代对象。
start -- 下标起始位置的值。
例如:
>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))       # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]

next() 函数

python的内置函数

next() 返回迭代器的下一个项目。

next() 函数要和生成迭代器的 iter() 函数一起使用。

返回值:返回下一个项目。

next(iterable[, default])
iterable -- 可迭代对象
default -- 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 StopIteration 异常。
例如:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
 
# 首先获得Iterator对象:
it = iter([1, 2, 3, 4, 5])
# 循环:
while True:
    try:
        # 获得下一个值:
        x = next(it)
        print(x)
    except StopIteration:
        # 遇到StopIteration就退出循环
        break
结果:
1
2
3
4
5

pytorch中.detach()

在 PyTorch 中,detach() 方法用于返回一个新的 Tensor,这个 Tensor 和原来的 Tensor 共享相同的内存空间,但是不会被计算图所追踪,也就是说它不会参与反向传播,不会影响到原有的计算图,这使得它成为处理中间结果的一种有效方式,通常在以下两种情况下使用:

第一:在计算图中间,需要截断反向传播的梯度计算时。例如,当计算某个 Tensor 的梯度时,我们希望在此处截断反向传播,而不是将梯度一直传递到计算图的顶部,从而减少计算量和内存占用。此时可以使用 detach() 方法将 Tensor 分离出来。

第二:在将 Tensor 从 GPU 上拷贝到 CPU 上时,由于 Tensor 默认是在 GPU 上存储的,所以直接进行拷贝可能会导致内存不一致的问题。此时可以使用 detach() 方法先将 Tensor 分离出来,然后再将分离出来的 Tensor 拷贝到 CPU 上。

torch.squeeze()

详情请看这篇博客深度学习之张量的处理(代码笔记)

torch.optim.SGD 

torch.optim.SGD 是 PyTorch 中用于实现随机梯度下降(Stochastic Gradient Descent,SGD)优化算法的类。SGD 是一种常用的优化算法。

原理部份可以看我的这篇博客:机器学习优化算法(深度学习)-CSDN博客

我们主要介绍一下常用的参数:

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

(图片内容来自网络,侵权必删) 

独热编码

就是如果是0-9这十个数,那我们就用[1,0,0,0,0,0,0,0,0,0]表示0,[0,1,0,0,0,0,0,0,0,0]表示1,等等

其余同理

代码实现

关于bp神经网络的原理,我们可以看这篇博客:深度学习之感知机,激活函数,梯度消失,BP神经网络-CSDN博客

我们接下来就只是讲解代码实现。

bp网络的搭建

#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x

一些关键点都在注释上

搭建这次的BP神经网络我的隐藏层有三层,分别是128,64,32个神经元,因为我们的图片是28*28=784得,我们需要把其展开成一维,所以第一层网络是784*128得,这样输入层中每一行代表一个样本(或者说一张图片得所有像素点),因为我们的每个神经元都有一个参数,第一层网络中每一列都是一个神经元对应的参数,所以最后就是n*784和784*128两个矩阵相乘,最后得到n*128得矩阵,以此类推,最后因为我们的输出要用到独热编码的思想,所以我们的输出层调整为10个神经元,或者说最后得线性网络是32*10。

建立我们的神经网络对象

 

#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15#循环得轮数
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()

注意看注释。

SGD还有损失函数等等上面的补充内容里面都有说明,这里不在多加阐述。

预测

#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)

注意看注释

我们实际上也可以不用next()我们可以直接用for循环比如这样。

import d2l.torch as d2l
import math
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
a=0
images=0
labelss=0
for i,j in example:
    a+=1
    index = i
    (imagess, labelss) = j
    print(imagess[0])
    print('数据集中抽取的64份数据的纯数据部份的尺寸:',imagess.shape)
    print(imagess[0].shape)
    print(labelss[0])
    print(labelss[0].shape)

    if a==1:
        break

这样得到的结果我们可以看到:

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\prictice.py 
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3412,
           0.4510,  0.2471,  0.1843, -0.5294, -0.7176, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.7412,
           0.9922,  0.9922,  0.9922,  0.9922,  0.8902,  0.5529,  0.5529,
           0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.3333,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4745,
          -0.1059, -0.4353, -0.1059,  0.2784,  0.7804,  0.9922,  0.7647,
           0.9922,  0.9922,  0.9922,  0.9608,  0.7961,  0.9922,  0.9922,
           0.0980, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.8667, -0.4824, -0.8902,
          -0.4745, -0.4745, -0.4745, -0.5373, -0.8353,  0.8510,  0.9922,
          -0.1686, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.3490,  0.9843,  0.6392,
          -0.8588, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.8275,  0.8275,  1.0000, -0.3490,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.0118,  0.9922,  0.8667, -0.6549,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.5373,  0.9529,  0.9922, -0.5137, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000,  0.0431,  0.9922,  0.4667, -0.9608, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.9294,  0.6078,  0.9451, -0.5451, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.0118,  0.9922,  0.4275, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.4118,  0.9686,  0.8824, -0.5529, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8510,
           0.7333,  0.9922,  0.3020, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9765,  0.5922,
           0.9922,  0.7176, -0.7255, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7020,  0.9922,
           0.9922, -0.3961, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.7569,  0.7569,  0.9922,
          -0.0980, -0.9922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.0431,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.5216,  0.8980,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.9922,  0.7176,
          -0.6863, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.6235, -0.8588,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
数据集中抽取的64份数据的纯数据部份的尺寸: torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
tensor(7)
torch.Size([])

进程已结束,退出代码0

 我们可以看到一张图片得数据格式(这里面是已经归一化处理过的),因为我们的手写字体识别是单通道的灰度图,所以size是[1,28,28],这里很正确,对于彩色图三通道得来说,会有些不一样。

完整代码

import matplotlib.pyplot as plt
from matplotlib import font_manager

print('BP识别MNIST任务说明---------------------')
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x
#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()
#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)
    #图像显示
    img=torch.squeeze(imagess[i]).numpy()

    plt.subplot(8,8,i+1)
    plt.tight_layout()
    plt.imshow(img,cmap='gray',interpolation='none')
    plt.title(f"预测值:{pre_label}",fontproperties=font, fontsize=9)
    plt.xticks([])
    plt.yticks([])
plt.show()

















结果极其图像显示

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py 
BP识别MNIST任务说明---------------------
D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py:49: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
第1轮的损失:815.2986399680376,抽取次数和:938
第2轮的损失:281.06414164602757,抽取次数和:1876
第3轮的损失:205.76270231604576,抽取次数和:2814
第4轮的损失:159.9431014917791,抽取次数和:3752
第5轮的损失:131.47989665158093,抽取次数和:4690
第6轮的损失:109.93954652175307,抽取次数和:5628
第7轮的损失:95.26143277343363,抽取次数和:6566
第8轮的损失:85.17149402853101,抽取次数和:7504
第9轮的损失:75.1239058477804,抽取次数和:8442
第10轮的损失:68.23363681556657,抽取次数和:9380
第11轮的损失:60.05981844640337,抽取次数和:10318
第12轮的损失:54.82598690944724,抽取次数和:11256
第13轮的损失:51.70861432119273,抽取次数和:12194
第14轮的损失:46.613128249999136,抽取次数和:13132
第15轮的损失:43.05269447225146,抽取次数和:14070

进程已结束,退出代码0

 

 可以看到,经过15轮得训练后,我们基本上已经完全能识别出来了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/524756.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue 有哪些主要的指令修饰符

目录 1. 什么是指令修饰符 2. 指令修饰符有哪些 2.1. 按键修饰符 2.2. v-model修饰符 2.3. 事件修饰符 1. 什么是指令修饰符 通过 "." 指明一些指令 后缀&#xff0c;不同 后缀 封装了不同的处理操作 目的&#xff1a;简化代码 2. 指令修饰符有哪些 2.1. 按键…

SpringMVC数据响应和请求

文章目录 1.SpringMVC简介2. SpringMVC快速入门3. SpringMVC执行的流程4.SpringMVC注解解释5. 视图解析器6.SpringMVC的数据响应6.1返回ModelView对象6.2直接返回字符串6.3返回json字符串 7.SpringMVC获得请求数据7.1 获得基本类型参数7.2获得POJO类型参数7.3获取数组类型参数7…

Python | Leetcode Python题解之第15题三数之和

题目&#xff1a; 题解&#xff1a; class Solution:def threeSum(self, nums: List[int]) -> List[List[int]]:n len(nums)nums.sort()ans list()# 枚举 afor first in range(n):# 需要和上一次枚举的数不相同if first > 0 and nums[first] nums[first - 1]:continu…

【问题处理】银河麒麟操作系统实例分享,银河麒麟高级服务器操作系统mellanox 网卡驱动编译

1.Mellanox 网卡源码驱动下载链接&#xff1a; https://www.mellanox.com/downloads/ofed/MLNX_EN-5.7-1.0.2.0/MLNX_EN_SRC-5.7-1.0.2.0.tgz 2.系统及内核版本如下截图&#xff1a; 3.未升级前 mellanox 网卡驱动版本如下&#xff1a; 4.解压 “MLNX_EN_SRC-5.7-1.0.2.0.tg…

uniapp使用npm命令引入font-awesome图标库最新版本

uniapp使用npm命令引入font-awesome图标库最新版本 图标库网址&#xff1a;https://fontawesome.com/search?qtools&or 命令行&#xff1a; 引入 npm i fortawesome/fontawesome-free 查看版本 npm list fortawesome在main.js文件中&#xff1a; import fortawesome/fo…

使用Springboot配置生产者、消费者RabbitMQ?

生产者服务 1、引入依赖以及配置rabbitmq 此时我们通过使用springboot来快速搭建一个生产者服务 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId> </dependency> applica…

网络安全流量平台_优缺点分析

FlowShadow&#xff08;流影&#xff09;&#xff0c;Ntm&#xff08;派网&#xff09;&#xff0c;Elastiflow。 Arkimesuricata&#xff0c;QNSMsuricata&#xff0c;Malcolm套件。 Malcolm套件优点&#xff1a;支持文件还原反病毒引擎&#xff08;clamav/yara&#xff09;…

webpack环境配置分类结合vue使用

文件目录结构 按照目录结构创建好文件 控制台执行: npm install /config/webpack.common.jsconst path require(path) const {merge} require(webpack-merge) const {CleanWebpackPlugin} require(clean-webpack-plugin) const { VueLoaderPlugin } require(vue-loader); c…

SQLite 4.9的虚拟表机制(十四)

返回&#xff1a;SQLite—系列文章目录 上一篇:SQLite 4.9的 OS 接口或“VFS”&#xff08;十三&#xff09; 下一篇&#xff1a;SQLite—系列文章目录 1. 引言 虚拟表是向打开的 SQLite 数据库连接注册的对象。从SQL语句的角度来看&#xff0c; 虚拟表对象与任何其他…

目标跟踪——行人检测数据集

一、重要性及意义 目标跟踪和行人检测是计算机视觉领域的两个重要任务&#xff0c;它们在许多实际应用中发挥着关键作用。为了推动这两个领域的进步&#xff0c;行人检测数据集扮演着至关重要的角色。以下是行人检测数据集的重要性及意义的详细分析&#xff1a; 行人检测数据…

MySQL操作DML

目录 1.概述 2.插入 3.更新 4.删除 5.查询 6.小结 1.概述 数据库DML是数据库操作语言&#xff08;Data Manipulation Language&#xff09;的简称&#xff0c;主要用于对数据库中的数据进行增加、修改、删除等操作。它是SQL语言的一部分&#xff0c;用于实现对数据库中数…

python统计分析——多组比较

参考资料&#xff1a;python统计分析【托马斯】 一、方差分析 1、原理 方差分析的思想是将方差分为组间方差和组内方差&#xff0c;看看这些分布是否符合零假设&#xff0c;即所有组都来自同一分布。区分不同群体的变量通常被称为因素或处理。 作为对比&#xff0c;t检验观察…

(React Hooks)前端八股文修炼Day9

一 对 React Hook 的理解&#xff0c;它的实现原理是什么 React Hooks是React 16.8版本中引入的一个特性&#xff0c;它允许你在不编写类组件的情况下&#xff0c;使用state以及其他的React特性。Hooks的出现主要是为了解决类组件的一些问题&#xff0c;如复杂组件难以理解、难…

科技云报道:卷完参数卷应用,大模型落地有眉目了?

科技云报道原创。 国内大模型战场的比拼正在进入新的阶段。 随着产业界对模型落地的态度逐渐回归理性&#xff0c;企业客户的认知从原来的“觉得大模型什么都能做”的阶段&#xff0c;已经收敛到“大模型能够给自身业务带来什么价值上了”。 2023 年下半年&#xff0c;不少企…

Zotero参考文献引用(适用国内)

配置环境 【1】Windows 10 【2】Office 2021 【3】Zotero 6 目录 配置环境 前言 一、软件及插件安装 二、网页信息抓取 三、文献引用 注意事项 前言 本教程记录使用zotero作为自动插入参看文献引用工具的全流程&#xff0c;流程较为简洁&#xff0c;适用于平时论文写作不多&…

C语言 | Leetcode C语言题解之第15题三数之和

题目&#xff1a; 题解&#xff1a; int cmp(const void *x, const void *y) {return *(int*)x - *(int*)y; } //判断重复的三元组 bool TheSame(int a, int b, int c, int **ans, int returnSize) {bool ret true;for(int i 0;i < returnSize;i){if(a ans[i][0] &&…

书生·浦语训练营二期第三次笔记-茴香豆:搭建你的 RAG 智能助理

RAG学习文档1&#xff1a; https://paragshah.medium.com/unlock-the-power-of-your-knowledge-base-with-openai-gpt-apis-db9a1138cac4 RAG学习文档2: https://blog.demir.io/hands-on-with-rag-step-by-step-guide-to-integrating-retrieval-augmented-generation-in-llms-a…

2024/4/1—力扣—主要元素

代码实现&#xff1a; 思路&#xff1a;摩尔投票算法 int majorityElement(int *nums, int numsSize) {int candidate -1;int count 0;for (int i 0; i < numsSize; i) {if (count 0) {candidate nums[i];}if (nums[i] candidate) {count;} else {count--;}}count 0;…

操作系统复习

虚拟内存 内存(memory)资源永远都是稀缺的&#xff0c;当越来越多的进程需要越来越来内存时&#xff0c;某些进程会因为得不到内存而无法运行&#xff1b; 虚拟内存是计算机系统内存管理的一种技术。它使得应用程序认为它拥有连续的可用的内存&#xff0c;而实际上&#xff0…

网络与并发编程(一)

并发编程介绍_串行_并行_并发的区别 串行、并行与并发的区别 串行(serial)&#xff1a;一个CPU上&#xff0c;按顺序完成多个任务并行(parallelism)&#xff1a;指的是任务数小于等于cpu核数&#xff0c;即任务真的是一起执行的并发(concurrency)&#xff1a;一个CPU采用时间…