This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9c6402b

T2T Team authored and Copybara-Service committed
Replace one_hot + matmul with tf.gather on R1 indices for faster gather operation.
PiperOrigin-RevId: 221513887
1 parent 7ae4c7f · commit 9c6402b


tensor2tensor/layers/common_layers.py

Lines changed: 3 additions & 6 deletions
@@ -272,13 +272,10 @@ def flatten4d3d(x):
 
 
 # TODO(noam): remove this function after TPUs do gather faster.
-def gather(params, indices, dtype=tf.float32):
+def gather(params, indices):
   """Version of tf.gather that works faster on tpu."""
-  if not is_xla_compiled():
-    return tf.gather(params, indices)
-  vocab_size = params.get_shape().as_list()[0]
   indices_flat = tf.reshape(indices, [-1])
-  out = tf.matmul(tf.one_hot(indices_flat, vocab_size, dtype=dtype), params)
+  out = tf.gather(params, indices_flat)
   out = reshape_like(out, tf.expand_dims(indices, -1))
   return out
 

@@ -352,7 +349,7 @@ def embedding(x,
     if not tf.contrib.eager.in_eager_mode():
       embedding_var = convert_gradient_to_tensor(embedding_var)
     x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
-    emb_x = gather(embedding_var, x, dtype)
+    emb_x = gather(embedding_var, x)
     if multiplier != 1.0:
       emb_x *= multiplier
     static_shape = emb_x.shape.as_list()
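
For context on the change itself: multiplying a one-hot matrix against params selects exactly the rows that tf.gather looks up directly, but the matmul first materializes a [num_indices, vocab_size] one-hot matrix. A minimal TF 1.x sketch, not part of the commit, illustrating the equivalence (the example tensors are made up, and a plain tf.reshape stands in for the module's reshape_like helper):

import tensorflow as tf  # TF 1.x graph mode, matching this codebase

params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # vocab_size=3, depth=2
indices = tf.constant([[2, 0], [1, 1]])               # rank-2 integer indices
indices_flat = tf.reshape(indices, [-1])              # flatten to rank-1 (R1)

# Old path: one-hot matrix times params selects the indexed rows.
one_hot_out = tf.matmul(tf.one_hot(indices_flat, 3, dtype=tf.float32), params)

# New path: gather the rows directly from the R1 indices.
gather_out = tf.gather(params, indices_flat)

# reshape_like(out, tf.expand_dims(indices, -1)) restores the index shape
# plus the embedding depth; an explicit tf.reshape stands in for it here.
restored = tf.reshape(gather_out, [2, 2, 2])

with tf.Session() as sess:
  a, b = sess.run([one_hot_out, gather_out])
  print((a == b).all())             # True: both select rows 2, 0, 1, 1
  print(sess.run(restored).shape)   # (2, 2, 2)

Since both paths produce identical values, the change is purely a performance optimization: gathering rows directly avoids the O(vocab_size) work per index that the one-hot matmul performs.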
