Commit a531fca

Merge pull request #128 from foundation-model-stack/drive_program_script_enhancements
Drive Paged Program Script enhancements
2 parents 046f9c4 + 2e9d6d8 commit a531fca

5 files changed: +503 -196 lines changed


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 51 additions & 7 deletions
@@ -77,8 +77,9 @@ def warmup_model(
     _max_new_tokens = max_new_tokens
     if compile_dynamic_sendnn:
         _max_new_tokens = 2
-    # always warmup with batch size 2 when using attn_type=paged
-    if "paged" in attn_name:
+    # When performing fp8 paged attention, we must pad to batch size 2
+    # this is fixed in torch >= 2.8
+    if attn_name == "spyre_paged_attn_fp8":
         _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
             input_ids,
             **extra_kwargs,
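
The change above narrows the batch-size-2 warmup padding from all paged attention variants to the fp8 paged case only. As a minimal illustration of what "pad to batch size 2" means here (this is not the project's adjust_inputs_to_batch, whose body is not shown in this diff), a single prompt is simply duplicated along the batch dimension:

import torch

# hypothetical sketch: mimic batch size 2 by repeating the lone sequence
input_ids = torch.randint(0, 32000, (1, 64))  # one prompt of 64 tokens
if input_ids.shape[0] == 1:
    input_ids = input_ids.repeat(2, 1)
print(input_ids.shape)  # torch.Size([2, 64])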
@@ -467,6 +468,42 @@ def __sample_requests(
     return filtered_dataset


+def sample_rag_factoid_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 65536,
+    seed: Optional[int] = None,
+    enforce_heterogeneous: bool = False,
+    enforce_sizes: List[int] = [],
+    truncation: bool = False,
+    pad_multiple: int = 64,
+) -> List[Tuple[str, int]]:
+    if not os.path.exists(dataset_path):
+        print("error dataset does not exist")
+
+    dataset = []
+    # Load the dataset.
+    with open(dataset_path, "r", encoding="utf-8") as f:
+        for line in f:
+            dataset.append(line)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+        enforce_heterogeneous,
+        enforce_sizes,
+        truncation,
+        pad_multiple,
+        _cached_dataset_key=dataset_path,
+    )
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
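
The new sample_rag_factoid_requests helper reads a line-delimited prompt file and defers sampling and filtering to the existing __sample_requests routine. A hedged usage sketch, assuming a transformers tokenizer and a local one-prompt-per-line dataset file (the model id and path below are placeholders):

from transformers import AutoTokenizer
from aiu_fms_testing_utils.utils import sample_rag_factoid_requests

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
requests = sample_rag_factoid_requests(
    dataset_path="/tmp/rag_factoid.txt",  # one prompt per line
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=64,
    prompt_length_max=4096,
    seed=0,
)
# per the return annotation, each entry is a (prompt, length) pair
for prompt, length in requests:
    print(length, prompt[:40])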
@@ -481,16 +518,23 @@ def sample_sharegpt_requests(
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
         print("downloading share-gpt dataset as it does not exist")
-        __download_file(
-            "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json",
-            dataset_path,
-        )
+        is_distributed_initialized = torch.distributed.is_initialized()
+        if not is_distributed_initialized or rank < 1:
+            __download_file(
+                "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json",
+                dataset_path,
+            )
+        else:
+            print("waiting for rank0 to complete download")
+
+        if is_distributed_initialized:
+            torch.distributed.barrier()

     if enforce_sizes is None:
         enforce_sizes = []

     # Load the dataset.
-    with open(dataset_path, encoding="utf-8") as f:
+    with open(dataset_path, "r", encoding="utf-8") as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
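
The download block now follows the usual rank-0-downloads, everyone-waits pattern so that multiple ranks do not race on the same file. Stripped of the project specifics, the pattern sketched by this change is roughly (download_once and its arguments are hypothetical names for illustration):

import torch.distributed as dist

def download_once(download_fn, is_rank_zero: bool):
    # only one process fetches the file; the rest wait at the barrier
    if not dist.is_initialized() or is_rank_zero:
        download_fn()
    if dist.is_initialized():
        dist.barrier()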

aiu_fms_testing_utils/utils/paged.py

Lines changed: 43 additions & 10 deletions
@@ -86,13 +86,15 @@ def generate(
     if extra_kwargs is not None:
         kwargs.update(extra_kwargs)

+    is_fp8 = "fp8" in kwargs["attn_name"]
     if isinstance(input_ids, torch.Tensor):
         if len(input_ids.shape) == 1:
             input_ids = input_ids.unsqueeze(0)

         is_batch = input_ids.shape[0] > 1
-        # our model requires batch dimension
-        if not is_batch:
+        # our model requires batch dimension when running with fp8
+        # this is fixed in torch >= 2.8
+        if is_fp8 and not is_batch:
             input_ids, kwargs = adjust_inputs_to_batch(input_ids, **kwargs)
     else:
         raise TypeError("input_ids must be one of Tensor or List")
@@ -115,7 +117,10 @@ def generate(
     # if we set these variables here, we run the risk of warming up and generating with different sizes
     _MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
     _MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
-    NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
+    # if the user provides a hint to the number of blocks to use, use it directly
+    NUM_BLOCKS = kwargs.get(
+        "_kvcache_num_blocks_hint", (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
+    )

     if hasattr(model, "head"):
         model_dtype = model.head.weight.dtype
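
With this change the KV-cache block count can be supplied directly through kwargs instead of always being derived from the vLLM environment variables. A worked example of the default formula, assuming a block size of 64 and placeholder environment values:

# assuming VLLM_DT_MAX_BATCH_SIZE=4 and VLLM_DT_MAX_CONTEXT_LEN=8192 (placeholders)
BLOCK_SIZE = 64  # assumed block size for this illustration
default_num_blocks = (4 * 8192) // BLOCK_SIZE  # = 512
# a caller can bypass the formula by adding {"_kvcache_num_blocks_hint": 2080}
# to extra_kwargs, which the kwargs.get() above picks up instead
print(default_num_blocks)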
@@ -345,7 +350,10 @@ def generate(
         [
             (
                 [b_seq[0]]
-                * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                * (
+                    max(2 if is_fp8 else 1, max([len(b) for b in block_table]))
+                    - len(b_seq)
+                )
             )
             + b_seq
             for b_seq in block_table
@@ -408,17 +416,19 @@ def generate(
         if post_iteration_hook is not None:
             _logits = logits
             _next_val = next_val
-            # since we cannot handle batch size 1 and mimic with batch size 2, we need to only pass in the first logits/next_val
-            if not is_batch:
+            # since we cannot handle batch size 1 for fp8 and mimic with batch size 2, we need to only pass in the first logits/next_val
+            if is_fp8 and not is_batch:
                 _logits = logits[0].unsqueeze(0)
                 _next_val = _next_val[0].unsqueeze(0)
             _next_val, kwargs = post_iteration_hook(
                 i + prompt_length, _logits, _next_val, kwargs
             )
             # we need to normalize back to batch size 2
-            if not is_batch:
+            if is_fp8 and not is_batch:
                 # we need to do an in-place copy here for the same reason we do in-place copy for injecting tokens
                 next_val.copy_(torch.cat((_next_val, _next_val), dim=0))
+            else:
+                next_val = _next_val

         result = torch.cat((result, next_val), dim=-1)
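For reference, the call site above fixes the contract a post_iteration_hook must satisfy: it receives the current decode position, the logits, the sampled token(s), and the kwargs dict, and must return the (possibly modified) tokens plus kwargs. A hypothetical hook matching that shape:

# hypothetical hook: logs each generated token id, leaves generation unchanged
def log_tokens_hook(position, logits, next_val, kwargs):
    print(f"step {position}: token ids {next_val.view(-1).tolist()}")
    return next_val, kwargs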

@@ -454,7 +464,12 @@ def generate(
     return result


-VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072
+# this value is default to 2080 to be consistent with vllm for granite 3.3 8b instruct
+KVCACHE_NUM_BLOCKS_HINT = int(
+    os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
+)
+
+VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))


 class ProgramCriteria:
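
Both module-level constants are now environment-driven, defaulting to 2080 blocks (to stay consistent with vLLM for granite 3.3 8b instruct) and a batch*tkv limit of 131072. Since they are read at import time, overrides have to be set before the module is imported; a small configuration sketch:

import os

os.environ.setdefault("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", "2080")
os.environ.setdefault("VLLM_DT_MAX_BATCH_TKV_LIMIT", "131072")

from aiu_fms_testing_utils.utils import paged  # noqa: E402

print(paged.KVCACHE_NUM_BLOCKS_HINT, paged.VLLM_DT_MAX_BATCH_TKV_LIMIT)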
@@ -468,7 +483,11 @@ def __init__(
         self.tkv_granularity = tkv_granularity

     def is_possible(self, batch_size, tkv):
-        return batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT
+        return (
+            (batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT)
+            and (batch_size <= self.max_batch)
+            and (tkv <= self.max_tkv)
+        )

     def calculate_padding(self, batch_size, tkv):
         min_batch_req = (
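
is_possible now also enforces the per-program max_batch and max_tkv bounds rather than only the global batch*tkv product limit. A standalone worked example with illustrative numbers (not taken from the repository):

VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072
max_batch, max_tkv = 4, 8192

def is_possible(batch_size, tkv):
    return (
        batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT
        and batch_size <= max_batch
        and tkv <= max_tkv
    )

print(is_possible(4, 8192))  # True: 32768 <= 131072 and within both bounds
print(is_possible(8, 4096))  # False now: the old product-only check would have passed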
@@ -496,7 +515,12 @@ def __hash__(self):


 def get_programs_prompts(
-    program_criteria_list, multiple, max_batch_size, max_tkv, program_cycles
+    program_criteria_list,
+    multiple,
+    max_batch_size,
+    max_tkv,
+    program_cycles,
+    prioritize_large_batch_sizes=True,
 ):
     program_map = {}
@@ -515,6 +539,11 @@ def get_programs_prompts(
                 if (
                     resolved_programs[program_index] is None
                     or padding < resolved_programs[program_index][1]
+                    or (
+                        padding == resolved_programs[program_index][1]
+                        and program_criteria.batch_granularity
+                        > resolved_programs[program_index][0].batch_granularity
+                    )
                 ):
                     resolved_programs[program_index] = (
                         program_criteria,
@@ -528,4 +557,8 @@ def get_programs_prompts(
             else:
                 program_map[key] = [(batch_size, prompt_len)]

+    # give higher priority to larger batches
+    for _, v in program_map.items():
+        v.sort(key=lambda t: t[0], reverse=prioritize_large_batch_sizes)
+
     return program_map
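
Each program's list of (batch_size, prompt_len) shapes is now sorted by batch size, descending by default, so larger batches are tried first while prioritize_large_batch_sizes stays True. A small illustration of the ordering:

shapes = [(1, 128), (4, 256), (2, 64)]
shapes.sort(key=lambda t: t[0], reverse=True)
print(shapes)  # [(4, 256), (2, 64), (1, 128)]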
