卷积神经网络——LeNet——FashionMNIST

目录

  • 一、文件结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

一、文件结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summary

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)
        self.sig = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)
        self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)

        self.flatten = nn.Flatten()
        self.f5 = nn.Linear(in_features=5*5*16,out_features=120)
        self.f6 = nn.Linear(in_features=120,out_features=84)
        self.f7 = nn.Linear(in_features=84,out_features=10)

    def forward(self,x):
        x = self.sig(self.c1(x))
        x = self.s2(x)
        x = self.sig(self.c3(x))
        x = self.s4(x)
        x = self.flatten(x)
        x = self.f5(x)
        x = self.f6(x)
        x = self.f7(x)
        return x

# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条

# 定义数据加载和处理函数
def train_val_data_process():
    # 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensor
    train_data = FashionMNIST(root="./data",
                              train=True,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    # 将加载的数据集分为80%的训练数据和20%的验证数据
    train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])

    # 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据
    train_dataloader = Data.DataLoader(dataset=train_data,
                                       batch_size=32,
                                       shuffle=True,
                                       num_workers=2)

    val_dataloader = Data.DataLoader(dataset=val_data,
                                     batch_size=32,
                                     shuffle=True,
                                     num_workers=2)

    # 返回训练和验证的DataLoader
    return train_dataloader, val_dataloader

# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):
    # 设置使用CUDA如果可用
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 打印使用的设备
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print(f'当前模型训练设备为: {dev}')

    # 初始化Adam优化器和交叉熵损失函数
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 将模型移动到选定的设备上
    model = model.to(device)

    # 复制模型权重用于后续更新最佳模型
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0  # 初始化最佳准确度

    # 初始化用于记录训练和验证过程中损失和准确度的列表
    train_loss_all = []
    val_loss_all = []
    train_acc_all = []
    val_acc_all = []

    # 记录训练开始时间
    start_time = time.time()

    # 迭代指定的训练轮数
    for epoch in range(1, num_epochs + 1):
        # 记录每个epoch开始的时间
        since = time.time()

        # 打印分隔符和当前epoch信息
        print("-" * 10)
        print(f"Epoch: {epoch}/{num_epochs}")

        # 初始化训练和验证过程中的损失和正确预测数量
        train_loss = 0.0
        train_corrects = 0
        val_loss = 0.0
        val_corrects = 0

        # 初始化批次计数器
        train_num = 0
        val_num = 0

        # 创建训练进度条
        progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')

        # 训练数据集的遍历
        for step, (b_x, b_y) in enumerate(train_dataloader):
            # 将数据移动到相应的设备上
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            # 训练模型
            model.train()

            # 前向传播
            output = model(b_x)

            # 计算预测标签
            pre_label = torch.argmax(output, dim=1)

            # 计算损失
            loss = criterion(output, b_y)

            # 清空梯度
            optimizer.zero_grad()

            # 反向传播
            loss.backward()

            # 更新权重
            optimizer.step()

            # 累加损失和正确预测数量
            train_loss += loss.item() * b_x.size(0)
            train_corrects += torch.sum(pre_label == b_y.data)

            # 更新批次计数器
            train_num += b_x.size(0)

            # 更新训练进度条
            progress_train_bar.update(1)

        # 关闭训练进度条
        progress_train_bar.close()

        # 创建验证进度条
        progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')

        # 验证数据集的遍历
        for step, (b_x, b_y) in enumerate(val_dataloader):
            # 将数据移动到相应的设备上
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            # 评估模型
            model.eval()

            # 前向传播
            output = model(b_x)

            # 计算预测标签
            pre_label = torch.argmax(output, dim=1)

            # 计算损失
            loss = criterion(output, b_y)

            # 累加损失和正确预测数量
            val_loss += loss.item() * b_x.size(0)
            val_corrects += torch.sum(pre_label == b_y.data)

            # 更新批次计数器
            val_num += b_x.size(0)

            # 更新验证进度条
            progress_val_bar.update(1)

        # 关闭验证进度条
        progress_val_bar.close()

        # 计算并记录epoch的平均损失和准确度
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects.double().item() / train_num)

        val_loss_all.append(val_loss / val_num)
        val_acc_all.append(val_corrects.double().item() / val_num)

        # 打印训练和验证的损失与准确度
        print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')
        print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')

        # 计算并打印epoch训练耗费的时间
        time_use = time.time() - since
        print(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')

        # 若当前epoch的验证准确度为最佳,则更新最佳模型权重
        if val_acc_all[-1] > best_acc:
            best_acc = val_acc_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

    # 训练结束,保存最佳模型权重
    torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')

    # 如果当前epoch为总epoch数,则保存最终模型权重
    if epoch == num_epochs:
        torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')

    # 将训练过程中的统计数据整理成DataFrame
    train_process = pd.DataFrame(data={
        "epoch": range(1, num_epochs + 1),
        "train_loss_all": train_loss_all,
        "val_loss_all": val_loss_all,
        "train_acc_all": train_acc_all,
        "val_acc_all": val_acc_all
    })

    # 打印总训练时间
    consume_time = time.time() - start_time
    print(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')

    # 返回包含训练过程统计数据的DataFrame
    return train_process

# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):
    # 创建图形和子图
    plt.figure(figsize=(12, 4))

    # 绘制训练和验证损失
    plt.subplot(1, 2, 1)
    plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")
    plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("loss")
    # 保存损失图像
    plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')

    # 绘制训练和验证准确度
    plt.subplot(1, 2, 2)
    plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")
    plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    # 保存准确率曲线图
    plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    model = LeNet()

    train_dataloader, val_dataloader = train_val_data_process()
    train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)

    matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表test


def t_data_process():
    test_data = FashionMNIST(root="./data",
                             train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)

    return test_dataloader


def t_model_process(model, test_dataloader):
    if model is not None:
        print('Successfully loaded the model.')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0
    all_preds = []  # 存储所有预测标签
    all_labels = []  # 存储所有实际标签

    # 只进行前向传播,不计算梯度
    with torch.no_grad():
        for test_x, test_y in test_dataloader:
            test_x = test_x.to(device)
            test_y = test_y.to(device)

            # 设置模型为验证模式
            model.eval()
            # 前向传播得到一个batch的结果
            output = model(test_x)
            # 查找最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)

            # 收集预测和实际标签
            all_preds.extend(pre_lab.tolist())
            all_labels.extend(test_y.tolist())

            # 计算准确率
            test_corrects += torch.sum(pre_lab == test_y.data)

            # 将所有的测试样本进行累加
            test_num += test_x.size(0)

    # 计算准确率
    test_acc = test_corrects.double().item() / test_num
    print(f'测试的准确率:{test_acc}')

    # 绘制混淆矩阵
    conf_matrix = confusion_matrix(all_labels, all_preds)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.title('Confusion Matrix')
    plt.show()
    plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')



if __name__=="__main__":
    # 加载模型
    model = LeNet()

    print('loading model')
    # 加载权重
    model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))

    # 加载测试数据
    test_dataloader = t_data_process()

    # 加载模型测试的函数
    t_model_process(model,test_dataloader)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
    with torch.no_grad():
        for b_x,b_y in test_dataloader:
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            model.eval()

            output = model(b_x)
            pre_lab = torch.argmax(output,dim=1)
            result = pre_lab.item()
            label = b_y.item()

            print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/796757.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM的校园一卡通管理系统的设计与实现

