Skip to content

Commit ca2f34c

Browse files
committed
issue/66: modified test py
1 parent 87d1097 commit ca2f34c

File tree

6 files changed

+544
-537
lines changed

6 files changed

+544
-537
lines changed

test/infiniop/causal_softmax.py

Lines changed: 96 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,50 @@
1-
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
1+
import torch
22
import ctypes
3-
import sys
4-
import os
5-
6-
7-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8-
from operatorspy import (
9-
open_lib,
10-
to_tensor,
11-
DeviceEnum,
3+
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
4+
from libinfiniop import (
125
infiniopHandle_t,
136
infiniopTensorDescriptor_t,
14-
create_handle,
15-
destroy_handle,
7+
open_lib,
8+
to_tensor,
9+
get_test_devices,
1610
check_error,
17-
rearrange_tensor,
11+
rearrange_if_needed,
1812
create_workspace,
13+
test_operator,
14+
get_args,
15+
debug,
16+
get_tolerance,
17+
profile_operation,
1918
)
2019

21-
from operatorspy.tests.test_utils import get_args
22-
import torch
20+
# ==============================================================================
21+
# Configuration (Internal Use Only)
22+
# ==============================================================================
23+
# These are not meant to be imported from other modules
2324

25+
_TEST_CASES = [
26+
# x_shape, x_stride
27+
((32, 512), None),
28+
((32, 512), (1024, 1)),
29+
((32, 5, 5), None),
30+
((32, 20, 512), None),
31+
((32, 20, 512), (20480, 512, 1)), # Ascend does not yet support non-contiguous tensors
32+
((32, 20, 4, 512), None),
33+
((32, 20, 4, 512), (81920, 2048, 512, 1)),
34+
]
35+
# Data types used for testing
36+
_TENSOR_DTYPES = [torch.float16, torch.float32]
37+
38+
# Tolerance map for different data types
39+
_TOLERANCE_MAP = {
40+
torch.float16: {'atol': 0, 'rtol': 1e-2},
41+
torch.float32: {'atol': 0, 'rtol': 1e-3},
42+
}
43+
44+
DEBUG = False
45+
PROFILE = False
46+
NUM_PRERUN = 10
47+
NUM_ITERATIONS = 1000
2448

2549
class CausalSoftmaxDescriptor(Structure):
2650
_fields_ = [("device", c_int32)]
@@ -37,88 +61,78 @@ def causal_softmax(x):
3761
return torch.nn.functional.softmax(masked, dim=-1).to(type)
3862

3963

40-
def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float16):
64+
def test(
65+
lib,
66+
handle,
67+
torch_device,
68+
x_shape,
69+
x_stride=None,
70+
dtype=torch.float16
71+
):
4172
print(
42-
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{x_dtype}"
73+
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
4374
)
44-
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
45-
if x_stride is not None:
46-
x = rearrange_tensor(x, x_stride)
75+
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
76+
4777
ans = causal_softmax(x)
78+
79+
80+
x = rearrange_if_needed(x, x_stride)
81+
4882
x_tensor = to_tensor(x, lib)
83+
4984
descriptor = infiniopCausalSoftmaxDescriptor_t()
5085
check_error(
5186
lib.infiniopCreateCausalSoftmaxDescriptor(
52-
handle, ctypes.byref(descriptor), x_tensor.descriptor
87+
handle,
88+
ctypes.byref(descriptor),
89+
x_tensor.descriptor
5390
)
5491
)
92+
93+
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
94+
x_tensor.descriptor.contents.invalidate()
95+
96+
5597
workspace_size = c_uint64(0)
5698
check_error(
5799
lib.infiniopGetCausalSoftmaxWorkspaceSize(
58100
descriptor, ctypes.byref(workspace_size)
59101
)
60102
)
61-
62-
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
63-
x_tensor.descriptor.contents.invalidate()
64-
65103
workspace = create_workspace(workspace_size.value, x.device)
66-
check_error(
67-
lib.infiniopCausalSoftmax(
68-
descriptor,
69-
workspace.data_ptr() if workspace is not None else None,
70-
workspace_size.value,
71-
x_tensor.data,
72-
None,
104+
def lib_causal_softmax():
105+
check_error(
106+
lib.infiniopCausalSoftmax(
107+
descriptor,
108+
workspace.data_ptr() if workspace is not None else None,
109+
workspace_size.value,
110+
x_tensor.data,
111+
None,
112+
)
73113
)
74-
)
75-
assert torch.allclose(x, ans, atol=0, rtol=1e-2)
76-
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
77-
78-
79-
def test_cpu(lib, test_cases):
80-
device = DeviceEnum.DEVICE_CPU
81-
handle = create_handle(lib, device)
82-
for x_shape, x_stride in test_cases:
83-
test(lib, handle, "cpu", x_shape, x_stride)
84-
destroy_handle(lib, handle)
85-
86-
87-
def test_cuda(lib, test_cases):
88-
device = DeviceEnum.DEVICE_CUDA
89-
handle = create_handle(lib, device)
90-
for x_shape, x_stride in test_cases:
91-
test(lib, handle, "cuda", x_shape, x_stride)
92-
destroy_handle(lib, handle)
93-
94-
95-
def test_bang(lib, test_cases):
96-
import torch_mlu
97-
98-
device = DeviceEnum.DEVICE_BANG
99-
handle = create_handle(lib, device)
100-
for x_shape, x_stride in test_cases:
101-
test(lib, handle, "mlu", x_shape, x_stride)
102-
destroy_handle(lib, handle)
114+
lib_causal_softmax()
115+
116+
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
117+
if DEBUG:
118+
debug(x, ans, atol=atol, rtol=rtol)
119+
assert torch.allclose(x, ans, atol=atol, rtol=rtol)
120+
121+
# Profiling workflow
122+
if PROFILE:
123+
# fmt: off
124+
profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
125+
profile_operation(" lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
126+
# fmt: on
103127

128+
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
104129

105-
def test_ascend(lib, test_cases):
106-
import torch_npu
107130

108-
device = DeviceEnum.DEVICE_ASCEND
109-
handle = create_handle(lib, device)
110-
for x_shape, x_stride in test_cases:
111-
test(lib, handle, "npu", x_shape, x_stride)
112131

113-
destroy_handle(lib, handle)
114132

115133

116134
if __name__ == "__main__":
117-
test_cases = [
118-
# x_shape, x_stride
119-
((32, 20, 512), None),
120-
((32, 20, 512), (20480, 512, 1)), # Ascend does not yet support non-contiguous tensors
121-
]
135+
122136
args = get_args()
123137
lib = open_lib()
124138
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
@@ -144,15 +158,14 @@ def test_ascend(lib, test_cases):
144158
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
145159
infiniopCausalSoftmaxDescriptor_t,
146160
]
161+
# Configure testing options
162+
DEBUG = args.debug
163+
PROFILE = args.profile
164+
NUM_PRERUN = args.num_prerun
165+
NUM_ITERATIONS = args.num_iterations
166+
167+
for device in get_test_devices(args):
168+
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
147169

148-
if args.cpu:
149-
test_cpu(lib, test_cases)
150-
if args.cuda:
151-
test_cuda(lib, test_cases)
152-
if args.bang:
153-
test_bang(lib, test_cases)
154-
if args.ascend:
155-
test_ascend(lib, test_cases)
156-
if not (args.cpu or args.cuda or args.bang or args.ascend):
157-
test_cpu(lib, test_cases)
158170
print("\033[92mTest passed!\033[0m")
171+

0 commit comments

Comments
 (0)