pytorch为自己的extension backend添加profiler功能

pytorch为自己的extension backend添加profiler功能

  • 1.参考文档
  • 2.your-extension-for-pytorch需要增加的代码
  • 3.pytorch demo及如何调整chrome trace json文件
  • 4.[可视化](https://ui.perfetto.dev/)

本文演示了pytorch如何为自己的extension backend添加profiler功能
背景介绍

  • 1.没有CNLight、Profiling AscendCL API、ROC Trace之类Profing功能,无法trace runtime,drive,kernel,也无法获取设备的metrics
  • 2.只有event功能,可以统计kernel耗时
  • 3.本文只是一种尝试,并不合理.
  • 4.torch原生的profiler框架,依赖kineto,kineto目前支持CUPTI和ROC Tracer,如果不修改torch源码,第三方设备不方便使用
  • 5.华为、寒武纪、habana都是采用torch.profile的接口形式及at::addThreadLocalCallback功能,但不依赖torch.profiler框架
    profing原始数据都是私有格式,并且修改TensorBoard的插件,可于可视化

实施步骤

  • 1.调用torch::profiler::impl::registerPrivateUse1Methods注册
  • 2.因为没有correlation ID去关联host api与kernel,因此export_chrome_trace出来的数据没有kernel信息
  • 3.获取prof.profiler.function_events里的数据,通过{ev.name}{ev.id}{ev.thread}拼成uuid与上面chrome trace中的events关联
  • 4.因为只有一个stream。可以根据Host lanuch时间、kernel耗时、launch latency(先验),推断出kernel的开始、结束时间,并用flow event进行关联(虽然并不准确)
  • 5.最后把kernel event以及flow event追加到chrome trace中

1.参考文档

  • ROC Tracer
  • CUPTI
  • 华为profiler_npu
  • Profiling AscendCL API
  • 寒武纪profile_mlu
  • 寒武纪CNLight
  • habana torch
  • intel_extension_for_pytorch
  • Make the kineto extendable for other runtime than CUD
  • pytorch_open_registration_example
  • rename_privateuse1_backend
  • Trace Event Format

2.your-extension-for-pytorch需要增加的代码

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
 
using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;
  
namespace torch {
namespace profiler {
namespace impl {
 
struct NPUMethods : public ProfilerStubs {
   void record(
        int* device,
        ProfilerVoidEventStub* event,
        int64_t* cpu_ns) const override
    {
      if (device) {
          TORCH_CHECK(xpurtGetDevice((uint32_t*)device));
      }
      xpurtEvent_t xpurt_event;
      TORCH_CHECK(xpurtEventCreate(&xpurt_event));
      *event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {
          TORCH_CHECK(xpurtEventDestroy(ptr));
      });
      auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());
      if (cpu_ns) {
          *cpu_ns = getTime();
      }
      TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream)); 
    } 
    float elapsed(
        const ProfilerVoidEventStub* event1_,
        const ProfilerVoidEventStub* event2_) const override
    {
 
        auto event1 = static_cast<xpurtEvent_t>(event1_->get());
        TORCH_CHECK(xpurtEventSynchronize(event1));
        auto event2 = static_cast<xpurtEvent_t>(event2_->get());
        TORCH_CHECK(xpurtEventSynchronize(event2));
        int64_t time_ms = 0;
        TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));
        return time_ms*1.0;
    } 
    void onEachDevice(std::function<void(int)> op) const override
    {
        uint32_t device = 0;
        TORCH_CHECK(xpurtGetDevice(&device));
        op(device);
    } 
    void synchronize() const override { } 
    bool enabled() const override {return true;} 
    void mark(const char*name) const override { } 
    void rangePush(const char*name) const override { } 
    void rangePop() const override {}
};
 
struct RegisterNPUMethods {
    RegisterNPUMethods()
    {
        static NPUMethods methods;
        torch::profiler::impl::registerPrivateUse1Methods(&methods);
    }
};
RegisterNPUMethods reg;
}}}

