@@ -189,6 +189,11 @@ def parse_cli_args() -> argparse.Namespace:
189189 action = "store_true" ,
190190 help = "set to true to save cpu validation outputs for later consumption" ,
191191 )
192+ parser .add_argument (
193+ "--save_validation_info_outputs_only" ,
194+ action = "store_true" ,
195+ help = "set to true to save cpu validation outputs for later consumption" ,
196+ )
192197 parser .add_argument (
193198 "--prioritize_large_batch_sizes" ,
194199 action = "store_true" ,
@@ -595,15 +600,11 @@ def get_program_prompt_list(
595600def run_validation (
596601 args : argparse .Namespace ,
597602 model : torch .nn .Module ,
598- validation_model : Optional [torch .nn .Module ],
599603 program_id : int ,
600604 valid_prompt ,
601605 input_ids : torch .Tensor ,
602606 extra_kwargs : Dict [str , Any ],
603- sample_key : str ,
604- attn_name : str ,
605- cpu_dtype : str ,
606- tokenizer : AutoTokenizer ,
607+ cpu_validation_info : ValidationInfo
607608):
608609
609610 if local_rank == 0 :
@@ -612,46 +613,46 @@ def run_validation(
612613 f"program id: { program_id } , valid prompt: { valid_prompt } , input shape: { input_ids .shape } "
613614 )
614615
615- cpu_validation_info : Optional [ ValidationInfo ] = None
616- if not args .skip_validation :
617- # attempt to load the cpu validation info if it is already computed
618- cpu_validation_info = __load_validation_info (
619- model_variant = args .model_variant ,
620- batch_size = valid_prompt [0 ],
621- seq_length = valid_prompt [1 ],
622- max_new_tokens = args .max_new_tokens ,
623- tokenizer = tokenizer ,
624- seed = 0 ,
625- cpu_dtype = cpu_dtype ,
626- attn_type = attn_name ,
627- validation_info_outputs_dir = args .validation_info_outputs_dir ,
628- sample_key = sample_key ,
629- )
630- # if the cpu validation info is not yet computed, compute it
631- if cpu_validation_info is None and validation_model is not None :
632- cpu_validation_info = extract_validation_information (
633- model = validation_model ,
634- input_ids = input_ids ,
635- max_new_tokens = args .max_new_tokens ,
636- post_iteration_hook = LogitsExtractorHook (),
637- attn_algorithm = "math" ,
638- ** extra_kwargs ,
639- )
640- # save the cpu validation info if requested
641- if args .save_validation_info_outputs :
642- cpu_validation_info .save (
643- get_validation_info_path (
644- validation_info_dir = args .validation_info_outputs_dir ,
645- model_variant = args .model_variant ,
646- batch_size = valid_prompt [0 ],
647- seq_length = valid_prompt [1 ],
648- max_new_tokens = args .max_new_tokens ,
649- seed = 0 ,
650- attn_type = attn_name ,
651- dtype = cpu_dtype ,
652- sample_key = sample_key ,
653- )
654- )
616+
617+ # if not args.skip_validation:
618+ # # attempt to load the cpu validation info if it is already computed
619+ # cpu_validation_info = __load_validation_info(
620+ # model_variant=args.model_variant,
621+ # batch_size=valid_prompt[0],
622+ # seq_length=valid_prompt[1],
623+ # max_new_tokens=args.max_new_tokens,
624+ # tokenizer=tokenizer,
625+ # seed=0,
626+ # cpu_dtype=cpu_dtype,
627+ # attn_type=attn_name,
628+ # validation_info_outputs_dir=args.validation_info_outputs_dir,
629+ # sample_key=sample_key,
630+ # )
631+ # # if the cpu validation info is not yet computed, compute it
632+ # if cpu_validation_info is None and validation_model is not None:
633+ # cpu_validation_info = extract_validation_information(
634+ # model=validation_model,
635+ # input_ids=input_ids,
636+ # max_new_tokens=args.max_new_tokens,
637+ # post_iteration_hook=LogitsExtractorHook(),
638+ # attn_algorithm="math",
639+ # **extra_kwargs,
640+ # )
641+ # # save the cpu validation info if requested
642+ # if args.save_validation_info_outputs:
643+ # cpu_validation_info.save(
644+ # get_validation_info_path(
645+ # validation_info_dir=args.validation_info_outputs_dir,
646+ # model_variant=args.model_variant,
647+ # batch_size=valid_prompt[0],
648+ # seq_length=valid_prompt[1],
649+ # max_new_tokens=args.max_new_tokens,
650+ # seed=0,
651+ # attn_type=attn_name,
652+ # dtype=cpu_dtype,
653+ # sample_key=sample_key,
654+ # )
655+ # )
655656
656657 golden_hook = None
657658 if args .test_type == "metrics" :
@@ -800,50 +801,6 @@ def run_tokens_test(
800801
801802__maybe_prepare_fp8_weights (model , is_fp8 )
802803
803- # Load validation model
804- validation_model = None
805- if not args .skip_validation :
806- validation_model = load_model (
807- device_type = "cpu" ,
808- model_variant = args .model_variant ,
809- is_fp8 = is_fp8 ,
810- distributed_kwargs = distributed_kwargs ,
811- stagger_load = args .stagger_load ,
812- is_validation = True
813- )
814-
815- ## MODEL WARMUP ##
816-
817- # warmup with any input so compiler produces criteria json
818- # TODO: Swap this with __prepare_inputs once fix for shape_id is available
819- # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
820- prompt_list = [torch .arange (0 , 64 , dtype = torch .int64 )]
821- # matching vllm warmup to pad to 2 on fp8, and no pad for fp16
822- if is_fp8 :
823- prompt_list = prompt_list * 2
824- input_ids , extra_kwargs = pad_input_ids (prompt_list , min_pad_length = 64 )
825- extra_kwargs ["mask" ] = extra_kwargs ["mask" ].to (torch .float16 )
826-
827- extra_kwargs ["attn_name" ] = ATTN_NAME
828- if ( "granite-3.3-8b-instruct" in args .model_variant and args .distributed and dist .get_world_size () == 4 ):
829- extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
830-
831- warmup_model (
832- model = model ,
833- input_ids = input_ids ,
834- max_new_tokens = args .max_new_tokens ,
835- compile_dynamic_sendnn = True ,
836- stagger_update_lazyhandle = args .stagger_update_lazyhandle ,
837- prefill_chunk_size = args .prefill_chunk_size ,
838- ** extra_kwargs ,
839- )
840-
841- if args .distributed :
842- # wait for rank0 to be finished as it is the only one generating the criteria json
843- # this is needed since otherwise we may run into a race condition
844- torch .distributed .barrier ()
845-
846-
847804## PREPARE PROGRAM CRITERIA AND PROMPTS ##
848805
849806with open (args .program_criteria_json_path , "r" ) as f :
@@ -891,25 +848,134 @@ def run_tokens_test(
891848 custom_shape = custom_shape ,
892849)
893850
851+
852+ ## CPU validation
853+
def get_cpu_validation(
    args: argparse.Namespace,
    valid_prompt,
    input_ids: torch.Tensor,
    extra_kwargs: Dict[str, Any],
    sample_key: str,
    attn_name: str,
    cpu_dtype: str,
    tokenizer: AutoTokenizer,
) -> "Optional[ValidationInfo]":
    """Load or compute the CPU (reference) validation info for one prompt shape.

    First tries to load a previously saved ValidationInfo from
    ``args.validation_info_outputs_dir``; only on a cache miss does it load the
    CPU validation model and run reference generation. Returns ``None`` when
    ``args.skip_validation`` is set or no validation model could be loaded.

    Args:
        args: parsed CLI arguments.
        valid_prompt: (batch_size, seq_length) tuple for this prompt shape.
        input_ids: padded input token ids for the prompt.
        extra_kwargs: extra model kwargs (mask, attn_name, ...).
        sample_key: key identifying the sampled prompt set.
        attn_name: attention implementation name used in the cache path.
        cpu_dtype: dtype string used for the CPU reference run / cache path.
        tokenizer: tokenizer used to reconstruct cached validation info.
    """
    if args.skip_validation:
        return None

    dprint("Generating CPU validation for prompt: {}".format(valid_prompt))

    # attempt to load the cpu validation info if it is already computed
    cpu_validation_info = __load_validation_info(
        model_variant=args.model_variant,
        batch_size=valid_prompt[0],
        seq_length=valid_prompt[1],
        max_new_tokens=args.max_new_tokens,
        tokenizer=tokenizer,
        seed=0,
        cpu_dtype=cpu_dtype,
        attn_type=attn_name,
        validation_info_outputs_dir=args.validation_info_outputs_dir,
        sample_key=sample_key,
    )
    if cpu_validation_info is not None:
        dprint("cpu validation info found, returning it")
        return cpu_validation_info

    dprint("cpu validation info not found, computing it now")
    # Load the CPU validation model only on a cache miss: the original loaded
    # it unconditionally, paying a full model load even when the cached
    # validation info was about to be returned.
    validation_model = load_model(
        device_type="cpu",
        model_variant=args.model_variant,
        is_fp8=is_fp8,
        distributed_kwargs=distributed_kwargs,
        stagger_load=args.stagger_load,
        is_validation=True,
    )
    if validation_model is None:
        return None

    dprint("extracting cpu validation info")
    cpu_validation_info = extract_validation_information(
        model=validation_model,
        input_ids=input_ids,
        max_new_tokens=args.max_new_tokens,
        post_iteration_hook=LogitsExtractorHook(),
        attn_algorithm="math",
        **extra_kwargs,
    )
    dprint("cpu validation info extracted")

    # save the cpu validation info if requested.
    # --save_validation_info_outputs_only exits right after this step, so it
    # must also trigger a save — otherwise that flag would be a no-op unless
    # --save_validation_info_outputs was passed as well.
    if args.save_validation_info_outputs or args.save_validation_info_outputs_only:
        dprint("saving cpu validation info")
        cpu_validation_info.save(
            get_validation_info_path(
                validation_info_dir=args.validation_info_outputs_dir,
                model_variant=args.model_variant,
                batch_size=valid_prompt[0],
                seq_length=valid_prompt[1],
                max_new_tokens=args.max_new_tokens,
                seed=0,
                attn_type=attn_name,
                dtype=cpu_dtype,
                sample_key=sample_key,
            )
        )
        dprint("cpu validation info saved")

    # BUG FIX: the original fell off the end here and implicitly returned
    # None, so a freshly computed ValidationInfo was silently discarded and
    # the caller passed None into run_validation.
    return cpu_validation_info
# Precompute (and optionally persist) the CPU validation info for every prompt
# shape before warming up the AIU model.
# NOTE(review): only the info for the *last* prompt survives this loop in
# cpu_validation_info, and the run loop below reuses that single variable for
# every program — confirm this is intended when valid_prompts holds more than
# one shape; otherwise the results should be keyed per prompt.
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
    # the redundant "= None" pre-initialization was removed; the annotation
    # stays on the single assignment.
    cpu_validation_info: Optional[ValidationInfo] = get_cpu_validation(
        args, valid_prompt, input_ids, extra_kwargs, sample_key, ATTN_NAME, CPU_DTYPE, tokenizer
    )

if args.save_validation_info_outputs_only:
    dprint("CPU validation information saved. Exiting.")
    exit(0)
930+
931+
## MODEL WARMUP ##

# Warm up with an arbitrary input so the compiler produces the criteria json.
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
if is_fp8:
    # match the vllm warmup: fp8 pads to a batch of 2, fp16 stays unpadded
    prompt_list = prompt_list * 2
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)

extra_kwargs["attn_name"] = ATTN_NAME
if (
    "granite-3.3-8b-instruct" in args.model_variant
    and args.distributed
    and dist.get_world_size() == 4
):
    extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT

warmup_model(
    model=model,
    input_ids=input_ids,
    max_new_tokens=args.max_new_tokens,
    compile_dynamic_sendnn=True,
    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
    prefill_chunk_size=args.prefill_chunk_size,
    **extra_kwargs,
)

if args.distributed:
    # rank0 alone generates the criteria json; all ranks wait on it here to
    # avoid a race condition
    torch.distributed.barrier()
963+
894964## RUN VALIDATION AND TESTS ##
895965
896966failed_cases = []
897967# for each program and valid prompt (batch size, sequence length)
898968for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
899969 extra_kwargs ["attn_name" ] = ATTN_NAME
900970
901- aiu_validation_info , cpu_validation_info = run_validation (
971+ aiu_validation_info = run_validation (
902972 args = args ,
903973 model = model ,
904- validation_model = validation_model ,
905974 program_id = program_id ,
906975 valid_prompt = valid_prompt ,
907976 input_ids = input_ids ,
908- extra_kwargs = extra_kwargs ,
909- sample_key = sample_key ,
910- attn_name = ATTN_NAME ,
911- cpu_dtype = CPU_DTYPE ,
912- tokenizer = tokenizer ,
977+ extra_kwargs = extra_kwargs ,
978+ cpu_validation_info = cpu_validation_info
913979 )
914980
915981 if args .test_type == "metrics" :
0 commit comments