2020import torch
2121from typing import Tuple
2222import numpy as np
23- import onnx
2423
25- PROFILE = True
24+ PROFILE = False
2625NUM_PRERUN = 10
2726NUM_ITERATIONS = 1000
2827
@@ -43,15 +42,22 @@ def tuple_to_void_p(py_tuple: Tuple):
def inferShape(x_shape, y_shape):
    """Return the broadcast shape of two shapes (NumPy/PyTorch broadcasting rules).

    Dimensions are compared right-to-left; a missing leading dimension is
    treated as 1, and a 1 broadcasts against any size.

    Args:
        x_shape: first shape, as a tuple/sequence of ints.
        y_shape: second shape, as a tuple/sequence of ints.

    Returns:
        tuple: the broadcast output shape.

    Raises:
        ValueError: if the shapes cannot be broadcast together.
    """
    ndim_x = len(x_shape)
    ndim_y = len(y_shape)
    ndim = max(ndim_x, ndim_y)
    output_shape = []

    # Walk dimensions from the trailing axis backwards using negative indices.
    for i in range(-1, -ndim - 1, -1):
        dim_x = x_shape[i] if i >= -ndim_x else 1
        dim_y = y_shape[i] if i >= -ndim_y else 1

        # Incompatible only when both differ and neither is the broadcastable 1.
        if dim_x != dim_y and dim_x != 1 and dim_y != 1:
            raise ValueError(f"Shapes { x_shape } and { y_shape } cannot be broadcast together")

        output_shape.append(max(dim_x, dim_y))

    # Collected right-to-left; one O(n) reverse instead of O(n^2) insert(0, ...).
    output_shape.reverse()
    return tuple(output_shape)
60+
5561
5662def test (
5763 lib ,
@@ -68,8 +74,7 @@ def test(
6874 condition = torch .randint (0 , 2 , condition_shape , dtype = torch .uint8 ).to (torch_device )
6975 src1 = torch .randn (src1_shape , dtype = tensor_dtype , device = torch_device )
7076 src2 = torch .randn (src2_shape , dtype = tensor_dtype , device = torch_device )
71- output = torch .randn (inferShape (src1_shape , src2_shape ), dtype = tensor_dtype , device = torch_device )
72-
77+ output = torch .randn (inferShape (inferShape (src1_shape , src2_shape ), condition_shape ), dtype = tensor_dtype , device = torch_device )
7378
7479 for i in range (NUM_PRERUN if PROFILE else 1 ):
7580 ans = where (condition , src1 , src2 )
@@ -130,18 +135,33 @@ def test(
def test_cpu(lib, test_cases):
    """Run every (condition, src1, src2, dtype) test case on the CPU backend."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases:
        test(lib, handle, "cpu", condition_shape, src1_shape, src2_shape,
             tensor_dtype=tensor_dtype)
        # NOTE(review): blank-line separator between cases — placement inside the
        # loop inferred from statement order; confirm against original indentation.
        print("\n")
    destroy_handle(lib, handle)
142+
def test_cuda(lib, test_cases):
    """Run every (condition, src1, src2, dtype) test case on the CUDA backend."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases:
        test(lib, handle, "cuda", condition_shape, src1_shape, src2_shape,
             tensor_dtype=tensor_dtype)
        # NOTE(review): blank-line separator between cases — placement inside the
        # loop inferred from statement order; confirm against original indentation.
        print("\n")
    destroy_handle(lib, handle)
137150
138151
139152if __name__ == "__main__" :
140153 test_cases = [
141- ((2 , 3 , 4 , 5 ), (2 , 3 , 4 , 5 ), (2 , 3 , 4 , 5 )),
142- ((3 , 1 ), (3 , 4 ), (1 , 4 )),
143- ((1 ,), (3 , 4 ), (3 , 4 )),
144- ((2 , 1 , 3 ), (1 , 4 , 3 ), (2 , 4 , 1 )),
154+ ((2 , 16 ), (2 , 16 ), (2 , 16 ), torch .float32 ),
155+ ((2 , 3 , 1 , 1 ), (1 , 4 , 5 ), (2 , 3 , 4 , 5 ), torch .float32 ),
156+ ((3 , 1 ), (3 , 4 ), (1 , 4 ), torch .float32 ),
157+ ((1 ,), (3 , 4 ), (3 , 4 ), torch .float32 ),
158+ ((2 , 1 , 3 ), (1 , 4 , 3 ), (2 , 4 , 1 ), torch .float32 ),
159+
160+ ((2 , 16 ), (2 , 16 ), (2 , 16 ), torch .float16 ),
161+ ((2 , 3 , 1 , 1 ), (1 , 4 , 5 ), (2 , 3 , 4 , 5 ), torch .float16 ),
162+ ((3 , 1 ), (3 , 4 ), (1 , 4 ), torch .float16 ),
163+ ((1 ,), (3 , 4 ), (3 , 4 ), torch .float16 ),
164+ ((2 , 1 , 3 ), (1 , 4 , 3 ), (2 , 4 , 1 ), torch .float16 ),
145165 ]
146166 args = get_args ()
147167 lib = open_lib ()
@@ -150,6 +170,9 @@ def test_cpu(lib, test_cases):
150170 infiniopHandle_t ,
151171 POINTER (infiniopWhereDescriptor_t ),
152172 infiniopTensorDescriptor_t ,
173+ infiniopTensorDescriptor_t ,
174+ infiniopTensorDescriptor_t ,
175+ infiniopTensorDescriptor_t
153176 ]
154177 lib .infiniopWhere .restype = c_int32
155178 lib .infiniopWhere .argtypes = [
@@ -162,5 +185,8 @@ def test_cpu(lib, test_cases):
162185 ]
163186 lib .infiniopDestroyWhereDescriptor .restype = c_int32
164187 lib .infiniopDestroyWhereDescriptor .argtypes = [infiniopWhereDescriptor_t ]
165- test_cpu (lib , test_cases )
188+ if args .cpu :
189+ test_cpu (lib , test_cases )
190+ if args .cuda :
191+ test_cuda (lib , test_cases )
166192 print ("All tests passed!" )
0 commit comments