17 changes: 8 additions & 9 deletions fastdeploy/cache_manager/cache_transfer_manager.py
@@ -930,6 +930,14 @@ def _run_write_back_storage(
return block_num

elif self.storage_backend_type == "attention_store":
try:
if (self.rank == 0) and self.storage_backend_type == "attention_store":
❓ Question: the inner `if`'s check `self.storage_backend_type == "attention_store"` is redundant: this code already runs inside the `elif self.storage_backend_type == "attention_store":` branch, so the condition is always `True`.

Suggested simplification:

if self.rank == 0:
    self.storage_backend.flush_token_index(task_id, token_ids, 0, False)

self.storage_backend.flush_token_index(task_id, token_ids, 0, False)


logger.info(f"Report cache index out HBM to cache storage for task {task_id}")
except Exception as e:
logger.info(
f"Failed to report cache index out HBM to cache storage for task {task_id}, error: {e}"
)
key_cache = []
val_cache = []
for i in range(self.num_layers + self.num_extra_layers):
@@ -1040,15 +1048,6 @@ def write_back_storage_task(self, task: WriteStorageTask):
except Exception as e:
logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}")
gpu_block_ids = []
finally:
try:
if (self.rank == 0) and self.storage_backend_type == "attention_store":
self.storage_backend.flush_token_index(task.task_id, task.token_ids, 0, False)
logger.info(f"Report cache index out HBM to cache storage for task {task.task_id}")
except Exception as e:
logger.info(
f"Failed to report cache index out HBM to cache storage for task {task.task_id}, error: {e}"
)

result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
self.cache_task_queue.swap_to_storage_barrier.wait()
115 changes: 63 additions & 52 deletions fastdeploy/cache_manager/prefix_cache_manager.py
@@ -831,58 +831,71 @@ def request_match_blocks(self, task: Request, block_size, *args):

if self.kvcache_storage_backend and no_match_token_num >= block_size and not envs.FD_AS_ONLY_FLUSH:
if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num, try_free_gpu_blocks=False):
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache"
logger.warning(
"request_match_blocks: skip storage cache prefetch because GPU blocks are insufficient, "
f"req_id {req_id}, need {no_match_block_num}, free {len(self.gpu_free_block_list)}"
)

logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {no_match_block_num} block to receive storage cache"
)
gpu_recv_storage_block_ids = self.allocate_gpu_blocks(no_match_block_num)

prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
else:
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {no_match_block_num} block to receive storage cache"
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
prefix_block_key = [cur_block_key]
gpu_recv_storage_block_ids = self.allocate_gpu_blocks(no_match_block_num)

prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
prefix_block_key = [cur_block_key]
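
Note: the loop above chains block hashes, so each block's key covers its entire prefix (every key hashes the block's tokens together with the previous block's key). A minimal standalone sketch of that chaining, using hashlib as a stand-in for the real get_hash_str (an assumption, not the actual implementation):

import hashlib

def get_hash_str(tokens, prefix_keys):
    # Stand-in: hash the block's tokens together with the prefix keys.
    h = hashlib.sha256()
    for key in prefix_keys:
        h.update(str(key).encode())
    h.update(",".join(map(str, tokens)).encode())
    return h.hexdigest()

block_size = 4
token_ids = list(range(12))  # three full blocks
prefix, keys = [], []
for start in range(0, len(token_ids), block_size):
    block = token_ids[start : start + block_size]
    key = get_hash_str(block, prefix)
    keys.append(key)
    prefix = [key]  # the next block hashes against this block's key only
print(keys)  # three chained keys; changing any early token changes every later key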

logger.info(
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
)
start_time = time.time()
read_storage_task = ReadStorageTask(
task_id=req_id,
keys=no_match_block_keys,
token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None,
gpu_block_ids=gpu_recv_storage_block_ids,
start_read_block_idx=match_token_num // block_size,
)
logger.debug(f"issue read storage task: {read_storage_task}")
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time
metrics["storage_cache_prepare_time"] = cost_time
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, "
f"matched block num: {storage_matched_block_num}, cost_time:{cost_time:.6f}s"
)
try:
logger.info(
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
)
start_time = time.time()
read_storage_task = ReadStorageTask(
task_id=req_id,
keys=no_match_block_keys,
token_ids=(
input_token_ids if self.kvcache_storage_backend == "attention_store" else None
),
gpu_block_ids=gpu_recv_storage_block_ids,
start_read_block_idx=match_token_num // block_size,
)
logger.debug(f"issue read storage task: {read_storage_task}")
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time
metrics["storage_cache_prepare_time"] = cost_time
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, "
f"matched block num: {storage_matched_block_num}, cost_time:{cost_time:.6f}s"
)

match_storage_block_ids = gpu_recv_storage_block_ids[:storage_matched_block_num]
self.recycle_gpu_blocks(gpu_recv_storage_block_ids[storage_matched_block_num:])
match_storage_block_ids = gpu_recv_storage_block_ids[:storage_matched_block_num]
self.recycle_gpu_blocks(gpu_recv_storage_block_ids[storage_matched_block_num:])
except Exception as e:
logger.warning(
"request_match_blocks: storage cache prefetch failed, fallback to cache miss, "
f"req_id {req_id}, error: {type(e)} {e}"
)
self.recycle_gpu_blocks(gpu_recv_storage_block_ids, req_id)
gpu_recv_storage_block_ids = []
storage_match_token_num = 0
match_storage_block_ids = []

# 4. update metrics
match_token_num = gpu_match_token_num + cpu_match_token_num + storage_match_token_num
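
Note: in isolation, the refactored flow above follows an allocate / prefetch / recycle-on-failure shape. A hypothetical standalone skeleton (the callables stand in for the manager's methods; not the actual implementation):

from typing import Callable, List, Tuple

def prefetch_with_fallback(
    allocate: Callable[[int], List[int]],
    prefetch: Callable[[List[int]], List[int]],
    recycle: Callable[[List[int]], None],
    num_blocks: int,
) -> Tuple[List[int], int]:
    blocks = allocate(num_blocks)
    try:
        matched = prefetch(blocks)      # block ids actually filled from storage
        recycle(blocks[len(matched):])  # hand back the unmatched tail
        return blocks[:len(matched)], len(matched)
    except Exception:
        recycle(blocks)                 # any failure degrades to a plain cache miss
        return [], 0

# Example with stubs: 4 blocks allocated, storage fills the first 2.
free = list(range(100))
matched_blocks, n = prefetch_with_fallback(
    allocate=lambda k: [free.pop() for _ in range(k)],
    prefetch=lambda blocks: blocks[:2],
    recycle=lambda blocks: free.extend(blocks),
    num_blocks=4,
)
print(matched_blocks, n)  # two block ids, 2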
@@ -1127,10 +1140,7 @@ def write_cache_to_storage(self, request: Request):
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()

if self.config.cache_config.enable_output_caching:
input_token_ids = token_ids + request.output_token_ids
else:
input_token_ids = token_ids
input_token_ids = token_ids + request.output_token_ids


❓ Question: the `enable_output_caching` conditional was removed, so `token_ids + request.output_token_ids` is now written to the storage cache regardless of whether the switch is enabled.

Original code:

if self.config.cache_config.enable_output_caching:
    input_token_ids = token_ids + request.output_token_ids
else:
    input_token_ids = token_ids

With `enable_output_caching=False`, is writing output tokens to the storage cache intended? Please explain the reason for this behavior change in the PR.


req_id = request.request_id
keys = []
@@ -1144,6 +1154,7 @@ def write_cache_to_storage(self, request: Request):

trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", ""))
gpu_block_ids = request.block_tables[: len(keys)]
input_token_ids = input_token_ids[: len(keys) * self.config.cache_config.block_size]
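
Note: a quick check of the block-aligned truncation above, with hypothetical numbers (block_size of 4, 11 tokens, two hashed keys):

block_size = 4
keys = ["k0", "k1"]                # two full blocks worth of keys
input_token_ids = list(range(11))  # 11 tokens; only 2 full blocks fit
print(input_token_ids[: len(keys) * block_size])  # [0, 1, ..., 7]; the trailing partial block is dropped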
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
write_storage_task = WriteStorageTask(
task_id=req_id,
@@ -189,42 +189,96 @@ def write(
start_write_block_idx: int,
timeout: float = 30.0,
) -> int:
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
k_data_ptrs = [k.data_ptr() for k in key_cache]
v_data_ptrs = [v.data_ptr() for v in val_cache]
num = 0
try:
if current_platform.is_cuda():
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
h2h_copy=False,
params=None,
layer_ids = list(range(self.config.layer_num))
block_token_size = self.config.block_token_size

total_timeout = float(os.getenv("AS_WRITE_TOTAL_TIMEOUT", str(timeout)))
slice_block_num = int(os.getenv("AS_WRITE_SLICE_BLOCK_NUM", "100"))
slice_timeout = float(os.getenv("AS_WRITE_SLICE_TIMEOUT", "10"))
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids}"
f" start_write_block_idx: {start_write_block_idx} timeout: {total_timeout}"
)
total_blocks = len(gpu_block_ids)
total_written = 0
overall_start = time.time()

for slice_start in range(0, total_blocks, slice_block_num):
elapsed = time.time() - overall_start
remaining_timeout = total_timeout - elapsed
if remaining_timeout <= 0:
logger.warning(
f"[WRITE TIMEOUT] task_id: {task_id} total timeout {total_timeout}s reached, "
f"written {total_written}/{total_blocks} blocks"
)
else:
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
break

slice_end = min(slice_start + slice_block_num, total_blocks)
slice_gpu_block_ids = gpu_block_ids[slice_start:slice_end]
slice_write_block_idx = start_write_block_idx + slice_start
slice_token_ids = token_ids[: (start_write_block_idx + slice_end) * block_token_size]
slice_tokens = Tokens(slice_token_ids, block_token_size)
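
Note: a quick check of the token-slice arithmetic above, with hypothetical numbers: if block_token_size is 16, start_write_block_idx is 2, and the slice covers blocks [0:100), then slice_token_ids spans token_ids[:(2 + 100) * 16], i.e. the first 1632 tokens, everything up to and including the last block of this slice.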

logger.debug(
f"[WRITE SLICE BEGIN] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"block_idx={slice_write_block_idx}, blocks={len(slice_gpu_block_ids)}, "
f"token_ids_len={len(slice_token_ids)}, timeout={slice_timeout:.2f}s"
)
slice_start_time = time.time()
try:
if current_platform.is_cuda():
written = self.sdk.write(
layer_ids,
slice_tokens,
slice_write_block_idx,
k_data_ptrs,
v_data_ptrs,
slice_gpu_block_ids,
slice_timeout,
h2h_copy=False,
params=None,
)
else:
written = self.sdk.write(
layer_ids,
slice_tokens,
slice_write_block_idx,
k_data_ptrs,
v_data_ptrs,
slice_gpu_block_ids,
slice_timeout,
)
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] task_id: {task_id} slice [{slice_start}:{slice_end}], "
f"traceback:\n{traceback.format_exc()}"
)
logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
written = 0
slice_cost = time.time() - slice_start_time
total_written += written

if written < len(slice_gpu_block_ids):
logger.warning(
f"[WRITE SLICE INCOMPLETE] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"({written}/{len(slice_gpu_block_ids)}), cost={slice_cost:.6f}s, "
f"total written {total_written}/{total_blocks}, "
f"prefix cache continuity broken, skip remaining slices"
)
break

logger.debug(
f"[WRITE SLICE END] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"written={written}, cost={slice_cost:.6f}s"
)
return num

total_cost = time.time() - overall_start
logger.info(
f"[WRITE END] task_id: {task_id} total cost={total_cost:.6f}s, "
f"written {total_written}/{total_blocks} blocks"
)
return total_written
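
Note: the env-driven slicing above can be exercised on its own. A hypothetical standalone sketch of the budget logic (same env variable names as the diff; the write call itself is omitted):

import os
import time

total_timeout = float(os.getenv("AS_WRITE_TOTAL_TIMEOUT", "30"))
slice_block_num = int(os.getenv("AS_WRITE_SLICE_BLOCK_NUM", "100"))

def iter_write_slices(total_blocks: int):
    overall_start = time.time()
    for slice_start in range(0, total_blocks, slice_block_num):
        if time.time() - overall_start >= total_timeout:
            break  # overall budget exhausted; remaining slices are skipped
        yield slice_start, min(slice_start + slice_block_num, total_blocks)

# 250 blocks in slices of 100 -> (0, 100), (100, 200), (200, 250)
print(list(iter_write_slices(250)))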

def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
"""
7 changes: 5 additions & 2 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -77,7 +77,10 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)

MAX_INFER_SEED = 9223372036854775806
if current_platform.is_xpu():
MAX_INFER_SEED = 2147483646
else:
MAX_INFER_SEED = 9223372036854775806
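
Note (an observation, not from the PR): the two constants look like bit-width bounds, which suggests the XPU seed is handled as a 32-bit integer while other platforms use 64-bit.

# Sanity check of the observation above:
assert 2147483646 == 2**31 - 2
assert 9223372036854775806 == 2**63 - 2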

token_lens = paddle.where(
seq_lens_encoder[:real_bsz] == 0,
@@ -97,7 +100,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le

offsets = paddle.where(
is_decoder,
local_pos * 4,
local_pos * 32,

🔴 Bug: `local_pos * 32` is not gated on the XPU platform, so the sampling offsets change globally for all hardware.

The `MAX_INFER_SEED` change in the same function is correctly gated with `if current_platform.is_xpu():`; this change should be consistent with it. If `local_pos * 32` is an XPU-specific requirement, gate it the same way:

offsets = paddle.where(
    is_decoder,
    local_pos * (32 if current_platform.is_xpu() else 4),
    paddle.zeros_like(local_pos),
)

If `* 32` is a correction meant for all platforms, please state the reason in the PR description.

paddle.zeros_like(local_pos),
)

17 changes: 13 additions & 4 deletions fastdeploy/scheduler/local_scheduler.py
@@ -130,8 +130,14 @@ def _recycle(self, request_id: Optional[str] = None):
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
self.ids.pop(self.ids.index(request_id))
self.ids_read_cursor -= 1
try:
idx = self.ids.index(request_id)
self.ids.pop(idx)
if idx < self.ids_read_cursor:
self.ids_read_cursor -= 1
except ValueError:
scheduler_logger.warning(f"_recycle error, request_id:{request_id} is not found in ids")
pass
return

if self.max_size <= 0:
@@ -148,10 +154,10 @@ def _recycle(self, request_id: Optional[str] = None):
break
expired_ids.append(request.request_id)

for i, expired_id in enumerate(expired_ids):
for expired_id in expired_ids:
self.requests.pop(expired_id, None)
self.responses.pop(expired_id, None)
self.ids.pop(i)
self.ids = self.ids[len(expired_ids) :]
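
Note: a standalone illustration of why the slice replaces the old pop(i) loop (hypothetical minimal lists, not the scheduler itself): popping by enumeration index removes the wrong elements once earlier pops shift the list, while slicing off the expired prefix is correct.

ids = ["a", "b", "c", "d", "e", "f"]
expired_ids = ["a", "b", "c"]  # the three oldest entries

buggy = ids.copy()
for i, _ in enumerate(expired_ids):
    buggy.pop(i)       # pops positions 0, 1, 2 of the *shrinking* list
print(buggy)           # ['b', 'd', 'f']: 'b' survives, 'e' is wrongly dropped

print(ids[len(expired_ids):])  # ['d', 'e', 'f']: the expired prefix removed in one slice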

if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
@@ -234,6 +240,9 @@ def calc_required_blocks(self, token_num, block_size):
return (token_num + block_size - 1) // block_size

def get_unhandled_request_num(self):
scheduler_logger.debug(
f"get_unhandled_request_num len(self.ids):{len(self.ids)}, self.ids_read_cursor:{self.ids_read_cursor}"
)
return len(self.ids) - self.ids_read_cursor

def get_requests(
9 changes: 9 additions & 0 deletions fastdeploy/splitwise/internal_adapter_utils.py
@@ -105,6 +105,15 @@ def _recv_external_module_control_instruct(self):
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)

elif task["cmd"] == "interrupt_requests":
self.engine.resource_manager.add_abort_req_ids(task["req_ids"])
result = {
"task_id": task_id_str,
"result": {"success": True, "interrupted_req_ids": task["req_ids"]},
}
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
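
Note: for reference, the control message shape this new branch consumes and the response it produces (field names mirrored from the diff; the values are hypothetical):

task = {
    "cmd": "interrupt_requests",
    "req_ids": ["req-1", "req-2"],
}

result = {
    "task_id": "task-123",  # hypothetical task_id_str
    "result": {"success": True, "interrupted_req_ids": task["req_ids"]},
}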

except Exception as e:
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

3 changes: 1 addition & 2 deletions fastdeploy/worker/xpu_model_runner.py
@@ -738,8 +738,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

self._process_mm_features(req_dicts)
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["not_need_stop"][0] = has_prefill_task or has_decode_task

if self.spec_method == SpecMethod.MTP:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)