使用Dask在多块AMD GPU上加速XGBoost

Accelerating XGBoost with Dask using multiple AMD GPUs — ROCm Blogs

2024年1月26日 由Clint Greene撰写。

XGBoost 是一个用于分布式梯度提升的优化库。它已经成为解决回归和分类问题的领先机器学习库。如果您想深入了解梯度提升的工作原理,推荐阅读 Introduction to Boosted Trees。

在这篇博客中,我们将向您展示如何构建和安装支持 ROCm 的 XGBoost,以及如何使用 Dask 在多个 AMD GPU 上加速 XGBoost 训练。为了在多个 GPU 上加速 XGBoost,我们利用了 AMD 加速云(AAC),这是一个提供按需 GPU 云计算资源的平台。具体来说,我们使用了装有四个 GPU 的 AAC Dask Docker 容器来加速 XGBoost 训练。需要访问 AAC 才能跟随文中的步骤进行操作。

安装

启动 AAC 上的 Dask 工作负载后,克隆 XGBoost ROCm 仓库并进入该目录。

git clone https://github.com/ROCmSoftwarePlatform/xgboost.git
cd xgboost
git submodule update --init --recursive

接下来,从源码构建 XGBoost。我们在 MI 200 Instinct 系列 GPU 上进行构建,因此将 DCMAKE_HIP_ARCHITECTURES 设置为 gfx90a。如果您使用的是不同架构,请相应地修改此项。您可以使用以下命令查找您的架构:​​​​​​​rocminfo | grep gfx

mkdir build
cd build
cmake -DUSE_HIP=ON -DCMAKE_HIP_ARCHITECTURES="gfx90a" -DUSE_RCCL=1 ../
make -j

现在您可以将 XGBoost 编译成 Python 包。它将安装 XGBoost 的 2.1.0-dev 版本。

 
cd python-package
pip install .

完成带有 ROCm 支持的 XGBoost 安装后,您可以使用 Dask 在 AMD GPU 上进行分布式训练和推理。

入门指南

下载一个训练数据集,以衡量分布在多个 GPU 上的计算所获得的加速效果。我们使用希格斯玻色子数据集,其目标是区分产生希格斯玻色子的信号过程和不产生希格斯玻色子的过程。

wget http://mlphysics.ics.uci.edu/data/higgs/HIGGS.csv.gz
gunzip HIGGS.csv.gz

导入 XGBoost 训练所需的库。

import time
import os
import dask
import pandas as pd
import xgboost as xgb
from distributed import Client
from dask import dataframe as dd
from dask_hip import LocalHIPCluster
from xgboost.dask import DaskDMatrix

分布式环境

要使用 XGBoost 进行多 GPU 训练,您需要使用 Dask 设置分布式环境。Dask 集群由三个不同的组件组成:一个集中的调度器、一个或多个工作节点(GPU),以及一个用于向集群提交任务的客户端。

为了将 XGBoost 计算分布到 Dask 集群上:

将 num_gpus 设置为您想要使用的 GPU 数量。这将用于指定要使用的设备。然后,使用 LocalHIPCluster 以单节点模式创建集群,并将客户端连接到该集群。

num_gpus = 4
devices = ','.join([str(i) for i in range(num_gpus)])
cluster = LocalHIPCluster(HIP_VISIBLE_DEVICES=devices)
client = Client(cluster)

您的分布式环境现已设置完毕,可以进行计算。

加载数据

数据集已经预先平衡、清理和标准化。为了进行基准测试,我们将使用完整的数据集,并将其加载为一个 Dask 数据框。

colnames = ['label'] + ['feature-%02d' % i for i in range(1, 29)]
fn = 'HIGGS.csv'
df = dd.read_csv(fn, header=None, names=colnames, dtype='float32')
X = df[df.columns.difference(['label'])]
y = df['label']

训练

我们接着创建一个 xgboost.dask.DaskDMatrix 对象,并将其与其他参数一起传递给 xgboost.dask.train(),这与XGBoost的常规非Dask接口非常相似。与常规接口不同的是,数据和标签必须是Dask DataFrame或Dask Array实例。

dtrain = xgb.dask.DaskDMatrix(client, X, y)

start_time = time.time()
param =  {
               'max_depth': 8,
               'objective': 'reg:squarederror',
               'tree_method': 'gpu_hist',
               'device': 'gpu'
          }
