Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 172a1b1

Browse files
authored
Merge pull request #392 from vince62s/translate
Fix the EnZh task
2 parents 9d6dce3 + 733de7b commit 172a1b1

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

tensor2tensor/data_generators/translate_enzh.py

Lines changed: 19 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -35,21 +35,23 @@
3535

3636
# End-of-sentence marker.
3737
EOS = text_encoder.EOS_ID
38-
39-
_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
38+
# This is far from the real WMT17 task; only a toy dataset is used here.
39+
# You need to register to obtain the UN data and the CWMT data.
40+
# Also, by convention this is EN to ZH; use translate_enzh_wmt8k_rev for the ZH-to-EN task.
41+
_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
4042
"training-parallel-nc-v12.tgz"),
41-
("training/news-commentary-v12.zh-en.zh",
42-
"training/news-commentary-v12.zh-en.en")]]
43+
("training/news-commentary-v12.zh-en.en",
44+
"training/news-commentary-v12.zh-en.zh")]]
4345

44-
_ZHEN_TEST_DATASETS = [[
46+
_ENZH_TEST_DATASETS = [[
4547
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
46-
("dev/newsdev2017-zhen-src.zh.sgm", "dev/newsdev2017-zhen-ref.en.sgm")
48+
("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
4749
]]
4850

4951

5052
@registry.register_problem
5153
class TranslateEnzhWmt8k(translate.TranslateProblem):
52-
"""Problem spec for WMT Zh-En translation."""
54+
"""Problem spec for WMT En-Zh translation."""
5355

5456
@property
5557
def targeted_vocab_size(self):
@@ -61,16 +63,16 @@ def num_shards(self):
6163

6264
@property
6365
def source_vocab_name(self):
64-
return "vocab.zhen-zh.%d" % self.targeted_vocab_size
66+
return "vocab.enzh-en.%d" % self.targeted_vocab_size
6567

6668
@property
6769
def target_vocab_name(self):
68-
return "vocab.zhen-en.%d" % self.targeted_vocab_size
70+
return "vocab.enzh-zh.%d" % self.targeted_vocab_size
6971

7072
def generator(self, data_dir, tmp_dir, train):
71-
datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
72-
source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS]
73-
target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS]
73+
datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
74+
source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
75+
target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
7476
source_vocab = generator_utils.get_or_generate_vocab(
7577
data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
7678
source_datasets)
@@ -79,21 +81,18 @@ def generator(self, data_dir, tmp_dir, train):
7981
target_datasets)
8082
tag = "train" if train else "dev"
8183
data_path = translate.compile_data(tmp_dir, datasets,
82-
"wmt_zhen_tok_%s" % tag)
83-
# We generate English->X data by convention, to train reverse translation
84-
# just add the "_rev" suffix to the problem name, e.g., like this.
85-
# --problems=translate_enzh_wmt8k_rev
86-
return translate.bi_vocabs_token_generator(data_path + ".lang2",
87-
data_path + ".lang1",
84+
"wmt_enzh_tok_%s" % tag)
85+
return translate.bi_vocabs_token_generator(data_path + ".lang1",
86+
data_path + ".lang2",
8887
source_vocab, target_vocab, EOS)
8988

9089
@property
9190
def input_space_id(self):
92-
return problem.SpaceID.ZH_TOK
91+
return problem.SpaceID.EN_TOK
9392

9493
@property
9594
def target_space_id(self):
96-
return problem.SpaceID.EN_TOK
95+
return problem.SpaceID.ZH_TOK
9796

9897
def feature_encoders(self, data_dir):
9998
source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)

0 commit comments

Comments
 (0)