This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 43dbf4c

Lukasz Kaiser authored and Ryan Sepassi committed
First (simple) version of scheduled sampling.
PiperOrigin-RevId: 172038992
1 parent 39fd769 commit 43dbf4c

File tree

3 files changed: +70 −3 lines changed

tensor2tensor/data_generators/wmt.py
Lines changed: 2 additions & 0 deletions

@@ -375,6 +375,8 @@ def _compile_data(tmp_dir, datasets, filename):
         compressed_filename = os.path.basename(url)
         compressed_filepath = os.path.join(tmp_dir, compressed_filename)
 
+        generator_utils.maybe_download(tmp_dir, compressed_filename, url)
+
         if dataset[1][0] == "tsv":
           _, src_column, trg_column, glob_pattern = dataset[1]
           filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
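
The one-line fix above ensures the corpus archive is actually downloaded before the code tries to open it. For readers unfamiliar with the helper, here is a minimal sketch of the download-if-missing idiom that a call like generator_utils.maybe_download relies on; the function body below is an illustrative stand-in, not the library's actual implementation.

# Illustrative sketch only -- not the actual tensor2tensor implementation.
import os
import urllib.request


def maybe_download_sketch(directory, filename, url):
  """Download url to directory/filename unless the file already exists."""
  if not os.path.exists(directory):
    os.makedirs(directory)
  filepath = os.path.join(directory, filename)
  if not os.path.exists(filepath):
    # Only fetch when the file is missing, so repeated data-generation runs
    # do not re-download large corpora.
    urllib.request.urlretrieve(url, filepath)
  return filepath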

tensor2tensor/layers/common_hparams.py
Lines changed: 13 additions & 0 deletions

@@ -160,6 +160,19 @@ def basic_params1():
       # entire inputs portion. This removes the challenge of
       # autoregressively predicting the inputs portion.
       prepend_mode="none",
+      # Scheduled sampling is interesting for auto-regressive models.
+      # It runs an additional step using the generated output as autoregressive
+      # targets, which can improve the models inference results later. The
+      # parameter scheduled_sampling_prob determines with what probability
+      # will such additional step be run. It's turned off (0.0) by default.
+      # This probability will exponentially warm up for the number of
+      # steps determined by scheduled_sampling_warmup_steps.
+      # The tensor used for the second step will consist of outputs from
+      # the first step mixed with gold truth, with the proportion of gold
+      # determined by scheduled_sampling_gold_mixin_prob.
+      scheduled_sampling_prob=0.0,
+      scheduled_sampling_warmup_steps=50000,
+      scheduled_sampling_gold_mixin_prob=0.5,
       # This is the actual batch size, *not* tokens per batch (i.e. for
       # language models this is the number of sentences in the batch)
       tpu_batch_size_per_shard=24,)
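
Usage note: all three knobs default to off (scheduled_sampling_prob=0.0), so existing models are unaffected. As a hypothetical sketch (not part of this commit), one way to turn the feature on is to register an hparams set that overrides the new fields; the set name and the transformer_base starting point below are illustrative assumptions.

# Hypothetical example -- not part of this commit.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_scheduled_sampling():
  """Transformer base hparams with the new scheduled-sampling knobs enabled."""
  hparams = transformer.transformer_base()
  hparams.scheduled_sampling_prob = 0.5  # run the second pass half the time
  hparams.scheduled_sampling_warmup_steps = 10000  # shorter warmup than default
  hparams.scheduled_sampling_gold_mixin_prob = 0.5  # 50% gold, 50% sampled
  return hparams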

tensor2tensor/utils/t2t_model.py
Lines changed: 55 additions & 3 deletions

@@ -26,6 +26,7 @@
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import registry
@@ -523,9 +524,9 @@ def model_fn(self, features, skip=False, last_position_only=False):
     with tf.variable_scope(target_modality.name, reuse=target_reuse):
       if not last_position_only:
         sharded_logits = target_modality.top_sharded(
-            body_outputs, sharded_features["targets"], self._data_parallelism)
+            body_outputs, sharded_features["targets"], dp)
         training_loss = target_modality.loss_sharded(
-            sharded_logits, sharded_features["targets"], self._data_parallelism)
+            sharded_logits, sharded_features["targets"], dp)
 
         training_loss *= self._problem_hparams.loss_multiplier
       else:
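
The switch from self._data_parallelism to the shorter dp alias matters for the next hunk, where dp(...) is used repeatedly to apply a function to every shard of a sharded tensor list. As a rough mental model only (a simplified stand-in for the data-parallelism helper from expert_utils, whose real implementation also handles device placement), dp behaves roughly like this:

# Simplified stand-in for the data-parallelism helper referred to as `dp`.
# It applies `fn` once per shard and collects the per-shard outputs;
# the real helper additionally pins each call to its own device.
class ParallelismSketch(object):

  def __init__(self, num_shards):
    self.num_shards = num_shards

  def __call__(self, fn, *args):
    def shard(arg, i):
      # Lists are treated as per-shard values; anything else is broadcast.
      return arg[i] if isinstance(arg, list) else arg
    return [fn(*[shard(a, i) for a in args]) for i in range(self.num_shards)]


# dp = ParallelismSketch(num_shards)
# sampled_targets = dp(sample, sharded_logits)  # runs sample() on each shard
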
@@ -543,9 +544,60 @@ def model_fn(self, features, skip=False, last_position_only=False):
             last_position_targets,
             self._data_parallelism)
         training_loss = None
+    losses["training"] = training_loss
+
+    # Scheduled sampling.
+    do_scheduled_sampling = (  # Only do it if training and set for it.
+        self._hparams.scheduled_sampling_prob > 0.0 and
+        self._hparams.mode == tf.estimator.ModeKeys.TRAIN and
+        not skip)
+    if do_scheduled_sampling:
+
+      def sample(x):
+        """Multinomial sampling from a n-dimensional tensor."""
+        vocab_size = target_modality.top_dimensionality
+        samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
+        reshaped_samples = tf.reshape(samples, tf.shape(x)[:-1])
+        return tf.to_int32(reshaped_samples)
+
+      def mix_gold_sampled(gold_targets, sampled_targets):
+        return tf.where(
+            tf.less(tf.random_uniform(tf.shape(sampled_targets)),
+                    self._hparams.scheduled_sampling_gold_mixin_prob),
+            gold_targets, sampled_targets)
+
+      def sampled_results():
+        """Generate scheduled sampling results."""
+        sampled_targets = dp(sample, sharded_logits)
+        new_targets = dp(mix_gold_sampled,
+                         sharded_features["targets"], sampled_targets)
+        new_features = transformed_features
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+          with tf.variable_scope(target_modality.name):
+            new_features["targets"] = target_modality.targets_bottom_sharded(
+                new_targets, dp)
+          with tf.variable_scope("body"):
+            body_outputs, losses = self.model_fn_body_sharded(new_features)
+            if not isinstance(losses, dict):  # If it's a single extra loss.
+              losses = {"extra": losses}
+          with tf.variable_scope(target_modality.name):
+            new_sharded_logits = target_modality.top_sharded(
+                body_outputs, sharded_features["targets"], dp)
+            training_loss = target_modality.loss_sharded(
+                sharded_logits, sharded_features["targets"], dp)
+            training_loss *= self._problem_hparams.loss_multiplier
+            losses["training"] = training_loss
+        return new_sharded_logits, losses
+      # Run the above conditionally.
+      prob = self._hparams.scheduled_sampling_prob
+      prob *= common_layers.inverse_exp_decay(
+          self._hparams.scheduled_sampling_warmup_steps, min_value=0.001)
+      sharded_logits, losses = tf.cond(
+          tf.less(tf.random_uniform([]), prob),
+          sampled_results,
+          lambda: (sharded_logits, losses))
 
     tf.logging.info("This model_fn took %.3f sec." % (time.time() - start_time))
-    losses["training"] = training_loss
     return sharded_logits, losses
 
   def model_fn_body_sharded(self, sharded_features):
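
To make the control flow of the new block easier to follow outside the sharded model_fn, here is a self-contained sketch of the same recipe on a single unsharded logits tensor: sample tokens from the current predictions, mix them element-wise with the gold targets, and with a warmed-up probability run a second pass on the mixed targets. The tensor names, the toy second_pass callable, and the linear warmup (standing in for common_layers.inverse_exp_decay) are illustrative assumptions; only the overall structure mirrors the commit.

# Self-contained sketch (TF 1.x style) of the scheduled-sampling step.
# logits: [batch, length, vocab] floats from the first pass.
# gold_targets: [batch, length] int32 ground-truth token ids.
# second_pass: placeholder callable that re-runs the model on new targets.
import tensorflow as tf


def scheduled_sampling_sketch(logits, gold_targets, second_pass, hparams,
                              global_step):
  vocab_size = tf.shape(logits)[-1]

  # 1) Sample token ids from the model's own predictions. tf.multinomial
  #    expects 2-D [N, vocab] inputs, hence the reshape round-trip.
  flat_logits = tf.reshape(logits, [-1, vocab_size])
  samples = tf.multinomial(flat_logits, 1)
  sampled_targets = tf.to_int32(tf.reshape(samples, tf.shape(gold_targets)))

  # 2) Mix gold and sampled tokens: keep the gold token with probability
  #    scheduled_sampling_gold_mixin_prob, independently per position.
  keep_gold = tf.less(tf.random_uniform(tf.shape(sampled_targets)),
                      hparams.scheduled_sampling_gold_mixin_prob)
  mixed_targets = tf.where(keep_gold, gold_targets, sampled_targets)

  # 3) Warm the second-pass probability up to scheduled_sampling_prob over
  #    scheduled_sampling_warmup_steps steps. The commit uses an exponential
  #    warmup via common_layers.inverse_exp_decay; a linear ramp is used here
  #    only to keep the sketch short.
  warmup = tf.minimum(
      tf.to_float(global_step) /
      float(hparams.scheduled_sampling_warmup_steps), 1.0)
  prob = hparams.scheduled_sampling_prob * warmup

  # 4) With that probability, recompute the logits on the mixed targets;
  #    otherwise keep the first-pass logits unchanged.
  return tf.cond(tf.less(tf.random_uniform([]), prob),
                 lambda: second_pass(mixed_targets),
                 lambda: logits)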
