使用 AlexNet 实现图片分类 | PyTorch 深度学习实战

前一篇文章,CNN 卷积神经网络处理图片任务 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】

使用 AlexNet 实现图片分类

  • 经典卷积网络
  • AlexNet 特点
  • 实验代码
  • 实验结果
  • Why AlexNet Works
  • Links

经典卷积网络

以下是卷积神经网络发展的里程碑:

在这里插入图片描述

  • AlexNet 在各项比赛中,比其它算法好,证明了深度神经网络算法的优越性和前景
  • VGGNet 则使用比 AlexNet 更深更宽的网络,取得了比 AlexNet 还好的成绩
  • GoogLeNet 效果则比 VGGNet 更好
  • ResNet 引入残差模块,解决了深度网络训练中的退化问题,超越之前的模型
  • DenseNet 模型采用密集连接的结构,使模型具有更好的鲁棒性

AlexNet 特点

在这里插入图片描述

AlexNet 结构

![[../assets/media/screenshot_20250208184431.png]]

更多详细介绍,阅读作者 Paper 论文, ImageNet Classification with Deep Convolutional Neural Networks

视频资源:9年后重读深度学习奠基作之一:AlexNet【上】【论文精读】

实验代码

from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import datasets
import torch.nn as nn
import torch
import numpy as np

# configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_rootdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, "data")


################################
# 定义 dataset loader
################################
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True):

    # 正则化图片的参数,mean 和 std 的值来自于 imagenet 的数据统计
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        # 数据增强
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )

    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)


def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transform
    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    return data_loader


# CIFAR10 dataset
CIFAR10_data = os.path.join(data_rootdir, "CIFAR10")
train_loader, valid_loader = get_train_valid_loader(data_dir= CIFAR10_data,                                      batch_size = 64,
                                                    augment = False,                                          random_seed = 1)

test_loader = get_test_loader(data_dir= CIFAR10_data,
                              batch_size = 64)


################################
# 定义 Model
################################
class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc2= nn.Sequential(
            nn.Linear(4096, num_classes))

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

######################################
# Setting Hyperparameters
######################################
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.005

model = AlexNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

# Train the model
total_step = len(train_loader)


######################################
# Training
######################################
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                    .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs

        print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))


# Now, we see how our model performs on unseen data
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))

在这里插入图片描述

实验结果

以上代码在 NVIDIA GeForce RTX 2050 WDDM 显存上训练和测试,大约花了半个小时时间。
最终,测试集上的准确率达到了 82.24% 。
在这里插入图片描述

Why AlexNet Works

Links

  • Writing AlexNet from Scratch in PyTorch
  • ImageNet Classification with Deep Convolutional
    Neural Networks
  • Conv2d API in PyTorch
  • AlexNet 论文精读(上),9年后重读深度学习奠基作之一
  • AlexNet 论文精读(下)
  • AlexNet Explained: A Step-by-Step Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/968281.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于进化式大语言模型的下一代漏洞挖掘范式:智能对抗与自适应攻防体系

摘要 本文提出了一种基于进化式大语言模型(Evolutionary LLM)的智能漏洞挖掘框架,突破了传统静态分析的局限,构建了具备对抗性思维的动态攻防体系。通过引入深度强化学习与多模态感知机制,实现了漏洞挖掘过程的自适应进化,在RCE、SQLi、XXE等关键漏洞类型的检测中达到97…

python自动化测试之Pytest框架之YAML详解以及Parametrize数据驱动!

一、YAML详解 YAML是一种数据类型,它能够和JSON数据相互转化,它本身也是有很多数据类型可以满足我们接口 的参数类型,扩展名可以是.yml或.yaml 作用: 1.全局配置文件 基础路径,数据库信息,账号信息&…

SQLMesh系列教程-2:SQLMesh入门项目实战(上篇)

假设你已经了解SQLMesh是什么,以及其他应用场景。如果没有,我建议你先阅读《SQLMesh系列教程-1:数据工程师的高效利器-SQLMesh》。 在本文中,我们将完成一个小项目或教程,以帮助你开始使用SQLMesh。你可以选择一步一步…

人工智能与低代码如何重新定义企业数字化转型?

引言:数字化转型的挑战与机遇 在全球化和信息化的浪潮中,数字化转型已经成为企业保持竞争力和创新能力的必经之路。然而,尽管“数字化”听上去是一个充满未来感的词汇,落地的过程却往往充满困难。 首先,传统开发方式…

使用云效解决docker官方镜像拉取不到的问题

目录 前言原文地址测试jenkins构建结果:后续使用说明 前言 最近经常出现docker镜像进行拉取不了,流水线挂掉的问题,看到一个解决方案: 《借助阿里个人版镜像仓库云效实现全免费同步docker官方镜像到国内》 原文地址 https://developer.aliyun.com/artic…

R语言LCMM多维度潜在类别模型流行病学研究:LCA、MM方法分析纵向数据

全文代码数据:https://tecdat.cn/?p39710 在数据分析领域,当我们面对一组数据时,通常会有已知的分组情况,比如不同的治疗组、性别组或种族组等(点击文末“阅读原文”获取完整代码数据)。 然而,…