3.pytorch demo及如何调整chrome trace json文件

import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
import tqdm

def is_valid_kernel(name,duration,valid_kernel_threshold=100):
    '''通过算子的名字和耗时判断是否是Device Kernel'''
    invalid_kernels=["aten::view","aten::reshape",
                    "aten::t","aten::empty",
                    "aten::transpose",
                    "aten::as_strided",
                    "aten::item",
                    "aten::_local_scalar_dense",
                    "aten::result_type",
                    "aten::_unsafe_view",
                    "aten::expand"]
    for k in invalid_kernels:
        if name.find(k)>=0:
            return False
    if duration<valid_kernel_threshold:
        return False    
    return True

def filter_ev(ev):
    '''过滤Kernel'''
    if 'args' in ev and "External id" in ev['args']:
        return True
    return False

def get_uuid(ev,tid_map):
    return f"{ev['name']}_{ev['args']['External id']}_{tid_map[ev['tid']]}"

def get_valid_kernels(traceEvents,kernel_event,tid_map):
    valid_kernels=[]
    device_memory_usage=0
    for ev in traceEvents:
        if filter_ev(ev):
            uuid=get_uuid(ev,tid_map)
            if uuid not in kernel_event:
                continue
            duration=kernel_event[uuid]['kernel_time']
            kernel_name=ev['name']
            if kernel_event[uuid]['device_memory_usage']>0:
                device_memory_usage=kernel_event[uuid]['device_memory_usage']
            if is_valid_kernel(kernel_name,duration):
                launch_beg=ev['ts']
                launch_end=ev['ts']+ev['dur']            
                valid_kernels.append({"name":kernel_name,
                                      "launch_beg":launch_beg,
                                      "launch_end":launch_end,
                                      "kernel_duration":duration,
                                      "host_pid":ev['pid'],
                                      "host_tid":ev['tid'],
                                      "device_memory_usage":device_memory_usage,
                                      "is_leaf_kernel":False})
                                      
    return sorted(valid_kernels,key=lambda x:x['launch_beg'])
    
def is_leaf_kernel(kernel,valid_kernels):
    '''判断是否是叶子Kernel'''
    ret=True
    for k in valid_kernels:
        if k['is_leaf_kernel']:
            continue
        #自己的时间跨度内还有别的Kernel
        if k['launch_beg']>kernel['launch_beg'] and k['launch_end']<kernel['launch_end']:
            ret=False
            break
    return ret

def create_tid_map(traceEvents):
    tids=set()
    for ev in traceEvents:
        if filter_ev(ev):
            tid=ev['tid']
            tids.add(tid)
    tid_map={}
    tids=sorted(tids,reverse=False)
    for i,v in enumerate(tids):
        tid_map[v]=i+1
    return tid_map
                                      
