WandB使用笔记

最近看代码,发现代码中有wandb有关的内容,搜索了一下发现是一个模型训练工具,然后学习了一下,这里记录一下使用过程,方便以后查阅。

WandB使用笔记

  • 登录WandB 并 创建团队
  • 安装 WandB 并 登录
  • 模型训练过程跟踪
  • 模型版本管理
  • 自动调参
  • 不同的模型训练工具对比
  • 参考资料

作者自注:之前训练模型一直使用的是Visdom,感觉非常好用,然后现在学习了一下WandB,发现先各有优劣。Visdom的曲线实时跟踪效果好,但是功能简单。WandB曲线实时跟踪效果差(可能是我的网的问题),但是功能强大,可以保存每次模型调优的参数,这样就不用手动再记录了;可以实现模型的版本管理,这样就可以随便改代码,不用担心改坏了;可以进行参数分析,这样就可以有目的的进行参数调优;可以进行自动调参,这样可在完成粗调制后进行局部的参数寻优。感觉以后两个可以同时使用,提高模型调优的效率

登录WandB 并 创建团队

点击下面的网站进入WandB:https://wandb.ai/site,然后点击界面中的 LOGIN 进行登录。

在这里插入图片描述

如下需要选择登录的方式,这里我选择的是 GitHub 。

在这里插入图片描述

完成登陆后进入如下初始界面,点击图片中红框中的内容,创建一个新的 team

在这里插入图片描述

之后进入如下界面,输入团队名称,并点击 Create team ,完成团队的创建。

在这里插入图片描述

团队创建成功后出现如下界面,选择是否把自己的 runs 更新到 team ,这里选择 Update

在这里插入图片描述

如此就完成了登录和团建创建过程!

如果想要删除创建的团队,则在主界面点击创建的团队,如下图所示:

在这里插入图片描述

进入团队后,点击 Team settings ,如下图所示:

在这里插入图片描述

接着滑动到最下面,点击 Delete team

在这里插入图片描述

接着需要你输入 团队的名称 进行删除,这里的逻辑跟GitHub删除项目一样。

在这里插入图片描述

安装 WandB 并 登录

使用 pip 安装 WandB:

pip install wandb

在这里插入图片描述

验证安装是否成功:

wandb --version

在这里插入图片描述

首次使用 WandB 时,需要登录账户:

wandb login

在这里插入图片描述

登录后,WandB 会提示输入 API 密钥。可以从 WandB 的 API 密钥页面 获取密钥,点击图片中的红框部分,复制密钥,然后粘贴到上图的 3 标识的地方,并点击回车,如此就完成了登录过程。

在这里插入图片描述

如果你之前已经登陆过了,则会出现如下的内容:
在这里插入图片描述

然后在终端输入如下的命令即可重新登录:

wandb login --relogin

在这里插入图片描述

模型训练过程跟踪

将如下代码复制到PyCharm中,进行实验。

import wandb
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(
    project_name='wandb_demo',

    batch_size=512,

    hidden_layer_width=64,
    dropout_p=0.1,

    lr=1e-4,
    optim_type='Adam',

    epochs=150,
    ckpt_path='checkpoint.pt'
)


def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/", train=True, download=True, transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/", train=False, download=True, transform=transform)

    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train = torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True, drop_last=True)
    dl_val = torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, drop_last=True)
    return dl_train, dl_val


def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=config.hidden_layer_width, kernel_size=3))
    net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv2", nn.Conv2d(in_channels=config.hidden_layer_width,
                                      out_channels=config.hidden_layer_width, kernel_size=5))
    net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("dropout", nn.Dropout2d(p=config.dropout_p))
    net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))
    net.add_module("flatten", nn.Flatten())
    net.add_module("linear1", nn.Linear(config.hidden_layer_width, config.hidden_layer_width))
    net.add_module("relu", nn.ReLU())
    net.add_module("linear2", nn.Linear(config.hidden_layer_width, 10))
    net.to(device)
    return net


