169169 action = "store_true" ,
170170 help = "set to true to save cpu validation outputs for later consumption" ,
171171)
172+ parser .add_argument (
173+ "--stop_after_info_outputs" ,
174+ action = "store_true" ,
175+ help = "set to true to stop after cpu validation outputs have been saved" ,
176+ )
172177parser .add_argument (
173178 "--prioritize_large_batch_sizes" ,
174179 action = "store_true" ,
@@ -372,57 +377,11 @@ def __load_validation_info(
372377 ** distributed_kwargs ,
373378 )
374379
380+ dprint ("model.eval()" )
375381model .eval ()
376- fx_config .backed_size_oblivious = True
377- model .compile (backend = "sendnn" , options = {"sendnn.dynamic" : True })
378-
379- __maybe_prepare_fp8_weights (model , is_fp8 )
380-
381- if not args .skip_validation :
382- with stagger_region (args .stagger_load ):
383- validation_model = get_model (
384- architecture = "hf_pretrained" ,
385- device_type = "cpu" ,
386- data_type = None if is_fp8 else torch .float32 ,
387- fused_weights = False ,
388- ** model_path_kwargs ,
389- ** distributed_kwargs ,
390- )
391- validation_model .eval ()
392-
393- # warmup with any input so compiler produces criteria json
394- # TODO: Swap this with __prepare_inputs once fix for shape_id is available
395- # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
396- prompt_list = [torch .arange (0 , 64 , dtype = torch .int64 )]
397- # matching vllm warmup to pad to 2 on fp8, and no pad for fp16
398- if is_fp8 :
399- prompt_list = prompt_list * 2
400- input_ids , extra_kwargs = pad_input_ids (prompt_list , min_pad_length = 64 )
401- extra_kwargs ["mask" ] = extra_kwargs ["mask" ].to (torch .float16 )
402-
403- extra_kwargs ["attn_name" ] = ATTN_NAME
404- if (
405- "granite-3.3-8b-instruct" in model_variant
406- and USE_DISTRIBUTED
407- and dist .get_world_size () == 4
408- ):
409- extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
410- warmup_model (
411- model ,
412- input_ids ,
413- max_new_tokens = max_new_tokens ,
414- compile_dynamic_sendnn = True ,
415- stagger_update_lazyhandle = args .stagger_update_lazyhandle ,
416- prefill_chunk_size = args .prefill_chunk_size ,
417- ** extra_kwargs ,
418- )
419-
420- if USE_DISTRIBUTED :
421- # wait for rank0 to be finished as it is the only one generating the criteria json
422- # this is needed since otherwise we may run into a race condition
423- torch .distributed .barrier ()
424-
382+ dprint ("finished eval" )
425383
384+ ######################################### Building valid prompts
426385@dataclass
427386class ProgramInfo :
428387 program_id : str
@@ -431,7 +390,6 @@ class ProgramInfo:
431390 prompt_length_limit : int
432391 prompt_length_limit_type : str
433392
434-
435393def parse_program_limit (limit_str : str ) -> tuple [int , str ]:
436394 matcher = re .compile (r"^(<|>|<=|>=|==)(\d+)" )
437395
@@ -448,7 +406,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
448406 limit_val = int (match .group (2 ))
449407 return limit_val , limit_type
450408
451-
452409with open (args .program_criteria_json_path , "r" ) as f :
453410 program_criteria_json_list = json .load (f )["programs" ]
454411 program_criteria_list = []
@@ -621,39 +578,26 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
621578 f"no valid prompt shape was found which would result in program { program_id } that satisfied batch{ batch_size_limit_type } { batch_size_limit } and prompt_length{ prompt_length_limit_type } { prompt_length_limit } "
622579 )
623580
624-
625- # metric calculator based on the cross-entropy and mean diff for each decode step
626- def __metric_calculator (r : torch .Tensor , t : torch .Tensor ):
627- cross_entropy = torch .nn .CrossEntropyLoss ()(
628- r , t .softmax (dim = 1 ).to (dtype = torch .float32 )
629- )
630- diff = torch .mean (
631- torch .abs (
632- r .softmax (dim = 1 ).to (dtype = torch .float32 )
633- - t .softmax (dim = 1 ).to (dtype = torch .float32 )
634- )
635- )
636- return (cross_entropy , diff )
581+ dprint ("valid prompts are prepared." )
582+ ################################################################
637583
638584
639- failed_cases = []
640- # for each program and valid prompt (batch size, sequence length)
641- for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
642- extra_kwargs ["attn_name" ] = ATTN_NAME
643- if (
644- "granite-3.3-8b-instruct" in model_variant
645- and USE_DISTRIBUTED
646- and dist .get_world_size () == 4
647- ):
648- extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
585+ ############## saving CPU validation ###########################
649586
650- if local_rank == 0 :
651- dprint (f"*** testing program { program_id } ***" )
652- dprint (
653- f"program id: { program_id } , valid prompt: { valid_prompt } , input shape: { input_ids .shape } "
587+ if not args .skip_validation :
588+ with stagger_region (args .stagger_load ):
589+ validation_model = get_model (
590+ architecture = "hf_pretrained" ,
591+ device_type = "cpu" ,
592+ data_type = None if is_fp8 else torch .float32 ,
593+ fused_weights = False ,
594+ ** model_path_kwargs ,
595+ ** distributed_kwargs ,
654596 )
597+ validation_model .eval ()
655598
656- if not args .skip_validation :
599+ for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
600+ dprint (f"Working on program_id: { program_id } " )
657601 # attempt to load the cpu validation info if it is already computed
658602 cpu_validation_info = __load_validation_info (
659603 model_variant ,
@@ -667,6 +611,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
667611 )
668612 # if the cpu validation info is not yet computed, compute it
669613 if cpu_validation_info is None :
614+ dprint (f"Let's start to compute it." )
670615 cpu_validation_info = extract_validation_information (
671616 validation_model ,
672617 input_ids ,
@@ -675,6 +620,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
675620 attn_algorithm = "math" ,
676621 ** extra_kwargs ,
677622 )
623+ dprint (f"Save that cpu info." )
678624 # save the cpu validation info for later consumption
679625 if save_validation_info_outputs :
680626 cpu_validation_info .save (
@@ -691,6 +637,81 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
691637 )
692638 )
693639
640+ if args .stop_after_info_outputs :
641+ dprint ("CPU validation info saved. Exiting as requested." )
642+ # sys.exit(0)
643+
644+ ################################################################
645+
646+ dprint ("onto model compilation related stuff ..." )
647+ fx_config .backed_size_oblivious = True
648+ model .compile (backend = "sendnn" , options = {"sendnn.dynamic" : True })
649+ __maybe_prepare_fp8_weights (model , is_fp8 )
650+
651+ # warmup with any input so compiler produces criteria json
652+ # TODO: Swap this with __prepare_inputs once fix for shape_id is available
653+ # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
654+ prompt_list = [torch .arange (0 , 64 , dtype = torch .int64 )]
655+ # matching vllm warmup to pad to 2 on fp8, and no pad for fp16
656+ if is_fp8 :
657+ prompt_list = prompt_list * 2
658+ input_ids , extra_kwargs = pad_input_ids (prompt_list , min_pad_length = 64 )
659+ extra_kwargs ["mask" ] = extra_kwargs ["mask" ].to (torch .float16 )
660+
661+ extra_kwargs ["attn_name" ] = ATTN_NAME
662+ if (
663+ "granite-3.3-8b-instruct" in model_variant
664+ and USE_DISTRIBUTED
665+ and dist .get_world_size () == 4
666+ ):
667+ extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
668+ warmup_model (
669+ model ,
670+ input_ids ,
671+ max_new_tokens = max_new_tokens ,
672+ compile_dynamic_sendnn = True ,
673+ stagger_update_lazyhandle = args .stagger_update_lazyhandle ,
674+ prefill_chunk_size = args .prefill_chunk_size ,
675+ ** extra_kwargs ,
676+ )
677+
678+ if USE_DISTRIBUTED :
679+ # wait for rank0 to be finished as it is the only one generating the criteria json
680+ # this is needed since otherwise we may run into a race condition
681+ torch .distributed .barrier ()
682+
683+ # metric calculator based on the cross-entropy and mean diff for each decode step
684+ def __metric_calculator (r : torch .Tensor , t : torch .Tensor ):
685+ cross_entropy = torch .nn .CrossEntropyLoss ()(
686+ r , t .softmax (dim = 1 ).to (dtype = torch .float32 )
687+ )
688+ diff = torch .mean (
689+ torch .abs (
690+ r .softmax (dim = 1 ).to (dtype = torch .float32 )
691+ - t .softmax (dim = 1 ).to (dtype = torch .float32 )
692+ )
693+ )
694+ return (cross_entropy , diff )
695+
696+
697+ failed_cases = []
698+ # for each program and valid prompt (batch size, sequence length)
699+ for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
700+ extra_kwargs ["attn_name" ] = ATTN_NAME
701+ if (
702+ "granite-3.3-8b-instruct" in model_variant
703+ and USE_DISTRIBUTED
704+ and dist .get_world_size () == 4
705+ ):
706+ extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
707+
708+ if local_rank == 0 :
709+ dprint (f"*** testing program { program_id } ***" )
710+ dprint (
711+ f"program id: { program_id } , valid prompt: { valid_prompt } , input shape: { input_ids .shape } "
712+ )
713+
714+ if not args .skip_validation :
694715 if args .test_type == "metrics" :
695716 aiu_validation_info = extract_validation_information (
696717 model ,
0 commit comments