From c15610145eed98e81d9ad87a41c88b6f5995f34c Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 14 Nov 2025 13:55:55 -0700 Subject: [PATCH 1/5] refactor get valid prompts Signed-off-by: Sukriti-Sharma4 --- .../scripts/drive_paged_programs.py | 187 ++++++++---------- 1 file changed, 87 insertions(+), 100 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index c10652b..cee85f0 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -516,110 +516,97 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # select prompts that fit the batch size criteria valid_prompts = [] -if custom_shape: - for program_criteria_seq, valid_prompt_shapes in program_map.items(): - for valid_prompt_shape in valid_prompt_shapes: - if valid_prompt_shape == custom_shape: - enforce_sizes = [valid_prompt_shape[1]] - input_ids, extra_kwargs, sample_key = __prepare_inputs( - valid_prompt_shape[0], - valid_prompt_shape[1], - tokenizer, - enforce_sizes=enforce_sizes, - ) - valid_prompts = [ - ( - program_criteria_seq[0].program_id, - custom_shape, - input_ids, - extra_kwargs, - sample_key, - ) - ] - break - if len(valid_prompts) > 0: - break -else: - for program_info in programs: - program_id = program_info.program_id - batch_size_limit = program_info.batch_size_limit - batch_size_limit_type = program_info.batch_size_limit_type - prompt_length_limit = program_info.prompt_length_limit - prompt_length_limit_type = program_info.prompt_length_limit_type - - filtered_program_map = program_map - if program_id.isnumeric(): - filtered_program_map = { - k: v - for k, v in program_map.items() - if k[0] == program_criteria_list[int(program_id)] - } - used_keys = set() - # for each program, we need to check if we have a shape that satisfies the --programs request - for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): - # if ? or numeric => we need to check if we have found at least one valid key to stop - if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: - break - # if * => we need to see if we have found the first key to see if we should skip - elif program_id == "*" and program_seq_key[0] in used_keys: - continue +def get_program_prompt_list(): + if custom_shape: + for program_criteria_seq, valid_prompt_shapes in program_map.items(): for valid_prompt_shape in valid_prompt_shapes: - # make sure the criteria for batch limit and prompt limit is satisfied - # eval is safe here because we have limited what type and limit can be before - - batch_check = eval( - f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" - ) - prompt_check = eval( - f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" - ) - if batch_check and prompt_check: - # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length - # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning - # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user + if valid_prompt_shape == custom_shape: enforce_sizes = [valid_prompt_shape[1]] - if args.enforce_homogeneous_prompt_programs: - # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length - tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) - possible_seq_lengths = [ - _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) - ] - # favor sequences that are close to the valid prompt length - possible_seq_lengths.reverse() - enforce_sizes = enforce_sizes + list( - itertools.islice( - itertools.cycle(possible_seq_lengths), - valid_prompt_shape[0] - 1, + input_ids, extra_kwargs, sample_key = __prepare_inputs( + valid_prompt_shape[0], + valid_prompt_shape[1], + tokenizer, + enforce_sizes=enforce_sizes, + ) + + yield program_criteria_seq[0].program_id, custom_shape, input_ids, extra_kwargs, sample_key + break + if len(valid_prompts) > 0: + break + else: + for program_info in programs: + program_id = program_info.program_id + batch_size_limit = program_info.batch_size_limit + batch_size_limit_type = program_info.batch_size_limit_type + prompt_length_limit = program_info.prompt_length_limit + prompt_length_limit_type = program_info.prompt_length_limit_type + + filtered_program_map = program_map + if program_id.isnumeric(): + filtered_program_map = { + k: v + for k, v in program_map.items() + if k[0] == program_criteria_list[int(program_id)] + } + used_keys = set() + # for each program, we need to check if we have a shape that satisfies the --programs request + for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): + # if ? or numeric => we need to check if we have found at least one valid key to stop + if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: + break + # if * => we need to see if we have found the first key to see if we should skip + elif program_id == "*" and program_seq_key[0] in used_keys: + continue + + for valid_prompt_shape in valid_prompt_shapes: + # make sure the criteria for batch limit and prompt limit is satisfied + # eval is safe here because we have limited what type and limit can be before + + batch_check = eval( + f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" + ) + prompt_check = eval( + f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" + ) + if batch_check and prompt_check: + # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length + # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning + # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user + enforce_sizes = [valid_prompt_shape[1]] + if args.enforce_homogeneous_prompt_programs: + # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) + possible_seq_lengths = [ + _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) + ] + # favor sequences that are close to the valid prompt length + possible_seq_lengths.reverse() + enforce_sizes = enforce_sizes + list( + itertools.islice( + itertools.cycle(possible_seq_lengths), + valid_prompt_shape[0] - 1, + ) ) - ) - try: - input_ids, extra_kwargs, sample_key = __prepare_inputs( - valid_prompt_shape[0], - valid_prompt_shape[1], - tokenizer, - enforce_sizes=enforce_sizes, - ) - valid_prompts.append( - ( - program_seq_key[0], - valid_prompt_shape, - input_ids, - extra_kwargs, - sample_key, + try: + input_ids, extra_kwargs, sample_key = __prepare_inputs( + valid_prompt_shape[0], + valid_prompt_shape[1], + tokenizer, + enforce_sizes=enforce_sizes, ) - ) - used_keys.add(program_seq_key[0]) - break - except ValueError: - dprint( - f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" - ) - - if len(used_keys) == 0 and local_rank == 0: - dprint( - f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" - ) + used_keys.add(program_seq_key[0]) + yield program_seq_key[0], valid_prompt_shape, input_ids, extra_kwargs, sample_key + break + except ValueError: + dprint( + f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" + ) + + if len(used_keys) == 0 and local_rank == 0: + dprint( + f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" + ) # metric calculator based on the cross-entropy and mean diff for each decode step @@ -638,7 +625,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): failed_cases = [] # for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in get_program_prompt_list(): extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant From ba345b15941195106adee5c1cd43fa35454e8c03 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 14 Nov 2025 14:46:07 -0700 Subject: [PATCH 2/5] black fmt Signed-off-by: Sukriti-Sharma4 --- .../scripts/drive_paged_programs.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index cee85f0..f19ba7f 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -517,6 +517,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # select prompts that fit the batch size criteria valid_prompts = [] + def get_program_prompt_list(): if custom_shape: for program_criteria_seq, valid_prompt_shapes in program_map.items(): @@ -529,8 +530,14 @@ def get_program_prompt_list(): tokenizer, enforce_sizes=enforce_sizes, ) - - yield program_criteria_seq[0].program_id, custom_shape, input_ids, extra_kwargs, sample_key + + yield ( + program_criteria_seq[0].program_id, + custom_shape, + input_ids, + extra_kwargs, + sample_key, + ) break if len(valid_prompts) > 0: break @@ -596,7 +603,13 @@ def get_program_prompt_list(): enforce_sizes=enforce_sizes, ) used_keys.add(program_seq_key[0]) - yield program_seq_key[0], valid_prompt_shape, input_ids, extra_kwargs, sample_key + yield ( + program_seq_key[0], + valid_prompt_shape, + input_ids, + extra_kwargs, + sample_key, + ) break except ValueError: dprint( @@ -625,7 +638,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): failed_cases = [] # for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in get_program_prompt_list(): +for ( + program_id, + valid_prompt, + input_ids, + extra_kwargs, + sample_key, +) in get_program_prompt_list(): extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant From 22c13ea7fd79d5fd4674592205453bd39ec41ac3 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Mon, 17 Nov 2025 15:58:13 -0700 Subject: [PATCH 3/5] rebase on main Signed-off-by: Sukriti-Sharma4 --- .../scripts/drive_paged_programs.py | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index f19ba7f..b94c056 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -264,7 +264,9 @@ def __custom_line_sampler(*args, **kwargs): max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]) -def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0): +def __prepare_inputs( + batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64 +): start = time.time() prompts_and_sizes, sample_key = sampler( DATASET_PATH, @@ -276,6 +278,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 enforce_sizes=enforce_sizes, truncation=allow_truncation, return_key=True, + pad_multiple=pad_multiple, ) end = time.time() if local_rank == 0: @@ -393,7 +396,13 @@ def __load_validation_info( # warmup with any input so compiler produces criteria json # TODO: Swap this with __prepare_inputs once fix for shape_id is available # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) -prompt_list = [torch.arange(0, 64, dtype=torch.int64)] +pad_multiple = 64 +if args.prefill_chunk_size > 0: + assert ( + args.prefill_chunk_size % 64 == 0 + ), "Chunk size must be a multiple of the page size" + pad_multiple = args.prefill_chunk_size +prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 if is_fp8: prompt_list = prompt_list * 2 @@ -505,7 +514,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # FIXME: filter condition for this on prompt and batch program_map = get_programs_prompts( program_criteria_list, - multiple=64, + multiple=pad_multiple, max_batch_size=max_batch_size, max_tkv=max_tkv, program_cycles=max_new_tokens, @@ -514,12 +523,11 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: for v in program_map.values(): random.Random(42).shuffle(v) -# select prompts that fit the batch size criteria -valid_prompts = [] - +# select prompts that fit the batch size criteria def get_program_prompt_list(): if custom_shape: + prompt_found = 0 for program_criteria_seq, valid_prompt_shapes in program_map.items(): for valid_prompt_shape in valid_prompt_shapes: if valid_prompt_shape == custom_shape: @@ -529,8 +537,9 @@ def get_program_prompt_list(): valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, + pad_multiple=pad_multiple, ) - + prompt_found = 1 yield ( program_criteria_seq[0].program_id, custom_shape, @@ -539,7 +548,7 @@ def get_program_prompt_list(): sample_key, ) break - if len(valid_prompts) > 0: + if prompt_found: break else: for program_info in programs: @@ -581,14 +590,29 @@ def get_program_prompt_list(): # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user enforce_sizes = [valid_prompt_shape[1]] - if args.enforce_homogeneous_prompt_programs: - # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length - tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) + if ( + args.enforce_homogeneous_prompt_programs + or args.prefill_chunk_size > 0 + ): + # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + tkv_cutoff = ( + 1 << (valid_prompt_shape[1].bit_length() - 1) + if args.enforce_homogeneous_prompt_programs + else pad_multiple + ) + possible_seq_lengths = [ - _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) + _ + for _ in range( + tkv_cutoff, valid_prompt_shape[1], pad_multiple + ) ] # favor sequences that are close to the valid prompt length possible_seq_lengths.reverse() + # add the valid prompt size to the end since it will already exist in the above enforce_sizes + possible_seq_lengths = possible_seq_lengths + [ + valid_prompt_shape[1] + ] enforce_sizes = enforce_sizes + list( itertools.islice( itertools.cycle(possible_seq_lengths), @@ -601,6 +625,7 @@ def get_program_prompt_list(): valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, + pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size) ) used_keys.add(program_seq_key[0]) yield ( @@ -610,6 +635,7 @@ def get_program_prompt_list(): extra_kwargs, sample_key, ) + break except ValueError: dprint( From 952d432b4f297eecd76deb3bec8ec247a2e0ced8 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Mon, 17 Nov 2025 16:01:10 -0700 Subject: [PATCH 4/5] merge main Signed-off-by: Sukriti-Sharma4 --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 0a114b3..b94c056 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -398,15 +398,9 @@ def __load_validation_info( # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) pad_multiple = 64 if args.prefill_chunk_size > 0: -<<<<<<< HEAD assert ( args.prefill_chunk_size % 64 == 0 ), "Chunk size must be a multiple of the page size" -======= - assert args.prefill_chunk_size % 64 == 0, ( - "Chunk size must be a multiple of the page size" - ) ->>>>>>> main pad_multiple = args.prefill_chunk_size prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 From 0aa793fca2ec60e841a1b3308024da0c6e002fe3 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Mon, 17 Nov 2025 16:20:54 -0700 Subject: [PATCH 5/5] ruff format Signed-off-by: Sukriti-Sharma4 --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index b94c056..2e75994 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -398,9 +398,9 @@ def __load_validation_info( # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) pad_multiple = 64 if args.prefill_chunk_size > 0: - assert ( - args.prefill_chunk_size % 64 == 0 - ), "Chunk size must be a multiple of the page size" + assert args.prefill_chunk_size % 64 == 0, ( + "Chunk size must be a multiple of the page size" + ) pad_multiple = args.prefill_chunk_size prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16