Skip to content

Commit c927faf

Browse files
committed
modified format
1 parent ca2f34c commit c927faf

File tree

6 files changed

+213
-285
lines changed

6 files changed

+213
-285
lines changed

test/infiniop/causal_softmax.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,28 @@
2323
# These are not meant to be imported from other modules
2424

2525
_TEST_CASES = [
26-
# x_shape, x_stride
27-
((32, 512), None),
28-
((32, 512), (1024, 1)),
29-
((32, 5, 5), None),
30-
((32, 20, 512), None),
31-
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
32-
((32, 20, 4, 512), None),
33-
((32, 20, 4, 512), (81920, 2048, 512, 1)),
34-
]
26+
# x_shape, x_stride
27+
((32, 512), None),
28+
((32, 512), (1024, 1)),
29+
((32, 5, 5), None),
30+
((32, 20, 512), None),
31+
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
32+
]
33+
3534
# Data types used for testing
36-
_TENSOR_DTYPES = [torch.float16, torch.float32]
35+
_TENSOR_DTYPES = [torch.float16]
3736

3837
# Tolerance map for different data types
3938
_TOLERANCE_MAP = {
40-
torch.float16: {'atol': 0, 'rtol': 1e-2},
41-
torch.float32: {'atol': 0, 'rtol': 1e-3},
39+
torch.float16: {"atol": 0, "rtol": 1e-2},
4240
}
4341

4442
DEBUG = False
4543
PROFILE = False
4644
NUM_PRERUN = 10
4745
NUM_ITERATIONS = 1000
4846

47+
4948
class CausalSoftmaxDescriptor(Structure):
5049
_fields_ = [("device", c_int32)]
5150

@@ -61,46 +60,37 @@ def causal_softmax(x):
6160
return torch.nn.functional.softmax(masked, dim=-1).to(type)
6261

6362

64-
def test(
65-
lib,
66-
handle,
67-
torch_device,
68-
x_shape,
69-
x_stride=None,
70-
dtype=torch.float16
71-
):
63+
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
7264
print(
7365
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
7466
)
67+
7568
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
7669

7770
ans = causal_softmax(x)
7871

79-
8072
x = rearrange_if_needed(x, x_stride)
81-
73+
8274
x_tensor = to_tensor(x, lib)
8375

8476
descriptor = infiniopCausalSoftmaxDescriptor_t()
8577
check_error(
8678
lib.infiniopCreateCausalSoftmaxDescriptor(
87-
handle,
88-
ctypes.byref(descriptor),
89-
x_tensor.descriptor
79+
handle, ctypes.byref(descriptor), x_tensor.descriptor
9080
)
9181
)
9282

9383
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
9484
x_tensor.descriptor.contents.invalidate()
9585

96-
9786
workspace_size = c_uint64(0)
9887
check_error(
9988
lib.infiniopGetCausalSoftmaxWorkspaceSize(
10089
descriptor, ctypes.byref(workspace_size)
10190
)
10291
)
10392
workspace = create_workspace(workspace_size.value, x.device)
93+
10494
def lib_causal_softmax():
10595
check_error(
10696
lib.infiniopCausalSoftmax(
@@ -111,8 +101,9 @@ def lib_causal_softmax():
111101
None,
112102
)
113103
)
104+
114105
lib_causal_softmax()
115-
106+
116107
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
117108
if DEBUG:
118109
debug(x, ans, atol=atol, rtol=rtol)
@@ -128,24 +119,23 @@ def lib_causal_softmax():
128119
check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
129120

130121

131-
132-
133-
134122
if __name__ == "__main__":
135-
136123
args = get_args()
137124
lib = open_lib()
125+
138126
lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
139127
lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
140128
infiniopHandle_t,
141129
POINTER(infiniopCausalSoftmaxDescriptor_t),
142130
infiniopTensorDescriptor_t,
143131
]
132+
144133
lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
145134
lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
146135
infiniopCausalSoftmaxDescriptor_t,
147136
POINTER(c_uint64),
148137
]
138+
149139
lib.infiniopCausalSoftmax.restype = c_int32
150140
lib.infiniopCausalSoftmax.argtypes = [
151141
infiniopCausalSoftmaxDescriptor_t,
@@ -154,18 +144,19 @@ def lib_causal_softmax():
154144
c_void_p,
155145
c_void_p,
156146
]
147+
157148
lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
158149
lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
159150
infiniopCausalSoftmaxDescriptor_t,
160151
]
152+
161153
# Configure testing options
162154
DEBUG = args.debug
163155
PROFILE = args.profile
164156
NUM_PRERUN = args.num_prerun
165157
NUM_ITERATIONS = args.num_iterations
166-
158+
167159
for device in get_test_devices(args):
168160
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
169161

170162
print("\033[92mTest passed!\033[0m")
171-

