【教程】DGL单机多卡分布式GCN训练

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

        PyTorch中的DDP会将模型复制到每个GPU中。

        梯度同步默认使用Ring-AllReduce进行,重叠了通信和计算。

        示例代码:

视频:https://youtu.be/Cvdhwx-OBBo

代码:multigpu.py

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

import dgl
from dgl.data import RedditDataset
from dgl.nn.pytorch import GraphConv


def ddp_setup(rank, world_size):
    """
    DDP初始化设置。
    
    参数:
        rank (int): 当前进程的唯一标识符。
        world_size (int): 总进程数。
    """
    os.environ["MASTER_ADDR"] = "localhost"  # 设置主节点地址
    os.environ["MASTER_PORT"] = "12355"      # 设置主节点端口
    init_process_group(backend="nccl", rank=rank, world_size=world_size)  # 初始化进程组
    torch.cuda.set_device(rank)  # 设置当前进程使用的GPU设备


class GCN(torch.nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        """
        初始化图卷积网络(GCN)。
        
        参数:
            in_feats (int): 输入特征的维度。
            h_feats (int): 隐藏层特征的维度。
            num_classes (int): 输出类别的数量。
        """
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)  # 第一层图卷积
        self.conv2 = GraphConv(h_feats, num_classes)  # 第二层图卷积

    def forward(self, g, in_feat):
        """
        前向传播。
        
        参数:
            g (DGLGraph): 输入的图。
            in_feat (Tensor): 输入特征。
        
        返回:
            Tensor: 输出的logits。
        """
        h = self.conv1(g, in_feat)  # 进行第一层图卷积
        h = F.relu(h)  # ReLU激活
        h = self.conv2(g, h)  # 进行第二层图卷积
        return h


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        """
        初始化训练器。
        
        参数:
            model (torch.nn.Module): 要训练的模型。
            train_data (DataLoader): 训练数据的DataLoader。
            optimizer (torch.optim.Optimizer): 优化器。
            gpu_id (int): GPU ID。
            save_every (int): 每隔多少个epoch保存一次检查点。
        """
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)  # 将模型移动到指定GPU
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(model, device_ids=[gpu_id])  # 使用DDP包装模型

    def _run_batch(self, batch):
        """
        运行单个批次。
        
        参数:
            batch: 单个批次的数据。
        """
        self.optimizer.zero_grad()  # 梯度清零
        graph, features, labels = batch
        graph = graph.to(self.gpu_id)  # 将图移动到GPU
        features = features.to(self.gpu_id)  # 将特征移动到GPU
        labels = labels.to(self.gpu_id)  # 将标签移动到GPU
        output = self.model(graph, features)  # 前向传播
        loss = F.cross_entropy(output, labels)  # 计算交叉熵损失
        loss.backward()  # 反向传播
        self.optimizer.step()  # 更新模型参数

    def _run_epoch(self, epoch):
        """
        运行单个epoch。
        
        参数:
            epoch (int): 当前epoch号。
        """
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Steps: {len(self.train_data)}")
        for batch in self.train_data:
            self._run_batch(batch)  # 运行每个批次

    def _save_checkpoint(self, epoch):
        """
        保存训练检查点。
        
        参数:
            epoch (int): 当前epoch号。
        """
        ckp = self.model.module.state_dict()  # 获取模型的状态字典
        PATH = "checkpoint.pt"  # 定义检查点路径
        torch.save(ckp, PATH)  # 保存检查点
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        """
        训练模型。
        
        参数:
            max_epochs (int): 总训练epoch数。
        """
        for epoch in range(max_epochs):
            self._run_epoch(epoch)  # 运行当前epoch
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)  # 保存检查点


def load_train_objs():
    """
    加载训练所需的对象:数据集、模型和优化器。
    
    返回:
        tuple: 数据集、模型和优化器。
    """
    data = RedditDataset(self_loop=True)  # 加载Reddit数据集,并添加自环
    graph = data[0]  # 获取图
    train_mask = graph.ndata['train_mask']  # 获取训练掩码
    features = graph.ndata['feat']  # 获取特征
    labels = graph.ndata['label']  # 获取标签

    model = GCN(features.shape[1], 128, data.num_classes)  # 初始化GCN模型
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # 初始化优化器
    train_data = [(graph, features, labels)]  # 准备训练数据
    
    return train_data, model, optimizer


