@@ -45,13 +45,13 @@ def decode_from_dataset(estimator):
4545 tf .logging .info ("Performing local inference." )
4646 infer_problems_data = data_reader .get_data_filepatterns (
4747 FLAGS .problems , hparams .data_dir , tf .contrib .learn .ModeKeys .INFER )
48+
4849 infer_input_fn = input_fn_builder .build_input_fn (
4950 mode = tf .contrib .learn .ModeKeys .INFER ,
5051 hparams = hparams ,
5152 data_file_patterns = infer_problems_data ,
5253 num_datashards = devices .data_parallelism ().n ,
5354 fixed_problem = i )
54- result_iter = estimator .predict (input_fn = infer_input_fn , as_iterable = False )
5555
5656 def log_fn (inputs ,
5757 targets ,
@@ -66,36 +66,47 @@ def log_fn(inputs,
6666 "%s_prediction_%d.jpg" % (problem , j ))
6767 show_and_save_image (inputs / 255. , save_path )
6868 elif inputs_vocab :
69- decoded_inputs = inputs_vocab .decode (_save_until_eos (inputs .flatten ()))
69+ decoded_inputs = inputs_vocab .decode (
70+ _save_until_eos (inputs .flatten ()))
7071 tf .logging .info ("Inference results INPUT: %s" % decoded_inputs )
7172
72- decoded_outputs = targets_vocab .decode (_save_until_eos (outputs .flatten ()))
73+ if FLAGS .identity_output :
74+ decoded_outputs = " " .join (map (str , outputs .flatten ()))
75+ decoded_targets = " " .join (map (str , targets .flatten ()))
76+ else :
77+ decoded_outputs = targets_vocab .decode (
78+ _save_until_eos (outputs .flatten ()))
79+ decoded_targets = targets_vocab .decode (
80+ _save_until_eos (targets .flatten ()))
81+
7382 tf .logging .info ("Inference results OUTPUT: %s" % decoded_outputs )
74- decoded_targets = targets_vocab .decode (_save_until_eos (targets .flatten ()))
7583 tf .logging .info ("Inference results TARGET: %s" % decoded_targets )
76-
7784 if FLAGS .decode_to_file :
7885 output_filepath = FLAGS .decode_to_file + ".outputs." + problem
7986 output_file = tf .gfile .Open (output_filepath , "a" )
8087 output_file .write (decoded_outputs + "\n " )
8188 target_filepath = FLAGS .decode_to_file + ".targets." + problem
8289 target_file = tf .gfile .Open (target_filepath , "a" )
8390 target_file .write (decoded_targets + "\n " )
84-
85- # The function predict() returns an iterable over the network's
86- # predictions from the test input. We use it to log inputs and decodes.
87- inputs_iter = result_iter ["inputs" ]
88- targets_iter = result_iter ["targets" ]
89- outputs_iter = result_iter ["outputs" ]
90- for j , result in enumerate (zip (inputs_iter , targets_iter , outputs_iter )):
91- inputs , targets , outputs = result
91+ result_iter = estimator .predict (input_fn = infer_input_fn , as_iterable = True )
92+ count = 0
93+ for result in result_iter :
94+ # predictions from the test input. We use it to log inputs and decodes.
95+ inputs = result ["inputs" ]
96+ targets = result ["targets" ]
97+ outputs = result ["outputs" ]
9298 if FLAGS .decode_return_beams :
9399 output_beams = np .split (outputs , FLAGS .decode_beam_size , axis = 0 )
94100 for k , beam in enumerate (output_beams ):
95101 tf .logging .info ("BEAM %d:" % k )
96- log_fn (inputs , targets , beam , problem , j )
102+ log_fn (inputs , targets , beam , problem , count )
97103 else :
98- log_fn (inputs , targets , outputs , problem , j )
104+ log_fn (inputs , targets , outputs , problem , count )
105+
106+ count += 1
107+ if FLAGS .decode_num_samples != - 1 and count >= FLAGS .decode_num_samples :
108+ break
109+ tf .logging .info ("Completed inference on %d samples." % count )
99110
100111
101112def decode_from_file (estimator , filename ):
0 commit comments