1616
1717"""Produces the training and dev data for --problem into --data_dir.
1818
19- generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
20- protocol buffers for a variety of datasets registered in this file.
21-
22- All datasets are registered in _SUPPORTED_PROBLEM_GENERATORS. Each entry maps a
23- string name (selectable on the command-line with --problem) to a function that
24- takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
25- yields for each training example a dictionary mapping string feature names to
26- lists of {string, int, float}. The generator will be run once for each mode.
19+ Produces sharded and shuffled TFRecord files of tensorflow.Example protocol
20+ buffers for a variety of registered datasets.
21+
22+ All Problems are registered with @registry.register_problem or are in
23+ _SUPPORTED_PROBLEM_GENERATORS in this file. Each entry maps a string name
24+ (selectable on the command-line with --problem) to a function that takes 2
25+ arguments - input_directory and mode (one of "train" or "dev") - and yields for
26+ each training example a dictionary mapping string feature names to lists of
27+ {string, int, float}. The generator will be run once for each mode.
2728"""
2829from __future__ import absolute_import
2930from __future__ import division
@@ -113,40 +114,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
113114 lambda : wiki .generator (FLAGS .tmp_dir , True ),
114115 1000
115116 ),
116- "image_mnist_tune" : (
117- lambda : image .mnist_generator (FLAGS .tmp_dir , True , 55000 ),
118- lambda : image .mnist_generator (FLAGS .tmp_dir , True , 5000 , 55000 )),
119- "image_mnist_test" : (
120- lambda : image .mnist_generator (FLAGS .tmp_dir , True , 60000 ),
121- lambda : image .mnist_generator (FLAGS .tmp_dir , False , 10000 )),
122- "image_cifar10_tune" : (
123- lambda : image .cifar10_generator (FLAGS .tmp_dir , True , 48000 ),
124- lambda : image .cifar10_generator (FLAGS .tmp_dir , True , 2000 , 48000 )),
125- "image_cifar10_test" : (
126- lambda : image .cifar10_generator (FLAGS .tmp_dir , True , 50000 ),
127- lambda : image .cifar10_generator (FLAGS .tmp_dir , False , 10000 )),
128- "image_mscoco_characters_test" : (
129- lambda : image .mscoco_generator (
130- FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ),
131- lambda : image .mscoco_generator (
132- FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 )),
133117 "image_celeba_tune" : (
134118 lambda : image .celeba_generator (FLAGS .tmp_dir , 162770 ),
135119 lambda : image .celeba_generator (FLAGS .tmp_dir , 19867 , 162770 )),
136- "image_mscoco_tokens_8k_test" : (
137- lambda : image .mscoco_generator (
138- FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ,
139- vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 ),
140- lambda : image .mscoco_generator (
141- FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 ,
142- vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 )),
143- "image_mscoco_tokens_32k_test" : (
144- lambda : image .mscoco_generator (
145- FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ,
146- vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 ),
147- lambda : image .mscoco_generator (
148- FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 ,
149- vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 )),
150120 "snli_32k" : (
151121 lambda : snli .snli_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
152122 lambda : snli .snli_token_generator (FLAGS .tmp_dir , False , 2 ** 15 ),
@@ -255,8 +225,7 @@ def generate_data_for_problem(problem):
255225 num_shards = FLAGS .num_shards or 10
256226 tf .logging .info ("Generating training data for %s." , problem )
257227 train_output_files = generator_utils .train_data_filenames (
258- problem + generator_utils .UNSHUFFLED_SUFFIX , FLAGS .data_dir ,
259- num_shards )
228+ problem + generator_utils .UNSHUFFLED_SUFFIX , FLAGS .data_dir , num_shards )
260229 generator_utils .generate_files (training_gen (), train_output_files ,
261230 FLAGS .max_cases )
262231 tf .logging .info ("Generating development data for %s." , problem )
@@ -276,9 +245,10 @@ def generate_data_for_registered_problem(problem_name):
276245 raise ValueError ("--num_shards should not be set for registered Problem." )
277246 problem = registry .problem (problem_name )
278247 task_id = None if FLAGS .task_id < 0 else FLAGS .task_id
279- problem .generate_data (os .path .expanduser (FLAGS .data_dir ),
280- os .path .expanduser (FLAGS .tmp_dir ),
281- task_id = task_id )
248+ problem .generate_data (
249+ os .path .expanduser (FLAGS .data_dir ),
250+ os .path .expanduser (FLAGS .tmp_dir ),
251+ task_id = task_id )
282252
283253
284254if __name__ == "__main__" :
0 commit comments