
Commit 72f0874

Lukasz Kaiser and Ryan Sepassi authored and committed
internal merge
PiperOrigin-RevId: 161608262
1 parent 912daf7 commit 72f0874

File tree

4 files changed: +361 -1 lines changed

    tensor2tensor/models/common_layers.py
    tensor2tensor/models/common_layers_test.py
    tensor2tensor/models/models.py
    tensor2tensor/models/transformer_alternative.py

tensor2tensor/models/common_layers.py

Lines changed: 126 additions & 1 deletion
@@ -292,7 +292,8 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
   """Conditional conv_fn making kernel 1d or 2d depending on inputs shape."""
   static_shape = inputs.get_shape()
   if not static_shape or len(static_shape) != 4:
-    raise ValueError("Inputs to conv must have statically known rank 4.")
+    raise ValueError("Inputs to conv must have statically known rank 4. "
+                     "Shape: " + str(static_shape))
   # Add support for left padding.
   if "padding" in kwargs and kwargs["padding"] == "LEFT":
     dilation_rate = (1, 1)
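
As a quick illustration (not part of the commit), the added shape makes the failure self-describing; a sketch assuming TF 1.x, with a deliberately rank-3 input:

    import tensorflow as tf
    from tensor2tensor.models import common_layers

    try:
      # conv expects a statically known rank-4 input; this one is rank 3.
      common_layers.conv(tf.zeros([5, 7, 11]), 13, (3, 3))
    except ValueError as e:
      print(e)
      # -> Inputs to conv must have statically known rank 4. Shape: (5, 7, 11)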
@@ -1402,3 +1403,127 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
   xentropy = tf.nn.softmax_cross_entropy_with_logits(
       logits=logits, labels=soft_targets)
   return xentropy - normalizing
+
+
+def global_pool_1d(inputs, pooling_type="MAX", mask=None):
+  """Pool elements across the sequence dimension.
+
+  Useful to convert a list of vectors into a single vector so as
+  to get a representation of a set.
+
+  Args:
+    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
+      containing the sequences of input vectors.
+    pooling_type: the pooling type to use, MAX or AVR.
+    mask: A tensor of dimensions batch_size x sequence_length containing a
+      mask for the inputs with 1's for existing elements, and 0's elsewhere.
+
+  Returns:
+    output: A tensor of dimensions batch_size x input_dims containing the
+      pooled representation of each sequence.
+  """
+  with tf.name_scope("global_pool", [inputs]):
+    if mask is not None:
+      mask = tf.expand_dims(mask, axis=2)
+      inputs = tf.multiply(inputs, mask)
+
+    if pooling_type == "MAX":
+      # A tf.pool can be used here, but reduce is cleaner.
+      output = tf.reduce_max(inputs, axis=1)
+    elif pooling_type == "AVR":
+      if mask is not None:
+        # Some elements are dummy elements, so we can't just take the mean.
+        output = tf.reduce_sum(inputs, axis=1)
+        num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True)
+        output = tf.div(output, tf.maximum(num_elems, 1))
+      else:
+        output = tf.reduce_mean(inputs, axis=1)
+
+  return output
+
+
+def linear_set_layer(layer_size,
+                     inputs,
+                     context=None,
+                     activation_fn=tf.nn.relu,
+                     dropout=0.0,
+                     name=None):
+  """Basic layer type for doing funky things with sets.
+
+  Applies a linear transformation to each element in the input set.
+  If a context is supplied, it is transformed to layer_size and added to
+  each element, which has the same effect as concatenating it with every
+  input before the linear transformation.
+  E.g. one can use global_pool_1d to get a representation of the set,
+  which can then be used as the context for the next layer.
+
+  TODO: Add bias add (or control the biases used).
+
+  Args:
+    layer_size: Dimension to transform the input vectors to.
+    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
+      containing the sequences of input vectors.
+    context: A tensor of dimensions batch_size x context_dims
+      containing a global statistic about the set.
+    activation_fn: The activation function to use.
+    dropout: Dropout probability.
+    name: name.
+
+  Returns:
+    output: A tensor of dimensions batch_size x sequence_length x layer_size
+      containing the sequences of transformed vectors.
+  """
+  with tf.variable_scope(name, "linear_set_layer", [inputs]):
+    # Apply a 1D convolution with kernel size 1, i.e. the same linear filter
+    # applied to each element along the sequence dimension.
+    outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv")
+
+    # Apply the context if it exists.
+    if context is not None:
+      # Unfortunately tf doesn't support broadcasting via concat, but we can
+      # simply add the transformed context to get the same effect.
+      context = tf.expand_dims(context, axis=1)
+      cont_tfm = conv1d(context, layer_size, 1,
+                        activation=None, name="cont_conv")
+      outputs += cont_tfm
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+
+    if dropout != 0.0:
+      outputs = tf.nn.dropout(outputs, 1.0 - dropout)
+
+    return outputs
+
+
+def ravanbakhsh_set_layer(layer_size,
+                          inputs,
+                          mask=None,
+                          activation_fn=tf.nn.tanh,
+                          dropout=0.0,
+                          name=None):
+  """Layer from Ravanbakhsh et al. (https://arxiv.org/abs/1611.04500).
+
+  More parameter-efficient version of a linear-set-layer with context.
+
+  Args:
+    layer_size: Dimension to transform the input vectors to.
+    inputs: A tensor of dimensions batch_size x sequence_length x input_dims
+      containing the sequences of input vectors.
+    mask: A tensor of dimensions batch_size x sequence_length containing a
+      mask for the inputs with 1's for existing elements, and 0's elsewhere.
+    activation_fn: The activation function to use.
+    dropout: Dropout probability.
+    name: name.
+
+  Returns:
+    output: A tensor of dimensions batch_size x sequence_length x layer_size
+      containing the sequences of transformed vectors.
+  """
+  with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]):
+    output = linear_set_layer(
+        layer_size,
+        inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1),
+        activation_fn=activation_fn,
+        dropout=dropout,
+        name=name)

  return output
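
Taken together, the three new functions support a simple set-encoding pattern: embed each element, pool the set into a context vector, then re-embed each element conditioned on that context. A minimal sketch of the intended composition (shapes and variable names are illustrative; assumes TF 1.x with these functions in common_layers):

    import tensorflow as tf
    from tensor2tensor.models import common_layers

    x = tf.random_normal([8, 10, 64])  # batch_size x sequence_length x dims
    mask = tf.ones([8, 10])            # 1's for real elements, 0's for padding

    # Embed each element, then pool the set into a single context vector.
    h = common_layers.linear_set_layer(128, x, name="embed")
    context = common_layers.global_pool_1d(h, pooling_type="AVR", mask=mask)

    # Re-embed each element conditioned on the pooled set representation.
    out = common_layers.linear_set_layer(128, h, context=context,
                                         name="reembed")

    # The equivariant alternative: each element i becomes
    # activation_fn(W * (x_i - maxpool(x))), with a single weight matrix.
    out2 = common_layers.ravanbakhsh_set_layer(128, x, mask=mask)

Because the pooled term is subtracted before a single shared linear map, ravanbakhsh_set_layer uses one weight matrix where linear_set_layer with a context uses two (set_conv plus cont_conv), which is the parameter saving its docstring refers to.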

tensor2tensor/models/common_layers_test.py

Lines changed: 60 additions & 0 deletions
@@ -82,6 +82,14 @@ def testConv(self):
       res = session.run(y)
     self.assertEqual(res.shape, (5, 5, 1, 13))
 
+  def testConv1d(self):
+    x = np.random.rand(5, 7, 11)
+    with self.test_session() as session:
+      y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
+      session.run(tf.global_variables_initializer())
+      res = session.run(y)
+    self.assertEqual(res.shape, (5, 7, 13))
+
   def testSeparableConv(self):
     x = np.random.rand(5, 7, 1, 11)
     with self.test_session() as session:
@@ -361,6 +369,58 @@ def testResidualFnWithLayerNorm(self):
       actual = session.run(x3)
     self.assertEqual(actual.shape, (5, 2, 1, 11))
 
+  def testGlobalPool1d(self):
+    x1 = np.random.rand(5, 4, 11)
+    no_mask = np.ones((5, 4))
+    full_mask = np.zeros((5, 4))
+
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+      no_mask_ = tf.Variable(no_mask, dtype=tf.float32)
+      full_mask_ = tf.Variable(full_mask, dtype=tf.float32)
+
+      none_mask_max = common_layers.global_pool_1d(x1_)
+      no_mask_max = common_layers.global_pool_1d(x1_, mask=no_mask_)
+      result1 = tf.reduce_sum(none_mask_max - no_mask_max)
+
+      full_mask_max = common_layers.global_pool_1d(x1_, mask=full_mask_)
+      result2 = tf.reduce_sum(full_mask_max)
+
+      none_mask_avr = common_layers.global_pool_1d(x1_, "AVR")
+      no_mask_avr = common_layers.global_pool_1d(x1_, "AVR", no_mask_)
+      result3 = tf.reduce_sum(none_mask_avr - no_mask_avr)
+
+      full_mask_avr = common_layers.global_pool_1d(x1_, "AVR", full_mask_)
+      result4 = tf.reduce_sum(full_mask_avr)
+
+      session.run(tf.global_variables_initializer())
+      actual = session.run([result1, result2, result3, result4])
+    self.assertAllEqual(actual[:3], [0.0, 0.0, 0.0])
+
+  def testLinearSetLayer(self):
+    x1 = np.random.rand(5, 4, 11)
+    cont = np.random.rand(5, 13)
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+      cont_ = tf.Variable(cont, dtype=tf.float32)
+
+      simple_ff = common_layers.linear_set_layer(32, x1_)
+      cont_ff = common_layers.linear_set_layer(32, x1_, context=cont_)
+
+      session.run(tf.global_variables_initializer())
+      actual = session.run([simple_ff, cont_ff])
+    self.assertEqual(actual[0].shape, (5, 4, 32))
+    self.assertEqual(actual[1].shape, (5, 4, 32))
+
+  def testRavanbakhshSetLayer(self):
+    x1 = np.random.rand(5, 4, 11)
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+      layer = common_layers.ravanbakhsh_set_layer(32, x1_)
+      session.run(tf.global_variables_initializer())
+      actual = session.run(layer)
+    self.assertEqual(actual.shape, (5, 4, 32))
+
 
 if __name__ == "__main__":
   tf.test.main()

tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -32,5 +32,6 @@
 from tensor2tensor.models import neural_gpu
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
+from tensor2tensor.models import transformer_alternative
 from tensor2tensor.models import xception
 # pylint: enable=unused-import
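
This one-line import is what exposes the new model: importing transformer_alternative executes its @registry.register_model and @registry.register_hparams decorators. A sketch of the lookup this enables (assuming the registry snake-cases class names, so TransformerAlt registers as "transformer_alt"):

    # Importing the models module triggers registration of every model file.
    from tensor2tensor.models import models  # pylint: disable=unused-import
    from tensor2tensor.utils import registry

    model_cls = registry.model("transformer_alt")
    hparams_fn = registry.hparams("transformer_alt")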
tensor2tensor/models/transformer_alternative.py

Lines changed: 174 additions & 0 deletions (new file)
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Alternative transformer network.

Uses different layer types to demonstrate alternatives to self-attention.

Code is mostly copied from the original Transformer source.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.models import common_attention
from tensor2tensor.models import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


@registry.register_model
class TransformerAlt(t2t_model.T2TModel):

  def model_fn_body(self, features):
    hparams = self._hparams
    targets = features["targets"]
    inputs = features.get("inputs")
    target_space = features.get("target_space_id")

    inputs = common_layers.flatten4d3d(inputs)
    targets = common_layers.flatten4d3d(targets)

    (encoder_input, encoder_attention_bias,
     _) = transformer.transformer_prepare_encoder(inputs, target_space, hparams)
    (decoder_input,
     decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
         targets, hparams)

    # We need masks of the form batch_size x input_sequence_length.
    # Biases seem to be of the form batch_size x 1 x input_sequence_length
    # x vec_dim, so squeeze out dimension one and take the first element
    # of each vector.
    encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:, :, 0]
    decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:, :, 0]

    def residual_fn(x, y):
      return common_layers.layer_norm(x + tf.nn.dropout(
          y, 1.0 - hparams.residual_dropout))

    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
    decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
    encoder_output = alt_transformer_encoder(
        encoder_input, residual_fn, encoder_mask, hparams)

    decoder_output = alt_transformer_decoder(
        decoder_input, encoder_output, residual_fn, decoder_mask,
        encoder_attention_bias, hparams)

    decoder_output = tf.expand_dims(decoder_output, 2)

    return decoder_output


def composite_layer(inputs, mask, hparams):
  """Composite layer."""
  x = inputs

  # Stack ravanbakhsh set layers on top of each other.
  if hparams.composite_layer_type == "ravanbakhsh":
    for layer in xrange(hparams.layers_per_layer):
      with tf.variable_scope(".%d" % layer):
        x = common_layers.ravanbakhsh_set_layer(
            hparams.hidden_size,
            x,
            mask=mask,
            dropout=0.0)

  # Transform elements to get a context, then use this context in a final
  # layer.
  elif hparams.composite_layer_type == "reembedding":
    # Transform elements n times and then pool.
    for layer in xrange(hparams.layers_per_layer):
      with tf.variable_scope(".%d" % layer):
        x = common_layers.linear_set_layer(
            hparams.hidden_size,
            x,
            dropout=0.0)
    context = common_layers.global_pool_1d(x, mask=mask)

    # Final layer.
    x = common_layers.linear_set_layer(
        hparams.hidden_size,
        x,
        context=context,
        dropout=0.0)

  return x


def alt_transformer_encoder(encoder_input,
                            residual_fn,
                            mask,
                            hparams,
                            name="encoder"):
  """Alternative encoder."""
  x = encoder_input

  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        x = residual_fn(x, composite_layer(x, mask, hparams))

  return x


def alt_transformer_decoder(decoder_input,
                            encoder_output,
                            residual_fn,
                            mask,
                            encoder_decoder_attention_bias,
                            hparams,
                            name="decoder"):
  """Alternative decoder."""
  x = decoder_input

  # Summaries don't work in the multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        x_ = common_attention.multihead_attention(
            x,
            encoder_output,
            encoder_decoder_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            summaries=summaries,
            name="encdec_attention")

        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
        x = residual_fn(x, x_)

  return x


@registry.register_hparams
def transformer_alt():
  """Set of hyperparameters."""
  hparams = transformer.transformer_base()
  hparams.batch_size = 64
  hparams.add_hparam("layers_per_layer", 4)
  hparams.add_hparam("composite_layer_type", "reembedding")
  return hparams
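
A sketch of how the two new hyperparameters drive composite_layer above (the trainer flag names are given for orientation and should be treated as assumptions):

    # Select the model and hparams set by their registered names, e.g. with
    # --model=transformer_alt --hparams_set=transformer_alt on the t2t trainer.
    hparams = transformer_alt()
    hparams.composite_layer_type = "ravanbakhsh"  # default is "reembedding"

    # Every encoder/decoder block now calls composite_layer, which stacks
    # hparams.layers_per_layer (here 4) ravanbakhsh_set_layers inside a
    # single residual connection, in place of self-attention.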
