神经网络入门实战:(六)PyTorch 中的实用工具 SummaryWriter 和 TensorBoard 的说明

(一) SummaryWriter

这里先讲解 SummaryWriter ,TensorBoard 会在第二大点进行说明。

SummaryWriter 是 PyTorch 中的一个非常实用的工具,它主要用于将深度学习模型训练过程中的各种日志和统计数据记录下来,并可以与 TensorBoard 配合使用,实现数据的可视化。以下是对 SummaryWriter 库的详细介绍:

1)概述

SummaryWriter 是 torch.utils.tensorboard 包中的一个类,它允许用户将训练过程中的关键信息(如损失值、准确率、学习率、模型权重分布、图像等)写入到指定的事件文件中。这些信息随后可以被 TensorBoard 解析和展示,从而帮助开发者更好地理解和监控模型的训练过程。

2)代码使用步骤

  1. 安装 TensorBoard:在使用 SummaryWriter 之前,需要确保已经安装了 TensorBoard 。可以使用 pip 命令进行安装:pip install tensorboard
  2. 导入 SummaryWriter:在代码中导入 SummaryWriter 类:from torch.utils.tensorboard import SummaryWriter
  3. 实例化 SummaryWriter:创建一个 SummaryWriter 对象,并指定一个日志目录(log_dir),用于保存事件文件。
    • 例如:writer = SummaryWriter('runs/my_experiment') ,或者 writer = SummaryWriter('logs')
    • 如果没有指定日志目录,那么当开始记录数据时,会自动在当前代码的目录下创建一个名为 runs 的文件夹,同时在此文件夹下创建一个以当前时间、日期和主机名命名的子目录,用于保存时间文件,类似于 ./runs/YYYYMMDD_HHMMSS_hostname/ 这样的路径。
  4. 记录数据(第三小点会详细介绍):在训练过程中,使用 SummaryWriter 对象的各种方法记录需要的数据。例如,使用 add_scalar 记录损失值,使用 add_histogram 记录权重分布等。
  5. 关闭 SummaryWriter:在训练结束后,调用 writer.close() 方法关闭 SummaryWriter 对象,确保所有数据都被正确写入事件文件。
  6. 启动 TensorBoard(见下方第二大点):在命令行中使用 tensorboard --logdir=事件文件所在的文件夹名 命令启动 TensorBoard服务。然后,在浏览器中访问 http://localhost:6006,就可以看到 TensorBoard 的可视化界面了。

3)记录不同数据的代码

  1. 记录标量信息:使用 add_scalar 方法,可以记录如 损失值、准确率 等标量信息。这些信息通常以曲线图的形式在 TensorBoard 中展示,便于观察其变化趋势。

    writer = SummaryWriter('logs')
    writer.add_scalar('曲线图标题',scalar_value,global_step)
    # scalar_value 表示曲线图纵坐标数值
    # global_step 表示曲线图横坐标数值
    
  2. 记录张量信息:通过 add_histogram 等方法,可以记录 模型权重、梯度 等张量信息。TensorBoard 会以直方图的形式展示这些张量的分布,有助于分析模型的稳定性和收敛性。

  3. 记录图像信息(注意图像格式):使用 add_image 方法记录 图像 数据。这对于处理图像任务的模型来说尤其有用,因为可以直观地看到模型对输入图像的预测结果或中间层的特征图。

    writer = SummaryWriter('logs')
    writer.add_image("图片标题",image_tensor,global_step,dataformats='HWC')
    # image_tensor 表示图片的数据格式,只能是 torch.tensor 或者 numpy.ndarray 或者 string 格式
    # global_step 表示步长,从0开始!!
    # dataformats 表示图片的具体格式是CHW(通道数,高,宽)还是HWC(高,宽,通道数)还是HW。默认是CHW
    
    • torch.tensor 格式的图片,可以由 transforms 工具转成,具体会在后续 transforms 模块中进行讲解。

      在将图片转换成 tensor 格式时,代码会自动将像素值从(0,255)放缩到(0,1)之间

    • numpy.ndarray 格式的图片,有两种生成方式:

      • 通过 opencv 库中的工具直接读取:

        # 导入cv2库 
        img_path = "dataset/train/bees/bees(3).jpg"
        img_cv = cv2.imread(img_path)
        
      • 通过 numpy 库,间接生成:

        # 导入PIL.Image库,和numpy库
        img_path = "dataset/train/bees/bees(3).jpg"
        img_PIL = Image.open(image_path)
        img_array = np.array(img_PIL)
        

      nimpy.ndarray 格式的图片,是由 RGB 三个通道的值组成的,每个元素的值,都在(0,255)之间。

  4. 记录模型结构:在训练开始前,可以使用add_graph方法记录模型的结构。这有助于开发者理解模型的复杂性,并在TensorBoard 中直观地查看模型的层次结构和参数。

