使用 SwanLab 进行可视化 MNIST 手写体识别训练

使用 SwanLab 进行可视化 MNIST 手写体识别训练

在线演示demo

本案例主要:

  • 使用pytorch进行CNN(卷积神经网络)的构建、模型训练与评估
  • 使用swanlab跟踪超参数、记录指标和可视化监控整个训练周期

一、相关简介

SwanLab

SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。

SwanLab提供了一套云端AI实验跟踪方案,面向训练过程,提供了训练可视化、实验跟踪、超参数记录、日志记录、多人协同等功能,研究者能轻松通过直观的可视化图表找到迭代灵感,并且通过在线链接的分享与基于组织的多人协同训练,打破团队沟通的壁垒。

可视化界面截图:

在这里插入图片描述

MNIST

MNIST手写体识别是深度学习最经典的入门任务之一,由 LeCun 等人提出。
该任务基于MNIST数据集,研究者通过构建机器学习模型,来识别10个手写数字(0~9)。

二、环境配置

本案例基于Python>=3.8,请在您的计算机上安装好Python。
环境依赖:

torch
torchvision
swanlab

快速安装命令:

pip install torch torchvision swanlab

MNIST 数据集已经被 torch 自动集成了,所以不需要额外下载,很方便。

三、训练代码

复制以下代码,创建 app.py 并粘贴代码,保存后直接使用 python 或 IDE 运行:python app.py

import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import ResNet18_Weights
import swanlab

# CNN网络构建
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,28x28
        self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24
        self.conv2 = nn.Conv2d(10, 20, 3)  # 128, 10x10
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)  # 24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # 12
        out = self.conv2(out)  # 10
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out


# 捕获并可视化前20张图像
def log_images(loader, num_images=16):
    images_logged = 0
    logged_images = []
    for images, labels in loader:
        # images: batch of images, labels: batch of labels
        for i in range(images.shape[0]):
            if images_logged < num_images:
                # 使用swanlab.Image将图像转换为wandb可视化格式
                logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))
                images_logged += 1
            else:
                break
        if images_logged >= num_images:
            break
    swanlab.log({"MNIST-Preview": logged_images})
    

def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):
    model.train()
    # 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签
    for iter, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # 2. 传入到resnet18模型中得到预测结果
        outputs = model(inputs)
        # 3. 将结果和标签传入损失函数中计算交叉熵损失
        loss = criterion(outputs, labels)
        # 4. 根据损失计算反向传播
        loss.backward()
        # 5. 优化器执行模型参数更新
        optimizer.step()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),
                                                                      loss.item()))
        # 6. 每20次迭代,用SwanLab记录一下loss的变化
        if iter % 20 == 0:
            swanlab.log({"train/loss": loss.item()})

def test(model, device, val_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签
        for inputs, labels in val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # 2. 传入到resnet18模型中得到预测结果
            outputs = model(inputs)
            # 3. 获得预测的数字
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # 4. 计算与标签一致的预测结果的数量
            correct += (predicted == labels).sum().item()
    
        # 5. 得到最终的测试准确率
        accuracy = correct / total
        # 6. 用SwanLab记录一下准确率的变化
        swanlab.log({"val/accuracy": accuracy}, step=epoch)
    

if __name__ == "__main__":

    #检测是否支持mps
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    #检测是否支持cuda
    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # 初始化swanlab
    run = swanlab.init(
        project="MNIST-example",
        experiment_name="PlainCNN",
        config={
            "model": "ResNet18",
            "optim": "Adam",
            "lr": 1e-4,
            "batch_size": 256,
            "num_epochs": 10,
            "device": device,
        },
    )

    # 设置MNIST训练集和验证集
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])

    train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
    val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)
    
    # (可选)看一下数据集的前16张图像
    log_images(train_dataloader, 16)

    # 初始化模型
    model = ConvNet()
    model.to(torch.device(device))

    # 打印模型
    print(model)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=run.config.lr)

    # 开始训练和测试循环
    for epoch in range(1, run.config.num_epochs+1):
        swanlab.log({"train/epoch": epoch}, step=epoch)
        train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)
        if epoch % 2 == 0: 
            test(model, device, val_dataloader, epoch)

    # 保存模型
    # 如果不存在checkpoint文件夹,则自动创建一个
    if not os.path.exists("checkpoint"):
        os.makedirs("checkpoint")
    torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')

