PyTorch: 基于【VGG16】处理MNIST数据集的图像分类任务【准确率98.9%+】

目录

  • 引言
  • 在Conda虚拟环境下安装pytorch
  • 步骤一:利用代码自动下载mnist数据集
  • 步骤二:搭建基于VGG16的图像分类模型
  • 步骤三:训练模型
  • 步骤四:测试模型
  • 运行结果
  • 后续模型的优化和改进建议
  • 完整代码
  • 结束语

引言

在本博客中,小编将向大家介绍如何使用VGG16处理MNIST数据集的图像分类任务。MNIST数据集是一个常用的手写数字分类数据集,包含60,000个训练样本和10,000个测试样本。我们将使用Python编程语言和PyTorch深度学习框架来实现这个任务。

在Conda虚拟环境下安装pytorch

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.3
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# CUDA 10.2
pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102
# CPU only
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu

步骤一:利用代码自动下载mnist数据集

import torchvision.datasets as datasets  
import torchvision.transforms as transforms  
  
# 定义数据预处理操作  
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])
  
# 下载并加载MNIST数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

步骤二:搭建基于VGG16的图像分类模型

class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构VGG16网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,这样可以有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

步骤三:训练模型

import torch.optim as optim  
from torch.utils.data import DataLoader  
  
# 定义超参数和训练参数  
batch_size = 64  # 批处理大小  
num_epochs = 5  # 训练轮数
learning_rate = 0.01  # 学习率
num_classes = 10  # 类别数(MNIST数据集有10个类别)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用GPU进行训练,否则使用CPU。

# 定义训练集和测试集的数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)  
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)  
  
# 初始化模型和优化器  
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)  
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数  
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)  
  
# 训练模型  
for epoch in range(num_epochs):  
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)  # 将图像数据移动到指定设备  
        labels = labels.to(device)  # 将标签数据移动到指定设备  
          
        # 前向传播  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
          
        # 反向传播和优化  
        optimizer.zero_grad()  # 清空梯度缓存  
        loss.backward()  # 计算梯度  
        optimizer.step()  # 更新权重参数  
          
        if (i+1) % 100 == 0:  # 每100个batch打印一次训练信息  
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))  
              
# 保存模型参数  
torch.save(model.state_dict(), './model.pth')

步骤四:测试模型

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

运行结果

在这里插入图片描述

后续模型的优化和改进建议

  1. 数据增强:通过旋转、缩放、平移等方式来增加训练数据,从而让模型拥有更好的泛化能力。
  2. 调整模型参数:可以尝试调整模型的参数,比如学习率、批次大小、迭代次数等,来提高模型的性能。
  3. 更换网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,来提高模型的性能。
  4. 调整优化器:本次代码采用SGD优化器,但仍可以尝试使用不同的优化器,如Adam、RMSprop等,来找到最适合我们模型的优化器。
  5. 添加正则化操作:为了防止过拟合,可以添加一些正则化项,如L1正则化、L2正则化等。
  6. 代码目前只有等训练完全结束后才能进入测试阶段,后续可以在每个epoch结束,甚至是指定的迭代次数完成后便进入测试阶段。因为训练完全结束的模型很可能已经过拟合,在测试集上不能表现较强的泛化能力。

完整代码

import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为(224, 224)
    transforms.ToTensor(),  # 将图像转换为PyTorch张量
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])

# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)


class VGGClassifier(nn.Module):
    def __init__(self, num_classes):
        super(VGGClassifier, self).__init__()
        self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器
        # 重构网络的第一层卷积层,适配mnist数据的灰度图像格式
        self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096
            nn.ReLU(True),
            nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。
            nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10)
        )
        self._initialize_weights()  # 初始化权重参数

    def forward(self, x):
        x = self.features(x)  # 通过特征提取器提取特征
        x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量
        x = self.classifier(x)  # 通过分类器进行分类预测
        return x

    def _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



# 定义超参数和训练参数
batch_size = 64  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.01  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重参数

        if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作

# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量

# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间
    for images, labels in test_loader:
        images = images.to(device)  # 将图像数据移动到指定设备
        labels = labels.to(device)  # 将标签数据移动到指定设备
        outputs = model(images)  # 模型前向传播,得到预测结果
        _, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别
        total += labels.size(0)  # 更新总样本数量
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量

# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/247421.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

antd+vue:tree组件:父级节点禁止选择并不展示选择框——基础积累

antdvue:tree组件:父级节点禁止选择并不展示选择框——基础积累 1.判断哪些是父节点,给父节点添加disabled属性——this.permissionList是数据源2.通过css样式来处理disabled的父节点3.完整代码如下: 最近在写后台管理系统的时候,…

C语言—每日选择题—Day47

第一题 1. 以下逗号表达式的值为() (x 4 * 5, x * 5), x 25 A:25 B:20 C:100 D:45 答案及解析 D 本题考查的就是逗号表达式,逗号表达式是依次计算每个表达式,但是只输出最后一个表…

DRF从入门到精通一(DRF介绍、API接口、接口测试工具)

文章目录 DRF入门规范一、Web应用模式1.前后端混合开发模式2.前后端分离开发模式 二、API接口概念Json格式与XML格式的区别前端、前台、后端以及后台的区别 三、接口测试工具:Postman/Apifox DRF入门规范 DRF(django rest framework)是基于django的一个框架&#xf…

