@@ -262,6 +262,32 @@ def _prepare_loss_weights(loss_weights, output_names):
                      'got {}'.format(str(loss_weights)))
 
 
+def _clone_metrics(metrics):
+  """Creates a copy of the maybe-nested metric specification.
+
+  Args:
+    metrics: A collection of metric specifications. Supports the same set of
+      formats as the `metrics` argument in `tf.keras.Model.compile`.
+
+  Returns:
+    The same format as the `metrics` argument, with all `tf.keras.metrics.Metric`
+    objects replaced by their copies.
+  """
+
+  def clone(metric):
+    # A `Metric` object is stateful and can only be used in 1 model on 1 output.
+    # Cloning the object allows the same metric to be applied in both base and
+    # adversarial-regularized models, and also on multiple outputs in one model.
+    # The cloning logic is the same as the `clone_metric` function in
+    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py
+    if not isinstance(metric, keras.metrics.Metric):
+      return metric
+    with tf.init_scope():
+      return metric.__class__.from_config(metric.get_config())
+
+  return tf.nest.map_structure(clone, metrics)
+
+
 def _prepare_metric_fns(metrics, output_names, loss_wrappers):
   """Converts `metrics` into a list of per-output list of metrics.
 
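
For illustration, a minimal sketch of how the new helper behaves, assuming `_clone_metrics` is in scope; the metric spec below is made up. Non-`Metric` entries such as string aliases pass through unchanged, while stateful `Metric` objects are rebuilt from their config so each copy starts with fresh state.

from tensorflow import keras

# Hypothetical metric spec mixing a string alias with a stateful Metric object.
spec = {'label': ['accuracy', keras.metrics.AUC()]}

cloned = _clone_metrics(spec)
assert cloned['label'][0] == 'accuracy'            # strings pass through as-is
assert cloned['label'][1] is not spec['label'][1]  # the AUC object is a fresh copy
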
@@ -290,16 +316,16 @@ def _prepare_metric_fns(metrics, output_names, loss_wrappers):
   to_list = lambda x: x if isinstance(x, list) else [x]
 
   if isinstance(metrics, collections.Mapping):
-    # If `metrics` is a dictionary mapping output name to a list of metric fns,
-    # coverts it to a list of lists using the order in `output_names`.
+    # Converts `metrics` from a dictionary to a list of lists using the order
+    # specified in `output_names`.
     metrics = [to_list(metrics.get(name, [])) for name in output_names]
 
   if not any(isinstance(m, list) for m in metrics):
-    # If `metrics` is a list of metric fns, replicates them to be a list of
-    # lists so that all metric fns can be applied to each output.
-    metrics = [metrics for _ in output_names]
+    # Replicates `metrics` to be a list of lists if it is a plain list of
+    # metrics, so that all metrics can be applied to each output.
+    metrics = [metrics] + [_clone_metrics(metrics) for _ in output_names[1:]]
 
-  # Here `metrics` is a list of lists, each sub-list corresponds to metric fns
+  # Here `metrics` is a list of lists, and each sub-list corresponds to metrics
   # to be applied on an output.
   if len(metrics) != len(output_names):
     raise ValueError('The number of sub-lists in `metrics` should be the '
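
For reference, a sketch (not from the diff) of the replication step on a hypothetical two-output model: a plain list of metrics is expanded to one sub-list per output, and every output after the first receives clones rather than the original stateful objects, so no metric state is shared across outputs.

from tensorflow import keras

output_names = ['output_1', 'output_2']        # hypothetical output names

shared = ['mae', keras.metrics.AUC()]          # plain-list form, no per-output nesting
replicated = [shared] + [_clone_metrics(shared) for _ in output_names[1:]]
assert replicated[0][1] is shared[1]           # first output keeps the originals
assert replicated[1][1] is not shared[1]       # second output gets an independent copy
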
@@ -326,6 +352,7 @@ def _compute_loss_and_metrics(losses,
       outputs. Must have the same length as `labels` and `outputs`.
     metrics: List of list of (metric fn, metric name) pairs, for additional
       metrics to report for each output. Must have the same length as `outputs`.
+      If set to `None`, no additional metrics will be reported.
     labels: List of `Tensor` objects of ground truth targets. Must have the same
       length as `losses` and `outputs`.
     outputs: List of `Tensor` objects of predicted targets. Must have the same
@@ -334,17 +361,26 @@ def _compute_loss_and_metrics(losses,
 
   Returns:
     total_loss: Weighted sum of losses on all outputs.
-    metrics: List of (value, name) pairs for metric reporting.
+    metrics: List of (value, aggregation, name) tuples for metric reporting.
   """
   outputs = tf.nest.flatten(outputs)
   total_loss, output_metrics = [], []
+  if metrics is None:
+    metrics = [[]] * len(losses)
   for (label, output, loss, per_output_metrics) in zip(labels, outputs, losses,
                                                        metrics):
     loss_value = loss(label, output, sample_weights)
     total_loss.append(loss.weight * loss_value)
-    output_metrics.append((loss_value, loss.name))
+    output_metrics.append((loss_value, 'mean', loss.name))
     for metric_fn, metric_name in per_output_metrics:
-      output_metrics.append((metric_fn(label, output), metric_name))
+      value = metric_fn(label, output)
+      # Metric objects always return an aggregated result, and shouldn't be
+      # aggregated again.
+      if isinstance(metric_fn, keras.metrics.Metric):
+        aggregation = None
+      else:
+        aggregation = 'mean'
+      output_metrics.append((value, aggregation, metric_name))
   return tf.add_n(total_loss), output_metrics
 
 
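
The aggregation flag encodes a real difference between the two kinds of metric specs. As an illustration (not from the diff): a plain metric function returns per-example values that Keras still needs to average, whereas a `Metric` object keeps its own running state and already returns an aggregated scalar, so averaging it a second time would be wrong.

import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([[0.], [1.], [1.]])
y_pred = tf.constant([[0.2], [0.7], [0.3]])

# A metric *function* yields one value per example; these are the tensors that
# get reported with aggregation='mean'.
per_example = keras.metrics.binary_accuracy(y_true, y_pred)   # shape [3]

# A Metric *object* accumulates state and returns an already-aggregated scalar,
# which is why it is reported with aggregation=None.
acc = keras.metrics.BinaryAccuracy()
aggregated = acc(y_true, y_pred)                              # scalar
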
@@ -451,7 +487,7 @@ def compile(self,
     self.base_model.compile(
         optimizer,
         loss=self._compile_arg_loss,
-        metrics=self._compile_arg_metrics,
+        metrics=_clone_metrics(self._compile_arg_metrics),
         loss_weights=self._compile_arg_loss_weights,
         **kwargs)
 
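
The clone here is what lets a user pass stateful metrics once: `compile` forwards the same `metrics` argument to `self.base_model.compile`, and without the copy a single `Metric` object would end up shared between the wrapper and the base model. A hedged usage sketch, assuming the usual `neural_structured_learning` entry points:

import neural_structured_learning as nsl
from tensorflow import keras

base_model = keras.Sequential([keras.layers.Dense(1, activation='sigmoid')])
adv_model = nsl.keras.AdversarialRegularization(
    base_model, label_keys=['label'],
    adv_config=nsl.configs.make_adv_reg_config(multiplier=0.2))

# The AUC object is cloned before being handed to base_model.compile, so the
# wrapper and the base model each update an independent copy of it.
adv_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=[keras.metrics.AUC()])
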
@@ -517,6 +553,9 @@ def _build_labeled_metrics(self, output_names, labeled_losses):
       per_output_metrics = []
       for metric_fn in metric_fns:
         metric_name = self._make_metric_name(metric_fn, label_key)
+        if isinstance(metric_fn, keras.metrics.Metric):
+          # Updates the name of the Metric object to make sure it is unique.
+          metric_fn._name = metric_name  # pylint: disable=protected-access
         per_output_metrics.append((metric_fn, metric_name))
       self._labeled_metrics.append(per_output_metrics)
 
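
The rename matters because Keras reports results under `Metric.name`, and applying the same spec to several outputs (or to both models) would otherwise produce colliding names. A small illustration with made-up names, assuming `_clone_metrics` is in scope; writing `_name` is the same protected-attribute workaround used above:

from tensorflow import keras

auc = keras.metrics.AUC()        # default name: 'auc'
auc_copy = _clone_metrics(auc)   # the clone keeps the name 'auc' from the config

# Without a rename both copies would report under 'auc'.
auc_copy._name = 'auc_output_2'  # pylint: disable=protected-access
assert auc.name != auc_copy.name
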
@@ -526,9 +565,10 @@ def _get_or_create_base_output_names(self, outputs):
            ['output_%d' % i for i in range(1, num_output + 1)])
 
   def _compute_total_loss(self, labels, outputs, sample_weights=None):
-    loss, _ = _compute_loss_and_metrics(self._labeled_losses,
-                                        self._labeled_metrics, labels, outputs,
-                                        sample_weights)
+    # `None` is passed instead of the actual metrics in order to skip computing
+    # metric values and updating metric states.
+    loss, _ = _compute_loss_and_metrics(self._labeled_losses, None, labels,
+                                        outputs, sample_weights)
     return loss
 
   def _split_inputs(self, inputs):
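
Skipping the metrics here is not just an optimization: this loss function is also evaluated on adversarially perturbed inputs, and because `Metric` objects accumulate state across calls, evaluating them there would fold the perturbed-batch results into the reported numbers. A tiny demonstration of that statefulness (values are made up):

from tensorflow import keras

m = keras.metrics.BinaryAccuracy()
m.update_state([[1.]], [[0.9]])    # clean batch: running accuracy is 1.0
print(float(m.result()))           # 1.0
m.update_state([[1.]], [[0.1]])    # a second evaluation keeps accumulating
print(float(m.result()))           # 0.5 -- the earlier result is now diluted
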
@@ -575,8 +615,8 @@ def call(self, inputs, **kwargs):
     outputs, labeled_loss, metrics, tape = self._forward_pass(
         inputs, labels, sample_weights, kwargs)
     self.add_loss(labeled_loss)
-    for value, name in metrics:
-      self.add_metric(value, aggregation='mean', name=name)
+    for value, aggregation, name in metrics:
+      self.add_metric(value, aggregation=aggregation, name=name)
 
     # Adversarial loss.
     base_model_fn = lambda inputs: self.base_model(inputs, **kwargs)
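
How the three-element tuples are consumed, sketched in isolation: the model below is a made-up subclassed example, not the class in this file, and assumes the TF 2.x `add_metric(value, aggregation=..., name=...)` signature. `Metric` results go through with `aggregation=None`, plain per-example tensors with `aggregation='mean'`.

from tensorflow import keras


class ExampleModel(keras.Model):  # hypothetical, for illustration only

  def __init__(self):
    super(ExampleModel, self).__init__()
    self.auc = keras.metrics.AUC()

  def call(self, inputs):
    labels, preds = inputs
    metrics = [
        # Already aggregated by the Metric object, so no further aggregation.
        (self.auc(labels, preds), None, 'auc'),
        # Per-example values that Keras averages across the batch.
        (keras.metrics.binary_accuracy(labels, preds), 'mean', 'accuracy'),
    ]
    for value, aggregation, name in metrics:
      self.add_metric(value, aggregation=aggregation, name=name)
    return preds
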