深度学习:Pytorch分布式训练

深度学习:Pytorch分布式训练

  • 简介
  • 模型并行
  • 数据并行
  • 参考文献

简介

在深度学习领域,模型越来越庞大、数据量不断增加,训练这些大型模型越来越耗时。通过在多个GPU或多个节点上并行地训练模型,我们可以显著减少训练时间。此外,某些模型因为巨大的参数量,单个设备可能无法容纳其整个模型和数据。在这种情况下,分布式训练不仅能提高训练速度,更是必要的手段来训练大模型。为此,PyTorch 分布式训练提供了两种基本的并行方法:

  • 模型并行(Model Parallel):模型并行是指将模型的不同部分放到不同的设备上。这种方式通常用于当一个单独的模型太大而无法放到单个GPU上时。

  • 数据并行(Data Parallel):数据并行是将训练数据分割并在多个设备上同时训练的方法。PyTorch提供了 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel 用于在多个GPU上并行化模型训练。

模型并行

在这里插入图片描述

模型并行主要利用to(device)函数将模型和数据(Tensor张量)放置在适当设备上,其余代码基本无需额外改动。
以下是一个简单的模型并行的代码示例:

import torch
import torch.nn as nn
import torch.optim as optim


class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

model = DemoModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

注意调用损失函数时,您只需要确保标签与输出位于同一设备上。不难看出,此模型并行的方法效率相对较低,因为在任何时间点,两个 GPU 中只有一个在工作,而另一个则处于闲置状态。而且中间过程变量从cuda:0复制到cuda:1,又会需要额外的开销。因此可以引入流水线并行来进行加速。

在以下代码示例中,采取将输入数据批次划分为 20 组。由于 PyTorch 异步启动 CUDA 操作,因此可以不需要生成多个线程来实现并发。值得注意的是,使用较小的结果split_size会导致许多微小的 CUDA 内核启动,而使用较大的split_size会导致在第一次和最后一次数据划分期间存在相对较长的空闲时间。因此split_size对于特定实验可能有一个最佳配置,可以多次尝试最佳的超参数。

class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. ``s_prev`` runs on ``cuda:1``
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. ``s_next`` runs on ``cuda:0``, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)

数据并行

在这里插入图片描述

DataParallel是单进程、多线程,仅适用于单机,而是DistributedDataParallel多进程,适用于单机和多机训练。由于跨线程的 GIL 争用、每次迭代复制模型以及分散输入和收集输出带来的额外开销,DataParallel通常比DistributedDataParallel在单台机器上更慢。

一般地,数据并行的流程为:

  1. 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时初始化 distributed 包。
  2. 如果需要进行组内集体通信,用 new_group 创建子分组
  3. 创建分布式并行模型 DDP(model, device_ids=device_ids)
  4. 为数据集创建 Sampler
  5. 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练
  6. 使用 destory_process_group() 销毁进程组

以下是一个简单的数据并行的代码示例:

# demo_ddp.py
# 在init_process_group()时,一般可设置为Gloo、NCCL或mpi后端,Gloo目前在GPU上运行速度比 NCCL慢。所以经验法则是:
# 分布式GPU训练使用 NCCL 后端
# 分布式CPU训练使用 Gloo 后端

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = DemoModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    demo_basic()

然后使用torchrun命令进行启动,其中,nnodes表示总节点数,nproc_per_node表示每个节点运行的进程数,rdzv_id表示用户定义的ID,唯一标识作业的工作组, rdzv_backend表示集合点的后端,rdzv_endpoint表示rendezvous后端运行的地址

# 需要应用 slurm 等集群管理工具来实际在 2 个节点上运行此命令。
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 demo_ddp.py

此命令表示在两台服务器上运行 DDP 脚本,每台服务器运行 8 个进程,即在 16 个 GPU 上运行。

参考文献

  1. https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
  2. https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  3. https://medium.com/deelvin-machine-learning/model-parallelism-vs-data-parallelism-in-unet-speedup-1341bc74ff9e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/560654.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python学习 | 我有两个dataframe,想通过某1列进行匹配

