第P2周:Pytorch实现CIFAR10彩色图片识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

  1. 实现CIFAR-10的彩色图片识别
  2. 实现比P1周更复杂一点的CNN网络

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1

(二)具体步骤
1.
import torch  
import torch.nn as nn  
import matplotlib.pyplot as plt  
import torchvision  
  
# 第一步:设置GPU  
def USE_GPU():  
    if torch.cuda.is_available():  
        print('CUDA is available, will use GPU')  
        device = torch.device("cuda")  
    else:  
        print('CUDA is not available. Will use CPU')  
        device = torch.device("cpu")  
  
    return device  
  
device = USE_GPU()  

输出:CUDA is available, will use GPU

  
# 第二步:导入数据。同样的CIFAR-10也是torch内置了,可以自动下载  
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  
                                             transform=torchvision.transforms.ToTensor())  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,  
                                            transform=torchvision.transforms.ToTensor())  
  
  
batch_size = 32  
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)  
  
# 取一个批次查看数据格式  
# 数据的shape为:[batch_size, channel, height, weight]  
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。  
imgs, labels = next(iter(train_dataload))  
print(imgs.shape)  
  
# 查看一下图片  
import numpy as np  
plt.figure(figsize=(20, 5))  
for i, images in enumerate(imgs[:20]):  
    # 使用numpy的transpose将张量(C,H, W)转换成(H, W, C),便于可视化处理  
    npimg = imgs.numpy().transpose((1, 2, 0))  
    # 将整个figure分成2行10列,并绘制第i+1个子图  
    plt.subplot(2, 10, i+1)  
    plt.imshow(npimg, cmap=plt.cm.binary)  
    plt.axis('off')  
plt.show()  

输出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
image.png

# 第三步,构建CNN网络  
import torch.nn.functional as F  
  
num_classes = 10  # 因为CIFAR-10是10种类型  
class Model(nn.Module):  
    def __init__(self):  
        super(Model, self).__init__()  
        # 提取特征网络  
        self.conv1 = nn.Conv2d(3, 64, 3)  
        self.pool1 = nn.MaxPool2d(kernel_size=2)  
        self.conv2 = nn.Conv2d(64, 64, 3)  
        self.pool2 = nn.MaxPool2d(kernel_size=2)  
        self.conv3 = nn.Conv2d(64, 128, 3)  
        self.pool3 = nn.MaxPool2d(kernel_size=2)  
  
        # 分类网络  
        self.fc1 = nn.Linear(512, 256)  
        self.fc2 = nn.Linear(256, num_classes)  
  
    # 前向传播  
    def forward(self, x):  
        x = self.pool1(F.relu(self.conv1(x)))  
        x = self.pool2(F.relu(self.conv2(x)))  
        x = self.pool3(F.relu(self.conv3(x)))  
  
        x = torch.flatten(x, 1)  
  
        x = F.relu(self.fc1(x))  
        x = self.fc2(x)  
  
        return x  
  
from torchinfo import summary  
# 将模型转移到GPU中  
model = Model().to(device)  
summary(model)  

image.png

# 训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)    # 设置优化器  
  
# 编写训练函数  
def train(dataloader, model, loss_fn, optimizer):  
    size = len(dataloader.dataset) # 训练集的大小 ,这里一共是60000张图片  
    num_batches = len(dataloader)   # 批次大小,这里是1875(60000/32=1875)  
  
    train_acc, train_loss = 0, 0    # 初始化训练正确率和损失率都为0  
  
    for X, y in dataloader: # 获取图片及标签,X-图片,y-标签(也是实际值)  
        X, y = X.to(device), y.to(device)  
  
        # 计算预测误差  
        pred = model(X) # 网络输出预测值  
        loss = loss_fn(pred, y) # 计算网络输出的预测值和实际值之间的差距  
  
        # 反向传播  
        optimizer.zero_grad()   # grad属性归零  
        loss.backward() # 反向传播  
        optimizer.step()    # 第一步自动更新  
  
        # 记录正确率和损失率  
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  
        train_loss += loss.item()  
  
    train_acc /= size  
    train_loss /= num_batches  
  
    return train_acc, train_loss  
  
# 测试函数  
def test(dataloader, model, loss_fn):  
    size = len(dataloader.dataset) # 测试集大小,这里一共是10000张图片  
    num_batches = len(dataloader)   # 批次大小 ,这里312,即10000/32=312.5,向上取整  
    test_acc, test_loss = 0, 0  
  
    # 因为是测试,因此不用训练,梯度也不用计算不用更新  
    with torch.no_grad():  
        for imgs, target in dataloader:  
            imgs, target = imgs.to(device), target.to(device)  
  
            # 计算loss  
            target_pred = model(imgs)  
            loss = loss_fn(target_pred, target)  
  
            test_loss += loss.item()  
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  
  
    test_acc /= size  
    test_loss /= num_batches  
  
    return test_acc, test_loss  
  