四、注意事项

在这里插入图片描述
在运行代码的时候,可能会出现如上提示,需要输入一个凭证,这个时候我们只需要去 SwanLab 云端版登录并获取,复制后粘贴到终端,回车后继续运行即可:

在这里插入图片描述

当然,有云端版肯定也有本地版。

上面的训练会将训练数据上传到云端,让我们可以直接通过在线链接的方式访问自己的实验数据和实验进度 。但是还可以选择不上传,而通过本地命令在本机开启一个面板服务,其前端界面与云端版基本一致,同样能查看实验数据和详细信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/649325.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue基础(数据绑定、export使用)

1、简介 在使用vue开发的过程中&#xff0c;经常会遇到一些容易混淆的问题&#xff0c;因此&#xff0c;在本文中进行汇总操作&#xff0c;只有通过不断总结学习&#xff0c;才能更好掌握vue的使用&#xff08;每天进步一点&#xff09;。 2、数据绑定 在js中定义数据&#xf…

三分钟一条AI小和尚视频 ,日引300+创业粉。单日变现四位数 全套工具

经过六个月的不懈努力和无数次的尝试错误&#xff0c;我终于找到了一个高效引流和积累粉丝的新策略&#xff0c;并愿意与大家无私分享。这一次&#xff0c;我将详尽地介绍这个方法&#xff0c;建议朋友们多次观看以彻底掌握其精髓。 简而言之&#xff0c;该策略主要依托于AI绘…

Spring 原理详解

1. Bean的作用域 Bean在Spring中表示的是Spring管理的对象&#xff0c;Bean的作用域是只Bean在Spring框架中的某种行为模式。 在Spring中&#xff0c;支持6中作用域&#xff1a; singleton&#xff1a;单例作用域&#xff0c;在整个 Spring IoC 容器中&#xff0c;只创建一个…

Json差异比较

json差异比较 如何比较两个json的差异 代码实现 导入依赖 <dependency><groupId>cn.xiaoandcai</groupId><artifactId>json-diff</artifactId><!-- 旧版本可能存在某些缺陷。版本请以maven仓库最版为准。 --><version>4.1.3-RC1-R…

没想到,一个小妙招让桌面运维效率翻倍

号主&#xff1a;老杨丨11年资深网络工程师&#xff0c;更多网工提升干货&#xff0c;请关注公众号&#xff1a;网络工程师俱乐部 我的网工朋友大家好。 咱们都知道&#xff0c;电脑用久了&#xff0c;总会出些小毛病&#xff0c;比如桌面图标不显示了&#xff0c;C盘又满了&a…

springboot+minio 文件上传

前期准备 需要先安装minio文件服务器&#xff0c;请参考我上一篇文章 pom.xml 版本 本次使用的是springboot2.7.16 版本&#xff0c; minio 版本是8.2.2 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-pare…

利用AI办公工具类API,大幅提高办公效率

AI办公工具类API是一项革命性的技术&#xff0c;利用人工智能的力量为办公场景提供了许多创新的解决方案。借助AI办公工具类API&#xff0c;用户可以实现自动化的文档处理、语音转文字、图像识别、数据分析等多种功能&#xff0c;大大提高了办公效率和工作质量。此外&#xff0…

Uni-App开发 导入(引入)Vant-Weapp组件;支持vue3/vue2版本和微信小程序

文章目录 目录 文章目录 操作流程 小结 概要安装流程技术细节小结 概要 Vant Weapp官网&#xff1a;Vant Weapp - 轻量、可靠的小程序 UI 组件库 准备工作&#xff0c;需要确保自己的电脑上已安装Hbuilde和node 全程操作的环境都需要这些配合才能运行上&#xff0c;可参考作者…

如何彻底搞懂组合(Composite)设计模式?

当我们在设计系统对象关系时&#xff0c;有时候会碰到这样一种场景&#xff0c;一个对象中包含了另一组对象&#xff0c;两者构成一种”部分-整体”的关联关系。 正如上图中所展示的&#xff0c;当我们面对这样一种对象关系时&#xff0c;通常都需要分别构建单独的访问方式&…

数据挖掘案例-航空公司客户价值分析

文章目录 1. 案例背景2. 分析方法与过程2.1 分析流程步骤2.2 分析过程1. 数据探索分析2. 描述性统计分析3. 分布分析1.客户基本信息分布分析2. 客户乘机信息分布分析3. 客户积分信息分布分析 4. 相关性分析 3. 数据预处理3.1 数据清洗3.2 属性约束3. 3 数据转换 4. 模型构建4. …

