LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                          shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()

classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):
    img = img / 2 + 0.5
    nping = img.numpy()
    plt.imshow(np.transpose(nping, (1, 2, 0)))
    plt.show()


# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

for epoch in range(5):

    running_loss = 0.0
    for step, data in enumerate(trainloader, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fuction(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():
                outputs = net(test_image)
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished Training')

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434472.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

npm 私服以及使用

在工作中,公司有很多内部的包并不希望发布到npm官网仓库,因为可能涉及到一些私有代码不能暴露。对于前端来讲,这时就可以选择在公司内网搭建npm私有仓库。当前比较主流的几种解决方案:verdaccio、nexus、cnpm。大家可以按照自己的…

限时特惠,立即购买CST电磁仿真软件,享受超值优惠与高质量服务

在电磁仿真领域,CST软件以其卓越的性能和广泛的应用成为了行业内的佼佼者。现在,我们非常高兴地宣布,为庆祝CST软件的持续创新与客户支持,我们推出限时特惠活动!(联系CST软件中国区代理亿达四方&#xff0c…

RabbitMQ如何保证消息不丢

如何保证Queue消息能不丢呢? RabbitMQ在接收到消息后,默认并不会立即进行持久化,而是先把消息暂存在内存中,这时候如果MQ挂了,那么消息就会丢失。所以需要通过持久化机制来保证消息可以被持久化下来。 队列和交换机的…

React Vite 构建工具如何查看代码占用体积

首先安装 Vite 中的 rollup-plugin-visualizer 插件 cnpm install rollup-plugin-visualizer 接着在你的 vite.config.ts 中引入并且使用到 plugins 中 import { visualizer } from "rollup-plugin-visualizer";export default defineConfig({plugins: [react(),vi…

现代信号处理学习笔记(三)现代谱估计

现代谱估计是信号处理和统计领域的一个重要主题,它涉及从信号中估计其频谱内容的方法。频谱表示一个信号在不同频率上的成分强度。谱估计在许多应用中都很重要,如通信系统、雷达、音频处理、生物医学工程等领域。 目录 前言 一、基础知识 1、功率谱估…

同步通信和异步通信(RabbitMq学习前篇)

MQ学习前篇 文章目录 MQ学习前篇1、同步和异步通讯1.1、同步通讯和异步通讯1.2、同步调用存在的问题1.3、异步调用方案1.4、异步通信的缺点 1、同步和异步通讯 学习mq之前,就要先知道同步通讯和异步通讯的区别。 1.1、同步通讯和异步通讯 同步通讯就像是打电话&am…

【网络安全】漏洞挖掘入门教程(非常详细),小白是如何挖漏洞(技巧篇)0基础入门到精通!

温馨提示: 初学者最好不要上手就去搞漏洞挖掘,因为漏洞挖掘需要很多的系统基础知识和一些理论知识做铺垫,而且难度较大…… 较合理的途径应该从漏洞利用入手,不妨分析一些公开的CVE漏洞。很多漏洞都有比较好的资料,分…

Java开发人员不得不收集的代码,java软件开发面试常见问题

前言 今年的金三银四已经过去一大半了,在这其中参与过不少面试,2021都说工作不好找,这也是对开发人员的要求变向的提高了。 之前在Github上收获15Kstar的Java核心神技(这参数,质量多高就不用我多说了吧)非…

C++与 Fluke5500A设备通过GPIB-USB-B通信的经验积累

C与 Fluke5500A设备通过GPIB-USB-B通信的经验积累 以下内容来自:C与 Fluke5500A设备通过GPIB-USB-B通信的经验积累 - JMarcus - 博客园 (cnblogs.com)START 1.需要安装NI-488.2.281,安装好了之后,GPIB-USB-B的驱动就自动安装好了 注意版本…

13. C++类使用方式

【类】 C语言使用函数将程序代码模块化,C通过类将代码进一步模块化,类用于将实现一种功能的全局数据、以及操作这些数据的函数集中存储在一起,同时可以设置类成员的访问权限,禁止外部代码使用和篡改本类成员,类成员访…

SMT贴片加工——品质检验要求

一、元器件贴装工艺品质要求 1.元器件贴装需整齐、正中,无偏移、歪斜 2.贴装的元器件型号规格应正确;元器件应无漏贴、错贴 3.贴片元器件不允许有反贴 4.有极性要求的贴片器件安装需按正确的极性标示安装 二、元器件焊锡工艺要求 1.FPC板面应无影响…

Java实现读取转码写入ES构建检索PDF等文档全栈流程

背景 之前已简单使用ES及Kibana和在线转Base64工具实现了检索文档的demo,并已实现WebHook的搭建和触发流程接口。 传送门: 基于GitBucket的Hook构建ES检索PDF等文档全栈方案 使用ES检索PDF、word等文档快速开始 实现读取本地文件入库ES 总体思路&…

真空展|2024上海国际真空技术及设备展览会

2024上海国际真空技术及设备展览会 2024 Shanghai International Exhibition of vacuum technology and equipment 时 间:2024年7月13-15日 地 点:上海新国际博览中心 承办单位:上海昶文展览服务有限公司 展会简…

JVM内部世界(内存划分,类加载,垃圾回收)

💕"Echo"💕 作者:Mylvzi 文章主要内容:JVM内部世界(内存划分,类加载,垃圾回收) 关于JVM的学习主要掌握三方面: JVM内存区的划分类加载垃圾回收 一.JVM内存区的划分 当一个Java进程开始执行时,JVM会首先向操作系统申…

CSS常见布局方式

一、静态布局(Static Layout) 既传统web设计 就是不管浏览器尺寸多少,网页布局就按当时写代码的布局来布置; 块级元素:每个块级元素会在上一个元素下面另起一行,他们会被设置好的margin分离。块级元素是垂直组织的。 …

SQL Server 安装部署

SQL Server是运行在win server上的关系型数据库,本文主要描述SQL Server的安装部署。 如上所示,SQL Server不同版本的定义以及特性 如上所示,SQL Server不同版本支持的容量规模 如上所示,SQL Server不同版本支持的高可用性 如上所…

能源大数据采集,为您提供专业数据采集服务

随着经济的不断发展,能源产业也逐渐成为国民经济的支柱产业之一。而对于能源行业来说,数据采集是一项至关重要的工作。以往,能源企业采集数据主要依靠人工收集、整理,但是这种方式不仅效率低下,而且容易出现数据不准确…

springcloud:3.7测试线程池服务隔离

服务提供者【test-provider8001】 Openfeign远程调用服务提供者搭建 文章地址http://t.csdnimg.cn/06iz8 相关接口 测试远程调用:http://localhost:8001/payment/index 服务消费者【test-consumer-resilience4j8004】 Openfeign远程调用消费者搭建 文章地址http://t…

ARM中专用指令(异常向量表、异常源、异常返回等)

状态寄存器传送指令 CPSR寄存器 状态寄存器传送指令:访问(读写)CPSR寄存器 读CPSR MRS R1, CPSR R1 CPSR 写CPSR MSR CPSR, #0x10 0x10为User模式,且开启IRQ和FRQ CPSR 0x10 在USER模式下不能随意修改CPSR,因为USER模式…

7.使用os.Args或flag解析命令行参数

文章目录 一、os.Args二、flag包基本使用 Go语言内置的flag包实现了命令行参数的解析,flag包使得开发命令行工具更为简单。 一、os.Args 如果你只是简单的想要获取命令行参数,可以像下面的代码示例一样使用os.Args来获取命令行参数。 package mainimp…