23/76-LeNet

LeNet
早期成功的神经网络。
先使用卷积层来学习图片空间信息。
然后使用全连接层转换到类别空间。

在这里插入图片描述

#In[]
'''
LeNet,上世纪80年代的产物,最初为了手写识别设计
'''
from d2l import torch as d2l
import torch 
from torch import nn
from torch.nn.modules.loss import CrossEntropyLoss

from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import Common_functions


'''
LeNet:
两个卷积层,两个池化层,三个线性层
假定为MNIST设计,输入为(batch_size,1,28,28)
'''

class Reshape(torch.nn.Module):
    def forward(self,x):
        return x.view(-1,1,28,28)

net = nn.Sequential(
    nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),padding=2),nn.Sigmoid(), #输出:(6,28,28)
    nn.AvgPool2d(kernel_size=(2,2)), #不指定stride默认不重叠 输出(6,14,14)
    nn.Conv2d(6,16,kernel_size=(5,5)),nn.Sigmoid(),#输出(16,10,10)
    nn.AvgPool2d(kernel_size=(2,2)),#输出(16,5,5)
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.Sigmoid(),#
    nn.Linear(120,84),nn.Sigmoid(),
    nn.Linear(84,10)
)


X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
    X=layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)

#In[]


batch_size = 256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)



#对evaluate_accuracy函数进行轻微修改
#使用GPU计算模型在数据集上的精度
#计算网络在测试数据集上面的准确率
#由于完整的测试数据集位于内存中,因此在模型使用GPU预测测试数据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):
    if isinstance(net,nn.Module):
        net.eval() #网络用于测试数据
        if not device:
            device = next(iter(net.parameters())).device #如果没有指定device设备,device设备则使用第一层网络参数的设备
    accumulator = d2l.Accumulator(2) #累加器里面包含两个元素
    for X,y in data_iter:
        if isinstance(X,list):
            X = [x.to(device) for x in X] #X为list类型时,需要加X里面每个元素都复制到device设备上面来
        else:
            X = X.to(device)
        y = y.to(device)
        accumulator.add(d2l.accuracy(net(X),y),y.numel()) #累加器第一个元素为在每一个batch_size中预测准确的个数,第二个元素为每一个batch_size中样本总数目,然后依次循环累加,得到测试数据集上面预测准确的总数目,以及数据集总数目
    return accumulator[0]/accumulator[1] #算出模型预测准确率


