Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d827bb2

Browse files
authored
Merge pull request #133 from EndingCredits/master
Add new set network layer type and example model
2 parents cf76c73 + bcf9a8b commit d827bb2

File tree

4 files changed

+387
-3
lines changed

4 files changed

+387
-3
lines changed

tensor2tensor/models/common_layers.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
292292
"""Conditional conv_fn making kernel 1d or 2d depending on inputs shape."""
293293
static_shape = inputs.get_shape()
294294
if not static_shape or len(static_shape) != 4:
295-
raise ValueError("Inputs to conv must have statically known rank 4.")
295+
raise ValueError("Inputs to conv must have statically known rank 4. Shape:" +str(static_shape))
296296
# Add support for left padding.
297297
if "padding" in kwargs and kwargs["padding"] == "LEFT":
298298
dilation_rate = (1, 1)
@@ -1378,3 +1378,128 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
13781378
xentropy = tf.nn.softmax_cross_entropy_with_logits(
13791379
logits=logits, labels=soft_targets)
13801380
return xentropy - normalizing
1381+
1382+
1383+
def global_pool_1d(inputs, pooling_type='MAX', mask=None):
1384+
"""
1385+
Pools elements across the last dimension. Useful to a list of vectors into a
1386+
single vector to get a representation of a set.
1387+
1388+
Args
1389+
inputs: A tensor of dimensions batch_size x sequence_length x input_dims
1390+
containing the sequences of input vectors.
1391+
pooling_type: the pooling type to use, MAX or AVR
1392+
mask: A tensor of dimensions batch_size x sequence_length containing a
1393+
mask for the inputs with 1's for existing elements, and 0's elsewhere.
1394+
Outputs
1395+
output: A tensor of dimensions batch_size x input_dims
1396+
dimension containing the sequences of transformed vectors.
1397+
"""
1398+
1399+
with tf.name_scope("global_pool", [inputs]):
1400+
if mask is not None:
1401+
mask = tf.expand_dims(mask, axis=2)
1402+
inputs = tf.multiply(inputs, mask)
1403+
1404+
if pooling_type == 'MAX':
1405+
# A tf.pool can be used here, but reduce is cleaner
1406+
output = tf.reduce_max(inputs, axis=1)
1407+
elif pooling_type == 'AVR':
1408+
if mask is not None:
1409+
# Some elems are dummy elems so we can't just reduce the average
1410+
output = tf.reduce_sum(inputs, axis=1)
1411+
num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True)
1412+
output = tf.div(output, num_elems)
1413+
#N.B: this will cause a NaN if one batch contains no elements
1414+
else:
1415+
output = tf.reduce_mean(inputs, axis=1)
1416+
1417+
return output
1418+
1419+
1420+
def linear_set_layer(layer_size,
1421+
inputs,
1422+
context=None,
1423+
activation_fn=tf.nn.relu,
1424+
dropout=0.0,
1425+
name=None):
1426+
"""
1427+
Basic layer type for doing funky things with sets.
1428+
Applies a linear transformation to each element in the input set.
1429+
If a context is supplied, it is concatenated with the inputs.
1430+
e.g. One can use global_pool_1d to get a representation of the set which
1431+
can then be used as the context for the next layer.
1432+
1433+
Args
1434+
layer_size: Dimension to transform the input vectors to
1435+
inputs: A tensor of dimensions batch_size x sequence_length x input_dims
1436+
containing the sequences of input vectors.
1437+
context: A tensor of dimensions batch_size x context_dims
1438+
containing a global statistic about the set.
1439+
dropout: Dropout probability.
1440+
activation_fn: The activation function to use.
1441+
Outputs
1442+
output: A tensor of dimensions batch_size x sequence_length x output_dims
1443+
dimension containing the sequences of transformed vectors.
1444+
1445+
TODO: Add bias add.
1446+
"""
1447+
1448+
with tf.variable_scope(name, "linear_set_layer", [inputs]):
1449+
# Apply 1D convolution to apply linear filter to each element along the 2nd
1450+
# dimension
1451+
#in_size = inputs.get_shape().as_list()[-1]
1452+
outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv")
1453+
1454+
# Apply the context if it exists
1455+
if context is not None:
1456+
# Unfortunately tf doesn't support broadcasting via concat, but we can
1457+
# simply add the transformed context to get the same effect
1458+
context = tf.expand_dims(context, axis=1)
1459+
#context_size = context.get_shape().as_list()[-1]
1460+
cont_tfm = conv1d(context, layer_size, 1,
1461+
activation=None, name="cont_conv")
1462+
outputs += cont_tfm
1463+
1464+
if activation_fn is not None:
1465+
outputs = activation_fn(outputs)
1466+
1467+
if dropout != 0.0:
1468+
output = tf.nn.dropout(output, 1.0 - dropout)
1469+
1470+
return outputs
1471+
1472+
1473+
def ravanbakhsh_set_layer(layer_size,
1474+
inputs,
1475+
mask=None,
1476+
activation_fn=tf.nn.tanh,
1477+
dropout=0.0,
1478+
name=None):
1479+
"""
1480+
Layer from Deep Sets paper: https://arxiv.org/abs/1611.04500
1481+
More parameter-efficient verstion of a linear-set-layer with context.
1482+
1483+
1484+
Args
1485+
layer_size: Dimension to transform the input vectors to.
1486+
inputs: A tensor of dimensions batch_size x sequence_length x vector
1487+
containing the sequences of input vectors.
1488+
mask: A tensor of dimensions batch_size x sequence_length containing a
1489+
mask for the inputs with 1's for existing elements, and 0's elsewhere.
1490+
activation_fn: The activation function to use.
1491+
Outputs
1492+
output: A tensor of dimensions batch_size x sequence_length x vector
1493+
dimension containing the sequences of transformed vectors.
1494+
"""
1495+
1496+
with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]):
1497+
output = linear_set_layer(
1498+
layer_size,
1499+
inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1),
1500+
activation_fn=activation_fn,
1501+
name=name)
1502+
1503+
return output
1504+
1505+

