139 changes: 96 additions & 43 deletions aiu_fms_testing_utils/utils/paged.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
import torch
import fms.utils.spyre.paged # noqa
from aiu_fms_testing_utils.utils import get_pad_size


def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -226,6 +227,12 @@ def generate(
# left_padded_prompt_mask - empty_slots + context_lengths
current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1])

# if using chunked prefill, reserve a pad block
# a dedicated pad block is required because writes to pad positions are done in parallel and could otherwise corrupt the real blocks
if prefill_chunk_size > 0:
pad_block_id = block_numbers.pop(0)
pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]

slot_mapping = []
block_table = []
# each sequence has the possibility of a different tkv, so loop over that
@@ -300,61 +307,107 @@ def generate(
last_n_tokens = kwargs.get("last_n_tokens", 0)

if prefill_chunk_size > 0:
left_padded_prompt_mask_seq_chunk = None
required_extra_pads = (
get_pad_size(current_tkv.item(), prefill_chunk_size)
- current_tkv.item()
)
left_padded_prompt_mask_seq_chunk = (
(kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum(
dim=0
)
- 1
+ required_extra_pads
)
left_padded_prompt_mask_seq_chunk = (
left_padded_prompt_mask_seq_chunk.unsqueeze(0)
)

# Chunked prefill
for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
chunk_start = -current_tkv + chunk_j * prefill_chunk_size
chunk_end = -current_tkv + min(
(chunk_j + 1) * prefill_chunk_size, current_tkv
)
# chunk_start and chunk_end are indices into the original (unchunked) sequence for this chunk
if chunk_j == 0:
chunk_start = 0
chunk_end = prefill_chunk_size - required_extra_pads
else:
required_extra_pads = 0
chunk_start = chunk_end
chunk_end += prefill_chunk_size
Comment on lines +328 to +334 (Contributor):
I know what chunk_start and chunk_end mean here, but I don't think they're the best names for these variables, as they are more of a mapping between the original sequence and its chunk partition. I don't know what would be a better name, maybe just a comment explaining what they are?
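For illustration only: a minimal sketch of what chunk_start and chunk_end resolve to in the loop above, assuming get_pad_size rounds up to the next multiple of the chunk size (the 300-token / 128-chunk numbers below are made up):

import math

def chunk_bounds(current_tkv: int, prefill_chunk_size: int):
    # yield (chunk_start, chunk_end) slices into the original sequence,
    # mirroring the chunked-prefill loop in this diff
    padded = math.ceil(current_tkv / prefill_chunk_size) * prefill_chunk_size
    required_extra_pads = padded - current_tkv
    chunk_end = 0
    for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
        if chunk_j == 0:
            chunk_start = 0
            chunk_end = prefill_chunk_size - required_extra_pads
        else:
            chunk_start = chunk_end
            chunk_end += prefill_chunk_size
        yield chunk_start, chunk_end

# e.g. a 300-token prompt with 128-token chunks:
# [(0, 44), (44, 172), (172, 300)] -- the first chunk carries the 84 left pads
print(list(chunk_bounds(300, 128)))

So the first chunk absorbs all of the required left padding and maps to a short prefix of the real sequence, while every later chunk covers a full prefill_chunk_size window.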


input_ids_seq_chunk = input_ids[seq_i][chunk_start:chunk_end]
slot_mapping_seq_chunk = slot_mapping[seq_i][
chunk_start:chunk_end
]
position_ids_seq_chunk = kwargs["position_ids"][seq_i][
chunk_start:chunk_end
]

# add the extra required padding to chunk
if required_extra_pads > 0:
input_ids_seq_chunk = torch.cat(
(
torch.zeros(
required_extra_pads,
dtype=torch.int64,
device=input_ids_seq_chunk.device,
),
input_ids_seq_chunk,
)
)
slot_mapping_seq_chunk = (
pad_slots * (required_extra_pads // BLOCK_SIZE)
+ slot_mapping_seq_chunk
)
position_ids_seq_chunk = torch.cat(
(
torch.zeros(
required_extra_pads,
dtype=torch.int64,
device=position_ids_seq_chunk.device,
),
position_ids_seq_chunk,
)
)

input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone()

ids_length = input_ids[seq_i].shape[0]
input_ids_seq_chunk = (
input_ids[seq_i][
chunk_start + ids_length : chunk_end + ids_length
]
.unsqueeze(0)
.clone()
)
slots_length = len(slot_mapping[seq_i])
slot_mapping_seq_chunk = (
torch.tensor(
slot_mapping[seq_i][
chunk_start + slots_length : chunk_end
+ slots_length
],
slot_mapping_seq_chunk,
dtype=torch.int64,
)
.unsqueeze(0)
.clone()
)
pids_length = kwargs["position_ids"][seq_i].shape[0]
position_ids_seq_chunk = (
kwargs["position_ids"][seq_i][
chunk_start + pids_length : chunk_end + pids_length
]
.unsqueeze(0)
.clone()

position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(
0
).clone()

assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(0)}"
)

# This view will result in a discontiguous tensor (creates a new graph during compile)
# For this reason, we must explicitly make contiguous
if left_padded_prompt_mask_seq_chunk is None:
left_padded_prompt_mask_seq_chunk = (
position_ids_seq_chunk == 0
).sum(dim=1) - 1
current_tkv_mask_seq_chunk = torch.min(
torch.tensor(
(chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
),
current_tkv,
assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, (
f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}"
)

assert position_ids_seq_chunk.size(1) == prefill_chunk_size, (
f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}"
)

current_tkv_mask_seq_chunk = torch.tensor(
(chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
).unsqueeze(0)

table_length = len(block_table[seq_i])
block_start = -current_tkv // BLOCK_SIZE + table_length
block_end = chunk_end // BLOCK_SIZE + table_length
block_end = chunk_end // BLOCK_SIZE
block_table_seq_chunk = torch.tensor(
block_table[seq_i][block_start:block_end], dtype=torch.int64
[pad_block_id]
* (
(prefill_chunk_size - chunk_end - chunk_start)
// BLOCK_SIZE
)
+ block_table[seq_i][:block_end],
dtype=torch.int64,
).unsqueeze(0)

chunked_kwargs = {
@@ -577,12 +630,12 @@ def generate(
return result


# this value defaults to 2080 to be consistent with vllm for granite 3.3 8b instruct
# this value defaults to 8192 to be consistent with vllm for granite 3.3 8b instruct w/ chunked prefill
KVCACHE_NUM_BLOCKS_HINT = int(
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 8192)
)

VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))
VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288))


class ProgramCriteria:
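To make the padding arithmetic above concrete, here is a minimal sketch of the pad-block bookkeeping introduced in this file (BLOCK_SIZE and the free-block list are hypothetical stand-ins; the real values come from the surrounding generate() code, and required_extra_pads is assumed to be a multiple of BLOCK_SIZE as in the diff):

BLOCK_SIZE = 64
block_numbers = list(range(1, 9))  # free KV-cache blocks (stand-in values)

# reserve one block whose slots absorb every pad write, so parallel writes to
# pad positions can never land in a block that holds real tokens
pad_block_id = block_numbers.pop(0)
pad_slots = [pad_block_id * BLOCK_SIZE + pos_i for pos_i in range(BLOCK_SIZE)]

# a chunk that needs 128 pad positions reuses the pad block's slots twice
required_extra_pads = 128
pad_prefix = pad_slots * (required_extra_pads // BLOCK_SIZE)
assert len(pad_prefix) == required_extra_pads

Reusing a single reserved block for all pad slots means the parallel pad writes only ever touch that throwaway block.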
34 changes: 25 additions & 9 deletions tests/models/test_scripts.py
@@ -175,17 +175,25 @@ def execute_dpp(
test_type,
skip_validation,
enforce_homogeneous_prompt_programs,
prefill_chunk_size,
shared_tmp_path,
isolated_env,
):
isolated_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024"
isolated_env["VLLM_DT_MAX_CONTEXT_LEN"] = "512"
isolated_env["VLLM_DT_MAX_BATCH_SIZE"] = "2"
if prefill_chunk_size > 0:
isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}"
Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True)
os.environ.setdefault(
"TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
)
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"

# only enable the compile cache for non-chunked prefill
if prefill_chunk_size == 0:
os.environ.setdefault(
"TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
)
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
else:
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "0"

command_list = [
"python3",
@@ -239,6 +247,9 @@ def execute_dpp(
if enforce_homogeneous_prompt_programs:
command_list += ["--enforce_homogeneous_prompt_programs"]

if prefill_chunk_size > 0:
command_list += [f"--prefill_chunk_size={prefill_chunk_size}"]

# add program criteria path
command_list += [
f"--program_criteria_json_path={os.environ['DT_PROG_CRITERIA_FILEPATH']}"
@@ -249,21 +260,24 @@

dpp_possibilities = []
dpp_possibilities.append(
("paged", None, 8, "sharegpt", "metrics", False, False)
("paged", None, 8, "sharegpt", "metrics", False, False, 0)
) # metrics and run all programs
dpp_possibilities.append(
("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False)
("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False, 0)
) # tokens and run all programs that satisfy 256 sequence length
dpp_possibilities.append(
("paged", "*:>=2,0", 65, "sharegpt", None, True, True)
("paged", "*:>=2,0", 65, "sharegpt", None, True, True, 0)
) # metrics and run all programs that have >=2 batch size
dpp_possibilities.append(
("paged", None, 8, "custom", "tokens", False, False)
("paged", None, 8, "custom", "tokens", False, False, 0)
) # tokens running with specific custom dataset
dpp_possibilities.append(
("paged", None, 8, "sharegpt", "tokens", False, False, 128)
) # tokens and run all programs with chunked prefill


@pytest.mark.parametrize(
"attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs",
"attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs,prefill_chunk_size",
dpp_possibilities,
)
def test_dpp_script(
@@ -274,6 +288,7 @@
test_type,
skip_validation,
enforce_homogeneous_prompt_programs,
prefill_chunk_size,
shared_tmp_path,
isolated_env,
):
@@ -290,6 +305,7 @@
test_type,
skip_validation,
enforce_homogeneous_prompt_programs,
prefill_chunk_size,
shared_tmp_path,
isolated_env,
)
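For reference, a sketch of how the new prefill_chunk_size parametrization above drives the environment and CLI wiring in execute_dpp (this mirrors the test; the script path is a placeholder):

import os

prefill_chunk_size = 128
env = dict(os.environ)
env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024"
env["VLLM_DT_MAX_CONTEXT_LEN"] = "512"
env["VLLM_DT_MAX_BATCH_SIZE"] = "2"
if prefill_chunk_size > 0:
    env["VLLM_DT_CHUNK_LEN"] = str(prefill_chunk_size)
    env["TORCH_SENDNN_CACHE_ENABLE"] = "0"  # compile cache stays off for chunked runs

cmd = ["python3", "path/to/dpp_script.py", f"--prefill_chunk_size={prefill_chunk_size}"]
print(cmd)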