 # Dependency imports

 import six
+from six import PY2
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf

+
+# Conversion between Unicode and UTF-8, if required (on Python 2).
+_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+
+
+_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+
+
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):


 class SubwordTextEncoder(TextEncoder):
-  """Class for breaking tokens into subtokens.
+  """Class for invertibly encoding text using a limited vocabulary.

-  Invertibly encodes a string as a sequence of subtokens from a limited
+  Invertibly encodes a native string as a sequence of subtokens from a limited
   vocabulary.

   A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
   the corpus), and stored to a file. See text_encoder_build_subword.py.

   It can then be loaded and used to encode/decode any text.
+
+  Encoding has four phases:
+
+  1. Tokenize into a list of tokens. Each token is a unicode string of either
+     all alphanumeric characters or all non-alphanumeric characters. We drop
+     tokens consisting of a single space that are between two alphanumeric
+     tokens.
+
+  2. Escape each token. This escapes away special and out-of-vocabulary
+     characters, and makes sure that each token ends with an underscore and
+     has no other underscores.
+
+  3. Represent each escaped token as the concatenation of a list of subtokens
+     from the limited vocabulary. Subtoken selection is done greedily from
+     beginning to end. That is, we construct the list in order, always picking
+     the longest subtoken in our vocabulary that matches a prefix of the
+     remaining portion of the token being encoded.
+
+  4. Concatenate these lists. This concatenation is invertible because the
+     trailing underscores indicate when one list is finished.
+
   """

   def __init__(self, filename=None, num_reserved_ids=2):
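As a rough illustration of the greedy matching in phase 3 above, a toy vocabulary and the helper name greedy_subtokens below are hypothetical, not part of this change:

# Illustrative only: greedy longest-prefix matching over an escaped token.
toy_vocab = {u"und", u"er", u"score", u"_",
             u"u", u"n", u"d", u"e", u"r", u"s", u"c", u"o"}

def greedy_subtokens(escaped_token):
  ret, pos = [], 0
  while pos < len(escaped_token):
    end = len(escaped_token)
    # Shrink the window until the prefix is in the vocabulary.
    while end > pos and escaped_token[pos:end] not in toy_vocab:
      end -= 1
    ret.append(escaped_token[pos:end])
    pos = end
  return ret

print(greedy_subtokens(u"underscore_"))  # -> ['und', 'er', 'score', '_']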
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)

   def encode(self, raw_text):
-    """Converts a string to a list of subtoken ids.
+    """Converts a native string to a list of subtoken ids.

     Args:
-      raw_text: a string.
+      raw_text: a native string.
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
+    return self._tokens_to_subtokens(self._tokenizer.encode(
+        _native_to_unicode(raw_text)))

   def decode(self, subtokens):
-    """Converts a sequence of subtoken ids to a string.
+    """Converts a sequence of subtoken ids to a native string.

     Args:
       subtokens: a list of integers in the range [0, vocab_size)
     Returns:
-      a string
+      a native string
     """
-    return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
+    return _unicode_to_native(self._tokenizer.decode(
+        self._subtokens_to_tokens(subtokens)))

   @property
   def vocab_size(self):
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
     if subtoken_string:
       return subtoken_string
     if 0 <= subtoken < self._num_reserved_ids:
-      return "%s_" % RESERVED_TOKENS[subtoken]
-    return "ID%d_" % subtoken
+      return u"%s_" % RESERVED_TOKENS[subtoken]
+    return u"ID%d_" % subtoken

   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      if end > pos:
-        ret.append(subtoken)
-        pos = end
-      else:
-        # No subtoken in the vocabulary matches escaped_token[pos].
-        # This can happen if the token contains a Unicode character
-        # that did not occur in the vocabulary training set.
-        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
-        # REPLACEMENT_CHARACTER.
-        ret.append(self.vocab_size - 1)
-        # Ensure that the outer loop continues
-        pos += 1
-    return ret
+      assert end > pos
+      ret.append(subtoken)
+      pos = end

-  @classmethod
-  def alphabet(cls, token_counts):
-    """Return the set of Unicode characters that appear in the tokens."""
-    alphabet_set = set()
-    for token in six.iterkeys(token_counts):
-      alphabet_set |= set(token)
-    return alphabet_set
+    return ret

   @classmethod
   def build_to_target_size(cls,
@@ -304,17 +320,12 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    # Calculate the alphabet, i.e. the set of all Unicode characters
-    # that appear in the tokens.
-    alphabet_set = cls.alphabet(token_counts)
-    tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
-
     def bisect(min_val, max_val):
       """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
-      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+      subtokenizer.build_from_token_counts(token_counts,
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
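Roughly, build_to_target_size binary-searches min_count: a larger min_count yields a smaller vocabulary. The helper below is a simplified, hypothetical sketch of that search, not the diffed code:

# Illustrative only: binary search over min_count for a target vocab size,
# assuming vocab_size_for(min_count) is non-increasing in min_count.
def find_min_count(vocab_size_for, target_size, lo=1, hi=1000):
  while lo < hi:
    mid = (lo + hi) // 2
    if vocab_size_for(mid) > target_size:
      lo = mid + 1   # vocabulary too large: demand a higher count
    else:
      hi = mid       # small enough: see if a lower count also works
  return lo

# e.g. find_min_count(lambda m: 10000 // m, 100) == 100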
@@ -333,17 +344,29 @@ def bisect(min_val, max_val):

   def build_from_token_counts(self,
                               token_counts,
-                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.

     Args:
       token_counts: a dictionary of Unicode strings to int.
-      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer. how many iterations of refinement.
     """
+    # First determine the alphabet: include all characters whose count in
+    # the dataset is at least min_count.
+    char_counts = defaultdict(int)
+    for token, count in six.iteritems(token_counts):
+      for c in token:
+        char_counts[c] += count
+    self._alphabet = set()
+    for c, count in six.iteritems(char_counts):
+      if count >= min_count:
+        self._alphabet.add(c)
+    # Make sure all characters needed for escaping are included.
+    for c in u"\\_;0123456789":
+      self._alphabet.add(c)
+
     # We build iteratively. On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
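A small worked example of the alphabet construction added above; the token counts are made up:

# Illustrative only: characters below min_count fall out of the alphabet and
# will later be escaped as \<ord>; sequences.
from collections import defaultdict
token_counts = {u"cat": 3, u"caf\u00e9": 1}
min_count = 2
char_counts = defaultdict(int)
for token, count in token_counts.items():
  for c in token:
    char_counts[c] += count
alphabet = set(c for c, n in char_counts.items() if n >= min_count)
alphabet |= set(u"\\_;0123456789")  # escape characters are always included
# alphabet keeps 'c', 'a', 't' (counts 4, 4, 3); 'f' and the accented 'e'
# (count 1) are out-of-alphabet and will be escaped.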
@@ -367,43 +390,36 @@ def build_from_token_counts(self,
           for end in xrange(start + 1, len(escaped_token) + 1):
             subtoken_string = escaped_token[start:end]
             counts[subtoken_string] += count
+      # Make sure every character in the alphabet survives the min_count cut.
+      for c in self._alphabet:
+        counts[c] += min_count
       # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # All subtoken strings of length 1 are automatically included
-        # later, so we don't need to consider them here
-        if count < min_count or lsub <= 1:
-          continue
-        # Add this subtoken string to its length set
-        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append(set())
-        len_to_subtoken_strings[lsub].add(subtoken_string)
+        if count >= min_count:
+          # Add this subtoken string to its length set
+          while len(len_to_subtoken_strings) <= lsub:
+            len_to_subtoken_strings.append(set())
+          len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
+      for lsub in reversed(range(1, len(len_to_subtoken_strings))):
+        subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
-            continue
-          new_subtoken_strings.append((count, subtoken_string))
-          for l in xrange(1, len(subtoken_string)):
-            counts[subtoken_string[:l]] -= count
-      # Sort what we've got so far in decreasing order by count
+          if count >= min_count:
+            new_subtoken_strings.append((count, subtoken_string))
+            for l in xrange(1, lsub):
+              counts[subtoken_string[:l]] -= count
+      # Sort in decreasing order by count
       new_subtoken_strings.sort(reverse=True)
-      # Add the alphabet set at the end of the vocabulary list
-      for char in alphabet_set:
-        new_subtoken_strings.append((0, char))
-      # Also include the Unicode REPLACEMENT CHARACTER to use
-      # when encountering previously unseen Unicode characters
-      # in the input (i.e. input external to the tokenizer training
-      # set, which may thus contain characters not in the alphabet_set).
-      # This must be the last entry in the subtoken vocabulary list.
-      new_subtoken_strings.append((0, u"\uFFFD"))
       # Now we have a candidate vocabulary
+      old_alphabet = self._alphabet
       self._init_from_list([u""] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
+      assert old_alphabet == self._alphabet
       tf.logging.info("vocab_size = %d" % self.vocab_size)

     original = "This sentence was encoded by the SubwordTextEncoder."
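To make the longest-to-shortest pass above concrete, here is a toy run of the prefix-discount step; the counts are invented:

# Illustrative only: accepting a long candidate discounts its prefixes, so a
# prefix is only kept if it also occurs often enough on its own.
counts = {u"the_": 100, u"th": 105, u"t": 130}
min_count = 10
new_subtoken_strings = []
for s in sorted(counts, key=len, reverse=True):
  count = counts[s]
  if count >= min_count:
    new_subtoken_strings.append((count, s))
    for l in range(1, len(s)):
      if s[:l] in counts:
        counts[s[:l]] -= count
# Result: [(100, u'the_'), (30, u't')] -- u'th' fell to 5 and was dropped.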
@@ -426,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
     self._all_subtoken_strings = subtoken_strings
     self._subtoken_string_to_id = {
         s: i for i, s in enumerate(subtoken_strings) if s}
+    self._alphabet = set([c for c in subtoken_strings if len(c) == 1])

   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)

   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write("'" + subtoken_string.encode("utf-8") + "'\n")
-        else:
-          f.write("'" + subtoken_string + "'\n")
+        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")

   def _escape_token(self, token):
-    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
+    r"""Escape away underscores and OOV characters and append '_'.
+
+    This allows the token to be expressed as the concatenation of a list
+    of subtokens from the vocabulary. The underscore acts as a sentinel
+    which allows us to invertibly concatenate multiple such lists.

     Args:
-      token: a string
+      token: a unicode string
     Returns:
-      escaped_token: a string
+      escaped_token: a unicode string
     """
-    return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    ret = u""
+    for c in token:
+      if c in self._alphabet:
+        ret += c
+      else:
+        ret += u"\\%d;" % ord(c)
+    return ret

   def _unescape_token(self, escaped_token):
-    r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
+    r"""Inverse of _escape_token().

     Args:
-      escaped_token: a string
+      escaped_token: a unicode string
     Returns:
-      token: a string
+      token: a unicode string
     """
-    assert escaped_token[-1] == "_"
-    return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
+    ret = u""
+    escaped_token = escaped_token[:-1]
+    pos = 0
+    while pos < len(escaped_token):
+      c = escaped_token[pos]
+      if c == "\\":
+        pos += 1
+        c = escaped_token[pos]
+        if c == u"u":
+          ret += u"_"
+          pos += 1
+        elif c == "\\":
+          ret += u"\\"
+          pos += 1
+        else:
+          semicolon_pos = escaped_token.find(u";", pos)
+          if semicolon_pos == -1:
+            continue
+          try:
+            ret += six.unichr(int(escaped_token[pos:semicolon_pos]))
+            pos = semicolon_pos + 1
+          except (ValueError, OverflowError) as _:
+            pass
+      else:
+        ret += c
+        pos += 1
+    return ret

   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
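A round-trip sketch of the escaping scheme implemented above; the toy alphabet and the helper name escape are hypothetical:

# Illustrative only: '_' becomes \u and out-of-alphabet characters become \<ord>;.
alphabet = set(u"abcdefghijklmnopqrstuvwxyz") | set(u"\\_;0123456789")

def escape(token):
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") + u"_"
  return u"".join(c if c in alphabet else u"\\%d;" % ord(c) for c in token)

print(escape(u"nie_wieder"))    # nie\uwieder_
print(escape(u"h\u00e9llo"))    # h\233;llo_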
@@ -477,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
     with tf.gfile.Open(text_filename) as f:
       for line in f:
         # The tokenizer updates token_counts in encode()
-        tok.encode(line.strip())
+        tok.encode(_native_to_unicode(line.strip()))
         lines_read += 1
         if corpus_max_lines > 0 and lines_read > corpus_max_lines:
           return tok.token_counts