LeNet-5(fashion-mnist)

文章目录

  • 前言
  • LeNet
  • 模型训练

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Flatten(),
                  nn.Linear(16*5*5,120),nn.Sigmoid(),
                  nn.Linear(120,84),nn.Sigmoid(),
                  nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

def evaluate_accuracy_gpu(net,data_iter,device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net,torch.nn.Module):
        net.eval()
        if not device:
            device=next(iter(net.parameters())).device
    # 正确预测的数量,预测的总数
    eva = 0.0
    y_num = 0.0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(X,list):
                X=[x.to(device) for x in X]
            else:
                X=X.to(device)
            y=y.to(device)
            eva += accuracy(net(X), y)
            y_num += y.numel()
    return eva/y_num

训练过程同样将数据送入GPU计算

def train_epoch_gpu(net, train_iter, loss, updater,device):

    # 训练损失之和,训练准确数之和,样本数
    train_loss_sum = 0.0
    train_acc_sum = 0.0
    num_samples = 0.0
    # timer = d2l.torch.Timer()
    for i, (X, y) in enumerate(train_iter):
        # timer.start()
        updater.zero_grad()
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        updater.step()
        with torch.no_grad():
            train_loss_sum += l * X.shape[0]
            train_acc_sum += evaluation.accuracy(y_hat, y)
            num_samples += X.shape[0]
        # timer.stop()
    return train_loss_sum/num_samples,train_acc_sum/num_samples


def train_gpu(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    net.to(device)
    print('training on',device)
    optimizer=torch.optim.SGD(net.parameters(),lr=lr)
    loss=torch.nn.CrossEntropyLoss()
    # num_batches=len(train_iter)
    tr_l=[]
    tr_a=[]
    te_a=[]
    for epoch in range(num_epochs):
        net.train()
        train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)
        test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)
        train_loss, train_acc = train_metric
        train_loss = train_loss.cpu().detach().numpy()
        tr_l.append(train_loss)
        tr_a.append(train_acc)
        te_a.append(test_accuracy)
        print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')
    x = torch.arange(num_epochs)
    plt.plot((x + 1), tr_l, '-', label='train_loss')
    plt.plot(x + 1, tr_a, '--', label='train_acc')
    plt.plot(x + 1, te_a, '-.', label='test_acc')
    plt.legend()
    plt.show()
    print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/317632.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GitHub项目推荐-incubator

项目地址 Github地址:GitHub - apache/incubator-anser 官网:Apache Answer | Free Open-source Q&A Platform 项目简述 这是Apache的一个开源在线论坛,也可以部署成为一个自有的QA知识库。项目主要使用了Go和Typescript来开发&#…

微服务治理:微服务断路器(微服务故障隔离模式)详解

微服务断路器是一种设计模式,可以保护系统免于级联故障,通过限制对故障服务的调用来实现。它的工作原理类似于电气断路器:当服务遇到问题时,它会切断请求流,使其有机会恢复,并防止其他服务被压垮。 工作原…

短视频抖音文案策划创作运营手册资料大全

【干货资料持续更新,以防走丢】 短视频抖音文案策划创作运营手册资料大全 部分资料预览 资料部分是网络整理,仅供学习参考。 抖音运营资料合集(完整资料包含以下内容) 目录 制作短视频的四部曲 主题 主题是短视频脚本的基调…

(学习日记)2024.01.13:一份关于自行车定位的调研 2

写在前面: 由于时间的不足与学习的碎片化,写博客变得有些奢侈。 但是对于记录学习(忘了以后能快速复习)的渴望一天天变得强烈。 既然如此 不如以天为单位,以时间为顺序,仅仅将博客当做一个知识学习的目录&a…

Springboot读取配置文件

多种配置文件格式 springboot项目中不同配置文件的优先加载顺序 为:properties> yml >yaml>自定义核心类配置 自定义配置文件的加载 一般系统会加载默认的application.properties或者application.yml,但如果使用自定义配置文件,可使用下面方…

计算机毕业设计-----SpringBoot招聘网站项目

项目介绍 SpringBoot招聘网站项目。主要功能说明: 管理员登录,简历管理,问答管理,职位管理,用户管理,职位申请进度更新,查看简历等功能。 用户角色包含以下功能:用户首页,登录注册,职位查看,职位详情,投递简历,查看我的申请,管理个人简历,附件简历管理…

【Linux】应用与驱动交互及应用间数据交换

一、应用程序与 Linux 驱动交互主要通过以下几种方式: 1. 系统调用接口(System Calls): 应用程序可以通过系统调用,如 open(), read(), write(), ioctl(), 等来与设备驱动进行交互。这些调用最终会通过内核转发到相应的驱动函数…

