深度学习笔记_7经典网络模型LSTM解决FashionMNIST分类问题

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt

# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 设置超参数
sequence_length = 28
input_size = 28  
hidden_size = 128
num_layers = 2 
num_classes = 10
batch_size = 64
learning_rate = 0.001
num_epochs = 50

# 定义数据转换操作
transform = transforms.Compose([
    transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转
    transforms.RandomHorizontalFlip(),   # 随机水平翻转
    transforms.RandomCrop(size=28, padding=4),   # 随机裁剪
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,
                        download=True, transform=transform)

# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,
                       download=True, transform=transform)

# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

3、定义LSTM模型

# 定义LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size  # LSTM隐含层神经元数
        self.num_layers = num_layers  # LSTM层数
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # LSTM层
        self.fc = nn.Linear(hidden_size, num_classes)  # 全连接层

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)  # 初始化状态
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        out, _ = self.lstm(x, (h0, c0))  # LSTM前向传播
        out = self.fc(out[:, -1, :])  # 只取序列最后一个时间步的输出
        return F.log_softmax(out, dim=1)  # 使用log_softmax作为输出

# 初始化模型、优化器和损失函数
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

4、 训练循环

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
        data = data.view(-1, sequence_length, input_size)

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # 计算训练准确率
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # 计算平均训练损失和训练准确率
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # 测试模型
    model.eval()
    test_loss = 0.0
    correct = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
            data = data.view(-1, sequence_length, input_size)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上
            all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上

    # 计算平均测试损失和测试准确率
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # 计算额外的指标
    conf_matrix = confusion_matrix(all_labels, all_preds)
    conf_matrix_list.append(conf_matrix)

    accuracy = accuracy_score(all_labels, all_preds)
    accuracy_list.append(accuracy)

    error_rate = 1 - accuracy
    error_rate_list.append(error_rate)

    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    precision_list.append(precision)
    recall_list.append(recall)
    f1_score_list.append(f1)

    fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    roc_auc_list.append(roc_auc)

    # 打印每个 epoch 的指标
    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# 打印或绘制训练后的最终指标
print(f'Final Confusion Matrix:\n{conf_matrix_list[-1]}')
print(f'Final Accuracy: {accuracy_list[-1]:.2%}')
print(f'Final Error Rate: {error_rate_list[-1]:.2%}')
print(f'Final Precision: {precision_list[-1]:.2%}')
print(f'Final Recall: {recall_list[-1]:.2%}')
print(f'Final F1 Score: {f1_score_list[-1]:.2%}')
print(f'Final ROC AUC: {roc_auc_list[-1]:.2%}')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()


# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()


# 计算混淆矩阵
class_labels = [str(i) for i in range(10)]
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/256992.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

自定义 spring-boot组件自动注入starter

1&#xff1a;创建maven项目 2&#xff1a;pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocati…

stm32与Freertos入门(二)移植FreeRTOS到STM32中

简介 注意&#xff1a;FreeRTOS并不是实时操作系统&#xff0c;而是分时复用的&#xff0c;只不过切换频率很快&#xff0c;感觉上是同时在工作。本次使用的单片机型号为STM32F103C8T6,通过CubeMX快速移植。 一、CubeMX快速移植 1、选择芯片 打开CubeMX软件&#xff0c;进行…

轻量封装WebGPU渲染系统示例<53>- 多盏灯灯光照在地面的效果

WebGPU实时渲染实现模拟多盏灯的灯光照在地面的效果灯光效果 。 当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/material/src/voxgpu/sample/MultiLightsTest.ts 当前示例运行效果: 此示例基于此渲染系统实现&#xff0c;当前示例TypeScript源…

Ribbon负载均衡原理、策略、饥饿加载

