@@ -60,14 +60,6 @@ def setUp(self):
6060 self .beta_op = g .get_operation_by_name ('conv1/BatchNorm/beta/read' )
6161 self .beta_op_slice = orm .OpSlice (self .beta_op , orm .Slice (0 , 5 ))
6262
63- self .mean_op = g .get_operation_by_name (
64- 'conv1/BatchNorm/AssignMovingAvg/sub_1' )
65- self .mean_op_slice = orm .OpSlice (self .mean_op , orm .Slice (0 , 5 ))
66-
67- self .std_op = g .get_operation_by_name (
68- 'conv1/BatchNorm/AssignMovingAvg_1/sub_1' )
69- self .std_op_slice = orm .OpSlice (self .std_op , orm .Slice (0 , 5 ))
70-
7163 # Create mock OpRegularizerManager with custom mapping of OpSlice and
7264 # OpGroup.
7365 self .mock_op_reg_manager = mock .create_autospec (orm .OpRegularizerManager )
@@ -78,8 +70,6 @@ def setUp(self):
7870 self .relu_op : [self .relu_op_slice ],
7971 self .gamma_op : [self .gamma_op_slice ],
8072 self .beta_op : [self .beta_op_slice ],
81- self .mean_op : [self .mean_op_slice ],
82- self .std_op : [self .std_op_slice ],
8373 }
8474 def get_op_slices (op ):
8575 return self .op_slice_dict .get (op )
@@ -92,7 +82,7 @@ def get_op_group(op_slice):
9282 self .mock_op_reg_manager .is_source_op .return_value = False
9383 self .mock_op_reg_manager .ops = [
9484 self .batch_norm_op , self .conv_op , self .relu_op , self .gamma_op ,
95- self .beta_op , self . mean_op , self . std_op ]
85+ self .beta_op ]
9686
9787 def testAssignGrouping_NoNeighborGroups (self ):
9888 # No ops have groups.
@@ -109,39 +99,30 @@ def testAssignGrouping_NoNeighborGroups(self):
10999 mock .call (self .gamma_op ),
110100 mock .call (self .beta_op ),
111101 mock .call (self .relu_op ),
112- mock .call (self .mean_op ),
113- mock .call (self .std_op ),
114102 # Initial slice data.
115103 mock .call (self .batch_norm_op ),
116104 mock .call (self .conv_op ),
117105 mock .call (self .gamma_op ),
118106 mock .call (self .beta_op ),
119107 mock .call (self .relu_op ),
120- mock .call (self .mean_op ),
121- mock .call (self .std_op ),
122108 # Reslicing.
123109 mock .call (self .conv_op ),
124110 mock .call (self .gamma_op ),
125111 mock .call (self .beta_op ),
126112 mock .call (self .batch_norm_op ),
127113 mock .call (self .relu_op ),
128- mock .call (self .mean_op ),
129- mock .call (self .std_op ),
130114 # Refreshing slice data.
131115 mock .call (self .conv_op ),
132116 mock .call (self .gamma_op ),
133117 mock .call (self .beta_op ),
134- mock .call (self .relu_op ),
135- mock .call (self .mean_op ),
136- mock .call (self .std_op )])
118+ mock .call (self .relu_op )])
137119
138120 # Verify manager does not group.
139121 self .mock_op_reg_manager .group_op_slices .assert_not_called ()
140122
141123 # Verify manager processes grouping for Conv2D, ReLU, and batch norm ops.
142124 self .mock_op_reg_manager .process_ops .assert_called_once_with (
143- [self .relu_op , self .mean_op , self .std_op , self .conv_op , self .gamma_op ,
144- self .beta_op ])
125+ [self .relu_op , self .conv_op , self .gamma_op , self .beta_op ])
145126 self .mock_op_reg_manager .process_ops_last .assert_called_once_with (
146127 [self .batch_norm_op ])
147128
@@ -167,43 +148,35 @@ def testAssignGrouping_AllInputsGrouped(self):
167148 mock .call (self .gamma_op ),
168149 mock .call (self .beta_op ),
169150 mock .call (self .relu_op ),
170- mock .call (self .mean_op ),
171- mock .call (self .std_op ),
172151 # Initial slice data.
173152 mock .call (self .batch_norm_op ),
174153 mock .call (self .conv_op ),
175154 mock .call (self .gamma_op ),
176155 mock .call (self .beta_op ),
177156 mock .call (self .relu_op ),
178- mock .call (self .mean_op ),
179- mock .call (self .std_op ),
180157 # Reslicing.
181158 mock .call (self .conv_op ),
182159 mock .call (self .gamma_op ),
183160 mock .call (self .beta_op ),
184161 mock .call (self .batch_norm_op ),
185162 mock .call (self .relu_op ),
186- mock .call (self .mean_op ),
187- mock .call (self .std_op ),
188163 # Refreshing slice data.
189164 mock .call (self .conv_op ),
190165 mock .call (self .gamma_op ),
191166 mock .call (self .beta_op ),
192167 mock .call (self .relu_op ),
193- mock .call (self .mean_op ),
194- mock .call (self .std_op ),
195168 # Group batch norm op.
196169 mock .call (self .batch_norm_op )])
197170
198171 # Verify manager groups batch norm with input ops.
199- self .mock_op_reg_manager .group_op_slices .assert_called_once_with (
200- [self .batch_norm_op_slice , self .conv_op_slice , self .gamma_op_slice ,
201- self .beta_op_slice ])
172+ self .mock_op_reg_manager .group_op_slices .assert_has_calls (
173+ [mock .call ([self .batch_norm_op_slice , self .relu_op_slice ]),
174+ mock .call ([self .batch_norm_op_slice , self .conv_op_slice ,
175+ self .gamma_op_slice , self .beta_op_slice ])])
202176
203177 # Verify manager processes grouping for mean_op and std_op which do not have
204178 # groups.
205- self .mock_op_reg_manager .process_ops .assert_called_once_with (
206- [self .mean_op , self .std_op ])
179+ self .mock_op_reg_manager .process_ops .assert_not_called ()
207180 self .mock_op_reg_manager .process_ops_last .assert_not_called ()
208181
209182 def testAssignGrouping_AllOutputsGrouped (self ):
@@ -213,8 +186,6 @@ def testAssignGrouping_AllOutputsGrouped(self):
213186 self .conv_op_slice : self .conv_op_group ,
214187 self .relu_op_slice : self .relu_op_group ,
215188 self .gamma_op_slice : self .conv_op_group ,
216- self .mean_op_slice : self .relu_op_group ,
217- self .std_op_slice : self .relu_op_group ,
218189 }
219190
220191 # Call handler to assign grouping.
@@ -228,31 +199,23 @@ def testAssignGrouping_AllOutputsGrouped(self):
228199 mock .call (self .gamma_op ),
229200 mock .call (self .beta_op ),
230201 mock .call (self .relu_op ),
231- mock .call (self .mean_op ),
232- mock .call (self .std_op ),
233202 # Initial slice data.
234203 mock .call (self .batch_norm_op ),
235204 mock .call (self .conv_op ),
236205 mock .call (self .gamma_op ),
237206 mock .call (self .beta_op ),
238207 mock .call (self .relu_op ),
239- mock .call (self .mean_op ),
240- mock .call (self .std_op ),
241208 # Reslicing.
242209 mock .call (self .conv_op ),
243210 mock .call (self .gamma_op ),
244211 mock .call (self .beta_op ),
245212 mock .call (self .batch_norm_op ),
246213 mock .call (self .relu_op ),
247- mock .call (self .mean_op ),
248- mock .call (self .std_op ),
249214 # Refreshing slice data.
250215 mock .call (self .conv_op ),
251216 mock .call (self .gamma_op ),
252217 mock .call (self .beta_op ),
253- mock .call (self .relu_op ),
254- mock .call (self .mean_op ),
255- mock .call (self .std_op )])
218+ mock .call (self .relu_op )])
256219
257220 # Verify manager does not group.
258221 self .mock_op_reg_manager .group_op_slices .assert_not_called ()
@@ -271,8 +234,6 @@ def testAssignGrouping_AllNeighborsGrouped(self):
271234 self .relu_op_slice : self .relu_op_group ,
272235 self .gamma_op_slice : self .conv_op_group ,
273236 self .beta_op_slice : self .conv_op_group ,
274- self .mean_op_slice : self .relu_op_group ,
275- self .std_op_slice : self .relu_op_group ,
276237 }
277238
278239 # Call handler to assign grouping.
@@ -286,38 +247,29 @@ def testAssignGrouping_AllNeighborsGrouped(self):
286247 mock .call (self .gamma_op ),
287248 mock .call (self .beta_op ),
288249 mock .call (self .relu_op ),
289- mock .call (self .mean_op ),
290- mock .call (self .std_op ),
291250 # Initial slice data.
292251 mock .call (self .batch_norm_op ),
293252 mock .call (self .conv_op ),
294253 mock .call (self .gamma_op ),
295254 mock .call (self .beta_op ),
296255 mock .call (self .relu_op ),
297- mock .call (self .mean_op ),
298- mock .call (self .std_op ),
299256 # Reslicing.
300257 mock .call (self .conv_op ),
301258 mock .call (self .gamma_op ),
302259 mock .call (self .beta_op ),
303260 mock .call (self .batch_norm_op ),
304261 mock .call (self .relu_op ),
305- mock .call (self .mean_op ),
306- mock .call (self .std_op ),
307262 # Refreshing slice data.
308263 mock .call (self .conv_op ),
309264 mock .call (self .gamma_op ),
310265 mock .call (self .beta_op ),
311266 mock .call (self .relu_op ),
312- mock .call (self .mean_op ),
313- mock .call (self .std_op ),
314267 # Group batch norm op.
315268 mock .call (self .batch_norm_op )])
316269
317270 # Verify manager groups batch norm with inputs and outputs.
318271 self .mock_op_reg_manager .group_op_slices .assert_has_calls (
319- [mock .call ([self .batch_norm_op_slice , self .relu_op_slice ,
320- self .mean_op_slice , self .std_op_slice ]),
272+ [mock .call ([self .batch_norm_op_slice , self .relu_op_slice ]),
321273 mock .call ([self .batch_norm_op_slice , self .conv_op_slice ,
322274 self .gamma_op_slice , self .beta_op_slice ])])
323275
@@ -333,8 +285,6 @@ def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
333285 self .relu_op_slice : self .batch_norm_op_group ,
334286 self .gamma_op_slice : self .batch_norm_op_group ,
335287 self .beta_op_slice : self .batch_norm_op_group ,
336- self .mean_op_slice : self .batch_norm_op_group ,
337- self .std_op_slice : self .batch_norm_op_group ,
338288 }
339289
340290 # Call handler to assign grouping.
@@ -348,31 +298,23 @@ def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
348298 mock .call (self .gamma_op ),
349299 mock .call (self .beta_op ),
350300 mock .call (self .relu_op ),
351- mock .call (self .mean_op ),
352- mock .call (self .std_op ),
353301 # Initial slice data.
354302 mock .call (self .batch_norm_op ),
355303 mock .call (self .conv_op ),
356304 mock .call (self .gamma_op ),
357305 mock .call (self .beta_op ),
358306 mock .call (self .relu_op ),
359- mock .call (self .mean_op ),
360- mock .call (self .std_op ),
361307 # Reslicing.
362308 mock .call (self .conv_op ),
363309 mock .call (self .gamma_op ),
364310 mock .call (self .beta_op ),
365311 mock .call (self .batch_norm_op ),
366312 mock .call (self .relu_op ),
367- mock .call (self .mean_op ),
368- mock .call (self .std_op ),
369313 # Refreshing slice data.
370314 mock .call (self .conv_op ),
371315 mock .call (self .gamma_op ),
372316 mock .call (self .beta_op ),
373317 mock .call (self .relu_op ),
374- mock .call (self .mean_op ),
375- mock .call (self .std_op ),
376318 # Group batch norm op.
377319 mock .call (self .batch_norm_op )])
378320
@@ -400,8 +342,6 @@ def is_passthrough(op):
400342 self .relu_op_slice : self .relu_op_group ,
401343 self .gamma_op_slice : self .conv_op_group ,
402344 self .beta_op_slice : self .conv_op_group ,
403- self .mean_op_slice : self .relu_op_group ,
404- self .std_op_slice : self .relu_op_group ,
405345 }
406346
407347 # Call handler to assign grouping.
@@ -415,37 +355,27 @@ def is_passthrough(op):
415355 mock .call (self .gamma_op ),
416356 mock .call (self .beta_op ),
417357 mock .call (self .relu_op ),
418- mock .call (self .mean_op ),
419- mock .call (self .std_op ),
420358 # Initial slice data.
421359 mock .call (self .batch_norm_op ),
422360 mock .call (self .conv_op ),
423361 mock .call (self .gamma_op ),
424362 mock .call (self .beta_op ),
425- mock .call (self .mean_op ),
426- mock .call (self .std_op ),
427363 # Reslicing.
428364 mock .call (self .conv_op ),
429365 mock .call (self .gamma_op ),
430366 mock .call (self .beta_op ),
431367 mock .call (self .batch_norm_op ),
432- mock .call (self .mean_op ),
433- mock .call (self .std_op ),
434368 # Refreshing slice data.
435369 mock .call (self .conv_op ),
436370 mock .call (self .gamma_op ),
437371 mock .call (self .beta_op ),
438- mock .call (self .mean_op ),
439- mock .call (self .std_op ),
440372 # Group batch norm op.
441373 mock .call (self .batch_norm_op )])
442374
443375 # Verify manager groups batch norm with inputs and outputs. ReLU is not
444376 # part of the grouping.
445377 self .mock_op_reg_manager .group_op_slices .assert_has_calls (
446- [mock .call ([self .batch_norm_op_slice , self .mean_op_slice ,
447- self .std_op_slice ]),
448- mock .call ([self .batch_norm_op_slice , self .conv_op_slice ,
378+ [mock .call ([self .batch_norm_op_slice , self .conv_op_slice ,
449379 self .gamma_op_slice , self .beta_op_slice ])])
450380
451381 # Verify manager does not process any additional ops.
@@ -454,12 +384,11 @@ def is_passthrough(op):
454384
455385 def testGetInputOutputOpSlices (self ):
456386 input_ops = [self .conv_op , self .gamma_op , self .beta_op ]
457- output_ops = [self .mean_op , self . std_op , self . relu_op ]
387+ output_ops = [self .relu_op ]
458388
459389 expected_input_op_slices = [
460390 [self .conv_op_slice ], [self .gamma_op_slice ], [self .beta_op_slice ]]
461- expected_output_op_slices = [
462- [self .mean_op_slice ], [self .std_op_slice ], [self .relu_op_slice ]]
391+ expected_output_op_slices = [[self .relu_op_slice ]]
463392
464393 # Instantiate handler.
465394 handler = grouping_op_handler .GroupingOpHandler ()
0 commit comments