Pytorch | 从零构建AlexNet对CIFAR10进行分类

Pytorch | 从零构建AlexNet对CIFAR10进行分类

  • CIFAR10数据集
  • AlexNet
    • 网络结构
    • 技术创新点
    • 性能表现
    • 影响和意义
  • AlexNet结构代码详解
    • 结构代码
    • 代码详解
      • 特征提取层 self.features
      • 分类部分self.classifier
      • 前向传播forward
  • 训练过程和测试结果
  • 代码汇总
    • alexnet.py
    • train.py
    • test.py

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

AlexNet

AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年提出的一种深度卷积神经网络,在ImageNet图像识别挑战赛中取得了巨大成功,推动了深度学习在计算机视觉领域的快速发展。以下是对它的详细介绍:

网络结构

  • 卷积层:包含5个卷积层,这些卷积层通过不同的卷积核大小、步长和填充方式,逐步提取图像的特征。
  • 池化层:有3个最大池化层,用于减小特征图的尺寸,同时保留关键特征,减少计算量和过拟合风险。
  • 全连接层:包括3个全连接层,用于对提取的特征进行分类,最后一层输出分类结果。
    在这里插入图片描述
    上图为AlexNet原文中的网络结构(针对ImageNet,图片尺寸为224×224),本文是针对CIFAR10,其尺寸为32×32,因此结构不太相同,比如卷积核的大小,具体可以参考下面的代码。

技术创新点

  • ReLU激活函数:使用ReLU(Rectified Linear Unit)作为激活函数,解决了传统激活函数在深度网络中梯度消失的问题,加快了训练速度。
  • Dropout正则化:在全连接层中使用了Dropout技术,随机丢弃部分神经元,防止过拟合,提高模型的泛化能力。
  • 重叠池化:采用重叠池化(Overlapping Pooling),即池化窗口之间有重叠,有助于提取更多的特征信息,提升模型的性能。
  • 多GPU训练:首次利用多GPU进行并行训练,大大提高了训练速度,使得在大规模数据集上训练深度网络成为可能。

性能表现

  • 在ImageNet数据集上,AlexNet的top-5错误率大幅降低至15.3%,相比之前的方法有了显著提升,展示了其强大的图像识别能力。
  • 能够学习到丰富的图像特征,对不同类别的物体具有很好的区分能力,在实际应用中取得了很好的效果。

影响和意义

  • 推动深度学习发展:AlexNet的成功引起了学术界和工业界对深度学习的广泛关注,激发了更多研究人员对深度神经网络的研究兴趣,推动了深度学习技术的快速发展。
  • 开启卷积神经网络新时代:为后续的卷积神经网络研究提供了重要的参考和借鉴,许多新的网络结构和技术都是在AlexNet的基础上发展而来的。
  • 拓展应用领域:由于其在图像识别任务上的出色表现,AlexNet及其改进模型被广泛应用于计算机视觉的各个领域,如目标检测、图像分割、人脸识别等。

AlexNet结构代码详解

结构代码

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            # input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)
            nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 *2)
        x = self.classifier(x)
        return x

代码详解

以下是对上述AlexNet代码的详细解释:

特征提取层 self.features

这部分构建了AlexNet的特征提取层,是一个由多个层组成的顺序结构(通过nn.Sequential来定义)。
- 第一个卷积层
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)表示输入图像的通道数为3(通常对应RGB图像的红、绿、蓝三个通道),输出的通道数为64(即卷积核的数量为64,意味着会生成64个不同的特征图),卷积核大小是3×3,步长为2(在空间维度上每次移动2个像素),填充为1(在图像边缘进行1个像素的填充,这样可以保证输入输出图像尺寸在卷积操作下能按预期变化),经过这个卷积层后,输入尺寸为(B, 3, 32, 32)的图像数据会变成(B, 64, 16, 16)
- 激活函数层
nn.ReLU(inplace=True)是使用修正线性单元(Rectified Linear Unit)作为激活函数,inplace=True表示直接在输入的张量上进行修改(节省内存空间),对经过卷积后的特征图进行非线性变换,增强网络的表达能力。
- 池化层
nn.MaxPool2d(kernel_size=2)是最大池化层,池化核大小为2×2,它会在每个2×2的窗口内选取最大值作为输出,起到下采样的作用,减少数据量同时保留重要特征,比如经过第一次池化后特征图尺寸从(B, 64, 16, 16)变为(B, 64, 8, 8)

后续依次重复卷积、激活、池化等操作,不断提取图像的特征,逐步降低特征图的尺寸同时增加特征图的深度(通道数),最终经过这一系列操作后得到尺寸为(B, 256, 2, 2)的特征图。

分类部分self.classifier

