@@ -90,25 +90,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
9090 "algorithmic_reverse_nlplike_decimal8K" : (
9191 lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 100000 ,
9292 10 , 1.300 ),
93- lambda : algorithmic .reverse_generator_nlplike (8000 , 700 , 10000 ,
93+ lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 10000 ,
9494 10 , 1.300 )),
9595 "algorithmic_reverse_nlplike_decimal32K" : (
9696 lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 100000 ,
9797 10 , 1.050 ),
98- lambda : algorithmic .reverse_generator_nlplike (32000 , 700 , 10000 ,
98+ lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 10000 ,
9999 10 , 1.050 )),
100100 "algorithmic_algebra_inverse" : (
101101 lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
102102 lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
103- "algorithmic_algebra_simplify" : (
104- lambda : algorithmic_math .algebra_simplify (8 , 0 , 2 , 100000 ),
105- lambda : algorithmic_math .algebra_simplify (8 , 3 , 3 , 10000 )),
106- "algorithmic_calculus_integrate" : (
107- lambda : algorithmic_math .calculus_integrate (8 , 0 , 2 , 100000 ),
108- lambda : algorithmic_math .calculus_integrate (8 , 3 , 3 , 10000 )),
109- "wmt_parsing_characters" : (
110- lambda : wmt .parsing_character_generator (FLAGS .tmp_dir , True ),
111- lambda : wmt .parsing_character_generator (FLAGS .tmp_dir , False )),
112103 "wmt_parsing_tokens_8k" : (
113104 lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
114105 lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )),
@@ -133,10 +124,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
133124 lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
134125 lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
135126 ),
136- "wmt_enfr_tokens_128k" : (
137- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 17 ),
138- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 17 )
139- ),
140127 "wmt_ende_characters" : (
141128 lambda : wmt .ende_character_generator (FLAGS .tmp_dir , True ),
142129 lambda : wmt .ende_character_generator (FLAGS .tmp_dir , False )),
@@ -151,10 +138,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
151138 lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
152139 lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
153140 ),
154- "wmt_ende_tokens_128k" : (
155- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 17 ),
156- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 17 )
157- ),
158141 "image_mnist_tune" : (
159142 lambda : image .mnist_generator (FLAGS .tmp_dir , True , 55000 ),
160143 lambda : image .mnist_generator (FLAGS .tmp_dir , True , 5000 , 55000 )),
@@ -227,33 +210,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
227210 40000 ,
228211 vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
229212 vocab_size = 2 ** 15 )),
230- "image_mscoco_tokens_128k_tune" : (
231- lambda : image .mscoco_generator (
232- FLAGS .tmp_dir ,
233- True ,
234- 70000 ,
235- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
236- vocab_size = 2 ** 17 ),
237- lambda : image .mscoco_generator (
238- FLAGS .tmp_dir ,
239- True ,
240- 10000 ,
241- 70000 ,
242- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
243- vocab_size = 2 ** 17 )),
244- "image_mscoco_tokens_128k_test" : (
245- lambda : image .mscoco_generator (
246- FLAGS .tmp_dir ,
247- True ,
248- 80000 ,
249- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
250- vocab_size = 2 ** 17 ),
251- lambda : image .mscoco_generator (
252- FLAGS .tmp_dir ,
253- False ,
254- 40000 ,
255- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
256- vocab_size = 2 ** 17 )),
257213 "snli_32k" : (
258214 lambda : snli .snli_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
259215 lambda : snli .snli_token_generator (FLAGS .tmp_dir , False , 2 ** 15 ),
@@ -340,10 +296,31 @@ def set_random_seed():
340296
341297def main (_ ):
342298 tf .logging .set_verbosity (tf .logging .INFO )
343- if FLAGS .problem not in _SUPPORTED_PROBLEM_GENERATORS :
299+
300+ # Calculate the list of problems to generate.
301+ problems = list (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
302+ if FLAGS .problem and FLAGS .problem [- 1 ] == "*" :
303+ problems = [p for p in problems if p .startswith (FLAGS .problem [:- 1 ])]
304+ elif FLAGS .problem :
305+ problems = [p for p in problems if p == FLAGS .problem ]
306+ else :
307+ problems = []
308+ # Remove TIMIT if paths are not given.
309+ if not FLAGS .timit_paths :
310+ problems = [p for p in problems if "timit" not in p ]
311+ # Remove parsing if paths are not given.
312+ if not FLAGS .parsing_path :
313+ problems = [p for p in problems if "parsing" not in p ]
314+ # Remove en-de BPE if paths are not given.
315+ if not FLAGS .ende_bpe_path :
316+ problems = [p for p in problems if "ende_bpe" not in p ]
317+
318+ if not problems :
344319 problems_str = "\n * " .join (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
345320 error_msg = ("You must specify one of the supported problems to "
346321 "generate data for:\n * " + problems_str + "\n " )
322+ error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
323+ "--timit_paths, --ende_bpe_path and --parsing_path." )
347324 raise ValueError (error_msg )
348325
349326 if not FLAGS .data_dir :
@@ -352,26 +329,28 @@ def main(_):
352329 "Data will be written to default data_dir=%s." ,
353330 FLAGS .data_dir )
354331
355- set_random_seed ()
332+ tf .logging .info ("Generating problems:\n * %s\n " % "\n * " .join (problems ))
333+ for problem in problems :
334+ set_random_seed ()
356335
357- training_gen , dev_gen = _SUPPORTED_PROBLEM_GENERATORS [FLAGS . problem ]
336+ training_gen , dev_gen = _SUPPORTED_PROBLEM_GENERATORS [problem ]
358337
359- tf .logging .info ("Generating training data for %s." , FLAGS . problem )
360- train_output_files = generator_utils .generate_files (
361- training_gen (), FLAGS . problem + UNSHUFFLED_SUFFIX + "-train" ,
362- FLAGS .data_dir , FLAGS .num_shards , FLAGS .max_cases )
338+ tf .logging .info ("Generating training data for %s." , problem )
339+ train_output_files = generator_utils .generate_files (
340+ training_gen (), problem + UNSHUFFLED_SUFFIX + "-train" ,
341+ FLAGS .data_dir , FLAGS .num_shards , FLAGS .max_cases )
363342
364- tf .logging .info ("Generating development data for %s." , FLAGS . problem )
365- dev_output_files = generator_utils .generate_files (
366- dev_gen (), FLAGS . problem + UNSHUFFLED_SUFFIX + "-dev" , FLAGS .data_dir , 1 )
343+ tf .logging .info ("Generating development data for %s." , problem )
344+ dev_output_files = generator_utils .generate_files (
345+ dev_gen (), problem + UNSHUFFLED_SUFFIX + "-dev" , FLAGS .data_dir , 1 )
367346
368- tf .logging .info ("Shuffling data..." )
369- for fname in train_output_files + dev_output_files :
370- records = generator_utils .read_records (fname )
371- random .shuffle (records )
372- out_fname = fname .replace (UNSHUFFLED_SUFFIX , "" )
373- generator_utils .write_records (records , out_fname )
374- tf .gfile .Remove (fname )
347+ tf .logging .info ("Shuffling data..." )
348+ for fname in train_output_files + dev_output_files :
349+ records = generator_utils .read_records (fname )
350+ random .shuffle (records )
351+ out_fname = fname .replace (UNSHUFFLED_SUFFIX , "" )
352+ generator_utils .write_records (records , out_fname )
353+ tf .gfile .Remove (fname )
375354
376355
377356if __name__ == "__main__" :
0 commit comments