@@ -192,7 +192,7 @@ def process_eval_set(self):
192192 ids = prompts
193193 if isinstance (ids , list ) and len (ids ) == 1 :
194194 ids = ids [0 ].unsqueeze (0 )
195- extra_generation_kwargs = None
195+ extra_generation_kwargs = {}
196196
197197 self .extra_generation_kwargs = extra_generation_kwargs
198198
@@ -252,15 +252,10 @@ def infer(self, ids, warmup):
252252 max_seq_len = self .model .config .max_expected_seq_len
253253
254254 # Add only_last_token optimization
255- extra_generation_kwargs = (
256- {}
257- if self .extra_generation_kwargs is None
258- else self .extra_generation_kwargs
259- )
260- extra_generation_kwargs ["only_last_token" ] = True
255+ self .extra_generation_kwargs ["only_last_token" ] = True
261256
262257 if args .device_type == "cpu" :
263- extra_generation_kwargs ["attn_algorithm" ] = "math"
258+ self . extra_generation_kwargs ["attn_algorithm" ] = "math"
264259
265260 if not args .no_early_termination and not warmup :
266261 eos_token_id = self .tokenizer .eos_token_id
@@ -277,7 +272,7 @@ def infer(self, ids, warmup):
277272 timing = args .timing ,
278273 eos_token_id = eos_token_id ,
279274 contiguous_cache = True ,
280- extra_kwargs = extra_generation_kwargs ,
275+ extra_kwargs = self . extra_generation_kwargs ,
281276 )
282277 if args .timing != "" :
283278 result , timings = result
0 commit comments