From b81c27573635cf35e886c993c16c2dc6e83f0812 Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Mon, 11 May 2026 17:15:42 +0800 Subject: [PATCH 1/5] support mla dummy load & optimize mla mem about value padding --- fastdeploy/model_executor/layers/linear.py | 17 +++++++++++++++++ fastdeploy/model_executor/models/deepseek_v3.py | 3 ++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index bea36b3e05a..bb52bb0345d 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -1010,6 +1010,23 @@ def __init__( # Override weight keys to use the combined kv_b_proj self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight" + if self.fd_config.load_config.load_choices == "dummy": + # Create K projection weight + self.k_b_proj_weight = self.create_parameter( + shape=[self.num_heads_per_partition, qk_nope_head_dim, kv_lora_rank], + dtype=self.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + # Create V projection weight + self.v_b_proj_weight = self.create_parameter( + shape=[self.num_heads_per_partition, kv_lora_rank, v_head_dim], + dtype=self.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + def process_weights_after_loading(self): if self.fd_config.load_config.dynamic_load_weight: return diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 0ac87fa6dfa..6389789c112 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -420,7 +420,8 @@ def forward( key = paddle.empty([full_k_pe.shape[0], self.num_attention_heads_tp, self.qk_head_dim], dtype=query.dtype) key[..., : self.qk_nope_head_dim] = key_nope key[..., self.qk_nope_head_dim :] = full_k_pe.unsqueeze(1) - value = paddle.nn.functional.pad(value, [0, 
self.qk_head_dim - self.v_head_dim], value=0) + if self.qk_head_dim - self.v_head_dim != 0: + value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) fmha_out = self.mla_attn( q=query, From 19b36b453aa07a96bbcf9412d3e46e0d4f8e9695 Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Mon, 11 May 2026 17:17:52 +0800 Subject: [PATCH 2/5] support dummy run profile --- fastdeploy/worker/gpu_model_runner.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 69795e671d5..dfb60306e0b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1173,8 +1173,9 @@ def get_input_length_list( # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. - if self.fd_config.parallel_config.enable_expert_parallel: - input_length = min(input_length, 32) + if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 0: + if self.fd_config.parallel_config.enable_expert_parallel: + input_length = min(input_length, 32) block_num = ( input_length + self.cache_config.block_size - 1 @@ -2030,6 +2031,12 @@ def _dummy_run( if self.enable_mm: model_inputs["image_features"] = self.share_inputs["image_features"] + if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: + import datetime + + paddle.distributed.barrier() + starttime = datetime.datetime.now() + # 3. 
Run model model_output = self.model( model_inputs, @@ -2059,6 +2066,13 @@ def _dummy_run( ) self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts) + if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: + paddle.distributed.barrier() + endtime = datetime.datetime.now() + duringtime = endtime - starttime + time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + print("The whole end to end time : ", time_ms, "ms") + # 7. Updata 'infer_seed' and step_cuda() if not self.speculative_decoding: self.share_inputs["infer_seed"].add_(self.infer_seed_increment) From 11b47622585d002b388c06091eb476284b880aa5 Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Mon, 11 May 2026 19:47:00 +0800 Subject: [PATCH 3/5] support dummy profile --- fastdeploy/envs.py | 2 ++ .../layers/moe/fused_moe_blackwell_backend.py | 7 +++++++ .../layers/moe/fused_moe_deepgemm_backend.py | 7 +++++++ fastdeploy/worker/gpu_model_runner.py | 6 +++--- fastdeploy/worker/worker_process.py | 2 +- 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 6be28f1f3be..7a9d6e412fc 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -279,6 +279,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), # enable kv cache manager v1 "ENABLE_V1_KVCACHE_MANAGER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_MANAGER", "0")), + # run dummy run for profile + "FD_RUN_DUMMY_FOR_PROFILE": lambda: int(os.getenv("FD_RUN_DUMMY_FOR_PROFILE", "0")), } diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py index 274deda8b69..2fa9eac7257 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py @@ -23,6 +23,7 @@ from 
paddleformers.utils.log import logger import fastdeploy +from fastdeploy import envs from fastdeploy.model_executor.layers.moe.ep import deep_ep from fastdeploy.model_executor.layers.quantization.fp8_utils import ( deep_gemm, @@ -623,6 +624,8 @@ def apply_ep_prefill( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") hidden_size = x.shape[1] @@ -963,6 +966,8 @@ def apply_ep_decode( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) @@ -1050,6 +1055,8 @@ def apply_tp( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") if layer.topk_method == "noaux_tc": diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 0264b11110c..18ffb4ee1e9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -24,6 +24,7 @@ from paddleformers.utils.log import logger import fastdeploy +from fastdeploy import envs from fastdeploy.model_executor.layers.moe.ep import deep_ep from fastdeploy.model_executor.layers.quantization.fp8_utils import ( deep_gemm, @@ -341,6 +342,8 @@ def apply_ep_prefill( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") hidden_size = layer.hidden_size @@ -674,6 +677,8 @@ def apply_ep_decode( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") # 1. 
Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) @@ -790,6 +795,8 @@ def apply_tp( ) else: gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index dfb60306e0b..0233b9c7175 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1173,7 +1173,7 @@ def get_input_length_list( # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. - if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 0: + if envs.FD_RUN_DUMMY_FOR_PROFILE: if self.fd_config.parallel_config.enable_expert_parallel: input_length = min(input_length, 32) @@ -2031,7 +2031,7 @@ def _dummy_run( if self.enable_mm: model_inputs["image_features"] = self.share_inputs["image_features"] - if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: + if envs.FD_RUN_DUMMY_FOR_PROFILE: import datetime paddle.distributed.barrier() @@ -2066,7 +2066,7 @@ def _dummy_run( ) self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts) - if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: + if envs.FD_RUN_DUMMY_FOR_PROFILE: paddle.distributed.barrier() endtime = datetime.datetime.now() duringtime = endtime - starttime diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 28a943cf9d4..fbaed132c49 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1333,7 +1333,7 @@ def run_worker_proc() -> None: # Instead of doing end to end tests which is very unstable, we can profile the following line of 
code to pick the best model. # so we add an environment variable RUN_DUMMY_FOR_PROFILE to control whether to run dummy run for profile. # Any Question refer to ChangWenBin. - if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: + if envs.FD_RUN_DUMMY_FOR_PROFILE: worker_proc.worker.model_runner._dummy_run( num_tokens=100, batch_size=1, expected_decode_len=10, step_use_cudagraph=True ) From 22c9d0c6221aa98f3dfba399ad848b5fe799984c Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Mon, 11 May 2026 19:47:54 +0800 Subject: [PATCH 4/5] support dummy profile --- .../layers/moe/fused_moe_cutlass_backend.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index faf5f774d6c..483fd8759db 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -23,6 +23,7 @@ from paddleformers.utils.log import logger import fastdeploy +from fastdeploy import envs from fastdeploy.platforms import current_platform from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_model @@ -137,6 +138,8 @@ def apply_ep_prefill( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) @@ -292,6 +295,8 @@ def apply_ep_decode( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") estimate_total_token_nums = gate_out.shape[0] * layer.top_k # 1. 
Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) @@ -439,6 +444,8 @@ def apply_tp( use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() and not fc1_latent_proj if not use_fused: gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") if fc1_latent_proj is not None: x = fc1_latent_proj(x) gate_out, topk_weights, topk_idx = get_moe_scores( @@ -481,6 +488,8 @@ def apply_tp( ) else: gate_out = gate_out.cast("float32") + if envs.FD_RUN_DUMMY_FOR_PROFILE: + gate_out = paddle.randn_like(gate_out, dtype="float32") if fc1_latent_proj is not None: x = fc1_latent_proj(x) ( From b6e58027e4f65a95659e05a16c1100c7bc23bbff Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Mon, 11 May 2026 19:56:33 +0800 Subject: [PATCH 5/5] support dummy profile --- fastdeploy/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index aa33edf6191..050129841a5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1173,7 +1173,7 @@ def get_input_length_list( # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. - if envs.FD_RUN_DUMMY_FOR_PROFILE: + if not envs.FD_RUN_DUMMY_FOR_PROFILE: if self.fd_config.parallel_config.enable_expert_parallel: input_length = min(input_length, 32)