This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 86703a2

T2T Team authored and Ryan Sepassi committed
Fast beam search decoding.
PiperOrigin-RevId: 173594699
1 parent 2245033 commit 86703a2

5 files changed: +225 additions, -38 deletions

tensor2tensor/models/transformer.py

Lines changed: 107 additions & 26 deletions
@@ -30,12 +30,15 @@
 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_hparams
 from tensor2tensor.layers import common_layers
+from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model
 
 import tensorflow as tf
 
+from tensorflow.python.util import nest
+
 
 @registry.register_model
 class Transformer(t2t_model.T2TModel):
@@ -159,6 +162,58 @@ def _greedy_infer(
       logits: Not returned
       losses: Not returned
 
+    Raises:
+      ValueError: If last_position_only is False
+      NotImplementedError: If there are multiple data shards.
+    """
+    decoded_ids = self._fast_decode(features, decode_length, last_position_only)
+    return decoded_ids, None, None
+
+  def _beam_decode(self, features, decode_length, beam_size, top_beams,
+                   last_position_only, alpha):
+    """Beam search decoding.
+
+    Args:
+      features: a map of string to `Tensor`
+      decode_length: an integer. How many additional timesteps to decode.
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      last_position_only: MUST be true for fast decoding!
+      alpha: Float that controls the length penalty. larger the alpha, stronger
+        the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+    """
+    return self._fast_decode(
+        features, decode_length, last_position_only, beam_size, top_beams,
+        alpha)
+
+  def _fast_decode(
+      self,
+      features,
+      decode_length,
+      last_position_only=True,
+      beam_size=1,
+      top_beams=1,
+      alpha=1.0):
+    """Fast decoding.
+
+    Implements both greedy and beam search decoding, uses beam search iff
+    beam_size > 1, otherwise beam search related arguments are ignored.
+
+    Args:
+      features: a map of string to model features.
+      decode_length: an integer. How many additional timesteps to decode.
+      last_position_only: MUST be true for fast decoding!
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      alpha: Float that controls the length penalty. larger the alpha, stronger
+        the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+
     Raises:
       ValueError: If last_position_only is False
       NotImplementedError: If there are multiple data shards.
@@ -192,6 +247,8 @@ def _greedy_infer(
     with tf.variable_scope("body"):
       encoder_output, encoder_decoder_attention_bias = dp(
           self.encode, inputs, features["target_space_id"], hparams)
+    encoder_output = encoder_output[0]
+    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
 
     if hparams.pos == "timing":
       timing_signal = common_attention.get_timing_signal_1d(
@@ -236,6 +293,7 @@ def preprocess_targets(targets, i):
 
     def symbols_to_logits_fn(ids, i, cache):
       """Go from ids to logits for next symbol."""
+      ids = ids[:, -1:]
       targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
       targets = preprocess_targets(targets, i)
 
@@ -245,22 +303,16 @@ def symbols_to_logits_fn(ids, i, cache):
       body_outputs = dp(
           self.decode,
           targets,
-          encoder_output[0],
-          encoder_decoder_attention_bias[0],
+          cache["encoder_output"],
+          cache["encoder_decoder_attention_bias"],
           bias,
           hparams,
           cache)
 
       with tf.variable_scope(target_modality.name):
         logits = target_modality.top_sharded(body_outputs, None, dp)[0]
 
-      return tf.squeeze(logits, axis=[1, 2, 3])
-
-    def inner_loop(i, next_id, decoded_ids, cache):
-      logits = symbols_to_logits_fn(next_id, i, cache)
-      next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
-      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
-      return i+1, next_id, decoded_ids, cache
+      return tf.squeeze(logits, axis=[1, 2, 3]), cache
 
     key_channels = hparams.attention_key_channels or hparams.hidden_size
     value_channels = hparams.attention_value_channels or hparams.hidden_size
@@ -272,24 +324,53 @@ def inner_loop(i, next_id, decoded_ids, cache):
             "v": tf.zeros([batch_size, 0, value_channels]),
         } for layer in range(num_layers)
     }
-    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
-    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
-    _, _, decoded_ids, _ = tf.while_loop(
-        # TODO(llion): Early stopping.
-        lambda i, *_: tf.less(i, decode_length),
-        inner_loop,
-        [tf.constant(0), next_id, decoded_ids, cache],
-        shape_invariants=[
-            tf.TensorShape([]),
-            tf.TensorShape([None, None]),
-            tf.TensorShape([None, None]),
-            {"layer_%d" % layer: {
-                "k": tf.TensorShape([None, None, key_channels]),
-                "v": tf.TensorShape([None, None, value_channels]),
-            } for layer in range(num_layers)}
-        ])
 
-    return decoded_ids, None, None
+    # Set 2nd dim to None since it's not invariant in the tf.while_loop
+    # Note: Tensor.set_shape() does not work here since it merges shape info.
+    # TODO(llion): Find a more robust solution.
+    # pylint: disable=protected-access
+    for layer in cache:
+      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
+      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
+    # pylint: enable=protected-access
+    cache["encoder_output"] = encoder_output
+    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+
+    if beam_size > 1:  # Beam Search
+      target_modality = (
+          self._hparams.problems[self._problem_idx].target_modality)
+      vocab_size = target_modality.top_dimensionality
+      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+      decoded_ids, _ = beam_search.beam_search(
+          symbols_to_logits_fn, initial_ids, beam_size, decode_length,
+          vocab_size, alpha, states=cache)
+
+      if top_beams == 1:
+        decoded_ids = decoded_ids[:, 0, 1:]
+      else:
+        decoded_ids = decoded_ids[:, :top_beams, 1:]
+    else:  # Greedy
+      def inner_loop(i, next_id, decoded_ids, cache):
+        logits, cache = symbols_to_logits_fn(next_id, i, cache)
+        next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
+        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
+        return i+1, next_id, decoded_ids, cache
+
+      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
+      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
+      _, _, decoded_ids, _ = tf.while_loop(
+          # TODO(llion): Early stopping.
+          lambda i, *_: tf.less(i, decode_length),
+          inner_loop,
+          [tf.constant(0), next_id, decoded_ids, cache],
+          shape_invariants=[
+              tf.TensorShape([]),
+              tf.TensorShape([None, None]),
+              tf.TensorShape([None, None]),
+              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
+          ])
+
+    return decoded_ids
 
 
 @registry.register_model
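
The rewritten decoder above serves both modes through one code path: symbols_to_logits_fn(ids, i, cache) looks only at the most recently generated id plus a cache of per-layer keys and values, and is driven either by a greedy tf.while_loop or by beam_search.beam_search. Below is a minimal, self-contained sketch of that pattern (not code from this commit); the one-matrix toy "model" and all names in it are illustrative assumptions.

import tensorflow as tf

batch_size = 2
vocab_size = 8
hidden_size = 4
decode_length = 5

embedding = tf.get_variable("embedding", [vocab_size, hidden_size])
proj = tf.get_variable("proj", [hidden_size, vocab_size])


def symbols_to_logits_fn(ids, i, cache):
  """Toy next-symbol logits; the cache keeps a running sum of embeddings."""
  del i  # The toy model ignores the position.
  last_ids = ids[:, -1]                              # only the newest symbol
  emb = tf.nn.embedding_lookup(embedding, last_ids)  # [batch, hidden]
  cache["state"] = cache["state"] + emb              # fixed-shape state update
  logits = tf.matmul(cache["state"], proj)           # [batch, vocab]
  return logits, cache


def inner_loop(i, next_id, decoded_ids, cache):
  logits, cache = symbols_to_logits_fn(next_id, i, cache)
  next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
  decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
  return i + 1, next_id, decoded_ids, cache


cache = {"state": tf.zeros([batch_size, hidden_size])}
decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
_, _, decoded_ids, _ = tf.while_loop(
    lambda i, *_: tf.less(i, decode_length),
    inner_loop,
    [tf.constant(0), next_id, decoded_ids, cache],
    shape_invariants=[
        tf.TensorShape([]),
        tf.TensorShape([None, None]),
        tf.TensorShape([None, None]),
        {"state": tf.TensorShape([None, hidden_size])},
    ])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(decoded_ids))  # shape (batch_size, decode_length)

Because the cache is threaded through the loop, each step consumes a single new symbol instead of re-running the decoder over the whole prefix; the beam path in the commit drives the same symbols_to_logits_fn through beam_search.beam_search(..., states=cache) instead of this loop.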

tensor2tensor/models/transformer_test.py

Lines changed: 46 additions & 0 deletions
@@ -112,5 +112,51 @@ def testGreedyVsFast(self):
     self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
     self.assertAllClose(greedy_res, fast_res)
 
+  def testBeamVsFast(self):
+    model, features = self.getModel(transformer.transformer_small())
+
+    decode_length = 2
+
+    out_logits, _ = model.model_fn(features)
+    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
+        labels=tf.reshape(features["targets"], [-1]))
+    loss = tf.reduce_mean(loss)
+    apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss)
+
+    with self.test_session():
+      tf.global_variables_initializer().run()
+      for _ in range(100):
+        apply_grad.run()
+
+    model, _ = self.getModel(transformer.transformer_small(),
+                             mode=tf.estimator.ModeKeys.PREDICT)
+
+    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+      beam_result = model._beam_decode_slow(
+          features,
+          decode_length,
+          beam_size=4,
+          top_beams=1,
+          last_position_only=True,
+          alpha=1.0)
+
+      fast_result = model._beam_decode(
+          features,
+          decode_length,
+          beam_size=4,
+          top_beams=1,
+          last_position_only=True,
+          alpha=1.0)
+
+    with self.test_session():
+      beam_res = beam_result.eval()
+      fast_res = fast_result.eval()
+
+    self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
+    self.assertAllClose(beam_res, fast_res)
+
+
 if __name__ == "__main__":
   tf.test.main()

tensor2tensor/utils/beam_search.py

Lines changed: 45 additions & 10 deletions
@@ -30,7 +30,45 @@
 INF = 1. * 1e7
 
 
-def expand_to_beam_size(tensor, beam_size):
+def _get_shape(tensor):
+  """Returns static shape if available and dynamic shape otherwise."""
+  static = tensor.shape.as_list()
+  dynamic = tf.unstack(tf.shape(tensor))
+  return [s[1] if s[0] is None else s[0] for s in zip(static, dynamic)]
+
+
+def _merge_beam_dim(tensor):
+  """Reshapes first two dimensions into a single dimension.
+
+  Args:
+    tensor: Tensor to reshape of shape [A, B, ...]
+
+  Returns:
+    Reshaped tensor of shape [A*B, ...]
+  """
+  shape = _get_shape(tensor)
+  shape[0] *= shape[1]  # batch -> batch * beam_size
+  shape.pop(1)  # Remove beam dim
+  return tf.reshape(tensor, shape)
+
+
+def _unmerge_beam_dim(tensor, batch_size, beam_size):
+  """Reshapes first dimension back to [batch_size, beam_size].
+
+  Args:
+    tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
+    batch_size: Tensor, original batch size.
+    beam_size: int, original beam size.
+
+  Returns:
+    Reshaped tensor of shape [batch_size, beam_size, ...]
+  """
+  shape = _get_shape(tensor)
+  new_shape = [batch_size] + [beam_size] + shape[1:]
+  return tf.reshape(tensor, new_shape)
+
+
+def _expand_to_beam_size(tensor, beam_size):
   """Tiles a given tensor by beam_size.
 
   Args:
@@ -191,11 +229,11 @@ def beam_search(symbols_to_logits_fn,
   alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
 
   # Expand each batch and state to beam_size
-  alive_seq = expand_to_beam_size(initial_ids, beam_size)
+  alive_seq = _expand_to_beam_size(initial_ids, beam_size)
   alive_seq = tf.expand_dims(alive_seq, axis=2)  # (batch_size, beam_size, 1)
   if states:
     states = nest.map_structure(
-        lambda state: expand_to_beam_size(state, beam_size), states)
+        lambda state: _expand_to_beam_size(state, beam_size), states)
   else:
     states = {}
 
@@ -302,12 +340,10 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
 
     # (batch_size * beam_size, decoded_length)
     if states:
-      flat_states = nest.map_structure(
-          lambda state: tf.reshape(state, [batch_size * beam_size, -1]), states)
-      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, flat_states)
+      flat_states = nest.map_structure(_merge_beam_dim, states)
+      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
       states = nest.map_structure(
-          lambda state: tf.reshape(state, [batch_size, beam_size, -1]),
-          flat_states)
+          lambda t: _unmerge_beam_dim(t, batch_size, beam_size), flat_states)
     else:
       flat_logits = symbols_to_logits_fn(flat_ids)
     logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])
@@ -478,8 +514,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
           finished_scores.get_shape(),
           finished_flags.get_shape(),
           nest.map_structure(
-              lambda tensor: tf.TensorShape([None] * tensor.shape.ndims),
-              states),
+              lambda tensor: tf.TensorShape(tensor.shape), states),
       ],
       parallel_iterations=1,
       back_prop=False)
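
The new _merge_beam_dim / _unmerge_beam_dim helpers are what let arbitrary-rank decoder states ride along with the beam: states of shape [batch, beam, ...] are flattened to [batch * beam, ...] before symbols_to_logits_fn is called and restored afterwards, instead of the old rank-losing tf.reshape(state, [..., -1]). Below is a small round-trip check against the helpers added above; it is illustrative only (the helpers are module-private) and not part of the commit.

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size = 2, 3
# A fake per-beam state with a trailing dimension of size 4.
state = tf.reshape(
    tf.range(batch_size * beam_size * 4, dtype=tf.float32),
    [batch_size, beam_size, 4])

# pylint: disable=protected-access
flat = beam_search._merge_beam_dim(state)  # [batch * beam, 4]
restored = beam_search._unmerge_beam_dim(flat, batch_size, beam_size)
# pylint: enable=protected-access

with tf.Session() as sess:
  print(sess.run(tf.reduce_all(tf.equal(state, restored))))  # True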

tensor2tensor/utils/beam_search_test.py

Lines changed: 4 additions & 2 deletions
@@ -289,7 +289,7 @@ def testStates(self):
 
     expected_states = tf.constant([[[0.]], [[1.]]])
 
-    def symbols_to_logits(ids, states):
+    def symbols_to_logits(ids, _, states):
       pos = tf.shape(ids)[1] - 1
       # We have to assert the values of state inline here since we can't fetch
       # them out of the loop!
@@ -303,6 +303,7 @@ def symbols_to_logits(ids, states):
     states = {
         "state": tf.zeros((batch_size, 1)),
     }
+    states["state"]._shape = tf.TensorShape((None, 1))
 
     final_ids, _ = beam_search.beam_search(
         symbols_to_logits,
@@ -336,7 +337,7 @@ def testStateBeamTwo(self):
     # at each position, which is the one that's getting 3 added to it each step.
     expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])
 
-    def symbols_to_logits(ids, states):
+    def symbols_to_logits(ids, _, states):
       pos = tf.shape(ids)[1] - 1
 
       # We have to assert the values of state inline here since we can't fetch
@@ -351,6 +352,7 @@ def symbols_to_logits(ids, states):
     states = {
         "state": tf.zeros((batch_size, 1)),
     }
+    states["state"]._shape = tf.TensorShape((None, 1))
 
     final_ids, _ = beam_search.beam_search(
         symbols_to_logits,
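
Taken together with the beam_search.py change, a stateful caller now passes states= and supplies a symbols_to_logits_fn with the new (ids, i, states) signature that returns (logits, states). The sketch below mirrors the updated tests; the toy logits, the "step" state, and the sizes are illustrative assumptions, not part of the commit.

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 2, 3, 4, 5


def symbols_to_logits(ids, i, states):
  del i  # Unused in this toy example.
  # Always favor token 2 and carry a running step counter as the state.
  logits = tf.one_hot(tf.fill([tf.shape(ids)[0]], 2), vocab_size) * 10.0
  states["step"] = states["step"] + 1.0
  return logits, states


states = {"step": tf.zeros((batch_size, 1))}
# Same trick as the tests above: relax the static batch dimension so the
# while_loop shape invariant accepts the tiled and merged state.
states["step"]._shape = tf.TensorShape((None, 1))  # pylint: disable=protected-access

initial_ids = tf.zeros([batch_size], dtype=tf.int32)
final_ids, final_scores = beam_search.beam_search(
    symbols_to_logits, initial_ids, beam_size, decode_length,
    vocab_size, 0.0, states=states)

with tf.Session() as sess:
  ids = sess.run(final_ids)
  print(ids.shape)  # (batch_size, beam_size, decode_length + 1), initial id included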

tensor2tensor/utils/t2t_model.py

Lines changed: 23 additions & 0 deletions
@@ -217,6 +217,29 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams,
                    last_position_only, alpha):
     """Beam search decoding.
 
+    Models should ideally implement a more efficient version of this function.
+
+    Args:
+      features: a map of string to `Tensor`
+      decode_length: an integer. How many additional timesteps to decode.
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      last_position_only: a boolean, speed-up by computing last position only.
+      alpha: Float that controls the length penalty. larger the alpha, stronger
+        the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+    """
+    return self._beam_decode_slow(features, decode_length, beam_size, top_beams,
+                                  last_position_only, alpha)
+
+  def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
+                        last_position_only, alpha):
+    """Slow version of Beam search decoding.
+
+    Quadratic time in decode_length.
+
     Args:
       features: a map of string to `Tensor`
       decode_length: an integer. How many additional timesteps to decode.
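
For models other than Transformer, the split above means the generic _beam_decode is now a thin wrapper around _beam_decode_slow, so a model with a faster decoder only needs to override _beam_decode (exactly what transformer.py does in this commit). Below is a hypothetical subclass showing the hook; the class name and its trivial body are assumptions, not part of the change.

from tensor2tensor.utils import t2t_model


class MyModel(t2t_model.T2TModel):
  """Hypothetical model showing where a fast beam decoder would plug in."""

  def _beam_decode(self, features, decode_length, beam_size, top_beams,
                   last_position_only, alpha):
    # A real model would run its cached, linear-time decoder here, the way
    # Transformer._fast_decode does. This sketch keeps the generic
    # quadratic-time fallback so the example stays self-contained.
    return self._beam_decode_slow(features, decode_length, beam_size,
                                  top_beams, last_position_only, alpha)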
