Pytorch DistributedDataParallel(DDP)教程一:快速入门理论篇

Pytorch DistributedDataParallel(DDP)教程一:快速入门理论篇

目录

          • 一、 写在前面
          • 二、什么是分布式并行训练
            • 1. 并行训练
            • 2. 数据并行
          • 三、DDP的基本原理
            • 1. DDP的训练过程
            • 2. Ring-All-Reduce算法
          • 四、如何搭建一个Pytorch DDP代码框架
            • 1. 与DDP有关的基本概念
            • 2. 与DDP有关的一些操作
            • 3. 要实现DDP训练,我们需要解决哪些问题?
            • 4. 一个最简单的DDP代码框架
          • 五、查资料过程中的一个小惊喜

一、 写在前面

随着深度学习技术的不断发展,模型的训练成本也越来越高。训练一个高效的通用模型,需要大量的训练数据和算力。在很多非大模型相关的常规任务上,往往也需要使用多卡来进行并行训练。在多卡训练中,最为常用的就是分布式数据并行(DistributedDataParallel, DDP)。但是现有的有关DDP的教程和博客比较少,内容也比较分散繁琐。在大多数情况下,我们只需要学会如何使用即可,不需要特别深入地了解原理。为此,写下这个系列博客,简明扼要地介绍一下DDP的使用,抛开繁杂的细节和原理,帮助快速上手使用(All in one blog)。

篇幅较长,分为上下两篇:这篇简要介绍相关背景和理论知识,下篇详细介绍代码框架和搭建流程。

二、什么是分布式并行训练
1. 并行训练

在Pytorch中,有两种并行训练方式:

1)模型并行。模型并行通常是指你的模型非常大,大到一块卡根本放不下,因而需要把模型进行拆分放到不同的卡上。

2)数据并行。数据并行通常用于训练数据非常庞大的时候,比如有几百万张图像用于训练模型。此时,如果只用一张卡来进行训练,那么训练时间就会非常的长。并且由于单卡显存的限制,训练时的batch size不能设置得过大。但是,对于很多模型的性能而言,由于BN层的使用,都会和batch size的大小正相关。此外,很多基于对比学习的训练算法,由于其对负样本的需求,性能也与batch size的大小正相关。因此,我们需要使用多卡训练,不仅可以训练加速,并且可以设置更大的batch size来提升性能。

2. 数据并行

在Pytorch中有两种方式来实现数据并行:

1)数据并行(DataParallel,DP)。DataParallel采用参数服务器架构,其训练过程是单进程的。在训练时,会将一块GPU作为server,其余的GPU作为worker,在每个GPU上都会保留一个模型的副本用于计算。训练时,首先将数据拆分到不同的GPU上,然后在每个worker上分别进行计算,最终将梯度汇总到server上,在server进行模型参数更新,然后将更新后的模型同步到其他GPU上。这种方式有一个很明显的弊端,作为server的GPU其通信开销和计算成本非常大。它需要和其他所有的GPU进行通信,并且梯度汇总、参数更新等步骤都是由它完成,导致效率比较低。并且,随着多卡训练的GPU数量增强,其通信开销也会线性增长。

Parameter Server架构

不过DataParallel的代码十分简洁,仅需在原有单卡训练的代码中加上一行即可。

model = nn.DataParallel(model) 

如果你的数据集并不大,只有几千的规模,并且你多卡训练时的卡也不多,只有4块左右,那么DataParallel会是一个不错的选择。

关于Parameter Server更详细的原理介绍,可以参考:

深度学习加速:算法、编译器、体系结构与硬件设计

一文讀懂「Parameter Server」的分布式機器學習訓練原理

2)分布式数据并行(DistributedDataParallel,DDP)。DDP采用Ring-All-Reduce架构,其训练过程是多进程的。如果要用DDP来进行训练,我们通常需要修改三个地方的代码:数据读取器dataloader,日志输出print,指标评估evaluate。其代码实现略微复杂,不过我们只需要始终牢记一点即可:每一块GPU都对应一个进程,除非我们手动实现相应代码,不然各个进程的数据都是不互通的。Pytorch只为我们实现了同步梯度和参数更新的代码,其余的需要我们自己实现。

Ring-All-Reduce架构

三、DDP的基本原理
1. DDP的训练过程

DDP的训练过程可以总结为如下步骤:

1)在训练开始时,整个数据集被均等分配到每个GPU上。每个GPU独立地对其分配到的数据进行前向传播(计算预测输出)和反向传播(计算梯度)。

2)同步各个GPU上的梯度,以确保模型更新的一致性,该过程通过Ring-All-Reduce算法实现。

