@@ -47,6 +47,7 @@ def inferShape(x_shape, axis, noop_with_empty_axes, keepdims=False):
4747 return tuple ([1 ] * len (x_shape ))
4848 else :
4949 return tuple ([])
50+
5051 assert len (axis ) <= len (x_shape ), "axis out of range"
5152 output_shape = []
5253 axis = [a if a >= 0 else a + len (x_shape ) for a in axis ] # 更新 axis 列表中的值
@@ -82,9 +83,9 @@ def test(
8283 print (
8384 f"Testing reducemean on { torch_device } with x_shape:{ x_shape } dtype:{ tensor_dtype } "
8485 )
85- x = torch .randn ( x_shape , dtype = tensor_dtype , device = torch_device )
86+ x = torch .randint ( 0 , 10 , x_shape , dtype = tensor_dtype , device = torch_device )
8687 print (f"y_shape = { inferShape (x_shape , axes if dynamic_axes == None else dynamic_axes , noop_with_empty_axes , keepdims )} " )
87- y = torch .full (inferShape (x_shape , axes if dynamic_axes == None else dynamic_axes , noop_with_empty_axes , keepdims ), float ('-inf' ), dtype = tensor_dtype , device = torch_device )
88+ y = torch .full (inferShape (x_shape , axes if dynamic_axes == None else dynamic_axes , noop_with_empty_axes , keepdims ), float (0 ), dtype = tensor_dtype , device = torch_device )
8889 print (f"y_shape = { y .shape } " )
8990 for i in range (NUM_PRERUN if PROFILE else 1 ):
9091 ans = reduce_mean (x , axes if dynamic_axes == None else dynamic_axes , noop_with_empty_axes , keepdims )
@@ -141,6 +142,7 @@ def test(
141142 )
142143 elapsed = (time .time () - start_time ) / NUM_ITERATIONS
143144 print (f"lib time: { elapsed :10f} " )
145+ #print(f"input_data = {x}")
144146 print (f"custom op output:{ y } " )
145147 print (f"pytorch output:{ ans } " )
146148 assert torch .allclose (y , ans , atol = 0 , rtol = 1e-3 )
@@ -150,30 +152,39 @@ def test(
def test_cpu(lib, test_cases):
    """Run every reducemean test case on the CPU device.

    Args:
        lib: The loaded operator library; forwarded unchanged to
            ``create_handle``, ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples
            ``(x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes,
            tensor_dtype)``; each tuple drives one call to ``test``.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            test(
                lib,
                handle,
                "cpu",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
            # NOTE(review): assumed to be a per-case separator (inside the
            # loop) -- confirm against the original indentation.
            print("\n ")
    finally:
        # Release the handle even when a case raises (e.g. a failed
        # assertion inside ``test``), so it is never leaked.
        destroy_handle(lib, handle)
158160
def test_cuda(lib, test_cases):
    """Run every reducemean test case on the CUDA device.

    Args:
        lib: The loaded operator library; forwarded unchanged to
            ``create_handle``, ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples
            ``(x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes,
            tensor_dtype)``; each tuple drives one call to ``test``.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            test(
                lib,
                handle,
                "cuda",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
            # NOTE(review): assumed to be a per-case separator (inside the
            # loop) -- confirm against the original indentation.
            print("\n ")
    finally:
        # Release the handle even when a case raises (e.g. a failed
        # assertion inside ``test``), so it is never leaked.
        destroy_handle(lib, handle)
168+
159169
160170if __name__ == "__main__" :
161171 test_cases = [
162172 # dynamic calc test eg
163- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
164- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
165- #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
166- ((2 , 10 , 24 , 10 ), [0 , 2 ], False , True , None ),
167- # stride =
168- ((2 , 10 , 24 , 10 ), [0 , 1 ], False , True , None ),
169- ((2 , 10 , 24 , 10 ), [2 , 3 ], False , True , None ),
170- ((2 , 10 , 24 , 10 ), [0 , 1 , 2 , 3 ], False , True , None ),
173+ # ((2, 3, 4, 5), [0, 2], False, True, None),
174+ # ((2, 3, 4, 5), [0, 2], False, True, None),
175+ # #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
176+ # ((2, 10, 24, 10), [0, 2], False, True, None),
177+ # # stride =
178+ # ((2, 10, 24, 10), [0, 1], False, True, None),
179+ # ((2, 10, 24, 10), [2, 3], False , True, None),
180+ ((50 , 3 ), [0 , 1 ], False , False , None , torch .float16 ),
181+ ((1000 , 3 ), [0 , 1 ], False , False , None , torch .float16 ),
171182 # validate attribute noop_with_empty_axes and keepdims
172- ((2 , 10 , 24 , 10 ), None , True , True , None ),
173- ((2 , 10 , 24 , 10 ), None , True , False , None ),
174- ((2 , 10 , 24 , 10 ), None , False , True , None ),
175- ((2 , 10 , 24 , 10 ), None , False , False , None ),
176- ((2 , 3 , 4 ), [0 , 1 ], False , False , None ),
183+ # ((2, 10, 24, 10), None, True, True, None),
184+ # ((2, 10, 24, 10), None, True, False, None),
185+ # ((2, 10, 24, 10), None, False, True, None),
186+ # ((2, 10, 24, 10), None, False, False, None),
187+ # ((2, 3, 4), [0, 1], False, False, None),
177188 #((2, 10, 24, 10), [], True),
178189 ]
179190 args = get_args ()
@@ -200,5 +211,8 @@ def test_cpu(lib, test_cases):
200211 ]
201212 lib .infiniopDestroyReducemeanDescriptor .restype = c_int32
202213 lib .infiniopDestroyReducemeanDescriptor .argtypes = [infiniopReducemeanDescriptor_t ]
203- test_cpu (lib , test_cases )
214+ if args .cpu :
215+ test_cpu (lib , test_cases )
216+ if args .cuda :
217+ test_cuda (lib , test_cases )
204218 print ("All tests passed!" )
0 commit comments