Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit dc190ec

Browse files
Lukasz KaiserRyan Sepassi
authored andcommitted
internal merege
PiperOrigin-RevId: 172006716
1 parent 9aa3326 commit dc190ec

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tensor2tensor/utils/metrics.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class Metrics(object):
4343
ROUGE_2_F = "rouge_2_fscore"
4444
ROUGE_L_F = "rouge_L_fscore"
4545
EDIT_DISTANCE = "edit_distance"
46+
SET_PRECISION = "set_precision"
47+
SET_RECALL = "set_recall"
4648

4749

4850
def padded_rmse(predictions, labels, weights_fn=common_layers.weights_all):
@@ -189,6 +191,54 @@ def padded_accuracy(predictions,
189191
return tf.to_float(tf.equal(outputs, padded_labels)), weights
190192

191193

194+
def set_precision(predictions,
195+
labels,
196+
weights_fn=common_layers.weights_nonzero):
197+
"""Precision of set predictions.
198+
199+
Args:
200+
predictions : A Tensor of scores of shape [batch, nlabels].
201+
labels: A Tensor of int32s giving true set elements,
202+
of shape [batch, seq_length].
203+
weights_fn: A function to weight the elements.
204+
205+
Returns:
206+
hits: A Tensor of shape [batch, nlabels].
207+
weights: A Tensor of shape [batch, nlabels].
208+
"""
209+
with tf.variable_scope("set_precision", values=[predictions, labels]):
210+
labels = tf.squeeze(labels, [2, 3])
211+
weights = weights_fn(labels)
212+
labels = tf.one_hot(labels, predictions.shape[-1])
213+
labels = tf.reduce_max(labels, axis=1)
214+
labels = tf.cast(labels, tf.bool)
215+
return tf.to_float(tf.equal(labels, predictions)), weights
216+
217+
218+
def set_recall(predictions,
219+
labels,
220+
weights_fn=common_layers.weights_nonzero):
221+
"""Recall of set predictions.
222+
223+
Args:
224+
predictions : A Tensor of scores of shape [batch, nlabels].
225+
labels: A Tensor of int32s giving true set elements,
226+
of shape [batch, seq_length].
227+
weights_fn: A function to weight the elements.
228+
229+
Returns:
230+
hits: A Tensor of shape [batch, nlabels].
231+
weights: A Tensor of shape [batch, nlabels].
232+
"""
233+
with tf.variable_scope("set_recall", values=[predictions, labels]):
234+
labels = tf.squeeze(labels, [2, 3])
235+
weights = weights_fn(labels)
236+
labels = tf.one_hot(labels, predictions.shape[-1])
237+
labels = tf.reduce_max(labels, axis=1)
238+
labels = tf.cast(labels, tf.bool)
239+
return tf.to_float(tf.equal(labels, predictions)), weights
240+
241+
192242
def create_evaluation_metrics(problems, model_hparams):
193243
"""Creates the evaluation metrics for the model.
194244
@@ -281,4 +331,6 @@ def wrapped_metric_fn():
281331
Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
282332
Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
283333
Metrics.EDIT_DISTANCE: sequence_edit_distance,
334+
Metrics.SET_PRECISION: set_precision,
335+
Metrics.SET_RECALL: set_recall,
284336
}

0 commit comments

Comments
 (0)