@@ -26,6 +26,7 @@
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import registry
@@ -523,9 +524,9 @@ def model_fn(self, features, skip=False, last_position_only=False):
     with tf.variable_scope(target_modality.name, reuse=target_reuse):
       if not last_position_only:
         sharded_logits = target_modality.top_sharded(
-            body_outputs, sharded_features["targets"], self._data_parallelism)
+            body_outputs, sharded_features["targets"], dp)
         training_loss = target_modality.loss_sharded(
-            sharded_logits, sharded_features["targets"], self._data_parallelism)
+            sharded_logits, sharded_features["targets"], dp)
 
         training_loss *= self._problem_hparams.loss_multiplier
       else:
@@ -543,9 +544,60 @@ def model_fn(self, features, skip=False, last_position_only=False):
             last_position_targets,
             self._data_parallelism)
         training_loss = None
+    losses["training"] = training_loss
+
+    # Scheduled sampling.
+    do_scheduled_sampling = (  # Only do it if training and set for it.
+        self._hparams.scheduled_sampling_prob > 0.0 and
+        self._hparams.mode == tf.estimator.ModeKeys.TRAIN and
+        not skip)
+    if do_scheduled_sampling:
+
+      def sample(x):
+        """Multinomial sampling from an n-dimensional tensor."""
+        vocab_size = target_modality.top_dimensionality
+        samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
+        reshaped_samples = tf.reshape(samples, tf.shape(x)[:-1])
+        return tf.to_int32(reshaped_samples)
+
+      def mix_gold_sampled(gold_targets, sampled_targets):
+        return tf.where(
+            tf.less(tf.random_uniform(tf.shape(sampled_targets)),
+                    self._hparams.scheduled_sampling_gold_mixin_prob),
+            gold_targets, sampled_targets)
+
+      def sampled_results():
+        """Generate scheduled sampling results."""
+        sampled_targets = dp(sample, sharded_logits)
+        new_targets = dp(mix_gold_sampled,
+                         sharded_features["targets"], sampled_targets)
+        new_features = transformed_features
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+          with tf.variable_scope(target_modality.name):
+            new_features["targets"] = target_modality.targets_bottom_sharded(
+                new_targets, dp)
+          with tf.variable_scope("body"):
+            body_outputs, losses = self.model_fn_body_sharded(new_features)
+            if not isinstance(losses, dict):  # If it's a single extra loss.
+              losses = {"extra": losses}
+          with tf.variable_scope(target_modality.name):
+            new_sharded_logits = target_modality.top_sharded(
+                body_outputs, sharded_features["targets"], dp)
+            # Score the re-decoded logits against the gold targets.
+            training_loss = target_modality.loss_sharded(
+                new_sharded_logits, sharded_features["targets"], dp)
+            training_loss *= self._problem_hparams.loss_multiplier
+            losses["training"] = training_loss
+        return new_sharded_logits, losses
+      # Run the above conditionally.
+      prob = self._hparams.scheduled_sampling_prob
+      prob *= common_layers.inverse_exp_decay(
+          self._hparams.scheduled_sampling_warmup_steps, min_value=0.001)
+      sharded_logits, losses = tf.cond(
+          tf.less(tf.random_uniform([]), prob),
+          sampled_results,
+          lambda: (sharded_logits, losses))
 
     tf.logging.info("This model_fn took %.3f sec." % (time.time() - start_time))
-    losses["training"] = training_loss
     return sharded_logits, losses
 
   def model_fn_body_sharded(self, sharded_features):
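
Note: the scheduled-sampling path added above boils down to "sample token ids
from the model's own logits, then randomly mix them with the gold targets."
A minimal standalone sketch of that step, using the same TF1-era ops as the
diff (tf.multinomial, tf.where); the shapes and the 0.5 mixin probability are
illustrative only, not values taken from the repo:

import tensorflow as tf

vocab_size = 8
batch, length = 2, 5
logits = tf.random_normal([batch, length, vocab_size])
gold_targets = tf.random_uniform(
    [batch, length], maxval=vocab_size, dtype=tf.int32)

# Draw one token id per position from the model's output distribution.
samples = tf.multinomial(tf.reshape(logits, [-1, vocab_size]), 1)
sampled_targets = tf.to_int32(tf.reshape(samples, tf.shape(logits)[:-1]))

# Keep the gold token with probability gold_mixin_prob, else the sample.
gold_mixin_prob = 0.5  # illustrative stand-in for the hparam
new_targets = tf.where(
    tf.less(tf.random_uniform(tf.shape(sampled_targets)), gold_mixin_prob),
    gold_targets, sampled_targets)

with tf.Session() as sess:
  print(sess.run(new_targets))  # int32, shape [batch, length]

In the diff, these mixed targets are then fed through the target modality's
bottom and the model body a second time under reuse=True, so the extra pass
creates no new variables.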
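
The gate probability is not the raw hparam: it is warmed up by
common_layers.inverse_exp_decay, which (assuming the usual t2t definition of
that helper) ramps exponentially from min_value at step 0 to 1.0 once the
warmup steps have passed. A hedged plain-Python sketch of that schedule; the
step values and hparams below are illustrative stand-ins:

def inverse_exp_decay(max_step, step, min_value=0.001):
  # min_value ** ((max_step - step) / max_step), saturating at 1.0.
  return min_value ** (max(max_step - step, 0) / float(max_step))

warmup = 10000     # stand-in for hparams.scheduled_sampling_warmup_steps
base_prob = 0.5    # stand-in for hparams.scheduled_sampling_prob
for step in (0, 2500, 5000, 10000, 20000):
  print(step, base_prob * inverse_exp_decay(warmup, step))

Early in training the second (sampled) pass almost never runs, so its extra
cost is mostly paid only once the model produces reasonable samples.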