@@ -144,6 +144,30 @@ def _create_modalities(self, problem_hparams, hparams):
   def has_input(self):
     return self._problem_hparams.input_modality
 
+  def eval_autoregressive(self,
+                          features=None,
+                          decode_length=50,
+                          last_position_only=False):
+    """Autoregressive eval.
+
+    Quadratic time in decode_length.
+
+    Args:
+      features: a map of string to `Tensor`.
+      decode_length: an integer. How many additional timesteps to decode.
+      last_position_only: a boolean, speed up by computing the last position
+        only.
+
+    Returns:
+      sharded_logits: a list of `Tensor`s. Assumes one datashard.
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
+          Contains a single key "training".
+    """
+    _, logits, losses = self._greedy_infer(
+        features,
+        decode_length=decode_length,
+        last_position_only=last_position_only)
+    return [logits], losses
+
   def infer(self,
             features=None,
             decode_length=50,
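The new `eval_autoregressive` wraps `_greedy_infer` and surfaces its logits and loss for evaluation. A minimal caller sketch follows; `model` and `features` are hypothetical stand-ins for a constructed `T2TModel` and an eval batch, neither of which appears in this diff:

```python
import tensorflow as tf

# Hypothetical caller sketch: `model` and `features` are assumed to exist.
# Decoding is autoregressive (quadratic in decode_length), and the returned
# list holds one logits tensor because a single datashard is assumed.
sharded_logits, losses = model.eval_autoregressive(features, decode_length=50)

logits = sharded_logits[0]               # [batch, time, 1, 1, vocab_size]
predictions = tf.argmax(logits, axis=-1)
eval_loss = losses["training"]           # the dictionary's single key
```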
@@ -179,11 +203,13 @@ def infer(self,
       beam_size = 1  # No use to run beam-search for a single class.
     if beam_size == 1:
       tf.logging.info("Greedy Decoding")
-      return self._greedy_infer(features, decode_length, last_position_only)
+      samples, _, _ = self._greedy_infer(features, decode_length,
+                                         last_position_only)
     else:
       tf.logging.info("Beam Decoding with beam size %d" % beam_size)
-      return self._beam_decode(features, decode_length, beam_size, top_beams,
-                               last_position_only, alpha)
+      samples = self._beam_decode(features, decode_length, beam_size, top_beams,
+                                  last_position_only, alpha)
+    return samples
 
   def _beam_decode(self, features, decode_length, beam_size, top_beams,
                    last_position_only, alpha):
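Both branches now fall through to a single `return samples`, so `infer` keeps one uniform return type while logits and losses stay internal to `_greedy_infer`. A hypothetical sketch contrasting the two entry points, assuming `model` and `features` as above and `infer`'s remaining keyword arguments (`beam_size` and friends) as referenced in this hunk:

```python
# Hypothetical sketch; `model` and `features` are assumed to exist.
# infer returns samples only, whichever decoding branch runs:
greedy_samples = model.infer(features, decode_length=50, beam_size=1)
beam_samples = model.infer(features, decode_length=50, beam_size=4)

# Callers that also need logits and losses use eval_autoregressive instead:
sharded_logits, losses = model.eval_autoregressive(features)
```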
@@ -268,6 +294,8 @@ def _greedy_infer(self, features, decode_length, last_position_only):
 
     Returns:
       samples: an integer `Tensor`.
+      logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
     """
     if not features:
       features = {}
@@ -278,14 +306,15 @@ def _greedy_infer(self, features, decode_length, last_position_only):
     if not self.has_input:
       features["partial_targets"] = tf.to_int64(features["inputs"])
 
-    def infer_step(recent_output, _):
+    def infer_step(recent_output, recent_logits, unused_loss):
       """Inference step."""
       recent_output.set_shape([None, None, None, 1])
       padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
       features["targets"] = padded
       # This is inefficient in that it generates samples at all timesteps,
       # not just the last one, except if last_position_only is set (dangerous).
-      samples = self.sample(features, last_position_only=last_position_only)
+      samples, logits, losses = self.sample(
+          features, last_position_only=last_position_only)
       # Concatenate the already-generated recent_output with last timestep
       # of the newly-generated samples.
       if last_position_only:
@@ -295,7 +324,11 @@ def infer_step(recent_output, _):
         cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
         samples = tf.concat([recent_output, cur_sample], axis=1)
       samples.set_shape([None, None, None, 1])
-      return samples
+
+      # Assuming we have one shard for logits.
+      logits = tf.concat([recent_logits, logits[0][:, -1:]], 1)
+      loss = sum(losses.values())
+      return samples, logits, loss
 
     # Create an initial output tensor. This will be passed
     # to the infer_step, which adds one timestep at every iteration.
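`infer_step` now threads a `(samples, logits, loss)` triple through the loop: the newest timestep of logits is sliced off the single assumed datashard and appended along the time axis, and the shard losses are reduced to one scalar. A standalone TF 1.x illustration of that slice-and-concat, with made-up shapes:

```python
import tensorflow as tf

# Standalone illustration of the accumulation above. logits[0] selects the
# single assumed datashard; [:, -1:] keeps only the newest timestep, which
# concat then appends to the running recent_logits along axis 1 (time).
batch_size, vocab_size = 2, 5
recent_logits = tf.zeros((batch_size, 3, 1, 1, vocab_size))  # 3 steps so far
sharded = [tf.ones((batch_size, 4, 1, 1, vocab_size))]       # one datashard

updated = tf.concat([recent_logits, sharded[0][:, -1:]], 1)

with tf.Session() as sess:
  print(sess.run(tf.shape(updated)))  # [2 4 1 1 5]
```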
@@ -308,20 +341,32 @@ def infer_step(recent_output, _):
     # input shape, so we confuse it about the input shape.
     initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                               tf.shape(initial_output))
-    if _is_class_modality(
-        self._hparams.problems[self._problem_idx].target_modality):
+    target_modality = self._hparams.problems[self._problem_idx].target_modality
+    if _is_class_modality(target_modality):
       decode_length = 1
     else:
       decode_length = tf.shape(features["inputs"])[1] + decode_length
-    result = tf.foldl(
-        infer_step,
-        tf.range(decode_length),
-        initializer=initial_output,
+    # Initial values of result, logits and loss.
+    result = initial_output
+    # tensor of shape [batch_size, time, 1, 1, vocab_size]
+    logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
+    logits.set_shape([None, None, None, None, None])
+    loss = 0.0
+
+    result, logits, loss = tf.while_loop(
+        lambda result, logits, loss: tf.shape(result)[1] < decode_length,
+        infer_step, [result, logits, loss],
+        shape_invariants=[
+            tf.TensorShape([None, None, None, None]),
+            tf.TensorShape([None, None, None, None, None]),
+            tf.TensorShape([]),
+        ],
         back_prop=False,
         parallel_iterations=1)
     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
-    return result
+    losses = {"training": loss}
+    return result, logits, losses
 
   def sample(self, features, last_position_only=False):
     """Run the model and extract samples.
@@ -332,8 +377,10 @@ def sample(self, features, last_position_only=False):
 
     Returns:
       samples: an integer `Tensor`.
+      logits: a list of `Tensor`s, one per datashard.
+      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
     """
-    sharded_logits, _ = self.model_fn(
+    sharded_logits, losses = self.model_fn(
         features, False, last_position_only=last_position_only)
     if self._hparams.sampling_method == "argmax":
       sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
@@ -349,7 +396,7 @@ def _multinomial_squeeze(logits):
 
       sharded_samples = self._data_parallelism(_multinomial_squeeze,
                                                sharded_logits)
-    return tf.concat(sharded_samples, 0)
+    return tf.concat(sharded_samples, 0), sharded_logits, losses
 
   def _shard_features(self, features):  # pylint: disable=missing-docstring
     sharded_features = dict()
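`sample` now surfaces the per-shard logits and the loss dictionary alongside the concatenated samples, which is what lets `infer_step` above accumulate them. For reference, a standalone TF 1.x sketch of the multinomial branch's reshape/sample/squeeze pattern; shapes are made up, and the real `_multinomial_squeeze` body is not shown in this diff:

```python
import tensorflow as tf

# Standalone sketch of multinomial sampling over 5-D logits: flatten to
# [batch * time, vocab], draw one id per row, then restore the leading
# dimensions, the same shape dance _multinomial_squeeze performs.
logits = tf.random_normal((2, 3, 1, 1, 5))  # [batch, time, 1, 1, vocab]
flat = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
choices = tf.multinomial(flat, 1)
samples = tf.reshape(choices, tf.shape(logits)[:4])

with tf.Session() as sess:
  print(sess.run(tf.shape(samples)))  # [2 3 1 1]
```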