Skip to content

Commit b7893d6

Browse files
Merge pull request #67 from PanZezhong1725/issue/66
issue/66: 重构7个算子的测试脚本 (refactor the test scripts of 7 operators)
2 parents 3165aba + 642e8de commit b7893d6

File tree

6 files changed

+556 additions, -649 deletions (lines changed)

test/infiniop/causal_softmax.py

Lines changed: 85 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,47 @@
1-
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
1+
import torch
22
import ctypes
3-
import sys
4-
import os
5-
6-
7-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8-
from operatorspy import (
9-
open_lib,
10-
to_tensor,
11-
DeviceEnum,
3+
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
4+
from libinfiniop import (
125
infiniopHandle_t,
136
infiniopTensorDescriptor_t,
14-
create_handle,
15-
destroy_handle,
7+
open_lib,
8+
to_tensor,
9+
get_test_devices,
1610
check_error,
17-
rearrange_tensor,
11+
rearrange_if_needed,
1812
create_workspace,
13+
test_operator,
14+
get_args,
15+
debug,
16+
get_tolerance,
17+
profile_operation,
1918
)
2019

21-
from operatorspy.tests.test_utils import get_args
22-
import torch
20+
# ==============================================================================
21+
# Configuration (Internal Use Only)
22+
# ==============================================================================
23+
# These are not meant to be imported from other modules
24+
_TEST_CASES = [
25+
# x_shape, x_stride
26+
((32, 512), None),
27+
((32, 512), (1024, 1)),
28+
((32, 5, 5), None),
29+
((32, 20, 512), None),
30+
((32, 20, 512), (20480, 512, 1)), # Ascend does not support non-contiguous layouts yet
31+
]
32+
33+
# Data types used for testing
34+
_TENSOR_DTYPES = [torch.float16]
35+
36+
# Tolerance map for different data types
37+
_TOLERANCE_MAP = {
38+
torch.float16: {"atol": 0, "rtol": 1e-2},
39+
}
40+
41+
DEBUG = False
42+
PROFILE = False
43+
NUM_PRERUN = 10
44+
NUM_ITERATIONS = 1000
2345

2446

2547
class CausalSoftmaxDescriptor(Structure):
@@ -37,101 +59,82 @@ def causal_softmax(x):
3759
return torch.nn.functional.softmax(masked, dim=-1).to(type)
3860

3961

40-
def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float16):
62+
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
4163
print(
42-
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{x_dtype}"
64+
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
4365
)
44-
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
45-
if x_stride is not None:
46-
x = rearrange_tensor(x, x_stride)
66+
67+
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
68+
4769
ans = causal_softmax(x)
70+
71+
x = rearrange_if_needed(x, x_stride)
72+
4873
x_tensor = to_tensor(x, lib)
74+
4975
descriptor = infiniopCausalSoftmaxDescriptor_t()
5076
check_error(
5177
lib.infiniopCreateCausalSoftmaxDescriptor(
5278
handle, ctypes.byref(descriptor), x_tensor.descriptor
5379
)
5480
)
55-
workspace_size = c_uint64(0)
56-
check_error(
57-
lib.infiniopGetCausalSoftmaxWorkspaceSize(
58-
descriptor, ctypes.byref(workspace_size)
59-
)
60-
)
6181

6282
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
6383
x_tensor.descriptor.contents.invalidate()
6484

