
Commit b4de995

Moved ice_parsing to data_generators; updated to 1.1.7
1 parent e38ab25 commit b4de995

5 files changed: +27 additions, −59 deletions

tensor2tensor/data_generators/all_problems.py

File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.data_generators import ice_parsing
 
 
 # Problem modules that require optional dependencies
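Note: this import is needed purely for its side effect. Problem classes register themselves with the central registry when their defining module is imported, so the problems in ice_parsing.py only become discoverable after all_problems.py imports the module. A minimal sketch of the registration pattern (the class name is illustrative, not from this commit):

    from tensor2tensor.data_generators import problem
    from tensor2tensor.utils import registry

    @registry.register_problem
    class MyIceParsingProblem(problem.Problem):  # illustrative name
      """Registered under its snake_cased name as soon as this module is imported."""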

tensor2tensor/ice_parsing/ice_parsing.py renamed to tensor2tensor/data_generators/ice_parsing.py

Lines changed: 17 additions & 27 deletions
@@ -28,7 +28,6 @@
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators.wmt import tabbed_generator
 from tensor2tensor.utils import registry
-from tensor2tensor.models import transformer
 
 import tensorflow as tf
 
@@ -69,9 +68,21 @@ def source_vocab_size(self):
     return 2**14  # 16384
 
   @property
-  def target_vocab_size(self):
+  def targeted_vocab_size(self):
     return 2**8  # 256
 
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.ICE_TOK
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.ICE_PARSE_TOK
+
+  @property
+  def num_shards(self):
+    return 10
+
   def feature_encoders(self, data_dir):
     source_vocab_filename = os.path.join(
         data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size)
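(With source_vocab_size = 2**14, the vocabulary filename above resolves to ice_source.tokens.vocab.16384.)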
@@ -89,7 +100,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice",
                                        self.source_vocab_size,
                                        self.target_vocab_size),
-        self.training_filepaths(data_dir, 1, shuffled=False),
+        self.training_filepaths(data_dir, self.num_shards, shuffled=False),
         tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice",
                                        self.source_vocab_size,
                                        self.target_vocab_size),
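The changed line writes the training data into 10 shards (via the new num_shards property) rather than a single file. A minimal sketch of the shard paths this implies, assuming T2T's usual %05d-of-%05d shard naming (the basename is illustrative):

    import os

    def train_shard_paths(data_dir, basename, num_shards=10):
      # e.g. ice_parsing-train-00000-of-00010 ... ice_parsing-train-00009-of-00010
      return [os.path.join(data_dir, "%s-train-%05d-of-%05d" % (basename, i, num_shards))
              for i in range(num_shards)]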
@@ -99,29 +110,8 @@ def hparams(self, defaults, model_hparams):
     p = defaults
     source_vocab_size = self._encoders["inputs"].vocab_size
     p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_vocab_size)}
-    p.target_modality = (registry.Modalities.SYMBOL, self.target_vocab_size)
-    p.input_space_id = problem.SpaceID.ICE_TOK
-    p.target_space_id = problem.SpaceID.ICE_PARSE_TOK
+    p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size)
+    p.input_space_id = self.input_space_id
+    p.target_space_id = self.target_space_id
     p.loss_multiplier = 2.5  # Rough estimate of avg number of tokens per word
 
-
-@registry.register_hparams
-def transformer_parsing_ice():
-  """Hparams for parsing Icelandic text."""
-  hparams = transformer.transformer_base_single_gpu()
-  hparams.batch_size = 4096
-  hparams.shared_embedding_and_softmax_weights = int(False)
-  return hparams
-
-
-@registry.register_hparams
-def transformer_parsing_ice_big():
-  """Hparams for parsing Icelandic text, bigger model."""
-  hparams = transformer_parsing_ice()
-  hparams.batch_size = 2048  # 4096 gives Out-of-memory on 8 GB 1080 GTX GPU
-  hparams.attention_dropout = 0.05
-  hparams.residual_dropout = 0.05
-  hparams.max_length = 512
-  hparams.hidden_size = 1024
-  return hparams
-
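After this change, the space IDs and shard count are exposed as properties on the problem object instead of being hard-coded inside hparams(), so callers can query them directly. A minimal sketch, assuming the problem class defined earlier in this file (the class name is illustrative, not taken from this commit):

    from tensor2tensor.data_generators import problem

    prob = ParsingIcelandic16k()  # illustrative class name
    assert prob.num_shards == 10
    assert prob.input_space_id == problem.SpaceID.ICE_TOK
    assert prob.target_space_id == problem.SpaceID.ICE_PARSE_TOK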

tensor2tensor/data_generators/wmt.py

Lines changed: 0 additions & 30 deletions
@@ -187,36 +187,6 @@ def bi_vocabs_token_generator(source_path,
     source, target = source_file.readline(), target_file.readline()
 
 
-def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
-  r"""Generator for sequence-to-sequence tasks using tabbed files.
-
-  Tokens are derived from text files where each line contains both
-  a source and a target string. The two strings are separated by a tab
-  character ('\t'). It yields dictionaries of "inputs" and "targets" where
-  inputs are characters from the source lines converted to integers, and
-  targets are characters from the target lines, also converted to integers.
-
-  Args:
-    source_path: path to the file with source and target sentences.
-    source_vocab: a SubwordTextEncoder to encode the source string.
-    target_vocab: a SubwordTextEncoder to encode the target string.
-    eos: integer to append at the end of each sequence (default: None).
-
-  Yields:
-    A dictionary {"inputs": source-line, "targets": target-line} where
-    the lines are integer lists converted from characters in the file lines.
-  """
-  eos_list = [] if eos is None else [eos]
-  with tf.gfile.GFile(source_path, mode="r") as source_file:
-    for line in source_file:
-      if line and "\t" in line:
-        parts = line.split("\t", maxsplit=1)
-        source, target = parts[0].strip(), parts[1].strip()
-        source_ints = source_vocab.encode(source) + eos_list
-        target_ints = target_vocab.encode(target) + eos_list
-        yield {"inputs": source_ints, "targets": target_ints}
-
-
 # Data-set URLs.
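For context, the generator deleted above consumes one tab-separated sentence pair per line. A minimal usage sketch; the data file and vocabulary file names are illustrative, and the vocabularies are standard text_encoder.SubwordTextEncoder objects loaded from vocab files:

    from tensor2tensor.data_generators import text_encoder
    from tensor2tensor.data_generators.wmt import tabbed_generator

    # pairs.tsv holds lines of the form: <source sentence>\t<target sentence>
    source_vocab = text_encoder.SubwordTextEncoder("ice_source.tokens.vocab.16384")
    target_vocab = text_encoder.SubwordTextEncoder("ice_target.tokens.vocab.256")
    for example in tabbed_generator("pairs.tsv", source_vocab, target_vocab, eos=1):
      # each value is a list of vocabulary ids, ending with the EOS id 1
      print(example["inputs"], example["targets"])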

tensor2tensor/ice_parsing/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

tensor2tensor/models/transformer.py

Lines changed: 9 additions & 0 deletions
@@ -391,6 +391,15 @@ def transformer_parsing_big():
   return hparams
 
 
+@registry.register_hparams
+def transformer_parsing_ice():
+  """Hparams for parsing and tagging Icelandic text."""
+  hparams = transformer.transformer_base_single_gpu()
+  hparams.batch_size = 4096
+  hparams.shared_embedding_and_softmax_weights = int(False)
+  return hparams
+
+
 @registry.register_hparams
 def transformer_tiny():
   hparams = transformer_base()
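This hparams function was moved verbatim from ice_parsing.py, which imported the transformer module; inside transformer.py itself the transformer. prefix has nothing to bind to, so the module-local call is presumably the intended form. A hedged sketch of that cleanup (same body, prefix dropped):

    @registry.register_hparams
    def transformer_parsing_ice():
      """Hparams for parsing and tagging Icelandic text."""
      hparams = transformer_base_single_gpu()  # module-local call; no self-import needed
      hparams.batch_size = 4096
      hparams.shared_embedding_and_softmax_weights = int(False)
      return hparams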
