     action="store_true",
     help="set to true to save cpu validation outputs for later consumption",
 )
+parser.add_argument(
+    "--stop_after_info_outputs",
+    action="store_true",
+    help="set to true to stop after cpu validation outputs have been saved",
+)
 parser.add_argument(
     "--prioritize_large_batch_sizes",
     action="store_true",
@@ -373,55 +378,6 @@ def __load_validation_info(
     )
 
 model.eval()
-fx_config.backed_size_oblivious = True
-model.compile(backend="sendnn", options={"sendnn.dynamic": True})
-
-__maybe_prepare_fp8_weights(model, is_fp8)
-
-if not args.skip_validation:
-    with stagger_region(args.stagger_load):
-        validation_model = get_model(
-            architecture="hf_pretrained",
-            device_type="cpu",
-            data_type=None if is_fp8 else torch.float32,
-            fused_weights=False,
-            **model_path_kwargs,
-            **distributed_kwargs,
-        )
-        validation_model.eval()
-
-# warmup with any input so compiler produces criteria json
-# TODO: Swap this with __prepare_inputs once fix for shape_id is available
-# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
-prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
-# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
-if is_fp8:
-    prompt_list = prompt_list * 2
-input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
-extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
-
-extra_kwargs["attn_name"] = ATTN_NAME
-if (
-    "granite-3.3-8b-instruct" in model_variant
-    and USE_DISTRIBUTED
-    and dist.get_world_size() == 4
-):
-    extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
-warmup_model(
-    model,
-    input_ids,
-    max_new_tokens=max_new_tokens,
-    compile_dynamic_sendnn=True,
-    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
-    prefill_chunk_size=args.prefill_chunk_size,
-    **extra_kwargs,
-)
-
-if USE_DISTRIBUTED:
-    # wait for rank0 to be finished as it is the only one generating the criteria json
-    # this is needed since otherwise we may run into a race condition
-    torch.distributed.barrier()
-
 
 @dataclass
 class ProgramInfo:
@@ -431,7 +387,6 @@ class ProgramInfo:
     prompt_length_limit: int
     prompt_length_limit_type: str
 
-
 def parse_program_limit(limit_str: str) -> tuple[int, str]:
     matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)")
 
@@ -448,7 +403,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
     limit_val = int(match.group(2))
     return limit_val, limit_type
 
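For reference, not part of the diff: a quick sketch of what parse_program_limit returns, assuming the function above is in scope. Group 1 of the regex is the comparison operator, group 2 the numeric bound; regex backtracking ensures "<=64" matches "<=" rather than "<".

# Illustrative only -- exercises parse_program_limit as defined above.
assert parse_program_limit("<=64") == (64, "<=")
assert parse_program_limit(">2") == (2, ">")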
-
+# TODO: Add a check for the case where the prog criteria json must exist if saving CPU outputs
 with open(args.program_criteria_json_path, "r") as f:
     program_criteria_json_list = json.load(f)["programs"]
     program_criteria_list = []
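One possible shape for the check flagged in the TODO above, illustrative only: the flag name args.save_validation_info_outputs is an assumption (its --help text appears at the top of this diff, but the flag name itself is not shown).

import os

# Hypothetical guard, not part of the diff: fail fast when CPU outputs are
# to be saved but the program criteria json does not exist yet.
if args.save_validation_info_outputs and not os.path.exists(
    args.program_criteria_json_path
):
    raise FileNotFoundError(
        f"program criteria json not found: {args.program_criteria_json_path}"
    )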
@@ -621,39 +576,20 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
         f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
     )
 
-
-# metric calculator based on the cross-entropy and mean diff for each decode step
-def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
-    cross_entropy = torch.nn.CrossEntropyLoss()(
-        r, t.softmax(dim=1).to(dtype=torch.float32)
-    )
-    diff = torch.mean(
-        torch.abs(
-            r.softmax(dim=1).to(dtype=torch.float32)
-            - t.softmax(dim=1).to(dtype=torch.float32)
-        )
-    )
-    return (cross_entropy, diff)
-
-
-failed_cases = []
-# for each program and valid prompt (batch size, sequence length)
-for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
-    extra_kwargs["attn_name"] = ATTN_NAME
-    if (
-        "granite-3.3-8b-instruct" in model_variant
-        and USE_DISTRIBUTED
-        and dist.get_world_size() == 4
-    ):
-        extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
-
-    if local_rank == 0:
-        dprint(f"*** testing program {program_id} ***")
-        dprint(
-            f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
+if not args.skip_validation:
+    with stagger_region(args.stagger_load):
+        validation_model = get_model(
+            architecture="hf_pretrained",
+            device_type="cpu",
+            data_type=None if is_fp8 else torch.float32,
+            fused_weights=False,
+            **model_path_kwargs,
+            **distributed_kwargs,
         )
+        validation_model.eval()
 
-    if not args.skip_validation:
+for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
+    dprint(f"Working on program_id: {program_id}")
     # attempt to load the cpu validation info if it is already computed
     cpu_validation_info = __load_validation_info(
         model_variant,
@@ -691,6 +627,80 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         )
     )
 
+if args.stop_after_info_outputs:
+    dprint("CPU validation outputs saved. Exiting as requested.")
+    exit(0)
+
+################################################################
+
+fx_config.backed_size_oblivious = True
+model.compile(backend="sendnn", options={"sendnn.dynamic": True})
+__maybe_prepare_fp8_weights(model, is_fp8)
+
+# warmup with any input so compiler produces criteria json
+# TODO: Swap this with __prepare_inputs once fix for shape_id is available
+# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
+prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
+# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
+if is_fp8:
+    prompt_list = prompt_list * 2
+input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
+extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
+
+extra_kwargs["attn_name"] = ATTN_NAME
+if (
+    "granite-3.3-8b-instruct" in model_variant
+    and USE_DISTRIBUTED
+    and dist.get_world_size() == 4
+):
+    extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
+warmup_model(
+    model,
+    input_ids,
+    max_new_tokens=max_new_tokens,
+    compile_dynamic_sendnn=True,
+    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
+    prefill_chunk_size=args.prefill_chunk_size,
+    **extra_kwargs,
+)
+
+if USE_DISTRIBUTED:
+    # wait for rank0 to be finished as it is the only one generating the criteria json
+    # this is needed since otherwise we may run into a race condition
+    torch.distributed.barrier()
+
+# metric calculator based on the cross-entropy and mean diff for each decode step
+def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
+    cross_entropy = torch.nn.CrossEntropyLoss()(
+        r, t.softmax(dim=1).to(dtype=torch.float32)
+    )
+    diff = torch.mean(
+        torch.abs(
+            r.softmax(dim=1).to(dtype=torch.float32)
+            - t.softmax(dim=1).to(dtype=torch.float32)
+        )
+    )
+    return (cross_entropy, diff)
+
+
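For reference, not part of the diff: a minimal sketch of what __metric_calculator computes, run on random stand-in logits. torch.nn.CrossEntropyLoss accepts a floating-point target of the same shape as the input, so t's softmax acts as a soft-target distribution.

# Illustrative only -- calls __metric_calculator as (re)added above.
import torch

torch.manual_seed(0)
r = torch.randn(8, 128)  # stand-in logits, e.g. device outputs (batch x vocab)
t = torch.randn(8, 128)  # stand-in reference logits of the same shape

ce, diff = __metric_calculator(r, t)
# ce: cross-entropy of r against t's softmax distribution
# diff: mean absolute difference between the two softmax distributions
print(float(ce), float(diff))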
+failed_cases = []
+# for each program and valid prompt (batch size, sequence length)
+for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
+    extra_kwargs["attn_name"] = ATTN_NAME
+    if (
+        "granite-3.3-8b-instruct" in model_variant
+        and USE_DISTRIBUTED
+        and dist.get_world_size() == 4
+    ):
+        extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
+
+    if local_rank == 0:
+        dprint(f"*** testing program {program_id} ***")
+        dprint(
+            f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
+        )
+
+    if not args.skip_validation:
         if args.test_type == "metrics":
             aiu_validation_info = extract_validation_information(
                 model,