Ribbon负载均衡原理、策略、饥饿加载 MapperScan("cn.itcast.order.mapper") SpringBootApplication public class OrderApplication {public static void main(String[] args) {SpringApplication.run(OrderApplication.class, args);}/*** 完成创建RestTemplate&am…

虾皮 选品:如何在虾皮平台上进行有效的选品?

在虾皮&#xff08;Shopee&#xff09;这个跨境电商平台上&#xff0c;选品对于卖家来说至关重要。选品决定了店铺的销售额和竞争力。为了帮助卖家进行选品&#xff0c;虾皮平台提供了一些免费的选品工具&#xff0c;如知虾。同时&#xff0c;还有一些第三方选品工具&#xff0…

支持可视化提取变量,Apipost配置变量不要太简单

在调试接口时我们需要将响应结果中的某个字段配置为环境变量在其他接口中引用&#xff0c;之前在Apipost中需要配置脚本而在最近Apipost后执行操作中可以进行可视化的断言和变量提取&#xff0c;无需配置繁琐脚本。 这里我们在登录接口下配置一条Token环境变量&#xff0c;在后…

去面试性能测试工程师必问的问题,

性能测试的三个核心原理是什么&#xff1f; 1.基于协议。性能测试的对象是网络分布式架构的软件&#xff0c;而网络分布式架构的核心是网络协议 2.多线程。人的大脑是单线程的&#xff0c;电脑的cpu是多线程的。性能测试就是利用多线程的技术模拟多用户去负载 3.模拟真实场景。…

外卖系统海外版:技术智能引领全球美食新潮流

随着全球数字化浪潮的推动&#xff0c;外卖系统海外版不仅是食客们品味美食的便捷通道&#xff0c;更是技术智能在美食领域的引领者。本文将深入剖析其背后的技术实现&#xff0c;揭开代码带来的美食革新。 多语言支持&#xff1a;构建全球美食沟通桥梁 def multilingual_su…

机器人也能干的更好:RPA技术的优势和应用场景

RPA是什么&#xff1f; 机器人流程自动化RPA&#xff08;Robotic Process Automation&#xff09;是一种自动化技术&#xff0c;它使用软件机器人来高效完成重复且有逻辑性的工作。近年来&#xff0c;随着人工智能和自动化技术的不断发展和普及&#xff0c;RPA已经成为企业提高…

双指针——找到字符串中的所有字母异位词

https://leetcode.cn/problems/find-all-anagrams-in-a-string/description/?envTypestudy-plan-v2&envIdtop-100-liked 双指针&#xff0c;每次都统计出来p长度的滑动窗口里的数字,拿Arrays.equals进行对比,然后滑动一小格&#xff0c;减1加1继续比对即可。 class Solut…

2023年12月最新软件测试面试题(带答案)

1. 请自我介绍一下(需简单清楚的表述自已的基本情况&#xff0c;在这过程中要展现出自信&#xff0c;对工作有激情&#xff0c;上进&#xff0c;好学) 面试官您好&#xff0c;我叫###&#xff0c;今年26岁&#xff0c;来自江西九江&#xff0c;就读专业是电子商务&#xff0c;毕…

教你如何使用天眼销查企业客户

天眼销是一款能够帮助客户获取最新的企业联系方式、工商信息等关键数据的平台。 平台基于先进的技术和大数据解决方案&#xff0c;深入挖掘和分析企业信息&#xff0c;能够高效地收集、整理和存储各类企业数据&#xff0c;为用户提供360度视角和洞察&#xff1b;提供全面、准确…

从零开始掌握MAYA 2022:打造视觉创意的艺术大师之路

&#x1f482; 个人网站:【 海拥】【神级代码资源网站】【办公神器】&#x1f91f; 基于Web端打造的&#xff1a;&#x1f449;轻量化工具创作平台&#x1f485; 想寻找共同学习交流的小伙伴&#xff0c;请点击【全栈技术交流群】 Autodesk Maya是一款强大的三维计算机图形软件…

ROS-安装Rviz

安装 运行下列命令进行安装&#xff0c;xxxxxx处更改为自己的版本 sudo apt-get install ros-xxxxxx-rviz运行 先打开roscore roscore再运行rviz rviz参考&#xff1a; [1]https://blog.csdn.net/qq_66540741/article/details/134400248 [2]https://blog.csdn.net/baidu_384…

PostgreSQL入门指南:快速学会创建和管理数据库!

当谈到数据库管理系统时&#xff0c;PostgreSQL是一个功能强大且广泛使用的开源关系型数据库。在本次讲解中&#xff0c;我将为您介绍如何创建和管理数据库&#xff0c;并提供一些有关PostgreSQL的基本概念和最佳实践的指导。 创建数据库 在开始之前&#xff0c;请确保您已经成…

服务器解析漏洞有哪些?IIS\APACHE\NGINX解析漏洞利用

解析漏洞是指在Web服务器处理用户请求时&#xff0c;对输入数据&#xff08;如文件名、参数等&#xff09;进行解析时产生的漏洞。这种漏洞可能导致服务器对用户提供的数据进行错误解析&#xff0c;使攻击者能够执行未经授权的操作。解析漏洞通常涉及到对用户输入的信任不足&am…

python爬虫进阶篇:利用Scrapy爬取同花顺个股行情并发送邮件通知

一、前言 上篇笔记我记录了scrapy的环境搭建和项目创建和第一次demo测试。本篇我们来结合现实场景利用scrapy给我们带来便利。 有炒股或者其它理财产品的朋友经常会关心每日的个股走势&#xff0c;如果结合爬虫进行实时通知自己&#xff0c;并根据自己预想的行情进行邮件通知&…

Java对接腾讯多人音视频房间示例

最近在对接腾讯的多人音视频房间&#xff0c;做一个类似于腾讯会议的工具&#xff0c;至于为什么不直接用腾讯会议&#xff0c;这个我也不知道&#xff0c;当然我也不敢问 首先是腾讯官方的文档地址&#xff1a;https://cloud.tencent.com/document/product/1690 我是后端所以…

【Spring】11 EnvironmentAware 接口

文章目录 1. 简介2. 作用3. 使用3.1 创建并实现接口3.2 配置 Bean 信息3.3 创建启动类3.4 启动 4. 应用场景总结 Spring 框架为开发者提供了丰富的扩展点&#xff0c;其中之一就是 Bean 生命周期中的回调接口。本文将着重介绍一个与环境&#xff08;Environment&#xff09;相关…