文心一言 VS 讯飞星火 VS chatgpt (157)-- 算法导论12.3 4题

四、用go语言,删除操作可交换吗?可交换的含义是,先删除 x 再删除 y 留下的结果树与先除 y 再删除 x 留下的结果树完全一样。如果是,说明为什么? 否则,给出一个反例。 文心一言: 在Go语言中,删除操作是不…

企业信息建设现状

信息化建设是传统计算机与互联网技术高速发展并融合的产物,现阶段已经成为引领产业创新的决定性技术手段。 随着信息化的不断发展与进步,各行各业都开始了信息化的建设与应用。信息化是未来发展的大趋势,企业运用信息技术可以大幅度提高员工效…

知识|基于混合模式的多余度飞控全数字仿真系统研究

*余度(Redundancy):一种确保安全的设计手段,使得出现两个及以上故障时,才会引起既定不希望发生的工作状态。 飞行控制软件主要完成飞行传感器数据处理、飞行姿态控制和余度管理任务,对保证飞机安全性和可靠…

leetCode算法—2.两数相加

1.给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这两个数都不会以 0…

干货:企业如何讲好品牌故事

品牌故事讲得好,不仅能够体现品牌特色还能向消费者传递品牌精神的重要工具,优秀的品牌故事能够促进产品销量,为品牌带来曝光率,今天媒介盒子就来和大家聊聊:如何讲好品牌故事。 一、 品类历史和故事 品牌虽然是新品牌…

基于单片机智能家具无线遥控控制系统设计

**单片机设计介绍,基于单片机智能家具无线遥控控制系统设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的智能家具无线遥控控制系统设计可以实现对家具(如灯具、窗帘、空调等)的…

【Python】解读a+=b 和 a=a+b是否一样?看完恍然大悟!

文章目录 前言一、可变对象和不可变对象总结 前言 在Python中,对于可变和不可变对象的行为差异是一个重要概念,特别是在涉及到和操作时。理解这一点对于编写高效且无误的代码至关重要。 一、可变对象和不可变对象 首先,让我们谈谈可变和不可…

JVM的内存分区以及垃圾收集

1.JVM的内存分区 1.1方法区 方法区(永久代)主要用来存储已在虚拟机加载的类的信息、常量、静态变量以及即时编译器编译后的代码信息。该区域是被线程共享的。 1.2虚拟机栈 虚拟机栈也就是我们平时说的栈内存,它是为java方法服务的。每个方法在执行的…

分析若依的文件上传处理逻辑

分析若依的文件上传处理逻辑 注:已经从若依框架完成拆分,此处单独分析一下人家精彩的封装,也来理解一下怎么做一个通用的上传接口!如有分析的,理解的不透彻的地方,大家多多包含,欢迎批评指正&am…

【C语言必学知识点五】指针

指针 导言一、指针与指针变量二、指针变量的创建和指针类型三、指针类型的意义3.1 指针 /- 整数3.2 指针解引用 四、野指针4.1 定义4.2 野指针的成因4.3 指针未初始化4.4 指针越界访问4.5 指针指向的空间被释放4.6 如何规避野指针 五、指针运算5.1指针-整数5.2 指针-指针5.2.1 …

B037-Mybatis基础

目录 为什么需要Mybatis?mybatis简介入门案例其余见代码查询流程增删改流程 - 变动数据要加事务去持久化抽取公共类 mapper接口开发规则概述代码 mapper.xml引入本地约束文件别名日志管理作用log4j的使用规范 井大括号与dollar大括号的区别 框架:半成品&…

Linux篇:信号

一、信号的概念: ①进程必须识别能够处理信号,信号没有产生,也要具备处理信号的能力---信号的处理能力属于进程内置功能的一部分 ②进程即便是没有收到信号,也能知道哪些信号该怎么处理。 ③当进程真的受到了一个具体的信号的时候…

Word公式居中+序号右对齐

Word公式居中序号右对齐 # 号制表位法表格法Mathtype法 # 号 制表位法 表格法 Mathtype法 参考1 参考2

力扣每日一题:2132. 用邮票贴满网格图(2023-12-14)

力扣每日一题 题目:2132. 用邮票贴满网格图 日期:2023-12-14 用时:38 m 32 s 思路:使用前缀和+差分,只是往常是一维,现在变二维了,原理差不多 时间:22ms 内存&#xff1…

certum ev ssl证书1180元一年,360浏览器显示公司名

Certum旗下的EV SSL证书是审核最严的数字证书,不仅对网站传输数据进行加密,还可以对网站身份进行验证,除此之外,它独有的绿色地址栏提升了网站的真实性,增强了客户对网站的信任感。今天就随SSL盾小编了解Certum旗下的E…

【Spring Boot】视图渲染技术之Freemarker

一、引言 1、什么是Freemarker FreeMarker是一款模板引擎,基于模板和要改变的数据,并用来生成输出文本(HTML网页、电子邮件、配置文件、源代码等)的通用工具。它不是面向最终用户的,而是一个Java类库,是一款…

LAMP平台部署及应用

1、安装PHP软件包 1.1、准备工作 检查软件是否安装,避免冲突 [rootyang ~]# rpm -e php php-cli php-ldap php-common php-mysql --nodeps 错误:未安装软件包 php 错误:未安装软件包 php-cli 错误:未安装软件包 php-ldap 错误…