【单机多卡】torch改造代码为DDP单机多卡分布式并行

torch分布式数据并行DDPtorch.nn.parallel.DistributedDataParallel代码修改记录。(要求pytorch_version>1.0)

目录

1.🍄🍄要修改的地方概览

2.✏️✏️初始化

3.✏️✏️设置当前进程GPU

4.✏️✏️设置sampler

5.✏️✏️sampler传入dataloader

6.✏️✏️数据放GPU

7.✏️✏️模型放GPU

8.✏️✏️load模型

9.✏️✏️save模型

10.✏️✏️执行命令

整理不易,欢迎一键三连!!!



1.🍄🍄要修改的地方概览

2.✏️✏️初始化

在代码最开始的地方设置初始化参数,即训练和数据组织之前。

n_gpus = args.n_gpus   #自行传入
#local_rank = args.local_rank   #自行传入
local_rank = int(os.environ['LOCAL_RANK'])   #代码计算

torch.distributed.init_process_group('nccl', world_size=n_gpus, rank=local_rank)#初始化进程组
  • 指定GPU之间的通信方式'nccl'
  • world_size:当前这个节点上要用多少GPU卡;(当前节点就是当前机器)
  • rank: 当前进程在哪个GPU卡上,通过args.local_rank来获取,local_rank变量是通过外部指令传入的;(也可以通过环境变量来接收)