print('Training')
bst = xgb.dask.train(client, param, dtrain, num_boost_round=1000)
print("[INFO]: ------ Training is completed in {} seconds ------".format((time.time() - start_time)))

使用XGBoost的Dask训练接口时,我们将Dask客户端作为额外的参数传递,以便进行计算。

[INFO]: ------ Training is completed in 24.53920841217041 seconds ------

推理

Dask接口有两个预测函数:`predict` 和 inplace_predict。我们使用 predict 函数。在获得训练好的模型 booster 后,可以通过以下代码进行推理:

booster = bst['booster']
# Set to use GPU for inference.
booster.set_param({'device': 'gpu'})
# dtrain is the DaskDMatrix defined above.
prediction = xgb.dask.predict(client, booster, dtrain)

性能

可以看到,随着GPU数量的增加,训练时间显著减少。例如,将GPU数量从1增加到4,训练时间减少了2.6倍,相对于使用CPU进行训练,训练时间减少了超过10倍。

Training speedup

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/899393.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Maven入门到实践:从安装到项目构建与IDEA集成

目录 1. Maven的概念 1.1 什么是Maven 1.2 什么是依赖管理 1.3 什么是项目构建 1.4 Maven的应用场景 1.5 为什么使用Maven 1.6 Maven模型 2.初识Maven 2.1 Maven安装 2.1.1 安装准备 2.1.2 Maven安装目录分析 2.1.3 Maven的环境变量 2.2 Maven的第一个项目 2.2.1…

深度学习Pytorch-Tensor函数

深度学习Pytorch-Tensor函数 Tensor的三角函数Tensor中其他的数学函数Tensor中统计学相关的函数(维度,对于二维数据:dim0 按列,dim1 按行,默认 dim1)Tensor的torch.distributions(分布函数)Tensor中的随机抽…

图论day62|拓扑排序理论基础、117.软件构建(卡码网)、最短路径之dijkstra理论基、47.参加科学大会(卡码网 第六期模拟笔试)

图论day62|拓扑排序理论基础、117.软件构建(卡码网)、最短路径之dijkstra理论基、47.参加科学大会(卡码网 第六期模拟笔试) 拓扑排序理论基础117.软件构建(卡码网)最短路径之dijkstra理论基础47.参加科学大…

IDEA 安装热部署 JRebel -新版-亲测有效

由于采用直接从idea 下载的插件会出现版本不适配,激活不成功 下载地址:https://note.youdao.com/web/#/file/recent/note/WEB0e3010b4015162dc6a11d6c0ab11f750/ 导入刚才下载的插件 其中,Team URL可以使用在线GUID地址在线生成GUID 拿到GUID…

Node.js 模块化

1. 介绍 1.1 什么是模块化与模块 ? 将一个复杂的程序文件依据一定规则(规范)拆分成多个文件的过程称之为 模块化其中拆分出的 每个文件就是一个模块 ,模块的内部数据是私有的,不过模块可以暴露内部数据以便其他模块使用 1.2 什…

蓝桥杯注意事项

蓝桥杯注意事项 比赛注意事项 能暴力枚举就暴力枚举,能用简单的思路做就尽量用简单的思路做。认真审核题目的题意和输入输出的要求,避免因为误解题意而导致题目错误。对于提供多组测试样例或者需要对一个过程重复进行循环的代码,要时刻记住…

第四范式发布AI Data Foundry,加速大模型训练及应用

产品上新 Product Release 今日,第四范式发布AI Data Foundry,提供基于AI技术,融合人类专家反馈的高质量、丰富可扩展、多样化的数据集,大幅提升模型效果。同时,通过模型评估系统及工具,对模型效果进行有效…

w外链如何跳转微信小程序

要创建外链跳转微信小程序,主要有以下几种方法: 使用第三方工具生成跳转链接: 注册并登录第三方外链平台:例如 “W外链” 等工具。前往该平台的官方网站,使用手机号、邮箱等方式进行注册并登录账号。选择创建小程序外…

windows SVN 忘记账号密码

一、本地登录过且记录未清空 1、打开C:\Users\用户名\AppData\Roaming\Subversion\auth\svn.simple目录 2、下载SvnPwd.exe文件 链接地址:TortoiseSVN Password Decrypter 复制SvnPwd.exe到 C:\Users\用户名\AppData\Roaming\Subversion\auth\svn.simple目录下 3、运…

