Skip to content

Commit 9efcd6a

Browse files
committed
First refactor
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 5e84730 commit 9efcd6a

File tree

1 file changed

+100
-79
lines changed

1 file changed

+100
-79
lines changed

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 100 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,11 @@
169169
action="store_true",
170170
help="set to true to save cpu validation outputs for later consumption",
171171
)
172+
parser.add_argument(
173+
"--stop_after_info_outputs",
174+
action="store_true",
175+
help="set to true to stop after cpu validation outputs have been saved",
176+
)
172177
parser.add_argument(
173178
"--prioritize_large_batch_sizes",
174179
action="store_true",
@@ -372,57 +377,11 @@ def __load_validation_info(
372377
**distributed_kwargs,
373378
)
374379

380+
dprint("model.eval()")
375381
model.eval()
376-
fx_config.backed_size_oblivious = True
377-
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
378-
379-
__maybe_prepare_fp8_weights(model, is_fp8)
380-
381-
if not args.skip_validation:
382-
with stagger_region(args.stagger_load):
383-
validation_model = get_model(
384-
architecture="hf_pretrained",
385-
device_type="cpu",
386-
data_type=None if is_fp8 else torch.float32,
387-
fused_weights=False,
388-
**model_path_kwargs,
389-
**distributed_kwargs,
390-
)
391-
validation_model.eval()
392-
393-
# warmup with any input so compiler produces criteria json
394-
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
395-
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
396-
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
397-
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
398-
if is_fp8:
399-
prompt_list = prompt_list * 2
400-
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
401-
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
402-
403-
extra_kwargs["attn_name"] = ATTN_NAME
404-
if (
405-
"granite-3.3-8b-instruct" in model_variant
406-
and USE_DISTRIBUTED
407-
and dist.get_world_size() == 4
408-
):
409-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
410-
warmup_model(
411-
model,
412-
input_ids,
413-
max_new_tokens=max_new_tokens,
414-
compile_dynamic_sendnn=True,
415-
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
416-
prefill_chunk_size=args.prefill_chunk_size,
417-
**extra_kwargs,
418-
)
419-
420-
if USE_DISTRIBUTED:
421-
# wait for rank0 to be finished as it is the only one generating the criteria json
422-
# this is needed since otherwise we may run into a race condition
423-
torch.distributed.barrier()
424-
382+
dprint("finished eval")
425383

384+
######################################### Building valid prompts
426385
@dataclass
427386
class ProgramInfo:
428387
program_id: str
@@ -431,7 +390,6 @@ class ProgramInfo:
431390
prompt_length_limit: int
432391
prompt_length_limit_type: str
433392

434-
435393
def parse_program_limit(limit_str: str) -> tuple[int, str]:
436394
matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)")
437395

@@ -448,7 +406,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
448406
limit_val = int(match.group(2))
449407
return limit_val, limit_type
450408

451-
452409
with open(args.program_criteria_json_path, "r") as f:
453410
program_criteria_json_list = json.load(f)["programs"]
454411
program_criteria_list = []
@@ -621,39 +578,26 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
621578
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
622579
)
623580

624-
625-
# metric calculator based on the cross-entropy and mean diff for each decode step
626-
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
627-
cross_entropy = torch.nn.CrossEntropyLoss()(
628-
r, t.softmax(dim=1).to(dtype=torch.float32)
629-
)
630-
diff = torch.mean(
631-
torch.abs(
632-
r.softmax(dim=1).to(dtype=torch.float32)
633-
- t.softmax(dim=1).to(dtype=torch.float32)
634-
)
635-
)
636-
return (cross_entropy, diff)
581+
dprint("valid prompts are prepared.")
582+
################################################################
637583

638584

639-
failed_cases = []
640-
# for each program and valid prompt (batch size, sequence length)
641-
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
642-
extra_kwargs["attn_name"] = ATTN_NAME
643-
if (
644-
"granite-3.3-8b-instruct" in model_variant
645-
and USE_DISTRIBUTED
646-
and dist.get_world_size() == 4
647-
):
648-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
585+
############## saving CPU validation ###########################
649586