def prepare_dataloader(dataset, batch_size: int):
    """
    准备DataLoader。
    
    参数:
        dataset: 数据集。
        batch_size (int): 批次大小。
    
    返回:
        DataLoader: DataLoader对象。
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        collate_fn=lambda x: x[0]  # 自定义collate函数,解包数据集中的单个元素
    )


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    """
    主训练函数。
    
    参数:
        rank (int): 当前进程的唯一标识符。
        world_size (int): 总进程数。
        save_every (int): 每隔多少个epoch保存一次检查点。
        total_epochs (int): 总训练epoch数。
        batch_size (int): 批次大小。
    """
    ddp_setup(rank, world_size)  # DDP初始化设置
    dataset, model, optimizer = load_train_objs()  # 加载训练对象
    train_data = prepare_dataloader(dataset, batch_size)  # 准备DataLoader
    trainer = Trainer(model, train_data, optimizer, rank, save_every)  # 初始化训练器
    trainer.train(total_epochs)  # 开始训练
    destroy_process_group()  # 销毁进程组


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Simple distributed training job')
    parser.add_argument('--total_epochs', default=50, type=int, help='Total epochs to train the model')
    parser.add_argument('--save_every', default=10, type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=8, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    world_size = torch.cuda.device_count()  # 获取可用GPU的数量
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)  # 启动多个进程进行分布式训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/702470.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【免费Web系列】大家好 ,今天是Web课程的第十九天点赞收藏关注,持续更新作品 !

1. Vue工程化 前面我们在介绍Vue的时候,我们讲到Vue是一款用于构建用户界面的渐进式JavaScript框架 。(官方:Vue.js - 渐进式 JavaScript 框架 | Vue.js) 那在前面的课程中,我们已经学习了Vue的基本语法、表达式、指令…

MapperStruct拷贝数据的介绍和使用

1、前言 在java 编程中,对象直接拷贝是很常用的方法,最初我们常用spring提供的拷贝工具BeanUtils的copyProperties方法完成对象之间属性的拷贝。但是它有几个明显的如下缺点 1、属性类型不一致导致摸一个属性值拷贝失败 2、通一个字段使用基本类型和包…

Mybatis plus join 一对多对象语法

1. 实体类环境 题目 package co.yixiang.exam.entity;import co.yixiang.domain.BaseDomain; import co.yixiang.exam.config.CustomStringListDeserializer; import com.baomidou.mybatisplus.annotation.TableField; import com.fasterxml.jackson.annotation.JsonCreator;…

使用Python爬取temu商品与评论信息

【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Python与爬虫领域研究与开发工作! 【&…

Pytorch--Convolution Layers

文章目录 1.nn.Conv1d2.torch.nn.Conv2d()3.torch.nn.ConvTranspose1d()3.torch.nn.ConvTranspose2d() 1.nn.Conv1d torch.nn.Conv1d() 是 PyTorch 中用于定义一维卷积层的类。一维卷积层常用于处理时间序列数据或具有一维结构的数据。 构造函数 torch.nn.Conv1d() 的语法如…

【运维自动化-配置平台】如何使用云资源同步功能(腾讯云为例)

云资源同步是通过apikey去单向同步云上的主机资源和云区域信息,目前支持腾讯云和亚马逊云。主要特性 1、蓝鲸配置平台周期性的单向只读同步云主机和vpc(对应蓝鲸云区域)信息,第一次全量,后面增量 2、默认同步到主机池…

kotlin 中的数字

以下均来自官方文档: 一、整数类型 1、kotlin中内置的整数类型,有四种不同大小的类型: 类型存储大小(比特数)最小值最大值Byte8-128127Short16-3276832767Int32-2,147,483,648 (-231)2,147,483,647 (231 - 1)Long64…

图片导入AutoCAD建立草图—CAD图像导入插件

插件介绍 CAD图像导入插件可将PNG,JPG等格式图片导入到AutoCAD软件内建立图像边缘的二维线条模型。插件可以提取图像黑色或白色区域的边界,并可绘制原状边界或平滑边界两种样式。 模型说明 边界提取,黑色或白色边界的提取根据原图类型选择…

c#调用c++dll方法

添加dll文件到debug目录,c#生成的exe的相同目录 就可以直接使用了,放在构造函数里面测试

排序的时间复杂度、空间复杂度和稳定性等的比较

时间复杂度和空间复杂度我们比较熟悉,重点来看一下稳定性。 稳定性是指假定在待排序的记录序列中,存在多个具有相同的关键字的记录,若经过排序,这些记录的相对次序保持不变,即在原序列中,a[i] a[j] &…

Golang 百题(实战快速掌握语法)_1

整形转字符串类型 实验介绍 本实验将展示三种方法来实现整形类型转字符串类型。 知识点 strconvfmt Itoa 函数 代码实例 Go 语言中 strconv 包的 itoa 函数输入一个 int 类型,返回转换后的字符串。下面是一个例子。 package mainimport ("fmt"&qu…

跟TED演讲学英文:Toward a new understanding of mental illness by Thomas Insel

Toward a new understanding of mental illness Link: https://www.ted.com/talks/thomas_insel_toward_a_new_understanding_of_mental_illness Speaker: Thomas Insel Date: January 2013 文章目录 Toward a new understanding of mental illnessIntroductionVocabularySum…

【C语言】联合(共用体)

目录 一、什么是联合体 二、联合类型的声明 三、联合变量的创建 四、联合的特点 五、联合体大小的计算 六、联合的应用(判断大小端) 七、联合体的优缺点 7.1 优点 7.2 缺点 一、什么是联合体 联合也是一种特殊的自定义类型。由多个不同类型的数…

测长仪的发展历程!

测长仪的发展历程可以大致分为以下几个阶段: 早期发展: 最早的测量工具主要是一些机械式测量工具,如角尺、卡钳等。 16世纪,在火炮制造中已开始使用光滑量规。 1772年和1805年,英国的J.瓦特和H.莫兹利等先后制造出利用…

Win快速删除node_modules

在Windows系统上删除 node_modules 文件夹通常是一个缓慢且耗时的过程。这主要是由于几个关键因素导致的: 主要原因 文件数量多且嵌套深: node_modules 文件夹通常包含成千上万的子文件夹和文件。由于其结构复杂,文件和文件夹往往嵌套得非常…

XXL-JOB分布式任务调度快速入门

文章目录 概念快速启动XXL-JOB调度初始化执行器项目配置执行器新增GLUE模式(Java)的任务新增BEAN模式(类形式)的任务BEAN模式(方法形式)的任务参考来源 概念 XXL-JOB是一个开源的分布式任务调度平台,它是一个轻量级、…

使用B树实现员工(人事)管理系统

1. 前言 使用B树来表示人事管理系统,其中每个节点代表一个人员,树的根节点为董事长,每个节点可以有多个子节点,表示下属。每一层代表一个等级分布。 addPerson: 添加人员功能通过查找指定上司节点,然后将新的人员作…

程序员/码农创业有多少种可能?

程序员创业,无疑是当下科技浪潮中的一股强大力量。凭借扎实的技术功底和敏锐的市场洞察力,在创业道路上展现出了无限的活力和创造力。那么,程序员创业究竟有哪些事情可以做呢?可以从技术产品的研发入手。 可以利用自己的专业知识…

分析GIS在疾病传播模型和公共卫生决策中的作用

在这个全球化日益加深的时代,疾病的跨国界传播成为全球公共卫生面临的重大挑战。地理信息科学(GIS)作为一门集成了空间数据采集、处理、分析及可视化的技术体系,在公共健康领域展现出其不可替代的价值。本文旨在深入探讨GIS如何助…

电动两轮车——电源方案

随着城镇化的发展人们的活动半径不断变宽,短交通出行方式仍能覆盖主要的范围。从主要国家核心地区的出行数据看平均通勤半径不高于15km,摩托车、电动两轮车等两轮出行方式能更好匹配日常短交通出行需求。 应用框图 通常,电动两轮车由三部分…