Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 113 additions & 107 deletions aiu_fms_testing_utils/scripts/drive_paged_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,129 +523,129 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
for v in program_map.values():
random.Random(42).shuffle(v)


# select prompts that fit the batch size criteria
valid_prompts = []
if custom_shape:
for program_criteria_seq, valid_prompt_shapes in program_map.items():
for valid_prompt_shape in valid_prompt_shapes:
if valid_prompt_shape == custom_shape:
enforce_sizes = [valid_prompt_shape[1]]
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=pad_multiple,
)
valid_prompts = [
(
def get_program_prompt_list():
if custom_shape:
prompt_found = 0
Copy link
Author

@Ssukriti Ssukriti Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's hard to see the changes made in this PR from the git diff, here is a summary of each change:
Change 1: use a `prompt_found` flag, because we are now yielding prompts instead of storing them in a list.

for program_criteria_seq, valid_prompt_shapes in program_map.items():
for valid_prompt_shape in valid_prompt_shapes:
if valid_prompt_shape == custom_shape:
enforce_sizes = [valid_prompt_shape[1]]
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=pad_multiple,
)
prompt_found = 1
yield (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change 2: yield the tuple instead of building a list; the flag is set before the yield.

program_criteria_seq[0].program_id,
custom_shape,
input_ids,
extra_kwargs,
sample_key,
)
]
break
if prompt_found:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change 3: check the flag instead of the length of the list.

break
if len(valid_prompts) > 0:
break
else:
for program_info in programs:
program_id = program_info.program_id
batch_size_limit = program_info.batch_size_limit
batch_size_limit_type = program_info.batch_size_limit_type
prompt_length_limit = program_info.prompt_length_limit
prompt_length_limit_type = program_info.prompt_length_limit_type

filtered_program_map = program_map
if program_id.isnumeric():
filtered_program_map = {
k: v
for k, v in program_map.items()
if k[0] == program_criteria_list[int(program_id)]
}
used_keys = set()
# for each program, we need to check if we have a shape that satisfies the --programs request
for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
# if ? or numeric => we need to check if we have found at least one valid key to stop
if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0:
break
# if * => we need to see if we have found the first key to see if we should skip
elif program_id == "*" and program_seq_key[0] in used_keys:
continue

for valid_prompt_shape in valid_prompt_shapes:
# make sure the criteria for batch limit and prompt limit is satisfied
# eval is safe here because we have limited what type and limit can be before
else:
for program_info in programs:
program_id = program_info.program_id
batch_size_limit = program_info.batch_size_limit
batch_size_limit_type = program_info.batch_size_limit_type
prompt_length_limit = program_info.prompt_length_limit
prompt_length_limit_type = program_info.prompt_length_limit_type

filtered_program_map = program_map
if program_id.isnumeric():
filtered_program_map = {
k: v
for k, v in program_map.items()
if k[0] == program_criteria_list[int(program_id)]
}
used_keys = set()
# for each program, we need to check if we have a shape that satisfies the --programs request
for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
# if ? or numeric => we need to check if we have found at least one valid key to stop
if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0:
break
# if * => we need to see if we have found the first key to see if we should skip
elif program_id == "*" and program_seq_key[0] in used_keys:
continue

for valid_prompt_shape in valid_prompt_shapes:
# make sure the criteria for batch limit and prompt limit is satisfied
# eval is safe here because we have limited what type and limit can be before

batch_check = eval(
f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
)
prompt_check = eval(
f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
)
if batch_check and prompt_check:
# when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
enforce_sizes = [valid_prompt_shape[1]]
if (
args.enforce_homogeneous_prompt_programs
or args.prefill_chunk_size > 0
):
# if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
tkv_cutoff = (
1 << (valid_prompt_shape[1].bit_length() - 1)
if args.enforce_homogeneous_prompt_programs
else pad_multiple
)

batch_check = eval(
f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
)
prompt_check = eval(
f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
)
if batch_check and prompt_check:
# when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
enforce_sizes = [valid_prompt_shape[1]]
if (
args.enforce_homogeneous_prompt_programs
or args.prefill_chunk_size > 0
):
# if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
tkv_cutoff = (
1 << (valid_prompt_shape[1].bit_length() - 1)
if args.enforce_homogeneous_prompt_programs
else pad_multiple
)

possible_seq_lengths = [
_
for _ in range(
tkv_cutoff, valid_prompt_shape[1], pad_multiple
possible_seq_lengths = [
_
for _ in range(
tkv_cutoff, valid_prompt_shape[1], pad_multiple
)
]
# favor sequences that are close to the valid prompt length
possible_seq_lengths.reverse()
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
possible_seq_lengths = possible_seq_lengths + [
valid_prompt_shape[1]
]
enforce_sizes = enforce_sizes + list(
itertools.islice(
itertools.cycle(possible_seq_lengths),
valid_prompt_shape[0] - 1,
)
)
]
# favor sequences that are close to the valid prompt length
possible_seq_lengths.reverse()
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
possible_seq_lengths = possible_seq_lengths + [
valid_prompt_shape[1]
]
enforce_sizes = enforce_sizes + list(
itertools.islice(
itertools.cycle(possible_seq_lengths),
valid_prompt_shape[0] - 1,
try:
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
)
)
try:
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
)
valid_prompts.append(
(
used_keys.add(program_seq_key[0])
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change 4: call `used_keys.add(program_seq_key[0])` before the yield, and then yield.

yield (
program_seq_key[0],
valid_prompt_shape,
input_ids,
extra_kwargs,
sample_key,
)
)
used_keys.add(program_seq_key[0])
break
except ValueError:
dprint(
f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
)

if len(used_keys) == 0 and local_rank == 0:
dprint(
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
)

break
except ValueError:
dprint(
f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
)

if len(used_keys) == 0 and local_rank == 0:
dprint(
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
)


# metric calculator based on the cross-entropy and mean diff for each decode step
Expand All @@ -664,7 +664,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):

failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
for (
program_id,
valid_prompt,
input_ids,
extra_kwargs,
sample_key,
) in get_program_prompt_list():
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change 5: iterate over the generator function `get_program_prompt_list()` instead of the `valid_prompts` list.

extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
Expand Down
Loading