pytest学习-pytorch单元测试

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" 

device="cpu"
device_type="cpu"
device_name="cpu"

try:
    if torch.cuda.is_available():     
        device_name=torch.cuda.get_device_name().replace(" ","")
        device="cuda:0"
        device_type="cuda"
        ccl_backend='nccl'
except:
    pass

host_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()

if not os.path.exists(metric_data_root):
    os.makedirs(metric_data_root)

def device_warmup(device):
    '''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
    left = torch.rand([128,512], dtype = torch.float16).to(device)
    right = torch.rand([512,128], dtype = torch.float16).to(device)
    out=torch.matmul(left,right)
    torch.cuda.synchronize()

torch.manual_seed(1) 
np.random.seed(1)

def loop_decorator(loops,rank=0):
    '''循环装饰器,用于统计函数的执行时间,内存占用等'''
    def decorator(func):
        def wrapper(*args,**kwargs):
            latency=[]
            memory_allocated_t0=torch.cuda.memory_allocated(rank)
            for _ in range(loops):
                input_copy=[x.clone() for x in args]
                beg= datetime.now().timestamp() * 1e6
                pred= func(*input_copy)
                gt=kwargs["golden"]
                torch.cuda.synchronize()
                end=datetime.now().timestamp() * 1e6
                mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                latency.append(end-beg)
            memory_allocated_t1=torch.cuda.memory_allocated(rank)
            avg_latency=np.mean(latency[len(latency)//2:]).round(3)
            first_latency=latency[0]
            return { "first_latency":first_latency,"avg_latency":avg_latency,
                      "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                      "mse":mse}
        return wrapper
    return decorator

class TorchUtMetrics:
    '''用于统计测试结果,比较之前的最小值'''
    def __init__(self,ut_name,thresold=0.2,rank=0):
        self.ut_name=f"{ut_name}_{rank}"
        self.thresold=thresold
        self.rank=rank
        self.data={"ut_name":self.ut_name,"metrics":[]}
        self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
        try:
            with open(self.metrics_path,"r") as f:
                self.data=json.loads(f.read())
        except:
            pass

    def __enter__(self):
        self.beg= datetime.now().timestamp() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):        
        self.report()
        self.save_data()

    def save_data(self):
        with open(self.metrics_path,"w") as f:
            f.write(json.dumps(self.data,indent=4))

    def set_metrics(self,metrics):
        self.end=datetime.now().timestamp() * 1e6
        item=collections.OrderedDict()
        item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        item["sdk_version"]=sdk_version
        item["device_name"]=device_name
        item["host_name"]=host_name
        item["metrics"]=metrics
        item["metrics"]["e2e_time"]=self.end-self.beg
        self.cur_item=item
        self.data["metrics"].append(self.cur_item)

    def get_metric_names(self):
        return self.data["metrics"][0]["metrics"].keys()

    def get_min_metric(self,metric_name,devicename=None):
        min_value=0
        min_value_index=-1
        for idx,item in enumerate(self.data["metrics"]):
            if devicename and (devicename!=item['device_name']):                
                continue            
            val=float(item["metrics"][metric_name])
            if min_value_index==-1 or val<min_value:
                min_value=val
                min_value_index=idx
        return min_value,min_value_index

    def get_metric_info(self,index):
        metrics=self.data["metrics"][index]
        return f'{metrics["device_name"]}@{metrics["sdk_version"]}'

    def report(self):
        assert len(self.data["metrics"])>0
        for metric_name in self.get_metric_names():
            min_value,min_value_index=self.get_min_metric(metric_name)
            min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
            cur_value=float(self.cur_item["metrics"][metric_name])
            print(f"-------------------------------{metric_name}-------------------------------")
            print(f"{cur_value}#{device_name}@{sdk_version}")
            if min_value_index_same_dev>=0:
                print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
            if min_value_index>=0:
                print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):
    #如果不满足条件,则跳过这个测试
    @unittest.skipIf(device_count>1, "Not enough devices") 
    def test_todo(self):
        print(".TODO")

    #框架会自动遍历以下参数组合
    @parametrize("shape", [(10240,20480),(128,256)])
    @parametrize("dtype", [torch.float16,torch.float32])
    def test_clone(self,shape,dtype):
        
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5)
        def run(input_dev):
            output=input_dev.clone()
            return output
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
            input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=input_host.cpu())
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        
instantiate_parametrized_tests(TestCaseClone)

if __name__ == "__main__":
    run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):
    '''CCL测试用例'''
    def _create_process_group_vccl(self, world_size, store):
        dist.init_process_group(
            ccl_backend, world_size=world_size, rank=self.rank, store=store
        )        
        pg = dist.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4
      
    #框架会自动遍历以下参数组合
    @unittest.skipIf(device_count<4, "Not enough devices") 
    @parametrize("op",[dist.ReduceOp.SUM])
    @parametrize("shape", [(1024,8192)])
    @parametrize("dtype", [torch.int64])
    def test_allreduce(self,op,shape,dtype):
        if self.rank >= self.world_size:
            return
        
        store = dist.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_vccl(self.world_size, store)
        if not torch.distributed.is_initialized():
            return
    
        torch.cuda.set_device(self.rank)
        device = torch.device(device_type,self.rank)
        device_warmup(device)
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5,rank=self.rank)
        def run(input_dev):
            dist.all_reduce(input_dev, op=op)
            return input_dev
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
            input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
            gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
            gt_=gt[0]
            for i in range(1,self.world_size):
                gt_=gt_+gt[i]
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=gt_)
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        dist.destroy_process_group(pg)
    
instantiate_parametrized_tests(TestCCL)

if __name__ == "__main__":
    run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./

# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/553132.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

sql知识总结二

一.报错注入 1.什么是报错注入&#xff1f; 这是一种页面响应形式&#xff0c;响应过程如下&#xff1a; 用户在前台页面输入检索内容----->后台将前台输入的检索内容无加区别的拼接成sql语句&#xff0c;送给数据库执行------>数据库将执行的结果返回给后台&#xff…

Java 集合(ArrayList、LinkedList、HashMap、HashSet、LinkedHashMap、LinkedHashSet)【补充复习】

Java 集合&#xff08;ArrayList、LinkedList、HashMap、HashSet、LinkedHashMap、LinkedHashSet&#xff09;【补充复习】 Java 集合概述Collection 接口继承树Map 接口继承树 Collection 接口方法使用 iterator 接口遍历集合元素使用 forearch 遍历集合元素 List 接口List 实…

媒体邀约的好处?怎么邀请媒体?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 媒体邀约的好处主要体现在提高品牌知名度、扩大受众群体以及与媒体建立良好的合作关系。 媒体邀约是一种有效的公关策略&#xff0c;通过吸引媒体关注来促进信息的传播。它可以帮助组织…

传统大数据架构与现代数据平台的期望——Lakehouse 架构(二)

文章目录 前言数据仓库数仓基础好处和优势限制和挑战 数据湖数据湖基础好处和优势限制和挑战 现代数据平台云数据湖与云数仓组合架构现代数据平台的期望Lakehouse 架构的出现未来数据平台的默认选择&#xff1f; 总结 前言 本文概述了传统数据架构&#xff1a;数据仓库和数据湖…

【Linux系列】Ctrl + R 的使用

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

python后端相关知识点汇总(十二)

python知识点汇总十二 1、什么是 C/S 和 B/S 架构2、count(1)、count(*)、count(列名)有啥区别&#xff1f;3、如何使用线程池3.1、为什么使用线程池&#xff1f; 4、MySQL 数据库备份命令5、supervisor和Gunicorn6、python项目部署6.1、entrypoint.sh制作6.2、Dockerfile制作6…

8.Jetson AGX Orin Ubuntu20.04 gRPC编译安装

Jetson AGX Orin Ubuntu20.04 gRPC编译安装 一、CMake版本检查 grpc编译cmake要求最低版本为3.15。首先&#xff0c;cmake -version 查看当前cmake版本&#xff0c;如果低于3.15&#xff0c;按照以下步骤进行安装。 1.1 卸载已经安装的旧版的CMake sudo apt-get autoremove…

Redmi Turbo 3新品发布,天星金融(原小米金融)优惠加持护航新机体验

Redmi新十年使命不变&#xff0c;挑战不断升级。Redmi Turbo 3&#xff0c;作为Turbo系列的开篇之作&#xff0c;将自身定位为新生代性能旗舰&#xff0c;决心重塑中端性能新格局。据悉&#xff0c;Redmi Turbo 3于4月10日已正式发布。预售期间更是连续数日&#xff0c;蝉联小米…

mac终端使用代理加速下载

环境变量增加前IP&#xff1a; 环境变量配置后&#xff0c;新打开一个终端的ip&#xff0c;开始享受极速吧~

【Python基础】MySQL

文章目录 [toc]创建数据库创建数据表数据插入数据查询数据更新 个人主页&#xff1a;丷从心 系列专栏&#xff1a;Python基础 学习指南&#xff1a;Python学习指南 创建数据库 import pymysqldef create_database():db pymysql.connect(hostlocalhost, userroot, passwordr…

【GIS教程】土地利用转移矩阵、土地利用面积变化

随着科技社会的不断进步&#xff0c;人类活动对地理环境的影响与塑造日益明显&#xff0c;土地不断的侵蚀与改变也导致一系列的环境问题日益突出。土地利用/覆盖&#xff08;LUCC&#xff09;作为全球环境变化研究的重点问题为越来越多的国际研究机构所重视&#xff0c;研究它的…

Python大数据分析——岭回归和LASSO回归模型

Python大数据分析——岭回归和LASSO回归模型 模型原因列数多于行数变量和变量间存在多重共线性 岭回归模型理论分析函数示例 LASSO回归模型理论分析函数示例 模型原因 我们为什么要有岭回归和LASSO回归呢&#xff1f; 因为根据线性回归模型的参数估计公式β(X’X)-1X’y可知&…

3DGS渐进式渲染 - 离线生成渲染视频

总览 输入&#xff1a;环绕Object拍摄的RGB视频 输出&#xff1a;自定义相机路径的渲染视频&#xff08;包含渐变效果&#xff09; 实现过程 首先&#xff0c;编译3DGS的C代码&#xff0c;并跑通convert.py、train.py和render.py。教程如下&#xff1a; github网址&#xf…

HarmonyOS开发实例:【分布式手写板】

介绍 本篇Codelab使用设备管理及分布式键值数据库能力&#xff0c;实现多设备之间手写板应用拉起及同步书写内容的功能。操作流程&#xff1a; 设备连接同一无线网络&#xff0c;安装分布式手写板应用。进入应用&#xff0c;点击允许使用多设备协同&#xff0c;点击主页上查询…

spring06:mybatis-spring(Spring整合MyBatis)

spring06&#xff1a;mybatis-spring&#xff08;Spring整合MyBatis&#xff09; 文章目录 spring06&#xff1a;mybatis-spring&#xff08;Spring整合MyBatis&#xff09;前言&#xff1a;什么是 MyBatis-Spring&#xff1f;MyBatis-Spring 会帮助你将 MyBatis 代码无缝地整合…

【VIC水文模型】准备工作:平台软件安装

VIC水文模型所需平台软件安装 1 Arcgis安装2 Cygwin安装&#xff08;Linux系统&#xff09;3 Matlab/R/Fortran的安装Notepad 4 VIC模型程序代码获取参考 由于VIC模型的编程语言为C语言&#xff0c;交互方式为控制台输指令&#xff0c;需要在Linux系统上运行。Windows 上使用 …

简述PDF原理和实践

Hello&#xff0c;我是小恒不会java。 由于最近有输出PDF报表的项目需求&#xff0c;所以复习一下PDF到底是什么&#xff0c;该如何产生&#xff0c;如何应用至项目中。 更多参见Adobe官方文档&#xff08;https://www.adobe.com/cn/&#xff09; PDF原理 PDF&#xff08;Port…

Docker应用推荐个人服务器实用有趣的项目推荐

Wallabag&#xff1a;是一个开源的、自托管的文章阅读和保存工具。它允许你保存网页文章并进行离线阅读&#xff0c;去除广告和不必要的内容&#xff0c;以提供更好的阅读体验。Wallabag支持多种导入和导出格式&#xff0c;并提供了一些实用的功能&#xff0c;如标签、阅读列表…

Flutter 像素编辑器#03 | 像素图层

theme: cyanosis 本系列&#xff0c;将通过 Flutter 实现一个全平台的像素编辑器应用。源码见开源项目 【pix_editor】 《Flutter 像素编辑器#01 | 像素网格》《Flutter 像素编辑器#02 | 配置编辑》《Flutter 像素编辑器#03 | 像素图层》 上一篇我们实现了编辑配置&#xff0c;…

【R语言】组合图:散点图+箱线图+平滑曲线图+柱状图

用算数运算符轻松组合不同的ggplot图&#xff0c;如图&#xff1a; 具体代码如下&#xff1a; install.packages("devtools")#安装devtools包 devtools::install_github("thomasp85/patchwork")#安装patchwork包 library(ggplot2) library(patchwork) #p1是…