2929import tensorflow as tf
3030
3131
32- def character_generator (source_path , target_path , eos = None ):
32+ # End-of-sentence marker (should correspond to the position of EOS in the
33+ # RESERVED_TOKENS list in text_encoder.py)
34+ EOS = 1
35+
36+
37+ def character_generator (source_path , target_path , character_vocab , eos = None ):
3338 """Generator for sequence-to-sequence tasks that just uses characters.
3439
3540 This generator assumes the files at source_path and target_path have
@@ -40,6 +45,7 @@ def character_generator(source_path, target_path, eos=None):
4045 Args:
4146 source_path: path to the file with source sentences.
4247 target_path: path to the file with target sentences.
48+ character_vocab: a TextEncoder to encode the characters.
4349 eos: integer to append at the end of each sequence (default: None).
4450
4551 Yields:
@@ -51,8 +57,8 @@ def character_generator(source_path, target_path, eos=None):
5157 with tf .gfile .GFile (target_path , mode = "r" ) as target_file :
5258 source , target = source_file .readline (), target_file .readline ()
5359 while source and target :
54- source_ints = [ ord ( c ) for c in source .strip ()] + eos_list
55- target_ints = [ ord ( c ) for c in target .strip ()] + eos_list
60+ source_ints = character_vocab . encode ( source .strip ()) + eos_list
61+ target_ints = character_vocab . encode ( target .strip ()) + eos_list
5662 yield {"inputs" : source_ints , "targets" : target_ints }
5763 source , target = source_file .readline (), target_file .readline ()
5864
@@ -226,14 +232,16 @@ def ende_wordpiece_token_generator(tmp_dir, train, vocab_size):
226232 tag = "train" if train else "dev"
227233 data_path = _compile_data (tmp_dir , datasets , "wmt_ende_tok_%s" % tag )
228234 return token_generator (data_path + ".lang1" , data_path + ".lang2" ,
229- symbolizer_vocab , 1 )
235+ symbolizer_vocab , EOS )
230236
231237
232238def ende_character_generator (tmp_dir , train ):
239+ character_vocab = text_encoder .ByteTextEncoder ()
233240 datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS
234241 tag = "train" if train else "dev"
235242 data_path = _compile_data (tmp_dir , datasets , "wmt_ende_chr_%s" % tag )
236- return character_generator (data_path + ".lang1" , data_path + ".lang2" , 1 )
243+ return character_generator (data_path + ".lang1" , data_path + ".lang2" ,
244+ character_vocab , EOS )
237245
238246
239247def enfr_wordpiece_token_generator (tmp_dir , train , vocab_size ):
@@ -244,22 +252,25 @@ def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
244252 tag = "train" if train else "dev"
245253 data_path = _compile_data (tmp_dir , datasets , "wmt_enfr_tok_%s" % tag )
246254 return token_generator (data_path + ".lang1" , data_path + ".lang2" ,
247- symbolizer_vocab , 1 )
255+ symbolizer_vocab , EOS )
248256
249257
250258def enfr_character_generator (tmp_dir , train ):
251259 """Instance of character generator for the WMT en->fr task."""
260+ character_vocab = text_encoder .ByteTextEncoder ()
252261 datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
253262 tag = "train" if train else "dev"
254263 data_path = _compile_data (tmp_dir , datasets , "wmt_enfr_chr_%s" % tag )
255- return character_generator (data_path + ".lang1" , data_path + ".lang2" , 1 )
264+ return character_generator (data_path + ".lang1" , data_path + ".lang2" ,
265+ character_vocab , EOS )
256266
257267
258268def parsing_character_generator (tmp_dir , train ):
269+ character_vocab = text_encoder .ByteTextEncoder ()
259270 filename = "parsing_%s" % ("train" if train else "dev" )
260271 text_filepath = os .path .join (tmp_dir , filename + ".text" )
261272 tags_filepath = os .path .join (tmp_dir , filename + ".tags" )
262- return character_generator (text_filepath , tags_filepath , 1 )
273+ return character_generator (text_filepath , tags_filepath , character_vocab , EOS )
263274
264275
265276def parsing_token_generator (tmp_dir , train , vocab_size ):
@@ -268,4 +279,4 @@ def parsing_token_generator(tmp_dir, train, vocab_size):
268279 filename = "parsing_%s" % ("train" if train else "dev" )
269280 text_filepath = os .path .join (tmp_dir , filename + ".text" )
270281 tags_filepath = os .path .join (tmp_dir , filename + ".tags" )
271- return token_generator (text_filepath , tags_filepath , symbolizer_vocab , 1 )
282+ return token_generator (text_filepath , tags_filepath , symbolizer_vocab , EOS )
0 commit comments