训练模型时 遇到速度过慢时的深思 速度提升 (From GPU CPU)

训练模型时 遇到速度过慢时的深思 & 速度提升

  • GPU
    • 查看GPU使用情况 配置
    • 单机多卡并行训练
      • torch.nn.DataParallel
      • 平衡DataParallel带来的显存使用不平衡的问题
      • torch.nn.parallel.DistributedDataParallel
    • 多机多gpu训练
      • Reference
    • 使用半精度训练
    • 更好的显卡,更轻的模型
    • batch_size
  • CPU
    • data loader
    • 减少日志IO操作频率
    • 使用pin_memory和num_workers
  • 总结
  • Reference

GPU

查看GPU使用情况 配置

nvidia_smi

单机多卡并行训练

torch.nn.DataParallel

在使用多GPU的时候, 使用os.environ[‘CUDA_VISIBLE_DEVICES’]来限制使用的GPU个数
例如我要使用第0和第3编号的GPU, 那么只需要在程序中设置:

os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'

Note:
这个参数的设定要保证在模型加载到gpu上之前, 一般在程序开始的时候就设定好这个参数

之后, 将模型加载到多GPU上面

如果是模型:

model = nn.DataParallel(model)
model = model.cuda()

如果是数据:

inputs = inputs.cuda()
labels = labels.cuda()

pytorch官网给的示例代码

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)

model.to(device)

DataParallel的内部代码:

class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

截取其中一部分代码,如果不设定好要使用的device_ids的话, 程序会自动找到这个机器上面可以用的所有的显卡, 然后用于训练.
但是因为前面使用os.environ['CUDA_VISIBLE_DEVICES']限定了这个程序可以使用的显卡, 所以这个地方程序如果自己获取的话, 获取到的其实就是我们上面设定的那几个显卡.

深入考究
使用os.environ['CUDA_VISIBLE_DEVICES']对可以使用的显卡进行限定之后, 显卡的实际编号和程序看到的编号应该是不一样的

例如
上面我们设定的是os.environ['CUDA_VISIBLE_DEVICES']="0,2",但是程序看到的显卡编号应该被改成了'0,1', 也就是说程序所使用的显卡编号实际上是经过了一次映射之后才会映射到真正的显卡编号上面的,
例如
这里的程序看到的1对应实际的2

平衡DataParallel带来的显存使用不平衡的问题

这个问题其实讨论的也比较多了, 官方给的解决方案就是使用 DistributedDataParallel来代替 DataParallel
(实际上DistributedDataParallel显存分配的也不是很平衡), 但是从某些角度来说, DataParallel使用起来确实比较方便, 而且最近使用 DistributedDataParallel 遇到一些小问题. 所以这里提供一个解决显存使用不平衡问题的方案:

  1. 首先这次的解决方案来自transformer-XL的官方代码: https://github.com/kimiyoung/transformer-xl

  2. 然后我将其中的平衡GPU显存的代码提取了出来(原代码好像有点小问题)放到了github上面:https://github.com/Link-Li/Balanced-DataParallel

这里的代码是原作者继承了 DataParallel 类之后进行了改写:

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)
 ...

这个 BalancedDataParallel 类使用起来和 DataParallel 类似, 下面是一个示例代码:

my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()

第一个参数是第一个GPU要分配多大的batch_size

如果使用了梯度累积, 那么这里传入的是每次进行运算的实际batch_size大小.

比如:
你在3个GPU上面跑代码, 但是一个GPU最大只能跑3条数据, 但是因为0号GPU还要做一些数据的整合操作, 于是0号GPU只能跑2条数据, 这样一算, 你可以跑的大小是2+3+3=8, 于是你可以设置下面的这样的参数:

batch_szie = 8
gpu0_bsz = 2
acc_grad = 1
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()

这个时候突然想跑个batch size是16的怎么办呢, 那就是4+6+6=16了, 这样设置累积梯度为2就行了:

batch_szie = 16
gpu0_bsz = 4
acc_grad = 2
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()

torch.nn.parallel.DistributedDataParallel

pytorch的官网建议使用DistributedDataParallel来代替DataParallel

因为DistributedDataParallel比DataParallel运行的更快

然后 显存分屏的更加均衡
且DistributedDataParallel 功能更加强悍
例如
分布式的模型(一个模型太大, 以至于无法放到一个GPU上运行, 需要分开到多个GPU上面执行).
只有DistributedDataParallel支持分布式的模型像单机模型那样可以进行多机多卡的运算.

先设定好os.environ[‘CUDA_VISIBLE_DEVICES’]
然后再进行下面的步骤.

因为DistributedDataParallel是支持多机多卡的, 所以这个需要先初始化一下:

torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=0, world_size=1)