【面经】单片机

1、单片机IO口工作方式 输入 模拟输入&#xff08;GPIO_Mode_AIN&#xff09;&#xff1a;关闭施密特触发器&#xff0c;将电压信号传送到片上外设模块&#xff0c;通常用于连接模拟信号源。浮空输入&#xff08;GPIO_Mode_IN_FLOATING&#xff09;&#xff1a;在浮空输入状态…

回收站清空的文件怎么恢复?8个方法公开(2024更新版)

“我太粗心了&#xff0c;刚想恢复部分回收站中误删的重要文件&#xff0c;一不小心把回收站清空了&#xff0c;现在还有什么方法可以恢复它们吗&#xff1f;” 在数字时代&#xff0c;电脑已经成为我们日常生活和工作中不可或缺的工具。然而&#xff0c;随着我们对电脑的依赖加…

etcd 和 MongoDB 的混沌(故障注入)测试方法

最近在对一些自建的数据库 driver/client 基础库的健壮性做混沌&#xff08;故障&#xff09;测试, 去验证了解业务的故障处理机制和恢复时长. 主要涉及到了 MongoDB 和 etcd 这两个基础组件. 本文会介绍下相关的测试方法. MongoDB 中的故障测试 MongoDB 是比较世界上热门的文…

【算法】排序——加更

补充1个排序&#xff1a;希尔排序 思路&#xff1a;首先定义一个gap,从第0个数开始&#xff0c;每隔一个gap取出一个数&#xff0c;将取出来的数进行比较&#xff0c;方法类似插入排序。第二轮从第二个数开始&#xff0c;每隔一个gap取出一个数再进行插入排序。四轮就可以取完…

新手一次过软考高级(系统规划与管理师)秘笈,请收藏!

2024上软考已经圆满结束&#xff0c;距离下半年的考试也只剩下半年不到的时间。需要备考下半年软考高级的小伙伴们可以抓紧开始准备了&#xff0c;毕竟高级科目的难度可是不低的。 今天给大家整理了——系统规划与管理师的备考资料 &#xff0c;都是核心重点&#xff0c;有PDF&…

微博v14.5.1,集成猪手模块2.3.0-276,移除广告和各类推广提示

软件介绍 微博 v14.5.1&#xff0c;内置猪手模块直装版是一款专业优化的微消客户端&#xff0c;该软件融合了咸猪手模块&#xff0c;并提供了用户友好的自定义选项。这些选项包括广告移除、停止推荐内容、消除各类提示消息等功能&#xff0c;旨在提升用户的个性化使用体验。 …

最详细Linux提权总结(建议收藏)

1、内核漏洞脏牛提权 查看内核版本信息 uname -a 具体提权 1、信息收集配合kali提权 uname -a #查看内核版本信息 内核版本为3.2.78&#xff0c;那我们可以搜索该版本漏洞 searchsploit linux 3.2.78 找到几个可以使用的脏牛提权脚本&#xff0c;这里我使用的是40839.c脚…

Facebook广告如何开户以及投放费用?

Facebook作为全球最大的社交媒体平台之一&#xff0c;成为了企业与个人推广品牌、产品或服务的重要渠道。其精准的广告定向功能和庞大的用户基数&#xff0c;为广告主提供了无限的商机。云衔科技为企业提供专业的Facebook上开户和运营服务&#xff0c;助力您高效获客。 一、Fa…

【Spring Cloud】Feign整合服务容错中间件Sentinel

文章目录 引入sentinel依赖配置文件为被容错的接口指定容错类创建容错类修改controller演示扩展为被容错的接口更改容错类创建回退工厂类演示 总结 上一篇文章中我们已经对服务容错中间件 Sentinel 持久化的两种模式进行了全面解析&#xff0c;本文我们将对Feign和Sentinel进行…

学术图表的基本配色方法

不论是商业图表还是专业图表&#xff0c;图表的配色都极其关键。图表配色主要有彩色和黑白两种配色方案。刘万祥老师曾提出&#xff1a; “在我看来&#xff0c;普通图表与专业图表的差别&#xff0c;很大程度就体现在颜色运用上。” 对于科学图表&#xff0c;大部分国内的期…