《深度学习实战》第2集-补充:卷积神经网络(CNN)与图像分类 实战代码解析和改进

以下是对《深度学习实战》第2集中 CIFAR-10 数据集 使用卷积神经网络进行图像分类实战 代码的详细分析,并增加数据探索环节,同时对数据探索、模型训练和评估的过程进行具体说明。所有代码都附上了运行结果配图,方便对比。


《深度学习实战》第2集 补充:数据探索与分析

在深度学习项目中,数据探索(Exploratory Data Analysis, EDA)是至关重要的一步。通过数据探索,我们可以了解数据集的基本特性、分布情况以及潜在问题,从而为后续的模型设计和优化提供指导。

1. 数据探索的目标

  • 了解 CIFAR-10 数据集的类别分布。
  • 可视化样本图像,观察其特征。
  • 分析数据预处理的效果。

2. 数据探索实现

2.1 类别分布分析

CIFAR-10 数据集包含 10 个类别,每个类别的样本数量应均匀分布。我们可以通过以下代码统计类别分布:

import matplotlib.pyplot as plt

# 统计类别分布
train_labels = [label for _, label in train_dataset]
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label_counts = {class_names[i]: train_labels.count(i) for i in range(10)}

# 可视化类别分布
plt.bar(label_counts.keys(), label_counts.values())
plt.title("Class Distribution in CIFAR-10")
plt.xlabel("Class")
plt.ylabel("Number of Samples")
plt.xticks(rotation=45)
plt.show()

代码运行结果输出:

在这里插入图片描述

结果分析

  • 如果类别分布均匀,说明数据集没有类别不平衡问题。
  • 在 CIFAR-10 中,每个类别有 5,000 张训练图像,分布均衡。
2.2 样本可视化

为了直观了解数据集中的图像特征,我们可以随机抽取一些样本并可视化:

import numpy as np

# 可视化样本图像
def imshow(img):
    img = img / 2 + 0.5  # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

# 获取一批数据
dataiter = iter(train_loader)
images, labels = next(dataiter)

# 显示图像
imshow(torchvision.utils.make_grid(images[:16]))  # 显示前 16 张图像
print("Labels:", [class_names[label] for label in labels[:16]])

代码运行结果输出:

在这里插入图片描述

Labels: ['bird', 'ship', 'automobile', 'ship', 'cat', 'truck', 'airplane', 'bird', 'airplane', 'frog', 'ship', 'bird', 'automobile', 'bird', 'automobile', 'truck']

结果分析

  • 图像大小为 32x32,分辨率较低,但足以捕捉基本特征。
  • 不同类别的图像具有明显的视觉差异(如飞机与汽车、猫与狗等),这有助于模型学习区分不同类别。
2.3 数据预处理效果

数据预处理包括调整大小、归一化等操作。我们可以通过打印预处理后的图像张量来验证其效果:

print("Preprocessed Image Shape:", images.shape)  # 输出形状
print("Preprocessed Image Values:", images[0].min().item(), images[0].max().item())  # 输出归一化范围

代码运行输出结果:

Preprocessed Image Shape: torch.Size([64, 3, 224, 224])
Preprocessed Image Values: -0.929411768913269 1.0

结果分析

  • 预处理后图像被调整为 224x224(ResNet 输入要求),并归一化到 [-1, 1] 范围。
  • 这些操作确保了输入数据的一致性和模型的稳定性。

原代码分析与改进

1. 数据加载与预处理

代码中使用 torchvision.transforms 对数据进行了标准化和尺寸调整。以下是关键步骤的解释:

  • Resize(224):将图像从原始的 32x32 调整为 ResNet 的输入尺寸 224x224。
  • Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):将像素值归一化到 [-1, 1] 范围,以加速收敛。

改进建议

  • 添加数据增强(Data Augmentation),如随机裁剪、水平翻转等,以提高模型的泛化能力:
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    

改进后代码替换原代码,运行后输出:

Epoch 1, Loss: 0.6051
Epoch 2, Loss: 0.3872
Epoch 3, Loss: 0.3126
Epoch 4, Loss: 0.2649
Epoch 5, Loss: 0.2316
Test Accuracy: 0.9022
可以看到损失系数不同,但精确率最终结果差不多。

2. 模型训练

代码中使用了预训练的 ResNet-18 模型,并修改了最后一层以适应 CIFAR-10 的 10 个类别。以下是训练过程的关键点:

