This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7566c4d

Merge pull request #151 from cshanbo/wmt_zhen_translate
add wmt_zhen_token_32k
2 parents: 9a1f888 + 235392c


3 files changed: +100 −0 lines

tensor2tensor/bin/t2t-datagen

Lines changed: 6 additions & 0 deletions
@@ -140,6 +140,12 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
+    "wmt_zhen_tokens_32k": (
+        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
+                                                   2**15, 2**15),
+        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False,
+                                                   2**15, 2**15)
+    ),
     "lm1b_32k": (
         lambda: lm1b.generator(FLAGS.tmp_dir, True),
         lambda: lm1b.generator(FLAGS.tmp_dir, False)
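
Each registry value above is a (training generator, dev generator) pair of zero-argument callables; t2t-datagen picks a pair by problem name and serializes the examples it yields. A minimal sketch of exercising the new entry directly, assuming it runs in a context where _SUPPORTED_PROBLEM_GENERATORS and FLAGS.tmp_dir are defined, as in t2t-datagen itself:

    # Look up the pair added in this commit and pull one example from the
    # training generator; each example is a dict of token-id lists.
    train_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS["wmt_zhen_tokens_32k"]
    first_example = next(train_gen())
    # first_example looks like {"inputs": [...], "targets": [...]}, with ids
    # drawn from the separate zh/en subword vocabularies built in wmt.py below.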

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 28 additions & 0 deletions
@@ -177,6 +177,7 @@ def default_problem_hparams():
       # 13: Audio spectral domain
       # 14: Parse characters
       # 15: Parse tokens
+      # 16: Chinese tokens
       # Add more above if needed.
       input_space_id=0,
       target_space_id=0,
@@ -472,6 +473,32 @@ def wmt_ende_tokens(model_hparams, wrong_vocab_size):
   return p


+def wmt_zhen_tokens(model_hparams, wrong_vocab_size):
+  """Chinese to English translation benchmark."""
+  p = default_problem_hparams()
+  # This vocab file must be present within the data directory.
+  if model_hparams.shared_embedding_and_softmax_weights == 1:
+    model_hparams.shared_embedding_and_softmax_weights = 0
+  source_vocab_filename = os.path.join(model_hparams.data_dir,
+                                       "tokens.vocab.zh.%d" % wrong_vocab_size)
+  target_vocab_filename = os.path.join(model_hparams.data_dir,
+                                       "tokens.vocab.en.%d" % wrong_vocab_size)
+  source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
+  target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
+  p.input_modality = {
+      "inputs": (registry.Modalities.SYMBOL, source_token.vocab_size)
+  }
+  p.target_modality = (registry.Modalities.SYMBOL, target_token.vocab_size)
+  p.vocabulary = {
+      "inputs": source_token,
+      "targets": target_token,
+  }
+  p.loss_multiplier = 1.4
+  p.input_space_id = 16
+  p.target_space_id = 4
+  return p
+
+
 def wmt_ende_v2(model_hparams, vocab_size):
   """English to German translation benchmark with separate vocabularies."""
   p = default_problem_hparams()
@@ -730,6 +757,7 @@ def img2img_imagenet(unused_model_hparams):
     "wmt_ende_bpe32k_160": wmt_ende_bpe32k,
     "wmt_ende_v2_32k_combined": lambda p: wmt_ende_v2(p, 2**15),
     "wmt_ende_v2_16k_combined": lambda p: wmt_ende_v2(p, 2**14),
+    "wmt_zhen_tokens_32k": lambda p: wmt_zhen_tokens(p, 2**15),
     "image_cifar10_tune": image_cifar10,
     "image_cifar10_test": image_cifar10,
     "image_mnist_tune": image_mnist,

tensor2tensor/data_generators/wmt.py

Lines changed: 66 additions & 0 deletions
@@ -101,6 +101,38 @@ def token_generator(source_path, target_path, token_vocab, eos=None):
         source, target = source_file.readline(), target_file.readline()


+def bi_vocabs_token_generator(source_path, target_path,
+                              source_token_vocab,
+                              target_token_vocab,
+                              eos=None):
+  """Generator for sequence-to-sequence tasks that uses tokens.
+
+  This generator assumes the files at source_path and target_path have
+  the same number of lines and yields dictionaries of "inputs" and "targets"
+  where inputs are token ids from the " "-split source (and target, resp.) lines
+  converted to integers using the token_map.
+
+  Args:
+    source_path: path to the file with source sentences.
+    target_path: path to the file with target sentences.
+    source_token_vocab: text_encoder.TextEncoder object.
+    target_token_vocab: text_encoder.TextEncoder object.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from tokens in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      while source and target:
+        source_ints = source_token_vocab.encode(source.strip()) + eos_list
+        target_ints = target_token_vocab.encode(target.strip()) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+        source, target = source_file.readline(), target_file.readline()
+
 def _get_wmt_ende_dataset(directory, filename):
   """Extract the WMT en-de corpus `filename` to directory unless it's there."""
   train_path = os.path.join(directory, filename)
@@ -177,6 +209,21 @@ def ende_bpe_token_generator(tmp_dir, train):
     ],
 ]

+_ZHEN_TRAIN_DATASETS = [
+    [
+        "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
+        ("training/news-commentary-v12.zh-en.zh",
+         "training/news-commentary-v12.zh-en.en")
+    ]
+]
+
+_ZHEN_TEST_DATASETS = [
+    [
+        "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+        ("dev/newsdev2017-zhen-src.zh",
+         "dev/newsdev2017-zhen-ref.en")
+    ]
+]

 def _compile_data(tmp_dir, datasets, filename):
   """Concatenate all `datasets` and save to `filename`."""
@@ -253,6 +300,25 @@ def ende_character_generator(tmp_dir, train):
                              character_vocab, EOS)


+def zhen_wordpiece_token_generator(tmp_dir, train,
+                                   source_vocab_size,
+                                   target_vocab_size):
+  datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
+  source_datasets = [[item[0], [item[1][0]]] for item in datasets]
+  target_datasets = [[item[0], [item[1][1]]] for item in datasets]
+  source_vocab = generator_utils.get_or_generate_vocab(
+      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size,
+      source_vocab_size, source_datasets)
+  target_vocab = generator_utils.get_or_generate_vocab(
+      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size,
+      target_vocab_size, target_datasets)
+  tag = "train" if train else "dev"
+  data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag)
+  return bi_vocabs_token_generator(data_path + ".lang1",
+                                   data_path + ".lang2",
+                                   source_vocab, target_vocab, EOS)
+
+
 def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
   """Instance of token generator for the WMT en->fr task."""
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
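
A minimal sketch of driving the new generator end to end, with an assumed temporary directory; the call mirrors the lambdas registered in t2t-datagen above and downloads and compiles the news-commentary zh-en data on first use:

    from tensor2tensor.data_generators import wmt

    # Build (or reuse) the two 2**15-entry subword vocabularies, compile the
    # corpus into wmt_zhen_tok_train.lang1/.lang2, and iterate examples.
    gen = wmt.zhen_wordpiece_token_generator("/tmp/t2t_tmp",  # assumed tmp_dir
                                             True,            # training split
                                             2**15, 2**15)
    sample = next(gen)
    print(len(sample["inputs"]), len(sample["targets"]))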
