人工智能应用-实验5-BP 神经网络分类手写数据集

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

编写 BP 神经网络分类, 实现对 MNIST 数据集分类的操作。


🧡🧡代码🧡🧡

需要配置torch。由于是小demo。为了提高效率,我采用的是google的colab进行实验编码,省去配环境的烦恼。

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim

#@title 加载
transform = transforms.Compose([
                transforms.ToTensor(), # 转为张量,同时如果是图片(uint8)类型,会自动进行归一化到(0,1)
                transforms.Normalize( (0.5, ) , (0.5, ) ) # 转为std=0.5、mean=0.5的分布, 灰色图像,通道只有一个  将值域(0,1)再次转为(-1,1)
                ])
train_set = datasets.MNIST('train_set', # 下载到该文件夹下
              download=not os.path.exists('train_set'), # 是否下载,如果下载过,则不重复下载
              train=True, # 是否为训练集
              transform=transform # 要对图片做的transform
              )
test_set = datasets.MNIST('test_set',
              download=not os.path.exists('test_set'),
              train=False,
              transform=transform
              )
test_set
# train_set[0][0]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

dataiter = iter(train_loader)
images, labels = next(iter(dataiter))
print(images.shape)
print(labels.shape)


#@title Bp net
class BP_Net(nn.Module):
    def __init__(self):
        super().__init__()
        """
        定义第一个线性层,
        输入为图片(28x28),
        输出为第一个隐层的输入,大小为128。
        """
        self.linear1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU() # 在第一个隐层使用ReLU激活函数
        """
        定义第二个线性层,
        输入是第一个隐层的输出,
        输出为第二个隐层的输入,大小为64。
        """
        self.linear2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU() # 在第二个隐层使用ReLU激活函数
        """
        定义第三个线性层,
        输入是第二个隐层的输出,
        输出为输出层,大小为10
        """
        self.linear3 = nn.Linear(64, 10)
        self.softmax = nn.LogSoftmax(dim=1) # 最终的输出经过softmax进行归一化

    def forward(self, x):
        """
        定义神经网络的前向传播
        x: 输入的图片数据, shape为(64, 1, 28, 28)
        """
        x = x.view(x.shape[0], -1) # 首先将x的shape转为(64, 784)

        # 进行前向传播
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        x = self.softmax(x)

        return x
model = BP_Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

#@title 评估
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
model.eval() # 将模型设置为评估模式

correct_count, all_count = 0, 0
predictions = [] # 预测结果列表
true_labels = [] # 真实标签列表

for images,labels in test_loader: # 从test_loader中一批一批加载图片
    for i in range(len(labels)):
        logps = model(images[i])  # 进行前向传播,获取预测值
        probab = list(logps.detach().numpy()[0]) # 将预测结果转为概率列表。[0]是取第一张照片的10个数字的概率列表(因为一次只预测一张照片)
        pred_label = probab.index(max(probab)) # 取最大的index作为预测结果
        true_label = labels.numpy()[i]
        if(true_label == pred_label): # 判断是否预测正确
            correct_count += 1
        all_count += 1
        predictions.append(pred_label)
        true_labels.append(true_label)

# 准确率
print("Number Of Images Tested =", all_count)
print("Model Accuracy =", (correct_count/all_count))

# 混淆矩阵
def plot_confusion_matrix(cm, classes):
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    thresh = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

cm = confusion_matrix(true_labels, predictions)
classes = [str(i) for i in range(10)]
plot_confusion_matrix(cm, classes)

#@title 验证
model.train() # 切回训练模式

## 验证本地图片
import cv2
from PIL import Image
for num in range(0,10):
    img = cv2.imread('./myImg/{}.jpg'.format(num), 0)  # 以灰度图的方式读取要预测的图片
    img = cv2.resize(img, (28, 28))
    height, width = img.shape
    dst = np.zeros((height, width), np.uint8)
    for i in range(height):
        for j in range(width):
            dst[i, j] = 255 - img[i, j]
    dst= dst / 255.0 #归一化
    dst = (dst - 0.5) / 0.5  # 标准化到[-1, 1]
    img = dst
    # print(img)
    img = np.array(img).astype(np.float32)
    img = np.expand_dims(img, 0)  # 扩展后,为[1,28,28]
    img = np.expand_dims(img, 0)  # 扩展后,为[1,1,28,28]
    img = torch.from_numpy(img)
    # print(img.shape)
    with torch.no_grad():
        output=model(img)
    # print(output.data)
    print(output.data.max(1)[1])


🧡🧡分析结果🧡🧡

数据预处理

  • 加载数据集:
    加载torch自带的minst数据集
  • 转换数据:
    先转为tensor变量(相当于直接除255归一化到值域为(0,1))
    在这里插入图片描述
    然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)
    在这里插入图片描述

设置基本参数:
在这里插入图片描述