4)注意事项

  1. 日志目录的选择:为了避免日志文件的混乱,建议为每个实验或模型训练任务指定一个唯一的日志目录。
  2. 数据的实时性:SummaryWriter 是异步更新文件内容的,这意味着在训练过程中记录的数据可能不会立即显示在TensorBoard 中。但是,这并不会影响数据的准确性和完整性。
  3. 资源的占用:长时间运行 TensorBoard 可能会占用较多的系统资源。因此,建议在需要时才启动 TensorBoard 服务,并在完成后及时关闭。

5)示例

代码中的 SummaryWriter 中的字母 SW 一定要大写。

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')

# y=x
for i in range(100):
	writer.add_scalar('y=x', i, i,)

writer.close()

(二) TensorBoard

要确保安装的 TensorBoard 版本与 PyTorch 兼容,一般使用指令 pip install tensorboard==2.12.0 来安装(此时的 torch 版本为 2.4.1),原先安装好了也可以通过此指令进行覆盖。

1)使用方法

查看记录好的数据日志文件,将其可视化

运行下面这两个命令时,TensorBoard 会启动一个本地服务器,并在 默认网页浏览器中 打开一个新的标签页或窗口,显示 TensorBoard 的用户界面。在这个界面中,可以看到各种图表和可视化工具,它们展示了训练过程中记录的各种指标,如损失、准确率、模型参数分布等。

  • 默认端口号:

    tensorboard --logdir=日志目录路径
    # 例如:tensorboard --logdir=logs,logs就在当前代码文件夹中,是相对路径
    

    日志目录路径:

    • 如果 日志目录路径 是一个相对路径(即不以斜杠 / 开头),那么它会被解释为相对于你当前工作目录的路径。
    • 如果 日志目录路径 是一个绝对路径(即以斜杠 / 开头,或者在 Windows 上以盘符开头后跟冒号和斜杠,如 C:/logs),那么它会被直接解释为那个位置的路径。
  • 指定端口号:

    tensorboard --logdir=日志目录路径 --port=指定端口号
    # 例如:tensorboard --logdir=logs --port=6007,logs就在当前代码文件夹中,是相对路径
    

2)问题及解决办法

  • 已经画了一个图像,第二次画图时,如果图像内容变了,但是图像标题没变的话,两个图像就会重合紊乱,

    建议换标题,或者删除所有的日志文件,重新运行

  • 其他的问题,诸如显示不全、显示不出等,都可以通过删除所有的日志文件,换个端口重新运行来解决。

3)演示(结合 SummaryWriter )

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import cv2

writer = SummaryWriter('logs')

image_path = "E:\\4_Data_sets\\1\\train\\ants_image\\ants(2).jpg"
img_cv = cv2.imread(image_path)
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)

print("img_cv :",type(img_cv))
print("img_PIL :",type(img_PIL))
print("img_array :",type(img_array))
print(img_array.shape)

writer.add_image("test",img_array,2,dataformats='HWC')
# y=x
for i in range(100):
	writer.add_scalar('y=x', i, i,)

writer.close()
-----------------------------------------------------------------------------------------------------------------
# 运行结果:
img_cv : <class 'numpy.ndarray'>
img_PIL : <class 'PIL.JpegImagePlugin.JpegImageFile'>
img_array : <class 'numpy.ndarray'>
(375, 500, 3)