2.1 模型结构
  • ResNet-18 是一个轻量级的 CNN 架构,包含 18 层卷积网络。
  • 修改全连接层(model.fc)以输出 10 个类别的概率。
2.2 训练过程
  • 使用 Adam 优化器,学习率为 0.001。
  • 损失函数为交叉熵损失(nn.CrossEntropyLoss),适用于多分类任务。
  • 每个 epoch 后打印平均损失,便于监控训练进度。

改进建议

  • 增加学习率调度器(Learning Rate Scheduler),例如余弦退火或 StepLR,以动态调整学习率。
  • 保存最佳模型权重,避免过拟合。

3. 模型评估

代码中通过测试集计算了模型的准确率。以下是评估过程的关键点:

3.1 测试过程
  • 将模型切换为评估模式(model.eval()),关闭 Dropout 和 BatchNorm 的随机性。
  • 使用 torch.no_grad() 禁用梯度计算,减少内存消耗。
3.2 结果分析

假设测试准确率为 75%,说明模型在 CIFAR-10 上表现良好,但仍有一定的提升空间。

改进建议

  • 计算混淆矩阵(Confusion Matrix),分析模型在不同类别上的表现:
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    cm = confusion_matrix(all_labels, all_preds)
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()
    

代码运行输出结果:
在这里插入图片描述

  • 分析上图混淆矩阵可以发现模型在某些类别(如“猫”与“狗”)上容易混淆,Cat 和 Dog 的矩阵交汇数值相对偏高,从而指导进一步优化。

总结

通过增加数据探索环节,我们深入了解了 CIFAR-10 数据集的特性,并验证了数据预处理的有效性。在模型训练和评估过程中,我们分析了代码的实现细节,并提出了改进建议,包括数据增强、学习率调度器和混淆矩阵分析。这些改进可以帮助模型更好地适应数据集,并提升性能。

希望这些内容能为你提供更全面的理解!如果你有任何问题或想法,欢迎在评论区留言讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/975636.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

dataframe如何在末尾添加多行

如果要在pandas的dataframe中添加多行该如何实现?可通过以下常见方式在DataFrame末尾添加: ### 方法一:使用loc索引器 利用loc索引器分两次操作来添加两行数据。假设已有DataFrame对象df,要添加的两行数据分别存储在字典new_row…

使用 DeepSeek 生成流程图、甘特图与思维导图:结合 Typora 和 XMind 的高效工作流

在现代工作与学习中,可视化工具如流程图、甘特图和思维导图能够极大地提升信息整理与表达的效率。本文将详细介绍如何使用 DeepSeek 生成 Mermaid 文本,结合 Typora 快速生成流程图和甘特图,并通过 Markdown 格式生成思维导图,最终…

插入排序(详解)c++

插⼊排序(Insertion Sort)类似于玩扑克牌插牌过程,每次将⼀个待排序的元素按照其关键字⼤⼩插⼊到前⾯已排好序的序列中,按照该种⽅式将所有元素全部插⼊完成即可 算法思想: 把待排序元素插入到已排序的序列中。想象一下一张一张整理扑克牌的…

【大模型】蓝耘智算云平台快速部署DeepSeek R1/R3大模型详解

目录 一、前言 二、蓝耘智算平台介绍 2.1 蓝耘智算平台是什么 2.2 平台优势 2.3 应用场景 2.4 对DeepSeek 的支持 2.4.1 DeepSeek 简介 2.4.2 DeepSeek 优势 三、蓝耘智算平台部署DeepSeek-R1操作过程 3.1 注册账号 3.1.1 余额检查 3.2 部署DeepSeek-R1 3.2.1 获取…

ai-financial-agent - 为金融投资打造的AI代理

探索人工智能在投资研究中的应用。本项目仅用于**教育**目的,不用于真实交易或投资。 作者声明: 本项目仅用于教育和研究目的。 不用于真实交易或投资不提供任何保证或担保过去的表现并不代表未来的结果Creator 对经济损失不承担任何责任咨询财务顾问…

基于keepalived的Nginx高可用架构