3)一旦所有的GPU上的梯度都同步完成,每个GPU就会使用这些聚合后的梯度来更新其维护的模型副本的参数。因为每个GPU都使用相同的更新梯度,所以所有的模型副本在任何时间点上都是相同的。

2. Ring-All-Reduce算法

Ring-All-Reduce架构是一个环形架构,所有GPU的位置都是对等的。每个GPU上都会维持一个模型的副本,并且只需要和它相连接的两个GPU通信。

对于第k个GPU而言,只需要接收来自于第k-1个GPU的数据,并将数据汇总后发送给第k+1个GPU。这个过程在环中持续进行,每个GPU轮流接收、聚合并发送梯度。

经过 N 次的迭代循环后(N是GPU的数量),每个GPU将累积得到所有其他GPU的梯度数据的总和。此时,每个GPU上的梯度数据都是完全同步的。

DDP的通信开销与GPU的数量无关,因而比DP更为高效。如果你的训练数据达到了十万这个量级,并且需要使用4卡及以上的设备来进行训练,DDP将会是你的最佳选择。

关于DDP和Ring-All-Reduce算法的更多实现原理和细节,可以参考:

Bringing HPC Techniques to Deep Learning

Pytorch 分散式訓練 DistributedDataParallel — 概念篇

Technologies behind Distributed Deep Learning: AllReduce

四、如何搭建一个Pytorch DDP代码框架
1. 与DDP有关的基本概念

在开始使用DDP之前,我们需要了解一些与DDP相关的概念。

参数含义查看方式
group分布式训练的进程组,每个group可以进行自己的通信和梯度同步Group通常在初始化分布式环境时创建,并通过torch.distributed.new_group等API创建自定义groups。
world size参与当前分布式训练任务的总进程数。在单机多GPU的情况下,world size通常等于GPU的数量;在多机情况下,它是所有机器上所有GPU的总和。torch.distributed.get_world_size()
rankRank是指在所有参与分布式训练的进程中每个进程的唯一标识符。Rank通常从0开始编号,到world size - 1结束。torch.distributed.get_rank()
local rankLocal rank是当前进程在其所在节点内的相对编号。例如,在一个有4个GPU的单机中,每个GPU进程的local rank将是0, 1, 2, 3。这个参数常用于确定每个进程应当使用哪个GPU。Local rank不由PyTorch的分布式API直接提供,而通常是在启动分布式训练时由用户设定的环境变量,或者通过训练脚本的参数传入。
2. 与DDP有关的一些操作

在DDP中,每个进程的数据是互不影响的(除了采用Ring-All-Reduce同步梯度)。如果我们要汇总或者同步不同进程上的数据,就需要用到一些对应的函数。

1)all_reduce

all_reduce操作会在所有进程中聚合每个进程的数据(如张量),并将结果返回给所有进程。聚合可以是求和、取平均、找最大值等。当你需要获得所有进程的梯度总和或平均值时,可以使用all_reduce。这在计算全局平均或总和时非常有用,比如全局平均损失。

一个示例代码如下:

import torch.distributed as dist

tensor_a = torch.tensor([1.0], device=device)
# 所有进程中的tensor_a将会被求和,并且结果会被分配给每个进程中的tensor_a。
dist.all_reduce(tensor_a, op=dist.ReduceOp.SUM)

2)all_gather

all_gather操作用于在每个进程中收集所有进程的数据。它不像all_reduce那样聚合数据,而是将每个进程的数据保留并汇总成一个列表。当每个进程计算出一个局部结果,并且你需要在每个进程中收集所有结果进行分析或进一步处理时,可以使用all_gather

一个示例代码如下:

import torch
import torch.distributed as dist

# 每个进程有一个tensor_a,其值为当前进程的rank
tensor_a = torch.tensor([rank], device=device)  # 假设rank是当前进程的编号
gather_list = [torch.zeros_like(tensor_a) for _ in range(dist.get_world_size())]
# 收集所有进程的tensor_a到每个进程的gather_list
dist.all_gather(gather_list, tensor)

3)broadcast

broadcast操作将一个进程的数据(如张量)发送到所有其他进程中。这通常用于当一个进程生成了某些数据,需要确保其他所有进程都得到相同的数据时。在在开始训练之前,可以用于同步模型的初始权重或者在所有进程中共享某些全局设置。一个示例代码如下:

import torch.distributed as dist

tensor_a = torch.tensor([1.0], device=device)
if rank == 0:
    tensor_a.fill_(10.0)  # 只有rank 0设置tensor_a为10
dist.broadcast(tensor_a, src=0)  # rank 0将tensor_a广播到所有其他进程
3. 要实现DDP训练,我们需要解决哪些问题?

1)如何将数据均等拆分到每个GPU