def merge_prof_timeline(prof_json,kernel_event_json,output_json):
    
    kernel_lanuch_latency=0
    with open(prof_json,'r',encoding='utf-8') as f:
        prof = json.load(f)

    with open(kernel_event_json,'r',encoding='utf-8') as f:
        kernel_event = json.load(f)   
    
    traceEvents=prof['traceEvents']
    tid_map=create_tid_map(traceEvents)
    print(tid_map)
    #获取所有kernel
    valid_kernels=get_valid_kernels(traceEvents,kernel_event,tid_map)
    print(len(valid_kernels))
    #筛出所有会在device上执行的kernel
    on_device_kernels=[]
    for kernel in tqdm.tqdm(valid_kernels):
        if is_leaf_kernel(kernel,valid_kernels):
            on_device_kernels.append(kernel)
    
    kernel_start_offset=0
    kernel_index=0

    for kernel in on_device_kernels:
        name=kernel['name']
        kernel_duration=kernel["kernel_duration"]
        lanuch_time=kernel["launch_beg"]
        host_pid=kernel['host_pid']
        host_tid=kernel['host_tid']
        device_memory_usage=kernel['device_memory_usage']
        
        if kernel_start_offset==0:
            kernel_start_offset=lanuch_time+kernel_start_offset
            
        if lanuch_time>kernel_start_offset: #kernel 队列空闲
            kernel_start_offset=lanuch_time
        
        #增加kernel事件
        traceEvents.append({"ph": "X", "cat": "device_kernel", "name":name, "pid": 10, "tid": 10,"ts": kernel_start_offset, "dur": kernel_duration})
        
        #增加内存事件
        traceEvents.append({"ph": "C", "cat": "memory", "name":"memory", "pid": 11, "tid": 11,"ts": lanuch_time, "args": {"value":device_memory_usage}})
        
        #增加flow event
        traceEvents.append({"ph": "s", "id": kernel_index, "pid": host_pid, "tid": host_tid, "ts": lanuch_time,"cat": "ac2g", "name": "ac2g"})
        traceEvents.append({"ph": "f", "id": kernel_index, "pid": 10,  "tid": 10,"ts": kernel_start_offset,"cat": "ac2g", "name": "ac2g", "bp": "e"})
        
        kernel_index+=1
        kernel_start_offset+=(kernel_duration+kernel_lanuch_latency)
    
    #保存最终的结果
    with open(output_json,'w',encoding='utf-8') as f:
        json.dump(prof, f,ensure_ascii=False,indent=4)
		
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
 
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
 
    def forward(self,query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = query@key.transpose(-2,-1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e20)
        p_attn = F.softmax(scores, dim = -1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn@value, p_attn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()
 
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query=self.linears[0](query).view(nbatches, -1, self.h, self.d_k)
        query=query.transpose(1, 2)
        key=self.linears[1](key).view(nbatches, -1, self.h, self.d_k)
        key=key.transpose(1, 2)
        value=self.linears[2](value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        x, self.attn = self.attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
 
use_cuda=True
try:
    import torch_xpu
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    torch.profiler.ProfilerActivity.PrivateUse1="xpu"
    use_cuda=False
except:
    pass
 
import os
os.environ['LOCAL_RANK']="0"
os.environ['RANK']="0"
os.environ['WORLD_SIZE']="1"
os.environ['MASTER_ADDR']="localhost"
os.environ['MASTER_PORT']="6006"

import torch.distributed as dist
dist.init_process_group(backend='vccl')
local_rank=int(os.environ['LOCAL_RANK'])
rank=torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")
 
cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64),dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()
out = cross_attn.forward(q1,k1,v1).sum()
torch.cuda.synchronize()
 
activities=[torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)
 
with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=3,
                repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
   ) as prof:
        for i in range(10):
            out = cross_attn.forward(q1,k1,v1).sum()
            prof.step()
        torch.cuda.synchronize()
 
if not use_cuda:
    kernel_event={}
    for ev in prof.profiler.function_events:
        if ev.privateuse1_time>0:
            uuid=f"{ev.name}_{ev.id}_{ev.thread}"
            #print(uuid,ev.id,ev.name,ev.privateuse1_time,ev.time_range.start,ev.time_range.end-ev.time_range.start,ev.privateuse1_memory_usage)
            kernel_event[uuid]={"kernel_time":ev.privateuse1_time,
								"device_memory_usage":ev.privateuse1_memory_usage,
								"start_us":ev.time_range.start,
								"host_dur":ev.time_range.end-ev.time_range.start,
								"thread":ev.thread} 
    import json
    with open(f"kernel_event_{rank}.json",'w',encoding='utf-8') as f:
        json.dump(kernel_event, f,ensure_ascii=False,indent=4)

    prof.export_chrome_trace(f"prof_{rank}.json")
    merge_prof_timeline(f"prof_{rank}.json",f"kernel_event_{rank}.json",f"prof_{rank}.json")
else:
    #print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")

4.可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/748985.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Verilog进行结构描述(二):Verilog基本单元(primitives)

目录 1.Verilog基本单元2.基本单元的引脚 (pin)的可扩展性3.带条件的基本单元4.基本单元实例化 微信公众号获取更多FPGA相关源码&#xff1a; 1.Verilog基本单元 Verilog基本单元提供基本的逻辑功能&#xff0c;也就是说这些逻辑功能是预定义的&#xff0c;用户不需要再定义…

后端之路第三站(Mybatis)——JDBC跟Mybatis、lombok

一、什么是JDBC JDBC就是sun公司研发的一套通过java来操控数据库的工具&#xff0c;对应不同的数据库系统有不同的JDBC&#xff0c;而他们统称【驱动】&#xff0c;这就是上一篇我们提到创建Mybatis项目时要引入的依赖、以及连接数据库四要素里的第一要素。 JDBC有自己一套原始…

应用案例 | 如何监测高价值货物在物流运输过程中受到的振动和冲击?全面保障货物安全

一、货物运输 不同种类的货物对运输的要求不同&#xff0c;钢铁、煤炭、矿石等大宗物资通常对运输要求较低&#xff0c;而电子产品、IT 产品、家电等高价值敏感类货物则更强调运输的安全性和时效性&#xff0c;往往希望能尽可能安全和快速送达这类货物&#xff0c;使之尽快进入…

使用nvm命令进行node和npm版本下载以及切换

下载以及安装nvm方式 https://blog.csdn.net/ppz8823/article/details/130862191 1.查看nvm版本 nvm -v2.查看node 和 npm版本 node -v npm -v3.使用nvm查看已下载的node版本 nvm ls4.使用nvm 查看可使用的在线node版本 nvm list available4.下载想要使用的node版本&#x…

探索 LLMs 在数据标注中的应用潜力:观察、思考与前景展望

编者按&#xff1a; 目前&#xff0c;LLMs 在机器翻译、文本生成、多轮问答等任务上已表现得非常出色了。人们开始思考它们是否也可以用于数据标注工作。数据标注是训练和评估各种机器学习模型的基础&#xff0c;一直是一项昂贵且耗时的工作。是否能够借助 LLMs 的强大能力来为…

OpenAI将终止对中国提供服务,国内模型接棒

说起来&#xff0c;OpenAI自始至终就没有对中国提供过服务&#xff0c;OpenAI官方支持的国家和地区&#xff1a;https://platform.openai.com/docs/supported-countries 列表里面没有“Chinese”的选项&#xff0c;那为什么又要明令禁止呢&#xff0c;国类IT高手们&#xff0…

来给大家推荐得10个有效磁力导航链接(好用搜资料找资源)

都2024现在网上找资源像流水得鱼一样&#xff0c;抓一大把结果很难吃&#xff0c;我通宵特意整理的网站&#xff0c;网上有许多磁力导航网站可以提供海量的磁力链接资源&#xff0c;以下是一些有效的磁力导航网站推荐&#xff1a; 磁力搜索 网站地址&#xff1a;www.chiliso…

我国氮化硼市场规模逐渐扩大 市场集中度有望不断提升

我国氮化硼市场规模逐渐扩大 市场集中度有望不断提升 氮化硼&#xff08;BN&#xff09;俗称为白石墨&#xff0c;是由硼原子和氮原子所构成的一种晶体材料&#xff0c;在常温条件下多表现为一种棕色或暗红色晶体。氮化硼具有导热性好、硬度大、熔点高、抗化学侵蚀性等优点&…

flex 与 overflow 冲突

问题场景&#xff1a; 父盒子高度会变化&#xff0c;可能会比子盒子大&#xff0c;也可能会比子盒子小。 比子盒子大的时候&#xff0c;希望子盒子垂直居中&#xff1b;比子盒子小的时候&#xff0c;能够正常滚动&#xff1b; <body><div class"outer">…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第49课-机器人自动跳舞

【WEB前端2024】3D智体编程&#xff1a;乔布斯3D纪念馆-第49课-机器人自动跳舞 使用dtns.network德塔世界&#xff08;开源的智体世界引擎&#xff09;&#xff0c;策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智体世界引擎…

vscode下无法识别node、npm的问题

node : 无法将“node”项识别为 cmdlet、函数、脚本文件或可运行程序的名称 因为node是在cmd安装的&#xff0c;是全局安装的&#xff0c;并不是在这个项目里安装的。 解决方案&#xff1a; 1.在vscode的控制台&#xff0c;针对一个项目安装特定版本的node&#xff1b; 2.已经…

SAP标准报表 S_ALR_8701XXXX是没有export to excel的 或者禁用 %PC也禁用了,如何开回来

以 s_alr_87012172为例子 系统-状态 进入程序 搜索 XXL 做下替换

手动将jar包导入本地Maven仓库

1、进入存放jar包的目录&#xff0c;可以先放进仓库底下 2、cmd回车 3、执行命令&#xff0c;看到BUILD SUCCESS就是成功了 -DgroupId、-DartifactId、-Dversion、-Dfile记得换成自己对应的 mvn install:install-file -DgroupIdcom.github.03 -DartifactIdonvif -Dversion1.0…

docker部署ClamAV集成java和python实现文件病毒扫描

介绍 官方文档&#xff1a;https://docs.clamav.net/manual/Signatures/DatabaseInfo.html ClamAV 是一个开源的反病毒引擎&#xff0c;它由多个模块组成&#xff0c;负责不同的任务处理。以下是 ClamAV 的主要模块和它们的功能&#xff1a; clamd&#xff1a;clamd 是 Clam…

# Kafka_深入探秘者(1):初识 kafka

Kafka_深入探秘者&#xff08;1&#xff09;&#xff1a;初识 kafka 一、kafka 特性 1、Kafka &#xff1a;最初是由 Linkedln 公司采用 Scala 语言开发的一个多分区、多副本并且基于 ZooKeeper 协调的分布式消息系统&#xff0c;现在已经捐献给了 Apache 基金会。目前 Kafka…

保姆级本地部署Qwen2

重点&#xff1a;Qwen2提供了CPU与GPU两种运行方式 运行成功效果图&#xff1a; 前提说明&#xff1a;如果需要用GPU&#xff0c;那么请在物理机安装ubuntu系统&#xff0c;不然显卡驱动很难安装&#xff0c;不建议新手部署。训练微调模型需要用到GPU。本文仅以ubuntu系统演示…

Todesk远程连接Ubuntu卡100%,以及小窗口打不开

Todesk远程连接Ubuntu卡100%&#xff0c;以及小窗口打不开 使用Todesk远程连接Ubuntu一直卡100%进不去还有todesk里的小悬浮窗打开就会小时&#xff08;小下拉框会消失&#xff09; 使用Todesk远程连接Ubuntu一直卡100%进不去 还有todesk里的小悬浮窗打开就会小时&#xff08;小…

梗图生成器突然爆红;ElevenLabs发布IOS APP 高质量语音朗读手机各种文本内容;开源工作流架构ControlFlow

✨ 1: 梗图生成器 fabianstelzer 在Glif做的一个超强meme生成器 Glif 是一个工作流&#xff0c;能生成文字图片和视频&#xff0c;用工作流的形式可以完成很多的花样来。 最近爆红的梗图生成器&#xff0c;WOJAK MEME GENERATOR &#xff0c;也是用工作流的形式来生成这些有…

防坑知识:如果要查自己的大数据信用报告,这几种平台一定不要选!

很多小伙伴在候遇到申贷碰壁&#xff0c;特别是被告知原因是大数据不良之后&#xff0c;都急着去了解自己的大数据信用情况&#xff0c;常见的方式就是在百度搜索大数据信用&#xff0c;大数据报告查询&#xff0c;哪里能查大数据信用等关键词&#xff0c;随便找一个地方就去查…

JavaScript的学习之图片的切换

目录 一、寻找素材 二、编写简单的静态html页面 代码示例 效果展示 三、JS功能的实现 JS代码 完整代码 效果展示 一、寻找素材 随便去网上找几张图片素材 二、编写简单的静态html页面 代码示例 <!doctype html> <html><head><meta charset"…