
Commit f3e5859

Author: Lukasz Kaiser

merge from github

PiperOrigin-RevId: 160671813
1 parent 302d0ca commit f3e5859

File tree

3 files changed: +40 -14 lines changed

.gitignore

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,9 +1,16 @@
 # Compiled python modules.
 *.pyc
 
+# Byte-compiled
+_pycache__/
+
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info
 
-# PyPI distribution artificats
+# PyPI distribution artifacts.
 build/
 dist/
+
+# Sublime project files
+*.sublime-project
+*.sublime-workspace
```
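Note: the added `_pycache__/` pattern appears to be missing a leading underscore. CPython writes byte-compiled caches to directories named `__pycache__/`, and a gitignore entry without wildcards only matches path components with exactly that name, so as written the new rule would not ignore them.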

tensor2tensor/data_generators/text_encoder.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -310,21 +310,24 @@ def build_to_target_size(cls,
     tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
 
     def bisect(min_val, max_val):
+      """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
       subtokenizer.build_from_token_counts(token_counts, alphabet_set,
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
+
       if subtokenizer.vocab_size > target_size:
         other_subtokenizer = bisect(present_count + 1, max_val)
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)
-        if (abs(other_subtokenizer.vocab_size - target_size) <
-            abs(subtokenizer.vocab_size - target_size)):
-          return other_subtokenizer
-        return subtokenizer
+
+      if (abs(other_subtokenizer.vocab_size - target_size) <
+          abs(subtokenizer.vocab_size - target_size)):
+        return other_subtokenizer
+      return subtokenizer
 
     return bisect(min_val, max_val)
 
```
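The change above does two things: it adds a docstring, and it dedents the closeness comparison so it runs after both recursive branches instead of only inside the `else` case (previously, when the vocabulary came out too large, the better of the two candidates was never chosen). A minimal standalone sketch of the same search, assuming a hypothetical `build_vocab(min_count)` stand-in for `build_from_token_counts`, where vocabulary size shrinks as `min_count` grows:

```python
def bisect_to_target_size(build_vocab, target_size, min_val, max_val):
  """Binary-search min_count for a vocabulary closest to target_size."""
  present_count = (max_val + min_val) // 2
  vocab = build_vocab(present_count)
  if min_val >= max_val or len(vocab) == target_size:
    return vocab

  if len(vocab) > target_size:
    # Vocabulary too large: raise the count threshold to prune more.
    other = bisect_to_target_size(build_vocab, target_size,
                                  present_count + 1, max_val)
  else:
    # Vocabulary too small: lower the threshold to keep more tokens.
    other = bisect_to_target_size(build_vocab, target_size,
                                  min_val, present_count - 1)

  # The fixed behavior: compare both candidates on the way out and
  # return whichever lands closer to target_size.
  if abs(len(other) - target_size) < abs(len(vocab) - target_size):
    return other
  return vocab

# Toy usage: 200 tokens with long-tail counts, aiming for ~50 entries.
counts = {i: 1000 // (i + 1) for i in range(200)}
build_vocab = lambda mc: [t for t, c in counts.items() if c >= mc]
print(len(bisect_to_target_size(build_vocab, 50, 1, 1000)))  # -> 50
```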
tensor2tensor/models/bluenet.py

Lines changed: 25 additions & 9 deletions
```diff
@@ -77,7 +77,8 @@ def run_binary_modules(modules, cur1, cur2, hparams):
   """Run binary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur1, cur2, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -89,7 +90,8 @@ def run_unary_modules_basic(modules, cur, hparams):
   """Run unary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -109,7 +111,8 @@ def run_unary_modules_sample(modules, cur, hparams, k):
                      lambda: tf.zeros_like(cur),
                      lambda i=n: modules[i](cur, hparams))
             for n in xrange(len(modules))]
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t - 1e9 * (1.0 - to_run))
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
   res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1])
```
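All three module-selection functions receive the same edit: the hard-coded 100000-step temperature schedule becomes the tunable `hparams.anneal_until`. A rough NumPy sketch of what the schedule does, assuming `common_layers.inverse_exp_decay(max_step, min_value)` rises exponentially from `min_value` at step 0 to 1.0 at `max_step` (the global step is made an explicit argument here for clarity):

```python
import numpy as np

def inverse_exp_decay(max_step, step, min_value=0.01):
  # Rises from min_value at step 0 to 1.0 at max_step, then stays at 1.0.
  inv_base = np.exp(np.log(min_value) / max_step)
  return inv_base ** max(max_step - step, 0.0)

# Early in training inv_t ~= 1.0, so the softmax over module-selection
# logits stays soft and every module receives gradient; by anneal_until
# steps inv_t reaches 100.0 and the softmax is nearly a hard argmax.
anneal_until = 40000                  # the new hparam's default below
logits = np.array([0.3, -0.1, 0.05])  # an illustrative selection_var
for step in (0, 20000, 40000):
  inv_t = 100.0 * inverse_exp_decay(anneal_until, step)
  w = np.exp(logits * inv_t) / np.sum(np.exp(logits * inv_t))
  print(step, round(inv_t, 2), np.round(w, 3))
```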
```diff
@@ -122,6 +125,14 @@ def run_unary_modules(modules, cur, hparams):
   return run_unary_modules_sample(modules, cur, hparams, 4)
 
 
+def batch_deviation(x):
+  """Average deviation of the batch."""
+  x_mean = tf.reduce_mean(x, axis=[0], keep_dims=True)
+  x_variance = tf.reduce_mean(
+      tf.square(x - x_mean), axis=[0], keep_dims=True)
+  return tf.reduce_mean(tf.sqrt(x_variance))
+
+
 @registry.register_model
 class BlueNet(t2t_model.T2TModel):
 
```
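`batch_deviation` collapses a tensor to one scalar: the standard deviation of each coordinate across the batch axis, averaged over all coordinates. A direct NumPy transcription to make the reduction concrete:

```python
import numpy as np

def batch_deviation_np(x):
  # Per-coordinate standard deviation across the batch axis (axis 0),
  # averaged into a single scalar, mirroring batch_deviation above.
  x_mean = x.mean(axis=0, keepdims=True)
  x_variance = np.mean(np.square(x - x_mean), axis=0, keepdims=True)
  return np.sqrt(x_variance).mean()

x = np.random.randn(8, 4, 4, 16)  # a batch of 8 small feature maps
print(batch_deviation_np(x))      # close to 1.0 for unit-variance input
```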
```diff
@@ -153,14 +164,15 @@ def run_unary(x, name):
       with tf.variable_scope("conv"):
         x = run_unary_modules(conv_modules, x, hparams)
         x.set_shape(x_shape)
-      return x
+      return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)
 
-    cur1, cur2 = inputs, inputs
+    cur1, cur2, extra_loss = inputs, inputs, 0.0
     cur_shape = inputs.get_shape()
     for i in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % i):
-        cur1 = run_unary(cur1, "unary1")
-        cur2 = run_unary(cur2, "unary2")
+        cur1, loss1 = run_unary(cur1, "unary1")
+        cur2, loss2 = run_unary(cur2, "unary2")
+        extra_loss += (loss1 + loss2) / float(hparams.num_hidden_layers)
       with tf.variable_scope("binary1"):
         next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
         next1.set_shape(cur_shape)
@@ -169,7 +181,9 @@ def run_unary(x, name):
         next2.set_shape(cur_shape)
       cur1, cur2 = next1, next2
 
-    return cur1
+    anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
+    extra_loss *= hparams.batch_deviation_loss_factor * anneal
+    return cur1, extra_loss
 
 
 @registry.register_hparams
```
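With the defaults added below, and assuming `inverse_exp_decay`'s default `min_value` of 0.01, the auxiliary term's effective weight is `batch_deviation_loss_factor * anneal`: roughly 1e-5 at step 0, rising to the full 1e-3 at `anneal_until` (40000) steps. The batch-deviation penalty therefore only reaches full strength around the point where module selection has hardened.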
```diff
@@ -185,7 +199,7 @@ def bluenet_base():
   hparams.num_hidden_layers = 8
   hparams.kernel_height = 3
   hparams.kernel_width = 3
-  hparams.learning_rate_decay_scheme = "exp50k"
+  hparams.learning_rate_decay_scheme = "exp10k"
   hparams.learning_rate = 0.05
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer_gain = 1.0
@@ -196,6 +210,8 @@ def bluenet_base():
   hparams.optimizer_adam_beta1 = 0.85
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
+  hparams.add_hparam("anneal_until", 40000)
+  hparams.add_hparam("batch_deviation_loss_factor", 0.001)
   return hparams
 
 
```
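A small usage sketch for experimenting with the two new knobs; `bluenet_base` is the hparams function registered in this file, and plain attribute assignment works once `add_hparam` has defined a key:

```python
from tensor2tensor.models import bluenet

hparams = bluenet.bluenet_base()
hparams.anneal_until = 20000               # harden module selection sooner
hparams.batch_deviation_loss_factor = 0.0  # disable the auxiliary loss
```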