65-
workspace = create_workspace(workspace_size.value, x.device)
85+
workspace_size = c_uint64(0)
6686
check_error(
67-
lib.infiniopCausalSoftmax(
68-
descriptor,
69-
workspace.data_ptr() if workspace is not None else None,
70-
workspace_size.value,
71-
x_tensor.data,
72-
None,
87+
lib.infiniopGetCausalSoftmaxWorkspaceSize(
88+
descriptor, ctypes.byref(workspace_size)
7389
)
7490
)
75-
assert torch.allclose(x, ans, atol=0, rtol=1e-2)
76-
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
77-
78-
79-
def test_cpu(lib, test_cases):
80-
device = DeviceEnum.DEVICE_CPU
81-
handle = create_handle(lib, device)
82-
for x_shape, x_stride in test_cases:
83-
test(lib, handle, "cpu", x_shape, x_stride)
84-
destroy_handle(lib, handle)
85-
86-
87-
def test_cuda(lib, test_cases):
88-
device = DeviceEnum.DEVICE_CUDA
89-
handle = create_handle(lib, device)
90-
for x_shape, x_stride in test_cases:
91-
test(lib, handle, "cuda", x_shape, x_stride)
92-
destroy_handle(lib, handle)
93-
94-
95-
def test_bang(lib, test_cases):
96-
import torch_mlu
91+
workspace = create_workspace(workspace_size.value, x.device)
9792

98-
device = DeviceEnum.DEVICE_BANG
99-
handle = create_handle(lib, device)
100-
for x_shape, x_stride in test_cases:
101-
test(lib, handle, "mlu", x_shape, x_stride)
102-
destroy_handle(lib, handle)
93+
def lib_causal_softmax():
94+
check_error(
95+
lib.infiniopCausalSoftmax(
96+
descriptor,
97+
workspace.data_ptr() if workspace is not None else None,
98+
workspace_size.value,
99+
x_tensor.data,
100+
None,
101+
)
102+
)
103103

104+
lib_causal_softmax()
104105

105-
def test_ascend(lib, test_cases):
106-
import torch_npu
106+
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
107+
if DEBUG:
108+
debug(x, ans, atol=atol, rtol=rtol)
109+
assert torch.allclose(x, ans, atol=atol, rtol=rtol)
107110

108-
device = DeviceEnum.DEVICE_ASCEND
109-
handle = create_handle(lib, device)
110-
for x_shape, x_stride in test_cases:
111-
test(lib, handle, "npu", x_shape, x_stride)
111+
# Profiling workflow
112+
if PROFILE:
113+
# fmt: off
114+
profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
115+
profile_operation(" lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
116+
# fmt: on
112117

113-
destroy_handle(lib, handle)
118+
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
114119

115120

116121
if __name__ == "__main__":
117-
test_cases = [
118-
# x_shape, x_stride
119-
((32, 20, 512), None),
120-
((32, 20, 512), (20480, 512, 1)), # Ascend does not support non-contiguous layouts yet
121-
]
122122
args = get_args()
123123
lib = open_lib()
124+
124125
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
125126
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
126127
infiniopHandle_t,
127128
POINTER(infiniopCausalSoftmaxDescriptor_t),
128129
infiniopTensorDescriptor_t,
129130
]
131+
130132
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
131133
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
132134
infiniopCausalSoftmaxDescriptor_t,
133135
POINTER(c_uint64),
134136
]
137+
135138
lib.infiniopCausalSoftmax.restype = c_int32
136139
lib.infiniopCausalSoftmax.argtypes = [
137140
infiniopCausalSoftmaxDescriptor_t,
@@ -140,19 +143,19 @@ def test_ascend(lib, test_cases):
140143
c_void_p,
141144
c_void_p,
142145
]
146+
143147
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
144148
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
145149
infiniopCausalSoftmaxDescriptor_t,
146150
]
147151

148-
if args.cpu:
149-
test_cpu(lib, test_cases)
150-
if args.cuda:
151-
test_cuda(lib, test_cases)
152-
if args.bang:
153-
test_bang(lib, test_cases)
154-
if args.ascend:
155-
test_ascend(lib, test_cases)
156-
if not (args.cpu or args.cuda or args.bang or args.ascend):
157-
test_cpu(lib, test_cases)
152+
# Configure testing options
153+
DEBUG = args.debug
154+
PROFILE = args.profile
155+
NUM_PRERUN = args.num_prerun
156+
NUM_ITERATIONS = args.num_iterations
157+
158+
for device in get_test_devices(args):
159+
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
160+
158161
print("\033[92mTest passed!\033[0m")

0 commit comments

Comments
 (0)