分布式执行引擎ray入门--(3)Ray Train

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # model.to("cuda")  # This is done by `prepare_model`
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpoint

 def train_func(config):

     ...
     torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)

     ...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/446092.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言学习--练习4(二维数组)

目录 1.统计有序数组中的负数 2.矩阵对角线元素和 3.最富有客户的资产总量 4.托普利兹矩阵 5.矩阵中的幸运数 6.二进制矩阵中的特殊位置 7.岛屿的周长 1.统计有序数组中的负数 //直接遍历二维数组即可 int countNegatives(int** grid, int gridSize, int* gridColSize) …

大厂面试-好未来一面算法之求最长无重复子串长度

目录指引 大厂面试-好未来一面算法之求最长无重复子串长度本文学习目标或巩固的知识点 3. 无重复字符的最长子串🟡🟢通过题目可知题解结果验证 大厂面试-好未来一面算法之求最长无重复子串长度 本文学习目标或巩固的知识点 学习如何处理经典题目《最长…

Java 中的 File 类常用方法介绍

Java 中的 File 类是 java.io 包的一部分,它提供了丰富的文件操作方法。File 类可以用来表示文件和目录路径名的抽象表示形式,即它可以用来获取文件或目录的属性,也可以用来创建、删除、重命名文件和目录。下面是一些常用的 File 类方法&…

OSPF Router-ID 实验简述

如果要运行OSPF协议,必须存在Router ID。Router ID是一个32比特无符号整数,是一台路由器在自治系统中的唯一标识。 在实际网络部署中,建议手工配置OSPF的Router ID,因为这关系到协议的稳定。 如果不指定,启运OSPF进程后…

【Web开发】深度学习HTML(超详细,一篇就够了)

💓 博客主页:从零开始的-CodeNinja之路 ⏩ 收录文章:【Web开发】深度学习html(超详细,一篇就够了) 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 HTML1. HTML基础1.1 什么是HTML1.2 认识HTML标签1.3 HTML文件基本…

HTML使用

文章目录 一、简介二、HTML快速入门三、基础标签四、图片、音频、视频标签五、超链接标签六、列表标签七、表格标签八、布局标签九、表单标签十、表单向标签 一、简介 二、HTML快速入门 ​ <html><head><title>你好</title></head><body>再…

Springboot整合Mybaits启动过程

Springboot整合Mybaits启动过程 1.前言2.MybatisAutoConfiguration3.SqlSessionFactoryBean3.1 XMLConfigBuilder.parse()3.1.1 XMLMapperBuilder.parse()3.1.1.1 XMLStatementBuilder.parse() 4.SqlSession4.1 Executor 1.前言 直接加载mybatis配置文件&#xff0c;然后创建S…

代码训练LeetCode(7)删除有序数组中的重复项

代码训练(7)LeetCode之删除有序数组中的重复项 Author: Once Day Date: 2024年3月10日 漫漫长路&#xff0c;才刚刚开始… 全系列文章可参考专栏: 十年代码训练_Once-Day的博客-CSDN博客 参考文章: 26. 删除有序数组中的重复项 - 力扣&#xff08;LeetCode&#xff09;力扣…

利用websocket +定时器简易的实现一个网络聊天室

其实原理非常简单,就是客户端用户通过websoket来连接websocket服务端。然后服务端,收集每个用户发出的消息, 进而将每条用户的消息通过广播的形式推送到每个连接到服务端的客户端。从而实现用户的实时聊天。 // TODO : 我主要是讲一下实现思路。并未完善其功能。 1.后端 依赖 …

2024年【电工(初级)】考试内容及电工(初级)考试报名

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 电工&#xff08;初级&#xff09;考试内容根据新电工&#xff08;初级&#xff09;考试大纲要求&#xff0c;安全生产模拟考试一点通将电工&#xff08;初级&#xff09;模拟考试试题进行汇编&#xff0c;组成一套电…

