【深度学习】pytorch,MNIST手写数字分类

efficientnet_b0的迁移学习


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt

# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10

# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([
    transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0
    transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])

# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 用于绘图的数据
train_losses = []
test_accuracies = []

# 训练模型
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # 计算平均损失
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # 测试准确率
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # Move test data to the correct device
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')

# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/485038.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言——sizeof与strlen的对比

一.sizeof 我们在学习操作符的时候&#xff0c;就了解到了sizeof操作符&#xff0c;它的作用是求参数所占内存空间的大小&#xff0c;单位是字节。如果参数是一个类型&#xff0c;那就返回参数所占的字节数。 #include <stdio.h>int main() {int a 10;size_t b sizeo…

【机器学习300问】48、如何绘制ROC曲线?

ROC曲线&#xff08;受试者工作特征曲线&#xff09;是一种用于可视化评估二分类模型性能的指标。特别是在不同阈值情况下模型对正类和负类的区分能力。那么“阈值”到底是个什么呢&#xff1f;ROC曲线中的每一个点到底是什么意思&#xff1f; 一、ROC曲线的绘制【理论】 二分…

LeetCode Python - 72. 编辑距离

目录 题目描述解法运行结果 题目描述 给你两个单词 word1 和 word2&#xff0c; 请返回将 word1 转换成 word2 所使用的最少操作数 。 你可以对一个单词进行如下三种操作&#xff1a; 插入一个字符删除一个字符替换一个字符 示例 1&#xff1a; 输入&#xff1a;word1 “h…

Linux的介绍以及其发展历史

文章目录 前言一、技术是推动社会发展的基本动力1.人为什么能成为万物之长呢&#xff1f;2.人为什么要发明工具&#xff0c;进行进化呢&#xff1f;3.人是如何发明工具的&#xff1f;4.为什么要有不同的岗位和行业&#xff1f; 二、计算机(操作系统)发展的基本脉络1.第一台计算…

Google ScreenAI代表了一款先进的视觉语言模型,专为用户界面(UI)和视觉情境下的语言理解而设计

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

二次开发Flink-coGroup算子支持迟到数据通过测输出流提取

1.背景 coGroup算子开窗到时间关闭之后&#xff0c;迟到数据无法通过测输出流提取&#xff0c;intervalJoin算子提供了api&#xff0c;因为join算子底层就是coGroup算子&#xff0c;所以Join算子也不行。 flink版本 v1.17.1 2.coGroup算子源码分析 2.1完成的coGroup算子调用流…

QT(C++)-error LNK2038: 检测到“_ITERATOR_DEBUG_LEVEL”的不匹配项: 值“2”不匹配值“0”

1、项目场景&#xff1a; 在VS中采用QT&#xff08;C&#xff09;调试时&#xff0c;出现error LNK2038: 检测到“_ITERATOR_DEBUG_LEVEL”的不匹配项: 值“2”不匹配值“0”错误 2、解决方案&#xff1a; 在“解决方案资源管理器”中选中出现此类BUG的项目&#xff0c;右键-…

jenkins介绍,帮助你从安装到使用jenkins

Jenkins 概述 官网地址&#xff1a;https://www.jenkins.io/zh/ 什么是 Jenkins Jenkins是一款开源 CI&CD 软件&#xff0c;用于自动化各种任务&#xff0c;包括构建、测试和部署软件。它提供了一个易于使用的图形化界面&#xff0c;可以通过配置简单的任务来实现自动化构…

javaSSM游泳馆日常管理系统IDEA开发mysql数据库web结构计算机java编程maven项目

一、源码特点 IDEA开发SSM游泳馆日常管理系统是一套完善的完整企业内部系统&#xff0c;结合SSM框架和bootstrap完成本系统&#xff0c;对理解JSP java编程开发语言有帮助系统采用SSM框架&#xff08;MVC模式开发&#xff09;MAVEN方式加载&#xff0c;系统具有完整的源代码和…

Vue 3 里的 onMounted 怎么用?

疑问 最近&#xff0c;一直在学习 Vue 3&#xff0c;此前我不懂前端&#xff0c;也没写过 Vue 2&#xff0c;所以是从 0 开始学习 Vue 3 的。很多对普通人不是疑问的&#xff0c;在我这里也会不太清楚。 我在写项目的时候&#xff0c;常见的一种场景是这样的&#xff1a;页面…

