@@ -328,7 +328,7 @@ def get_model_path_kwargs(model_variant: str) -> Dict[str, Any]:
328328
329329 return model_path_kwargs
330330
331- def get_distributed_kwargs (is_distributed : bool , dist_timeout : str ) -> Dict [str , Any ]:
331+ def get_distributed_kwargs (is_distributed : bool , dist_timeout : str , save_validation_info_outputs : bool ) -> Dict [str , Any ]:
332332
333333 distributed_kwargs = {}
334334 if is_distributed :
@@ -403,7 +403,6 @@ def load_model(
403403 is_validation : bool = False ,
404404):
405405
406- device_type = "cpu"
407406 dtype = None if is_fp8 else (torch .float32 if is_validation else torch .float16 )
408407
409408 model_path_kwargs = get_model_path_kwargs (model_variant )
@@ -631,10 +630,10 @@ def run_validation_tests(
631630 # if the cpu validation info is not yet computed, compute it
632631 if cpu_validation_info is None and validation_model is not None :
633632 cpu_validation_info = extract_validation_information (
634- validation_model ,
635- input_ids ,
636- args .max_new_tokens ,
637- LogitsExtractorHook (),
633+ model = validation_model ,
634+ input_ids = input_ids ,
635+ max_new_tokens = args .max_new_tokens ,
636+ post_iteration_hook = LogitsExtractorHook (),
638637 attn_algorithm = "math" ,
639638 ** extra_kwargs ,
640639 )
@@ -660,10 +659,10 @@ def run_validation_tests(
660659 golden_hook = GoldenTokenHook (cpu_validation_info .get_info ("tokens" ))
661660
662661 aiu_validation_info = extract_validation_information (
663- model ,
664- input_ids ,
665- args .max_new_tokens ,
666- golden_hook ,
662+ model = model ,
663+ input_ids = input_ids ,
664+ max_new_tokens = args .max_new_tokens ,
665+ post_iteration_hook = golden_hook ,
667666 last_n_tokens = 64 ,
668667 timing = args .timing ,
669668 prefill_chunk_size = args .prefill_chunk_size ,
@@ -825,7 +824,7 @@ def process_tokens_test(
825824## MODEL LOADING ##
826825
827826# Get distributed kwargs (empty if not distributed)
828- distributed_kwargs = get_distributed_kwargs (args .distributed , args .dist_timeout )
827+ distributed_kwargs = get_distributed_kwargs (args .distributed , args .dist_timeout , args . save_validation_info_outputs )
829828
830829model = load_model (
831830 device_type = "cpu" ,
@@ -841,7 +840,12 @@ def process_tokens_test(
841840validation_model = None
842841if not args .skip_validation :
843842 validation_model = load_model (
844- args .model_variant , is_fp8 , distributed_kwargs , args .stagger_load , is_validation = True
843+ device_type = "cpu" ,
844+ model_variant = args .model_variant ,
845+ is_fp8 = is_fp8 ,
846+ distributed_kwargs = distributed_kwargs ,
847+ stagger_load = args .stagger_load ,
848+ is_validation = True
845849 )
846850
847851## MODEL WARMUP ##
@@ -912,14 +916,15 @@ def process_tokens_test(
912916
913917# Select concrete prompts and program associations
914918valid_prompts = get_program_prompt_list (
915- args ,
916- programs_to_test ,
917- program_criteria_list ,
918- program_map ,
919- tokenizer ,
920- sampler ,
921- allow_truncation ,
922- custom_shape
919+ program_map = program_map ,
920+ dataset_path = args .dataset_path ,
921+ enforce_homogeneous_prompt_programs = args .enforce_homogeneous_prompt_programs ,
922+ programs_to_test = programs_to_test ,
923+ program_criteria_list = program_criteria_list ,
924+ tokenizer = tokenizer ,
925+ sampler = sampler ,
926+ allow_truncation = allow_truncation ,
927+ custom_shape = custom_shape ,
923928)
924929
925930## RUN TESTS ##
0 commit comments