
Commit 0eeb116

Niki Parmar authored and Ryan Sepassi committed

Evaluate auto-regressively in t2t. Currently, we use the actual output during eval. To use the output predicted at the previous step instead, the infer code is extended to run eval auto-regressively.

PiperOrigin-RevId: 164755091

1 parent d12cb9d commit 0eeb116
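
To make the distinction concrete: teacher-forced eval conditions each position on the ground-truth targets, while autoregressive eval conditions on the model's own earlier predictions. A minimal framework-free sketch, where model_fn is a hypothetical stand-in mapping a prefix of token ids to next-token logits (not code from this commit):

# Hypothetical helper; `model_fn(prefix) -> list of next-token logits` is assumed.

def teacher_forced_eval(model_fn, targets):
  # Position t sees the true targets[:t], never the model's own samples.
  return [model_fn(targets[:t + 1]) for t in range(len(targets))]

def autoregressive_eval(model_fn, start_id, decode_length):
  # Position t sees only what the model itself emitted at positions < t.
  outputs = [start_id]
  for _ in range(decode_length):
    logits = model_fn(outputs)
    outputs.append(max(range(len(logits)), key=lambda i: logits[i]))
  return outputs[1:]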

File tree

3 files changed: +71, -17 lines changed


tensor2tensor/utils/model_builder.py

Lines changed: 6 additions & 2 deletions

@@ -192,8 +192,12 @@ def nth_model(n):
     # On worker 0 also build graph for problems <= 1.
     # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
     skip_this_one = skip_this_one and (FLAGS.worker_id != 0 or n > 1)
-    sharded_logits, losses_dict = model_class.model_fn(
-        features, skip=(skipping_is_on and skip_this_one))
+    if (FLAGS.eval_run_autoregressive and
+        mode == tf.contrib.learn.ModeKeys.EVAL):
+      sharded_logits, losses_dict = model_class.eval_autoregressive(features)
+    else:
+      sharded_logits, losses_dict = model_class.model_fn(
+          features, skip=(skipping_is_on and skip_this_one))
     with tf.variable_scope("losses_avg"):
       total_loss, ops = 0.0, []
       for loss_key, loss_value in six.iteritems(losses_dict):
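
Since eval_autoregressive returns its losses dict keyed only by "training" (per its docstring in t2t_model.py below), the existing losses_avg loop over six.iteritems(losses_dict) works unchanged regardless of which branch produced the dict.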

tensor2tensor/utils/t2t_model.py

Lines changed: 62 additions & 15 deletions

@@ -144,6 +144,30 @@ def _create_modalities(self, problem_hparams, hparams):
   def has_input(self):
     return self._problem_hparams.input_modality

+  def eval_autoregressive(self,
+                          features=None,
+                          decode_length=50,
+                          last_position_only=False):
+    """Autoregressive eval.
+
+    Quadratic time in decode_length.
+
+    Args:
+      features: a map of string to `Tensor`.
+      decode_length: an integer. How many additional timesteps to decode.
+      last_position_only: a boolean, speed-up by computing last position only.
+
+    Returns:
+      sharded_logits: a list of `Tensor`s. Assumes one datashard.
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
+        Contains a single key "training".
+    """
+    _, logits, losses = self._greedy_infer(
+        features,
+        decode_length=decode_length,
+        last_position_only=last_position_only)
+    return [logits], losses
+
   def infer(self,
             features=None,
             decode_length=50,
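
A note on the "Quadratic time in decode_length" caveat above: unless last_position_only prunes the computation, each decode step re-runs the model over the entire prefix generated so far, so emitting T extra timesteps costs about 1 + 2 + ... + T = T(T+1)/2 forward positions, i.e. O(T^2) overall, versus the single O(T) teacher-forced pass that ordinary eval uses.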
@@ -179,11 +203,13 @@ def infer(self,
       beam_size = 1  # No use to run beam-search for a single class.
     if beam_size == 1:
       tf.logging.info("Greedy Decoding")
-      return self._greedy_infer(features, decode_length, last_position_only)
+      samples, _, _ = self._greedy_infer(features, decode_length,
+                                         last_position_only)
     else:
       tf.logging.info("Beam Decoding with beam size %d" % beam_size)
-      return self._beam_decode(features, decode_length, beam_size, top_beams,
-                               last_position_only, alpha)
+      samples = self._beam_decode(features, decode_length, beam_size, top_beams,
+                                  last_position_only, alpha)
+    return samples

   def _beam_decode(self, features, decode_length, beam_size, top_beams,
                    last_position_only, alpha):
@@ -268,6 +294,8 @@ def _greedy_infer(self, features, decode_length, last_position_only):

     Returns:
       samples: an integer `Tensor`.
+      logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}
     """
     if not features:
       features = {}
@@ -278,14 +306,15 @@ def _greedy_infer(self, features, decode_length, last_position_only):
     if not self.has_input:
       features["partial_targets"] = tf.to_int64(features["inputs"])

-    def infer_step(recent_output, _):
+    def infer_step(recent_output, recent_logits, unused_loss):
       """Inference step."""
       recent_output.set_shape([None, None, None, 1])
       padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
       features["targets"] = padded
       # This is inefficient in that it generates samples at all timesteps,
       # not just the last one, except if last_position_only is set (dangerous).
-      samples = self.sample(features, last_position_only=last_position_only)
+      samples, logits, losses = self.sample(
+          features, last_position_only=last_position_only)
       # Concatenate the already-generated recent_output with last timestep
       # of the newly-generated samples.
       if last_position_only:
@@ -295,7 +324,11 @@ def infer_step(recent_output, _):
         cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
         samples = tf.concat([recent_output, cur_sample], axis=1)
       samples.set_shape([None, None, None, 1])
-      return samples
+
+      # Assuming we have one shard for logits.
+      logits = tf.concat([recent_logits, logits[0][:, -1:]], 1)
+      loss = sum(losses.values())
+      return samples, logits, loss

     # Create an initial output tensor. This will be passed
     # to the infer_step, which adds one timestep at every iteration.
@@ -308,20 +341,32 @@ def infer_step(recent_output, _):
     # input shape, so we confuse it about the input shape.
     initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                               tf.shape(initial_output))
-    if _is_class_modality(
-        self._hparams.problems[self._problem_idx].target_modality):
+    target_modality = self._hparams.problems[self._problem_idx].target_modality
+    if _is_class_modality(target_modality):
       decode_length = 1
     else:
       decode_length = tf.shape(features["inputs"])[1] + decode_length
-    result = tf.foldl(
-        infer_step,
-        tf.range(decode_length),
-        initializer=initial_output,
+    # Initial values of result, logits and loss.
+    result = initial_output
+    # tensor of shape [batch_size, time, 1, 1, vocab_size]
+    logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
+    logits.set_shape([None, None, None, None, None])
+    loss = 0.0
+
+    result, logits, loss = tf.while_loop(
+        lambda result, logits, loss: tf.shape(result)[1] < decode_length,
+        infer_step, [result, logits, loss],
+        shape_invariants=[
+            tf.TensorShape([None, None, None, None]),
+            tf.TensorShape([None, None, None, None, None]),
+            tf.TensorShape([]),
+        ],
         back_prop=False,
         parallel_iterations=1)
     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
-    return result
+    losses = {"training": loss}
+    return result, logits, losses

   def sample(self, features, last_position_only=False):
     """Run the model and extract samples.
@@ -332,8 +377,10 @@ def sample(self, features, last_position_only=False):

     Returns:
       samples: an integer `Tensor`.
+      logits: a list of `Tensor`s, one per datashard.
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
     """
-    sharded_logits, _ = self.model_fn(
+    sharded_logits, losses = self.model_fn(
         features, False, last_position_only=last_position_only)
     if self._hparams.sampling_method == "argmax":
       sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
@@ -349,7 +396,7 @@ def _multinomial_squeeze(logits):

       sharded_samples = self._data_parallelism(_multinomial_squeeze,
                                                sharded_logits)
-    return tf.concat(sharded_samples, 0)
+    return tf.concat(sharded_samples, 0), sharded_logits, losses

   def _shard_features(self, features):  # pylint: disable=missing-docstring
     sharded_features = dict()

tensor2tensor/utils/trainer_utils.py

Lines changed: 3 additions & 0 deletions

@@ -63,6 +63,9 @@
                      "The number of steps to run training for.")
 flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.")
 flags.DEFINE_bool("eval_print", False, "Print eval logits and predictions.")
+flags.DEFINE_bool("eval_run_autoregressive", False,
+                  "Run eval autoregressively where we condition on previous "
+                  "generated output instead of the actual target.")
 flags.DEFINE_integer("keep_checkpoint_max", 20,
                      "How many recent checkpoints to keep.")
 flags.DEFINE_bool("experimental_optimize_placement", False,