test/infiniop/random_sample.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,40 @@
1212
create_workspace,
1313
test_operator,
1414
get_args,
15-
debug,
15+
debug_all,
1616
get_tolerance,
1717
profile_operation,
18+
synchronize_device,
1819
)
1920

2021
# ==============================================================================
2122
# Configuration (Internal Use Only)
2223
# ==============================================================================
2324
# These are not meant to be imported from other modules
25+
2426
_TEST_CASES = [
2527
# voc, random_val, topp, topk, temperature
26-
(512, 0.8, 0.8, 3, 0.5),
27-
(4096, 0.05, 0.9, 5, 1.0),
28-
(16384, 0.15, 0.85, 10, 2.0),
29-
(512, 0.08, 0, 3, 0.5),
30-
(4096, 0.5, 0.9, 1, 1.0),
31-
(16384, 0.15, 0, 1, 2.0),
32-
(16384, 0.15, 0, 1, 2.0),
33-
(32000, 0.08, 0.8, 50, 1.0),
34-
(32000, 0.08, 1.0, 25, 1.0),
35-
# (119696, 0.01, 1.0, 100, 1.0),
28+
(512, 0.8, 0.8, 3, 0.5),
29+
(4096, 0.05, 0.9, 5, 1.0),
30+
(16384, 0.15, 0.85, 10, 2.0),
31+
(512, 0.08, 0, 3, 0.5),
32+
(4096, 0.5, 0.9, 1, 1.0),
33+
(16384, 0.15, 0, 1, 2.0),
34+
(16384, 0.15, 0, 1, 2.0),
35+
(32000, 0.08, 0.8, 50, 1.0),
36+
(32000, 0.08, 1.0, 25, 1.0),
37+
# (119696, 0.01, 1.0, 100, 1.0),
3638
]
3739

3840
# Data types used for testing
39-
_TENSOR_DTYPES = [torch.float16, torch.float32]
41+
_TENSOR_DTYPES = [torch.float16]
42+
43+
_TOLERANCE_MAP = {
44+
torch.float16: {"atol": 0, "rtol": 0},
45+
}
4046

4147

48+
DEBUG = False
4249
PROFILE = False
4350
NUM_PRERUN = 10
4451
NUM_ITERATIONS = 1000
@@ -113,6 +120,7 @@ def test(
113120
x_dtype=torch.float16,
114121
):
115122
print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
123+
116124
data = torch.arange(voc).float() * 0.0001
117125
_perm = torch.randperm(voc)
118126
data = data[_perm].to(x_dtype).to(torch_device)
@@ -122,9 +130,11 @@ def test(
122130
)
123131
else:
124132
ans = random_sample_0(data)
133+
125134
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
126-
x_tensor = to_tensor(data, lib)
127-
indices_tensor = to_tensor(indices, lib)
135+
136+
x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
137+
128138
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
129139

130140
descriptor = infiniopRandomSampleDescriptor_t()
@@ -148,7 +158,7 @@ def test(
148158
)
149159
)
150160
workspace = create_workspace(workspace_size.value, torch_device)
151-
161+
152162
def lib_random_sample():
153163
check_error(
154164
lib.infiniopRandomSample(
@@ -164,11 +174,21 @@ def lib_random_sample():
164174
None,
165175
)
166176
)
167-
if torch_device == "npu":
168-
torch.npu.synchronize()
169177

178+
if torch_device == "npu":
179+
synchronize_device(torch_device)
180+
181+
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
182+
if DEBUG:
183+
debug_all(
184+
(indices[0].type(ans.dtype), data[indices[0]]),
185+
(ans, data[ans]),
186+
"or",
187+
atol=atol,
188+
rtol=rtol,
189+
)
170190
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
171-
191+
172192
# Profiling workflow
173193
if PROFILE:
174194
# fmt: off
@@ -184,23 +204,23 @@ def lib_random_sample():
184204
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
185205

186206

187-
188-
189207
if __name__ == "__main__":
190-
191208
args = get_args()
192209
lib = open_lib()
210+
193211
lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
194212
lib.infiniopCreateRandomSampleDescriptor.argtypes = [
195213
infiniopHandle_t,
196214
POINTER(infiniopRandomSampleDescriptor_t),
197215
infiniopTensorDescriptor_t,
198216
]
217+
199218
lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
200219
lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
201220
infiniopRandomSampleDescriptor_t,
202221
POINTER(c_uint64),
203222
]
223+
204224
lib.infiniopRandomSample.restype = c_int32
205225
lib.infiniopRandomSample.argtypes = [
206226
infiniopRandomSampleDescriptor_t,
@@ -214,11 +234,13 @@ def lib_random_sample():
214234
c_float,
215235
c_void_p,
216236
]
237+
217238
lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
218239
lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
219240
infiniopRandomSampleDescriptor_t,
220241
]
221242

243+
DEBUG = args.debug
222244
PROFILE = args.profile
223245
NUM_PRERUN = args.num_prerun
224246
NUM_ITERATIONS = args.num_iterations

0 commit comments

Comments
 (0)