简单搭建卷积神经网络实现手写数字10分类

搭建卷积神经网络实现手写数字10分类

1.思路流程

1.导入minest数据集

2.对数据进行预处理

3.构建卷积神经网络模型

4.训练模型,评估模型

5.用模型进行训练预测

一.导入minest数据集

 

MNIST--->raw--->test-->(0,1,2...) 10个文件夹

MNIST--->raw--->train-->(0,1,2...) 10个文件夹

共60000张图片.可自己去网上下载

二.对数据进行预处理

----读取图片,将图片先转为张量。

img=cv2.imread(path)

----将图片进行归一化,即将像素值标准化到0-1之间

img_tensor=util.transforms_train(img)

----裁剪,翻转等,实现数据增强。

数据增强:通过对原始图像进行旋转、翻转等操作,可以增加数据的多样性。这有助于模型学习到更具泛化性的特征,减少对特定方向或位置的依赖,从而提高模型的鲁棒性和准确性

transforms_train=transforms.Compose([
    # transforms.CenterCrop(10),
    # transforms.PILToTensor(),
    transforms.ToTensor(),#归一化,转tensor
    transforms.Resize((28,28)),
    transforms.RandomVerticalFlip()
])

ps:为什么要归一化

  1. 消除量纲影响:不同图像的像素值范围可能差异很大。归一化可以将像素值范围统一到一个特定的区间,例如 [0, 1] 或 [-1, 1],消除不同图像之间因像素值范围差异带来的影响,使模型更关注图像的特征和结构,而不是像素值的绝对大小。

  2. 提高训练稳定性:有助于优化算法的收敛性和稳定性。如果像素值范围较大且分布不均匀,可能导致梯度计算不稳定,从而影响模型的训练效率和效果。

  3. 缓解过拟合:一定程度上可以减少数据中的噪声和异常值对模型的影响,降低模型对某些特定像素值的过度依赖,从而提高模型的泛化能力,减少过拟合的风险。

三.构建卷积神经网络模型

常见卷积神经网络(CNN),主要由卷积,池化,全连接组成。卷积核在输入图像上滑动,通过卷积运算提取局部特征。卷积核在整个图像上重复使用,大大减少了模型的参数数量,降低了计算复杂度,同时也增强了模型对平移不变性的鲁棒性。池化层对特征进行压缩,提取主要特征,减少噪声和冗余信息。

x=torch.randn(2,3,28,28)

用x表示初始图形的信息。为了简单理解,简单表述。其中

2--->两张图片

3--->图片的通道数是3个,即 RGB

28,28--->图片的宽高是28px 28px

采用以上的神经网络conv为卷积操作,maxpool为池化。Linear为全连接。relu为激活函数。

进入全连接层时需要将展平。torch.Size([2, 16, 5, 5])--->torch.Size([2, 400])

x=torch.flatten(x,1)

因为全连接是只进行的线性的变化。所以要把每张图片的维数参数降为1。

使用print(summary(net, x))可查看网络的层次结构。其中-1就表示自己算,是多少张图片就是多少

输入的的是x=torch.randn(2,3,28,28),最终输出的是(2,10)

四.训练模型,评估模型

需要初始化之前的数据和网络,然后选择合适的优化器和损失函数,学习率和加载图片的批次去训练模型。使用loss_avg和accurary来评估模型的性能。对于pytorch来说优化器可以实现自动梯度清0,自动更新参数。我们需要主要的是就实现其中的维度的转化。loss越小越接近真实值。其中计算精度的方法使用one-hot编码。其中0表示[0,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],2表示[0,0,1,0,0,0,0,0,0,0].。。。其他依次类推。我们把用网络得出的参数,类似[0.1,0.2,0.1,0.5,0,0,0,0,0,0](数据我随便写的),然后用Python的argmax去处最大值的索引与one-hot真实值的索引相比,如果相等就是正确的结果。

----本次实验使用的是MSE损失函数

----lr(学习率)设为0.01

----使用的优化器Adam ,其实其他优化器你也可以随便试试。

Adam 算法的主要优点包括:

  1. 自适应学习率:能够为每个参数自适应地调整学习率。

  2. 偏差校正:在初始阶段对梯度估计进行校正,加速初期的学习速率。

  3. 适应性强:在很多不同的模型和数据集上都表现出良好的性能。

  4. 实现简单,计算高效,对内存需求少。

使用tensorboard进行可视化

五.用模型进行训练预测

需要读取之前训练好的模型,然后用这个模型来实现预测一个自己手写的图片

    # 加载整个模型
    loaded_model = torch.load('whole_model.pth')
​
    # 保存模型参数
    torch.save(loaded_model.state_dict(),'model_params.pth')

代码附上:

dataset.py

