深度学习:Matplotlib篇

一、简介

1.1 什么是 Matplotlib?

Matplotlib 是一个广泛使用的 2D 绘图库,它可以用来在 Python 中创建各种静态、动态和交互式的图表。无论是科学计算、数据可视化,还是深度学习模型的训练与评估,Matplotlib 都能提供强大的图形展示功能。在深度学习领域,Matplotlib 通常用于可视化训练过程中的损失函数、准确率曲线以及各种训练结果

1.2 为什么会在深度学习中使用 Matplotlib?

在深度学习中,使用 Matplotlib 可以帮助开发者和研究人员更直观地理解模型的性能。常见的应用包括:

  • 绘制训练过程中的损失和精度曲线,以监控模型是否过拟合或欠拟合
  • 展示分类模型的混淆矩阵,以可视化分类错误类型
  • 可视化图像特征或特征图,以深入理解卷积神经网络(CNN)内部的工作机制

二、基本使用方法

2.1 导入

在使用之前,需要先导入其主要模块。一般来说我们会将 Matplotlib.pyplot 模块导入并简写为 plt,为了代码看起来简洁

import matplotlib.pyplot as plt

2.2 基本绘图操作

这里以折线图为例

# 创建一些数据
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]

# 绘制折线图
plt.plot(x, y)

# 显示图像
plt.show()

2.3 标题、标签与图例

添加标题、轴标签以及图例让图表更具信息性

