
Commit f5e371f

Lukasz Kaiser authored and Ryan Sepassi committed
internal merge from github
PiperOrigin-RevId: 160210398
1 parent c820307 commit f5e371f

File tree

2 files changed: +52 -23 lines changed

2 files changed

+52
-23
lines changed

tensor2tensor/data_generators/wmt.py

Lines changed: 20 additions & 9 deletions
@@ -29,7 +29,12 @@
 import tensorflow as tf
 
 
-def character_generator(source_path, target_path, eos=None):
+# End-of-sentence marker (should correspond to the position of EOS in the
+# RESERVED_TOKENS list in text_encoder.py)
+EOS = 1
+
+
+def character_generator(source_path, target_path, character_vocab, eos=None):
   """Generator for sequence-to-sequence tasks that just uses characters.
 
   This generator assumes the files at source_path and target_path have
@@ -40,6 +45,7 @@ def character_generator(source_path, target_path, eos=None):
   Args:
     source_path: path to the file with source sentences.
     target_path: path to the file with target sentences.
+    character_vocab: a TextEncoder to encode the characters.
     eos: integer to append at the end of each sequence (default: None).
 
   Yields:
@@ -51,8 +57,8 @@ def character_generator(source_path, target_path, eos=None):
     with tf.gfile.GFile(target_path, mode="r") as target_file:
       source, target = source_file.readline(), target_file.readline()
       while source and target:
-        source_ints = [ord(c) for c in source.strip()] + eos_list
-        target_ints = [ord(c) for c in target.strip()] + eos_list
+        source_ints = character_vocab.encode(source.strip()) + eos_list
+        target_ints = character_vocab.encode(target.strip()) + eos_list
         yield {"inputs": source_ints, "targets": target_ints}
         source, target = source_file.readline(), target_file.readline()
 
@@ -226,14 +232,16 @@ def ende_wordpiece_token_generator(tmp_dir, train, vocab_size):
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_ende_tok_%s" % tag)
   return token_generator(data_path + ".lang1", data_path + ".lang2",
-                         symbolizer_vocab, 1)
+                         symbolizer_vocab, EOS)
 
 
 def ende_character_generator(tmp_dir, train):
+  character_vocab = text_encoder.ByteTextEncoder()
   datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_ende_chr_%s" % tag)
-  return character_generator(data_path + ".lang1", data_path + ".lang2", 1)
+  return character_generator(data_path + ".lang1", data_path + ".lang2",
+                             character_vocab, EOS)
 
 
 def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
@@ -244,22 +252,25 @@ def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag)
   return token_generator(data_path + ".lang1", data_path + ".lang2",
-                         symbolizer_vocab, 1)
+                         symbolizer_vocab, EOS)
 
 
 def enfr_character_generator(tmp_dir, train):
   """Instance of character generator for the WMT en->fr task."""
+  character_vocab = text_encoder.ByteTextEncoder()
   datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_chr_%s" % tag)
-  return character_generator(data_path + ".lang1", data_path + ".lang2", 1)
+  return character_generator(data_path + ".lang1", data_path + ".lang2",
+                             character_vocab, EOS)
 
 
 def parsing_character_generator(tmp_dir, train):
+  character_vocab = text_encoder.ByteTextEncoder()
   filename = "parsing_%s" % ("train" if train else "dev")
   text_filepath = os.path.join(tmp_dir, filename + ".text")
   tags_filepath = os.path.join(tmp_dir, filename + ".tags")
-  return character_generator(text_filepath, tags_filepath, 1)
+  return character_generator(text_filepath, tags_filepath, character_vocab, EOS)
 
 
 def parsing_token_generator(tmp_dir, train, vocab_size):
@@ -268,4 +279,4 @@ def parsing_token_generator(tmp_dir, train, vocab_size):
   filename = "parsing_%s" % ("train" if train else "dev")
   text_filepath = os.path.join(tmp_dir, filename + ".text")
   tags_filepath = os.path.join(tmp_dir, filename + ".tags")
-  return token_generator(text_filepath, tags_filepath, symbolizer_vocab, 1)
+  return token_generator(text_filepath, tags_filepath, symbolizer_vocab, EOS)
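
Taken together, the wmt.py changes thread an explicit character vocabulary (a text_encoder.ByteTextEncoder) and the new module-level EOS constant through every character-level generator. A minimal usage sketch of the updated interface follows; the temporary files, their contents, and the print loop are illustrative only and not part of this commit.

# Illustrative sketch (not from the commit): drive the updated
# character_generator with a ByteTextEncoder and the new EOS constant.
import tempfile

from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import wmt

# Write a tiny parallel "corpus" to temporary files.
with tempfile.NamedTemporaryFile(suffix=".src", delete=False) as src:
  src.write(b"hello world\n")
with tempfile.NamedTemporaryFile(suffix=".tgt", delete=False) as tgt:
  tgt.write(b"hallo welt\n")

character_vocab = text_encoder.ByteTextEncoder()
for example in wmt.character_generator(src.name, tgt.name,
                                       character_vocab, eos=wmt.EOS):
  # Each example is a dict of integer id lists; wmt.EOS (== 1) is appended
  # to both "inputs" and "targets" because it was passed as `eos`.
  print(example["inputs"], example["targets"])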

tensor2tensor/data_generators/wmt_test.py

Lines changed: 32 additions & 14 deletions
@@ -25,6 +25,7 @@
 # Dependency imports
 
 import six
+from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators import wmt
 
 import tensorflow as tf
@@ -36,31 +37,48 @@ def testCharacterGenerator(self):
     # Generate a trivial source and target file.
     tmp_dir = self.get_temp_dir()
     (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
+    if six.PY2:
+      enc_f = lambda s: s
+    else:
+      enc_f = lambda s: s.encode("utf-8")
     with io.open(tmp_file_path + ".src", "wb") as src_file:
-      src_file.write("source1\n")
-      src_file.write("source2\n")
+      src_file.write(enc_f("source1\n"))
+      src_file.write(enc_f("source2\n"))
     with io.open(tmp_file_path + ".tgt", "wb") as tgt_file:
-      tgt_file.write("target1\n")
-      tgt_file.write("target2\n")
+      tgt_file.write(enc_f("target1\n"))
+      tgt_file.write(enc_f("target2\n"))
 
     # Call character generator on the generated files.
     results_src, results_tgt = [], []
-    for dictionary in wmt.character_generator(tmp_file_path + ".src",
-                                              tmp_file_path + ".tgt"):
+    character_vocab = text_encoder.ByteTextEncoder()
+    for dictionary in wmt.character_generator(
+        tmp_file_path + ".src", tmp_file_path + ".tgt", character_vocab):
       self.assertEqual(sorted(list(dictionary)), ["inputs", "targets"])
       results_src.append(dictionary["inputs"])
      results_tgt.append(dictionary["targets"])
 
     # Check that the results match the files.
+    # First check that the results match the encoded original strings;
+    # this is a comparison of integer arrays.
     self.assertEqual(len(results_src), 2)
-    self.assertEqual("".join([six.int2byte(i)
-                              for i in results_src[0]]), "source1")
-    self.assertEqual("".join([six.int2byte(i)
-                              for i in results_src[1]]), "source2")
-    self.assertEqual("".join([six.int2byte(i)
-                              for i in results_tgt[0]]), "target1")
-    self.assertEqual("".join([six.int2byte(i)
-                              for i in results_tgt[1]]), "target2")
+    self.assertEqual(results_src[0],
+                     character_vocab.encode("source1"))
+    self.assertEqual(results_src[1],
+                     character_vocab.encode("source2"))
+    self.assertEqual(results_tgt[0],
+                     character_vocab.encode("target1"))
+    self.assertEqual(results_tgt[1],
+                     character_vocab.encode("target2"))
+    # Then decode the results and compare with the original strings;
+    # this is a comparison of strings
+    self.assertEqual(character_vocab.decode(results_src[0]),
+                     "source1")
+    self.assertEqual(character_vocab.decode(results_src[1]),
+                     "source2")
+    self.assertEqual(character_vocab.decode(results_tgt[0]),
+                     "target1")
+    self.assertEqual(character_vocab.decode(results_tgt[1]),
+                     "target2")
 
     # Clean up.
     os.remove(tmp_file_path + ".src")
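
The rewritten assertions reflect that the generator now yields whatever ids character_vocab.encode produces rather than raw ord() byte values, presumably because ByteTextEncoder shifts byte values past its reserved ids (which is why the old six.int2byte comparison had to go). A short sketch of the encode/decode round trip the new assertions rely on, under that assumption:

# Sketch (assumption, not from the commit): the round trip the test checks.
from tensor2tensor.data_generators import text_encoder

vocab = text_encoder.ByteTextEncoder()
ids = vocab.encode("source1")           # one integer id per byte of the string
assert vocab.decode(ids) == "source1"   # decode inverts encode for ASCII text
# The ids are presumably not plain ord() values (low ids are reserved, e.g.
# EOS == 1), so comparing against six.int2byte output would no longer match.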