需求 我有两个dataframe,第1个dataframe A的columns是[‘id’, ‘A’, ‘B’, ‘C’],第2个dataframe B的columns是[‘id’, ‘1’, ‘2’, ‘3’],其中’id’列A是B的子集,我想通过’id’列进行匹配,把A给扩充成[‘i…

js 特定索引下拆分字符串并组建成新的字符串数据

要在特定索引处拆分字符串,请使用 slice 方法获取字符串的两个部分,例如 str.slice(0, index) 返回字符串的一部分,但不包括提供的索引,而 str.slice(index) 返回字符串的其余部分。 过程:我们创建一个可重用的变量&a…

「GO基础」目录

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

删除word中下划线的内容

当试卷的题目直接含答案,不利用我们刷题。这时如果能够把下划线的内容删掉,那么将有利于我们复习。 删除下划线内容的具体做法: ①按ctrl H ②点格式下面的字体 ③选择下划线线型中的_____ ④勾选使用通配符并在查找内容中输入"?&qu…

Linux 序列化、反序列化、实现网络版计算器

目录 一、序列化与反序列化 1、序列化(Serialization) 2、反序列化(Deserialization) 3、Linux环境中的应用实例 二、实现网络版计算器 Sock.hpp TcpServer.hpp Jsoncpp库 Protocol.hpp 类 Request 类 Response 辅助函…

SQLite数据库中JSON 函数和运算符(二十七)

返回:SQLite—系列文章目录 上一篇:维护SQLite的私有分支(二十六) 下一篇:SQLite—系列文章目录 ​1. 概述 默认情况下,SQLite 支持 29 个函数和 2 个运算符 处理 JSON 值。还有两个表值函数可用于分解 JSON 字…

程序员自由创业周记#32:新产品构思

程序员自由创业周记#32:新产品构思 新作品 我时常把自己看做一位木匠,有点手艺,能做一些作品养活自己。而 加一、Island Widgets、Nap 就是我的作品。 接下来在持续维护迭代的同时,要开启下一个作品的创造了。 其实早在2022的1…

【C++初阶】List使用特性及其模拟实现

1. list的介绍及使用 1.1 list的介绍 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。 2. list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前…

(2022级)成都工业学院数据库原理及应用实验五: SQL复杂查询

写在前面 1、基于2022级软件工程/计算机科学与技术实验指导书 2、成品仅提供参考 3、如果成品不满足你的要求,请寻求其他的途径 运行环境 window11家庭版 Navicat Premium 16 Mysql 8.0.36 实验要求 在实验三的基础上完成下列查询: 1、查询医生…

漆包线行业你了解多少?专业漆包线行业MES生产管理系统

今天就说说漆包线行业,漆包线是工业电机(包括电动机和发电机)、变压器、电工仪表、电力及电子元器件、电动工具、家用电器、汽车电器等用来绕制电磁线圈的主要材料。 漆包线上游是铜杆行业,下游是各种消费终端,主要是电…

[大模型]Qwen-Audio-chat FastApi 部署调用

Qwen-Audio-chat FastApi 部署调用 Qwen-Audio 介绍 Qwen-Audio 是阿里云研发的大规模音频语言模型(Large Audio Language Model)。Qwen-Audio 可以以多种音频 (包括说话人语音、自然音、音乐、歌声)和文本作为输入,并以文本作为…

Linux-用户管理类命令实训

查看根目录下有哪些内容 进入/tmp目录,以自己的学号建一个目录,并进入该目录 像是目前所在的目录 在当前目录下,建立权限为741的目录test1 在目录test1下建立目录test2/test3/test4 进入test2,删除目录test3/test4 (7&…

Python实现在线翻译工具

Python实现在线翻译工具 使用Python的内置的标准库tkinter和webbrowser,实现一个简单Python在线翻译工具。 tkinter库用来创建一个图形用户界面(GUI),webbrowser库用来打开网页。 webbrowser 是 Python 的一个标准库&#xff0…

7.MMD 法线贴图的设置与调教

前期准备 人物 导入温迪模型导入ray.x和ray_controler.pmx导入天空盒time of day调成模型绘制顺序,将天空盒调到最上方给温迪模型添加main.fx材质在自发光一栏,给天空盒添加time of lighting材质 打开材质里的衣服,发现只有一个衣服文件 …

SpringBoot集成FTP

1.加入核心依赖 <dependency><groupId>commons-net</groupId><artifactId>commons-net</artifactId><version>3.8.0</version></dependency> 完整依赖 <dependencies><dependency><groupId>org.springfra…

什么是时间序列分析

时间序列分析是现代计量经济学的重要内容&#xff0c;广泛应用于经济、商业、社会问题研究中&#xff0c;在指标预测中具有重要地位&#xff0c;是研究统计指标动态特征和周期特征及相关关系的重要方法。 一、基本概念 经济社会现象随着时间的推移留下运行轨迹&#xff0c;按…

Linux网络编程--网络传输

Linux网络编程--网络传输 Linux网络编程TCP/IP网络模型网络通信的过程局域网通信跨网络通信&#xff1a;问题总结&#xff1a; Linux网络编程 TCP/IP网络模型 发送方&#xff08;包装&#xff09;&#xff1a; 应用层&#xff1a;HTTP HTTPS SSH等 —> 包含数据&#xff0…

二维码门楼牌管理应用平台建设:助力场所整改与消防安全

文章目录 前言一、二维码门楼牌管理应用平台的构建背景二、二维码门楼牌管理应用平台在场所整改中的作用三、二维码门楼牌管理应用平台的意义与价值四、二维码门楼牌管理应用平台的未来展望 前言 随着城市管理的日益精细化&#xff0c;二维码门楼牌管理应用平台的建设成为了提…

正则表达式中 “$” 并不是表示 “字符串结束”

△△请给“Python猫”加星标 &#xff0c;以免错过文章推送 作者&#xff1a;Seth Larson 译者&#xff1a;豌豆花下猫Python猫 英文&#xff1a;Regex character “$” doesnt mean “end-of-string” 转载请保留作者及译者信息&#xff01; 这篇文章写一写我最近在用 Python …

【Django】学习笔记

文章目录 [toc]MVC与MTVMVC设计模式MTV设计模式 Django下载Django工程创建与运行创建工程运行工程 子应用创建与注册安装创建子应用注册安装子应用 数据模型ORM框架模型迁移 Admin站点修改语言和时区设置管理员账号密码模型注册显示对象名称模型显示中文App显示中文 视图函数与…