Skip to content

Commit 4642142

Browse files
🚧 wip stuff
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
1 parent 340c383 commit 4642142

File tree

1 file changed

+162
-96
lines changed

1 file changed

+162
-96
lines changed

aiu_fms_testing_utils/scripts/refactored_dpp.py

Lines changed: 162 additions & 96 deletions
Original file line number · Diff line number · Diff line change
@@ -189,6 +189,11 @@ def parse_cli_args() -> argparse.Namespace:
189189
action="store_true",
190190
help="set to true to save cpu validation outputs for later consumption",
191191
)
192+
parser.add_argument(
193+
"--save_validation_info_outputs_only",
194+
action="store_true",
195+
help="set to true to only save cpu validation outputs and exit without running AIU validation",
196+
)
192197
parser.add_argument(
193198
"--prioritize_large_batch_sizes",
194199
action="store_true",
@@ -595,15 +600,11 @@ def get_program_prompt_list(
595600
def run_validation(
596601
args: argparse.Namespace,
597602
model: torch.nn.Module,
598-
validation_model: Optional[torch.nn.Module],
599603
program_id: int,
600604
valid_prompt,
601605
input_ids: torch.Tensor,
602606
extra_kwargs: Dict[str, Any],
603-
sample_key: str,
604-
attn_name: str,
605-
cpu_dtype: str,
606-
tokenizer: AutoTokenizer,
607+
cpu_validation_info: ValidationInfo
607608
):
608609

609610
if local_rank == 0:
@@ -612,46 +613,46 @@ def run_validation(
612613
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
613614
)
614615

615-
cpu_validation_info: Optional[ValidationInfo] = None
616-
if not args.skip_validation:
617-
# attempt to load the cpu validation info if it is already computed
618-
cpu_validation_info = __load_validation_info(
619-
model_variant=args.model_variant,
620-
batch_size=valid_prompt[0],
621-
seq_length=valid_prompt[1],
622-
max_new_tokens=args.max_new_tokens,
623-
tokenizer=tokenizer,
624-
seed=0,
625-
cpu_dtype=cpu_dtype,
626-
attn_type=attn_name,
627-
validation_info_outputs_dir=args.validation_info_outputs_dir,
628-
sample_key=sample_key,
629-
)
630-
# if the cpu validation info is not yet computed, compute it
631-
if cpu_validation_info is None and validation_model is not None:
632-
cpu_validation_info = extract_validation_information(
633-
model=validation_model,
634-
input_ids=input_ids,
635-
max_new_tokens=args.max_new_tokens,
636-
post_iteration_hook=LogitsExtractorHook(),
637-
attn_algorithm="math",
638-
**extra_kwargs,
639-
)
640-
# save the cpu validation info if requested
641-
if args.save_validation_info_outputs:
642-
cpu_validation_info.save(
643-
get_validation_info_path(
644-
validation_info_dir=args.validation_info_outputs_dir,
645-
model_variant=args.model_variant,
646-
batch_size=valid_prompt[0],
647-
seq_length=valid_prompt[1],
648-
max_new_tokens=args.max_new_tokens,
649-
seed=0,
650-
attn_type=attn_name,
651-
dtype=cpu_dtype,
652-
sample_key=sample_key,
653-
)
654-
)
616+
617+
# if not args.skip_validation:
618+
# # attempt to load the cpu validation info if it is already computed
619+
# cpu_validation_info = __load_validation_info(
620+
# model_variant=args.model_variant,
621+
# batch_size=valid_prompt[0],
622+
# seq_length=valid_prompt[1],
623+
# max_new_tokens=args.max_new_tokens,
624+
# tokenizer=tokenizer,
625+
# seed=0,
626+
# cpu_dtype=cpu_dtype,
627+
# attn_type=attn_name,
628+
# validation_info_outputs_dir=args.validation_info_outputs_dir,
629+
# sample_key=sample_key,
630+
# )
631+
# # if the cpu validation info is not yet computed, compute it
632+
# if cpu_validation_info is None and validation_model is not None:
633+
# cpu_validation_info = extract_validation_information(
634+
# model=validation_model,
635+
# input_ids=input_ids,
636+
# max_new_tokens=args.max_new_tokens,
637+
# post_iteration_hook=LogitsExtractorHook(),
638+
# attn_algorithm="math",
639+
# **extra_kwargs,
640+
# )
641+
# # save the cpu validation info if requested
642+
# if args.save_validation_info_outputs:
643+
# cpu_validation_info.save(
644+
# get_validation_info_path(
645+
# validation_info_dir=args.validation_info_outputs_dir,
646+
# model_variant=args.model_variant,
647+
# batch_size=valid_prompt[0],
648+
# seq_length=valid_prompt[1],
649+
# max_new_tokens=args.max_new_tokens,
650+
# seed=0,
651+
# attn_type=attn_name,
652+
# dtype=cpu_dtype,
653+
# sample_key=sample_key,
654+
# )
655+
# )
655656

656657
golden_hook = None
657658
if args.test_type == "metrics":
@@ -800,50 +801,6 @@ def run_tokens_test(
800801

801802
__maybe_prepare_fp8_weights(model, is_fp8)
802803

803-
# Load validation model
804-
validation_model = None
805-
if not args.skip_validation:
806-
validation_model = load_model(
807-
device_type="cpu",
808-
model_variant=args.model_variant,
809-
is_fp8=is_fp8,
810-
distributed_kwargs=distributed_kwargs,
811-
stagger_load=args.stagger_load,
812-
is_validation=True
813-
)
814-
815-
## MODEL WARMUP ##
816-
817-
# warmup with any input so compiler produces criteria json
818-
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
819-
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
820-
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
821-
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
822-
if is_fp8:
823-
prompt_list = prompt_list * 2
824-
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
825-
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
826-
827-
extra_kwargs["attn_name"] = ATTN_NAME
828-
if ( "granite-3.3-8b-instruct" in args.model_variant and args.distributed and dist.get_world_size() == 4):
829-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
830-
831-
warmup_model(
832-
model=model,
833-
input_ids=input_ids,
834-
max_new_tokens=args.max_new_tokens,
835-
compile_dynamic_sendnn=True,
836-
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
837-
prefill_chunk_size=args.prefill_chunk_size,
838-
**extra_kwargs,
839-
)
840-
841-
if args.distributed:
842-
# wait for rank0 to be finished as it is the only one generating the criteria json
843-
# this is needed since otherwise we may run into a race condition
844-
torch.distributed.barrier()
845-
846-
847804
## PREPARE PROGRAM CRITERIA AND PROMPTS ##
848805

849806
with open(args.program_criteria_json_path, "r") as f:
@@ -891,25 +848,134 @@ def run_tokens_test(
891848
custom_shape=custom_shape,
892849
)
893850

851+
852+
## CPU validation
853+
854+
def get_cpu_validation(
855+
args: argparse.Namespace,
856+
valid_prompt,
857+
input_ids: torch.Tensor,
858+
extra_kwargs: Dict[str, Any],
859+
sample_key: str,
860+
attn_name: str,
861+
cpu_dtype: str,
862+
tokenizer: AutoTokenizer,
863+
):
864+
if not args.skip_validation:
865+
dprint("Generating CPU validation for prompt: {}".format(valid_prompt))
866+
# Load validation model
867+
validation_model = load_model(
868+
device_type="cpu",
869+
model_variant=args.model_variant,
870+
is_fp8=is_fp8,
871+
distributed_kwargs=distributed_kwargs,
872+
stagger_load=args.stagger_load,
873+
is_validation=True
874+
)
875+
876+
# attempt to load the cpu validation info if it is already computed
877+
cpu_validation_info = __load_validation_info(
878+
model_variant=args.model_variant,
879+
batch_size=valid_prompt[0],
880+
seq_length=valid_prompt[1],
881+
max_new_tokens=args.max_new_tokens,
882+
tokenizer=tokenizer,
883+
seed=0,
884+
cpu_dtype=cpu_dtype,
885+
attn_type=attn_name,
886+
validation_info_outputs_dir=args.validation_info_outputs_dir,
887+
sample_key=sample_key,
888+
)
889+
890+
if cpu_validation_info is not None:
891+
dprint("cpu validation info found, returning it")
892+
return cpu_validation_info
893+
dprint("cpu validation info not found, computing it now")
894+
# if the cpu validation info is not yet computed, compute it
895+
if validation_model is not None:
896+
dprint("extracting cpu validation info")
897+
cpu_validation_info = extract_validation_information(
898+
model=validation_model,
899+
input_ids=input_ids,
900+
max_new_tokens=args.max_new_tokens,
901+
post_iteration_hook=LogitsExtractorHook(),
902+
attn_algorithm="math",
903+
**extra_kwargs,
904+
)
905+
dprint("cpu validation info extracted")
906+
# save the cpu validation info if requested
907+
if args.save_validation_info_outputs:
908+
dprint("saving cpu validation info")
909+
cpu_validation_info.save(
910+
get_validation_info_path(
911+
validation_info_dir=args.validation_info_outputs_dir,
912+
model_variant=args.model_variant,
913+
batch_size=valid_prompt[0],
914+
seq_length=valid_prompt[1],
915+
max_new_tokens=args.max_new_tokens,
916+
seed=0,
917+
attn_type=attn_name,
918+
dtype=cpu_dtype,
919+
sample_key=sample_key,
920+
)
921+
)
922+
dprint("cpu validation info saved")
923+
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
924+
cpu_validation_info: Optional[ValidationInfo] = None
925+
cpu_validation_info = get_cpu_validation(args, valid_prompt, input_ids, extra_kwargs, sample_key, ATTN_NAME, CPU_DTYPE, tokenizer)
926+
927+
if args.save_validation_info_outputs_only:
928+
dprint("CPU validation information saved. Exiting.")
929+
exit(0)
930+
931+
932+
## MODEL WARMUP ##
933+
934+
# warmup with any input so compiler produces criteria json
935+
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
936+
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
937+
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
938+
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
939+
if is_fp8:
940+
prompt_list = prompt_list * 2
941+
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
942+
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
943+
944+
extra_kwargs["attn_name"] = ATTN_NAME
945+
if ( "granite-3.3-8b-instruct" in args.model_variant and args.distributed and dist.get_world_size() == 4):
946+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
947+
948+
warmup_model(
949+
model=model,
950+
input_ids=input_ids,
951+
max_new_tokens=args.max_new_tokens,
952+
compile_dynamic_sendnn=True,
953+
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
954+
prefill_chunk_size=args.prefill_chunk_size,
955+
**extra_kwargs,
956+
)
957+
958+
if args.distributed:
959+
# wait for rank0 to be finished as it is the only one generating the criteria json
960+
# this is needed since otherwise we may run into a race condition
961+
torch.distributed.barrier()
962+
963+
894964
## RUN VALIDATION AND TESTS ##
895965

896966
failed_cases = []
897967
# for each program and valid prompt (batch size, sequence length)
898968
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
899969
extra_kwargs["attn_name"] = ATTN_NAME
900970

901-
aiu_validation_info, cpu_validation_info = run_validation(
971+
aiu_validation_info = run_validation(
902972
args=args,
903973
model=model,
904-
validation_model=validation_model,
905974
program_id=program_id,
906975
valid_prompt=valid_prompt,
907976
input_ids=input_ids,
908-
extra_kwargs=extra_kwargs,
909-
sample_key=sample_key,
910-
attn_name=ATTN_NAME,
911-
cpu_dtype=CPU_DTYPE,
912-
tokenizer=tokenizer,
977+
extra_kwargs=extra_kwargs,
978+
cpu_validation_info=cpu_validation_info
913979
)
914980

915981
if args.test_type == "metrics":

0 commit comments

Comments
 (0)