diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index bfec131..c02bd8f 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -5,6 +5,7 @@ from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 
 import torch
 import fms.utils.spyre.paged  # noqa
+from aiu_fms_testing_utils.utils import get_pad_size
 
 
 def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -226,6 +227,12 @@ def generate(
     # left_padded_prompt_mask - empty_slots + context_lengths
     current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1])
 
+    # if using chunked prefill, reserve a pad block
+    # reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks
+    if prefill_chunk_size > 0:
+        pad_block_id = block_numbers.pop(0)
+        pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]
+
     slot_mapping = []
     block_table = []
     # each sequence has the possibility of a different tkv, so loop over that
@@ -300,61 +307,107 @@ def generate(
                 last_n_tokens = kwargs.get("last_n_tokens", 0)
 
                 if prefill_chunk_size > 0:
-                    left_padded_prompt_mask_seq_chunk = None
+                    required_extra_pads = (
+                        get_pad_size(current_tkv.item(), prefill_chunk_size)
+                        - current_tkv.item()
+                    )
+                    left_padded_prompt_mask_seq_chunk = (
+                        (kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum(
+                            dim=0
+                        )
+                        - 1
+                        + required_extra_pads
+                    )
+                    left_padded_prompt_mask_seq_chunk = (
+                        left_padded_prompt_mask_seq_chunk.unsqueeze(0)
+                    )
+
                     # Chunked prefill
                     for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
-                        chunk_start = -current_tkv + chunk_j * prefill_chunk_size
-                        chunk_end = -current_tkv + min(
-                            (chunk_j + 1) * prefill_chunk_size, current_tkv
-                        )
+                        # chunk_start and chunk_end are the index mappings from the original sequence
+                        if chunk_j == 0:
+                            chunk_start = 0
+                            chunk_end = prefill_chunk_size - required_extra_pads
+                        else:
+                            required_extra_pads = 0
+                            chunk_start = chunk_end
+                            chunk_end += prefill_chunk_size
+
+                        input_ids_seq_chunk = input_ids[seq_i][chunk_start:chunk_end]
+                        slot_mapping_seq_chunk = slot_mapping[seq_i][
+                            chunk_start:chunk_end
+                        ]
+                        position_ids_seq_chunk = kwargs["position_ids"][seq_i][
+                            chunk_start:chunk_end
+                        ]
+
+                        # add the extra required padding to chunk
+                        if required_extra_pads > 0:
+                            input_ids_seq_chunk = torch.cat(
+                                (
+                                    torch.zeros(
+                                        required_extra_pads,
+                                        dtype=torch.int64,
+                                        device=input_ids_seq_chunk.device,
+                                    ),
+                                    input_ids_seq_chunk,
+                                )
+                            )
+                            slot_mapping_seq_chunk = (
+                                pad_slots * (required_extra_pads // BLOCK_SIZE)
+                                + slot_mapping_seq_chunk
+                            )
+                            position_ids_seq_chunk = torch.cat(
+                                (
+                                    torch.zeros(
+                                        required_extra_pads,
+                                        dtype=torch.int64,
+                                        device=position_ids_seq_chunk.device,
+                                    ),
+                                    position_ids_seq_chunk,
+                                )
+                            )
+
+                        input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone()
 
-                        ids_length = input_ids[seq_i].shape[0]
-                        input_ids_seq_chunk = (
-                            input_ids[seq_i][
-                                chunk_start + ids_length : chunk_end + ids_length
-                            ]
-                            .unsqueeze(0)
-                            .clone()
-                        )
-                        slots_length = len(slot_mapping[seq_i])
                         slot_mapping_seq_chunk = (
                             torch.tensor(
-                                slot_mapping[seq_i][
-                                    chunk_start + slots_length : chunk_end
-                                    + slots_length
-                                ],
+                                slot_mapping_seq_chunk,
                                 dtype=torch.int64,
                             )
                             .unsqueeze(0)
                             .clone()
                         )
-                        pids_length = kwargs["position_ids"][seq_i].shape[0]
-                        position_ids_seq_chunk = (
-                            kwargs["position_ids"][seq_i][
-                                chunk_start + pids_length : chunk_end + pids_length
-                            ]
-                            .unsqueeze(0)
-                            .clone()
+
+                        position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(
+                            0
+                        ).clone()
+
+                        assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                            f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(1)}"
                         )
-                        # This view will result in a discontiguous tensor (creates a new graph during compile)
-                        # For this reason, we must explicitly make contiguous
-                        if left_padded_prompt_mask_seq_chunk is None:
-                            left_padded_prompt_mask_seq_chunk = (
-                                position_ids_seq_chunk == 0
-                            ).sum(dim=1) - 1
-                        current_tkv_mask_seq_chunk = torch.min(
-                            torch.tensor(
-                                (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
-                            ),
-                            current_tkv,
+                        assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, (
+                            f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(1)}"
+                        )
+
+                        assert position_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                            f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(1)}"
+                        )
+
+                        current_tkv_mask_seq_chunk = torch.tensor(
+                            (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
                         ).unsqueeze(0)
-                        table_length = len(block_table[seq_i])
-                        block_start = -current_tkv // BLOCK_SIZE + table_length
-                        block_end = chunk_end // BLOCK_SIZE + table_length
+                        block_end = chunk_end // BLOCK_SIZE
                         block_table_seq_chunk = torch.tensor(
-                            block_table[seq_i][block_start:block_end], dtype=torch.int64
+                            [pad_block_id]
+                            * (
+                                (prefill_chunk_size - chunk_end - chunk_start)
+                                // BLOCK_SIZE
+                            )
+                            + block_table[seq_i][:block_end],
+                            dtype=torch.int64,
                         ).unsqueeze(0)
 
                         chunked_kwargs = {
@@ -577,12 +630,12 @@ def generate(
     return result
 
 
-# this value is default to 2080 to be consistent with vllm for granite 3.3 8b instruct
+# this value defaults to 8192 to be consistent with vllm for granite 3.3 8b instruct w/ chunked prefill
 KVCACHE_NUM_BLOCKS_HINT = int(
-    os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
+    os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 8192)
 )
 
-VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))
+VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288))
 
 
 class ProgramCriteria:
diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py
index a725e43..8388533 100644
--- a/tests/models/test_scripts.py
+++ b/tests/models/test_scripts.py
@@ -175,17 +175,25 @@ def execute_dpp(
     test_type,
     skip_validation,
     enforce_homogeneous_prompt_programs,
+    prefill_chunk_size,
     shared_tmp_path,
     isolated_env,
 ):
     isolated_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024"
     isolated_env["VLLM_DT_MAX_CONTEXT_LEN"] = "512"
     isolated_env["VLLM_DT_MAX_BATCH_SIZE"] = "2"
+    if prefill_chunk_size > 0:
+        isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}"
     Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True)
-    os.environ.setdefault(
-        "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
-    )
-    isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+
+    # only enable the torch_sendnn cache for non-chunked prefill
+    if prefill_chunk_size == 0:
+        os.environ.setdefault(
+            "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
+        )
+        isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    else:
+        isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "0"
 
     command_list = [
         "python3",
@@ -239,6 +247,9 @@ def execute_dpp(
     if enforce_homogeneous_prompt_programs:
         command_list += ["--enforce_homogeneous_prompt_programs"]
 
+    if prefill_chunk_size > 0:
+        command_list += [f"--prefill_chunk_size={prefill_chunk_size}"]
+
     # add program criteria path
     command_list += [
         f"--program_criteria_json_path={os.environ['DT_PROG_CRITERIA_FILEPATH']}"
@@ -249,21 +260,24 @@ def execute_dpp(
 
 dpp_possibilities = []
 dpp_possibilities.append(
-    ("paged", None, 8, "sharegpt", "metrics", False, False)
+    ("paged", None, 8, "sharegpt", "metrics", False, False, 0)
 )  # metrics and run all programs
 dpp_possibilities.append(
-    ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False)
+    ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False, 0)
 )  # tokens and run all programs that satisfy 256 sequence length
 dpp_possibilities.append(
-    ("paged", "*:>=2,0", 65, "sharegpt", None, True, True)
+    ("paged", "*:>=2,0", 65, "sharegpt", None, True, True, 0)
 )  # metrics and run all programs that have >=2 batch size
 dpp_possibilities.append(
-    ("paged", None, 8, "custom", "tokens", False, False)
+    ("paged", None, 8, "custom", "tokens", False, False, 0)
 )  # tokens running with specific custom dataset
+dpp_possibilities.append(
+    ("paged", None, 8, "sharegpt", "tokens", False, False, 128)
+)  # tokens and run all programs with chunked prefill
 
 
 @pytest.mark.parametrize(
-    "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs",
+    "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs,prefill_chunk_size",
     dpp_possibilities,
 )
 def test_dpp_script(
@@ -274,6 +288,7 @@ def test_dpp_script(
     test_type,
     skip_validation,
     enforce_homogeneous_prompt_programs,
+    prefill_chunk_size,
     shared_tmp_path,
     isolated_env,
 ):
@@ -290,6 +305,7 @@ def test_dpp_script(
         test_type,
         skip_validation,
         enforce_homogeneous_prompt_programs,
+        prefill_chunk_size,
         shared_tmp_path,
         isolated_env,
     )
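
Editor's note (not part of the patch): the sketch below illustrates the per-sequence chunk-boundary bookkeeping that the new chunked-prefill path performs. It assumes get_pad_size rounds a length up to the next multiple of the chunk size and that BLOCK_SIZE is 64; both names mirror the diff above, but the helpers here are standalone stand-ins, not the library implementations.

import math

BLOCK_SIZE = 64  # assumed paged-attention block size, mirroring BLOCK_SIZE in the diff


def get_pad_size(length: int, chunk_size: int) -> int:
    # Assumption: rounds `length` up to the next multiple of `chunk_size`.
    return math.ceil(length / chunk_size) * chunk_size


def chunk_plan(current_tkv: int, prefill_chunk_size: int):
    """Reproduce the chunk_start / chunk_end / left-pad bookkeeping of the patched loop."""
    required_extra_pads = get_pad_size(current_tkv, prefill_chunk_size) - current_tkv
    plan = []
    chunk_end = 0
    for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
        if chunk_j == 0:
            # the first chunk absorbs all of the extra left padding
            chunk_start = 0
            chunk_end = prefill_chunk_size - required_extra_pads
            pads = required_extra_pads
        else:
            # later chunks are full-size slices of the original sequence
            pads = 0
            chunk_start = chunk_end
            chunk_end += prefill_chunk_size
        # every chunk handed to the model is exactly prefill_chunk_size positions:
        # `pads` positions mapped to the reserved pad block, then the real tokens
        assert pads + (chunk_end - chunk_start) == prefill_chunk_size
        plan.append((chunk_j, chunk_start, chunk_end, pads))
    return plan


if __name__ == "__main__":
    # a 192-token (block-aligned) prompt with a 128-token chunk:
    # chunk 0 = 64 pad positions + tokens 0:64, chunk 1 = tokens 64:192
    print(chunk_plan(192, 128))  # [(0, 0, 64, 64), (1, 64, 192, 0)]

Assuming current_tkv and prefill_chunk_size are both multiples of BLOCK_SIZE, required_extra_pads is a whole number of blocks, which is why the diff can map those pad positions onto the single reserved pad block via pad_slots * (required_extra_pads // BLOCK_SIZE).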