深度学习笔记_6经典预训练网络LeNet-18解决FashionMNIST数据集

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
from torchvision import models

# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 50

# 定义数据转换操作
transform = transforms.Compose([
    transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转
    transforms.RandomHorizontalFlip(),   # 随机水平翻转
    transforms.Resize((224, 224)),  # 调整图像大小
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,
                        download=True, transform=transform)

# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,
                       download=True, transform=transform)

# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)

3、使用预训练的ResNet-18模型

# 使用预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
# 修改最后一层,使其适应FashionMNIST的输出类别数
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# 冻结预训练模型的参数
for param in model.parameters():
    param.requires_grad = False

# 只训练模型的最后一层
for param in model.fc.parameters():
    param.requires_grad = True
# 初始化优化器和损失函数
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

4、 训练循环

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

# 训练循环
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # 计算训练准确率
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # 计算平均训练损失和训练准确率
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # 测试模型
    model.eval()
    test_loss = 0.0
    correct = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上
            all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上

    # 计算平均测试损失和测试准确率
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # 计算额外的指标
    conf_matrix = confusion_matrix(all_labels, all_preds)
    conf_matrix_list.append(conf_matrix)

    accuracy = accuracy_score(all_labels, all_preds)
    accuracy_list.append(accuracy)

    error_rate = 1 - accuracy
    error_rate_list.append(error_rate)

    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    precision_list.append(precision)
    recall_list.append(recall)
    f1_score_list.append(f1)

    fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    roc_auc_list.append(roc_auc)

    # 打印每个 epoch 的指标
    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.show()

# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.show()

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
class_labels = [str(i) for i in range(10)]
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/254362.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

算法模板之双链表图文详解

🌈个人主页:聆风吟 🔥系列专栏:算法模板、数据结构 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 📋前言一. ⛳️使用数组模拟双链表讲解1.1 🔔为什么我们要使用数组去模拟双链表…

全国巡展“2024人工智能展·世亚智博会”3月上海·4月杭州·6月北京

近年来,我国积极布局人工智能产业,竞跑“未来赛道”。随着各行业、各领域对人工智能需求的日益增长,与实体经济深度融合的新模式不断涌现,形成了具有中国特色的研发体系和应用生态,引领着经济社会各领域从数字化、网络…

YOLOv3-YOLOv8的一些总结

0 写在前面 这个文档主要总结YOLO系列的创新点,以YOLOv3为baseline。参考(抄)了不少博客,就自己看看吧。有些模型的trick不感兴趣就没写进来,核心的都写了。 YOLO系列的网络都由四个部分组成:Input、Backbone、Neck、Prediction…

高新技术企业工时管理的挑战与应对策略

随着科技的飞速发展,高新技术企业已成为推动社会进步的重要力量。而在这类企业中,工时管理作为企业管理的重要组成部分,其意义也日益凸显。有效的工时管理不仅关乎企业的项目进度、人力掌控和资源合理配置,还直接影响到企业的研发…

centos7服务器上的文件上传到谷歌云盘(google drive)

1,下载gdrive客户端,Releases glotlabs/gdrive GitHub 2,下载完解压,并移动到cp gdrive /usr/local/bin/ 3,查看是否安装成功 4,添加账户,gdrive account add 根据链接,创建Client id和 Client secret 5,填写Client…

spring boot 配置多数据源 踩坑 BindingException: Invalid bound statement (not found)

在上一篇:《【已解决】Spring Boot多数据源的时候,mybatis报错提示:Invalid bound statement (not found)》 凯哥(凯哥Java) 已经接受了,在Spring Boot配置多数据源时候,因为自己马虎,导致的一个坑。下面&a…

SEO专业人士成功所需的8大技能

你有能力在SEO领域建立职业生涯吗?您需要某些技能才能成功。在这里了解这些技能是什么。 尽管SEO已经存在了几十年,但许多大学仍然没有教授SEO,也没有在大多数营销课程中提及。 SEO专业人士来自不同的背景。有些是程序员,有些是…

IDA PRO 0A - 交叉引用

本文将讨论IDA中的交叉引用的相关知识。 更多c逆向知识可以看B站的课程《C 反汇编基础教程(IDA Pro Visual Studio)》 交叉引用 IDA 中的交叉引用通常简称为xref 。从名字可以看出,使用快捷键就可以找出某个函数或者数据被引用的地方。 在IDA 中有两类基本的交叉引…

