@@ -377,11 +377,8 @@ def __load_validation_info(
377377 ** distributed_kwargs ,
378378 )
379379
380- dprint ("model.eval()" )
381380model .eval ()
382- dprint ("finished eval" )
383381
384- ######################################### Building valid prompts
385382@dataclass
386383class ProgramInfo :
387384 program_id : str
@@ -406,6 +403,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
406403 limit_val = int (match .group (2 ))
407404 return limit_val , limit_type
408405
406+ # TODO: Add a check (or handling logic) for the requirement that the program criteria JSON must exist when saving CPU outputs
409407with open (args .program_criteria_json_path , "r" ) as f :
410408 program_criteria_json_list = json .load (f )["programs" ]
411409 program_criteria_list = []
@@ -578,12 +576,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
578576 f"no valid prompt shape was found which would result in program { program_id } that satisfied batch{ batch_size_limit_type } { batch_size_limit } and prompt_length{ prompt_length_limit_type } { prompt_length_limit } "
579577 )
580578
581- dprint ("valid prompts are prepared." )
582- ################################################################
583-
584-
585- ############## saving CPU validation ###########################
586-
587579if not args .skip_validation :
588580 with stagger_region (args .stagger_load ):
589581 validation_model = get_model (
@@ -611,7 +603,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
611603 )
612604 # if the cpu validation info is not yet computed, compute it
613605 if cpu_validation_info is None :
614- dprint (f"Let's start to compute it." )
615606 cpu_validation_info = extract_validation_information (
616607 validation_model ,
617608 input_ids ,
@@ -620,7 +611,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
620611 attn_algorithm = "math" ,
621612 ** extra_kwargs ,
622613 )
623- dprint (f"Save that cpu info." )
624614 # save the cpu validation info for later consumption
625615 if save_validation_info_outputs :
626616 cpu_validation_info .save (
@@ -638,12 +628,11 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
638628 )
639629
640630if args .stop_after_info_outputs :
641- dprint ("CPU validation info saved. Exiting as requested." )
642- # sys. exit(0)
631+ dprint ("CPU validation outputs saved. Exiting as requested." )
632+ exit (0 )
643633
644634################################################################
645635
646- dprint ("onto model compilation related stuff ..." )
647636fx_config .backed_size_oblivious = True
648637model .compile (backend = "sendnn" , options = {"sendnn.dynamic" : True })
649638__maybe_prepare_fp8_weights (model , is_fp8 )
0 commit comments