LeNet原理及代码实现

目录

1.原理及介绍 

2.代码实现

2.1model.py

2.2model_train.py

2.3model.test.py


1.原理及介绍 

2.代码实现

2.1model.py

import torch
from torch import nn
from torchsummary import summary


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 定义第一层卷积(in_channels;输入通道;out_channels:输出通道;kernel_size:卷积核大小;stride:步长,默认为1;padding:填充,默认为0)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.sig = nn.Sigmoid()  # 激活函数
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第一次池化
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # 定义第二层卷积
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第二次池化

        self.flatten = nn.Flatten()  # 平展操作
        self.fc1 = nn.Linear(5 * 5 * 16, 120)  # 第一个全连接层
        self.fc2 = nn.Linear(120, 84)  # 第二个全连接层
        self.fc3 = nn.Linear(84, 10)  # 第三个全连接层

    def forward(self, x):
        x = self.conv1(x)
        x = self.sig(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.sig(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LeNet().to(device)
    print(summary(model, (1, 28, 28)))  # 打印模型信息

2.2model_train.py

import copy
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from model import LeNet


def train_val_data_process():
    train_dataset = FashionMNIST(
        root='./data',  # 指定下载文件夹
        train=True,  # 只下载训练集
        # 归一化处理数据集
        transform=transforms.Compose([transforms.Resize(size=28),
                                      transforms.ToTensor()]),  # 数据转换成张量形式,方便模型应用
        download=True  # 下载数据
    )
    # 划分训练集和验证集
    train_data, val_data = data.random_split(train_dataset,
                                             [round(0.8 * len(train_dataset)), round(0.2 * len(train_dataset))])
    # 训练集加载
    train_dataloader = data.DataLoader(dataset=train_data,
                                       batch_size=32,  # 一个批次数据的数量
                                       shuffle=True,  # 数据打乱
                                       num_workers=2)  # 分配的进程数目
    # 验证集加载
    val_dataloader = data.DataLoader(dataset=val_data,
                                     batch_size=32,
                                     shuffle=True,
                                     num_workers=2)
    return train_dataloader, val_dataloader


def train_model_process(model, train_dataloader, val_dataloader, num_epochs):
    # 定义训练使用的设备,有GPU则用,没有则用CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 使用Adam优化器进行模型参数更新,学习率为0.001
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # 损失函数为交叉熵损失函数
    criterion = nn.CrossEntropyLoss()
    # 将模型放入训练设备内
    model = model.to(device)
    # 复制当前模型参数(w,b等),以便将最好的模型参数权重保存下来
    best_model_wts = copy.deepcopy(model.state_dict())

    # 初始化参数
    # 最高准确度
    best_acc = 0.0
    # 训练集损失值列表
    train_loss_all = []
    # 验证集损失值列表
    val_loss_all = []
    # 训练集准确度列表
    train_acc_all = []
    # 验证集准确度列表
    val_acc_all = []
    # 当前时间
    since = time.time()

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # 初始化参数
        # 训练集损失值
        train_loss = 0.0
        # 训练集精确度
        train_corrects = 0
        # 验证集损失值
        val_loss = 0.0
        # 验证集精确度
        val_corrects = 0
        # 训练集样本数量
        train_num = 0
        # 验证集样本数量
        val_num = 0

        # 对每一个mini-batch训练和计算
        for step, (b_x, b_y) in enumerate(train_dataloader):
            # 将特征放入到训练设备中
            b_x = b_x.to(device)  # batch_size*28*28*1的tensor数据
            # 将标签放入到训练设备中
            b_y = b_y.to(device)  # batch_size大小的向量tensor数据
            # 设置模型为训练模式
            model.train()

            # 前向传播过程,输入为一个batch,输出为一个batch中对应的预测
            output = model(b_x)  # 输出为:batch_size大小的行和10列组成的矩阵
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)  # batch_size大小的向量表示属于物品的标签
            # 计算每一个batch的损失函数,向量形式的交叉熵损失函数
            loss = criterion(output, b_y)

            # 将梯度初始化为0,防止梯度累积
            optimizer.zero_grad()
            # 反向传播计算
            loss.backward()
            # 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用
            optimizer.step()
            # 对损失函数进行累加,该批次的loss值乘于该批次数量得到批次总体loss值,在将其累加得到轮次总体loss值
            train_loss += loss.item() * b_x.size(0)
            # 如果预测正确,则准确度train_corrects加1
            train_corrects += torch.sum(pre_lab == b_y.data)
            # 当前用于训练的样本数量
            train_num += b_x.size(0)

        for step, (b_x, b_y) in enumerate(val_dataloader):
            # 将特征放入到验证设备中
            b_x = b_x.to(device)
            # 将标签放入到验证设备中
            b_y = b_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为一个batch,输出为一个batch中对应的预测
            output = model(b_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 计算每一个batch的损失函数
            loss = criterion(output, b_y)

            # 对损失函数进行累加
            val_loss += loss.item() * b_x.size(0)
            # 如果预测正确,则准确度val_corrects加1
            val_corrects += torch.sum(pre_lab == b_y.data)
            # 当前用于验证的样本数量
            val_num += b_x.size(0)

        # 计算并保存每一轮次迭代的loss值和准确率
        # 计算并保存训练集的loss值
        train_loss_all.append(train_loss / train_num)
        # 计算并保存训练集的准确率
        train_acc_all.append(train_corrects.double().item() / train_num)
        # 计算并保存验证集的loss值
        val_loss_all.append(val_loss / val_num)
        # 计算并保存验证集的准确率
        val_acc_all.append(val_corrects.double().item() / val_num)

        # 打印每一轮次的loss值和准确度
        print("{} train loss:{:.4f} train acc: {:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))
        print("{} val loss:{:.4f} val acc: {:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))

        if val_acc_all[-1] > best_acc:
            # 保存当前最高准确度
            best_acc = val_acc_all[-1]
            # 保存当前最高准确度的模型参数
            best_model_wts = copy.deepcopy(model.state_dict())

        # 计算训练和验证的耗时
        time_use = time.time() - since
        print("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use // 60, time_use % 60))

    # 选择最优参数,保存最优参数的模型
    torch.save(best_model_wts, "./model_save/LeNet_best_model.pth")

    # 将产生的数据保存成表格,方便查看
    train_process = pd.DataFrame(data={"epoch": range(num_epochs),
                                       "train_loss_all": train_loss_all,
                                       "val_loss_all": val_loss_all,
                                       "train_acc_all": train_acc_all,
                                       "val_acc_all": val_acc_all})

    return train_process


def matplot_acc_loss(train_process):
    # 显示每一次迭代后的训练集和验证集的损失函数和准确率
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)  # 表示一行两列的第一张图
    plt.plot(train_process['epoch'], train_process.train_loss_all, "ro-", label="Train loss")
    plt.plot(train_process['epoch'], train_process.val_loss_all, "bs-", label="Val loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("Loss")

    plt.subplot(1, 2, 2)  # 表示一行两列的第二张图
    plt.plot(train_process['epoch'], train_process.train_acc_all, "ro-", label="Train acc")
    plt.plot(train_process['epoch'], train_process.val_acc_all, "bs-", label="Val acc")
    plt.xlabel("epoch")
    plt.ylabel("acc")
    plt.legend()
    plt.show()


if __name__ == '__main__':
    # 加载需要的模型
    LeNet = LeNet()
    # 加载数据集
    train_data, val_data = train_val_data_process()
    # 利用现有的模型进行模型的训练
    train_process = train_model_process(LeNet, train_data, val_data, num_epochs=20)
    matplot_acc_loss(train_process)

2.3model.test.py

import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet


def test_data_process():
    test_data = FashionMNIST(root='./data',
                             train=False,  # 用测试集进行测试
                             transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                             download=True)

    test_dataloader = data.DataLoader(dataset=test_data,
                                      batch_size=1,  # 该批次设为1
                                      shuffle=True,
                                      num_workers=0)
    return test_dataloader


def test_model_process(model, test_dataloader):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # 讲模型放入到训练设备中
    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output = model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

    # 计算测试准确率
    test_acc = test_corrects.double().item() / test_num
    print("测试的准确率为:", test_acc)


if __name__ == "__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('./model_save/LeNet_best_model.pth'))  # 调用训练好的参数权重
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

可以点个免费的赞吗!!!   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/787512.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 去掉小数末尾多余的0

文章目录 在uniapp或者一般的JavaScript环境中,要去掉小数末尾的0,可以使用以下几种方法: 使用parseFloat()函数 let num 123.4500; let result parseFloat(num); console.log(result); // 输出: 123.45字符串处理 将数字转换为字符串&am…

02day-C++学习(const 指针与引用的关系 inline nullptr)

02day-C学习 1. 使用const注意事项 注意事项 • 可以引⽤⼀个const对象,但是必须⽤const引⽤。const引⽤也可以引⽤普通对象,因为对象的访 问权限在引⽤过程中可以缩⼩,但是不能放⼤。 • 不需要注意的是类似 int& rb a3; double d 1…

代码随想录——合并区间(Leecode LCR74)

题目链接 贪心 排序 class Solution {public int[][] merge(int[][] intervals) {ArrayList<int[]> res new ArrayList<>();// 先将数组按照左区间排序Arrays.sort(intervals, new Comparator<int[]>() {public int compare(int[] intervals1, int[] in…

模板语法之插值语法{{}}——01

<主要研究&#xff1a;{{ 这里可以写什么}} 1.在data中声明的变量函数都可以 2.常量 3.只要是合法的JavaScript的表达式&#xff0c;都可以 4. 模板表达式都被放在沙盒中&#xff0c;只能访问全局变量的一个白名单&#xff0c;如 Math 和 Date <body> <div i…

STM32智能仓库管理系统教程

目录 引言环境准备智能仓库管理系统基础代码实现&#xff1a;实现智能仓库管理系统 4.1 数据采集模块 4.2 数据处理与控制算法 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;仓库管理与优化问题解决方案与优化收尾与总结 1. 引言 智能仓库管理系统通…

深入理解循环神经网络(RNN)

深入理解循环神经网络&#xff08;RNN&#xff09; 循环神经网络&#xff08;Recurrent Neural Network, RNN&#xff09;是一类专门处理序列数据的神经网络&#xff0c;广泛应用于自然语言处理、时间序列预测、语音识别等领域。本文将详细解释RNN的基本结构、工作原理以及其优…

国际网课平台Udemy上的亚马逊云科技AWS免费高分课程和创建、维护EC2动手实践

亚马逊云科技(AWS)是全球云行业最&#x1f525;火的云平台&#xff0c;在全球经济形势不好的大背景下&#xff0c;通过网课学习亚马逊云科技AWS基础备考亚马逊云科技AWS证书&#xff0c;对于找工作或者无背景转行做AWS帮助巨大。欢迎大家关注小李哥&#xff0c;及时了解世界最前…

香橙派AIpro初体验:搭建无线随身NAS

文章目录 1.引言2. 香橙派 AIPro概述3. 开发准备3.0 烧录镜像3.1 需要准备的硬件3.2 需要准备的软件3.3 启动并连接香橙派 AIPro3.3.1 初始化启动香橙派 AIPro3.3.2 无线连接香橙派 AIPro3.3.3.3 VNC连接香橙派 AIPro 3.4 设置固定ip3.4.1 设置开机自动连接WIFI3.4.1 设置香橙派…

遍历请求后端数据引出的数组forEach异步操作的坑

有一个列表数据&#xff0c;每项数据里有一个额外的字段需要去调另外一个接口才能拿到&#xff0c;后端有现有的这2个接口&#xff0c;现在临时需要前端显示出来&#xff0c;所以这里需要前端先去调列表数据的接口拿到列表数据&#xff0c;然后再遍历请求另外一个接口去拿到对应…

springboot封装请求参数json的源码解析

源码位置&#xff1a; org.springframework.web.servlet.mvc.method.annotation.AbstractMessageConverterMethodArgumentResolver#readWithMessageConverters(org.springframework.http.HttpInputMessage, org.springframework.core.MethodParameter, java.lang.reflect.Type…

Java PKI Programmer‘s Guide

一、PKI程序员指南概述 PKI Programmer’s Guide Overview Java认证路径API由一系列类和接口组成&#xff0c;用于创建、构建和验证认证路径。这些路径也被称作认证链。实现可以通过基于提供者的接口插入。 这个API基于密码服务提供者架构&#xff0c;这在《Java密码架构参考指…

c++入门基础篇(上)

目录 前言&#xff1a; 1.c&#xff0b;&#xff0b;的第一个程序 2.命名空间 2.1 namespace的定义 2.2 命名空间使用 3.c&#xff0b;&#xff0b;输入&输出 4.缺省参数 5.函数重载 前言&#xff1a; 我们在之前学完了c语言的大部分语法知识&#xff0c;是不是意…

springboot驾校管理系统-计算机毕业设计源码49777

驾校管理系统 摘 要 驾校管理系统是一个基于Spring Boot框架开发的系统&#xff0c;旨在帮助驾校提高管理效率和服务水平。该系统主要实现了用户管理、年月类型管理、区域信息管理、驾校信息管理、车辆信息管理、报名信息管理、缴费信息管理、财务信息管理、教练分配管理、更换…

微搭低代码从入门到实战01创建数据源

目录 1 创建数据源2 创建字段总结 很多零基础的想学习低代码开发&#xff0c;苦于没有编程的经验感觉入门困难。本次教程就按照我们日常开发的思路&#xff0c;从浅入深逐步拆解一下低代码该如何学习。 开发软件&#xff0c;不管是管理后台还是小程序&#xff0c;先需要规划好数…

忘记Apple ID密码怎么退出苹果ID账号?

忘记Apple ID密码怎么退出账号&#xff1f;Apple ID对每个苹果用户来说都是必不可少的&#xff0c;没有它&#xff0c;用户就不能享受iCloud、App Store、iTunes等服务。苹果手机软件下载、丢失解锁、恢复出厂设置等都需要使用Apple ID。如果忘记Apple ID 密码&#xff0c;这会…

C语言 结构体和共用体——结构体和数组的嵌套

目录 结构体和数组的相互嵌套​编辑 嵌套的结构体 嵌套结构体变量的初始化 结构体数组的定义和初始化 结构体和数组的相互嵌套 嵌套的结构体 在一个结构体内包含了另一个结构体作为其成员 嵌套结构体变量的初始化 STUDENT stu1 {100310121, " 王刚 ", M, {1991…

【Java 的四大引用详解】

首先分别介绍一下这几种引用 强引用&#xff1a; 只要能通过GC ROOT根对象引用链找到就不会被垃圾回收器回收&#xff0c;当所有的GC Root都不通过强引用引用该对象时&#xff0c;才能被垃圾回收器回收。 软引用&#xff08;SoftReference&#xff09;&#xff1a; 当只有软引…

打开ps提示dll文件丢失如何解决?教你几种靠谱的方法

在日常使用电脑过程中&#xff0c;由于不当操作&#xff0c;dll文件丢失是一种常见现象。当dll文件丢失时&#xff0c;程序将无法正常运行&#xff0c;比如ps&#xff0c;pr等待软件。此时&#xff0c;我们需要对其进行修复以恢复其功能&#xff0c;下面我们一起来了解一下出现…

后端登录校验——Filter过滤器和Interceptor拦截器

一、Filter过滤器 前面我们学会了最先进的会话跟踪技术jwt令牌&#xff0c;那么我们要让用户使用某些功能时就要根据jwt令牌来验证用户身份&#xff0c;来决定他是否登陆了、让不让用户访问这个页面&#xff08;或功能&#xff09; 但是这样一来&#xff0c;没发一个请求&…

数学建模中常用的数据处理方法

常用的数据处理方法 本文参考 B站西电数模协会的讲解视频 &#xff0c;只作笔记提纲&#xff0c;想要详细学习具体内容请观看 up 的学习视频。一般来说国赛的 C 题一般数据量比较大。 这里介绍以下两种方法&#xff1a; 数据预处理方法 数据分析方法 数据预处理方法 1. 数据…