# 正式训练  
epochs = 10  
train_acc, train_loss, test_acc, test_loss = [], [], [], []  
  
for epoch in range(epochs):  
    model.train()  
    epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt)  
  
    model.eval()  
    epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn)  
  
    train_acc.append(epoch_train_acc)  
    train_loss.append(epoch_train_loss)  
    test_acc.append(epoch_test_acc)  
    test_loss.append(epoch_test_loss)  
  
    template = 'Epoch:{:2d}, 训练正确率:{:.1f}%, 训练损失率:{:.3f}, 测试正确率:{:.1f}%, 测试损失率:{:.3f}'  
    print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  
  
print('Done')  
  
# 结果可视化  
# 隐藏警告  
import warnings  
warnings.filterwarnings('ignore')   # 忽略警告信息  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False  # 正常显示+/-号  
plt.rcParams['figure.dpi'] = 100    # 分辨率  
  
epochs_range = range(epochs)  
  
plt.figure(figsize=(12, 3))  
  
plt.subplot(1, 2, 1)    # 第一张子图  
plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练和测试正确率比较')  
  
plt.subplot(1, 2, 2)    # 第二张子图  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练和测试损失率比较')  
  
plt.show()

# 保存模型  
torch.save(model, './models/cnn-cifar10.pth')

image.png
再次设置epochs为50训练结果:
image.png
epochs增加到100,训练结果:
image.png
可以看到训练集和测试集的差距有点大,不太理想。做一下数据增加试试:

data_transforms= {  
    'train': transforms.Compose([  
        transforms.RandomHorizontalFlip(),  
        transforms.ToTensor(),  
    ]),  
    'test': transforms.Compose([  
       transforms.ToTensor(),  
    ])  
}

在dataset中:

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=data_transforms['train'])  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])

运行结果:
image.png
image.png
比较漂亮了,再调整batch_size=16和epochs=20,提高了近6个百分点。
image.png
batch_size=16,epochs=50:有第20轮左右的时候,验证集的确认性基本就没有再提高了。和上面基本一样。
image.png

(三)总结
  1. epochs并不是越多越好。batch_size同样的道理
  2. 数据增强确实可以提高模型训练的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/935867.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数字花园】数字花园(个人网站、博客)搭建经历汇总教程

目录 写在最最前面第一章:netlify免费搭建数字花园相关教程使用的平台步骤信息管理 第二章:本地部署数字花园数字花园网站本地手动部署方案1. 获取网站源码2.2 安装 Node.js 3. 项目部署3.1 安装项目依赖3.2 构建项目3.3 启动http服务器 4. 本地预览5. 在…

Hadoop一课一得

Hadoop作为大数据时代的奠基技术之一,自问世以来就深刻改变了海量数据存储与处理的方式。本文将带您深入了解Hadoop,从其起源、核心架构、关键组件,到典型应用场景,并结合代码示例和图示,帮助您更好地掌握Hadoop的实战…

使用 GD32F470ZGT6,手写 I2C 的实现

我的代码:https://gitee.com/a1422749310/gd32_-official_-code I2C 具体代码位置:https://gitee.com/a1422749310/gd32_-official_-code/blob/master/Hardware/i2c/i2c.c 黑马 - I2C原理 官方 - IIC 协议介绍 个人学习过程中的理解,有错误&…

WPF Prism ViewInjection

ViewInjection介绍 ViewInjection是Prism框架提供的一种机制,用于将视图动态地注入到指定的容器(Region)中。这种注入方式允许你在运行时动态地添加、移除或替换视图,从而实现更灵活的用户界面设计。 ViewInjection示例 GitHub…

软考高级架构 - 11.1- 信息物理系统CPS

信息物理系统CPS 信息物理系统(CPS)是控制系统、嵌入式系统的扩展与延伸。通过集成先进的感知、计算、通信、控制等信息技术和自动控制技,构建了物理空间与信息空间中人、机、物、环境、信息等要素相互映射、适时交互、高效协同的夏杂系统。 CPS的本质是基于…

后端开发工程师需要掌握哪些设计模式?

大家好,我是袁庭新。 作为后端开发者,学习和掌握设计模式是非常有必要的。不仅可以帮助后端开发者更好地设计和实现软件架构,还可以提高代码的质量和可维护性。此外,设计模式也是后端开发面试中常见的考点之一,掌握它…

【Android Studio】学习——数据存储管理