在终端执行 tensorboard --logdir=logs --port=6007 并打开相应网址之后的运行结果:
在这里插入图片描述

上一篇下一篇
神经网络入门实战(五)神经网络入门实战(七)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/930172.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C#实现一个HttpClient集成通义千问-开发前准备

集成一个在线大模型&#xff08;如通义千问&#xff09;&#xff0c;来开发一个chat对话类型的ai应用&#xff0c;我需要先了解OpenAI的API文档&#xff0c;请求和返回的参数都是以相关接口文档的标准进行的 相关文档 OpenAI API文档 https://platform.openai.com/docs/api-…

开发知识点-uniCloud

开发知识点-uniCloud 服务空间云函数 cloudfunctions云对象importObjectJSON 格式的文档型数据库Collection unicloud数据的指定表集合 DB SchemaJQL 语法参考资料 服务空间 项目关联空间 云函数 cloudfunctions 云对象importObject JSON 格式的文档型数据库 nosql 非关系…

Vue Web开发(二)

1. 项目搭建 1.1. 首页架子搭建 使用Element ui中的Container布局容器&#xff0c;选择倒数第二个样式&#xff0c;将代码复制到Home.vue。 1.1.1.下载less &#xff08;1&#xff09;下载less样式 npm i less   &#xff08;2&#xff09;下载less编辑解析器 npm i less…

GWAS分析先做后学

大家好&#xff0c;我是邓飞。 GWAS分析是生物信息和统计学的交叉学科&#xff0c;上可以学习编程&#xff0c;下可以学习统计。对于Linux系统&#xff0c;R语言&#xff0c;作图&#xff0c;统计学&#xff0c;机器学习等方向&#xff0c;都是一个极好的入门项目。生物信息如…

Go学习:变量

目录 1. 变量的命名 2. 变量的声明 3. 变量声明时注意事项 4. 变量的初始化 5. 简单例子 变量主要用来存储数据信息&#xff0c;变量的值可以通过变量名进行访问。 1. 变量的命名 在Go语言中&#xff0c;变量名的命名规则 与其他编程语言一样&#xff0c;都是由字母、数…

Netty 心跳机制示例 —— 服务端实现

Netty 心跳机制示例 —— 服务端实现 1. 背景 在分布式系统和网络通信中&#xff0c;保持客户端与服务器端的连接活跃是非常重要的。如果长时间没有数据传输&#xff0c;连接可能会超时或被中断。为了解决这个问题&#xff0c;我们可以通过 心跳机制 来保证连接持续有效。 N…

【Linux】 进程池 一主多从 管道通信

目录 1.代码介绍 2.channel 类 3.进程池类编写 4.主函数及其他 5. 源码 1.代码介绍 本文代码采用一主多从式&#xff08;一个主进程&#xff08;master&#xff09;多个子进程&#xff08;worker&#xff09;&#xff09;通过管道进行通信&#xff0c;实现主进程分发任务&…

小红薯最新x-s 算法补环境教程12-06更新(下)

在上一篇文章中已经讲了如何去定位x-s生成的位置&#xff0c;本篇文章就直接开始撸代码吧 如果没看过的话可以看&#xff1a;小红薯最新x-s算法分析12-06&#xff08;x-s 56&#xff09;&#xff08;上&#xff09;-CSDN博客 1、获取加密块代码 首先来到参数生成的位置&…

Nacos源码学习-本地环境搭建

本文主要记录如何在本地搭建Nacos调试环境来进一步学习其源码&#xff0c;如果你也刚好刷到这篇文章&#xff0c;希望对你有所帮助。 1、本地环境准备 Maven: 3.5.4 Java: 1.8 开发工具&#xff1a;idea 版本控制工具: git 2、下载源码 官方仓库地址 &#xff1a;https://git…

视频码率到底是什么?详细说明

视频码率&#xff08;Video Bitrate&#xff09;是指在单位时间内&#xff08;通常是每秒&#xff09;传输或处理的视频数据量&#xff0c;用比特&#xff08;bit&#xff09;表示。它通常用来衡量视频文件的压缩程度和质量&#xff0c;码率越高&#xff0c;视频质量越好&#…

