CNN——LeNet

1.LeNet概述       

         LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。

        当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。

        下面是整个网络的结构图

        LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。

        上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。

2.详解LeNet

下面对图中每一层做详细的介绍:

  • LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
  • 激活函数为Sigmoid
  • 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充

        input输入层,尺寸为1*32*32的二值图

        C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。

        S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。

        C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。

        S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图

        C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。

        F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。

        输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现

3.使用LeNet实现Mnist数据集分类

1.导入所需库

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条

2.使用GPU

device = 'cuda' if torch.cuda.is_available() else 'cpu'

3.读取Mnist数据集

# 定义数据转换以进行数据标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
])

# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)

# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

4.搭建LeNet

        需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可,在训练时使用torch.nn.CrossEntropyLoss

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) # Mnist尺寸为28*28,这里设置填充变成32*32
        self.sigmoid = nn.Sigmoid()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(self.sigmoid(self.conv1(x)))
        x = self.pool(self.sigmoid(self.conv2(x)))
        x = self.flatten(x)
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))

5.训练函数

def train(model, lr, epochs):
    # 将模型放入GPU
    model = model.to(device)
    # 使用交叉熵损失函数
    loss_fn = nn.CrossEntropyLoss().to(device)
    # SGD
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # 记录训练与验证数据
    train_losses = []
    train_accuracies = []
    # 开始迭代   
    for epoch in range(epochs):   
        # 切换训练模式
        model.train()  
        # 记录变量
        train_loss = 0.0
        correct_train = 0
        total_train = 0
        # 读取训练数据并使用 tqdm 显示进度条
        for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):
            # 训练数据移入GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            # 模型预测
            outputs = model(inputs)
            # 计算损失
            loss = loss_fn(outputs, targets)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 使用优化器优化参数
            optimizer.step()
            # 记录损失
            train_loss += loss.item()
            # 计算训练正确个数
            _, predicted = torch.max(outputs, 1)
            total_train += targets.size(0)
            correct_train += (predicted == targets).sum().item()
        # 计算训练正确率并记录
        train_loss /= len(train_dataloader)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # 输出训练信息
        print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
    # 绘制损失和正确率曲线
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(epochs), train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(range(epochs), train_accuracies, label='Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.show()

6.模型训练

model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)

 

7.模型测试 

def test(model, test_dataloader, device, model_path):
    # 将模型设置为评估模式
    model.eval()
    # 将模型移动到指定设备上
    model.to(device)

    # 从给定路径加载模型的状态字典
    model.load_state_dict(torch.load(model_path))

    correct_test = 0
    total_test = 0
    # 不计算梯度
    with torch.no_grad():
        # 遍历测试数据加载器
        for inputs, targets in test_dataloader:  
            # 将输入数据和标签移动到指定设备上
            inputs = inputs.to(device)
            targets = targets.to(device)
            # 模型进行推理
            outputs = model(inputs)
            # 获取预测结果中的最大值
            _, predicted = torch.max(outputs, 1)
            total_test += targets.size(0)
            # 统计预测正确的数量
            correct_test += (predicted == targets).sum().item()
    
    # 计算并打印测试数据的准确率
    test_accuracy = correct_test / total_test
    print(f"Accuracy on Test: {test_accuracy:.4f}")
    return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/287882.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【前沿技术】超级稳定的视频卡通画方案

Git clone项目到本地 git clone gitgithub.com:Artiprocher/DiffSynth-Studio.git 基本原理 使用了stable diffusion稳定扩散模型和controlnet来控制图像生成的轮廓,animatediff控制视频帧与帧之间的连续性,最后使用RIFE技术平滑整个生成后的视频。 …

40道java集合面试题含答案(很全)

1. 什么是集合 集合就是一个放数据的容器,准确的说是放数据对象引用的容器集合类存放的都是对象的引用,而不是对象的本身集合类型主要有3种:set(集)、list(列表)和map(映射)。 2. 集合的特点 集合的特点主要有如下两…

使用Python做个可视化的“剪刀石头布”小游戏

目录 一、引言 二、环境准备与基础知识 三、游戏界面制作 四、游戏逻辑实现 五、代码示例 六、游戏测试与优化 七、扩展与改进 八、总结 一、引言 “剪刀石头布”是一种古老的手势游戏,它简单易懂,趣味性强,适合各个年龄段的人参与。…

虎克:开发小程序要多少钱一个,非专业开发如何做自己的小程序

小程序开发费用主要取决于小程序的功能复杂度和开发周期。一般来说,小程序开发费用可以分为两类:模板开发和定制开发。 模板开发:模板开发是指使用现成的模板进行开发,价格相对较低,一般在几千元左右。优点是价格便宜&…

你不知道的 CSS 之 包含块 ! 最细讲解,一听就懂!

你不知道的 CSS 之包含块 一说到 CSS 盒模型,这是很多小伙伴耳熟能详的知识,甚至有的小伙伴还能说出 border-box 和 content-box 这两种盒模型的区别。 但是一说到 CSS 包含块,有的小伙伴就懵圈了,什么是包含块?好像…

(切图笔记)layui表格单元格添加超链接 以及传参方法 亲测可用 附代码

