2121from typing import Tuple
2222import numpy as np
2323
# Profiling switch for this test script; the code that reads it is outside
# this excerpt.
PROFILE = True
# NOTE(review): presumably the number of warm-up runs before timing starts;
# the consumer is not visible here, so confirm against the test body.
NUM_PRERUN = 1
# Divisor used to average the measured wall-clock time per iteration.
NUM_ITERATIONS = 1
2727
@@ -141,41 +141,50 @@ def test(
141141 )
142142 elapsed = (time .time () - start_time ) / NUM_ITERATIONS
143143 print (f"lib time: { elapsed :10f} " )
144- print (f"custom op output:{ y } " )
145- print (f"pytorch output:{ ans } " )
144+ # print(f"custom op output:{y}")
145+ # print(f"pytorch output:{ans}")
146146 assert torch .allclose (y , ans , atol = 0 , rtol = 1e-3 )
147147
148148 check_error (lib .infiniopDestroyReducemaxDescriptor (descriptor ))
149149
def test_cpu(lib, test_cases):
    """Run every test case against the CPU device.

    Args:
        lib: Loaded operator library handle, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples
            ``(x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes,
            tensor_dtype)``.

    Raises:
        Whatever ``test`` raises for a failing case. The device handle is
        destroyed before the exception propagates.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            # ``test`` takes ``dynamic_axes`` before ``noop_with_empty_axes``
            # and ``keepdims``, which differs from the tuple order above.
            test(
                lib,
                handle,
                "cpu",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
    finally:
        # Release the handle even when a case fails, so it is not leaked.
        destroy_handle(lib, handle)
159157
def test_cuda(lib, test_cases):
    """Run every test case against the CUDA device.

    Args:
        lib: Loaded operator library handle, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples
            ``(x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes,
            tensor_dtype)``.

    Raises:
        Whatever ``test`` raises for a failing case. The device handle is
        destroyed before the exception propagates.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            # ``test`` takes ``dynamic_axes`` before ``noop_with_empty_axes``
            # and ``keepdims``, which differs from the tuple order above.
            test(
                lib,
                handle,
                "cuda",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
            # Visually separate the output of consecutive cases.
            print("\n")
    finally:
        # Release the handle even when a case fails, so it is not leaked.
        destroy_handle(lib, handle)
160165
161166if __name__ == "__main__" :
162167 test_cases = [
163168 # dynamic calc test eg
164- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
165- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
166- #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
167- ((2 , 10 , 24 , 10 ), [0 , 2 ], False , True , None ),
168- # stride =
169- ((2 , 10 , 24 , 10 ), [0 , 1 ], False , True , None ),
170- ((2 , 10 , 24 , 10 ), [2 , 3 ], False , True , None ),
171- ((2 , 10 , 24 , 10 ), [0 , 1 , 2 , 3 ], False , True , None ),
172- # validate attribute noop_with_empty_axes and keepdims
173- ((2 , 10 , 24 , 10 ), None , True , True , None ),
174- ((2 , 10 , 24 , 10 ), None , True , False , None ),
175- ((2 , 10 , 24 , 10 ), None , False , True , None ),
176- ((2 , 10 , 24 , 10 ), None , False , False , None ),
177- ((2 , 3 , 4 ), [0 , 1 ], False , False , None ),
178- #((2, 10, 24, 10), [], True),
169+ # ((2, 3, 4, 5), [0, 2], False, True, None),
170+ # ((2, 3, 4, 5), [0, 2], False, True, None),
171+ # #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
172+ # ((2, 10, 24, 10), [0, 2], False, True, None),
173+ # # stride =
174+ # ((2, 10, 24, 10), [0, 1], False, True, None),
175+ # ((2, 10, 24, 10), [2, 3], False , True, None),
176+ # ((2, 10, 24, 10), [0, 1, 2, 3], False, True, None),
177+ # # validate attribute noop_with_empty_axes and keepdims
178+ # ((2, 10, 24, 10), None, True, True, None),
179+ # ((2, 10, 24, 10), None, True, False, None),
180+ # ((2, 10, 24, 10), None, False, True, None),
181+ # ((2, 10, 24, 10), None, False, False, None),
182+ # ((2, 3, 4), [0, 1], False, False, None),
183+ # #((2, 10, 24, 10), [], True),
184+ ((2 , 1000 ), [0 , 1 ], False , False , None , torch .float32 ),
185+ ((2 , 2 , 5 ), [0 , 1 ], False , True , None , torch .float32 ),
186+ ((1000 , 200 , 500 ), [0 , 1 ], False , True , None , torch .float16 ),
187+ ((1000 , 200 , 50 ), [0 , 1 ], False , True , None , torch .float32 ),
179188 ]
180189 args = get_args ()
181190 lib = open_lib ()
@@ -201,5 +210,8 @@ def test_cpu(lib, test_cases):
201210 ]
202211 lib .infiniopDestroyReduceminDescriptor .restype = c_int32
203212 lib .infiniopDestroyReduceminDescriptor .argtypes = [infiniopReduceminDescriptor_t ]
204- test_cpu (lib , test_cases )
213+ if args .cpu :
214+ test_cpu (lib , test_cases )
215+ if args .cuda :
216+ test_cuda (lib , test_cases )
205217 print ("All tests passed!" )
0 commit comments