@@ -592,7 +592,7 @@ def get_program_prompt_list(
592592 return valid_prompts
593593
594594
595- def run_validation_tests (
595+ def run_validation (
596596 args : argparse .Namespace ,
597597 model : torch .nn .Module ,
598598 validation_model : Optional [torch .nn .Module ],
@@ -604,7 +604,7 @@ def run_validation_tests(
604604 attn_name : str ,
605605 cpu_dtype : str ,
606606 tokenizer : AutoTokenizer ,
607- ) -> None :
607+ ):
608608
609609 if local_rank == 0 :
610610 dprint (f"*** testing program { program_id } ***" )
@@ -669,47 +669,17 @@ def run_validation_tests(
669669 ** extra_kwargs ,
670670 )
671671
672- if args .test_type == "metrics" :
673- process_metrics_test (
674- cross_entropy_threshold = args .cross_entropy_threshold ,
675- failure_rate_threshold = args .failure_rate_threshold ,
676- aiu_validation_info = aiu_validation_info ,
677- cpu_validation_info = cpu_validation_info ,
678- program_id = program_id ,
679- prompt_shape = valid_prompt ,
680- tokenizer = tokenizer
681- )
682-
683- elif args .test_type == "tokens" :
684- process_tokens_test (
685- max_new_tokens = args .max_new_tokens ,
686- aiu_validation_info = aiu_validation_info ,
687- cpu_validation_info = cpu_validation_info ,
688- program_id = program_id ,
689- tokenizer = tokenizer
690- )
691-
692- if args .skip_validation and local_rank == 0 :
693- for sentence_idx , test_sentence in enumerate (
694- aiu_validation_info .get_info ("tokens" )
695- ):
696- tokens_prompt = [t .item () for t in test_sentence [:- args .max_new_tokens ]]
697- aiu_tokens_generated = [t .item () for t in test_sentence [- args .max_new_tokens :]]
698- dprint (f"For Program { program_id } in sentence { sentence_idx + 1 } :" )
699- dprint (f"Prompt:\n { tokenizer .decode (tokens_prompt )} " )
700- dprint (f"AIU tokens:\n { aiu_tokens_generated } " )
701- dprint (f"AIU output:\n { tokenizer .decode (aiu_tokens_generated )} " )
672+ return aiu_validation_info , cpu_validation_info
702673
703674
704- def process_metrics_test (
675+ def run_metrics_test (
705676 cross_entropy_threshold : float ,
706- failure_rate_threshold : float ,
707677 aiu_validation_info : ValidationInfo ,
708678 cpu_validation_info : ValidationInfo ,
709679 program_id : str ,
710680 prompt_shape : Tuple [int , int ],
711681 tokenizer : AutoTokenizer ,
712- ) -> None :
682+ ):
713683
714684 level_1_metrics = capture_level_1_metrics (
715685 cpu_validation_info .get_info ("logits" ),
@@ -737,17 +707,11 @@ def process_metrics_test(
737707 ce_fail_responses = filter_failed_level_1_cases (
738708 level_1_metrics , lambda m : m [0 ] >= cross_entropy_threshold
739709 )
710+ failure_rate = len (ce_fail_responses ) / len (level_1_metrics )
740711
741- failure_rate = len (ce_fail_responses ) / len (level_1_metrics ) if level_1_metrics else 0.0
742-
743- if failure_rate >= failure_rate_threshold :
744- dprint (f"[FAIL] Program { program_id } failed with rate { failure_rate :.4f} >= threshold { failure_rate_threshold } ." )
745-
746- if local_rank == 0 :
747- dprint (f"[PASS] Program { program_id } passed. Failure Rate: { failure_rate :.4f} ." )
748-
712+ return failure_rate
749713
750- def process_tokens_test (
714+ def run_tokens_test (
751715 max_new_tokens : int ,
752716 aiu_validation_info : ValidationInfo ,
753717 cpu_validation_info : ValidationInfo ,
@@ -827,10 +791,10 @@ def process_tokens_test(
827791distributed_kwargs = get_distributed_kwargs (args .distributed , args .dist_timeout , args .save_validation_info_outputs )
828792
829793model = load_model (
830- device_type = "cpu" ,
831- model_variant = args .model_variant ,
832- is_fp8 = is_fp8 ,
833- distributed_kwargs = distributed_kwargs ,
794+ device_type = "cpu" ,
795+ model_variant = args .model_variant ,
796+ is_fp8 = is_fp8 ,
797+ distributed_kwargs = distributed_kwargs ,
834798 stagger_load = args .stagger_load ,
835799 is_validation = False )
836800
@@ -927,14 +891,14 @@ def process_tokens_test(
927891 custom_shape = custom_shape ,
928892)
929893
930- ## RUN TESTS ##
894+ ## RUN VALIDATION AND TESTS ##
931895
932896failed_cases = []
933897# for each program and valid prompt (batch size, sequence length)
934898for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
935899 extra_kwargs ["attn_name" ] = ATTN_NAME
936900
937- run_validation_tests (
901+ aiu_validation_info , cpu_validation_info = run_validation (
938902 args = args ,
939903 model = model ,
940904 validation_model = validation_model ,
@@ -947,3 +911,50 @@ def process_tokens_test(
947911 cpu_dtype = CPU_DTYPE ,
948912 tokenizer = tokenizer ,
949913 )
914+
915+ if args .test_type == "metrics" :
916+ failure_rate = run_metrics_test (
917+ cross_entropy_threshold = args .cross_entropy_threshold ,
918+ aiu_validation_info = aiu_validation_info ,
919+ cpu_validation_info = cpu_validation_info ,
920+ program_id = program_id ,
921+ prompt_shape = valid_prompt ,
922+ tokenizer = tokenizer
923+ )
924+ if failure_rate > args .failure_rate_threshold :
925+ failed_cases .append (
926+ (program_id , valid_prompt , failure_rate )
927+ )
928+
929+ elif args .test_type == "tokens" :
930+ run_tokens_test (
931+ max_new_tokens = args .max_new_tokens ,
932+ aiu_validation_info = aiu_validation_info ,
933+ cpu_validation_info = cpu_validation_info ,
934+ program_id = program_id ,
935+ tokenizer = tokenizer
936+ )
937+
938+ else :
939+ raise ValueError ("test type must be one of metrics or tokens" )
940+
941+ if args .skip_validation and local_rank == 0 :
942+ for sentence_idx , test_sentence in enumerate (
943+ aiu_validation_info .get_info ("tokens" )
944+ ):
945+ tokens_prompt = [t .item () for t in test_sentence [:- args .max_new_tokens ]]
946+ aiu_tokens_generated = [t .item () for t in test_sentence [- args .max_new_tokens :]]
947+ dprint (f"For Program { program_id } in sentence { sentence_idx + 1 } :" )
948+ dprint (f"Prompt:\n { tokenizer .decode (tokens_prompt )} " )
949+ dprint (f"AIU tokens:\n { aiu_tokens_generated } " )
950+ dprint (f"AIU output:\n { tokenizer .decode (aiu_tokens_generated )} " )
951+
952+ if not args .skip_validation and local_rank == 0 :
953+ if len (failed_cases ) != 0 :
954+ dprint ("The test failed with the following cases:" )
955+ for failed_case in failed_cases :
956+ dprint (
957+ f"Program ID: { failed_case [0 ]} , Prompt Shape: { failed_case [1 ]} , Failure Rate: { failed_case [2 ]} "
958+ )
959+ else :
960+ dprint ("all tests passed" )