PyTorch搭建LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                          shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()

classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):
    img = img / 2 + 0.5
    nping = img.numpy()
    plt.imshow(np.transpose(nping, (1, 2, 0)))
    plt.show()


# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数

# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)

#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):

    running_loss = 0.0  # 用来累加在学习过程中的损失
    for step, data in enumerate(trainloader, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()   # 历时损失梯度清零。

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失
        loss.backward()
        optimizer.step()  # step()函数实现参数更新

        # print statistics  打印数据的过程
        running_loss += loss.item()
        if step % 500 == 499:  # 每隔500步打印一次数据的信息
            with torch.no_grad():  # 上下文管理器
                outputs = net(test_image)
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished Training')

# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/445151.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【亲测有效】解决三月八号ChatGPT 发消息无响应!

背景 今天忽然发现 ChatGPT 无法发送消息,能查看历史对话,但是无法发送消息。 可能的原因 出现这个问题的各位,应该都是点击登录后顶部弹窗邀请 [加入多语言 alapha 测试] 了,并且语言选择了中文,抓包看到 ab.chatg…

Flutter 开发环境搭建-VS Code篇

1.准备环境 Java SDK 下载及安装Flutter SDK 安装及配置环境变量 下载地址将flutter sdk解压目录下的bin目录放到系统环境变量中 检查环境,在系统终端中输入: # 打印flutter sdk版本号 flutter --version# 检查flutter运行环境 flutter doctor第一次运…

qt 格式化打印 日志 QMessagePattern 格式词法语法及设置

一、qt源码格式化日志 关键内部类 QMessagePattern qt为 格式化打印日志 提供了一个简易的 pattern(模式/格式) 词法解析的简易的内部类QMessagePattern,作用是获取和解析自定义的日志格式信息。 该类在qt的专门精心日志操作的源码文件Src\qtbase\src\corelib\global\qloggi…

专题1 - 双指针 - leetcode 11. 盛最多水的容器

leetcode 11. 盛最多水的容器 1. leetcode 11. 盛最多水的容器1. 题目详情1. 原题链接2. 基础框架 2. 解题思路1. 题目分析2. 算法原理3. 时间复杂度 3. 代码实现4. 知识与收获 1. leetcode 11. 盛最多水的容器 1. 题目详情 给定一个长度为 n 的整数数组 height 。有 n 条垂线…

软考攻略/软考详解/软考等级/软考科目

目录 前言 一、软考是什么 二、证书样式 三、软考介绍 3.1 什么是软考? 3.2 通过了软考,就算有职称了么? 3.3 哪些人可以参加软考? 3.4 软考设置了哪些资格? 3.5 哪些资格含金量比较高呢?报考建议? 四、中级资格推荐以下几个: 计算机软件类 --软件…

【AI视野·今日NLP 自然语言处理论文速览 第八十二期】Tue, 5 Mar 2024

AI视野今日CS.NLP 自然语言处理论文速览 Tue, 5 Mar 2024 (showing first 100 of 175 entries) Totally 100 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Key-Point-Driven Data Synthesis with its Enhancement on Mathematica…

农场管理小程序|基于微信小程序的农场管理系统设计与实现(源码+数据库+文档)

农场管理小程序目录 目录 基于微信小程序的农场管理系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、用户信息管理 2、农场信息管理 3、公告信息管理 4、论坛信息管理 四、数据库设计 五、核心代码 七、最新计算机毕设选题推荐 八、源码获取&#x…

Day22:安全开发-PHP应用留言板功能超全局变量数据库操作第三方插件引用

目录 开发环境 数据导入-mysql架构&库表列 数据库操作-mysqli函数&增删改查 数据接收输出-html混编&超全局变量 第三方插件引用-js传参&函数对象调用 完整源码 思维导图 PHP知识点: 功能:新闻列表,会员中心&#xff0…

Stable Diffusion 模型下载:ZavyChromaXL(现实、魔幻)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 作者述:该模型系列应该是用于 SDXL 的 ZavyMix SD1.5 模型的延续。主要重点是获…

【工具】Git的24种常用命令

相关链接 传送门&#xff1a;>>>【工具】Git的介绍与安装<< 1.Git配置邮箱和用户 第一次使用Git软件&#xff0c;需要告诉Git软件你的名称和邮箱&#xff0c;否则无法将文件纳入到版本库中进行版本管理。 原因&#xff1a;多人协作时&#xff0c;不同的用户可…

M1电脑 Xcode15升级遇到的问题

遇到四个问题 一、模拟器下载经常报错。 二、Xcode15报错: SDK does not contain libarclite 三、报错coreAudioTypes not found 四、xcode模拟器运行一次下次必定死机 一、模拟器下载经常报错。 可以https://developer.apple.com/download/all/?qios 下载最新的模拟器&…

Skywalking

1、简介 Skywalking是由国内开源爱好者吴晟开源并提交到Apache孵化器的开源项目&#xff0c; 2017年12月SkyWalking成为Apache国内首个个人孵化项目&#xff0c; 2019年4月17日SkyWalking从Apache基金会的孵化器毕业成为顶级项目&#xff0c; 目前SkyWalking支持Java、 .Net、 …

lvs集群中NAT模式

群集的含义 由多台主机构成&#xff0c;但对外表现为一个整体&#xff0c;只提供一个访问入口&#xff0c;相当于一台大型的计算机。 横向发展:放更多的服务器&#xff0c;有调度分配的问题。 垂直发展&#xff1a;升级单机的硬件设备&#xff0c;提高单个服务器自身功能。 …

复盘-PPT

调整PPT编号起始页码在设计→幻灯片大小 设置所有以及文本项目符号 ## 打开母版&#xff0c;找到对应级别设置重置 当自动生成的smartart图形不符合预期时

Python Web应用程序构建的最佳实践:代码实例与深度解析【第122篇—装饰器详解】

Python Web应用程序构建的最佳实践&#xff1a;代码实例与深度解析 在当今数字时代&#xff0c;构建高效、可扩展的Web应用程序是开发者们的一项重要任务。Python&#xff0c;作为一种简洁、强大的编程语言&#xff0c;为Web开发提供了丰富的工具和框架。在本篇文章中&#xff…

递增三元组 刷题笔记

题意为 若存在 a中的数小于b中的数&#xff0c;b中的数小于c中的数 则该数算一种方案 思路 暴力模拟优化 两层循环遍历即可 从b到c的过程我们发现 第三层并不需要循环 直接加上 大于b的数量即可 那么第一层和第三层是对称的 我们有没有可能再去掉一层循环 只做一次遍历 …

使用51单片机控制lcd1602字体显示

部分效果图&#xff1a; 准备工作&#xff1a; 51单片机&#xff08;BST&#xff09;1602显示屏 基础知识&#xff1a; 注&#xff1a;X表示可以是0&#xff0c;也可以是1&#xff1b; DL 1&#xff0c; N 1&#xff0c; F 0&#xff0c; 代码一&#xff1a; 要求显示字母…

NIN网络中的网络

是什么 intro LeNet→AlexNet→VGG→NiN→GoogLeNet→ResNetLeNet→AlexNet→VGG 卷积层模块充分抽取空间特征全连接层输出分类结果AlexNet & VGG 改进在于把两个模块加宽 、加深&#xff08;加宽指增加通道数&#xff0c;那加深呢&#xff1f;&#xff08;层数增加叭 Ni…

【STM32】HAL库 CubeMX 教程 --- 通用定时器 TIM2 定时

实验目标&#xff1a; 通过CUbeMXHAL&#xff0c;配置TIM2&#xff0c;1s中断一次&#xff0c;闪烁LED。 一、常用型号的TIM时钟频率 1. STM32F103系列&#xff1a; 所有 TIM 的时钟频率都是72MHz&#xff1b;F103C8不带基本定时器&#xff0c;F103RC及以上才带基本定时器。…

ClickHouse Grafana插件4.0版 - 升级SQL可观测性

本文字数&#xff1a;3396&#xff1b;估计阅读时间&#xff1a;9 分钟 作者&#xff1a;Clickhouse team 审校&#xff1a;庄晓东&#xff08;魏庄&#xff09; 本文在公众号【ClickHouseInc】首发 自2022年5月Grafana的ClickHouse插件首次发布以来&#xff0c;两种技术都有了…