Web组态-仪器间的相互通信(WebSocket技术)

Web组态,通过Vue3TypeScriptWebSocket技术实现平台仪器间的相互通信,用于设计工业化虚拟仿真。 界面图如下(之前文章有详细教学) 如下是通信设备虚拟仿真的三个仪器,设计初衷是想三个仪器能够数据互通,实现…

【Thymeleaf】spring boot模板引擎thymeleaf用法详解

快速入门Thymeleaf 1️⃣ 什么是Thymeleaf?1️⃣ 模板入门2️⃣ 创建测试工程2️⃣ 配置文件2️⃣ 创建controller2️⃣ 写一个html页面2️⃣ 启动测试 1️⃣ Thymeleaf基础2️⃣ 实体类2️⃣ 增加接口2️⃣ $符号使用2️⃣ *符号的使用2️⃣ 符号的使用2️⃣ #符号…

一文掌握异步web框架FastAPI(五)-- 中间件(测试环境、访问速率限制、请求体解析、自定义认证、重试机制、请求频率统计、路径重写)

接上篇:一文掌握异步web框架FastAPI(四)-CSDN博客 目录 七、中间件 15、测试环境中间件 16、访问速率限制中间件,即限制每个IP特定时间内的请求数(基于内存,生产上要使用数据库) 1)限制单ip访问速率 2)增加限制单ip并发(跟上面的一样,也是限制每个IP特定时间内的请…

??? 命令行形式的简单功能的计算器的Shell脚本

文章目录 需求编码Way1Way2: 测试 需求 需求分析: 支持浮点型:使用let命令 编码 Way1 用下循环吧! #!/bin/bash # Author: # Date: # Description:# functions defines: input_check_to_startup() {num1$1num2$2isNum_statu…

Node版本管理nvm

公司项目比较多,且有历史包袱,没时间升级,高版本的node无法在低版本项目中打包编译; 下载地址 gitHub地址 nvm-setup.zip:安装版,推荐使用 nvm-setup.exe 常用指令 // 查看版本信息 nvm -v // 查看能安装…

《线下学习受局限,知识付费小程序开启新篇》

在知识大爆炸的时代,人们对知识的渴望从未如此强烈。然而,传统的线下学习方式却逐渐显露出诸多局限。 线下学习往往受到时间和空间的严格限制。为了参加一场培训课程或者讲座,你可能需要在特定的时间赶到特定的地点,这对于忙碌的…

大数据-188 Elasticsearch - ELK 家族 Logstash Output 插件

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

基于开源Jetlinks物联网平台协议包-MQTT自定义主题数据的编解码

目录 前言 1.下载官方协议包 2.解压 3.自定义主题 4.重写解码方法 5.以下是我解析后接收到的数据 前言 最近这段时间,一直在用开源的Jetlinks物联网平台在学习,偶尔有一次机会接触到物联网设备对接,在协议对接的时候,遇到了…

400行程序写一个实时操作系统(十):用面向对象思想构建抢占式内核

前言 通过前几章的学习,我们学会了如何为RTOS设计一个合理的内存管理算法。现在,是时候学习设计RTOS内核了。 关于RTOS内核的文章也有很多,但都有一点先射箭再化靶子的意味。要么是代码连篇解释却寥寥无几,要么是要先怎么样再怎么…

【星闪开发连载】WS63E模块连接华为IoT云

目录 引言 WS63E对MQTT的支持 程序修改 测试结果 结语 引言 在上一篇博文中已经介绍了WiFi的使用。今天介绍一下如何使用MQTT协议连接到华为云上。 WS63E对MQTT的支持 WS63E的代码参考直接提供了MQTT的支持,文档介绍见docs/board/WS63V100 MQTT 开发指南.pd…

【STM32开发笔记】移植AI框架TensorFlow到STM32单片机【DSP指令加速篇】

【STM32开发笔记】移植AI框架TensorFlow到STM32单片机【DSP指令加速篇】 一、前文回顾二、CMSIS-NN简介2.1 为什么介绍CMSIS-NN?2.2 CMSIS-NN是什么?2.3 CMSIS-NN核心特性2.4 CMSIS-NN算子支持 三、TFLMCMSIS-NN集成3.1 包含TFLM的STM32项目3.2 理解TFLM…