@@ -43,6 +43,8 @@ class Metrics(object):
4343 ROUGE_2_F = "rouge_2_fscore"
4444 ROUGE_L_F = "rouge_L_fscore"
4545 EDIT_DISTANCE = "edit_distance"
46+ SET_PRECISION = "set_precision"
47+ SET_RECALL = "set_recall"
4648
4749
4850def padded_rmse (predictions , labels , weights_fn = common_layers .weights_all ):
@@ -189,6 +191,54 @@ def padded_accuracy(predictions,
189191 return tf .to_float (tf .equal (outputs , padded_labels )), weights
190192
191193
def set_precision(predictions,
                  labels,
                  weights_fn=common_layers.weights_nonzero):
  """Compares set-valued predictions with the true label set (precision).

  Args:
    predictions: A Tensor whose last dimension indexes the candidate labels
      (described by the original author as scores of shape [batch, nlabels]).
    labels: A Tensor of integer label ids giving the true set elements. Axes
      2 and 3 are squeezed away, so both must have size 1; what remains is
      treated as [batch, seq_length].
    weights_fn: A function applied to the squeezed labels to produce the
      weights that are returned alongside the hits.

  Returns:
    hits: A float Tensor holding 1.0 wherever the multi-hot label mask equals
      `predictions` and 0.0 elsewhere.
    weights: The output of `weights_fn` on the squeezed labels.
  """
  # NOTE(review): this body performs the same computation as `set_recall`;
  # confirm whether precision and recall are meant to differ here.
  with tf.variable_scope("set_precision", values=[predictions, labels]):
    label_ids = tf.squeeze(labels, [2, 3])
    # Weights come from the raw ids, before they are expanded to a mask.
    weights = weights_fn(label_ids)

    # One-hot every id, then collapse axis 1 so each label that occurs
    # anywhere along that axis is flagged once (a multi-hot membership mask).
    num_labels = predictions.shape[-1]
    membership = tf.cast(
        tf.reduce_max(tf.one_hot(label_ids, num_labels), axis=1), tf.bool)

    # NOTE(review): `membership` is boolean while `predictions` is documented
    # as scores; presumably the caller passes boolean predictions -- confirm.
    hits = tf.to_float(tf.equal(membership, predictions))
    return hits, weights
216+
217+
def set_recall(predictions,
               labels,
               weights_fn=common_layers.weights_nonzero):
  """Compares set-valued predictions with the true label set (recall).

  Args:
    predictions: A Tensor whose last dimension indexes the candidate labels
      (described by the original author as scores of shape [batch, nlabels]).
    labels: A Tensor of integer label ids giving the true set elements. Axes
      2 and 3 are squeezed away, so both must have size 1; what remains is
      treated as [batch, seq_length].
    weights_fn: A function applied to the squeezed labels to produce the
      weights that are returned alongside the hits.

  Returns:
    hits: A float Tensor holding 1.0 wherever the multi-hot label mask equals
      `predictions` and 0.0 elsewhere.
    weights: The output of `weights_fn` on the squeezed labels.
  """
  # NOTE(review): this body performs the same computation as `set_precision`;
  # confirm whether precision and recall are meant to differ here.
  with tf.variable_scope("set_recall", values=[predictions, labels]):
    true_ids = tf.squeeze(labels, [2, 3])
    # Weights are derived from the ids themselves, not from the mask below.
    weights = weights_fn(true_ids)

    # Expand ids to one-hot rows and take the max over axis 1, which marks
    # every label present along that axis exactly once.
    one_hot_ids = tf.one_hot(true_ids, predictions.shape[-1])
    present = tf.reduce_max(one_hot_ids, axis=1)
    present = tf.cast(present, tf.bool)

    # NOTE(review): `present` is boolean while `predictions` is documented as
    # scores; presumably the caller passes boolean predictions -- confirm.
    matches = tf.equal(present, predictions)
    return tf.to_float(matches), weights
240+
241+
192242def create_evaluation_metrics (problems , model_hparams ):
193243 """Creates the evaluation metrics for the model.
194244
@@ -281,4 +331,6 @@ def wrapped_metric_fn():
281331 Metrics .ROUGE_2_F : rouge .rouge_2_fscore ,
282332 Metrics .ROUGE_L_F : rouge .rouge_l_fscore ,
283333 Metrics .EDIT_DISTANCE : sequence_edit_distance ,
334+ Metrics .SET_PRECISION : set_precision ,
335+ Metrics .SET_RECALL : set_recall ,
284336}
0 commit comments