From c78b9ed313eaac58ce5f1151d06cd085e6dd35b2 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 24 Mar 2026 20:16:31 +0800 Subject: [PATCH 1/2] fix pd-split metrics and support other model runner --- fastdeploy/engine/request.py | 4 ++++ fastdeploy/engine/sched/resource_manager_v1.py | 4 +++- fastdeploy/spec_decode/mtp.py | 5 +++-- fastdeploy/worker/gpu_model_runner.py | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 391e2038534..bb309a44c63 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -865,6 +865,7 @@ class RequestMetrics: llm_engine_recv_req_timestamp: Optional[float] = None llm_engine_send_req_to_engine_timestamp: Optional[float] = None + llm_engine_send_req_to_decoder_engine_timestamp: Optional[float] = None llm_engine_recv_latest_token_timestamp: Optional[float] = None llm_engine_recv_token_timestamp: Optional[float] = None @@ -952,6 +953,9 @@ def __getitem__(self, key): def __setitem__(self, key, value): setattr(self, key, value) + def update_decoder_start_time(self): + self.llm_engine_send_req_to_decoder_engine_timestamp = self.decode_inference_start_time + class RequestOutput: """The output data of a completion request to the LLM. diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index b0425d779d1..0d91ea4d8bc 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1439,7 +1439,9 @@ def add_prefilled_request(self, request_output: RequestOutput): request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time - request.metrics = request_output.metrics + request.metrics = copy.deepcopy(request_output.metrics) + request.metrics.decode_inference_start_time = time.time() + request.metrics.update_decoder_start_time() self.running.append(request) def _free_blocks(self, request: Request): diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 5868d0bff0c..ad1e639f897 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -141,6 +141,7 @@ def __init__( self.attn_backends: list[AttentionBackend] = [] self._initialize_attn_backend() + self.eb5_runner = bool(int(os.getenv("EB5_ENABLE_FD_RUNNER", "0"))) # Forward meta store the global meta information of the forward self.forward_meta = None @@ -503,7 +504,7 @@ def insert_tasks_v1( self.model_inputs["step_idx"][idx : idx + 1] = ( len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) - if self.enable_mm: + if self.enable_mm and not self.eb5_runner: inputs = request.multimodal_inputs self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = ( paddle.to_tensor( @@ -885,7 +886,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F self.model_inputs["seq_lens_decoder"], ) - if self.enable_mm: + if self.enable_mm and not self.eb5_runner: attn_mask_offsets = update_attn_mask_offsets( ids_remove_padding, getattr( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1c91847d849..19cf5be86ef 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -870,6 +870,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self._cached_launch_token_num = -1 if self.speculative_decoding: # D speculate decode, seq_lens_this_time = length + 1 + logger.info(f"seq_lens_this_time: {length + 1}") + logger.info(f"draft_tokens: {request.draft_token_ids}") self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + 1 self.share_inputs["draft_tokens"][idx : idx + 1, 0 : length + 1] = paddle.to_tensor( request.draft_token_ids[0 : length + 1], @@ -2011,6 +2013,13 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_ return prefill_done_idxs + def _execute_empty_mtp_input(self, forward_meta) -> None: + """ + run ep inference forward with empty input. + """ + for _ in range(self.fd_config.speculative_config.num_model_steps): + self.proposer.model.empty_input_forward(forward_meta) + def execute_model( self, model_forward_batch: Optional[List[Request]] = None, @@ -2038,6 +2047,12 @@ def execute_model_normal( model_inputs, p_done_idxs, _ = self._preprocess(model_forward_batch, num_running_requests) model_output = self._execute(model_inputs) if model_output is None or self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item() <= 0: + if ( + self.fd_config.speculative_config.method == SpecMethod.MTP + and hasattr(self.proposer.model, "empty_input_forward") + and self.parallel_config.use_ep + ): + self._execute_empty_mtp_input(self.forward_meta) return model_output_data, sampler_output, post_process_event = self._postprocess( model_output, p_done_idxs, model_forward_batch, num_running_requests From 9fa1c0300a8465a0759f92f5e483c04c88cd312e Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 25 Mar 2026 00:11:40 +0800 Subject: [PATCH 2/2] del print info --- fastdeploy/worker/gpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 19cf5be86ef..5f7d297371e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -870,8 +870,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self._cached_launch_token_num = -1 if self.speculative_decoding: # D speculate decode, seq_lens_this_time = length + 1 - logger.info(f"seq_lens_this_time: {length + 1}") - logger.info(f"draft_tokens: {request.draft_token_ids}") self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + 1 self.share_inputs["draft_tokens"][idx : idx + 1, 0 : length + 1] = paddle.to_tensor( request.draft_token_ids[0 : length + 1],