@@ -275,8 +275,9 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
275275 ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
276276 else:
277277 assert len(prev_model_state_sequences) == 1
278- state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
279-
278+ state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
279+ num_beams, -1
280+ )  # TODO: Make this more robust
280281
281282 # Cleanup -- combine this with the above
282283 if self.is_encoder_decoder:
@@ -287,14 +288,14 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
287288 )
288289
289290 # Preprocess inputs for generation
290- model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
291+ model_inputs = self.model.prepare_inputs_for_generation(
292+ token_indices, **model_kwargs
293+ )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
291294 if self.is_huggingface_model:
292295 model_inputs.update(self._huggingface_model_input_values)
293296 if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
294- model_inputs["past_key_values"] = self.model._reorder_cache(
295- model_kwargs["past"],
296- torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
297- )
297+ beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
298+ model_inputs["past_key_values"] = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
298299
299300 # Forward pass
300301 outputs = self.model(**model_inputs)
0 commit comments