在分布式训练中,为了确保每个GPU都能高效地工作,需要将训练数据均等地分配到每个GPU上。如果数据分配不均,可能导致某些GPU数据多、某些GPU数据少,从而影响整体的训练效率。

在PyTorch中,可以使用torch.utils.data.DataLoader结合torch.utils.data.distributed.DistributedSamplerDistributedSampler会自动根据数据集、进程总数(world size)和当前进程编号(rank)来分配数据,确保每个进程获取到的数据互不重复且均衡分布。

2)如何在IO操作时避免重复

在使用PyTorch的分布式数据并行(DDP)进行模型训练时,由于每个进程都是独立运行的,IO操作如打印(print)、保存(save)或加载(load)等如果未经特别处理,将会在每个GPU进程上执行。这样的行为通常会导致以下问题:重复打印(每个进程都会输出同样的打印信息到控制台,导致输出信息重复,难以阅读)、文件写入冲突(如果多个进程尝试同时写入同一个文件,会产生写入冲突,导致数据损坏或者输出不正确)、资源浪费(每个进程重复加载相同的数据文件会增加IO负担,降低效率和浪费资源)。

一个简单且可行的解决方案是只在特定进程中进行相关操作,例如,只在rank为0的进程中执行,如有必要,再同步到其他进程。

3)如何收集每个进程上的数据进行评估

在DDP训练中,每个GPU进程独立计算其数据的评估结果(如准确率、损失等),在评估时,可能需要收集和整合这些结果。

通过torch.distributed.all_gather函数,可以将所有进程的评估结果聚集到每个进程中。这样每个进程都可以获取到完整的评估数据,进而计算全局的指标。如果只需要全局的汇总数据(如总损失或平均准确率),可以使用torch.distributed.reduceall_reduce操作直接计算汇总结果,这样更加高效。

4. 一个最简单的DDP代码框架

篇幅太长,见下篇。

五、查资料过程中的一个小惊喜

在查找DDP有关过程中,发现了一些博客和视频做得很不错,而且这里面有一部分是女生做的。博客和视频的质量都很高,内容安排合理,逻辑表达清晰,参考资料也很全面。我看到的时候,还是很惊艳的,巾帼不让须眉!链接如下:

国立中央大学的李馨伊

复旦大学的_Meilinger_

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/556434.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

二次元AI绘画生成器免费:教你生成精美图片

二次元AI绘画生成器,无疑是现代技术与艺术完美结合的典范。这些工具不仅将复杂的绘画过程简化,更让每一个艺术爱好者的创意得以充分展现。这些生成器能够精准捕捉大家的创意精髓,将其转化为细腻、独特的二次元画作。无论是角色设计、场景描绘…

OpenHarmony UI动画-box2d

简介 用于游戏开发,使物体的运动更加真实,让游戏场景看起来更具交互性,比如愤怒的小鸟 效果展示 下载安装 ohpm install ohos/box2dOpenHarmony ohpm环境配置等更多内容,请参考 如何安装OpenHarmony ohpm包 使用说明 import * …

OCR图片转化为Excel文件功能在哪些平台和设备上能实现?

在现今数字化时代,将图片转化为Excel文件已成为许多工作者和研究者的重要需求。这种转换功能在多个平台和设备上均能实现,为用户提供了极大的便利。 首先,我们来看看在线平台方面。目前,市面上有许多在线工具可以将图片转换为Exc…

MoJoCo 入门教程(七)XML 参考

系列文章目录 前言 表格第二列中的符号含义如下: ! 必填元素,只能出现一次 ? 可选元素,只能出现一次 * 可选元素,可多次出现 R 可选元素,可递归出现多次 一、简介 本章是 MuJoCo 中使用的 MJCF 建模语言的参考手册。…

【云计算】安全组和网络ACL的区别

安全组和网络ACL的区别 ACL(Access Control List)和 安全组(Security Group)是两种不同的网络安全控制机制,通常用于管理云计算平台中的网络访问权限。它们在功能和实现上有一些显著的区别: 辨析 范围不同&…

如何把网易云ncm转换为mp3格式

大家在网易云下载音乐的时候会发现下载的音乐格式是ncm,这个格式只能在网易云播放器才能播放,无法在别的播放器上播放。今天教大家把ncm转换为更加通用的格式 打开智游剪辑(官网: zyjj.cc),搜索ncm转mp3 这里就只需要上…

DC电源模块的未来发展方向和创新应用领域

BOSHIDA DC电源模块的未来发展方向和创新应用领域 随着科技的快速发展,直流(DC)电源模块的应用领域也在不断扩大。从传统的电子产品到新兴的清洁能源领域,DC电源模块正发挥着越来越重要的作用。未来,DC电源模块将继续…

windows docker desktop==spark环境搭建

