Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a9826de

Browse files
Niki Parmar and Ryan Sepassi
authored and committed
Extend decode_from_dataset to run decode iteratively for a specified number of samples rather than one
PiperOrigin-RevId: 164761976
1 parent 0eeb116 commit a9826de

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

tensor2tensor/utils/decoding.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ def decode_from_dataset(estimator):
4545
tf.logging.info("Performing local inference.")
4646
infer_problems_data = data_reader.get_data_filepatterns(
4747
FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)
48+
4849
infer_input_fn = input_fn_builder.build_input_fn(
4950
mode=tf.contrib.learn.ModeKeys.INFER,
5051
hparams=hparams,
5152
data_file_patterns=infer_problems_data,
5253
num_datashards=devices.data_parallelism().n,
5354
fixed_problem=i)
54-
result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=False)
5555

5656
def log_fn(inputs,
5757
targets,
@@ -66,36 +66,47 @@ def log_fn(inputs,
6666
"%s_prediction_%d.jpg" % (problem, j))
6767
show_and_save_image(inputs / 255., save_path)
6868
elif inputs_vocab:
69-
decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
69+
decoded_inputs = inputs_vocab.decode(
70+
_save_until_eos(inputs.flatten()))
7071
tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
7172

72-
decoded_outputs = targets_vocab.decode(_save_until_eos(outputs.flatten()))
73+
if FLAGS.identity_output:
74+
decoded_outputs = " ".join(map(str, outputs.flatten()))
75+
decoded_targets = " ".join(map(str, targets.flatten()))
76+
else:
77+
decoded_outputs = targets_vocab.decode(
78+
_save_until_eos(outputs.flatten()))
79+
decoded_targets = targets_vocab.decode(
80+
_save_until_eos(targets.flatten()))
81+
7382
tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
74-
decoded_targets = targets_vocab.decode(_save_until_eos(targets.flatten()))
7583
tf.logging.info("Inference results TARGET: %s" % decoded_targets)
76-
7784
if FLAGS.decode_to_file:
7885
output_filepath = FLAGS.decode_to_file + ".outputs." + problem
7986
output_file = tf.gfile.Open(output_filepath, "a")
8087
output_file.write(decoded_outputs + "\n")
8188
target_filepath = FLAGS.decode_to_file + ".targets." + problem
8289
target_file = tf.gfile.Open(target_filepath, "a")
8390
target_file.write(decoded_targets + "\n")
84-
85-
# The function predict() returns an iterable over the network's
86-
# predictions from the test input. We use it to log inputs and decodes.
87-
inputs_iter = result_iter["inputs"]
88-
targets_iter = result_iter["targets"]
89-
outputs_iter = result_iter["outputs"]
90-
for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
91-
inputs, targets, outputs = result
91+
result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=True)
92+
count = 0
93+
for result in result_iter:
94+
# predictions from the test input. We use it to log inputs and decodes.
95+
inputs = result["inputs"]
96+
targets = result["targets"]
97+
outputs = result["outputs"]
9298
if FLAGS.decode_return_beams:
9399
output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
94100
for k, beam in enumerate(output_beams):
95101
tf.logging.info("BEAM %d:" % k)
96-
log_fn(inputs, targets, beam, problem, j)
102+
log_fn(inputs, targets, beam, problem, count)
97103
else:
98-
log_fn(inputs, targets, outputs, problem, j)
104+
log_fn(inputs, targets, outputs, problem, count)
105+
106+
count += 1
107+
if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples:
108+
break
109+
tf.logging.info("Completed inference on %d samples." % count)
99110

100111

101112
def decode_from_file(estimator, filename):

tensor2tensor/utils/trainer_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@
121121
flags.DEFINE_integer("decode_max_input_size", -1,
122122
"Maximum number of ids in input. Or <= 0 for no max.")
123123
flags.DEFINE_bool("identity_output", False, "To print the output as identity")
124+
flags.DEFINE_integer("decode_num_samples", -1,
125+
"Number of samples to decode. Currently used in"
126+
"decode_from_dataset. Use -1 for all.")
124127

125128

126129
def make_experiment_fn(data_dir, model_name, train_steps, eval_steps):

0 commit comments

Comments (0)