diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 7ecac8ff126..05ee4a348ea 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -50,10 +50,11 @@ class RequestStatus(Enum):
     WAITING = 0
-    RUNNING = 1
-    PREEMPTED = 2
-    FINISHED = 3
-    ABORT = 4
+    RUNNING_PREFILL = 1
+    RUNNING_DECODE = 2
+    PREEMPTED = 3
+    FINISHED = 4
+    ABORT = 5


 class RequestType(Enum):
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 3382e077d60..75b87d7a4c1 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -218,18 +218,23 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         self.bos_client = None
         self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
-        self.init_reserve_output_block_num = (
-            envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # int
-        self.decay_output_block_num = (
-            envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # float
-        self.min_reserve_output_block_num = (
-            envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
-        )  # int
-        self.current_reserve_output_block_num = self.init_reserve_output_block_num
-        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
-        self.can_relax_prefill_strategy = True
+        self.use_new_token_ratio_reserve = envs.FD_USE_NEW_TOKEN_RATIO_RESERVE
+        if self.use_new_token_ratio_reserve:
+            self.init_new_token_ratio = envs.FD_INIT_NEW_TOKEN_RATIO
+            self.min_new_token_ratio = envs.FD_MIN_NEW_TOKEN_RATIO
+            self.new_token_ratio_decay = envs.FD_NEW_TOKEN_RATIO_DECAY
+            self.clip_max_new_tokens = envs.FD_CLIP_MAX_NEW_TOKENS
+            self.new_token_ratio = self.init_new_token_ratio
+        else:
+            self.init_reserve_output_block_num = envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
+            self.decay_output_block_num = envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
+            self.min_reserve_output_block_num = (
+                envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
+            )
+            self.current_reserve_output_block_num = self.init_reserve_output_block_num
+            self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num)
+            self.can_relax_prefill_strategy = True
+        # Scheduler-side requests that have not been moved into resource manager waiting queue yet.
         self.scheduler_unhandled_request_num = 0
@@ -312,15 +317,6 @@ def _info_each_block(self):
                 f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables"
             )

-    def _can_preempt(self):
-        """
-        cannot preempt request which use extend block
-        """
-        for req in self.running:
-            if not req.use_extend_tables:
-                return True
-        return False
-
     def preempted_all(self):
         with self.lock:
             preempted_reqs = []
@@ -355,17 +351,49 @@ def wait_worker_inflight_requests_finish(self, timeout=60):
                     f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
                 )

+    def _select_preempt_candidate(self):
+        # Scan from back to front to find the last preemptable request
+        preempted_req = None
+        i = len(self.running) - 1
+        while i >= 0:
+            candidate = self.running[i]
+            # Skip requests that are not in decode status
+            if candidate.status != RequestStatus.RUNNING_DECODE:
+                i -= 1
+                continue
+            # Skip requests using extend tables
+            if candidate.use_extend_tables:
+                i -= 1
+                continue
+            # Found a valid preempt target
+            preempted_req = candidate
+            break
+        return preempted_req, i
+
     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         """
         If the request cannot be scheduled, preempt the running request one by one until it can be scheduled.
         Last in, first out.
+        Only requests that are in decode status can be preempted.
         """
         can_schedule = False
-        while self._can_preempt():
-            if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
-                preempted_req = self.running.pop()
-                if preempted_req.use_extend_tables:
-                    self.running.insert(0, preempted_req)
-                    continue
+        while True:
+            if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
+                # The request can be scheduled.
+                can_schedule = True
+                break
+            else:
+                # Try to find a candidate request to preempt.
+                preempted_req, preempted_idx = self._select_preempt_candidate()
+                if preempted_req is None:
+                    can_schedule = False
+                    llm_logger.warning(
+                        f"Preemption was triggered but no preemptable request could be found; the scheduler may hang! "
+                        f"Running requests: {self.running}"
+                    )
+                    break
+
+                # Remove the preempted request from the running list
+                self.running.pop(preempted_idx)
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
                 if self.config.scheduler_config.splitwise_role == "decode":
@@ -397,33 +425,82 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                 llm_logger.debug(
                     f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
                 )
+                llm_logger.debug(self.info())
                 self._info_each_block()
+                self._reset_reserve_on_preemption()
                 if preempted_req == request:
                     # No more request to preempt.
                     can_schedule = False
                     break
-            else:
-                # The request can be scheduled.
-                can_schedule = True
-                break
-        self.current_reserve_output_block_num = self.init_reserve_output_block_num
-        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
-        self.can_relax_prefill_strategy = False
+
         return can_schedule

+    def _reset_reserve_on_preemption(self):
+        """Reset reserved blocks on preemption."""
+        if self.use_new_token_ratio_reserve:
+            if not self.running:
+                self.new_token_ratio = self.init_new_token_ratio
+                return
+            total_decoded_tokens = sum(len(req.output_token_ids) for req in self.running)
+            total_max_new_tokens = 0
+            for req in self.running:
+                max_tokens = req.sampling_params.max_tokens
+                if max_tokens is None:
+                    max_tokens = self.config.model_config.max_model_len - req.prompt_token_ids_len
+                total_max_new_tokens += max_tokens
+            num_running_decode = sum(
+                [1 if req.num_total_tokens > req.need_prefill_tokens else 0 for req in self.running]
+            )
+            extra_decode_steps = (
+                16 * self.config.cache_config.block_size
+            )  # assume an extra 16 blocks' worth of tokens per running decode request when estimating the ratio
+            new_ratio = (total_decoded_tokens + extra_decode_steps * num_running_decode) / (total_max_new_tokens + 1)
+            self.new_token_ratio = min(new_ratio, self.init_new_token_ratio)
+            llm_logger.info(
+                f"Estimate new token ratio for preemption: {self.new_token_ratio}, "
+                f"total_decoded_tokens={total_decoded_tokens}, total_max_new_tokens={total_max_new_tokens}, num_running_decode={num_running_decode}"
+            )
+
+        else:
+            self.current_reserve_output_block_num = self.init_reserve_output_block_num
+            self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num)
+            self.can_relax_prefill_strategy = False
+
+    def _get_running_request_reserve_blocks(self, request: Request) -> int:
+        """Estimate KV-cache blocks to reserve for a running request's future decode tokens.
+
+        Aligned with SGLang's per-request budget estimation:
+            reserved_tokens = min(max_tokens - already_generated, CLIP_MAX_NEW_TOKENS) * new_token_ratio
+        then ceil-divided by block_size. The ratio decays each scheduling step so that
+        the reservation gradually relaxes; on preemption it is re-estimated from the running requests.
+ """ + max_tokens = getattr(request.sampling_params, "max_tokens", None) + if max_tokens is None: + max_tokens = self.config.model_config.max_model_len - request.prompt_token_ids_len + remaining_tokens = max_tokens - len(request.output_token_ids) + clipped_remaining = min(remaining_tokens, self.clip_max_new_tokens) + reserved_tokens = max(int(clipped_remaining * self.new_token_ratio), 0) + block_size = self.config.cache_config.block_size + return (reserved_tokens + block_size - 1) // block_size + def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block): - if self.can_relax_prefill_strategy: - can_schedule_block_num_threshold = num_chunk_new_block + """Compute the minimum free blocks required to admit a new prefill request.""" + if self.use_new_token_ratio_reserve: + reserve_blocks = sum(self._get_running_request_reserve_blocks(req) for req in self.running) + can_schedule_block_num_threshold = num_chunk_new_block + reserve_blocks else: - can_schedule_block_num_threshold = ( - num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num - ) - if self.config.speculative_config.method is not None: - can_schedule_block_num_threshold = min( - can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + if self.can_relax_prefill_strategy: + can_schedule_block_num_threshold = num_chunk_new_block + else: + can_schedule_block_num_threshold = ( + num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num ) + if self.config.speculative_config.method is not None: + can_schedule_block_num_threshold = min( + can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + ) return can_schedule_block_num_threshold def _update_mm_hashes(self, request): @@ -786,6 +863,7 @@ def get_enough_request(request, scheduled_reqs): self.config.scheduler_config.max_num_batched_tokens - num_running_decode_reqs * tokens_per_seq ) need_abort_requests = [] # users trigger abortion + chunk_prefill_in_running_not_satisfied = False # First, schedule the RUNNING requests. req_index = 0 @@ -922,22 +1000,17 @@ def _allocate_decode_and_extend(): req_index += 1 continue num_new_block = self.get_new_block_nums(request, num_new_tokens) + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(num_new_block) # Allocate blocks to prefill - if self.cache_manager.can_allocate_gpu_blocks(num_new_block): - request.block_tables.extend( - self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) - ) - # Prepare prefill task - scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) - else: # Not enough blocks to allocate, trigger preemption - can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) - if not can_schedule: - break + if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): request.block_tables.extend( self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) ) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + else: # Not enough blocks to allocate + chunk_prefill_in_running_not_satisfied = True + break # For chunk prefill request, if not satisfy condition for prefill, just break token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( @@ -955,7 +1028,7 @@ def _allocate_decode_and_extend(): self.running.remove(request) # Second, schedule the WAITING requests. 
-        if not preempted_reqs:
+        if (not preempted_reqs) and (not chunk_prefill_in_running_not_satisfied):
             skip_requests: list[Request] = []
             while self.waiting and token_budget > 0:
                 if (
@@ -1041,7 +1114,7 @@ def _allocate_decode_and_extend():
                     self.cache_manager.update_cache_blocks(
                         request, self.config.cache_config.block_size, request.num_computed_tokens
                     )
-                    request.status = RequestStatus.RUNNING
+                    request.status = RequestStatus.RUNNING_PREFILL
                     if self.config.scheduler_config.splitwise_role == "mixed":
                         allocated_position = self.get_available_position()
                         request.idx = allocated_position
@@ -1110,7 +1183,7 @@ def _allocate_decode_and_extend():
                     self.cache_manager.update_cache_blocks(
                         request, self.config.cache_config.block_size, request.num_computed_tokens
                     )
-                    request.status = RequestStatus.RUNNING
+                    request.status = RequestStatus.RUNNING_PREFILL
                 else:
                     if self.config.cache_config.enable_prefix_caching:
                         self._free_blocks(request)
@@ -1124,14 +1197,20 @@ def _allocate_decode_and_extend():

         if scheduled_reqs:
             llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
-        self.current_reserve_output_block_num_float -= self.decay_output_block_num
-        self.current_reserve_output_block_num = max(
-            int(self.current_reserve_output_block_num_float),
-            self.min_reserve_output_block_num,
-            0,
-        )
-        if self.current_reserve_output_block_num == 0:
-            self.can_relax_prefill_strategy = True
+        if self.use_new_token_ratio_reserve:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+        else:
+            self.current_reserve_output_block_num_float -= self.decay_output_block_num
+            self.current_reserve_output_block_num = max(
+                int(self.current_reserve_output_block_num_float),
+                self.min_reserve_output_block_num,
+                0,
+            )
+            if self.current_reserve_output_block_num == 0:
+                self.can_relax_prefill_strategy = True

         self._log_console_scheduler_metrics(scheduled_reqs)
@@ -1355,6 +1434,7 @@ def pre_recycle_resource(self, request_id: str):

     def add_request_in_p(self, requests: list[Request]):
         with self.lock:
             for request in requests:
+                request.status = RequestStatus.RUNNING_PREFILL
                 self.running.append(request)

     def preallocate_resource_in_p(self, request: Request):
@@ -1487,6 +1567,7 @@ def add_prefilled_request(self, request_output: RequestOutput):
         ):
             request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
         request.need_prefill_tokens = len(request.prompt_token_ids) + 1
+        request.status = RequestStatus.RUNNING_DECODE

         request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time
         request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time
@@ -1553,7 +1634,7 @@ def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):

     def finish_requests(self, request_ids: Union[str, Iterable[str]]):
         llm_logger.info(f"recycle resources for requests: {request_ids}")
-        self.update_metrics(verbose=True)
+        self.update_metrics()
         try:
             if isinstance(request_ids, str):
                 request_ids = (request_ids,)
@@ -1608,7 +1689,7 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
         except Exception as e:
             llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
         finally:
-            self.update_metrics(verbose=True)
+            self.update_metrics()

     def clear_data(self):
         self.waiting: deque[Request] = deque()
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 7e0f809d5d3..789f6861b91 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -215,6 +215,11 @@ def _validate_split_kv_size(value: int) -> int:
     # Whether to enable low latency in mixed scenario
     "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
     # Reserve output blocks for decoding requests when schedule new prefill requests
+    "FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")),
+    "FD_MIN_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO", "0.1")),
+    "FD_NEW_TOKEN_RATIO_DECAY": lambda: float(os.getenv("FD_NEW_TOKEN_RATIO_DECAY", "0.001")),
+    "FD_CLIP_MAX_NEW_TOKENS": lambda: int(os.getenv("FD_CLIP_MAX_NEW_TOKENS", "4096")),
+    # Legacy reserve block env vars (used only when FD_USE_NEW_TOKEN_RATIO_RESERVE is disabled)
     "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
         os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
     ),
@@ -224,6 +229,9 @@ def _validate_split_kv_size(value: int) -> int:
     "FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
         os.getenv("FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0")
     ),
+    # When True, use per-request new_token_ratio to estimate reserved blocks (SGLang-style).
+    # When False, fall back to the legacy fixed-block reservation strategy.
+    "FD_USE_NEW_TOKEN_RATIO_RESERVE": lambda: bool(int(os.getenv("FD_USE_NEW_TOKEN_RATIO_RESERVE", "1"))),
     # Timeout for worker process health check in seconds
     "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
     # File path for file storage backend
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 6f6a8043803..2a8328b28fe 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -37,6 +37,7 @@
     Request,
     RequestMetrics,
     RequestOutput,
+    RequestStatus,
     SpeculateMetrics,
 )
 from fastdeploy.inter_communicator import ZmqIpcServer
@@ -950,6 +951,8 @@ def _process_batch_output(self):
                     continue

                 self.total_step += 1
+                if task.status == RequestStatus.RUNNING_PREFILL:
+                    task.status = RequestStatus.RUNNING_DECODE
                 current_time = time.time()
                 trace_carrier = None
                 if self.tokens_counter[task_id] == 0:
diff --git a/tests/engine/test_resource_manager_v1.py b/tests/engine/test_resource_manager_v1.py
index 23275f29f70..716770294a6 100644
--- a/tests/engine/test_resource_manager_v1.py
+++ b/tests/engine/test_resource_manager_v1.py
@@ -72,7 +72,7 @@ def test_preempted_all_with_normal_requests(self):
         req1 = Mock(spec=Request)
         req1.request_id = "req1"
         req1.use_extend_tables = False
-        req1.status = RequestStatus.RUNNING
+        req1.status = RequestStatus.RUNNING_DECODE
         req1.block_tables = [1, 2, 3]
         req1.num_cached_blocks = 0
         req1.idx = 0
@@ -80,7 +80,7 @@ def test_preempted_all_with_normal_requests(self):
         req2 = Mock(spec=Request)
         req2.request_id = "req2"
         req2.use_extend_tables = False
-        req2.status = RequestStatus.RUNNING
+        req2.status = RequestStatus.RUNNING_DECODE
         req2.block_tables = [4, 5]
         req2.num_cached_blocks = 0
         req2.idx = 1
diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py
index b47344470de..c84514c06d5 100644
--- a/tests/output/test_process_batch_output.py
+++ b/tests/output/test_process_batch_output.py
@@ -21,7 +21,7 @@

 import paddle

-from fastdeploy.engine.request import RequestMetrics, RequestOutput
+from fastdeploy.engine.request import RequestMetrics, RequestOutput, RequestStatus
 from fastdeploy.output.token_processor import TokenProcessor

 paddle.set_device("cpu")
@@ -82,6 +82,7 @@ def __init__(self):
         self.ic_req_data = {}
         self.prompt_token_ids_len = 0
         self.trace_carrier = {}
+        self.status = RequestStatus.RUNNING_DECODE

         now = time.time()
         self.metrics = RequestMetrics(
diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py
index 0fd4d1753ee..5c26db778c7 100644
--- a/tests/output/test_token_processor.py
+++ b/tests/output/test_token_processor.py
@@ -25,7 +25,12 @@
 import pytest

 from fastdeploy import envs
-from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput
+from fastdeploy.engine.request import (
+    Request,
+    RequestMetrics,
+    RequestOutput,
+    RequestStatus,
+)
 from fastdeploy.output import token_processor
 from fastdeploy.output.token_processor import (
     MAX_BSZ,
@@ -671,6 +676,7 @@ def test_process_batch_output_consumes_tokens_and_finishes_task():
         prompt_token_ids_len=0,
         num_total_tokens=1,
         block_tables=[1],
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     task.get = lambda key, default=None: getattr(task, key, default)
@@ -708,6 +714,7 @@ def test_process_batch_output_logprob_records_topk_and_caching():
         num_total_tokens=1,
         block_tables=[1],
         get=lambda key, default=None: None,
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     rm.tasks_list[0] = task
@@ -784,6 +791,7 @@ def test_process_batch_output_speculative_recovery_stop_finishes():
         num_total_tokens=1,
         block_tables=[1],
         get=lambda key, default=None: None,
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     rm.tasks_list[0] = task
@@ -911,6 +919,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores():
         num_total_tokens=1,
         block_tables=[1],
         get=lambda key, default=None: None,
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     rm.tasks_list[0] = task
@@ -1076,6 +1085,7 @@ def test_process_batch_output_records_second_decode_token():
         num_total_tokens=1,
         block_tables=[1],
         get=lambda key, default=None: None,
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     task.metrics.inference_start_time = time.time()
@@ -1145,6 +1155,7 @@ def test_process_batch_output_prefill_sets_draft_tokens():
         num_total_tokens=1,
         block_tables=[1],
         get=lambda key, default=None: None,
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     rm.tasks_list[0] = task
@@ -1186,6 +1197,7 @@ def test_process_batch_output_logs_recovery_stop_for_non_speculative():
         prompt_token_ids_len=0,
         num_total_tokens=1,
         block_tables=[1],
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     task.get = lambda k, d=None: getattr(task, k, d)
@@ -1223,6 +1235,7 @@ def test_process_batch_output_sets_multimodal_token_counts():
         num_total_tokens=1,
         block_tables=[1],
         multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5},
+        status=RequestStatus.RUNNING_DECODE,
     )
     task.trace_carrier = None
     task.get = lambda key, default=None: getattr(task, key, default)
diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py
index a93d5741d14..d9ab6a59dbc 100644
--- a/tests/v1/test_resource_manager_v1.py
+++ b/tests/v1/test_resource_manager_v1.py
@@ -650,7 +650,7 @@ def test_schedule_decode_and_waiting_prefill(self):
         decode_request = _make_request(request_id="req-decode", prompt_token_ids=[1, 2])
         decode_request.idx = 0
-        decode_request.status = RequestStatus.RUNNING
+        decode_request.status = RequestStatus.RUNNING_DECODE
         decode_request.num_computed_tokens = 2
         decode_request.output_token_ids = [99]
         decode_request.block_tables = [1]
@@ -665,7 +665,7 @@ def test_schedule_decode_and_waiting_prefill(self):
         self.assertGreaterEqual(len(scheduled_reqs), 2)
         self.assertEqual(error_reqs, [])
         self.assertIn(decode_request.request_id, manager.using_extend_tables_req_id)
-        self.assertEqual(waiting_request.status, RequestStatus.RUNNING)
+        self.assertEqual(waiting_request.status, RequestStatus.RUNNING_PREFILL)

     def test_trigger_preempt_records_tasks(self):
         manager = _build_manager()
@@ -678,6 +678,7 @@ def test_trigger_preempt_records_tasks(self):
         preempted_req = _make_request(request_id="req-preempted")
         preempted_req.idx = 0
         preempted_req.use_extend_tables = False
+        preempted_req.status = RequestStatus.RUNNING_DECODE
         request = _make_request(request_id="req-target")
         request.idx = 1
         manager.running = [request, preempted_req]
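
Reviewer note: a minimal standalone sketch of the reservation math this diff introduces, kept separate from the patch itself. The helper names and the block_size of 64 are illustrative assumptions, not the FastDeploy API; it only mirrors the formula in _get_running_request_reserve_blocks and the per-step ratio decay.

# Illustrative sketch (assumptions: block_size=64, defaults match the new env vars).
import math


def reserve_blocks(max_tokens: int, generated: int, ratio: float,
                   clip_max_new_tokens: int = 4096, block_size: int = 64) -> int:
    """Blocks reserved for one running request's future decode tokens."""
    remaining = min(max_tokens - generated, clip_max_new_tokens)
    reserved_tokens = max(int(remaining * ratio), 0)
    return math.ceil(reserved_tokens / block_size)


def decay(ratio: float, step_decay: float = 0.001, min_ratio: float = 0.1) -> float:
    """Ratio after one scheduling step; the reservation relaxes over time."""
    return max(ratio - step_decay, min_ratio)


if __name__ == "__main__":
    ratio = 0.7  # FD_INIT_NEW_TOKEN_RATIO default
    print(reserve_blocks(2048, 100, ratio))  # ceil(int(1948 * 0.7) / 64) = 22 blocks
    for _ in range(200):                     # 200 scheduling steps later
        ratio = decay(ratio)
    print(reserve_blocks(2048, 100, ratio))  # ratio is now ~0.5 -> 16 blocks

A new prefill is admitted only if the free-block count covers its own chunk plus the sum of these per-request reservations; a preemption re-estimates the ratio from the currently running requests instead of resetting a fixed block budget.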