构建BP神经网络:
如下,输入为一张2828图片,拆解成2828=784个特征,最终经过三个线性层(784,128)、(128、64)、(64,10),输出为10个特征(对应10个类),归一化这10个特征,它们的大小即认为它属于哪张图片的概率值,取出概率最大的特征对应的类别作为最终预测类别。
在这里插入图片描述

模型训练:
在这里插入图片描述
在这里插入图片描述

模型评估:
准确率:达到97.69%
在这里插入图片描述
混淆矩阵
在这里插入图片描述

接下来,分析网络层数对分类准确率的影响。
被对照试验:隐藏层数目改为2,神经元数目分别为128、64
准确率为:97.69%
对照实验1:隐藏层数目改为3,神经元数目分别为256、128、64
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵如下:97.55%
在这里插入图片描述
对照实验2:隐藏层数目改为5,神经元数目分别为512、256、128、64、32
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵:97.85%
在这里插入图片描述
总结结果如下表:
在这里插入图片描述
分析可知:

  • 运行时间:从实验结果来看,在增加隐藏层数的情况下,运行时间明显增加。
  • 准确率:实验结果显示,在增加隐藏层数的情况下,准确率大体上有所提升,但是总体变化幅度并不大,可能是因为epochs或者随机梯度下降等参数已经设为较优值,使得准确率已经接近最优效果,从而导致增加网络层数的提优空间并不明显。
    综合来看,增加隐藏层数对于提高分类准确率有一定的帮助,但是也会明显增加运行时间。其次,需要注意的是,若增加隐藏层数并非一定能够带来准确率的提升,过多的隐藏层可能会导致过拟合等问题。

🧡🧡实验总结🧡🧡

在完成基础实验上,我自己画了几张数字图,以对模型进行验证
在这里插入图片描述
结果如下,可以看到,对数字1和数字5分类错误(分布预测成了5和8),其余均分类正确,大体上效果良好。考虑原因,可能是因为minst的数据集是“黑底白字”,而我手画的图片则为“黑字白底”,导致了一些误差。
在这里插入图片描述
理论理解:
通过本次实验,大体上掌握了BP神经网络的定义和结构,总的来说,BP神经网络可以理解为一个黑盒子,通过不断根据loss进行反向传播,最终目的就是得到线性参数w和b,从而根据Y=wx+b 对输入的新x进行预测分类。
代码实践:
一开始想用纯numpy进行BP网络的编写,但是在编写后向传播时,可能是线代和高数知识有些遗忘,求导数时琢磨了很久。后面还是选择直接使用pytorch进行编写,也容易调参,方便进行实验。对我而言,代码中比较纠结的是shape的转换和传入,因此最好多查看中间过程的shape,以便更好理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/662559.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

图形学初识--空间变换

文章目录 前言正文矩阵和向量相乘二维变换1、缩放2、旋转3、平移4、齐次坐标下总结 三维变换1、缩放2、平移3、旋转绕X轴旋转:绕Z轴旋转:绕Y轴旋转: 结尾:喜欢的小伙伴可以点点关注赞哦 前言 前面章节补充了一下基本的线性代数中…

统计计算五|MCMC( Markov Chain Monte Carlo)

系列文章目录 统计计算一|非线性方程的求解 统计计算二|EM算法(Expectation-Maximization Algorithm,期望最大化算法) 统计计算三|Cases for EM 统计计算四|蒙特卡罗方法(Monte Carlo Method) 文章目录 系列文章目录一…

彻底理解浏览器的进程与线程

彻底理解浏览器的进程与线程 什么是进程和线程,两者的区别及联系浏览器的进程和线程总结浏览器核心进程有哪些浏览器进程与线程相关问题 什么是进程和线程,两者的区别及联系 进程和线程是操作系统中用于管理程序执行的两个基本概念进程的定义及理解 定义…

今日分享站

同志们,字符函数和字符串函数已经全部学习完啦,笔记也已经上传完毕,大家可以去看啦。字符函数和字符串函数and模拟函数 加油!!!!!

应用上架后的关键!苹果商店(AppStore)运营策略与技巧指南

1、运营期:怎么能活得好? ▍封号和下架问题 14天 在收到苹果封号通知(我们将会在14天后封你的账号)如果觉得冤枉可以在14天内进行申诉。14天并不是一个严格准确的时间,有可能会在第15天或者在第20天,甚至…

开源基于Node编写的批量HTML转PDF

LTPP批量HTML转PDF工具 Github 地址 LTPP-GIT 地址 官方文档 功能 LTPP 批量 HTML 转 PDF 工具支持将当前目录下所有 HTML 文件转成 PDF 文件,并且在新目录中保存文件结构与原目录结构一致 说明 一共两个独立版本,html-pdf 目录下是基于 html-pdf 模…

【考研数学】数学一和数学二哪个更难?如何复习才能上90分?

很明显考研数学一更难! 不管是复习量还是题目难度 对比项考研数学一考研数学二适用专业理工科类及部分经济学类理工科类考试科目高等数学、线性代数、概率论与数理统计高等数学、线性代数试卷满分150分150分考试时间180分钟180分钟试卷内容结构高等数学约60%&…