self.classifier = nn.Sequential(
    nn.Dropout(),
    nn.Linear(256 * 2 * 2, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(inplace=True),
    nn.Linear(4096, num_classes)
)

这部分构建了AlexNet的分类器,同样是顺序结构。
- Dropout层
nn.Dropout()是一种正则化技术,在训练过程中以一定概率(默认0.5)随机将神经元的输出设置为0,防止过拟合,提高模型的泛化能力。这里使用了两次Dropout,分别在不同的全连接层之前。
- 全连接层
第一个nn.Linear(256 * 2 * 2, 4096)表示将经过特征提取后展平的特征向量(尺寸为256 * 2 * 2,因为前面特征提取部分最后得到的特征图尺寸是(B, 256, 2, 2),展平后维度就是256 * 2 * 2)映射到一个4096维的向量空间,后面接着激活函数nn.ReLU(inplace=True)进行非线性变换。然后又是一个Dropout层和一个同样输出维度为4096的全连接层以及相应的激活函数,最后通过nn.Linear(4096, num_classes)将4096维的向量映射到指定的类别数(num_classes)维度,得到最终的分类预测结果。

前向传播forward

def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), 256 * 2 *2)
    x = self.classifier(x)
    return x

forward方法定义了数据在网络中的前向传播过程。

  • 特征提取
    首先x = self.features(x),将输入数据x送入到之前定义的特征提取部分(features),按照特征提取层中定义的卷积、激活、池化等操作依次对输入数据进行处理,得到提取后的特征图。
  • 特征图展平
    x = x.view(x.size(0), 256 * 2 *2)这行代码将特征图进行展平操作,使其变成一个二维张量,其中第一维对应批次大小(x.size(0)表示批次中的样本数量),第二维就是展平后的特征向量长度(由前面特征提取最后得到的特征图尺寸计算得出),这样才能输入到后面的全连接层中进行分类处理。
  • 分类预测
    最后x = self.classifier(x)将展平后的特征向量送入分类器部分(classifier),经过全连接层、激活函数、Dropout等操作逐步得到最终的分类预测结果,然后通过return x返回这个预测结果。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述
在这里插入图片描述
训练过程准确率变化曲线:
在这里插入图片描述
测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models
	|--__init__.py
	|--alexnet.py
|--results
|--weights
|--train.py
|--test.py

alexnet.py

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            # input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)
            nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 *2)
        x = self.classifier(x)
        return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import AlexNet
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练轮次
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # 保存模型权重,每5轮次保存到weights文件夹下
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/alexnet_epoch_{epoch + 1}.pth')
    # 绘制损失曲线
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig('results\\train_loss_curve.png')
    plt.close()

    # 绘制准确率曲线
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig('results\\train_acc_curve.png')
    plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import AlexNet

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = "weights/alexnet_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print("================AlexNet Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/939844.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

初学stm32 --- 系统时钟配置

众所周知,时钟系统是 CPU 的脉搏,就像人的心跳一样。所以时钟系统的重要性就不言而喻了。 STM32 的时钟系统比较复杂,不像简单的 51 单片机一个系统时钟就可以解决一切。于是有人要问,采用一个系统时钟不是很简单吗?为…

进程间通信方式---System V IPC信号量

进程间通信方式—System V IPC信号量 文章目录 进程间通信方式---System V IPC信号量信号量1.信号量原语2.semget 系统调用参数返回值 3.semop 系统调用参数返回值 4.semctl 系统调用5.特殊键值 IPC_PRIVATE6.信号量实现进程间通信1. 数据结构定义2. 信号量操作相关部分3. 生产…

深入理解Kafka:核心设计与实践原理读书笔记

目录 初识Kafka基本概念安装与配置ZooKeeper安装与配置Kafka的安装与配置 生产与消费服务端参数配置 生产者客户端开发消息对象 ProducerRecord必要的参数配置发送消息序列化分区器生产者拦截器 原理分析整体架构元数据的更新 重要的生产者参数acksmax.request.sizeretries和re…

electron 顶部的元素点不中,点击事件不生效

electron 顶部的元素点不中,点击事件不生效

Excel设置生日自动智能提醒,公式可直接套用!

大家好,我是小鱼。 今天跟大家分享一个WPS表格中根据出生日期,设置生日提醒,并且根据距离生日天数自动标记数据颜色。简单又实用,一个公式轻松搞定! 接下来我们先学习一下需要使用到的函数,然后再根据实例让…

全域数据集成平台ETL

全域数据集成平台ETL Restcloud 工作原理 RestCloud数据集成平台采用SpringCloud微服务架构技术开发,底层基于纯Java语言采用前后端分离架构,前端采用React技术进行开发。 RestCloud数据集成平台是基于数据流工作流引擎的架构进行研发的,底…

Spring(一)---IOC(控制权反转)

