用TensorBoard可视化PyTorch

一、TensorBoard与PyTorch配合使用的基本步骤

PyTorch可以直接与TensorBoard进行集成,因为TensorBoard是一个独立于TensorFlow之外的可视化工具。TensorBoard被设计为支持机器学习实验的可视化,如训练的进度和结果等。PyTorch中的`torch.utils.tensorboard`模块允许PyTorch用户使用这个强大的可视化工具。
以下是将TensorBoard与PyTorch配合使用的基本步骤:
1. 在PyTorch中安装TensorBoard:   

pip install tensorboard

2. 在Python代码中导入TensorBoard的`SummaryWriter`:   

from torch.utils.tensorboard import SummaryWriter

3. 创建一个`SummaryWriter`实例,它将日志写入指定的目录:

writer = SummaryWriter('runs/your_experiment_name')

4. 将数据写入日志:

   # For example, log scalars
   writer.add_scalar('Loss/train', loss_value, epoch)

   # Log values and models
   writer.add_histogram('weights', model.weight, epoch)
   writer.add_graph(model, input_to_model)

   # Log images
   writer.add_image('input_image', img, epoch)

   # And many more...

   5. 当所有日志都写入后,在命令行启动TensorBoard,在浏览器中查看结果:

   tensorboard --logdir=runs

之后,就可以在TensorBoard的Web界面中看到各种图形和数据的可视化展现,这对于理解模型的学习过程、调试以及展示结果是非常有用的。
此外,社区也开发了一些其他可视化工具,比如`visdom`,但TensorBoard因其功能强大和易用性,在PyTorch社区中得到了广泛的应用。 

二、PyTorch与TensorBoard进行集成的完整示例

要将PyTorch与TensorBoard结合起来,可以使用`tensorboardX`库,这是一个提供了与TensorBoard兼容的API的库,使得可以从PyTorch中记录数据并在TensorBoard中查看。不过,从PyTorch 1.1.0起,官方直接内置了对TensorBoard的支持,称为`torch.utils.tensorboard`。以下是一个简单的例子,说明如何使用PyTorch训练一个模型并使用TensorBoard记录日志:
首先,确保已经安装PyTorch和TensorBoard:

pip install torch torchvision tensorboard

接下来,是一个简单的训练脚本示例,将会记录损失和精度:


import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 创建一些数据进行演示
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 转换为torch张量
tensor_x = torch.Tensor(X_train)
tensor_y = torch.Tensor(y_train).long()
tensor_x_test = torch.Tensor(X_test)
tensor_y_test = torch.Tensor(y_test).long()

