pytorch 利用Tensorboar记录训练过程loss变化

文章目录

    • 1. LossHistory日志类定义
    • 2. LossHistory类的使用
      • 2.1 实例化LossHistory
      • 2.2 记录每个epoch的loss
      • 2.3 训练结束close掉SummaryWriter
    • 3. 利用Tensorboard 可视化
      • 3.1 显示可视化效果
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():
    def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
      f.write(str(loss))
      f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]

if local_rank == 0:
      time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
      log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
      loss_history    = LossHistory(log_dir, model, input_shape=input_shape)
  else:
      loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

for epoch in range(start_epoch, total_epoch):
	...
	loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
    在这里插入图片描述
  • GRAPH模块展示的是模型的网络结构
    在这里插入图片描述
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况
    在这里插入图片描述

参考

  • 1 https://github.com/bubbliiiing/deeplabv3-plus-pytorch
  • 2 pytorch中使用tensorboard实现训练过程可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/367209.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

布隆过滤器的概述和使用

1 布隆过滤器概述 1.1 概述 布隆过滤器&#xff08;Bloom Filter&#xff09;是1970年由布隆提出的。它实际上是由一个很长的二进制向量&#xff08;数组&#xff09;和一系列随机映射函数&#xff08;hash函数&#xff09;组成&#xff0c;它不存放数据的明细内容&#xff0…

FANUC机器人开机时无法进入系统,示教器黑屏故障处理总结

FANUC机器人开机时无法进入系统&#xff0c;示教器黑屏故障处理总结 故障描述&#xff1a; FANUC机器人开机时&#xff0c;示教器在初始化时显示&#xff1a;EMAC initial call failed&#xff08;示教器上电时会进入boot画面&#xff0c;左上角会出现一些白色的英文提示&#…

YOLOv5白皮书-第Y3周:yolov5s.yaml文件解读

YOLOv5白皮书-第Y3周:yolov5s.yaml文件解读 YOLOv5白皮书-第Y3周:yolov5s.yaml文件解读一、前言二、我的环境三、yolov5s.yaml源文件内容四、Parameters五、anchors配置六、backbone七、head八、总结 OLOv5-第Y2周&#xff1a;训练自己的数据集) YOLOv5白皮书-第Y3周:yolov5s.…

学习日志以及个人总结 (16)