注意:自行传入的变量需要通过argparse第三方库写入,示例如下:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n_gpus", help="num of gpus")
parser.add_argument("-p", "--project", help="project name")
parser.add_argument('-s', '--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
parser.add_argument('-d', '--seed', type=int, default=72, help='Random seed.')
parser.add_argument('-e', '--epochs', type=int, default=10000, help='Number of epochs to train.')

args = parser.parse_args()
print(args.n_gpus)
print(args.sparse)
print(args.seed)
print(args.epochs)

3.✏️✏️设置当前进程GPU

在初始化之后紧接着设置当前进程的GPU

torch.cuda.set_device(local_rank)

上述指令作用相当于设置CUDA_VISBLE_DEVICES环境变量,设置当前进程要用第几张卡;

4.✏️✏️设置sampler

from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(dataset_train)
...
for epoch in range(start_epoch, total_epochs):
    train_sampler.set_epoch(epoch)  #为了让每张卡在每个周期中得到的数据是随机的
    ...

此处的train_dataset为load数据的Dataset类,根据数据地址return出每个image和队形的mask,DistributedSampler返回一堆数据的索引train_sampler,根据索引去dataloader中拿数据,并且在每次epoch训练之前,加上train_sampler.set_epoch(epoch)这句,达到shuffle=True的目的。

5.✏️✏️sampler传入dataloader

from torch.utils.data import DataLoader
dataloader_train = DataLoader(
                              dataset_train, 
                              batch_size=args.batch_size,
                              sampler = train_sampler
    )
dataloader_val = DataLoader(
                            dataset_val,
                            batch_size=1,
    )

通过将train_sampler传入dataloader达到数据传入模型的数据格式。

6.✏️✏️数据放GPU

在每次训练过程中,设置数据放GPU里。

for img,label in dataloader_train:
    inputs = img.cuda(local_rank)  #数据放GPU
    labels = label.cuda(local_rank)  #数据放GPU
    ...

7.✏️✏️模型放GPU

在定义模型的地方,设置将模型放入GPU

model = XXNet()
net = torch.nn.parallel.DistributedDataParallel(model.cuda(local_rank),device_ids=[local_rank])  #模型拷贝,放入DistributedDataParallel

8.✏️✏️load模型

torch.load(model_file_path, map_location = local_rank)

设置 map_location指定将模型传入哪个GPU上

9.✏️✏️save模型

torch.save(net.module.state_dict(), os.path.join(ckp_savepath, ckp_name))

注意,此处保存的net是net.module.state_dict

10.✏️✏️执行命令

python -m torch.distributed.launch --nproc_per_node=n_gpus --master_port 29502 train.py
  • nproc_per_node:等于GPU数量
  • master_port:默认为29501,如果出现address already in use,可以将其修改为其他值,比如29502

参考:视频讲解

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--


🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/53870.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

后端技术趋势指南|如何选择自己的技术方向

编程多条路,条条通罗马 后台大佬 后台路线都是面对后台服务器业务,比如web后台服务器,视频后台服务器,搜索后台服务器,游戏后台服务器,直播后台服务器,社交IM后台服务器等等,大部分…

Python基础入门教程(下)

目录 七、函数进阶 7.1、函数多返回值 7.2、函数多种传参方式 位置参数 关键字参数 缺省参数 不定长参数 位置传递 关键字传递 7.3、匿名函数 函数作为参数传递 lambda匿名函数 八、文件操作 8.1、文件的读取 open()打开函数 mode常用的三种基础访问模式 读操…

无人机调试笔记——常见参数

无人机的PID调试以及速度相关参数 1、Multicopter Position Control主要是用来设置无人机的各种速度和位置参数。调试顺序是先调试内环PID,也就是无人机的速度闭环控制,确认没有问题后再进行外环位置控制,也就是定点模式控制。 2、调试的时…

【C++】做一个飞机空战小游戏(一)——使用getch()函数获得键盘码值

[导读]本系列博文内容链接如下: 【C】做一个飞机空战小游戏(一)——使用getch()函数获得键盘码值 【C】做一个飞机空战小游戏(二)——利用getch()函数实现键盘控制单个字符移动 【C】【C】做一个飞机空战小游戏(三)——模块化程设设计 最近想用c做一个小游戏&#x…

ClickHouse的安装启动

安装步骤 1.关闭防火墙 2.修改资源限制配置文件 2.1 路径:/etc/security/limits.conf 在末尾添加: * soft nofile 65536 #任何用户可以打开的最大的文件描述符数量,默认1024 这里的设置会限制tcp连接数 * hard nofile 65536 * soft nproc…

Android 测试

工程目录图 1- Espresso 2- uiautomator Espresso 文档UI Automator文档ui-automator 英文文档 请点击下面工程名称,跳转到代码的仓库页面,将工程 下载下来 Demo Code 里有详细的注释 代码:testespresso 参考文献 Android 利用 espre…

自建纯内网iot平台服务,软硬件服务器全栈实践

基于以下几个考虑,自制硬件设备,mqtt内网服务器。 1.米家app不稳定,逻辑在云端或xiaomi中枢网关只支持少部分在本地计算。 2.监控homeassistant官方服务有大量数据交互。可能与hass安装小米账户有关。 3.硬件:原理图,l…

机器学习李宏毅学习笔记39

文章目录 前言一、大模型的发展趋势二、KNN LM总结 前言 大模型大资料 大模型的顿悟时刻 一、大模型的发展趋势 随数据量增加,模型可以从量变达到质变,从某一刻开始突然学会东西。 当成为大模型时,分数会从0,0突然变成100,完成“…

OSG3.6.5 + VS2017前期准备及编译

OSG3.6.5 VS2017前期准备及编译 1、前期准备 1.1、osg稳定版本源码 Stable releases (openscenegraph.com) 1.2、osg依赖项 Dependencies (openscenegraph.com) 1.3、osg测试及演示数据 Data Resources (openscenegraph.com) 1.4、安装doxygen和Graphviz(用…

Springboot计算机课程教学辅助系统小程序【纯干货分享,免费领源码01616】

目 录 摘要 1 绪论 1.1课题背景 1.2研究现状 1.3springboot框架介绍 1.4小程序框架以及目录结构介绍 2 计算机课程教学辅助系统小程序系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 操作可行性分析 2.2 系统业务流程分析 2.3 系统功能…

Practice3|922. 按奇偶排序数组 II、143. 重排链表

922. 按奇偶排序数组 II 1.题目: 给定一个非负整数数组 nums, nums 中一半整数是 奇数 ,一半整数是 偶数 。 对数组进行排序,以便当 nums[i] 为奇数时,i 也是 奇数 ;当 nums[i] 为偶数时, i…

Spring Cloud Eureka 服务注册和服务发现超详细(附加--源码实现案例--及实现逻辑图)

文章目录 EurekaEureka组件可以实现哪些功能什么是CAP原则?服务注册代码实战搭建注册中心服务A搭建服务B搭建启动服务启动注册中心启动服务A启动服务B 结束语 Eureka 这篇文章先讲述一下Eureka的应用场景、代码实现案例,多个服务模块注册到Euraka中&…

Mac查看系统状态

syatem profiler mac系统中提供了system profiler来查看系统的详细信息,包括硬件、网络以及安装的软件 Console 显示了系统上的日志文件信息,有助于诊断问题 Activity Monitor 可以提供正在运行的系统的相关信息 https://zhhll.icu/2021/Mac/查看系统…

Ansible的脚本 --- playbook 剧本

文章目录 一、playbook剧本的组成创建剧本运行playbook二、定义、引用变量三、指定远程主机sudo切换用户四、when条件判断五、迭代Templates 模块tags 模块 一、playbook剧本的组成 playbooks 本身由以下各部分组成 (1)Tasks:任务&#xff0…

this is incompatible with sql_mode=only_full_group_by

查看配置 select global.sql_mode 在sql命令行中输入select sql_mode 能够看到sql_mode配置,如果有ONLY_FULL_GROUP_BY,则需要修改 在mysql5.7.5后,ONLY_FULL_GROUP_BY是默认选项,所以就会导致group by的问题 set sql_mode‘复制去掉ONLY_F…

[个人笔记] Windows配置NTP时间同步

Windows - 运维篇 第六章 Windows配置NTP时间同步 Windows - 运维篇系列文章回顾Windows配置NTP时间同步域控环境的NTP配置工作组环境的NTP配置Windows的CMD部分命令集 参考来源 系列文章回顾 第一章 迁移WinSrv系统到虚拟机 第二章 本地安全策略xcopy实现实时备份文件夹内容 …

clickhouse分布式查询降级为本地查询

在基于 clickhouse 做类数仓建模时通常的做法是在本地创建物化视图,然后使用分布式表做代理对外提供服务。我们知道 clickhouse 对于 DQL 内部实现了分布式,而对于 DDL 则需要我们自动实现比如: drop table table_name on cluster cluster_n…

mysql的整体架构

服务层-引擎层-硬盘层 一条语句执行的整体过程: 先建立连接(mysql -h -p 密码)–预处理-词法分析-语法分析-优化器选择用什么索引表如何连接等-执行器 到这里都是属于server层,大多数功能包括视图,存储过程,触发器都是…

WIZnet W5500-EVB-Pico DHCP 配置教程(三)

DHCP协议介绍 什么是DHCP? 动态主机配置协议DHCP(Dynamic Host Configuration Protocol)是一种网络管理协议,用于集中对用户IP地址进行动态管理和配置。 DHCP于1993年10月成为标准协议,其前身是BOOTP协议。DHCP协议由R…

AD21 PCB设计的高级应用(六)极坐标的应用

(六)极坐标的应用 在 PCB 设计过程中,特别是 LED圆形灯板的 PCB 设计,需要对 LED灯珠进行圆形等间距排列,如果每个元件都计算清楚其坐标再进行放置会非常烦琐。要实现如图 所示的元件布局效果,在 Altium Designer 软件里可以使用极坐标的方法。 (1)打开…