plt.plot(x, y, label='y = x^2')
plt.title('Example of a Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.legend()
plt.show()

2.4 调整图像样式

Matplotlib 支持多种样式的图表,如折线图、柱状图、散点图等。在绘图时,我们也可以调整线条颜色、样式以及图形的其他参数

plt.plot(x, y, color='green', linestyle='--', marker='o', label='y = x^2')
plt.title('Customized Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.legend()
plt.show()

三、在深度学习中的应用

3.1 可视化训练过程

在深度学习模型训练过程中,最常见的做法是通过 Matplotlib 绘制训练和验证集上的损失曲线及准确率曲线。这些图表能够帮助我们判断模型的表现,分析是否存在过拟合或欠拟合的现象

import matplotlib.pyplot as plt

# 假设有训练和验证的损失和准确率数据 随便取的
epochs = range(1, 11)
train_loss = [0.8, 0.6, 0.4, 0.3, 0.2, 0.15, 0.1, 0.08, 0.06, 0.04]
val_loss = [0.9, 0.7, 0.5, 0.4, 0.3, 0.25, 0.22, 0.21, 0.2, 0.19]

# 绘制损失曲线
plt.plot(epochs, train_loss, 'bo-', label='Training loss')
plt.plot(epochs, val_loss, 'ro-', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

通过这种方法,我们就可以很直观地观察模型在每个 epoch 训练后的损失变化趋势,尤其是训练集和验证集的差异可以帮助判断是否发生过拟合

3.2 可视化卷积神经网络的特征图

在深度学习中,卷积神经网络(CNN)可以学习到图像的层次化特征。在某些场景下,我们希望可视化这些特征图,以更好地理解网络的工作机制。这时就可以通过 Matplotlib 将卷积层的输出特征图绘制出来

import torch
import matplotlib.pyplot as plt

# 假设我们有一个 CNN 模型和一张输入图像
model = ...
image = ...

# 获取卷积层的输出
features = model.conv1(image)

# 可视化第一层卷积后的特征图
fig, axarr = plt.subplots(1, 4)
for idx in range(4):
    axarr[idx].imshow(features[0, idx].detach().numpy(), cmap='gray')
plt.show()

给你们补全一下

import torch  
import torch.nn as nn  
import torchvision.transforms as transforms  
from PIL import Image  
import matplotlib.pyplot as plt  
  
# 定义一个简单的 CNN 模型  
class SimpleCNN(nn.Module):  
    def __init__(self):  
        super(SimpleCNN, self).__init__()  
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1)  
        # 添加其他层如果需要的话,这里仅作为示例  
  
    def forward(self, x):  
        x = self.conv1(x)  
        # x = ... 其他操作  
        return x  
  
# 实例化模型  
model = SimpleCNN()  
  
# 将模型移动到 GPU(如果可用),或者保持在 CPU 上  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
model.to(device)  
  
# 加载并预处理一张图像  
# 假设我们使用一张 RGB 图像,尺寸为 224x224(根据模型需求调整)  
transform = transforms.Compose([  
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量,并归一化到 [0, 1]  
])  
  
# 这里我们加载一个示例图像,你需要提供实际的图像路径  
image_path = 'path_to_your_image.jpg'  # 替换为你的图像路径  
image = Image.open(image_path).convert('RGB')  
image = transform(image).unsqueeze(0)  # 增加一个 batch 维度  
image = image.to(device)  # 确保图像在正确的设备上  
  
# 获取卷积层的输出  
with torch.no_grad():  # 我们不需要计算梯度  
    features = model.conv1(image)  
  
# 可视化第一层卷积后的特征图  
features = features.squeeze(0).cpu()  # 移除 batch 维度,并移动到 CPU  
fig, axarr = plt.subplots(1, 4, figsize=(12, 3))  
for idx in range(4):  
    axarr[idx].imshow(features[idx, :, :].numpy(), cmap='gray')  # 注意这里的索引可能需要根据实际的输出形状调整  
    axarr[idx].axis('off')  # 关闭坐标轴  
plt.show()

 看看效果:

原图:

3.3 混淆矩阵的可视化

混淆矩阵用于衡量分类模型性能,它、可以直观展示模型在不同类别上的分类正确率和错误率

import seaborn as sns
from sklearn.metrics import confusion_matrix

# 假设我们有预测值和真实值
y_true = [0, 1, 2, 2, 0, 1, 1, 2]
y_pred = [0, 0, 2, 2, 0, 2, 1, 2]

# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)

# 可视化混淆矩阵
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

这种可视化方法可以帮助我们深入分析模型在哪些类别上分类错误较多,进而做出相应的改进措施

四、进阶技巧

4.1 子图与多图显示

在深度学习中,往往需要同时观察多个图表(如损失和准确率),这时可以通过 subplot 函数来实现多个子图的排列显示

import matplotlib.pyplot as plt  
import numpy as np  
  
# 假设我们已经有了一些训练数据  
# 这些数据通常是在训练循环中收集的  
  
# 示例数据(您应该使用您的实际数据替换这些)  
epochs = np.arange(1, 21)  # 假设我们训练了20个epoch  
train_loss = np.linspace(0.5, 0.1, 20)  # 假设训练损失从0.5线性降低到0.1  
val_loss = np.linspace(0.55, 0.15, 20)  # 假设验证损失从0.55线性降低到0.15  
train_acc = np.linspace(0.5, 0.9, 20)   # 假设训练准确率从0.5线性增加到0.9  
val_acc = np.linspace(0.45, 0.85, 20)   # 假设验证准确率从0.45线性增加到0.85  
  
# 创建子图  
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  
  
# 绘制第一个图(损失曲线)  
ax1.plot(epochs, train_loss, 'b-', label='Training loss')  
ax1.plot(epochs, val_loss, 'r-', label='Validation loss')  
ax1.set_title('Loss Over Epochs')  
ax1.set_xlabel('Epochs')  
ax1.set_ylabel('Loss')  
ax1.legend()  
ax1.grid(True)  # 可选:添加网格线  
  
# 绘制第二个图(准确率曲线)  
ax2.plot(epochs, train_acc, 'b-', label='Training accuracy')  
ax2.plot(epochs, val_acc, 'r-', label='Validation accuracy')  
ax2.set_title('Accuracy Over Epochs')  
ax2.set_xlabel('Epochs')  
ax2.set_ylabel('Accuracy')  
ax2.legend()  
ax2.grid(True)  # 可选:添加网格线  
  
# 显示图形  
plt.tight_layout()  # 可选:调整子图之间的间距  
plt.show()

4.2 动态绘图

在深度学习训练过程中,有时我们希望能实时查看训练的进展。这时可以利用 Matplotlib 的动态绘图功能,通过 plt.ion() 实现图表的实时更新

import matplotlib.pyplot as plt  
import random  # 用于生成模拟数据  
  
# 初始化列表  
epochs = []  
train_loss = []  
val_loss = []  
  
# 开启交互模式  
plt.ion()  
  
# 模拟10个epoch的训练过程  
for epoch in range(10):  
    # 更新epochs列表(虽然在这种情况下,我们可以直接使用range(11)来绘制,但为了与您的代码一致,我们还是更新这个列表)  
    epochs.append(epoch + 1)  
      
    # 模拟新的训练损失和验证损失(在实际应用中,这些值将来自您的训练循环)  
    new_train_loss = random.uniform(0.1, 1.0)  # 生成一个0.1到1.0之间的随机浮点数  
    new_val_loss = random.uniform(0.1, 1.0)    # 生成另一个0.1到1.0之间的随机浮点数  
      
    # 将新的损失值添加到列表中  
    train_loss.append(new_train_loss)  
    val_loss.append(new_val_loss)  
      
    # 清除上一帧图像  
    plt.clf()  
      
    # 绘制新的曲线  
    plt.plot(epochs, train_loss, 'b-', label='Training loss')  
    plt.plot(epochs, val_loss, 'r-', label='Validation loss')  
    plt.xlabel('Epoch')  
    plt.ylabel('Loss')  
    plt.title('Loss Over Epochs')  
    plt.legend()  
    plt.grid(True)  # 可选:添加网格线  
      
    # 暂停一段时间以更新图形(0.1秒)  
    plt.pause(0.1)  
  
# 关闭交互模式(在显示最终图形之前通常不需要这样做,因为plt.show()会处理它)  
# 但为了与您的代码一致,我们还是包含了这个调用  
plt.ioff()  
  
# 显示最终图形(在交互模式下,这通常不是必需的,因为图形已经在循环中更新了)  
# 但由于我们包含了plt.ioff(),所以我们需要调用plt.show()来确保图形显示出来  
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/902415.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虚拟现实新纪元:VR/AR技术将如何改变娱乐与教育

内容概要 在当今科技飞速发展的时代,虚拟现实(VR)和增强现实(AR)技术不仅让我们的娱乐体验如虎添翼,更为教育变革注入了新活力。这些技术的飞跃进展,将原本平淡无奇的场景转变为令人沉醉的沉浸…

深入浅出 C++ STL:解锁高效编程的秘密武器

引言 C 标准模板库(STL)是现代 C 的核心部分之一,为开发者提供了丰富的预定义数据结构和算法,极大地提升了编程效率和代码的可读性。理解和掌握 STL 对于 C 开发者来说至关重要。以下是对 STL 的详细介绍,涵盖其基础知…

面向对象编程中类与类之间的关系(一)

目录 1.引言 2."有一个"关系 3."是一个"关系(继承) 4.“有一个”与“是一个”的区别 5.not-a关系 6.层次结构 7.多重继承 8.混入类 1.引言 作为程序员,必然会遇到这样的情况:不同的类具有共同的特征,至少看起来彼…

JavaWeb——Web入门(1/9)-Spring Boot Web介绍(Spring家族,Spring Boot)

目录 Spring家族 Spring Boot 在我们了解完了 Maven 这款项目构建工具的基本使用之后,接下来我们正式的进入到 Web 后端开发的学习。 第一篇章要了解的是 Spring Boot Web 的入门。 在正式开始之前,我们先需要介绍一下什么是 Spring 以及什么是 Spri…

H3C Hybrid 实验

实验拓扑 图 1-1 注:如无特别说明,描述中的 R1 或 SW1 对应拓扑中设备名称末尾数字为 1 的设备,R2 或 SW2 对应拓扑中设备名称末尾数字为 2 的设备,以此类推;另外,同一网段中,IP 地址的主机位为…

【NOI】C++函数入门二(自定义函数)

文章目录 前言一、概念1.导入1.1 首先什么是函数呢? 2.函数分类3.为什么要定义函数呢?4.函数结构5.函数使用注意事项 二、例题讲解问题:1137 - 纯粹素数问题:1258 - 求一个三位数问题:1140 - 亲密数对问题:…

Flutter仿京东商城APP实战 用户中心基础布局

用户中心界面 pages/tabs/user/user.dart import package:flutter/material.dart; import package:jdshop/utils/zdp_screen.dart; import package:provider/provider.dart;import ../../../store/counter_store.dart;class UserPage extends StatefulWidget {const UserPage…

如何在Node.js中执行解压缩文件操作

一、解压文件 1.安装依赖: 安装adm-zip依赖包:npm install adm-zip --save 安装iconv-lite依赖包:npm install iconv-lite --save 解压前的file文件夹结构: update-1.0.2.zip压缩包内容: 2.在depresssFile.js文件&…

Vue学习笔记(六)

模板引用(获取DOM 操作) 虽然Vue的声明性渲染模型为你抽象了大部分对DOM的直接操作&#xff0c;但在某些情况下&#xff0c;我们仍然需要直接访问底层DOM元素。要实现这一点&#xff0c;我们可以使用特殊的refattribute。 挂载结束后引用都会被暴露在this.$refs之上。 <s…

华为原生鸿蒙操作系统:我国移动操作系统的新篇章

华为原生鸿蒙操作系统&#xff1a;我国移动操作系统的新篇章 引言 在移动操作系统领域&#xff0c;苹果iOS和安卓系统一直占据主导地位。然而&#xff0c;随着华为原生鸿蒙操作系统的正式发布&#xff0c;这一格局正在发生深刻变化。作为继苹果iOS和安卓系统后的全球第三大移动…

android studio编译错误提示无法下载仓库

一、调整方法之一 buildscript {repositories {google()jcenter()//maven { url https://maven.aliyun.com/repository/google }//maven { url https://maven.aliyun.com/repository/central }}dependencies {// classpath "com.android.tools.build:gradle:4.1.1"c…

Prompt Engineering (Prompt工程)

2 prompt工程2大原则 2.1 给出清晰&#xff0c;详细的指令 策略1&#xff1a;使用分割符清晰的指示输出的不同部分&#xff0c;比如"",<>,<\tag>等分隔符 策略2&#xff1a;指定一个结构化的输出&#xff0c;比如json,html等格式 策略3&#xff1a;要…

C++STL之stack

1.stack的使用 函数说明 接口说明 stack() 构造空的栈 empty() 检测 stack 是否为空 size() 返回 stack 中元素的个数 top() 返回栈顶元素的引用 push() 将元素 val 压入 stack 中 pop() 将 stack 中尾部的元素弹出 2.stack的模拟实现 #include<vector> namespace abc { …

监控-08-skywalking监控告警

文章目录 前言一、准备二、配置skywalking2.1 修改alarm-settings.yml2.2 重启skywalking 三、收到告警邮件总结 前言 skywalking根据监控规则&#xff0c;通过webhook调后端微服务接口&#xff0c;从而发送告警邮件。 一、准备 根据上几章内容&#xff0c;保证skywalking能监…

Spring IoC DI

博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 关注博主带你了解更多数据结构知识 目录 1. 应用分层 1.1 如何分层: 1.2 MVC与三层架构区别联系 2. Spring 3.IoC & DI⼊⻔ 3.1 什么是IoC&#xff1f; 3.2 DI 介绍 …

CANFD SSP第二采样点引发的“风波”分析

案例背景&#xff1a; 近几年来&#xff0c;主机厂逐渐大范围使用CANFD通信。在CAN网络中&#xff0c;因SSP第二采样点引发的错误帧偶有发生&#xff0c;所以在主机厂的SPEC需求中&#xff0c;明确要求启用CAN控制器的TDC收发器延迟补偿&#xff0c; 目录 1 CANFD的SSP第二采…

【设计模式】《Java 设计模式魔法:解锁高效编程的秘密武器》

标题&#xff1a;《Java 设计模式奇幻之旅&#xff1a;解锁高效编程的魔法钥匙》 摘要&#xff1a; 本文将深入探讨 Java 中的十种设计模式&#xff0c;包括单例模式、工厂方法模式、抽象工厂模式…迭代器模式、组合模式、模板方法模式等。通过详细的解释、生动有趣的例子以及…

【skywalking 】选择Elasticsearch存储

介绍 skywalking支持 Elasticsearch 和 OpenSearch 作为存储。 OpenSearch 是 ElasticSearch 7.11 的一个分支&#xff0c;但在 Apache 2.0 中获得许可。 OpenSearch 存储与 ElasticSearch 共享相同的配置。为了激活 OpenSearch 作为存储&#xff0c;请将存储提供程序设置为e…

Python4

4. 更多控制流工具 除了刚介绍的 while 语句&#xff0c;Python 还用了一些别的。我们将在本章中遇到它们。 4.1. if 语句 if elif else if x<0: x 0 print(Negative changed to zero) elif x0: print( zero) else: print(More) 4.2. for 语句 Pyth…

Mybatis的关联关系-多对多

在进行数据库原理的时候&#xff0c;我们将E-R图的实体转化为我们的表时&#xff0c;有时要考虑到多对多的关系。比如下图&#xff1a; 我们可以转化为下面的表&#xff1a; 因为User和Orders是1:n的关系&#xff0c;所以Orders有一个外键。 t_orders表 idnumberuser_id(外键…