
Commit 09a6084

Niki Parmar authored and Ryan Sepassi committed
Remove redundant copies of res_fn. Use from common_layers
PiperOrigin-RevId: 161451356
1 parent afd1565 commit 09a6084

4 files changed: +107 -29 lines changed


tensor2tensor/models/common_hparams.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,8 @@ def basic_params1():
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"
       multiply_embedding_mode="sqrt_depth",
+      norm_type="none",  # "batch", "layer", "noam", "none".
+      layer_norm_epsilon=1e-6,
       symbol_modality_num_shards=16,
       # setting the max length in a minibatch. 0 means default behavior,
       # max_length = hparams.batch_size * length_multiplier
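Because both new fields land in basic_params1, every model inherits them and switching normalizers becomes a plain attribute override rather than an add_hparam call. A minimal sketch of such an override, assuming the usual t2t hparams flow (the my_params name is hypothetical, not part of this commit):

# Hypothetical hparams set layered on basic_params1; illustration only.
from tensor2tensor.models import common_hparams

def my_params():
  hparams = common_hparams.basic_params1()
  hparams.norm_type = "layer"        # one of "batch", "layer", "noam", "none"
  hparams.layer_norm_epsilon = 1e-6  # only consulted by layer norm
  return hparams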

tensor2tensor/models/common_layers.py

Lines changed: 34 additions & 10 deletions
@@ -433,24 +433,48 @@ def noam_norm(x, name=None):
             tf.sqrt(tf.to_float(shape[-1])))
 
 
-def residual_function(hparams):
+def get_norm(norm_type):
+  """Get the normalizer function."""
+  if norm_type == "layer":
+    return lambda x, name, filters=None, epsilon=1e-6: layer_norm(  # pylint: disable=g-long-lambda
+        x, filters=filters, epsilon=epsilon, name=name)
+  if norm_type == "batch":
+    return tf.layers.batch_normalization
+  if norm_type == "noam":
+    return noam_norm
+  if norm_type == "none":
+    return lambda x, name: x
+  raise ValueError("Parameter normalizer_fn must be one of: 'layer', 'batch',"
+                   "'noam', 'none'.")
+
+
+def residual_fn(x, y, norm_type, residual_dropout,
+                filters=None,
+                epsilon=1e-16,
+                name="residual"):
   """Returns a function for combining layer input and layer output.
 
   The returned function on x (layer input) and y (layer output) computes:
-    norm_function(x + t
+    norm_function(x + dropout(y))
 
   Args:
-    hparams: model hyperparameters
+    x: tensor, input layer
+    y: tensor, output layer
+    norm_type: string, type of normalizer function
+    residual_dropout: integer, dropout value for residual connection
+    filters: integer, dimension for layer norm, optional
+    epsilon: integer, value of layer norm epsilon
+    name: string, name
 
   Returns:
-    a function from x=<layer input> and y=<layer output> to computed output
+    residual layer output with applied norm_fn.
   """
-
-  def residual_fn(x, y):
-    return hparams.norm_function(x + tf.nn.dropout(
-        y, 1.0 - hparams.residual_dropout))
-
-  return residual_fn
+  norm_fn = get_norm(norm_type)
+  res = x + tf.nn.dropout(y, 1.0 - residual_dropout)
+  if norm_type == "layer":
+    return norm_fn(res, name=name, filters=filters, epsilon=epsilon)
+  else:
+    return norm_fn(res, name=name)
 
 
 def conv_block_internal(conv_fn,
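After the refactor, a caller passes the layer input, the layer output, and the normalization settings straight to common_layers.residual_fn instead of building a closure from hparams. A rough usage sketch under the new signature; the shapes, the 0.1 dropout value, and the scope name are illustrative, not from the commit:

import tensorflow as tf
from tensor2tensor.models import common_layers

x = tf.random_normal([8, 20, 1, 512])  # layer input
y = tf.random_normal([8, 20, 1, 512])  # layer output, e.g. from a conv block

# "layer" norm takes filters (and epsilon); pass the channel dimension.
out = common_layers.residual_fn(x, y, "layer", 0.1, filters=512, name="res")

# "batch", "noam" and "none" only take a name; filters/epsilon are not used.
out_noam = common_layers.residual_fn(x, y, "noam", 0.1)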

tensor2tensor/models/common_layers_test.py

Lines changed: 67 additions & 0 deletions
@@ -294,6 +294,73 @@ def testDeconvStride2MultiStep(self):
       actual = session.run(a)
     self.assertEqual(actual.shape, (5, 32, 1, 16))
 
+  def testGetNormLayerFn(self):
+    norm_type = "layer"
+    with self.test_session() as session:
+      a = common_layers.get_norm(norm_type)
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = a(tf.constant(x1, dtype=tf.float32), name="layer", filters=11)
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x2)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+
+  def testGetNormNoamFn(self):
+    norm_type = "noam"
+    with self.test_session() as session:
+      a = common_layers.get_norm(norm_type)
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = a(tf.constant(x1, dtype=tf.float32), name="noam")
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x2)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+
+  def testGetNormBatchFn(self):
+    norm_type = "batch"
+    with self.test_session() as session:
+      a = common_layers.get_norm(norm_type)
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = a(tf.constant(x1, dtype=tf.float32), name="batch")
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x2)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+
+  def testGetNormNoneFn(self):
+    norm_type = "none"
+    with self.test_session() as session:
+      a = common_layers.get_norm(norm_type)
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = a(tf.constant(x1, dtype=tf.float32), name="none")
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x2)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+    self.assertAllClose(actual, x1, atol=1e-03)
+
+  def testResidualFn(self):
+    norm_type = "batch"
+    with self.test_session() as session:
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = np.random.rand(5, 2, 1, 11)
+      x3 = common_layers.residual_fn(
+          tf.constant(x1, dtype=tf.float32),
+          tf.constant(x2, dtype=tf.float32),
+          norm_type, 0.1)
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x3)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+
+  def testResidualFnWithLayerNorm(self):
+    norm_type = "layer"
+    with self.test_session() as session:
+      x1 = np.random.rand(5, 2, 1, 11)
+      x2 = np.random.rand(5, 2, 1, 11)
+      x3 = common_layers.residual_fn(
+          tf.constant(x1, dtype=tf.float32),
+          tf.constant(x2, dtype=tf.float32),
+          norm_type, 0.1, epsilon=0.1)
+      session.run(tf.global_variables_initializer())
+      actual = session.run(x3)
+    self.assertEqual(actual.shape, (5, 2, 1, 11))
+
 
 if __name__ == "__main__":
   tf.test.main()
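The new tests cover each supported normalizer plus both residual_fn paths; an unrecognized norm_type falls through to the ValueError in get_norm (whose message still refers to the old normalizer_fn name). A hypothetical extra test for that error path, not part of this commit, might look like:

  def testGetNormUnknownFnRaises(self):
    # "group" is not one of "batch", "layer", "noam", "none".
    with self.assertRaises(ValueError):
      common_layers.get_norm("group")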

tensor2tensor/models/slicenet.py

Lines changed: 4 additions & 19 deletions
@@ -31,21 +31,6 @@
 import tensorflow as tf
 
 
-def get_norm(hparams):
-  """Get the normalizer function."""
-  if hparams.normalizer_fn == "layer":
-    return lambda x, name: common_layers.layer_norm(  # pylint: disable=g-long-lambda
-        x, hparams.hidden_size, name=name)
-  if hparams.normalizer_fn == "batch":
-    return tf.layers.batch_normalization
-  if hparams.normalizer_fn == "noam":
-    return common_layers.noam_norm
-  if hparams.normalizer_fn == "none":
-    return lambda x, name: x
-  raise ValueError("Parameter normalizer_fn must be one of: 'layer', 'batch',"
-                   "'noam', 'none'.")
-
-
 def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
   """Complete attention layer with preprocessing."""
   separabilities = [hparams.separability, hparams.separability]
@@ -128,7 +113,7 @@ def multi_conv_res(x, padding, name, layers, hparams,
         hparams.separability - i
         for i in reversed(range(len(dilations_and_kernels2)))
     ]
-    norm_fn = get_norm(hparams)
+    norm_fn = common_layers.get_norm(hparams.norm_type)
     for layer in xrange(layers):
       with tf.variable_scope("layer_%d" % layer):
         y = common_layers.subseparable_conv_block(
@@ -188,7 +173,7 @@ def similarity_cost(inputs_encoded, targets_encoded):
 
 def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
   """Middle part of slicenet, connecting encoder and decoder."""
-  norm_fn = get_norm(hparams)
+  norm_fn = common_layers.get_norm(hparams.norm_type)
 
   # Flatten targets and embed target_space_id.
   targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
@@ -311,7 +296,7 @@ def slicenet_params1():
   hparams.num_hidden_layers = 4
   hparams.kernel_height = 3
   hparams.kernel_width = 1
-  hparams.add_hparam("normalizer_fn", "layer")  # New ones are added like this.
+  hparams.norm_type = "layer"
   hparams.learning_rate_decay_scheme = "exp50k"
   hparams.learning_rate = 0.05
   hparams.learning_rate_warmup_steps = 3000
@@ -322,7 +307,7 @@ def slicenet_params1():
   hparams.optimizer_adam_epsilon = 1e-6
   hparams.optimizer_adam_beta1 = 0.85
   hparams.optimizer_adam_beta2 = 0.997
-  hparams.add_hparam("large_kernel_size", 15)
+  hparams.add_hparam("large_kernel_size", 15)  # New ones are added like this.
   hparams.add_hparam("separability", -2)
   # A dilation scheme, one of _DILATION_SCHEMES.
   hparams.add_hparam("dilation_scheme", "1.1.1.1")
