Skip to content

Commit 0f47515

Browse files
committed
Fix a few vars and args
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 46fefc1 commit 0f47515

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

aiu_fms_testing_utils/scripts/refactored_dpp.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def get_model_path_kwargs(model_variant: str) -> Dict[str, Any]:
328328

329329
return model_path_kwargs
330330

331-
def get_distributed_kwargs(is_distributed: bool, dist_timeout: str) -> Dict[str, Any]:
331+
def get_distributed_kwargs(is_distributed: bool, dist_timeout: str, save_validation_info_outputs: bool) -> Dict[str, Any]:
332332

333333
distributed_kwargs = {}
334334
if is_distributed:
@@ -403,7 +403,6 @@ def load_model(
403403
is_validation: bool = False,
404404
):
405405

406-
device_type = "cpu"
407406
dtype = None if is_fp8 else (torch.float32 if is_validation else torch.float16)
408407

409408
model_path_kwargs = get_model_path_kwargs(model_variant)
@@ -631,10 +630,10 @@ def run_validation_tests(
631630
# if the cpu validation info is not yet computed, compute it
632631
if cpu_validation_info is None and validation_model is not None:
633632
cpu_validation_info = extract_validation_information(
634-
validation_model,
635-
input_ids,
636-
args.max_new_tokens,
637-
LogitsExtractorHook(),
633+
model=validation_model,
634+
input_ids=input_ids,
635+
max_new_tokens=args.max_new_tokens,
636+
post_iteration_hook=LogitsExtractorHook(),
638637
attn_algorithm="math",
639638
**extra_kwargs,
640639
)
@@ -660,10 +659,10 @@ def run_validation_tests(
660659
golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))
661660

662661
aiu_validation_info = extract_validation_information(
663-
model,
664-
input_ids,
665-
args.max_new_tokens,
666-
golden_hook,
662+
model=model,
663+
input_ids=input_ids,
664+
max_new_tokens=args.max_new_tokens,
665+
post_iteration_hook=golden_hook,
667666
last_n_tokens=64,
668667
timing=args.timing,
669668
prefill_chunk_size=args.prefill_chunk_size,
@@ -825,7 +824,7 @@ def process_tokens_test(
825824
## MODEL LOADING ##
826825

827826
# Get distributed kwargs (empty if not distributed)
828-
distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout)
827+
distributed_kwargs = get_distributed_kwargs(args.distributed, args.dist_timeout, args.save_validation_info_outputs)
829828

830829
model = load_model(
831830
device_type="cpu",
@@ -841,7 +840,12 @@ def process_tokens_test(
841840
validation_model = None
842841
if not args.skip_validation:
843842
validation_model = load_model(
844-
args.model_variant, is_fp8, distributed_kwargs, args.stagger_load, is_validation=True
843+
device_type="cpu",
844+
model_variant=args.model_variant,
845+
is_fp8=is_fp8,
846+
distributed_kwargs=distributed_kwargs,
847+
stagger_load=args.stagger_load,
848+
is_validation=True
845849
)
846850

847851
## MODEL WARMUP ##
@@ -912,14 +916,15 @@ def process_tokens_test(
912916

913917
# Select concrete prompts and program associations
914918
valid_prompts = get_program_prompt_list(
915-
args,
916-
programs_to_test,
917-
program_criteria_list,
918-
program_map,
919-
tokenizer,
920-
sampler,
921-
allow_truncation,
922-
custom_shape
919+
program_map=program_map,
920+
dataset_path=args.dataset_path,
921+
enforce_homogeneous_prompt_programs=args.enforce_homogeneous_prompt_programs,
922+
programs_to_test=programs_to_test,
923+
program_criteria_list=program_criteria_list,
924+
tokenizer=tokenizer,
925+
sampler=sampler,
926+
allow_truncation=allow_truncation,
927+
custom_shape=custom_shape,
923928
)
924929

925930
## RUN TESTS ##

0 commit comments

Comments
 (0)