import glob
import os.path
​
import cv2
import torch
import util
​
class DataAndLabel:
    def __init__(self,path='D:\\0MNIST\\raw',is_train=True):
        super().__init__()
        # 拼接路径
        #data里面是path,label
        clas='train' if is_train==True else 'test'
        path=os.path.join(path,clas)
        paths=glob.glob(os.path.join(path,'*','*'))
        # print(paths)
        # print(path)
        self.data=[]
        for path in paths:
            label=int(path.split('\\')[-2])
            self.data.append((path,label))
    def __getitem__(self, idx):
        #返回一个tensor,one-hot
        path,label =self.data[idx]
        img=cv2.imread(path)
        # cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
        img_tensor=util.transforms_train(img)
        one_hot=torch.zeros(10)
        one_hot[label]=1
        return img_tensor,one_hot
    def __len__(self):
        return len(self.data)
# if __name__ == '__main__':
#     data=DataAndLabel()
#     print(data[0])
#     print()

lenet5.py

import torch
import torch.nn as nn
from torchkeras import summary
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(3,6,5,1)
        self.maxpool1=nn.MaxPool2d(2)
        self.conv2=nn.Conv2d(6,16,3,1)
        self.maxpool2=nn.MaxPool2d(2)
        self.layer1=nn.Linear(16*5*5,10)
        self.layer2=nn.Linear(10,10)
        self.relu=nn.Softmax()
    def forward(self,x):
        x=self.conv1(x)
        x=self.relu(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.relu(x)
        x=self.maxpool2(x)
        # print(x.shape)
        x=torch.flatten(x,1)
        # print(x.shape)
​
        x=self.layer1(x)
        x=self.layer2(x)
        return x
if __name__ == '__main__':
    x=torch.randn(2,3,28,28)
    net=Net()
    out=net(x)
    # print(out.shape)
    # print(summary(net, x))

train_and_test

import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from lenet5 import Net
import torch.nn as nn
from dataset import DataAndLabel
class TrainAndTest(Dataset):
    def __init__(self):
        super().__init__()
        # self.writer=SummaryWriter("logs")
        net=Net()
        self.net=net
        self.loss=nn.MSELoss()
        self.opt = torch.optim.Adam(net.parameters(), lr=0.1)
        self.train_data=DataAndLabel(is_train=True)
        self.test_data=DataAndLabel(is_train=False)
        self.train_loader=DataLoader(self.train_data,batch_size=100,shuffle=False)
        self.test_loader=DataLoader(self.test_data,batch_size=100,shuffle=False)
    # 拿到数据,网络
    def train(self,epoch):
        loss_sum = 0
        accurary_sum = 0
        for img_tensor, label in tqdm.tqdm(self.train_loader, desc='train...', total=len(self.train_loader)):
            out = self.net(img_tensor)
            loss = self.loss(out, label)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            loss_sum += loss.item()
            accurary_sum += torch.mean(
                torch.eq(torch.argmax(label, dim=1), torch.argmax(out, dim=1)).to(torch.float32)).item()
        loss_avg = loss_sum / len(self.train_loader)
        accurary_avg = accurary_sum / len(self.train_loader)
        print(f'train---->loss_avg={round(loss_avg, 3)},accurary_avg={round(accurary_avg, 3)}')
        # self.writer.add_scalars('loss',{'loss_avg':loss_avg},epoch)
    def train1(self):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm.tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy =torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.train_loader)
        avg_acc = sum_acc / len(self.train_loader)
        print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
​
​
    def run(self):
        for epoch in range(10):
            self.train1()
            # self.test(epoch)
if __name__ == '__main__':
    tt=TrainAndTest()
    tt.run()

util.py

from torchvision import transforms
​
transforms_train=transforms.Compose([
    # transforms.CenterCrop(10),
    # transforms.PILToTensor(),
    transforms.ToTensor(),#归一化,转tensor
    transforms.Resize((28,28)),
    transforms.RandomVerticalFlip()
])
transforms_test=transforms.Compose([
    transforms.ToTensor(),  # 归一化,转tensor
    transforms.Resize((28, 28)),
])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/798720.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】基于环形队列RingQueue的生产消费者模型

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 环形队列的概念及定义 POSIX信号量 RingQueue的实现方式 RingQueue.hpp的构建 Thread.hpp Main.cc主函数的编写 Task.hpp function包装器的使用 总结 前言…

《Python数据科学之一:初见数据科学与环境》

《Python数据科学之一:初见数据科学与环境》 欢迎来到“Python数据科学”系列的第一篇文章。在这个系列中,我们将通过Python的镜头,深入探索数据科学的丰富世界。首先,让我们设置和理解数据科学的基本概念以及在开始任何数据科学项…

《C专家编程》 C++

抽象 就是观察一群数据,忽略不重要的区别,只记录关注的事务特征的关键数据项。比如有一群学生,关键数据项就是学号,身份证号,姓名等。 class student {int stu_num;int id_num;char name[10]; } 访问控制 this关键字…

安全防御:防火墙概述

目录 一、信息安全 1.1 恶意程序一般会具备一下多个或全部特点 1.2 信息安全五要素: 二、了解防火墙 2.1 防火墙的核心任务 2.2 防火墙的分类 2.3 防火墙的发展历程 2.3.1 包过滤防火墙 2.3.2 应用代理防火墙 2.3.3 状态检测防火墙 补充防御设备 三、防…

