Commit 340c383

Refactor of validation and tests

Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 0f47515 · commit 340c383

1 file changed: aiu_fms_testing_utils/scripts/refactored_dpp.py (+61, −50 lines)
```diff
@@ -592,7 +592,7 @@ def get_program_prompt_list(
     return valid_prompts
 
 
-def run_validation_tests(
+def run_validation(
     args: argparse.Namespace,
     model: torch.nn.Module,
     validation_model: Optional[torch.nn.Module],
```
```diff
@@ -604,7 +604,7 @@ def run_validation_tests(
     attn_name: str,
     cpu_dtype: str,
     tokenizer: AutoTokenizer,
-) -> None:
+):
 
     if local_rank == 0:
         dprint(f"*** testing program {program_id} ***")
```
```diff
@@ -669,47 +669,17 @@ def run_validation_tests(
         **extra_kwargs,
     )
 
-    if args.test_type == "metrics":
-        process_metrics_test(
-            cross_entropy_threshold=args.cross_entropy_threshold,
-            failure_rate_threshold=args.failure_rate_threshold,
-            aiu_validation_info=aiu_validation_info,
-            cpu_validation_info=cpu_validation_info,
-            program_id=program_id,
-            prompt_shape=valid_prompt,
-            tokenizer=tokenizer
-        )
-
-    elif args.test_type == "tokens":
-        process_tokens_test(
-            max_new_tokens=args.max_new_tokens,
-            aiu_validation_info=aiu_validation_info,
-            cpu_validation_info=cpu_validation_info,
-            program_id=program_id,
-            tokenizer=tokenizer
-        )
-
-    if args.skip_validation and local_rank == 0:
-        for sentence_idx, test_sentence in enumerate(
-            aiu_validation_info.get_info("tokens")
-        ):
-            tokens_prompt = [t.item() for t in test_sentence[:-args.max_new_tokens]]
-            aiu_tokens_generated = [t.item() for t in test_sentence[-args.max_new_tokens:]]
-            dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:")
-            dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
-            dprint(f"AIU tokens:\n{aiu_tokens_generated}")
-            dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
+    return aiu_validation_info, cpu_validation_info
 
 
-def process_metrics_test(
+def run_metrics_test(
     cross_entropy_threshold: float,
-    failure_rate_threshold: float,
     aiu_validation_info: ValidationInfo,
     cpu_validation_info: ValidationInfo,
     program_id: str,
     prompt_shape: Tuple[int, int],
     tokenizer: AutoTokenizer,
-) -> None:
+):
 
     level_1_metrics = capture_level_1_metrics(
         cpu_validation_info.get_info("logits"),
```
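After this hunk, `run_validation` only produces the AIU and CPU `ValidationInfo` outputs; scoring and reporting move to the caller. A minimal, self-contained sketch of the resulting call pattern — the stub bodies below are illustrative stand-ins, not the repo's real implementations or signatures:

```python
# Stand-ins for run_validation / run_metrics_test: the point is the
# shape of the refactored control flow, not the real AIU/CPU plumbing.
def run_validation(program_id: str):
    aiu_info = {"logits": [0.10, 0.92, 0.31]}  # pretend AIU outputs
    cpu_info = {"logits": [0.10, 0.90, 0.31]}  # pretend CPU reference
    return aiu_info, cpu_info  # generation only; no scoring here

def run_metrics_test(aiu_info, cpu_info) -> float:
    # Toy metric: fraction of positions disagreeing beyond a tolerance.
    pairs = zip(aiu_info["logits"], cpu_info["logits"])
    fails = sum(abs(a - c) > 0.01 for a, c in pairs)
    return fails / len(cpu_info["logits"])

failed_cases = []
for program_id in ["0", "1"]:
    aiu_info, cpu_info = run_validation(program_id)
    failure_rate = run_metrics_test(aiu_info, cpu_info)
    if failure_rate > 0.25:  # stand-in for args.failure_rate_threshold
        failed_cases.append((program_id, failure_rate))
print(failed_cases)
```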
```diff
@@ -737,17 +707,11 @@ def process_metrics_test(
     ce_fail_responses = filter_failed_level_1_cases(
         level_1_metrics, lambda m: m[0] >= cross_entropy_threshold
     )
+    failure_rate = len(ce_fail_responses) / len(level_1_metrics)
 
-    failure_rate = len(ce_fail_responses) / len(level_1_metrics) if level_1_metrics else 0.0
-
-    if failure_rate >= failure_rate_threshold:
-        dprint(f"[FAIL] Program {program_id} failed with rate {failure_rate:.4f} >= threshold {failure_rate_threshold}.")
-
-    if local_rank == 0:
-        dprint(f"[PASS] Program {program_id} passed. Failure Rate: {failure_rate:.4f}.")
-
+    return failure_rate
 
-def process_tokens_test(
+def run_tokens_test(
     max_new_tokens: int,
     aiu_validation_info: ValidationInfo,
     cpu_validation_info: ValidationInfo,
```
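The metrics function now computes and returns the failure rate, while the pass/fail decision against `args.failure_rate_threshold` happens at the call site. A worked example of the arithmetic with toy numbers — note that the refactored line divides by `len(level_1_metrics)` directly, so it assumes the metrics list is non-empty:

```python
# Toy level-1 metrics: one cross-entropy value per response.
level_1_metrics = [(0.8,), (2.6,), (0.4,), (3.1,)]
cross_entropy_threshold = 2.5

# Mirrors filter_failed_level_1_cases with the lambda from the diff.
ce_fail_responses = [m for m in level_1_metrics if m[0] >= cross_entropy_threshold]

failure_rate = len(ce_fail_responses) / len(level_1_metrics)  # 2 / 4 = 0.5
assert failure_rate == 0.5  # fails any run whose threshold is below 0.5
```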
```diff
@@ -827,10 +791,10 @@ def process_tokens_test(
 distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout, args.save_validation_info_outputs)
 
 model = load_model(
-        device_type="cpu",
-        model_variant=args.model_variant,
-        is_fp8=is_fp8,
-        distributed_kwargs=distributed_kwargs,
+    device_type="cpu",
+    model_variant=args.model_variant,
+    is_fp8=is_fp8,
+    distributed_kwargs=distributed_kwargs,
     stagger_load=args.stagger_load,
     is_validation=False)
 
```
```diff
@@ -927,14 +891,14 @@ def process_tokens_test(
         custom_shape=custom_shape,
     )
 
-## RUN TESTS ##
+## RUN VALIDATION AND TESTS ##
 
 failed_cases = []
 # for each program and valid prompt (batch size, sequence length)
 for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
     extra_kwargs["attn_name"] = ATTN_NAME
 
-    run_validation_tests(
+    aiu_validation_info, cpu_validation_info = run_validation(
         args=args,
         model=model,
         validation_model=validation_model,
```
```diff
@@ -947,3 +911,50 @@ def process_tokens_test(
         cpu_dtype=CPU_DTYPE,
         tokenizer=tokenizer,
     )
+
+    if args.test_type == "metrics":
+        failure_rate = run_metrics_test(
+            cross_entropy_threshold=args.cross_entropy_threshold,
+            aiu_validation_info=aiu_validation_info,
+            cpu_validation_info=cpu_validation_info,
+            program_id=program_id,
+            prompt_shape=valid_prompt,
+            tokenizer=tokenizer
+        )
+        if failure_rate > args.failure_rate_threshold:
+            failed_cases.append(
+                (program_id, valid_prompt, failure_rate)
+            )
+
+    elif args.test_type == "tokens":
+        run_tokens_test(
+            max_new_tokens=args.max_new_tokens,
+            aiu_validation_info=aiu_validation_info,
+            cpu_validation_info=cpu_validation_info,
+            program_id=program_id,
+            tokenizer=tokenizer
+        )
+
+    else:
+        raise ValueError("test type must be one of metrics or tokens")
+
+    if args.skip_validation and local_rank == 0:
+        for sentence_idx, test_sentence in enumerate(
+            aiu_validation_info.get_info("tokens")
+        ):
+            tokens_prompt = [t.item() for t in test_sentence[:-args.max_new_tokens]]
+            aiu_tokens_generated = [t.item() for t in test_sentence[-args.max_new_tokens:]]
+            dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:")
+            dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
+            dprint(f"AIU tokens:\n{aiu_tokens_generated}")
+            dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
+
+if not args.skip_validation and local_rank == 0:
+    if len(failed_cases) != 0:
+        dprint("The test failed with the following cases:")
+        for failed_case in failed_cases:
+            dprint(
+                f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
+            )
+    else:
+        dprint("all tests passed")
```
