Training - PyTorch Lightning 的 Horovod 策略实践 (all_gather)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://blog.csdn.net/caroline_wendy/article/details/137686312

Horovod

在 PyTorch Lightning 中使用 Horovod 策略,可以在多个 GPU 上并行训练模型。Horovod 是分布式训练框架,通过优化数据传输来提高多 GPU/CPU 训练的效率。要在 PyTorch Lightning 中使用 Horovod,需要在训练命令中指定 Horovod 作为策略。

  • PyTorch Lightning 源码:GitHub - pytorch-lightning
  • Horovod 策略的具体源码:pytorch_lightning.strategies.horovod

1. 构建 Docker 环境

首先,需要构建支持 MPI 运行的 Docker,安装 PyTorch Lightning 与 Horovod 的安装包,目前而言,PyTorch Lightning 的 2.+ 版本,以上,已经移除 Horovod 策略,需要降级至 1.8.6 版本,才支持 Horovod 策略,即:

pip install pytorch-lightning==1.8.6
pip install cmake==3.24.2 
pip install horovod==0.27.0

注意:horovod 安装之前,需要满足 cmake 版本,需要预先安装 cmake 包,否则报错:

File "/tmp/pip-install-qcugcd1u/horovod_a39ef0ac7a9e4940bc6b5969457a47f4/setup.py", line 88, in get_cmake_bin
	raise RuntimeError("Failed to install temporary CMake. "
RuntimeError: Failed to install temporary CMake. Please update your CMake to 3.13+ or set HOROVOD_CMAKE appropriately.

参考:StackOverflow - How to reinstall the latest cmake version?

验证 PyTorch 与 Horovod 是否安装成功:

python

import torch
print(torch.__version__)  # 1.13.1
print(torch.cuda.is_available())  # True

from horovod.torch import mpi_lib_v2 as mpi_lib
# pass

也可以使用 Horovod 策略补充工程,支持 PyTorch Lightning 的 2.+ 版本,参考 GitHub - lightning-Horovod

Horovod

启动 Docker:

nvidia-docker run -it --name [your name] -v /pfs_beijing:/pfs_beijing -v /nfs_beijing:/nfs_beijing -v /nfs_beijing_ai:/nfs_beijing_ai [your image]:[version]

上传 Docker 至服务器:

# 提交 Tag
docker ps -l
docker commit 20df5ad955bb [your image]:[version]

# 准备远程 Tag
docker tag [your image]:[version] [remote image]:[version]
docker images | grep [your image]

# 推送至远程
docker push [remote image]:[version]

2. 配置 Horovod 策略

固定随机种子,确保分布式的表现一致:

# 设置 seed 参数
if args.seed is not None:
    seed_everything(args.seed)
    logger.info(f"[CL] Using seed: {args.seed}")

配置 Horovod 环境变量 与 策略,即:

from pytorch_lightning.strategies import HorovodStrategy

os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
os.environ["HOROVOD_CACHE_CAPACITY"] = "0"
os.environ["OMPI_MCA_btl_vader_single_copy_mechanism"] = "none"
import horovod.torch as hvd
hvd.init()
torch.cuda.set_device(hvd.local_rank())
strategy = HorovodStrategy()

# Horovod 不需要设置,使用默认值
args.num_nodes = 1
args.gpus = None

logger.info(f"[CL] Using HorovodStrategy")

注意:Horovod 策略,在 pl.Trainer 中,不需要设置 num_nodesgpus,使用默认值,即 1 和 None。

具体的 pl.Trainer 配置 Horovod 策略,如下:

trainer = pl.Trainer(
    accelerator="gpu",
    # ...
    strategy=strategy,  # 多机多卡配置
    num_nodes=args.num_nodes,  # 节点数
    devices=args.gpus,  # 每个节点 GPU 卡数
)

3. 配置 Horovod 的 all_gather 实例

在 PyTorch Lightning 中,不推荐直接使用 torch.distributed.all_gather_object() 进行分布式数据汇集,建议在 pl.LightningModule 类中,直接调用 self.all_gather() 方法。

  • torch.distributed.all_gather_object() 的源码,参考 Doc - PyTorch
  • LightningModule.all_gather() 的源码,参考 Doc - Lighting 1.8.6
  • horovod.torch.allgather() 的源码,参考 Doc - Horovod

LightningModule 的 all_gather() 调用 Horovod 的 allgather() 函数,源码如下:

def all_gather(self, result: Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False) -> Tensor:
        if group is not None and group != dist_group.WORLD:
            raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

其中,torch.distributed.all_gather_object() 方法,报错如下:

horovod all_gather_object "Default process group has not been initialized, please make sure to call init_process_group.""

原因是,在 LightningModule 中,不推荐直接使用 torch.distributed 的方法,建议直接调用 LightningModule 的内部方法。

其中 all_gather 的源码修改示例,如下:

class ModelWrapper(pl.LightningModule):
  	
    def gather_log(self, log, world_size):
        if world_size == 1:
            return log

        # 异常代码,不建议直接调用 torch.distributed
        # log_list = [None] * world_size
        # torch.distributed.all_gather_object(log_list, log)
        # log = {key: sum([l[key] for l in log_list], []) for key in log}

        log_gather_map = self.all_gather(log)
        # logger.info(f"[CL] log: {log}")
        # logger.info(f"[CL] log_list_map: {log_gather_map}")

        log_parse_map = dict()
        for key in log_gather_map.keys():
            # [sample,num_node],例如 样本 3 个,Node 2个,[[1,2],[3,4],[5,6]]
            tmp_list = log_gather_map[key]
            for item in tmp_list:
                if isinstance(item, torch.Tensor):
                    item_cpu = item.detach().cpu()
                    item_x = item_cpu.numpy().tolist()
                    if key not in log_parse_map.keys():
                        log_parse_map[key] = []
                    # sum([[1,2],[3,4]], []) -> [1, 2, 3, 4]
                    log_parse_map[key] += item_x
                elif isinstance(item, str):
                    # val_name = ['7skh_B', '7vqk_A', '7vrf_A'],all_gather 问题
                    continue
        # logger.info(f"[CL] log_parse_map: {log_parse_map}")
        return log_parse_map
      
	# ...

日志输出,包括2个卡,每个卡的数据,all_gather之后,获得全部数据,如下:

# Worker 0, all_gather 之前:
[worker-0:163] [INFO] [CL] log: 
{
  'val_first_ref_rmsd': [30.974, 21.57, 18.238],
  # ...
}

# Worker 1, all_gather 之前:
[worker-1:163] [INFO] [CL] log: 
{
	'val_first_ref_rmsd': [27.358, 19.888, 32.003],
  # ...
}

# Worker 0, all_gather 之后:
[worker-0:163] [INFO] [CL] log_gather_map:
{
  'val_first_ref_rmsd': [
    tensor([30.9740, 27.4560], device='cuda:0'),
    tensor([21.5700, 19.6400], device='cuda:0'),
    tensor([18.2380, 31.5020], device='cuda:0')
  ],
  # ...
}

# 获得全部的6个样本数据:
[worker-1:163] [INFO] [CL] log_parse_map: 
{
	'val_first_ref_rmsd': [30.9740, 27.4560, 21.5700, 19.6400, 18.2380, 31.5020],
	# ...
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/556995.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux sudo suid提权练习

题目比较简单,可以利用sudo和多种suid程序提权,做个记录 进入靶场题目环境 获得节点信息 远程连接上 执行命令id,发现只是admin普通账户 sudo提权 发现存在 /usr/bin/vim, /usr/bin/bash, /usr/bin/more, /usr/bin/less, /usr/bin/nano, /…

CSS入门:link链接样式和4种状态的详解

你好,我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生,一枚程序媛,感谢关注。回复 “前端基础题”,可免费获得前端基础 100 题汇总,回复 “前端工具”,可获取 Web 开发工具…

React + 项目(从基础到实战) -- 第九期

实现分页 , LoadMore 上划加载更多功能效果 分页 page : 当前页 pageSize: 页面大小 自定义分页组件 组件传值 import {FC , useEffect, useState } from reactimport { useNavigate , useLocation ,useSearchParams} from react-router-dom;import { Pagination } from &quo…

每日两题3

礼物最大价值 class Solution { public:int jewelleryValue(vector<vector<int>>& frame) {int m frame.size(),n frame[0].size();vector<vector<int>> dp(m1,vector<int>(n1,0));for(int i 1; i < m;i){for(int j 1; j < n;j){d…

轻松点餐|餐饮小程序新玩法,美食触手可及

在企业经营领域&#xff0c;小程序正成为越来越多行业开展线上经营的重要工具。依托小程序等工具自主开发数字化经营平台&#xff0c;已经成为零售、餐饮等日常消费行业的趋势。餐饮行业向智能化快速迭代已势在必行&#xff0c;在此进程中&#xff0c;小程序成为了备受餐饮商家…

Mysql嵌套查询太简单了

1、子查询的分类 不相关查询&#xff1a; 子查询能独立执行 相关查询&#xff1a; 子查询不能独立运行 相关查询的执行顺序&#xff1a; 首先取外层查询中表的第一个元组,根据它与内层查询相关的属性值处理内层查询, 若WHERE子句返回值为真&#xff0c;则取此元组放入结果…

SpringBoot整合PDF动态填充数据并下载

目录 目录 一、准备环境 二、iTextPDF介绍 三、步骤 四、访问查看结果 五、源代码参考 一、准备环境 ①下载一个万兴pdf软件 ②准备一个pdf 文件 二、iTextPDF介绍 这是一个用于生成PDF文档的Java库&#xff0c; 文档创建与修改&#xff1a;iTextPDF能够从零开始创建…

2024红明谷杯——Misc 加密的流量

2024红明谷杯——Misc 加密的流量 写在前面&#xff1a; 这里是贝塔贝塔&#xff0c;照例来一段闲聊 打比赛但赛前一波三折&#xff0c;又是成功签到的一个比赛 说起来比赛全名叫红明谷卫星应用数据安全场景赛&#xff0c;但好像真的跟卫星的关系不大&#xff0c;没有bin方…

面试Spring框架

什么是Spring框架&#xff1f; Spring框架是一个开源的Java应用程序框架&#xff0c;提供了综合的基础设施支持&#xff0c;用于开发Java企业应用程序。它涵盖了从基本的核心容器到全面的企业服务&#xff0c;可以用于构建任何规模的应用程序。 Spring框架的核心特性是什么&am…

Go之map详解

map的结构 map实现的两个关键数据结构 hmap 定义了map的结构bmap 定义了hmap.buckets中每个bucket的结构 // A header for a Go map. type hmap struct {count int // 元素的个数flags uint8 // 状态标记&#xff0c;标记map当前状态&#xff0c;是否正在写入B …

<计算机网络自顶向下> 可靠数据传输的原理(未完成)

可靠数据传输&#xff08;rdt&#xff1a;Reliable Data Transfer&#xff09;的原理 rdt在应用层&#xff0c;传输层和数据链路层都很重要是网络TOP10问题之一信道的不可靠特点决定了可靠数据传输rdt的复杂性rdt_send: 被上层&#xff08;如应用层&#xff09;调用&#xff0…

41.缺失的第一个正数

1. 解题原理&#xff1a; &#xff08;1&#xff09;对于一个有序的、不缺失元素的正数数组nums&#xff0c;元素nums[i]应当位于nums[i]-1的位置处。 &#xff08;2&#xff09;nums数组的长度为N&#xff0c;缺失的第一个正数如果不位于[1,N]&#xff0c;那么就肯定是N1 2. …

excel表格怎么设置密码?excel文件加密的两个方法

一、加密码的原理​ Excel加密码的原理主要基于加密算法和密钥管理。当用户为Excel文件或工作表设置密码时&#xff0c;Excel会采用一种加密算法对文件或工作表进行加密处理。这种加密算法通常是对称加密算法&#xff0c;如AES(高级加密标准)或DES(数据加密标准)。 二&#x…

海外住宅代理:推特账号为何容易被关小黑屋?

推特是全球最受欢迎的社交媒体之一&#xff0c;每天都有数以百万计的用户在这个平台上发布信息、分享观点和交流互动。然而&#xff0c;有些用户可能会发现他们的推特账号不幸陷入了所谓的“关小黑屋”状态&#xff0c;即账号被限制了可见度&#xff0c;导致发布的内容无法被其…

【数据分析面试】24.20个数据库问答题 (考察数据开发和实际应用能力)

作为数据从业者&#xff0c;日常工作除了对各类业务数据进行分析挖掘&#xff0c;也需要经常和数据库打交道、甚至也少不了要承担一些数据开发、数仓管理的工作。掌握数据库管理的基本概念和技术是至关重要的。无论是初学者还是从业者&#xff0c;理解数据库索引、范式、事务、…

四.音视频编辑-音频混合-概述

引言 当我们在前两篇博客中成功地构建了一个媒体组合&#xff0c;并且略过了音频部分时&#xff0c;我们意识到了我们需要对这个项目进行更详细的探讨。在本篇博客中&#xff0c;我们将会展示如何创建一个包含视频轨道、配音音频轨道以及背景音频轨道的完整媒体组合。更进一步…

游泳耳机哪个牌子好?体验与口碑兼顾的4大游泳耳机汇总!

最近的天气越来越炎热了&#xff0c;许多人选择游泳作为一种既能锻炼身体又能享受清凉的活动。而随着科技的发展&#xff0c;越来越多的运动爱好者希望在游泳时也能享受到音乐的乐趣。因此&#xff0c;游泳耳机应运而生&#xff0c;成为市场上的热门产品。然而&#xff0c;面对…

项目中的解耦小能手-观察者模式

目录 1.使用场景 2.什么是观察模式 3.观察者模式结构图 4.代码实现案例 4.1 subject代码实现 4.2 Observer类代码实现 5. 回顾总结 1.使用场景 当一个对象的改变需要同事改变其他对象的时候&#xff0c;如&#xff1a;订单中心-下单成功需要通知库存、物流和积分去做相应…

交流回馈老化测试负载优点和应用

交流回馈老化测试负载是用于模拟真实环境下设备运行状态的测试工具&#xff0c;通过对设备进行长时间的连续工作&#xff0c;以检测其性能的稳定性和可靠性。这种测试负载具有许多优点&#xff0c;并且在实际应用中有着广泛的用途。 在实际应用中&#xff0c;设备往往需要在各种…

Flask实战

from flask import Flask appFlask(__name__)点击Flask同时点击键盘ctrl即可查看Flask的默认初始化函数 def __init__(self,import_name: str,static_url_path: str | None None,static_folder: str | os.PathLike[str] | None "static",static_host: str | None …