共用体 共用体 union 共用体名 { 成员列表&#xff1b; }&#xff1b;//表示定义一个共用体类型 注意&#xff1a; 1.共用体 初始化 --- 只能给一个值&#xff0c;默认是给到第一个成员变量 2.共用体成员变量辅助 3.可以判断大小端 ----※&#xff01;&#xff01; 实际用途…

猫用空气净化器好吗?好用的养猫宠物空气净化器品牌推荐

作为一个养猫五年的资深铲屎官&#xff0c;我对如何轻松快乐地养猫有一些心得。猫咪每天在家里奔跑&#xff0c;导致家里经常会出现“猫毛雪”&#xff0c;沙发、地板和衣服都成了重灾区。在除猫毛的问题上&#xff0c;我真的尝试了各种方法&#xff0c;几乎用上了所有的技能。…

2024美赛预测算法 | 回归预测 | Matlab基于RIME-LSSVM霜冰算法优化最小二乘支持向量机的数据多输入单输出回归预测

2024美赛预测算法 | 回归预测 | Matlab基于RIME-LSSVM霜冰算法优化最小二乘支持向量机的数据多输入单输出回归预测 目录 2024美赛预测算法 | 回归预测 | Matlab基于RIME-LSSVM霜冰算法优化最小二乘支持向量机的数据多输入单输出回归预测预测效果基本介绍程序设计参考资料 预测效…

2024美赛E题数学建模思路代码数据分享

2024 ICM Problem E: Sustainability of Property Insurance 本题要求选取不同大陆上经历极端天气的两个地区来为保险公司开发模型&#xff0c;本题的重点是找到尽可能多而全的数据&#xff0c;包括天气数据&#xff0c;经济数据&#xff0c;人口数据等。 模型选择&#xff1a…

《最新出炉》系列入门篇-Python+Playwright自动化测试-10-标签页操作(tab)

1.简介 标签操作其实也是基于浏览器上下文&#xff08;BrowserContext&#xff09;进行操作的&#xff0c;而且宏哥在之前的BrowserContext也有提到过&#xff0c;但是有的童鞋或者小伙伴还是不清楚怎么操作&#xff0c;或者思路有点模糊&#xff0c;因此今天单独来对其进行讲…

Windows内存管理 - 物理内存概念(Physical Memory Address)

作为windows驱动程序的程序员&#xff0c;需要比普通程序员更多的了解Windows内部的内存管理机制&#xff0c;并在驱动程序中有效地使用内存。在驱动程序编写中&#xff0c;分配和管理内存不能使用熟知的Win32 API函数&#xff0c;取而代之的是DDK提供的高效的内核函数。程序员…

PKG系统安装包及IPSW固件:MacOS 11-14 Sonoma 正式版

MacOS 14 Sonoma&#xff0c;为提高生产力和创造力带来了全新的功能&#xff0c;有了更多使用小部件和令人惊叹的新屏幕保护程序进行个性化设置的方法&#xff0c;对Safari浏览器和视频会议进行了重大更新&#xff0c;以及优化的游戏体验——Mac体验比以往任何时候都更好。 mac…

MySQL篇----第三篇

系列文章目录 文章目录 系列文章目录前言一、InnoDB与MyISAM的区别二、索引三、常见索引原则有前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 一、InnoDB与MyISAM…

【Android Studio 启动出错】

Android Studio版本&#xff1a;2022.3.1 出错前操作&#xff1a; 昨晚开着三四个项目&#xff0c;然后太晚了直接关机睡觉&#xff0c;第二天起来开机&#xff0c;启动Android Studio&#xff0c;就出现了这个问题&#xff1a; Internal error. Please refer to https://co…

opencv+mediapipe 手势识别控制电脑音量(详细注释解析)

前段时间社团布置了一个手势识别控制电脑音量的小任务&#xff0c;今天记录一下学习过程&#xff0c;将大佬作品在我的贫瘠的基础上解释一下~ 项目主要由以下4个步骤组成&#xff1a; 1、使用OpenCV读取摄像头视频流 2、识别手掌关键点像素坐标 3、根据拇指和食指指尖的坐标…

今日早报 每日精选15条新闻简报 每天一分钟 知晓天下事 2月3日,星期六

每天一分钟&#xff0c;知晓天下事&#xff01; 2024年2月3日 星期六 农历腊月廿四 南小年 1、 气象局&#xff1a;将雨雪冰冻三级应急响应提升为二级&#xff0c;针对性做好春运气象保障服务。 2、 教育部&#xff1a;鼓励银龄教师投身西部地区、民族地区民办学校。 3、 四部…

解决java.lang.ClassCastException

目录 问题 原因 解决方案 问题 前后端分离开发中&#xff0c;往往需要统一封装返回数据用到一个Result<T>类包装多个接口&#xff1a; 重复劳动并不优雅&#xff0c;于是想用RestControllerAdvice做控制器拦截增强&#xff0c;进行封装。 代码如下&#xff1a; Res…

[Python] 什么是PCA降维技术以及scikit-learn中PCA类使用案例(图文教程,含详细代码)

什么是维度&#xff1f; 对于Numpy中数组来说&#xff0c;维度就是功能shape返回的结果&#xff0c;shape中返回了几个数字&#xff0c;就是几维。索引以外的数据&#xff0c;不分行列的叫一维&#xff08;此时shape返回唯一的维度上的数据个数&#xff09;&#xff0c;有行列…

ImportError: urllib3 v2.0 only supports OpenSSL 1.1.1+

错误记录&#xff1a; 安装使用moviepy&#xff0c;测试出现问题。 解决方案&#xff1a; 采用降低urllib3的版本的方式&#xff0c;实测可行。 pip install urllib31.*

【Kafka专栏】windows搭建Kafka环境 详细教程(01)

文章目录 01 引言1.1 官网地址1.2 概述简介1.3 kafka与zookeeper 02 部署zookeeper2.1 下载组件包2.2 解压压缩包&#xff08;1&#xff09;解压到任意路径&#xff08;2&#xff09;解压后的目录创建数据目录data 2.3 修改zoo配置2.4 设置系统变量2.5 启动zookeepe服务&#x…

数据结构+算法(第13篇):精通二叉树的“独门忍术”——线索二叉树(上)

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

【数据结构】(分治策略)中位数的查询和最接近点对问题

中位数查询&#xff1a; 寻找一组字符串中第k小的数&#xff0c;返回其值和下标。 不可以有重复值&#xff08;在缩小规模的时候&#xff0c;会导致程序死循环&#xff09; 相对位置的转换体现了分治策略的思想。> 划分函数 int partition(int *nums,int left, int rig…