@@ -2817,6 +2817,31 @@ def func(x):
28172817 return tf .identity (y , name = _TFOUTPUT )
28182818 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 )
28192819
2820+ @check_opset_min_version (7 , "batchnorm" )
2821+ @check_tf_min_version ("2.0" , "tf-1.x does not support NDHWC" )
2822+ def test_fused_batchnorm_3d (self ):
2823+ x_shape = [1 , 28 , 28 , 2 , 2 ]
2824+ x_dtype = np .float32
2825+ scale_dtype = np .float32
2826+ scale_shape = [2 ]
2827+ data_format = "NDHWC"
2828+ x_val = np .random .random_sample (x_shape ).astype (x_dtype )
2829+ scale_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2830+ offset_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2831+ mean_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2832+ var_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2833+ def func (x ):
2834+ scale = tf .constant (scale_val , name = 'scale' )
2835+ offset = tf .constant (offset_val , name = 'offset' )
2836+ mean = tf .constant (mean_val , name = 'mean' )
2837+ var = tf .constant (var_val , name = 'variance' )
2838+ epsilon = 0.001
2839+ y , _ , _ = fused_batch_norm (
2840+ x , scale , offset , mean = mean , variance = var ,
2841+ epsilon = epsilon , data_format = data_format , is_training = False )
2842+ return tf .identity (y , name = _TFOUTPUT )
2843+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 )
2844+
28202845 @check_opset_min_version (7 , "batchnorm" )
28212846 @skip_tfjs ("TFJS executes model incorrectly" )
28222847 def test_fused_batchnorm_training (self ):
0 commit comments