从0开始深度学习(27)——卷积神经网络(LeNet)

1 LeNet神经网络

LeNet是最早的卷积神经网络之一,由Yann LeCun等人在1990年代提出,并以其名字命名。最初,LeNet被设计用于手写数字识别,最著名的应用是在美国的邮政系统中识别手写邮政编码。LeNet架构的成功证明了卷积神经网络在解决实际问题中的有效性,为后续更复杂、更强大的CNN模型的发展奠定了基础。

结构如下:
在这里插入图片描述
先用pytorch代码实现该结构:

import torch
from torch import nn

net=nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

我们知道手写数字识别数据集的数据,都是 28 × 28 28\times28 28×28的灰度图,下面我们将输入一个 28 × 28 28\times28 28×28的矩阵,看看经过这个模型过后,会输出什么。

x=torch.rand(size=(1,1,28,28))
for layer in net:
    x=layer(x)
    print(layer.__class__.__name__,"output shape:",x.shape)
    

运行结果:
在这里插入图片描述
可以发现最后输出为 1 × 10 1\times10 1×10的张量,该维度与我们需要的结果分类数(0~9)匹配。

2 模型训练

检测一下LeNet-5在Fashion-MNIST数据集上的表现。

import torch
from torch import nn,optim
import torchvision
from torch.utils import data
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
from d2l import torch as d2l
import matplotlib.pyplot as plt

net=nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

batch_size=128

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  
    transforms.Normalize((0.5,), (0.5,))  # 标准化到[-1, 1]区间,加快计算
])

# 加载Fashion-MNIST数据集
train_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=True, download=False, transform=transform)
test_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=False, download=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 自定义 try_gpu 函数
def try_gpu(i=0):
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

def evaluate_acc_gpu(net, data_iter, device=None):
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
        metric = d2l.Accumulator(2)
        with torch.no_grad():
            for X, y in data_iter:
                if isinstance(X, list):
                    X = [x.to(device) for x in X]
                else:
                    X = X.to(device)
                y = y.to(device)
                temp = net(X)
                acc = accuracy(temp, y)
                metric.add(acc, y.numel())
    return metric[0] / metric[1]