分类预测 | Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力机制多特征分类预测/故障识别

分类预测 | Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力机制多特征分类预测/故障识别 目录 分类预测 | Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力机制多特征分类预测/故障识别分类效果基本介绍模型描述程序设计参考…

基于SSM非遗视域下喀什旅游网站

ssm非遗视域下喀什旅游网站的设计与实现 摘要 我们的生活水平正在不断的提高&#xff0c;然而提高的一个重要的侧面表现就是更加注重我们的娱乐生活。旅行是我们都喜欢的一种娱乐方式&#xff0c;各式各样的旅行经历给我们带来的喜悦也是大不相同的。带来快乐的同时也因为其复…

IntelliJ IDE 插件开发 | (七)PSI 入门及实战(实现 MyBatis 插件的跳转功能)

系列文章 IntelliJ IDE 插件开发 |&#xff08;一&#xff09;快速入门IntelliJ IDE 插件开发 |&#xff08;二&#xff09;UI 界面与数据持久化IntelliJ IDE 插件开发 |&#xff08;三&#xff09;消息通知与事件监听IntelliJ IDE 插件开发 |&#xff08;四&#xff09;来查收…

MongoDB高可用架构涉及常用功能整理

MongoDB高可用架构涉及常用功能整理 1. mongo架构和相关组件1.1. Master-Slave主从模式1.2. Replica Set 副本集模式1.3. Sharding 分片模式 2. Sharding 分片模式2.1. Hashed Sharding方式2.2. Range Sharding方式 3. 事务性4. 疑问和思考4.1. 怎么保证数据的高可靠&#xff1…

常用中间件redis,kafka及其测试方法

常用消息中间件及其测试方法 一、中间件的使用场景引入中间件的目的一般有两个&#xff1a;1、提升性能常用的中间件&#xff1a;1) 高速缓存&#xff1a;redis2) 全文检索&#xff1a;ES3) 存日志&#xff1a;ELK架构4) 流量削峰&#xff1a;kafka 2、提升可用性产品架构中高可…

Web前端—浏览器渲染原理

浏览器渲染原理 浏览器渲染原理渲染时间点渲染流水线1. 解析HTML—Parse HTML2. 样式计算—Recalculate Style3. 布局—Layout4. 分层—Layer5. 绘制—Paint6. 分块—Tiling7. 光栅化—Raster8. 画—Draw完整过程 面试题1. 浏览器是如何渲染页面的&#xff1f;2. 什么是 reflow…

linux apt 速度慢 换源

Ubuntu 20.04.1 LTS已推出,一样的为期5年的服务,感觉不错,安装了一个,但是苦于使用默认源在国内下载太慢,就想着把apt源改为国内源,目前国内比较好的源,有阿里源,清华源,豆瓣源等,下面我以阿里源为例,说下如何修改。 也可以在中科大https://mirrors.ustc.edu.cn/查…

使用amd架构的计算机部署其他架构的虚拟机(如:arm)

1 下载quem模拟器 https://qemu.weilnetz.de/w64/2 QEMU UEFI固件文件下载(引导文件) 推荐使用&#xff1a;https://releases.linaro.org/components/kernel/uefi-linaro/latest/release/qemu64/QEMU_EFI.fd3 QEMU 安装 安装完成之后&#xff0c;需要将安装目录添加到环境变…

福昕阅读器 PDF 文档基本操作

福昕阅读器 PDF 文档基本操作 References 转至 PDF 顶部 快捷键&#xff1a;Home. 转至 PDF 顶部 快捷键&#xff1a;End. 打开超链接 文本选择工具 -> 手形工具 (Hand Tool) -> 点击超链接 福昕阅读器 同时在多个窗口中打开多个文件 文件 -> 偏好设置 -> 文…

数据库导入文件或者运行文件的时候报错误 #1046 - No database selected

如果我们在使用数据库导入文件的时候报错误 #1046 - No database selected该怎么解决 那么小编带我们可以从三个角度去观察 1、这种情况一般是因为你在数据库中没有这个数据库&#xff0c;你新建一个你要导入的数据库名字的数据库&#xff0c;然后选中该数据库&#xff0c;再进…