摘 要 本报告全方位、深层次地阐述了校园一卡通管理系统从构思到落地的整个设计与实现历程。此系统凭借前沿的 SSM(Spring、Spring MVC、MyBatis)框架精心打造而成,旨在为学校构建一个兼具高效性、便利性与智能化的一卡通管理服务平台。 该系…

liunx硬盘分区挂载笔记

NAME: 设备名称。 MAJ : 主设备号和次设备号。 RM: 只读标志(0 表示可读写,1 表示只读)。 SIZE: 设备的总大小。 RO: 只读状态(0 表示可读写,1 表示只读)。 TYPE: 设备类型(disk 表示物理磁盘设…

C 语言结构体

由于近期项目需求,需使用到大量的指针与结构体,为更好的完成项目,故对结构体与指针的内容进行回顾,同时撰写本博客,方便后续查阅。 本博客涉及的结构体知识有: 1.0:结构体的创建和使用 2.0: typedef 关…

怎样在 C 语言中进行类型转换?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会! 📙C 语言百万年薪修炼课程 通俗易懂,深入浅出,匠心打磨,死磕细节,6年迭代,看过的人都说好。 文章目…

记一次 .NET某上位视觉程序 离奇崩溃分析

一:背景 1. 讲故事 前段时间有位朋友找到我,说他们有一个崩溃的dump让我帮忙看下怎么回事,确实有太多的人在网上找各种故障分析最后联系到了我,还好我一直都是免费分析,不收取任何费用,造福社区。 话不多…

快速读出linux 内核中全局变量

查问题时发现全局变量能读出来会提高效率,于是考虑从怎么读出内核态的全局变量,脚本如下 f open("/proc/kcore", rb) f.seek(4) # skip magic assert f.read(1) b\x02 # 64 位def read_number(bytes):return int.from_bytes(bytes, little,…

每日一练:奇怪的TTL字段(python实现图片操作实战)

打开图片,只有四种数字:127,191,63,255 最大数字为255,想到进制转换 将其均转换为二进制: 发现只有前2位不一样 想着把每个数的前俩位提取出来,组成新的二进制,然后每…

c++ 多边形 xyz 数据 获取 中心点方法,线的中心点取中心值搞定 已解决

有需求需要对。多边形 获取中心点方法&#xff0c;绝大多数都是 puthon和java版本。立体几何学中的知识。 封装函数 point ##########::getCenterOfGravity(std::vector<point> polygon) {if (polygon.size() < 2)return point();auto Area [](point p0, point p1, p…

AI绘画Midijourney操作技巧及变现渠道喂饭式教程!

前言 盘点Midijourney&#xff08;AIGF&#xff09;热门赚米方法&#xff0c;总有一种适合你之AI绘画操作技巧及变现渠道剖析 【表情包制作】 首先我们对表情包制作进行详细的讲解&#xff1a; 当使用 Midjourney&#xff08;AIGF&#xff09; 绘画来制作表情包时&#xff…

ensp防火墙综合实验作业+实验报告

实验目的要求及拓扑图&#xff1a; 我的拓扑&#xff1a; 更改防火墙和交换机&#xff1a; [USG6000V1-GigabitEthernet0/0/0]ip address 192.168.110.5 24 [USG6000V1-GigabitEthernet0/0/0]service-manage all permit [Huawei]vlan batch 10 20 [Huawei]int g0/0/2 [Huawei-…

218.贪心算法:分发糖果(力扣)

核心思想 初始化每个学生的糖果数为1&#xff1a; 确保每个学生至少有一颗糖果。从左到右遍历&#xff1a; 如果当前学生的评分高于前一个学生&#xff0c;则当前学生的糖果数应比前一个学生多一颗。从右到左遍历&#xff1a; 如果当前学生的评分高于后一个学生&#xff0c;则…

排序【选择排序和快速排序】

1.选择排序 1.1基本思想 每次选出最小&#xff08;或最大&#xff09;的一个元素&#xff0c;存放在数组的起始位置&#xff0c;直到所有元素都排完。 1.2直接插入排序&#xff1a; 在数组arr[i]到arr[n-1]中选出最大&#xff08;小&#xff09;的元素。若该元素不是数组的…

前端的页面代码

根据老师教的前端页面的知识&#xff0c;加上我也是借鉴了老师上课所说的代码&#xff0c;马马虎虎的写出了页面。如下代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</ti…

小型 FPGA 瞄准 4K 视频接口,MiSTer FPGA 现已支持 Sinden 光枪-FPGA新闻速览

无需矩阵乘法&#xff0c;在 FPGA 上实现低功耗、高性能的 LLM UC Santa Cruz, Soochow University, UC Davis 和 LuxiTech 发表了一篇题为“可扩展的无 MatMul 语言建模”的新技术论文。 “矩阵乘法 (MatMul) 通常占据大型语言模型 (LLM) 总体计算量的主导地位。随着 LLM 扩展…

PLC物联网关在工业自动化领域的应用的意义-天拓四方

随着信息技术的飞速发展&#xff0c;物联网技术正逐步渗透到各个行业领域&#xff0c;其中&#xff0c;工业自动化领域的PLC与物联网的结合&#xff0c;为工业自动化的发展开辟了新的道路。PLC物联网关作为连接PLC与物联网的重要桥梁&#xff0c;其重要性日益凸显。 PLC物联网…

单例模式Singleton

设计模式 23种设计模式 Singleton 所谓类的单例设计模式&#xff0c;就是采取一定的方法保证在整个的软件系统中&#xff0c;对某个类只能存在一个对象实例&#xff0c;并且该类只提供一个取得其对象实例的方法。 饿汉式 public class BankTest {public static void main(…

四个“一体化”——构建数智融合时代下的一站式大数据平台

随着智能化技术的飞速发展&#xff0c;尤其是以生成式AI为代表的技术快速应用&#xff0c;推动了数据与智能的深化融合&#xff0c;给数据基础设施带来了新的变革和挑战。如何简化日益复杂的系统架构&#xff0c;提高数据处理效率&#xff0c;降低开发运维成本&#xff0c;促进…

Selenium使用注意事项:

find_element 和 find_elements 的区别 WebDriver和WebElement的区别 问题&#xff1a; 会遇到报错&#xff1a; selenium.common.exceptions.NoSuchElementException: Message: no such element: Unable to locate element: {"method":"css selector",&…

STM32智能空气质量监测系统教程

目录 引言环境准备智能空气质量监测系统基础代码实现&#xff1a;实现智能空气质量监测系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;空气质量监测与优化问题解决方案与优化收尾与总结 1. 引言 智能空…

UCSD和MIT的华人学者最新成果展示:沉浸式远程遥操作机器人

你是否曾想过&#xff0c;自己身处某地&#xff0c;可以控制几千公里以外的「机器人」本体&#xff1f;这个想法&#xff0c;最近被来自UCSD和MIT的华人学者们实现了。UCSD位于加利福尼亚州&#xff0c;MIT位于马萨诸塞州&#xff0c;这两地之差&#xff0c;约3000英里&#xf…