650-
if local_rank == 0:
651-
dprint(f"*** testing program {program_id} ***")
652-
dprint(
653-
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
587+
if not args.skip_validation:
588+
with stagger_region(args.stagger_load):
589+
validation_model = get_model(
590+
architecture="hf_pretrained",
591+
device_type="cpu",
592+
data_type=None if is_fp8 else torch.float32,
593+
fused_weights=False,
594+
**model_path_kwargs,
595+
**distributed_kwargs,
654596
)
597+
validation_model.eval()
655598

656-
if not args.skip_validation:
599+
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
600+
dprint(f"Working on program_id: {program_id}")
657601
# attempt to load the cpu validation info if it is already computed
658602
cpu_validation_info = __load_validation_info(
659603
model_variant,
@@ -667,6 +611,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
667611
)
668612
# if the cpu validation info is not yet computed, compute it
669613
if cpu_validation_info is None:
614+
dprint(f"Let's start to compute it.")
670615
cpu_validation_info = extract_validation_information(
671616
validation_model,
672617
input_ids,
@@ -675,6 +620,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
675620
attn_algorithm="math",
676621
**extra_kwargs,
677622
)
623+
dprint(f"Save that cpu info.")
678624
# save the cpu validation info for later consumption
679625
if save_validation_info_outputs:
680626
cpu_validation_info.save(
@@ -691,6 +637,81 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
691637
)
692638
)
693639

640+
# Honor --stop_after_info_outputs: the flag's help text promises to stop once the
# CPU validation outputs have been saved, so actually terminate here instead of
# only logging. SystemExit(0) is what sys.exit(0) raises and needs no import.
# NOTE(review): nesting of this guard is not recoverable from the flattened diff
# view; confirm it sits after the CPU-validation loop, not inside it.
if args.stop_after_info_outputs:
    dprint("CPU validation info saved. Exiting as requested.")
    raise SystemExit(0)
643+
644+
################################################################
645+
646+
dprint("onto model compilation related stuff ...")
647+
fx_config.backed_size_oblivious = True
648+
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
649+
__maybe_prepare_fp8_weights(model, is_fp8)
650+
651+
# warmup with any input so compiler produces criteria json
652+
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
653+
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
654+
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
655+
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
656+
if is_fp8:
657+
prompt_list = prompt_list * 2
658+
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
659+
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
660+
661+
extra_kwargs["attn_name"] = ATTN_NAME
662+
if (
663+
"granite-3.3-8b-instruct" in model_variant
664+
and USE_DISTRIBUTED
665+
and dist.get_world_size() == 4
666+
):
667+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
668+
warmup_model(
669+
model,
670+
input_ids,
671+
max_new_tokens=max_new_tokens,
672+
compile_dynamic_sendnn=True,
673+
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
674+
prefill_chunk_size=args.prefill_chunk_size,
675+
**extra_kwargs,
676+
)
677+
678+
if USE_DISTRIBUTED:
679+
# wait for rank0 to be finished as it is the only one generating the criteria json
680+
# this is needed since otherwise we may run into a race condition
681+
torch.distributed.barrier()
682+
683+
# metric calculator based on the cross-entropy and mean diff for each decode step
684+
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
# Compare tensors `r` and `t` and return the 2-tuple (cross_entropy, diff).
# NOTE(review): presumably `r` and `t` are logits for one decode step from the
# two models being compared, with classes along dim=1 - confirm against caller.
685+
# Cross-entropy of `r` (passed as-is) against softmax(t) over dim=1, cast to
# float32; a floating-point target is treated as class probabilities.
cross_entropy = torch.nn.CrossEntropyLoss()(
686+
r, t.softmax(dim=1).to(dtype=torch.float32)
687+
)
688+
# Mean absolute difference between the two softmax distributions (dim=1),
# both cast to float32 before subtracting.
diff = torch.mean(
689+
torch.abs(
690+
r.softmax(dim=1).to(dtype=torch.float32)
691+
- t.softmax(dim=1).to(dtype=torch.float32)
692+
)
693+
)
694+
return (cross_entropy, diff)
695+
696+
697+
failed_cases = []
698+
# for each program and valid prompt (batch size, sequence length)
699+
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
700+
extra_kwargs["attn_name"] = ATTN_NAME
701+
if (
702+
"granite-3.3-8b-instruct" in model_variant
703+
and USE_DISTRIBUTED
704+
and dist.get_world_size() == 4
705+
):
706+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
707+
708+
if local_rank == 0:
709+
dprint(f"*** testing program {program_id} ***")
710+
dprint(
711+
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
712+
)
713+
714+
if not args.skip_validation:
694715
if args.test_type == "metrics":
695716
aiu_validation_info = extract_validation_information(
696717
model,

0 commit comments

Comments
 (0)