经典卷积神经网络 LeNet

一、实例图片

#我们传入的是28*28,所以加了padding
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

二、总结

1、LeNet是早期成功的神经网络

2、先使用卷积层来学习图片空间信息

3、然后使用全连接层来转换到类别空间

三、代码

1、评估模型,将参数放到GPU中

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        #eval是将模型设置为评估模式,评估模式就不会改变模型参数了可以用来预测结果;eval就是关闭模型中的dropout功能,调到评价模式;与之相对的是train()
        net.eval()
        if not device:
            #如果未提供设备参数,则使用模型第一个参数的设备作为默认设备。这确保了数据和模型在            同一设备上运行
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    #确保在计算精度时不会计算梯度,从而节省显存和提高计算效率
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

2、训练模型

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    #对于net里面的所有parameter,都去run一下那个初始化权重的函数。就是说在整个net中的所有层
     上面都使用init__weights函数来初始化所有现行层和卷积层的权重
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            #xavier能够根据输入输出的大小,使得初始化随机权重能使,输入和输出的方差差别不会
             很大,保证在模型最开始的时候,结果不会指数爆炸或者消失
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    #SGD是随机梯度下降算法
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                #l是当前批次的平均损失。X.shape[0]是当前批次的样本数。l * X.shape[0]
                 计算的是当前批次的总损失。
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]

            #这里是可视化
            #(i + 1) % (num_batches // 5)每训练到一个阶段时(5次中的每一次)会更新可视化数据
            #i == num_batches - 1:当训练到最后一个批次时,条件也会满足
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/772660.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

安卓请求服务器[根据服务器的内容来更新spinner]

根据服务器的内容来更新spinner 本文内容请结合如下两篇文章一起看: 腾讯云函数node.js返回自动带反斜杠 腾讯云函数部署环境[使用函数URL] 现在有这样一个需求,APP有一个下拉选择框作为版本选择,因为改个管脚就变成一个版本,客户需求也很零散,所以后期会大量增加版本,这时候每…

三菱PLC标签使用(I/O的映射)与内容

今天,小编继续开始三菱PLC的学习,今天的内容是标签及其标签的内容说明,如果对你有帮助,欢迎评论收藏。 标签的种类,等级,定义 种类 三菱3U的PLC的种类分别为二种:全局标签与局部标签 全局标签…

【大数据】—量化交易实战案例(海龟交易策略)

声明:股市有风险,投资需谨慎!本人没有系统学过金融知识,对股票有敬畏之心没有踏入其大门,今天用另外一种方法模拟炒股,后面的模拟的实战全部用同样的数据,最后比较哪种方法赚的钱多。 海龟交易…

交互未来入选“北京市通用人工智能产业创新伙伴计划”模型伙伴

在“开启数智新时代,共享数字新未来”的宏大主题下,2024全球数字经济大会于近日盛大开幕,汇聚全球智慧,共谋数字经济新篇章。大会首日,备受瞩目的人工智能专题论坛率先拉开帷幕,以“应用即未来——大模型赋…

PLC工作原理

PLC(可编程逻辑控制器)的工作原理简述为:集中采样、集中输出、周期性循环扫描。 西门子PLC 一、集中采样 顺序读取所有输入端子的通断状态,并将所读取的信息存到输入映像寄存器中,此时输入映像寄存器被刷新&#xff…

基于Java中的SSM框架实现野生动物公益保护系统项目【项目源码+论文说明】计算机毕业设计

基于Java中的SSM框架实现野生动物公益保护系统演示 摘要 本系统按照网站系统设计的基本流程,遵循系统开发生命周期法和结构化方法,基于Java语言设计并实现了野生动物公益保护系统。该系统基于浏览器/服务器模式,采用JSP技术,后台…

Land survey boundary report (template)

Land survey boundary report (template) 土地勘测定界报告(模板).doc

昇思25天学习打卡营第7天|保存与加载

