@@ -99,15 +99,12 @@ def _update_model_kwargs_for_generation(
         self,
         outputs: Dict[str, Any],
         model_kwargs: Dict[str, Any],
-    ) -> MODEL_KWARGS_TYPE:
+    ) -> None:
         """After a forward pass, update model_kwargs for faster decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L692.

         Args:
             outputs (Dict[str, Any]): LM output.
             model_kwargs (Dict[str, Any]): Model keyword args to be modified for future runs.
-
-        Returns:
-            Modified model_kwargs w/ updated past, token_type_ids, and attention_mask.
         """
         # Update past
         if "past_key_values" in outputs:
@@ -138,8 +135,6 @@ def _update_model_kwargs_for_generation(
                 dim=-1,
             )

-        return model_kwargs
-
     def greedy_search(
         self,
         input_ids: torch.Tensor,
@@ -222,6 +217,8 @@ def beam_search(
         Returns:
             Tensor of the generated sequences.
         """
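+        # Cache the prompt's device once; tensors built inside `update_func` below are allocated on it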
+        device = input_ids.device
+
         if self.is_encoder_decoder:
             encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
             encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
@@ -231,9 +228,6 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_

             i = T  # Hacky access to the current seq in inputs

-            # Copy over the `model_kwargs` in order to modify
-            new_model_kwargs = model_kwargs.copy()
-
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
@@ -268,7 +262,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
                     token_indices = (
                         torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=self.model.device)
+                        .to(dtype=torch.long, device=device)
                         .reshape(num_samples, 1)
                     )

@@ -281,23 +275,24 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                     ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
                 else:
                     assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0]  # dims: [1, 1]
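+                    # Broadcast the single start sequence across `num_beams` rows (a view, no copy) so the batch matches the expanded encoder outputs below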
+                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(num_beams, -1)  # TODO: Make this more robust
+

                 # Cleanup -- combine this with the above
                 if self.is_encoder_decoder:
                     # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
                     # This is a view-only operation and doesn't copy
-                    new_model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else 1, -1, -1
+                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                        num_samples if timestep > 0 else num_beams, -1, -1
                     )

                 # Preprocess inputs for generation
-                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **new_model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(token_indices, **model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
-                if len(prev_step_hyp_idxs) > 1 and model_inputs["past_key_values"] is not None:
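+                # Reorder the cached key/value states so each surviving hypothesis lines up with its own past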
+                if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
                     model_inputs["past_key_values"] = self.model._reorder_cache(
-                        model_inputs["past_key_values"],
+                        model_kwargs["past"],
                         torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32),
                     )

@@ -310,32 +305,23 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_

                 # HF optimizations to reduce overhead in future `forward` calls
                 if self.is_huggingface_model:
-                    new_model_kwargs = self._update_model_kwargs_for_generation(outputs, new_model_kwargs)
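+                    # Mutates `model_kwargs` in place; the method no longer returns a value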
+                    self._update_model_kwargs_for_generation(outputs, model_kwargs)

                 # Keep track of probabilities over vocab for this pairing
-                # TODO: clean up duplicate code in these branches
-                if timestep == 0:
-                    sample_lm_scores = torch.squeeze(lm_scores[:, -1])
+                # TODO: fix how we track the number here?
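+                # A single loop now covers timestep 0 as well, since the start state above is expanded to one row per beam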
+                for i in range(lm_scores.shape[0]):
+                    sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
+                    # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
-                            Seq2SeqModelState(timestep=timestep, sequence=state_and_tokens, lm_scores=sample_lm_scores)
-                        )
-                    )
-                else:
-                    for i in range(num_samples):
-                        sample_lm_scores = lm_scores[i, -1]
-                        out_probs.append(sample_lm_scores.tolist())
-                        # Keep track of sequence and decoder hidden states
-                        model_states.append(
-                            create_emitting_model_state(
-                                Seq2SeqModelState(
-                                    timestep=timestep,
-                                    sequence=state_and_tokens[i].unsqueeze(0),
-                                    lm_scores=sample_lm_scores,
-                                )
+                            Seq2SeqModelState(
+                                timestep=timestep,
+                                sequence=state_and_tokens[i].unsqueeze(0),
+                                lm_scores=sample_lm_scores,
                             )
                         )
-                    )
+                    )

                 start += step
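Note: the main contract change here is that `_update_model_kwargs_for_generation` now mutates `model_kwargs` in place and returns `None`, so `update_func` no longer copies or reassigns the kwargs dict on every decoding step. A minimal, self-contained sketch of that pattern (the helper name and dict keys below are illustrative, not the repo's API):

    from typing import Any, Dict

    def update_in_place(outputs: Dict[str, Any], model_kwargs: Dict[str, Any]) -> None:
        # Mirror the new contract: write through to the caller's dict, return nothing
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs["past_key_values"]

    model_kwargs = {"past": None}
    update_in_place({"past_key_values": "dummy-cache"}, model_kwargs)
    assert model_kwargs["past"] == "dummy-cache"  # caller sees the update without reassignment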