1- from ctypes import POINTER , Structure , c_int32 , c_uint64 , c_void_p
1+ import torch
22import ctypes
3- import sys
4- import os
5-
6-
7- sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , ".." )))
8- from operatorspy import (
9- open_lib ,
10- to_tensor ,
11- DeviceEnum ,
3+ from ctypes import POINTER , Structure , c_int32 , c_size_t , c_uint64 , c_void_p , c_float
4+ from libinfiniop import (
125 infiniopHandle_t ,
136 infiniopTensorDescriptor_t ,
14- create_handle ,
15- destroy_handle ,
7+ open_lib ,
8+ to_tensor ,
9+ get_test_devices ,
1610 check_error ,
17- rearrange_tensor ,
11+ rearrange_if_needed ,
1812 create_workspace ,
13+ test_operator ,
14+ get_args ,
15+ debug ,
16+ get_tolerance ,
17+ profile_operation ,
1918)
2019
21- from operatorspy .tests .test_utils import get_args
22- import torch
20+ # ==============================================================================
21+ # Configuration (Internal Use Only)
22+ # ==============================================================================
23+ # These are not meant to be imported from other modules
2324
25+ _TEST_CASES = [
26+ # x_shape, x_stride
27+ ((32 , 512 ), None ),
28+ ((32 , 512 ), (1024 , 1 )),
29+ ((32 , 5 , 5 ), None ),
30+ ((32 , 20 , 512 ), None ),
31+ ((32 , 20 , 512 ), (20480 , 512 , 1 )), # Ascend 暂不支持非连续
32+ ((32 , 20 , 4 , 512 ), None ),
33+ ((32 , 20 , 4 , 512 ), (81920 , 2048 , 512 , 1 )),
34+ ]
35+ # Data types used for testing
36+ _TENSOR_DTYPES = [torch .float16 , torch .float32 ]
37+
38+ # Tolerance map for different data types
39+ _TOLERANCE_MAP = {
40+ torch .float16 : {'atol' : 0 , 'rtol' : 1e-2 },
41+ torch .float32 : {'atol' : 0 , 'rtol' : 1e-3 },
42+ }
43+
44+ DEBUG = False
45+ PROFILE = False
46+ NUM_PRERUN = 10
47+ NUM_ITERATIONS = 1000
2448
2549class CausalSoftmaxDescriptor (Structure ):
2650 _fields_ = [("device" , c_int32 )]
@@ -37,88 +61,78 @@ def causal_softmax(x):
3761 return torch .nn .functional .softmax (masked , dim = - 1 ).to (type )
3862
3963
40- def test (lib , handle , torch_device , x_shape , x_stride = None , x_dtype = torch .float16 ):
64+ def test (
65+ lib ,
66+ handle ,
67+ torch_device ,
68+ x_shape ,
69+ x_stride = None ,
70+ dtype = torch .float16
71+ ):
4172 print (
42- f"Testing CausalSoftmax on { torch_device } with x_shape:{ x_shape } x_stride:{ x_stride } dtype:{ x_dtype } "
73+ f"Testing CausalSoftmax on { torch_device } with x_shape:{ x_shape } x_stride:{ x_stride } dtype:{ dtype } "
4374 )
44- x = torch .rand (x_shape , dtype = x_dtype ).to (torch_device )
45- if x_stride is not None :
46- x = rearrange_tensor (x , x_stride )
75+ x = torch .rand (x_shape , dtype = dtype ).to (torch_device )
76+
4777 ans = causal_softmax (x )
78+
79+
80+ x = rearrange_if_needed (x , x_stride )
81+
4882 x_tensor = to_tensor (x , lib )
83+
4984 descriptor = infiniopCausalSoftmaxDescriptor_t ()
5085 check_error (
5186 lib .infiniopCreateCausalSoftmaxDescriptor (
52- handle , ctypes .byref (descriptor ), x_tensor .descriptor
87+ handle ,
88+ ctypes .byref (descriptor ),
89+ x_tensor .descriptor
5390 )
5491 )
92+
93+ # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
94+ x_tensor .descriptor .contents .invalidate ()
95+
96+
5597 workspace_size = c_uint64 (0 )
5698 check_error (
5799 lib .infiniopGetCausalSoftmaxWorkspaceSize (
58100 descriptor , ctypes .byref (workspace_size )
59101 )
60102 )
61-
62- # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
63- x_tensor .descriptor .contents .invalidate ()
64-
65103 workspace = create_workspace (workspace_size .value , x .device )
66- check_error (
67- lib .infiniopCausalSoftmax (
68- descriptor ,
69- workspace .data_ptr () if workspace is not None else None ,
70- workspace_size .value ,
71- x_tensor .data ,
72- None ,
104+ def lib_causal_softmax ():
105+ check_error (
106+ lib .infiniopCausalSoftmax (
107+ descriptor ,
108+ workspace .data_ptr () if workspace is not None else None ,
109+ workspace_size .value ,
110+ x_tensor .data ,
111+ None ,
112+ )
73113 )
74- )
75- assert torch .allclose (x , ans , atol = 0 , rtol = 1e-2 )
76- check_error (lib .infiniopDestroyCausalSoftmaxDescriptor (descriptor ))
77-
78-
79- def test_cpu (lib , test_cases ):
80- device = DeviceEnum .DEVICE_CPU
81- handle = create_handle (lib , device )
82- for x_shape , x_stride in test_cases :
83- test (lib , handle , "cpu" , x_shape , x_stride )
84- destroy_handle (lib , handle )
85-
86-
87- def test_cuda (lib , test_cases ):
88- device = DeviceEnum .DEVICE_CUDA
89- handle = create_handle (lib , device )
90- for x_shape , x_stride in test_cases :
91- test (lib , handle , "cuda" , x_shape , x_stride )
92- destroy_handle (lib , handle )
93-
94-
95- def test_bang (lib , test_cases ):
96- import torch_mlu
97-
98- device = DeviceEnum .DEVICE_BANG
99- handle = create_handle (lib , device )
100- for x_shape , x_stride in test_cases :
101- test (lib , handle , "mlu" , x_shape , x_stride )
102- destroy_handle (lib , handle )
114+ lib_causal_softmax ()
115+
116+ atol , rtol = get_tolerance (_TOLERANCE_MAP , dtype )
117+ if DEBUG :
118+ debug (x , ans , atol = atol , rtol = rtol )
119+ assert torch .allclose (x , ans , atol = atol , rtol = rtol )
120+
121+ # Profiling workflow
122+ if PROFILE :
123+ # fmt: off
124+ profile_operation ("PyTorch" , lambda : causal_softmax (x ), torch_device , NUM_PRERUN , NUM_ITERATIONS )
125+ profile_operation (" lib" , lambda : lib_causal_softmax (), torch_device , NUM_PRERUN , NUM_ITERATIONS )
126+ # fmt: on
103127
128+ check_error (lib .infiniopDestroyCausalSoftmaxDescriptor (descriptor ))
104129
105- def test_ascend (lib , test_cases ):
106- import torch_npu
107130
108- device = DeviceEnum .DEVICE_ASCEND
109- handle = create_handle (lib , device )
110- for x_shape , x_stride in test_cases :
111- test (lib , handle , "npu" , x_shape , x_stride )
112131
113- destroy_handle (lib , handle )
114132
115133
116134if __name__ == "__main__" :
117- test_cases = [
118- # x_shape, x_stride
119- ((32 , 20 , 512 ), None ),
120- ((32 , 20 , 512 ), (20480 , 512 , 1 )), # Ascend 暂不支持非连续
121- ]
135+
122136 args = get_args ()
123137 lib = open_lib ()
124138 lib .infiniopCreateCausalSoftmaxDescriptor .restype = c_int32
@@ -144,15 +158,14 @@ def test_ascend(lib, test_cases):
144158 lib .infiniopDestroyCausalSoftmaxDescriptor .argtypes = [
145159 infiniopCausalSoftmaxDescriptor_t ,
146160 ]
161+ # Configure testing options
162+ DEBUG = args .debug
163+ PROFILE = args .profile
164+ NUM_PRERUN = args .num_prerun
165+ NUM_ITERATIONS = args .num_iterations
166+
167+ for device in get_test_devices (args ):
168+ test_operator (lib , device , test , _TEST_CASES , _TENSOR_DTYPES )
147169
148- if args .cpu :
149- test_cpu (lib , test_cases )
150- if args .cuda :
151- test_cuda (lib , test_cases )
152- if args .bang :
153- test_bang (lib , test_cases )
154- if args .ascend :
155- test_ascend (lib , test_cases )
156- if not (args .cpu or args .cuda or args .bang or args .ascend ):
157- test_cpu (lib , test_cases )
158170 print ("\033 [92mTest passed!\033 [0m" )
171+
0 commit comments