From 2355d0e83cf2769efe0132ea5e60a788a743e49a Mon Sep 17 00:00:00 2001 From: Clarity256 <1140021759@qq.com> Date: Wed, 17 Jun 2026 13:26:23 +0800 Subject: [PATCH] [XPU][Speculative Decoding] Enable CudaGraph capture for MTP draft model - Enable step_use_cudagraph for draft model with proper gating logic - Pass forward_meta and use_cudagraph to xpu_pre_process in draft path - Add padding_cudagraph_inputs() for draft model buffer management - Slice model output by real_token_num when graph is active - Adapt target model warmup and execute_model for MTP+CudaGraph - Use build_sampling_params kernel in verify path (replaces padding_sampling_params) - Fix memory issue by using copy_ instead of clone for seq_lens_this_time - Fix expected_decode_len for TP>1 in dummy_prefill Co-Authored-By: Clarity256 <1140021759@qq.com> --- .../model_executor/layers/sample/sampler.py | 34 +++---- .../xpu_pre_and_post_process.py | 8 +- fastdeploy/spec_decode/mtp_xpu.py | 37 +++++++- fastdeploy/worker/xpu_model_runner.py | 89 +++++++++++-------- ...mtp_cudagraph.py => test_mtp_cudagraph.py} | 0 5 files changed, 109 insertions(+), 59 deletions(-) rename tests/xpu_ci/4cards_cases/{run_mtp_cudagraph.py => test_mtp_cudagraph.py} (100%) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 70d6cc029bc..e0c2ce3e53b 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -61,6 +61,12 @@ build_sampling_params_logprob, naive_update_model_status, ) +elif current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + build_sampling_params, + top_p_candidates, + verify_draft_tokens, + ) def _apply_triton_top_k_top_p( @@ -1232,19 +1238,12 @@ def _normal_sample_xpu( share_inputs: List[paddle.Tensor], ) -> SamplerOutput: """Normal sampling for NAIVE mode on XPU.""" - top_p, top_k, topp_seed = padding_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), - paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), - ) _, next_tokens = top_k_top_p_sampling( probs, - top_p=top_p, - top_k=top_k, + top_p=sampling_metadata.top_p, + top_k=sampling_metadata.top_k, top_k_list=sampling_metadata.top_k_list, - topp_seed=topp_seed, + topp_seed=sampling_metadata.seed, ) real_bsz = share_inputs["seq_lens_this_time"].shape[0] running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32") @@ -1264,25 +1263,24 @@ def _verify_and_sample_xpu( sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens.""" - from fastdeploy.model_executor.ops.xpu import ( - top_p_candidates, - verify_draft_tokens, - ) target_tokens = None candidate_ids, candidate_scores, candidate_lens = None, None, None if self.verify_strategy == VerifyStrategy.TARGET_MATCH: - top_p, top_k, topp_seed = padding_sampling_params( + top_p, top_k, topp_seed = build_sampling_params( sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.seed, - paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), - paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + token_num_output_cpu=int(share_inputs["cu_seqlens_q_output"][-1]), + increment_value=increment_value, ) _, target_tokens = top_k_top_p_sampling( probs, @@ -1344,6 +1342,7 @@ def forward_xpu( sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: @@ -1397,6 +1396,7 @@ def forward_xpu( sampling_metadata, max_model_len, share_inputs, + increment_value, accept_all_drafts, reject_all_drafts, ) diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index c104f91a596..592c411ecb6 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -137,8 +137,12 @@ def xpu_pre_process( ) = speculate_pre_process( token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder ) - share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output - share_inputs["batch_id_per_token_output"] = batch_id_per_token_output + if use_cudagraph: + share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) + share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) + else: + share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output + share_inputs["batch_id_per_token_output"] = batch_id_per_token_output else: ( ids_remove_padding, diff --git a/fastdeploy/spec_decode/mtp_xpu.py b/fastdeploy/spec_decode/mtp_xpu.py index b97eec5b66d..98cea92c8eb 100644 --- a/fastdeploy/spec_decode/mtp_xpu.py +++ b/fastdeploy/spec_decode/mtp_xpu.py @@ -106,6 +106,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) + # 1. CUDA Graph capture sizes must be recorded in descending order (large → small). + # 2. In multi-step execution, only the first step should be captured. + self.forward_meta.step_use_cudagraph = ( + step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run) + ) + def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): """ Main process for MTP inference. @@ -126,6 +132,8 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, self.model_inputs["draft_tokens"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], + forward_meta=self.forward_meta, + use_cudagraph=self.draft_model_use_cudagraph, num_speculative_tokens=self.speculative_config.num_speculative_tokens, ) @@ -146,7 +154,12 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, ) self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) - self._initialize_forward_meta() + self._initialize_forward_meta( + step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep + ) + # Padding inputs for cuda graph + self.padding_cudagraph_inputs() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], @@ -168,13 +181,16 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, ) if self.num_model_steps > 1: - self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) - + self.model_inputs.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False) + real_num = self.model_inputs["ids_remove_padding"].shape[0] + target_hidden_states = self.model_inputs["target_hidden_states"][:real_num] model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=self.model_inputs["target_hidden_states"], + previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) + if self.forward_meta.step_use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) @@ -298,3 +314,16 @@ def _update_status(self): self.target_model_inputs["seq_lens_encoder"], self.target_model_inputs["stop_flags"], ) + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. + if self.forward_meta.step_use_cudagraph: + self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time"] + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] + return diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index cbd167e3984..c5ce39a669f 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -171,9 +171,12 @@ def __init__( self.share_inputs.init_share_inputs() self.max_num_seqs = self.fd_config.scheduler_config.max_num_seqs + self.increment_value = ( + 4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4 + ) self.infer_seed_increment = paddle.full( shape=[self.scheduler_config.max_num_seqs, 1], - fill_value=4, + fill_value=self.increment_value, dtype="int64", ).cpu() @@ -847,22 +850,8 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: if self.use_cudagraph: # Update Batch type for cuda graph for only_decode_batch if_only_decode = self.only_decode() - - only_decode_use_cudagraph = self.use_cudagraph and if_only_decode - # Update config about moe for better performance - # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply() - if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": - self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" - if self.speculative_decoding: - self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill" - - # Update Batch type for cuda graph for only_prefill_batch - only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() - self.forward_meta.step_use_cudagraph = ( - only_prefill_use_cudagraph - if self.cudagraph_only_prefill - else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0 + self.use_cudagraph and if_only_decode and self.forward_meta.ids_remove_padding.shape[0] > 0 ) # Update bad tokens len @@ -874,11 +863,10 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": self.forward_meta.kv_signal_sender = self.share_inputs["kv_signal_sender"] - if ( - self.fd_config.scheduler_config.splitwise_role == "mixed" and envs.FD_XPU_ENABLE_MIXED_EP_MODE - ): # Centralized scenario: the phase is initialized as "prefill" by default. During inference runtime, different types of batches can achieve phase switching at this point. + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": if_only_decode = self.only_decode() self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" + # TODO: sync proposer.fd_config.model_config.moe_phase.phase for MTP draft model in mixed EP mode # Get sampling metadata # TODU(lilujia): sync with GPU @@ -1132,6 +1120,7 @@ def _dummy_run( batch_size: paddle.Tensor, expected_decode_len: int = 1, in_capturing: bool = False, + accept_all_drafts=False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -1156,11 +1145,11 @@ def _dummy_run( self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, - expected_decode_len=1, + expected_decode_len=expected_decode_len, ) while True: - self.execute_model(is_dummy_run=True, in_capturing=in_capturing) + self.execute_model(is_dummy_run=True, in_capturing=in_capturing, accept_all_drafts=accept_all_drafts) if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break @@ -1209,14 +1198,30 @@ def capture_model(self) -> None: capture_sizes = self.cudagraph_capture_sizes.copy() try: - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=batch_size, - expected_decode_len=expected_decode_len, - in_capturing=True, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + if self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: + for capture_size in sorted(capture_sizes, reverse=True): + expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 + self._dummy_run( + num_tokens=self.fd_config.get_max_chunk_tokens(), + batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), + in_capturing=True, + expected_decode_len=expected_decode_len, + accept_all_drafts=True, + ) + logger.info( + f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}" + ) + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len, + in_capturing=True, + ) + logger.info( + f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" + ) except RuntimeError as e: if "out of memory" in str(e): raise RuntimeError( @@ -1273,6 +1278,7 @@ def execute_model( num_running_requests: int = None, is_dummy_run: bool = False, in_capturing: bool = False, + accept_all_drafts: bool = False, ) -> Optional[ModelRunnerOutput]: """ The Entrance of model execute. @@ -1286,14 +1292,18 @@ class at the server level, which is too granular for ModelRunner. # 0. set debug level # self._set_debug_level(0x1, model_forward_batch, is_dummy_run) with kv_signal_sender_context_manager(self.pd_disaggregation_mode) as sender: - self.share_inputs["kv_signal_sender"] = sender # 1. Prepare inputs of model and decoder. self._prepare_inputs(is_dummy_run=is_dummy_run) + # 2. Padding inputs for cuda graph + self.padding_cudagraph_inputs() if is_dummy_run: self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph - # 2. Padding inputs for cuda grph - self.padding_cudagraph_inputs() + else: + self.forward_meta.step_use_cudagraph = ( + self.forward_meta.step_use_cudagraph + and self.real_token_num <= self.fd_config.graph_opt_config.max_capture_size + ) num_tokens = self.share_inputs["ids_remove_padding"].shape[0] if not self.parallel_config.enable_expert_parallel and num_tokens <= 0: @@ -1310,7 +1320,7 @@ class at the server level, which is too granular for ModelRunner. model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] if self.enable_mm: model_inputs["image_features"] = self.share_inputs["image_features"] - # 3. Execute model + # 3. Execute model_output = self.model( model_inputs, forward_meta=self.forward_meta, @@ -1341,6 +1351,8 @@ class at the server level, which is too granular for ModelRunner. self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, + self.increment_value, + accept_all_drafts=accept_all_drafts, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -1438,13 +1450,18 @@ class at the server level, which is too granular for ModelRunner. # 6. Draft model propose if self.speculative_decoding and self.proposer is not None: if self.spec_method == SpecMethod.MTP: - self.proposer.run(full_hidden_states=model_output) + self.proposer.run( + full_hidden_states=model_output, + step_use_cudagraph=self.forward_meta.step_use_cudagraph, + is_dummy_run=is_dummy_run, + ) else: self.proposer.run(share_inputs=self.share_inputs) # 7. Updata 'infer_seed' and step_paddle() - self.share_inputs["infer_seed"].add_(self.infer_seed_increment) - self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not self.speculative_decoding: + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED if self.speculative_decoding: speculate_schedule_cache( diff --git a/tests/xpu_ci/4cards_cases/run_mtp_cudagraph.py b/tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py similarity index 100% rename from tests/xpu_ci/4cards_cases/run_mtp_cudagraph.py rename to tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py