3030from tensor2tensor .layers import common_attention
3131from tensor2tensor .layers import common_hparams
3232from tensor2tensor .layers import common_layers
33+ from tensor2tensor .utils import beam_search
3334from tensor2tensor .utils import expert_utils
3435from tensor2tensor .utils import registry
3536from tensor2tensor .utils import t2t_model
3637
3738import tensorflow as tf
3839
40+ from tensorflow .python .util import nest
41+
3942
4043@registry .register_model
4144class Transformer (t2t_model .T2TModel ):
@@ -159,6 +162,58 @@ def _greedy_infer(
159162 logits: Not returned
160163 losses: Not returned
161164
165+ Raises:
166+ ValueError: If last_position_only is False
167+ NotImplementedError: If there are multiple data shards.
168+ """
169+ decoded_ids = self ._fast_decode (features , decode_length , last_position_only )
170+ return decoded_ids , None , None
171+
172+ def _beam_decode (self , features , decode_length , beam_size , top_beams ,
173+ last_position_only , alpha ):
174+ """Beam search decoding.
175+
176+ Args:
177+ features: an map of string to `Tensor`
178+ decode_length: an integer. How many additional timesteps to decode.
179+ beam_size: number of beams.
180+ top_beams: an integer. How many of the beams to return.
181+ last_position_only: MUST be true for fast decoding!
182+ alpha: Float that controls the length penalty. larger the alpha, stronger
183+ the preference for slonger translations.
184+
185+ Returns:
186+ samples: an integer `Tensor`. Top samples from the beam search
187+ """
188+ return self ._fast_decode (
189+ features , decode_length , last_position_only , beam_size , top_beams ,
190+ alpha )
191+
192+ def _fast_decode (
193+ self ,
194+ features ,
195+ decode_length ,
196+ last_position_only = True ,
197+ beam_size = 1 ,
198+ top_beams = 1 ,
199+ alpha = 1.0 ):
200+ """Fast decoding.
201+
202+ Implements both greedy and beam search decoding, uses beam search iff
203+ beam_size > 1, otherwise beam search related arguments are ignored.
204+
205+ Args:
206+ features: a map of string to model features.
207+ decode_length: an integer. How many additional timesteps to decode.
208+ last_position_only: MUST be true for fast decoding!
209+ beam_size: number of beams.
210+ top_beams: an integer. How many of the beams to return.
211+ alpha: Float that controls the length penalty. larger the alpha, stronger
212+ the preference for slonger translations.
213+
214+ Returns:
215+ samples: an integer `Tensor`. Top samples from the beam search
216+
162217 Raises:
163218 ValueError: If last_position_only is False
164219 NotImplementedError: If there are multiple data shards.
@@ -192,6 +247,8 @@ def _greedy_infer(
192247 with tf .variable_scope ("body" ):
193248 encoder_output , encoder_decoder_attention_bias = dp (
194249 self .encode , inputs , features ["target_space_id" ], hparams )
250+ encoder_output = encoder_output [0 ]
251+ encoder_decoder_attention_bias = encoder_decoder_attention_bias [0 ]
195252
196253 if hparams .pos == "timing" :
197254 timing_signal = common_attention .get_timing_signal_1d (
@@ -236,6 +293,7 @@ def preprocess_targets(targets, i):
236293
237294 def symbols_to_logits_fn (ids , i , cache ):
238295 """Go from ids to logits for next symbol."""
296+ ids = ids [:, - 1 :]
239297 targets = tf .expand_dims (tf .expand_dims (ids , axis = 2 ), axis = 3 )
240298 targets = preprocess_targets (targets , i )
241299
@@ -245,22 +303,16 @@ def symbols_to_logits_fn(ids, i, cache):
245303 body_outputs = dp (
246304 self .decode ,
247305 targets ,
248- encoder_output [ 0 ],
249- encoder_decoder_attention_bias [ 0 ],
306+ cache [ "encoder_output" ],
307+ cache [ "encoder_decoder_attention_bias" ],
250308 bias ,
251309 hparams ,
252310 cache )
253311
254312 with tf .variable_scope (target_modality .name ):
255313 logits = target_modality .top_sharded (body_outputs , None , dp )[0 ]
256314
257- return tf .squeeze (logits , axis = [1 , 2 , 3 ])
258-
259- def inner_loop (i , next_id , decoded_ids , cache ):
260- logits = symbols_to_logits_fn (next_id , i , cache )
261- next_id = tf .expand_dims (tf .argmax (logits , axis = - 1 ), axis = 1 )
262- decoded_ids = tf .concat ([decoded_ids , next_id ], axis = 1 )
263- return i + 1 , next_id , decoded_ids , cache
315+ return tf .squeeze (logits , axis = [1 , 2 , 3 ]), cache
264316
265317 key_channels = hparams .attention_key_channels or hparams .hidden_size
266318 value_channels = hparams .attention_value_channels or hparams .hidden_size
@@ -272,24 +324,53 @@ def inner_loop(i, next_id, decoded_ids, cache):
272324 "v" : tf .zeros ([batch_size , 0 , value_channels ]),
273325 } for layer in range (num_layers )
274326 }
275- decoded_ids = tf .zeros ([batch_size , 0 ], dtype = tf .int64 )
276- next_id = tf .zeros ([batch_size , 1 ], dtype = tf .int64 )
277- _ , _ , decoded_ids , _ = tf .while_loop (
278- # TODO(llion): Early stopping.
279- lambda i , * _ : tf .less (i , decode_length ),
280- inner_loop ,
281- [tf .constant (0 ), next_id , decoded_ids , cache ],
282- shape_invariants = [
283- tf .TensorShape ([]),
284- tf .TensorShape ([None , None ]),
285- tf .TensorShape ([None , None ]),
286- {"layer_%d" % layer : {
287- "k" : tf .TensorShape ([None , None , key_channels ]),
288- "v" : tf .TensorShape ([None , None , value_channels ]),
289- } for layer in range (num_layers )}
290- ])
291327
292- return decoded_ids , None , None
328+ # Set 2nd dim to None since it's not invariant in the tf.while_loop
329+ # Note: Tensor.set_shape() does not work here since it merges shape info.
330+ # TODO(llion): Find a more robust solution.
331+ # pylint: disable=protected-access
332+ for layer in cache :
333+ cache [layer ]["k" ]._shape = tf .TensorShape ([None , None , key_channels ])
334+ cache [layer ]["v" ]._shape = tf .TensorShape ([None , None , value_channels ])
335+ # pylint: enable=protected-access
336+ cache ["encoder_output" ] = encoder_output
337+ cache ["encoder_decoder_attention_bias" ] = encoder_decoder_attention_bias
338+
339+ if beam_size > 1 : # Beam Search
340+ target_modality = (
341+ self ._hparams .problems [self ._problem_idx ].target_modality )
342+ vocab_size = target_modality .top_dimensionality
343+ initial_ids = tf .zeros ([batch_size ], dtype = tf .int32 )
344+ decoded_ids , _ = beam_search .beam_search (
345+ symbols_to_logits_fn , initial_ids , beam_size , decode_length ,
346+ vocab_size , alpha , states = cache )
347+
348+ if top_beams == 1 :
349+ decoded_ids = decoded_ids [:, 0 , 1 :]
350+ else :
351+ decoded_ids = decoded_ids [:, :top_beams , 1 :]
352+ else : # Greedy
353+ def inner_loop (i , next_id , decoded_ids , cache ):
354+ logits , cache = symbols_to_logits_fn (next_id , i , cache )
355+ next_id = tf .expand_dims (tf .argmax (logits , axis = - 1 ), axis = 1 )
356+ decoded_ids = tf .concat ([decoded_ids , next_id ], axis = 1 )
357+ return i + 1 , next_id , decoded_ids , cache
358+
359+ decoded_ids = tf .zeros ([batch_size , 0 ], dtype = tf .int64 )
360+ next_id = tf .zeros ([batch_size , 1 ], dtype = tf .int64 )
361+ _ , _ , decoded_ids , _ = tf .while_loop (
362+ # TODO(llion): Early stopping.
363+ lambda i , * _ : tf .less (i , decode_length ),
364+ inner_loop ,
365+ [tf .constant (0 ), next_id , decoded_ids , cache ],
366+ shape_invariants = [
367+ tf .TensorShape ([]),
368+ tf .TensorShape ([None , None ]),
369+ tf .TensorShape ([None , None ]),
370+ nest .map_structure (lambda t : tf .TensorShape (t .shape ), cache ),
371+ ])
372+
373+ return decoded_ids
293374
294375
295376@registry .register_model
0 commit comments