3434# End-of-sentence marker.
3535EOS = text_encoder .EOS_ID
3636
37- _ENFR_TRAIN_DATASETS = [
37+ _ENFR_TRAIN_SMALL_DATA = [
3838 [
3939 "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz" ,
4040 ("baseline-1M-enfr/baseline-1M_train.en" ,
4141 "baseline-1M-enfr/baseline-1M_train.fr" )
4242 ],
43- # [
44- # "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
45- # ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
46- # ],
47- # [
48- # "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
49- # ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
50- # ],
51- # [
52- # "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
53- # ("training/news-commentary-v9.fr-en.en",
54- # "training/news-commentary-v9.fr-en.fr")
55- # ],
56- # [
57- # "http://www.statmt.org/wmt10/training-giga-fren.tar",
58- # ("giga-fren.release2.fixed.en.gz",
59- # "giga-fren.release2.fixed.fr.gz")
60- # ],
61- # [
62- # "http://www.statmt.org/wmt13/training-parallel-un.tgz",
63- # ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
64- # ],
6543]
66- _ENFR_TEST_DATASETS = [
44+ _ENFR_TEST_SMALL_DATA = [
6745 [
6846 "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz" ,
6947 ("baseline-1M-enfr/baseline-1M_valid.en" ,
7048 "baseline-1M-enfr/baseline-1M_valid.fr" )
7149 ],
72- # [
73- # "http://data.statmt.org/wmt17/translation-task/dev.tgz",
74- # ("dev/newstest2013.en", "dev/newstest2013.fr")
75- # ],
50+ ]
51+ _ENFR_TRAIN_LARGE_DATA = [
52+ [
53+ "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" ,
54+ ("commoncrawl.fr-en.en" , "commoncrawl.fr-en.fr" )
55+ ],
56+ [
57+ "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz" ,
58+ ("training/europarl-v7.fr-en.en" , "training/europarl-v7.fr-en.fr" )
59+ ],
60+ [
61+ "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" ,
62+ ("training/news-commentary-v9.fr-en.en" ,
63+ "training/news-commentary-v9.fr-en.fr" )
64+ ],
65+ [
66+ "http://www.statmt.org/wmt10/training-giga-fren.tar" ,
67+ ("giga-fren.release2.fixed.en.gz" ,
68+ "giga-fren.release2.fixed.fr.gz" )
69+ ],
70+ [
71+ "http://www.statmt.org/wmt13/training-parallel-un.tgz" ,
72+ ("un/undoc.2000.fr-en.en" , "un/undoc.2000.fr-en.fr" )
73+ ],
74+ ]
75+ _ENFR_TEST_LARGE_DATA = [
76+ [
77+ "http://data.statmt.org/wmt17/translation-task/dev.tgz" ,
78+ ("dev/newstest2013.en" , "dev/newstest2013.fr" )
79+ ],
7680]
7781
7882
7983@registry .register_problem
80- class TranslateEnfrWmt8k (translate .TranslateProblem ):
84+ class TranslateEnfrWmtSmall8k (translate .TranslateProblem ):
8185 """Problem spec for WMT En-Fr translation."""
8286
8387 @property
@@ -88,11 +92,18 @@ def targeted_vocab_size(self):
8892 def vocab_name (self ):
8993 return "vocab.enfr"
9094
95+ @property
96+ def use_small_dataset (self ):
97+ return True
98+
9199 def generator (self , data_dir , tmp_dir , train ):
92100 symbolizer_vocab = generator_utils .get_or_generate_vocab (
93101 data_dir , tmp_dir , self .vocab_file , self .targeted_vocab_size ,
94- _ENFR_TRAIN_DATASETS )
95- datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
102+ _ENFR_TRAIN_SMALL_DATA )
103+ if self .use_small_dataset :
104+ datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
105+ else :
106+ datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
96107 tag = "train" if train else "dev"
97108 data_path = translate .compile_data (tmp_dir , datasets ,
98109 "wmt_enfr_tok_%s" % tag )
@@ -109,15 +120,31 @@ def target_space_id(self):
109120
110121
111122@registry .register_problem
112- class TranslateEnfrWmt32k ( TranslateEnfrWmt8k ):
123+ class TranslateEnfrWmtSmall32k ( TranslateEnfrWmtSmall8k ):
113124
114125 @property
115126 def targeted_vocab_size (self ):
116127 return 2 ** 15 # 32768
117128
118129
119130@registry .register_problem
120- class TranslateEnfrWmtCharacters (translate .TranslateProblem ):
131+ class TranslateEnfrWmt8k (TranslateEnfrWmtSmall8k ):
132+
133+ @property
134+ def use_small_dataset (self ):
135+ return False
136+
137+
138+ @registry .register_problem
139+ class TranslateEnfrWmt32k (TranslateEnfrWmtSmall32k ):
140+
141+ @property
142+ def use_small_dataset (self ):
143+ return False
144+
145+
146+ @registry .register_problem
147+ class TranslateEnfrWmtSmallCharacters (translate .TranslateProblem ):
121148 """Problem spec for WMT En-Fr translation."""
122149
123150 @property
@@ -130,7 +157,10 @@ def vocab_name(self):
130157
131158 def generator (self , data_dir , tmp_dir , train ):
132159 character_vocab = text_encoder .ByteTextEncoder ()
133- datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
160+ if self .use_small_dataset :
161+ datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
162+ else :
163+ datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
134164 tag = "train" if train else "dev"
135165 data_path = translate .compile_data (tmp_dir , datasets ,
136166 "wmt_enfr_chr_%s" % tag )
@@ -144,3 +174,11 @@ def input_space_id(self):
144174 @property
145175 def target_space_id (self ):
146176 return problem .SpaceID .FR_CHR
177+
178+
179+ @registry .register_problem
180+ class TranslateEnfrWmtCharacters (TranslateEnfrWmtSmallCharacters ):
181+
182+ @property
183+ def use_small_dataset (self ):
184+ return False
0 commit comments