1717)
1818
1919from operatorspy .tests .test_utils import get_args
20+ from enum import Enum , auto
2021import torch
2122import ctypes
2223import torch .nn .functional as F
2829NUM_PRERUN = 10
2930NUM_ITERATIONS = 1000
3031
32+ class Inplace (Enum ):
33+ OUT_OF_PLACE = auto ()
34+ INPLACE_X = auto ()
35+
3136
3237class BatchNormDescriptor (Structure ):
3338 _fields_ = [("device" , c_int32 )]
@@ -64,17 +69,18 @@ def test(
6469 x_shape ,
6570 eps = 1e-5 ,
6671 tensor_dtype = torch .float16 ,
72+ inplace = Inplace .OUT_OF_PLACE ,
6773):
6874 print (
69- f"Testing BatchNorm on { torch_device } with x_shape: { x_shape } , scale_shape: { x_shape [1 ]} , b_shape: { x_shape [1 ]} , mean_shape: { x_shape [1 ]} , var_shape: { x_shape [1 ]} , eps: { eps } dtype:{ tensor_dtype } "
75+ f"Testing BatchNorm on { torch_device } with x_shape: { x_shape } , scale_shape: { x_shape [1 ]} , b_shape: { x_shape [1 ]} , mean_shape: { x_shape [1 ]} , var_shape: { x_shape [1 ]} , eps: { eps } , dtype:{ tensor_dtype } , Inplace: { inplace } "
7076 )
7177 num_channel = x_shape [1 ]
7278 bn_dtype = tensor_dtype if tensor_dtype != torch .float16 else torch .float32
7379 x = torch .rand (x_shape , dtype = tensor_dtype ).to (torch_device ) * 10 - 2
7480 scale = torch .rand (num_channel , dtype = bn_dtype ).to (torch_device )
7581 b = torch .rand (num_channel , dtype = bn_dtype ).to (torch_device )
7682 mean , var = get_mean_variance (x , bn_dtype )
77- y = torch .zeros (x_shape , dtype = tensor_dtype ).to (torch_device )
83+ y = torch .zeros (x_shape , dtype = tensor_dtype ).to (torch_device ) if inplace == Inplace . OUT_OF_PLACE else x
7884
7985 # get the pytorch answer
8086 for i in range (NUM_PRERUN if PROFILE else 1 ):
@@ -92,7 +98,7 @@ def test(
9298 b_tensor = to_tensor (b , lib )
9399 mean_tensor = to_tensor (mean , lib )
94100 var_tensor = to_tensor (var , lib )
95- y_tensor = to_tensor (y , lib )
101+ y_tensor = to_tensor (y , lib ) if inplace == Inplace . OUT_OF_PLACE else x_tensor
96102 descriptor = infiniopBatchNormDescriptor_t ()
97103
98104 check_error (
@@ -145,18 +151,18 @@ def test(
145151def test_cpu (lib , test_cases ):
146152 device = DeviceEnum .DEVICE_CPU
147153 handle = create_handle (lib , device )
148- for x_shape , eps in test_cases :
149- test (lib , handle , "cpu" , x_shape , eps , tensor_dtype = torch .float16 )
150- test (lib , handle , "cpu" , x_shape , eps , tensor_dtype = torch .float32 )
154+ for x_shape , eps , inplace in test_cases :
155+ test (lib , handle , "cpu" , x_shape , eps , tensor_dtype = torch .float16 , inplace = inplace )
156+ test (lib , handle , "cpu" , x_shape , eps , tensor_dtype = torch .float32 , inplace = inplace )
151157 destroy_handle (lib , handle )
152158
153159
154160def test_cuda (lib , test_cases ):
155161 device = DeviceEnum .DEVICE_CUDA
156162 handle = create_handle (lib , device )
157- for x_shape , eps in test_cases :
158- test (lib , handle , "cuda" , x_shape , eps , tensor_dtype = torch .float16 )
159- test (lib , handle , "cuda" , x_shape , eps , tensor_dtype = torch .float32 )
163+ for x_shape , eps , inplace in test_cases :
164+ test (lib , handle , "cuda" , x_shape , eps , tensor_dtype = torch .float16 , inplace = inplace )
165+ test (lib , handle , "cuda" , x_shape , eps , tensor_dtype = torch .float32 , inplace = inplace )
160166 destroy_handle (lib , handle )
161167
162168
@@ -165,19 +171,20 @@ def test_bang(lib, test_cases):
165171
166172 device = DeviceEnum .DEVICE_BANG
167173 handle = create_handle (lib , device )
168- for x_shape , eps in test_cases :
169- test (lib , handle , "mlu" , x_shape , eps , tensor_dtype = torch .float16 )
170- test (lib , handle , "mlu" , x_shape , eps , tensor_dtype = torch .float32 )
174+ for x_shape , eps , inplace in test_cases :
175+ test (lib , handle , "mlu" , x_shape , eps , tensor_dtype = torch .float16 , inplace = inplace )
176+ test (lib , handle , "mlu" , x_shape , eps , tensor_dtype = torch .float32 , inplace = inplace )
171177 destroy_handle (lib , handle )
172178
173179
174180if __name__ == "__main__" :
175181 test_cases = [
176- # x_shape, eps
177- ((2 , 5 , 7 ), 1e-5 ),
178- ((32 , 3 , 1024 ), 1e-5 ),
179- ((32 , 3 , 128 , 128 ), 1e-5 ),
180- ((32 , 3 , 64 , 64 , 64 ), 1e-5 ),
182+ # x_shape, eps, inplace
183+ ((2 , 5 , 7 ), 1e-5 , Inplace .OUT_OF_PLACE ),
184+ ((2 , 5 , 7 ), 1e-5 , Inplace .INPLACE_X ),
185+ ((32 , 3 , 1024 ), 1e-5 , Inplace .OUT_OF_PLACE ),
186+ ((32 , 3 , 128 , 128 ), 1e-5 , Inplace .OUT_OF_PLACE ),
187+ ((32 , 3 , 64 , 64 , 64 ), 1e-5 , Inplace .OUT_OF_PLACE ),
181188 ]
182189 args = get_args ()
183190 lib = open_lib ()
0 commit comments