NSSCTF第16页(3)

[SWPUCTF 2023 秋季新生赛]ez_talk 上传一句话木马得到 抓包改文件类型 上传成功,只是倒序而已 得到flag [第五空间 2021]PNG图片转换器 这道题采用的是ruby语言,第一次听说 2021-第五空间智能安全大赛-PNG图片转换器 | 管道符与反引号的配合、open…

使用python实现链表

手写代码 class Node(object):def __init__(self, item):self.item itemself.next Noneclass LinkListFunction(object):"""此对象为Node对象的方法类"""def __init__(self):self.linklistlength 0 # 当前链表的长度def create_linklist_he…

C语言学习第二十六天(算法的时间复杂度和空间复杂度)

1、算法效率 衡量一个算法的好坏,是从时间和空间两个方面来衡量的,换句话说就是从时间复杂度和空间复杂度来衡量的 这里需要补充一点:时间复杂度是衡量一个算法的运行快慢,空间复杂度是主要衡量一个算法运行所需要的额外空间。 …

「Verilog学习笔记」交通灯

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 timescale 1ns/1nsmodule triffic_light(input rst_n, //异位复位信号,低电平有效input clk, //时钟信号input pass_request,output wire[7:0]clock,output reg…

Shared Worker的快速理解与简单应用

SharedWorker 是 HTML5 中引入的一种 WebWorkers 类型,用于在浏览器中创建可在多个浏览器窗口或标签页之间共享的后台线程。Web Workers 是在主线程之外运行的脚本,允许执行一些耗时的任务而不会阻塞用户界面。 对 SharedWorker 的概念、理解和应用的简要…

2023第十七届中国品牌节,酷开科技荣获金谱奖!

11月18日,以“复苏与腾飞”为主题的2023第十七届中国品牌节,在杭州市云栖小镇国际会展中心盛大开幕。来自政界、商界、文化界等领域的6000余名嘉宾出席本次盛会,共同见证了民族品牌的崛起,全力奉献一场史无前例的“品牌人的亚运会…

Pikachu漏洞练习平台之暴力破解(基于burpsuite)

从来没有哪个时代的黑客像今天一样热衷于猜解密码 ---奥斯特洛夫斯基 Burte Force(暴力破解)概述 “暴力破解”是一攻击具手段,在web攻击中,一般会使用这种手段对应用系统的认证信息进行获取。 其过程就是使用大量的认证信息在认…

【操作系统】实验名称: 实验五 文件系统

实验目的: 1. 掌握文件系统的基本概念和工作机制 2. 掌握文件系统的主要数据结构的实现 3、掌握软件系统实现算法 实验内容: 设计并实现一个虚拟的一级(单用户)文件系统程序 提供以下操作 1、文件创建/删除接口命令 2、目录创建/删…

合并 K 个排序链表——Java解答

题目:合并 K 个排序链表 题目描述: 给你一个链表数组,每个链表都已经按升序排列。请你将所有链表合并到一个升序链表中,返回合并后的链表。 示例: 假设有以下三个链表: 1->4->5, 1->3->4,…

QUIC在零信任解决方案的落地实践

一 前言 ZTNA为以“网络为中心”的传统企业体系架构向以“身份为中心”的新型企业安全体系架构转变,提供解决方案。随着传统网络边界不断弱化,企业SaaS规模化日益增多,给终端安全访问接入创造了多元化的空间。其中BYOD办公方式尤为突出&#…

SpringBoot使用@DS配置 多数据源 【mybatisplus druid datasource mysql】

项目最近需要使用多数据源,不同的mapper分别读取不同的链接,本项目使用了mybatisplus druid 来配置多数据源,基于mysql数据库。 目录 1.引入依赖 ​2.配置文件 application.yaml 3.Mapper中使用DS切换数据源 4.使用DS的注意事项 1.引入依…

苹果忽略iPhone?2024可穿戴产品或成重心!

一代版本一代神,即便是强如iPhone也有着被忽视的一天,当然,这么说有些夸张。虽然iPhone永远都是苹果最重要的产品,但在明年,苹果的重心将偏向其他产品。 彭博社记者马克古曼(Mark Gurman)在新一…