CNN-day5-经典神经网络LeNets5

经典神经网络-LeNets5

1998年Yann LeCun等提出的第一个用于手写数字识别问题并产生实际商业(邮政行业)价值的卷积神经网络

参考:论文笔记:Gradient-Based Learning Applied to Document Recognition-CSDN博客

1 网络模型结构

整体结构解读:

输入图像:32×32×1

三个卷积层:

C1:输入图片32×32,6个5×5卷积核 ,输出特征图大小28×28(32-5+1)=28,一个bias参数;

可训练参数一共有:(5×5+1)×6=156

C3 :输入图片14×14,16个5×5卷积核,有6×3+6×4+3×4+1×6=60个通道,输出特征图大小10×10((14-5)/1+1),一个bias参数;

可训练参数一共有:6(3×5×5+1)+6×(4×5×5+1)+3×(4×5×5+1)+1×(6×5×5+1)=1516

C3的非密集的特征图连接:

C3的前6个特征图与S2层相连的3个特征图相连接,后面6个特征图与S2层相连的4个特征图相连 接,后面3个特征图与S2层部分不相连的4个特征图相连接,最后一个与S2层的所有特征图相连。 采用非密集连接的方式,打破对称性,同时减少计算量,共60组卷积核。主要是为了节省算力。

C5:输入图片5×5,16个5×5卷积核,包括120×16个5×5卷积核 ,输出特征图大小1×1(5-5+1),一个bias参数;

可训练参数一共有:120×(16×5×5+1)=48120

两个池化层S2和S4:

都是2×2的平均池化,并添加了非线性映射

S2(下采样层):输入28×28,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:14×14(28/2)

S4(下采样层):输入10×10,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:5×5(10/2)

两个全连接层:

第一个全连接层:输入120维向量,输出84个神经元,计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过sigmoid函数输出。84的原因是:字符编码是ASCII编码,用7×12大小的位图表示,-1白色1黑色,84可以用于对每一个像素点的值进行估计。

第二个全连接层(Output层-输出层):输出 10个神经元 ,共有10个节点,代表数字0-9。

所有激活函数采用Sigmoid

2 网络模型实现

2.1模型定义

import torch
import torch.nn as nn
​
​
class LeNet5s(nn.Module):
    def __init__(self):
        super(LeNet5s, self).__init__()  # 继承父类
        # 第一个卷积层
        self.C1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,  # 输入通道
                out_channels=6,  # 输出通道
                kernel_size=5,  # 卷积核大小
            ),
            nn.ReLU(),
        )
        # 池化:平均池化
        self.S2 = nn.AvgPool2d(kernel_size=2)
​
        # C3:3通道特征融合单元
        self.C3_unit_6x3 = nn.Conv2d(
            in_channels=3,
            out_channels=1,
            kernel_size=5,
        )
        # C3:4通道特征融合单元
        self.C3_unit_6x4 = nn.Conv2d(
            in_channels=4,
            out_channels=1,
            kernel_size=5,
        )
​
        # C3:4通道特征融合单元,剔除中间的1通道
        self.C3_unit_3x4_pop1 = nn.Conv2d(
            in_channels=4,
            out_channels=1,
            kernel_size=5,
        )
​
        # C3:6通道特征融合单元
        self.C3_unit_1x6 = nn.Conv2d(
            in_channels=6,
            out_channels=1,
            kernel_size=5,
        )
​
        # S4:池化
        self.S4 = nn.AvgPool2d(kernel_size=2)
        # 全连接层
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.ReLU()
        )
        self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84), nn.ReLU())
        self.fc3 = nn.Linear(in_features=84, out_features=10)
​
    def forward(self, x):
        # 训练数据批次大小batch_size
        num = x.shape[0]
​
        x = self.C1(x)
        x = self.S2(x)
        # 生成一个empty张量
        outchannel = torch.empty((num, 0, 10, 10))
        # 6个3通道的单元
        for i in range(6):
            # 定义一个元组:存储要提取的通道特征的下标
            channel_idx = tuple([j % 6 for j in range(i, i + 3)])
            x1 = self.C3_unit_6x3(x[:, channel_idx, :, :])
            outchannel = torch.cat([outchannel, x1], dim=1)
​
        # 6个4通道的单元
        for i in range(6):
            # 定义一个元组:存储要提取的通道特征的下标
            channel_idx = tuple([j % 6 for j in range(i, i + 4)])
            x1 = self.C3_unit_6x4(x[:, channel_idx, :, :])
            outchannel = torch.cat([outchannel, x1], dim=1)