Python语言 AI框架:Mindspore 1.模型构建 class Network(nn.Cell):def __init__(self):super().__init__()self.flatten nn.Flatten()self.dense_relu_sequential nn.SequentialCell(nn.Dense(28*28, 512, weight_init"normal", bias_init"zeros…

C#Winform窗体中嵌入exe文件

1,效果以嵌入Modbus Slave为例: 2,代码: public partial class Form1 : Form{//设置嵌入exe的常量private const int nIndex -16;private const int dwNewLong 0x10000000;Process m_AppProcess;public Form1(){InitializeCompo…

符号速率、调制方式、码率、传输速率

符号速率、调制方式、码率(比特率)和传输速率在通信系统中是相互关联且重要的概念,它们之间的关系可以归纳如下: 一、定义 符号速率(Symbol Rate): 也称波特率(Baud Rate&#xff…

mmdetection3增加12种注意力机制

在mmdetection/mmdet/models/layers/目录下增加attention_layers.py import torch.nn as nn from mmdet.registry import MODELS #自定义注意力机制算法 from .attention.CBAM import CBAMBlock as _CBAMBlock from .attention.BAM import BAMBlock as _BAMBlock from .attent…

欢乐钓鱼大师游戏攻略:在什么地方掉称号鱼?云手机游戏辅助!

《欢乐钓鱼大师》是一款融合了休闲娱乐和策略挑战的钓鱼游戏。游戏中的各种鱼类不仅各具特色,而且钓鱼过程充满了挑战和乐趣。下面将为大家详细介绍如何在游戏中钓鱼,以及一些有效的钓鱼技巧,帮助你成为一个出色的钓鱼大师。 实用工具推荐 为…

Day64 代码随想录打卡|回溯算法篇---组合总和

题目(leecode T39): 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 ,并以列表形式返回。你可以按 任意顺序 返回这些组合。 …

什么开放式耳机好用?五大王牌开放式耳机种草!

随着科技的持续进步,开放式蓝牙耳机悄然兴起,逐步取代了经典的入耳式耳机。入耳式耳机以其卓越的隔音性能著称,然而,长时间的使用却容易引发耳道受压,伴随而来的不仅是疼痛与不适,更潜藏着耳膜受损的风险。…

90%的铲屎官必遇到家里猫毛满天飞问题,热门宠物空气净化器分享

作为一名资深猫奴,一到换毛季节家中就会忍受猫毛飞舞、异味四溢的双重困扰?花粉加上宠物的毛发和体味,过敏和不适似乎成了家常便饭。尝试过很多半方法,用过空气净化器去除毛和异味,虽然普通空气净化器可能提供一定程度…

swiftui中几个常用的手势控制单击点击,双击和长按事件

简单做了一个示例代码,包含三个圆形形状,配置了不同的事件,示例代码: // // RouterView.swift // SwiftBook // // Created by song on 2024/7/4. //import SwiftUIstruct RouterView: View {State var isClick falsevar bod…

Movable antenna 早期研究

原英文论文名字Historical Review of Fluid Antenna and Movable Antenna 最近,无线通信研究界对“流体天线”和“可移动天线”两种新兴天线技术的发展引起了极大的关注,这两种技术因其前所未有的灵活性和可重构性而极大地提高了无线应用中的系统性能。…

用Vue3和Plotly.js绘制交互式3D烛形图

本文由ScriptEcho平台提供技术支持 项目地址:传送门 Plotly.js实现交互式K线图 应用场景 K线图广泛应用于金融领域,用于展示股票、外汇等金融产品的价格走势。它直观地呈现了开盘价、收盘价、最高价和最低价等信息,帮助投资者分析市场趋势…

“第六感”真的存在吗?

现在已有证据表明,人类除视觉、听觉、嗅觉、味觉和触觉五种感觉以外,确实存在“第六感” “第六感”的学术名称为“超感自知觉”(简称ESP),它能透过正感官之外的渠道接收信息, 预知将要发生的事,而且与当事人之前的经…

探索Figma:下载流程及使用前准备

Figma 是基于浏览器的 UI 设计合作工具。无需下载,打开浏览器使用。虽然更建议直接在浏览器中使用 Figma,但是如果确实需要下载 Figma 客户端,可以直接在 Figma 官网的 Products > Downloads 页面下载。如果你不能访问 Figma 官网&#xf…