def train_epoch(model, dl_train, optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features, labels = batch
        features, labels = features.to(device), labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds, labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    return model


def eval_epoch(model, dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features, labels = batch
        features, labels = features.to(device), labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds = (predictions == labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc
def train(config=config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config);
    optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
    # ======================================================================
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
    model.run_id = wandb.run.id
    # ======================================================================
    model.best_metric = -1.0
    for epoch in range(1, config.epochs + 1):
        model = train_epoch(model, dl_train, optimizer)
        val_acc = eval_epoch(model, dl_val)
        if val_acc > model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(), config.ckpt_path)
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
        # ======================================================================
        wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
        # ======================================================================
    # ======================================================================
    wandb.finish()
    # ======================================================================
    return model

上述代码最关键的就是如下三个部分:

  1. 初始化部分:
wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
  1. 模型训练参数上传
wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
  1. 模型训练完成关闭wandb:
wandb.finish()

最后在PyCharm中输入如下代码,即可运行上述代码:

model = train(config)

代码运行成功,即可出现如下的界面,点击下图中红框中的部分,即可跳转到曲线监视界面。

在这里插入图片描述

模型训练过程监视界面如下图所示:

在这里插入图片描述

点击下图中的红框部分,更改曲线的横坐标值。

在这里插入图片描述

如下图所示,将横坐标值更改为 epoch。

在这里插入图片描述

然后我们还可以增加一个 section

在这里插入图片描述

在新的 section 中添加新的显示模块,如下图所示:

在这里插入图片描述

此处我们添加了验证集的准确率,实现实时的监控。

在这里插入图片描述

模型训练结束,我们可以点击 runs 查看历史记录。

在这里插入图片描述

如下图可以看到,我们刚才监视的曲线,如图中的长方形红框所示。然后点击小红框中的 runs ,查看每一次训练过程的模型参数。

在这里插入图片描述

每一次模型训练的参数如下图所示,可以选择图中红框中的内容,选择需要的参数进行显示。

在这里插入图片描述

可选择的指标如下图所示:

在这里插入图片描述

对于某些我们比较关注的指标,我们可以将其固定显示:

在这里插入图片描述

固定后,我们回到 Workspace 界面,即可看到固定的参数。

在这里插入图片描述

模型版本管理

除了可以记录实验日志传递到 wandb 网站的云端服务器 并进行可视化分析。wandb还能够将实验关联的数据集,代码和模型 保存到 wandb 服务器。我们可以通过 wandb.log_artifact的方法来保存任务的关联的重要成果。例如 dataset, code,和 model,并进行版本管理。

当我们跑出一个相对不错的结果时,我们希望把这个结果给保存下来,此时我们就可以使用该功能。

我们先使用run_id 恢复 run任务,以便继续记录。

import wandb
# resume the run
run = wandb.init(project='wandb_demo', id='6h5xkv16', resume='allow')

上述代码中的 id 是用来关联我们训练的 runs 的,参数的值来自下图红框中的内容,想搞关联某一次的训练过程,就把某一次训练的 ID 写入上述代码。

在这里插入图片描述
保存数据集的代码:

# save dataset
arti_dataset = wandb.Artifact(name='mnist', type='dataset')
arti_dataset.add_dir('mnist/')
wandb.log_artifact(arti_dataset)

保存模型文件的代码:

# save code
arti_code = wandb.Artifact(name='py', type='code')
arti_code.add_file('./wandb_test.py')
wandb.log_artifact(arti_code)

保存模型权重的代码:

# save model
arti_model = wandb.Artifact(name='cnn', type='model')
arti_model.add_file(config.ckpt_path)
wandb.log_artifact(arti_model)

最后结束时要使用一下代码:

# finish时会提交保存
wandb.finish()

上传后的效果如图所示:

在这里插入图片描述

自动调参

sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

在这里插入图片描述

使用Sweep的3步骤:

  1. 配置 sweep_config
# 配置 Sweep config
sweep_config = {
    'method': 'random',  # 选择调优算法,超参数搜索方法:随机搜索
    'metric': {          # 定义调优目标
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {     # 定义超参空间
        'project_name': {'value': 'wandb_demo'},    # 固定不变的超参
        'epochs': {'value': 10},
        'ckpt_path': {'value': 'checkpoint.pt'},

        'optim_type': {                             # 离散型分布超参
            'values': ['Adam', 'SGD', 'AdamW']
        },
        'hidden_layer_width': {
            'values': [16, 32, 48, 64, 80, 96, 112, 128]
        },

        'lr': {                                     # 连续型分布超参
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 0.1
        },
        'batch_size': {
            'distribution': 'q_uniform',
            'q': 8,
            'min': 32,
            'max': 256,
        },
        'dropout_p': {
            'distribution': 'uniform',
            'min': 0,
            'max': 0.6,
        }
    },
    # 'early_terminate': {    # 定义剪枝策略 (可选)
    #     'type': 'hyperband',    # 使用 HyperBand 作为早停策略
    #     'min_iter': 3,          # 最小评估迭代次数(第 3 次迭代后开始考虑剪枝)
    #     'eta': 2,               # 成倍增长的资源分配比例(每次迭代中仅保留约 1/eta 的实验)
    #     's': 3                  # HyperBand 的最大阶数,影响资源分配的层级
    # }
}
from pprint import pprint
pprint(sweep_config)

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes.
创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

  1. 初始化 sweep controller
# 初始化 sweep controller
sweep_id = wandb.sweep(sweep_config, project=config.project_name)
  1. 启动 sweep agents
# 启动 Sweep agent
# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

等代码跑完我们就有了一个 sweep,如下图所示:

在这里插入图片描述

进入 sweep 之后就可以添加 Parallel coordinatesParameter importance 进行参数分析。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

不同的模型训练工具对比

工具实验管理数据版本控制模型部署团队协作离线支持特点
TensorBoard轻量级工具,适合快速原型开发
WandB功能全面,支持超参数调优和实时协作
Comet简单易用,支持离线模式
MLflow实验管理与模型部署一体化
Neptune强大的可视化功能
Sacred极简实验管理工具
Polyaxon分布式训练与大规模实验管理支持
DVC专注于数据和模型版本控制
ClearML全面的 MLOps 功能

参考资料

30分钟吃掉wandb模型训练可视化

wandb我最爱的炼丹伴侣操作指南

30分钟吃掉wandb可视化自动调参

wandb可视化调参完全指南

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/948445.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一文理解ssh,ssl协议以及应用

在使用基于密钥的认证方式的时候,私钥的位置一定要符合远程服务器规定的位置,否则找不到私钥的位置会导致建立ssh连接失败 SSH 全称是 “Secure Shell”,即安全外壳协议。 它是一种网络协议,用于在不安全的网络中安全地进行远程登…

Elasticsearch 创建索引 Mapping映射属性 索引库操作 增删改查

Mapping Type映射属性 mapping是对索引库中文档的约束,有以下类型。 text:用于分析和全文搜索,通常适用于长文本字段。keyword:用于精确匹配,不会进行分析,适用于标签、ID 等精确匹配场景。integer、long…

【Ubuntu】 Ubuntu22.04搭建NFS服务

安装NFS服务端 sudo apt install nfs-kernel-server 安装NFS客户端 sudo apt install nfs-common 配置/etc/exports sudo vim /etc/exports 第一个字段:/home/lm/code/nfswork共享的目录 第二个字段:指定哪些用户可以访问 ​ * 表示所有用户都可以访…

【谷歌开发者月刊】十二月精彩资讯回顾,探索科技新可能

我们在今年的尾声中回顾本月精彩,开发者们借助创新技术为用户打造温暖的应用体验,展现技术与实用的结合。欢迎您查阅本期月刊,掌握最新动态。 本月看点 精彩看点多多,请上下滑动阅览 01DevFest 北京站和上海站圆满举办&#xff0c…

浙江中医药大学携手云轴科技ZStack荣获“鼎信杯”金鼎实践奖

近日,2024“鼎信杯”信息技术发展论坛(以下简称“论坛”)在北京隆重召开。本次论坛汇聚多位领导和专家,以及业内骨干企业、研究机构、用户单位、行业组织代表等500余人,共同探讨信息技术应用创新产业趋势,分…

嵌入式linux系统中CMake的基本用法

第一:CMake的基本使用 在上篇文章中,我们聊了聊 Makefile。虽然它是 C/C++ 项目编译的“老司机”,但写起来真的是让人头大。尤其是当项目文件一多,手写依赖就像在搬砖,费时又费力。 那么问题来了,难道我们就没有更优雅的工具了吗?答案是:有! 这时候,CMake 就像一个…

vulnhub Earth靶机

搭建靶机直接拖进来就行 1.扫描靶机IP arp-scan -l 2.信息收集 nmap -sS -A -T4 192.168.47.132 得到两个DNS; 在443端口处会让我们加https dirb https://earth.local/ dirb https://terratest.earth.local/ #页面下有三行数值 37090b59030f11060b0a1b4e0000000000004312170a…

【AI日记】25.01.04 kaggle 比赛 3-3 | 王慧玲与基层女性

【AI论文解读】【AI知识点】【AI小项目】【AI战略思考】【AI日记】 工作 参加:kaggle 比赛 Forecasting Sticker Sales时间:6 小时 读书 书名:基层女性时间:3 小时原因:虽然我之前就知道这个作者,因为我…

《learn_the_architecture_-_aarch64_exception_model》学习笔记

1.当发生异常时,异常级别可以增加或保持不变,永远无法通过异常来转移到较低的权限级别。从异常返回时,异常级别可能会降低或保持不变,永远无法通过从异常返回来移动到更高的权限级别。EL0级不进行异常处理,异常必须在比…

linux上安装MySQL教程

1.准备好MySQL压缩包,并进行解压 tar -xvf mysql-5.7.28-1.el7.x86_64.rpm-bundle.tar -C /usr/local 2.检查是否有mariadb数据库 rpm -aq|grep mariadb 关于mariadb:是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可 MariaDB的目…

量子力学复习

黑体辐射 热辐射 绝对黑体: (辐射能力很强,完全的吸收体,理想的发射体) 辐射实验规律: 温度越高,能量越大,亮度越亮 温度越高,波长越短 光电效应 实验装置&#xf…

OSI模型的网络层中产生拥塞的主要原因?

( 1 )缓冲区容量有限;( 1.5 分) ( 2 )传输线路的带宽有限;( 1.5 分) ( 3 )网络结点的处理能力有限;( 1 分…

Spring Boot 的自动配置,以rabbitmq为例,请详细说明

Spring Boot 的自动配置特性能够大大简化集成外部服务和组件的配置过程。以 RabbitMQ 为例,Spring Boot 通过 spring-boot-starter-amqp 提供了自动配置支持,开发者只需在应用中添加相关依赖并配置必要的属性,Spring Boot 会自动配置所需的连…

visual studio 安全模式

一、安全模式: 在 Visual Studio 中,安全模式是一种启动方式,允许你在禁用所有扩展和自定义设置的情况下启动 Visual Studio。这个模式可以帮助排除插件或扩展引起的问题,特别是在 Visual Studio 无法正常启动时。 二、安全模式下…

使用SSH建立内网穿透,能够访问内网的web服务器

搞了一个晚上,终于建立了一个内网穿透。和AI配合,还是得自己思考,AI配合才能搞定,不思考只依赖AI也不行。内网服务器只是简单地使用了python -m http.server 8899,但是对于Gradio建立的服务器好像不行,会出…

服务器信息整理:用途、操作系统安装日期、设备序列化、IP、MAC地址、BIOS时间、系统

文章目录 引言I BIOS时间Windows查看BIOS版本安装日期linux查看BIOS时间II 操作系统安装日期LinuxWindowsIII MAC 地址IV 设备序列号Linux 查看主板信息知识扩展Linux常用命令引言 信息内容:重点信息:用途、操作系统安装日期、设备序列化、IP、MAC地址、BIOS时间、系统 Linux…

java项目之读书笔记共享平台(源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的闲一品交易平台。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 读书笔记共享平台的主要使…

【信息系统项目管理师】【综合知识】【备考知识点】【思维导图】第十一章 项目成本管理

word版☞【信息系统项目管理师】【综合知识】【备考知识点】第十一章 项目成本管理 移动端【思维导图】☞【信息系统项目管理师】【思维导图】第十一章 项目成本管理

1、单片机寄存器-io输入实验笔记

1、硬件 时钟总线如下: PB端口挂载在AHB1总线上,因此要对该位进行使能。 引脚 LED0和LED1挂载在PB0和PB1上:推挽输出、100M、 上拉默认高电平,低电平点亮。 2、软件 位带操作 #ifndef _IO_BIT_H_ #define _IO_BIT_H_#define …

【嵌入式硬件】嵌入式显示屏接口

数字显示串行接口(Digital Display Serial Interface) SPI 不过多赘述。 I2C-bus interface 不过多赘述 MIPI DSI MIPI (Mobile Industry Processor Interface) Alliance, DSI (Display Serial Interface) 一般用于移动设备,下面是接口…