第一个参数pytorch支持的通讯后端, 这里单机多卡, 这个就是走走过场.
第二个参数各个机器之间通讯的方式,这里是单机多卡, 设置成localhost, 后面的端口找一个空着没用的就OK.
**rank**标识主机和从机, 这里就一个主机, 设置成0就OK.
**world_size**是标识使用几个主机, 这里就一个主机, 设置成1就OK, 设置多了代码不允许.

其实如果是使用单机多卡的情况下, 根据pytorch的官方代码distributeddataparallel, 是直接可以使用下面的代码的:

torch.distributed.init_process_group(backend="nccl")
model = DistributedDataParallel(model) # device_ids will include all GPU devices by default

Note:
如果使用这句代码, 直接在pycharm或者别的编辑器中,是没法正常运行的
因为这个需要在shell的命令行中运行, 如果想要正确执行这段代码, 假设这段代码的名字是main.py, 可以使用如下的方法进行(参考1 参考2):

python -m torch.distributed.launch main.py

如果使用了argparse, 一定要在参数里面加上–local_rank, 否则运行还是会出错的

之后就和使用DataParallel很类似了.

model = model.cuda()
model = nn.parallel.DistributedDataParallel(model)

Note: 这里要先将model加载到GPU 然后才能使用DistributedDataParallel进行分发
之后的使用和DataParallel就基本一样了

多机多gpu训练

在单机多gpu可以满足的情况下, 绝对不建议使用多机多gpu进行训练
经过测试, 发现多台机器之间传输数据的时间非常慢, 主要是因为测试的机器可能只是千兆网卡, 再加上别的一些损耗, 网络的传输速度跟不上, 导致训练速度实际很慢.

Reference

pytorch/examples/imagenet/main.py

Distributed-VGG-F

使用半精度训练

更好的显卡,更轻的模型

batch_size

增大 batch size 提高epoch速度
但是收敛速度也会变慢
需要再适当升高学习率

CPU

查看占用率
top命令

top -bn 1 -i -c

在这里插入图片描述

data loader

线程数 arg 增加
多线程加载数据
dataloader 的param.

num_works=4

如果数据集很小,使用多线程加载数据可能会更慢,因为多线程有一定的开销。
在这种情况下,最好使用单线程读取数据。

减少日志IO操作频率

文件的IO操作,这会导致GPU得不到连续性使用,整体速度特别慢。

使用pin_memory和num_workers

总结

有的时候模型训练慢并不是因为显卡不行或者模型太大,而是在跑模型过程中有一些其他的操作导致速度很慢
如:IO操作

  1. GPU 升级
  2. 硬盘读取,重点观察数据是否量大且容量较小
  3. data loader 线程数arg 增加
  4. 观察CPU占用率
  5. 内存占用率
  6. IO速度等

Reference

https://zhuanlan.zhihu.com/p/86441879

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/344686.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

编译和链接---C语言

引言 众所周知,C语言是一门高级的编程语言,是无法被计算机直接读懂的,C语言也不同于汇编PHP,无法直接翻译成机器语言,在学习的过程中,你是否好奇过我们所敲的C语言代码,是如何一步步翻译成机器…

Docker容器引擎(1)

目录 一.Docker 概述 为什么要用到容器? docker是什么? 容器与虚拟机的区别? docker的三个核心概念: 二.安装docker 安装依赖包: 安装 Docker-CE并设置为开机自动启动: 查看 docker 版本信息&#…

10个常考的前端手写题,你全都会吗?(下)

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 今天接着上篇再来分享一下10个常见的JavaScript手写功能。 目录 1.实现继承 ES5继…

docker 部署springboot项目

新建Dockerfile ## AdoptOpenJDK 停止发布 OpenJDK 二进制,而 Eclipse Temurin 是它的延伸,提供更好的稳定性 ## 感谢复旦核博士的建议!灰子哥,牛皮! FROM eclipse-temurin:8-jre## 将后端项目的 Jar 文件&#xff0c…

软考复习之多媒体篇

常用的计算公式 数据传输率(单位:b/s) 未压缩的数据传输率 采样频率(Hz)* 量化位数(位)* 声道数 波形声音经过数字化后的信息数据量(单位:字节) 声音信号数据量 数据传输率 * …

简单实现网络编程

1. 前置知识 在学习网络编程前,我们需要先了解一些前置知识 1.1 客户端和服务器 在网络编程中,客户端和服务器是两个关键的角色。 客户端是发起连接并向服务器发送请求的一方。客户端通常是一个应用程序或设备,通过与服务器建立连接&…

白盒测试和黑盒测试的区别

黑盒测试 等价类划分 白盒测试 灰盒测试

K8S图像化工具rancher

Rancher是一个开源的企业级多集群的k8s管理平台 Rancher和k8s的区别 都是为了容器的调度和编排系统,但是rancher不仅能够调度,还能挂历k8s集群,自带监控(普罗米修斯),你哪怕不知带k8s是什么,一样…

