This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b669110

Merge pull request #220 from vthorsteinsson/ice
Move Icelandic parsing problem to separate module
2 parents 73f0be2 + ab9b004

File tree

10 files changed: +128 −74 lines changed

tensor2tensor/bin/t2t-datagen
File mode changed: 100644 → 100755
Lines changed: 0 additions & 10 deletions

@@ -82,16 +82,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
-    "ice_parsing_tokens": (
-        lambda: wmt.tabbed_parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
-        lambda: wmt.tabbed_parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
-    "ice_parsing_characters": (
-        lambda: wmt.tabbed_parsing_character_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True),
-        lambda: wmt.tabbed_parsing_character_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),

tensor2tensor/bin/t2t-trainer
File mode changed: 100644 → 100755

tensor2tensor/data_generators/all_problems.py
File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.data_generators import ice_parsing


 # Problem modules that require optional dependencies
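This single import is what surfaces the new problem: the @registry.register_problem decorator in ice_parsing runs as an import side effect, replacing the hand-listed _SUPPORTED_PROBLEM_GENERATORS entries removed from t2t-datagen above. A minimal sketch of the lookup this enables, assuming the registry.problem() accessor in tensor2tensor.utils.registry:

# Sketch: registration happens at import time, lookup is by name.
from tensor2tensor.data_generators import all_problems  # noqa: F401
from tensor2tensor.utils import registry

ice = registry.problem("ice_parsing_tokens")  # resolves to IceParsingTokens
print(type(ice).__name__)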

tensor2tensor/data_generators/generator_utils.py
File mode changed: 100644 → 100755
tensor2tensor/data_generators/ice_parsing.py (new file)
Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This module implements the ice_parsing_* problems, which
+# parse plain text into flattened parse trees and POS tags.
+# The training data is stored in files named `parsing_train.pairs`
+# and `parsing_dev.pairs`. These files are UTF-8 text files where
+# each line contains an input sentence and a target parse tree,
+# separated by a tab character.
+
+import os
+
+# Dependency imports
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.data_generators.wmt import tabbed_generator
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+
+def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
+                                   source_vocab_size, target_vocab_size):
+  """Generate source and target data from a single file."""
+  filename = "parsing_{0}.pairs".format("train" if train else "dev")
+  source_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      data_dir, tmp_dir, filename, 0,
+      prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size)
+  target_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      data_dir, tmp_dir, filename, 1,
+      prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
+  pair_filepath = os.path.join(tmp_dir, filename)
+  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
+
+
+def tabbed_parsing_character_generator(tmp_dir, train):
+  """Generate source and target data from a single file."""
+  character_vocab = text_encoder.ByteTextEncoder()
+  filename = "parsing_{0}.pairs".format("train" if train else "dev")
+  pair_filepath = os.path.join(tmp_dir, filename)
+  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
+
+
+@registry.register_problem("ice_parsing_tokens")
+class IceParsingTokens(problem.Problem):
+  """Problem spec for parsing tokenized Icelandic text to
+  constituency trees, also tokenized but to a smaller vocabulary."""
+
+  @property
+  def source_vocab_size(self):
+    return 2**14  # 16384
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**8  # 256
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.ICE_TOK
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.ICE_PARSE_TOK
+
+  @property
+  def num_shards(self):
+    return 10
+
+  def feature_encoders(self, data_dir):
+    source_vocab_filename = os.path.join(
+        data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size)
+    target_vocab_filename = os.path.join(
+        data_dir, "ice_target.tokens.vocab.%d" % self.targeted_vocab_size)
+    source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
+    target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
+    return {
+        "inputs": source_subtokenizer,
+        "targets": target_subtokenizer,
+    }
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    generator_utils.generate_dataset_and_shuffle(
+        tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice",
+                                       self.source_vocab_size,
+                                       self.targeted_vocab_size),
+        self.training_filepaths(data_dir, self.num_shards, shuffled=False),
+        tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice",
+                                       self.source_vocab_size,
+                                       self.targeted_vocab_size),
+        self.dev_filepaths(data_dir, 1, shuffled=False))
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    source_vocab_size = self._encoders["inputs"].vocab_size
+    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_vocab_size)}
+    p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size)
+    p.input_space_id = self.input_space_id
+    p.target_space_id = self.target_space_id
+    p.loss_multiplier = 2.5  # Rough estimate of avg number of tokens per word
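The module's header comment pins down the on-disk format; a concrete record makes it tangible. Below is a hypothetical example of one line in parsing_train.pairs (both the Icelandic sentence and the flattened tree are invented for illustration), followed by a minimal reader in the spirit of tabbed_generator — a sketch, not the actual t2t implementation:

# -*- coding: utf-8 -*-
# Hypothetical parsing_train.pairs record: sentence TAB flattened parse.
# (Sentence and tree are made up; the real tag/label set differs.)
SAMPLE = u"Hundurinn eltir köttinn\t(S (NP Hundurinn) (VP eltir (NP köttinn)))\n"

import io

EOS = 1  # stand-in for text_encoder.EOS_ID


def tabbed_pairs(path, encode):
  """Sketch of a tabbed pair reader yielding encoded source/target dicts."""
  with io.open(path, encoding="utf-8") as pairs_file:
    for line in pairs_file:
      if u"\t" not in line:
        continue  # skip malformed lines
      source, target = line.rstrip(u"\n").split(u"\t", 1)
      yield {"inputs": encode(source) + [EOS],
             "targets": encode(target) + [EOS]}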

tensor2tensor/data_generators/problem_hparams.py
File mode changed: 100644 → 100755
Lines changed: 0 additions & 37 deletions

@@ -462,39 +462,6 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
   return p
 
 
-def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
-  """Icelandic to parse tree translation benchmark.
-
-  Args:
-    model_hparams: a tf.contrib.training.HParams
-    wrong_source_vocab_size: a number used in the filename indicating the
-      approximate vocabulary size. This is not to be confused with the actual
-      vocabulary size.
-
-  Returns:
-    A tf.contrib.training.HParams object.
-  """
-  p = default_problem_hparams()
-  # This vocab file must be present within the data directory.
-  source_vocab_filename = os.path.join(
-      model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
-  target_vocab_filename = os.path.join(model_hparams.data_dir,
-                                       "ice_target.vocab.256")
-  source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
-  target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
-  p.input_modality = {
-      "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
-  }
-  p.target_modality = (registry.Modalities.SYMBOL, 256)
-  p.vocabulary = {
-      "inputs": source_subtokenizer,
-      "targets": target_subtokenizer,
-  }
-  p.input_space_id = 18  # Icelandic tokens
-  p.target_space_id = 19  # Icelandic parse tokens
-  return p
-
-
 def img2img_imagenet(unused_model_hparams):
   """Image 2 Image for imagenet dataset."""
   p = default_problem_hparams()

@@ -544,10 +511,6 @@ def image_celeba(unused_model_hparams):
         lm1b_32k,
     "wiki_32k":
         wiki_32k,
-    "ice_parsing_characters":
-        wmt_parsing_characters,
-    "ice_parsing_tokens":
-        lambda p: ice_parsing_tokens(p, 2**13),
     "wmt_parsing_tokens_8k":
         lambda p: wmt_parsing_tokens(p, 2**13),
     "wsj_parsing_tokens_16k":

tensor2tensor/data_generators/wmt.py
File mode changed: 100644 → 100755
Lines changed: 0 additions & 22 deletions

@@ -648,28 +648,6 @@ def target_space_id(self):
     return problem.SpaceID.CS_CHR
 
 
-def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
-                                   source_vocab_size, target_vocab_size):
-  """Generate source and target data from a single file."""
-  source_vocab = generator_utils.get_or_generate_tabbed_vocab(
-      data_dir, tmp_dir, "parsing_train.pairs", 0,
-      prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size)
-  target_vocab = generator_utils.get_or_generate_tabbed_vocab(
-      data_dir, tmp_dir, "parsing_train.pairs", 1,
-      prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size)
-  filename = "parsing_%s" % ("train" if train else "dev")
-  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
-  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
-
-
-def tabbed_parsing_character_generator(tmp_dir, train):
-  """Generate source and target data from a single file."""
-  character_vocab = text_encoder.ByteTextEncoder()
-  filename = "parsing_%s" % ("train" if train else "dev")
-  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
-  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
-
-
 def parsing_token_generator(data_dir, tmp_dir, train, vocab_size):
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
       data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size)

tensor2tensor/models/transformer.py
File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion

@@ -393,7 +393,7 @@ def transformer_parsing_big():
 
 @registry.register_hparams
 def transformer_parsing_ice():
-  """Hparams for parsing Icelandic text."""
+  """Hparams for parsing and tagging Icelandic text."""
   hparams = transformer_base_single_gpu()
   hparams.batch_size = 4096
   hparams.shared_embedding_and_softmax_weights = int(False)

tensor2tensor/utils/decoding.py
File mode changed: 100644 → 100755
Lines changed: 6 additions & 1 deletion

@@ -259,14 +259,19 @@ def _interactive_input_fn(hparams):
   vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
   # This should be longer than the longest input.
   const_array_size = 10000
+  # Import readline if available for command line editing and recall
+  try:
+    import readline
+  except ImportError:
+    pass
   while True:
     prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n"
               "  it=<input_type>     ('text' or 'image' or 'label')\n"
               "  pr=<problem_num>    (set the problem number)\n"
               "  in=<input_problem>  (set the input problem number)\n"
               "  ou=<output_problem> (set the output problem number)\n"
               "  ns=<num_samples>    (changes number of samples)\n"
-              "  dl=<decode_length>  (changes decode legnth)\n"
+              "  dl=<decode_length>  (changes decode length)\n"
              "  <%s>                (decode)\n"
              "  q                   (quit)\n"
              ">" % (num_samples, decode_length, "source_string"

tensor2tensor/utils/registry.py
File mode changed: 100644 → 100755
Lines changed: 3 additions & 3 deletions

@@ -225,10 +225,10 @@ def parse_problem_name(problem_name):
     was_copy: A boolean.
   """
   # Recursively strip tags until we reach a base name.
-  if len(problem_name) > 4 and problem_name[-4:] == "_rev":
+  if problem_name.endswith("_rev"):
     base, _, was_copy = parse_problem_name(problem_name[:-4])
     return base, True, was_copy
-  elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
+  elif problem_name.endswith("_copy"):
     base, was_reversed, _ = parse_problem_name(problem_name[:-5])
     return base, was_reversed, True
   else:

@@ -352,7 +352,7 @@ def list_modalities():
 
 
 def parse_modality_name(name):
-  name_parts = name.split(":")
+  name_parts = name.split(":", maxsplit=1)
   if len(name_parts) < 2:
     name_parts.append("default")
   modality_type, modality_name = name_parts
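Both rewrites are simplifications rather than behavior changes for ordinary names: endswith() replaces the manual length-and-slice checks (the guards differ only on the degenerate names "_rev" and "_copy" themselves), and maxsplit=1 lets a modality name itself contain colons while still unpacking into exactly two parts. A standalone reimplementation of the suffix parsing, for illustration only:

def parse_problem_name(problem_name):
  """Recursively strip _rev/_copy tags; return (base, was_reversed, was_copy)."""
  if problem_name.endswith("_rev"):
    base, _, was_copy = parse_problem_name(problem_name[:-4])
    return base, True, was_copy
  elif problem_name.endswith("_copy"):
    base, was_reversed, _ = parse_problem_name(problem_name[:-5])
    return base, was_reversed, True
  return problem_name, False, False

assert parse_problem_name("ice_parsing_tokens") == ("ice_parsing_tokens", False, False)
assert parse_problem_name("ice_parsing_tokens_rev") == ("ice_parsing_tokens", True, False)
assert parse_problem_name("ice_parsing_tokens_rev_copy") == ("ice_parsing_tokens", True, True)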