AI剪辑助手:轻松剪辑专注创意,视频批量AI智剪的方法

随着科技的飞速发展,人工智能(AI)在许多领域都展现出了强大的能力。在视频剪辑领域,AI剪辑助手的出现,给内容创作者带来了前所未有的便利。它不仅能快速、高效地完成视频剪辑工作,还能释放创造力。今天一起…

记录用python封装的第一个小程序

前言 我要封装的是前段时间复现的一个视频融合拼接的程序,现在我打算将他封装成exe程序,我在这里只记录一下我封装的过程,使用的是pyinstaller,具体的封装知识我就不多说了,可以参考我另一篇博客:将Python…

Linux操作系统基础

目录 计算机存储结构 冯.诺依曼结构 操作系统 在前几期我们学写了linux中常见的一些指令,本期我们将正式进行linux操作系统的学习。 计算机存储结构 要学习linux操作系统,我们就得先进行计算机存储结构的学习,要进行计算机存储结构的学…

电脑桌面便签在哪里设置?Win10电脑桌面便签设置指南

很多上班族在使用Win10电脑办公时,需要随时记录工作事项。例如,需要参加一场紧急会议但时间不确定,或者在电话沟通时需要记下重要的信息,甚至可能是需要快速记录工作计划、想法或者临时安排的场景。在这些情况下,如果有…

高精度恒流/恒压(CC/CV)原边反馈功率转换器

一、产品概述 PR6214是一款应用于小功率AC/DC充电器和电源适配器的高性能离线式功率开关转换器。PR6214采用PFM工作模式,使用原边反馈架构,无需次级反馈电路,因此省去了光耦和431,应用电路简单,降低了系统的成本和体积…

FEB(acwing)

文章目录 FEB题目描述输入格式输出格式数据范围输入样例1:输出样例1:输入样例2:输出样例2:输入样例3:输出样例3:代码题解情况1:xxxxxx:0,1,2,…&a…

css宽度适应内容

废话不多说,看如下demo,我需要将下面这个盒子的宽度变成内容自适应 方法有很多,如下 父元素设置display:flex 实现子元素宽度适应内容 如下给父元素设置flex能实现宽度自适应内容 <!DOCTYPE html><html lang"en"><head><meta charset"U…

在CentOS中,对静态HTTP服务的性能监控

在CentOS中&#xff0c;对静态HTTP服务的性能监控和日志管理是确保系统稳定运行和及时发现潜在问题的关键。以下是对这一主题的详细探讨。 性能监控 使用工具监控&#xff1a;top、htop、vmstat、iostat等工具可以用来监控CPU、内存、磁盘I/O等关键性能指标。这些工具可以实时…

pandas增强—数据表的非等式连接和条件连接。

Pandas 支持 equi-join&#xff0c;其中 join 中涉及的键被认为是相等的。这是通过 merge 和 join 函数实现的。但是&#xff0c;在某些情况下&#xff0c;所涉及的Key可能不相等;联接中还涉及一些其他逻辑条件、这称为非等式连接或不等式连接或者条件连接。 这种情况下使用pa…

C#MQTT编程02--报文格式

1、报文结构 在MQTT协议中&#xff0c;一个MQTT数据包由&#xff1a;固定头&#xff08;Fixed header&#xff09;、可变头&#xff08;Variable header&#xff09;、消息体&#xff08;Payload&#xff09;三部分构成。 注意2点&#xff1a; 1&#xff09;所有的数据包结构…

大模型实战营Day4 XTuner 大模型单卡低成本微调实战

本次讲师是一位从事算法工作的优秀贡献者。 一起来看看吧&#xff01; 本次课程内容主要有&#xff1a; 我将在此整理前三节的内容&#xff0c;第四节放在作业章节进行讲解&#xff1a; 同第三节的建立数据库中所提及到的&#xff0c;如果通用大模型在专用领域表现能力不强&…

GSTAE

缺失数据的流量预测:一种多任务学习方法 摘要:基于真实交通数据的交通速度预测是智能交通系统(ITS)中的一个经典问题。大多数现有的交通速度预测模型都是基于交通数据完整或具有罕见缺失值的假设而提出的。然而,由于各种人为和自然因素,在现实场景中收集的此类数据往往是…

matlab中any()函数用法

一、帮助文档中的介绍 B any(A) 沿着大小不等于 1 的数组 A 的第一维测试所有元素为非零数字还是逻辑值 1 (true)。实际上&#xff0c;any 是逻辑 OR 运算符的原生扩展。 二、解读 分两步走&#xff1a; ①确定维度&#xff1b;②确定运算规则 以下面二维数组为例 >>…