pytorch 演示 tensor并行

pytorch 演示 tensor并行

  • 一.原理
  • 二.实现代码

本文演示了tensor并行的原理。如何将二个mlp切分到多张GPU上分别计算自己的分块,最后做一次reduce。
1.为了避免中间数据产生集合通信,A矩阵只能列切分,只计算全部batch*seqlen的部分feature
2.因为上面的步骤每张GPU只有部分feature,只因B矩阵按行切分,可与之进行矩阵乘,生成部分和
3.最后把每张GPU上的部分和加起来,就是最张的结果
以下demo,先实现了非分块的模型,然后模拟nccl分块,最后是分布式的实现

一.原理

在这里插入图片描述

二.实现代码

# torch_tp_demo.py
import os
import torch
from torch import nn
import torch.nn.functional as F 
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOp
  
import time
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')
parser.add_argument('--ffn_size', default=1024, type=int, help='')
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')
parser.add_argument('--device', default="cuda", type=str, help='')

class FeedForward(nn.Module): 

    def __init__(self,hidden_size,ffn_size): 
        super(FeedForward, self).__init__() 
        self.fc1 = nn.Linear(hidden_size, ffn_size,bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size,bias=False)

    def forward(self, input): 
        return self.fc2(self.fc1(input))

class FeedForwardTp(nn.Module):

    def __init__(self,hidden_size,ffn_size,tp_size,rank): 
        super(FeedForwardTp, self).__init__() 
        self.fc1 = nn.Linear(hidden_size, ffn_size//tp_size,bias=False)
        self.fc2 = nn.Linear(ffn_size//tp_size, hidden_size,bias=False)
      
        self.fc1.weight.data=torch.from_numpy(np.fromfile(f"fc1_{rank}.bin",dtype=np.float32)).reshape(self.fc1.weight.data.shape)
        self.fc2.weight.data=torch.from_numpy(np.fromfile(f"fc2_{rank}.bin",dtype=np.float32)).reshape(self.fc2.weight.data.shape)

    def forward(self, input): 
        return self.fc2(self.fc1(input))


args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.device

def native_mode():
  print(args)
  torch.random.manual_seed(1)
  model = FeedForward(hidden_size,ffn_size)
  model.eval()
  input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)

  for idx,chunk in enumerate(torch.split(model.fc1.weight, ffn_size//world_size, dim=0)):
      chunk.data.numpy().tofile(f"fc1_{idx}.bin")
  
  for idx,chunk in enumerate(torch.split(model.fc2.weight, ffn_size//world_size, dim=1)):
      chunk.data.numpy().tofile(f"fc2_{idx}.bin")
  
  model=model.half().to(device)
  

  usetime=[]
  for i in range(32):
    t0=time.time()    
    out = model(input)
    torch.cuda.synchronize()
    t1=time.time()
    if i>3:
      usetime.append(t1-t0)
  
  print("[INFO] native: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))
  

  result=[]
  for rank in range(world_size):
      model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)
      model.eval()
      out=model(input)
      torch.cuda.synchronize()
      result.append(out)
  
  sum_all=result[0]
  for t in result[1:]:
      sum_all=sum_all+t
  
  print("[INFO] tp_simulate: shape:{},sum:{:.5f}".format(sum_all.shape,sum_all.sum().item()))

def tp_mode():
  torch.random.manual_seed(1)
  dist.init_process_group(backend='nccl')
    
  world_size = torch.distributed.get_world_size()
  rank=rank = torch.distributed.get_rank()
  local_rank=int(os.environ['LOCAL_RANK'])
  
  torch.cuda.set_device(local_rank)
  device = torch.device("cuda",local_rank)
  
  input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)  
  model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)
  model.eval()
  if rank==0:
    print(args)
    
  usetime=[]
  for i in range(32):        
    dist.barrier()
    t0=time.time()
    out=model(input)
    #dist.reduce(out,0, op=ReduceOp.SUM) 
    dist.all_reduce(out,op=ReduceOp.SUM)
    torch.cuda.synchronize()
    if rank==0:
      t1=time.time()
      if i>3:
        usetime.append(t1-t0)
  
  if rank==0:
    print("[INFO] tp: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))


if __name__ == "__main__":
  num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
  is_distributed = num_gpus > 1
  if is_distributed:
    tp_mode()
  else:
    native_mode()

运行命令:

python3 torch_tp_demo.py --hidden_size 512 \
			--ffn_size 4096 --seq_len 512 \
			--batch_size 8 --world_size 4 --device "cuda"
torchrun -m --nnodes=1 --nproc_per_node=4 \
			torch_tp_demo --hidden_size 512 \
			--ffn_size 4096 --seq_len 512 \
			--batch_size 8 --world_size 4 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/523897.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

布隆过滤器详解及java实现

什么是布隆过滤器? 布隆过滤器(Bloom Filter)是一种数据结构,用于判断一个元素是否属于一个集合。它的特点是高效地判断一个元素是否可能存在于集合中,但是存在一定的误判率。 布隆过滤器的基本原理是使用一个位数组…

【STL学习】(4)vector的模拟

前言 本文将模拟实现vector的常用功能,目的在于更深入理解vector。 一、前置知识 在模拟之前先对vector的结构和常用接口学习,有一个大致了解。看源码,本文参考的源码是SGI版本的stl3.0。 技巧: 看源码不要一行一行的看&#xff…

Severt

severt是让我们自己写一些类,然后把这些类给加载Tomcat中,后续Tomcat收到HTTP请求(来自于浏览器),就会执行到咱们上面写的代码.从而通过这些代码,完成一定的业务逻辑. 创建项目 此处创建的是一种新的项目的形式称为Maven项目,Maven是Java 中的一个的构建…

libVLC 音频立体声模式切换

在libVLC中,可以使用libvlc_audio_set_channel函数来设置音频的立体声模式。这个函数允许选择不同的音频通道,例如立体声、左声道、右声道、环绕声等。 /*** Set current audio channel.** \param p_mi media player* \param channel the audio channel…

Datacom HCIP笔记-路由策略与路由控制 之二

路由策略和策略的区别? 路由策略: 操作的对象是路由表条目, 实现路由过滤,从而实现访问控制,引入时过滤,发送和接收路由时过滤。 通过配置cost,来实现路径的控制。 策略路由: 对…

【Python】还在用print进行调试,你Out了!!!

1. 引言 Python 中最常用的函数是什么?像在大多数编程语言中,print() 函数是最常用的。我相信大多数开发者都会像我一样,在开发过程中多次使用它将信息进行打印。 当然,没有其他方法可以完全取代print()函数。不过,当…

QA测试开发工程师面试题满分问答9: Python中内存管理的概念、原理、使用

概念原理 Python中的内存管理是由解释器自动处理的,它使用引用计数和垃圾回收机制来管理内存。以下是Python内存管理的一些关键概念、设计原理和最佳实践,以帮助您高效使用和管理内存: 引用计数:Python使用引用计数来追踪对象的引…

谷歌浏览器如何截全屏图片?

有时候想要截取浏览器全屏,谷歌浏览器自带截取全屏命令,操作步骤如下: 1、按住键盘的F12或者是空白处点击鼠标右键找到检查项 2、按住ctrlshiftp,会出现搜索框的界面 3、搜索框中输入screen,选中Capture full size scr…

项目架构MVC,DDD学习

写在前面 本文一起看下项目架构DDD,MVC相关的内容。 1:MVC 不管我们做什么项目,自己想想其实只是做了三件事,如下: 其实,这三件事完全在一个类中做完也可以可以正常把项目完成的,就像下面这…

Redis简介、常用命令

目录 一、关系数据库​​与非关系数据库​​ 1.1. 关系型数据库 1.2 非关系型数据库 1.3.关系数据库与非关系型数据库区别 1.3.1. 数据存储方式不同 1.3.2. 扩展方式不同 1.3.3.对事务性的支持不同 1.4.非关系型数据库产生背景 二、Redis 2.1.Redis简介 2.2.Redis的…

如何使用开源情报跟踪一个人?在线访问网站以及使用方法介绍

如何使用开源情报跟踪一个人?在线访问网站以及使用方法介绍。 开源情报(OSINT)是一门关于收集和分析公开可用信息的独特技艺,它致力于构建个人或团体的详尽档案。 这一过程中,信息搜集者会利用多元化的信息源&#xff…

火山方舟大模型服务平台调用Demo测试(豆包)

豆包得后台大模型支持为字节得火山方舟,所以想使用豆包的API,直接从这里就可以。 一、首先注册账号: 火山引擎-云上增长新动力 注册完成之后,控制台-账户-API访问密钥 二、找到API测试用例: Skylark-chat API调用…

Linux实验3 shell命令进阶

一:实验目的 学习Linux下的文件系统结构,了解最基本的Linux下的shell命令操作,例如ls, cd, cat等各种指令操作。 学习vim编辑器的使用方式,学习如何使用ssh连接远程服务器。 二:实验内容 1.利用ls命令查找…

记一次Debug与Release版程序输出不一致的问题解决

问题叙述: 在x86平台下无论Debug还是Release都没问题,而在arm平台下Debug版本程序无问题,Release版本程序(-O3编译)发现输出值不正确,怀疑值被篡改,于是在调用前后分别使用printf打印出参数值&…

pdf操作器(图片转文字、PDF转word、PDF拆分、图片jpg、png互转)

pdf操作器(不用联网图片转文字、PDF转word、PDF拆分、图片jpg、png互转)介绍目前该软件实现了以下功能 pdf转wordpdf拆分图片,图片导出在桌面的一个文件夹里图片合并为pdf压缩、转换图片格式(jpg和png)OCR图片转文字&…

Leetcode刷题-哈希表详细总结(Java)

哈希表 当我们想使⽤哈希法来解决问题的时候,我们⼀般会选择如下三种数据结构。 数组set (集合)map(映射) 当我们遇到了要快速判断⼀个元素是否出现集合⾥的时候,就要考虑哈希法。如果在做⾯试题⽬的时候…

【Frida】【Android】 10_爬虫之WebSocket协议分析

🛫 系列文章导航 【Frida】【Android】01_手把手教你环境搭建 https://blog.csdn.net/kinghzking/article/details/136986950【Frida】【Android】02_JAVA层HOOK https://blog.csdn.net/kinghzking/article/details/137008446【Frida】【Android】03_RPC https://bl…

鸿蒙HarmonyOS 与 Android 的NDK有什么不一样?

1. 序言 就像开发Android要用Android Studio一样,Android Studio(简称AS)其实是基于IDEAgradle插件android插件开发而来。 鸿蒙系统,你可以认为它和android有点像,但又是超越android的存在,除了手机&…

IO流

一、IO概述 1.什么是IO流? 存储和读取数据的解决方案l: inputo: output流∶像水流一样传输数据 2.IO流的作用? 用于读写数据(本地文件,网络) 3.IO流按照流向可以分类哪两种流? 输出流:程序 - > 文件 输入流:文件 - > 程…

30道python自动化测试面试题与答案汇总!

Python是不可或缺的语言,它的优美与简洁令人无法自拔,下面这篇文章主要给大家介绍了关于30道python自动化测试面试题与答案汇总的相关资料,需要的朋友可以参考下 1、什么项目适合做自动化测试? 关键字:不变的、重复的、规范的 1)任务测试明…