​
        # 3个4通道的单元,先拿五个,干掉中那一个
        for i in range(3):
            # 定义一个元组:存储要提取的通道特征的下标
            channel_idx = tuple([j % 6 for j in range(i, i + 5)])
            # 删除第三个元素
            channel_idx = channel_idx[:2] + channel_idx[3:]
            print(channel_idx)
            x1 = self.C3_unit_3x4_pop1(x[:, channel_idx, :, :])
            outchannel = torch.cat([outchannel, x1], dim=1)
​
        x1 = self.C3_unit_1x6(x)
        # 平均池化
        outchannel = torch.cat([outchannel, x1], dim=1)
        outchannel = nn.ReLU()(outchannel)
​
        x = self.S4(outchannel)
        # 对数据进行变形
        x = x.view(x.size(0), -1)
        # 全连接层
        x = self.fc1(x)
        x = self.fc2(x)
        # TODO:SOFTMAX
        output = self.fc3(x)
​
        return output
​
​
def test001():
    net = LeNet5s()
    # 随机一个测试数据
    input = torch.randn(128, 1, 32, 32)
    output = net(input)
    print(output.shape)
    pass
​
​
if __name__ == "__main__":
    test001()

2.2全局变量

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
​
dir = os.path.dirname(__file__)
modelpath = os.path.join(dir, "weight/model.pth")
datapath = os.path.join(dir, "data")
​
# 数据预处理和加载
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # 调整输入图像大小为32x32
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
​

2.3模型训练

def train():
​
    trainset = torchvision.datasets.MNIST(
        root=datapath, train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
​
    # 实例化模型
    net = LeNet5()
​
    # 使用MSELoss作为损失函数
    criterion = nn.MSELoss()
​
    # 使用SGD优化器
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
​
    # 训练模型
    num_epochs = 10
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
​
            # 将labels转换为one-hot编码
            labels_one_hot = torch.zeros(labels.size(0), 10).scatter_(
                1, labels.view(-1, 1), 1.0
            )
            labels_one_hot = labels_one_hot.to(torch.float32)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels_one_hot)
            loss.backward()
            optimizer.step()
​
            running_loss += loss.item()
            if i % 100 == 99:
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
                running_loss = 0.0
    # 保存模型参数
    torch.save(net.state_dict(), modelpath)
    print("Finished Training")

2.4验证

