1- from ctypes import POINTER , Structure , c_int32 , c_uint64 , c_void_p
1+ import torch
22import ctypes
3- import sys
4- import os
5-
6-
7- sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , ".." )))
8- from operatorspy import (
9- open_lib ,
10- to_tensor ,
11- DeviceEnum ,
3+ from ctypes import POINTER , Structure , c_int32 , c_size_t , c_uint64 , c_void_p , c_float
4+ from libinfiniop import (
125 infiniopHandle_t ,
136 infiniopTensorDescriptor_t ,
14- create_handle ,
15- destroy_handle ,
7+ open_lib ,
8+ to_tensor ,
9+ get_test_devices ,
1610 check_error ,
17- rearrange_tensor ,
11+ rearrange_if_needed ,
1812 create_workspace ,
13+ test_operator ,
14+ get_args ,
15+ debug ,
16+ get_tolerance ,
17+ profile_operation ,
1918)
2019
21- from operatorspy .tests .test_utils import get_args
22- import torch
20+ # ==============================================================================
21+ # Configuration (Internal Use Only)
22+ # ==============================================================================
23+ # These are not meant to be imported from other modules
24+ _TEST_CASES = [
25+ # x_shape, x_stride
26+ ((32 , 512 ), None ),
27+ ((32 , 512 ), (1024 , 1 )),
28+ ((32 , 5 , 5 ), None ),
29+ ((32 , 20 , 512 ), None ),
30+ ((32 , 20 , 512 ), (20480 , 512 , 1 )), # Ascend does not yet support non-contiguous tensors
31+ ]
32+
33+ # Data types used for testing
34+ _TENSOR_DTYPES = [torch .float16 ]
35+
36+ # Tolerance map for different data types
37+ _TOLERANCE_MAP = {
38+ torch .float16 : {"atol" : 0 , "rtol" : 1e-2 },
39+ }
40+
41+ DEBUG = False
42+ PROFILE = False
43+ NUM_PRERUN = 10
44+ NUM_ITERATIONS = 1000
2345
2446
2547class CausalSoftmaxDescriptor (Structure ):
@@ -37,101 +59,82 @@ def causal_softmax(x):
3759 return torch .nn .functional .softmax (masked , dim = - 1 ).to (type )
3860
3961
40- def test (lib , handle , torch_device , x_shape , x_stride = None , x_dtype = torch .float16 ):
62+ def test (lib , handle , torch_device , x_shape , x_stride = None , dtype = torch .float16 ):
4163 print (
42- f"Testing CausalSoftmax on { torch_device } with x_shape:{ x_shape } x_stride:{ x_stride } dtype:{ x_dtype } "
64+ f"Testing CausalSoftmax on { torch_device } with x_shape:{ x_shape } x_stride:{ x_stride } dtype:{ dtype } "
4365 )
44- x = torch . rand ( x_shape , dtype = x_dtype ). to ( torch_device )
45- if x_stride is not None :
46- x = rearrange_tensor ( x , x_stride )
66+
67+ x = torch . rand ( x_shape , dtype = dtype ). to ( torch_device )
68+
4769 ans = causal_softmax (x )
70+
71+ x = rearrange_if_needed (x , x_stride )
72+
4873 x_tensor = to_tensor (x , lib )
74+
4975 descriptor = infiniopCausalSoftmaxDescriptor_t ()
5076 check_error (
5177 lib .infiniopCreateCausalSoftmaxDescriptor (
5278 handle , ctypes .byref (descriptor ), x_tensor .descriptor
5379 )
5480 )
55- workspace_size = c_uint64 (0 )
56- check_error (
57- lib .infiniopGetCausalSoftmaxWorkspaceSize (
58- descriptor , ctypes .byref (workspace_size )
59- )
60- )
6181
6282 # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
6383 x_tensor .descriptor .contents .invalidate ()
6484
65- workspace = create_workspace ( workspace_size . value , x . device )
85+ workspace_size = c_uint64 ( 0 )
6686 check_error (
67- lib .infiniopCausalSoftmax (
68- descriptor ,
69- workspace .data_ptr () if workspace is not None else None ,
70- workspace_size .value ,
71- x_tensor .data ,
72- None ,
87+ lib .infiniopGetCausalSoftmaxWorkspaceSize (
88+ descriptor , ctypes .byref (workspace_size )
7389 )
7490 )
75- assert torch .allclose (x , ans , atol = 0 , rtol = 1e-2 )
76- check_error (lib .infiniopDestroyCausalSoftmaxDescriptor (descriptor ))
77-
78-
79- def test_cpu (lib , test_cases ):
80- device = DeviceEnum .DEVICE_CPU
81- handle = create_handle (lib , device )
82- for x_shape , x_stride in test_cases :
83- test (lib , handle , "cpu" , x_shape , x_stride )
84- destroy_handle (lib , handle )
85-
86-
87- def test_cuda (lib , test_cases ):
88- device = DeviceEnum .DEVICE_CUDA
89- handle = create_handle (lib , device )
90- for x_shape , x_stride in test_cases :
91- test (lib , handle , "cuda" , x_shape , x_stride )
92- destroy_handle (lib , handle )
93-
94-
95- def test_bang (lib , test_cases ):
96- import torch_mlu
91+ workspace = create_workspace (workspace_size .value , x .device )
9792
98- device = DeviceEnum .DEVICE_BANG
99- handle = create_handle (lib , device )
100- for x_shape , x_stride in test_cases :
101- test (lib , handle , "mlu" , x_shape , x_stride )
102- destroy_handle (lib , handle )
93+ def lib_causal_softmax ():
94+ check_error (
95+ lib .infiniopCausalSoftmax (
96+ descriptor ,
97+ workspace .data_ptr () if workspace is not None else None ,
98+ workspace_size .value ,
99+ x_tensor .data ,
100+ None ,
101+ )
102+ )
103103
104+ lib_causal_softmax ()
104105
105- def test_ascend (lib , test_cases ):
106- import torch_npu
106+ atol , rtol = get_tolerance (_TOLERANCE_MAP , dtype )
107+ if DEBUG :
108+ debug (x , ans , atol = atol , rtol = rtol )
109+ assert torch .allclose (x , ans , atol = atol , rtol = rtol )
107110
108- device = DeviceEnum .DEVICE_ASCEND
109- handle = create_handle (lib , device )
110- for x_shape , x_stride in test_cases :
111- test (lib , handle , "npu" , x_shape , x_stride )
111+ # Profiling workflow
112+ if PROFILE :
113+ # fmt: off
114+ profile_operation ("PyTorch" , lambda : causal_softmax (x ), torch_device , NUM_PRERUN , NUM_ITERATIONS )
115+ profile_operation (" lib" , lambda : lib_causal_softmax (), torch_device , NUM_PRERUN , NUM_ITERATIONS )
116+ # fmt: on
112117
113- destroy_handle (lib , handle )
118+ check_error (lib . infiniopDestroyCausalSoftmaxDescriptor ( descriptor ) )
114119
115120
116121if __name__ == "__main__" :
117- test_cases = [
118- # x_shape, x_stride
119- ((32 , 20 , 512 ), None ),
120- ((32 , 20 , 512 ), (20480 , 512 , 1 )), # Ascend does not yet support non-contiguous tensors
121- ]
122122 args = get_args ()
123123 lib = open_lib ()
124+
124125 lib .infiniopCreateCausalSoftmaxDescriptor .restype = c_int32
125126 lib .infiniopCreateCausalSoftmaxDescriptor .argtypes = [
126127 infiniopHandle_t ,
127128 POINTER (infiniopCausalSoftmaxDescriptor_t ),
128129 infiniopTensorDescriptor_t ,
129130 ]
131+
130132 lib .infiniopGetCausalSoftmaxWorkspaceSize .restype = c_int32
131133 lib .infiniopGetCausalSoftmaxWorkspaceSize .argtypes = [
132134 infiniopCausalSoftmaxDescriptor_t ,
133135 POINTER (c_uint64 ),
134136 ]
137+
135138 lib .infiniopCausalSoftmax .restype = c_int32
136139 lib .infiniopCausalSoftmax .argtypes = [
137140 infiniopCausalSoftmaxDescriptor_t ,
@@ -140,19 +143,19 @@ def test_ascend(lib, test_cases):
140143 c_void_p ,
141144 c_void_p ,
142145 ]
146+
143147 lib .infiniopDestroyCausalSoftmaxDescriptor .restype = c_int32
144148 lib .infiniopDestroyCausalSoftmaxDescriptor .argtypes = [
145149 infiniopCausalSoftmaxDescriptor_t ,
146150 ]
147151
148- if args .cpu :
149- test_cpu (lib , test_cases )
150- if args .cuda :
151- test_cuda (lib , test_cases )
152- if args .bang :
153- test_bang (lib , test_cases )
154- if args .ascend :
155- test_ascend (lib , test_cases )
156- if not (args .cpu or args .cuda or args .bang or args .ascend ):
157- test_cpu (lib , test_cases )
152+ # Configure testing options
153+ DEBUG = args .debug
154+ PROFILE = args .profile
155+ NUM_PRERUN = args .num_prerun
156+ NUM_ITERATIONS = args .num_iterations
157+
158+ for device in get_test_devices (args ):
159+ test_operator (lib , device , test , _TEST_CASES , _TENSOR_DTYPES )
160+
158161 print ("\033 [92mTest passed!\033 [0m" )
0 commit comments