99
1010import numpy as np
1111
12+
1213class TestLayers (unittest .TestCase ):
1314 def convert_model (self ):
14- subprocess .run ([sys .executable ,
15- '-m' ,
16- 'mo' ,
17- '--input_model=model.onnx' ,
18- '--extension' , Path (__file__ ).absolute ().parent / 'mo_extensions' ],
19- check = True )
15+ subprocess .run (
16+ [
17+ sys .executable ,
18+ "-m" ,
19+ "mo" ,
20+ "--input_model=model.onnx" ,
21+ "--extension" ,
22+ Path (__file__ ).absolute ().parent / "mo_extensions" ,
23+ ],
24+ check = True ,
25+ )
2026
2127 def run_test (self , convert_ir = True , test_onnx = False , num_inputs = 1 , threshold = 1e-5 ):
2228 if convert_ir and not test_onnx :
@@ -25,67 +31,63 @@ def run_test(self, convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e-
2531 inputs = {}
2632 shapes = {}
2733 for i in range (num_inputs ):
28- suffix = '{}' .format (i if i > 0 else '' )
29- data = np .load (' inp' + suffix + ' .npy' )
30- inputs [' input' + suffix ] = data
31- shapes [' input' + suffix ] = data .shape
34+ suffix = "{}" .format (i if i > 0 else "" )
35+ data = np .load (" inp" + suffix + " .npy" )
36+ inputs [" input" + suffix ] = data
37+ shapes [" input" + suffix ] = data .shape
3238
33- ref = np .load (' ref.npy' )
39+ ref = np .load (" ref.npy" )
3440
3541 ie = IECore ()
36- ie .add_extension (get_extensions_path (), ' CPU' )
37- ie .set_config ({' CONFIG_FILE' : ' user_ie_extensions/gpu_extensions.xml' }, ' GPU' )
42+ ie .add_extension (get_extensions_path (), " CPU" )
43+ ie .set_config ({" CONFIG_FILE" : " user_ie_extensions/gpu_extensions.xml" }, " GPU" )
3844
39- net = ie .read_network (' model.onnx' if test_onnx else ' model.xml' )
45+ net = ie .read_network (" model.onnx" if test_onnx else " model.xml" )
4046 net .reshape (shapes )
41- exec_net = ie .load_network (net , ' CPU' )
47+ exec_net = ie .load_network (net , " CPU" )
4248
4349 out = exec_net .infer (inputs )
4450 out = next (iter (out .values ()))
4551
4652 diff = np .max (np .abs (ref - out ))
4753 self .assertLessEqual (diff , threshold )
4854
49-
5055 def test_unpool (self ):
5156 from examples .unpool .export_model import export
52- export (mode = 'default' )
53- self .run_test ()
5457
58+ export (mode = "default" )
59+ self .run_test ()
5560
5661 def test_unpool_reshape (self ):
5762 from examples .unpool .export_model import export
58- export (mode = 'dynamic_size' , shape = [5 , 3 , 6 , 9 ])
63+
64+ export (mode = "dynamic_size" , shape = [5 , 3 , 6 , 9 ])
5965 self .run_test ()
6066
61- export (mode = ' dynamic_size' , shape = [4 , 3 , 17 , 8 ])
67+ export (mode = " dynamic_size" , shape = [4 , 3 , 17 , 8 ])
6268 self .run_test (convert_ir = False )
6369
64-
6570 def test_fft (self ):
6671 from examples .fft .export_model import export
6772
6873 for shape in [[5 , 120 , 2 ], [4 , 240 , 320 , 2 ], [3 , 5 , 240 , 320 , 2 ]]:
6974 export (shape = shape )
7075 self .run_test ()
7176
72-
7377 def test_fft_roll (self ):
7478 from examples .fft .export_model_with_roll import export
7579
7680 export ()
7781 self .run_test ()
7882 self .run_test (test_onnx = True )
7983
80-
8184 def test_grid_sample (self ):
8285 from examples .grid_sample .export_model import export
8386
8487 export ()
8588 self .run_test (num_inputs = 2 )
8689 self .run_test (num_inputs = 2 , test_onnx = True )
8790
88-
8991 def test_complex_mul (self ):
9092 from examples .complex_mul .export_model import export
9193
@@ -94,6 +96,23 @@ def test_complex_mul(self):
9496 self .run_test (num_inputs = 2 )
9597 self .run_test (num_inputs = 2 , test_onnx = True )
9698
97-
98- if __name__ == '__main__' :
99+ def test_deformable_conv (self ):
100+ from examples .deformable_conv .export_model import export
101+
102+ export (
103+ inplanes = 15 ,
104+ outplanes = 15 ,
105+ kernel_size = 3 ,
106+ stride = 1 ,
107+ padding = 1 ,
108+ dilation = 1 ,
109+ deformable_groups = 1 ,
110+ inp_shape = [1 , 15 , 128 , 240 ],
111+ offset_shape = [1 , 18 , 128 , 240 ],
112+ )
113+ self .run_test (num_inputs = 2 , threshold = 2e-5 )
114+ self .run_test (num_inputs = 2 , test_onnx = True , threshold = 2e-5 )
115+
116+
117+ if __name__ == "__main__" :
99118 unittest .main ()
0 commit comments