当前位置:
首页 > Python基础教程 >
-
pytorch单元测试的实现示例
单元测试是一种软件测试方法,本文主要介绍了pytorch单元测试的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.
本文基于torch.testing._internal
一.公共模块[common.py]
import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0"
device="cpu"
device_type="cpu"
device_name="cpu"
try:
if torch.cuda.is_available():
device_name=torch.cuda.get_device_name().replace(" ","")
device="cuda:0"
device_type="cuda"
ccl_backend='nccl'
except:
pass
host_name=socket.gethostname()
sdk_version=os.getenv("SDK_VERSION","") #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data") #日志存放的目录
device_count=torch.cuda.device_count()
if not os.path.exists(metric_data_root):
os.makedirs(metric_data_root)
def device_warmup(device):
'''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
left = torch.rand([128,512], dtype = torch.float16).to(device)
right = torch.rand([512,128], dtype = torch.float16).to(device)
out=torch.matmul(left,right)
torch.cuda.synchronize()
torch.manual_seed(1)
np.random.seed(1)
def loop_decorator(loops,rank=0):
'''循环装饰器,用于统计函数的执行时间,内存占用等'''
def decorator(func):
def wrapper(*args,**kwargs):
latency=[]
memory_allocated_t0=torch.cuda.memory_allocated(rank)
for _ in range(loops):
input_copy=[x.clone() for x in args]
beg= datetime.now().timestamp() * 1e6
pred= func(*input_copy)
gt=kwargs["golden"]
torch.cuda.synchronize()
end=datetime.now().timestamp() * 1e6
mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
latency.append(end-beg)
memory_allocated_t1=torch.cuda.memory_allocated(rank)
avg_latency=np.mean(latency[len(latency)//2:]).round(3)
first_latency=latency[0]
return { "first_latency":first_latency,"avg_latency":avg_latency,
"memory_allocated":memory_allocated_t1-memory_allocated_t0,
"mse":mse}
return wrapper
return decorator
class TorchUtMetrics:
'''用于统计测试结果,比较之前的最小值'''
def __init__(self,ut_name,thresold=0.2,rank=0):
self.ut_name=f"{ut_name}_{rank}"
self.thresold=thresold
self.rank=rank
self.data={"ut_name":self.ut_name,"metrics":[]}
self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
try:
with open(self.metrics_path,"r") as f:
self.data=json.loads(f.read())
except:
pass
def __enter__(self):
self.beg= datetime.now().timestamp() * 1e6
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.report()
self.save_data()
def save_data(self):
with open(self.metrics_path,"w") as f:
f.write(json.dumps(self.data,indent=4))
def set_metrics(self,metrics):
self.end=datetime.now().timestamp() * 1e6
item=collections.OrderedDict()
item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
item["sdk_version"]=sdk_version
item["device_name"]=device_name
item["host_name"]=host_name
item["metrics"]=metrics
item["metrics"]["e2e_time"]=self.end-self.beg
self.cur_item=item
self.data["metrics"].append(self.cur_item)
def get_metric_names(self):
return self.data["metrics"][0]["metrics"].keys()
def get_min_metric(self,metric_name,devicename=None):
min_value=0
min_value_index=-1
for idx,item in enumerate(self.data["metrics"]):
if devicename and (devicename!=item['device_name']):
continue
val=float(item["metrics"][metric_name])
if min_value_index==-1 or val<min_value:
min_value=val
min_value_index=idx
return min_value,min_value_index
def get_metric_info(self,index):
metrics=self.data["metrics"][index]
return f'{metrics["device_name"]}@{metrics["sdk_version"]}'
def report(self):
assert len(self.data["metrics"])>0
for metric_name in self.get_metric_names():
min_value,min_value_index=self.get_min_metric(metric_name)
min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
cur_value=float(self.cur_item["metrics"][metric_name])
print(f"-------------------------------{metric_name}-------------------------------")
print(f"{cur_value}#{device_name}@{sdk_version}")
if min_value_index_same_dev>=0:
print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
if min_value_index>=0:
print(f"{min_value}#{self.get_metric_info(min_value_index)}")
二.普通算子测试[test_clone.py]
from common import *
class TestCaseClone(TestCase):
#如果不满足条件,则跳过这个测试
@unittest.skipIf(device_count>1, "Not enough devices")
def test_todo(self):
print(".TODO")
#框架会自动遍历以下参数组合
@parametrize("shape", [(10240,20480),(128,256)])
@parametrize("dtype", [torch.float16,torch.float32])
def test_clone(self,shape,dtype):
#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
@loop_decorator(loops=5)
def run(input_dev):
output=input_dev.clone()
return output
#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
input_dev=input_host.to(device)
metrics=run(input_dev,golden=input_host.cpu())
m.set_metrics(metrics)
assert(metrics["mse"]==0)
instantiate_parametrized_tests(TestCaseClone)
if __name__ == "__main__":
run_tests()
三.集合通信测试[test_ccl.py]
from common import *
class TestCCL(MultiProcessTestCase):
'''CCL测试用例'''
def _create_process_group_vccl(self, world_size, store):
dist.init_process_group(
ccl_backend, world_size=world_size, rank=self.rank, store=store
)
pg = dist.distributed_c10d._get_default_group()
return pg
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property
def world_size(self):
return 4
#框架会自动遍历以下参数组合
@unittest.skipIf(device_count<4, "Not enough devices")
@parametrize("op",[dist.ReduceOp.SUM])
@parametrize("shape", [(1024,8192)])
@parametrize("dtype", [torch.int64])
def test_allreduce(self,op,shape,dtype):
if self.rank >= self.world_size:
return
store = dist.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_vccl(self.world_size, store)
if not torch.distributed.is_initialized():
return
torch.cuda.set_device(self.rank)
device = torch.device(device_type,self.rank)
device_warmup(device)
#让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
@loop_decorator(loops=5,rank=self.rank)
def run(input_dev):
dist.all_reduce(input_dev, op=op)
return input_dev
#记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
gt_=gt[0]
for i in range(1,self.world_size):
gt_=gt_+gt[i]
input_dev=input_host.to(device)
metrics=run(input_dev,golden=gt_)
m.set_metrics(metrics)
assert(metrics["mse"]==0)
dist.destroy_process_group(pg)
instantiate_parametrized_tests(TestCCL)
if __name__ == "__main__":
run_tests()
四.测试命令
# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./
# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"
五.测试报告
在这里插入图片描述
到此这篇关于pytorch单元测试的实现示例的文章就介绍到这了,更多相关pytorch单元测试内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!
原文链接:https://blog.csdn.net/m0_61864577/article/details/137888050
栏目列表
最新更新
求1000阶乘的结果末尾有多少个0
详解MyBatis延迟加载是如何实现的
IDEA 控制台中文乱码4种解决方案
SpringBoot中版本兼容性处理的实现示例
Spring的IOC解决程序耦合的实现
详解Spring多数据源如何切换
Java报错:UnsupportedOperationException in Col
使用Spring Batch实现批处理任务的详细教程
java中怎么将多个音频文件拼接合成一个
SpringBoot整合ES多个精确值查询 terms功能实
SQL Server 中的数据类型隐式转换问题
SQL Server中T-SQL 数据类型转换详解
sqlserver 数据类型转换小实验
SQL Server数据类型转换方法
SQL Server 2017无法连接到服务器的问题解决
SQLServer地址搜索性能优化
Sql Server查询性能优化之不可小觑的书签查
SQL Server数据库的高性能优化经验总结
SQL SERVER性能优化综述(很好的总结,不要错
开启SQLSERVER数据库缓存依赖优化网站性能
uniapp/H5 获取手机桌面壁纸 (静态壁纸)
[前端] DNS解析与优化
为什么在js中需要添加addEventListener()?
JS模块化系统
js通过Object.defineProperty() 定义和控制对象
这是目前我见过最好的跨域解决方案!
减少回流与重绘
减少回流与重绘
如何使用KrpanoToolJS在浏览器切图
performance.now() 与 Date.now() 对比