计算机网络复习——概念强化作业

物理层负责网络通信的二进制传输 用于将MAC地址解析为IP地址的协议为RARP。 一个交换机接收到一帧,其目的地址在它的MAC地址表中查不到,交换机应该向除了来的端口外的所有其它端口转发。 关于ICMP协议,下面的论述中正确的是ICMP可传送IP通信过程中出现的错误信息。 在B类网络…

【AI系统】感知量化训练 QAT

感知量化训练 QAT 本文将会介绍感知量化训练&#xff08;QAT&#xff09;流程&#xff0c;这是一种在训练期间模拟量化操作的方法&#xff0c;用于减少将神经网络模型从 FP32 精度量化到 INT8 时的精度损失。QAT 通过在模型中插入伪量化节点&#xff08;FakeQuant&#xff09;…

【AI系统】模型压缩基本介绍

基本介绍 随着神经网络模型的复杂性和规模不断增加&#xff0c;模型对存储空间和计算资源的需求越来越多&#xff0c;使得部署和运行成本显著上升。模型压缩的目标是通过减少模型的存储空间、减少计算量或提高模型的计算效率&#xff0c;从而在保持模型性能的同时&#xff0c;…

使用GO--Swagger生成文档

概述 在前后端分离的项目中&#xff0c;后端配置swagger可以很好的帮助前端人员了解后端接口参数和数据传输。go-swagger 是一个功能全面且高性能的Go语言实现工具包&#xff0c;用于处理Swagger 2.0&#xff08;即OpenAPI 2.0&#xff09;规范。它提供了丰富的工具集&#x…

排查bug的通用思路

⭐️前言⭐️ APP点击某个按钮没有反应/PC端执行某个操作后&#xff0c;响应较慢&#xff0c;通用的问题排查方法: 从多个角度来排查问题 &#x1f349;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f349;博主将持续更新学习记录收获&#xff0c;友友们有任何问题可以在评…

2024年认证杯SPSSPRO杯数学建模C题(第一阶段)云中的海盐解题全过程文档及程序

2024年认证杯SPSSPRO杯数学建模 C题 云中的海盐 原题再现&#xff1a; 巴黎气候协定提出的目标是&#xff1a;在2100年前&#xff0c;把全球平均气温相对于工业革命以前的气温升幅控制在不超过2摄氏度的水平&#xff0c;并为1.5摄氏度而努力。但事实上&#xff0c;许多之前的…

oracle之用户的相关操作

&#xff08;1&#xff09;创建用户(sys用户下操作) 简单创建用户如下&#xff1a; CREATE USER username IDENTIFIED BY password; 如果需要自定义更多的信息&#xff0c;如用户使用的表空间等&#xff0c;可以使用如下&#xff1a; CREATE USER mall IDENTIFIED BY 12345…

ArcMap 处理河道坡度、计算污染区、三维爆炸功能

ArcMap 处理河道坡度、计算污染区、三维爆炸功能今天分析 一、计算河道方向坡度 1、折线转栅格 确定 2、提取河道高程值 确定后展示河流的高程值 3、计算坡向数据 确定后展示 4、计算坡度数据 确定后展示 二、计算上游集水区污染值 1、填挖处理 确定 2、计算流向 确定 3、计算…

一睹:微软最新发布的LazyGraphRAG

微软近期推出了一项革新性的技术——LazyGraphRAG&#xff0c;这是一种启用图谱的检索增强生成&#xff08;Retrieval Augmented Generation&#xff0c;RAG&#xff09;技术&#xff0c;它以其卓越的效率和成本效益&#xff0c;彻底颠覆了传统观念中对“懒惰”的刻板印象。 位…

linux_kernel_编程

内核报错信息查看 include/uapi/asm-generic/errno-base.h 设备树的读取操作 struct device_node *ncof_property_read_bool(nc, "spi-cpha")if (!of_node_name_eq(nc, "slave"))rc of_property_read_u32(nc, "reg", &…