-
Notifications
You must be signed in to change notification settings - Fork 753
[XPU][Speculative Decoding] Enable CudaGraph capture for MTP draft model #8061
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,9 +171,12 @@ def __init__( | |
| self.share_inputs.init_share_inputs() | ||
| self.max_num_seqs = self.fd_config.scheduler_config.max_num_seqs | ||
|
|
||
| self.increment_value = ( | ||
| 4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4 | ||
| ) | ||
| self.infer_seed_increment = paddle.full( | ||
| shape=[self.scheduler_config.max_num_seqs, 1], | ||
| fill_value=4, | ||
| fill_value=self.increment_value, | ||
| dtype="int64", | ||
| ).cpu() | ||
|
|
||
|
|
@@ -847,22 +850,8 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: | |
| if self.use_cudagraph: | ||
| # Update Batch type for cuda graph for only_decode_batch | ||
| if_only_decode = self.only_decode() | ||
|
|
||
| only_decode_use_cudagraph = self.use_cudagraph and if_only_decode | ||
| # Update config about moe for better performance | ||
| # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply() | ||
| if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": | ||
| self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" | ||
| if self.speculative_decoding: | ||
| self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill" | ||
|
|
||
| # Update Batch type for cuda graph for only_prefill_batch | ||
| only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() | ||
|
|
||
| self.forward_meta.step_use_cudagraph = ( | ||
| only_prefill_use_cudagraph | ||
| if self.cudagraph_only_prefill | ||
| else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0 | ||
| self.use_cudagraph and if_only_decode and self.forward_meta.ids_remove_padding.shape[0] > 0 | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| ) | ||
|
|
||
| # Update bad tokens len | ||
|
|
@@ -874,11 +863,10 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: | |
| if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": | ||
| self.forward_meta.kv_signal_sender = self.share_inputs["kv_signal_sender"] | ||
|
|
||
| if ( | ||
| self.fd_config.scheduler_config.splitwise_role == "mixed" and envs.FD_XPU_ENABLE_MIXED_EP_MODE | ||
| ): # Centralized scenario: the phase is initialized as "prefill" by default. During inference runtime, different types of batches can achieve phase switching at this point. | ||
| if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": | ||
This comment was marked as outdated.
Sorry, something went wrong.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a pre-existing issue not introduced by this PR (only triggers in mixed EP mode, not PD separation). Added a TODO comment to track it. Will address in a follow-up when mixed EP + MTP scenario is validated. |
||
| if_only_decode = self.only_decode() | ||
| self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" | ||
| # TODO: sync proposer.fd_config.model_config.moe_phase.phase for MTP draft model in mixed EP mode | ||
|
|
||
| # Get sampling metadata | ||
| # TODU(lilujia): sync with GPU | ||
|
|
@@ -1132,6 +1120,7 @@ def _dummy_run( | |
| batch_size: paddle.Tensor, | ||
| expected_decode_len: int = 1, | ||
| in_capturing: bool = False, | ||
| accept_all_drafts=False, | ||
| ) -> paddle.Tensor: | ||
| """ | ||
| Use dummy inputs to run before formal execution. | ||
|
|
@@ -1156,11 +1145,11 @@ def _dummy_run( | |
| self.proposer.dummy_prefill_inputs( | ||
| num_tokens=num_tokens, | ||
| batch_size=batch_size, | ||
| expected_decode_len=1, | ||
| expected_decode_len=expected_decode_len, | ||
| ) | ||
|
|
||
| while True: | ||
| self.execute_model(is_dummy_run=True, in_capturing=in_capturing) | ||
| self.execute_model(is_dummy_run=True, in_capturing=in_capturing, accept_all_drafts=accept_all_drafts) | ||
|
|
||
| if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: | ||
| break | ||
|
|
@@ -1209,14 +1198,30 @@ def capture_model(self) -> None: | |
| capture_sizes = self.cudagraph_capture_sizes.copy() | ||
|
|
||
| try: | ||
| for batch_size in sorted(capture_sizes, reverse=True): | ||
| self._dummy_run( | ||
| num_tokens=self.scheduler_config.max_num_batched_tokens, | ||
| batch_size=batch_size, | ||
| expected_decode_len=expected_decode_len, | ||
| in_capturing=True, | ||
| ) | ||
| logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") | ||
| if self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: | ||
| for capture_size in sorted(capture_sizes, reverse=True): | ||
| expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 | ||
| self._dummy_run( | ||
| num_tokens=self.fd_config.get_max_chunk_tokens(), | ||
| batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), | ||
| in_capturing=True, | ||
| expected_decode_len=expected_decode_len, | ||
| accept_all_drafts=True, | ||
| ) | ||
| logger.info( | ||
| f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}" | ||
| ) | ||
| else: | ||
| for batch_size in sorted(capture_sizes, reverse=True): | ||
| self._dummy_run( | ||
| num_tokens=self.scheduler_config.max_num_batched_tokens, | ||
| batch_size=batch_size, | ||
| expected_decode_len=expected_decode_len, | ||
| in_capturing=True, | ||
| ) | ||
| logger.info( | ||
| f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" | ||
| ) | ||
| except RuntimeError as e: | ||
| if "out of memory" in str(e): | ||
| raise RuntimeError( | ||
|
|
@@ -1273,6 +1278,7 @@ def execute_model( | |
| num_running_requests: int = None, | ||
| is_dummy_run: bool = False, | ||
| in_capturing: bool = False, | ||
| accept_all_drafts: bool = False, | ||
| ) -> Optional[ModelRunnerOutput]: | ||
| """ | ||
| The Entrance of model execute. | ||
|
|
@@ -1286,14 +1292,18 @@ class at the server level, which is too granular for ModelRunner. | |
| # 0. set debug level | ||
| # self._set_debug_level(0x1, model_forward_batch, is_dummy_run) | ||
| with kv_signal_sender_context_manager(self.pd_disaggregation_mode) as sender: | ||
|
|
||
| self.share_inputs["kv_signal_sender"] = sender | ||
| # 1. Prepare inputs of model and decoder. | ||
| self._prepare_inputs(is_dummy_run=is_dummy_run) | ||
| # 2. Padding inputs for cuda graph | ||
| self.padding_cudagraph_inputs() | ||
| if is_dummy_run: | ||
| self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph | ||
| # 2. Padding inputs for cuda grph | ||
| self.padding_cudagraph_inputs() | ||
| else: | ||
| self.forward_meta.step_use_cudagraph = ( | ||
| self.forward_meta.step_use_cudagraph | ||
| and self.real_token_num <= self.fd_config.graph_opt_config.max_capture_size | ||
| ) | ||
|
|
||
| num_tokens = self.share_inputs["ids_remove_padding"].shape[0] | ||
| if not self.parallel_config.enable_expert_parallel and num_tokens <= 0: | ||
|
|
@@ -1310,7 +1320,7 @@ class at the server level, which is too granular for ModelRunner. | |
| model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] | ||
| if self.enable_mm: | ||
| model_inputs["image_features"] = self.share_inputs["image_features"] | ||
| # 3. Execute model | ||
| # 3. Execute | ||
| model_output = self.model( | ||
| model_inputs, | ||
| forward_meta=self.forward_meta, | ||
|
|
@@ -1341,6 +1351,8 @@ class at the server level, which is too granular for ModelRunner. | |
| self.sampling_metadata, | ||
| self.model_config.max_model_len, | ||
| self.share_inputs, | ||
| self.increment_value, | ||
| accept_all_drafts=accept_all_drafts, | ||
| ) | ||
| if self.parallel_config.tensor_parallel_size > 1: | ||
| paddle.distributed.broadcast( | ||
|
|
@@ -1438,13 +1450,18 @@ class at the server level, which is too granular for ModelRunner. | |
| # 6. Draft model propose | ||
| if self.speculative_decoding and self.proposer is not None: | ||
| if self.spec_method == SpecMethod.MTP: | ||
| self.proposer.run(full_hidden_states=model_output) | ||
| self.proposer.run( | ||
| full_hidden_states=model_output, | ||
| step_use_cudagraph=self.forward_meta.step_use_cudagraph, | ||
| is_dummy_run=is_dummy_run, | ||
| ) | ||
| else: | ||
| self.proposer.run(share_inputs=self.share_inputs) | ||
|
|
||
| # 7. Updata 'infer_seed' and step_paddle() | ||
| self.share_inputs["infer_seed"].add_(self.infer_seed_increment) | ||
| self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED | ||
| if not self.speculative_decoding: | ||
| self.share_inputs["infer_seed"].add_(self.infer_seed_increment) | ||
| self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED | ||
|
|
||
| if self.speculative_decoding: | ||
| speculate_schedule_cache( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.