神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn

# 定义神经网络Network
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # 线性层1,输入层和隐藏层之间的线性层
        self.layer1 = nn.Linear(784, 258)
        # 线性层2,隐藏层和输出层之间的线性层
        self.layer2 = nn.Linear(256, 10)
    # 在前向传播,forward函数中,输入为图像x
    def forward(self, x):
        x = x.view(-1, 28 * 28) # 使用view函数,将x展平
        x = self.layer1(x) # 将x输入到layer1
        x = torch.relu(x) # 使用relu激活
        return self.layer2(x) # 输入至layer2计算结果

    # 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数
    # 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'
    # 实现图像的预处理pipeline
    transform = trnasforms.Compose([
        # 转换成单通道灰度图
        transforms.Grayscale(num_output_channels=1),
        # 转换为张量
        transforms.ToTensor()
    ])

    # 使用ImageFolder函数,读取数据文件夹,构建数据集dataset
    # 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据
    train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)
    test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)

    # 打印他们的长度
    print("train_dataset length: ", len(train_dataset))
    print("test_dataset length: ", len(test_dataset))

    # 使用train_loader, 实现小批量的数据读取
    # 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据
    train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)
    # 打印train_loader的长度
    print("train_loader length: ", len(train_loader))
    # 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组
    # 938*64=60032,说明最后一组不够64个数据

    # 循环遍历train_loader
    # 每一次循环,都会取出64个图像数据,作为一个小批量batch
    for batch_idx, (data, label) in enumerate(train_loader)
        if batch_idx == 3:
            break
        print("batch_idx: ", batch_idx)
        print("data.shape: ", data.shape) # 数据的尺寸
        print("label: ", label.shape) # 图像中的数字
        print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

if __name__ == '__main__'
    # 图像的预处理
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])

    # 读入并构造数据集
    train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)
    print("train_dataset length: ", len(train_dataset))

    # 小批量的数据读入
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    print("train_loader length: ", len(train_loader))

    # 在使用Pytorch训练模型时,需要创建三个对象:
    model = Network() # 1.模型本身,就是我们设计的神经网络
    optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数
    criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差

    # 进入模型的循环迭代
    # 外层循环,代表了整个训练数据集的遍历次数
    for epoch in range(10):
        # 内层循环使用train_loader, 进行小批量的数据读取
        for batch_idx, (data, label) in enumerate(train_loader):
            # 内层每循环一次,就会进行一次梯度下降算法
            # 包括了5个步骤
            # 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可
            # 1. 计算神经网络的前向传播结果
            output = model(data)
            # 2. 计算output和标签label之间的损失loss
            loss = criterion(output, label)
            # 3. 使用backward计算梯度
            loss.backward()
            # 4. 使用optimizer.step更新参数
            optimizer.step()
            # 5.将梯度清零
            optimizer.zero_grad()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch + 1}/10"
                      f"| Batch {batch_idx}/{len(train_loader)}"
                      f"| Loss: {loss.item():.4f}"
                      )
    torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torch

