Pytorch入门实战 P2-CIFAR10彩色图片识别

目录

一、前期准备

1、数据集CIFAR10

2、判断自己的设备,是否可以使用GPU运行。

3、下载数据集,划分好训练集和测试集

4、加载训练集、测试集

5、取一个批次查看下

6、数据可视化

二、搭建简单的CNN网络模型

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、模型训练结果可视化

五、模型训练结果:


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这周的实战内容,主要使用的数据集是CIFAR10数据集。用来验证彩色图片的识别。

一、前期准备

1、数据集CIFAR10

我们使用的数据集的文档地址:Datasets — Torchvision 0.17 documentation

简单介绍下CIFAR10数据集:

CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。

50000张训练图像10000张测试图像

2、判断自己的设备,是否可以使用GPU运行。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

3、下载数据集,划分好训练集和测试集

import torchvision.datasets

# 下载训练集
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 下载测试集
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)

4、加载训练集、测试集

# 使用dataloader加载数据集,并设置好batch_size
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,
                                       shuffle=True,
                                       batch_size=batch_size)
test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)

5、取一个批次查看下

# 取一个批次,查看下数据
imgs,labels = next(iter(train_dl))
print(imgs.shape)   #  数据的shape为:[batch_size,channel,height,weight]  
'''
    对于CIFAR10,这里的shape是 [32,3,32,32],即 因为取得是train_dl的数据,batch_size为32;
    channel为3是因为,是彩色图片RGB的3通道,如果是黑白图片,则channel为1;剩下的32x32是高度和宽度;
'''

6、数据可视化

即:展示下取到的数据。

# 数据可视化
plt.figure(figsize=(20,5))
for i, imgs in enumerate(imgs[:20]):
    npimg = imgs.numpy().transpose((1,2,0))   
            #.numpy()用于将Tensor转换为一个Numpy数组。transpose是Numpy数组的一个方法,用于重新排列数组的维度。
    plt.subplot(2, 10, i+1)
    plt.imshow(npimg, cmap=plt.cm.binary)
    plt.axis('off')
plt.show()

运行结果展示: 

二、搭建简单的CNN网络模型

 CNN(卷积神经网络),需要注意其结构、层与层之间的连接关系以及各层的功能。

①卷积层:负责提取特征。(通常使用局部连接权值共享方式,这有助于减少网络的参数数量和计算复杂度。)

②池化层:负责降低数据的空间尺寸和计算复杂度。

③全连接层:负责将提取的特征映射到输出类别。

# 构建简单的CNN网络
num_classes = 10
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        # 特征提取
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2)

        # 分类网络
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)

    # 前向传播
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        x = torch.flatten(x, start_dim=1)  # 线性层+激活函数  是构建复杂模型的基础
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 打印并加载模型
model = Model().to(device)
print(model)

三、训练模型

1、设置超参数

# 1、设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2   #学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)   # 定义一个随机梯度下降优化器,即SGD优化器。
                    # model.parameters() 返回模型中所有可训练的参数(通常是权重和偏置)

2、编写训练函数

# 2、编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) # 数据集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目 1875 (60000/32 = 1875)

    train_loss, train_acc = 0, 0   # 初始化训练的损失和正确率
    for X,y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,y为真实值,计算二者差值,即为损失。

        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

3、编写测试函数

# 3、编写测试函数
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 数据集的大小,共10000张
    num_batches = len(dataloader)  # 批次数目 ,313( 10000/32 = 321.5 ,向上取整)

    test_loss, test_acc = 0, 0  # 初始化测试的损失和精确

    # 不进行训练时,停止梯度下降,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

        test_acc /= size
        test_loss /= num_batches
        return test_acc, test_loss

4、正式训练

# 4、正式训练
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []

'''
     model.train()和model.eval() 是深度学习中常见的两个方法,它们用于设置模型的训练模式和评估模式。
        ①当你调用model.train()时,你正在告诉模型你即将进入训练阶段。通常意味着模型中的某些层(如Dropout层和BatchNormalization层)会改变它们的行为以适应训练过程。
            Dropout层:在训练模式下,Dropout层会随机将一部分神经元的输出设置为0,有助于防止过拟合。
            BatchNormalization层:在训练模式下,BatchNoralization层会使用当前批次的数据来更新其运行均值和方差,并应用这些统计量来标准化输入。
        ②当你调用model.eval()时,你正在告诉模型你即将进入评估或推断阶段。在这种模式下,模型的某些层会改变它们的行为,以确保在评估时模型给出一致的结果。
'''
for epoch in range(epochs):
    model.train()  # 进入训练阶段
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch+1, epoch_train_acc*100,epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Finish')