在 iCloud.com 上导入、导出或打印联系人

想将iPhone上的电话本备份一份到本地电脑上,发现iTunes好像只是音乐播放了,不再支持像电话本等功能,也不想通过其他第三方软件,好在可以通过iCloud进行导入导出。下面只是对操作过程进行一个图片记录而已,文字说明可以…

CSS中的Flex布局

目录 一.什么是Flex布局 二.Flex布局使用 2.1Flex使用语法 2.2基本概念 三.容器的属性 3.1所有属性概述 3.2flex-direction 3.3flex-wrap 3.4flex-flow 3.5justify-content 3.6align-items 3.7align-content 四.项目(子元素)的属性 4.1所有属性概述 4.2order 4…

失落的方舟 命运方舟台服账号怎么注册 游戏账号最全图文注册教程

探索奇幻大陆阿克拉西亚的奥秘,加入《失落的方舟》(Lost Ark)这场史诗般的冒险。这是一款由Smilegate精心雕琢的MMORPG巨作,它融合了激烈动作战斗与深邃故事叙述,引领玩家步入一个因恶魔侵袭而四分五裂的世界。作为勇敢…

非量表题如何进行信效度分析

效度是指设计的题确实在测量某个东西,一般问卷中使用到。如果是量表类的数据,其一般是用因子分析这种方法去验证效度水平,其可通过因子分析探究各测量量表的内部结构情况,分析因子分析得到的内部结构与自己预期的内部结构进行对比…

Websocket助手

功能介绍 WS助手是WebSocket调试的开发工具,该客户端工具可以帮助开发人员快速连接到测试/生产环境,它可以帮助您监视和分析 Websocket 消息,并在开发过程中解决问题;可以模拟客户端实现与服务器的数据交互,并完成批量…

QT基础初学

目录 1.什么是QT 2.环境搭建 QT SDK的下载 QT的使用 QT构建项目 快捷指令 QT的简单编写 对象树 编码问题 组件 初识信号槽 窗口的释放 窗口坐标体系 1.什么是QT QT 是一个跨平台的 C 图形用户界面库,支持多个系统,用于开发具有图形界面的应…

乡村振兴与农业科技创新:加大农业科技研发投入,推动农业科技创新,促进农业现代化和美丽乡村建设

一、引言 在当代中国,乡村振兴已成为国家发展的重要战略之一。作为国民经济的基础,农业的发展直接关系到国家的稳定和人民的福祉。随着科技的不断进步,农业科技创新在推动农业现代化和美丽乡村建设中发挥着越来越重要的作用。本文旨在探讨如…

线下实体店相亲机构不靠谱!靠谱的相亲交友婚恋软件有哪些?单身找对象必看!

当下大龄剩男剩女矛盾越来越大,单身市场越来越火热,相亲市场需求也在逐渐变大,线下相亲实体店也越来越多。但是从个人经历来说,实体店相亲不靠谱,收费很高,拖又多,根本脱不了单。现在呢&#xf…

echarts-dataset,graphic,dataZoom, toolbox

dataset数据集配置数据 dataset数据集,也可以完成数据的映射,一般用于一段数据画多个图表 例子: options {tooltip: {},dataset: {source: [["product", "2015", "2016", "2017"],["test&q…

视频汇聚EasyCVR视频监控平台GA/T 1400协议特点及应用领域解析

GA/T 1400协议,也被称为视图库标准,全称为《公安视频图像信息应用系统》。这一标准在公安系统中具有举足轻重的地位,它详细规定了公安视频图像信息应用系统的设计原则、系统结构、视频图像信息对象、统一标识编码、系统功能、系统性能、接口协…

亚马逊VC账号产品热销,在美国哪些智能小家电产品最好卖?

亚马逊VC账号产品在美国市场的热销,反映了消费者对于特定智能小家电产品的强烈需求。智能小家电产品因其实用性、便捷性和科技感,近年来在美国市场备受追捧。 以下是一些在亚马逊VC账号上热销的智能小家电产品: 智能扫地机器人 这类产品在美…

重庆耶非凡科技选品师项目大揭秘:成功背后的故事与经验

在电商行业迅猛发展的今天,选品师这一职业愈发受到市场的关注。重庆耶非凡科技有限公司凭借其专业的选品团队和科学的选品方法,成为众多商家关注的焦点。那么,该公司的选品师项目是否真的有成功的案例呢?接下来,我们将从多个角度…

计算机算法中的数字表示法——原码、反码、补码

目录 1.前言2.研究数字表示法的意义3.数字表示法3.1 无符号整数3.2 有符号数值3.3 二进制补码(Twos Complement, 2C)3.4 二进制反码(也称作 1 的补码, Ones Complement, 1C)3.5 减 1 表示法(Diminished one System, D1)3.6 原码、反码、补码总结 1.前言 昨天有粉丝让我讲解下定…