From 5b4d826a05acf121588a8c8c372494b4f95d1cfb Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 30 Oct 2025 01:27:13 +0000 Subject: [PATCH 01/13] Add padding to chunk size and update env vars Signed-off-by: Antoni Viros i Martin --- .../scripts/drive_paged_programs.py | 16 ++++++++++++++-- aiu_fms_testing_utils/utils/paged.py | 12 ++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 2dcd0216..90b5ed73 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -3,6 +3,7 @@ import datetime import itertools import json +import math import os import random import time @@ -254,7 +255,7 @@ def __custom_line_sampler(*args, **kwargs): max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]) -def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0): +def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64): start = time.time() prompts_and_sizes, sample_key = sampler( DATASET_PATH, @@ -266,6 +267,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 enforce_sizes=enforce_sizes, truncation=allow_truncation, return_key=True, + pad_multiple=pad_multiple, ) end = time.time() if local_rank == 0: @@ -284,7 +286,11 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 ) prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list - input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) + min_pad_length = seq_length + if seq_length % pad_multiple != 0: + min_pad_length = math.ceil(seq_length / pad_multiple) * pad_multiple + print(pad_multiple, min_pad_length, seq_length) + input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=min_pad_length) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) return input_ids, extra_kwargs, sample_key @@ -505,6 +511,10 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: random.Random(42).shuffle(v) # select prompts that fit the batch size criteria +pad_multiple = 64 +if args.prefill_chunk_size > 0: + assert args.prefill_chunk_size % 64 == 0, "Chunk size must be a multiple of the page size" + pad_multiple = args.prefill_chunk_size valid_prompts = [] if custom_shape: for program_criteria_seq, valid_prompt_shapes in program_map.items(): @@ -516,6 +526,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, + pad_multiple=pad_multiple, ) valid_prompts = [ ( @@ -589,6 +600,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, + pad_multiple=pad_multiple, ) valid_prompts.append( ( diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index e472085f..a696aad1 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -372,9 +372,9 @@ def generate( torch._dynamo.mark_static(block_table_seq_chunk, 0) # seq dynamic - torch._dynamo.mark_dynamic(input_ids_seq_chunk, 1) - torch._dynamo.mark_dynamic(slot_mapping_seq_chunk, 1) - torch._dynamo.mark_dynamic(position_ids_seq_chunk, 1) + # torch._dynamo.mark_dynamic(input_ids_seq_chunk, 1) + # torch._dynamo.mark_dynamic(slot_mapping_seq_chunk, 1) + # torch._dynamo.mark_dynamic(position_ids_seq_chunk, 1) 
torch._dynamo.mark_dynamic(block_table_seq_chunk, 1) logits, current_kv_cache = model( @@ -573,12 +573,12 @@ def generate( return result -# this value is default to 2080 to be consistent with vllm for granite 3.3 8b instruct +# this value is default to 8192 to be consistent with vllm for granite 3.3 8b instruct w/ chunked prefill KVCACHE_NUM_BLOCKS_HINT = int( - os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080) + os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 8192) ) -VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072)) +VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288)) class ProgramCriteria: From ff9f9a403957d9ad581f35b8363b63d3f8b39c45 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 30 Oct 2025 01:34:20 +0000 Subject: [PATCH 02/13] Fix the DPP warmup Signed-off-by: Antoni Viros i Martin --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 10 +++++----- aiu_fms_testing_utils/utils/paged.py | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 90b5ed73..9d39cd72 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -389,7 +389,11 @@ def __load_validation_info( # warmup with any input so compiler produces criteria json # TODO: Swap this with __prepare_inputs once fix for shape_id is available # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) -prompt_list = [torch.arange(0, 64, dtype=torch.int64)] +pad_multiple = 64 +if args.prefill_chunk_size > 0: + assert args.prefill_chunk_size % 64 == 0, "Chunk size must be a multiple of the page size" + pad_multiple = args.prefill_chunk_size +prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 if is_fp8: prompt_list = prompt_list * 2 @@ -511,10 +515,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: random.Random(42).shuffle(v) # select prompts that fit the batch size criteria -pad_multiple = 64 -if args.prefill_chunk_size > 0: - assert args.prefill_chunk_size % 64 == 0, "Chunk size must be a multiple of the page size" - pad_multiple = args.prefill_chunk_size valid_prompts = [] if custom_shape: for program_criteria_seq, valid_prompt_shapes in program_map.items(): diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index a696aad1..962a453e 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -365,6 +365,8 @@ def generate( "block_table": block_table_seq_chunk, } + print(chunked_kwargs) + # batch static torch._dynamo.mark_static(input_ids_seq_chunk, 0) torch._dynamo.mark_static(slot_mapping_seq_chunk, 0) From edc6d6db7cf4c5bbe2af27c6dba63e7ea042a5a0 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 30 Oct 2025 01:48:37 +0000 Subject: [PATCH 03/13] address PR comments Signed-off-by: Antoni Viros i Martin --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 8 ++------ aiu_fms_testing_utils/utils/paged.py | 5 ----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 9d39cd72..044b5e56 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ 
b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -286,11 +286,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 ) prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list - min_pad_length = seq_length - if seq_length % pad_multiple != 0: - min_pad_length = math.ceil(seq_length / pad_multiple) * pad_multiple - print(pad_multiple, min_pad_length, seq_length) - input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=min_pad_length) + input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) return input_ids, extra_kwargs, sample_key @@ -505,7 +501,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # FIXME: filter condition for this on prompt and batch program_map = get_programs_prompts( program_criteria_list, - multiple=64, + multiple=pad_multiple, max_batch_size=max_batch_size, max_tkv=max_tkv, program_cycles=max_new_tokens, diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 962a453e..d8641c6a 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -365,8 +365,6 @@ def generate( "block_table": block_table_seq_chunk, } - print(chunked_kwargs) - # batch static torch._dynamo.mark_static(input_ids_seq_chunk, 0) torch._dynamo.mark_static(slot_mapping_seq_chunk, 0) @@ -374,9 +372,6 @@ def generate( torch._dynamo.mark_static(block_table_seq_chunk, 0) # seq dynamic - # torch._dynamo.mark_dynamic(input_ids_seq_chunk, 1) - # torch._dynamo.mark_dynamic(slot_mapping_seq_chunk, 1) - # torch._dynamo.mark_dynamic(position_ids_seq_chunk, 1) torch._dynamo.mark_dynamic(block_table_seq_chunk, 1) logits, current_kv_cache = model( From 5ffd7e383ed05a264bf6ba73ba111c2ea46788ee Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 30 Oct 2025 11:27:02 +0000 Subject: [PATCH 04/13] Ruff and dynamic Signed-off-by: Antoni Viros i Martin --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 8 ++++++-- aiu_fms_testing_utils/utils/paged.py | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 044b5e56..8e2a06bf 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -255,7 +255,9 @@ def __custom_line_sampler(*args, **kwargs): max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]) -def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64): +def __prepare_inputs( + batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64 +): start = time.time() prompts_and_sizes, sample_key = sampler( DATASET_PATH, @@ -387,7 +389,9 @@ def __load_validation_info( # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) pad_multiple = 64 if args.prefill_chunk_size > 0: - assert args.prefill_chunk_size % 64 == 0, "Chunk size must be a multiple of the page size" + assert args.prefill_chunk_size % 64 == 0, ( + "Chunk size must be a multiple of the page size" + ) pad_multiple = args.prefill_chunk_size prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index d8641c6a..ec925dc9 100644 --- a/aiu_fms_testing_utils/utils/paged.py 
+++ b/aiu_fms_testing_utils/utils/paged.py @@ -372,6 +372,9 @@ def generate( torch._dynamo.mark_static(block_table_seq_chunk, 0) # seq dynamic + torch._dynamo.mark_dynamic(input_ids_seq_chunk, 1) + torch._dynamo.mark_dynamic(slot_mapping_seq_chunk, 1) + torch._dynamo.mark_dynamic(position_ids_seq_chunk, 1) torch._dynamo.mark_dynamic(block_table_seq_chunk, 1) logits, current_kv_cache = model( From c5fab859a80d221f79cfdb2adb2243e5097ef479 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 30 Oct 2025 11:27:53 +0000 Subject: [PATCH 05/13] remove import Signed-off-by: Antoni Viros i Martin --- aiu_fms_testing_utils/scripts/drive_paged_programs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 8e2a06bf..9a21ca27 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -3,7 +3,6 @@ import datetime import itertools import json -import math import os import random import time From 2851133b58d4ae1c9170e1da1be8b4ddc2264ca2 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Mon, 3 Nov 2025 04:13:13 +0000 Subject: [PATCH 06/13] fixed incorrect chunk sizes in all sequences but the largest in the batch Signed-off-by: Joshua Rosenkranz --- .../scripts/drive_paged_programs.py | 25 +++++++++++++++---- aiu_fms_testing_utils/utils/paged.py | 3 +++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 9a21ca27..3fce8ff3 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -579,14 +579,29 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user enforce_sizes = [valid_prompt_shape[1]] - if args.enforce_homogeneous_prompt_programs: - # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length - tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) + if ( + args.enforce_homogeneous_prompt_programs + or args.prefill_chunk_size > 0 + ): + # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + tkv_cutoff = ( + 1 << (valid_prompt_shape[1].bit_length() - 1) + if args.enforce_homogeneous_prompt_programs + else pad_multiple + ) + possible_seq_lengths = [ - _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) + _ + for _ in range( + tkv_cutoff, valid_prompt_shape[1], pad_multiple + ) ] # favor sequences that are close to the valid prompt length possible_seq_lengths.reverse() + # add the valid prompt size to the end since it will already exist in the above enforce_sizes + possible_seq_lengths = possible_seq_lengths + [ + valid_prompt_shape[1] + ] enforce_sizes = enforce_sizes + list( itertools.islice( itertools.cycle(possible_seq_lengths), @@ -599,7 +614,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, - pad_multiple=pad_multiple, + pad_multiple=64, # this should be the smallest granularity to ensure 
we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size) ) valid_prompts.append( ( diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index ec925dc9..60f108f0 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -312,6 +312,9 @@ def generate( .unsqueeze(0) .clone() ) + assert input_ids_seq_chunk.size(1) == prefill_chunk_size, ( + f"prefill chunk size was not equal to the chunk size. Found {input_ids_seq_chunk.size(0)}" + ) slots_length = len(slot_mapping[seq_i]) slot_mapping_seq_chunk = ( torch.tensor( From 3510df611735d0e0438106e414b89e5ccf489379 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 13 Nov 2025 04:38:47 +0000 Subject: [PATCH 07/13] fixed logic to allow any multiple of 64 in chunked prefill Signed-off-by: Joshua Rosenkranz --- .../scripts/drive_paged_programs.py | 34 ++----- aiu_fms_testing_utils/utils/paged.py | 89 ++++++++++--------- 2 files changed, 54 insertions(+), 69 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 6b8f3bef..1c51f777 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -264,9 +264,7 @@ def __custom_line_sampler(*args, **kwargs): max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"]) -def __prepare_inputs( - batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64 -): +def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0): start = time.time() prompts_and_sizes, sample_key = sampler( DATASET_PATH, @@ -278,7 +276,6 @@ def __prepare_inputs( enforce_sizes=enforce_sizes, truncation=allow_truncation, return_key=True, - pad_multiple=pad_multiple, ) end = time.time() if local_rank == 0: @@ -396,13 +393,7 @@ def __load_validation_info( # warmup with any input so compiler produces criteria json # TODO: Swap this with __prepare_inputs once fix for shape_id is available # input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) -pad_multiple = 64 -if args.prefill_chunk_size > 0: - assert args.prefill_chunk_size % 64 == 0, ( - "Chunk size must be a multiple of the page size" - ) - pad_multiple = args.prefill_chunk_size -prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)] +prompt_list = [torch.arange(0, 64, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 if is_fp8: prompt_list = prompt_list * 2 @@ -514,7 +505,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # FIXME: filter condition for this on prompt and batch program_map = get_programs_prompts( program_criteria_list, - multiple=pad_multiple, + multiple=64, max_batch_size=max_batch_size, max_tkv=max_tkv, program_cycles=max_new_tokens, @@ -535,7 +526,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, - pad_multiple=pad_multiple, ) valid_prompts = [ ( @@ -589,22 +579,11 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user enforce_sizes = [valid_prompt_shape[1]] - if ( - args.enforce_homogeneous_prompt_programs - or args.prefill_chunk_size > 0 - ): + if 
args.enforce_homogeneous_prompt_programs: # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length - tkv_cutoff = ( - 1 << (valid_prompt_shape[1].bit_length() - 1) - if args.enforce_homogeneous_prompt_programs - else pad_multiple - ) - + tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) possible_seq_lengths = [ - _ - for _ in range( - tkv_cutoff, valid_prompt_shape[1], pad_multiple - ) + _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) ] # favor sequences that are close to the valid prompt length possible_seq_lengths.reverse() @@ -624,7 +603,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape[1], tokenizer, enforce_sizes=enforce_sizes, - pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size) ) valid_prompts.append( ( diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 12f47763..bb4b82ef 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -5,6 +5,7 @@ from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union import torch import fms.utils.spyre.paged # noqa +from aiu_fms_testing_utils.utils import get_pad_size def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs): @@ -300,62 +301,68 @@ def generate( last_n_tokens = kwargs.get("last_n_tokens", 0) if prefill_chunk_size > 0: - left_padded_prompt_mask_seq_chunk = None + required_extra_pads = get_pad_size(current_tkv, prefill_chunk_size) - current_tkv + left_padded_prompt_mask_seq_chunk = ( + kwargs["position_ids"][seq_i] == 0 + ).sum(dim=0) - 1 + required_extra_pads + left_padded_prompt_mask_seq_chunk = left_padded_prompt_mask_seq_chunk.unsqueeze(0) + # Chunked prefill for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)): - chunk_start = -current_tkv + chunk_j * prefill_chunk_size - chunk_end = -current_tkv + min( - (chunk_j + 1) * prefill_chunk_size, current_tkv - ) + + if chunk_j == 0: + chunk_start = 0 + chunk_end = prefill_chunk_size - required_extra_pads + else: + required_extra_pads = 0 + chunk_start = chunk_end + chunk_end += prefill_chunk_size + + input_ids_seq_chunk = input_ids[seq_i][chunk_start: chunk_end] + if required_extra_pads > 0: + input_ids_seq_chunk = torch.cat(( + torch.zeros(required_extra_pads, dtype=torch.int64, device=input_ids_seq_chunk.device), + input_ids_seq_chunk + )) + + input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone() - ids_length = input_ids[seq_i].shape[0] - input_ids_seq_chunk = ( - input_ids[seq_i][ - chunk_start + ids_length : chunk_end + ids_length - ] - .unsqueeze(0) - .clone() - ) assert input_ids_seq_chunk.size(1) == prefill_chunk_size, ( - f"prefill chunk size was not equal to the chunk size. Found {input_ids_seq_chunk.size(0)}" + f"prefill chunk size was not equal to the chunk size for input_ids. 
Found {input_ids_seq_chunk.size(0)}" ) - slots_length = len(slot_mapping[seq_i]) + + slot_mapping_seq_chunk = slot_mapping[seq_i][chunk_start:chunk_end] + if required_extra_pads > 0: + slot_mapping_seq_chunk = slot_mapping_seq_chunk[chunk_start:chunk_start+BLOCK_SIZE] * (required_extra_pads//BLOCK_SIZE) + slot_mapping_seq_chunk slot_mapping_seq_chunk = ( torch.tensor( - slot_mapping[seq_i][ - chunk_start + slots_length : chunk_end - + slots_length - ], + slot_mapping_seq_chunk, dtype=torch.int64, ) .unsqueeze(0) .clone() ) - pids_length = kwargs["position_ids"][seq_i].shape[0] - position_ids_seq_chunk = ( - kwargs["position_ids"][seq_i][ - chunk_start + pids_length : chunk_end + pids_length - ] - .unsqueeze(0) - .clone() + + assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, ( + f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}" ) - # This view will result in a discontiguous tensor (creates a new graph during compile) - # For this reason, we must explicitly make contiguous - if left_padded_prompt_mask_seq_chunk is None: - left_padded_prompt_mask_seq_chunk = ( - position_ids_seq_chunk == 0 - ).sum(dim=1) - 1 - current_tkv_mask_seq_chunk = torch.min( - torch.tensor( - (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64 - ), - current_tkv, - ).unsqueeze(0) + position_ids_seq_chunk = kwargs["position_ids"][seq_i][chunk_start:chunk_end] + if required_extra_pads > 0: + position_ids_seq_chunk = torch.cat(( + torch.zeros(required_extra_pads, dtype=torch.int64, device=position_ids_seq_chunk.device), + position_ids_seq_chunk + )) + position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(0).clone() + + assert position_ids_seq_chunk.size(1) == prefill_chunk_size, ( + f"prefill chunk size was not equal to the chunk size for position_ids. 
Found {position_ids_seq_chunk.size(0)}" + ) + + current_tkv_mask_seq_chunk = torch.tensor(prefill_chunk_size, dtype=torch.int64).unsqueeze(0) - table_length = len(block_table[seq_i]) - block_start = -current_tkv // BLOCK_SIZE + table_length - block_end = chunk_end // BLOCK_SIZE + table_length + block_start = chunk_start // BLOCK_SIZE + block_end = chunk_end // BLOCK_SIZE block_table_seq_chunk = torch.tensor( block_table[seq_i][block_start:block_end], dtype=torch.int64 ).unsqueeze(0) From 6fa414cd64522390ef718ed0b32fa297021a6ec4 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 13 Nov 2025 07:04:30 +0000 Subject: [PATCH 08/13] compile/inference is now completing Signed-off-by: Joshua Rosenkranz --- aiu_fms_testing_utils/utils/paged.py | 112 +++++++++++++++++++++------ 1 file changed, 90 insertions(+), 22 deletions(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index bb4b82ef..f9265009 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -301,15 +301,23 @@ def generate( last_n_tokens = kwargs.get("last_n_tokens", 0) if prefill_chunk_size > 0: - required_extra_pads = get_pad_size(current_tkv, prefill_chunk_size) - current_tkv + required_extra_pads = ( + get_pad_size(current_tkv.item(), prefill_chunk_size) + - current_tkv.item() + ) + left_padded_prompt_mask_seq_chunk = ( + (kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum( + dim=0 + ) + - 1 + + required_extra_pads + ) left_padded_prompt_mask_seq_chunk = ( - kwargs["position_ids"][seq_i] == 0 - ).sum(dim=0) - 1 + required_extra_pads - left_padded_prompt_mask_seq_chunk = left_padded_prompt_mask_seq_chunk.unsqueeze(0) + left_padded_prompt_mask_seq_chunk.unsqueeze(0) + ) # Chunked prefill for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)): - if chunk_j == 0: chunk_start = 0 chunk_end = prefill_chunk_size - required_extra_pads @@ -318,22 +326,49 @@ def generate( chunk_start = chunk_end chunk_end += prefill_chunk_size - input_ids_seq_chunk = input_ids[seq_i][chunk_start: chunk_end] + input_ids_seq_chunk = input_ids[seq_i][chunk_start:chunk_end] if required_extra_pads > 0: - input_ids_seq_chunk = torch.cat(( - torch.zeros(required_extra_pads, dtype=torch.int64, device=input_ids_seq_chunk.device), - input_ids_seq_chunk - )) + input_ids_seq_chunk = torch.cat( + ( + torch.zeros( + required_extra_pads, + dtype=torch.int64, + device=input_ids_seq_chunk.device, + ), + input_ids_seq_chunk, + ) + ) + if os.environ["LOCAL_RANK"] == "0": + print("pads were required: ", required_extra_pads) + + if os.environ["LOCAL_RANK"] == "0": + print("input_ids[seq_i] - ", input_ids[seq_i].size(0)) + print( + "chunk start - ", + chunk_start, + " chunk end - ", + chunk_end, + ) + print("chunk ", chunk_j, "-", input_ids_seq_chunk.size(0)) + print("current_tkv", current_tkv) input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone() assert input_ids_seq_chunk.size(1) == prefill_chunk_size, ( f"prefill chunk size was not equal to the chunk size for input_ids. 
Found {input_ids_seq_chunk.size(0)}" ) - - slot_mapping_seq_chunk = slot_mapping[seq_i][chunk_start:chunk_end] + + slot_mapping_seq_chunk = slot_mapping[seq_i][ + chunk_start:chunk_end + ] if required_extra_pads > 0: - slot_mapping_seq_chunk = slot_mapping_seq_chunk[chunk_start:chunk_start+BLOCK_SIZE] * (required_extra_pads//BLOCK_SIZE) + slot_mapping_seq_chunk + slot_mapping_seq_chunk = ( + slot_mapping_seq_chunk[ + chunk_start : chunk_start + BLOCK_SIZE + ] + * (required_extra_pads // BLOCK_SIZE) + + slot_mapping_seq_chunk + ) slot_mapping_seq_chunk = ( torch.tensor( slot_mapping_seq_chunk, @@ -347,26 +382,59 @@ def generate( f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}" ) - position_ids_seq_chunk = kwargs["position_ids"][seq_i][chunk_start:chunk_end] + position_ids_seq_chunk = kwargs["position_ids"][seq_i][ + chunk_start:chunk_end + ] if required_extra_pads > 0: - position_ids_seq_chunk = torch.cat(( - torch.zeros(required_extra_pads, dtype=torch.int64, device=position_ids_seq_chunk.device), - position_ids_seq_chunk - )) - position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(0).clone() + position_ids_seq_chunk = torch.cat( + ( + torch.zeros( + required_extra_pads, + dtype=torch.int64, + device=position_ids_seq_chunk.device, + ), + position_ids_seq_chunk, + ) + ) + position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze( + 0 + ).clone() assert position_ids_seq_chunk.size(1) == prefill_chunk_size, ( f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}" ) - current_tkv_mask_seq_chunk = torch.tensor(prefill_chunk_size, dtype=torch.int64).unsqueeze(0) + current_tkv_mask_seq_chunk = torch.tensor( + (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64 + ).unsqueeze(0) - block_start = chunk_start // BLOCK_SIZE block_end = chunk_end // BLOCK_SIZE block_table_seq_chunk = torch.tensor( - block_table[seq_i][block_start:block_end], dtype=torch.int64 + block_table[seq_i][:1] + * ( + (prefill_chunk_size - chunk_end - chunk_start) + // BLOCK_SIZE + ) + + block_table[seq_i][:block_end], + dtype=torch.int64, ).unsqueeze(0) + if os.environ["LOCAL_RANK"] == "0": + print("slot_mapping - ", slot_mapping_seq_chunk.shape) + print("position_ids - ", position_ids_seq_chunk.shape) + print( + "left_padded_prompt_mask_seq_chunk -", + left_padded_prompt_mask_seq_chunk, + ) + print("current_tkv_mask -", current_tkv_mask_seq_chunk) + print( + "block_table_seq_chunk - ", + block_table_seq_chunk.shape, + " - ", + block_table_seq_chunk.tolist(), + ) + print("input_ids_seq_chunk - ", input_ids_seq_chunk.shape) + chunked_kwargs = { "slot_mapping": slot_mapping_seq_chunk, "position_ids": position_ids_seq_chunk, From 13df6396260f86459702d9ae5178f4bdac55e1c4 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 13 Nov 2025 14:14:53 +0000 Subject: [PATCH 09/13] added tests Signed-off-by: Joshua Rosenkranz --- aiu_fms_testing_utils/utils/paged.py | 29 ---------------------------- tests/models/test_scripts.py | 21 +++++++++++++++----- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index f9265009..7679bbcc 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -338,19 +338,6 @@ def generate( input_ids_seq_chunk, ) ) - if os.environ["LOCAL_RANK"] == "0": - print("pads were required: ", required_extra_pads) - - if os.environ["LOCAL_RANK"] == "0": - 
print("input_ids[seq_i] - ", input_ids[seq_i].size(0)) - print( - "chunk start - ", - chunk_start, - " chunk end - ", - chunk_end, - ) - print("chunk ", chunk_j, "-", input_ids_seq_chunk.size(0)) - print("current_tkv", current_tkv) input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone() @@ -419,22 +406,6 @@ def generate( dtype=torch.int64, ).unsqueeze(0) - if os.environ["LOCAL_RANK"] == "0": - print("slot_mapping - ", slot_mapping_seq_chunk.shape) - print("position_ids - ", position_ids_seq_chunk.shape) - print( - "left_padded_prompt_mask_seq_chunk -", - left_padded_prompt_mask_seq_chunk, - ) - print("current_tkv_mask -", current_tkv_mask_seq_chunk) - print( - "block_table_seq_chunk - ", - block_table_seq_chunk.shape, - " - ", - block_table_seq_chunk.tolist(), - ) - print("input_ids_seq_chunk - ", input_ids_seq_chunk.shape) - chunked_kwargs = { "slot_mapping": slot_mapping_seq_chunk, "position_ids": position_ids_seq_chunk, diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index a725e43f..67d4a53c 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -175,12 +175,15 @@ def execute_dpp( test_type, skip_validation, enforce_homogeneous_prompt_programs, + prefill_chunk_size, shared_tmp_path, isolated_env, ): isolated_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024" isolated_env["VLLM_DT_MAX_CONTEXT_LEN"] = "512" isolated_env["VLLM_DT_MAX_BATCH_SIZE"] = "2" + if prefill_chunk_size > 0: + isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}" Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True) os.environ.setdefault( "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache") @@ -239,6 +242,9 @@ def execute_dpp( if enforce_homogeneous_prompt_programs: command_list += ["--enforce_homogeneous_prompt_programs"] + if prefill_chunk_size > 0: + command_list += [f"--prefill_chunk_size={prefill_chunk_size}"] + # add program criteria path command_list += [ f"--program_criteria_json_path={os.environ['DT_PROG_CRITERIA_FILEPATH']}" @@ -249,21 +255,24 @@ def execute_dpp( dpp_possibilities = [] dpp_possibilities.append( - ("paged", None, 8, "sharegpt", "metrics", False, False) + ("paged", None, 8, "sharegpt", "metrics", False, False, 0) ) # metrics and run all programs dpp_possibilities.append( - ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False) + ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False, 0) ) # tokens and run all programs that satisfy 256 sequence length dpp_possibilities.append( - ("paged", "*:>=2,0", 65, "sharegpt", None, True, True) + ("paged", "*:>=2,0", 65, "sharegpt", None, True, True, 0) ) # metrics and run all programs that have >=2 batch size dpp_possibilities.append( - ("paged", None, 8, "custom", "tokens", False, False) + ("paged", None, 8, "custom", "tokens", False, False, 0) ) # tokens running with specific custom dataset +dpp_possibilities.append( + ("paged", None, 8, "sharegpt", "tokens", False, False, 128) +) # metrics and run all programs @pytest.mark.parametrize( - "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs", + "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs,prefill_chunk_size", dpp_possibilities, ) def test_dpp_script( @@ -274,6 +283,7 @@ def test_dpp_script( test_type, skip_validation, enforce_homogeneous_prompt_programs, + prefill_chunk_size, shared_tmp_path, isolated_env, ): @@ -290,6 +300,7 @@ def test_dpp_script( test_type, 
skip_validation, enforce_homogeneous_prompt_programs, + prefill_chunk_size, shared_tmp_path, isolated_env, ) From cc465016435e3b91c469b40639fdcac044cb265f Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 13 Nov 2025 16:52:51 +0000 Subject: [PATCH 10/13] addressed PR comments Signed-off-by: Joshua Rosenkranz --- .../scripts/drive_paged_programs.py | 6 +- aiu_fms_testing_utils/utils/paged.py | 57 ++++++++++--------- tests/models/test_scripts.py | 2 +- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py index 1c51f777..c10652b1 100644 --- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py +++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py @@ -580,17 +580,13 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user enforce_sizes = [valid_prompt_shape[1]] if args.enforce_homogeneous_prompt_programs: - # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length + # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) possible_seq_lengths = [ _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) ] # favor sequences that are close to the valid prompt length possible_seq_lengths.reverse() - # add the valid prompt size to the end since it will already exist in the above enforce_sizes - possible_seq_lengths = possible_seq_lengths + [ - valid_prompt_shape[1] - ] enforce_sizes = enforce_sizes + list( itertools.islice( itertools.cycle(possible_seq_lengths), diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 7679bbcc..88207fb2 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -318,6 +318,7 @@ def generate( # Chunked prefill for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)): + # chunk_start and chunk_end are the index mappings from the original sequence if chunk_j == 0: chunk_start = 0 chunk_end = prefill_chunk_size - required_extra_pads @@ -327,6 +328,14 @@ def generate( chunk_end += prefill_chunk_size input_ids_seq_chunk = input_ids[seq_i][chunk_start:chunk_end] + slot_mapping_seq_chunk = slot_mapping[seq_i][ + chunk_start:chunk_end + ] + position_ids_seq_chunk = kwargs["position_ids"][seq_i][ + chunk_start:chunk_end + ] + + # add the extra required padding to chunk if required_extra_pads > 0: input_ids_seq_chunk = torch.cat( ( @@ -338,17 +347,6 @@ def generate( input_ids_seq_chunk, ) ) - - input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone() - - assert input_ids_seq_chunk.size(1) == prefill_chunk_size, ( - f"prefill chunk size was not equal to the chunk size for input_ids. 
Found {input_ids_seq_chunk.size(0)}" - ) - - slot_mapping_seq_chunk = slot_mapping[seq_i][ - chunk_start:chunk_end - ] - if required_extra_pads > 0: slot_mapping_seq_chunk = ( slot_mapping_seq_chunk[ chunk_start : chunk_start + BLOCK_SIZE @@ -356,23 +354,6 @@ def generate( * (required_extra_pads // BLOCK_SIZE) + slot_mapping_seq_chunk ) - slot_mapping_seq_chunk = ( - torch.tensor( - slot_mapping_seq_chunk, - dtype=torch.int64, - ) - .unsqueeze(0) - .clone() - ) - - assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, ( - f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}" - ) - - position_ids_seq_chunk = kwargs["position_ids"][seq_i][ - chunk_start:chunk_end - ] - if required_extra_pads > 0: position_ids_seq_chunk = torch.cat( ( torch.zeros( @@ -383,10 +364,30 @@ def generate( position_ids_seq_chunk, ) ) + + input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone() + + slot_mapping_seq_chunk = ( + torch.tensor( + slot_mapping_seq_chunk, + dtype=torch.int64, + ) + .unsqueeze(0) + .clone() + ) + position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze( 0 ).clone() + assert input_ids_seq_chunk.size(1) == prefill_chunk_size, ( + f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(0)}" + ) + + assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, ( + f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}" + ) + assert position_ids_seq_chunk.size(1) == prefill_chunk_size, ( f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}" ) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 67d4a53c..e813daec 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -268,7 +268,7 @@ def execute_dpp( ) # tokens running with specific custom dataset dpp_possibilities.append( ("paged", None, 8, "sharegpt", "tokens", False, False, 128) -) # metrics and run all programs +) # metrics and run all programs with chunked prefill @pytest.mark.parametrize( From c28aed86af52b4536ded66fb62fd355438dcf9e5 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 13 Nov 2025 16:55:52 +0000 Subject: [PATCH 11/13] disable cache for chunk prefill in test_scripts Signed-off-by: Joshua Rosenkranz --- tests/models/test_scripts.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index e813daec..83885339 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -185,10 +185,15 @@ def execute_dpp( if prefill_chunk_size > 0: isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}" Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True) - os.environ.setdefault( - "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache") - ) - isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1" + + # only enable for non-chunk + if prefill_chunk_size == 0: + os.environ.setdefault( + "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache") + ) + isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1" + else: + isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "0" command_list = [ "python3", From 40f88e1fbdb3a95bc118ef1bf3cfd97dc6ed4bac Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Fri, 14 Nov 2025 00:22:22 +0000 Subject: [PATCH 12/13] adding pad block for chunked prefill; Signed-off-by: Joshua Rosenkranz --- 
aiu_fms_testing_utils/utils/paged.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 88207fb2..3dd7f579 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -227,6 +227,15 @@ def generate( # left_padded_prompt_mask - empty_slots + context_lengths current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1]) + # if using chunked prefill, reserve a pad block + # reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks + if prefill_chunk_size > 0: + pad_block_id = block_numbers.pop(0) + pad_slots = [ + (pad_block_id * BLOCK_SIZE) + (pos_i % BLOCK_SIZE) + for pos_i in range(BLOCK_SIZE) + ] + slot_mapping = [] block_table = [] # each sequence has the possibility of a different tkv, so loop over that @@ -348,10 +357,7 @@ def generate( ) ) slot_mapping_seq_chunk = ( - slot_mapping_seq_chunk[ - chunk_start : chunk_start + BLOCK_SIZE - ] - * (required_extra_pads // BLOCK_SIZE) + pad_slots * (required_extra_pads // BLOCK_SIZE) + slot_mapping_seq_chunk ) position_ids_seq_chunk = torch.cat( @@ -398,7 +404,7 @@ def generate( block_end = chunk_end // BLOCK_SIZE block_table_seq_chunk = torch.tensor( - block_table[seq_i][:1] + [pad_block_id] * ( (prefill_chunk_size - chunk_end - chunk_start) // BLOCK_SIZE From b6175ac61aad5fdd52c2c315ff271b090198d93f Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Fri, 14 Nov 2025 15:18:54 +0000 Subject: [PATCH 13/13] removed unnecessary % Signed-off-by: Joshua Rosenkranz --- aiu_fms_testing_utils/utils/paged.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 3dd7f579..c02bd8f3 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -231,10 +231,7 @@ def generate( # reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks if prefill_chunk_size > 0: pad_block_id = block_numbers.pop(0) - pad_slots = [ - (pad_block_id * BLOCK_SIZE) + (pos_i % BLOCK_SIZE) - for pos_i in range(BLOCK_SIZE) - ] + pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)] slot_mapping = [] block_table = []
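
For reference, patches 06-13 above split each sequence's prefill into fixed-size chunks and fold the extra left padding into the first chunk, so every chunk handed to the model is exactly prefill_chunk_size tokens long. Below is a minimal, illustrative sketch of that boundary arithmetic only — it is not the repository code; round_up stands in for the get_pad_size helper imported in patch 07, and the returned tuple layout is an assumption made for this example.

    import math

    def round_up(n: int, multiple: int) -> int:
        # stand-in for get_pad_size: smallest multiple of `multiple` that is >= n
        return math.ceil(n / multiple) * multiple

    def chunk_bounds(current_tkv: int, prefill_chunk_size: int):
        # yields (chunk_start, chunk_end, extra_pads) index ranges into the
        # original sequence, mirroring the chunk loop reworked in patch 07
        required_extra_pads = round_up(current_tkv, prefill_chunk_size) - current_tkv
        chunk_end = 0
        for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
            if chunk_j == 0:
                # the first chunk is shortened by the pads that will be prepended to it
                chunk_start, chunk_end = 0, prefill_chunk_size - required_extra_pads
                yield chunk_start, chunk_end, required_extra_pads
            else:
                chunk_start = chunk_end
                chunk_end += prefill_chunk_size
                yield chunk_start, chunk_end, 0

    # Example: a 150-token sequence with 64-token chunks needs 42 pads in chunk 0
    # -> [(0, 22, 42), (22, 86, 0), (86, 150, 0)]; each chunk covers 64 positions
    # once the pads are prepended to the first one.
    print(list(chunk_bounds(150, 64)))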
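
Patches 12 and 13 additionally reserve a dedicated pad block so the padded slots of that first chunk never alias a block holding real tokens (pad writes happen in parallel and could otherwise corrupt real blocks). The snippet below is a hedged sketch of that slot computation, assuming BLOCK_SIZE = 64 and a toy free list of block ids rather than the real allocator.

    BLOCK_SIZE = 64  # page size assumed throughout the patches

    block_numbers = list(range(8))  # toy free list of physical KV-cache blocks

    # take one block out of circulation; all pad positions map into its slots
    pad_block_id = block_numbers.pop(0)
    pad_slots = [pad_block_id * BLOCK_SIZE + pos_i for pos_i in range(BLOCK_SIZE)]

    # required_extra_pads is always a multiple of BLOCK_SIZE here (both the padded
    # prompt length and the chunk size are multiples of 64), so the pad slots are
    # repeated in front of the real slot mapping for the first chunk only
    required_extra_pads = 128
    real_slots = [block_numbers[0] * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
    slot_mapping_seq_chunk = pad_slots * (required_extra_pads // BLOCK_SIZE) + real_slots
    assert len(slot_mapping_seq_chunk) == required_extra_pads + len(real_slots)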