线程的魔法:揭开现代操作系统并发执行的面纱

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

elasticsearch篇:RestClient操作

1. RestClient ES官方提供了各种不同语言的客户端&#xff0c;用来操作ES。这些客户端的本质就是组装DSL语句&#xff0c;通过http请求发送给ES。官方文档地址&#xff1a;Elasticsearch Clients | Elastic 其中的Java Rest Client又包括两种&#xff1a; Java Low Level Res…

解读BOT攻击,探索灵活高效的防护之道

回顾早期的互联网应用&#xff0c;由于业务流量比较小&#xff0c;往往单台服务器就能满足负载需求。随着互联网的流量越来越大&#xff0c;单服务器已经不能满足业务需求&#xff0c;无论它优化得再好&#xff0c;都较难承受大量的访问压力。支持负载均衡的技术很多&#xff0…

openssl3.2 - exp - 选择最好的内建椭圆曲线

文章目录 openssl3.2 - exp - 选择最好的内建椭圆曲线概述笔记将 openssl ecparam -list_curves 实现迁移到自己的demo工程备注END openssl3.2 - exp - 选择最好的内建椭圆曲线 概述 在openssl中使用椭圆曲线, 只允许选择椭圆曲线的名字, 无法给定椭圆曲线的位数. 估计每种椭…

扩展学习|系统理解数字经济

文献来源&#xff1a;[1]肖静华,胡杨颂,吴瑶.成长品&#xff1a;数据驱动的企业与用户互动创新案例研究[J].管理世界,2020,36(03):183-205.DOI:10.19744/j.cnki.11-1235/f.2020.0041. [2]陈晓红,李杨扬,宋丽洁等.数字经济理论体系与研究展望[J].管理世界,2022,38(02):208-22413…

浅谈JUC的理解(含JUC知识体系图)

浅谈JUC的理解 一、前言感悟二、并发知识三、一年前回答四、补充体系回答五、补充层次回答六、碎碎念 本文除了说技术&#xff0c;更多的是在一个两年多开发经验的程序员视角下&#xff0c;记录下自己探索到的世界。 如有不妥之处&#xff0c;还请指正。共勉。 一、前言感悟 当…

力扣hot100:239.滑动窗口最大值(优先队列/单调队列)

本题是一个经典的单调队列题。不过用优先队列也能解决。 一、优先队列 在使用优先队列时&#xff0c;我们会遇到这样的问题&#xff1a;如何将一个目标数从优先队列中弹出&#xff1f;如果使用stl这是办不到的&#xff0c;虽然可以自行实现这样的功能。但是我们可以这样思考&am…

什么是GoogLeNet,亮点是什么,为什么是这个结构?

GooLeNet 亮点 最明显的亮点就是引入了Inception&#xff0c;初衷是多卷积核增加特征的多样性&#xff0c;提高泛化能力 &#xff0c;比如&#xff0c;最下边是一个输入层&#xff0c;然后这个输入分别传递给1*1&#xff0c;3 * 3 &#xff0c;5 * 5和一个最大池化层&#xff…

IP数据报格式

每一行都由32位比特&#xff0c;即4个字节组成&#xff0c;每个格子称为字段或者域。IP数据报由20字节的固定部分和最大40字节的可变部分组成。 总长度 总长度为16个比特&#xff0c;该字段的取值以字节为单位&#xff0c;用来表示IPv4数据报的长度(首部长度数据载荷长度)最大…

Long-term Correlation Tracking LCT 目标跟踪算法源码运行

资源 LCT-tracker项目地址VLFeat官网OpenCV下载地址OTB50数据集百度网盘资源 参考博客 一步一步教你跑lct-tracker&#xff08;Win10Matlab 2016bVisual Studio 2015&#xff09;LCT代码跑起来先文章思路总结 正文 1. 环境配置 我的环境&#xff1a;Win11、Visual Studio…