11#!/usr/bin/env python
2+ # coding=utf-8
23# Copyright 2017 The Tensor2Tensor Authors.
34#
45# Licensed under the Apache License, Version 2.0 (the "License");
@@ -42,7 +43,6 @@ from tensor2tensor.data_generators import audio
4243from tensor2tensor .data_generators import generator_utils
4344from tensor2tensor .data_generators import image
4445from tensor2tensor .data_generators import lm1b
45- from tensor2tensor .data_generators import ptb
4646from tensor2tensor .data_generators import snli
4747from tensor2tensor .data_generators import wiki
4848from tensor2tensor .data_generators import wmt
@@ -62,10 +62,12 @@ flags.DEFINE_string("problem", "",
6262 "The name of the problem to generate data for." )
# Command-line flags controlling which problems are generated and how.
flags.DEFINE_string("exclude_problems", "",
                    "Comma-separated list of problems to exclude.")
# 0 means "unset": generate_data_for_registered_problem raises ValueError when
# this is non-zero, and the separate train/dev generator path falls back to 10
# shards when it is 0.
flags.DEFINE_integer("num_shards", 0, "How many shards to use. Must be left "
                     "unset (0) for registered Problems.")
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
# A negative value (the default) is translated to task_id=None before being
# handed to a registered Problem's generate_data.
flags.DEFINE_integer("task_id", -1, "For distributed data generation.")
6971flags .DEFINE_string ("t2t_usr_dir" , "" ,
7072 "Path to a Python module that will be imported. The "
7173 "__init__.py file should include the necessary imports. "
@@ -103,6 +105,10 @@ _SUPPORTED_PROBLEM_GENERATORS = {
103105 lambda : lm1b .generator (FLAGS .tmp_dir , True ),
104106 lambda : lm1b .generator (FLAGS .tmp_dir , False )
105107 ),
108+ "lm1b_characters" : (
109+ lambda : lm1b .generator (FLAGS .tmp_dir , True , characters = True ),
110+ lambda : lm1b .generator (FLAGS .tmp_dir , False , characters = True )
111+ ),
106112 "wiki_32k" : (
107113 lambda : wiki .generator (FLAGS .tmp_dir , True ),
108114 1000
@@ -164,12 +170,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
164170 lambda : audio .timit_generator (
165171 FLAGS .data_dir , FLAGS .tmp_dir , False , 626 ,
166172 vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 )),
167- "lmptb_10k" : (
168- lambda : ptb .train_generator (
169- FLAGS .tmp_dir ,
170- FLAGS .data_dir ,
171- False ),
172- ptb .valid_generator ),
173173}
174174
175175# pylint: enable=g-long-lambda
@@ -241,7 +241,7 @@ def generate_data_for_problem(problem):
241241 if isinstance (dev_gen , int ):
242242 # The dev set and test sets are generated as extra shards using the
243243 # training generator. The integer specifies the number of training
244- # shards. FLAGS.num_shards is ignored.
244+ # shards. FLAGS.num_shards is ignored.
245245 num_training_shards = dev_gen
246246 tf .logging .info ("Generating data for %s." , problem )
247247 all_output_files = generator_utils .combined_data_filenames (
@@ -252,10 +252,11 @@ def generate_data_for_problem(problem):
252252 else :
253253 # usual case - train data and dev data are generated using separate
254254 # generators.
255+ num_shards = FLAGS .num_shards or 10
255256 tf .logging .info ("Generating training data for %s." , problem )
256257 train_output_files = generator_utils .train_data_filenames (
257258 problem + generator_utils .UNSHUFFLED_SUFFIX , FLAGS .data_dir ,
258- FLAGS . num_shards )
259+ num_shards )
259260 generator_utils .generate_files (training_gen (), train_output_files ,
260261 FLAGS .max_cases )
261262 tf .logging .info ("Generating development data for %s." , problem )
@@ -270,10 +271,14 @@ def generate_data_for_problem(problem):
270271
271272
def generate_data_for_registered_problem(problem_name):
  """Generates data for a Problem looked up in the registry.

  Reads FLAGS.data_dir and FLAGS.tmp_dir (with `~` expanded) and passes them,
  together with the task id, to the Problem's `generate_data` method.

  Args:
    problem_name: name under which the Problem is registered.

  Raises:
    ValueError: if --num_shards is set to a non-zero value; the flag is not
      passed to registered Problems, so setting it is rejected.
  """
  tf.logging.info("Generating training data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  # --task_id defaults to -1; any negative value means "no task id" and is
  # forwarded as None.
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  problem.generate_data(os.path.expanduser(FLAGS.data_dir),
                        os.path.expanduser(FLAGS.tmp_dir),
                        task_id=task_id)
277282
278283
279284if __name__ == "__main__" :
0 commit comments