java项目之基于用户兴趣的影视推荐系统设计与实现源码(ssm+mysql)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的基于用户兴趣的影视推荐系统设计与实现。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 基于用户…

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 0基础…

vue2 多页面pdf预览

使用pdfjs-dist预览pdf&#xff0c;实现预加载&#xff0c;滚动条翻页。pdfjs的版本很重要&#xff0c;换了好多版本&#xff0c;终于有一个能用的 node 20.18.1 "pdfjs-dist": "^2.2.228", vue页面代码如下 <template><div v-loading"loa…

堆排序

目录 堆排序&#xff08;不稳定&#xff09;&#xff1a; 代码实现&#xff1a; 思路分析&#xff1a; 总结&#xff1a; 堆排序&#xff08;不稳定&#xff09;&#xff1a; 如果想要一段数据从小到大进行排序&#xff0c;则要先建立大根堆&#xff0c;因为这样每次堆顶上都能…

【C++】多态原理剖析

目录 1.虚表指针与虚表 2.多态原理剖析 1.虚表指针与虚表 &#x1f36a;类的大小计算规则 一个类的大小&#xff0c;实际就是该类中成员变量之和&#xff0c;需要注意内存对齐空类&#xff1a;编译器给空类一个字节来唯一标识这个类的对象 对于下面的Base类&#xff0c;它的…

【Git】完美解决git push报错403

remote: Permission to xx.git denied to xx. fatal: unable to access https://github.com/xx/xx.git/: The requested URL returned error: 403出现这个就是因为你的&#xff08;personal access tokens &#xff09;PAT过期了 删掉旧的token 生成一个新的 mac系统 在mac的…

初窥强大,AI识别技术实现图像转文字(OCR技术)

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ &#x1f434;作者&#xff1a;秋无之地 &#x1f434;简介&#xff1a;CSDN爬虫、后端、大数据、人工智能领域创作者。目前从事python全栈、爬虫和人工智能等相关工作&#xff0c;主要擅长领域有&#xff1a;python…

黑马Redis详细笔记(实战篇---短信登录)

目录 一.短信登录 1.1 导入项目 1.2 Session 实现短信登录 1.3 集群的 Session 共享问题 1.4 基于 Redis 实现共享 Session 登录 一.短信登录 1.1 导入项目 数据库准备 -- 创建用户表 CREATE TABLE user (id BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 用户ID,phone …

逻辑回归不能解决非线性问题,而svm可以解决

逻辑回归和支持向量机&#xff08;SVM&#xff09;是两种常用的分类算法&#xff0c;它们在处理数据时有一些不同的特点&#xff0c;特别是在面对非线性问题时。 1. 逻辑回归 逻辑回归本质上是一个线性分类模型。它的目的是寻找一个最适合数据的直线&#xff08;或超平面&…

41.兼职网站管理系统(基于springbootvue的Java项目)

目录 1.系统的受众说明 2.相关技术 2.1 B/S架构 2.2 Java技术介绍 2.3 mysql数据库介绍 2.4 Spring Boot框架 3.系统分析 3.1 需求分析 3.2 系统可行性分析 3.2.1技术可行性&#xff1a;技术背景 3.2.2经济可行性 3.2.3操作可行性&#xff1a; 3.3 项目设计目…

MS08067练武场--WP

免责声明&#xff1a;本文仅用于学习和研究目的&#xff0c;不鼓励或支持任何非法活动。所有技术内容仅供个人技术提升使用&#xff0c;未经授权不得用于攻击、侵犯或破坏他人系统。我们不对因使用本文内容而引起的任何法律责任或损失承担责任。 注&#xff1a;此文章为快速通关…

Elasticsearch:如何使用 Elastic 检测恶意浏览器扩展

作者&#xff1a;来着 Elastic Aaron Jewitt 当你的 CISO 询问你的任何工作站上是否安装过特定的浏览器扩展时&#xff0c;你多快能得到正确答案&#xff1f;恶意浏览器扩展是一个重大威胁&#xff0c;许多组织无法管理或检测。这篇博文探讨了 Elastic Infosec 团队如何使用 os…

检测网络安全漏洞 工具 网络安全 漏洞扫描 实验

实验一的名称为信息收集和漏洞扫描 实验环境&#xff1a;VMware下的kali linux2021和Windows7 32&#xff0c;网络设置均为NAT&#xff0c;这样子两台机器就在一个网络下。攻击的机器为kali,被攻击的机器为Windows 7。 理论知识记录&#xff1a; 1.信息收集的步骤 2.ping命令…

esxi添加内存条因为资源不足虚拟机无法开机——避坑

exsi8.0我加了6根内存条&#xff0c;然后将里面的ubuntu的内存增加 haTask-2-vim.VirtualMachine.powerOn-919 描述 打开该虚拟机电源 虚拟机 ub22 状况 失败 - 模块“MonitorLoop”打开电源失败。 错误 模块“MonitorLoop”打开电源失败。无法将交换文件 /vmfs/volumes…