Torch-Pruning 库入门级使用介绍

项目地址:https://github.com/VainF/Torch-Pruning Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合…

uniapp实现水印相机

uniapp实现水印相机-livePusher 水印相机 背景 前两天拿到了一个需求,要求在内部的oaApp中增加一个卫生检查模块,这个模块中的核心诉求就是要求拍照的照片添加水印。对于这个需求,我首先想到的是直接去插件市场,下一个水印相机…

《Python数据科学之五:模型评估与调优深入解析》

《Python数据科学之五:模型评估与调优深入解析》 在数据科学项目中,精确的模型评估和细致的调优过程是确保模型质量、提高预测准确性的关键步骤。本文将详细探讨如何利用 Python 及其强大的库进行模型评估和调优,确保您的模型能够达到最佳性能…

docker中1个nginx容器搭配多个django项目中设置uwsgi.ini的django项目路径

docker中,1个nginx容器搭配多个django项目容器,设置各个uwsgi.ini的django项目路径 被这个卡了一下,真是,哎 各个uwsgi配置应该怎样设置项目路径 django项目1中创建的django项目名为 web 那么uwsgi.ini中要设置为 chdir …

【Vue3 ts】echars图表展示统计的月份数据

图片展示 此处内容为展示24年各个月份产品的创建数量。在后端统计24年各个月份产品数量后,以数组的格式发送给前端,前端负责展示。 后端 entity层: Data Schema(description "月份统计")public class MonthCount {private Stri…

得物六宫格验证码分析

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言(lianxi a…

算法的时间复杂度和空间复杂度-例题

一、消失的数字 . - 力扣&#xff08;LeetCode&#xff09; 本题要求的时间复杂度是O(n) &#xff0c;所以我们不能用循环嵌套&#xff1b; 解法一&#xff1a; int missingNumber(int* nums, int numsSize){int sum10;for(int i0;i<numsSize;i){sum1i;}int sum20;for(i…

C到C嘎嘎的衔接篇

本篇文章&#xff0c;是帮助大家从C向C嘎嘎的过渡&#xff0c;那么我们直接开始吧 不知道大家是否有这样一个问题&#xff0c;学完C的时候感觉还能听懂&#xff0c;但是听C嘎嘎感觉就有点难度或者说很难听懂&#xff0c;那么本篇文章就是帮助大家从C过渡到C嘎嘎。 C嘎嘎与C的区…

MPC轨迹跟踪控制器推导及Simulink验证

文章目录 MPC轨迹跟踪控制器推导及Simulink验证MPC的特点MPC轨迹跟踪控制器推导一 系统离散化二 预测区间状态和变量推导三 代价函数推导四 优化求解 <center> 基于MPC的倒立摆控制系统相关资料Reference&#xff1a; MPC轨迹跟踪控制器推导及Simulink验证 MPC的特点 多…

SAP 消息输出 - Adobe Form

目录 1 安装链接 2 前台配置 - Fiori app 2.1 维护表单模板 (maintain form templates) 2.2 管理微标 (manage logos) 2.3 管理文本 (manage texts) 3 后台配置 3.1 定义表单输出规则 3.2 分配表单模板 SAP 消息输出&#xff0c;不仅是企业内部用来记录关键业务操作也是…

Win11任务栏当中对 STM32CubeMX 的堆叠问题

当打开多个 CubeMX 程序的时候&#xff0c;Win11 自动将其进行了堆叠&#xff0c;这时候就无法进行预览与打开。 问题分析&#xff1a;大部分ST的工具都是基于 JDK 来进行开发的&#xff0c;Win11 将其识别成了同一个 Binary 但是实际上他们并不是同一个&#xff0c;通过配置…

基于conda包的环境创建、激活、管理与删除

Anaconda是一个免费、易于安装的包管理器、环境管理器和 Python 发行版&#xff0c;支持平台包括Windows、macOS 和 Linux。下载安装地址&#xff1a;Download Anaconda Distribution | Anaconda 很多不同的项目可能需要使用不同的环境。例如某个项目需要使用pytorch1.6&#x…

C语言详解(结构体)

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸各位能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎~~ &#x1f4a5;个人主页&#xff1a;小羊在奋斗 &#x1f4a5;所属专栏&#xff1a;C语言 本系列文章为个人学习笔记&#xff0c;在这里撰写成文一…

《后端程序猿 · EasyPOI 导入导出》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

Android OkHttp3中HttpLoggingInterceptor使用

目录 一 概述1.1 日志级别 二 使用2.1 引入依赖2.2 创建对象2.3 添加拦截器 三 结果展示3.1 日志级别为BODY3.2 日志级别为BASIC3.3 日志级别为HEADERS 参考 一 概述 HttpLoggingInterceptor是OkHttp3提供的拦截器&#xff0c;用来记录HTTP请求和响应的详细信息。 1.1 日志级…