四、模型训练结果可视化

# 四、结果可视化
warnings.filterwarnings('ignore')   # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100    # 分辨率

epochs_range = range(epochs)  # 生成从0到epoches-1的整数序列

plt.figure(figsize=(12,3))  # figsize=(12,3)  包含两个元素的元组,分别代表图形的宽度和高度,单位是英寸。

plt.subplot(1,2,1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')

# 在远程服务器上面跑代码,想要保存下,plt.show()的结果,打下下面的注释
# plt.savfig('想要保存的服务器的地址+图片的名称.png/jpg自行定义即可')  
# eg:plt.savefig('/data/jupyter/deepinglearning/resultImg.jpg')

plt.show()
print("画图结束。。。")

五、模型训练结果:

这周和上周的代码类似,但是,比起刚开始的时候,好多代码都清晰了很多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/456292.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深入理解设计模式】命令设计模式

命令设计模式: 命令模式(Command Pattern)是一种行为型设计模式,它将请求封装为一个对象,从而使你可以用不同的请求对客户端进行参数化,对请求排队或记录请求日志,以及支持可撤销的操作。 概述…

onecloud刷CasaOS系统后如何安装内网穿透实现公网访问本地文件

文章目录 1. CasaOS系统介绍2. 内网穿透安装3. 创建远程连接公网地址4. 创建固定公网地址远程访问 2月底,玩客云APP正式停止运营,不再提供上传、云添加功能。3月初,有用户进行了测试,局域网内的各种服务还能继续使用,但…

国产化兼容问题与解决办法: java.lang.ClassNotFoundException: javafx.util.Pair

先说解决办法:找一个大版本相同的jdk将/jre/lib/ext中的所有jar包放到服务器jdk相同路径下,跳过相同名称. 下面是详细的问题分析,感觉啰嗦或者没有用,可以直接关闭 运行环境: 服务器:麒麟v10.x86_64 jdk:BiSheng (build 1.8.0_402-b11) 问题描述: 将程序部署在国产化服务器…

STC89C52单片机 启动!!!(一)

跑马灯实现 直接上代码 #include<regx52.h> sbit D1P2^0; sbit D2P2^1; sbit D3P2^2; sbit D4P2^3; sbit D5P2^4; sbit D6P2^5; sbit D7P2^6; sbit D8P2^7; void delay(int num){while(num--){} } void led_running(){//从第1盏灯到第8盏灯依次点亮D10;delay(40000);D2…

unity2D生成9*9格子

1.创建一个空对象和格子 2将格子做成预制体&#xff08;直接将格子拖到这里即可&#xff0c;拖了过后删掉原来的格子&#xff09; 3.创建脚本并将脚本拖到空对象上 using System.Collections; using System.Collections.Generic; using UnityEngine;public class CreateMap : M…

2024年雪糕线上市场未来发展趋势分析(2024京东淘宝天猫雪糕数据分析报告)

据相关媒体报道&#xff0c;北京多位雪糕批发商称钟薛高停产了&#xff0c;从年前开始就已经不供货了。还有记者实探钟薛高的北京总部&#xff0c;发现有不少人离职&#xff0c;办公区内仅剩零星几人。 从60元到2.5元&#xff0c;钟薛高在这两年经历了不少风波&#xff0c;终究…

鸿蒙不再适合JS语言开发

ArkTS是鸿蒙生态的应用开发语言。它在保持TypeScript&#xff08;简称TS&#xff09;基本语法风格的基础上&#xff0c;对TS的动态类型特性施加更严格的约束&#xff0c;引入静态类型。同时&#xff0c;提供了声明式UI、状态管理等相应的能力&#xff0c;让开发者可以以更简洁、…

mysql事务(MVCC机制:undo日志)(mysql执行过程:redo日志,Buffer Pool缓存池)

事务 目的&#xff1a;保证数据的最终一致性## 事务的目的 事务的4大特性&#xff08;ACID&#xff09; 1.原子性(Atomicity):由undo log日志来保证 2.一致性(Consistency):使用事务的最终目的&#xff0c;由业务代码正确逻辑保证,比如错误的try-catch 3.隔离性(Isolation):…

在任意一个文件下,进入cmd

直接在界面上输入cmd&#xff0c;回车就出来了

安卓六大布局

LinearLayout&#xff08;线性布局&#xff09; 1.简介 线性布局在开发中使用最多&#xff0c;具有垂直方向与水平方向的布局方式。LinearLayout 默认是垂直排列的&#xff0c;但是可以通过设置 android:orientation 属性来改变为水平排列。 2.常用属性 orientation&#xf…

Windows系统下载安装Emby结合内网穿透实现公网访问本地影音网站

文章目录 1.前言2. Emby网站搭建2.1. Emby下载和安装2.2 Emby网页测试 3. 本地网页发布3.1 注册并安装cpolar内网穿透3.2 Cpolar云端设置3.3 Cpolar内网穿透本地设置 4.公网访问测试5.结语 1.前言 在现代五花八门的网络应用场景中&#xff0c;观看视频绝对是主力应用场景之一&…

3.2 RK3399项目开发实录-初次使用的环境搭建(物联技术666)

通过百度网盘分享的文件&#xff1a;嵌入式物联网单片… 链接:https://pan.baidu.com/s/1Zi9hj41p_dSskPOhIUnu9Q?pwd8qo1 提取码:8qo1 复制这段内容打开「百度网盘APP 即可获取」 1. 用户和密码 1.1. Ubuntu Desktop 系统 Ubuntu Desktop 系统开机启动后&#xff0c;自动登录…

权限管理和操作指令

文章目录 前言一、文件的权限分类二、操作时无相应权限解决办法1.使用sudo指令2.修改文档权限 总结 前言 &#x1f4a6; Linux操作系统中&#xff0c;主要都是对文件进行操作&#xff0c;完成读写或者执行功能。Ubuntu 下我们会常跟用户权限打交道&#xff0c;权限就是用户对于…

python操作dataframe--打乱df的顺序

在Python中&#xff0c;可以使用Pandas库来操作DataFrame。要打乱DataFrame的顺序&#xff0c;可以使用sample方法来实现。以下是一个示例代码&#xff1a; import pandas as pd# 创建一个示例DataFrame data {A: [1, 2, 3, 4, 5],B: [10, 20, 30, 40, 50]} df pd.DataFrame…

为什么ERP与MES集成那么难搞?怎么有效解决这一难题

在现代企业信息化进程中&#xff0c;ERP&#xff08;企业资源规划&#xff09;和MES&#xff08;制造执行系统&#xff09;作为企业管理的核心信息系统&#xff0c;它们之间的深度集成是提升生产效率、实现精益管理和智能决策的关键环节。然而&#xff0c;ERP与MES集成并非易事…

【Python】成功解决NameError: name ‘sns‘ is not defined

【Python】成功解决NameError: name ‘sns’ is not defined &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f448; 希望得到您…

1个二维码能包含多个视频吗?制作视频二维码的方法

二维码在生活中现在随处可见&#xff0c;除了用于支付之外&#xff0c;展示内容也可以通过二维码来展现&#xff0c;比如常见的视频、图片、文件、音频等内容都可以通过二维码来展现。那么当我们需要将多个视频存入一个二维码中展示时&#xff0c;该如何利用二维码生成器的工具…

开发知识点-python-Tornado框架

介绍 Tornado是一个基于Python语言的高性能Web框架和异步网络库&#xff0c;它专注于提供快速、可扩展和易于使用的网络服务。由于其出色的性能和灵活的设计&#xff0c;Tornado被广泛用于构建高性能的Web应用程序、实时Web服务、长连接的实时通信以及网络爬虫等领域。 Torna…

jmeter接口自动化测试通过csv文件读取用例并执行测试

最近在公司测试中经常使用jmeter这个工具进行接口自动化&#xff0c;简单记录下~ 一、在csv文件中编写好用例 首先在csv文件首行填写相关参数&#xff08;可根据具体情况而定&#xff09;并编写测试用例。脚本可通过优先级参数控制执行哪些接口&#xff0c;通过端口参数同时执…

leetcode110.平衡二叉树

之前没有通过的样例 return语句只写了一个 return abs(l-r)<1缺少了 isBalanced(root->left)&&isBalanced(root->right);补上就好了 class Solution { public:bool isBalanced(TreeNode* root) {if(!root){return true;}int lgetHeight(root->left);i…