
Commit c91989c (2 parents: 1368b00 + e2ed8ed)

Merge pull request #162 from vthorsteinsson/ice
Bug fixes in inference and data generation; faster token unescaping

File tree: 6 files changed, +30 −36 lines

tensor2tensor/bin/t2t-datagen

File mode changed: 100644 → 100755

tensor2tensor/bin/t2t-trainer

File mode changed: 100644 → 100755

tensor2tensor/data_generators/generator_utils.py

File mode changed: 100644 → 100755
Lines changed: 4 additions & 3 deletions
@@ -329,18 +329,19 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
     return vocab
 
   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
       line = line.strip()
       if line and "\t" in line:
         parts = line.split("\t", maxsplit=1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1
 
   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
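The change above counts token occurrences into a local defaultdict instead of reading tokenizer.token_counts from the tokenizer module. A minimal, self-contained sketch of that counting pattern (a plain whitespace split stands in for tokenizer.encode, and the sample data and function name are made up for illustration):

from collections import defaultdict

def count_tokens(lines, index=0):
  """Tally token occurrences in the chosen tab-separated column."""
  token_counts = defaultdict(int)
  for line in lines:
    line = line.strip()
    if line and "\t" in line:
      parts = line.split("\t", maxsplit=1)
      part = parts[index].strip()
      for tok in part.split():  # stand-in for tokenizer.encode(...)
        token_counts[tok] += 1
  return token_counts

print(count_tokens(["hallo heimur\thello world", "hallo aftur\thello again"]))
# defaultdict(<class 'int'>, {'hallo': 2, 'heimur': 1, 'aftur': 1})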

tensor2tensor/data_generators/text_encoder.py

Lines changed: 19 additions & 29 deletions
@@ -24,6 +24,7 @@
 from __future__ import print_function
 
 from collections import defaultdict
+import re
 
 # Dependency imports
 
@@ -225,6 +226,7 @@ class SubwordTextEncoder(TextEncoder):
 
   def __init__(self, filename=None):
     """Initialize and read from a file, if provided."""
+    self._alphabet = set()
     if filename is not None:
       self._load_from_file(filename)
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)
@@ -503,6 +505,12 @@ def _escape_token(self, token):
         ret += u"\\%d;" % ord(c)
     return ret
 
+  # Regular expression for unescaping token strings
+  # '\u' is converted to '_'
+  # '\\' is converted to '\'
+  # '\213;' is converted to unichr(213)
+  _UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
+
   def _unescape_token(self, escaped_token):
     """Inverse of _escape_token().
 
@@ -511,32 +519,14 @@ def _unescape_token(self, escaped_token):
     Returns:
       token: a unicode string
     """
-    ret = u""
-    escaped_token = escaped_token[:-1]
-    pos = 0
-    while pos < len(escaped_token):
-      c = escaped_token[pos]
-      if c == "\\":
-        pos += 1
-        if pos >= len(escaped_token):
-          break
-        c = escaped_token[pos]
-        if c == u"u":
-          ret += u"_"
-          pos += 1
-        elif c == "\\":
-          ret += u"\\"
-          pos += 1
-        else:
-          semicolon_pos = escaped_token.find(u";", pos)
-          if semicolon_pos == -1:
-            continue
-          try:
-            ret += unichr(int(escaped_token[pos:semicolon_pos]))
-            pos = semicolon_pos + 1
-          except (ValueError, OverflowError) as _:
-            pass
-      else:
-        ret += c
-        pos += 1
-    return ret
+    def match(m):
+      if m.group(1) is not None:
+        # Convert '\213;' to unichr(213)
+        try:
+          return unichr(int(m.group(1)))
+        except (ValueError, OverflowError) as _:
+          return ""
+      # Convert '\u' to '_' and '\\' to '\'
+      return u"_" if m.group(0) == u"\\u" else u"\\"
+    # Cut off the trailing underscore and apply the regex substitution
+    return self._UNESCAPE_REGEX.sub(match, escaped_token[:-1])
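The rewrite replaces the character-by-character loop with one compiled regex and a substitution callback, which is where the "faster token unescaping" in the commit message comes from. A minimal sketch of the same idea outside the class (Python 3, so chr stands in for unichr; the function name and test string are illustrative only, not the tensor2tensor API):

import re

# Same three alternatives as in the diff: literal '\u', literal '\\',
# and numeric escapes such as '\213;'.
_UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))

def unescape_token(escaped_token):
  """Inverse of the escaping scheme; expects a trailing '_' terminator."""
  def match(m):
    if m.group(1) is not None:
      try:
        return chr(int(m.group(1)))  # '\213;' -> chr(213)
      except (ValueError, OverflowError):
        return u""
    return u"_" if m.group(0) == u"\\u" else u"\\"  # '\u' -> '_', '\\' -> '\'
  # Drop the trailing underscore, then substitute every escape in one pass.
  return _UNESCAPE_REGEX.sub(match, escaped_token[:-1])

print(unescape_token(u"foo\\u bar\\92;_"))  # prints: foo_ bar\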

tensor2tensor/data_generators/tokenizer_test.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# coding=utf-8
 """Tests for tensor2tensor.data_generators.tokenizer."""
 
 from __future__ import absolute_import

tensor2tensor/utils/trainer_utils.py

File mode changed: 100644 → 100755
Lines changed: 6 additions & 3 deletions
@@ -585,6 +585,7 @@ def decode_from_dataset(estimator):
   tf.logging.info("Performing local inference.")
   infer_problems_data = get_datasets_for_mode(hparams.data_dir,
                                               tf.contrib.learn.ModeKeys.INFER)
+
   infer_input_fn = get_input_fn(
       mode=tf.contrib.learn.ModeKeys.INFER,
       hparams=hparams,
@@ -625,9 +626,11 @@ def log_fn(inputs,
 
   # The function predict() returns an iterable over the network's
   # predictions from the test input. We use it to log inputs and decodes.
-  for j, result in enumerate(result_iter):
-    inputs, targets, outputs = (result["inputs"], result["targets"],
-                                result["outputs"])
+  inputs_iter = result_iter["inputs"]
+  targets_iter = result_iter["targets"]
+  outputs_iter = result_iter["outputs"]
+  for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
+    inputs, targets, outputs = result
     if FLAGS.decode_return_beams:
       output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
       for k, beam in enumerate(output_beams):
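The loop change assumes predict() hands back a dict of per-key iterables rather than an iterable of per-example dicts, so the fix zips the three iterables back into per-example tuples. A small illustrative sketch with made-up data (not the actual estimator output):

# Sketch only: result_iter as a dict of per-key sequences, as assumed by the fix.
result_iter = {
    "inputs": [[1, 2], [3, 4]],
    "targets": [[5], [6]],
    "outputs": [[7, 8], [9, 10]],
}

inputs_iter = result_iter["inputs"]
targets_iter = result_iter["targets"]
outputs_iter = result_iter["outputs"]

for j, (inputs, targets, outputs) in enumerate(
    zip(inputs_iter, targets_iter, outputs_iter)):
  print(j, inputs, targets, outputs)
# 0 [1, 2] [5] [7, 8]
# 1 [3, 4] [6] [9, 10]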