def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):#手动初始化模型参数
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight) #使用xavier_uniform分布初始化参数
    net.apply(init_weights)
    net.to(device)#将模型复制到gpu上面
    print('training on',device)
    loss = nn.CrossEntropyLoss() #定义loss
    optim = torch.optim.SGD(net.parameters(),lr=lr) #定义优化器
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train_loss','train_acc','test_acc'])
    timer = d2l.Timer()
    num_batches = len(train_iter)
    for epoch in range(num_epochs):
        net.train()#模型开始训练,需要放在第一层循环里面,因为后面evaluate_accuracy_gpu()函数里面有net.eval(),将模型改变为测试状态,因此需要在每一个循环epoch后面手动再加上模型开始处于训练状态
        accumulator = d2l.Accumulator(3) #累加器
        for i,(X,y) in enumerate(train_iter):
           timer.start()
           optim.zero_grad()
           X = X.to(device)#将X复制到gpu上面
           y = y.to(device) #将y复制到gpu上面
           y_hat = net(X) #得到模型训练后的输出标签y_hat
           l = loss(y_hat,y)#计算每一个batch_size的loss
           l.backward() #计算梯度
           optim.step() #使用优化器更新模型参数
           with torch.no_grad():#不需要模型梯度
               accumulator.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])
           timer.stop()
           train_loss = accumulator[0]/accumulator[2] #从累加器里面获得所有训练集的loss之和
           train_acc = accumulator[1]/accumulator[2] #从累加器里面获得所有训练集的准确数之和
           if (i+1) % (num_batches // 5) == 0 or i == num_batches-1:
               animator.add(epoch+(i+1)/num_batches,(train_loss,train_acc,None))
        test_accuracy = evaluate_accuracy_gpu(net,test_iter) #每次训练完一个epoch后的模型用于测试数据集上面计算测试精确度
        animator.add(epoch+1,(None,None,test_accuracy))
    print(f'模型训练完最后一轮时 train_loss:{train_loss},train_acc:{train_acc},test_acc:{test_accuracy}')
    print(f'{num_epochs*accumulator[2]/timer.sum()}examples/second on {str(device)}')#打印出模型每秒能处理多少个样本数

lr,num_epochs= 0.9,10
train_ch6(net,train_iter=train_iter,test_iter=test_iter,lr=lr,num_epochs=num_epochs,device=d2l.try_gpu())
'''
输出结果:
模型训练完最后一轮时 train_loss:0.4322478462855021,train_acc:0.8396666666666667,test_acc:0.8163
55954.65804440994examples/second on cuda:0
'''








#训练
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device = torch.device(device)

Common_functions.train_device(net,train_iter,test_iter,lr=0.9,device=device)
# %%

plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/328007.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Transformer详解(附代码实现及翻译任务实现)

一:了解背景和动机 阅读Transformer论文: 阅读原始的Transformer论文:“Attention is All You Need”,由Vaswani等人于2017年提出,是Transformer模型的开创性工作。 二:理解基本构建块 注意力机制&#…

C++面试宝典第21题:字符串解码

题目 给定一个经过编码的字符串,返回其解码后的字符串。具体的编码规则为:k[encoded_string],表示方括号内部的encoded_string正好重复k次。注意:k保证为正整数;encoded_string只包含大小写字母,不包含空格和数字;方括号确定是匹配的,且可以嵌套。 示例: 编码字符串为…

tcpdump常用参数以及wireshark密文解密

tcpdump常用参数以及wireshark密文解密 文章目录 一、tcpdump命令和常用参数二、在wireshark中协议解析 tcpdump常用参数 一、tcpdump命令和常用参数 tcpdump常用命令:tcpdump -i eth0 src host 11.6.224.1 and udp port 161 -s 0 -w 161.pcap (161为sn…

开发知识点-JAVA-springboot

springboot springbootConfiguration注解的底层核心原理Bean注解的底层核心原理 springboot Configuration注解的底层核心原理 https://www.bilibili.com/video/BV1rq4y1E7gK/?spm_id_from333.999.0.0&vd_sourcef21773b7086456ae21a58a6cc59023be spring.io 全家桶 24…

Mysql 数据库DDL 数据定义语言——数据库,数据表的创建

DDL:数据定义语言,用来定义数据库对象(数据库,表,字段)—Database Definition Language 1、登录数据库,输入用户名和密码 mysql -ufdd -p990107Wjl2、查看数据库 show databases;3、创建一个…

主流人工智能AI工具测评

主流人工智能AI工具测评 主流的人工智能AI工具ChatGPT ——OpenAI研发CHAT_BISON——Google研发Qwen通义千问 ——阿里云研发文心一言——百度研发 根据10个问题分析人工智能的回答女朋友生气了怎么哄千元机性价比推荐小米13 和 vivo iQOO 11s哪个好计算机专业毕业论文护士年终…

node.js(expree.js )模拟手机验证码功能及登录功能

dbconfig.js const mysql require(mysql) module.exports {// 数据库配置config: {host: localhost, // 连接地址port: 3306, //端口号user: root, //用户名password: wei630229, //密码database: exapp2, //数据库名}, // 连接数据库,使用mysql的连接池连接方式…

transfomer的位置编码

什么是位置编码 在transformer的encoder和decoder的输入层中,使用了Positional Encoding,使得最终的输入满足: input_embeddingpositional_encoding 这里,input_embedding的shape为[n,b,embed_dim],positional_encoding和input_…

xilinxi mulitboot 启动

xilinix在线更新有两种方式,一种是使用ICAP原语,另一中是在xdc中约束,根据应用的场景不同,选用不同的启动方式,第二种更为简单。 可参考官方提供的手册和实例 XAPP1247 链接: XAPP1247 golden和updata.b…

VitePress-01-从零开始的项目创建(npm版)

说明 本文介绍一下 VitePress的项目创建的步骤。 主要用到的命令工具是 npm。 本文的操作步骤是从无到有的创建一个完整的基本的【VitePress】项目。 环境准备 根据官方文档的介绍,截止本文发稿时,需要使用node.js 18 的版本。 可以使用node -v 的命令查…

muduo网络库剖析——通道Channel类

muduo网络库剖析——通道Channel类 前情从muduo到my_muduo 概要事件种类channel 框架与细节成员函数细节实现使用方法 源码结尾 前情 从muduo到my_muduo 作为一个宏大的、功能健全的muduo库,考虑的肯定是众多情况是否可以高效满足;而作为学习者&#x…

Asp .Net Core 系列:集成 Ocelot+Consul实现网关、服务注册、服务发现

什么是Ocelot? Ocelot是一个开源的ASP.NET Core微服务网关,它提供了API网关所需的所有功能,如路由、认证、限流、监控等。 Ocelot是一个简单、灵活且功能强大的API网关,它可以与现有的服务集成,并帮助您保护、监控和扩展您的微…

RTSP/Onvif安防视频监控平台EasyNVR漏洞扫描及解决方法

视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。安防视频监控平台可提供视频实时监控直播、云端录像、云存储、录像检索与回看、告警等视频…

【前端架构】前端通用架构

一个强大的前端通用架构应该具备多种能力,以支持现代化的应用程序开发和提高开发效率。以下是一些前端通用架构应该具备的关键能力: 模块化和组件化:支持模块化开发和组件化架构,能够将应用拆分为独立的模块和组件,以便…

外观模式(结构型)

目录 一、前言 二、外观模式 三、总结 一、前言 外观模式(Facade Pattern)是一种结构型设计模式,它为系统中的一组复杂子系统提供一个简单的接口,从而隐藏了这些子系统的复杂性,并且使得代码更加易于使用和理解。 外…

XUbuntu22.04之免费开源DesktopNaotu脑图(二百零七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

5个99%的人可能不知道的实用程序库!

前言 作为一名前端开发者,这些 JavaScript 库极大地提高了我的工作效率,如格式化日期、处理 URL 参数和调试移动网页。朋友们,我想和你们分享这些库。 1. 使用 “Day.js” 来格式化日期和时间 链接 作为开发者,我已经厌倦了在 JavaScript 中操作日期和时间,因为它太麻烦了。…

MySQL的多表数据记录查询笔记

关系数据操作 合并查询数据记录 在MySQL中通过关键字UNION来实现并操作,即可以通过其将多个SELECT语句的查询结果合并在一起组成新的关系。 两张表,表1 和表2 带有关键字UNION的合并操作 关键字UNION会把查询结果集直接合并在一起,同时将…

vite和webpack的区别

1 构建原理 Webpack 是一个静态模块打包器,通过对项目中的 JavaScript、CSS、图片等文件进行分析,生成对应的静态资源,并且可以通过一些插件和加载器来实现各种功能。Webpack 的主要特点是支持各种复杂的构建场景,例如代码分割、…

vs2022配置OpenCV测试

1,下载Opencv安装包 OpenCV官网下载地址:Releases - OpenCV 大家可以按需选择版本进行下载,官网下载速度还是比较慢的,推荐大家使用迅雷进行下载 下载安装包到自定义文件夹下 双击安装 按以下图示进行安装 2、 添加环境变量 打…