40 changes: 33 additions & 7 deletions aiu_fms_testing_utils/scripts/drive_paged_programs.py
@@ -264,7 +264,9 @@ def __custom_line_sampler(*args, **kwargs):
max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])


def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
def __prepare_inputs(
batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64
):
start = time.time()
prompts_and_sizes, sample_key = sampler(
DATASET_PATH,
@@ -276,6 +278,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0
enforce_sizes=enforce_sizes,
truncation=allow_truncation,
return_key=True,
pad_multiple=pad_multiple,
)
end = time.time()
if local_rank == 0:
@@ -393,7 +396,13 @@ def __load_validation_info(
# warmup with any input so compiler produces criteria json
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
pad_multiple = 64
if args.prefill_chunk_size > 0:
assert args.prefill_chunk_size % 64 == 0, (
"Chunk size must be a multiple of the page size"
)
pad_multiple = args.prefill_chunk_size
prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)]
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
if is_fp8:
prompt_list = prompt_list * 2
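Reviewer note: a minimal, self-contained sketch of the warmup-sizing rule this hunk introduces, assuming a 64-token page size; `build_warmup_prompts` and its signature are illustrative and not part of the script.

```python
import torch

PAGE_SIZE = 64  # assumed page size used by the paged-attention kernels


def build_warmup_prompts(prefill_chunk_size: int, is_fp8: bool) -> list[torch.Tensor]:
    # pad_multiple falls back to the page size when chunked prefill is disabled
    pad_multiple = PAGE_SIZE
    if prefill_chunk_size > 0:
        # each prefill chunk must be a whole number of pages
        assert prefill_chunk_size % PAGE_SIZE == 0, (
            "Chunk size must be a multiple of the page size"
        )
        pad_multiple = prefill_chunk_size
    # the warmup prompt length matches the padding granularity the compiler will see
    prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)]
    # mirror the vllm warmup: batch of 2 for fp8, a single sequence otherwise
    return prompt_list * 2 if is_fp8 else prompt_list
```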
@@ -505,7 +514,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
# FIXME: filter condition for this on prompt and batch
program_map = get_programs_prompts(
program_criteria_list,
multiple=64,
multiple=pad_multiple,
max_batch_size=max_batch_size,
max_tkv=max_tkv,
program_cycles=max_new_tokens,
@@ -526,6 +535,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=pad_multiple,
)
valid_prompts = [
(
@@ -579,14 +589,29 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
enforce_sizes = [valid_prompt_shape[1]]
if args.enforce_homogeneous_prompt_programs:
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
if (
args.enforce_homogeneous_prompt_programs
or args.prefill_chunk_size > 0
):
# when enforcing homogeneous prompt programs, the cutoff is the largest power of 2 that is less than or equal to the sequence length; otherwise (chunked prefill) it is pad_multiple
tkv_cutoff = (
1 << (valid_prompt_shape[1].bit_length() - 1)
if args.enforce_homogeneous_prompt_programs
else pad_multiple
)

possible_seq_lengths = [
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64)
_
for _ in range(
tkv_cutoff, valid_prompt_shape[1], pad_multiple
)
]
# favor sequences that are close to the valid prompt length
possible_seq_lengths.reverse()
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
possible_seq_lengths = possible_seq_lengths + [
valid_prompt_shape[1]
]
enforce_sizes = enforce_sizes + list(
itertools.islice(
itertools.cycle(possible_seq_lengths),
@@ -599,6 +624,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
)
valid_prompts.append(
(
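For reference, a sketch of the size-selection logic above. The `islice` count is cut off in this diff, so filling the remaining `batch_size - 1` slots is an assumption; `candidate_enforce_sizes` is illustrative only.

```python
import itertools


def candidate_enforce_sizes(
    prompt_len: int, batch_size: int, pad_multiple: int, homogeneous: bool
) -> list[int]:
    # always enforce the valid prompt length itself
    sizes = [prompt_len]
    # homogeneous prompt programs start at the largest power of 2 <= prompt_len;
    # chunked prefill starts at pad_multiple (the chunk size)
    tkv_cutoff = 1 << (prompt_len.bit_length() - 1) if homogeneous else pad_multiple
    possible = list(range(tkv_cutoff, prompt_len, pad_multiple))
    possible.reverse()  # favor sequences close to the valid prompt length
    possible.append(prompt_len)  # add it at the end since it is already enforced above
    # cycle (repeating sizes if necessary) to cover the rest of the batch -- count is an assumption
    sizes += list(itertools.islice(itertools.cycle(possible), batch_size - 1))
    return sizes


# e.g. candidate_enforce_sizes(448, batch_size=4, pad_multiple=128, homogeneous=False)
# -> [448, 384, 256, 128]
```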
9 changes: 6 additions & 3 deletions aiu_fms_testing_utils/utils/paged.py
@@ -316,6 +316,9 @@ def generate(
.unsqueeze(0)
.clone()
)
assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
f"sequence chunk length was not equal to the prefill chunk size. Found {input_ids_seq_chunk.size(1)}"
)
slots_length = len(slot_mapping[seq_i])
slot_mapping_seq_chunk = (
torch.tensor(
@@ -577,12 +580,12 @@ def generate(
return result


# this value is default to 2080 to be consistent with vllm for granite 3.3 8b instruct
# this value defaults to 8192 to be consistent with vllm for granite 3.3 8b instruct w/ chunked prefill
KVCACHE_NUM_BLOCKS_HINT = int(
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 8192)
)

VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))
VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288))


class ProgramCriteria:
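To illustrate the chunk-length invariant the new assertion in `generate` guards, here is a minimal sketch of slicing one padded prompt into fixed-size prefill chunks; `iter_prefill_chunks` is a hypothetical helper under those assumptions, not the actual generation loop.

```python
import torch


def iter_prefill_chunks(input_ids_seq: torch.Tensor, prefill_chunk_size: int):
    # the drive_paged_programs changes pad every prompt to a multiple of the
    # chunk size, so an exact split is expected here
    seq_len = input_ids_seq.size(0)
    assert seq_len % prefill_chunk_size == 0, (
        "prompt length must be padded to a multiple of the prefill chunk size"
    )
    for start in range(0, seq_len, prefill_chunk_size):
        chunk = (
            input_ids_seq[start : start + prefill_chunk_size].unsqueeze(0).clone()
        )
        # same invariant as the new assert in generate(): every chunk is full-size
        assert chunk.size(1) == prefill_chunk_size
        yield chunk


# e.g. list(iter_prefill_chunks(torch.zeros(256, dtype=torch.int64), 128))
# yields 2 chunks of shape (1, 128)
```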