AndroidStudio实验——数据存储管理 文章目录 AndroidStudio实验——数据存储管理[toc]一:实验目标和实验内容:二:数据库的CRUD操作【一】创建(Create)【2】读取(Read)【3】更新(Upd…

科研绘图系列:R语言绘制热图和散点图以及箱线图(pheatmap, scatterplot boxplot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载图1图2图3系统信息参考介绍 R语言绘制热图和散点图以及箱线图(pheatmap, scatterplot & boxplot) 加载R包 library(magrittr) library(dplyr) library(ve…

【Qt】信号、槽

目录 一、信号和槽的基本概念 二、connect函数:关联信号和槽 三、自定义信号和槽 1.自定义槽函数 2.自定义信号函数 例子: 四、带参的信号和槽 例子: 五、Q_OBJECT宏 六、断开信号和槽的连接 例子: 一、信号和槽的基本…

一种构建网络安全知识图谱的实用方法

文章主要工作 论述了构建网络安全知识库的三个步骤,并提出了一个构建网络安全知识库的框架;讨论网络安全知识的推演 1.框架设计 总体知识图谱框架如图1所示,其包括数据源(结构化数据和非结构化数据)、信息抽取及本体构建、网络…

JAVA后端实现全国区县下拉选择--树形结构

设计图如图&#xff1a; 直接上代码 数据库中的格式&#xff1a; JAVA实体类&#xff1a; Data public class SysAreaZoningDO {private Long districtId;private Long parentId;private String districtName;private List<SysAreaZoningDO> children; } MapperSQL语句…

青少年夏令营管理系统的设计与开发(社团管理)(springboot+vue)+文档

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

安卓低功耗蓝牙BLE官方开发例程(JAVA)翻译注释版

官方原文链接 https://developer.android.com/develop/connectivity/bluetooth/ble/ble-overview?hlzh-cn 目录 低功耗蓝牙 基础知识 关键术语和概念 角色和职责 查找 BLE 设备 连接到 GATT 服务器 设置绑定服务 设置 BluetoothAdapter 连接到设备 声明 GATT 回…

Windows 系统中的组策略编辑器如何打开?

组策略是 Windows 操作系统中用于设置计算机和用户配置的重要工具。它允许管理员控制各种系统功能&#xff0c;从桌面背景到安全设置等。对于 Windows 专业版、企业版和教育版用户来说&#xff0c;可以通过组策略编辑器&#xff08;Group Policy Editor&#xff09;来管理这些设…

MySQL删除外键报错check that column/key exists

在我们删除外键的时候&#xff0c;报了check that column/key exists这个错误&#xff0c;这是因为你的外键名字没写对&#xff0c;我们以为我们写的字段名就是我们的外键其实并不是&#xff0c;我们可以通过show create table[ ]来查看外键的名字 所以删除外键的时候应该这样…

python学opencv|读取图像(十)用numpy创建彩色图像

【1】引言 前序文章中&#xff0c;我们已经学会了用numpy规划数据控制像素大小&#xff0c;然后用像素规划矩阵&#xff0c;对矩阵赋值后输出灰度图&#xff0c;相关链接为&#xff1a; python学opencv|读取图像&#xff08;八&#xff09;用numpy创建纯黑灰度图-CSDN博客 p…

线程池(ThreadPoolExecutor)

目录 一、线程池 标准提供的线程池 ThreadPoolExecutor 自定义线程池 一、线程池 为什么要引入线程池? 这个原因我们需要追溯到线程&#xff0c;我们线程存在的意义在于&#xff0c;使用进程进行并发编程太重了&#xff0c;所以引入了线程&#xff0c;因为线程又称为 “轻…

hbase读写操作后hdfs内存占用太大的问题

hbase读写操作后hdfs内存占用太大的问题 查看内存信息hbase读写操作 查看内存信息 查看本地磁盘的内存信息 df -h查看hdfs上根目录下各个文件的内存大小 hdfs dfs -du -h /查看hdfs上/hbase目录下各个文件的内存大小 hdfs dfs -du -h /hbase查看hdfs上/hbase/oldWALs目录下…

【IntelliJ IDEA 集成工具】TalkX - AI编程助手

前言 在数字化时代&#xff0c;技术的迅猛发展给软件开发者带来了更多的挑战和机遇。为了提高技术开发群体在繁多项目中的编码效率和质量&#xff0c;他们需要一个强大而专业的工具来辅助开发过程&#xff0c;而正是为了满足这一需求&#xff0c;TalkX 应运而生。 一、概述 1…

vue2+element-ui实现多行行内表格编辑

效果图展示 当在表格中点击编辑按钮时:点击的行变成文本框且数据回显可以点击确定按钮修改数据或者取消修改回退数据: 具体实现步骤 1. 行数据定义编辑标记 行数据定义编辑标记 当在组件中获取到用于表格展示数据的方法中,针对每一行数据添加一个编辑标记 this.list.f…