
Commit 84445cc

Lukasz Kaiser authored and Ryan Sepassi committed
Change the format of generated vocab files to include languages, put them in data_dir, add --generate_data option.
PiperOrigin-RevId: 162569315
1 parent 293b5f6 commit 84445cc

12 files changed: +150 −158 lines

README.md

Lines changed: 15 additions & 2 deletions
@@ -26,6 +26,21 @@ issues](https://github.com/tensorflow/tensor2tensor/issues).
 And chat with us and other users on
 [Gitter](https://gitter.im/tensor2tensor/Lobby).
 
+Here is a one-command version that installs tensor2tensor, downloads the data,
+trains an English-German translation model, and lets you use it interactively:
+```
+pip install tensor2tensor && t2t-trainer \
+  --generate_data \
+  --data_dir=~/t2t_data \
+  --problems=wmt_ende_tokens_32k \
+  --model=transformer \
+  --hparams_set=transformer_base_single_gpu \
+  --output_dir=~/t2t_train/base \
+  --decode_interactive
+```
+
+See the [Walkthrough](#walkthrough) below for more details on each step.
+
 ### Contents
 
 * [Walkthrough](#walkthrough)
@@ -72,8 +87,6 @@ t2t-datagen \
   --num_shards=100 \
   --problem=$PROBLEM
 
-cp $TMP_DIR/tokens.vocab.* $DATA_DIR
-
 # Train
 # * If you run out of memory, add --hparams='batch_size=2048' or even 1024.
 t2t-trainer \

tensor2tensor/bin/t2t-datagen

Lines changed: 42 additions & 87 deletions
@@ -80,24 +80,30 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
     "ice_parsing_tokens": (
-        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                   True, "ice", 2**13, 2**8),
-        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                   False, "ice", 2**13, 2**8)),
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
     "ice_parsing_characters": (
-        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
-        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True),
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
-        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
-        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
+        lambda: wmt.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
+        lambda: wmt.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
     "wsj_parsing_tokens_16k": (
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
-                                                    2**14, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
-                                                    2**14, 2**9)),
+        lambda: wsj_parsing.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
+        lambda: wsj_parsing.parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
     "wmt_ende_bpe32k": (
-        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
-        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
+        lambda: wmt.ende_bpe_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True),
+        lambda: wmt.ende_bpe_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False)),
     "lm1b_32k": (
         lambda: lm1b.generator(FLAGS.tmp_dir, True),
         lambda: lm1b.generator(FLAGS.tmp_dir, False)
@@ -119,101 +125,50 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
         lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
     "image_mscoco_characters_test": (
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
+        lambda: image.mscoco_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000),
+        lambda: image.mscoco_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000)),
     "image_celeba_tune": (
         lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
         lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
     "image_mscoco_tokens_8k_test": (
         lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            80000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
+            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
         lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            False,
-            40000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
+            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
     "image_mscoco_tokens_32k_test": (
         lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            80000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
+            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
         lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            False,
-            40000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
+            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
     "snli_32k": (
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
     ),
-    "audio_timit_characters_tune": (
-        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1374),
-        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 344, 1374)),
     "audio_timit_characters_test": (
-        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1718),
-        lambda: audio.timit_generator(FLAGS.tmp_dir, False, 626)),
-    "audio_timit_tokens_8k_tune": (
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            1374,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718),
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            344,
-            1374,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 626)),
     "audio_timit_tokens_8k_test": (
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            1718,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
-        lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            False,
-            626,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
-    "audio_timit_tokens_32k_tune": (
-        lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            1374,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
+            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            344,
-            1374,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
+            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
     "audio_timit_tokens_32k_test": (
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            True,
-            1718,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
+            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
+            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
         lambda: audio.timit_generator(
-            FLAGS.tmp_dir,
-            False,
-            626,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
+            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
     "lmptb_10k": (
         lambda: ptb.train_generator(
             FLAGS.tmp_dir,
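
The pattern across these edits is uniform: every vocabulary-building generator now receives `FLAGS.data_dir` ahead of `FLAGS.tmp_dir`, and the shared subword vocab files are renamed from `tokens.vocab.*` to `vocab.endefr.*`. A minimal sketch of the registration convention, where `"my_tokens_8k"` and `my_module` are hypothetical stand-ins for the real names above:

```python
# Sketch only: "my_tokens_8k" and my_module are hypothetical; FLAGS is the
# module-level tf.flags object used throughout t2t-datagen.
_SUPPORTED_PROBLEM_GENERATORS = {
    "my_tokens_8k": (
        # Each problem maps to a (train, dev) pair of zero-argument
        # callables, so nothing is downloaded until a problem is selected.
        # data_dir now comes first: the vocab file is written there, next
        # to the generated examples, instead of into the scratch tmp_dir.
        lambda: my_module.token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
        lambda: my_module.token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
}
```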

tensor2tensor/bin/t2t-trainer

Lines changed: 20 additions & 4 deletions
@@ -31,7 +31,8 @@ from __future__ import print_function
 
 # Dependency imports
 
-from tensor2tensor.utils import trainer_utils as utils
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import trainer_utils
 from tensor2tensor.utils import usr_dir
 
 import tensorflow as tf
@@ -45,14 +46,29 @@ flags.DEFINE_string("t2t_usr_dir", "",
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
+flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
+                    "Temporary storage directory.")
+flags.DEFINE_bool("generate_data", False, "Generate data before training?")
 
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
-  utils.log_registry()
-  utils.validate_flags()
-  utils.run(
+  trainer_utils.log_registry()
+  trainer_utils.validate_flags()
+  tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  # Generate data if requested.
+  if FLAGS.generate_data:
+    tf.gfile.MakeDirs(FLAGS.data_dir)
+    tf.gfile.MakeDirs(FLAGS.tmp_dir)
+    for problem_name in FLAGS.problems.split("-"):
+      tf.logging.info("Generating data for %s" % problem_name)
+      problem = registry.problem(problem_name)
+      problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
+
+  # Run the trainer.
+  trainer_utils.run(
       data_dir=FLAGS.data_dir,
       model=FLAGS.model,
       output_dir=FLAGS.output_dir,
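
The new `--generate_data` path loops over the dash-separated `--problems` list and calls each registered problem's `generate_data` before training starts, which is what makes the README's one-command example work. The same two calls can be made directly from Python; a minimal sketch with example paths, assuming tensor2tensor is installed and the problem name is registered:

```python
from tensor2tensor.utils import registry

# Equivalent of what --generate_data does for one entry of --problems:
# look the problem up in the registry, then write the vocab and examples
# to data_dir, staging raw downloads in tmp_dir.
problem = registry.problem("wmt_ende_tokens_32k")
problem.generate_data("/home/me/t2t_data", "/tmp/t2t_datagen")
```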

tensor2tensor/data_generators/audio.py

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,8 @@ def _get_text_data(filepath):
   return " ".join(words)
 
 
-def timit_generator(tmp_dir,
+def timit_generator(data_dir,
+                    tmp_dir,
                     training,
                     how_many,
                     start_from=0,
@@ -107,6 +108,7 @@ def timit_generator(tmp_dir,
   """Data generator for TIMIT transcription problem.
 
   Args:
+    data_dir: path to the data directory.
     tmp_dir: path to temporary storage directory.
     training: a Boolean; if true, we use the train set, otherwise the test set.
     how_many: how many inputs and labels to generate.
@@ -128,7 +130,7 @@ def timit_generator(tmp_dir,
   eos_list = [1] if eos_list is None else eos_list
   if vocab_filename is not None:
     vocab_symbolizer = generator_utils.get_or_generate_vocab(
-        tmp_dir, vocab_filename, vocab_size)
+        data_dir, tmp_dir, vocab_filename, vocab_size)
   _get_timit(tmp_dir)
   datasets = (_TIMIT_TRAIN_DATASETS if training else _TIMIT_TEST_DATASETS)
   i = 0

tensor2tensor/data_generators/generator_utils.py

Lines changed: 6 additions & 8 deletions
@@ -244,16 +244,13 @@ def gunzip_file(gz_path, new_path):
         "http://www.statmt.org/wmt13/training-parallel-un.tgz",
         ["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"]
     ],
-    [
-        "https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz",  # pylint: disable=line-too-long
-        ["train.mk", "train.en"]
-    ],
 ]
 
 
-def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
+def get_or_generate_vocab(data_dir, tmp_dir,
+                          vocab_filename, vocab_size, sources=None):
   """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
-  vocab_filepath = os.path.join(tmp_dir, vocab_filename)
+  vocab_filepath = os.path.join(data_dir, vocab_filename)
   if tf.gfile.Exists(vocab_filepath):
     tf.logging.info("Found vocab file: %s", vocab_filepath)
     vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
@@ -304,7 +301,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
   return vocab
 
 
-def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
+def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename,
                                  index, vocab_filename, vocab_size):
   r"""Generate a vocabulary from a tabbed source file.
 
@@ -313,6 +310,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
   The index parameter specifies 0 for the source or 1 for the target.
 
   Args:
+    data_dir: path to the data directory.
     tmp_dir: path to the temporary directory.
     source_filename: the name of the tab-separated source file.
     index: index.
@@ -322,7 +320,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
   Returns:
     The vocabulary.
   """
-  vocab_filepath = os.path.join(tmp_dir, vocab_filename)
+  vocab_filepath = os.path.join(data_dir, vocab_filename)
   if os.path.exists(vocab_filepath):
     vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab
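
The practical effect for callers: the vocab file is now created in, and reloaded from, `data_dir`, so repeated datagen runs and the trainer all find the same file without the manual `cp $TMP_DIR/tokens.vocab.* $DATA_DIR` step the README used to require. A hedged sketch of the new caller-side pattern, with `data_dir` and `tmp_dir` supplied by the enclosing generator's arguments:

```python
# Returns a text_encoder.SubwordTextEncoder, as in the diff above.
vocab = generator_utils.get_or_generate_vocab(
    data_dir, tmp_dir, "vocab.endefr.%d" % 2**15, 2**15)
# First run: sources are downloaded to tmp_dir and the vocab is written to
# os.path.join(data_dir, "vocab.endefr.32768"). Later runs just reload it.
```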

tensor2tensor/data_generators/image.py

Lines changed: 4 additions & 2 deletions
@@ -230,7 +230,8 @@ def _get_mscoco(directory):
   zipfile.ZipFile(path, "r").extractall(directory)
 
 
-def mscoco_generator(tmp_dir,
+def mscoco_generator(data_dir,
+                     tmp_dir,
                      training,
                      how_many,
                      start_from=0,
@@ -240,6 +241,7 @@ def mscoco_generator(tmp_dir,
   """Image generator for MSCOCO captioning problem with token-wise captions.
 
   Args:
+    data_dir: path to the data directory.
     tmp_dir: path to temporary storage directory.
     training: a Boolean; if true, we use the train set, otherwise the test set.
     how_many: how many images and labels to generate.
@@ -261,7 +263,7 @@ def mscoco_generator(tmp_dir,
   eos_list = [1] if eos_list is None else eos_list
   if vocab_filename is not None:
     vocab_symbolizer = generator_utils.get_or_generate_vocab(
-        tmp_dir, vocab_filename, vocab_size)
+        data_dir, tmp_dir, vocab_filename, vocab_size)
   _get_mscoco(tmp_dir)
   caption_filepath = (_MSCOCO_TRAIN_CAPTION_FILE
                       if training else _MSCOCO_EVAL_CAPTION_FILE)

tensor2tensor/data_generators/inspect.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 python data_generators/inspect.py \
     --logtostderr \
     --print_targets \
-    --subword_text_encoder_filename=$DATA_DIR/tokens.vocab.8192 \
+    --subword_text_encoder_filename=$DATA_DIR/vocab.endefr.8192 \
     --input_filename=$DATA_DIR/wmt_ende_tokens_8k-train-00000-of-00100
 """