tensor2tensor/models/common_layers_test.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def testSaturatingSigmoid(self):
5050
self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
5151

5252
def testFlatten4D3D(self):
53-
x = np.random.random_integers(1, high=8, size=(3, 5, 2))
53+
x = np.random.randint(1, 9, size=(3, 5, 2))
5454
with self.test_session() as session:
5555
y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7))
5656
session.run(tf.global_variables_initializer())
5757
res = session.run(y)
5858
self.assertEqual(res.shape, (3, 5 * 2, 7))
5959

6060
def testEmbedding(self):
61-
x = np.random.random_integers(1, high=8, size=(3, 5))
61+
x = np.random.randint(1, 9, size=(3, 5))
6262
with self.test_session() as session:
6363
y = common_layers.embedding(x, 10, 16)
6464
session.run(tf.global_variables_initializer())
@@ -81,6 +81,14 @@ def testConv(self):
8181
session.run(tf.global_variables_initializer())
8282
res = session.run(y)
8383
self.assertEqual(res.shape, (5, 5, 1, 13))
84+
85+
def testConv1d(self):
86+
x = np.random.rand(5, 7, 11)
87+
with self.test_session() as session:
88+
y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
89+
session.run(tf.global_variables_initializer())
90+
res = session.run(y)
91+
self.assertEqual(res.shape, (5, 7, 13))
8492

8593
def testSeparableConv(self):
8694
x = np.random.rand(5, 7, 1, 11)
@@ -293,6 +301,66 @@ def testDeconvStride2MultiStep(self):
293301
session.run(tf.global_variables_initializer())
294302
actual = session.run(a)
295303
self.assertEqual(actual.shape, (5, 32, 1, 16))
304+
305+
def testGlobalPool1d(self):
306+
shape = (5, 4)
307+
x1 = np.random.rand(5,4,11)
308+
#mask = np.random.randint(2, size=shape)
309+
no_mask = np.ones((5,4))
310+
full_mask = np.zeros((5,4))
311+
312+
with self.test_session() as session:
313+
x1_ = tf.Variable(x1, dtype=tf.float32)
314+
no_mask_ = tf.Variable(no_mask, dtype=tf.float32)
315+
full_mask_ = tf.Variable(full_mask, dtype=tf.float32)
316+
317+
none_mask_max = common_layers.global_pool_1d(x1_)
318+
no_mask_max = common_layers.global_pool_1d(x1_, mask=no_mask_)
319+
result1 = tf.reduce_sum(none_mask_max - no_mask_max)
320+
321+
full_mask_max = common_layers.global_pool_1d(x1_, mask=full_mask_)
322+
result2 = tf.reduce_sum(full_mask_max)
323+
324+
none_mask_avr = common_layers.global_pool_1d(x1_, 'AVR')
325+
no_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', no_mask_)
326+
result3 = tf.reduce_sum(none_mask_avr - no_mask_avr)
327+
328+
full_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', full_mask_)
329+
result4 = tf.reduce_sum(full_mask_avr)
330+
331+
session.run(tf.global_variables_initializer())
332+
actual = session.run([result1, result2, result3, result4])
333+
# N.B: Last result will give a NaN.
334+
self.assertAllEqual(actual[:3], [0.0, 0.0, 0.0])
335+
336+
337+
def testLinearSetLayer(self):
338+
x1 = np.random.rand(5,4,11)
339+
cont = np.random.rand(5,13)
340+
with self.test_session() as session:
341+
x1_ = tf.Variable(x1, dtype=tf.float32)
342+
cont_ = tf.Variable(cont, dtype=tf.float32)
343+
344+
simple_ff = common_layers.linear_set_layer(32, x1_)
345+
cont_ff = common_layers.linear_set_layer(32, x1_, context=cont_)
346+
347+
session.run(tf.global_variables_initializer())
348+
actual = session.run([simple_ff, cont_ff])
349+
self.assertEqual(actual[0].shape, (5,4,32))
350+
self.assertEqual(actual[1].shape, (5,4,32))
351+
352+
def testRavanbakhshSetLayer(self):
353+
x1 = np.random.rand(5,4,11)
354+
cont = np.random.rand(5,13)
355+
with self.test_session() as session:
356+
x1_ = tf.Variable(x1, dtype=tf.float32)
357+
cont_ = tf.Variable(cont, dtype=tf.float32)
358+
359+
layer = common_layers.ravanbakhsh_set_layer(32, x1_)
360+
361+
session.run(tf.global_variables_initializer())
362+
actual = session.run(layer)
363+
self.assertEqual(actual.shape, (5,4,32))
296364

297365

298366
if __name__ == "__main__":

0 commit comments

Comments
 (0)