-
Notifications
You must be signed in to change notification settings - Fork 742
[Cherry-Pick][Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap(#7323) #7794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,7 +49,10 @@ | |
| share_external_data, | ||
| update_attn_mask_offsets, | ||
| ) | ||
|
|
||
| # temporary solution | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议：当前实现存在重复，建议后续提独立 issue 跟踪：将两个实现统一到 |
||
| from fastdeploy.model_executor.xpu_pre_and_post_process import ( | ||
| async_set_value, | ||
| xpu_pre_process, | ||
| xpu_process_output, | ||
| ) | ||
|
|
@@ -483,28 +486,32 @@ def insert_tasks_v1( | |
| input_ids = request.prompt_token_ids + request.output_token_ids | ||
|
|
||
| self.model_inputs["input_ids_len"][idx] = length - 1 | ||
| self.model_inputs["pre_ids"][idx : idx + 1] = -1 | ||
| async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1) | ||
| self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ | ||
| idx : idx + 1, 1:length | ||
| ] | ||
| self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ | ||
| "input_ids" | ||
| ][idx : idx + 1, 1:length].cpu() | ||
| # TODO: use token_all_ids replace with input_ids_cpu | ||
| if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs: | ||
| self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ | ||
| "input_ids" | ||
| ][idx : idx + 1, 1:length].cpu() | ||
| encoder_block_num = len(request.block_tables) | ||
| self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num | ||
| self.model_inputs["block_tables"][idx : idx + 1, :] = -1 | ||
| self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( | ||
| request.block_tables, dtype="int32" | ||
| async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) | ||
| async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) | ||
| async_set_value( | ||
| self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables | ||
| ) | ||
| self.model_inputs["stop_flags"][idx : idx + 1] = False | ||
| self.model_inputs["batch_drop"][idx : idx + 1] = False | ||
|
|
||
| self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length | ||
| async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False) | ||
| async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False) | ||
|
|
||
| async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length) | ||
| self.exist_prefill_flag = True | ||
| self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index | ||
| self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length | ||
| self.model_inputs["step_idx"][idx : idx + 1] = ( | ||
| len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 | ||
| async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) | ||
| async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) | ||
| async_set_value( | ||
| self.model_inputs["step_idx"][idx : idx + 1], | ||
| len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, | ||
| ) | ||
| if self.use_attn_mask_offset: | ||
| inputs = request.multimodal_inputs | ||
|
|
@@ -522,18 +529,19 @@ def insert_tasks_v1( | |
| if ( | ||
| self.fd_config.scheduler_config.splitwise_role == "decode" | ||
| ): # In PD, we continue to decode after P generates first token | ||
| self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 | ||
| async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) | ||
| self.exist_prefill_flag = False | ||
| self.model_inputs["recompute_token_num"][idx : idx + 1] = 0 | ||
| self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1 | ||
| async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1) | ||
| # NOTE(liuzichang): | ||
| # extra 1 : P-D split need rollback one step | ||
| self.model_inputs["mask_rollback"][idx : idx + 1] = 1 | ||
|
|
||
| async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0) | ||
| async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1) | ||
| # has_prefill_task = True | ||
| elif request.task_type.value == RequestType.DECODE.value: # decode task | ||
| encoder_block_num = len(request.block_tables) | ||
| self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num | ||
| self.model_inputs["block_tables"][idx : idx + 1, :] = -1 | ||
| async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) | ||
| async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) | ||
| if current_platform.is_cuda(): | ||
| async_set_value( | ||
| self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables | ||
|
|
@@ -542,16 +550,13 @@ def insert_tasks_v1( | |
| self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( | ||
| request.block_tables, dtype="int32" | ||
| ) | ||
| # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode | ||
| # has_decode_task = True | ||
| # continue | ||
| else: | ||
| self.model_inputs["block_tables"][idx : idx + 1, :] = -1 | ||
| self.model_inputs["stop_flags"][idx : idx + 1] = True | ||
| self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 | ||
| self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 | ||
| self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 | ||
| self.model_inputs["is_block_step"][idx : idx + 1] = False | ||
| async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) | ||
| async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True) | ||
| async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) | ||
| async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0) | ||
| async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) | ||
| async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False) | ||
| continue | ||
|
|
||
| # TODO(liuzichang): Solve splitewise-p bug to restore | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议
`except Exception` 过于宽泛。当前将 `except ImportError` 改为 `except Exception`，会捕获所有异常（包括 AttributeError、NameError、TypeError 等非导入相关错误），可能让调试困难。例如 `_cuda_ver.split(".")[0]` 若返回非预期对象，`int(...)` 会抛 ValueError，被静默吞掉后 `cudart = None`，只留一行 warning，难以排查根因。建议拆分为更精确的异常类型：