跟着pink老师前端入门教程-day09

二十二、定位 22.1 为什么需要定位 1. 某个元素可以自由的在一个盒子内移动位置,并且压住其他盒子 2. 当我们滚动窗口时,盒子是固定屏幕某个位置的 解决方法: 1. 浮动可以让多个块级盒子一行没有缝隙排列显示,经常用于横向排…

C#用DateTime.Now静态属性返回日期的星期信息

目录 一、使用的方法 1.Now属性 2.ToString方法 二、示例 使用DateTime结构的Now静态属性,可以方便地获取系统日期信息。调用时间对象的ToString方法,在该方法的参数中添加适当的格式化字符串,将返回日期的星期信息。 一、使用的方法 1…

C语言入门到精通之练习实例10:打印楼梯,同时在楼梯上方打印两个笑脸(附源码)

题目:打印楼梯,同时在楼梯上方打印两个笑脸。 程序分析:用 ASCII 1 来输出笑脸;用i控制行,j来控制列,j根据i的变化来控制输出黑方格的个数。 如果出现乱码情况请参考【C 练习实例7】的解决方法。 // Cr…

uniapp 在static/index.html中添加全局样式

前言 略 在static/index.html中添加全局样式 <style>div {background-color: #ccc;} </style>static/index.html源码&#xff1a; <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"utf-8"><meta http-…

从零开始用Rust编写nginx,命令行参数的设计与解析及说明

wmproxy wmproxy已用Rust实现http/https代理, socks5代理, 反向代理, 静态文件服务器&#xff0c;四层TCP/UDP转发&#xff0c;七层负载均衡&#xff0c;内网穿透&#xff0c;后续将实现websocket代理等&#xff0c;会将实现过程分享出来&#xff0c;感兴趣的可以一起造个轮子 …

教育研究方法有哪些分类

教育研究方法的分类是多种多样的&#xff0c;可以从不同的角度进行划分。根据研究目的、研究范围、研究层次等&#xff0c;可以将教育研究方法分为不同的类型。 一、根据研究目的划分 基础性研究基础性研究也称为理论性研究&#xff0c;主要是为了探索和发现新的理论观点和原…

uniapp开发过程一些小坑

问题1、uniapp使用scroll-view的:scroll-into-view“lastChatData“跳到某个元素id时候&#xff0c;在app上不生效&#xff0c;小程序没问题 使用this.$nextTick或者 setTimeout(()>{that.lastChatData 元素id },500) 进行延后处理就可以了。 问题2&#xff1a;uniapp开…

运动刷步工具2.4。运动步数同步wx和支付宝

时隔一年&#xff1b;上次更新还是2022-11-26&#xff0c;陆续收到私信旧版已不能使用&#xff0c;2.4版本更新修复超时问题。后续有时间还会更新邮箱登录功能2023-11-15更新&#xff1a;修复错误超时问题测试不行的小伙伴&#xff0c;先看说明再使用&#xff01;&#xff01;&…

关于常见分布式组件高可用设计原理的理解和思考

文章目录 1. 数据存储场景和存储策略1.1 镜像模式-小规模数据1.2 分片模式-大规模数据 2. 数据一致性和高可用问题2.1 镜像模式如何保证数据一致性2.2 镜像模式如何保证数据高可用2.2.1 HA模式2.2.2 分布式选主模式 2.3 分片模式如何数据一致性和高可用 3. 大规模数据集群的架构…

垃圾收集算法

垃圾收集算法有如下几种&#xff1a; 分代收集理论&#xff1a;年龄代和老年代选择各自的垃圾收集算法。 复制算法&#xff1a;可达性分析算法找非垃圾对象&#xff0c;然后把非垃圾对象移动到另一端&#xff0c;这一端的垃圾对象清除&#xff0c;该方法浪费内…

如何查找SpringBoot应用中的请求路径(不使用idea)

背景 昨天有个同事向我咨询某个接口的物理表是哪个&#xff0c;由于公司业务较多、这块业务的确不是我负责的&#xff0c;也没有使用idea不能全局搜索(eclipse搜不到jar内的字符串)&#xff0c;也就回复了不清楚。 除了自己写代码输出servlet的路径和类外&#xff0c;发现了一…

【JVM故障问题排查心得】「Java技术体系方向」Java虚拟机内存优化之虚拟机参数调优原理介绍

Java技术体系方向-JVM虚拟机参数调优原理 内容简介栈上分配与逃逸分析逃逸分析(Escape Analysis)栈上分配基本思想使用场景线程私有对象 虚拟机内存逻辑图JVM内存分配源码&#xff1a;代码总体逻辑 在某些场景使用栈上分配设置JVM运行参数&#xff1a;开启逃逸模式&#xff0c;…