def train(net, train_iter, test_iter, num_epochs, lr, device, train_acc_list,test_acc_list):
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print("training on", device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    timer = d2l.Timer()
    train_acc_list = train_acc_list
    test_acc_list = test_acc_list
    print("init train_list nad test_list is ok")

    for epoch in range(num_epochs):
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
        train_l = metric[0] / metric[2]
        train_acc = metric[1] / metric[2]
        train_acc_list.append(train_acc)
        print(f"epoch: {epoch+1}, train_l: {train_l:.3f}, train_acc: {train_acc:.3f}")
        test_acc = evaluate_acc_gpu(net, test_iter)
        test_acc_list.append(test_acc)
        print(f"test acc: {test_acc:.3f}")
    return train_acc_list,test_acc_list

    


# 实现 accuracy 函数
def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
    
lr, num_epochs = 0.9, 10
train_acc_list=[]
test_acc_list=[]

train_acc_list,test_acc_list=train(net, train_loader, test_loader, num_epochs, lr, try_gpu(),train_acc_list,test_acc_list)

print(f"num_epochs: {num_epochs}")
print(f"train_acc_list: {train_acc_list}")
print(f"test_acc_list: {test_acc_list}")

try:
    # 绘制训练和测试准确率的折线图
    epochs = range(1, num_epochs + 1)
    plt.plot(epochs, train_acc_list, 'b', label='Training Accuracy')
    plt.plot(epochs, test_acc_list, 'r', label='Testing Accuracy')
    plt.title('Training and Testing Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()
except Exception as e:
    print(f"An error occurred: {e}")

运行结果
在这里插入图片描述
在这里插入图片描述
分析图像可以看出,准确率还没有稳定,说明还有提升空间,可以添加epoch继续训练以获得更准的分类效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/912467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何用C#和Aspose.PDF实现PDF转Word工具

在本篇博文中,我将详细讲解如何用C#实现一个PDF转Word工具。这款工具基于Aspose.PDF库,实现PDF文件转为Word(DOC/DOCX)格式的功能,并通过用户友好的界面和状态提示提升用户体验。希望通过这篇文章帮助大家理解软件的实…

运维技术之文件系统(File System for 0peration and Maintenance Technology)

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 本人主要分享计算机核心技…

IoTDB 与 HBase 对比详解:架构、功能与性能

五大方向,洞悉 IoTDB 与 HBase 的详尽对比! 在物联网(IoT)领域,数据的采集、存储和分析是确保系统高效运行和决策准确的重要环节。随着物联网设备数量的增加和数据量的爆炸式增长,开发者和决策者们需要选择…

【Vue】Vue2和Vue3响应式原理

前言 Vue 3 的核心部分可以分为三个主要模块:Compiler、Reactivity 和 Runtime。响应式的处理逻辑在 Reactivity 部分。 Compiler(编译器):Template > 渲染函数 将 Vue 的模板(Template)转换成 JavaS…

哪些人群适合考取 PostgreSQL 数据库 PGCM 证书?

#postgresql#,作为开源数据库领域的佼佼者,凭借其强大的功能和广泛的应用场景,吸引了大量数据库从业者的关注。它代表着持有者在PostgreSQL数据库管理、优化、安全和高可用性设计等方面的专家级技能。 PGCM证书适合那些具备扎实理论基础和一…

C++高级编程(9)

九、STL模板库 1.C函数模板 函数模板是一个独立于类型的函数,可产生函数特定类型的版本。通过对参数类型进行参数化,获取有相同形式的函数体。 它是一个通用函数,它可适应一定范围内的不同类型对象的操作。 函数模板将代表着不同类型的一组…

深圳世界之窗:文化与娱乐交织的旅游胜地

深圳世界之窗位于广东省深圳市南山区华侨城,是中国著名的缩微景区。它以弘扬世界文化为宗旨,将世界奇观、历史遗迹、民间歌舞表演、高科技游乐项目等融为一体,为游客打造出一个不出国门就能领略世界风情的旅游胜地。 从文化角度来看&#xff…

贪心day04(买卖股票的最佳时机)

1.买卖股票的最佳时机 题目链接:. - 力扣(LeetCode) 思路:我们其实只需遍历一篇就可以解决这个问题。首先我们定义一个min为无穷大值,再遍历只要有数字比min跟小我们就更改min的值就好,此时我们只需要找出…

ClickHouse创建账号和连接测试

在之前搭建ClickHouse的时候,把账户相关的去掉了,所以登录和连接的时候是不需要账号密码的,但是实际项目中,肯定是需要根据需要创建账号。 一,创建账号 1,进入到 /etc/clickhouse-server, 编辑…

网页版五子棋——匹配模块(客户端开发)

前一篇文章:网页版五子棋——用户模块(客户端开发)-CSDN博客 目录 前言 一、前后端交互接口设计 二、游戏大厅页面 1.页面代码编写 2.前后端交互代码编写 3.测试获取用户信息功能 结尾 前言 前面文章介绍完了五子棋项目用户模块的代码…

10. java基础知识(下)

文章目录 一、一带而过二、字符串类型String1. 简单了解2. 关于结束符\03. 自动类型转换与强制类型转换 三、API文档与import导包1. API文档2. import导包 四、java中的数组1. 创建2. 遍历3. 补充4. Arrays类① 简单介绍② 练习 五、方法的重载六、规范约束七、内容出处 一、一…

Sam Altman:年底将有重磅更新,但不是GPT-5!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,专注于分享AI全维度知识,包括但不限于AI科普,AI工…

《Python编程实训快速上手》第四天--字符串操作

一、处理字符串 1、单引号和双引号 Python中单双引号均可以表示字符串,区别在于: 1、双引号中可以使用到单引号 2、单引号字符串中如果要使用单引号,要使用到转义字符 \ \ \t \n \\ 原始字符串 在开始的引号前加r&#xf…

原生html+js输入框下拉多选带关闭模块完整案例

<!DOCTYPE html> <html> <head> <title>多选下拉框</title> <style> * { box-sizing: border-box; margin: 0; padding: 0; } .multi-select-container { position: relative; width: 300px; margin: 20px; font-family: Arial, sans-seri…

在 Ubuntu 操作系统上:改变 App 任务栏菜单的背景色

Ubuntu 官方默认的终端&#xff0c;与操作系统的主题 theme 无关。

【优选算法 — 滑动窗口】滑动窗口小专题(一)

长度最小的子数组 长度最小的子数组 题目解析&#xff1a; 对于示例一 对于剩下两种示例&#xff1a; 解法一&#xff1a;暴力枚举 把所有的子数组全部枚举出来&#xff0c;并且枚举出的每一个子数组求和判断&#xff0c;返回长度最小的子数组&#xff1b; 时间复杂度 &…

半波整流器原理

一、二极管不控整流 1.阻性负载 1.1.电路拓扑结构 电路只由交流源、二极管和电阻组成。最基本的带阻性负载的半波整流器如图所示。输入源为交流源&#xff0c;目标是使输出电压含有非零直流分量&#xff0c;负载为R。功率二极管只允许电流往一个方向流动。 1.2.工作模态分析…

yolov8涨点系列之引入CBAM注意力机制

文章目录 YOLOv8 中添加注意力机制 CBAM 具有多方面的好处特征增强与选择通道注意力方面空间注意力方面 提高模型性能计算效率优化&#xff1a; yolov8增加CBAM具体步骤CBAM代码(1)在__init.pyconv.py文件的__all__内添加‘CBAM’(2)conv.py文件复制粘贴CBAM代码(3)修改task.py…

Rust-AOP编程实战

文章本天成&#xff0c;妙手偶得之。粹然无疵瑕&#xff0c;岂复须人为&#xff1f;君看古彝器&#xff0c;巧拙两无施。汉最近先秦&#xff0c;固已殊淳漓。胡部何为者&#xff0c;豪竹杂哀丝。后夔不复作&#xff0c;千载谁与期&#xff1f; ——《文章》宋陆游 【哲理】文章…

基于 SSM(Spring + Spring MVC + MyBatis)框架构建电器网上订购系统

基于 SSM&#xff08;Spring Spring MVC MyBatis&#xff09;框架构建电器网上订购系统可以为用户提供一个方便快捷的购物平台。以下将详细介绍该系统的开发流程&#xff0c;包括需求分析、技术选型、数据库设计、项目结构搭建、主要功能实现以及前端页面设计。 需求分析 …