
Commit 8f5fcc2

nshazeer authored and Ryan Sepassi committed
add wiki-scramble dataset.
PiperOrigin-RevId: 168037859
1 parent c99d5b5 commit 8f5fcc2

File tree

2 files changed: +145 -7 lines


tensor2tensor/data_generators/wiki.py

Lines changed: 117 additions & 0 deletions
@@ -25,6 +25,8 @@
 
 import bz2file
 
+import numpy as np
+
 import six
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
@@ -130,3 +132,118 @@ def generator(self, data_dir, tmp_dir, _):
       encoded = encoder.encode(page) + [EOS]
       encoded_title = encoder.encode(title) + [EOS]
       yield {"inputs": encoded_title, "targets": encoded}
+
+
+class LanguagemodelWikiScramble(problem.Text2TextProblem):
+  """Language modeling on English Wikipedia.
+
+  "targets" is a sequence of sequence_length tokens - a fragment of an article.
+  "inputs" is a copy of "targets", but with a random scramble_fraction of the
+  tokens randomly permuted.
+
+  This dataset is intended to test parallel (non-autoregressive) prediction
+  of the target sequence given the input sequence.
+  """
+
+  @property
+  def sequence_length(self):
+    raise NotImplementedError()
+
+  @property
+  def scramble_fraction(self):
+    raise NotImplementedError()
+
+  @property
+  def is_character_level(self):
+    return False
+
+  @property
+  def has_inputs(self):
+    return True
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def num_shards(self):
+    return 1000
+
+  @property
+  def vocab_name(self):
+    return "vocab.wiki"
+
+  @property
+  def use_subword_tokenizer(self):
+    return True
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**13  # 8192
+
+  @property
+  def use_train_shards_for_dev(self):
+    return True
+
+  @property
+  def max_cases(self):
+    return (2 ** 30) / self.sequence_length
+
+  def scramble(self, seq):
+    seq = np.array(seq)
+    num_permute = int(self.sequence_length * self.scramble_fraction)
+    full_permutation = np.random.permutation(self.sequence_length)
+    inverse_full_permutation = np.argsort(full_permutation)
+    partial_permutation = np.random.permutation(num_permute)
+    seq = seq[full_permutation]
+    seq = np.concatenate(
+        (seq[:num_permute][partial_permutation], seq[num_permute:]))
+    seq = seq[inverse_full_permutation]
+    seq = list(seq)
+    return seq
+
+  def generator(self, data_dir, tmp_dir, _):
+    encoder = generator_utils.get_or_generate_vocab_inner(
+        data_dir, self.vocab_file, self.targeted_vocab_size,
+        lambda: page_generator(tmp_dir, max_docs=1000))
+    case_num = 0
+    for page in page_generator(tmp_dir):
+      encoded = encoder.encode(page)
+      for i in xrange(len(encoded) // self.sequence_length):
+        case_num += 1
+        if self.max_cases and case_num > self.max_cases:
+          return
+        targets = encoded[
+            i * self.sequence_length:(i + 1) * self.sequence_length]
+        inputs = self.scramble(targets)
+        yield {"inputs": inputs, "targets": targets}
+
+
+@registry.register_problem
+class LanguagemodelWikiScramble1k50(LanguagemodelWikiScramble):
+  """Sequence length 1024, 50% scrambled."""
+
+  @property
+  def sequence_length(self):
+    return 1024
+
+  @property
+  def scramble_fraction(self):
+    return 0.5
+
+
+@registry.register_problem
+class LanguagemodelWikiScramble8k50(LanguagemodelWikiScramble):
+  """Sequence length 8192, 50% scrambled."""
+
+  @property
+  def sequence_length(self):
+    return 8192
+
+  @property
+  def scramble_fraction(self):
+    return 0.5
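
Note: the scramble method above permutes the whole sequence, shuffles only the first num_permute slots of that permuted view, and then inverts the permutation, so a random scramble_fraction of positions are permuted among themselves while every other token stays in place. The following is a minimal standalone sketch of the same trick, not part of this commit; the function name and the rng argument are illustrative.

```python
import numpy as np


def scramble_fraction_of_tokens(seq, scramble_fraction, rng=None):
  """Return a copy of `seq` with ~scramble_fraction of positions permuted.

  A full random permutation is applied, the first num_permute entries of the
  permuted sequence are shuffled among themselves, and the full permutation is
  then inverted.  Net effect: a random subset of positions trades tokens, the
  remaining positions keep their original tokens.
  """
  rng = rng or np.random
  seq = np.array(seq)
  num_permute = int(len(seq) * scramble_fraction)
  full_permutation = rng.permutation(len(seq))
  inverse_full_permutation = np.argsort(full_permutation)
  partial_permutation = rng.permutation(num_permute)
  seq = seq[full_permutation]
  seq = np.concatenate(
      (seq[:num_permute][partial_permutation], seq[num_permute:]))
  return list(seq[inverse_full_permutation])


if __name__ == "__main__":
  tokens = list(range(10))
  scrambled = scramble_fraction_of_tokens(tokens, 0.5)
  print(tokens)
  print(scrambled)
  # Roughly half of the positions differ; the rest are unchanged.
  print(sum(a != b for a, b in zip(tokens, scrambled)))
```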

tensor2tensor/models/attention_lm_moe.py

Lines changed: 28 additions & 7 deletions
@@ -68,18 +68,21 @@ def model_fn_body_sharded(self, sharded_features):
     # Remove dropout if not training
     hparams = self._hparams
     dp = self._data_parallelism
-    targets = sharded_features["targets"]
-    targets = dp(tf.squeeze, targets, 2)
+    if hparams.use_inputs:
+      decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
+      decoder_self_attention_bias = None
+    else:
+      targets = sharded_features["targets"]
+      targets = dp(tf.squeeze, targets, 2)
+      (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
+          attention_lm_moe_prepare_decoder, targets, hparams)
 
     def preprocess(x):
       return dp(common_layers.layer_preprocess, x, hparams)
 
     def postprocess(x, y):
       return dp(common_layers.layer_postprocess, x, y, hparams)
 
-    (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
-        attention_lm_moe_prepare_decoder, targets, hparams)
-
     x = dp(tf.nn.dropout, decoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
     extra_loss = 0.0
@@ -95,7 +98,8 @@ def _diet_expert(x):
     expert_fn = expert_utils.ffn_expert_fn(
         hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
 
-    if hparams.attention_type == AttentionType.LOCAL_EXPERTS:
+    if (hparams.attention_type == AttentionType.LOCAL_EXPERTS
+        and not hparams.use_inputs):
       # As preprocess and postprocess are called with batch of size one (all
       # batches concatenated), we just make sure that batch_norm is not use (
       # should not either way)
@@ -162,7 +166,7 @@ def print_shape(x, suffix, debug=False):
               attention_num_experts=hparams.attention_num_experts,
               train=hparams.mode == ModeKeys.TRAIN,
               batch_coordinate=batch_coordinate,
-              mask_right=True,
+              mask_right=not hparams.use_inputs,
               split_batch=bool(hparams.attention_split_batch),
               attention_kq_size=hparams.attention_kq_size,
               attention_v_size=hparams.attention_v_size)
@@ -356,6 +360,9 @@ def attention_lm_moe_base():
   hparams.add_hparam("use_sepconv", int(False))
   hparams.add_hparam("diet_experts", int(False))
   hparams.add_hparam("memory_efficient_ffn", int(False))
+  # if True, we learn a non-autoregressive model from "inputs" to "targets".
+  # if False, we learn an autoregressive model to generate "targets"
+  hparams.add_hparam("use_inputs", int(False))
   return hparams
 
 
@@ -526,3 +533,17 @@ def attention_lm_moe_translation():
   hparams.moe_layers = "0,1,2,3,4,5"
   hparams.shared_embedding_and_softmax_weights = int(True)
   return hparams
+
+
+@registry.register_hparams
+def attention_lm_moe_unscramble_base():
+  """Version to use with languagemodel_wiki_scramble1k50."""
+  hparams = attention_lm_no_moe_small()
+  hparams.use_inputs = True
+  hparams.min_length_bucket = 1024
+  hparams.max_length = 1024
+  hparams.batch_size = 5000
+  hparams.layer_prepostprocess_dropout = 0.0
+  hparams.layer_preprocess_sequence = "n"
+  hparams.layer_postprocess_sequence = "da"
+  return hparams
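
The new use_inputs switch turns attention_lm_moe into a non-autoregressive inputs-to-targets model: the decoder reads the scrambled "inputs" directly, with no causal self-attention bias and with mask_right disabled, and attention_lm_moe_unscramble_base pairs that mode with languagemodel_wiki_scramble1k50 by fixing the length buckets to the problem's sequence_length. A variant for the 8k problem could presumably be registered the same way; the sketch below is not part of the commit, and the name attention_lm_moe_unscramble_8k and the chosen values are assumptions.

```python
from tensor2tensor.models import attention_lm_moe
from tensor2tensor.utils import registry


@registry.register_hparams
def attention_lm_moe_unscramble_8k():
  """Hypothetical variant for languagemodel_wiki_scramble8k50 (length 8192)."""
  hparams = attention_lm_moe.attention_lm_moe_unscramble_base()
  # Keep the length buckets aligned with the problem's fixed sequence_length,
  # mirroring what the base set does for the 1k problem.
  hparams.min_length_bucket = 8192
  hparams.max_length = 8192
  # Assumed value: roughly one 8192-token example per batch per worker.
  hparams.batch_size = 8192
  return hparams
```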