一、概述 Keepalived 是一个基于 VRRP(Virtual Router Redundancy Protocol)协议 的高可用性解决方案,为了解决静态路由器出现的单点故障问题,它能偶保证网络的不间断、稳定的运行。 二、核心功能 IP 漂移(VIP&…

学术论文项目网站搭建教程【Github】

本教程使用的是linux系统,ubuntu20.04版本进行学术项目网站搭建 一:创建github的个人组织 我个人习惯使用自己的github组织【Your organizations】来进行学术项目网站的创建: New一个organization,点击Free中的Create a free o…

postman调用ollama的api

按照如下设置,不需要设置key 保持长会话的方法 # 首次请求 curl http://localhost:11434/api/generate -d {"model": "deepseek-r1:32b","prompt": "请永久记住:110,1-12,之后所有数学计算必…

【Linux】多线程 -> 线程同步与基于BlockingQueue的生产者消费者模型

线程同步 条件变量 当一个线程互斥地访问某个变量时,它可能发现在其它线程改变状态之前,它什么也做不了。 例如:一个线程访问队列时,发现队列为空,它只能等待,直到其它线程将一个节点添加到队列中。这…

ChatGPT各模型版本对比分析

文章目录 1. GPT-3.5(2022年11月)2. GPT-4(2023年3月)3. GPT-4o(2024年5月)4. GPT-4o mini(2024年7月)5. o1系列(2024年9月至12月)6. o3-mini(202…

萌新学 Python 之自定义函数

函数主要用来封装功能,具有独立功能的代码块,可以提高代码重复利用率,便于模块管理 函数的定义: def 函数名(形参): 函数体,独立功能的代码 return ‘函数的返回值’ 函数注意事项: 1.函数的命名通常使…

【工作流】Spring Boot 项目与 Camunda 的整合

【工作流】Spring Boot 项目与 Camunda 的整合 【一】Camunda 和主流流程引擎的对比【二】概念介绍【1】Camunda 概念:【2】BPMN 概念 【三】环境准备【1】安装流程设计器CamundaModeler【画图工具】(1)下载安装 【2】CamundaModeler如何设计…

【Linux】基于UDP/TCP套接字编程与守护进程

目录 一、网路套接字编程 (一)基础概念 1、源IP地址与目的IP地址 2、端口号 3、TCP与UDP 4、网络字节序 (二)套接字编程接口 1、socket 常见API 2、sockaddr结构 (三)UDP套接字 1、UDP服务器创建…

【图像处理】:两幅图中相同区域的相似度比较

两幅图中相同区域的相似度比较 1.OpenCV和Python实现的两幅图相似度衡量方法1. 均方误差(MSE)2. 结构相似性指数(SSIM)图像协方差能显示结构特征的原因 3. 直方图相似度4. 特征点匹配5. 相关系数(Pearson Correlation&…

[python脚本]论文1.(一)CPU/内存数据分析和分组

CPU 收集到的CPU数据,格式如下: 由于这里6个数据为一组来收集latency的数据以及各个分位值的数据,而本质上每一行都是一次完整的测试,因此这里将这个csv文件分为两个文件,第一个是和latency相关的,将6条数…

綫性與非綫性泛函分析與應用_1.例題(下)-半母本

第1章 實分析與函數論:快速回顧(下) 五、基數;有限集和無限集相關例題 例題1:集合基數的判斷 判斷集合和集合B=\{a,b,c,d,e\}的基數關係。 解析: 可以構造一個雙射,例如,,,,。 所以,兩個集合具有相同的基數。 例題2:可數集的證明 證明整數集是可數集。 解析: …

MQTT实现智能家居------3、源码分析(超详细)

一、连接服务器 1、初始化: mqtt_log_init();是一个空函数,自己定义宏 client mqtt_lease();//创建一个client结构体,从此以后client代表客户端 platform_memory_alloc();//是一个分配内存的总函数,可以适用于Linux、FreeRTos…

Qt常用控件之日历QCalendarWidget

日历QCalendarWidget QCalendarWidget 是一个日历控件。 QCalendarWidget属性 属性说明selectDate当前选中日期。minimumDate最小日期。maximumDate最大日期。firstDayOfWeek设置每周的第一天是周几(影响日历的第一列是周几)。gridVisible是否显示日历…

智慧废品回收小程序php+uniapp

废品回收小程序:数字化赋能环保,开启资源循环新时代 城市垃圾治理难题,废品回收小程序成破局关键 随着城市化进程加速与消费水平提升,我国生活垃圾总量逐年攀升,年均增速达5%-8%,其中超30%为可回收物。然…

SkyWalking集成Kafka实现日志异步采集经验总结

SkyWalking日志异步采集架构 【重点知识】 1、【Agent】kafka-reporter-plugin-x.x.x.jar包放plugins目录后必走kafka(kafka没有正确配置就会报错) 2、【Agent】异步如不开启数据压缩,日志数据较大,pod多、业务大时容易造成网络…