if __name__ == '__main__'
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])
    # 读取测试数据集
    test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)
    print("test_dataset length: ", len(test_dataset))

    model = Network() # 定义神经网络模型
    model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件

    rigth = 0 # 保存正确识别的数量
    for i, (x, y) in enumerate(test_dataset):
        output = model(x) # 将其中的数据x输入到模型
        predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果
        # 对比预测值predict和真实标签y
        if predict == y:
            right += 1
        else:
            # 将识别错误的样例打印出来
            img_path = test_dataset.samples[i][0]
            print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")
    # 计算出测试效果
    sample_num = len(test_dataset)
    acc = right * 1.0 / sample_num
    print("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/799036.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

采用自动微分进行模型的训练

自动微分训练模型 简单代码实现: import torch import torch.nn as nn import torch.optim as optim# 定义一个简单的线性回归模型 class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear nn.Linear(1, 1) …

链接追踪系列-07.logstash安装json_lines插件

进入docker中的logstash 容器内: jelexbogon ~ % docker exec -it 7ee8960c99a31e607f346b2802419b8b819cc860863bc283cb7483bc03ba1420 /bin/sh $ pwd /usr/share/logstash $ ls bin CONTRIBUTORS Gemfile jdk logstash-core modules tools x-pack …

如何预防最新的baxia变种勒索病毒感染您的计算机?

引言 在当今数字化时代,网络安全威胁层出不穷,其中勒索病毒已成为企业和个人面临的重大挑战之一。近期,.baxia勒索病毒以其高隐蔽性和破坏性引起了广泛关注。本文将详细介绍.baxia勒索病毒的特点、传播方式,并给出相应的应对策略…

2024-07-15 Unity插件 Odin Inspector3 —— Button Attributes

文章目录 1 说明2 Button 特性2.1 Button2.2 ButtonGroup2.3 EnumPaging2.4 EnumToggleButtons2.5 InlineButton2.6 ResponsiveButtonGroup 1 说明 ​ 本章介绍 Odin Inspector 插件中有关 Button 特性的使用方法。 2 Button 特性 2.1 Button 依据方法,在 Inspec…

YOLOv8训练自己的数据集(超详细)

一、准备深度学习环境 本人的笔记本电脑系统是:Windows10 YOLO系列最新版本的YOLOv8已经发布了,详细介绍可以参考我前面写的博客,目前ultralytics已经发布了部分代码以及说明,可以在github上下载YOLOv8代码,代码文件夹…

力扣经典题目之->移除值为val元素的讲解,的实现与讲解

一:题目 博主本文将用指向来形象的表示下标位的移动。 二:思路 1:两个整形,一个start,一个end,在一开始都 0,即这里都指向第一个元素。 2:在查到val之前,查一个&…

昇思25天学习打卡营第12天|MindSpore 助力下的 GPT2:数据集加载处理及模型全攻略

环境配置 %%capture captured_output 此乃 Jupyter Notebook 中的一个魔法命令,其作用在于捕获后续单元格中的输出,并将之存储于变量 captured_output 当中,而非直接于输出区域予以显示。如此一来,便可隐匿某些可能存在的输出信息…

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…

小程序里面使用vant ui中的vant-field组件,如何使得输入框自动获取焦点

//.wxml <van-fieldmodel:value"{{ userName }}"placeholder"请输入学号"focus"{{focusUserName}}"/>// .js this.setData({focusUserName: true});vant-field

钡铼ARMxy控制器在智能网关中的应用

随着IoT物联网技术的飞速发展&#xff0c;智能网关作为连接感知层与网络层的枢纽&#xff0c;可以实现感知网络和通信网络以及不同类型感知网络之间的协议转换。钡铼技术的ARMxy系列控制器凭借其高性能、低功耗和高度灵活性的特点&#xff0c;在智能网关中发挥了关键作用&#…

RPC与服务的注册发现

文章目录 1. 什么是远程过程调用(RPC)?2. RPC的流程3. RPC实践4. RPC与REST的区别4.1 RPC与REST的相似之处4.2 RPC与REST的架构原则4.3 RPC与REST的主要区别 5. RPC与服务发现5.1 以zookeeper为服务注册中心5.2 以etcd为服务注册中心 6. 小结参考 1. 什么是远程过程调用(RPC)?…

大语言模型诞生过程剖析

过程图如下 &#x1f4da; 第一步&#xff1a;海量文本的无监督学习 得到基座大模型&#x1f389; &#x1f50d; 原料&#xff1a;首先&#xff0c;我们需要海量的文本数据&#xff0c;这些数据可以来自互联网上的各种语料库&#xff0c;包括书籍、新闻、科学论文、社交媒体帖…

<数据集>光伏板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;2400张 标注数量(xml文件个数)&#xff1a;2400 标注数量(txt文件个数)&#xff1a;2400 标注类别数&#xff1a;4 标注类别名称&#xff1a;[Crack,Grid,Spot] 序号类别名称图片数框数1Crack8688922Grid8248843S…

全栈智能家居系统设计方案:STM32+Linux+多协议(MQTT、Zigbee、Z-Wave)通信+云平台集成

1. 项目概述 随着物联网技术的快速发展,智能家居系统正在成为现代生活中不可或缺的一部分。本文介绍了一个基于STM32微控制器和Linux系统的智能家居解决方案,涵盖了硬件设计、软件架构、通信协议以及云平台集成等方面。 该系统具有以下特点: 采用STM32作为终端设备的控制核心…

springboot3——项目部署

springboot的项目开发完了&#xff0c;怎么样把他放到服务器上或者生产环境上让他运行起来跑起来。就要牵扯到项目部署&#xff0c;打包的方式了。 springboot支持jar和war: 打jar包&#xff1a;默认方式&#xff0c;项目开发完打个jar包&#xff0c;通过命令把jar包起起来就…

汇川ST 实现分拣

//初始化 IF init FALSE THEN// 初始化init : 1 ;//45 Y数组 BOOL[8] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] 不保持 私有 Y0(*Y数组[0] BOOL OFF 发料…

MySQL 中的几种锁

MySQL 中的锁 #按锁粒度如何划分? 按锁粒度划分的话&#xff0c;MySQL 的锁有&#xff1a; 表锁&#xff1a;开销小&#xff0c;加锁快&#xff1b;锁定力度大&#xff0c;发生锁冲突概率高&#xff0c;并发度最低;不会出现死锁。行锁&#xff1a;开销大&#xff0c;加锁慢…

unity宏编译版本

在写c程序的时候我们通常可以用不同的宏定义来控制不同版本的编译内容&#xff0c;最近有个需求就是根据需要编译一个完全体验版本&#xff0c;就想到了用vs的那套方法。经过研究发现unity也有类似的控制方法。 注意这里设置完后要点击右下的应用&#xff0c;我起先就没有设置…

7/13 - 7/15

vo.setId(rs.getLong("id"))什么意思&#xff1f; vo.setId(rs.getLong("id")); 这行代码是在Java中使用ResultSet对象&#xff08;通常用于从数据库中检索数据&#xff09;获取一个名为"id"的列&#xff0c;并将其作为long类型设置为一个对象…

深度学习基础:Numpy 数组包

数组基础 在使用导入 Numpy 时&#xff0c;通常给其一个别名 “np”&#xff0c;即 import numpy as np 。 数据类型 整数类型数组与浮点类型数组 为了克服列表的缺点&#xff0c;一个 Numpy 数组只容纳一种数据类型&#xff0c;以节约内存。为方便起见&#xff0c;可将 Nu…