# End-of-sentence marker appended to every generated token-id sequence.
EOS = text_encoder.EOS_ID
38-
39- _ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
38+ # This is far from being the real WMT17 task - only toyset here
39+ # you need to register to get UN data and CWT data
40+ # also by convention this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
41+ _ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
4042 "training-parallel-nc-v12.tgz" ),
41- ("training/news-commentary-v12.zh-en.zh " ,
42- "training/news-commentary-v12.zh-en.en " )]]
43+ ("training/news-commentary-v12.zh-en.en " ,
44+ "training/news-commentary-v12.zh-en.zh " )]]
4345
44- _ZHEN_TEST_DATASETS = [[
46+ _ENZH_TEST_DATASETS = [[
4547 "http://data.statmt.org/wmt17/translation-task/dev.tgz" ,
46- ("dev/newsdev2017-zhen-src.zh .sgm" , "dev/newsdev2017-zhen-ref.en .sgm" )
48+ ("dev/newsdev2017-zhen-src.en .sgm" , "dev/newsdev2017-zhen-ref.zh .sgm" )
4749]]
4850
4951
5052@registry .register_problem
5153class TranslateEnzhWmt8k (translate .TranslateProblem ):
52- """Problem spec for WMT Zh-En translation."""
54+ """Problem spec for WMT En-Zh translation."""
5355
5456 @property
5557 def targeted_vocab_size (self ):
@@ -61,16 +63,16 @@ def num_shards(self):
6163
6264 @property
6365 def source_vocab_name (self ):
64- return "vocab.zhen-zh .%d" % self .targeted_vocab_size
66+ return "vocab.enzh-en .%d" % self .targeted_vocab_size
6567
6668 @property
6769 def target_vocab_name (self ):
68- return "vocab.zhen-en .%d" % self .targeted_vocab_size
70+ return "vocab.enzh-zh .%d" % self .targeted_vocab_size
6971
7072 def generator (self , data_dir , tmp_dir , train ):
71- datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
72- source_datasets = [[item [0 ], [item [1 ][0 ]]] for item in _ZHEN_TRAIN_DATASETS ]
73- target_datasets = [[item [0 ], [item [1 ][1 ]]] for item in _ZHEN_TRAIN_DATASETS ]
73+ datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
74+ source_datasets = [[item [0 ], [item [1 ][0 ]]] for item in _ENZH_TRAIN_DATASETS ]
75+ target_datasets = [[item [0 ], [item [1 ][1 ]]] for item in _ENZH_TRAIN_DATASETS ]
7476 source_vocab = generator_utils .get_or_generate_vocab (
7577 data_dir , tmp_dir , self .source_vocab_name , self .targeted_vocab_size ,
7678 source_datasets )
@@ -79,21 +81,18 @@ def generator(self, data_dir, tmp_dir, train):
7981 target_datasets )
8082 tag = "train" if train else "dev"
8183 data_path = translate .compile_data (tmp_dir , datasets ,
82- "wmt_zhen_tok_%s" % tag )
83- # We generate English->X data by convention, to train reverse translation
84- # just add the "_rev" suffix to the problem name, e.g., like this.
85- # --problems=translate_enzh_wmt8k_rev
86- return translate .bi_vocabs_token_generator (data_path + ".lang2" ,
87- data_path + ".lang1" ,
84+ "wmt_enzh_tok_%s" % tag )
85+ return translate .bi_vocabs_token_generator (data_path + ".lang1" ,
86+ data_path + ".lang2" ,
8887 source_vocab , target_vocab , EOS )
8988
9089 @property
9190 def input_space_id (self ):
92- return problem .SpaceID .ZH_TOK
91+ return problem .SpaceID .EN_TOK
9392
9493 @property
9594 def target_space_id (self ):
96- return problem .SpaceID .EN_TOK
95+ return problem .SpaceID .ZH_TOK
9796
9897 def feature_encoders (self , data_dir ):
9998 source_vocab_filename = os .path .join (data_dir , self .source_vocab_name )
0 commit comments