Skip to content

Commit 8e99ab7

Browse files
committed
First refactor
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 5e84730 commit 8e99ab7

File tree

1 file changed

+92
-82
lines changed

1 file changed

+92
-82
lines changed

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 92 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@
169169
action="store_true",
170170
help="set to true to save cpu validation outputs for later consumption",
171171
)
172+
parser.add_argument(
173+
"--stop_after_info_outputs",
174+
action="store_true",
175+
help="set to true to stop after cpu validation outputs have been saved",
176+
)
172177
parser.add_argument(
173178
"--prioritize_large_batch_sizes",
174179
action="store_true",
@@ -373,55 +378,6 @@ def __load_validation_info(
373378
)
374379

375380
model.eval()
376-
fx_config.backed_size_oblivious = True
377-
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
378-
379-
__maybe_prepare_fp8_weights(model, is_fp8)
380-
381-
if not args.skip_validation:
382-
with stagger_region(args.stagger_load):
383-
validation_model = get_model(
384-
architecture="hf_pretrained",
385-
device_type="cpu",
386-
data_type=None if is_fp8 else torch.float32,
387-
fused_weights=False,
388-
**model_path_kwargs,
389-
**distributed_kwargs,
390-
)
391-
validation_model.eval()
392-
393-
# warmup with any input so compiler produces criteria json
394-
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
395-
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
396-
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
397-
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
398-
if is_fp8:
399-
prompt_list = prompt_list * 2
400-
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
401-
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
402-
403-
extra_kwargs["attn_name"] = ATTN_NAME
404-
if (
405-
"granite-3.3-8b-instruct" in model_variant
406-
and USE_DISTRIBUTED
407-
and dist.get_world_size() == 4
408-
):
409-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
410-
warmup_model(
411-
model,
412-
input_ids,
413-
max_new_tokens=max_new_tokens,
414-
compile_dynamic_sendnn=True,
415-
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
416-
prefill_chunk_size=args.prefill_chunk_size,
417-
**extra_kwargs,
418-
)
419-
420-
if USE_DISTRIBUTED:
421-
# wait for rank0 to be finished as it is the only one generating the criteria json
422-
# this is needed since otherwise we may run into a race condition
423-
torch.distributed.barrier()
424-
425381

426382
@dataclass
427383
class ProgramInfo:
@@ -431,7 +387,6 @@ class ProgramInfo:
431387
prompt_length_limit: int
432388
prompt_length_limit_type: str
433389

434-
435390
def parse_program_limit(limit_str: str) -> tuple[int, str]:
436391
matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)")
437392

@@ -448,7 +403,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
448403
limit_val = int(match.group(2))
449404
return limit_val, limit_type
450405

451-
406+
# TODO: Add a check (or fallback logic) for the case where CPU outputs are being saved: the program criteria JSON must already exist at this point
452407
with open(args.program_criteria_json_path, "r") as f:
453408
program_criteria_json_list = json.load(f)["programs"]
454409
program_criteria_list = []
@@ -621,39 +576,20 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
621576
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
622577
)
623578

624-
625-
# metric calculator based on the cross-entropy and mean diff for each decode step
626-
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
627-
cross_entropy = torch.nn.CrossEntropyLoss()(
628-
r, t.softmax(dim=1).to(dtype=torch.float32)
629-
)
630-
diff = torch.mean(
631-
torch.abs(
632-
r.softmax(dim=1).to(dtype=torch.float32)
633-
- t.softmax(dim=1).to(dtype=torch.float32)
634-
)
635-
)
636-
return (cross_entropy, diff)
637-
638-
639-
failed_cases = []
640-
# for each program and valid prompt (batch size, sequence length)
641-
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
642-
extra_kwargs["attn_name"] = ATTN_NAME
643-
if (
644-
"granite-3.3-8b-instruct" in model_variant
645-
and USE_DISTRIBUTED
646-
and dist.get_world_size() == 4
647-
):
648-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
649-
650-
if local_rank == 0:
651-
dprint(f"*** testing program {program_id} ***")
652-
dprint(
653-
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
579+
if not args.skip_validation:
580+
with stagger_region(args.stagger_load):
581+
validation_model = get_model(
582+
architecture="hf_pretrained",
583+
device_type="cpu",
584+
data_type=None if is_fp8 else torch.float32,
585+
fused_weights=False,
586+
**model_path_kwargs,
587+
**distributed_kwargs,
654588
)
589+
validation_model.eval()
655590

656-
if not args.skip_validation:
591+
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
592+
dprint(f"Working on program_id: {program_id}")
657593
# attempt to load the cpu validation info if it is already computed
658594
cpu_validation_info = __load_validation_info(
659595
model_variant,
@@ -691,6 +627,80 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
691627
)
692628
)
693629

630+
if args.stop_after_info_outputs:
631+
dprint("CPU validation outputs saved. Exiting as requested.")
632+
exit(0)
633+
634+
################################################################
635+
636+
fx_config.backed_size_oblivious = True
637+
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
638+
__maybe_prepare_fp8_weights(model, is_fp8)
639+
640+
# warmup with any input so compiler produces criteria json
641+
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
642+
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
643+
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
644+
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
645+
if is_fp8:
646+
prompt_list = prompt_list * 2
647+
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
648+
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
649+
650+
extra_kwargs["attn_name"] = ATTN_NAME
651+
if (
652+
"granite-3.3-8b-instruct" in model_variant
653+
and USE_DISTRIBUTED
654+
and dist.get_world_size() == 4
655+
):
656+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
657+
warmup_model(
658+
model,
659+
input_ids,
660+
max_new_tokens=max_new_tokens,
661+
compile_dynamic_sendnn=True,
662+
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
663+
prefill_chunk_size=args.prefill_chunk_size,
664+
**extra_kwargs,
665+
)
666+
667+
if USE_DISTRIBUTED:
668+
# wait for rank0 to be finished as it is the only one generating the criteria json
669+
# this is needed since otherwise we may run into a race condition
670+
torch.distributed.barrier()
671+
672+
# metric calculator based on the cross-entropy and mean diff for each decode step
673+
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
674+
cross_entropy = torch.nn.CrossEntropyLoss()(
675+
r, t.softmax(dim=1).to(dtype=torch.float32)
676+
)
677+
diff = torch.mean(
678+
torch.abs(
679+
r.softmax(dim=1).to(dtype=torch.float32)
680+
- t.softmax(dim=1).to(dtype=torch.float32)
681+
)
682+
)
683+
return (cross_entropy, diff)
684+
685+
686+
failed_cases = []
687+
# for each program and valid prompt (batch size, sequence length)
688+
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
689+
extra_kwargs["attn_name"] = ATTN_NAME
690+
if (
691+
"granite-3.3-8b-instruct" in model_variant
692+
and USE_DISTRIBUTED
693+
and dist.get_world_size() == 4
694+
):
695+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
696+
697+
if local_rank == 0:
698+
dprint(f"*** testing program {program_id} ***")
699+
dprint(
700+
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
701+
)
702+
703+
if not args.skip_validation:
694704
if args.test_type == "metrics":
695705
aiu_validation_info = extract_validation_information(
696706
model,

0 commit comments

Comments
 (0)