编写文件docker-compose.yml version: 3services:spark-master:image: bde2020/spark-master:3.1.1-hadoop3.2container_name: spark-masterports:- "8080:8080"- "7077:7077"- "2220:22"volumes:- F:\spark-data\m1:/dataenvironment:- INIT_D…

android studio gradle 下载不下来的问题

有时候由于网络问题,会发现 android-studio 或 idea 下载不懂 gradle,可以修改 gradle-wrapper.properties 配置文件中的 distributionUrl 选项来解决这个问题。 distributionUrlhttps\://mirrors.cloud.tencent.com/gradle/gradle-5.1.1-all.zip 这个选…

【七 (2)FineBI FCP模拟试卷-平台新增用户留存分析】

目录 文章导航一、字段解释1、用户平台登录信息表格2、用户平台激活信息表格 二、需求三、操作步骤1、建立用户平台登录信息表格和用户平台激活信息表格的关联关系2、将用户平台激活信息表格的激活日期添加到用户平台登录信息表格3、新增公式列,计算激活时间和登录时…

【ACM列表推荐会议 | EI稳定检索】2024年第四届人工智能、自动化与高性能计算国际会议(AIAHPC 2024)

2024年第四届人工智能、自动化与高性能计算国际会议(AIAHPC 2024) 2024 4th International Conference on Artificial Intelligence, Automation and High Performance Computing 2024第四届人工智能、自动化与高性能计算国际会议(AIAHPC 2024)将于20…

JVM虚拟机(十)Java内存泄漏的排查思路

目录 一、可能产生内存泄露的地方二、复现堆内存泄漏三、如何排查堆内存问题?3.1 获取对内存快照 dump3.2 使用 Visual VM 去分析 dump 文件3.3 定位内存溢出问题 一、可能产生内存泄露的地方 在进行排查 Java 的内存泄漏问题之前,首先我们要知道哪里可…

数据可视化-ECharts Html项目实战(12)

在之前的文章中,我们深入学习ECharts特殊图表中的矩形树图以及Echarts中高级功能的多图表联动。想了解的朋友可以查看这篇文章。同时,希望我的文章能帮助到你,如果觉得我的文章写的不错,请留下你宝贵的点赞,谢谢。 数…

通过IP地理位置阻止网络攻击:有效性与局限性

网络攻击已成为当今互联网世界中的一项常见挑战。黑客和恶意用户利用各种手段对网络系统进行攻击,造成数据泄露、服务中断甚至财产损失。在这种背景下,寻找有效的网络安全解决方案变得至关重要。 IP地理位置阻止是一种基于黑名单的网络安全措施。它的原…

最新SpringBoot项目财务管理系统

采用技术 最新SpringBoot项目财务管理系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringBootMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 系统登录页面 管理员功能 管理员功能页面 员工管理页面 部…

Python采集海外电影信息,Roxlabs全球IP代理轻松搞定

一、为什么要学会应用 IP 代理技术二、采集海外电影信息爬虫实战2.1 选择目标网站并生成代理2.2 编写爬虫并设置代理2.3 运行爬虫2.4 处理数据2.5 完整代码 三、Roxlabs 代理及优势分析四、快速应用Roxlabs代理功能4.1 领免费流量4.2 代理类型4.3 获取代理4.4 配置代理设置4.5 …

进销存是什么?怎么快速上手制作进销存报表?

在现代商业运营中,进销存管理是企业持续稳健发展的关键组成部分。为了更好地把握库存状况、合理调配资源、提高决策效率,企业采用进销存报表是一种极为有效的手段。本文将从进销存的基本概念、报表的内容和意义,以及如何制作进销存报表等方面…

opencv的cmake报错

opencv编译报错 CMakeDownloadLog.txt #use_cache "D:/opencv/.cache" #do_unpack "ippicv_2021.8_win_intel64_20230330_general.zip" "71e4f58de939f0348ec7fb58ffb17dbf" "https://raw.githubusercontent.com/opencv/opencv_3rdparty/1…

Ctags的安装和使用

ctags 是一个代码跟踪软件 由于在源码中,具有海量的代码,如果我们在源码中找特定的代码,需要大量时间,但是ctags则可以快速寻找 安装: sudo apt-get install ctagsctags使用 ctrl] //代码跟踪 crlT //代码返回我们…

鸿蒙入门06-常见装饰器( 简单装饰器 )

装饰器是鸿蒙开发中非常重要的一个环节因为在很多地方我们都需要用到装饰器并且如果我们想高度的复用, 那么装饰器就是必不可少的一环接下来我们就来介绍一些常见的装饰器注意 : 所有装饰器首字母大写 Entry 用来装饰 struct 使用表示页面的入口 Component 装饰 struct, …