目录 引入 1.什么叫IOC(Inversion of Control)控制权反转? 2.什么叫AOP(Aspect-Oriented Programming)面向切面编程(涉及Java代理)? 3.简单谈一下Java怎么实现ICO? Spring框架的介绍 1. Spring框架的概述 2. Spring框架的优点 Spring IOC容器介绍…

【GESP】C++二级考试大纲知识点梳理, (4)流程图

GESP C二级官方考试大纲中,共有9条考点,本文针对C(4)号知识点进行总结梳理。 (4)了解流程图的概念及基本表示符号,掌握绘制流程图的方法,能正确使用流程图描述程序设计的三种基本结构…

scala中正则表达式的使用

正则表达式: 基本概念 在 Scala 中,正则表达式是用于处理文本模式匹配的强大工具。它通过java.util.regex.Pattern和java.util.regex.Matcher这两个 Java 类来实现(因为 Scala 运行在 Java 虚拟机上,可以无缝使用 Java 类库&…

使用VSCode Debugger 调试 React项目

一般我们调试代码时,用的最多的应该就是console.log方式了,还有的是使用Chrome DevTools 通过在对应的 sourcemap代码位置打断点进行调试,除了上面两种方式外还有一种更好用的调试方式: VSCode Debugger。 VSCode Debugger可以直…

微信小程序实现上传图片自定义水印功能、放大缩小旋转删除、自定义字号颜色位置、图片导出下载、图像预览裁剪、Canvas绘制 开箱即用

目录 功能实现画布绘制上传图片并渲染图片操作事件添加文字水印canvas解析微信小程序中 canvas 的应用场景canvas 与 2D 上下文、webgl 上下文的关系图像的加载与绘制总结说明功能实现 画布绘制 在wxml添加canvas标签并在在当前页面的 data 对象中,创建一个 Canvas 上下文(c…

用.Net Core框架创建一个Web API接口服务器

我们选择一个Web Api类型的项目创建一个解决方案为解决方案取一个名称我们这里选择的是。Net 8.0框架 注意,需要勾选的项。 我们找到appsetting.json配置文件 appsettings.json配置文件内容如下 {"Logging": {"LogLevel": {"Default&quo…

[创业之路-199]:《华为战略管理法-DSTE实战体系》- 3 - 价值转移理论与利润区理论

目录 一、价值转移理论 1.1. 什么是价值? 1.2. 什么价值创造 (1)、定义 (2)、影响价值创造的因素 (3)、价值创造的三个过程 (4)、价值创造的实践 (5&…

【阅读记录-章节6】Build a Large Language Model (From Scratch)

文章目录 6. Fine-tuning for classification6.1 Different categories of fine-tuning6.2 Preparing the dataset第一步:下载并解压数据集第二步:检查类别标签分布第三步:创建平衡数据集第四步:数据集拆分 6.3 Creating data loa…

[搜广推]王树森推荐系统——矩阵补充最近邻查找

矩阵补充(工业界不常用) 模型结构 embedding可以把 用户ID 或者 物品ID 映射成向量输入用户ID 和 物品ID,输出向量的内积(一个实数),内积越大说明用户对这个物品越感兴趣模型中的两个embedding层不共享参…

【优选算法篇】揭秘快速排序:分治算法如何突破性能瓶颈

文章目录 须知 💬 欢迎讨论:如果你在学习过程中有任何问题或想法,欢迎在评论区留言,我们一起交流学习。你的支持是我继续创作的动力! 👍 点赞、收藏与分享:觉得这篇文章对你有帮助吗&#xff1…

建投数据与腾讯云数据库TDSQL完成产品兼容性互认证

近日,经与腾讯云联合测试,建投数据自主研发的人力资源信息管理系统V3.0、招聘管理系统V3.0、绩效管理系统V2.0、培训管理系统V3.0通过腾讯云数据库TDSQL的技术认证,符合腾讯企业标准的要求,产品兼容性良好,性能卓越。 …

Java-30 深入浅出 Spring - IoC 基础 启动IoC 纯XML启动 Bean、DI注入

点一下关注吧!!!非常感谢!!持续更新!!! 大数据篇正在更新!https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了: MyBatis&#xff…

如何利用Python爬虫获得1688按关键字搜索商品

在当今的数字化时代,数据已成为企业竞争的核心资源。对于电商行业来说,了解市场动态、分析竞争对手、获取商品信息是至关重要的。Python作为一种强大的编程语言,其丰富的库和框架使得数据爬取变得简单易行。本文将介绍如何使用Python爬虫技术…

WatchAlert - 开源多数据源告警引擎

概述 在现代 IT 环境中,监控和告警是确保系统稳定性和可靠性的关键环节。然而,随着业务规模的扩大和数据源的多样化,传统的单一数据源告警系统已经无法满足复杂的需求。为了解决这一问题,我开发了一个开源的多数据源告警引擎——…