layui在切图网日常的工作中常常用到,特别是它的layer弹窗,基本可以满足网站切图时候遇到的绝大多数弹窗的情况,参数比较丰富 灵活,是不可多得的网页插件之一,我见很多人说layui过时了,这是相比于vue正流行的…

具有不规则结果的常规 PyTorch 张量函数

一、说明 深度学习从业者应注意的常用 PyTorch 张量函数的例外情况。你是不是也和上面的人一样呢?如果是,那么本文可能会帮助您在使用 PyTorch 构建深度学习模型时发现一些常见错误。 我在下面提到了 5 个最常用的 PyTorch 函数及其小示例以及它们无法按…

阿里云服务器8080端口怎么打开?在安全组中设置

阿里云服务器8080端口开放在安全组中放行,Tomcat默认使用8080端口,8080端口也用于www代理服务,阿腾云atengyun.com以8080端口为例来详细说下阿里云服务器8080端口开启教程教程: 阿里云服务器8080端口开启教程 阿里云服务器8080端…

Codeforces Good Bye 2023 A~E

A.2023(思维) 题意: 有一个序列 A a 1 , a 2 , . . . , a n k A a_1, a_2, ..., a_{n k} Aa1​,a2​,...,ank​,且这个序列满足 ∏ i 1 n k a i 2023 \prod\limits_{i 1}^{n k}a_i 2023 i1∏nk​ai​2023,而这个序列中的 k k k个…

[Flutter]WindowsOS上运行遇到的问题总结

[Flutter]WindowsOS上运行遇到的问题总结 写在开头 Flutter项目已能在移动端完美使用后,想看看在桌面端等使用情况 基于Flutter3.0后已支持Windows/MacOS等桌面端,不过具体的系统,还需要看下官方文档解释。 这里抛出文档地址,可…

solidity显示以太坊美元价格

看过以太坊白皮书的都知道,以太坊比较比特币而言所提升的地方中,我认为最重要的一点就是能够访问外部的数据,这一点在赌博、金融领域应用会很广泛,但是区块链是一个确定的系统,包括里面的所有数值包括交易ID等都是确定…

教师行业的行业现状

teacher行业现状,近年来呈现出许多新的变化。作为一名从事教育行业多年的教师,深感这个行业的日新月异。今天,就让我来为大家揭秘一下,这个行业究竟有着怎样的现状吧! 需求持续增长随着不断发展,家长们对孩…

【计算机毕业设计】SSM实现的在线农产品商城

项目介绍 本项目分为前后台,且有普通用户与管理员两种角色。 用户角色包含以下功能: 用户登录,查看首页,按分类查看商品,查看新闻资讯,查看关于我们,查看商品详情,加入购物车,查看我的订单,提交订单,添加收获地址,支付订单等功能。 管理员角色包含以…

AntDB设计之CheckPoint——引言与功能简述

1.引言 数据库服务能力提升是一项系统性的工程,在不同的应用场景下,用户对于数据库各项能力的关注点也不同,如:读写延迟、吞吐量、扩展性、可靠性、可用性等等。国内不少数据库系统通过系统架构优化、硬件设备升级等方式&#xf…

工程项目管理软件哪个好用?这款顶级软件别错过!

“随着市场竞争加剧、产品利润走薄、用户响应要求提高、产品更新迭代加快等各项因素的变化,项目管理开始成为越来越多企业的管理方式。项目管理的核心目标是在规定时间和预算内,完成事先确定的范围内的工作,同时达到质量要求。” 你所在公司…

架构师使用的8种重要生命周期图

什么是生命周期? 百度给出的定义是:生命周期就是指一个对象的生老病死。 生命周期的概念应用很广泛,特别是在政治、经济、环境、技术、社会等诸多领域经常出现,其基本涵义可以通俗地理解为"从摇篮到坟墓"的整个过程。对于某个产品而言,它的生命周期其实是指产…

【hyperledger-fabric】部署和安装

简介 对hyperledger-fabric进行安装,话不多说,直接开干。但是需要申明一点,也就是本文章全程是开着加速器进行的资源操作,所以对于没有开加速器的情况可能会由于网络原因导致下载资源失败。 资料提供 1.官方部署文档在此&#…

mfc100u.dll文件丢失,有五种不同解决方法

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是“找不到mfc100u.dll文件”。那么,mfc100u.dll文件到底是什么?为什么会出现丢失的情况?本文将详细介绍mfc100u.dll文件的作用以及丢失的原因,并…

软件测试入门(知识汇总)

1、黑盒测试、白盒测试、灰盒测试 1.1 黑盒测试 黑盒测试又叫功能测试、数据驱动测试 或 基于需求规格说明书的功能测试。该类测试注重于测试软件的功能性需求。 采用这种测试方法,测试工程师把测试对象看作一个黑盒子,完全不考虑程序内部的逻辑结构和…

【多线程与高并发 四】CAS、Unsafe 及 JUC 原子类详解

👏作者简介:大家好,我是若明天不见,BAT的Java高级开发工程师,CSDN博客专家,后端领域优质创作者 📕系列专栏:多线程及高并发系列 📕其他专栏:微服务框架系列、…