@@ -101,6 +101,38 @@ def token_generator(source_path, target_path, token_vocab, eos=None):
         source, target = source_file.readline(), target_file.readline()


+def bi_vocabs_token_generator(source_path, target_path,
+                              source_token_vocab,
+                              target_token_vocab,
+                              eos=None):
+  """Generator for sequence-to-sequence tasks that uses tokens.
+
+  This generator assumes the files at source_path and target_path have
+  the same number of lines and yields dictionaries of "inputs" and "targets"
+  where inputs are token ids from the " "-split source (and target, resp.)
+  lines converted to integers using the corresponding vocabulary.
+
+  Args:
+    source_path: path to the file with source sentences.
+    target_path: path to the file with target sentences.
+    source_token_vocab: text_encoder.TextEncoder object for the source side.
+    target_token_vocab: text_encoder.TextEncoder object for the target side.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from tokens in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      while source and target:
+        source_ints = source_token_vocab.encode(source.strip()) + eos_list
+        target_ints = target_token_vocab.encode(target.strip()) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+        source, target = source_file.readline(), target_file.readline()
+
 def _get_wmt_ende_dataset(directory, filename):
   """Extract the WMT en-de corpus `filename` to directory unless it's there."""
   train_path = os.path.join(directory, filename)
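
For reference, a hedged usage sketch of the new `bi_vocabs_token_generator` (not part of the diff): the file names and the toy vocabulary class below are hypothetical stand-ins; any object exposing an `encode()` method, such as a `text_encoder.TextEncoder` subclass, would work the same way.

```python
# Minimal sketch only: ToyVocab, corpus.zh and corpus.en are hypothetical.
class ToyVocab(object):
  """Stand-in encoder: maps " "-split tokens to integer ids."""

  def __init__(self, tokens):
    self._ids = {tok: i for i, tok in enumerate(tokens)}

  def encode(self, sentence):
    return [self._ids[tok] for tok in sentence.split()]

source_vocab = ToyVocab([u"你好", u"世界"])
target_vocab = ToyVocab(["hello", "world"])

# corpus.zh / corpus.en are assumed to be line-aligned, pre-tokenized files.
for example in bi_vocabs_token_generator("corpus.zh", "corpus.en",
                                         source_vocab, target_vocab, eos=1):
  print(example)  # e.g. {"inputs": [0, 1, 1], "targets": [0, 1, 1]}
```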
@@ -177,6 +209,21 @@ def ende_bpe_token_generator(tmp_dir, train):
     ],
 ]
 
+_ZHEN_TRAIN_DATASETS = [
+    [
+        "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
+        ("training/news-commentary-v12.zh-en.zh",
+         "training/news-commentary-v12.zh-en.en")
+    ]
+]
+
+_ZHEN_TEST_DATASETS = [
+    [
+        "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+        ("dev/newsdev2017-zhen-src.zh",
+         "dev/newsdev2017-zhen-ref.en")
+    ]
+]
 
 def _compile_data(tmp_dir, datasets, filename):
   """Concatenate all `datasets` and save to `filename`."""
@@ -253,6 +300,26 @@ def ende_character_generator(tmp_dir, train):
                               character_vocab, EOS)
 
 
+def zhen_wordpiece_token_generator(tmp_dir, train,
+                                   source_vocab_size,
+                                   target_vocab_size):
+  """Instance of token generator for the WMT zh->en task with two vocabs."""
+  datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
+  source_datasets = [[item[0], [item[1][0]]] for item in datasets]
+  target_datasets = [[item[0], [item[1][1]]] for item in datasets]
+  source_vocab = generator_utils.get_or_generate_vocab(
+      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size,
+      source_vocab_size, source_datasets)
+  target_vocab = generator_utils.get_or_generate_vocab(
+      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size,
+      target_vocab_size, target_datasets)
+  tag = "train" if train else "dev"
+  data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag)
+  return bi_vocabs_token_generator(data_path + ".lang1",
+                                   data_path + ".lang2",
+                                   source_vocab, target_vocab, EOS)
+
+
 def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
   """Instance of token generator for the WMT en->fr task."""
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
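
A hedged sketch of how the new zh->en generator might be wired up; the wrapper name, vocabulary sizes, and data directory below are illustrative assumptions rather than part of this change:

```python
# Hypothetical driver: a zh->en training generator with 32k wordpiece
# vocabularies on each side.
def zhen_wordpiece_tokens_32k(tmp_dir, train):
  return zhen_wordpiece_token_generator(tmp_dir, train, 2**15, 2**15)

gen = zhen_wordpiece_tokens_32k("/tmp/t2t_datagen", train=True)
first = next(gen)  # {"inputs": [...zh token ids...], "targets": [...en ids...]}
```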