def vaild():
​
    testset = torchvision.datasets.MNIST(
        root=datapath, train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    # 实例化模型
    net = LeNet5()
    net.load_state_dict(torch.load(modelpath))
    # 在测试集上测试模型
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
​
    print(f"验证集: {100 * correct / total:.2f}%")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/967190.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

导航守卫router.beforeEach

router.beforeEach 是一个全局前置守卫,在每次路由跳转之前都会触发。 //index.jsrouter.beforeEach((to, from, next) > {// 打印即将要进入的目标路由信息console.log(即将要进入的目标路由信息:, to)// 打印当前正要离开的路由信息console.log(当前正要离开的…

[ESP32:Vscode+PlatformIO]添加第三方库 开源库 与Arduino导入第三方库的区别

前言 PlatformIO与Arduino在添加第三方库方面的原理存在显著差异 在PlatformIO中,第三方库的使用是基于项目(工程)的。具体而言,只有当你为一个特定的项目添加了某个第三方库后,该项目才能使用该库。这些第三方库的文…

了解AI绘图,Stable Diffusion的使用

AI绘图对GPU算力要求较高。 个人电脑配置可参考: CPU:14600kf 盒装 显卡:RTX 4080金属大师 OC,16G显存 主板:z790吹雪d4 内存:芝奇皇家戟4000c18,162G 硬盘:宏基gm7000 1T 散热:追风…

linux环境自动化golang项目启动脚本解析

一.场景介绍 当在本地创建了golang项目,修改了代码功能,怎么在远程测试服务器上更新该功能呢,可以使用下面的步骤来解决该问题(这只是其中一种方法): (1).推送最新代码到远程仓库 (2).在测试服务器上创建该项目并拉取最新代码 (3).创建deploy.sh脚本 (4).运行deploy.sh脚本 二.…

归一化与伪彩:LabVIEW图像处理的区别

在LabVIEW的图像处理领域,归一化(Normalization)和伪彩(Pseudo-coloring)是两个不同的概念,虽然它们都涉及图像像素值的调整,但目的和实现方式截然不同。归一化用于调整像素值的范围&#xff0c…

基于DeepSeek API和VSCode的自动化网页生成流程

1.创建API key 访问官网DeepSeek ,点击API开放平台。 在开放平台界面左侧点击API keys,进入API keys管理界面,点击创建API key按钮创建API key,名称自定义。 2.下载并安装配置编辑器VSCode 官网Visual Studio Code - Code Editing…

Open WebUI项目源码学习记录(从0开始基于纯CPU环境部署一个网页Chat服务)

感谢您点开这篇文章:D,鼠鼠我是一个代码小白,下文是学习开源项目Open WebUI过程中的一点笔记记录,希望能帮助到你~ 本人菜鸟,持续成长,能力不足有疏漏的地方欢迎一起探讨指正,比心心&#xff5e…

SSM仓库物品管理系统 附带详细运行指导视频

文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.用户登录代码:2.保存物品信息代码:3.删除仓库信息代码: 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SSM框架开发的仓库…

Python微博动态爬虫

本文是刘金路的《语言数据获取与分析基础》第十章的扩展,详细解释了如何利用Python进行微博爬虫,爬虫内容包括微博指定帖子的一级评论、评论时间、用户名、id、地区、点赞数。 整个过程十分明了,就是用户利用代码模拟Ajax请求,发…

时序数据库:Influxdb详解

文章目录 一、简介1、简介2、官网 二、部署1、安装2、配置(1)用户初始化 三、入门(Web UI)1、加载数据(1)上传数据文件(2)代码接入模板 2、管理存储桶(1)创建…

unity学习32:角色相关1,基础移动控制

目录 0 应用商店 1 角色上新增CharacterController 组件 1.1 角色上新增CharacterController 组件 1.2 如果没有这个则会报错 2 速度 2.1 默认速度,按帧率计算 2.2 修改速度为按时间计算 2.3 movespeed,基础是1米/秒,这个就是每 move…

Centos Ollama + Deepseek-r1+Chatbox运行环境搭建

Centos Ollama Deepseek-r1Chatbox运行环境搭建 内容介绍下载ollama在Ollama运行DeepSeek-r1模型使用chatbox连接ollama api 内容介绍 你好! 这篇文章简单讲述一下如何在linux环境搭建 Ollama Deepseek-r1。并在本地安装的Chatbox中进行远程调用 下载ollama 登…

mysql8.0使用pxc实现高可用

环境准备 准备三台虚拟机,其对应的主机名和IP地址为 pxc-1192.168.190.129pxc-2192.168.190.133pxc-3192.168.190.134 解析,都要做解析 测试 下载pxc的安装包, 官网:https://www.percona.com/downloads 选择8.0的版本并下载,…

LabVIEW污水生化处理在线监测

污水处理是环保领域的重要工作,传统污水处理方法在监测方面存在实时性差、操作不便等问题。为解决这些问题,本项目设计并实现了一套基于LabVIEW的污水生化处理在线监测平台,能够实时监测污水处理过程中的关键参数,如温度、pH值、溶…

【AI学习】关于 DeepSeek-R1的几个流程图

遇见关于DeepSeek-R1的几个流程图,清晰易懂形象直观,记录于此。 流程图一 来自文章《Understanding Reasoning LLMs》, 文章链接:https://magazine.sebastianraschka.com/p/understanding-reasoning-llms?continueFlagaf07b1a0…

vs封装dll 给C#使用

一,vs创建控制台应用 创建控制台应用得好处时,我们可以自己测试接口,如果接口没有问题,改成dll重新编译一遍就可以。 二, 创建一个c 类,将所需提供得功能 封装到类中。 这样可以将 所有功能,进…

ubuntu20使用tigervnc远程桌面配置记录

一、安装tigervnc sudo apt install tigervnc-common sudo apt install tigervnc-standalone-server二、增加配置文件 安装完后新增配置文件:vim ~/.vnc/xstartup #!/bin/sh #Uncomment the following two lines for normal desktop: #unset SESSION_MANAGER #ex…

DeepSeek使用技巧大全(含本地部署教程)

在人工智能技术日新月异的今天,DeepSeek 作为一款极具创新性和实用性的 AI,在众多同类产品中崭露头角,凭借其卓越的性能和丰富的功能,吸引了大量用户的关注。 DeepSeek 是一款由国内顶尖团队研发的人工智能,它基于先进…

网络原理之HTTPS(如果想知道网络原理中有关HTTPS的知识,那么只看这一篇就足够了!)

前言:随着互联网安全问题日益严重,HTTPS已成为保障数据传输安全的标准协议,通过加密技术和身份验证,HTTPS有效防止数据窃取、篡改和中间人攻击,确保通信双方的安全和信任。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要…

MySQL 8.0.41 终端修改root密码

1.在 MySQL 命令行中,运行以下命令修改密码 ALTER USER rootlocalhost IDENTIFIED BY new_password; 其中,new_password替换为你想要设置的新密码 2.退出 MySQL终端,重新打开,使用新密码进入,修改成功