# 创建数据加载器
train_dataset = TensorDataset(tensor_x, tensor_y)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TensorDataset(tensor_x_test, tensor_y_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 创建一个简单的模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 设置TensorBoard
writer = SummaryWriter()

# 训练模型
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # 将结果写入TensorBoard
    writer.add_scalar('training loss', running_loss / len(train_loader), epoch)
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    writer.add_scalar('accuracy', correct / total, epoch)

writer.close()
print('Finished Training')

# Now you can view the results by running in your terminal:
# tensorboard --logdir=runs

在这个脚本中,我们创建了一个简单的完全连接网络用于分类,并将训练过程中的损失和精度写入TensorBoard。要查看TensorBoard结果,保存并运行上面的脚本。然后,在终端中运行以下命令:

tensorboard --logdir=runs

打开显示的URL,将能够看见TensorBoard的仪表盘,反映出模型训练过程中记录的数据。

这段代码是一个使用PyTorch进行数据分类的示例,同时演示了如何将训练过程的信息记录到TensorBoard中。
1. 导入必要的库:
   这段代码首先导入了PyTorch相关的库,包括模型(layers)、优化(optimizer)、数据处理等组件。同时还导入了`SummaryWriter`用于向TensorBoard写入数据。
2. 生成和预处理数据:
   代码使用scikit-learn库中的`make_classification`函数生成了一个具有1000个样本、20个特征的合成分类数据集。然后,它使用`train_test_split`将数据集分割为训练集和测试集。标准化这些数据使得每个特征的分布均值为0,方差为1。
3. 准备PyTorch数据加载器:
   代码将处理好的数据转换为PyTorch张量,然后创建了`TensorDataset`数据集对象,最后通过`DataLoader`为训练和测试数据集创建迭代器,用于在训练过程中加载数据。
4. 定义简单的神经网络模型:
   定义了一个名为`SimpleNet`的神经网络类,它包含两个全连接层,第一个全连接层将20个特征映射到64个隐藏单元,接着是ReLU激活函数,最后一个全连接层将64个隐藏单元映射到2个输出(因为是二分类问题)。
5. 创建损失函数和优化器:
   使用交叉熵损失函数作为分类问题的损失函数,以及使用Adam优化器对模型的参数进行优化。
6. 设置TensorBoard日志记录器:
   初始化了`SummaryWriter`,这个对象将用于将训练过程中的信息写入日志文件,这些文件可以被TensorBoard读取并可视化。
7. 训练模型:
   在一个循环中,代码遍历了数据集多次(这里定义了10个epoch),在每个epoch中,设置模型为训练模式,并用数据加载器获取训练数据。对于每个批次的数据,执行前向传播,计算损失,执行反向传播和优化步骤。同时,汇总损失并在每个epoch结束时将平均损失记录到TensorBoard中。
8. 评估模型和记录准确率:
   在每个训练epoch之后,代码进入评估模式,并停止梯度计算,使用测试数据集计算模型预测的准确性,并将这个结果记录到TensorBoard中。
9. 关闭TensorBoard日志记录器并完成训练:
   训练结束后,会关闭`SummaryWriter`,此时训练生成的日志文件已经写入磁盘中`runs`目录下。打印完成训练的信息。
10. 查看TensorBoard中的结果:
    最后,通过命令行中运行`tensorboard --logdir=runs`来启动TensorBoard服务,并可以在浏览器中打开显示的URL来查看训练过程中记录的损失和准确率曲线。
另,with torch.no_grad():是用来停止PyTorch跟踪梯度信息。在测试模式下通常需要这么做,以减少内存消耗并加速计算。

三、可能出现的问题

由于`collections.Mapping`在Python 3.10及以后的版本已经被移除了,而应该使用`collections.abc.Mapping`。由于`tensorboard`的某些依赖库在较新版的Python中可能仍在使用已经废弃的模块路径,因此抛出了`ImportError`。
如果TensorBoard是独立于PyTorch环境外安装的,可能需要在一个PyTorch支持的Python环境中安装和运行TensorBoard。PyTorch目前支持的Python版本是3.6-3.9,如果环境中的Python版本是3.12,这有可能导致兼容性问题。
要解决这个问题,可以试图降低Python的版本,创建一个新的虚拟环境,安装一个TensorBoard版本,该版本与Python版本兼容,或等待或协助贡献TensorBoard对新Python版本的支持。下面是创建新虚拟环境并尝试安装TensorBoard的方法:
1. 创建新的Python环境(推荐使用Python 3.9)并激活它:

conda create -n new_env python=3.9
conda activate new_env

2. 在新环境中安装TensorBoard和PyTorch:

pip install tensorboard torch torchvision

3. 重新尝试启动TensorBoard:

tensorboard --logdir=runs

这种方式安装可能会避开遇到的兼容性问题。如果问题依旧存在,请考虑在相应的TensorBoard或相关依赖包的GitHub问题跟踪页面提交问题报告,以获取官方或社区的解决方案。在等待修复的同时,可以使用其他的Python版本,在那里TensorBoard是兼容的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/527956.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SSM党员管理系统

一、系统介绍 党员管理系统: 可以方便管理人员对党员管理系统的管理,提高信息管理工作效率及查询效率,有利于更好的为用户提供服务。 主要的模块包括: 1、后台功能: 管理员角色:首页、个人中心,党员管理…

“成像光谱遥感技术中的AI革命:ChatGPT在遥感领域中的应用“

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境,是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型,在理解和生成人类语言方面表现出了非凡的能力。本文重点介绍ChatGPT在遥感中的应用,人工智能…

Vue-B站学习笔记

1. 路由配置 B站视频之Vue route文件下的index.js app.vue

云原生:5分钟了解一下Kubernetes是什么

在当今的云计算时代,容器化技术变得越来越重要。它能够帮助开发者更高效地部署和管理应用程序。而Kubernetes,作为容器编排领域的领军者,正逐渐成为企业构建和管理云原生应用的核心工具。 近期将持续为大家分享Kubernetes相关知识&#xff…

78、WAF攻防——菜刀冰蝎哥斯拉流量通讯特征绕过检测反制感知

文章目录 菜刀冰蝎哥斯拉 本节主要讲上传了后门,要是用webshell连接工具进行连接,而webshell连接工具特征可能被检测系统识别相关内容。 web防护软件: WEB应用防火墙——WAF入侵检测系统——IDS、HIDS威胁感知——天眼 感知威胁技术&#x…

基于单片机数码管20V电压表仿真设计

**单片机设计介绍,基于单片机数码管20V电压表仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机数码管20V电压表仿真设计的主要目的是通过单片机和数码管显示电路实现一个能够测量0到20V直流电压的电…

如何设计系统容量?

单位每年都会举行运动会,有一个2000m长跑的项目,大约每年报名人员为男选手40人,女选手20人,只有一条橡胶跑道。一次比赛10人齐跑,所以至少需要6场比赛。 2000米的完成时间要求是20分钟,超过20分钟不计数&a…

世强硬创获德佑威授权代理,拓展UV胶粘剂/PUR热熔胶等产品布局

随着下游应用市场产品不断更新迭代,以及企业的环保意识提高,企业对电子胶粘材料的性能要求越来越高,从而推动上游原厂的技术创新与升级,为国内提供更多高性能国产胶粘材料。 基于优良的口碑,世强先进(深圳…

关于JVM-三色标记算法剖析

相关系列 深入理解JVM垃圾收集器-CSDN博客 深入理解JVM垃圾收集算法-CSDN博客 深入理解jvm执行引擎-CSDN博客 jvm优化原则-CSDN博客 jvm流程图-CSDN博客 三色标记产生的原因? 在并发标记的过程中,因为标记期间应用线程还在继续跑,对象间的引…

面试题:volatile

一旦一个共享变量(类的成员变量、类的静态成员变量)被volatile修饰之后,那么就具备了两层语义: 1. 保证线程间的可见性 保证了不同线程对这个变量进行操作时的可见性,即一个线程修改了某个变量的值,这新值…

Javascript进阶内容

1. 作用域 1.1 局部作用域 局部作用域分为函数作用域 和 块级作用域 块级作用域就是用 {} 包起来的,let、const声明的变量就是产生块作用域,var不会;不同代码块之间的变量无法互相访问,里面的变量外部无法访问 1.2 全局作用域…

安卓开机启动流程

目录 一、整体框架二、流程代码分析2.1 Boot ROM2.2 Boot Loader2.3 Kernel层Kernel代码部分 2.4 Init进程Init进程代码部分 2.5 zygote进程zygote代码部分 2.6 SystemServer进程SystemServer代码部分 2.7 启动Launcher与SystemUI 三、SystemServices3.1 引导服务3.2 核心服务3…

Linux|从 STDIN 读取 Awk 输入

简介 在之前关于 Awk 工具的系列文章中,主要探讨了如何从文件中读取数据。但如果你希望从标准输入(STDIN)中读取数据,又该如何操作呢? 在本文中,将介绍几个示例,展示如何使用 Awk 来过滤其他命令…

开创加密资产新纪元:深度解析ERC-314协议

随着加密资产市场的不断发展和区块链技术的日益成熟,新的协议和标准不断涌现,其中包括了ERC-314协议。本文将深入分析ERC-314协议的特点、功能以及对加密资产市场可能产生的影响。 1. ERC-314协议简介 ERC-314协议是一项建立在以太坊区块链上的新提案&a…

软件测试中的43个功能测试点总结

功能测试就是对产品的各功能进行验证,根据功能测试用例,逐项测试,检查产品是否达到用户要求的功能。针对web系统的常用测试方法如下: 1、页面链接检查: 每一个链接是否都有对应的页面,并且页面之间切换正…

设计模式之状态模式讲解

概念:又称为状态对象模式,该模式允许一个对象在其内部状态改变时改变其行为。状态模式的核心是封装,状态的变更引起行为的变动,从外部看来就好像该对象对应的类发生改变一样。 抽象状态:用以封装环境对象的一个特定状态…

thinkphp6使用阿里云SDK发送短信

使用composer安装sdk "alibabacloud/dysmsapi-20170525": "2.0.24"封装发送短信类 发送到的短信参数写在env文件里面的 #发送短信配置 [AliyunSms] AccessKeyId "" AccessKeySecret "" signName"" templateCode"&…

尚硅谷html5+css3(3)布局

1.文档流normal flow -网页是一个多层结构 -通过CSS可以分别为每一层设置样式 -用户只能看到最顶层 -最底层&#xff1a;文档流&#xff08;我们所创建的元素默认都是从文档流中进行排列&#xff09; <head><style>.box1 {background-color: blue;}/*它的父元…

精益管理培训在哪些行业比较适用?

在当今瞬息万变的市场环境中&#xff0c;企业竞争日趋激烈&#xff0c;如何提升内部管理水平、降低成本、提高效率&#xff0c;成为企业持续发展的关键。精益管理作为一种先进的管理理念和方法&#xff0c;正逐渐被越来越多的行业所采纳和应用。本文&#xff08;深圳天行健精益…

MSO7104A安捷伦MSO7104A示波器

181/2461/8938产品概述&#xff1a; 带宽:1 GHz通道:4个模拟通道和16个数字通道采样速率:4 GSa/s记录长度:标准8 Mpts MegaZoom III深内存垂直分辨率:8位自动量程和峰值检测有洞察力的应用软件分段存储器使用FFT的波形数学模拟高清电视/EDTV触发器 总线模式显示和简单的软件升…