From b24765a746d925be2ebb74ae5a91c369f666292b Mon Sep 17 00:00:00 2001 From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:29:22 +0800 Subject: [PATCH 001/143] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 55e78125f5a..d0edf2c907b 100644 --- a/setup.py +++ b/setup.py @@ -251,7 +251,7 @@ def get_name(): cmdclass_dict = {"bdist_wheel": CustomBdistWheel} cmdclass_dict["build_ext"] = CMakeBuild -FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.5.0-dev") +FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.6.0") cmdclass_dict["build_optl"] = PostInstallCommand From 55dbc83310d6b693e42b26cb940538a3ed9279af Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:46:13 +0800 Subject: [PATCH 002/143] [Cherry-Pick][BugFix] prevent requests from entering running state without a slot(#7141) (#7181) * [BugFix] Set MC_MAX_MR_SIZE to avoid register hang (#7163) * Set MC_MAX_MR_SIZE to avoid register hang * up * [fix] prevent requests from entering running state without a slot * [fix] count abort set * [fix] count preempted task in waiting list --------- Co-authored-by: jc <52520497+juncaipeng@users.noreply.github.com> --- .../transfer_factory/mooncake_store/mooncake_store.py | 2 +- fastdeploy/engine/common_engine.py | 3 +-- fastdeploy/engine/sched/resource_manager_v1.py | 8 +++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index 3fc10996a61..ba7d003b7ae 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -112,7 +112,7 @@ def __init__(self, tp_rank=None): os.environ["MC_TCP_BIND_ADDRESS"] = host_ip logger.info(f"Set MC_TCP_BIND_ADDRESS to {host_ip}") if os.environ.get("MC_MAX_MR_SIZE") is None: - os.environ["MC_MAX_MR_SIZE"] = "4294967296" # 4GB + os.environ["MC_MAX_MR_SIZE"] = str(4 * 1024**3) # 4GB logger.info("MC_MAX_MR_SIZE is not set, default to 4GB.") try: diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 61a9914225b..f1152c6e22c 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1152,8 +1152,7 @@ def _fetch_request(): time.sleep(0.005) except RuntimeError as e: - if "cannot schedule new futures after shutdown" in str(e): - break + raise e except Exception as e: err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 5af1605cdaf..80b58d68972 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -941,7 +941,13 @@ def _allocate_decode_and_extend(): if not preempted_reqs: skip_requests: list[Request] = [] while self.waiting and token_budget > 0: - if len(self.running) == self.max_num_seqs: + if ( + len(self.running) + + len(self.to_be_rescheduled_request_id_set) + + len(self.to_be_aborted_req_id_set) + + sum([req.status == RequestStatus.PREEMPTED for req in self.waiting]) + >= self.max_num_seqs + ): break request = self.waiting[0] From 7ab48c47605cec3c7465633656736f2501b139bc Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 3 Apr 2026 20:55:53 +0800 Subject: [PATCH 003/143] [Cherry-Pick][CI] Use GPU-Build-RL runner for _build_linux_rl.yml (#7186) (#7195) --- .github/workflows/_build_linux_rl.yml | 4 ++-- .github/workflows/ce_job.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml index ede288c805a..9e809d59a59 100644 --- a/.github/workflows/_build_linux_rl.yml +++ b/.github/workflows/_build_linux_rl.yml @@ -8,7 +8,7 @@ on: description: "Build Images" required: true type: string - default: "iregistry.baidu-int.com/tiangexiao/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-rc2" + default: "iregistry.baidu-int.com/new_rl_infra/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-v2.4.0-rc1" FASTDEPLOY_ARCHIVE_URL: description: "URL of the compressed FastDeploy code archive." required: true @@ -54,7 +54,7 @@ on: value: ${{ jobs.fd-build-rl.outputs.wheel_path_rl }} jobs: fd-build-rl: - runs-on: [self-hosted, GPU-Build] + runs-on: [self-hosted, GPU-Build-RL] timeout-minutes: 360 outputs: wheel_path_rl: ${{ steps.set_output.outputs.wheel_path_rl }} diff --git a/.github/workflows/ce_job.yml b/.github/workflows/ce_job.yml index 5b20eccdf2e..30775a455a6 100644 --- a/.github/workflows/ce_job.yml +++ b/.github/workflows/ce_job.yml @@ -186,7 +186,7 @@ jobs: COMPILE_ARCH: "80,90" WITH_NIGHTLY_BUILD: OFF FD_VERSION: 0.0.0 - PADDLE_WHL_URL: https://paddle-qa.bj.bcebos.com/paddle-pipeline/Paddle-RL-Compile/develop/latest/paddlepaddle_gpu-3.3.0.dev-cp310-cp310-linux_x86_64.whl + PADDLE_WHL_URL: https://paddle-qa.bj.bcebos.com/paddle-pipeline/Develop-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/latest/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl build_sm8689: name: BUILD_SM8689 From 36909bf27d399389cf56731999cd0427eded1d2c Mon Sep 17 00:00:00 2001 From: huicongyao Date: Wed, 8 Apr 2026 10:24:38 +0800 Subject: [PATCH 004/143] [Cherry-Pick][BugFix] fix MTP bugs in TP and overlap(#7172) (#7192) * fix MTP bugs in TP and overlap * fix --- .../gpu_ops/speculate_decoding/speculate_save_output.cc | 5 +++-- .../speculate_decoding/speculate_save_output_with_topk.cc | 4 +++- fastdeploy/worker/gpu_model_runner.py | 4 +--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc index 2a040a7e7b4..f72f3774107 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc @@ -36,8 +36,9 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, int msg_queue_id, int save_each_rank, bool skip_prefill) { - // printf("enter save output"); - if (!save_each_rank && rank_id > 0) { + // NOTE(yaohuicong): Skip non-zero TP ranks — they share identical sampling + // outputs, so only rank 0 needs to send results to the message queue. + if (rank_id > 0) { return; } diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 53e822e6223..3d75886bd25 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -53,7 +53,9 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, int message_flag, // Target: 3, Draft: 4 int64_t rank_id, bool save_each_rank) { - if (!save_each_rank && rank_id > 0) { + // NOTE(yaohuicong): Skip non-zero TP ranks — they share identical sampling + // outputs, so only rank 0 needs to send results to the message queue. + if (rank_id > 0) { return; } diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index c0e689735d4..bc315c3646b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -345,9 +345,7 @@ def _predict_next_launch_token_num(self) -> int: is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy() next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item() token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 - next_launch_token_num = ( - seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step - ) + next_launch_token_num = next_real_bsz * token_num_one_step return next_launch_token_num, next_real_bsz def only_prefill(self): From 403ce139c7e667102f907d901141cc1f5e1e9c8d Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Wed, 8 Apr 2026 15:25:21 +0800 Subject: [PATCH 005/143] remove arctic_inference deps (#7236) --- fastdeploy/spec_decode/suffix.py | 2 +- requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fastdeploy/spec_decode/suffix.py b/fastdeploy/spec_decode/suffix.py index f4d1495524c..e11c2255a3e 100644 --- a/fastdeploy/spec_decode/suffix.py +++ b/fastdeploy/spec_decode/suffix.py @@ -43,7 +43,7 @@ def __init__(self, fd_config: "FDConfig"): if SuffixDecodingCache is None: raise ImportError( - "arctic_inference.suffix_decoding is not available. Please install arctic-inference package." + "arctic_inference.suffix_decoding is not available. Please install via `pip install arctic-inference==0.1.2`." ) # Initialize SuffixDecodingCache diff --git a/requirements.txt b/requirements.txt index a6a7b6619c9..e662f07e974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,5 +48,4 @@ p2pstore py-cpuinfo flashinfer-python-paddle flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl -arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl transformers>=4.55.1,<5.0.0 From 6b78981dde307908718add1be50dc86f3e478737 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:32:04 +0800 Subject: [PATCH 006/143] Split enable_mm (#7183) (#7233) Co-authored-by: K11OntheBoat Co-authored-by: liuruian --- fastdeploy/config.py | 33 +++++++++++-- fastdeploy/engine/async_llm.py | 3 +- fastdeploy/engine/common_engine.py | 8 +-- .../engine/sched/resource_manager_v1.py | 6 +-- fastdeploy/entrypoints/engine_client.py | 4 +- fastdeploy/input/preprocess.py | 5 +- .../inter_communicator/engine_worker_queue.py | 2 - .../layers/attention/append_attn_backend.py | 4 +- .../layers/attention/dsa_attention_backend.py | 2 +- .../layers/attention/flash_attn_backend.py | 4 +- .../attention/flash_mask_attn_backend.py | 4 +- .../layers/attention/mla_attention_backend.py | 2 +- .../iluvatar/attention/mha_attn_backend.py | 2 +- .../intel_hpu/attention/hpu_attn_backend.py | 2 +- .../metax/attention/flash_attn_backend.py | 4 +- .../metax/attention/mla_attn_metax_backend.py | 2 +- .../layers/backends/xpu/attention.py | 4 +- fastdeploy/output/token_processor.py | 1 + fastdeploy/spec_decode/base.py | 2 +- fastdeploy/spec_decode/mtp.py | 2 +- fastdeploy/worker/gcu_model_runner.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 4 +- fastdeploy/worker/iluvatar_worker.py | 2 +- fastdeploy/worker/input_batch.py | 49 +++++++++---------- fastdeploy/worker/metax_model_runner.py | 2 +- fastdeploy/worker/worker_process.py | 4 +- fastdeploy/worker/xpu_model_runner.py | 2 +- tests/distributed/chunked_moe.py | 3 +- tests/entrypoints/test_engine_client.py | 7 +++ ...est_kv_cache_int8_dynamic_quant_backend.py | 3 ++ .../test_chunked_prefill_determinism.py | 2 + tests/worker/test_gpu_prompt_logprobs.py | 1 + .../test_reorder_split_prefill_and_decode.py | 1 + 33 files changed, 109 insertions(+), 69 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index b15a6dc824b..8b92a138a34 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1980,6 +1980,7 @@ def expand_bsz_map(real_bsz_to_captured_size): int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0 and self.model_config is not None and self.model_config.enable_mm + and self.deploy_modality != DeployModality.TEXT ): self.max_prefill_batch = 1 # TODO:当前V0多模prefill阶段只支持并行度为1,待优化 else: @@ -2019,6 +2020,20 @@ def expand_bsz_map(real_bsz_to_captured_size): self.check() # self.print() # NOTE: it's better to explicitly call .print() when FDConfig is initialized + @property + def enable_mm_runtime(self) -> bool: + return ( + self.model_config is not None + and self.model_config.enable_mm + and self.deploy_modality != DeployModality.TEXT + ) + + @property + def enable_rope_3d_runtime(self) -> bool: + return self.enable_mm_runtime and ( + getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False) + ) + def _disable_sequence_parallel_moe_if_needed(self, mode_name): if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: self.parallel_config.use_sequence_parallel_moe = False @@ -2057,9 +2072,21 @@ def postprocess(self): if self.long_prefill_token_threshold == 0: self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04) + if ( + self.model_config is not None + and self.model_config.enable_mm + and self.deploy_modality == DeployModality.TEXT + ): + if getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False): + logger.info( + "Deploy modality is text; forcing the multimodal-capable model onto the 2D RoPE runtime path." + ) + setattr(self.model_config, "rope_3d", False) + setattr(self.model_config, "use_3d_rope", False) + self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size) self.cache_config.postprocess(self.get_max_chunk_tokens(), self.scheduler_config.max_num_seqs) - if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER: + if self.model_config is not None and self.enable_mm_runtime and not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.cache_config.enable_prefix_caching = False if ( self.structured_outputs_config is not None @@ -2085,7 +2112,7 @@ def postprocess(self): f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]" ) - if self.model_config.enable_mm: + if self.enable_mm_runtime: if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0: self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens elif self.cache_config.max_encoder_cache != 0: @@ -2392,7 +2419,7 @@ def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): num_tokens = self.scheduler_config.max_num_seqs else: num_tokens = self.scheduler_config.max_num_batched_tokens - if mm_max_tokens_per_item is not None and self.deploy_modality != DeployModality.TEXT: + if self.enable_mm_runtime and mm_max_tokens_per_item is not None: max_mm_tokens = max( mm_max_tokens_per_item.get("image", 0), mm_max_tokens_per_item.get("video", 0), diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index 4afb3dc5c49..c06292ec981 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -294,6 +294,7 @@ def __init__(self, cfg, pid): cfg.limit_mm_per_prompt, cfg.mm_processor_kwargs, cfg.tool_parser, + enable_mm_runtime=cfg.enable_mm_runtime, ) # Create data processor self.data_processor = self.input_processor.create_processor() @@ -446,7 +447,7 @@ async def add_request( ) if envs.ZMQ_SEND_BATCH_DATA and self.connection_manager is not None: request["zmq_worker_pid"] = self.connection_manager.worker_pid - if self.cfg.model_config.enable_mm: + if self.cfg.enable_mm_runtime: self.request_client.send_pyobj(request) else: self.request_client.send_json(request) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index f1152c6e22c..cd9e42f8bcf 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -330,6 +330,7 @@ def create_data_processor(self): self.cfg.limit_mm_per_prompt, self.cfg.mm_processor_kwargs, self.cfg.tool_parser, + enable_mm_runtime=self.cfg.enable_mm_runtime, ) self.data_processor = self.input_processor.create_processor() self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item( @@ -601,7 +602,7 @@ def insert_tasks(self, tasks: List[Request], current_id=-1): LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "") ) if not is_prefill: - if not self.cfg.model_config.enable_mm: + if not self.cfg.enable_mm_runtime: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) @@ -1217,7 +1218,7 @@ def _insert_zmq_task_to_scheduler(self): while self.running: try: block = True if len(added_requests) == 0 else False - if not self.cfg.model_config.enable_mm: + if not self.cfg.enable_mm_runtime: err, data = self.recv_request_server.receive_json_once(block) else: err, data = self.recv_request_server.receive_pyobj_once(block) @@ -1275,6 +1276,7 @@ def _insert_zmq_task_to_scheduler(self): err_msg = None try: request = Request.from_dict(data) + request.metrics.scheduler_recv_req_time = time.time() main_process_metrics.requests_number.inc() trace_carrier = data.get("trace_carrier") @@ -2377,7 +2379,7 @@ def _setting_environ_variables(self): if self.cfg.scheduler_config.splitwise_role == "prefill": variables["FLAGS_fmt_write_cache_completed_signal"] = 1 - if self.cfg.model_config.enable_mm: + if self.cfg.enable_mm_runtime: variables["FLAGS_max_partition_size"] = 1024 command_prefix = "" diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 80b58d68972..45ec18aa1c0 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -205,11 +205,11 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.need_block_num_map = dict() self.encoder_cache = None - if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0: + if config.enable_mm_runtime and config.cache_config.max_encoder_cache > 0: self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None - if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0: + if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) @@ -550,7 +550,7 @@ def _get_num_new_tokens(self, request, token_budget): num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size request.with_image = False - if not self.config.model_config.enable_mm: + if not self.config.enable_mm_runtime: return num_new_tokens inputs = request.multimodal_inputs diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index f03a18594de..4c56e9bcd76 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -84,7 +84,7 @@ class EngineClient: def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20): self.fd_config = fd_config self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size - self.enable_mm = self.fd_config.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.max_logprobs = max_logprobs input_processor = InputPreprocessor( self.fd_config.model_config, @@ -93,6 +93,7 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers self.fd_config.mm_processor_kwargs, self.fd_config.tool_parser, self.enable_mm and self.fd_config.cache_config.max_processor_cache > 0, + enable_mm_runtime=self.enable_mm, ) self.enable_logprob = self.fd_config.model_config.enable_logprob self.data_processor = input_processor.create_processor() @@ -358,6 +359,7 @@ async def add_requests(self, task): task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) min_tokens = task.get("min_tokens", 1) + if "messages" in task: task["messages"] = None api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}") diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 755f0612def..6467f6a89ac 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -48,6 +48,7 @@ def __init__( mm_processor_kwargs: Optional[Dict[str, Any]] = None, tool_parser: str = None, enable_processor_cache: bool = False, + enable_mm_runtime: Optional[bool] = None, ) -> None: self.model_config = model_config self.model_name_or_path = self.model_config.model @@ -56,6 +57,7 @@ def __init__( self.mm_processor_kwargs = mm_processor_kwargs self.tool_parser = tool_parser self.enable_processor_cache = enable_processor_cache + self.enable_mm_runtime = self.model_config.enable_mm if enable_mm_runtime is None else enable_mm_runtime def create_processor(self): reasoning_parser_obj = None @@ -77,10 +79,11 @@ def create_processor(self): reasoning_parser_obj=reasoning_parser_obj, tool_parser_obj=tool_parser_obj, mm_processor_kwargs=self.mm_processor_kwargs, + enable_mm_runtime=self.enable_mm_runtime, ) except Exception as e: logger.info(f"Plugin input processor not available ({e}), using built-in processor") - if not self.model_config.enable_mm: + if not self.enable_mm_runtime: from fastdeploy.input.text_processor import TextProcessor tokenizer_type = "ernie4_5" if ErnieArchitectures.contains_ernie_arch(architecture) else "auto" diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index b64fcacda33..a7876669f8f 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -549,7 +549,6 @@ def put_tasks(self, tasks: List[Any]) -> None: self.lock.release() time.sleep(0.001) self.lock.acquire() - if envs.FD_ENABLE_MAX_PREFILL or envs.FD_ENABLE_E2W_TENSOR_CONVERT: # multimodal input numpy -> tensor to_tensor(tasks[0]) @@ -571,7 +570,6 @@ def get_tasks(self) -> Tuple[List[Any], bool]: """ tasks: List[Any] = list() self.lock.acquire() - tasks.extend(self.tasks) self.client_read_flag[self.client_id] = 1 all_client_read: bool = np.sum(self.client_read_flag) == self.num_client diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 81eab7cce86..15b657c249d 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -138,9 +138,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index 4d6bbcdfb7d..66a92a52599 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -136,7 +136,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b51dce1449d..bcffcd0bac0 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -267,9 +267,7 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False # Note(ZKK): here must be consistent with append_attn_backend.py diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 35d27504ab5..6ebea2cb3d9 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -121,9 +121,7 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768")) diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 61ccc4e16e7..209817e69a2 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -263,7 +263,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py b/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py index 092912149a9..d01973f80d0 100644 --- a/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py @@ -89,7 +89,7 @@ def __init__( # note: scale need to change if using MLA self.scale = 1.0 / sqrt(head_dim) self.dtype = paddle.get_default_dtype() - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.rope_batch_stride = self.max_context_len * self.head_dim if self.enable_mm else 0 if "paddleocr" in fd_config.model_config.model_type: self.is_interleaved_rope_mode = False diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py index 82938c87367..bd2d8505228 100644 --- a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py @@ -219,7 +219,7 @@ def __init__( self.block_size = llm_config.cache_config.block_size self.max_seq_len = llm_config.model_config.max_model_len self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta - self.rope_3d = getattr(llm_config.model_config, "rope_3d", False) + self.rope_3d = llm_config.enable_rope_3d_runtime self.causal = getattr(llm_config.model_config, "causal", True) self.speculative_method = llm_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index 74fc27f67b4..0fd3553fda9 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -101,7 +101,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -128,7 +128,7 @@ def __init__( fd_config.parallel_config.expert_parallel_rank = 0 self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.model_type = fd_config.model_config.model_type self.is_neox_style = False if "paddleocr" in fd_config.model_config.model_type: diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index c905086f9f7..dcd5589f0d8 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -105,7 +105,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index 85565d33efb..31fce9bdf51 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -88,9 +88,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1ab0b48f350..85e54647b7e 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -947,6 +947,7 @@ def _process_batch_output(self): if not is_prefill: self._record_completion_metrics(task, current_time) llm_logger.info(f"task {task_id} received eos token. Recycling.") + if ( envs.ENABLE_V1_KVCACHE_SCHEDULER and self.cfg.cache_config.enable_prefix_caching diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index fa50eae462a..8db764fcf12 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -71,7 +71,7 @@ def __init__(self, fd_config: "FDConfig"): self.max_ngram_size = self.speculative_config.max_ngram_size self.min_ngram_size = self.speculative_config.min_ngram_size - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime spec_logger.info(f"Speculate config: {self.speculative_config}") diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 4ec57e93594..9ca5f535ab0 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -103,7 +103,7 @@ def __init__( self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id - self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text" + self.use_attn_mask_offset = self.enable_mm self._update_mtp_config(main_model) self._load_model() diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 44a8c5f3578..284cee8843d 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -62,7 +62,7 @@ def __init__( local_rank: int, ): super().__init__(fd_config=fd_config, device=device) - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bc315c3646b..2454b016209 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -119,7 +119,7 @@ def __init__( ): super().__init__(fd_config=fd_config, device=device) self.MAX_INFER_SEED = 9223372036854775806 - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id @@ -1118,10 +1118,12 @@ def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None: """Prepare the model inputs""" + if self.enable_mm and self.share_inputs["image_features_list"] is not None: tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)] if tensor_feats: self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0) + recover_decode_task( self.share_inputs["stop_flags"], self.share_inputs["seq_lens_this_time"], diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py index 625aca86db1..44be900bb73 100644 --- a/fastdeploy/worker/iluvatar_worker.py +++ b/fastdeploy/worker/iluvatar_worker.py @@ -40,7 +40,7 @@ def __init__( local_rank: int, rank: int, ): - if fd_config.model_config.enable_mm: + if fd_config.enable_mm_runtime: paddle.set_flags({"FLAGS_enable_ixattnbkd": True, "FLAGS_enable_ixdnn_attn": False}) super(IluvatarWorker, self).__init__( fd_config=fd_config, diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 363dfb63097..55a3f39a2ee 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -17,13 +17,7 @@ import paddle from paddleformers.utils.log import logger -from fastdeploy.config import ( - CacheConfig, - DeployModality, - FDConfig, - ModelConfig, - SpeculativeConfig, -) +from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.logits_processor import build_logits_processors from fastdeploy.platforms import current_platform @@ -101,7 +95,8 @@ def __init__(self, fd_config: FDConfig) -> None: self.scheduler_config = fd_config.scheduler_config self.speculative_config: SpeculativeConfig = fd_config.speculative_config self.speculative_decoding = self.speculative_config.method is not None - self.enable_mm = self.model_config.enable_mm + self.is_mm_model = self.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel self.index_to_batch_id = {} self.enable_pd_reorder = False @@ -231,6 +226,9 @@ def init_share_inputs(self): model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) + if self.is_mm_model: + self.image_features = None + self.image_features_list = None # Set block tables pre_max_block_num = ( @@ -677,6 +675,9 @@ def reset_share_inputs(self): model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) + if self.is_mm_model: + self.image_features = None + self.image_features_list = None # Reset other miscellaneous tensors fill_paddle_tensor(self, "mask_rollback", 0) @@ -689,7 +690,7 @@ def reset_share_inputs(self): class ProposerInputBatch(InputBatch): def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None: - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.num_model_steps = fd_config.speculative_config.num_model_steps self.index_to_batch_id = {} self.target_model_input_batch = target_model_input_batch @@ -863,18 +864,15 @@ def init_share_inputs(self): -1, dtype="int32", ) - if self.fd_config.deploy_modality != DeployModality.TEXT: - self.attn_mask_offsets = paddle.full( - shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len], - fill_value=-1, - dtype="int32", - ) - self.attn_mask_offsets_full = paddle.full( - [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32" - ) - self.attn_mask_offsets_decoder = paddle.full( - [self.scheduler_config.max_num_seqs, 1], -1, dtype="int32" - ) + self.attn_mask_offsets = paddle.full( + shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len], + fill_value=-1, + dtype="int32", + ) + self.attn_mask_offsets_full = paddle.full( + [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32" + ) + self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32") def swap_states(self, i1, i2) -> None: def swap_data(tensor, idx1, idx2): @@ -896,7 +894,7 @@ def swap_data(tensor, idx1, idx2): swap_data(self.input_ids_len, i1, i2) swap_data(self.mask_rollback, i1, i2) swap_data(self.recompute_token_num, i1, i2) - if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT: + if self.enable_mm: swap_data(self.attn_mask_offsets_full, i1, i2) swap_data(self.attn_mask_offsets_decoder, i1, i2) @@ -1030,10 +1028,9 @@ def reset_model_inputs(self) -> None: # Reset multimodal tensors if enabled if self.enable_mm: fill_paddle_tensor(self, "decode_states", -1) - if self.fd_config.deploy_modality != DeployModality.TEXT: - fill_paddle_tensor(self, "attn_mask_offsets", -1) - fill_paddle_tensor(self, "attn_mask_offsets_full", -1) - fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1) + fill_paddle_tensor(self, "attn_mask_offsets", -1) + fill_paddle_tensor(self, "attn_mask_offsets_full", -1) + fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1) logger.info("model_inputs reset completed") except Exception as e: diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 93f5cec6a57..28c769e1166 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -97,7 +97,7 @@ def __init__( ): super().__init__(fd_config=fd_config, device=device) self.MAX_INFER_SEED = 9223372036854775806 - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8182e06990b..2f51359959f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -138,7 +138,7 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]: def update_fd_config_for_mm(fd_config: FDConfig) -> None: architectures = fd_config.model_config.architectures - if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures): + if fd_config.enable_mm_runtime and ErnieArchitectures.contains_ernie_arch(architectures): fd_config.model_config.tensor_model_parallel_size = fd_config.parallel_config.tensor_parallel_size fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype @@ -506,7 +506,7 @@ def event_loop_normal(self) -> None: if tp_rank == 0: if self.task_queue.exist_tasks(): if envs.ENABLE_V1_KVCACHE_SCHEDULER or not ( - self.fd_config.model_config.enable_mm and self.worker.exist_prefill() + self.fd_config.enable_mm_runtime and self.worker.exist_prefill() ): if self.nnode > 1: self.task_queue.read_finish_flag.set(1) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 1446257d3ae..bd585519520 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -97,7 +97,7 @@ def __init__( local_rank: int, ): super().__init__(fd_config=fd_config, device=device) - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id diff --git a/tests/distributed/chunked_moe.py b/tests/distributed/chunked_moe.py index fee1582f3c7..0fe9f9f3974 100644 --- a/tests/distributed/chunked_moe.py +++ b/tests/distributed/chunked_moe.py @@ -92,6 +92,7 @@ class SchedulerConfig: model_config = MockModelConfig() cache_config = MockCacheConfig() speculative_config = MockSpecaulativeConfig() + enable_mm_runtime = MockModelConfig.enable_mm def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): return 8192 @@ -139,7 +140,7 @@ def setup_model_runner(self): model_runner.model_config = mock_model_config model_runner.cache_config = mock_cache_config model_runner.attn_backends = [MockAttentionBackend()] - model_runner.enable_mm = True + model_runner.enable_mm = mock_fd_config.enable_mm_runtime model_runner.cudagraph_only_prefill = False model_runner.use_cudagraph = False model_runner.speculative_decoding = False diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py index 0ed8fbdc033..71ad4b29db9 100644 --- a/tests/entrypoints/test_engine_client.py +++ b/tests/entrypoints/test_engine_client.py @@ -102,6 +102,7 @@ def create_mock_fd_config( mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None + mock_config.enable_mm_runtime = enable_mm return mock_config @@ -181,6 +182,7 @@ async def asyncSetUp(self): mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.node_rank = 0 + mock_config.enable_mm_runtime = mock_model_config.enable_mm # Create mocks for all the external dependencies mock_input_processor = Mock() @@ -363,6 +365,7 @@ def setUp(self): mock_config.structured_outputs_config = MagicMock() # Add this mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None # Add this attribute + mock_config.enable_mm_runtime = mock_model_config.enable_mm # Mock IPCSignal to avoid file system dependencies with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: @@ -655,6 +658,7 @@ async def test_init_basic_parameters(self): mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None + mock_config.enable_mm_runtime = mock_config.model_config.enable_mm client = EngineClient( pid=5678, @@ -1078,6 +1082,7 @@ async def test_init_with_multimodal_prefix_cache(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False @@ -1131,6 +1136,7 @@ async def test_init_as_worker_node(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False @@ -1408,6 +1414,7 @@ async def test_init_iluvatar_platform(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False diff --git a/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py b/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py index 17a393ee11e..f679be08b31 100644 --- a/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py +++ b/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py @@ -92,6 +92,7 @@ def __init__(self): "max_model_len": 2048, "head_dim": 128, "num_hidden_layers": 2, + "enable_mm": False, "causal": True, "start_layer_index": 0, "rope_3d": False, @@ -124,6 +125,8 @@ def __init__(self): "model_type": "main", }, )() + self.enable_mm_runtime = self.model_config.enable_mm + self.enable_rope_3d_runtime = self.model_config.enable_mm class DummyLayer: diff --git a/tests/scheduler/test_chunked_prefill_determinism.py b/tests/scheduler/test_chunked_prefill_determinism.py index 1a0f786f3d1..17b466a014c 100644 --- a/tests/scheduler/test_chunked_prefill_determinism.py +++ b/tests/scheduler/test_chunked_prefill_determinism.py @@ -78,6 +78,7 @@ def __init__(self): self.cache_config = CacheConfig() self.parallel_config = ParallelConfig() self.speculative_config = SpeculativeConfig() + self.enable_mm_runtime = self.model_config.enable_mm # --------------------------------------------------------------------------- @@ -168,6 +169,7 @@ def _create_resource_manager(self, config): def _create_mm_resource_manager(self): config = StubConfig() config.model_config.enable_mm = True + config.enable_mm_runtime = config.model_config.enable_mm return self._create_resource_manager(config) # ==================== 1. Deterministic disabled ==================== diff --git a/tests/worker/test_gpu_prompt_logprobs.py b/tests/worker/test_gpu_prompt_logprobs.py index d26bc915339..f12bc4cf3dc 100644 --- a/tests/worker/test_gpu_prompt_logprobs.py +++ b/tests/worker/test_gpu_prompt_logprobs.py @@ -64,6 +64,7 @@ class SpecaulativeConfig: scheduler_config = SchedulerConfig() cache_config = CacheConfig() parallel_config = ParallelConfig() + enable_mm_runtime = model_config.enable_mm def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): return 8192 diff --git a/tests/worker/test_reorder_split_prefill_and_decode.py b/tests/worker/test_reorder_split_prefill_and_decode.py index aff9f551cf4..d2d9e3a1f61 100644 --- a/tests/worker/test_reorder_split_prefill_and_decode.py +++ b/tests/worker/test_reorder_split_prefill_and_decode.py @@ -83,6 +83,7 @@ def create_mock_config(): fd_config.parallel_config = parallel_config fd_config.structured_outputs_config = structured_outputs_config fd_config.pad_to = 8 + fd_config.enable_mm_runtime = model_config.enable_mm def get_max_chunk_tokens(mm_max_tokens_per_item=None): return 100 From 84d62712c9b62b7ce9a9e7daafff4ad30830bf74 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:32:38 +0800 Subject: [PATCH 007/143] [Feature]distinguish whl version (#7204) (#7224) * [Feature]whl version * [Feature]whl version,set root_is_pure = false * [Feature]code style Co-authored-by: ChowMingSing <610208940@qq.com> --- setup.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index d0edf2c907b..4c4e24f950e 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ from pathlib import Path import paddle +from packaging import tags from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.install import install @@ -42,16 +43,17 @@ class CustomBdistWheel(bdist_wheel): - """Custom wheel builder for pure Python packages.""" + """Custom wheel builder.""" def finalize_options(self): - """Configure wheel as pure Python and platform-independent.""" + """Configure wheel as {python tag}-{abi tag}-{platform tag}.""" super().finalize_options() - self.root_is_pure = True - self.python_tag = "py3" - self.abi_tag = "none" + tag = next(tags.sys_tags()) + self.root_is_pure = False + self.python_tag = tag.interpreter + self.abi_tag = tag.abi self.plat_name_supplied = True - self.plat_name = "any" + self.plat_name = tag.platform class CMakeExtension(Extension): From 01818844b4369ce7a64327e4bce760c0343a03b1 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Wed, 8 Apr 2026 20:56:23 +0800 Subject: [PATCH 008/143] support moe for sm103 (#7240) --- .../cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h | 4 ++-- .../moe_gemm/fused_moe_gemm_kernels_template.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 9c5e7bfc47b..7e93f169028 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -635,7 +635,7 @@ struct MoeFCGemm { static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1100) static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); @@ -1060,7 +1060,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 800) && (__CUDA_ARCH__ < 1010) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1100) KernelRunner::run_kernel( params, shared_storage); #else diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index db5af4f4938..68b5b054476 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -709,7 +709,7 @@ void MoeGemmRunner::dispatch_to_arch( dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70); } else if (sm_ >= 75 && sm_ < 80) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75); - } else if (sm_ >= 80 && sm_ < 101) { + } else if (sm_ >= 80 && sm_ < 104) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80); } else { throw std::runtime_error( From 9c65655cb3bc48e77ddc955b4a5476f98ebc6665 Mon Sep 17 00:00:00 2001 From: JYChen Date: Thu, 9 Apr 2026 11:01:10 +0800 Subject: [PATCH 009/143] [Cherry-Pick][RL] support moe-topk use topk_reduce_func #7218 (#7256) * support moe-topk use topk_reduce_func * fix ep error * fix ut * fix ut --- fastdeploy/model_executor/layers/moe/ep.py | 2 + .../layers/moe/fused_moe_cutlass_backend.py | 1 + .../layers/moe/fused_moe_deepgemm_backend.py | 110 ++---------------- fastdeploy/model_executor/layers/moe/moe.py | 22 ++++ fastdeploy/model_executor/models/glm4_moe.py | 1 + .../layers/test_fused_moe_cutlass_backend.py | 4 +- tests/operators/test_noaux_tc_redundant.py | 38 ++++-- 7 files changed, 66 insertions(+), 112 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 243567a422f..33993872cb6 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -509,6 +509,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): expert_in_rank_num_list=expert_in_rank_num_list, tokens_per_expert_stats_list=tokens_per_expert_stats_list, redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1, + topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( @@ -534,6 +535,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 0c86270c630..d36d2d3b6e7 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -285,6 +285,7 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) ( permute_input, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 135fb5ecafc..a4a8d831e26 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( return ffn_out -def moe_topk_select( - gating_output: paddle.Tensor, - n_group: int, - topk_group: int, - top_k: int, - routed_scaling_factor: float, - e_score_correction_bias: paddle.Tensor, - renormalize: bool = False, -): - """ - Topk selection using paddle PHI topk API. - - Args: - gating_output: gate output logits, shape [seq_len, n_experts] - n_group: number of expert groups - topk_group: number of top-k groups to select - top_k: number of top experts per token - routed_scaling_factor: scaling factor for routed experts - e_score_correction_bias: bias for expert selection - renormalize: whether to renormalize topk probabilities - - Returns: - topk_weights: normalized topk probabilities, shape [seq_len, top_k] - topk_ids: topk expert indices, shape [seq_len, top_k] - """ - # compute gate probs via sigmoid - gate_probs = paddle.nn.functional.sigmoid(gating_output) - # probs_for_choice includes correction bias for topk selection - probs_for_choice = gate_probs + e_score_correction_bias if e_score_correction_bias is not None else gate_probs - # group-based topk selection - n_group = n_group if n_group > 0 else 1 - topk_group = topk_group if topk_group > 0 else 1 - if n_group > 1 and topk_group < n_group: - seq_length, n_experts = probs_for_choice.shape - group_scores = ( - probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) - ) # [seq_len, n_group] - group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group] - group_mask = paddle.sum( - paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype), - axis=1, # Sum over topk_group dimension -> [seq_len, n_group] - ) - score_mask = ( - group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1]) - ) # [seq_len, n_experts] - probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) - - _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1) - topk_weights = paddle.index_sample(gate_probs, topk_ids) - - # normalize combine weights - if renormalize: - topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12) - - # apply routed scaling factor - if routed_scaling_factor: - topk_weights = topk_weights * routed_scaling_factor - - return topk_weights, topk_ids - - class DeepGemmFusedMoeMethod(MoEMethodBase): """ DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend. @@ -403,22 +342,7 @@ def apply_ep_prefill( hidden_size = x.shape[1] # 1. Select topk experts and weights - if ( - fastdeploy.envs.FD_USE_PHI_MOE_TOPK - and layer.redundant_table_manger is None - and layer.topk_method == "noaux_tc" - ): - topk_weights, topk_idx = moe_topk_select( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - else: - topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) @@ -820,28 +744,16 @@ def apply_tp( gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": - - if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: - _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - else: - topk_weights, topk_ids = moe_topk_select( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - + _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + ) else: topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 4e56c7485f9..f7d0b32c7a5 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -90,6 +90,7 @@ def get_moe_scores( expert_in_rank_num_list: paddle.Tensor = None, tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, + topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. @@ -97,6 +98,14 @@ def get_moe_scores( scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" scores_with_bias = scores + e_score_correction_bias + + if envs.FD_USE_PHI_MOE_TOPK: + # calculate renormalize and routed_scaling_factor value outside the noaux_tc + original_renormalize = renormalize + original_routed_scaling_factor = routed_scaling_factor + renormalize = False + routed_scaling_factor = 1.0 + if expert_id_to_ep_rank_array is None: scores, topk_values, topk_idx = noaux_tc( scores, @@ -123,6 +132,16 @@ def get_moe_scores( routed_scaling_factor, redundant_ep_rank_num_plus_one, ) + if envs.FD_USE_PHI_MOE_TOPK: + if original_renormalize: + if topk_reduce_func is not None: + topk_values = topk_values / topk_reduce_func(topk_values) + else: + # 使用默认的 sum + epsilon + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + + if original_routed_scaling_factor != 1.0: + topk_values *= original_routed_scaling_factor return scores, topk_values, topk_idx @@ -152,6 +171,8 @@ def __init__( with_bias: bool = False, activation="swiglu", model_format: Optional[str] = None, + topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + + 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel ): """ Initialize the Moe layer with given parameters. @@ -197,6 +218,7 @@ def __init__( self.moe_tag = moe_tag self.with_bias = with_bias self.activation = activation + self.topk_reduce_func = topk_reduce_func if self.ep_size > 1: expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index b32ebb2ced9..f1927fea5d2 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -182,6 +182,7 @@ def __init__( layer_idx=layer_id, gate_correction_bias=self.gate.e_score_correction_bias, weight_key_map=weight_key_map, + topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, ) if self.n_shared_experts > 0: diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 2e8ea281daa..0c9ecc9f6eb 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -388,7 +388,9 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0)) def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): - def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize): + def fake_get_moe_scores( + gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None + ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) def fake_dispatch(*args, **kwargs): diff --git a/tests/operators/test_noaux_tc_redundant.py b/tests/operators/test_noaux_tc_redundant.py index 60d1aad2a22..f5289e0ab3c 100644 --- a/tests/operators/test_noaux_tc_redundant.py +++ b/tests/operators/test_noaux_tc_redundant.py @@ -1,10 +1,22 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest +from unittest import mock import paddle -from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import ( - moe_topk_select, -) from fastdeploy.model_executor.layers.moe.moe import get_moe_scores @@ -135,15 +147,17 @@ def test_group_topk_using_phi_topk(self): e_score_correction_bias=e_score_correction_bias, ) - topk_values, topk_idx = moe_topk_select( - gating_output=gating_output, - n_group=n_group, - topk_group=topk_group, - top_k=top_k, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - renormalize=renormalize, - ) + with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}): + new_score, topk_values, topk_idx = get_moe_scores( + gating_output=gating_output, + n_group=n_group, + topk_group=topk_group, + top_k=top_k, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + renormalize=renormalize, + topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + ) equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item() equal_topk_ids = paddle.allclose( From 5fd8020363371d38ed37d887d5af0bde1566c4a8 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Thu, 9 Apr 2026 11:05:43 +0800 Subject: [PATCH 010/143] [Cherry-Pick][BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7216) --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 1 - .../model_executor/layers/attention/flash_mask_attn_backend.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..dce65b97274 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 6ebea2cb3d9..5b3c5ecdd3a 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -307,7 +307,7 @@ def forward_mixed( q, k, v, - forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]], forward_meta.attn_cu_seqlens_k, forward_meta.seq_lens_encoder, res_encoder, From 098dd2c2515865a586c183800c75832513eef616 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Thu, 9 Apr 2026 12:46:13 +0800 Subject: [PATCH 011/143] [XPU][CI] lock xvllm version for fix bug (#7264) (#7266) * Remove duplicate NICs from environment variables * Update version for xvllm in download_dependencies.sh Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> --- custom_ops/xpu_ops/download_dependencies.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/xpu_ops/download_dependencies.sh b/custom_ops/xpu_ops/download_dependencies.sh index ad6d4d2dea6..4aaa95777d0 100644 --- a/custom_ops/xpu_ops/download_dependencies.sh +++ b/custom_ops/xpu_ops/download_dependencies.sh @@ -15,7 +15,7 @@ if [ "$1" == "stable" ]; then version_xvllm="20251017" version_xtdk="3.4.0.1" else - version_xvllm="latest" + version_xvllm="20260407" version_xtdk="latest" fi From 849eb3df65ef26685f63ca7a97996df508b7b5b3 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Thu, 9 Apr 2026 14:15:43 +0800 Subject: [PATCH 012/143] =?UTF-8?q?[Cherry-Pick][Optimization]=20merge=20m?= =?UTF-8?q?atmul=20and=20add=20=EF=BC=88#6986=EF=BC=89=20(#7191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * merge matmul and add * modify format * using paddle.nn.functional.linear * using _C_ops.linear * using paddle.nn.functional.linear * add FLAGS_use_legacy_linear env var in test case * fix format * add assert and remove env * modify format * using matmul for no bias * modify accurate baseline --- fastdeploy/model_executor/layers/linear.py | 13 ++++++++++--- .../e2e/utils/rollout_routing_replay_test_utils.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 2bee885ff43..b35d97d7660 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -82,10 +82,17 @@ def process_loaded_weights(self, layer, weights) -> None: layer.weight.set_value(weights) def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor: - linear_out = paddle.matmul(x, layer.weight) if layer.with_bias: - linear_out = paddle.add(linear_out, layer.bias) - return linear_out + bias = layer.bias + assert bias.dim() == 1 and bias.shape[-1] == layer.weight.shape[-1], ( + f"bias must be 1D with size equal to the last dim of weight, " + f"but got bias.shape={bias.shape}, weight.shape[-1]={layer.weight.shape[-1]}" + ) + out = paddle.nn.functional.linear(x, layer.weight, bias) + else: + out = paddle.matmul(x, layer.weight) + + return out class LinearBase(nn.Layer): diff --git a/tests/e2e/utils/rollout_routing_replay_test_utils.py b/tests/e2e/utils/rollout_routing_replay_test_utils.py index 4186a71649a..74af852a292 100644 --- a/tests/e2e/utils/rollout_routing_replay_test_utils.py +++ b/tests/e2e/utils/rollout_routing_replay_test_utils.py @@ -157,10 +157,10 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode model_path = os.getenv("MODEL_PATH") if model_path: baseline_path = os.path.join( - model_path, f"R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}" + model_path, f"R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}" ) else: - baseline_path = f"./R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}" + baseline_path = f"./R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}" stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream") nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream") From 6fcc25f3f6c86cd408c81cbad163f679e4389f30 Mon Sep 17 00:00:00 2001 From: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:31:20 +0800 Subject: [PATCH 013/143] Update ci_metax.yml (#7286) --- .github/workflows/ci_metax.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci_metax.yml b/.github/workflows/ci_metax.yml index 5584147eb8c..c983ae38590 100644 --- a/.github/workflows/ci_metax.yml +++ b/.github/workflows/ci_metax.yml @@ -6,8 +6,7 @@ on: - opened - synchronize branches: - - develop - - release/** + - never-trigger-this permissions: contents: read From 921a0ae60b439c43b07404ada5fab21a521b373c Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:03:19 +0800 Subject: [PATCH 014/143] [Docs] Update docs for release/2.5 (#7267) (#7277) * Update docs for release/2.5 * Update English docs for release/2.5 - Update README_EN.md: add v2.5 news entry, reformat v2.4 entry with release link - Update docs/get_started/installation/nvidia_gpu.md: - Docker image: 2.4.0 -> 2.5.0, notice now shows SM80/86/89/90 support - paddlepaddle-gpu: 3.3.0 -> 3.3.1, add CUDA 12.9 alternatives - fastdeploy-gpu: 2.4.0 -> 2.5.0, unified arch install with CUDA 12.9 option - Update docs/zh/get_started/installation/nvidia_gpu.md: - Fix remaining paddlepaddle-gpu==3.3.0 refs in sections 4&5 -> 3.3.1 Agent-Logs-Url: https://github.com/PaddlePaddle/FastDeploy/sessions/fa0be381-324e-4b0d-b7a6-e2c1fa12174f * Clarify --extra-index-url usage in installation docs Add note explaining that --extra-index-url is only for downloading fastdeploy-gpu dependencies; fastdeploy-gpu itself must be installed from the Paddle source specified by -i. Applied to both Chinese and English nvidia_gpu.md installation guides. Agent-Logs-Url: https://github.com/PaddlePaddle/FastDeploy/sessions/9fa8b3c9-7555-4eae-b9b9-026cddd7e74c * Update nvidia_gpu.md --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: jiang-jia-jun Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- README_CN.md | 4 +- README_EN.md | 4 +- docs/get_started/installation/nvidia_gpu.md | 45 ++++++++++--------- .../zh/get_started/installation/nvidia_gpu.md | 45 ++++++++++--------- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/README_CN.md b/README_CN.md index 4d110b8d830..33642687062 100644 --- a/README_CN.md +++ b/README_CN.md @@ -27,7 +27,9 @@ ## 最新活动 -**[2026-01] FastDeploy v2.4 全新发布!** 新增 DeepSeek V3 与 Qwen3-MoE 模型的 PD 分离部署,增强MTP 投机解码能力,全面优化多硬件平台上的 MoE 推理与多模态前缀缓存性能,升级全部内容参阅 [v2.4 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0)。 +**[2026-03] FastDeploy v2.5 全新发布!** 新增Qwen3-VL与Qwen3-VL MoE模型部署支持,新增W4AFP8量化方法,增强强化学习训练支持能力,包含170+项Bug修复与性能优化,升级全部内容参阅 [v2.5 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.5.0)。 + +**[2026-01] FastDeploy v2.4**: 新增 DeepSeek V3 与 Qwen3-MoE 模型的 PD 分离部署,增强MTP 投机解码能力,全面优化多硬件平台上的 MoE 推理与多模态前缀缓存性能,升级全部内容参阅 [v2.4 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0)。 **[2025-11] FastDeploy v2.3**: 新增[ERNIE-4.5-VL-28B-A3B-Thinking](docs/zh/get_started/ernie-4.5-vl-thinking.md)与[PaddleOCR-VL-0.9B](docs/zh/best_practices/PaddleOCR-VL-0.9B.md)两大重磅模型在多硬件平台上的部署支持,进一步优化全方位推理性能,以及带来更多部署功能和易用性的提升,升级全部内容参阅[v2.3 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0)。 diff --git a/README_EN.md b/README_EN.md index 4d918455d5f..72c8cf1a1ac 100644 --- a/README_EN.md +++ b/README_EN.md @@ -27,7 +27,9 @@ English | [简体中文](README_CN.md) ## News -[2026-01] FastDeploy v2.4 is released! Featuring PD-separated deployment for DeepSeek V3 and Qwen3-MoE, enhanced MTP speculative decoding, and comprehensive performance boosts for MoE inference and multi-modal Prefix Caching across various hardware backends. See the full v2.4 ReleaseNote for more details. +**[2026-03] FastDeploy v2.5 is released!** It adds deployment support for Qwen3-VL and Qwen3-VL MoE models, introduces the W4AFP8 quantization method, enhances reinforcement learning training capabilities, and includes 170+ bug fixes and performance optimizations. For all the upgrade details, refer to the [v2.5 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.5.0). + +**[2026-01] FastDeploy v2.4**: Featuring PD-separated deployment for DeepSeek V3 and Qwen3-MoE, enhanced MTP speculative decoding, and comprehensive performance boosts for MoE inference and multi-modal Prefix Caching across various hardware backends. For all the upgrade details, refer to the [v2.4 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0). **[2025-11] FastDeploy v2.3**: It adds deployment support for two major models, [ERNIE-4.5-VL-28B-A3B-Thinking](docs/get_started/ernie-4.5-vl-thinking.md) and [PaddleOCR-VL-0.9B](docs/best_practices/PaddleOCR-VL-0.9B.md), across multiple hardware platforms. It further optimizes comprehensive inference performance and brings more deployment features and usability enhancements. For all the upgrade details, refer to the [v2.3 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0). diff --git a/docs/get_started/installation/nvidia_gpu.md b/docs/get_started/installation/nvidia_gpu.md index c59467175de..cc7f8caffd3 100644 --- a/docs/get_started/installation/nvidia_gpu.md +++ b/docs/get_started/installation/nvidia_gpu.md @@ -12,10 +12,10 @@ The following installation methods are available when your environment meets the ## 1. Pre-built Docker Installation (Recommended) -**Notice**: The pre-built image only supports SM80/90 GPU(e.g. H800/A800),if you are deploying on SM86/89GPU(L40/4090/L20), please reinstall ```fastdeploy-gpu``` after you create the container. +**Notice**: The pre-built image supports SM 80/86/89/90 architecture GPUs (e.g. A800/H800/L20/L40/4090). ```shell -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.4.0 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 ``` ## 2. Pre-built Pip Installation @@ -23,30 +23,33 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12 First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) ```shell # Install stable release -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.6 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.9 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ # Install latest Nightly build +# CUDA 12.6 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ +# CUDA 12.9 +python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ ``` -Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead: +Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead (supports SM80/86/89/90 GPU architectures). -For SM80/90 architecture GPUs(e.g A30/A100/H100/): +**Note**: Stable FastDeploy release pairs with stable PaddlePaddle; Nightly Build FastDeploy pairs with Nightly Build PaddlePaddle. The `--extra-index-url` is only used for downloading fastdeploy-gpu's dependencies; fastdeploy-gpu itself must be installed from the Paddle source specified by `-i`. ``` -# Install stable release -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Install latest Nightly build -python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple -``` - -For SM86/89 architecture GPUs(e.g A10/4090/L20/L40): -``` -# Install stable release -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Install latest Nightly build +# Install stable release FastDeploy +# CUDA 12.6 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# Install Nightly Build FastDeploy +# CUDA 12.6 python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ``` ## 3. Build from Source Using Docker @@ -64,10 +67,8 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu . First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) ```shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` - -Then clone the source code and build: ```shell git clone https://github.com/PaddlePaddle/FastDeploy cd FastDeploy @@ -92,7 +93,7 @@ First, install paddlepaddle-gpu. For detailed instructions, please refer to the [PaddlePaddle Installation Guide](https://www.paddlepaddle.org.cn/). ```shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` Then, clone the FastDeploy repository and build using the precompiled operator wheels: diff --git a/docs/zh/get_started/installation/nvidia_gpu.md b/docs/zh/get_started/installation/nvidia_gpu.md index 8e989db5368..004216c6133 100644 --- a/docs/zh/get_started/installation/nvidia_gpu.md +++ b/docs/zh/get_started/installation/nvidia_gpu.md @@ -14,10 +14,10 @@ ## 1. 预编译Docker安装(推荐) -**注意**: 如下镜像仅支持SM 80/90架构GPU(A800/H800等),如果你是在L20/L40/4090等SM 86/89架构的GPU上部署,请在创建容器后,卸载```fastdeploy-gpu```再重新安装如下文档指定支持86/89架构的`fastdeploy-gpu`包。 +**注意**: 预编译镜像支持 80/86/89/90 架构的GPU硬件 (如 A800/H800/L20/L40/4090)。 ``` shell -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.4.0 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 ``` ## 2. 预编译Pip安装 @@ -26,32 +26,33 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12 ``` shell # Install stable release -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.6 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.9 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ # Install latest Nightly build +# CUDA 12.6 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ +# CUDA 12.9 +python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ ``` -再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装 +再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装(目前支持80/86/89/90四个架构GPU) -如你的 GPU 是 SM80/90 架构(A100/H100等),按如下方式安装 - -``` -# 安装稳定版本fastdeploy -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# 安装Nightly Build的最新版本fastdeploy -python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +**注意**: 稳定版本的FastDeploy搭配稳定版本的PaddlePaddle; 而Nightly Build的FastDeploy则对应Nightly Build的PaddlePaddle。其中 `--extra-index-url` 仅用于安装 fastdeploy-gpu 所需的依赖包,fastdeploy-gpu 本身必须从 `-i` 指定的 Paddle 源安装。 ``` - -如你的 GPU 是 SM86/89 架构(4090/L20/L40等),按如下方式安装 - -``` -# 安装稳定版本fastdeploy -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# 安装Nightly Build的最新版本fastdeploy +# 安装稳定版本FastDeploy +# CUDA 12.6 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# 安装Nightly Build版本FastDeploy +# CUDA 12.6 python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ``` ## 3. 镜像自行构建 @@ -70,7 +71,7 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu . 首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/) ``` shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` 接着克隆源代码,编译安装 @@ -98,7 +99,7 @@ FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码 首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/) ``` shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` 接着克隆源代码,拉取 whl 包并安装 From dea9d3517172f43f443243c47eb60d97210e7588 Mon Sep 17 00:00:00 2001 From: fxyfxy777 <137464345+fxyfxy777@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:37:42 +0800 Subject: [PATCH 015/143] [OP]Unify MoE op with moe_permute path for bf16 GLM (#7164) (#7279) --- custom_ops/gpu_ops/cpp_extensions.cc | 4 +- custom_ops/gpu_ops/moe/deepgemm_preprocess.cu | 70 ++++-- .../layers/moe/fused_moe_cutlass_backend.py | 200 +++++++++++---- .../layers/moe/fused_moe_deepgemm_backend.py | 4 +- .../layers/test_fused_moe_cutlass_backend.py | 235 ++++++++++++++++++ 5 files changed, 444 insertions(+), 69 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 40898434bf1..8e9cf6a3ddc 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -537,7 +537,9 @@ std::vector TextImageGatherScatter( const bool is_scatter); std::vector count_tokens_per_expert_func( - const paddle::Tensor& topk_ids, int64_t num_experts); + const paddle::Tensor& topk_ids, + int64_t num_experts, + bool compute_padded_cumsum = false); void GetPositionIdsAndMaskEncoderBatch( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, diff --git a/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu b/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu index 6eda3598cd3..4316aa5cbda 100644 --- a/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu +++ b/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu @@ -15,10 +15,11 @@ #include "helper.h" #include "paddle/extension.h" -template +template __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids, int32_t *__restrict__ res, int32_t *__restrict__ res_padded, + int32_t *__restrict__ res_padded_cumsum, size_t numel, int num_experts) { extern __shared__ int32_t tokens_per_ep[]; @@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids, __syncthreads(); - for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) { - res[i] = tokens_per_ep[i]; - res_padded[i] = (res[i] + 127) / 128 * 128; + if constexpr (kComputeCumsum) { + if (threadIdx.x == 0) { + int32_t running_sum = 0; + for (int i = 0; i < num_experts; i++) { + int32_t count = tokens_per_ep[i]; + int32_t padded = (count + 127) / 128 * 128; + res[i] = count; + res_padded[i] = padded; + running_sum += padded; + res_padded_cumsum[i] = running_sum; + } + } + } else { + for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) { + res[i] = tokens_per_ep[i]; + res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128; + } } } std::vector count_tokens_per_expert_func( - const paddle::Tensor &topk_ids, int64_t num_experts) { + const paddle::Tensor &topk_ids, + int64_t num_experts, + bool compute_padded_cumsum) { int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1]; + int64_t num_rows = compute_padded_cumsum ? 3 : 2; auto token_nums_per_expert = paddle::empty( - {2, num_experts}, paddle::DataType::INT32, topk_ids.place()); + {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place()); auto stream = topk_ids.stream(); using scalar_t = int64_t; - // CUDA_CHECK(cudaGetLastError()); - cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>( - topk_ids.data(), - token_nums_per_expert.data(), - token_nums_per_expert.data() + num_experts, - topk_ids_numel, - num_experts); + if (compute_padded_cumsum) { + cuda_kernel + <<<1, 1024, num_experts * sizeof(int32_t), stream>>>( + topk_ids.data(), + token_nums_per_expert.data(), + token_nums_per_expert.data() + num_experts, + token_nums_per_expert.data() + 2 * num_experts, + topk_ids_numel, + num_experts); + } else { + cuda_kernel + <<<1, 1024, num_experts * sizeof(int32_t), stream>>>( + topk_ids.data(), + token_nums_per_expert.data(), + token_nums_per_expert.data() + num_experts, + nullptr, + topk_ids_numel, + num_experts); + } - // CUDA_CHECK(cudaGetLastError()); return {token_nums_per_expert}; } std::vector count_tokens_per_expert_func_infer_dtype( - const paddle::DataType &topk_ids_dtype, int64_t num_experts) { + const paddle::DataType &topk_ids_dtype, + int64_t num_experts, + bool compute_padded_cumsum) { return {paddle::DataType::INT32}; } std::vector> count_tokens_per_expert_func_infer_shape( - const std::vector &topk_ids_shape, int64_t num_experts) { - return {{2, num_experts}}; + const std::vector &topk_ids_shape, + int64_t num_experts, + bool compute_padded_cumsum) { + int64_t num_rows = compute_padded_cumsum ? 3 : 2; + return {{num_rows, num_experts}}; } PD_BUILD_STATIC_OP(count_tokens_per_expert_func) .Inputs({"topk_ids"}) .Outputs({"token_nums_per_expert"}) - .Attrs({"num_experts:int64_t"}) + .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"}) .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func)) .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape)) .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype)); diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index d36d2d3b6e7..f927cd8c5ee 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -28,7 +28,11 @@ from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce + from fastdeploy.model_executor.ops.gpu import ( + count_tokens_per_expert_func, + moe_expert_dispatch, + moe_expert_reduce, + ) try: from fastdeploy.model_executor.ops.gpu import ( @@ -126,6 +130,7 @@ def apply_ep_prefill( # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) # 2. EP Dispatch + dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {} ( recv_x, recv_topk_idx, @@ -133,7 +138,7 @@ def apply_ep_prefill( recv_num_tokens_per_expert_list, handle, event, - ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights) + ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs) if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) @@ -146,54 +151,91 @@ def apply_ep_prefill( # 3. Compute ffn if token_all_num > 0: logger.debug(f"token_all_num {token_all_num}") - ( - permute_input, - permute_indices_per_token, - recv_num_tokens_per_expert_list_cumsum, - dst_weights, - dst_indices, - cumsum_idx_gpu, - expert_idx_per_token, - dequant_scale, - ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch( - recv_x, - recv_topk_idx, - recv_topk_weights, - (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), - recv_num_tokens_per_expert_list, - token_all_num, - self.moe_quant_type, - ) - if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": - # only w4a8 and w4afp8 need expert_idx_per_token - # Other need not this tensor, so we make it None. - expert_idx_per_token = None + + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": + # --- moe_permute / moe_unpermute path --- + recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32) + (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute( + hidden_states=recv_x, + scale=None, + expert_routemap_topk=recv_topk_idx_i32, + expert_prob_topk=recv_topk_weights, + num_experts=layer.num_local_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=token_all_num, + ) + + token_nums_per_expert_cumsum = count_tokens_per_expert_func( + recv_topk_idx, layer.num_local_experts, True + )[2].cast(paddle.int64) + ffn_out = self.compute_ffn( + layer, + permute_input, + token_nums_per_expert_cumsum, + None, + False, + -1, + None, + None, + ) + + tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=recv_topk_idx_i32, + token_prob_unzipped=dst_weights, + total_zipped_tokens=recv_x.shape[0], + num_experts=layer.num_local_experts, + using_weighted_combine=True, + ) else: - expert_idx_per_token = expert_idx_per_token.cast("int64") + # --- original ep_moe_expert_dispatch / combine path --- + ( + permute_input, + permute_indices_per_token, + recv_num_tokens_per_expert_list_cumsum, + dst_weights, + dst_indices, + cumsum_idx_gpu, + expert_idx_per_token, + dequant_scale, + ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch( + recv_x, + recv_topk_idx, + recv_topk_weights, + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), + recv_num_tokens_per_expert_list, + token_all_num, + self.moe_quant_type, + ) + if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": + expert_idx_per_token = None + else: + expert_idx_per_token = expert_idx_per_token.cast("int64") - if hasattr(layer, "up_gate_proj_in_scale"): - dequant_scale = None + if hasattr(layer, "up_gate_proj_in_scale"): + dequant_scale = None - ffn_out = self.compute_ffn( - layer, - permute_input, - recv_num_tokens_per_expert_list_cumsum, - expert_idx_per_token, - False, - -1, - dequant_scale, - ) + ffn_out = self.compute_ffn( + layer, + permute_input, + recv_num_tokens_per_expert_list_cumsum, + expert_idx_per_token, + False, + -1, + dequant_scale, + ) - # prmt back per rank - tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( - ffn_out, - dst_weights, - permute_indices_per_token, - dst_indices, - None, # down_proj_bias, - False, # norm_topk_prob - 1.0, - ) + tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( + ffn_out, + dst_weights, + permute_indices_per_token, + dst_indices, + None, # down_proj_bias, + False, # norm_topk_prob + 1.0, + ) else: tmp_ffn_out = recv_x @@ -276,6 +318,69 @@ def apply_tp( """ gate_out = gate(x) gate_out = gate_out.cast("float32") + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": + if layer.topk_method == "noaux_tc": + gate_out, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + else: + topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + topk_idx_i32 = topk_idx.astype(paddle.int32) + override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) + (permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap + paddle.nn.functional.moe_permute( + hidden_states=x, + scale=None, + expert_routemap_topk=topk_idx_i32, + expert_prob_topk=topk_weights, + num_experts=layer.num_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=override_buffer_size, + ) + ) + + # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert. + token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast( + paddle.int64 + ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + + ffn_out = self.compute_ffn( + layer, + permute_input, + token_nums_per_expert_cumsum, + None, # expert_idx_per_token not needed for w16a16 without bias + False, + -1, + None, # dequant_scale + None, # max_tokens_per_expert + ) + + fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=topk_idx_i32, + token_prob_unzipped=dst_weights, + total_zipped_tokens=x.shape[0], + num_experts=layer.num_experts, + using_weighted_combine=True, + ) + return fused_moe_out + if layer.topk_method == "noaux_tc": gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, @@ -287,6 +392,7 @@ def apply_tp( getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) + ( permute_input, token_nums_per_expert, @@ -341,7 +447,6 @@ def apply_tp( expert_idx_per_token = None else: expert_idx_per_token = expert_idx_per_token.cast("int64") - ffn_out = self.compute_ffn( layer, permute_input, @@ -363,7 +468,6 @@ def apply_tp( norm_topk_prob=False if layer.topk_method == "noaux_tc" else True, routed_scaling_factor=1.0, ) - return fused_moe_out diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index a4a8d831e26..53247e29126 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -521,7 +521,7 @@ def apply_ep_prefill( using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) else: - token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts) + token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False) ( permute_input, permute_scale, @@ -805,7 +805,7 @@ def apply_tp( using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) else: - tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts) + tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False) ( permute_input, permute_scale, diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 0c9ecc9f6eb..a476744b51b 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -35,6 +35,7 @@ iluvatar_stub.prefill_fused_paged_attention = lambda *args, **kwargs: None sys.modules["fastdeploy.model_executor.ops.iluvatar"] = iluvatar_stub +import fastdeploy # noqa: E402 from fastdeploy.model_executor.layers import utils as layer_utils from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend @@ -709,3 +710,237 @@ def test_weight_only_prequanted_and_int4_create(self): int4_method.create_weights( int4_layer, num_experts=2, hidden_size=4, moe_intermediate_size=2, model_format="paddle" ) + + +# --------------------------------------------------------------------------- +# Real-op tests for FD_USE_PHI_MOE_PERMUTE=True (w16a16, moe_permute path) +# --------------------------------------------------------------------------- + +from fastdeploy.platforms import current_platform # noqa: E402 + +_CUDA_AVAILABLE = current_platform.is_cuda() +requires_cuda = pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA required") + + +class RealMoELayer(paddle.nn.Layer): + """Minimal bf16 MoE layer with real weights for moe_permute path testing.""" + + def __init__(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2): + super().__init__() + self.fd_config = DummyFDConfig() + self.num_experts = num_experts + self.num_local_experts = num_experts + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.top_k = top_k + self.topk_method = "noaux_tc" + self.n_group = 1 + self.topk_group = 1 + self.routed_scaling_factor = 1.0 + self.with_bias = False + self.ep_size = 1 + self.ep_rank = 0 + self.layer_idx = 0 + self.weight_dtype = "bfloat16" + self.is_quantized = False + self.activation = "swiglu" + self.moe_quant_config = types.SimpleNamespace(moe_dynamic_quant=False, hadamard_block_size=128) + self.gate_correction_bias = self.create_parameter( + shape=[1, num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + paddle.seed(0) + self.up_gate_proj_weight = self.create_parameter( + shape=[num_experts, 2 * moe_intermediate_size, hidden_size], + dtype="bfloat16", + ) + self.down_proj_weight = self.create_parameter( + shape=[num_experts, hidden_size, moe_intermediate_size], + dtype="bfloat16", + ) + self.up_gate_proj_weight.set_value( + paddle.randn([num_experts, 2 * moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01 + ) + self.down_proj_weight.set_value( + paddle.randn([num_experts, hidden_size, moe_intermediate_size]).cast("bfloat16") * 0.01 + ) + + +class SimpleLinearGate(paddle.nn.Layer): + def __init__(self, hidden_size, num_experts): + super().__init__() + self.weight = self.create_parameter(shape=[hidden_size, num_experts], dtype="float32") + + def forward(self, x): + return paddle.matmul(x.cast("float32"), self.weight) + + +class TestMoePermuteTrueRealOps: + """Real-op tests for FD_USE_PHI_MOE_PERMUTE=True on the w16a16 path.""" + + def _build(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2): + layer = RealMoELayer( + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + top_k=top_k, + ) + gate = SimpleLinearGate(hidden_size, num_experts) + method = backend.CutlassMoEMethod(None) + method.moe_quant_type = "w16a16" + return layer, gate, method + + @requires_cuda + def test_apply_tp_moe_permute_real_ops(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16: real moe_permute/moe_unpermute/ + count_tokens_per_expert_func/moe_expert_ffn all called end-to-end.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + num_tokens, hidden_size = 8, 64 + layer, gate, method = self._build(hidden_size=hidden_size) + + paddle.seed(42) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Spy: confirm moe_permute is called, moe_expert_dispatch is NOT + permute_called = {"v": False} + dispatch_called = {"v": False} + original_permute = paddle.nn.functional.moe_permute + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + monkeypatch.setattr( + backend, + "moe_expert_dispatch", + lambda *a, **kw: (_ for _ in ()).throw(AssertionError("moe_expert_dispatch must not be called")), + ) + + out = method.apply_tp(layer, x, gate) + + assert permute_called["v"], "moe_permute was not called" + assert not dispatch_called["v"], "moe_expert_dispatch must not be called" + assert list(out.shape) == [num_tokens, hidden_size], f"wrong output shape: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" + + @requires_cuda + def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute / + moe_unpermute / count_tokens_per_expert_func / moe_expert_ffn end-to-end. + The EP dispatch/combine are stubbed (no real NCCL needed). + Use num_tokens=128 and num_experts=4 so each expert gets exactly 64 tokens + (128 * top_k=2 / 4 experts = 64), satisfying moe_expert_ffn alignment.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + # 128 tokens, top_k=2, 4 experts → 64 tokens/expert (128-aligned after padding) + num_tokens, hidden_size = 128, 64 + layer, gate, method = self._build(num_experts=4, hidden_size=hidden_size, top_k=2) + + paddle.seed(42) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Stub only the EP communication runner (dispatch/combine). + # All on-device compute (moe_permute, moe_expert_ffn, moe_unpermute) runs for real. + class StubEPRunner: + ep_engine = types.SimpleNamespace(async_finish=False) + + def moe_select(self, _layer, gate_out): + n = gate_out.shape[0] + # Route token i to experts (i % E) and ((i+1) % E) so all experts + # get tokens and recv_num_tokens_per_expert_list is accurate. + E = _layer.num_local_experts + idx0 = paddle.arange(n, dtype="int64") % E + idx1 = (paddle.arange(n, dtype="int64") + 1) % E + topk_ids = paddle.stack([idx0, idx1], axis=1) + topk_weights = paddle.ones([n, _layer.top_k], dtype="float32") / _layer.top_k + return topk_ids, topk_weights + + def dispatch(self, x, topk_idx, topk_weights, **kwargs): + # Pass tensors through unchanged — single-rank, no real communication. + # Compute accurate recv_num_tokens_per_expert_list from topk_idx. + E = layer.num_local_experts + counts = [int((topk_idx == e).sum().item()) for e in range(E)] + return ( + x, + topk_idx, + topk_weights, + counts, + object(), + types.SimpleNamespace(current_stream_wait=lambda: None), + ) + + def combine(self, ffn_out, handle, recv_topk_weights): + return ffn_out, types.SimpleNamespace(current_stream_wait=lambda: None) + + method.ep_prefill_runner = StubEPRunner() + + # Spy: confirm moe_permute is called inside ep_prefill + permute_called = {"v": False} + original_permute = paddle.nn.functional.moe_permute + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + + out = method.apply_ep_prefill(layer, x, gate) + + assert permute_called["v"], "moe_permute was not called in ep_prefill path" + assert len(out.shape) == 2, f"wrong output ndim: {out.shape}" + assert out.shape[1] == hidden_size, f"wrong hidden_size: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" + + @requires_cuda + def test_apply_tp_moe_permute_non_noaux_tc(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16 + topk_method != 'noaux_tc': + the else-branch calls moe_topk_select instead of get_moe_scores, + then proceeds through moe_permute / moe_expert_ffn / moe_unpermute.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + num_tokens, hidden_size = 8, 64 + layer, gate, method = self._build(hidden_size=hidden_size) + # Switch to non-noaux_tc to exercise the else-branch (moe_topk_select) + layer.topk_method = "greedy" + + paddle.seed(7) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Spy on which routing function is invoked + get_moe_scores_called = {"v": False} + moe_topk_select_called = {"v": False} + permute_called = {"v": False} + + original_get_moe_scores = backend.get_moe_scores + original_moe_topk_select = fastdeploy.model_executor.ops.gpu.moe_topk_select + original_permute = paddle.nn.functional.moe_permute + + def spy_get_moe_scores(*args, **kwargs): + get_moe_scores_called["v"] = True + return original_get_moe_scores(*args, **kwargs) + + def spy_moe_topk_select(*args, **kwargs): + moe_topk_select_called["v"] = True + return original_moe_topk_select(*args, **kwargs) + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(backend, "get_moe_scores", spy_get_moe_scores) + monkeypatch.setattr(fastdeploy.model_executor.ops.gpu, "moe_topk_select", spy_moe_topk_select) + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + + out = method.apply_tp(layer, x, gate) + + assert not get_moe_scores_called["v"], "get_moe_scores must NOT be called for non-noaux_tc" + assert moe_topk_select_called["v"], "moe_topk_select must be called for non-noaux_tc" + assert permute_called["v"], "moe_permute must be called" + assert list(out.shape) == [num_tokens, hidden_size], f"wrong shape: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" From dd0863b07680e455f1288c38446be235aa9059d8 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 10 Apr 2026 13:54:02 +0800 Subject: [PATCH 016/143] [BugFix] Fix Async D2H copy bug & flash mash atten cache V out of bound bug (#7221) (#7296) Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com> --- .../get_block_shape_and_split_kv_block.cu | 8 ++++---- .../gpu_ops/append_attn/pre_cache_len_concat.cu | 4 ++-- .../gpu_ops/flash_mask_attn/mainloop_attn.hpp | 17 +++++++++++++++++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index f94e8493f7f..d61aa3c2313 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -296,7 +296,7 @@ void GetBlockShapeAndSplitKVBlock( if (!phi::backends::gpu::IsCUDAGraphCapturing()) #endif max_len_tensor_cpu.copy_( - max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + max_len_tensor_gpu, max_len_tensor_cpu.place(), true); auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -378,7 +378,7 @@ void GetBlockShapeAndSplitKVBlock( if (!phi::backends::gpu::IsCUDAGraphCapturing()) #endif decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), true); } } // mla_backend not need run the following code. @@ -409,7 +409,7 @@ void GetBlockShapeAndSplitKVBlock( block_size); kv_num_blocks_x_cpu.copy_( - kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true); // Clear buffer const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); @@ -433,7 +433,7 @@ void GetBlockShapeAndSplitKVBlock( encoder_block_shape_q, group_size); encoder_num_blocks_x_cpu.copy_( - encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); + encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), true); } } diff --git a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu index 492b3a26647..435c87ba4bb 100644 --- a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu +++ b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu @@ -87,9 +87,9 @@ std::vector PreCacheLenConcat( bsz, block_size); paddle::Tensor pre_cache_num_blocks_cpu = - pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false); + pre_cache_num_blocks.copy_to(paddle::CPUPlace(), true); paddle::Tensor kv_token_num_cpu = - kv_token_num.copy_to(paddle::CPUPlace(), false); + kv_token_num.copy_to(paddle::CPUPlace(), true); return { cu_seqlens_k, diff --git a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp index cb76da20d6a..277ed46f851 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp @@ -490,6 +490,23 @@ struct CollectiveMainloopAttn { softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); + if (seq_len_k - n_block * kBlockN < kBlockN) { + int valid_k = seq_len_k - n_block * kBlockN; + auto sVt_this = sVt(_, _, smem_pipe_read_v.index()); + constexpr int kHdLo = decltype(get<0, 0>(shape(sVt_this)))::value; + constexpr int kHdHi = decltype(get<0, 1>(shape(sVt_this)))::value; + if (thread_idx >= valid_k && thread_idx < kBlockN) { +#pragma unroll + for (int hd_hi = 0; hd_hi < kHdHi; ++hd_hi) { +#pragma unroll + for (int hd_lo = 0; hd_lo < kHdLo; ++hd_lo) { + sVt_this(make_coord(make_coord(hd_lo, hd_hi), thread_idx)) = + Element(0); + } + } + } + cutlass::arch::fence_view_async_shared(); + } gemm( tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); From 4f36346e14c517f1b27c82495d3b92a47fed8948 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:03:00 +0800 Subject: [PATCH 017/143] [Cherry-Pick] change rms norm for glm #7269 (#7276) * fix * refine code * refine code * refine code * refine code * refine code --- fastdeploy/envs.py | 2 ++ fastdeploy/model_executor/models/glm4_moe.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0c7ac3e22b1..d918a2e4648 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -217,6 +217,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), # Whether to use phi MOE permute,if 1,use paddle op. "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), + # Whether to use phi rms_norm,if 1,use paddle op. + "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Reserve output blocks for decoding requests when schedule new prefill requests diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index f1927fea5d2..7840107a046 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -25,6 +25,7 @@ from paddleformers.transformers import PretrainedModel from paddleformers.utils.log import logger +import fastdeploy from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -266,6 +267,14 @@ def forward( return output +def rms_norm_func(x, weight, eps): + rms_norm_out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps) + if isinstance(rms_norm_out, (tuple, list)): + return rms_norm_out[0].astype(weight.dtype) + else: + return rms_norm_out.astype(weight.dtype) + + class Glm4MoeDecoderLayer(nn.Layer): """ """ @@ -319,8 +328,11 @@ def forward( residual: paddle.Tensor = None, ): """ """ + + proxy_rmsnorm = rms_norm_func if fastdeploy.envs.FD_USE_PHI_RMSNORM else None + hidden_states, residual = self.input_layernorm( - hidden_states, residual_input=residual, forward_meta=forward_meta + hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm ) hidden_states = self.self_attn( @@ -329,7 +341,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, proxy_rmsnorm=proxy_rmsnorm) hidden_states = self.mlp(hidden_states, forward_meta) From c7560383ab79569da027bcb1888cfd4e64aa3f9b Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:10:31 +0800 Subject: [PATCH 018/143] [Cherry-Pick][FDConfig] Auto-scale CUDA Graph Capture & CLI Quantization Params + CUDAGraph Validation (#7215,#7281) (#7301) * refactor cudagraph args * refactor quant cli param * fix * fix * tmp skip xpu * fix --- fastdeploy/config.py | 111 +++++++++--------- fastdeploy/engine/args_utils.py | 13 +- .../cudagraph_piecewise_backend.py | 45 ++++++- .../layers/quantization/__init__.py | 53 ++++++++- fastdeploy/worker/gpu_model_runner.py | 3 +- fastdeploy/worker/gpu_worker.py | 32 +++-- 6 files changed, 170 insertions(+), 87 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 8b92a138a34..6e7001bc18c 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1058,6 +1058,7 @@ def __init__( - None (default): capture sizes are inferred from llm config. - list[int]: capture sizes are specified as given.""" self.cudagraph_capture_sizes: Optional[list[int]] = None + self.flag_cudagraph_capture_sizes_initlized = False self.cudagraph_capture_sizes_prefill: list[int] = [1, 2, 4, 8] """ Number of warmup runs for cudagraph. """ self.cudagraph_num_of_warmups: int = 2 @@ -1108,13 +1109,27 @@ def __init__( self.check_legality_parameters() - def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0) -> None: + def init_with_cudagrpah_size( + self, + max_capture_size: int = 0, + max_capture_shape_prefill: int = 0, + num_speculative_tokens: int = 0, + ) -> None: """ Initialize cuda graph capture sizes and pre-compute the mapping from batch size to padded graph size """ # Regular capture sizes - self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] + if num_speculative_tokens != 0: + max_capture_size = max_capture_size * (num_speculative_tokens + 1) + if not self.flag_cudagraph_capture_sizes_initlized and num_speculative_tokens != 0: + self.cudagraph_capture_sizes = [ + size * (num_speculative_tokens + 1) + for size in self.cudagraph_capture_sizes + if (size * (num_speculative_tokens + 1)) <= max_capture_size + ] + else: + self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] self.cudagraph_capture_sizes_prefill = [ size for size in self.cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill ] @@ -1154,24 +1169,41 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_ self.real_shape_to_captured_size_prefill[bs] = end self.real_shape_to_captured_size_prefill[self.max_capture_size_prefill] = self.max_capture_size_prefill + if num_speculative_tokens != 0: + real_bsz_to_captured_size = {} + for capture_size in self.cudagraph_capture_sizes: + dummy_batch_size = int(capture_size / (num_speculative_tokens + 1)) + real_bsz_to_captured_size[dummy_batch_size] = capture_size + + def expand_bsz_map(real_bsz_to_captured_size): + sorted_items = sorted(real_bsz_to_captured_size.items()) + result = {} + prev_bsz = 0 + for curr_bsz, cap in sorted_items: + for bsz in range(prev_bsz + 1, curr_bsz + 1): + result[bsz] = cap + prev_bsz = curr_bsz + return result + + self.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size) + + self.flag_cudagraph_capture_sizes_initlized = True + def _set_cudagraph_sizes( self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0, - dec_token_per_query_per_step: int = 1, ): """ Calculate a series of candidate capture sizes, and then extract a portion of them as the capture list for the CUDA graph based on user input. """ - # Shape [1, 2, 4, 8, 16, ... 120, 128] * dec_token_per_query_per_step - draft_capture_sizes = [i * dec_token_per_query_per_step for i in [1, 2, 4]] + [ - 8 * i * dec_token_per_query_per_step for i in range(1, 17) - ] - # Shape [128, 144, ... 240, 256] * dec_token_per_query_per_step - draft_capture_sizes += [16 * i * dec_token_per_query_per_step for i in range(9, 17)] - # Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step - draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)] + # Shape [1, 2, 4, 8, 16, ... 120, 128] + draft_capture_sizes = [i for i in [1, 2, 4]] + [8 * i for i in range(1, 17)] + # Shape [128, 144, ... 240, 256] + draft_capture_sizes += [16 * i for i in range(9, 17)] + # Shape [256, 288, ... 992, 1024] + draft_capture_sizes += [32 * i for i in range(9, 33)] draft_capture_sizes_prefill = draft_capture_sizes.copy() draft_capture_sizes.append(max_capture_size) @@ -1881,65 +1913,34 @@ def __init__( self.deploy_modality: DeployModality = deploy_modality # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs - if self.speculative_config is not None and self.speculative_config.method in [ - SpecMethod.MTP, - SpecMethod.SUFFIX, - ]: - max_capture_shape = self.scheduler_config.max_num_seqs * ( - self.speculative_config.num_speculative_tokens + 1 - ) - assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios." - self.graph_opt_config.real_bsz_to_captured_size = { - k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1) - } if self.graph_opt_config.cudagraph_only_prefill: max_capture_shape = 512 else: - max_capture_shape = ( - max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape) - ) + max_capture_shape = min(512, max_capture_shape) max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill if self.graph_opt_config.cudagraph_capture_sizes is None: - dec_token_per_query_per_step = ( - self.speculative_config.num_speculative_tokens + 1 - if self.speculative_config is not None and self.speculative_config.method is not None - else 1 - ) self.graph_opt_config._set_cudagraph_sizes( max_capture_size=max_capture_shape, max_capture_shape_prefill=max_capture_shape_prefill, - dec_token_per_query_per_step=dec_token_per_query_per_step, ) - if self.speculative_config is not None and self.speculative_config.method is not None: - real_bsz_to_captured_size = {} - for capture_size in self.graph_opt_config.cudagraph_capture_sizes: - dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1)) - real_bsz_to_captured_size[dummy_batch_size] = capture_size - def expand_bsz_map(real_bsz_to_captured_size): - """ - Expand a sparse batch size mapping into a dense one. - - Args: - real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping. - Returns: - dict: Dense batch size to capture size mapping. - """ - sorted_items = sorted(real_bsz_to_captured_size.items()) - result = {} - prev_bsz = 0 - for curr_bsz, cap in sorted_items: - for bsz in range(prev_bsz + 1, curr_bsz + 1): - result[bsz] = cap - prev_bsz = curr_bsz - return result - - self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size) self.graph_opt_config.init_with_cudagrpah_size( max_capture_size=max_capture_shape, max_capture_shape_prefill=max_capture_shape_prefill, + num_speculative_tokens=( + self.speculative_config.num_speculative_tokens + if ( + self.speculative_config is not None + and self.speculative_config.method + in [ + SpecMethod.MTP, + SpecMethod.SUFFIX, + ] + ) + else 0 + ), ) self.tokenizer = tokenizer diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ff0965c56bb..848926f963c 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -857,11 +857,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--quantization", type=parse_quantization, default=EngineArgs.quantization, - help="Quantization name for the model, currently support " - "'wint8', 'wint4'," - "default is None. The priority of this configuration " - "is lower than that of the config file. " - "More complex quantization methods need to be configured via the config file.", + help="Quantization config for the model. Can be a simple method name " + "(e.g. 'wint8', 'wint4') or a full JSON quantization_config string " + '(e.g. \'{"quantization": "mix_quant", "kv_cache_quant_type": "block_wise_fp8", ' + '"dense_quant_type": "block_wise_fp8", "moe_quant_type": "block_wise_fp8"}\'). ' + "When a JSON config is provided, it is processed the same way as " + "quantization_config in the model's config.json. " + "If both CLI and config.json specify quantization_config, " + "config.json takes higher priority. Default is None.", ) model_group.add_argument( "--graph-optimization-config", diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 526f55c2369..c04f137d10d 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -29,6 +29,7 @@ capture_custom_allreduce, custom_ar_clear_ipc_handles, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import get_logger logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log") @@ -123,9 +124,46 @@ def __init__( self.max_num_seqs = fd_config.scheduler_config.max_num_seqs self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size - def run_static_model(self, entry: ConcreteSizeEntry, **kwargs): + # Expected decode capture sequence (descending), consistent with capture_model() iteration order. + # Used to validate that captures happen in the correct order. + self._decode_expected_sequence: list[int] = sorted(self.cudagraph_capture_sizes, reverse=True) + # Points to the next expected position in _decode_expected_sequence. + self._decode_capture_index: int = 0 + + def _validate_decode_capture_order(self, shape: int) -> None: + """Validate that decode CUDA graph captures happen in expected descending order. + + Raises RuntimeError immediately if the actual capture order deviates from + the order defined by cudagraph_capture_sizes (sorted descending). + """ + if current_platform.is_xpu(): + return + + if self._decode_capture_index >= len(self._decode_expected_sequence): + raise RuntimeError( + f"[CUDA GRAPH][ID:{id(self)}] Unexpected CUDA graph capture: shape={shape}. " + f"All {len(self._decode_expected_sequence)} expected captures have already completed. " + f"Expected sequence: {self._decode_expected_sequence}" + ) + expected = self._decode_expected_sequence[self._decode_capture_index] + if shape != expected: + raise RuntimeError( + f"[CUDA GRAPH][ID:{id(self)}] CUDA graph capture order mismatch at index " + f"{self._decode_capture_index}: expected shape={expected}, got shape={shape}. " + f"Full expected sequence: {self._decode_expected_sequence}" + ) + logger.debug( + f"[CUDA GRAPH][ID:{id(self)}] Capture order validated: shape={shape} matches " + f"expected sequence at index {self._decode_capture_index} " + f"(sequence: {self._decode_expected_sequence})" + ) + self._decode_capture_index += 1 + + def run_static_model(self, entry: ConcreteSizeEntry, is_decode: bool = False, **kwargs): if not entry.captured: + if is_decode: + self._validate_decode_capture_order(entry.real_shape) # Warmup the model for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 @@ -194,13 +232,14 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor: # - Static full graph mode: Dynamic for prefill/mixed, Static + CUDAGraph for decode # - Dynamic mode: Dynamic + CUDAGraph for decode only if static_cudagraph_for_prefill or static_cudagraph_for_decode: - return self.run_static_model(entry, **kwargs) + return self.run_static_model(entry, is_decode=static_cudagraph_for_decode, **kwargs) # Capture a new cuda graph if entry.cuda_graph is None: assert ( real_shape == padding_real_shape ), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph." + self._validate_decode_capture_order(padding_real_shape) # Warmup the model for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 @@ -278,6 +317,8 @@ def clear_graph(self): del self.concrete_size_entries paddle.device.cuda.empty_cache() + self._decode_capture_index = 0 + # Create new entrys self._create_entry_dict() diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index 3e9e34c54ab..2c9992b18b5 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -54,17 +54,56 @@ def _compute_hadamard_block_size(moe_intermediate_size: int, tp_size: int) -> in return block_size +def _is_full_quantization_config(quantization_dict): + """ + Determine whether the parsed quantization dict is a simple method name or a full quantization_config. + Simple method name: {"quantization": "wint4"} (only one key "quantization") + Full config: {"quantization": "mix_quant", "dense_quant_type": "wint8", ...} (multiple keys) + Or torch format: {"quant_method": "fp8", "weight_block_size": [128, 128]} (has "quant_method" key) + """ + if "quant_method" in quantization_dict: + return True + if len(quantization_dict) > 1: + return True + return False + + def parse_quant_config(args, model_config, is_ernie, is_v1_loader): if args.quantization is not None and isinstance(args.quantization, str): args.quantization = parse_quantization(args.quantization) + + # Determine whether CLI --quantization is a simple method name or a full JSON quantization_config + cli_quantization = args.quantization + cli_is_full_config = ( + cli_quantization is not None + and isinstance(cli_quantization, dict) + and _is_full_quantization_config(cli_quantization) + ) + + model_quantization_config = model_config.quantization_config + quantization_config = model_quantization_config + + # If CLI provides a full quantization_config JSON, handle priority with config.json + if cli_is_full_config: + if model_quantization_config is not None: + if model_quantization_config != cli_quantization: + logger.warning( + "The quantization_config from --quantization argument " + "differs from the one in model's config.json. " + "Using config.json's quantization_config as it has higher priority. " + f"config.json: {model_quantization_config}, " + f"--quantization: {cli_quantization}" + ) + else: + # config.json has no quantization_config, use CLI's full config + quantization_config = cli_quantization + # 1.model_config.is_quantized # TODO(bukejiyu) model_config.is_quantized is v0 only need to be removed in future if model_config.model_format == "torch": - quantization_config = model_config.quantization_config if quantization_config is not None: model_config.is_quantized = True else: - quantization_config = model_config.quantization_config if not model_config.is_quantized: if quantization_config is not None: if "is_quantized" in quantization_config: @@ -84,11 +123,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quant_config_name = None - if quantization_config is not None: + if model_quantization_config is not None: quant_config_name = _get_offline_quant_config_name( - quantization_config, model_config.model_format == "torch", is_v1_loader + model_quantization_config, model_config.model_format == "torch", is_v1_loader ) - elif args.quantization is not None: + elif cli_quantization is not None and not cli_is_full_config: quantization_config = {} try: quantization_config.update(args.quantization) @@ -116,8 +155,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quantization_config["hadamard_block_size"] = 512 quantization_config["quantization"] = "mix_quant" quant_config_name = "mix_quant" + elif cli_quantization is not None and cli_is_full_config: + quant_config_name = quantization_config["quantization"] else: quant_config_name = None + if quant_config_name is None: quant_config = None else: @@ -127,6 +169,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quantization_config["is_quantized"] = True quant_cls = get_quantization_config(quant_config_name) quant_config = quant_cls.from_config(quantization_config) + return quant_config diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2454b016209..6218e58687b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1895,8 +1895,7 @@ def capture_model(self) -> None: logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding and self.spec_method == SpecMethod.MTP: - # Capture Target Model without bsz 1 + elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: for capture_size in sorted(capture_sizes, reverse=True): expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 self._dummy_run( diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index aebf3f21111..423d9fb54a5 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -126,14 +126,12 @@ def determine_available_memory(self) -> int: before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) logger.info( - ( - "Before running the profile, the memory usage info is as follows:", - f"\nDevice Total memory: {before_run_meminfo.total / Gb}", - f"\nDevice used memory: {before_run_meminfo.used / Gb}", - f"\nDevice free memory: {before_run_meminfo.free / Gb}", - f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}", - ) + "Before running the profile, the memory usage info is as follows:" + f"\nDevice Total memory: {before_run_meminfo.total / Gb}" + f"\nDevice used memory: {before_run_meminfo.used / Gb}" + f"\nDevice free memory: {before_run_meminfo.free / Gb}" + f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}" + f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}" ) # 2. Profile run @@ -161,16 +159,14 @@ def determine_available_memory(self) -> int: end_time = time.perf_counter() logger.info( - ( - "After running the profile, the memory usage info is as follows:", - f"\nDevice Total memory: {after_run_meminfo.total / Gb}", - f"\nDevice used memory: {after_run_meminfo.used / Gb}", - f"\nDevice free memory: {after_run_meminfo.free / Gb}", - f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}", - f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}", - f"Profile time: {end_time - start_time}", - ) + "After running the profile, the memory usage info is as follows:" + f"\nDevice Total memory: {after_run_meminfo.total / Gb}" + f"\nDevice used memory: {after_run_meminfo.used / Gb}" + f"\nDevice free memory: {after_run_meminfo.free / Gb}" + f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}" + f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}" + f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}" + f"Profile time: {end_time - start_time}" ) return available_kv_cache_memory # return to calculate the block num in this device From 2ac9b894090328ee090ccf5ce082b9afa7aa39ab Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Sat, 11 Apr 2026 00:27:54 +0800 Subject: [PATCH 019/143] [XPU][CI]Update xtdk version in download_dependencies.sh (#7320) (#7322) Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> --- custom_ops/xpu_ops/download_dependencies.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/xpu_ops/download_dependencies.sh b/custom_ops/xpu_ops/download_dependencies.sh index 4aaa95777d0..b927448ccc2 100644 --- a/custom_ops/xpu_ops/download_dependencies.sh +++ b/custom_ops/xpu_ops/download_dependencies.sh @@ -16,7 +16,7 @@ if [ "$1" == "stable" ]; then version_xtdk="3.4.0.1" else version_xvllm="20260407" - version_xtdk="latest" + version_xtdk="3.6.2.1" fi ( From 65c6e726f5795c79100c503c01ead33ff139ddc2 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:48:06 +0800 Subject: [PATCH 020/143] [Cherry-Pick][Docs] Update Release Note(#7302) (#7341) --- dockerfiles/Dockerfile.gpu | 8 ++++---- dockerfiles/Dockerfile.xpu | 6 +++--- docs/get_started/installation/nvidia_gpu.md | 18 ++++++++++-------- docs/zh/get_started/installation/nvidia_gpu.md | 15 +++++++++------ 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/dockerfiles/Dockerfile.gpu b/dockerfiles/Dockerfile.gpu index 5ce8b05b199..4a4240cd76a 100644 --- a/dockerfiles/Dockerfile.gpu +++ b/dockerfiles/Dockerfile.gpu @@ -1,6 +1,6 @@ FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:tag-base -ARG PADDLE_VERSION=3.3.0 -ARG FD_VERSION=2.4.0 +ARG PADDLE_VERSION=3.3.1 +ARG FD_VERSION=2.5.0 ENV DEBIAN_FRONTEND=noninteractive @@ -16,8 +16,8 @@ RUN python -m pip uninstall paddlepaddle-gpu fastdeploy-gpu -y RUN python -m pip install --no-cache-dir paddlepaddle-gpu==${PADDLE_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ # build and install FastDeploy -RUN python -m pip install --no-cache-dir fastdeploy-gpu==${FD_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +RUN python -m pip install --no-cache-dir fastdeploy-gpu==${FD_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ENV http_proxy="" ENV https_proxy="" -ENV no_proxy="" +ENV no_proxy="" \ No newline at end of file diff --git a/dockerfiles/Dockerfile.xpu b/dockerfiles/Dockerfile.xpu index 14998860a12..3b98165dd0c 100644 --- a/dockerfiles/Dockerfile.xpu +++ b/dockerfiles/Dockerfile.xpu @@ -15,7 +15,7 @@ RUN python -m pip uninstall paddlepaddle-gpu paddlepaddle-xpu fastdeploy-xpu -y RUN python -m pip uninstall -y Pillow && rm -rf /usr/local/lib/python3.10/dist-packages/Pillow* && rm -rf /usr/local/lib/python3.10/dist-packages/pillow* && python -m pip install Pillow==11.3.0 # install paddlepaddle-xpu -ARG PADDLE_VERSION=nightly +ARG PADDLE_VERSION=3.3.1 RUN if [ "$PADDLE_VERSION" = "nightly" ]; then \ python -m pip install --no-cache-dir --progress-bar off paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/; \ @@ -26,7 +26,7 @@ RUN if [ "$PADDLE_VERSION" = "nightly" ]; then \ # install fastdeploy-xpu ARG INSTALL_REQUIREMENTS=true ARG INSTALL_FASTDEPLOY=true -ARG FASTDEPLOY_VERSION=2.4.0 +ARG FASTDEPLOY_VERSION=2.5.0 RUN if [ "$INSTALL_FASTDEPLOY" = "true" ]; then \ python -m pip install --no-cache-dir fastdeploy-xpu==${FASTDEPLOY_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple; \ @@ -40,4 +40,4 @@ RUN mkdir -p /workspace/deps && cd /workspace/deps && \ wget https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.21/xre-Linux-x86_64-5.0.21.21.tar.gz && \ tar -zxf xre-Linux-x86_64-5.0.21.21.tar.gz && mv xre-Linux-x86_64-5.0.21.21 xre -ENV PATH=/workspace/deps/xre/bin:$PATH +ENV PATH=/workspace/deps/xre/bin:$PATH \ No newline at end of file diff --git a/docs/get_started/installation/nvidia_gpu.md b/docs/get_started/installation/nvidia_gpu.md index cc7f8caffd3..5a1f1ae2156 100644 --- a/docs/get_started/installation/nvidia_gpu.md +++ b/docs/get_started/installation/nvidia_gpu.md @@ -12,10 +12,13 @@ The following installation methods are available when your environment meets the ## 1. Pre-built Docker Installation (Recommended) -**Notice**: The pre-built image supports SM 80/86/89/90 architecture GPUs (e.g. A800/H800/L20/L40/4090). +**Notice**: The pre-built image supports SM 80/86/89/90 architecture GPUs (e.g. A800/H800/L20/L40/4090), and requires Python 3.10. ```shell +# CUDA 12.6 docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 +# CUDA 12.9 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.9:2.5.0 ``` ## 2. Pre-built Pip Installation @@ -38,7 +41,7 @@ python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/ Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead (supports SM80/86/89/90 GPU architectures). **Note**: Stable FastDeploy release pairs with stable PaddlePaddle; Nightly Build FastDeploy pairs with Nightly Build PaddlePaddle. The `--extra-index-url` is only used for downloading fastdeploy-gpu's dependencies; fastdeploy-gpu itself must be installed from the Paddle source specified by `-i`. -``` +```shell # Install stable release FastDeploy # CUDA 12.6 python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -54,7 +57,7 @@ python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages ## 3. Build from Source Using Docker -- Note: ```dockerfiles/Dockerfile.gpu``` by default supports SM 80/90 architectures. To support other architectures, modify ```bash build.sh 1 python false [80,90]``` in the Dockerfile. It's recommended to specify no more than 2 architectures. +> Note: `dockerfiles/Dockerfile.gpu` currently supports CUDA 12.6 only, targeting SM 80/86/89/90 architectures, and requires Python 3.10. To support other architectures, modify ```bash build.sh 1 python false [80,90]``` in the Dockerfile. It's recommended to specify no more than 2 architectures. ```shell git clone https://github.com/PaddlePaddle/FastDeploy @@ -83,8 +86,7 @@ The built packages will be in the ```FastDeploy/dist``` directory. ## 5. Precompiled Operator Wheel Packages -FastDeploy provides precompiled GPU operator wheel packages for quick setup without building the entire source code. -This method currently supports **SM80/90 architecture (e.g., A100/H100)** and **CUDA 12.6** environments only. +FastDeploy provides precompiled GPU operator wheel packages for quick setup without building the entire source code. This method currently supports **SM80/90 architecture (e.g., A100/H100)** **CUDA 12.6** and **Python 3.10** environments only. > By default, `build.sh` compiles all custom operators from source.To use the precompiled package, enable it with the `FD_USE_PRECOMPILED` parameter. > If the precompiled package cannot be downloaded or does not match the current environment, the system will automatically fall back to `4. Build Wheel from Source`. @@ -113,7 +115,7 @@ cd FastDeploy bash build.sh 1 python false [90] 1 # Use precompiled wheel from a specific commit -bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 +bash build.sh 1 python false [90] 1 d693d4be1448d414097882386fdc24c8bec2a63a ``` The downloaded wheel packages will be stored in the `FastDeploy/pre_wheel` directory. @@ -122,9 +124,9 @@ After the build completes, the operator binaries can be found in `FastDeploy/fas > **Notes:** > > - This mode prioritizes downloading precompiled GPU operator wheels to reduce build time. -> - Currently supports **GPU, SM80/90, CUDA 12.6** only. +> - Currently supports **GPU, SM80/90, CUDA 12.6, Python3.10** only. > - For custom architectures or modified operator logic, please use **source compilation (Section 4)**. -> - You can check whether the precompiled wheel for a specific commit has been successfully built on the [FastDeploy CI Build Status Page](https://github.com/PaddlePaddle/FastDeploy/actions/workflows/ci_image_update.yml). +> - You can check whether the precompiled wheel for a specific commit has been successfully built on the [FastDeploy CI Build Status Page](https://github.com/PaddlePaddle/FastDeploy/actions/workflows/ce_job.yml). ## Environment Verification diff --git a/docs/zh/get_started/installation/nvidia_gpu.md b/docs/zh/get_started/installation/nvidia_gpu.md index 004216c6133..dd266b6c7eb 100644 --- a/docs/zh/get_started/installation/nvidia_gpu.md +++ b/docs/zh/get_started/installation/nvidia_gpu.md @@ -14,10 +14,13 @@ ## 1. 预编译Docker安装(推荐) -**注意**: 预编译镜像支持 80/86/89/90 架构的GPU硬件 (如 A800/H800/L20/L40/4090)。 +**注意**: 预编译镜像支持 80/86/89/90 架构的GPU硬件 (如 A800/H800/L20/L40/4090) 且仅支持 Python 3.10。 ``` shell +# CUDA 12.6 docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 +# CUDA 12.9 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.9:2.5.0 ``` ## 2. 预编译Pip安装 @@ -41,7 +44,7 @@ python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/ 再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装(目前支持80/86/89/90四个架构GPU) **注意**: 稳定版本的FastDeploy搭配稳定版本的PaddlePaddle; 而Nightly Build的FastDeploy则对应Nightly Build的PaddlePaddle。其中 `--extra-index-url` 仅用于安装 fastdeploy-gpu 所需的依赖包,fastdeploy-gpu 本身必须从 `-i` 指定的 Paddle 源安装。 -``` +```shell # 安装稳定版本FastDeploy # CUDA 12.6 python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -57,7 +60,7 @@ python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages ## 3. 镜像自行构建 -> 注意 ```dockerfiles/Dockerfile.gpu``` 默认编译的架构支持SM 80/90,如若需要支持其它架构,需自行修改Dockerfile中的 ```bash build.sh 1 python false [80,90]```,建议不超过2个架构。 +> 注意 ```dockerfiles/Dockerfile.gpu``` 默认编译产物仅支持 SM 80/86/89/90 架构,基于 CUDA 12.6 环境构建,且仅支持 Python 3.10,如若需要支持其它架构,需自行修改Dockerfile中的 ```bash build.sh 1 python false [80,90]```,建议不超过2个架构。 ``` git clone https://github.com/PaddlePaddle/FastDeploy @@ -91,7 +94,7 @@ bash build.sh 1 python false [80,90] ## 5. 算子预编译 Wheel 包 -FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码编译的情况下快速构建。该方式当前仅支持 **SM80/90 架构(A100/H100等)** 和 **CUDA 12.6** 环境。 +FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码编译的情况下快速构建。该方式当前仅支持 **SM80/90 架构(A100/H100等)** **CUDA 12.6** 和 **Python 3.10** 环境。 >默认情况下,`build.sh` 会从源码编译;若希望使用预编译包,可使用`FD_USE_PRECOMPILED` 参数; >若预编译包下载失败或与环境不匹配,系统会自动回退至 `4. wheel 包源码编译` 模式。 @@ -119,7 +122,7 @@ cd FastDeploy bash build.sh 1 python false [90] 1 # 从指定 commitID 获取对应预编译算子 -bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 +bash build.sh 1 python false [90] 1 d693d4be1448d414097882386fdc24c8bec2a63a ``` 下载的 whl 包在 `FastDeploy/pre_wheel`目录下。 @@ -128,7 +131,7 @@ bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 > **说明:** > - 该模式会优先下载预编译的 GPU 算子 whl 包,减少编译时间; -> - 目前仅支持 **GPU, SM80/90 架构, CUDA 12.6**; +> - 目前仅支持 **GPU, SM80/90 架构, CUDA 12.6, Python3.10**; > - 若希望自定义架构或修改算子逻辑,请使用 **源码编译方式(第4节)**。 > - 您可以在 FastDeploy CI 构建状态页面查看对应 commit 的预编译 whl 是否已构建成功。 From 42b0f59b9ebbd5752a23ad301daafd1851892b67 Mon Sep 17 00:00:00 2001 From: JYChen Date: Sat, 11 Apr 2026 18:38:37 +0800 Subject: [PATCH 021/143] [Cherry-Pick][RL] change glm rope_emb calculation #7316 (#7318) * change glm rope_emb calculation * glm without EnforceFmulRN * fix ci --- .../decoder_write_cache_with_rope_kernel.cu | 8 +-- .../encoder_write_cache_with_rope_impl.cuh | 8 +-- .../append_attn/gqa_rope_write_cache.cu | 50 ++++++++++--------- .../speculate_write_cache_with_rope_kernel.cu | 9 ++-- fastdeploy/envs.py | 2 + .../model_executor/layers/rotary_embedding.py | 10 +++- 6 files changed, 49 insertions(+), 38 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index 963ccfa23d9..e25816fcbb3 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, rope_3d); } else { if (rotary_dim < dim_head) { - auto* kernelFn = - append_decode_cache_T_neox_partial_rope_kernel; + auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel< + T, + PackSize, + false>; // GLM use EnforceFmulRN=false launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 0cdea537327..60d5d34bf48 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable( } const int pack_num_new = elem_nums / PackSize; GetNumBlocks<128>(pack_num_new, &grid_size); - auto *kernelFn = - GQANeoxVariableLengthPartialRotaryKernel; + auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel< + T, + PackSize, + false>; // GLM use EnforceFmulRN=false launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index e4d0554fea6..c86ec27dca8 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2; - launchWithPdlWhenEnabled( - GQAVariableLengthNeoxPartialRotarySplitKernel, - grid_size, - block_size, - 0, - stream, - qkv_input, - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - cu_seqlens_k, - qkv_out, - q, - k, - v, - elem_nums, - num_heads, - kv_num_heads, - max_model_len, - head_dim, - rotary_dim); + launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel< + T, + PackSize, + false>, // GLM use EnforceFmulRN=false + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rotary_dim); } template + append_speculate_cache_neox_partial_rope_kernel< + T, + PackSize, + QKV_TYPE, + false> // GLM use EnforceFmulRN=false <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index d918a2e4648..e78db512be6 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -268,6 +268,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Whether to align RoPE and moe gate precision with training + "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), } diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index af7203ed6f1..dd77cf2bc0d 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -20,6 +20,7 @@ import paddle from paddle import nn +from fastdeploy import envs from fastdeploy.config import ModelConfig from fastdeploy.platforms import current_platform @@ -87,8 +88,13 @@ def __init__(self, rotary_dim, base, partial_rotary_factor): def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] - inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) - freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + if envs.FD_ENABLE_RL == 1: + idx = paddle.arange(0, self.rotary_dim, 2, dtype=paddle.int64).astype(paddle.float32) + inv_freq = 1.0 / (self.base ** (idx / self.rotary_dim)) + freqs = paddle.outer(position_ids.astype(inv_freq.dtype), inv_freq) + else: + inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) # shape: [B, S, D/2] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32") emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2)) From 7446665676db34657fb06a265418c73d893e7c08 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Sat, 11 Apr 2026 21:51:26 +0800 Subject: [PATCH 022/143] [Cherry-Pick][RL]moe bf16 ep support paddle batch_gemm(#7337) (#7339) * moe bf16 ep support paddle batch_gemm --- .../layers/moe/fused_moe_cutlass_backend.py | 26 ++++++++++--------- .../layers/test_fused_moe_cutlass_backend.py | 16 ++++++++---- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index f927cd8c5ee..1dc349478c8 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -43,6 +43,7 @@ logger.warning("import w4afp8_gemm_scale_permute Failed!") from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops from fastdeploy.model_executor.utils import ( TensorTracker, free_tensor, @@ -166,18 +167,19 @@ def apply_ep_prefill( override_buffer_size=token_all_num, ) - token_nums_per_expert_cumsum = count_tokens_per_expert_func( - recv_topk_idx, layer.num_local_experts, True - )[2].cast(paddle.int64) - ffn_out = self.compute_ffn( - layer, + out = paddle.incubate.nn.functional.batched_gemm( permute_input, - token_nums_per_expert_cumsum, - None, - False, - -1, - None, - None, + getattr(layer, self.added_weight_attrs[0]), + recv_num_tokens_per_expert_list, + ) + if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE: + out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) + else: + out = paddle.incubate.nn.functional.swiglu(out) + ffn_out = paddle.incubate.nn.functional.batched_gemm( + out, + getattr(layer, self.added_weight_attrs[1]), + recv_num_tokens_per_expert_list, ) tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( @@ -187,7 +189,7 @@ def apply_ep_prefill( token_prob_unzipped=dst_weights, total_zipped_tokens=recv_x.shape[0], num_experts=layer.num_local_experts, - using_weighted_combine=True, + using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE, ) else: # --- original ep_moe_expert_dispatch / combine path --- diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index a476744b51b..0a03ecc62a8 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -40,6 +40,10 @@ from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend +def align(x, y): + return (x + y - 1) // y * y + + class DummyQuantConfig: def __init__(self, algo="weight_only_int8", is_quantized=False, is_checkpoint_bf16=False): self.algo = algo @@ -752,18 +756,18 @@ def __init__(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_ ) paddle.seed(0) self.up_gate_proj_weight = self.create_parameter( - shape=[num_experts, 2 * moe_intermediate_size, hidden_size], + shape=[num_experts, hidden_size, 2 * moe_intermediate_size], dtype="bfloat16", ) self.down_proj_weight = self.create_parameter( - shape=[num_experts, hidden_size, moe_intermediate_size], + shape=[num_experts, moe_intermediate_size, hidden_size], dtype="bfloat16", ) self.up_gate_proj_weight.set_value( - paddle.randn([num_experts, 2 * moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01 + paddle.randn([num_experts, hidden_size, 2 * moe_intermediate_size]).cast("bfloat16") * 0.01 ) self.down_proj_weight.set_value( - paddle.randn([num_experts, hidden_size, moe_intermediate_size]).cast("bfloat16") * 0.01 + paddle.randn([num_experts, moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01 ) @@ -863,7 +867,9 @@ def dispatch(self, x, topk_idx, topk_weights, **kwargs): # Pass tensors through unchanged — single-rank, no real communication. # Compute accurate recv_num_tokens_per_expert_list from topk_idx. E = layer.num_local_experts - counts = [int((topk_idx == e).sum().item()) for e in range(E)] + counts = [ + align(int((topk_idx == e).sum().item()), kwargs.get("expert_alignment", 1)) for e in range(E) + ] return ( x, topk_idx, From 9e8ea7db14f404c59eee7cddccd0cc3d7950cbc0 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Sun, 12 Apr 2026 13:22:52 +0800 Subject: [PATCH 023/143] [Cherry-Pick][CI] Sync dev optimizations to 2.6(#7335) (#7343) --- .github/workflows/_accuracy_test.yml | 34 +++++++++++++++++--- .github/workflows/_base_test.yml | 23 +++++++++++-- .github/workflows/_build_linux.yml | 12 ++++++- .github/workflows/_build_linux_cu129.yml | 12 ++++++- .github/workflows/_build_linux_cu130.yml | 12 ++++++- .github/workflows/_build_linux_fd_router.yml | 12 ++++++- .github/workflows/_build_linux_rl.yml | 13 +++++++- .github/workflows/_golang_router_test.yml | 29 +++++++++++++++-- .github/workflows/_gpu_4cards_case_test.yml | 29 +++++++++++++++-- .github/workflows/_logprob_test_linux.yml | 30 +++++++++++++++-- .github/workflows/_pre_ce_test.yml | 29 +++++++++++++++-- .github/workflows/_stable_test.yml | 29 +++++++++++++++-- .github/workflows/_unit_test_coverage.yml | 33 ++++++++++++++++--- scripts/run_pre_ce.sh | 6 +++- 14 files changed, 272 insertions(+), 31 deletions(-) diff --git a/.github/workflows/_accuracy_test.yml b/.github/workflows/_accuracy_test.yml index 4efb008da17..87994625c58 100644 --- a/.github/workflows/_accuracy_test.yml +++ b/.github/workflows/_accuracy_test.yml @@ -69,12 +69,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -145,7 +160,10 @@ jobs: docker rm -f ${runner_name} || true fi - docker run --rm --ipc=host --pid=host --net=host \ + docker run --rm --net=host \ + --shm-size=64g \ + --sysctl kernel.msgmax=1048576 \ + --sysctl kernel.msgmnb=268435456 \ --name ${runner_name} \ -v $(pwd):/workspace \ -w /workspace \ @@ -160,6 +178,7 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ @@ -204,3 +223,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index b114bad15d4..3eb022725e5 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -81,7 +81,14 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' @@ -111,7 +118,11 @@ jobs: exit 1 fi - tar -xf FastDeploy.tar.gz + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -200,6 +211,7 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ @@ -294,3 +306,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml index 172f07cfd73..1431df353cb 100644 --- a/.github/workflows/_build_linux.yml +++ b/.github/workflows/_build_linux.yml @@ -125,6 +125,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -156,7 +157,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -171,6 +173,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -248,3 +251,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_cu129.yml b/.github/workflows/_build_linux_cu129.yml index 6370268c7cb..61108a82b40 100644 --- a/.github/workflows/_build_linux_cu129.yml +++ b/.github/workflows/_build_linux_cu129.yml @@ -112,6 +112,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -143,7 +144,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -158,6 +160,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -235,3 +238,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_cu129=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_cu130.yml b/.github/workflows/_build_linux_cu130.yml index 278aff6956b..7c2aee69c6f 100644 --- a/.github/workflows/_build_linux_cu130.yml +++ b/.github/workflows/_build_linux_cu130.yml @@ -112,6 +112,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -143,7 +144,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache_cu130:/root/.cache" \ @@ -158,6 +160,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -235,3 +238,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_cu130=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_fd_router.yml b/.github/workflows/_build_linux_fd_router.yml index b600cc2328e..9e93290d509 100644 --- a/.github/workflows/_build_linux_fd_router.yml +++ b/.github/workflows/_build_linux_fd_router.yml @@ -107,6 +107,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy FD_ROUTER Build shell: bash env: @@ -137,7 +138,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -151,6 +153,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -211,3 +214,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" FD_ROUTER_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/fd-router echo "fd_router_path=${FD_ROUTER_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml index 9e809d59a59..1a131adb1a1 100644 --- a/.github/workflows/_build_linux_rl.yml +++ b/.github/workflows/_build_linux_rl.yml @@ -52,6 +52,7 @@ on: wheel_path_rl: description: "Output path of the generated wheel" value: ${{ jobs.fd-build-rl.outputs.wheel_path_rl }} + jobs: fd-build-rl: runs-on: [self-hosted, GPU-Build-RL] @@ -107,6 +108,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -137,7 +139,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache_rl:/root/.cache" \ @@ -151,6 +154,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -202,3 +206,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_rl=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_golang_router_test.yml b/.github/workflows/_golang_router_test.yml index 4964f3a3a05..bbb1bc7799f 100644 --- a/.github/workflows/_golang_router_test.yml +++ b/.github/workflows/_golang_router_test.yml @@ -76,12 +76,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -191,6 +206,7 @@ jobs: -e "fd_router_url=${fd_router_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy @@ -211,3 +227,10 @@ jobs: bash scripts/run_golang_router.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_gpu_4cards_case_test.yml b/.github/workflows/_gpu_4cards_case_test.yml index 02a16b8b93b..5c9a51aa809 100644 --- a/.github/workflows/_gpu_4cards_case_test.yml +++ b/.github/workflows/_gpu_4cards_case_test.yml @@ -81,12 +81,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -186,6 +201,7 @@ jobs: -e "fd_wheel_url=${fd_wheel_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy @@ -204,3 +220,10 @@ jobs: export CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/run_gpu_4cards.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml index 47486cef243..0a014d26854 100644 --- a/.github/workflows/_logprob_test_linux.yml +++ b/.github/workflows/_logprob_test_linux.yml @@ -78,11 +78,27 @@ jobs: if ls /workspace/* >/dev/null 2>&1; then echo "ERROR: Failed to clean /workspace/* after multiple attempts" ls -ld /workspace/* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls /workspace/* >/dev/null 2>&1; then + echo "ERROR: Force cleanup failed. Exiting..." + exit 1 + else + echo "Force cleanup succeeded." + fi fi ' - wget -q --no-proxy ${paddletest_archive_url} - tar -xf PaddleTest.tar.gz + + wget -q --no-proxy ${paddletest_archive_url} || { + echo "ERROR: Failed to download archive from ${paddletest_archive_url}" + exit 1 + } + + tar --no-same-owner -xf PaddleTest.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf PaddleTest.tar.gz cd PaddleTest git config --global user.name "FastDeployCI" @@ -171,6 +187,7 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ @@ -223,3 +240,10 @@ jobs: run: | echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}" exit 8 + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml index 9e313606a36..8c5e20e2de0 100644 --- a/.github/workflows/_pre_ce_test.yml +++ b/.github/workflows/_pre_ce_test.yml @@ -83,12 +83,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -182,6 +197,7 @@ jobs: -e "FD_ZMQ_SEND_RESPONSE_SERVER_PORT=${FD_ZMQ_SEND_RESPONSE_SERVER_PORT}" \ -e "FD_ZMQ_CONTROL_CMD_SERVER_PORTS=${FD_ZMQ_CONTROL_CMD_SERVER_PORTS}" \ -e "fd_wheel_url=${fd_wheel_url}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy @@ -189,3 +205,10 @@ jobs: python -m pip install ${fd_wheel_url} bash scripts/run_pre_ce.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_stable_test.yml b/.github/workflows/_stable_test.yml index dd4ce4e811d..11ae14927ef 100644 --- a/.github/workflows/_stable_test.yml +++ b/.github/workflows/_stable_test.yml @@ -81,12 +81,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -176,6 +191,7 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ @@ -221,3 +237,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index 75aef3e937a..4096fe4f4c0 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -86,12 +86,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -178,10 +193,12 @@ jobs: --sysctl kernel.msgmnb=268435456 \ --name ${runner_name} \ --cap-add=SYS_PTRACE --cap-add=IPC_LOCK \ - --shm-size=64G \ + --shm-size=128G \ ${RDMA_DEVICES} \ --device=/dev/infiniband/rdma_cm \ --ulimit memlock=-1:-1 \ + --ulimit nofile=65536:65536 \ + --ulimit nproc=8192:8192 \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -201,6 +218,7 @@ jobs: -e "fd_wheel_url=${fd_wheel_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy @@ -388,6 +406,13 @@ jobs: echo "coverage passed" exit 0 + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} + diff_coverage_report: needs: run_tests_with_coverage if: always() diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 8eafe280346..928aa2e7cef 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -7,7 +7,11 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p python -m pip install -r requirements.txt python -m pip install jsonschema aistudio_sdk==0.3.5 -python -m pip install xgrammar==0.1.19 torch==2.6.0 +# Use prebuilt wheel files to install xgrammar==0.1.19 and torch==2.6.0 specifically for the CI environment +python -m pip install \ + https://paddle-qa.bj.bcebos.com/FastDeploy/torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/xgrammar-0.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl failed_files=() run_path="$DIR/../tests/ci_use/" From 9cb82d79a04263b6931fccbf36ad0650897a422b Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:02:08 +0800 Subject: [PATCH 024/143] [Cherry-Pick][TI-consistent] support quant use pow2scale(#7308) (#7310) * support quant use pow2scale * fix * fix --- fastdeploy/envs.py | 25 +++++++++++-------- .../layers/moe/fused_moe_deepgemm_backend.py | 9 ++++--- .../layers/moe/fused_moe_triton_backend.py | 6 +++-- .../layers/quantization/block_wise_fp8.py | 2 +- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index e78db512be6..733d874edc4 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -210,17 +210,6 @@ def _validate_split_kv_size(value: int) -> int: "FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""), # Whether to enable low latency in mixed scenario "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))), - # Whether to use phi FP8 quantization,if 1,use paddle default. - "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), - # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, - # intended for training alignment. Defaults to 0 (disabled). - "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), - # Whether to use phi MOE permute,if 1,use paddle op. - "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), - # Whether to use phi rms_norm,if 1,use paddle op. - "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), - # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function - "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Reserve output blocks for decoding requests when schedule new prefill requests "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16") @@ -268,8 +257,22 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), + # Whether to use phi FP8 quantization,if 1,use paddle default. + "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), + # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, + # intended for training alignment. Defaults to 0 (disabled). + "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), + # Whether to use phi MOE permute,if 1,use paddle op. + "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), + # Whether to use phi rms_norm,if 1,use paddle op. + "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), + # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function + "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), + # Whether to enable FP8 quantization with pow2scale. + "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), } diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 53247e29126..a16e5ccbe9c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=not disable_ue8m0_cast, + using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=not disable_ue8m0_cast, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -355,7 +355,7 @@ def apply_ep_prefill( else: x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) @@ -581,7 +581,8 @@ def apply_ep_prefill( else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 + or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -773,7 +774,7 @@ def apply_tp( else: recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index d1db43a3241..65d1d23b9be 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -1247,7 +1247,7 @@ def python_op_fused_moe_kernel_paddle( x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, using_pow2_scale=False, output_scale_transpose=False + x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False ) x_scale = x_scale[: x.shape[0]] @@ -1305,7 +1305,9 @@ def python_op_fused_moe_kernel_paddle( ) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False + intermediate_cache2, + using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, + output_scale_transpose=False, ) x_scale = x_scale[: x_q.shape[0]] diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index a86170e0727..ae37ca45961 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -343,7 +343,7 @@ def apply(self, layer, x): else: x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=True, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) From b2997f3aad74ad3f4e1e0f867168718ec14bdd13 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:20:11 +0800 Subject: [PATCH 025/143] fix overlap mtp empty run (#7314) --- fastdeploy/worker/gpu_model_runner.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6218e58687b..726a11a7627 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2110,6 +2110,12 @@ def execute_model_overlap( self._cached_sampler_output = sampler_output self._cached_post_process_event = post_process_event else: + if ( + self.fd_config.speculative_config.method == SpecMethod.MTP + and hasattr(self.proposer.model, "empty_input_forward") + and self.parallel_config.use_ep + ): + self._execute_empty_mtp_input(self.forward_meta) self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None From d9a008f3c82cab23cac32d5dc34a1a57ff8695ea Mon Sep 17 00:00:00 2001 From: chenjian <1435317881@qq.com> Date: Mon, 13 Apr 2026 15:24:01 +0800 Subject: [PATCH 026/143] [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1 (#7159) (#7351) * [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1 * [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1 * fix --- fastdeploy/worker/gpu_model_runner.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 726a11a7627..73d9a791843 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ from paddle import nn from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig +from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.request import ImagePosition, Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( @@ -2409,6 +2409,16 @@ def _postprocess( # 5.1. Async cpy post_process_event = paddle.device.cuda.create_event() + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + # If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished. + paddle.assign( + paddle.where( + self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1, + PREEMPTED_TOKEN_ID, + sampler_output.sampled_token_ids, + ), + sampler_output.sampled_token_ids, + ) # if not self.speculative_decoding: self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False) if self.speculative_decoding: From 9823d632202e99017c832f3f829183a3275b6830 Mon Sep 17 00:00:00 2001 From: JYChen Date: Mon, 13 Apr 2026 19:24:24 +0800 Subject: [PATCH 027/143] remove fa4 requirements (#7354) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e662f07e974..f51ae1bdf9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,5 +47,5 @@ aistudio_sdk p2pstore py-cpuinfo flashinfer-python-paddle -flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl +#flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl transformers>=4.55.1,<5.0.0 From 144dc17b14caa8fc18826439abbf5f293955edf7 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:06:16 +0800 Subject: [PATCH 028/143] update attn_mask_q 2 (#7373) --- custom_ops/gpu_ops/get_attn_mask_q.cu | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/custom_ops/gpu_ops/get_attn_mask_q.cu b/custom_ops/gpu_ops/get_attn_mask_q.cu index 4ee814178bc..a485d04f6bc 100644 --- a/custom_ops/gpu_ops/get_attn_mask_q.cu +++ b/custom_ops/gpu_ops/get_attn_mask_q.cu @@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel( const int max_batch_size) { constexpr int VecSize = 4; const uint32_t tid = threadIdx.x, bid = blockIdx.x; - int startend_row_vec[4]; + int startend_row_vec[2]; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel( const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start; startend_row_vec[0] = this_batch_q_end; - startend_row_vec[1] = cu_seqlens_q[max_batch_size]; - startend_row_vec[2] = 0; - startend_row_vec[3] = this_batch_q_end; + // startend_row_vec[1] = cu_seqlens_q[max_batch_size]; + // startend_row_vec[2] = 0; + startend_row_vec[1] = this_batch_q_end; for (int this_batch_q_idx = this_batch_q_start; this_batch_q_idx < this_batch_q_end; ++this_batch_q_idx) { @@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel( : this_batch_q_idx - this_batch_q_start + kv_len - (this_batch_q_len); if (cache_k_idx <= append_mask_k_end) { - startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx); + startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx); // 可提前跳出循环 break; } } - reinterpret_cast(startend_row_indices_ptr + - cu_seqlens_k_idx * 4)[0] = - reinterpret_cast(startend_row_vec)[0]; + reinterpret_cast(startend_row_indices_ptr + + cu_seqlens_k_idx * 2)[0] = + reinterpret_cast(startend_row_vec)[0]; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); @@ -82,7 +82,7 @@ std::vector get_attn_mask_q( const paddle::optional& attn_mask_kv, const int kv_token_num) { paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor( - {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place()); + {1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place()); const int max_batch_size = cu_seqlens_k.dims()[0] - 1; constexpr int block_size = 512; int grid_size = div_up(kv_token_num, block_size); @@ -123,7 +123,7 @@ std::vector> GetAttnMaskQInferShape( const std::vector& cu_seqlens_k_shape, const paddle::optional>& attn_mask_kv_shape, const int kv_token_num) { - return {{1, 1, kv_token_num, 4}}; + return {{1, 1, kv_token_num, 2}}; } PD_BUILD_STATIC_OP(get_attn_mask_q) From e7c8dc2fe940ba2f50d932f17b41eab555846ecb Mon Sep 17 00:00:00 2001 From: lonelygsh <80582973+lonelygsh@users.noreply.github.com> Date: Tue, 14 Apr 2026 12:54:22 +0800 Subject: [PATCH 029/143] [Speculate Decoding] Fix step_idx semantics in limit_thinking and set_stop_value kernels (#7370) - speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel. --- ...speculate_limit_thinking_content_length.cu | 17 +- .../speculate_set_stop_value_multi_seqs.cu | 64 ++--- .../unified_update_model_status.cu | 2 +- ...est_speculate_set_stop_value_multi_seqs.py | 248 +++++++++++++----- .../test_unified_update_model_status.py | 4 +- 5 files changed, 229 insertions(+), 106 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu index 18aa5d53d21..e620e914a25 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu @@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int64_t* next_tokens, // [bs, tokens_per_step] const int* max_think_lens, // [bs] int* max_reply_lens, // [bs] - int64_t* step_idx, // [bs] + const int64_t* step_idx, // [bs] const int64_t* eos_token_ids, // [eos_len] int* limit_status, // [bs] int* accept_num, // [bs] @@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int new_accept_num = original_accept_num; // 本 step 的 token offset 对应的绝对 step - const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + const int64_t current_base_step = step_idx[bid] + 1; for (int token_offset = 0; token_offset < original_accept_num; token_offset++) { @@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel( // inject_token_ids[0]) if (status == 0 && (current_step - 1) == - max_think_len) { // current_step - 1 是因为 speculate_verify 里 - // step_idx + 1 了 + max_think_len) { // current_step - 1 : 已输出 current_step-1 + // 个thinking token status = (inject_len > 0) ? 1 : done_status; } } else if (max_think_len == 0) { @@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel( } } - // 更新 step_idx / accept_num(被截断的 token 需要回退 - // step_idx) - const int discarded_tokens = original_accept_num - new_accept_num; - if (discarded_tokens > 0) { - step_idx[bid] -= discarded_tokens; - } - accept_num[bid] = new_accept_num; limit_status[bid] = status; max_reply_lens[bid] = max_reply_len; @@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength( const_cast(next_tokens.data()), max_think_lens.data(), const_cast(max_reply_lens.data()), - const_cast(step_idx.data()), + step_idx.data(), eos_token_ids.data(), const_cast(limit_status.data()), const_cast(accept_num.data()), diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index ee364884e96..c6379387efe 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, const int64_t step_idx_now = step_idx[bid]; const int64_t min_token_limit = min_tokens[bid]; - const bool can_stop = (step_idx_now >= min_token_limit); + const bool can_stop = (step_idx_now + accept_num >= min_token_limit); if (!can_stop) return; if (!stop_flags[bid]) { - int accept_idx = 0; + /* + accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based) + accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾 + (pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。 + 为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1] + (当前轮最后一个 token),该 token 延迟到下一轮匹配。 + 循环范围:accept_num > 0 时为 [-1, accept_num-2]; + accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。 + */ + int accept_idx = -1; bool is_end = false; - // 遍历起始位置 - for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { + + // 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾 + // 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens + // 中的匹配。两者共享同一套从后向前匹配逻辑。 + int loop_end = (accept_num > 0) ? accept_num - 2 : -1; + for (; accept_idx <= loop_end && !is_end; accept_idx++) { if (step_idx_now + accept_idx + 1 < stop_seq_len) { #ifdef DEBUG_SPEC_STOP_SEQS printf("num %d < stop_seq_len %d\n", - step_idx_now - accept_num + accept_idx + 1, + step_idx_now + accept_idx + 1, stop_seq_len); #endif continue; } - // 遍历一个 stop_seqs + // 从后向前匹配 stop_seq 的每个 token for (int i = stop_seq_len - 1; i >= 0; --i) { int64_t cur_token_idx = -1; - // 通过当前值判断 token 是在 pre_ids 还是 accept_token 里 - if (stop_seq_len - 1 - i < accept_idx) { + int offset = stop_seq_len - 1 - i; + int accept_tokens_idx = accept_idx - offset; + + if (accept_tokens_idx >= 0) { #ifdef DEBUG_SPEC_STOP_SEQS printf( "AcceptTokens bid:%d. tid:%d, accept_idx:%d, " - "accept_token_idx: " - "%d\n", + "accept_token_idx: %d\n", bid, tid, accept_idx, - accept_idx - (stop_seq_len - 1 - i) - 1); + accept_tokens_idx); #endif - cur_token_idx = - accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]; + cur_token_idx = accept_tokens_now[accept_tokens_idx]; } else { + int pre_ids_idx = step_idx_now + accept_tokens_idx; #ifdef DEBUG_SPEC_STOP_SEQS printf( "PreIds bid:%d. tid:%d, step_idx_now:%ld. " - "accept_idx:%d. " - "pre_id_idx: %ld\n", + "accept_idx:%d. pre_id_idx: %d\n", bid, tid, step_idx_now, accept_idx, - step_idx_now - accept_num + accept_idx - - (stop_seq_len - 1 - i)); + pre_ids_idx); #endif - int pre_ids_idx = - step_idx_now + accept_idx - (stop_seq_len - 1 - i); - // EC3 - // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, - // 导致异常结束 - if (pre_ids_idx <= 0) { - break; - } + if (pre_ids_idx < 0) break; cur_token_idx = pre_ids_now[pre_ids_idx]; } #ifdef DEBUG_SPEC_STOP_SEQS @@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, } if (is_end) { #ifdef DEBUG_SPEC_STOP_SEQS - printf("bid:%d end with accept_idx %d", bid, accept_idx); + printf("bid:%d end with accept_idx %d\n", bid, accept_idx); #endif - - accept_nums[bid] = accept_idx; - accept_tokens_now[accept_idx - 1] = end_ids[0]; - // stop_flags[bid] = true; + // accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置 + accept_nums[bid] = accept_idx + 1; + accept_tokens_now[accept_idx] = end_ids[0]; } } } diff --git a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu index 94f71d6fd0e..f7ab5daece6 100644 --- a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu +++ b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu @@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder, int64_t *token_ids_all_now = &token_ids_all[batch_id * max_model_len + prompt_len]; int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; - int64_t base = cur_step_idx - output_len + 1; + int64_t base = cur_step_idx - output_len; for (int i = 0; i < output_len; i++) { token_ids_all_now[base + i] = output_ids[i]; } diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 45d8a0ef34f..aa048560c30 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: return paddle_inputs -def run_kernel(paddle_inputs, inputs): +def run_kernel(paddle_inputs): """Call the CUDA kernel.""" speculate_set_stop_value_multi_seqs( paddle_inputs["accept_tokens"], @@ -137,7 +137,18 @@ def gen_inputs( def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]: - """Python reference — must match CUDA kernel logic exactly.""" + """Python reference — must match CUDA kernel logic exactly. + + token_ids_all 布局 (新 step_idx 语义): + pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed) + 最后一个 output token 在 pre_ids_now[step_idx - 1] + step_idx = 历史已生成的 token 数量 + + 核心设计: + 1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况) + 2. 主循环检查 accept_idx <= accept_num - 2 + 3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos + """ accept_tokens = inputs["accept_tokens"].copy() accept_num = inputs["accept_num"].copy() stop_flags = inputs["stop_flags"].copy() @@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str step_idx_now = int(step_idx[bid]) min_token_limit = int(min_tokens[bid]) - can_stop = step_idx_now >= min_token_limit + can_stop = step_idx_now + an >= min_token_limit if not can_stop: continue if stop_flags[bid]: continue - accept_idx = 0 + # CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾 + accept_idx = -1 is_end = False - while accept_idx <= an - 1 and not is_end: + + # loop_end = accept_num > 0 ? accept_num - 2 : -1 + loop_end = an - 2 if an > 0 else -1 + while accept_idx <= loop_end and not is_end: if step_idx_now + accept_idx + 1 < stop_seq_len: accept_idx += 1 continue - # Check one stop_seq match + # 从后向前匹配 stop_seq 的每个 token for i in range(stop_seq_len - 1, -1, -1): + offset = stop_seq_len - 1 - i + accept_tokens_idx = accept_idx - offset cur_token_idx = -1 - if stop_seq_len - 1 - i < accept_idx: - cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1] + + if accept_tokens_idx >= 0: + cur_token_idx = accept_tokens_now[accept_tokens_idx] else: - pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i) - if pre_ids_idx <= 0: + # 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[0] 是第 1 个 output token + pre_ids_idx = step_idx_now + accept_tokens_idx + if pre_ids_idx < 0: break cur_token_idx = pre_ids_now[pre_ids_idx] @@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx += 1 if is_end: - accept_num[bid] = accept_idx - accept_tokens[bid, accept_idx - 1] = end_ids[0] - # stop_flags[bid] = True # kernel no longer sets stop_flags + # accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置 + # 保留 stop_seq 所有 token,在其后追加 eos + accept_num[bid] = accept_idx + 1 + accept_tokens[bid, accept_idx] = end_ids[0] return { "accept_tokens": accept_tokens, @@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): def _run_and_get(self, inputs): paddle_inputs = to_paddle_inputs(inputs) - run_kernel(paddle_inputs, inputs) + run_kernel(paddle_inputs) return get_outputs(paddle_inputs) def _check_all_outputs(self, inputs, outputs): @@ -264,7 +285,7 @@ def test_configs(self): self._run_full_test(test_cfg) def test_match_in_accept_tokens_only(self): - """Stop seq found entirely within accept_tokens.""" + """Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10) # Place stop seq [A, B, C] at accept_tokens positions [0,1,2] inputs["accept_num"][:] = 4 @@ -276,9 +297,13 @@ def test_match_in_accept_tokens_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30) + # After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq def test_match_spanning_pre_ids_and_accept(self): - """Stop seq spans token_ids_all (pre_ids) and accept_tokens.""" + """Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -290,12 +315,15 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 6 inputs["accept_num"][:] = 3 - # Kernel matching at accept_idx=2 (3rd token, 0-indexed): - # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1] - # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0] - # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6] - # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]] - inputs["token_ids_all"][0, 6] = 99 + # stop_seq = [99, 11, 22] (len=3) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5] + # At accept_idx=1 (window ends at accept_tokens[1]=22): + # i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓ + # i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓ + # i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓ + inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed) inputs["accept_tokens"][0, :3] = [11, 22, 33] inputs["stop_seqs"][0, 0, :3] = [99, 11, 22] inputs["stop_seqs_len"][0, 0] = 3 @@ -303,12 +331,14 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - # Match at accept_idx=2, loop increments to 3 + # Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2] self.assertEqual(outputs["accept_num"][0], 3) - self.assertEqual(outputs["accept_tokens"][0, 2], -1) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq - def test_match_in_pre_ids_only(self): - """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0.""" + def test_match_in_pre_ids_only_not_detected(self): + """Stop seq ending purely in pre_ids history but NOT at the end position. + The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check. + Stop seq placed earlier in pre_ids should not be detected.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -320,15 +350,13 @@ def test_match_in_pre_ids_only(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70 - # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids - # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check - # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70 - # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60 - # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50 - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7] + # accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq + # 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到 + inputs["token_ids_all"][0, 2] = 50 + inputs["token_ids_all"][0, 3] = 60 + inputs["token_ids_all"][0, 4] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 @@ -336,7 +364,8 @@ def test_match_in_pre_ids_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - self.assertEqual(outputs["accept_num"][0], 1) + # No match: stop_seq is in pre_ids but not at the end, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_already_stopped(self): """Kernel skips sequences with stop_flags=True.""" @@ -351,7 +380,7 @@ def test_already_stopped(self): np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"]) def test_min_tokens_blocks_stop(self): - """Kernel skips stop check when step_idx < min_tokens.""" + """Kernel skips stop check when step_idx + accept_num < min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -363,20 +392,24 @@ def test_min_tokens_blocks_stop(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Same setup that would match (like test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1) + # pre_ids_now[0..7] = 8 个历史 output token + # accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2] + inputs["token_ids_all"][0, 5] = 50 + inputs["token_ids_all"][0, 6] = 60 + inputs["token_ids_all"][0, 7] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop + inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # min_tokens prevents stop, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_min_tokens_allows_stop(self): - """Kernel allows stop when step_idx >= min_tokens.""" + """Kernel allows stop when step_idx + accept_num >= min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -388,15 +421,17 @@ def test_min_tokens_allows_stop(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 - inputs["accept_tokens"][0, :3] = [1, 2, 3] - inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] - inputs["stop_seqs_len"][0, 0] = 3 + # stop_seq [X, 50] spans pre_ids and accept_tokens[0]. + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # At accept_idx=0 (window ends at accept_tokens[0]=50): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7] + pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7] + inputs["accept_tokens"][0, :3] = [50, 60, 70] + inputs["stop_seqs"][0, 0, :2] = [pre_val, 50] + inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop + inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) @@ -413,20 +448,24 @@ def test_multiple_stop_seqs_second_matches(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # accept_tokens: stop_seq[20,30] matches at accept_idx=2: - # i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK - # i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK + # accept_tokens: [20, 30, 40] + # Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30): + # i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓ + # i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓ inputs["accept_tokens"][0, :3] = [20, 30, 40] # First stop seq doesn't match inputs["stop_seqs"][0, 0, :3] = [99, 98, 97] inputs["stop_seqs_len"][0, 0] = 3 - # Second stop seq matches + # Second stop seq [20, 30] matches inputs["stop_seqs"][0, 1, :2] = [20, 30] inputs["stop_seqs_len"][0, 1] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2] + self.assertEqual(outputs["accept_num"][0], 3) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq def test_nonzero_prompt_lens(self): """Verify prompt_lens offset is applied correctly.""" @@ -444,19 +483,104 @@ def test_nonzero_prompt_lens(self): inputs["accept_num"][:] = 2 inputs["accept_tokens"][0, :2] = [55, 66] # pre_ids_now starts at token_ids_all[0, prompt_len:] - # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx] - # For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4 - # -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4] - # For accept_idx=1 (second token is accept_tokens[0,0]=55): - # i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55 - # i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5] - target_val = int(inputs["token_ids_all"][0, prompt_len + 5]) + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4] + # At accept_idx=0 (window ends at accept_tokens[0]=55): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4] + target_val = int(inputs["token_ids_all"][0, prompt_len + 4]) inputs["stop_seqs"][0, 0, :2] = [target_val, 55] inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1] + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq + + def test_single_token_stop_seq_preserved(self): + """Single token stop_seq (like <|im_end|>) with eos appended after it.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=90, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999 + inputs["accept_tokens"][0, :4] = [100, 200, 999, 300] + # stop_seq = [<|im_end|>] (single token) + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # Match at accept_idx=2 (window ends at accept_tokens[2]=999) + # After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq + + def test_stop_seq_at_last_position_not_detected(self): + """Stop seq at the last position of accept_tokens is NOT detected (deferred to next round).""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=100, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # stop_seq [999] is at accept_tokens[3] (last valid position) + # Since we only check up to accept_num - 2 = 2, this won't be detected + inputs["accept_tokens"][0, :4] = [100, 200, 300, 999] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # No match because accept_idx only goes up to 2, and 999 is at position 3 + # accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 4) + + def test_stop_seq_detected_from_previous_round(self): + """Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=110, + ) + inputs["prompt_lens"][:] = 0 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9] + # accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token) + inputs["step_idx"][:] = 10 + inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed) + inputs["accept_num"][:] = 3 + inputs["accept_tokens"][0, :3] = [100, 200, 300] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # stop_seq [999] was in pre_ids at end, accept_idx=-1 matches + # After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0] + self.assertEqual(outputs["accept_num"][0], 1) + self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos if __name__ == "__main__": diff --git a/tests/operators/test_unified_update_model_status.py b/tests/operators/test_unified_update_model_status.py index 56656fdbe75..ed97aa86879 100644 --- a/tests/operators/test_unified_update_model_status.py +++ b/tests/operators/test_unified_update_model_status.py @@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: # Write history to token_ids_all (forward loop, mirrors kernel step 5) if output_len > 0: base_addr = int(prompt_lens[batch_id]) - base = cur_step_idx - output_len + 1 + # 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len + # 第一个 output token 写入位置 = cur_step_idx - output_len + base = cur_step_idx - output_len for i in range(output_len): write_idx = base_addr + base + i if 0 <= write_idx < max_model_len: From 8a8beca548e8c064714d07284a76a5b6c7259d70 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 14 Apr 2026 19:25:12 +0800 Subject: [PATCH 030/143] [BugFix][PD Disaggregation][KVCache] Fix low cache hit rate in PD split scenario (#7364) (#7387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation 在 PD 分离场景下,decode 节点在接收 prefill 节点转发的请求后,没有及时更新 cache block 的命中信息, 导致 prefix cache 命中率低,影响推理性能。 ## Modifications 1. 在 `_free_blocks_when_stop` 方法中,额外排除 prefill 节点(`splitwise_role == "prefill"`) 的 cache block 更新,避免 prefill 节点重复更新 cache 导致状态混乱。 2. 在 decode 节点分配请求(`_alloc_requests_with_cache`)成功后,主动调用 `update_cache_blocks` 使用 `need_prefill_tokens` 更新 cache block 信息, 确保 decode 节点能正确感知已命中的 prefix cache。 Co-authored-by: kevin --- fastdeploy/engine/sched/resource_manager_v1.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 45ec18aa1c0..ae0e0c798b3 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -927,6 +927,7 @@ def _allocate_decode_and_extend(): if ( self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode" + and self.config.scheduler_config.splitwise_role != "prefill" ): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens @@ -1374,6 +1375,11 @@ def preallocate_resource_in_p(self, request: Request): self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position + + self.cache_manager.update_cache_blocks( + request, self.config.cache_config.block_size, request.need_prefill_tokens + ) + return True else: self._free_blocks(request) From f6c066fb9dff6e74b14ea2d85ed6a66899153270 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 14 Apr 2026 20:01:39 +0800 Subject: [PATCH 031/143] Revert "[Optimization] Optimize ttft for prefill pd (#6680)" (#7386) * Revert "[Optimization] Optimize ttft for prefill pd (#6680)" This reverts commit 6727df82868a23cf1a41cb906fed9e4440a0beb3. * fix revert pr --- docs/usage/environment_variables.md | 3 + docs/zh/usage/environment_variables.md | 3 + fastdeploy/engine/common_engine.py | 53 +++++-------- fastdeploy/envs.py | 2 + fastdeploy/scheduler/dp_scheduler.py | 55 +++++++++++--- .../splitwise/internal_adapter_utils.py | 5 +- fastdeploy/worker/worker_process.py | 74 +++++-------------- tests/ci_use/metrics/test_metrics.py | 37 +++++----- tests/engine/test_common_engine.py | 12 +-- tests/scheduler/test_dp_scheduler.py | 26 +++++++ .../splitwise/test_internal_adapter_utils.py | 3 - 11 files changed, 137 insertions(+), 136 deletions(-) diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 692ad8cd023..e54ec8f8798 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 0a4cfd389db..ab625bd4d2c 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # 是否启用 decode 缓存请求以预分配资源 "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # EP 中批处理 token 的超时时间 + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # PD 中最大预取请求数量 "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index cd9e42f8bcf..dabed9e4342 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -372,15 +372,6 @@ def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进 create=True, ) - engine_forward_signal_data = np.zeros([1], dtype=np.int32) - self.engine_forward_signal = IPCSignal( - name="engine_forward_signal", - array=engine_forward_signal_data, - dtype=np.int32, - suffix=current_suffix, - create=True, - ) - # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = np.zeros( shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32 @@ -1050,29 +1041,26 @@ def _fetch_request(): with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) try: - if not is_fetching: - # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. - try: + if self.engine_worker_queue.exist_tasks(): + time.sleep(0.001) + continue + if self.cfg.scheduler_config.splitwise_role != "mixed": + if not is_fetching: is_fetching = True get_request_pool.submit(_fetch_request) - except RuntimeError as e: - if "shutdown" in str(e): - self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") - break - else: - raise - if self.cfg.scheduler_config.splitwise_role != "mixed": - # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. - # Once the forward pass finishes, these accumulated requests can be scheduled in larger, - # more efficient batches. - if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0: - time.sleep(0.001) - continue + else: - # In mixed, todo: optimze cache swap, to decouple swap from scheduler - if self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - continue + if len(self.resource_manager.waiting) == 0 and (not is_fetching): + # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. + try: + is_fetching = True + get_request_pool.submit(_fetch_request) + except RuntimeError as e: + if "shutdown" in str(e): + self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") + break + else: + raise if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() @@ -1133,13 +1121,6 @@ def _fetch_request(): elif not task.has_been_preempted_before: task.metrics.inference_start_time = time.time() self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - else: - # When there are no actual tasks to schedule, send an empty task batch to EP workers. - # This helps EP workers barrier for syncing tasks not hang. - if self.cfg.parallel_config.enable_expert_parallel: - self.engine_worker_queue.put_tasks( - ([], self.resource_manager.real_bsz) - ) # Empty (as idle tasks for ep) # 4. Response error tasks if error_tasks: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 733d874edc4..96bc09934a8 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -145,6 +145,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), # Enable or disable model caching. diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 2339a077c96..f5b03eba30f 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -23,7 +23,7 @@ from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse from fastdeploy.scheduler.local_scheduler import LocalScheduler -from fastdeploy.utils import get_logger +from fastdeploy.utils import envs, get_logger class DPLocalScheduler(LocalScheduler): @@ -131,19 +131,52 @@ def get_requests( Returns: List of Request objects ready for processing """ - # DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that. + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() requests: List[Request] = [] with self.requests_not_empty: - batch_ids = self.requests_not_empty.wait_for( - lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1], - 0.005, - ) - if batch_ids: - for request_id in batch_ids: - request = self.requests[request_id] - requests.append(request.raw) - self.ids_read_cursor += 1 + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if current_prefill_tokens > max_num_batched_tokens: + break + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) if len(requests) > 0: self.scheduler_logger.info( diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index e64e468b186..5c2f793fdbf 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -53,9 +53,6 @@ def _get_current_server_info(self): available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) available_block_num = self.engine.resource_manager.available_block_num() - unhandled_request_num = self.engine.scheduler.get_unhandled_request_num() - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting)) server_info = { "splitwise_role": self.cfg.scheduler_config.splitwise_role, "block_size": int(self.cfg.cache_config.block_size), @@ -65,7 +62,7 @@ def _get_current_server_info(self): "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), "max_batch_size": int(available_batch_size), "max_input_token_num": self.cfg.model_config.max_model_len, - "unhandled_request_num": unhandled_request_num, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), "available_batch": int(self.engine.resource_manager.available_batch()), } return server_info diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 2f51359959f..3f2a1fcf0dd 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -287,19 +287,6 @@ def init_health_status(self) -> None: create=False, ) - # init engine forward signal - # If engine is being forward, engine_forward_signal_data should be 1. - # If engine is out of forward, engine_forward_signal_data should be 0. - # In pd disaggregation + EP parallel, only when engine is out of forward, scheduler send next batch to worker. - # When engine is out of forward, engine_forward_signal_data must be 0, otherwise scheduler will not schedule next batch. - engine_forward_signal_data = np.zeros([1], dtype=np.int32) - self.engine_forward_signal = IPCSignal( - name="engine_forward_signal", - array=engine_forward_signal_data, - dtype=np.int32, - suffix=self.parallel_config.local_engine_worker_queue_port, - create=False, - ) # gpu_cache_lock: file-based lock for mutual exclusion between worker # and CPU transfer when accessing GPU KV cache. self.gpu_cache_lock = IPCLock( @@ -494,6 +481,9 @@ def event_loop_normal(self) -> None: # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one self.model_weights_signal = np.zeros([1], dtype=np.int32) while True: + # run eplb + self._run_eplb(tp_rank) + if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) if self.ranks > 1: @@ -571,7 +561,7 @@ def event_loop_normal(self) -> None: if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1: logger.debug(f"Rank: {self.local_rank} Detected new requests.") - self.engine_forward_signal.value[0] = 1 + tasks, read_finish = self.task_queue.get_tasks() # Only one of all tp_size client will get read_finish == True. if read_finish: @@ -580,39 +570,25 @@ def event_loop_normal(self) -> None: self.task_queue.read_finish_flag.set(0) else: self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY - # In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival. - # Only EP + DP prefill should barrier for data arrival. - # In mixed mode and decoder in D, we should not barrier to influence decoding. - if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill": - paddle.distributed.barrier(self.parallel_config.ep_group) req_dicts, control_reqs = [], [] - assert ( - len(tasks) > 0 - ), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}" - # In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue. - # tasks[0] contains two part, ([req1, ...] ,real_bsz) - # tasks[0][0] is [req1, ...] - # if empty batch is delived, eval(tasks[0][0]) should be False ([]), - # if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below. - if tasks[0][0]: - for req_dict, bsz in tasks: - if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): - control_reqs.append(req_dict[0]) + for req_dict, bsz in tasks: + if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): + control_reqs.append(req_dict[0]) + else: + max_occupied_batch_index = int(bsz) + req_dicts.extend(req_dict) + + # todo: run control request async + if len(control_reqs) > 0: + logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") + for control_req in control_reqs: + if self.parallel_config.use_ep: + self.cached_control_reqs.append(control_req) + logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") else: - max_occupied_batch_index = int(bsz) - req_dicts.extend(req_dict) - - # todo: run control request async - if len(control_reqs) > 0: - logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") - for control_req in control_reqs: - if self.parallel_config.use_ep: - self.cached_control_reqs.append(control_req) - logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") - else: - self.run_control_method(control_req) - self._tp_barrier_wait() if tp_size > 1 else None + self.run_control_method(control_req) + self._tp_barrier_wait() if tp_size > 1 else None if len(req_dicts) > 0: # Count prefill requests in current batch @@ -628,12 +604,6 @@ def event_loop_normal(self) -> None: # Process prefill inputs self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index) - else: - if self.scheduler_config.splitwise_role == "prefill": - if tp_size > 1: - # Synchronize the signal for other workers - self._tp_barrier_wait() - continue # Let the ep group run control method synchronically if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep: @@ -648,7 +618,6 @@ def event_loop_normal(self) -> None: and not self.worker.model_runner.not_need_stop() ): self._tp_barrier_wait() if tp_size > 1 else None - self.engine_forward_signal.value[0] = 0 time.sleep(0.001) continue @@ -673,9 +642,6 @@ def event_loop_normal(self) -> None: if not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill() logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s") - # run eplb - self._run_eplb(tp_rank) - self.engine_forward_signal.value[0] = 0 if ( not self.parallel_config.use_ep diff --git a/tests/ci_use/metrics/test_metrics.py b/tests/ci_use/metrics/test_metrics.py index 0d5353780f0..a54504c29bd 100644 --- a/tests/ci_use/metrics/test_metrics.py +++ b/tests/ci_use/metrics/test_metrics.py @@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset(): """ Test the metrics monitoring endpoint. """ - pass # not stable, uncomment after bug fix - # metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" - # async_concurrency(n=10) + async_concurrency(n=10) - # time.sleep(0.3) + time.sleep(0.3) # ===== clear_load_weight ===== - # clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" - # print("Calling clear_load_weight...") - # r = requests.get(clear_url, timeout=30) - # assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" - - # metrics = get_metrics_dict(metrics_url) - # running = metrics["fastdeploy:num_requests_running"] - # waiting = metrics["fastdeploy:num_requests_waiting"] - - # print( - # "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", - # running, - # "waiting:", - # waiting, - # ) + clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" + print("Calling clear_load_weight...") + r = requests.get(clear_url, timeout=30) + assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" + + metrics = get_metrics_dict(metrics_url) + running = metrics["fastdeploy:num_requests_running"] + waiting = metrics["fastdeploy:num_requests_waiting"] + + print( + "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", + running, + "waiting:", + waiting, + ) # assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight" diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 551f93babd8..8778e2013ea 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1457,9 +1457,7 @@ def test_schedule_request_to_worker_v1_decode_preempted_and_errors(self): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) @@ -1493,9 +1491,7 @@ def test_schedule_request_to_worker_v1_decode_prefill_task_path(self): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) @@ -1526,9 +1522,7 @@ def test_schedule_request_to_worker_v1_error_task_none_skips_send(self): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) diff --git a/tests/scheduler/test_dp_scheduler.py b/tests/scheduler/test_dp_scheduler.py index 0e42c4491f3..a5f9cfa8380 100644 --- a/tests/scheduler/test_dp_scheduler.py +++ b/tests/scheduler/test_dp_scheduler.py @@ -411,6 +411,32 @@ def test_recycle_expired_requests(self, mock_time): self.assertEqual(scheduler.ids, ["fresh_req"]) self.assertEqual(scheduler.ids_read_cursor, 1) + def test_get_requests_insufficient_resources(self): + """Test getting requests when resources are insufficient.""" + mock_logger.reset_mock() + + # Test with insufficient blocks - mock the condition variable to avoid threading issues + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1 + ) + + self.assertEqual(requests, []) + # The logger should have been called for insufficient resources + self.assertTrue(mock_logger.debug.called) + # Check the message contains expected content + call_args = mock_logger.debug.call_args[0][0] + self.assertIn("insufficient", call_args.lower()) + + def test_get_requests_insufficient_batch(self): + """Test getting requests when batch size is insufficient.""" + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0 + ) + + self.assertEqual(requests, []) + @patch("time.time") @patch.object(dp_scheduler_module, "envs") def test_get_requests_no_requests_available(self, mock_envs, mock_time): diff --git a/tests/splitwise/test_internal_adapter_utils.py b/tests/splitwise/test_internal_adapter_utils.py index f8f22215c02..4d772789848 100644 --- a/tests/splitwise/test_internal_adapter_utils.py +++ b/tests/splitwise/test_internal_adapter_utils.py @@ -25,9 +25,6 @@ class DummyEngine: """Dummy Engine class to simulate the actual Engine for testing.""" class ResourceManager: - def __init__(self): - self.waiting = [] - def available_batch(self): return 4 From 5f7524eb8561e01be3446760acc79665eb587ca9 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Tue, 14 Apr 2026 20:04:09 +0800 Subject: [PATCH 032/143] fix rl moe gate type (#7394) --- fastdeploy/rl/rollout_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index cade1355088..0caefd9ada1 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -68,6 +68,7 @@ def __init__( routing_replay_config: str = None, load_choices: str = "default_v1", lm_head_fp32: bool = False, + moe_gate_fp32: bool = True, ): # Required parameters self.model = model_name_or_path @@ -121,6 +122,7 @@ def __init__( self.routing_replay_config = routing_replay_config self.load_choices = load_choices self.lm_head_fp32 = lm_head_fp32 + self.moe_gate_fp32 = moe_gate_fp32 def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) From 2ee1cc3d0ad5ea4375e9462b751072aeb9f7bab9 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:05:20 +0800 Subject: [PATCH 033/143] check init_flash_attn_version log (#7401) --- .../model_executor/layers/attention/flash_attn_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index bcffcd0bac0..2549f9f5d87 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -95,7 +95,7 @@ def init_flash_attn_version(): logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") if FLASH_ATTN_VERSION is None: - if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()): + if sm_version == 90 and 90 in paddle.version.cuda_archs(): FLASH_ATTN_VERSION = 3 logger.info("The current platform supports Flash Attention V3.") else: From 61bfe6e5b3ef874d95c03ea3c574218b77093b25 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Wed, 15 Apr 2026 18:19:21 +0800 Subject: [PATCH 034/143] modify flashmask version (#7414) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f51ae1bdf9e..2edef89b859 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,5 +47,5 @@ aistudio_sdk p2pstore py-cpuinfo flashinfer-python-paddle -#flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl +flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 From 26674bbbb62726c61787e58c22d85ee71e533093 Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Wed, 15 Apr 2026 19:45:09 +0800 Subject: [PATCH 035/143] [Cherry-Pick][RL] Add clear_graph_opt_backend for glm4_mtp (#7378) (#7379) * add clear_grpah func * fix spell --- .../model_executor/graph_optimization/decorator.py | 2 +- fastdeploy/model_executor/models/deepseek_v3.py | 4 ++-- fastdeploy/model_executor/models/ernie4_5_moe.py | 4 ++-- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 4 ++-- fastdeploy/model_executor/models/glm4_moe.py | 4 ++-- fastdeploy/model_executor/models/glm4_mtp.py | 4 ++++ fastdeploy/model_executor/models/qwen2.py | 4 ++-- fastdeploy/model_executor/models/qwen3.py | 4 ++-- fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py | 4 ++-- fastdeploy/model_executor/models/qwen3moe.py | 4 ++-- fastdeploy/worker/gpu_model_runner.py | 6 +++--- fastdeploy/worker/metax_model_runner.py | 2 +- tests/graph_optimization/test_cuda_graph_recapture.py | 10 +++++----- tests/worker/test_gpu_model_runner.py | 4 ++-- 14 files changed, 32 insertions(+), 28 deletions(-) diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py index 562164aae1a..05ec79a495c 100644 --- a/fastdeploy/model_executor/graph_optimization/decorator.py +++ b/fastdeploy/model_executor/graph_optimization/decorator.py @@ -92,7 +92,7 @@ def forward(self, **kwargs): def __call__(self, **kwargs): return self.graph_opt_backend(**kwargs) - def clear_grpah_opt_backend(self, fd_config): + def clear_graph_opt_backend(self, fd_config): """ """ # TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights assert ( diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 4e75ba1d90b..aa3f3af346e 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -1306,9 +1306,9 @@ def forward( ) return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class DeepSeekV3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 4cc4306de5f..bf8b3d93481 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -701,9 +701,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) @ModelRegistry.register_model_class( diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index f4d70108e4b..b6fa97ab0ba 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -829,9 +829,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) class Ernie4_5_VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 7840107a046..fba36185a4a 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -563,9 +563,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Glm4MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_mtp.py b/fastdeploy/model_executor/models/glm4_mtp.py index c28023202d2..c700ea442c5 100644 --- a/fastdeploy/model_executor/models/glm4_mtp.py +++ b/fastdeploy/model_executor/models/glm4_mtp.py @@ -369,3 +369,7 @@ def forward( ) return hidden_states + + def clear_graph_opt_backend(self): + """Clear graph optimization backend, the captured cuda graph will be cleaned""" + self.model.clear_graph_opt_backend(fd_config=self.fd_config) diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 1bca09265ee..1d0ce349bf2 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -417,9 +417,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen2PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index ebbf4f5aed0..b0bcf9d5883 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -341,9 +341,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py index a4d3f1579c3..3f2a6904248 100644 --- a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py +++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py @@ -382,9 +382,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 74ca37ab695..95adc7ad0eb 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -453,9 +453,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 73d9a791843..2bdbdb345b1 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2692,13 +2692,13 @@ def clear_parameters(self, pid): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) if self.spec_method == SpecMethod.MTP: - self.proposer.model.clear_grpah_opt_backend() + self.proposer.model.clear_graph_opt_backend() self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache() @@ -2752,7 +2752,7 @@ def sleep(self, tags): logger.info("GPU model runner's weight is already sleeping, no need to sleep again!") return if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() if self.fd_config.parallel_config.enable_expert_parallel: self.dynamic_weight_manager.clear_deepep_buffer() self.dynamic_weight_manager.clear_model_weight() diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 28c769e1166..d72538ba8d7 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -2511,7 +2511,7 @@ def clear_parameters(self, pid): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py index 1a28c0731b3..902bcf182fd 100644 --- a/tests/graph_optimization/test_cuda_graph_recapture.py +++ b/tests/graph_optimization/test_cuda_graph_recapture.py @@ -91,10 +91,10 @@ def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta): return sublayer2_output - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """ """ - self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config) - self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config) + self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config) class TestCUDAGrpahRecapture(unittest.TestCase): @@ -152,7 +152,7 @@ def capture_and_replay(self, input_tensor1, forward_meta1): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) def recapture_and_replay(self, input_tensor1, forward_meta1): @@ -168,7 +168,7 @@ def recapture_and_replay(self, input_tensor1, forward_meta1): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index 3a02475b5ae..43ab5130cdb 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -487,7 +487,7 @@ def _make_runner(self): runner.local_rank = 0 runner.device_id = 1 runner.num_gpu_blocks = 8 - runner.model = Mock(clear_grpah_opt_backend=Mock()) + runner.model = Mock(clear_graph_opt_backend=Mock()) runner.clear_cache = Mock() runner.initialize_kv_cache = Mock() runner.capture_model = Mock() @@ -523,7 +523,7 @@ def test_sleep_offloads_weight_and_cache(self, mock_empty_cache, mock_print_memo runner.sleep("weight,kv_cache") - runner.model.clear_grpah_opt_backend.assert_called_once() + runner.model.clear_graph_opt_backend.assert_called_once() runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once() runner.dynamic_weight_manager.clear_model_weight.assert_called_once() runner.dynamic_weight_manager.clear_communication_group.assert_called_once() From b8e8a6253f8ece29b41e1cf2426796ddf7713a60 Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:02:10 +0800 Subject: [PATCH 036/143] PD deployment support without router (#7412) (#7424) --- examples/splitwise/start_v0_tp1.sh | 113 ----- fastdeploy/config.py | 13 +- fastdeploy/engine/args_utils.py | 11 +- fastdeploy/engine/expert_service.py | 4 +- ...test_ernie_03b_pd_wo_router_v1_rdma_tp1.py | 455 ++++++++++++++++++ tests/model_executor/test_thinking_budget.py | 2 +- tests/utils/test_config.py | 2 +- 7 files changed, 473 insertions(+), 127 deletions(-) delete mode 100644 examples/splitwise/start_v0_tp1.sh create mode 100644 tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py diff --git a/examples/splitwise/start_v0_tp1.sh b/examples/splitwise/start_v0_tp1.sh deleted file mode 100644 index 40c20301138..00000000000 --- a/examples/splitwise/start_v0_tp1.sh +++ /dev/null @@ -1,113 +0,0 @@ -#!/bin/bash -set -e - -# Test splitwise deployment -# There are two methods for splitwise deployment: -# v0: using splitwise_scheduler or dp_scheduler (deprecated) -# v1: using local_scheduler + router - -# prepare environment -export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" -export FD_DEBUG=1 -export ENABLE_V1_KVCACHE_SCHEDULER=1 -export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 - -SCRIPT_PATH=$(readlink -f "$0") -SCRIPT_DIR=$(dirname "$SCRIPT_PATH") -export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) -echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" -if [ -z "${KVCACHE_RDMA_NICS}" ]; then - echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" - exit 1 -fi - -unset http_proxy && unset https_proxy -source ${SCRIPT_DIR}/utils.sh - -P_PORT=52400 -D_PORT=52500 -REDIS_PORT="${REDIS_PORT:-6379}" -LOG_DATE=$(date +%Y%m%d_%H%M%S) - -ports=( - $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5)) - $D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5)) - $REDIS_PORT -) -check_ports "${ports[@]}" || { - echo "❌ Some ports are in use. Please release them." - exit 1 -} - -# start redis -if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then - echo "Redis is not running. Starting redis-server..." - redis-server --daemonize yes --port ${REDIS_PORT} - sleep 1 -else - echo "Redis is already running." -fi -sleep 1 - -# start prefill -export CUDA_VISIBLE_DEVICES=0 -export FD_LOG_DIR="log/$LOG_DATE/prefill" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${P_PORT} \ - --metrics-port $((P_PORT + 1)) \ - --engine-worker-queue-port $((P_PORT + 2)) \ - --cache-queue-port $((P_PORT + 3)) \ - --max-model-len 32768 \ - --num-gpu-blocks-override 1000 \ - --splitwise-role "prefill" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((P_PORT + 4)) \ - --pd-comm-port $((P_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${P_PORT} - -# start decode -export CUDA_VISIBLE_DEVICES=1 -export FD_LOG_DIR="log/$LOG_DATE/decode" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${D_PORT} \ - --metrics-port $((D_PORT + 1)) \ - --engine-worker-queue-port $((D_PORT + 2)) \ - --cache-queue-port $((D_PORT + 3)) \ - --max-model-len 32768 \ - --splitwise-role "decode" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((D_PORT + 4)) \ - --pd-comm-port $((D_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${D_PORT} - - -# send request -sleep 10 # make sure server is registered to router -echo "send request..." -curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \ --H "Content-Type: application/json" \ --d '{ - "messages": [ - {"role": "user", "content": "hello"} - ], - "max_tokens": 20, - "stream": false -}' diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6e7001bc18c..1b37db9611b 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2009,13 +2009,13 @@ def __init__( and self.router_config and self.router_config.router ): - # For RL scenario: version.yaml will be required for models in future releases. + # For RL scenario, version.yaml is required for models # Temporarily enforce use router to be enabled. self.model_config.read_model_version() self.read_from_config() self.postprocess() - self.init_cache_info() + self.init_pd_info() if test_mode: return self.check() @@ -2348,18 +2348,17 @@ def print(self): logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=============================================================") - def init_cache_info(self): + def init_pd_info(self): """ - initialize cache info + initialize info for pd deployment """ - # TODO: group the splitiwse params # There are two methods for splitwise deployment: # 1. v0 splitwise_scheduler or dp_scheduler - # 2. v1 local_scheduler + router + # 2. v1 local_scheduler + router (optional) self.splitwise_version = None if self.scheduler_config.name in ("splitwise", "dp"): self.splitwise_version = "v0" - elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router: + elif self.scheduler_config.name == "local": self.splitwise_version = "v1" # the information for registering this server to router or splitwise_scheduler diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 848926f963c..d350350f85d 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -592,10 +592,15 @@ def __post_init__(self): raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1") if self.splitwise_role != "mixed": - if self.scheduler_name == "local" and self.router is None: + if self.scheduler_name == "splitwise": raise ValueError( - f"When using {self.splitwise_role} role and the {self.scheduler_name} " - f"scheduler, please provide --router argument." + "Setting scheduler_name as splitwise is not supported in pd deployment, " + "please use router as scheduler." + ) + if self.scheduler_name == "local" and self.router is None: + console_logger.warning( + f"Running {self.splitwise_role} role with {self.scheduler_name} " + f"scheduler without --router. Router registration and request routing will be disabled." ) if not ( diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 81fe93e52a4..5958b3d9bd3 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -109,7 +109,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): if envs.FD_ENABLE_RETURN_TEXT: self.engine.create_data_processor() if self.cfg.scheduler_config.name == "dp": - self.cfg.init_cache_info() + self.cfg.init_pd_info() self.engine.scheduler.start(local_data_parallel_id) if ipc_signal_suffix is not None: @@ -122,7 +122,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): self.llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.scheduler_config.name == "splitwise": - self.cfg.init_cache_info() + self.cfg.init_pd_info() role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip self.engine.scheduler.start(role, host_ip, self.cfg.register_info) diff --git a/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py new file mode 100644 index 00000000000..efe702240e1 --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py @@ -0,0 +1,455 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test splitwise deployment WITHOUT Router: +# use local_scheduler, manually construct disaggregate_info, +# send requests to both Prefill and Decode concurrently. +# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time +import uuid + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + check_service_health, + clean, +) + +# Ports for PD disaggregation (no router port needed) +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623)) + +# Prefill uses base ports, Decode uses base+1 +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_RDMA_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_RDMA_PORT + 1, +] + + +def _build_disaggregate_info() -> dict: + """Build disaggregate_info manually, replicating Router's handle_splitwise_request logic.""" + host_ip = os.getenv("FD_HOST_IP", "127.0.0.1") + return { + "prefill_ip": host_ip, + "decode_ip": host_ip, + "prefill_connector_port": FD_CONNECTOR_PORT, + "decode_connector_port": FD_CONNECTOR_PORT + 1, + "decode_device_ids": ["1"], + "decode_rdma_ports": [FD_RDMA_PORT + 1], + "transfer_protocol": "rdma", + "decode_tp_size": 1, + } + + +def _send_pd_request(payload: dict, timeout: int = 120): + """ + Send request to both Prefill and Decode concurrently, + replicate Router's fan-out forwarding behavior. + Returns the Decode response (same as Router's return_result_url_index=-1). + """ + disaggregate_info = _build_disaggregate_info() + + # Inject disaggregate_info and request_id (same as Router) + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + headers = {"Content-Type": "application/json"} + + # For streaming, use requests with stream=True for decode response + if payload.get("stream", False): + # Send to both concurrently (same as Router's fan-out), stream from decode + import concurrent.futures + + def _post_stream(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post_stream, prefill_url) + decode_future = executor.submit(_post_stream, decode_url) + # Return decode streaming response immediately + decode_resp = decode_future.result() + # Consume prefill response in background (don't block) + try: + prefill_future.result(timeout=timeout) + except Exception: + pass + return decode_resp + else: + # Non-streaming: send to both, return decode response + import concurrent.futures + + def _post(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post, prefill_url) + decode_future = executor.submit(_post, decode_url) + # Wait for both, return decode response + decode_resp = decode_future.result() + # Also check prefill didn't error (but don't block on it) + try: + prefill_future.result(timeout=5) + except Exception: + pass + return decode_resp + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts Prefill and Decode instances WITHOUT Router + - Waits for both to be healthy + - Tears down after all tests finish + """ + print("Pre-test port cleanup...") + clean(PORTS_TO_CLEAN) + + print("log dir clean") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + base_log_dir = os.getenv("FD_LOG_DIR", "log") + + # Prefill instance + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill") + + prefill_log_path = "prefill.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT), + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + # No --router flag + ] + + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_prefill, + ) + time.sleep(1) + + # Decode instance + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode") + + decode_log_path = "decode.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT + 1), + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + # No --router flag + ] + + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_decode, + ) + + # Wait up to 300 seconds for both instances to be healthy + for _ in range(60): + prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}") + decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}") + if prefill_healthy and decode_healthy: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError("Prefill or decode server did not start") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the Decode API endpoint URL (where final responses come from). + """ + return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + +@pytest.fixture +def headers(): + return {"Content-Type": "application/json"} + + +def get_stream_chunks(response): + """Parse streaming response into chunk list.""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"Parse failed: {e}, line: {line}") + else: + print(f"Request failed, status: {response.status_code}") + print("Response:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """Test streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """Test non-streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """Test streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit( + requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True + ) + response = decode_future.result() + + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """Test non-streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120) + response = decode_future.result().json() + + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py index 139b6859951..4cc5a1563bd 100644 --- a/tests/model_executor/test_thinking_budget.py +++ b/tests/model_executor/test_thinking_budget.py @@ -111,7 +111,7 @@ def setUp(self): self._fdconfig_patches = [ patch.object(FDConfig, "read_from_config", return_value=None), patch.object(FDConfig, "postprocess", return_value=None), - patch.object(FDConfig, "init_cache_info", return_value=None), + patch.object(FDConfig, "init_pd_info", return_value=None), patch.object(FDConfig, "check", return_value=None), ] for patcher in self._fdconfig_patches: diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 240cf702ed7..4f55ca46472 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -138,7 +138,7 @@ def test_fdconfig_init_cache(self): model_config=model_config, test_mode=True, ) - fd_config.init_cache_info() + fd_config.init_pd_info() assert fd_config.register_info is not None def test_fdconfig_postprocess_ports(self): From 72ce56b10b366f287a892177e9bc155f3e9eecfb Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:15:03 +0800 Subject: [PATCH 037/143] [BugFix] fix tool call parser (#7369) (#7419) * fix tool call parser * add unit test * fix unit test * add unit test Co-authored-by: luukunn <981429396@qq.com> --- .../tool_parsers/ernie_x1_tool_parser.py | 33 +- .../tool_parsers/test_ernie_x1_tool_parser.py | 773 ++++++++++-------- 2 files changed, 461 insertions(+), 345 deletions(-) diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index f4556a3679f..7435dbce490 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -111,7 +111,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) ) ) return ExtractedToolCallInformation( - tools_called=True, + tools_called=len(tool_calls) > 0, tool_calls=tool_calls, ) except Exception: @@ -182,11 +182,13 @@ def extract_tool_calls_streaming( logger.debug("attempting to close tool call, but no tool call") return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") - if diff: - if '"}' not in delta_text: + if diff is not None: + if "}" not in delta_text: + return None + end_loc = delta_text.rindex("}") + diff = delta_text[:end_loc] + if not diff: return None - end_loc = delta_text.rindex('"}') - diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff, @@ -248,15 +250,15 @@ def extract_tool_calls_streaming( prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") - if not cur_arguments and not prev_arguments: + if cur_arguments is None and prev_arguments is None: logger.debug("Skipping text %s - no arguments", delta_text) delta = None - elif not cur_arguments and prev_arguments: + elif cur_arguments is None and prev_arguments is not None: logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.") delta = None - elif cur_arguments and not prev_arguments: + elif cur_arguments is not None and prev_arguments is None: function_name = current_tool_call.get("name") match = re.search( r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', @@ -265,6 +267,19 @@ def extract_tool_calls_streaming( ) if match: cur_arguments_json = match.group(1) + # When tool_call_portion is complete JSON, the regex + # (.*) over-captures the outer closing brace of the + # tool call object. Strip it from both + # cur_arguments_json and delta_text, consistent with + # the both-have-arguments branch handling. + try: + json.loads(tool_call_portion) + if cur_arguments_json.endswith("}"): + cur_arguments_json = cur_arguments_json[:-1] + if delta_text.rstrip().endswith("}"): + delta_text = delta_text.rstrip()[:-1] + except Exception: + pass else: cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) @@ -287,7 +302,7 @@ def extract_tool_calls_streaming( ) self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - elif cur_arguments and prev_arguments: + elif cur_arguments is not None and prev_arguments is not None: try: json.loads(tool_call_portion) is_complete_json = True diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py index 01a68c2380d..0dbda0c35eb 100644 --- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -60,6 +60,50 @@ def get_vocab(self): return ErnieX1ToolParser(tokenizer=DummyTokenizer()) + def _simulate_streaming(self, parser, deltas): + """Simulate a multi-step streaming flow. + + Args: + parser: ErnieX1ToolParser instance + deltas: list of delta text strings, each representing one streaming step + + Returns: + list of results from each extract_tool_calls_streaming call + """ + results = [] + previous_text = "" + token_id = 0 + previous_token_ids = [] + + for delta in deltas: + current_text = previous_text + delta + # When delta contains plus more content, use 2 tokens + # so that the parser extracts tool_call_portion (line 163-164) + if "" in delta and delta != "": + n_tokens = 2 + else: + n_tokens = 1 + + delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens)) + token_id += n_tokens + current_token_ids = previous_token_ids + delta_token_ids + + result = parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta, + previous_token_ids, + current_token_ids, + delta_token_ids, + self.dummy_request, + ) + results.append(result) + + previous_text = current_text + previous_token_ids = list(current_token_ids) + + return results + # ==================== __init__ tests (lines 60-81) ==================== def test_init_sets_tokens_and_ids(self): @@ -116,6 +160,14 @@ def test_extract_tool_calls_no_arguments(self): self.assertTrue(result.tools_called) self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_empty_arguments(self): + """Cover: tool call with explicit empty arguments {}""" + output = '{"name": "fn", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "fn") + self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_nested_arguments(self): """Cover regex with nested braces in arguments""" output = '{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}' @@ -182,38 +234,24 @@ def test_streaming_balanced_counts_text_after_tool(self): def test_streaming_end_token_in_delta(self): """Cover lines 149-156: appears in delta""" parser = self._new_parser() - # First, start a tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Now stream arguments - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - ', "arguments": {"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Close with end token in delta - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 2], - [2], - self.dummy_request, - ) - # Should handle end token - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "', # start + name + args key + "v", # args value + '"}}', # close with end token in delta + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v') + # Step 3: end token in delta triggers close handling + # delta before is '"}}', close branch: rindex('}')=2, diff='"}' + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, '"}') # --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) --- @@ -255,37 +293,29 @@ def test_streaming_new_tool_call_multi_tokens(self): def test_streaming_continue_tool_call_no_name_yet(self): """Cover lines 174-176, 220-222: partial JSON without name yet""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Continue with partial content, no name parseable yet - result = parser.extract_tool_calls_streaming( - "", - '{"na', - '{"na', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"na', # partial content, no name yet + ], ) - self.assertIsNone(result) + self.assertIsNone(results[0]) + self.assertIsNone(results[1]) def test_streaming_continue_tool_call_with_name(self): """Cover lines 174-176, 223-235: name becomes available""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Name appears - result = parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [1], - [1, 10], - [10], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"name": "get_weather"', # name appears + ], + ) + self.assertIsNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.name, "get_weather") self.assertTrue(parser.current_tool_name_sent) # --- Lines 236-237: name not sent and function_name is None --- @@ -293,18 +323,14 @@ def test_streaming_continue_tool_call_with_name(self): def test_streaming_no_function_name(self): """Cover lines 236-237: parsed JSON has no 'name' field""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Send JSON without name field - result = parser.extract_tool_calls_streaming( - "", - '{"arguments": {"k": "v"}}', - '{"arguments": {"k": "v"}}', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"arguments": {"k": "v"}}', # JSON without name field + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) # --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) --- @@ -333,9 +359,9 @@ def test_streaming_close_with_remaining_diff(self): parser.streamed_args_for_tool = [""] parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] result = parser.extract_tool_calls_streaming( - '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"', '{"name":"fn","arguments":{"k":"v"}}', - '"}}', + "}}", [1, 10], [1, 10, 2], [2], @@ -343,9 +369,14 @@ def test_streaming_close_with_remaining_diff(self): ) self.assertIsNotNone(result) self.assertIsNotNone(result.tool_calls) + self.assertEqual(result.tool_calls[0].function.arguments, "}") - def test_streaming_close_with_diff_no_end_marker(self): - """Cover lines 184-185: close with arguments but no '"}' in delta_text""" + def test_streaming_text_after_completed_tool_call(self): + """Cover lines 143-147: text content after a completed tool call. + + When start==end counts, prev_end==cur_end, and end_token not in delta, + the parser treats delta as regular text content. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True @@ -353,7 +384,7 @@ def test_streaming_close_with_diff_no_end_marker(self): parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] # Simulate end token in delta but without '"}' pattern # We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta - # so that we enter the elif at 178 + # so that we enter the text-content branch at line 143-147 result = parser.extract_tool_calls_streaming( '{"name":"fn","arguments":{"k":"v"}}', '{"name":"fn","arguments":{"k":"v"}} text', @@ -363,8 +394,9 @@ def test_streaming_close_with_diff_no_end_marker(self): [30], self.dummy_request, ) - # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147) - self.assertIsInstance(result, DeltaMessage) + # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149) + self.assertIsNotNone(result) + self.assertEqual(result.content, " text") def test_streaming_close_no_arguments(self): """Cover lines 182-183: close branch where prev arguments is None/empty""" @@ -382,8 +414,126 @@ def test_streaming_close_no_arguments(self): [2], self.dummy_request, ) - # diff is None (no arguments), so falls through to partial_json_parser - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + # diff is None (no arguments key in prev), falls through to partial_json_parser + # parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None + self.assertIsNone(result) + + def test_streaming_close_with_empty_dict_arguments(self): + """Regression: close branch must handle arguments={} (empty dict). + + Before fix, `if diff:` was False for empty dict {}, so the close + logic was skipped. After fix, `if diff is not None:` correctly + enters the branch. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + "", # end token + ], + ) + # Step 1: name sent + # Step 2: first-args, cur_args={} is not None, prev_args=None + # Without fix: not {} == True -> no-args branch -> returns None + # With fix: enters first-args -> streams "{}" -> DeltaMessage + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_arguments_with_outer_brace_in_same_token(self): + """Regression: when arguments={} and outer } arrive in the same token '{}}', + regex (.*) over-captures the outer brace, producing '{}}'. + + Real production data showed arguments='{}}}' for get_default_weather + with empty arguments. This test reproduces that exact scenario. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "get_default_weather", "arguments": ', # start + name + args key + "{}}", # empty args + outer close brace in same token + "", # end token + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather") + # Step 2: first-args branch, tool_call_portion is complete JSON + # regex (.*) captures '{}}' but fix strips outer '}' -> '{}' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 3: end token, close branch + # diff = prev_arguments = {} (not None), delta_text = '' (empty after split) + # '}' not in '' -> returns None + self.assertIsNone(results[2]) + + def test_streaming_close_with_number_ending_arguments(self): + """Regression: close branch must flush remaining args ending with number. + + Before fix, '"}' not in delta was True for numbers, causing return None. + After fix, rindex('}') correctly finds the closing brace. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"count": ', # start + name + args key + "123", # number value + "}}", # close braces + end token + ], + ) + # Step 1: name sent + # Step 2: first-args, streams {"count": 123 + # Step 3: close branch flushes remaining "}" + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"count": 123}') + + def test_streaming_close_with_boolean_ending_arguments(self): + """Regression: close branch must flush remaining args ending with boolean.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"flag": ', # start + args key + "true", # boolean value + "}}", # close + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"flag": true}') + + def test_streaming_close_with_nested_object_ending(self): + """Regression: close branch must flush remaining args ending with nested '}'.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"nested": {"a": ', # start + args key + "1", # nested value + "}}}", # close all + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"nested": {"a": 1}}') # --- Lines 202-206: else branch (cur_start < cur_end, edge case) --- @@ -404,23 +554,21 @@ def test_streaming_else_branch(self): def test_streaming_malformed_json(self): """Cover lines 213-215: MalformedJSON from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Feed badly formed content - result = parser.extract_tool_calls_streaming( - "", - "{{{", - "{{{", - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + "{{{", # badly formed content + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) def test_streaming_json_decode_error(self): """Cover lines 216-218: JSONDecodeError from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Step 1: start tool call normally + self._simulate_streaming(parser, [""]) + # Step 2: mock partial_json_parser to throw ValueError with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", side_effect=ValueError("bad json"), @@ -430,8 +578,8 @@ def test_streaming_json_decode_error(self): "bad", "bad", [1], - [1, 10], - [10], + [1, 2], + [2], self.dummy_request, ) self.assertIsNone(result) @@ -469,30 +617,17 @@ def test_streaming_tool_portion_none_with_text(self): def test_streaming_first_arguments_with_regex_match(self): """Cover lines 243-244, 257-286: first arguments appear, regex matches""" parser = self._new_parser() - # Start tool call and send name - parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Now stream arguments (first time) - # Key must be complete (closing quote) so partial_json_parser returns truthy arguments. - # delta must be a substring of the regex-extracted arguments portion (after "arguments":). - result = parser.extract_tool_calls_streaming( - '{"name": "get_weather"', - '{"name": "get_weather", "arguments": {"location": "bei', - '"bei', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertIsNotNone(result.tool_calls) + results = self._simulate_streaming( + parser, + [ + '{"name": "get_weather", "arguments": {"location": "', # start + name + args key + "bei", # args value + ], + ) + # Step 1: name sent + # Step 2: first-args, regex finds "bei" in '{"location": "bei' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei') def test_streaming_first_arguments_no_regex_match(self): """Cover lines 266-267: regex doesn't match, fallback to json.dumps""" @@ -522,67 +657,119 @@ def test_streaming_first_arguments_no_regex_match(self): self.assertIsNotNone(result.tool_calls) def test_streaming_first_arguments_delta_not_in_json(self): - """Cover lines 271-272: delta_text not found in cur_arguments_json""" + """Cover lines 275-276: delta_text not found in cur_arguments_json, returns None. + When delta contains the arguments key itself (e.g. ', "arguments": {'), + regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'. + """ parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Delta text that doesn't appear in the arguments JSON - result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}}', - "ZZZZZ", - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNone(result) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + partial name + ', "arguments": {', # delta introduces arguments key + open brace + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts cur_arguments_json='{' + # delta_text=', "arguments": {' is NOT in '{' -> returns None + self.assertIsNone(results[1]) # --- Lines 249-251: no cur_arguments and no prev_arguments --- def test_streaming_no_arguments_at_all(self): """Cover lines 249-251: both cur and prev arguments are empty/None""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + name + "}", # close JSON, no arguments + ], ) - # Continue with name only, no arguments + # prev_arguments=None, cur_arguments=None -> delta=None + self.assertIsNone(results[1]) + + def test_streaming_empty_dict_arguments_not_skipped(self): + """Regression: arguments={} (empty dict) must not be treated as no arguments. + + Empty dict is falsy in Python (`not {} == True`). Before the fix, + this caused empty arguments to enter the no-arguments branch, + silently dropping them during streaming. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + ], + ) + # Step 1: name sent + # Step 2: cur_arguments={} (not None), prev_arguments=None + # With fix: enters first-arguments branch -> streams "{}" + # Without fix: not {} == True -> no-arguments branch -> delta=None + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_dict_prev_arguments_not_reset(self): + """Regression: prev_arguments={} must not be treated as no arguments. + + When prev has {} and cur has a non-empty dict, the code should enter + the both-have-arguments branch, not the first-arguments branch. + + This scenario (arguments growing from {} to non-empty) is hard to + produce naturally, so we build up state through a real flow then + verify the branch behavior with one additional call. + """ + parser = self._new_parser() + # Build up state naturally: prev_tool_call_arr gets arguments={} + self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # name + args key + "{}", # empty dict value + "}", # outer close + ], + ) + # Verify state is correct + self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {}) + + # Now test: if more argument data arrives, prev_args={} should be + # treated as "not None" -> enters both-have-arguments branch + # Without fix: not {} == True -> first-arguments branch (wrong) result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn"}', - "}", - [1, 10], - [1, 10, 20], - [20], + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "val', + "al", + [1, 2, 3], + [1, 2, 3, 4], + [4], self.dummy_request, ) - # prev_arguments=None, cur_arguments=None -> delta=None - # then prev_tool_call_arr updated and returns delta (which is None) - self.assertIsNone(result) + # both-have-arguments branch: delta_text="al" streamed as arguments + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.arguments, "al") # --- Lines 253-255: cur_arguments reset (impossible branch) --- def test_streaming_arguments_reset_mid_call(self): - """Cover lines 253-255: prev has arguments but cur doesn't (impossible case)""" + """Cover lines 253-255: prev has arguments but cur doesn't (impossible case). + + This is an edge case that shouldn't happen in normal flow, but tests + defensive handling when partial parser returns no arguments after + previously having them. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True parser.streamed_args_for_tool = [""] + # Simulate state where prev already had arguments parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] - # Feed content where cur has no arguments but prev does + # Mock parser to return no arguments (simulating the impossible reset) with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", return_value={"name": "fn"}, @@ -591,9 +778,9 @@ def test_streaming_arguments_reset_mid_call(self): '{"name": "fn", "arguments": {"k": "v"', '{"name": "fn", "arguments": {"k": "v"}', '"}', - [1, 10], - [1, 10, 20], - [20], + [1, 2], + [1, 2, 3], + [3], self.dummy_request, ) self.assertIsNone(result) @@ -603,110 +790,48 @@ def test_streaming_arguments_reset_mid_call(self): def test_streaming_incremental_arguments_incomplete(self): """Cover lines 288-314: both prev and cur have arguments, JSON incomplete""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # More argument tokens (both prev and cur have arguments now) - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "val', - "al", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.arguments, "al") + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + "l", # incremental: both-have-args + ], + ) + # Step 1: name sent + # Step 2: first-args branch + # Step 3: both-have-args branch, streams "l" + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, "l") def test_streaming_incremental_arguments_complete_json(self): """Cover lines 289-305: complete JSON with trailing }""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Complete with closing braces - both prev and cur have arguments - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta ends with }, should strip trailing } - # After strip: '"' which is not empty, so returns DeltaMessage - self.assertIsNotNone(result) - self.assertIsInstance(result, DeltaMessage) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + '"}}', # completes JSON + ], + ) + # Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}' + self.assertIsNotNone(results[2]) + self.assertIsInstance(results[2], DeltaMessage) def test_streaming_incremental_arguments_complete_empty_delta(self): """Cover lines 304-305: complete JSON where delta becomes empty after strip""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v"', # start + name + first args + "}", # inner close (establishes prev_args) + "}", # outer close: both-have-args, complete, delta stripped to "" + ], ) - # First arguments with proper delta - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}', - '{"k": "v"}', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Send just the outer closing brace - # tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v"}', - '{"name": "fn", "arguments": {"k": "v"}}', - "}", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta="}" -> stripped to "" -> return None - self.assertIsNone(result) + # Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None + self.assertIsNone(results[2]) # --- Lines 316-319: prev_tool_call_arr update branches --- @@ -759,95 +884,71 @@ def test_streaming_general_exception(self): def test_streaming_full_flow(self): """Integration test: simulate a full streaming tool call flow""" parser = self._new_parser() - req = self.dummy_request - - # Step 1: text before tool call - r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req) - self.assertEqual(r.content, "thinking") - - # Step 2: tool_call start token - r = parser.extract_tool_calls_streaming("thinking", "thinking", "", [], [1], [1], req) - self.assertIsNone(r) + results = self._simulate_streaming( + parser, + [ + "thinking", # Step 1: text before tool call + "", # Step 2: tool_call start token + '{"name": "search", "arguments": {"query": "', # Step 3: name + args key + "test", # Step 4: args value + " data", # Step 5: more args + ], + ) + # Step 1: plain text + self.assertEqual(results[0].content, "thinking") + # Step 2: start token -> None + self.assertIsNone(results[1]) + # Step 3: name sent + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "search") + # Step 4: first arguments + self.assertIsNotNone(results[3]) + self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test') + # Step 5: more arguments + self.assertIsNotNone(results[4]) + self.assertEqual(results[4].tool_calls[0].function.arguments, " data") - # Step 3: function name appears - r = parser.extract_tool_calls_streaming( - "thinking", - 'thinking{"name": "search"', - '{"name": "search"', - [1], - [1, 10], - [10], - req, - ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "search") - - # Step 4: arguments start - delta must appear in regex-extracted arguments portion - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search"', - 'thinking{"name": "search", "arguments": {"query": "test', - '{"query": "test', - [1, 10], - [1, 10, 20], - [20], - req, - ) - self.assertIsNotNone(r) + def test_streaming_empty_arguments_full_flow(self): + """Integration: streaming tool call with arguments={} must not lose arguments. - # Step 5: more arguments - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search", "arguments": {"query": "test', - 'thinking{"name": "search", "arguments": {"query": "test data', - " data", - [1, 10, 20], - [1, 10, 20, 30], - [30], - req, - ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.arguments, " data") + Simulates a complete streaming flow where the tool call has empty + arguments. Verifies the name is sent and arguments are streamed. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # Step 1: start + name + args key + "{}", # Step 2: empty dict value + "}", # Step 3: outer close + "", # Step 4: end token + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args with cur_args={}, streams "{}" + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 4: close branch, delta_text="" after stripping + # diff={} is not None, but "}" not in "" -> return None + self.assertIsNone(results[2]) + self.assertIsNone(results[3]) def test_streaming_multiple_tool_calls(self): """Integration test: two tool calls in one response""" parser = self._new_parser() - req = self.dummy_request - - # First tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn1"', - '{"name": "fn1"', - [], - [1, 10], - [1, 10], - req, - ) - self.assertEqual(parser.current_tool_id, 0) - - # Close first tool - parser.extract_tool_calls_streaming( - '{"name": "fn1"', - '{"name": "fn1"}', - "}", - [1, 10], - [1, 10, 2], - [2], - req, - ) - - # Second tool call - r = parser.extract_tool_calls_streaming( - '{"name": "fn1"}', - '{"name": "fn1"}{"name": "fn2"', - '{"name": "fn2"', - [1, 10, 2], - [1, 10, 2, 1, 20], - [1, 20], - req, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn1"', # First tool: start + name + "}", # Close first tool + '{"name": "fn2"', # Second tool: start + name + ], ) self.assertEqual(parser.current_tool_id, 1) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "fn2") + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "fn2") if __name__ == "__main__": From 185708b56655234cadd24e73a201e67d6eb2b306 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Fri, 17 Apr 2026 16:17:59 +0800 Subject: [PATCH 038/143] [Cherry-Pick][BugFix] Fix real token exceeding max_batched_tokens limit(#7438) (#7439) * fix max_num_batched_tokens error compute * add temperatory solution * fix bug --- fastdeploy/engine/sched/resource_manager_v1.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index ae0e0c798b3..ffc9c0bacf4 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -768,7 +768,17 @@ def get_enough_request(request, scheduled_reqs): scheduled_reqs: list[Request] = [] preempted_reqs: list[Request] = [] error_reqs: list[tuple[str, str]] = [] - token_budget = self.config.scheduler_config.max_num_batched_tokens + tokens_per_seq = ( + (self.config.speculative_config.num_speculative_tokens + 1) + if self.config.speculative_config is not None + else 1 + ) + token_budget = ( + self.config.scheduler_config.max_num_batched_tokens + - self.config.scheduler_config.max_num_seqs * tokens_per_seq + ) + # temperatory solution to avoid negative token_budget + token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512)) need_abort_requests = [] # users trigger abortion # First, schedule the RUNNING requests. From 650d1e49aaf2f6fedf7354655e11480fe7aeba6d Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:37:42 +0800 Subject: [PATCH 039/143] [Cherry-Pick][Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442) (#7464) * support mtp logprob in pd * fix * fix * fix * fix xpu bugs --- .../mtp_save_first_token_with_topk.cc | 218 ++++++++++++++++++ .../speculate_get_output_with_topk.cc | 45 ++-- .../speculate_logprob_msg.h | 39 ++++ .../speculate_save_output_with_topk.cc | 44 ++-- .../model_executor/pre_and_post_process.py | 75 +++++- fastdeploy/spec_decode/mtp.py | 38 +-- fastdeploy/worker/gpu_model_runner.py | 8 +- 7 files changed, 389 insertions(+), 78 deletions(-) create mode 100644 custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc new file mode 100644 index 00000000000..02203a51cff --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" +#include "../../custom_ftok.h" +#include "../speculate_logprob_msg.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& logprob_token_ids, + const paddle::Tensor& logprob_scores, + const paddle::Tensor& logprob_ranks, + const paddle::Tensor& token_num_per_batch, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& preempted_idx, + int message_flag, // Target: 3, Draft: 4 + int64_t rank_id, + bool save_each_rank) { + if (!save_each_rank && rank_id > 0) { + return; + } + + int max_draft_tokens = sampled_token_ids.shape()[1]; + int bsz = token_num_per_batch.shape()[0]; + + auto sampled_token_ids_cpu = + sampled_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_token_ids_cpu = + logprob_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false); + auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false); + auto token_num_per_batch_cpu = + token_num_per_batch.copy_to(paddle::CPUPlace(), false); + auto cu_batch_token_offset_cpu = + cu_batch_token_offset.copy_to(paddle::CPUPlace(), false); + auto seq_lens_decoder_cpu = + seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true); + int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data(); + int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data(); + float* logprob_scores_data = logprob_scores_cpu.data(); + int64_t* logprob_ranks_data = logprob_ranks_cpu.data(); + int* token_num_per_batch_data = token_num_per_batch_cpu.data(); + int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data(); + int* seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + int64_t* prompt_lens_data = prompt_lens_cpu.data(); + const int32_t* preempted_idx_data = preempted_idx.data(); + + static struct msgdata msg_sed; + int msg_queue_id = 1; + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 is perserve for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." + << std::endl; +#endif + } + static key_t key = custom_ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output_key: " << key << std::endl; + std::cout << "save msgid: " << msgid << std::endl; +#endif + msg_sed.mtype = 1; + msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env + : -inference_msg_id_from_env; + msg_sed.meta[1] = message_flag; + msg_sed.meta[2] = bsz; + int max_num_logprobs = logprob_token_ids.shape()[1]; + for (int i = 0; i < bsz; i++) { + int cur_token_num; + if (seq_lens_decoder_data[i] < prompt_lens_data[i] || + token_num_per_batch_data[i] == 0) { + // chunk prefill or stop slots + cur_token_num = 0; + } else { + cur_token_num = token_num_per_batch_data[i] + 1; + } + msg_sed.meta[3 + i] = cur_token_num; + if (preempted_idx_data[i] == 1) { + msg_sed.meta[3 + i] = -9; + } + + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + int token_offset = cu_batch_token_offset_data[i]; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + if (j == 0) { + // first token has full logprobs + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + if (k == 0) { + cur_tokens[k] = + (int)sampled_token_ids_data[i * max_draft_tokens + j]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + + k]; + } else if (k < max_num_logprobs) { + // only for first token + cur_tokens[k] = + (int)logprob_token_ids_data[(token_offset + j) * + (SPEC_LOGPROB_K + 1) + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + + k]; + } else { + cur_tokens[k] = -1; + cur_scores[k] = 0.0; + } + } + cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; + } else { + // draft token only has token_id + cur_tokens[0] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; + } + } + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << msg_sed.meta[0] + << ", message_flag: " << msg_sed.meta[1] + << ", bsz: " << msg_sed.meta[2] << std::endl; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_sed.meta[3 + i]; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + std::cout << "tokens: "; + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << cur_tokens[k] << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << cur_scores[k] << " "; + } + std::cout << std::endl; + std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl; + } + } + std::cout << std::endl; +#endif + if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) { + printf("full msg buffer\n"); + } +} + +PD_BUILD_STATIC_OP(mtp_save_first_token_with_topk) + .Inputs({"sampled_token_ids", + "logprob_token_ids", + "logprob_scores", + "logprob_ranks", + "token_num_per_batch", + "cu_batch_token_offset", + "not_need_stop", + "seq_lens_decoder", + "prompt_lens", + "preempted_idx"}) + .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"}) + .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenWithTopK)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 76ff5e190d8..4fd7d4103c4 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -19,27 +19,12 @@ #include #include "paddle/extension.h" #include "../custom_ftok.h" +#include "speculate_logprob_msg.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -#define MAX_BSZ 512 -#define K 20 -#define MAX_DRAFT_TOKEN_NUM 6 - -struct batch_msgdata { - int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - int ranks[MAX_DRAFT_TOKEN_NUM]; -}; - -struct msgdata { - long mtype; - int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums - batch_msgdata mtext[MAX_BSZ]; -}; - void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, const paddle::Tensor& output_scores, const paddle::Tensor& output_ranks, @@ -93,22 +78,22 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; - int output_tokens_offset = 3 + MAX_BSZ; + int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_rcv.meta[3 + i]; output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums auto* cur_output_token = output_tokens_data + output_tokens_offset + - i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_output_score = - output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; for (int j = 0; j < cur_token_num; j++) { for (int k = 0; k < real_k + 1; k++) { - cur_output_token[j * (K + 1) + k] = - (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k]; - cur_output_score[j * (K + 1) + k] = - cur_batch_msg_rcv->scores[j * (K + 1) + k]; + cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = + (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; + cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = + cur_batch_msg_rcv->scores[j * (SPEC_LOGPROB_K + 1) + k]; } output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = (int64_t)cur_batch_msg_rcv->ranks[j]; @@ -124,17 +109,19 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; for (int j = 0; j < cur_token_num; j++) { std::cout << "tokens: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << output_tokens_data[output_tokens_offset + - i * MAX_DRAFT_TOKEN_NUM * (K + 1) + - j * (K + 1) + k] + i * MAX_DRAFT_TOKEN_NUM * + (SPEC_LOGPROB_K + 1) + + j * (SPEC_LOGPROB_K + 1) + k] << " "; } std::cout << std::endl; std::cout << "scores: "; - for (int k = 0; k < K + 1; k++) { - std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) + - j * (K + 1) + k] + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * + (SPEC_LOGPROB_K + 1) + + j * (SPEC_LOGPROB_K + 1) + k] << " "; } std::cout << std::endl; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h new file mode 100644 index 00000000000..dc2c6f399f4 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h @@ -0,0 +1,39 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#define SPEC_LOGPROB_MAX_BSZ 512 +#define SPEC_LOGPROB_K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + // stop_flag, message_flag, bsz, batch_token_nums + int meta[3 + SPEC_LOGPROB_MAX_BSZ]; + batch_msgdata mtext[SPEC_LOGPROB_MAX_BSZ]; +}; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 3d75886bd25..0b3de384cee 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -19,27 +19,12 @@ #include #include "paddle/extension.h" #include "../custom_ftok.h" +#include "speculate_logprob_msg.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -#define MAX_BSZ 512 -#define K 20 -#define MAX_DRAFT_TOKEN_NUM 6 - -struct batch_msgdata { - int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - int ranks[MAX_DRAFT_TOKEN_NUM]; -}; - -struct msgdata { - long mtype; - int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums - batch_msgdata mtext[MAX_BSZ]; -}; - void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, const paddle::Tensor& logprob_token_ids, const paddle::Tensor& logprob_scores, @@ -154,16 +139,21 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { - auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; - auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; - for (int k = 0; k < K + 1; k++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; - cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + + k]; } else if (k < max_num_logprobs) { - cur_tokens[k] = - (int)logprob_token_ids_data[(token_offset + j) * (K + 1) + k]; - cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k]; + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + + k]; } else { cur_tokens[k] = -1; cur_scores[k] = 0.0; @@ -182,15 +172,15 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; for (int j = 0; j < cur_token_num; j++) { - auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; - auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; std::cout << "tokens: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << cur_tokens[k] << " "; } std::cout << std::endl; std::cout << "scores: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << cur_scores[k] << " "; } std::cout << std::endl; diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0fc6bfde5d0..29fc4235381 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -22,9 +22,14 @@ from fastdeploy import envs from fastdeploy.config import SpeculativeConfig +from fastdeploy.model_executor.ops.gpu import ( + mtp_save_first_token, + mtp_save_first_token_with_topk, +) from fastdeploy.platforms import current_platform from fastdeploy.worker.input_batch import ( InputBatch, + ProposerInputBatch, recover_batch_index_for_output, recover_batch_index_for_sampler_output, ) @@ -525,10 +530,76 @@ def save_output_specualate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, + proposer_share_inputs: ProposerInputBatch, + local_rank: int, + tensor_parallel_rank: int, save_each_rank: bool = False, - skip_save_output: bool = False, + is_mtp_prefill: bool = False, ): - if not skip_save_output: + if is_mtp_prefill: + if tensor_parallel_rank == 0: + skip_chunk_prefill = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + if sampler_output.logprobs_tensors is None: + recover_proposer_share_inputs_map = recover_batch_index_for_output( + proposer_share_inputs, + proposer_share_inputs.index_to_batch_id, + proposer_share_inputs.enable_pd_reorder, + [ + "base_model_draft_tokens", + "seq_lens_decoder", + "prompt_lens", + "step_idx", + ], + ) + mtp_save_first_token( + recover_proposer_share_inputs_map["base_model_draft_tokens"], + proposer_share_inputs["not_need_stop"], + recover_proposer_share_inputs_map["seq_lens_decoder"], + recover_proposer_share_inputs_map["prompt_lens"], + recover_proposer_share_inputs_map["step_idx"], + local_rank, + save_each_rank, + skip_chunk_prefill, + ) + else: + recover_share_inputs_map = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + [ + "sampled_token_ids", + "accept_tokens_cpu", + "accept_num_cpu", + "seq_lens_decoder_cpu", + "prompt_lens_cpu", + "last_preempted_idx", + ], + ) + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + recover_proposer_share_inputs_map = recover_batch_index_for_output( + proposer_share_inputs, + proposer_share_inputs.index_to_batch_id, + proposer_share_inputs.enable_pd_reorder, + ["base_model_draft_tokens"], + ) + mtp_save_first_token_with_topk( + recover_proposer_share_inputs_map["base_model_draft_tokens"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_share_inputs_map["accept_num_cpu"], + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + recover_share_inputs_map["seq_lens_decoder_cpu"], + recover_share_inputs_map["prompt_lens_cpu"], + recover_share_inputs_map["last_preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) + else: if sampler_output.logprobs_tensors is None: recover_share_inputs = recover_batch_index_for_output( share_inputs, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 9ca5f535ab0..acf7bee27a5 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -62,7 +62,6 @@ eagle_get_self_hidden_states, eagle_gather_hidden_states, hybrid_mtp_ngram, - mtp_save_first_token, mtp_step_paddle, share_external_data, speculate_get_logits, @@ -835,23 +834,26 @@ def _post_process(self, sampled_token_ids): ) if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: - skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], - ) - mtp_save_first_token( - recover_model_output_map["base_model_draft_tokens"], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_model_output_map["step_idx"], - self.local_rank, - self.parallel_config.use_ep, - skip_save, - ) + if current_platform.is_xpu(): + # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner. + # Only XPU platform is retained here. + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], + ) + mtp_save_first_token( + recover_model_output_map["base_model_draft_tokens"], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_model_output_map["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) # Ensure only save first token once. paddle.assign( paddle.where( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2bdbdb345b1..43478e1a817 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2478,13 +2478,17 @@ def _save_model_output( sampler_output, ): if self.speculative_decoding: - skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" save_output_specualate( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + proposer_share_inputs=self.proposer.model_inputs, + local_rank=self.local_rank, + tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, - skip_save_output=skip_save_output, + is_mtp_prefill=( + self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" + ), ) else: save_output_normal( From 56b761de3fb48d305420be4dfab1f7c47f016b91 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Sat, 18 Apr 2026 00:07:34 +0800 Subject: [PATCH 040/143] [Cherry-Pick][Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy(#7467) (#7468) * fix repeat_time kernel and change default spec verify strategy * fix unit_test --- ...peculate_get_token_penalty_multi_scores.cu | 152 +++++++++--------- fastdeploy/config.py | 2 +- tests/layers/test_speculative_sampler.py | 8 +- ...peculate_get_token_penalty_multi_scores.py | 2 +- 4 files changed, 81 insertions(+), 83 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu index ca5d8353c3e..022c39bfb64 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu @@ -16,12 +16,12 @@ template __global__ inline void min_length_logits_process( - T *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process( template <> __global__ inline void min_length_logits_process( - half *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + half* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process( } } -__global__ void update_repeat_times(const int64_t *token_ids_all, - const int64_t *prompt_lens, - const int64_t *cur_len, - int *repeat_times, - const int *batch_id_per_token_output, +__global__ void update_repeat_times(const int64_t* token_ids_all, + const int64_t* prompt_lens, + const int64_t* cur_len, + int* repeat_times, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, return; } int tid = threadIdx.x; - const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; - int *repeat_times_now = repeat_times + token_idx * length; - for (int i = tid; i < length_id; i += blockDim.x) { + const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; + int* repeat_times_now = repeat_times + token_idx * length; + for (int i = tid; i < cur_len[bi]; i += blockDim.x) { int64_t id = pre_ids_now[i]; if (id < 0) break; atomicAdd(&repeat_times_now[id], 1); @@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, template __global__ void update_value_by_repeat_times( - const int *repeat_times, - const T *penalty_scores, - const T *frequency_score, - const T *presence_score, - const float *temperatures, - T *logits, - const int *batch_id_per_token_output, + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times( if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int *repeat_times_now = repeat_times + token_idx * length; + T* logits_now = logits + token_idx * length; + const int* repeat_times_now = repeat_times + token_idx * length; float alpha = static_cast(penalty_scores[bi]); float beta = static_cast(frequency_score[bi]); float gamma = static_cast(presence_score[bi]); @@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times( } template -__global__ void ban_bad_words(T *logits, - const int64_t *bad_tokens, - const int64_t *bad_tokens_len, - const int *batch_id_per_token_output, +__global__ void ban_bad_words(T* logits, + const int64_t* bad_tokens, + const int64_t* bad_tokens_len, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits, if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length; + T* logits_now = logits + token_idx * length; + const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length; const int32_t bad_token_len = static_cast(min(bad_tokens_len[bi], bad_words_length)); for (int i = tid; i < bad_token_len; i += blockDim.x) { @@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits, template void token_penalty_multi_scores_kernel( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_score, - const paddle::Tensor &presence_score, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_score, + const paddle::Tensor& presence_score, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel( int64_t end_length = eos_token_id.shape()[0]; int block_size = (token_num + 32 - 1) / 32 * 32; min_length_logits_process<<<1, block_size, 0, cu_stream>>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), cur_len.data(), min_len.data(), eos_token_id.data(), @@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel( update_value_by_repeat_times <<>>( repeat_times.data(), - reinterpret_cast( - const_cast(penalty_scores.data())), - reinterpret_cast( - const_cast(frequency_score.data())), - reinterpret_cast( - const_cast(presence_score.data())), + reinterpret_cast( + const_cast(penalty_scores.data())), + reinterpret_cast( + const_cast(frequency_score.data())), + reinterpret_cast( + const_cast(presence_score.data())), temperatures.data(), - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast( + const_cast(logits.data())), batch_id_per_token_output.data(), token_num, bs, @@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel( block_size = (length_bad_words + 32 - 1) / 32 * 32; block_size = min(block_size, 512); ban_bad_words<<>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), bad_tokens.data(), bad_tokens_len.data(), batch_id_per_token_output.data(), @@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel( } void SpecTokenPenaltyMultiScores( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { switch (logits.type()) { case paddle::DataType::BFLOAT16: { diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 1b37db9611b..8dc18403608 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -774,7 +774,7 @@ class SpeculativeConfig: "benchmark_mode": False, "enf_gen_phase_tag": False, "enable_draft_logprob": False, - "verify_strategy": "topp", + "verify_strategy": "target_match", "accept_policy": "normal", } diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index ef75fe5d4e8..227e73db5a8 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -97,12 +97,12 @@ def _create_default_sampling_metadata( return fake_sampling_metadata -def _create_fd_config(max_model_len, method=None): +def _create_fd_config(max_model_len, method=None, verify_strategy="topp"): model_config: Mock = Mock() model_config.max_model_len = max_model_len model_config.architectures = ["test_model"] model_config.mm_max_tokens_per_item = None - speculative_config = SpeculativeConfig({"method": method} if method else {}) + speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy}) graph_opt_config = GraphOptimizationConfig({}) scheduler_config = SchedulerConfig({}) parallel_config = ParallelConfig({}) @@ -187,7 +187,7 @@ def test_speculative_sampler(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) @@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) sampling_metadata.share_inputs = share_inputs diff --git a/tests/operators/test_speculate_get_token_penalty_multi_scores.py b/tests/operators/test_speculate_get_token_penalty_multi_scores.py index 845f666ee7d..61efdbf270f 100644 --- a/tests/operators/test_speculate_get_token_penalty_multi_scores.py +++ b/tests/operators/test_speculate_get_token_penalty_multi_scores.py @@ -61,7 +61,7 @@ def update_repeat_times( token_ids_all_now = token_ids_all[bi] repeat_times_now = repeat_times[token_idx] - for i in range(length_id): + for i in range(cur_len[bi]): id = token_ids_all_now[i] if id < 0: break From fc801f83870f5b10e8c939de8c0161ef77d0e8e8 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Mon, 20 Apr 2026 11:23:44 +0800 Subject: [PATCH 041/143] [Bugfix][RL] fix control request timeout in async update weights pipeline (#7470) --- fastdeploy/entrypoints/engine_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 4c56e9bcd76..278d6e576ab 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -598,6 +598,10 @@ def check_health(self, time_interval_threashold=30): async def run_control_method(self, request: ControlRequest): api_server_logger.info(f"Received control request: {request}") + request_id = request.request_id + dealer, response_queue = await self.connection_manager.get_connection(request_id) + if not envs.ZMQ_SEND_BATCH_DATA: + dealer.write([b"", request_id.encode("utf-8")]) req_dict = request.to_dict() if envs.ZMQ_SEND_BATCH_DATA: req_dict["zmq_worker_pid"] = self.worker_pid @@ -605,10 +609,6 @@ async def run_control_method(self, request: ControlRequest): self.zmq_client.send_json(req_dict) else: self.zmq_client.send_pyobj(req_dict) - request_id = request.request_id - dealer, response_queue = await self.connection_manager.get_connection(request_id) - if not envs.ZMQ_SEND_BATCH_DATA: - dealer.write([b"", request_id.encode("utf-8")]) try: # todo: support user specified timeout. default 600s is enough for most control cases response = await asyncio.wait_for(response_queue.get(), timeout=600) From f4f7760925ab2696c2d3700c5d409be62fe76538 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Mon, 20 Apr 2026 21:09:21 +0800 Subject: [PATCH 042/143] [CI] Temporarily pin paddlepaddle-gpu to 3.5.0.dev20260417 (#7486) (#7519) --- .github/workflows/_accuracy_test.yml | 2 +- .github/workflows/_base_test.yml | 2 +- .github/workflows/_build_linux.yml | 2 +- .github/workflows/_build_linux_cu129.yml | 2 +- .github/workflows/_build_linux_cu130.yml | 2 +- .github/workflows/_build_linux_rl.yml | 3 +-- .github/workflows/_golang_router_test.yml | 2 +- .github/workflows/_gpu_4cards_case_test.yml | 2 +- .github/workflows/_logprob_test_linux.yml | 2 +- .github/workflows/_pre_ce_test.yml | 2 +- .github/workflows/_stable_test.yml | 2 +- .github/workflows/_unit_test_coverage.yml | 2 +- 12 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/_accuracy_test.yml b/.github/workflows/_accuracy_test.yml index 87994625c58..f06e3c88669 100644 --- a/.github/workflows/_accuracy_test.yml +++ b/.github/workflows/_accuracy_test.yml @@ -180,7 +180,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 3eb022725e5..fce46f04412 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -213,7 +213,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml index 1431df353cb..bd8f9f257c2 100644 --- a/.github/workflows/_build_linux.yml +++ b/.github/workflows/_build_linux.yml @@ -196,7 +196,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_cu129.yml b/.github/workflows/_build_linux_cu129.yml index 61108a82b40..9800795b1b2 100644 --- a/.github/workflows/_build_linux_cu129.yml +++ b/.github/workflows/_build_linux_cu129.yml @@ -183,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_cu130.yml b/.github/workflows/_build_linux_cu130.yml index 7c2aee69c6f..85593eada9a 100644 --- a/.github/workflows/_build_linux_cu130.yml +++ b/.github/workflows/_build_linux_cu130.yml @@ -183,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu130/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu130/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu130/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml index 1a131adb1a1..9c3f2a47966 100644 --- a/.github/workflows/_build_linux_rl.yml +++ b/.github/workflows/_build_linux_rl.yml @@ -166,8 +166,7 @@ jobs: cd FastDeploy python -m pip uninstall paddlepaddle-gpu -y || true - wget -q --no-proxy https://paddle-qa.bj.bcebos.com/paddle-pipeline/Develop-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/latest/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl - python -m pip install paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_golang_router_test.yml b/.github/workflows/_golang_router_test.yml index bbb1bc7799f..93c794482f6 100644 --- a/.github/workflows/_golang_router_test.yml +++ b/.github/workflows/_golang_router_test.yml @@ -212,7 +212,7 @@ jobs: git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt diff --git a/.github/workflows/_gpu_4cards_case_test.yml b/.github/workflows/_gpu_4cards_case_test.yml index 5c9a51aa809..be580a08dd9 100644 --- a/.github/workflows/_gpu_4cards_case_test.yml +++ b/.github/workflows/_gpu_4cards_case_test.yml @@ -208,7 +208,7 @@ jobs: cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml index 0a014d26854..b0ebec7d791 100644 --- a/.github/workflows/_logprob_test_linux.yml +++ b/.github/workflows/_logprob_test_linux.yml @@ -189,7 +189,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml index 8c5e20e2de0..8420c388ac3 100644 --- a/.github/workflows/_pre_ce_test.yml +++ b/.github/workflows/_pre_ce_test.yml @@ -201,7 +201,7 @@ jobs: --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ python -m pip install ${fd_wheel_url} bash scripts/run_pre_ce.sh ' diff --git a/.github/workflows/_stable_test.yml b/.github/workflows/_stable_test.yml index 11ae14927ef..fc89dfa6cac 100644 --- a/.github/workflows/_stable_test.yml +++ b/.github/workflows/_stable_test.yml @@ -193,7 +193,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index 4096fe4f4c0..9f06d2f1ccb 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -224,7 +224,7 @@ jobs: git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt From 95261f098ba44d03e774810c21035044c3c6a4d1 Mon Sep 17 00:00:00 2001 From: zhouchong <43821961+xyxinyang@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:21:47 +0800 Subject: [PATCH 043/143] Unify num_experts_per_tok to moe_k in ModelConfig for MoE model compatibility (#7517) --- fastdeploy/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 8dc18403608..2637a820c82 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -378,6 +378,9 @@ def override_name_from_config(self): # Because the ERNIE 4.5 config.json contains two sets of keys, adaptation is required. self.moe_num_shared_experts = self.n_shared_experts + if hasattr(self, "num_experts_per_tok") and not hasattr(self, "moe_k"): + self.moe_k = self.num_experts_per_tok + def read_from_env(self): """ Read configuration information from environment variables and update the object's attributes. From 74ddb20a734e1ef1b09ca0f70b45bc4aceba7661 Mon Sep 17 00:00:00 2001 From: RAM Date: Tue, 21 Apr 2026 16:51:45 +0800 Subject: [PATCH 044/143] [RL][Cherry-Pick] Fix the out-of-bounds issue caused by int32 in the R3 kernel (#7496) * [RL]Perf: Optimize batch delete prefix and fused put in R3 (#6604) * Optimizate delete batch and fused put * refine code * refine code * refine code * Support suspend r3 * [RL] Fix R3 Empty bug with TP=1 (#6777) * Fix int32 overflow * refine code * fix seq_lens_decoder bug --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> --- fastdeploy/config.py | 6 + fastdeploy/envs.py | 2 + .../layers/moe/routing_indices_cache.py | 109 +++++++++++------- fastdeploy/worker/gpu_model_runner.py | 6 +- 4 files changed, 79 insertions(+), 44 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 2637a820c82..7bdb413c126 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1748,6 +1748,9 @@ def __init__(self, args: dict): else: self.metrics_port = self.api_server_port + def __str__(self): + return json.dumps({key: value for key, value in self.__dict__.items()}) + class CommitConfig: """ @@ -1861,6 +1864,9 @@ def to_json_string(self): """ return json.dumps({key: value for key, value in self.__dict__.items()}) + def __str__(self): + return self.to_json_string() + class FDConfig: """ diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 96bc09934a8..f9db2f8e50b 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -259,6 +259,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Suspend rollouting routing replay + "FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))), # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index b27957bf0c6..c139700f347 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -54,6 +54,7 @@ def _save_routing_kernel( TOP_K, NUM_HIDDEN_LAYERS, MAX_MODEL_LEN, + MAX_NUM_SEQS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): @@ -63,45 +64,37 @@ def _save_routing_kernel( token_mask = token_offsets < TOKEN_NUM k_offsets = tl.arange(0, BLOCK_SIZE_K) - k_mask = k_offsets < TOP_K topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] - # [BLOCK_SIZE_M, BLOCK_SIZE_K] - load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) - - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) - pad_mask = token_mask & (batch_ids != -1) - # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] - # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) + topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) + + batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) + + batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) + pad_mask = token_mask & (batch_ids != -1) & batch_mask + + start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) token_relative_index = token_offsets - start_offsets - # [BLOCK_SIZE_M] - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) + len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) token_seq_pos = len_decoder + token_relative_index - STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K - STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K + STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) + STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) STRIDE_BUF_LAYER = TOP_K - # [BLOCK_SIZE_M, BLOCK_SIZE_K] output_ptrs = ( ROUTING_REPLAY_TABLE_PTR - + batch_ids[:, None] * STRIDE_BUF_SEQ - + token_seq_pos[:, None] * STRIDE_BUF_TOKEN - + LAYER_IDX * STRIDE_BUF_LAYER + + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ + + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN + + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER + k_offsets[None, :] ) - pos_mask = token_seq_pos < MAX_MODEL_LEN + pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) pos_mask = pos_mask & pad_mask - - # [BLOCK_SIZE_M, BLOCK_SIZE_K] pos_mask = pos_mask[:, None] & k_mask[None, :] final_mask = load_mask & pos_mask @@ -120,10 +113,10 @@ def save_routing_to_buffer( ep_size: int, tp_group: dist.communication.group.Group, ): + token_num_per_rank = topk_ids.shape[0] + if token_num_per_rank == 0: + return if tp_size > 1 and ep_size > 1: - token_num_per_rank = topk_ids.shape[0] - if token_num_per_rank == 0: - return topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] @@ -131,9 +124,10 @@ def save_routing_to_buffer( token_num, top_k = topk_ids.shape max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape assert token_num > 0 - assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3]) - assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num) - assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs) + assert ( + topk_ids.shape[1] == routing_replay_table.shape[3] + ), f"({topk_ids.shape[1]}, {routing_replay_table.shape[3]})" + assert batch_id_per_token.shape[0] == token_num, f"({batch_id_per_token.shape[0]}, {token_num})" BLOCK_SIZE_M = 128 BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k @@ -150,6 +144,7 @@ def save_routing_to_buffer( TOP_K=top_k, NUM_HIDDEN_LAYERS=num_hidden_layers, MAX_MODEL_LEN=max_model_len, + MAX_NUM_SEQS=max_num_seqs, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) @@ -166,6 +161,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num): self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index self.only_last_turn = fd_config.routing_replay_config.only_last_turn self.use_fused_put = fd_config.routing_replay_config.use_fused_put + logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}") if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": self.moe_top_k = fd_config.model_config.num_experts_per_tok else: @@ -186,6 +182,17 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num): ) self._store_wrapper.start_store_warpper() + # Suspend Routing Replay + self.suspend_routing_replay = False + self.update_suspend_routing_replay() + + def update_suspend_routing_replay(self): + """Allow RL to use R3 in different training rounds""" + # TODO(gongshaotian): Delete this func + suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0") + self.suspend_routing_replay = bool(int(suspend_routing_replay)) + logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}") + def _init_routing_cache(self, dtype: str, total_block_num: int): """Initialize the device buffer and host buffer.""" @@ -341,6 +348,11 @@ def _put_request_to_store( seq_lens_decoder, ): if self.tp_rank == 0: + # TODO(gongshaotian): Delete the suspend func + if self.suspend_routing_replay: + logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store") + return + before_put_request_time = time.perf_counter() # Collect the routing of finished request @@ -351,16 +363,19 @@ def _put_request_to_store( if self.use_fused_put: self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) + # Only store the routing of last turn + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + else: for layer_id in range(self.num_moe_layers): layer_buffer = batch_buffer[layer_id] self._store_wrapper.submit_put_task( routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id ) - - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + # Only store the routing of last turn + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") @@ -481,7 +496,6 @@ def _monitor_queue_load(self): if qsize > self.queue_max_size * 0.8: logger.warning( f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " - f"Dropped tasks so far: {self._dropped_tasks}. " "Consider increasing max_workers or queue_max_size." ) logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") @@ -523,22 +537,26 @@ def submit_clear_store_task(self) -> None: raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") - def submit_clear_prefix_batch_task(self, rollout_id) -> None: + def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: """Submit clear prefix batch task""" if not self._sotre_process_running: raise RuntimeError("Store not started.") - prefix_batch = self.get_needed_clear_ids(rollout_id) - - if prefix_batch is None: + prefix_batch_id = self.get_needed_clear_ids(rollout_id) + if prefix_batch_id is None: return start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None} + if layer_idx is not None: + rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" + else: + rdma_rollout_key = prefix_batch_id + + task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} try: self._task_queue.put_nowait(task) except Exception: raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") logger.info( - f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s" + f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" ) def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: @@ -615,7 +633,7 @@ def run(self): self._task_queue.task_done() raise RuntimeError(f"Error during processing task. {e}") - logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.") + logger.info("RoutingReplay Consumer Process Shutdown.") def process_put_task(self, store_task: StoreTask) -> None: try: @@ -838,13 +856,18 @@ def __init__(self, routing_replay_config) -> None: async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: """Put the routing indices into store""" time_before_put = time.perf_counter() - result = await self.p2p_client.put(routing_key, routing_indices) + if len(routing_indices.shape) == 3: + # NOTE(gongshaotian) Fused put with bytes data + routing_bytes = routing_indices.tobytes() + result = await self.p2p_client.put(routing_key, routing_bytes) + else: + result = await self.p2p_client.put(routing_key, routing_indices) logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") return result async def clear_prefix_batch(self, routing_prefix_key: str): time_before_clear = time.perf_counter() - result = await self.p2p_client.delete_prefix_batch([routing_prefix_key]) + result = await self.p2p_client.delete_batch([routing_prefix_key]) logger.info( f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 43478e1a817..a6addca4abd 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2737,9 +2737,13 @@ def update_parameters(self, pid): # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() + # Rollout Routing Replay + if self.fd_config.routing_replay_config.enable_routing_replay: + # TODO(gongshaotian): Delete suspend func + self.routing_replay_manager.update_suspend_routing_replay() + # Send single self.dynamic_weight_manager.finalize_update(pid) - self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def update_weights(self, version: str = None, verify_checksum: bool = False): From be2fd17e7d143d493370e6cff9dcd3e78f856db5 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Tue, 21 Apr 2026 20:20:03 +0800 Subject: [PATCH 045/143] add m_grouped_bf16_gemm_nn_contiguous(#7536) --- .../layers/moe/fused_moe_cutlass_backend.py | 59 +++++++++++++------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 1dc349478c8..96423cc6e4f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -53,6 +53,12 @@ ) +def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token): + out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype) + paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token) + return out + + class CutlassMoEMethod(UnquantizedFusedMoEMethod): """ Use Cutlass Group Gemm to compute Fused MoE. @@ -156,31 +162,46 @@ def apply_ep_prefill( if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": # --- moe_permute / moe_unpermute path --- recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32) - (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute( - hidden_states=recv_x, - scale=None, - expert_routemap_topk=recv_topk_idx_i32, - expert_prob_topk=recv_topk_weights, - num_experts=layer.num_local_experts, - tokens_per_expert=[], - padding_alignment=128, - override_buffer_size=token_all_num, + (permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = ( + paddle.nn.functional.moe_permute( + hidden_states=recv_x, + scale=None, + expert_routemap_topk=recv_topk_idx_i32, + expert_prob_topk=recv_topk_weights, + num_experts=layer.num_local_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=token_all_num, + return_expert_indices=True, + ) ) - out = paddle.incubate.nn.functional.batched_gemm( - permute_input, - getattr(layer, self.added_weight_attrs[0]), - recv_num_tokens_per_expert_list, - ) + if paddlefleet_ops is not None: + out = m_grouped_bf16_gemm_nn_contiguous( + permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices + ) + else: + out = paddle.incubate.nn.functional.batched_gemm( + permute_input, + getattr(layer, self.added_weight_attrs[0]), + recv_num_tokens_per_expert_list, + ) + if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE: out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) else: out = paddle.incubate.nn.functional.swiglu(out) - ffn_out = paddle.incubate.nn.functional.batched_gemm( - out, - getattr(layer, self.added_weight_attrs[1]), - recv_num_tokens_per_expert_list, - ) + + if paddlefleet_ops is not None: + ffn_out = m_grouped_bf16_gemm_nn_contiguous( + out, getattr(layer, self.added_weight_attrs[1]), m_indices + ) + else: + ffn_out = paddle.incubate.nn.functional.batched_gemm( + out, + getattr(layer, self.added_weight_attrs[1]), + recv_num_tokens_per_expert_list, + ) tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( hidden_states_unzipped=ffn_out, From 13034ef0ca4b42ecd05ad5889eae61b39d2199ed Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 21 Apr 2026 21:31:45 +0800 Subject: [PATCH 046/143] [BugFix] Fix skip_x_record_stream incompatibility across deep_ep versions (#7542) (#7546) * fix skip_x_record_stream * fix * optim Co-authored-by: Yuanle Liu --- fastdeploy/model_executor/layers/moe/ep.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 33993872cb6..1b1df3748ad 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import inspect import traceback from abc import abstractmethod from types import ModuleType @@ -602,6 +603,8 @@ def __init__( use_internode_ll_two_stage=use_internode_ll_two_stage, ) self.num_worst_tokens = prefill_num_worst_tokens + self._dispatch_parameters: Optional[set] = None + self._combine_parameters: Optional[set] = None logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}") def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False): @@ -656,8 +659,12 @@ def dispatch( } if envs.FD_USE_PFCC_DEEP_EP: - dispatch_args["num_worst_tokens"] = self.num_worst_tokens - dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + if self._dispatch_parameters is None: + self._dispatch_parameters = set(inspect.signature(buffer.dispatch).parameters) + if "num_worst_tokens" in self._dispatch_parameters: + dispatch_args["num_worst_tokens"] = self.num_worst_tokens + if "skip_x_record_stream" in self._dispatch_parameters: + dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0 return buffer.dispatch(**dispatch_args) @@ -683,7 +690,10 @@ def combine( } if envs.FD_USE_PFCC_DEEP_EP: - combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + if self._combine_parameters is None: + self._combine_parameters = set(inspect.signature(buffer.combine).parameters) + if "skip_x_record_stream" in self._combine_parameters: + combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0 fused_moe_out, _, event = buffer.combine(**combine_args) return fused_moe_out, event From d5518463ce0152b309d149fa6df8c2fe00e700fe Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:46:57 +0800 Subject: [PATCH 047/143] Mooncake storage register local buffer by chunk (#7416) (#7540) --- docs/features/global_cache_pooling.md | 2 +- docs/zh/features/global_cache_pooling.md | 2 +- .../mooncake_store/mooncake_store.py | 75 ++++++++++++++++--- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/docs/features/global_cache_pooling.md b/docs/features/global_cache_pooling.md index 2218e788cf3..3c8e18301b6 100644 --- a/docs/features/global_cache_pooling.md +++ b/docs/features/global_cache_pooling.md @@ -90,7 +90,7 @@ Create a `mooncake_config.json` file: "metadata_server": "http://0.0.0.0:15002/metadata", "master_server_addr": "0.0.0.0:15001", "global_segment_size": 1000000000, - "local_buffer_size": 134217728, + "local_buffer_size": 1048576, "protocol": "rdma", "rdma_devices": "" } diff --git a/docs/zh/features/global_cache_pooling.md b/docs/zh/features/global_cache_pooling.md index 292e764ac80..b0cf985f3a3 100644 --- a/docs/zh/features/global_cache_pooling.md +++ b/docs/zh/features/global_cache_pooling.md @@ -90,7 +90,7 @@ pip install ./dist/fastdeploy*.whl "metadata_server": "http://0.0.0.0:15002/metadata", "master_server_addr": "0.0.0.0:15001", "global_segment_size": 1000000000, - "local_buffer_size": 134217728, + "local_buffer_size": 1048576, "protocol": "rdma", "rdma_devices": "" } diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index ba7d003b7ae..1a81cfd652f 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -31,7 +31,14 @@ from fastdeploy.utils import get_host_ip DEFAULT_GLOBAL_SEGMENT_SIZE = 1024 * 1024 * 1024 # 1 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128MB +DEFAULT_LOCAL_BUFFER_SIZE = 1024 * 1024 # 1MB +DEFAULT_MC_MAX_MR_SIZE = 4 * 1024 * 1024 * 1024 # 4GB +MIN_MC_MAX_MR_SIZE = 1024 * 1024 * 1024 # 1GB +MAX_MC_MAX_MR_SIZE = 6 * 1024 * 1024 * 1024 # 6GB + + +def byte_to_gb(byte): + return byte / (1024 * 1024 * 1024) @dataclass @@ -111,9 +118,25 @@ def __init__(self, tp_rank=None): host_ip = get_host_ip() os.environ["MC_TCP_BIND_ADDRESS"] = host_ip logger.info(f"Set MC_TCP_BIND_ADDRESS to {host_ip}") - if os.environ.get("MC_MAX_MR_SIZE") is None: - os.environ["MC_MAX_MR_SIZE"] = str(4 * 1024**3) # 4GB - logger.info("MC_MAX_MR_SIZE is not set, default to 4GB.") + + # Set MC_MAX_MR_SIZE for mooncake store to control the maximum mr size + self.mc_max_mr_size = int(os.environ.get("MC_MAX_MR_SIZE", 0)) + if self.mc_max_mr_size == 0: + self.mc_max_mr_size = DEFAULT_MC_MAX_MR_SIZE + logger.info(f"MC_MAX_MR_SIZE is not set, default to {byte_to_gb(DEFAULT_MC_MAX_MR_SIZE)} GB.") + elif self.mc_max_mr_size < MIN_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MIN_MC_MAX_MR_SIZE + logger.info( + f"MC_MAX_MR_SIZE is smaller than {byte_to_gb(MIN_MC_MAX_MR_SIZE)} GB, set to {byte_to_gb(MIN_MC_MAX_MR_SIZE)} GB." + ) + elif self.mc_max_mr_size > MAX_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MAX_MC_MAX_MR_SIZE + logger.info( + f"MC_MAX_MR_SIZE is larger than {byte_to_gb(MAX_MC_MAX_MR_SIZE)} GB, set to {byte_to_gb(MAX_MC_MAX_MR_SIZE)} GB." + ) + else: + logger.info(f"MC_MAX_MR_SIZE is set to {self.mc_max_mr_size} bytes.") + os.environ["MC_MAX_MR_SIZE"] = str(self.mc_max_mr_size) try: from mooncake.store import MooncakeDistributedStore @@ -129,6 +152,11 @@ def __init__(self, tp_rank=None): self.config = MooncakeStoreConfig.create() if self.tp_rank is not None: self.config.select_rdma_device(self.tp_rank) + if self.config.local_buffer_size > self.mc_max_mr_size: + raise ValueError( + f"local_buffer_size {self.config.local_buffer_size} must be " + f"smaller than mc_max_mr_size {self.mc_max_mr_size}" + ) logger.info(f"Mooncake Configuration loaded, {self.config}.") ret_code = self.store.setup( @@ -162,13 +190,38 @@ def warmup(self): self.store.remove(warmup_key) def register_buffer(self, buffer_ptr, buffer_size) -> None: - try: - ret_code = self.store.register_buffer(buffer_ptr, buffer_size) - if ret_code: - logger.error(f"failed to register buffer, error code: {ret_code}") - except TypeError as err: - logger.error("Failed to register buffer to Mooncake Store: %s", err) - raise TypeError("Mooncake Store Register Buffer Error.") from err + """Register a buffer with Mooncake Store. + If buffer_size exceeds mc_max_mr_size, the buffer is split into + multiple chunks, each registered separately. + cuda_host_alloc returns physically contiguous pinned memory, so + pointer offset arithmetic is valid for sub-region registration. + """ + max_mr_size = self.mc_max_mr_size + if buffer_size <= max_mr_size: + try: + ret_code = self.store.register_buffer(buffer_ptr, buffer_size) + assert ret_code == 0, f"failed to register buffer, error code: {ret_code}" + except TypeError as err: + logger.error("Failed to register buffer to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Register Buffer Error.") from err + else: + num_chunks = (buffer_size + max_mr_size - 1) // max_mr_size + logger.info( + f"Registering buffer of {byte_to_gb(buffer_size):.2f}GB in {num_chunks} chunks " + f"(max_mr_size={byte_to_gb(max_mr_size):.2f}GB per chunk)" + ) + for i in range(num_chunks): + chunk_ptr = buffer_ptr + i * max_mr_size + chunk_size = min(max_mr_size, buffer_size - i * max_mr_size) + try: + ret_code = self.store.register_buffer(chunk_ptr, chunk_size) + assert ret_code == 0, ( + f"failed to register chunk {i}/{num_chunks}, " + f"size={byte_to_gb(chunk_size):.2f}GB, error code: {ret_code}" + ) + except TypeError as err: + logger.error("Failed to register chunk %d/%d to Mooncake Store: %s", i, num_chunks, err) + raise TypeError("Mooncake Store Register Buffer Error.") from err def set( self, From 86df2a9e86472ff9295329a46228b657d8fec309 Mon Sep 17 00:00:00 2001 From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:59:52 +0800 Subject: [PATCH 048/143] Update args_utils.py (#7549) --- fastdeploy/engine/args_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index d350350f85d..4b4e7aabf6c 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -264,7 +264,7 @@ class EngineArgs: """ Flag to enable prefix caching. """ - enable_output_caching: bool = True + enable_output_caching: bool = False """ Flag to enable kv cache for output tokens, only valid in V1 scheduler. """ From b0fde163a6b9b58fa8f1096b89d943790bcddb65 Mon Sep 17 00:00:00 2001 From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:01:54 +0800 Subject: [PATCH 049/143] Enable output caching by default --- fastdeploy/engine/args_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 4b4e7aabf6c..d350350f85d 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -264,7 +264,7 @@ class EngineArgs: """ Flag to enable prefix caching. """ - enable_output_caching: bool = False + enable_output_caching: bool = True """ Flag to enable kv cache for output tokens, only valid in V1 scheduler. """ From 2961400190dd7951ae118552141239d2054d252d Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:24:10 +0800 Subject: [PATCH 050/143] [Cherry-Pick][BugFix] Fix clear_parameters hang issue in MTP during weight cleanup in RL (#7522) (#7523) * fix mtp clear graph bugs in rl --- fastdeploy/worker/gpu_model_runner.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a6addca4abd..fc3dbc4ab2c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2697,12 +2697,19 @@ def clear_parameters(self, pid): # Clear CUDAGraph if self.use_cudagraph: self.model.clear_graph_opt_backend() + if ( + self.speculative_decoding + and self.spec_method == SpecMethod.MTP + and self.graph_opt_config.draft_model_use_cudagraph + ): + self.proposer.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.spec_method == SpecMethod.MTP: - self.proposer.model.clear_graph_opt_backend() + + # NOTE(wangyanpeng): MTP cache must be cleared before clearing the main KV cache + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache() From 9c91ecb1ec81ff4dc69e4de596ec68e5a8cc49c4 Mon Sep 17 00:00:00 2001 From: qwes5s5 <45442318+qwes5s5@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:49:51 +0800 Subject: [PATCH 051/143] [Cherry-Pick][BugFix] Fix bugs in /v1/abort_requests interface from PR(#6992) (#7176) (#7551) * abort api bug fix * bug fix * bug fix --- fastdeploy/engine/common_engine.py | 12 +++++++++++- fastdeploy/engine/sched/resource_manager_v1.py | 6 ++++-- fastdeploy/entrypoints/openai/protocol.py | 8 ++++---- tests/engine/test_common_engine.py | 4 ++++ 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index dabed9e4342..5f91a318556 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1587,13 +1587,14 @@ def _control_abort_requests(self, control_req: ControlRequest): engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, request_start_time=request.metrics.arrival_time if request.metrics else now, ) + eos_token_ids = getattr(request, "eos_token_ids", [0]) result = RequestOutput( request_id=req_id, finished=True, outputs=CompletionOutput( index=0, send_idx=len(partial_token_ids), - token_ids=[self.data_processor.eos_token_ids[0]], + token_ids=[eos_token_ids[0]], ), metrics=abort_metrics, error_code=200, @@ -1637,10 +1638,19 @@ def _wait_abort_complete(self, target_req_ids, stall_timeout=1): reset progress state if any, then continue monitoring """ target_set = set(target_req_ids) + target_set = target_set & (set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) prev_remaining_count = len(target_set) last_progress_time = time.time() remaining = target_set & self.resource_manager.get_reqs_in_aborting() while remaining: + alive_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) + finished_reqs = target_set - alive_reqs + if finished_reqs: + self.llm_logger.info(f"abort targets already finished, skip: {finished_reqs}") + for req_id in finished_reqs: + self.resource_manager.waiting_abort_req_id_set.discard(req_id) + self.resource_manager.to_be_aborted_req_id_set.discard(req_id) + target_set -= finished_reqs remaining = target_set & self.resource_manager.get_reqs_in_aborting() if not remaining: self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index ffc9c0bacf4..f3704a533b5 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -281,7 +281,7 @@ def recycle_abort_task(self, request_id): self.stop_flags[request.idx] = True # 设置停止标志 del self.requests[request_id] del self.req_dict[request_id] - self.to_be_aborted_req_id_set.remove(request_id) + self.to_be_aborted_req_id_set.discard(request_id) self.update_metrics() def _trigger_abort(self, request_id, scheduled_reqs): @@ -293,7 +293,7 @@ def _trigger_abort(self, request_id, scheduled_reqs): abort_request.cached_block_num = 0 scheduled_reqs.append(self._prepare_abort_task(abort_request)) self.to_be_aborted_req_id_set.add(request_id) - self.waiting_abort_req_id_set.remove(request_id) + self.waiting_abort_req_id_set.discard(request_id) def _info_each_block(self): """ @@ -1544,6 +1544,8 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): del self.requests[req_id] if req_id in self.req_dict: del self.req_dict[req_id] + self.waiting_abort_req_id_set.discard(req_id) + self.to_be_aborted_req_id_set.discard(req_id) # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 3560f3a8aef..b4e87e7a20c 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -268,7 +268,7 @@ class ChatCompletionResponseChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] speculate_metrics: Optional[SpeculateMetrics] = None @@ -333,7 +333,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -369,7 +369,7 @@ class CompletionResponseChoice(BaseModel): draft_logprobs: Optional[CompletionLogprobs] = None prompt_logprobs: Optional[PromptLogprobs] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -415,7 +415,7 @@ class CompletionResponseStreamChoice(BaseModel): prompt_tokens: Optional[str] = None completion_tokens: Optional[str] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 8778e2013ea..53bd8462d81 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -3700,6 +3700,8 @@ def test_wait_abort_complete_progress(self): """_wait_abort_complete exits when background thread cleans up.""" eng = self._make_abort_engine() eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"} + # Add the request to requests dict so it won't be filtered out + eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} call_count = [0] @@ -3718,6 +3720,8 @@ def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self): """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set.""" eng = self._make_abort_engine() eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"} + # Add the request to requests dict so it won't be filtered out + eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} def mock_recycle(req_id): eng.resource_manager.to_be_aborted_req_id_set.discard(req_id) From 3d6d3a217768f0934cc5a3676272e229022adbe5 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 22 Apr 2026 18:05:17 +0800 Subject: [PATCH 052/143] [DataProcessor] add completions (#7543) (#7558) * add completions * add unit test * add unit test Co-authored-by: luukunn <981429396@qq.com> --- fastdeploy/input/base_processor.py | 3 +++ tests/input/test_text_processor.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/fastdeploy/input/base_processor.py b/fastdeploy/input/base_processor.py index 357339be766..24e0d7ddec8 100644 --- a/fastdeploy/input/base_processor.py +++ b/fastdeploy/input/base_processor.py @@ -412,6 +412,9 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): if len(request["prompt_token_ids"]) == 0: raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") + if request.get("completion_token_ids"): + request["prompt_token_ids"].extend(request["completion_token_ids"]) + # truncate prompts that exceed the length limit if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1] diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py index 818f4b77d0f..56137633350 100644 --- a/tests/input/test_text_processor.py +++ b/tests/input/test_text_processor.py @@ -418,6 +418,31 @@ def test_process_request_dict_rejects_bad_kwargs(self): with self.assertRaisesRegex(ValueError, "chat_template_kwargs must be a dict"): self.processor.process_request_dict(request) + def test_process_request_dict_completion_token_ids_extend(self): + request = {"prompt": "hi", "completion_token_ids": [10, 11, 12], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # prompt "hi" is tokenized to [2] by DummyTokenizer, then extended with completion_token_ids + self.assertEqual(processed["prompt_token_ids"], [2, 10, 11, 12]) + + def test_process_request_dict_no_completion_token_ids(self): + request = {"prompt": "hi", "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # without completion_token_ids, prompt_token_ids should remain as tokenized result + self.assertEqual(processed["prompt_token_ids"], [2]) + + def test_process_request_dict_empty_completion_token_ids(self): + request = {"prompt": "hi", "completion_token_ids": [], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # empty list is falsy, should not extend prompt_token_ids + self.assertEqual(processed["prompt_token_ids"], [2]) + + def test_process_request_dict_completion_token_ids_truncated(self): + # prompt "hi" -> [2], extend [10,11,12] -> [2,10,11,12] (len=4) + # max_model_len=3, 4 > 3 triggers truncation: [:3-1] = [:2] -> [2, 10] + request = {"prompt": "hi", "completion_token_ids": [10, 11, 12], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=3) + self.assertEqual(processed["prompt_token_ids"], [2, 10]) + def test_ids2tokens_and_clear_request_status(self): delta, _, _ = self.processor.ids2tokens([3], "task-1") self.assertEqual(delta, "3") From 2c04dfdffd986e576a8753bbfca24f5acdcfa7c4 Mon Sep 17 00:00:00 2001 From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Date: Wed, 22 Apr 2026 21:49:01 +0800 Subject: [PATCH 053/143] Update args_utils.py --- fastdeploy/engine/args_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index d350350f85d..4b4e7aabf6c 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -264,7 +264,7 @@ class EngineArgs: """ Flag to enable prefix caching. """ - enable_output_caching: bool = True + enable_output_caching: bool = False """ Flag to enable kv cache for output tokens, only valid in V1 scheduler. """ From 9ef8467d74b681751ee983295205c7a6f6e5d474 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 22 Apr 2026 21:49:10 +0800 Subject: [PATCH 054/143] [Scheduler][BugFix] Fix token_budget calculation to use actual decode request count (#7499) (#7562) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Scheduler][BugFix] Fix token_budget calculation to use actual decode request count ## Motivation 当前 `token_budget` 的计算方式存在两个问题: 1. **预扣过多**:budget 按 `max_num_seqs * tokens_per_seq` 预扣,而不是 running 队列中实际处于 decode 阶段的请求数,导致 prefill 可用的 token 数被低估。 2. **循环内重复扣减**:decode 分支固定执行 `token_budget -= 1`,在 spec decode 场景下(`tokens_per_seq > 1`)每个 decode 请求只扣 1,少扣了 `num_speculative_tokens` 个;此外,当 running 队列中 prefill 请求耗尽 budget 后,排在其后的 decode 请求会被循环退出条件 `token_budget > 0` 提前跳过,导致调度漏发。 ## Modifications - `resource_manager_v1.py` - 新增 `_is_decoding(request)` 内部方法,封装 `num_computed_tokens >= need_prefill_tokens` 判断,全文统一使用 - 调度前统计 running 队列中真实的 decode 请求数 `num_running_decode_reqs`,以 `num_running_decode_reqs * tokens_per_seq` 一次性预扣 budget,替代原来的 `max_num_seqs * tokens_per_seq` - 去掉 decode 分支内的 `token_budget -= 1`(已在循环前整体预扣) - 修改循环退出条件:decode 请求不受 `token_budget <= 0` 限制,仅 prefill 请求在 budget 耗尽时退出 - `config.py` - 修复 `max_num_batched_tokens` 的合法性校验,考虑 spec decode 场景下 `tokens_per_seq = num_speculative_tokens + 1`,改为检查 `max_num_batched_tokens >= max_num_seqs * tokens_per_seq` ## Usage or Command ```bash # 普通启动(非spec decode,行为不变) python -m fastdeploy.entrypoints.openai.api_server \ --max-num-batched-tokens 8192 \ --max-num-seqs 256 \ ... # spec decode 场景(tokens_per_seq = num_speculative_tokens + 1) # 确保 max_num_batched_tokens >= max_num_seqs * tokens_per_seq,否则启动报错 python -m fastdeploy.entrypoints.openai.api_server \ --max-num-batched-tokens 8192 \ --max-num-seqs 256 \ --num-speculative-tokens 4 \ ... ``` * [FDConfig][BugFix] Fix AttributeError when speculative_config is SimpleNamespace without num_speculative_tokens ## Motivation 当测试中使用 `SimpleNamespace(method=None)` 构造 `speculative_config` 时, `config.py` 的 `check()` 方法直接访问 `self.speculative_config.num_speculative_tokens`, 导致 `AttributeError: 'types.SimpleNamespace' object has no attribute 'num_speculative_tokens'`。 影响以下测试文件: - tests/v1/test_resource_manager_v1.py - tests/eplb/test_eplb_utils.py - tests/eplb/test_experts_manager.py - tests/v1/cache_manager/test_prefix_cache.py - tests/v1/test_schedule_output.py ## Modifications - `fastdeploy/config.py`: 使用 `getattr(..., "num_speculative_tokens", 0)` 兜底, 防止 speculative_config 对象缺少该属性时崩溃 - 测试文件:将 `speculative_config=SimpleNamespace(method=None)` 统一改为 `speculative_config=None`,与无投机解码场景语义一致 --------- Co-authored-by: kevin Co-authored-by: Claude Sonnet 4.6 --- fastdeploy/config.py | 10 ++++++++-- fastdeploy/engine/sched/resource_manager_v1.py | 16 +++++++++------- tests/eplb/test_eplb_utils.py | 3 +-- tests/eplb/test_experts_manager.py | 3 +-- tests/v1/cache_manager/test_prefix_cache.py | 3 +-- tests/v1/test_resource_manager_v1.py | 3 +-- tests/v1/test_schedule_output.py | 9 +++------ 7 files changed, 24 insertions(+), 23 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 7bdb413c126..e56ebe2e704 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2238,9 +2238,15 @@ def check(self): assert ( self.scheduler_config.max_num_seqs >= 1 ), f"max_num_seqs: {self.scheduler_config.max_num_seqs} should be larger than 1" - assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs, ( + tokens_per_seq = ( + (getattr(self.speculative_config, "num_speculative_tokens", 0) + 1) + if self.speculative_config is not None + else 1 + ) + assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs * tokens_per_seq, ( f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} " - f"should be larger than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs}" + f"should be larger than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs} " + f"* tokens_per_seq: {tokens_per_seq}" ) assert ( self.scheduler_config.max_num_batched_tokens diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index f3704a533b5..cf65edf717e 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -245,6 +245,10 @@ def get_new_block_nums(self, request: Request, num_new_tokens: int): block_num = min(block_num, self.config.cache_config.max_block_num_per_seq) return block_num + def _is_decoding(self, request) -> bool: + """Return True if the request has finished prefill and is in the decoding phase.""" + return request.num_computed_tokens >= request.need_prefill_tokens + def _prepare_prefill_task(self, request, new_token_num): request.prefill_start_index = request.num_computed_tokens request.prefill_end_index = request.num_computed_tokens + new_token_num @@ -749,7 +753,7 @@ def cache_output_tokens(self, request): and self.config.scheduler_config.splitwise_role != "decode" ): with self.lock: - if request.num_computed_tokens >= request.need_prefill_tokens: # request is decoding + if self._is_decoding(request): # request is decoding self.cache_manager.cache_output_blocks(request, self.config.cache_config.block_size) def schedule(self): @@ -773,12 +777,10 @@ def get_enough_request(request, scheduled_reqs): if self.config.speculative_config is not None else 1 ) + num_running_decode_reqs = sum(1 for req in self.running if self._is_decoding(req)) token_budget = ( - self.config.scheduler_config.max_num_batched_tokens - - self.config.scheduler_config.max_num_seqs * tokens_per_seq + self.config.scheduler_config.max_num_batched_tokens - num_running_decode_reqs * tokens_per_seq ) - # temperatory solution to avoid negative token_budget - token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512)) need_abort_requests = [] # users trigger abortion # First, schedule the RUNNING requests. @@ -791,7 +793,7 @@ def get_enough_request(request, scheduled_reqs): self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1) self.need_block_num_signal.value[request.idx] = 0 - if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding + if self._is_decoding(request): # to be decoding if ( self.config.scheduler_config.splitwise_role == "prefill" ): # do not need to schedule for decoding @@ -838,7 +840,7 @@ def get_enough_request(request, scheduled_reqs): # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) num_decoding_req_nums += 1 - token_budget -= 1 + # Decode token cost has been pre-deducted upfront (num_running_decode_reqs * tokens_per_seq). if ( request.use_extend_tables and request.request_id not in self.using_extend_tables_req_id diff --git a/tests/eplb/test_eplb_utils.py b/tests/eplb/test_eplb_utils.py index 4b367c2a36d..08d4d89dc0f 100644 --- a/tests/eplb/test_eplb_utils.py +++ b/tests/eplb/test_eplb_utils.py @@ -168,7 +168,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.num_hidden_layers = 3 @@ -200,7 +199,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, eplb_config=eplb_config, ) diff --git a/tests/eplb/test_experts_manager.py b/tests/eplb/test_experts_manager.py index b736c20f263..15060ea480a 100644 --- a/tests/eplb/test_experts_manager.py +++ b/tests/eplb/test_experts_manager.py @@ -48,7 +48,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.num_hidden_layers = 3 @@ -80,7 +79,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, eplb_config=eplb_config, ) diff --git a/tests/v1/cache_manager/test_prefix_cache.py b/tests/v1/cache_manager/test_prefix_cache.py index b3393500173..0a5eb669582 100644 --- a/tests/v1/cache_manager/test_prefix_cache.py +++ b/tests/v1/cache_manager/test_prefix_cache.py @@ -31,7 +31,6 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.architectures = ["test_model"] model_cfg.mm_max_tokens_per_item = None @@ -46,7 +45,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 7cee36cd060..5cc7eb4ef01 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -138,7 +138,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 3200 model_cfg.architectures = ["test_model"] @@ -155,7 +154,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) self.manager = ResourceManagerV1( diff --git a/tests/v1/test_schedule_output.py b/tests/v1/test_schedule_output.py index 3175b087e2e..db30b15f48b 100644 --- a/tests/v1/test_schedule_output.py +++ b/tests/v1/test_schedule_output.py @@ -29,7 +29,6 @@ def test_normal_schedule(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -41,7 +40,7 @@ def test_normal_schedule(): model_config=model_cfg, cache_config=cache_cfg, parallel_config=parallel_cfg, - speculative_config=speculative_cfg, + speculative_config=None, graph_opt_config=graph_opt_cfg, scheduler_config=scheduler_cfg, ) @@ -95,7 +94,6 @@ def test_preempted_request(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -108,7 +106,7 @@ def test_preempted_request(): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) resource_manager_v1 = ResourceManagerV1( @@ -162,7 +160,6 @@ def test_caching_output(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -175,7 +172,7 @@ def test_caching_output(): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) resource_manager_v1 = ResourceManagerV1( From 258b22abeab470266ed354f94ffed2dffa530225 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Thu, 23 Apr 2026 10:45:45 +0800 Subject: [PATCH 055/143] support deepgemm without bias input (#7559) (#7565) Co-authored-by: JYChen --- .../layers/quantization/block_wise_fp8.py | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index ae37ca45961..19f4597ab34 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -51,6 +51,19 @@ else: fp8_gemm_nt = None +# Detect whether fp8_gemm_nt accepts a 'bias' keyword argument +_fp8_gemm_nt_has_bias_kwarg = False +if fp8_gemm_nt is not None: + import inspect + + try: + _sig = inspect.signature(fp8_gemm_nt) + _fp8_gemm_nt_has_bias_kwarg = "bias" in _sig.parameters + except (ValueError, TypeError): + # pybind11 functions may not expose signatures via inspect; + # fall back to a cheap probe call to determine support. + pass + class BlockWiseFP8Config(QuantConfigBase): """ @@ -128,14 +141,22 @@ def deep_gemm_fp8_gemm_nt( sm_version = get_sm_version() if sm_version >= 100 and current_platform.is_cuda(): # disable_ue8m0_cast is default False for SM100 - fp8_gemm_nt( - (x, x_scale_tensor), - (layer_weight, layer_weight_scale_inv), - linear_out, - bias=bias, - ) + if _fp8_gemm_nt_has_bias_kwarg: + fp8_gemm_nt( + (x, x_scale_tensor), + (layer_weight, layer_weight_scale_inv), + linear_out, + bias=bias, + ) + else: + fp8_gemm_nt( + (x, x_scale_tensor), + (layer_weight, layer_weight_scale_inv), + linear_out, + ) + if bias is not None: + linear_out = paddle.add(linear_out, bias) else: - # disable_ue8m0_cast is default False for SM100 fp8_gemm_nt( (x, x_scale_tensor), (layer_weight, layer_weight_scale_inv), From b3aa46978db2bfb8e6778b94b429fecd5429ffed Mon Sep 17 00:00:00 2001 From: Zero Rains Date: Thu, 23 Apr 2026 15:53:42 +0800 Subject: [PATCH 056/143] [KSM] support keep sampling mask (#7460) * [KSM] support keep sampling mask * fix bug * fix typo * fix test case * update zmq --- fastdeploy/config.py | 1 + fastdeploy/engine/args_utils.py | 20 ++ fastdeploy/engine/common_engine.py | 1 + fastdeploy/engine/engine.py | 1 + fastdeploy/engine/request.py | 5 + fastdeploy/entrypoints/openai/protocol.py | 5 + fastdeploy/entrypoints/openai/serving_chat.py | 32 +++ .../model_executor/layers/sample/logprobs.py | 40 +++- .../model_executor/layers/sample/meta_data.py | 2 + .../model_executor/layers/sample/sampler.py | 223 +++++++++++++++++- .../model_executor/pre_and_post_process.py | 74 ++++++ fastdeploy/output/stream_transfer_data.py | 4 + fastdeploy/output/token_processor.py | 22 ++ fastdeploy/worker/gpu_model_runner.py | 46 ++++ fastdeploy/worker/output.py | 29 ++- fastdeploy/worker/worker_process.py | 10 + tests/e2e/test_ernie_21b_mtp.py | 174 ++++++++++++++ .../openai/test_max_streaming_tokens.py | 1 + tests/entrypoints/openai/test_serving_chat.py | 2 + tests/metrics/test_new_metrics.py | 2 + .../output/test_process_batch_draft_tokens.py | 2 + tests/output/test_process_batch_output.py | 1 + .../test_process_batch_output_use_zmq.py | 1 + .../test_token_processor_trace_print.py | 2 + 24 files changed, 691 insertions(+), 9 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e56ebe2e704..18efa2586b6 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -211,6 +211,7 @@ def __init__( self.enable_logprob = False self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" + self.enable_keep_sampling_mask = False self.redundant_experts_num = 0 self.seed = 0 self.quantization = None diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 4b4e7aabf6c..02879bcc7fc 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -460,6 +460,14 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + enable_keep_sampling_mask: bool = False + """ + When enabled, the server returns a sparse index list for each generated token, indicating + which vocabulary positions were retained after top_p/top_k sampling, and streams it to + the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]], + where each inner list contains the retained vocabulary indices for a predicted token. + """ + max_logprobs: int = 20 """ Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the @@ -901,6 +909,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--enable-keep-sampling-mask", + action="store_true", + default=EngineArgs.enable_keep_sampling_mask, + help=( + "Enable output of sampling mask as a sparse index list over the vocabulary. " + "For non-MTP decoding, this is a list[int] per token step indicating which " + "vocabulary indices were kept after top_p/top_k sampling. " + "For MTP decoding, this is a list[list[int]] per token step, where each inner " + "list corresponds to one MTP group." + ), + ) model_group.add_argument( "--max-logprobs", type=int, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 5f91a318556..6b500a38bdf 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2501,6 +2501,7 @@ def _start_worker_service(self): "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 283693fae8c..44edea80d34 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -655,6 +655,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0e95cd5e1fb..ccab1ac4114 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -727,6 +727,10 @@ class CompletionOutput: delta_message: Optional[DeltaMessage] = None multipart: Optional[list[Any]] = None num_image_tokens: Optional[int] = None + # Sparse indices of retained vocab ids: + # - Non-MTP: list[int] + # - MTP: list[list[int]] + sampling_mask: Optional[Any] = None def to_dict(self): """ @@ -745,6 +749,7 @@ def to_dict(self): "text": self.text, "reasoning_content": self.reasoning_content, "reasoning_token_num": self.reasoning_token_num, + "sampling_mask": self.sampling_mask, } @classmethod diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b4e87e7a20c..5642daca0c5 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -270,6 +270,8 @@ class ChatCompletionResponseChoice(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] speculate_metrics: Optional[SpeculateMetrics] = None + # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token + sampling_mask: Optional[List[List[int]]] = None class ChatCompletionResponse(BaseModel): @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None + # Per-token index list of retained positions after top_p sampling. + # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step). + sampling_mask: Optional[List[List[int]]] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index eb106f6550f..55bd37412a0 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -435,6 +435,11 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs_res, draft_logprobs=draft_logprobs_res, + sampling_mask=( + self._make_sampling_mask_list(output["sampling_mask"]) + if output.get("sampling_mask") is not None + else None + ), arrival_time=arrival_time, speculate_metrics=output_speculate_metrics, ) @@ -580,6 +585,7 @@ async def chat_completion_full_generator( decoder_base_url=self.tokenizer_base_url, ) prompt_logprobs_res_list = [[] for _ in range(num_choices)] + sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] while num_choices > 0: @@ -660,6 +666,9 @@ async def chat_completion_full_generator( ) if prompt_logprobs_res: prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res)) + output_sampling_mask = output.get("sampling_mask", None) + if output_sampling_mask is not None: + sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask)) speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None) if data["finished"]: trace_carrier = data.get("trace_carrier") @@ -695,6 +704,7 @@ async def chat_completion_full_generator( draft_logprob_contents=draft_logprob_contents, response_processor=response_processor, prompt_logprobs_res_list=prompt_logprobs_res_list, + sampling_mask_list=sampling_mask_list, max_tokens=max_tokens, speculate_metrics=speculate_metrics[idx], ) @@ -749,6 +759,7 @@ async def _create_chat_completion_choice( logprob_contents: list, draft_logprob_contents: list, prompt_logprobs_res_list: list, + sampling_mask_list: list, response_processor: ChatResponseProcessor, max_tokens: int, speculate_metrics: SpeculateMetrics | None, @@ -787,6 +798,11 @@ async def _create_chat_completion_choice( if prompt_logprobs_res_list[idx]: prompt_logprobs_full_res = prompt_logprobs_res_list[idx] + # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens. + sampling_mask_full_res = None + if sampling_mask_list and sampling_mask_list[idx]: + sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step] + num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) @@ -810,6 +826,7 @@ async def _create_chat_completion_choice( logprobs=logprobs_full_res, draft_logprobs=draft_logprobs_full_res, prompt_logprobs=prompt_logprobs_full_res, + sampling_mask=sampling_mask_full_res, finish_reason=finish_reason, speculate_metrics=speculate_metrics, ) @@ -1000,3 +1017,18 @@ def _make_logprob_dict( ) for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) } + + @staticmethod + def _make_sampling_mask_list(sampling_mask) -> List[List[int]]: + """Wrap sampling_mask into a uniform List[List[int]] format. + + sampling_mask is already in sparse-index form (no bool-to-index conversion needed): + Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]] + MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...] + """ + assert sampling_mask is not None + if sampling_mask and isinstance(sampling_mask[0], list): + # MTP: already List[List[int]], return as-is + return sampling_mask + # Non-MTP: already List[int], wrap in outer list for uniform format + return [sampling_mask] diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 559abdb298e..33fbbc01603 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -133,7 +133,7 @@ def build_output_logprobs( is_naive: bool = False, logprobs_mode: str = "default", compute_logprobs_fn: Optional[Callable] = None, -) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: +) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: """ Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. @@ -153,15 +153,12 @@ def build_output_logprobs( scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". Returns: - tuple: (logprobs_tensors, cu_batch_token_offset) + tuple: (logprobs_tensors, cu_batch_token_offset, output_logits) """ num_logprobs = sampling_metadata.max_num_logprobs logprobs_tensors = None cu_batch_token_offset = None - if num_logprobs is None: - return logprobs_tensors, cu_batch_token_offset - real_bsz = share_inputs["seq_lens_this_time"].shape[0] if is_naive: @@ -208,6 +205,10 @@ def build_output_logprobs( mask = idx < share_inputs["accept_num"].unsqueeze(1) token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) + # Adapt for sampling mask + if num_logprobs is None: + return None, None, output_logits + # Compute logprobs with temperature scaling and top_p normalization if logprobs_mode == "raw_logprobs": raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) @@ -217,5 +218,32 @@ def build_output_logprobs( raw_logprobs = F.log_softmax(output_logits, axis=-1) logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + # output_logits use to compute sampling_mask + return logprobs_tensors, cu_batch_token_offset, output_logits + - return logprobs_tensors, cu_batch_token_offset +def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz, logprobs_tensors: LogprobsTensors): + """ + Renormalize logprobs to match truncated sampling distribution. + Args: + logprobs: tensor [B, max_num_logprobs + 1] + logz: [B], log(sum(probs in candidate set K)) for each request. + Can be np.ndarray or paddle.Tensor (CPU pinned memory). + logprobs_tensors: LogprobsTensors + """ + if isinstance(logz, paddle.Tensor): + logz = logz.astype(logprobs.dtype) + else: + logz = paddle.to_tensor(logz, dtype=logprobs.dtype) + # Renormalize: log π_masked = log π_full - log Z_K + # Only normalize valid candidates; padding positions use -inf + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) + ) + # Update logprobs_tensors with normalized values + return LogprobsTensors( + logprob_token_ids=logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=logprobs_tensors.selected_token_ranks, + ) diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 0d7f6915ab4..e2ecb276957 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -66,3 +66,5 @@ class SamplingMetadata: # Add for HPU post-processing seq_lens_encoder: Optional[paddle.Tensor] = None seq_lens_decoder: Optional[paddle.Tensor] = None + # Add for keep sampling mask + keep_sampling_mask: Optional[bool] = None diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 08a33c11096..2729cddba8f 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -19,6 +19,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional +import numpy as np import paddle import paddle.nn.functional as F from paddle import nn @@ -105,6 +106,153 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le return top_p_padding, top_k_padding, topp_seed +def _compute_sampling_mask( + probs: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, int]: + """ + Compute a combined top-k + top-p (nucleus) sampling mask — GPU only, + no D2H transfer or CPU sync. + + Processing order: + 1. Sort probs descending once (shared by top-k and top-p stages). + 2. top-k mask — zero out positions beyond top_k[i] in sorted order. + 3. top-k renorm — renormalise in-place after truncation. + 4. top-p mask — cumsum on the already-sorted renormed probs; no + second argsort needed. + 5. intersect — AND of the two masks, applied on GPU before D2H. + + Either filter can be disabled: + - top-k is skipped when top_k_list is None or all values <= 0. + - top-p[i] >= 1.0 → keep all tokens for that request. + + Args: + probs: [num_reqs, vocab_size] softmax probabilities (GPU). + top_p: [num_reqs, 1] top-p threshold per request (GPU). + top_k: [num_reqs, 1] top-k per request (GPU, int); 0 = disabled. + top_k_list: Python list of top-k values; used to decide whether any + top-k filtering is needed at all. + + Returns: + Tuple of (indices_window, mask_window, logz_per_batch, real_bsz): + - indices_window: [B, max_k] GPU int64 tensor of sorted vocab indices. + - mask_window: [B, max_k] GPU bool tensor, True = retained. + - logz_per_batch: [B] GPU float32 tensor, log(Z_K) per request. + - real_bsz: int, the batch size. + """ + real_bsz = probs.shape[0] + vocab_size = probs.shape[1] + top_p = top_p[:real_bsz] # [B, 1] + + has_top_k = top_k is not None and top_k_list and any(x > 0 for x in top_k_list) + + # ------------------------------------------------------------------ + # Stage 1: single sort — descending by probability. + # sorted_indices / sorted_probs are reused by both top-k and top-p. + # ------------------------------------------------------------------ + sorted_indices = paddle.argsort(probs, axis=-1, descending=True) # [B, V] + sorted_probs = paddle.take_along_axis(probs, sorted_indices, axis=-1) # [B, V] + + # ------------------------------------------------------------------ + # Stage 2: top-k mask (GPU, no D2H) + # ------------------------------------------------------------------ + if has_top_k: + top_k = top_k[:real_bsz] # [B, 1] + # top_k == 0 means "disabled" → keep all columns for that row. + effective_k = paddle.where(top_k > 0, top_k, paddle.full_like(top_k, vocab_size)) + + # Relax: also keep positions whose prob ties with the k-th element. + # boundary index (0-based) = effective_k - 1, clamped to [0, V-1]. + k_idx = (effective_k - 1).clip(min=0).squeeze(-1).astype("int64") # [B] k-th index + batch_idx = paddle.arange(k_idx.shape[0], dtype="int64") # [B] bs index + boundary_prob = sorted_probs[batch_idx, k_idx].unsqueeze(-1) # [B, 1] min_probs in topk candidates + topk_mask = sorted_probs >= boundary_prob # [B, V] True = retained by top-k + + # Zero out tail, then renorm row-wise. + masked_sorted_probs = paddle.where(topk_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + row_sums = masked_sorted_probs.sum(axis=-1, keepdim=True).clip(min=1e-9) + renorm_sorted_probs = masked_sorted_probs / row_sums # [B, V] + else: + topk_mask = None + renorm_sorted_probs = sorted_probs + + # ------------------------------------------------------------------ + # Stage 3: top-p mask on already-sorted renormed probs (no re-sort). + # ------------------------------------------------------------------ + cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V] + topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V] + # When top_p[i] >= 1.0, keep the entire row. + topp_mask = paddle.where( + (top_p >= 1.0).expand_as(topp_mask), + paddle.ones_like(topp_mask), + topp_mask, + ) + + # Extend mask to cover sort tie-breaking: include all tokens whose + # probability >= the boundary token's probability (last retained + # in sorted order). In descending-sorted probs this just extends + # the contiguous True block by the run of equal-prob tokens. + k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1] + # boundary_idx = last True position (k-1), clamp for safety + boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1] + boundary_prob = paddle.take_along_axis( + renorm_sorted_probs, + boundary_idx, + axis=-1, + ) # [B, 1] + topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob) + + # ------------------------------------------------------------------ + # Stage 4: intersect on GPU, then minimal D2H. + # ------------------------------------------------------------------ + final_mask = topk_mask & topp_mask if has_top_k else topp_mask # [B, V] + + k_per_row = final_mask.astype("int32").sum(axis=-1) # [B] + max_k = k_per_row.max().reshape([-1]) # [1], stays on GPU + + # ------------------------------------------------------------------ + # Stage 5: compute logZ_K for renormalization + # Z_K = sum(probs[i] * final_mask[i]) for each request i + # logZ_K = log(Z_K), with small constant to avoid log(0) + # ------------------------------------------------------------------ + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10) # [B], GPU + + # Slice only the leading max_k columns on GPU — typically max_k << vocab_size. + # All outputs stay on GPU; D2H is deferred to save_output via async copy_. + indices_window = sorted_indices.slice([1], [0], max_k) # [B, max_k] + mask_window = final_mask.slice([1], [0], max_k) # [B, max_k] + + return indices_window, mask_window, logz_per_batch, real_bsz + + +def _extract_sparse_indices( + indices_window_cpu: np.ndarray, + mask_window_cpu: np.ndarray, + real_bsz: int, +) -> List[np.ndarray]: + """ + Extract per-request sparse retained-token indices from CPU numpy arrays. + + This is the CPU-side counterpart of _compute_sampling_mask. It should be + called after the sampling_mask_event has been synchronized, so that the + async D2H copy is guaranteed to be complete. + + Args: + indices_window_cpu: [B, max_k] int64 numpy array of sorted vocab indices. + mask_window_cpu: [B, max_k] bool numpy array, True = retained. + real_bsz: batch size (number of rows to process). + + Returns: + List of length real_bsz; element i is a 1-D int64 numpy array of + retained vocab indices for request i. + """ + return [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] + + class GuidedDecoding: """ processor for guided decoding. @@ -554,6 +702,34 @@ def forward_cuda( _record_logits_diagnostic(logits, tag="post_penalty_logits", probs=probs) probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) + + # Compute sampling mask BEFORE top_k_top_p_sampling modifies probs. + # All GPU ops; D2H is done via async copy_ with event sync in save_output. + sampling_mask = None + logz_per_batch = None + sampling_mask_event = None + if sampling_metadata.keep_sampling_mask: + sampling_mask_event = paddle.device.cuda.create_event() + indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( + probs, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + # Allocate CPU pinned tensors and async copy + indices_window_cpu = paddle.empty_like( + indices_window_gpu, dtype=indices_window_gpu.dtype, device="cpu" + ).pin_memory() + mask_window_cpu = paddle.empty_like( + mask_window_gpu, dtype=mask_window_gpu.dtype, device="cpu" + ).pin_memory() + indices_window_cpu.copy_(indices_window_gpu, False) + mask_window_cpu.copy_(mask_window_gpu, False) + # Record event — sync this event before reading CPU buffers + sampling_mask_event.record() + # Store deferred GPU→CPU data; sparse extraction happens in save_output + sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) + _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, @@ -577,6 +753,9 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, logits=logits, + sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, + sampling_mask_event=sampling_mask_event, ) return sampler_output @@ -1029,9 +1208,10 @@ def forward_cuda( reject_all_drafts, ) + keep_sampling_mask = sampling_metadata.keep_sampling_mask # Build logprobs via unified path (outside of sampling logic) - if sampling_metadata.max_num_logprobs is not None: - logprobs_tensors, cu_batch_token_offset = build_output_logprobs( + if sampling_metadata.max_num_logprobs is not None or keep_sampling_mask: + logprobs_tensors, cu_batch_token_offset, target_logits = build_output_logprobs( logits, sampling_metadata, share_inputs, @@ -1042,6 +1222,45 @@ def forward_cuda( sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() + if keep_sampling_mask: + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1]) + # Derive target probs from already-extracted target_logits; avoids a second kernel call. + target_probs = F.softmax(target_logits, axis=-1) + # Compute sampling mask at accepted token positions. + # Expand top_p from [batch, 1] to [total_accepted, 1]. + accept_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + accept_top_k = None + if ( + sampling_metadata.top_k is not None + and sampling_metadata.top_k_list + and any(x > 0 for x in sampling_metadata.top_k_list) + ): + accept_top_k = ( + sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( + target_probs, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, + ) + # Async D2H copy with event + indices_window_cpu = paddle.empty_like( + indices_window_gpu, dtype=indices_window_gpu.dtype, device="cpu" + ).pin_memory() + mask_window_cpu = paddle.empty_like( + mask_window_gpu, dtype=mask_window_gpu.dtype, device="cpu" + ).pin_memory() + indices_window_cpu.copy_(indices_window_gpu, False) + mask_window_cpu.copy_(mask_window_gpu, False) + sampling_mask_event = paddle.device.cuda.create_event() + sampling_mask_event.record() + sampler_output.sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) + sampler_output.logz_per_batch = logz_per_batch + sampler_output.sampling_mask_event = sampling_mask_event return sampler_output def forward_xpu( diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 29fc4235381..0db31252bbd 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -114,7 +114,11 @@ from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) +from fastdeploy.model_executor.layers.sample.logprobs import ( + logprobs_renormalize_with_logz, +) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import _extract_sparse_indices from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput @@ -221,6 +225,7 @@ def _build_stream_transfer_data( pooler_outputs: List[PoolingSequenceGroupOutput] = None, logprobs: Optional[LogprobsTensors] = None, prompt_logprobs_list: Optional[LogprobsTensors] = None, + sampling_mask: Optional[List[np.ndarray]] = None, ): """Split output_tokens and output""" @@ -230,6 +235,8 @@ def _build_stream_transfer_data( output_tokens = output_tokens.numpy().reshape([-1]) output_tokens_lists = np.split(output_tokens, output_tokens.shape[0]) + sampling_mask_list = sampling_mask + for bid, output_token_per_sample in enumerate(output_tokens_lists): stream_transfer_data = StreamTransferData( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid @@ -238,6 +245,8 @@ def _build_stream_transfer_data( stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1) if prompt_logprobs_list: stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] + if sampling_mask_list is not None: + stream_transfer_data.sampling_mask = sampling_mask_list[bid] stream_transfer_datas.append(stream_transfer_data) elif pooler_outputs is not None: for bid, pooler_output in enumerate(pooler_outputs): @@ -373,6 +382,9 @@ def post_process_normal( model_output.is_block_step, ) + # logprobs renormalization with logz is deferred to save_output, + # so that async D2H of logz_per_batch has more time to complete. + def save_output_normal( model_output: ModelOutputData, @@ -380,7 +392,28 @@ def save_output_normal( share_inputs: Dict[str, paddle.Tensor], async_output_queue: queue.Queue = None, save_each_rank: bool = False, + sampling_mask_async_queue: Optional[queue.Queue] = None, ): + # Resolve deferred async D2H: sync event once at the top so all paths below + # can safely read sampling_mask and logz_per_batch. + if sampler_output.sampling_mask_event is not None: + sampler_output.sampling_mask_event.synchronize() + # Extract sparse indices from pinned CPU buffers + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) + sampler_output.sampling_mask_event = None + + # Renormalize logprobs with logz (deferred from post_process for better overlap). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + # Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -398,6 +431,7 @@ def save_output_normal( recover_share_inputs_map["sampled_token_ids"], logprobs=sampler_output.logprobs_tensors, prompt_logprobs_list=model_output.prompt_logprobs_list, + sampling_mask=sampler_output.sampling_mask, ) async_output_queue.put(output) else: @@ -434,6 +468,13 @@ def save_output_normal( recover_share_inputs_map["last_preempted_idx"], model_output.mp_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled (async via background thread). + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask already resolved at function entry. + assert ( + sampling_mask_async_queue is not None + ), "sampling_mask_async_queue must not be None when sampling_mask is enabled" + sampling_mask_async_queue.put((sampler_output.sampling_mask, None)) share_inputs["last_preempted_idx"][:] = 0 @@ -525,6 +566,9 @@ def post_process_specualate( model_output.max_dec_len, # max_dec_len ) + # logprobs renormalization with logz is deferred to save_output, + # so that async D2H of logz_per_batch has more time to complete. + def save_output_specualate( sampler_output: SamplerOutput, @@ -534,8 +578,28 @@ def save_output_specualate( local_rank: int, tensor_parallel_rank: int, save_each_rank: bool = False, + sampling_mask_async_queue: Optional[queue.Queue] = None, is_mtp_prefill: bool = False, ): + # Resolve deferred async D2H: sync event once at the top so all paths below + # can safely read sampling_mask and logz_per_batch. + if sampler_output.sampling_mask_event is not None: + sampler_output.sampling_mask_event.synchronize() + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) + sampler_output.sampling_mask_event = None + + # Renormalize logprobs with logz (deferred from post_process for better overlap). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + if is_mtp_prefill: if tensor_parallel_rank == 0: skip_chunk_prefill = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) @@ -656,6 +720,16 @@ def save_output_specualate( model_output.mp_rank, save_each_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled (async via background thread). + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask already resolved at function entry. + # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req). + real_bsz = model_output.accept_num.shape[0] + accept_nums = model_output.accept_num[:real_bsz].flatten().tolist() + assert ( + sampling_mask_async_queue is not None + ), "sampling_mask_async_queue must not be None when sampling_mask is enabled" + sampling_mask_async_queue.put((sampler_output.sampling_mask, accept_nums)) share_inputs["last_preempted_idx"][:] = 0 diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index b32e01c954f..dce21bb5963 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -46,3 +46,7 @@ class StreamTransferData: accept_num: Optional[np.array] = None # [num_reqs, hidden_size] pooler_output: Optional[np.array] = None + # 1-D int32 numpy array of vocab indices retained by top_p/top_k for + # this request. Sparse format: only retained positions, not a dense + # vocab-sized bool mask. + sampling_mask: Optional[np.array] = None diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 85e54647b7e..7195bf83aa8 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -83,6 +83,14 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False) + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask: + rank_id = self.cfg.parallel_config.local_data_parallel_id + port = self.cfg.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_server = ZmqIpcServer( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL + ) + llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}") self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: @@ -357,6 +365,8 @@ def _process_batch_output_use_zmq(self, receive_datas): result.prompt_logprobs = stream_data.prompt_logprobs except Exception as e: llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") + if getattr(stream_data, "sampling_mask", None) is not None: + result.outputs.sampling_mask = stream_data.sampling_mask.tolist() if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -734,6 +744,15 @@ def _process_batch_output(self): batch = self.output_tokens[1, 0] tokens = tokens[2 : batch + 2] + # Receive sampling constraints per request from ZMQ side-channel (if enabled). + # The worker sends a dict {batch_id: sparse_vocab_indices} each step, + # where the value is a list[int] or list[list[int]] of allowed token ids + sampling_masks_per_request = {} + if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"): + _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True) + if mask_data is not None and isinstance(mask_data, dict): + sampling_masks_per_request = mask_data + batch_result = list() # reschedule for i in range(batch): @@ -868,6 +887,9 @@ def _process_batch_output(self): result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0) result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0) + if self.use_sampling_mask and i in sampling_masks_per_request: + result.outputs.sampling_mask = sampling_masks_per_request[i] + if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index fc3dbc4ab2c..a9502be7509 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -126,6 +126,7 @@ def __init__( self.spec_method = self.fd_config.speculative_config.method self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob + self.enable_keep_sampling_mask = fd_config.model_config.enable_keep_sampling_mask self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size @@ -236,6 +237,27 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None + # ZMQ side-channel for sampling_mask in non-FD_USE_GET_SAVE_OUTPUT_V1 path + self.sampling_mask_zmq_client = None + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.enable_keep_sampling_mask: + rank_id = self.parallel_config.local_data_parallel_id + port = self.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_client = ZmqIpcClient( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PUSH + ) + self.sampling_mask_zmq_client.connect() + logger.info(f"create send zmq sampling_mask_output_rank_{rank_id}_{port}") + + self.sampling_mask_async_queue = None + if self.sampling_mask_zmq_client is not None: + self.sampling_mask_async_queue = queue.Queue() + self._sampling_mask_send_thread = Thread( + target=self._async_sampling_mask_send_loop, + daemon=True, + name="WorkerAsyncSamplingMaskSend", + ) + self._sampling_mask_send_thread.start() + self.zmq_client = None self.async_output_queue = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -286,6 +308,27 @@ def _async_output_busy_loop(self): except Exception as e: logger.exception("Exception in async output loop: %s", e) + def _async_sampling_mask_send_loop(self): + """Background thread: serialize and send sampling_mask over ZMQ.""" + while True: + try: + mask_list, accept_nums = self.sampling_mask_async_queue.get() + if accept_nums is None: + # Normal (non-speculative) path + mask_dict = {i: arr.tolist() for i, arr in enumerate(mask_list)} + else: + # Speculative path: group by accept_num + mask_dict = {} + offset = 0 + for i, n in enumerate(accept_nums): + n = int(n) + if n > 0: + mask_dict[i] = [arr.tolist() for arr in mask_list[offset : offset + n]] + offset += n + self.sampling_mask_zmq_client.send_pyobj(mask_dict) + except Exception as e: + logger.exception("Exception in async sampling_mask send loop: %s", e) + def exist_prefill(self): """ check whether prefill stage exist @@ -1233,6 +1276,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, + keep_sampling_mask=self.enable_keep_sampling_mask, ) return token_num, token_num_event @@ -2486,6 +2530,7 @@ def _save_model_output( local_rank=self.local_rank, tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, + sampling_mask_async_queue=self.sampling_mask_async_queue, is_mtp_prefill=( self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" ), @@ -2497,6 +2542,7 @@ def _save_model_output( share_inputs=self.share_inputs, async_output_queue=self.async_output_queue, save_each_rank=self.parallel_config.use_ep, + sampling_mask_async_queue=self.sampling_mask_async_queue, ) def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 365fec12475..3f247d66197 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -15,8 +15,9 @@ """ from dataclasses import dataclass, field -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional +import numpy as np import paddle @@ -178,6 +179,32 @@ class SamplerOutput: token_num_per_batch: Optional[paddle.Tensor] = None cu_batch_token_offset: Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None + # Sparse sampling mask for top_p/top_k: + # Before sampling_mask_event sync: stored as a deferred tuple + # (indices_window_cpu, mask_window_cpu, real_bsz) where the CPU tensors + # are pinned-memory targets of async D2H copies. + # After event sync + _extract_sparse_indices: converted to the final + # List[np.ndarray] format below. + # - Non-speculative decoding: per-request mask. This is a list of length + # num_reqs, where element i is a 1-D int32 numpy array of vocab indices + # retained by top_p/top_k for request i. Replaces the previous dense + # [num_reqs, vocab_size] bool tensor. + # - Speculative decoding: flattened per-accepted-token mask. This may be + # stored as a list aligned with all accepted tokens + # (e.g. length = total_accepted_tokens) and is regrouped by accept_num + # (number of accepted tokens per request) in post-processing before + # being sent back as per-request data. + # Callers MUST NOT assume this is always shaped by num_reqs; they should + # check whether the current path is speculative or non-speculative when + # interpreting the dimension. + sampling_mask: Optional[List[np.ndarray]] = None + # logZ_K for each request: log(sum(probs in candidate set K)) + # Used for renormalizing logprobs to match the truncated sampling distribution. + # Shape: [num_reqs] + logz_per_batch: Optional[np.ndarray] = None + # CUDA event that guards async D2H copy of sampling_mask / logz_per_batch. + # Must be synchronized before reading sampling_mask or logz_per_batch. + sampling_mask_event: Optional[object] = None @dataclass diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 3f2a1fcf0dd..be4883bdefb 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1095,6 +1095,16 @@ def parse_args(): help="Maximum tokens per item in mm input.", ) + parser.add_argument( + "--enable_keep_sampling_mask", + "--enable-keep-sampling-mask", + action="store_true", + help=( + "Enable output of keep_sampling_mask as sparse vocab index list per token step " + "(Non-MTP: List[int]; MTP: List[List[int]])." + ), + ) + parser.add_argument( "--num_cpu_blocks", type=int, diff --git a/tests/e2e/test_ernie_21b_mtp.py b/tests/e2e/test_ernie_21b_mtp.py index dc60a213217..0ac4ec789af 100644 --- a/tests/e2e/test_ernie_21b_mtp.py +++ b/tests/e2e/test_ernie_21b_mtp.py @@ -83,6 +83,7 @@ def setup_and_run_server(): json.dumps(speculative_config), "--graph-optimization-config", '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + "--enable-keep-sampling-mask", ] # Start subprocess in new process group @@ -366,3 +367,176 @@ def test_mtp_accept_ratio(api_url): prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" + + +def _assert_sampling_mask_format(sampling_mask, max_tokens): + """验证 sampling_mask 字段格式的公共辅助函数。 + + sampling_mask 是 List[List[int]]: + - 外层列表长度 == 生成的 token 数(completion_tokens),对应 MTP 每步可接受多个 token + - 内层列表为保留位置的词汇表索引(int),非空且单调递增 + """ + assert sampling_mask is not None, "sampling_mask 不应为 None" + assert isinstance(sampling_mask, list), "sampling_mask 应为 list" + assert len(sampling_mask) > 0, "sampling_mask 不应为空" + assert len(sampling_mask) <= max_tokens, "sampling_mask 长度不应超过 max_tokens" + + for token_mask in sampling_mask: + assert isinstance(token_mask, list), f"每个 token 的 mask 应为 list,实际: {type(token_mask)}" + assert len(token_mask) > 0, "每个 token 的 mask 不应为空(至少保留采样到的 token)" + for idx in token_mask: + assert isinstance(idx, int), f"mask 中的每个元素应为 int,实际: {type(idx)}" + assert idx >= 0, f"mask 索引不应为负数,实际: {idx}" + + +def test_keep_sampling_mask_stream(api_url): + """测试流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. 每个非空 chunk 的 choices[0].sampling_mask 格式为 List[List[int]] + 2. 内层列表包含词汇表保留位置的索引,非空且单调递增 + 3. 最终 sampling_mask 总长度等于 completion_tokens + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + all_sampling_masks = [] + for chunk in chunks[:-1]: # 最后一个 chunk 是 usage-only + choice = chunk["choices"][0] + # 仅当 delta 有实际内容时才应携带 sampling_mask(首个 role chunk 内容为空,不含该字段) + has_content = bool(choice.get("delta", {}).get("content")) + mask = choice.get("sampling_mask") + if has_content: + assert mask is not None, f"有内容的 chunk 缺少 sampling_mask 字段: {choice}" + if mask is not None: + assert isinstance(mask, list), f"sampling_mask 应为 list,实际: {type(mask)}" + for token_mask in mask: + assert isinstance(token_mask, list), "每个 token mask 应为 list" + assert len(token_mask) > 0, "每个 token mask 不应为空" + for idx in token_mask: + assert isinstance(idx, int) and idx >= 0, f"mask 索引应为非负 int,实际: {idx}" + all_sampling_masks.extend(mask) + + # 最后一个 chunk 携带 usage 信息 + usage = chunks[-1].get("usage") + if usage: + completion_tokens = usage["completion_tokens"] + assert ( + len(all_sampling_masks) == completion_tokens + ), f"sampling_mask 总长度 {len(all_sampling_masks)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_non_stream(api_url): + """测试非流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. choices[0].sampling_mask 格式为 List[List[int]] + 2. 长度等于 completion_tokens + 3. 内层列表包含非负递增的词汇表索引 + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + + response = send_request(url=api_url, payload=payload).json() + assert "choices" in response, f"响应缺少 choices 字段: {response}" + choice = response["choices"][0] + assert "sampling_mask" in choice, f"choice 缺少 sampling_mask 字段: {choice}" + + sampling_mask = choice["sampling_mask"] + completion_tokens = response["usage"]["completion_tokens"] + _assert_sampling_mask_format(sampling_mask, max_tokens) + assert ( + len(sampling_mask) == completion_tokens + ), f"sampling_mask 长度 {len(sampling_mask)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_top_p_1_stream(api_url): + """测试 top_p=1.0 时流式响应的 sampling_mask(MTP 模式)。 + + top_p=1.0 表示保留全部词汇,每个 token mask 应包含所有词汇表位置。 + 验证 mask 非空且每个内层列表长度 > 1(至少保留多个候选 token)。 + """ + max_tokens = 10 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 1.0, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "1+1="}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + for chunk in chunks[:-1]: + choice = chunk["choices"][0] + mask = choice.get("sampling_mask") + if mask is not None: + for token_mask in mask: + assert len(token_mask) > 1, "top_p=1.0 时每个 token 的候选集应大于 1" + + +def test_keep_sampling_mask_consistent_with_top_p(api_url): + """对比 top_p=0.1 与 top_p=0.9 时 sampling_mask 的候选集大小(非流式,MTP 模式)。 + + top_p 越小,保留的候选 token 越少,平均 mask 长度应更短。 + """ + max_tokens = 15 + + def get_avg_mask_len(top_p): + payload = { + "model": "default", + "temperature": 1.0, + "top_p": top_p, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请列举三种编程语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + resp = send_request(url=api_url, payload=payload).json() + mask = resp["choices"][0].get("sampling_mask") + if not mask: + return 0 + return sum(len(m) for m in mask) / len(mask) + + avg_small = get_avg_mask_len(0.1) + avg_large = get_avg_mask_len(0.9) + assert avg_small <= avg_large, f"top_p=0.1 的平均 mask 长度 ({avg_small:.1f}) 应 <= top_p=0.9 ({avg_large:.1f})" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index d98e79b74f2..bd7b6482b09 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -577,6 +577,7 @@ async def test_create_chat_completion_choice(self): response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], speculate_metrics=None, + sampling_mask_list=None, ) expected = case["expected"] diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 1b33405503f..12f20f39eab 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -398,6 +398,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice.finish_reason, "recover_stop") @@ -421,6 +422,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice_length.finish_reason, "length") diff --git a/tests/metrics/test_new_metrics.py b/tests/metrics/test_new_metrics.py index 030acaf4299..f650d6d7d7c 100644 --- a/tests/metrics/test_new_metrics.py +++ b/tests/metrics/test_new_metrics.py @@ -54,6 +54,8 @@ def test_cache_metrics_update_history(self, mock_main_process_metrics): def setUp(self): """为 TokenProcessor 测试设置通用的 mock 对象。""" self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] self.mock_cached_generated_tokens = MagicMock() self.mock_engine_worker_queue = MagicMock() self.mock_split_connector = MagicMock() diff --git a/tests/output/test_process_batch_draft_tokens.py b/tests/output/test_process_batch_draft_tokens.py index 3686dd1b64b..eef5df62cc9 100644 --- a/tests/output/test_process_batch_draft_tokens.py +++ b/tests/output/test_process_batch_draft_tokens.py @@ -30,6 +30,8 @@ def setUp(self): # 模拟 cfg cfg = MagicMock() cfg.speculative_config = MagicMock() + cfg.parallel_config.local_data_parallel_id = 0 + cfg.parallel_config.engine_worker_queue_port = ["9700"] cfg.speculative_config.method = "mtp" cfg.speculative_config.num_speculative_tokens = 3 cfg.model_config = MagicMock() diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 46282cd386a..9398e07d9f5 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -166,6 +166,7 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): processor.total_step_per_request = {} processor.accept_token_num_per_head_per_request = {} processor.accept_token_num_per_head = [0] * MAX_DRAFT_TOKENS + processor.use_sampling_mask = False # processor._recycle_resources = Mock() diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index 07826e6f0eb..8244bb06bbf 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -31,6 +31,7 @@ def setUp(self): self.cfg.model_config.enable_logprob = True self.cfg.speculative_config.method = None self.cfg.parallel_config.local_data_parallel_id = 0 + self.cfg.parallel_config.engine_worker_queue_port = ["9700"] self.cached_generated_tokens = MagicMock() self.engine_worker_queue = MagicMock() self.split_connector = MagicMock() diff --git a/tests/output/test_token_processor_trace_print.py b/tests/output/test_token_processor_trace_print.py index 9ba9b45dfae..018038143f3 100644 --- a/tests/output/test_token_processor_trace_print.py +++ b/tests/output/test_token_processor_trace_print.py @@ -23,6 +23,8 @@ class TestTokenProcessorMetrics: def setup_method(self): self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] self.mock_cached_tokens = MagicMock() self.mock_engine_queue = MagicMock() self.mock_split_connector = MagicMock() From eb92613c9770e7feeb503477e9ce59091262238d Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:40:46 +0800 Subject: [PATCH 057/143] [Cherry-Pick][BugFix] Fix save_output_specualate parameter bugs in suffix decoding (#7566) (#7569) * fix bugs in suffix --- fastdeploy/model_executor/pre_and_post_process.py | 3 ++- fastdeploy/spec_decode/base.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0db31252bbd..fb720affb58 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -574,12 +574,12 @@ def save_output_specualate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, - proposer_share_inputs: ProposerInputBatch, local_rank: int, tensor_parallel_rank: int, save_each_rank: bool = False, sampling_mask_async_queue: Optional[queue.Queue] = None, is_mtp_prefill: bool = False, + proposer_share_inputs: Optional[ProposerInputBatch] = None, ): # Resolve deferred async D2H: sync event once at the top so all paths below # can safely read sampling_mask and logz_per_batch. @@ -601,6 +601,7 @@ def save_output_specualate( ) if is_mtp_prefill: + assert proposer_share_inputs is not None if tensor_parallel_rank == 0: skip_chunk_prefill = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) if sampler_output.logprobs_tensors is None: diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 8db764fcf12..08553411188 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -118,7 +118,7 @@ def prepare_dummy_speculative_drafts( stop = share_inputs["stop_flags"][0].item() if not stop: - share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5 + share_inputs["draft_tokens"][:batch_size, : max_fake_drafts + 1] = 5 share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1 else: share_inputs["seq_lens_this_time"][:batch_size] = 0 diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a9502be7509..5e09d2a5d07 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2526,7 +2526,6 @@ def _save_model_output( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, - proposer_share_inputs=self.proposer.model_inputs, local_rank=self.local_rank, tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, @@ -2534,6 +2533,7 @@ def _save_model_output( is_mtp_prefill=( self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" ), + proposer_share_inputs=self.proposer.model_inputs if self.spec_method == SpecMethod.MTP else None, ) else: save_output_normal( From 10f5a20855d7cc7f52c2f3b8120a120d0e0c35bf Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:43:27 +0800 Subject: [PATCH 058/143] Cache queue support ipc (#7589) --- .../cache_manager/cache_transfer_manager.py | 7 +++++-- .../cache_manager/prefix_cache_manager.py | 10 ++++++++- fastdeploy/engine/common_engine.py | 21 ++++++++++++------- .../engine/sched/resource_manager_v1.py | 4 ++++ .../inter_communicator/engine_cache_queue.py | 10 +++++---- fastdeploy/trace/constants.py | 8 +++++++ 6 files changed, 45 insertions(+), 15 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 85a113adf66..bd347b384e6 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -208,9 +208,12 @@ def __init__(self, args): self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 self.ctrl_output_queue = None - address = (args.pod_ip, args.cache_queue_port) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + engine_cache_queue_address = (args.pod_ip, args.cache_queue_port) + else: + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{args.cache_queue_port}.sock" self.cache_task_queue = EngineCacheQueue( - address=address, + address=engine_cache_queue_address, is_server=False, num_client=args.mp_num, client_id=self.rank, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 52cd83682eb..1da592f891b 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -37,6 +37,8 @@ from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.trace.constants import LoggingEventName +from fastdeploy.trace.trace_logger import print as trace_print from fastdeploy.utils import get_hash_str, get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") @@ -211,8 +213,12 @@ def launch_cache_manager( create=True, ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + engine_cache_queue_address = (pod_ip, cache_config.local_cache_queue_port) + else: + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{cache_config.local_cache_queue_port}.sock" self.cache_task_queue = EngineCacheQueue( - address=(pod_ip, cache_config.local_cache_queue_port), + address=engine_cache_queue_address, authkey=b"cache_queue_service", is_server=False, num_client=tensor_parallel_size, @@ -1157,6 +1163,7 @@ def write_cache_to_storage(self, request: Request): if not keys: return + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) gpu_block_ids = request.block_tables[: len(keys)] logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") write_storage_task = WriteStorageTask( @@ -1170,6 +1177,7 @@ def write_cache_to_storage(self, request: Request): self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) def write_cache_to_storage_decode(self, request: Request): """ diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 6b500a38bdf..090eabd40d8 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -453,15 +453,19 @@ def start_worker_queue_service(self, start_queue): start queue service for engine worker communication """ if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: - address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port) + engine_worker_queue_address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port) + engine_cache_queue_address = (self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port) else: - address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock" + engine_worker_queue_address = ( + f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock" + ) + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{self.cfg.cache_config.local_cache_queue_port}.sock" if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": if start_queue: - self.llm_logger.info(f"Starting engine worker queue server service at {address}") + self.llm_logger.info(f"Starting engine worker queue server service at {engine_worker_queue_address}") self.engine_worker_queue_server = EngineWorkerQueue( - address=address, + address=engine_worker_queue_address, is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, @@ -471,7 +475,7 @@ def start_worker_queue_service(self, start_queue): self.cfg.parallel_config.local_engine_worker_queue_port = ( self.engine_worker_queue_server.get_server_port() ) - address = ( + engine_worker_queue_address = ( self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port, ) @@ -481,17 +485,18 @@ def start_worker_queue_service(self, start_queue): f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}" ) self.cache_task_queue = EngineCacheQueue( - address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port), + address=engine_cache_queue_address, authkey=b"cache_queue_service", is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) - self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port() + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port() self.engine_worker_queue = EngineWorkerQueue( - address=address, + address=engine_worker_queue_address, is_server=False, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index cf65edf717e..39ad13f0b66 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1264,6 +1264,8 @@ def get_prefix_cached_blocks(self, request: Request): Match and fetch cache for a task. """ try: + trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_START, request.request_id, getattr(request, "user", "")) + (common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks( request, self.config.cache_config.block_size ) @@ -1316,6 +1318,8 @@ def get_prefix_cached_blocks(self, request: Request): main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.cpu_cache_token_num) + trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_END, request.request_id, getattr(request, "user", "")) + return True except Exception as e: llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...") diff --git a/fastdeploy/inter_communicator/engine_cache_queue.py b/fastdeploy/inter_communicator/engine_cache_queue.py index 535dc1dc4c3..97a08cd4e88 100644 --- a/fastdeploy/inter_communicator/engine_cache_queue.py +++ b/fastdeploy/inter_communicator/engine_cache_queue.py @@ -24,7 +24,7 @@ Value, ValueProxy, ) -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union from fastdeploy.utils import get_logger @@ -39,7 +39,7 @@ class EngineCacheQueue: def __init__( self, - address: Tuple[str, int] = ("127.0.0.1", 56666), + address: Union[Tuple[str, int], str] = ("127.0.0.1", 56666), authkey: bytes = b"cache_queue_service", is_server: bool = False, num_client: int = 1, # tensor parallel size @@ -62,7 +62,7 @@ def __init__( TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue. """ - self.address: Tuple[str, int] = address + self.address: Union[Tuple[str, int], str] = address self.authkey: bytes = authkey self.is_server: bool = is_server self.num_client: int = num_client @@ -210,8 +210,10 @@ class QueueManager(BaseManager): QueueManager.register("get_swap_storage_to_gpu_barrier") QueueManager.register("get_swap_to_storage_barrier") + logger.info(f"Try to connect QueueManager, address: {self.address}") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() + logger.info(f"Connected to QueueManager, address: {self.address}") # Get proxy objects for shared resources self.transfer_task_queue = self.manager.get_transfer_task_queue(self.local_data_parallel_id) @@ -246,7 +248,7 @@ def get_server_port(self) -> int: Returns the actual port that the server instance is listening on. Calling this method only makes sense on instances where is_server=True. """ - if not self.is_server: + if not self.is_server or isinstance(self.address, str): raise RuntimeError("Only the server instance can provide the port.") return self.address[1] diff --git a/fastdeploy/trace/constants.py b/fastdeploy/trace/constants.py index eaf54d68085..b8ffc94271a 100644 --- a/fastdeploy/trace/constants.py +++ b/fastdeploy/trace/constants.py @@ -26,12 +26,16 @@ class LoggingEventName(Enum): REQUEST_QUEUE_START = "REQUEST_QUEUE_START" REQUEST_QUEUE_END = "REQUEST_QUEUE_END" RESOURCE_ALLOCATE_START = "RESOURCE_ALLOCATE_START" + PREPARE_PREFIX_CACHE_START = "PREPARE_PREFIX_CACHE_START" + PREPARE_PREFIX_CACHE_END = "PREPARE_PREFIX_CACHE_END" RESOURCE_ALLOCATE_END = "RESOURCE_ALLOCATE_END" REQUEST_SCHEDULE_END = "REQUEST_SCHEDULE_END" INFERENCE_START = "INFERENCE_START" FIRST_TOKEN_GENERATED = "FIRST_TOKEN_GENERATED" DECODE_START = "DECODE_START" INFERENCE_END = "INFERENCE_END" + WRITE_CACHE_TO_STORAGE_START = "WRITE_CACHE_TO_STORAGE_START" + WRITE_CACHE_TO_STORAGE_END = "WRITE_CACHE_TO_STORAGE_END" POSTPROCESSING_START = "POSTPROCESSING_START" POSTPROCESSING_END = "POSTPROCESSING_END" PREEMPTED = "PREEMPTED" @@ -57,6 +61,8 @@ class StageName(Enum): LoggingEventName.REQUEST_QUEUE_START: StageName.SCHEDULE, LoggingEventName.REQUEST_QUEUE_END: StageName.SCHEDULE, LoggingEventName.RESOURCE_ALLOCATE_START: StageName.SCHEDULE, + LoggingEventName.PREPARE_PREFIX_CACHE_START: StageName.SCHEDULE, + LoggingEventName.PREPARE_PREFIX_CACHE_END: StageName.SCHEDULE, LoggingEventName.RESOURCE_ALLOCATE_END: StageName.SCHEDULE, LoggingEventName.REQUEST_SCHEDULE_END: StageName.SCHEDULE, LoggingEventName.INFERENCE_START: StageName.PREFILL, @@ -65,6 +71,8 @@ class StageName(Enum): LoggingEventName.PREEMPTED: StageName.DECODE, LoggingEventName.RESCHEDULED_INFERENCE_START: StageName.DECODE, LoggingEventName.INFERENCE_END: StageName.DECODE, + LoggingEventName.WRITE_CACHE_TO_STORAGE_START: StageName.POSTPROCESSING, + LoggingEventName.WRITE_CACHE_TO_STORAGE_END: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_START: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_END: StageName.POSTPROCESSING, } From af68b26a0c96281b90740850af407198fd2bf138 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 11:12:46 +0800 Subject: [PATCH 059/143] [RL] Remove redundant barrier and optimize model weights signal broadcast (#7545) (#7592) * rm signal sync * overlap Co-authored-by: sunxin <68891411+Sunny-bot1@users.noreply.github.com> --- fastdeploy/worker/worker_process.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index be4883bdefb..865cbb909db 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -316,10 +316,9 @@ def update_weights_from_tensor(self, mmap_infos): self.experts_manager.tensor_infos = None def _broadcast_model_weights_signal(self, src: int, group) -> int: - model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32") - paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group) - value = model_weights_signal_tensor.numpy()[0] - return int(value) + signal_list = [self.model_weights_signal[0]] + paddle.distributed.broadcast_object_list(signal_list, src=src, group=group) + return int(signal_list[0]) def _tp_barrier_wait(self): if current_platform.is_xpu() or self.enable_overlap_schedule: @@ -507,9 +506,9 @@ def event_loop_normal(self) -> None: self._tp_barrier_wait() if tp_size > 1 else None if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: - if self.ranks > 1: - paddle.distributed.barrier() if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL: + if self.ranks > 1: + paddle.distributed.barrier() logger.info( f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]" ) From 4cbae626eff1f946435f24d9b5aeb363a8abd5df Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:36:51 +0800 Subject: [PATCH 060/143] Use triton qk_norm both in Prefill and Decode (#7213) (#7306) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: K11OntheBoat Co-authored-by: “liuruian” --- fastdeploy/model_executor/layers/normalization.py | 2 +- tests/e2e/test_Qwen3VL_serving.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 14e248e0a72..4532b8d27af 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -341,7 +341,7 @@ def forward( forward_meta, proxy_rmsnorm=None, ) -> paddle.Tensor: - if proxy_rmsnorm is None and self.qk_norm_fused and forward_meta.step_use_cudagraph: + if proxy_rmsnorm is None and self.qk_norm_fused: qkv_out = qk_rmsnorm_fused( qkv_out, self.q_norm.weight, diff --git a/tests/e2e/test_Qwen3VL_serving.py b/tests/e2e/test_Qwen3VL_serving.py index 3872b4050ce..bbb053b13dd 100644 --- a/tests/e2e/test_Qwen3VL_serving.py +++ b/tests/e2e/test_Qwen3VL_serving.py @@ -173,7 +173,7 @@ def test_consistency_between_runs(api_url, headers, consistent_payload): content1 = result1["choices"][0]["message"]["content"] # base result - content2 = "视频中手机支架的颜色是黑色的。" + content2 = "视频中手机支架的颜色是黑色。" # Verify that result is same as the base result assert content1.startswith(content2), content1 From 8d7063ec49da55515588e6709dea5eb8a381b882 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:37:50 +0800 Subject: [PATCH 061/143] [Cherry-Pick][Optimization]Change default workers and max-concurrency when launch api-server(#7457) (#7516) * Change default workers and max-concurrency when launch api-server (#7457) Co-authored-by: zhangxiao35 * [CI] Add --workers=1 to keep test behavior consistent with default change --------- Co-authored-by: K11OntheBoat Co-authored-by: zhangxiao35 --- .github/workflows/_base_test.yml | 2 +- fastdeploy/entrypoints/api_server.py | 2 +- fastdeploy/entrypoints/openai/utils.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index fce46f04412..99bf7209747 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -266,7 +266,7 @@ jobs: curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }" + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--workers\": 1, \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }" check_service 90 python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1 diff --git a/fastdeploy/entrypoints/api_server.py b/fastdeploy/entrypoints/api_server.py index 4f4d7f2250c..e182eb61fe9 100644 --- a/fastdeploy/entrypoints/api_server.py +++ b/fastdeploy/entrypoints/api_server.py @@ -123,7 +123,7 @@ def main(): parser = FlexibleArgumentParser() parser.add_argument("--port", default=9904, type=int, help="port to the http server") parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") - parser.add_argument("--workers", default=1, type=int, help="number of workers") + parser.add_argument("--workers", default=4, type=int, help="number of workers") parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() launch_api_server(args) diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index baa428b5003..57976b0f5b2 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -341,9 +341,10 @@ async def close(self): def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + _is_multi_server = os.environ.get("FD_ENABLE_MULTI_API_SERVER") == "1" parser.add_argument("--port", default=8000, type=int, help="port to the http server") parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") - parser.add_argument("--workers", default=1, type=int, help="number of workers") + parser.add_argument("--workers", default=1 if _is_multi_server else 4, type=int, help="number of workers") parser.add_argument("--metrics-port", default=None, type=int, help="port for metrics server") parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") parser.add_argument( @@ -352,7 +353,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, help="max waiting time for connection, if set value -1 means no waiting time limit", ) - parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") + parser.add_argument( + "--max-concurrency", default=512 if _is_multi_server else 2048, type=int, help="max concurrency" + ) parser.add_argument( "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. " From 0de0be4b51d78e3eaad1e17e48d66650b3b2a754 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:50:26 +0800 Subject: [PATCH 062/143] [Others] print evictable blocks in console log (#7384) (#7580) * [chore] print evictable blocks in console log * [chore] add resource manager info log Co-authored-by: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> --- fastdeploy/engine/resource_manager.py | 10 ++++++++-- fastdeploy/engine/sched/resource_manager_v1.py | 14 +++++++++++++- .../engine/sched/scheduler_metrics_logger.py | 8 ++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index 609c88533bd..173cbdf9dd7 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -368,6 +368,12 @@ def info(self): total_block_number = self.total_block_number() available_block_num = self.available_block_num() used_block_num = total_block_number - available_block_num + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) + evictable_block_num = used_block_num - len(blocks_used_by_tasks) block_usage = used_block_num / total_block_number * 100 total_batch_number = len(self.stop_flags) available_batch_num = self.available_batch() @@ -375,8 +381,8 @@ def info(self): batch_usage = used_batch_num / total_batch_number * 100 info = ( f"ResourceManager info, " - f"total_block_number: {total_block_number}, total_batch_number: {total_batch_number}, " - f"available_block_num: {available_block_num}, available_batch: {available_batch_num}," + f"total_block_number: {total_block_number}, available_block_num: {available_block_num}, evictable_block_num: {evictable_block_num}, " + f"total_batch_number: {total_batch_number}, available_batch: {available_batch_num}," f"running_reqs: {used_batch_num}, block_usage: {block_usage:.2f}%, batch_usage: {batch_usage:.2f}%" ) return info diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 39ad13f0b66..46925699412 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1586,7 +1586,8 @@ def update_metrics(self, verbose=False): blocks_used_by_tasks = set() for task in self.tasks_list: if task is not None: - blocks_used_by_tasks.update(task.block_tables) + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) @@ -1619,6 +1620,13 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule total_blocks = self.total_block_number() free_blocks = self.available_block_num() used_blocks = max(total_blocks - free_blocks, 0) + # Evictable = used blocks not held by any running task + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) + evictable_blocks = used_blocks - len(blocks_used_by_tasks) tokens_used = used_blocks * self.config.cache_config.block_size token_usage = used_blocks / total_blocks if total_blocks > 0 else 0.0 running_cnt = len(self.running) @@ -1634,6 +1642,8 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule queue_cnt=queue_cnt, tokens_used=tokens_used, token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, ) if has_decode: has_prefill = len(prefill_reqs) > 0 @@ -1659,4 +1669,6 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule tokens_used=tokens_used, token_usage=token_usage, use_cudagraph=use_decode_cudagraph, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, ) diff --git a/fastdeploy/engine/sched/scheduler_metrics_logger.py b/fastdeploy/engine/sched/scheduler_metrics_logger.py index 9e08375a395..0aaa29e246c 100644 --- a/fastdeploy/engine/sched/scheduler_metrics_logger.py +++ b/fastdeploy/engine/sched/scheduler_metrics_logger.py @@ -72,6 +72,8 @@ def log_prefill_batch( queue_cnt: int, tokens_used: int, token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, ) -> None: if not self.enabled: return @@ -95,6 +97,8 @@ def log_prefill_batch( f"#new-token: {new_tokens}, " f"#cached-token: {cached_tokens}, " f"token usage: {token_usage:.2f}, " + f"#free-block: {free_blocks}, " + f"#evictable-block: {evictable_blocks}, " f"#running-req: {running_cnt}, " f"#queue-req: {queue_cnt}, " ) @@ -107,6 +111,8 @@ def log_decode_batch( tokens_used: int, token_usage: float, use_cudagraph: bool, + free_blocks: int = 0, + evictable_blocks: int = 0, ) -> None: if not self.enabled: return @@ -129,6 +135,8 @@ def log_decode_batch( f"#running-req: {running_cnt}, " f"#token: {tokens_used}, " f"token usage: {token_usage:.2f}, " + f"#free-block: {free_blocks}, " + f"#evictable-block: {evictable_blocks}, " f"cuda graph: {use_cudagraph}, " f"gen throughput (token/s): {throughput:.2f}, " f"#queue-req: {queue_cnt}, " From d88982bc1d7f15efa3ddc4275818c1da4e6b3f41 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:00:16 +0800 Subject: [PATCH 063/143] [Optimization] Support async D2H copy for MTP logprobs & Clean up overlap schedule condition checks (#7521) (#7594) Co-authored-by: sunxin <68891411+Sunny-bot1@users.noreply.github.com> --- fastdeploy/engine/args_utils.py | 7 +---- .../model_executor/layers/sample/logprobs.py | 13 +++++++++- .../model_executor/layers/sample/sampler.py | 26 +++++++++++++++++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 02879bcc7fc..5690dee0a33 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -579,12 +579,7 @@ def __post_init__(self): and not current_platform.is_maca() ): self.enable_prefix_caching = False - if ( - not current_platform.is_cuda() - or (self.speculative_config is not None and self.enable_logprob) - or self.splitwise_role == "prefill" - or self.dynamic_load_weight - ): + if not current_platform.is_cuda() or self.splitwise_role == "prefill": self.enable_overlap_schedule = False if self.enable_logprob: if not current_platform.is_cuda() and not current_platform.is_xpu(): diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 33fbbc01603..80ccfc2fdd9 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -123,7 +123,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu()) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def build_output_logprobs( diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 2729cddba8f..7892b4d73ad 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -948,7 +948,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices, top_logprobs, token_ranks) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def _verify_and_sample( self, @@ -1488,7 +1499,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices, top_logprobs, token_ranks) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def forward_cuda( self, From 5508979a3d898d0e7b3f30dc4196c5da30543f1c Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:25:03 +0800 Subject: [PATCH 064/143] Fix PD interaction and error response (#7606) --- fastdeploy/engine/common_engine.py | 15 +- fastdeploy/engine/request.py | 2 + .../engine/sched/resource_manager_v1.py | 4 +- fastdeploy/entrypoints/openai/protocol.py | 8 +- fastdeploy/entrypoints/openai/serving_chat.py | 52 ++++-- .../entrypoints/openai/serving_completion.py | 18 +- fastdeploy/input/base_processor.py | 11 ++ fastdeploy/output/token_processor.py | 9 +- fastdeploy/router/router.py | 168 +++++++++++++----- fastdeploy/splitwise/splitwise_connector.py | 16 +- tests/engine/test_request.py | 1 + tests/entrypoints/test_serving_completion.py | 4 +- tests/input/test_ernie_vl_processor.py | 4 +- tests/output/test_token_processor.py | 4 +- tests/router/test_router.py | 9 +- tests/splitwise/test_splitwise_connector.py | 9 - tests/v1/test_resource_manager_v1.py | 1 + 17 files changed, 233 insertions(+), 102 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 090eabd40d8..22016511f61 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -964,14 +964,18 @@ def _fetch_request(): status, msg = self.split_connector.check_decode_allocated(task) task.metrics.ask_decode_resource_finish_time = time.time() if not status: - self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + error_msg = ( + f"PD Error: prefill failed to apply for resource from decode, " + f"req: {task.request_id}, msg:{msg}." + ) + self.llm_logger.error(error_msg) self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, - error_msg=msg, + error_msg=error_msg, ) ] ) @@ -1077,14 +1081,17 @@ def _fetch_request(): if self.cfg.scheduler_config.splitwise_role == "decode": for task in tasks: if task.task_type == RequestType.PREEMPTED: - msg = f"{task.request_id} decode not enough blocks, need to be rescheduled." + msg = ( + f"PD Error: decode does not have enough blocks for " + f"preallocated request. req:{task.request_id} " + ) self.llm_logger.error(msg) self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, - error_code=500, + error_code=502, error_msg=msg, ) ] diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index ccab1ac4114..624f8f32951 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -205,6 +205,7 @@ def __init__( self.metrics = RequestMetrics() else: self.metrics = metrics + self.metrics.prompt_token_ids_len = self.prompt_token_ids_len # from ChatCompletionRequest or CompletionRequest self.user = user self.metadata = metadata @@ -877,6 +878,7 @@ class RequestMetrics: speculate_metrics: Optional[SpeculateMetrics] = None # cache related + prompt_token_ids_len: Optional[int] = None gpu_cache_token_num: Optional[int] = 0 cpu_cache_token_num: Optional[int] = 0 storage_cache_token_num: Optional[int] = 0 diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 46925699412..31af47f0507 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1313,6 +1313,7 @@ def get_prefix_cached_blocks(self, request: Request): request.metrics.storage_cache_token_num = metrics["storage_match_token_num"] request.metrics.cpu_cache_prepare_time = metrics["cpu_cache_prepare_time"] request.metrics.storage_cache_prepare_time = metrics["storage_cache_prepare_time"] + request.metrics.prompt_token_ids_len = request.prompt_token_ids_len main_process_metrics.prefix_cache_token_num.inc(request.num_computed_tokens) main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) @@ -1445,7 +1446,6 @@ def preallocate_resource_in_d(self, request: Request): request.disaggregate_info["block_tables"] = request.block_tables allocated_position = self.get_available_position() request.idx = allocated_position - self.tasks_list[request.idx] = request self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position @@ -1489,6 +1489,8 @@ def add_prefilled_request(self, request_output: RequestOutput): request.metrics = copy.deepcopy(request_output.metrics) request.metrics.decode_inference_start_time = time.time() request.metrics.update_decoder_start_time() + + self.tasks_list[request.idx] = request self.running.append(request) def _free_blocks(self, request: Request): diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 5642daca0c5..883ff71e858 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -268,7 +268,7 @@ class ChatCompletionResponseChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] speculate_metrics: Optional[SpeculateMetrics] = None # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token sampling_mask: Optional[List[List[int]]] = None @@ -338,7 +338,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): # Per-token index list of retained positions after top_p sampling. # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step). sampling_mask: Optional[List[List[int]]] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -374,7 +374,7 @@ class CompletionResponseChoice(BaseModel): draft_logprobs: Optional[CompletionLogprobs] = None prompt_logprobs: Optional[PromptLogprobs] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -420,7 +420,7 @@ class CompletionResponseStreamChoice(BaseModel): prompt_tokens: Optional[str] = None completion_tokens: Optional[str] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 55bd37412a0..c6e31605e0a 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -620,10 +620,22 @@ async def chat_completion_full_generator( request=request, ) async for data in generator: - if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) idx = int(data["request_id"].split("_")[-1]) - # api_server_logger.debug(f"Client {request_id} received: {data}") + if data.get("error_code", 200) != 200: + # Error response - include already-generated tokens in the response + data["outputs"] = { + "text": "", + "completion_tokens": "", + "reasoning_content": "", + "tool_calls": None, + "reasoning_token_num": 0, + "num_image_tokens": 0, + "token_ids": [], + "top_logprobs": None, + "draft_top_logprobs": None, + } + data["metrics"] = data.get("metrics") or {} + data["finished"] = True previous_num_tokens[idx] += len(data["outputs"]["token_ids"]) completion_token_ids[idx].extend(data["outputs"]["token_ids"]) # The logprob for handling the response @@ -767,6 +779,26 @@ async def _create_chat_completion_choice( idx = int(data["request_id"].split("_")[-1]) output = data["outputs"] + finish_reason = "stop" + if previous_num_tokens != max_tokens: + finish_reason = "stop" + if output.get("tool_calls"): + finish_reason = "tool_calls" + else: + finish_reason = "length" + if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: + finish_reason = "recover_stop" + + if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: + finish_reason = "abort" + + if data.get("error_msg", None) is not None and "PD Error" in data["error_msg"]: + finish_reason = "pd_reschedule" + + return_completion_token_ids = False + if request.return_token_ids or finish_reason == "pd_reschedule": + return_completion_token_ids = True + if output is not None and output.get("metrics") and output["metrics"].get("request_start_time"): main_process_metrics.e2e_request_latency.observe( time.time() - data.get("metrics").get("request_start_time") @@ -776,7 +808,7 @@ async def _create_chat_completion_choice( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_calls"), prompt_token_ids=prompt_token_ids if request.return_token_ids else None, - completion_token_ids=completion_token_ids if request.return_token_ids else None, + completion_token_ids=completion_token_ids if return_completion_token_ids else None, prompt_tokens=prompt_tokens if request.return_token_ids else None, completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, ) @@ -808,18 +840,6 @@ async def _create_chat_completion_choice( num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) num_image_tokens[idx] = output.get("num_image_tokens", 0) or 0 - finish_reason = "stop" - if previous_num_tokens != max_tokens: - finish_reason = "stop" - if output.get("tool_calls"): - finish_reason = "tool_calls" - else: - finish_reason = "length" - if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: - finish_reason = "recover_stop" - - if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: - finish_reason = "abort" return ChatCompletionResponseChoice( index=idx, message=message, diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index b277576a1fc..cdd0a1a096d 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -305,7 +305,15 @@ async def completion_full_generator( for data in response: rid = int(data["request_id"].split("_")[-1]) if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) + data["outputs"] = { + "text": "", + "completion_tokens": "", + "token_ids": [], + "top_logprobs": None, + "draft_top_logprobs": None, + } + data["metrics"] = data.get("metrics") or {} + data["finished"] = True output = data["outputs"] output_top_logprobs = output.get("top_logprobs") or None @@ -727,13 +735,19 @@ def request_output_to_completion_response( ) if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]: finish_reason = "abort" + if final_res.get("error_msg", None) is not None and "PD Error" in final_res["error_msg"]: + finish_reason = "pd_reschedule" + + return_completion_token_ids = False + if request.return_token_ids or finish_reason == "pd_reschedule": + return_completion_token_ids = True choice_data = CompletionResponseChoice( token_ids=token_ids, index=len(choices), text=output_text, prompt_token_ids=prompt_token_ids if request.return_token_ids else None, - completion_token_ids=completion_token_ids if request.return_token_ids else None, + completion_token_ids=completion_token_ids if return_completion_token_ids else None, completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, prompt_tokens=( prompt_tokens_list[idx // (1 if request.n is None else request.n)] diff --git a/fastdeploy/input/base_processor.py b/fastdeploy/input/base_processor.py index 24e0d7ddec8..cdc01067bfd 100644 --- a/fastdeploy/input/base_processor.py +++ b/fastdeploy/input/base_processor.py @@ -236,6 +236,17 @@ def process_response_dict(self, response_dict, **kwargs): ``stream`` is read from ``kwargs`` (default: True). """ + # Error responses (e.g., preemption) have outputs=None or error_code!=200. + # Skip token decoding and return as-is to let upstream error handling take over. + if isinstance(response_dict, dict): + outputs = response_dict.get("outputs") + error_code = response_dict.get("error_code", 200) + else: + outputs = getattr(response_dict, "outputs", None) + error_code = getattr(response_dict, "error_code", 200) + if outputs is None or error_code != 200: + return response_dict + stream = kwargs.get("stream", True) if stream: return self.process_response_dict_streaming(response_dict, **kwargs) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 7195bf83aa8..a8682374ed6 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -538,8 +538,11 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] if task_id in self.prefill_result_status: if self.prefill_result_status[task_id] != "finished": - result.error_code = 400 - result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" + result.error_code = 501 + result.error_msg = ( + f"PD Error: prefill failed to send cache to decode, " + f"{task_id}, {self.prefill_result_status[task_id]}" + ) llm_logger.info( f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" ) @@ -756,7 +759,7 @@ def _process_batch_output(self): batch_result = list() # reschedule for i in range(batch): - if self.resource_manager.stop_flags[i]: + if self.resource_manager.stop_flags[i] or self.resource_manager.tasks_list[i] is None: continue recovery_stop = False diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index 960a64e7f58..bdb9c5b9c6a 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -49,6 +49,14 @@ class RouterArgs: """ Request timeout in seconds """ + preempt_retry_count: int = 3 + """ + Max retry count when decode instance preempts a request in splitwise mode. + """ + preempt_retry_exclude_decode: bool = False + """ + Whether to exclude the previously used decode instance when retrying after preemption. + """ @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -76,6 +84,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=RouterArgs.request_timeout_secs, help="Request timeout in seconds", ) + parser.add_argument( + "--preempt-retry-count", + type=int, + default=RouterArgs.preempt_retry_count, + help="Max retry count when decode instance preempts a request in splitwise mode.", + ) + parser.add_argument( + "--preempt-retry-exclude-decode", + action="store_true", + default=RouterArgs.preempt_retry_exclude_decode, + help="Whether to exclude the previously used decode instance when retrying after preemption.", + ) return parser @@ -91,6 +111,8 @@ def __init__(self, args): self.port = args.port self.splitwise = args.splitwise self.timeout = args.request_timeout_secs + self.preempt_retry_count = args.preempt_retry_count + self.preempt_retry_exclude_decode = args.preempt_retry_exclude_decode self.mixed_servers = [] self.prefill_servers = [] @@ -152,16 +174,21 @@ async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict instances = [inst for inst in instances if inst.version == version] return [inst.to_dict() for inst in instances] - async def select_pd(self): - """Select one prefill and one decode server""" + async def select_pd(self, exclude_decode=None): + """Select one prefill and one decode server, optionally excluding a decode instance.""" async with self.lock: if not self.prefill_servers: raise RuntimeError(f"No prefill servers available (decode={len(self.decode_servers)})") if not self.decode_servers: raise RuntimeError(f"No decode servers available (prefill={len(self.prefill_servers)})") pidx = random.randint(0, len(self.prefill_servers) - 1) - didx = random.randint(0, len(self.decode_servers) - 1) - return self.prefill_servers[pidx], self.decode_servers[didx] + available_decode = ( + [d for d in self.decode_servers if d is not exclude_decode] if exclude_decode else self.decode_servers + ) + if not available_decode: + available_decode = self.decode_servers + didx = random.randint(0, len(available_decode) - 1) + return self.prefill_servers[pidx], available_decode[didx] async def select_mixed(self): """Select one mixed server""" @@ -191,57 +218,108 @@ async def handle_mixed_request(self, request_data: dict, endpoint_name: str): async def handle_splitwise_request(self, request_data: dict, endpoint_name: str): logger.debug(f"Received request: {request_data}") - prefill_server, decode_server = await self.select_pd() - logger.debug(f"Selected prefill server: {prefill_server}") - logger.debug(f"Selected decode server: {decode_server}") - - if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: - raise HTTPException( - status_code=400, - detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1", + last_decode_server = None + # Preserve client request_id on first attempt; append retry suffix on subsequent attempts + base_request_id = request_data.get("request_id") or str(uuid4()) + max_attempts = self.preempt_retry_count + 1 + completion_token_ids = [] + + for attempt in range(max_attempts): + prefill_server, decode_server = await self.select_pd( + exclude_decode=last_decode_server if self.preempt_retry_exclude_decode else None ) + logger.debug(f"Selected prefill server: {prefill_server}, decode server: {decode_server}") - # TODO: unify the disaggregate_info in server and remove redundancy params - is_same_node = prefill_server.host_ip == decode_server.host_ip - is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol - is_same_tp_size = prefill_server.tp_size == decode_server.tp_size - use_ipc = is_same_node and is_support_ipc and is_same_tp_size - - disaggregate_info = { - "prefill_ip": prefill_server.host_ip, - "decode_ip": decode_server.host_ip, - "prefill_connector_port": prefill_server.connector_port, - "decode_connector_port": decode_server.connector_port, - "decode_device_ids": decode_server.device_ids, - "decode_rdma_ports": decode_server.rdma_ports, - "transfer_protocol": "ipc" if use_ipc else "rdma", - "decode_tp_size": decode_server.tp_size, - } + if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: + raise HTTPException( + status_code=400, + detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1", + ) - modified_request = request_data.copy() - modified_request["disaggregate_info"] = disaggregate_info - if "request_id" not in modified_request: - modified_request["request_id"] = str(uuid4()) + # TODO: unify the disaggregate_info in server and remove redundancy params + is_same_node = prefill_server.host_ip == decode_server.host_ip + is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol + is_same_tp_size = prefill_server.tp_size == decode_server.tp_size + use_ipc = is_same_node and is_support_ipc and is_same_tp_size + + disaggregate_info = { + "prefill_ip": prefill_server.host_ip, + "decode_ip": decode_server.host_ip, + "prefill_connector_port": prefill_server.connector_port, + "decode_connector_port": decode_server.connector_port, + "decode_device_ids": decode_server.device_ids, + "decode_rdma_ports": decode_server.rdma_ports, + "transfer_protocol": "ipc" if use_ipc else "rdma", + "decode_tp_size": decode_server.tp_size, + } + + modified_request = request_data.copy() + modified_request["disaggregate_info"] = disaggregate_info + if completion_token_ids: + modified_request["completion_token_ids"] = completion_token_ids + if attempt == 0: + modified_request["request_id"] = base_request_id + else: + modified_request["request_id"] = f"{base_request_id}-retry{attempt}" - logger.debug(f"Modified request: {modified_request}") + logger.debug(f"Modified request: {modified_request}") - if request_data.get("stream", False): - return await self._generate_stream( - modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name - ) - else: - return await self._generate( - modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name - ) - - async def _generate( + if request_data.get("stream", False): + return await self._generate_stream( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + else: + ret_json, status_code = await self._do_generate( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + logger.debug(f"Get response of req {modified_request['request_id']}: {ret_json}") + + if self._is_need_reschedule(ret_json): + last_decode_server = decode_server + choices = ret_json.get("choices", []) + if choices: + completion_token_ids.extend(choices[0].get("message", {}).get("completion_token_ids", [])) + + logger.warning( + f"Preemption detected on attempt {attempt+1}/{max_attempts}, " + f"decode={decode_server.url()}, req_id {modified_request['request_id']}," + f"retrying with new PD instances..." + ) + else: + break + + logger.debug(f"Return response of req_id {base_request_id}: {ret_json}") + return ORJSONResponse(content=ret_json, status_code=status_code) + + def _is_need_reschedule(self, ret_json: dict) -> bool: + # ChatCompletionResponse format: choices[0].finish_reason == "pd_reschedule" + choices = ret_json.get("choices", []) + if choices: + finish_reason = choices[0].get("finish_reason", "") + if finish_reason == "pd_reschedule": + logger.debug(f"PD reschedule request, ret_json: {ret_json}") + return True + # ErrorResponse format compatibility + error = ret_json.get("error", {}) + if isinstance(error, dict) and "PD Error" in str(error.get("message", "")): + return True + return False + + async def _do_generate( self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" - ) -> ORJSONResponse: + ) -> tuple: + """Send requests and return (ret_json, status_code).""" async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session: tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls] results = await asyncio.gather(*tasks) ret_json = await results[return_result_url_index].json() - return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status) + return ret_json, results[return_result_url_index].status + + async def _generate( + self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" + ) -> ORJSONResponse: + ret_json, status_code = await self._do_generate(modified_request, urls, return_result_url_index, endpoint) + return ORJSONResponse(content=ret_json, status_code=status_code) async def _generate_stream( self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 7200c99ed9c..e5feb661f66 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -277,23 +277,15 @@ def send_cache_info_to_prefill(self, tasks: List[Request]): "request_id": tasks[i].request_id, "error_msg": tasks[i].get("error_msg"), } - if ( - envs.ENABLE_V1_KVCACHE_SCHEDULER - and tasks[i].request_id in self.resource_manager.waiting_abort_req_id_set - ): - addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" - if addr not in cache_info: - cache_info[addr] = [] - cache_info[addr].append(info) else: - addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" info = { "request_id": tasks[i].request_id, "dest_block_ids": dsg_info["block_tables"], } - if addr not in cache_info: - cache_info[addr] = [] - cache_info[addr].append(info) + addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" + if addr not in cache_info: + cache_info[addr] = [] + cache_info[addr].append(info) self.logger.debug(f"send cache info to prefill, {cache_info}") if len(cache_info): diff --git a/tests/engine/test_request.py b/tests/engine/test_request.py index 9a1f0bc31cf..fd9eab17dc7 100644 --- a/tests/engine/test_request.py +++ b/tests/engine/test_request.py @@ -398,6 +398,7 @@ def test_to_dict_basic(self): request.prompt_token_ids_len = 3 request.sampling_params = SamplingParams() request.metrics = RequestMetrics() + request.metrics.prompt_token_ids_len = 3 data = request.to_dict() diff --git a/tests/entrypoints/test_serving_completion.py b/tests/entrypoints/test_serving_completion.py index 9c2beb678df..9b48b2271a3 100644 --- a/tests/entrypoints/test_serving_completion.py +++ b/tests/entrypoints/test_serving_completion.py @@ -21,6 +21,7 @@ import paddle import fastdeploy.metrics.trace as tracing +from fastdeploy.entrypoints.openai.protocol import CompletionResponse from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion from fastdeploy.utils import ErrorCode, ParameterError from fastdeploy.worker.output import LogprobsLists, LogprobsTensors, SpeculateMetrics @@ -171,7 +172,8 @@ async def test_completion_full_generator_branches(self): ec.connection_manager.get_connection = AsyncMock(return_value=(Mock(), rq)) serving = OpenAIServingCompletion(ec, None, "pid", None, -1) res = await serving.completion_full_generator(_make_request(), 1, "req", 1, "m", [[1, 2]], [["p1", "p2"]], [2]) - self.assertIsNone(res) + self.assertIsNotNone(res) + self.assertIsInstance(res, CompletionResponse) ec.connection_manager.cleanup_request.assert_called_once_with("req") def test_logprobs_helpers(self): diff --git a/tests/input/test_ernie_vl_processor.py b/tests/input/test_ernie_vl_processor.py index 6e4fac00182..c440187667f 100644 --- a/tests/input/test_ernie_vl_processor.py +++ b/tests/input/test_ernie_vl_processor.py @@ -361,14 +361,14 @@ def test_process_response_dict(self): # Test with stream=True processor.process_response_dict_streaming = MagicMock(return_value={"text": "response"}) - response_dict = {"ids": [1, 2, 3]} + response_dict = {"ids": [1, 2, 3], "outputs": [[1, 2, 3]]} result = processor.process_response_dict(response_dict, stream=True) processor.process_response_dict_streaming.assert_called_once() self.assertEqual(result, {"text": "response"}) # Test with stream=False processor.process_response_dict_normal = MagicMock(return_value={"text": "response"}) - response_dict = {"ids": [1, 2, 3]} + response_dict = {"ids": [1, 2, 3], "outputs": [[1, 2, 3]]} result = processor.process_response_dict(response_dict, stream=False) processor.process_response_dict_normal.assert_called_once() self.assertEqual(result, {"text": "response"}) diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index c0609094a2b..6a692f24cc4 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -601,8 +601,8 @@ def test_recycle_resources_prefill_failure_sets_error(): with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False): processor._recycle_resources(task_id, 0, task, result, is_prefill=True) - assert result.error_code == 400 - assert "failed" in result.error_message + assert result.error_code == 501 + assert "failed" in result.error_msg assert connector.calls and connector.calls[0][1][0] is result diff --git a/tests/router/test_router.py b/tests/router/test_router.py index aa5be52f2f1..ba21b814870 100644 --- a/tests/router/test_router.py +++ b/tests/router/test_router.py @@ -28,7 +28,14 @@ def _make_args(**kwargs): - defaults = {"host": "0.0.0.0", "port": 9000, "splitwise": False, "request_timeout_secs": 30} + defaults = { + "host": "0.0.0.0", + "port": 9000, + "splitwise": False, + "request_timeout_secs": 30, + "preempt_retry_count": 3, + "preempt_retry_exclude_decode": False, + } defaults.update(kwargs) return SimpleNamespace(**defaults) diff --git a/tests/splitwise/test_splitwise_connector.py b/tests/splitwise/test_splitwise_connector.py index 610cfd9246d..6a50b34cfc5 100644 --- a/tests/splitwise/test_splitwise_connector.py +++ b/tests/splitwise/test_splitwise_connector.py @@ -208,15 +208,6 @@ def test_send_cache_info_to_prefill_groups_by_addr_and_skips_error(): "block_tables": [1, 2, 3], }, ), - DummyTask( - request_id="req-err", - disaggregate_info={ - "prefill_ip": "10.0.0.2", - "prefill_connector_port": 9002, - "block_tables": [9], - }, - error_msg="failed", - ), ] connector.send_cache_info_to_prefill(tasks) diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 5cc7eb4ef01..5db1cc7f7f3 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -579,6 +579,7 @@ def test_prefilled_request_flow_and_resource_check(self): self.assertTrue(manager.has_resource_for_prefilled_req("prefilled")) request = _make_request(request_id="req-prefilled") + request.idx = 0 request.metrics.decode_recv_req_time = 1.0 request.metrics.decode_preallocate_req_time = 2.0 manager.requests[request.request_id] = request From c8a59a3ead0f528e18180865ed8574b35171cff0 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 24 Apr 2026 19:45:48 +0800 Subject: [PATCH 065/143] [Cherry-Pick][CI] Sync dev optimizations to 2.6(#7602) (#7610) --- scripts/check_approval.sh | 8 ++-- scripts/coverage_run.sh | 35 ++++++++++++-- scripts/run_golang_router.sh | 2 +- scripts/run_gpu_4cards.sh | 2 +- scripts/run_pre_ce.sh | 2 +- tests/conftest.py | 90 +++++++++++++++++++++++++++--------- 6 files changed, 106 insertions(+), 33 deletions(-) diff --git a/scripts/check_approval.sh b/scripts/check_approval.sh index db1e8a3e225..899019f5a24 100644 --- a/scripts/check_approval.sh +++ b/scripts/check_approval.sh @@ -40,7 +40,7 @@ function add_failed(){ } -HAS_CUSTOM_REGISTRER=`git diff -U0 upstream/$BRANCH | grep '^\+' | grep -zoE "PD_BUILD_(STATIC_)?OP" || true` +HAS_CUSTOM_REGISTRER=`git diff --merge-base -U0 upstream/$BRANCH | grep '^\+' | grep -zoE "PD_BUILD_(STATIC_)?OP" || true` if [ ${HAS_CUSTOM_REGISTRER} ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (qingqing01(dangqingqing), Jiang-Jia-Jun(jiangjiajun), heavengate(dengkaipeng)) approval for adding custom op.\n" echo_line2="You must have one PaddlePaddle RD (jeff41404(gaoxiang), yongqiangma(mayongqiang)) approval for adding custom op.\n" @@ -52,7 +52,7 @@ WORKER_OR_CONFIG_LIST=( "fastdeploy/model_executor/graph_optimization" ) -HAS_WORKER_OR_CONFIG_MODIFY=`git diff upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${WORKER_OR_CONFIG_LIST[@]}") || true` +HAS_WORKER_OR_CONFIG_MODIFY=`git diff --merge-base upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${WORKER_OR_CONFIG_LIST[@]}") || true` if [ "${HAS_WORKER_OR_CONFIG_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD gongshaotian(gongshaotian) approval for modifing [$(IFS=', '; echo "${WORKER_OR_CONFIG_LIST[*]}")]." check_approval "$echo_line1" 1 gongshaotian @@ -63,7 +63,7 @@ SPECULATIVE_DECODING_LIST=( "custom_ops/gpu_ops/speculate_decoding" ) -HAS_SPECULATIVE_DECODING_MODIFY=`git diff upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${SPECULATIVE_DECODING_LIST[@]}") || true` +HAS_SPECULATIVE_DECODING_MODIFY=`git diff --merge-base upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${SPECULATIVE_DECODING_LIST[@]}") || true` if [ "${HAS_SPECULATIVE_DECODING_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (freeliuzc(liuzichang01), Deleter-D(wangyanpeng04)) approval for modifing [$(IFS=', '; echo "${SPECULATIVE_DECODING_LIST[*]}")]." check_approval "$echo_line1" 1 freeliuzc Deleter-D @@ -71,7 +71,7 @@ fi ENV_FILE="fastdeploy/envs.py" -HAS_ENV_MODIFY=$(git diff upstream/$BRANCH --name-only | grep -E "^${ENV_FILE}$" || true) +HAS_ENV_MODIFY=$(git diff --merge-base upstream/$BRANCH --name-only | grep -E "^${ENV_FILE}$" || true) if [ "${HAS_ENV_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (Jiang-Jia-Jun(jiangjiajun), yuanlehome(liuyuanle), rainyfly(chenjian26), Wanglongzhi2001(wanglongzhi)) approval for modifying [${ENV_FILE}]." check_approval "$echo_line1" 1 Jiang-Jia-Jun yuanlehome rainyfly Wanglongzhi2001 diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh index 7d001ea83c1..d3b641f1878 100644 --- a/scripts/coverage_run.sh +++ b/scripts/coverage_run.sh @@ -47,11 +47,13 @@ classify_tests() { } # ============================================================ -# Run Test With Logging +# Run Test With Logging (with retry for OOM/Kill) # ============================================================ run_test_with_logging() { local test_file=$1 local log_prefix=$2 + local max_retries=3 # Max retries for OOM/Kill issues + local retry_count=0 local status echo "Running pytest file: $test_file" @@ -67,14 +69,37 @@ run_test_with_logging() { # Set FD_LOG_DIR to isolate logs for each test export FD_LOG_DIR="$isolated_log_dir" - # Run test - timeout 600 python -m coverage run -m pytest -c ${PYTEST_INI} "$test_file" -vv -s - status=$? + # Retry loop for OOM/Kill issues (only handle "Killed" / SIGKILL) + while [ $retry_count -le $max_retries ]; do + if [ $retry_count -gt 0 ]; then + echo "" + echo "==================== Retrying (${retry_count}/${max_retries}) ====================" + echo "Previous attempt was Killed, retrying..." + # Clean up before retry + sleep 5 # Wait a bit to let resources be released + fi + + # Run test + timeout 600 python -m coverage run -m pytest -c ${PYTEST_INI} "$test_file" -vv -s + status=$? + + # Exit code 137 = SIGKILL (Killed / OOM) + if [ "$status" -eq 137 ] && [ $retry_count -lt $max_retries ]; then + retry_count=$((retry_count + 1)) + continue + fi + + # Break loop on success or non-Kill error or max retries reached + break + done if [ "$status" -ne 0 ]; then echo "$test_file" >> "$log_prefix" echo "" echo "==================== Test Failed: $test_file ====================" + if [ $retry_count -gt 0 ]; then + echo "Total attempts: $((retry_count + 1))" + fi # Use isolated log directory for this test if [ -d "$isolated_log_dir" ]; then @@ -94,7 +119,7 @@ run_test_with_logging() { fi echo ">>> grep error in ${isolated_log_dir}" - grep -Rni --color=auto "error" "${isolated_log_dir}" || true + grep -Rni --color=auto "error" "${isolated_log_dir}" --exclude="pytest_*_error.log" || true fi # print all server logs diff --git a/scripts/run_golang_router.sh b/scripts/run_golang_router.sh index 66578d267d9..85e204bf72c 100644 --- a/scripts/run_golang_router.sh +++ b/scripts/run_golang_router.sh @@ -54,7 +54,7 @@ for test_file in "${test_files[@]}"; do fi echo ">>> grep error in ${log_dir}" - grep -Rni --color=auto "error" "${log_dir}" || true + grep -Rni --color=auto "error" "${log_dir}" --exclude="pytest_*_error.log" || true fi done diff --git a/scripts/run_gpu_4cards.sh b/scripts/run_gpu_4cards.sh index 719ec19255c..9874302cf86 100644 --- a/scripts/run_gpu_4cards.sh +++ b/scripts/run_gpu_4cards.sh @@ -44,7 +44,7 @@ for test_file in "${test_files[@]}"; do if [ -d "${REPO_ROOT}/log" ]; then echo ">>> grep error in ${REPO_ROOT}/log/" - grep -Rni --color=auto "error" "${REPO_ROOT}/log/" || true + grep -Rni --color=auto "error" "${REPO_ROOT}/log/" --exclude="pytest_*_error.log" || true else echo "${REPO_ROOT}/log directory not found" fi diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 928aa2e7cef..4bf56a290aa 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -38,7 +38,7 @@ for subdir in "$run_path"*/; do if [ $exit_code -ne 0 ]; then if [ -d "${subdir%/}/log" ]; then echo ">>> grep error in ${subdir%/}/log/" - grep -Rni --color=auto "error" "${subdir%/}/log/" || true + grep -Rni --color=auto "error" "${subdir%/}/log/" --exclude="pytest_*_error.log" || true else echo "${subdir%/}/log directory not found" fi diff --git a/tests/conftest.py b/tests/conftest.py index 057dc15aebc..5a57414bc69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,23 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. +import glob +import os +import re +import time +from typing import Any, Union + import pytest +from e2e.utils.serving_utils import ( # noqa: E402 + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) def pytest_configure(config): + """ + Configure pytest: + - Register custom markers + - Ensure log directory exists + """ config.addinivalue_line("markers", "gpu: mark test as requiring GPU platform") + log_dir = os.environ.get("FD_LOG_DIR", "log") + os.makedirs(log_dir, exist_ok=True) -def pytest_collection_modifyitems(config, items): - """Skip GPU-marked tests when not on a GPU platform. - IMPORTANT: Do NOT import paddle or fastdeploy here. This function runs - during pytest collection (before fork). Importing paddle initializes the - CUDA runtime, which makes forked child processes unable to re-initialize - CUDA (OSError: CUDA error(3), initialization error). +def pytest_collection_modifyitems(config, items): + """ + Skip tests marked with 'gpu' if no GPU device is detected. + + IMPORTANT: + Do NOT import paddle or fastdeploy here. + This hook runs during test collection (before process fork). + Importing CUDA-related libraries will initialize CUDA runtime, + causing forked subprocesses to fail with: + OSError: CUDA error(3), initialization error. """ - import glob - has_gpu = len(glob.glob("/dev/nvidia[0-9]*")) > 0 if has_gpu: @@ -40,18 +61,11 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_marker) -import time -from typing import Any, Union - -from e2e.utils.serving_utils import ( # noqa: E402 - FD_API_PORT, - FD_CACHE_QUEUE_PORT, - FD_ENGINE_QUEUE_PORT, - clean_ports, -) - - class FDRunner: + """ + Wrapper for FastDeploy LLM serving process. + """ + def __init__( self, model_name_or_path: str, @@ -88,7 +102,9 @@ def generate( sampling_params, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - + """ + Run generation and return token IDs and generated texts. + """ req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs) outputs: list[tuple[list[list[int]], list[str]]] = [] for output in req_outputs: @@ -101,6 +117,9 @@ def generate_topp0( max_tokens: int, **kwargs: Any, ) -> list[tuple[list[int], str]]: + """ + Generate outputs with deterministic sampling (top_p=0, temperature=0). + """ from fastdeploy.engine.sampling_params import SamplingParams topp_params = SamplingParams(temperature=0.0, top_p=0, max_tokens=max_tokens) @@ -116,4 +135,33 @@ def __exit__(self, exc_type, exc_value, traceback): @pytest.fixture(scope="session") def fd_runner(): + """Provide FDRunner as a pytest fixture.""" return FDRunner + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """ + Capture failed test cases and save error logs to FD_LOG_DIR. + + Only logs failures during the test execution phase. + """ + outcome = yield + report = outcome.get_result() + + if report.when == "call" and report.failed: + log_dir = os.environ.get("FD_LOG_DIR", "log") + os.makedirs(log_dir, exist_ok=True) + + case_name = re.sub(r"_+", "_", re.sub(r"[^\w\-.]", "_", item.nodeid.split("::", 1)[-1])).strip("_")[:200] + + error_log_file = os.path.join(log_dir, f"pytest_{case_name}_error.log") + + with open(error_log_file, "w", encoding="utf-8") as f: + f.write(f"Case name: {item.nodeid}\n") + f.write(f"Outcome: {report.outcome}\n") + f.write(f"Duration: {report.duration:.4f}s\n") + f.write("-" * 80 + "\n") + + if report.longrepr: + f.write(str(report.longrepr)) From 6ad8fce47c0ec5aedb7690824d7fa887c2e9c007 Mon Sep 17 00:00:00 2001 From: RAM Date: Mon, 27 Apr 2026 10:39:45 +0800 Subject: [PATCH 066/143] [RL][Feature] R3 Support GPUPrefixCache, CPUPrefixCache, PD Disaggregation (#7390) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [RL][Feature] R3 Phase 2: routing data follows KVCache block lifecycle (swap/storage/PD) Implement dual-buffer architecture for routing replay: - GPU transient buffer [max_num_batched_tokens, L, K] with Triton v2 kernel - SharedMemory routing_host_buffer for cross-process Engine/Worker/CTM sharing - Lazy SharedMemory attach in Worker and TokenProcessor (Engine creates after profiling) - CTM routing write/read for swap and storage backends - PD disaggregation: P gathers routing via send_first_token, D writes to host buffer - Local store persistence verified end-to-end Co-Authored-By: Claude Opus 4.6 * [RL][Refactor] Extract Store/Buffer classes from routing_indices_cache.py Move Store-related classes (StoreWrapper, StoreProcess, StoreTask, RoutingStoreBase, RoutingStoreLocal, RoutingStoreRDMA, etc.) to fastdeploy/cache_manager/routing_store.py and Buffer classes (RoutingHostBuffer, RoutingHostBufferView, RoutingSwapBuffer, RoutingSwapBufferView) to fastdeploy/cache_manager/routing_cache_manager.py. Update import paths in cache_transfer_manager.py, common_engine.py, and token_processor.py. Keep only actually-used imports in routing_indices_cache.py (RoutingHostBufferView, StoreWrapper). Pure file reorganization — zero functional change. Co-Authored-By: Claude Opus 4.6 * [RL][Refactor] Unify routing config, delete v1 kernel, rename to RoutedExpertsCapturer + RoutingCacheManager - Centralize routing_dtype/num_moe_layers/moe_top_k computation in RoutingReplayConfig (FDConfig.__init__), eliminating 4 duplicate computation sites - Delete v1 kernel (_save_routing_kernel) and all v1 fallback branches in moe.py and pre_and_post_process.py — single code path through v2 linear-indexed kernel - Rename RoutingReplayManager → RoutedExpertsCapturer (Worker-side, capture-only) with backward-compat alias - Remove request management (register/deregister/put_finished_batch), store dispatch (_put_request_to_store), and suspend methods from Worker - Create stateless RoutingCacheManager (Engine-side) for routing data aggregation and return-mode dispatch (response/local/rdma) - Migrate all callers: gpu_model_runner, metax_model_runner, common_engine, token_processor, prefix_cache_manager - Remove FD_SUSPEND_ROUTING_REPLAY env var from envs.py Co-Authored-By: Claude Opus 4.6 * [RL][BugFix] Fix _init_routing_host_buffer rename missed in engine.py The rename from _init_routing_host_buffer to _init_routing_cache_manager in common_engine.py was not updated in engine.py:_stop_profile(), causing AttributeError on startup with routing replay enabled. Co-Authored-By: Claude Opus 4.6 * success run response&store * success run p2pstore * Refine code * Refine code * fix config bug * fix ci bug * close share memory * Supplementary shape inspection * refine debug log * fix tp+ep all_gatcher bug * delete suspend r3 * get total token num from batch_id_per_token * refine code * [CI] Update R3_BaseLine to R3_BaseLine_uint8_0424 * [CI] Fix code codestyle * fix test_engine --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> --- fastdeploy/cache_manager/cache_data.py | 26 +- .../cache_manager/cache_transfer_manager.py | 220 ++++- .../cache_manager/prefix_cache_manager.py | 8 + .../cache_manager/routing_cache_manager.py | 286 ++++++ fastdeploy/cache_manager/routing_store.py | 515 ++++++++++ fastdeploy/config.py | 32 +- fastdeploy/engine/common_engine.py | 37 + fastdeploy/engine/engine.py | 16 + fastdeploy/engine/request.py | 11 +- .../engine/sched/resource_manager_v1.py | 31 + fastdeploy/entrypoints/openai/protocol.py | 2 + fastdeploy/entrypoints/openai/serving_chat.py | 11 + .../entrypoints/openai/serving_completion.py | 12 + fastdeploy/envs.py | 2 - fastdeploy/model_executor/forward_meta.py | 4 +- fastdeploy/model_executor/layers/moe/moe.py | 14 +- .../layers/moe/routing_indices_cache.py | 902 +++--------------- .../model_executor/pre_and_post_process.py | 49 +- fastdeploy/output/token_processor.py | 104 ++ fastdeploy/worker/gpu_model_runner.py | 26 +- fastdeploy/worker/metax_model_runner.py | 31 +- .../test_cache_transfer_manager.py | 1 + .../rollout_routing_replay_test_utils.py | 6 +- tests/engine/test_engine.py | 1 + tests/output/test_process_batch_output.py | 1 + tests/output/test_token_processor.py | 1 + 26 files changed, 1479 insertions(+), 870 deletions(-) create mode 100644 fastdeploy/cache_manager/routing_cache_manager.py create mode 100644 fastdeploy/cache_manager/routing_store.py diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 82911eccfa3..9fd48cec2ce 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -14,13 +14,35 @@ # limitations under the License. """ +from dataclasses import dataclass from enum import Enum +from typing import Any, Optional from fastdeploy.utils import get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") +@dataclass +class AuxBlockDataSpec: + """ + Describes a type of auxiliary data bound to KVCache blocks. + CacheTransferManager iterates registered specs during swap/storage + to perform corresponding data transfers. + """ + + name: str + num_layers: int + per_token_size: int = 0 + block_size: int = 0 + dtype: str = "uint8" + swap_buffer: Optional[Any] = None + enabled: bool = True + + def get_storage_key(self, key_prefix: str, block_hash: str, rank: int) -> str: + return f"prefix{key_prefix}_{block_hash}_{rank}_{self.name}" + + class CacheStatus(Enum): """ cache status enum class @@ -56,6 +78,7 @@ def __init__( cache_status=CacheStatus.GPU, is_persistent=False, persistent_shared_count=0, + aux_data_names=None, ): """ Args: @@ -89,6 +112,7 @@ def __init__( self.cache_status = cache_status self.is_persistent = is_persistent self.persistent_shared_count = persistent_shared_count + self.aux_data_names = aux_data_names or [] self.req_id_set = set() def __lt__(self, other): @@ -102,7 +126,7 @@ def __lt__(self, other): else: return self.depth > other.depth - def __str__(self): + def __str__(self) -> str: """ return node info """ diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index bd347b384e6..f5e39fe0c0f 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -24,6 +24,7 @@ import threading import time import traceback +import weakref from typing import List import numpy as np @@ -48,7 +49,7 @@ FileStore, MooncakeStore, ) -from fastdeploy.config import CacheConfig, SpeculativeConfig +from fastdeploy.config import CacheConfig, RoutingReplayConfig, SpeculativeConfig from fastdeploy.engine.request import ControlRequest, ControlResponse from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.inter_communicator.fmq import FMQ @@ -129,7 +130,11 @@ def parse_args(): ) parser.add_argument("--model_path", type=str, help="The path of model") + # Routing replay (R3) — single JSON arg, mirrors SpeculativeConfig pattern + parser.add_argument("--routing_replay_config", type=json.loads, default="{}", help="Routing replay config JSON") + args = parser.parse_args() + args.routing_replay_config = RoutingReplayConfig(args.routing_replay_config) return args @@ -244,6 +249,25 @@ def __init__(self, args): self._init_cpu_cache() if self.storage_backend_type is not None: self._init_storage(args) + + # Initialize auxiliary data specs (e.g., routing replay) + self.aux_data_specs = {} + self.routing_host_view = None + self.routing_swap_buffer = None + self.routing_replay_config = args.routing_replay_config + self.engine_worker_queue_port = args.engine_worker_queue_port + self._init_routing_aux_data() + + # Register finalizer to release routing SharedMemory on process exit. + # Must use a static method — callback must NOT hold a reference to self, + # otherwise the object can never be GC'd and the finalizer won't fire. + self._finalizer = weakref.finalize( + self, + CacheTransferManager._cleanup_routing_resources, + self.routing_swap_buffer, + self.routing_host_view, + ) + self._init_control() cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) @@ -310,6 +334,185 @@ def __init__(self, args): ) self.cache_transfer_inited_signal.value[self.rank] = 1 + def _init_routing_aux_data(self): + """Initialize routing auxiliary data buffers for swap sync.""" + routing_replay_config = self.routing_replay_config + if not routing_replay_config.enable_routing_replay: + return + + try: + from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + RoutingSwapBuffer, + ) + + num_moe_layers = routing_replay_config.num_moe_layers + moe_top_k = routing_replay_config.moe_top_k + routing_dtype = routing_replay_config.routing_dtype + + if num_moe_layers == 0 or moe_top_k == 0: + return + + spec = AuxBlockDataSpec( + name="routing", + num_layers=num_moe_layers, + per_token_size=moe_top_k, + block_size=self.block_size, + dtype=routing_dtype, + ) + + # Create routing swap buffer (for CPU blocks). + # Only rank 0 needs it — _swap_routing() only runs on rank 0. + if self.num_cpu_blocks > 0 and self.rank == 0: + dp_suffix = str(self.engine_worker_queue_port) + self.routing_swap_buffer = RoutingSwapBuffer( + num_cpu_blocks=self.num_cpu_blocks, + block_size=self.block_size, + num_moe_layers=num_moe_layers, + top_k=moe_top_k, + dtype=routing_dtype, + dp_suffix=dp_suffix, + ) + spec.swap_buffer = self.routing_swap_buffer + + # Attach to routing host buffer (SharedMemory created by Engine) + dp_suffix = str(self.engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = self.num_gpu_blocks * self.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + try: + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name) + logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found") + + self.aux_data_specs["routing"] = spec + logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}") + + except Exception as e: + logger.warning(f"[R3] CTM failed to init routing aux data: {e}") + + @staticmethod + def _cleanup_routing_resources(routing_swap_buffer, routing_host_view): + """Release routing SharedMemory on process exit. Called by weakref.finalize.""" + if routing_swap_buffer is not None: + routing_swap_buffer.close() + if routing_host_view is not None: + routing_host_view.close() + + def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction): + """ + Swap routing data between routing_host_buffer and routing_swap_buffer. + Pure CPU-to-CPU numpy memcpy, no GPU DMA. + Only rank 0 performs this (routing buffers are cross-rank SharedMemory). + """ + if self.routing_host_view is None or self.routing_swap_buffer is None: + logger.warning( + f"[R3] _swap_routing skipped: host_view={self.routing_host_view is not None}, " + f"swap_buffer={self.routing_swap_buffer is not None}" + ) + return + if self.rank > 0: + return + bs = self.block_size + for gpu_bid, cpu_bid in zip(gpu_block_ids, cpu_block_ids): + gpu_start = gpu_bid * bs + gpu_end = gpu_start + bs + cpu_start = cpu_bid * bs + cpu_end = cpu_start + bs + if direction == "to_cpu": + self.routing_swap_buffer.buffer[cpu_start:cpu_end] = self.routing_host_view.buffer[gpu_start:gpu_end] + elif direction == "to_gpu": + self.routing_host_view.buffer[gpu_start:gpu_end] = self.routing_swap_buffer.buffer[cpu_start:cpu_end] + else: + raise ValueError(f"[R3] _swap_routing: unknown direction '{direction}', expected 'to_cpu' or 'to_gpu'") + logger.info( + f"[R3] _swap_routing {direction}: {len(gpu_block_ids)} blocks, " + f"gpu_ids={gpu_block_ids[:3]}{'...' if len(gpu_block_ids) > 3 else ''}, " + f"cpu_ids={cpu_block_ids[:3]}{'...' if len(cpu_block_ids) > 3 else ''}" + ) + + def _write_routing_to_storage(self, task_keys, gpu_block_ids): + """ + Write routing data from routing_host_buffer to storage backend. + Only for mooncake/file backends; only tp_rank=0 writes routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + routing_keys = [] + routing_ptrs = [] + routing_sizes = [] + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + block_data = self.routing_host_view.buffer[start:end] + if not block_data.flags["C_CONTIGUOUS"]: + block_data = np.ascontiguousarray(block_data) + routing_keys.append(key) + routing_ptrs.append(block_data.ctypes.data) + routing_sizes.append(per_block_bytes) + + if routing_keys: + self.storage_backend.batch_set( + keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes + ) + logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage") + except Exception as e: + logger.warning(f"[R3] Failed to write routing to storage: {e}") + + def _read_routing_from_storage(self, task_keys, gpu_block_ids): + """ + Read routing data from storage backend into routing_host_buffer. + Only for mooncake/file backends; only tp_rank=0 reads routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + target_slice = self.routing_host_view.buffer[start:end] + if not target_slice.flags["C_CONTIGUOUS"]: + # Need contiguous target for ctypes pointer + tmp = np.ascontiguousarray(target_slice) + result = self.storage_backend.get( + key=key, target_location=tmp.ctypes.data, target_size=per_block_bytes + ) + if result is not None and result >= 0: + self.routing_host_view.buffer[start:end] = tmp + else: + self.storage_backend.get( + key=key, target_location=target_slice.ctypes.data, target_size=per_block_bytes + ) + + logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage") + except Exception as e: + logger.warning(f"[R3] Failed to read routing from storage: {e}") + def _init_control(self): dp_rank = self.local_data_parallel_id tp_rank = self.rank @@ -812,6 +1015,9 @@ def read_storage_task(self, task: ReadStorageTask): logger.info( f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) + # Read routing data from storage for matched blocks + matched_keys = task.keys[: len(valid_gpu_block_ids)] + self._read_routing_from_storage(matched_keys, valid_gpu_block_ids) except Exception as e: logger.error( f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}" @@ -1003,6 +1209,9 @@ def write_back_storage_task(self, task: WriteStorageTask): logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) + # Write routing data to storage (shares dedup with KVCache) + remaining_keys = task.keys[match_block_num:] + self._write_routing_to_storage(remaining_keys, gpu_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") gpu_block_ids = [] @@ -1387,6 +1596,10 @@ def _transfer_data( 0, ) + # Routing: routing_host_buffer → routing_swap_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_cpu") + elif event_type.value == CacheStatus.SWAP2GPU.value: swap_cache_all_layers( self.gpu_cache_k_tensors, @@ -1425,6 +1638,11 @@ def _transfer_data( self.device, 1, ) + + # Routing: routing_swap_buffer → routing_host_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_gpu") + else: logger.warning( f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported" diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 1da592f891b..328fd224f5b 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -299,6 +299,13 @@ def launch_cache_manager( else: storage_arg_str = " " + # Compute routing replay args for CTM — single JSON arg + routing_replay_config = getattr(self.config, "routing_replay_config", None) + if routing_replay_config is not None and routing_replay_config.enable_routing_replay: + routing_arg_str = f" --routing_replay_config '{routing_replay_config.to_json_string()}'" + else: + routing_arg_str = "" + if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend: for i in range(tensor_parallel_size): launch_cmd = ( @@ -330,6 +337,7 @@ def launch_cache_manager( + f" --write_policy {cache_config.write_policy}" + f" --max_model_len {self.config.model_config.max_model_len}" + f" --model_path {self.config.model_config.model}" + + routing_arg_str + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" ) logger.info(f"Launch cache transfer manager, command:{launch_cmd}") diff --git a/fastdeploy/cache_manager/routing_cache_manager.py b/fastdeploy/cache_manager/routing_cache_manager.py new file mode 100644 index 00000000000..68dff10b37d --- /dev/null +++ b/fastdeploy/cache_manager/routing_cache_manager.py @@ -0,0 +1,286 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +import multiprocessing +import multiprocessing.shared_memory +from typing import Optional + +import numpy as np + +from fastdeploy.utils import get_logger + +logger = get_logger("routing_cache_manager", "routing_cache_manager.log") + + +class RoutingHostBuffer: + """ + Manages routing_host_buffer (corresponds to KVCache GPU cache). + Indexed by gpu_block_id * block_size + offset. + Shared across processes via POSIX SharedMemory. + Each DP rank creates its own instance; name includes dp_suffix. + """ + + def __init__( + self, num_gpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_gpu_tokens = num_gpu_blocks * block_size + self.shape = (max_num_gpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_host_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingHostBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingHostBufferView: + """Read/write view of routing_host_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def scatter(self, slot_mapping: np.ndarray, data: np.ndarray): + """Scatter GPU buffer data to corresponding slots (Worker calls this).""" + self.buffer[slot_mapping] = data + + def gather(self, slot_mapping: np.ndarray) -> np.ndarray: + """Gather data from specified slots (TokenProcessor calls this).""" + return self.buffer[slot_mapping].copy() + + def close(self): + self.shm.close() + + +class RoutingSwapBuffer: + """ + Manages routing_swap_buffer (corresponds to KVCache CPU cache). + Indexed by cpu_block_id * block_size + offset. + CacheTransferManager creates this; shared via SharedMemory. + """ + + def __init__( + self, num_cpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_cpu_tokens = num_cpu_blocks * block_size + self.shape = (max_num_cpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_swap_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingSwapBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingSwapBufferView: + """Read/write view of routing_swap_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def close(self): + self.shm.close() + + +def split_request_id(request_id: str) -> str: + """ + Split the request id to get rollout id. + + request_id: "chatcmpl-request.user-uuid" + rollout_id: "request.user" + example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" + -> "xxx_xxx_epoch_15:2:2:1" + """ + chat_type, tmp_str = request_id.split("-", 1) + assert ( + chat_type == "chatcmpl" + ), "Rollout Routing Replay only supports chatcmpl. Please check request type and userid settings." + reversed_tmp_str = tmp_str[::-1].split("-", 5) + rollout_id = reversed_tmp_str[-1][::-1] + return rollout_id + + +class RoutingCacheManager: + """ + Engine-side stateless routing data manager. + Does NOT maintain request mapping — request state is fully managed by Scheduler. + Responsible for: SharedMemory creation/destruction, routing data gather, return mode dispatch. + """ + + def __init__(self, fd_config, num_gpu_blocks: int): + routing_replay_config = fd_config.routing_replay_config + self.num_moe_layers = routing_replay_config.num_moe_layers + self.moe_top_k = routing_replay_config.moe_top_k + self.routing_dtype = routing_replay_config.routing_dtype + self.only_last_turn = routing_replay_config.only_last_turn + self.use_fused_put = routing_replay_config.use_fused_put + self.block_size = fd_config.cache_config.block_size + self.return_mode = ( + routing_replay_config.routing_store_type + ) # "local" / "rdma" → p2pstore; "response" → attach to RequestOutput + + dp_suffix = str(fd_config.parallel_config.local_engine_worker_queue_port) + + # Create SharedMemory routing_host_buffer + self.host_buffer = RoutingHostBuffer( + num_gpu_blocks=num_gpu_blocks, + block_size=self.block_size, + num_moe_layers=self.num_moe_layers, + top_k=self.moe_top_k, + dtype=self.routing_dtype, + dp_suffix=dp_suffix, + ) + + # Host view for gather operations + self.host_view = RoutingHostBufferView( + shape=self.host_buffer.shape, + dtype=self.routing_dtype, + shm_name=self.host_buffer.shm_name, + ) + + # Initialize store wrapper for p2pstore mode + self._store_wrapper = None + if self.return_mode in ("local", "rdma"): + from fastdeploy.cache_manager.routing_store import StoreWrapper + + self._store_wrapper = StoreWrapper(fd_config=fd_config) + self._store_wrapper.start_store_warpper() + + logger.info( + f"[R3] RoutingCacheManager initialized: return_mode={self.return_mode}, " + f"host_buffer shape={self.host_buffer.shape}" + ) + + def gather_routing_for_request(self, block_table, seq_len: int) -> np.ndarray: + """ + Gather complete routing data for a request from routing_host_buffer. + + Args: + block_table: List of block IDs for the request + seq_len: Total sequence length + + Returns: + routing_data: [seq_len, num_moe_layers, top_k] numpy array + """ + num_blocks = math.ceil(seq_len / self.block_size) + block_ids = block_table[:num_blocks] + positions = np.arange(seq_len) + block_indices = positions // self.block_size + offsets = positions % self.block_size + slot_mapping = np.array(block_ids)[block_indices] * self.block_size + offsets + return self.host_view.gather(slot_mapping) + + def on_request_finished(self, request_id: str, block_table, seq_len: int) -> Optional[np.ndarray]: + """ + Unified entry point when a request finishes. Called by TokenProcessor on EOS detection. + Scheduler/TokenProcessor passes request_id, block_table, seq_len. + + Returns: + - "response" mode: routing_data numpy array (caller attaches to RequestOutput) + - "local"/"rdma" mode: None (submitted to StoreWrapper internally) + """ + routing_data = self.gather_routing_for_request(block_table, seq_len) + + if self._store_wrapper is not None: + # P2PStore mode: submit to store + rollout_id = split_request_id(request_id) + # Transpose to [num_moe_layers, seq_len, top_k] for store compatibility + # TODO(gongshaotian): Delete redundant transpose + routing_data = np.ascontiguousarray(routing_data.transpose(1, 0, 2)) + + if self.use_fused_put: + self._store_wrapper.submit_put_task(routing_indices=routing_data, rollout_id=rollout_id) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + else: + for layer_id in range(self.num_moe_layers): + layer_buffer = routing_data[layer_id] + self._store_wrapper.submit_put_task( + routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id + ) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) + return None + else: + # Response mode: return data for caller to attach to RequestOutput + return routing_data + + def reset(self): + """Reset SharedMemory buffer. Used during RL round cleanup.""" + self.host_buffer.buffer[:] = -1 + + def close(self): + """Clean up SharedMemory resources.""" + if self.host_view is not None: + self.host_view.close() + self.host_view = None + if self.host_buffer is not None: + self.host_buffer.close() + self.host_buffer = None diff --git a/fastdeploy/cache_manager/routing_store.py b/fastdeploy/cache_manager/routing_store.py new file mode 100644 index 00000000000..cd1dce19bd4 --- /dev/null +++ b/fastdeploy/cache_manager/routing_store.py @@ -0,0 +1,515 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import atexit +import functools +import multiprocessing +import os +import shutil +import threading +import time +import traceback +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Process, Queue +from typing import Optional, TypedDict + +import numpy as np +import paddle + +from fastdeploy.utils import get_logger + +logger = get_logger("routing_cache_manager", "routing_cache_manager.log") + +from fastdeploy.config import RoutingReplayConfig + + +class StoreTask(TypedDict): + task_type: str + key: str + data: np.ndarray + + +class StoreWrapper(object): + def __init__(self, fd_config) -> None: + super().__init__() + self.fd_config = fd_config + + # Initialize task queue + moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index + max_num_seqs = fd_config.scheduler_config.max_num_seqs + self.queue_max_size = moe_layer_num * max_num_seqs * 1000 + + self.manager = multiprocessing.Manager() + self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) + + self._monitor_thread: threading.Thread = None + self._stop_monitor = threading.Event() + + # Initialize consumer process + self._routing_store_process = StoreProcess( + task_queue=self._task_queue, + routing_replay_config=self.fd_config.routing_replay_config, + max_model_len=self.fd_config.model_config.max_model_len, + ) + self._store_process_running = False + + # Register atexit handler + atexit.register(self.shutdown) + + def shutdown(self): + """ """ + if not self._store_process_running: + return + self._store_process_running = False + + # Stop the monitor thread + self._stop_monitor.set() + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=3.0) + + # Put a sentinel value to signal the consumer to stop + if self._routing_store_process and self._routing_store_process.is_alive(): + try: + self._task_queue.put_nowait(None) + except Exception as e: + logger.info(f"Could not put sentinel into queue: {e}") + + if self._routing_store_process and self._routing_store_process.is_alive(): + # Wait for all tasks to be processed + self._routing_store_process.join(timeout=10.0) + if self._routing_store_process.is_alive(): + self._routing_store_process.close() + self._routing_store_process.join() + + self._task_queue.join() + self.manager.shutdown() + self._store_process_running = False + + def start_store_warpper(self): + """ """ + if self._store_process_running: + return + self._store_process_running = True + + # Start monitor thread + self._stop_monitor.clear() + self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) + self._monitor_thread.start() + + # Start Routing Store Wrapper in sub process + self._routing_store_process.start() + + def _monitor_queue_load(self): + """ """ + while not self._stop_monitor.is_set(): + time.sleep(2.0) + if not self._store_process_running: + break + qsize = self._task_queue.qsize() + + # Alarm when the task exceeds 80% of the queue capacity + if qsize > self.queue_max_size * 0.8: + logger.warning( + f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " + "Consider increasing max_workers or queue_max_size." + ) + logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") + + def submit_put_task(self, routing_indices: np.ndarray, rollout_id: str, layer_idx: int = None) -> None: + """Submit a put task to the task queue""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + else: + rdma_rollout_key = rollout_id + + task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices} + + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") + logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_store_task(self) -> None: + """Submit clear store task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} + + try: + self._task_queue.put_nowait(task) + # Wait for the task to be processed + self._task_queue.join() + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: + """Submit clear prefix batch task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + prefix_batch_id = self.get_needed_clear_ids(rollout_id) + if prefix_batch_id is None: + return + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" + else: + rdma_rollout_key = prefix_batch_id + + task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info( + f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" + ) + + def get_needed_clear_ids(self, rollout_id: str) -> Optional[str]: + """ + Generate the prefix IDs for all closed multi-round tasks. + rollout_id: "xxx_xxx_epoch_15:2:2:1" + example: xxx_xxx_data_id:gen_id:turn_id:segment_id + """ + reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = rollout_id[::-1].split(":", 2) + prefix_gen_id = reversed_prefix_gen_id[::-1] + turn_id = eval(reversed_turn_id[::-1]) + segment_id = eval(reversed_segment_id[::-1]) + + assert turn_id >= 0 and segment_id >= 0 + prefix_batch = None + if turn_id > 0: + prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" + return prefix_batch + + +class StoreProcess(Process): + def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: + super().__init__() + self.max_model_len = max_model_len + self._task_queue = task_queue + self.routing_replay_config = routing_replay_config + self.max_workers = 5 + self._closed = False + + # Note: _routing_store and _event_loop_thread must be initialized in run() + # because they cannot be properly inherited after fork() + self._routing_store = None + self._event_loop_thread = None + + def run(self): + logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") + + # Initialize routing store in subprocess + self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) + + # Initialize event loop thread in subprocess + self._event_loop_thread = AsyncEventLoopThread() + self._event_loop_thread.start() + if not self._event_loop_thread._started_event.wait(timeout=5.0): + raise RuntimeError("Failed to start async event loop thread in subprocess") + + clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) + self._task_queue.put_nowait(clear_store_task) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + while not self._closed: + try: + task = self._task_queue.get() + if task is None: # Sentinel + self._task_queue.task_done() + break + + if task["task_type"] == "put": + future = executor.submit(self.process_put_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_store": + future = executor.submit(self.process_clear_store_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_prefix_batch": + future = executor.submit(self.process_clear_prefix_batch_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + except Exception as e: + self._task_queue.task_done() + raise RuntimeError(f"Error during processing task. {e}") + + logger.info("RoutingReplay Consumer Process Shutdown.") + + def process_put_task(self, store_task: StoreTask) -> None: + try: + # TODO(gongshaotian): delete this after trainer support dynamic len + store_task["data"] = self.pad_routing_indices(store_task["data"]) + coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting put task: {e}") + traceback.print_exc() + raise + + def process_clear_store_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_store() + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error during processing clear store task. {e}") + traceback.print_exc() + raise + + def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting clear_prefix_batch task: {e}") + traceback.print_exc() + raise + + def _on_async_task_completed(self, task, future): + """ """ + try: + # result = future.result() + logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") + except Exception as e: + logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") + traceback.print_exc() + raise + + def close(self): + """Close the store process""" + self._closed = True + if hasattr(self, "_event_loop_thread"): + self._event_loop_thread.stop() + + def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: + """Pad routing indices of the request levevl to max model len""" + routing_shape = routing_indices.shape + if len(routing_shape) == 2: # [token, topk] + pad_array = np.full( + shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=0) + + elif len(routing_shape) == 3: # [layer, token, topk] + pad_array = np.full( + shape=[ + routing_indices.shape[0], + (self.max_model_len - routing_indices.shape[1]), + routing_indices.shape[2], + ], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=1) + else: + raise ValueError(f"Invalid routing indices shape: {routing_shape}") + + +class AsyncEventLoopThread(threading.Thread): + def __init__(self): + super().__init__(daemon=True) + self._loop = None + self._started_event = threading.Event() + self._closed = False + + def run(self): + """Run the async event loop""" + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Set the event loop to be started + self._started_event.set() + logger.info("[EventLoopThread] Event loop started, running forever...") + + try: + self._loop.run_forever() + logger.info("[EventLoopThread] Event loop stopped") + except Exception as e: + logger.error(f"[EventLoopThread] Event loop exception: {e}") + traceback.print_exc() + finally: + logger.info("[EventLoopThread] Closing event loop") + self._loop.close() + + def submit_coroutine(self, coro, callback=None): + """Thread safely submit coroutine to event loop""" + if self._closed: + raise RuntimeError("Event loop thread is closed") + if not self._started_event.wait(timeout=5.0): + raise RuntimeError("Event loop failed to start within 5 seconds") + + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + if callback: + + def wrapped_callback(f): + try: + callback(f) + except Exception as e: + logger.error(f"Error in callback: {e}") + traceback.print_exc() + + future.add_done_callback(wrapped_callback) + return future + + def stop(self): + """Stop the event loop""" + if not self._closed: + self._closed = True + if self._loop: + self._loop.call_soon_threadsafe(self._loop.stop) + + +class RoutingStoreBase(ABC): + """Base class for routing store""" + + def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: + self.routing_replay_config = routing_replay_config + + @abstractmethod + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + raise NotImplementedError + + @abstractmethod + async def clear_store( + self, + ): + """Clear the routing indices store""" + raise NotImplementedError + + @abstractmethod + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreLocal(RoutingStoreBase): + """Routing Store using local memory""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + self.local_store_dir = routing_replay_config.local_store_dir + os.makedirs(self.local_store_dir, exist_ok=True) + + async def put( + self, + routing_key: str, + routing_indices: np.ndarray, + ) -> None: + """Put the routing indices into store""" + # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor + time_before_put = time.perf_counter() + + if len(routing_indices.shape) == 2: + re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) + rollout_id = re_rollout_id[::-1] + layer_id = re_layer_id[::-1] + request_path = os.path.join(self.local_store_dir, rollout_id) + file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") + elif len(routing_indices.shape) == 3: + request_path = os.path.join(self.local_store_dir, routing_key) + file_path = os.path.join(request_path, f"{routing_key}.pdtensor") + else: + raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") + + paddle.save(routing_indices, file_path) + logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") + + async def clear_store(self): + """Clear the routing indices store""" + if os.path.isdir(self.local_store_dir): + shutil.rmtree(self.local_store_dir) + + logger.info("[R3] Clear routing store.") + + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreRDMA(RoutingStoreBase): + """Routing Store using RDMA""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + try: + # Only used in RLHF + from p2pstore import P2PClient, P2PConfig + except ModuleNotFoundError: + raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") + + rdma_store_server = routing_replay_config.rdma_store_server + p2pConfig = P2PConfig(metadata_server=rdma_store_server) + self.p2p_client = P2PClient(p2pConfig) + + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + time_before_put = time.perf_counter() + if len(routing_indices.shape) == 3: + # NOTE(gongshaotian) Fused put with bytes data + routing_bytes = routing_indices.tobytes() + result = await self.p2p_client.put(routing_key, routing_bytes) + else: + result = await self.p2p_client.put(routing_key, routing_indices) + logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") + return result + + async def clear_prefix_batch(self, routing_prefix_key: str): + time_before_clear = time.perf_counter() + result = await self.p2p_client.delete_batch([routing_prefix_key]) + logger.info( + f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" + ) + return result + + async def clear_store(self): + """Clear the routing indices store""" + time_before_clear = time.perf_counter() + result = await self.p2p_client.clear() + logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") + return result + + +def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: + if routing_replay_config.routing_store_type == "local": + return RoutingStoreLocal(routing_replay_config=routing_replay_config) + elif routing_replay_config.routing_store_type == "rdma": + return RoutingStoreRDMA(routing_replay_config=routing_replay_config) + else: + raise ValueError( + f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " + "Valid types are: 'local', 'rdma'" + ) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 18efa2586b6..3ee93724f28 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1839,7 +1839,7 @@ def __init__(self, args) -> None: self.enable_routing_replay: bool = False - # Routing store type: local/rdma + # Routing return mode: "local" (file store) / "rdma" (P2PStore) / "response" (attach to RequestOutput) self.routing_store_type: str = "local" # Local routing store @@ -1854,11 +1854,37 @@ def __init__(self, args) -> None: # Fused routing of all layers self.use_fused_put: bool = False + # Auto-filled by FDConfig from ModelConfig (do not set manually) + self.routing_dtype: str = "" # "uint8" / "uint16" / "uint32" + self.num_moe_layers: int = 0 + self.moe_top_k: int = 0 + if args is not None: for key, value in args.items(): if hasattr(self, key) and value != "None": setattr(self, key, value) + def postprocess(self, model_config: "ModelConfig") -> None: + """Fill computed fields from ModelConfig. Must be called after model-specific + field unification (e.g. GLM's first_k_dense_replace → moe_layer_start_index).""" + if not self.enable_routing_replay: + return + self.num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + self.moe_top_k = model_config.num_experts_per_tok + else: + self.moe_top_k = model_config.moe_k + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + total_number = num_experts + 1 # +1 for reserved fill value + if total_number <= 255: + self.routing_dtype = "uint8" + elif total_number <= 65535: + self.routing_dtype = "uint16" + elif total_number <= 4294967295: + self.routing_dtype = "uint32" + else: + raise ValueError(f"num_experts {num_experts} exceeds uint32 range") + def to_json_string(self): """ Convert routing replay config to json string. @@ -1921,6 +1947,7 @@ def __init__( self.router_config: RouterConfig = router_config self.routing_replay_config = routing_replay_config self.deploy_modality: DeployModality = deploy_modality + # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs if self.graph_opt_config.cudagraph_only_prefill: @@ -2062,6 +2089,9 @@ def postprocess(self): # The first moe layer id of GLM4.5 model self.model_config.moe_layer_start_index = self.model_config.first_k_dense_replace + if self.routing_replay_config is not None: + self.routing_replay_config.postprocess(self.model_config) + if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0: self.is_master = True self.master_ip = "0.0.0.0" diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 22016511f61..833b4830824 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2552,10 +2552,47 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) after num_gpu_blocks is known + self.routing_cache_manager = None + if self.cfg.routing_replay_config.enable_routing_replay: + self._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + def _init_routing_cache_manager(self, num_gpu_blocks: int): + """Create RoutingCacheManager (includes SharedMemory host buffer) after profiling.""" + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingCacheManager, + RoutingHostBufferView, + ) + + self.routing_cache_manager = RoutingCacheManager( + fd_config=self.cfg, + num_gpu_blocks=num_gpu_blocks, + ) + + # Pass routing_cache_manager to TokenProcessor for local/rdma store dispatch + self.token_processor.routing_cache_manager = self.routing_cache_manager + + # Set routing_host_view on resource_manager for PD disaggregation (D side) + if hasattr(self, "resource_manager") and hasattr(self.resource_manager, "routing_host_view"): + rrc = self.cfg.routing_replay_config + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = num_gpu_blocks * self.cfg.cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + try: + self.resource_manager.routing_host_view = RoutingHostBufferView( + shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name + ) + except FileNotFoundError: + self.llm_logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {shm_name} not found for resource_manager" + ) + def check_health(self, time_interval_threashold=30): """ Check the health of the model server by checking whether all workers are alive. diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 44edea80d34..ba9049be0c6 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -143,6 +143,12 @@ def start(self, api_server_pid=None): self.engine.create_data_processor() self.data_processor = self.engine.data_processor + # Create RoutingCacheManager when skipping profiling (num_gpu_blocks_override is set) + if not self.do_profile and self.cfg.routing_replay_config.enable_routing_replay: + num_gpu_blocks = self.cfg.cache_config.num_gpu_blocks_override + if num_gpu_blocks is not None: + self.engine._init_routing_cache_manager(num_gpu_blocks) + # If block numer is specified and model is deployed in mixed mode, start cache manager first if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): @@ -458,6 +464,11 @@ def _exit_sub_services(self): if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() + if hasattr(self, "engine") and hasattr(self.engine, "routing_cache_manager"): + if self.engine.routing_cache_manager is not None: + self.engine.routing_cache_manager.close() + self.engine.routing_cache_manager = None + if hasattr(self, "dp_processed"): for p in self.dp_processed: console_logger.info(f"Waiting for worker {p.pid} to exit") @@ -763,6 +774,11 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) before starting cache service + if self.cfg.routing_replay_config.enable_routing_replay: + self.engine._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): device_ids = self.cfg.parallel_config.device_ids.split(",") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 624f8f32951..7ecac8ff126 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -1026,6 +1026,7 @@ def __init__( self.ic_req_data = ic_req_data self.prompt_token_ids_len = prompt_token_ids_len self.trace_carrier = trace_carrier + self.routing_data = None # Optional[np.ndarray], [seq_len, num_moe_layers, top_k] if prompt_token_ids is None: self.prompt_token_ids = [] @@ -1121,12 +1122,15 @@ def from_dict(cls, d: dict): d.pop("metrics", None) metrics = None trace_carrier = d.pop("trace_carrier", {}) - return RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + routing_data = d.pop("routing_data", None) + obj = RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + obj.routing_data = routing_data + return obj def to_dict(self): """convert RequestOutput into a serializable dict""" - return { + d = { "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, @@ -1144,6 +1148,9 @@ def to_dict(self): "prompt_token_ids_len": self.prompt_token_ids_len, "trace_carrier": self.trace_carrier, } + if self.routing_data is not None: + d["routing_data"] = self.routing_data + return d def get(self, key: str, default_value=None): if hasattr(self, key): diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 31af47f0507..27a1328c1ee 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -209,6 +209,8 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None + self.routing_host_view = None # Set by Engine after RoutingHostBuffer creation + if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) @@ -1490,9 +1492,38 @@ def add_prefilled_request(self, request_output: RequestOutput): request.metrics.decode_inference_start_time = time.time() request.metrics.update_decoder_start_time() + # [R3] Write P's prefill routing data into D's routing_host_buffer + if ( + self.routing_host_view is not None + and hasattr(request_output, "routing_data") + and request_output.routing_data is not None + ): + try: + self._write_prefill_routing_to_host_buffer(request, request_output.routing_data) + except Exception as e: + llm_logger.warning(f"[R3] Failed to write prefill routing for {request_output.request_id}: {e}") + self.tasks_list[request.idx] = request self.running.append(request) + def _write_prefill_routing_to_host_buffer(self, request, routing_data): + """ + Write P's prefill routing data into D's routing_host_buffer. + Uses D's block_tables to compute slot_mapping. + """ + import math + + seq_len = routing_data.shape[0] + block_size = self.config.cache_config.block_size + num_blocks_needed = math.ceil(seq_len / block_size) + block_ids = request.block_tables[:num_blocks_needed] + + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + self.routing_host_view.scatter(slot_mapping, routing_data) + def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode": self.cache_manager.release_block_ids(request) diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 883ff71e858..a546017d30f 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -285,6 +285,7 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class LogProbEntry(BaseModel): @@ -390,6 +391,7 @@ class CompletionResponse(BaseModel): model: str choices: List[CompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class CompletionLogprobs(BaseModel): diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index c6e31605e0a..1bbee7b07f9 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -588,6 +588,7 @@ async def chat_completion_full_generator( sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] + routing_data_result = None while num_choices > 0: if self.engine_client.check_model_weight_status(): return ErrorResponse( @@ -721,6 +722,15 @@ async def chat_completion_full_generator( speculate_metrics=speculate_metrics[idx], ) choices.append(choice) + if data.get("routing_data") is not None: + import base64 + + import numpy as np + + rd = data["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routing_data_result = base64.b64encode(rd.tobytes()).decode("utf-8") finally: trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", "")) tracing.trace_req_finish(request_id) @@ -752,6 +762,7 @@ async def chat_completion_full_generator( model=model_name, choices=choices, usage=usage, + routed_experts=routing_data_result, ) api_server_logger.info(f"Chat response: {res.model_dump_json()}") return res diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index cdd0a1a096d..e7941197f37 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -786,12 +786,24 @@ def request_output_to_completion_response( ) del request + routed_experts = None + if final_res_batch and final_res_batch[-1].get("routing_data") is not None: + import base64 + + import numpy as np + + rd = final_res_batch[-1]["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routed_experts = base64.b64encode(rd.tobytes()).decode("utf-8") + return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + routed_experts=routed_experts, ) async def _call_process_response_dict(self, res, request, stream): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index f9db2f8e50b..96bc09934a8 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -259,8 +259,6 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), - # Suspend rollouting routing replay - "FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))), # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index d6a448a2693..64555bcbdfe 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -146,8 +146,8 @@ class ForwardMeta: caches: Optional[list[paddle.Tensor]] = None # Flag of profile run is_dummy_or_profile_run: bool = False - # Routing Replay table buffer - routing_replay_table: Optional[paddle.Tensor] = None + # GPU transient routing buffer [max_num_batched_tokens, num_moe_layers, top_k] + gpu_routing_buffer: Optional[paddle.Tensor] = None # chunked MoE related moe_num_chunk: int = 1 diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index f7d0b32c7a5..cc427a3fe54 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -28,7 +28,7 @@ ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - save_routing_to_buffer, + save_routing_to_buffer_v2, ) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import h2d_copy, slice_fn @@ -720,23 +720,21 @@ def forward( Tensor: Output tensor.s """ + topk_ids_hookfunc = None if self.enable_routing_replay: # When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None. - if forward_meta is not None and forward_meta.routing_replay_table is not None: + if forward_meta is not None and forward_meta.gpu_routing_buffer is not None: moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index topk_ids_hookfunc = partial( - save_routing_to_buffer, - routing_replay_table=forward_meta.routing_replay_table, - batch_id_per_token=forward_meta.batch_id_per_token, - seq_lens_decoder=forward_meta.seq_lens_decoder, - cu_seqlens_q=forward_meta.cu_seqlens_q, + save_routing_to_buffer_v2, + gpu_routing_buffer=forward_meta.gpu_routing_buffer, layer_idx=moe_layer_idx, tp_size=self.fd_config.parallel_config.tensor_parallel_size, ep_size=self.fd_config.parallel_config.expert_parallel_size, tp_group=self.fd_config.parallel_config.tp_group, + total_token_num=forward_meta.batch_id_per_token.shape[0], ) - if current_platform.is_intel_hpu(): out = self.forward_normal( x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index c139700f347..534e303c89c 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -14,20 +14,6 @@ # limitations under the License. """ -import asyncio -import atexit -import functools -import multiprocessing -import os -import shutil -import threading -import time -import traceback -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from multiprocessing import Process, Queue -from typing import Dict, Optional, TypedDict - import numpy as np import paddle import paddle.distributed as dist @@ -35,7 +21,8 @@ import triton.language as tl from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig, RoutingReplayConfig +from fastdeploy.cache_manager.routing_cache_manager import RoutingHostBufferView +from fastdeploy.config import FDConfig from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( enable_compat_on_triton_kernel, ) @@ -43,75 +30,44 @@ @enable_compat_on_triton_kernel @triton.jit -def _save_routing_kernel( - ROUTING_REPLAY_TABLE_PTR, +def _save_routing_kernel_v2( + GPU_ROUTING_BUFFER_PTR, TOPK_IDS_PTR, - BATCH_ID_PER_TOKEN_PTR, - CU_SEQLENS_Q_PTR, - SEQ_LENS_DECODER_PTR, LAYER_IDX, TOKEN_NUM, TOP_K, - NUM_HIDDEN_LAYERS, - MAX_MODEL_LEN, - MAX_NUM_SEQS, + NUM_MOE_LAYERS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(axis=0) - token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) token_mask = token_offsets < TOKEN_NUM - k_offsets = tl.arange(0, BLOCK_SIZE_K) k_mask = k_offsets < TOP_K - topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) - - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) - - batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) - pad_mask = token_mask & (batch_ids != -1) & batch_mask - - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) - token_relative_index = token_offsets - start_offsets - - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) - token_seq_pos = len_decoder + token_relative_index - - STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) - STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) - STRIDE_BUF_LAYER = TOP_K + topk_vals = tl.load( + TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :], + mask=load_mask, + ) + STRIDE_TOKEN = NUM_MOE_LAYERS * TOP_K + STRIDE_LAYER = TOP_K output_ptrs = ( - ROUTING_REPLAY_TABLE_PTR - + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ - + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN - + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER - + k_offsets[None, :] + GPU_ROUTING_BUFFER_PTR + token_offsets[:, None] * STRIDE_TOKEN + LAYER_IDX * STRIDE_LAYER + k_offsets[None, :] ) + tl.store(output_ptrs, topk_vals, mask=load_mask) - pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) - pos_mask = pos_mask & pad_mask - pos_mask = pos_mask[:, None] & k_mask[None, :] - - final_mask = load_mask & pos_mask - - tl.store(output_ptrs, topk_vals, mask=final_mask) - -def save_routing_to_buffer( - routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k] - topk_ids: paddle.Tensor, # [token_num, top_k] - batch_id_per_token: paddle.Tensor, # [token_num, 1] - seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1] - cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1] +def save_routing_to_buffer_v2( + gpu_routing_buffer: paddle.Tensor, + topk_ids: paddle.Tensor, layer_idx: int, tp_size: int, ep_size: int, tp_group: dist.communication.group.Group, + total_token_num: int = -1, ): token_num_per_rank = topk_ids.shape[0] if token_num_per_rank == 0: @@ -119,125 +75,141 @@ def save_routing_to_buffer( if tp_size > 1 and ep_size > 1: topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) - topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] + assert ( + total_token_num >= token_num_per_rank + ), f"[R3] total_token_num={total_token_num} < token_num_per_rank={token_num_per_rank}" + topk_ids = topk_ids_all[:total_token_num, :] token_num, top_k = topk_ids.shape - max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape - assert token_num > 0 + buf_max_tokens, num_moe_layers, buf_top_k = gpu_routing_buffer.shape + assert ( - topk_ids.shape[1] == routing_replay_table.shape[3] - ), f"({topk_ids.shape[1]}, {routing_replay_table.shape[3]})" - assert batch_id_per_token.shape[0] == token_num, f"({batch_id_per_token.shape[0]}, {token_num})" + token_num <= buf_max_tokens + ), f"[R3] token_num={token_num} exceeds gpu_routing_buffer capacity={buf_max_tokens}" + assert top_k == buf_top_k, f"[R3] top_k mismatch: topk_ids.top_k={top_k} vs gpu_routing_buffer.top_k={buf_top_k}" + assert 0 <= layer_idx < num_moe_layers, f"[R3] layer_idx={layer_idx} out of range [0, {num_moe_layers})" BLOCK_SIZE_M = 128 - BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k - + BLOCK_SIZE_K = triton.next_power_of_2(top_k) grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) - _save_routing_kernel[grid]( - routing_replay_table, + _save_routing_kernel_v2[grid]( + gpu_routing_buffer, topk_ids, - batch_id_per_token, - cu_seqlens_q, - seq_lens_decoder, LAYER_IDX=layer_idx, TOKEN_NUM=token_num, TOP_K=top_k, - NUM_HIDDEN_LAYERS=num_hidden_layers, - MAX_MODEL_LEN=max_model_len, - MAX_NUM_SEQS=max_num_seqs, + NUM_MOE_LAYERS=num_moe_layers, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) -class RoutingReplayManager: - """Request level routing replay table manager""" +class RoutedExpertsCapturer: + """ + Worker-side routing capture: manages GPU transient buffer and GPU→CPU scatter. + Does NOT manage request lifecycle — that is handled by RoutingCacheManager on the Engine side. + """ def __init__(self, fd_config: FDConfig, block_table, total_block_num): self.fd_config = fd_config self.block_table = block_table self.max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.max_model_len = fd_config.model_config.max_model_len - self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - self.only_last_turn = fd_config.routing_replay_config.only_last_turn - self.use_fused_put = fd_config.routing_replay_config.use_fused_put - logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}") - if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": - self.moe_top_k = fd_config.model_config.num_experts_per_tok - else: - self.moe_top_k = fd_config.model_config.moe_k + + # Read routing params from centralized config + rrc = fd_config.routing_replay_config + self.num_moe_layers = rrc.num_moe_layers + self.moe_top_k = rrc.moe_top_k + self.routing_dtype = rrc.routing_dtype self.tp_rank = fd_config.parallel_config.tensor_parallel_rank - # Initialize the routing replay table and routing cache - self.routing_batch_to_request: Dict[int, str] = {} - num_experts = fd_config.model_config.moe_num_experts + fd_config.model_config.moe_num_shared_experts - self.routing_dtype = self.get_routing_dtype(num_experts=num_experts) + logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}") + self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num) self.pending_update_positions = None - # Initialize routing store wrapper - if self.tp_rank == 0: - self._store_wrapper = StoreWrapper( - fd_config=fd_config, - ) - self._store_wrapper.start_store_warpper() - - # Suspend Routing Replay - self.suspend_routing_replay = False - self.update_suspend_routing_replay() - - def update_suspend_routing_replay(self): - """Allow RL to use R3 in different training rounds""" - # TODO(gongshaotian): Delete this func - suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0") - self.suspend_routing_replay = bool(int(suspend_routing_replay)) - logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}") - def _init_routing_cache(self, dtype: str, total_block_num: int): - """Initialize the device buffer and host buffer.""" - + """Initialize GPU transient buffer and prepare lazy SharedMemory attach.""" max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size - self._host_cache = paddle.full( - shape=[max_num_kv_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, device="cpu" - ) - - self.routing_replay_table = paddle.full( - shape=[self.max_num_seqs, self.max_model_len, self.num_moe_layers, self.moe_top_k], + # Small GPU transient buffer: only current step's token routing + max_num_batched_tokens = self.fd_config.scheduler_config.max_num_batched_tokens + self.gpu_routing_buffer = paddle.full( + shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, ) + + # Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling) + self.routing_host_view = None + self._routing_host_view_attach_attempted = False + self._routing_host_view_shm_name = ( + f"routing_host_buffer.{str(self.fd_config.parallel_config.local_engine_worker_queue_port)}" + ) + self._routing_host_view_shape = (max_num_kv_tokens, self.num_moe_layers, self.moe_top_k) + self._routing_host_view_dtype = dtype + + gpu_buffer_bytes = int(np.prod(self.gpu_routing_buffer.shape)) * np.dtype(dtype).itemsize logger.info( - f"[R3] The host cache size is:{self._host_cache.shape}, device cache size is: {self.routing_replay_table.shape}" + f"[R3] GPU transient routing buffer: {self.gpu_routing_buffer.shape} " + f"({gpu_buffer_bytes / 1024:.1f} KB)" ) - def get_routing_dtype(self, num_experts: int, reserved_fill_value: int = 1) -> str: - """Calculate the minimum number of bits required for storage routing.""" - if num_experts <= 0: - raise ValueError(f"num_experts must be greater than 0 but got {num_experts}, please check model config.") - dtype = "uint8" - total_number = num_experts + reserved_fill_value - if total_number <= 255: # uint8: 0~255 - dtype = "uint8" - elif total_number <= 65535: # uint16: 0~65,535 - dtype = "uint16" - elif total_number <= 4294967295: # uint32: 0~4,294,967,295 - dtype = "uint32" - else: - raise ValueError( - f"The number of experts {num_experts} exceeds the representation range of uint32, please check model config." + def _try_attach_routing_host_view(self): + """Lazily attach to SharedMemory routing_host_buffer on first use.""" + if self._routing_host_view_attach_attempted: + return + self._routing_host_view_attach_attempted = True + try: + self.routing_host_view = RoutingHostBufferView( + shape=self._routing_host_view_shape, + dtype=self._routing_host_view_dtype, + shm_name=self._routing_host_view_shm_name, + ) + logger.info(f"[R3] Attached to RoutingHostBuffer SharedMemory: {self._routing_host_view_shm_name}") + except FileNotFoundError: + logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {self._routing_host_view_shm_name} not found. " + "Routing capture will be skipped." ) - logger.info(f"[R3] Routing replay table dtype: {dtype}") - return dtype - def update_host_cache(self, positions: paddle.Tensor, slot_mapping: paddle.Tensor): - """Update the host cache with new tokens""" - for batch_id, position in enumerate(positions): - if len(position) > 0 and len(slot_mapping[batch_id]) > 0: - routing_ids = self.routing_replay_table[batch_id, position, :, :].contiguous() - routing_ids = routing_ids.cpu() + def save_captured_routing(self, num_tokens: int, slot_mapping: np.ndarray): + """ + After forward, scatter GPU buffer routing data to routing_host_buffer. + Called in step gap (post_process), not during forward. CUDAGraph compatible. + """ + assert slot_mapping.shape[0] == num_tokens + if num_tokens == 0: + return + + # Lazy attach to SharedMemory (Engine creates it after profiling completes) + if self.routing_host_view is None and not self._routing_host_view_attach_attempted: + self._try_attach_routing_host_view() - self._host_cache[slot_mapping[batch_id], :, :] = routing_ids + if self.routing_host_view is None: + return + + # D2H copy: GPU → CPU numpy, then scatter to SharedMemory + data = self.gpu_routing_buffer[:num_tokens].cpu().numpy() + self.routing_host_view.scatter(slot_mapping, data) + + def compute_slot_mapping_flat(self, positions) -> np.ndarray: + """ + Compute flat slot_mapping for all tokens in the step. + Returns a 1D numpy array of slot indices. + """ + all_slots = [] + block_size = self.fd_config.cache_config.block_size + for batch_id, position in enumerate(positions): + if len(position) == 0: + continue + block_table_indices = position // block_size + token_block_ids = self.block_table[batch_id, block_table_indices] + block_offset = position % block_size + token_cache_ids = np.array(token_block_ids) * block_size + block_offset + all_slots.append(token_cache_ids) + if all_slots: + return np.concatenate(all_slots) + return np.array([], dtype=np.int64) def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): """Get token position of each sequence in a batch.""" @@ -245,7 +217,7 @@ def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): increase_num = seq_lens_this_time.numpy() positions = [] - for i in range(self.max_num_seqs): + for i in range(seq_lens_this_time.shape[0]): if seq_lens_this_time[i] == 0: positions.append([]) continue @@ -254,640 +226,14 @@ def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): return positions - def compute_slot_mapping(self, positions: np.ndarray): - """Compute the mapping between token ids and kvcache slots""" - slot_mapping = [] - for batch_id, position in enumerate(positions): - if len(position) == 0: - slot_mapping.append([]) - continue - block_table_indices = position // self.fd_config.cache_config.block_size - token_block_ids = self.block_table[batch_id, block_table_indices] - block_offset = position % self.fd_config.cache_config.block_size - - token_cache_ids = np.array(token_block_ids) * self.fd_config.cache_config.block_size + block_offset - slot_mapping.append(token_cache_ids) + def get_gpu_routing_buffer(self) -> paddle.Tensor: + return self.gpu_routing_buffer - return slot_mapping - - def _get_routing_from_cache(self, finished_batch_ids, seq_lens_decoder): - """ - When request is finished or cleared the length of the request is recorded at seq_lens_decoder - 1. finish the step: after update input, lens = seq_lens_decoder_buffer - 2. clear parameter: after update input, lens = seq_lens_decoder_buffer - """ - # Get the slot mapping of the request cache. - current_token_nums = seq_lens_decoder.numpy() - positions = [] - for batch_id in range(self.max_num_seqs): - position = [] - if batch_id in finished_batch_ids: - position = np.arange(0, current_token_nums[batch_id]) - positions.append(position) - - # Collection the cached routing information - token_cache_ids = self.compute_slot_mapping(positions=positions) - for slot_map in token_cache_ids: - if len(slot_map) > 0: - token_cached_routing = self._host_cache[slot_map, :, :] - return paddle.transpose(token_cached_routing, [1, 0, 2]) - raise ValueError("No cached routing found") - - def put_finished_batch( - self, - finished_batch_ids, - seq_lens_decoder, - ): - finished_batch_ids_list = finished_batch_ids.cpu().tolist() - for batch_id, finished in enumerate(finished_batch_ids_list): - if finished: - assert batch_id in self.routing_batch_to_request.keys() - # Deregister the request - request_id = self._deregister_request(batch_id) - # Put the routing of finished request to store - self._put_request_to_store( - batch_id=batch_id, - request_id=request_id, - seq_lens_decoder=seq_lens_decoder, - ) - # Clear the slot of the finished batch - self._clear_table_slot(batch_id) - - def register_request(self, batch_id: int, request_id: str): - """ - Register a new request to routing replay table - Args: - batch_id: The batch ID of this request - request_id: The global ID of the request is usually executed by the training process in RL - """ - # The chunked prefill tasks will be registered repeatedly - if batch_id in self.routing_batch_to_request: - if self.routing_batch_to_request[batch_id] == request_id: - logger.warning(f"[R3] Request {request_id} has been registered at {batch_id}.") - return - else: - raise RuntimeError( - f"[R3] The Batch {batch_id} has been registered by request {self.routing_batch_to_request[batch_id]}, now robed by {request_id}," - ) - - # Register the new request - self.routing_batch_to_request[batch_id] = request_id - logger.info(f"[R3] Register request {request_id} with batch id {batch_id}") - - def _deregister_request(self, batch_id: int) -> str: - """ - Deregister a request from routing replay table - """ - assert batch_id in self.routing_batch_to_request - return self.routing_batch_to_request.pop(batch_id) - - def _put_request_to_store( - self, - batch_id: int, - request_id: str, - seq_lens_decoder, - ): - if self.tp_rank == 0: - # TODO(gongshaotian): Delete the suspend func - if self.suspend_routing_replay: - logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store") - return - - before_put_request_time = time.perf_counter() - - # Collect the routing of finished request - batch_buffer = self._get_routing_from_cache( - finished_batch_ids=[batch_id], seq_lens_decoder=seq_lens_decoder - ) - rollout_id = self.split_request_id(request_id) - - if self.use_fused_put: - self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) - - else: - for layer_id in range(self.num_moe_layers): - layer_buffer = batch_buffer[layer_id] - self._store_wrapper.submit_put_task( - routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id - ) - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) - - logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") - - def clear_request(self, batch_id: int): - """Clear the routing indices of the request""" - self._clear_table_slot(batch_id) - self.routing_batch_to_request.pop(batch_id, None) - - def _clear_table_slot(self, batch_id: int): - assert 0 <= batch_id < self.max_num_seqs - self.routing_replay_table[batch_id].fill_(-1) - - def get_routing_table(self) -> paddle.Tensor: - return self.routing_replay_table - - def split_request_id(self, request_id: str): - """ - Split the request id to get rollout id. - - request_id: "chatcmpl-request.user-uuid" - rollout_id: "request.user" - example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" -> "xxx_xxx_epoch_15:2:2:1" - """ - chat_type, tmp_str = request_id.split("-", 1) - # NOTE(gongshaotian): only support chatcmpl now - assert ( - chat_type == "chatcmpl" - ), "Rollout Routing Replay only supports chatcmpl. Please check whether the request type and userid settings are correct." - reversed_tmp_str = tmp_str[::-1].split("-", 5) - rollout_id = reversed_tmp_str[-1][::-1] - return rollout_id - - def clear_all_request(self): - """Clear all requests""" - self.routing_replay_table.fill_(-1) - self.routing_batch_to_request = {} - - -class StoreWrapper(object): - def __init__(self, fd_config: False) -> None: - super().__init__() - self.fd_config = fd_config - - # Initialize task queue - moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.queue_max_size = moe_layer_num * max_num_seqs * 1000 - - self.manager = multiprocessing.Manager() - self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) - - self._monitor_thread: threading.Thread = None - self._stop_monitor = threading.Event() - - # Initialize consumer process - self._routing_store_process = StoreProcess( - task_queue=self._task_queue, - routing_replay_config=self.fd_config.routing_replay_config, - max_model_len=self.fd_config.model_config.max_model_len, - ) - self._sotre_process_running = False - - # Register atexit handler - atexit.register(self.shutdown) - - def shutdown(self): - """ """ - if not self._sotre_process_running: - return - self._sotre_process_running = False - - # Stop the monitor thread - self._stop_monitor.set() - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=3.0) - - # Put a sentinel value to signal the consumer to stop - if self._routing_store_process and self._routing_store_process.is_alive(): - try: - self._task_queue.put_nowait(None) - except Exception as e: - logger.info(f"Could not put sentinel into queue: {e}") - - if self._routing_store_process and self._routing_store_process.is_alive(): - # Wait for all tasks to be processed - self._routing_store_process.join(timeout=10.0) - if self._routing_store_process.is_alive(): - self._routing_store_process.close() - self._routing_store_process.join() - - self._task_queue.join() - self.manager.shutdown() - self._sotre_process_running = False - - def start_store_warpper(self): - """ """ - if self._sotre_process_running: - return - self._sotre_process_running = True - - # Start monitor thread - self._stop_monitor.clear() - self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) - self._monitor_thread.start() - - # Start Routing Store Wrapper in sub process - self._routing_store_process.start() - - def _monitor_queue_load(self): - """ """ - while not self._stop_monitor.is_set(): - time.sleep(2.0) - if not self._sotre_process_running: - break - qsize = self._task_queue.qsize() - - # Alarm when the task exceeds 80% of the queue capacity - if qsize > self.queue_max_size * 0.8: - logger.warning( - f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " - "Consider increasing max_workers or queue_max_size." - ) - logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") - - def submit_put_task(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int = None) -> None: - """Submit a put task to the task queue""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - if layer_idx is not None: - rdma_rollout_key = f"{rollout_id}_{layer_idx}" - else: - rdma_rollout_key = rollout_id - - routing_indices_np = routing_indices.numpy() - - task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices_np} - - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") - logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") - - def submit_clear_store_task(self) -> None: - """Submit clear store task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} - - try: - self._task_queue.put_nowait(task) - # Wait for the task to be processed - self._task_queue.join() - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") - logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") - - def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: - """Submit clear prefix batch task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - prefix_batch_id = self.get_needed_clear_ids(rollout_id) - if prefix_batch_id is None: - return - start_time = time.perf_counter() - if layer_idx is not None: - rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" - else: - rdma_rollout_key = prefix_batch_id - - task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") - logger.info( - f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" - ) - - def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: - """ - Generate the prefix IDs for all closed multi-round tasks. - rollout_id: "xxx_xxx_epoch_15:2:2:1" - example: xxx_xxx_data_id:gen_id:turn_id:segment_id - """ - reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = roullout_id[::-1].split(":", 2) - prefix_gen_id = reversed_prefix_gen_id[::-1] - turn_id = eval(reversed_turn_id[::-1]) - segment_id = eval(reversed_segment_id[::-1]) - - assert turn_id >= 0 and segment_id >= 0 - prefix_batch = None - if turn_id > 0: - prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" - return prefix_batch - - -class StoreTask(TypedDict): - task_type: str - key: str - data: np.ndarray - - -class StoreProcess(Process): - def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: - super().__init__() - self.max_model_len = max_model_len - self._task_queue = task_queue - self.routing_replay_config = routing_replay_config - self.max_workers = 5 - self._closed = False - - # Note: _routing_store and _event_loop_thread must be initialized in run() - # because they cannot be properly inherited after fork() - self._routing_store = None - self._event_loop_thread = None - - def run(self): - logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") - - # Initialize routing store in subprocess - self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) - - # Initialize event loop thread in subprocess - self._event_loop_thread = AsyncEventLoopThread() - self._event_loop_thread.start() - if not self._event_loop_thread._started_event.wait(timeout=5.0): - raise RuntimeError("Failed to start async event loop thread in subprocess") - - clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) - self._task_queue.put_nowait(clear_store_task) - - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - while not self._closed: - try: - task = self._task_queue.get() - if task is None: # Sentinel - self._task_queue.task_done() - break - - if task["task_type"] == "put": - future = executor.submit(self.process_put_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_store": - future = executor.submit(self.process_clear_store_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_prefix_batch": - future = executor.submit(self.process_clear_prefix_batch_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - except Exception as e: - self._task_queue.task_done() - raise RuntimeError(f"Error during processing task. {e}") - - logger.info("RoutingReplay Consumer Process Shutdown.") - - def process_put_task(self, store_task: StoreTask) -> None: - try: - # TODO(gongshaotian): delete this after trainer support dynamic len - store_task["data"] = self.pad_routing_indices(store_task["data"]) - coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting put task: {e}") - traceback.print_exc() - raise - - def process_clear_store_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_store() - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error during processing clear store task. {e}") - traceback.print_exc() - raise - - def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting clear_prefix_batch task: {e}") - traceback.print_exc() - raise - - def _on_async_task_completed(self, task, future): - """ """ - try: - # result = future.result() - logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") - except Exception as e: - logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") - traceback.print_exc() - raise - - def close(self): - """Close the store process""" - self._closed = True - if hasattr(self, "_event_loop_thread"): - self._event_loop_thread.stop() - - def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: - """Pad routing indices of the request levevl to max model len""" - routing_shape = routing_indices.shape - if len(routing_shape) == 2: # [token, topk] - pad_array = np.full( - shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], - fill_value=-1, - dtype=routing_indices.dtype, - ) - return np.concatenate([routing_indices, pad_array], axis=0) - - elif len(routing_shape) == 3: # [layer, token, topk] - pad_array = np.full( - shape=[ - routing_indices.shape[0], - (self.max_model_len - routing_indices.shape[1]), - routing_indices.shape[2], - ], - fill_value=-1, - dtype=routing_indices.dtype, - ) - return np.concatenate([routing_indices, pad_array], axis=1) - else: - raise ValueError(f"Invalid routing indices shape: {routing_shape}") - - -class AsyncEventLoopThread(threading.Thread): - def __init__(self): - super().__init__(daemon=True) - self._loop = None - self._started_event = threading.Event() - self._closed = False - - def run(self): - """Run the async event loop""" - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) + def clear(self): + """Clear GPU buffer and pending positions. Used during RL round cleanup.""" + self.gpu_routing_buffer.fill_(-1) + self.pending_update_positions = None - # Set the event loop to be started - self._started_event.set() - logger.info("[EventLoopThread] Event loop started, running forever...") - try: - self._loop.run_forever() - logger.info("[EventLoopThread] Event loop stopped") - except Exception as e: - logger.error(f"[EventLoopThread] Event loop exception: {e}") - traceback.print_exc() - finally: - logger.info("[EventLoopThread] Closing event loop") - self._loop.close() - - def submit_coroutine(self, coro, callback=None): - """Thread safely submit coroutine to event loop""" - if self._closed: - raise RuntimeError("Event loop thread is closed") - if not self._started_event.wait(timeout=5.0): - raise RuntimeError("Event loop failed to start within 5 seconds") - - future = asyncio.run_coroutine_threadsafe(coro, self._loop) - - if callback: - - def wrapped_callback(f): - try: - callback(f) - except Exception as e: - logger.error(f"Error in callback: {e}") - traceback.print_exc() - - future.add_done_callback(wrapped_callback) - return future - - def stop(self): - """Stop the event loop""" - if not self._closed: - self._closed = True - if self._loop: - self._loop.call_soon_threadsafe(self._loop.stop) - - -class RoutingStoreBase(ABC): - """Base class for routing store""" - - def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: - self.routing_replay_config = routing_replay_config - - @abstractmethod - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - raise NotImplementedError - - @abstractmethod - async def clear_store( - self, - ): - """Clear the routing indices store""" - raise NotImplementedError - - @abstractmethod - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError - - -class RoutingStoreLocal(RoutingStoreBase): - """Routing Store using local memory""" - - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - self.local_store_dir = routing_replay_config.local_store_dir - os.makedirs(self.local_store_dir, exist_ok=True) - - async def put( - self, - routing_key: str, - routing_indices: np.ndarray, - ) -> None: - """Put the routing indices into store""" - # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor - time_before_put = time.perf_counter() - - if len(routing_indices.shape) == 2: - re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) - rollout_id = re_rollout_id[::-1] - layer_id = re_layer_id[::-1] - request_path = os.path.join(self.local_store_dir, rollout_id) - file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") - elif len(routing_indices.shape) == 3: - request_path = os.path.join(self.local_store_dir, routing_key) - file_path = os.path.join(request_path, f"{routing_key}.pdtensor") - else: - raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") - - paddle.save(routing_indices, file_path) - logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") - - async def clear_store(self): - """Clear the routing indices store""" - if os.path.isdir(self.local_store_dir): - shutil.rmtree(self.local_store_dir) - - logger.info("[R3] Clear routing store.") - - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError - - -class RoutingStoreRDMA(RoutingStoreBase): - """Routing Store using RDMA""" - - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - try: - # Only used in RLHF - from p2pstore import P2PClient, P2PConfig - except ModuleNotFoundError: - raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") - - rdma_store_server = routing_replay_config.rdma_store_server - p2pConfig = P2PConfig(metadata_server=rdma_store_server) - self.p2p_client = P2PClient(p2pConfig) - - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - time_before_put = time.perf_counter() - if len(routing_indices.shape) == 3: - # NOTE(gongshaotian) Fused put with bytes data - routing_bytes = routing_indices.tobytes() - result = await self.p2p_client.put(routing_key, routing_bytes) - else: - result = await self.p2p_client.put(routing_key, routing_indices) - logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") - return result - - async def clear_prefix_batch(self, routing_prefix_key: str): - time_before_clear = time.perf_counter() - result = await self.p2p_client.delete_batch([routing_prefix_key]) - logger.info( - f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" - ) - return result - - async def clear_store(self): - """Clear the routing indices store""" - time_before_clear = time.perf_counter() - result = await self.p2p_client.clear() - logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") - return result - - -def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: - if routing_replay_config.routing_store_type == "local": - return RoutingStoreLocal(routing_replay_config=routing_replay_config) - elif routing_replay_config.routing_store_type == "rdma": - return RoutingStoreRDMA(routing_replay_config=routing_replay_config) - else: - raise ValueError( - f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " - "Valid types are: 'local', 'rdma'" - ) +# Backward compatibility alias +RoutingReplayManager = RoutedExpertsCapturer diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index fb720affb58..0aa50f4ef68 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -339,18 +339,18 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + # GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( positions=routing_replay_manager.pending_update_positions ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)) - context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder - routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): @@ -521,27 +521,18 @@ def post_process_specualate( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + # GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( positions=routing_replay_manager.pending_update_positions ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - last_accept_token = paddle.full_like(model_output.accept_tokens, -1) - col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype) - mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1) - last_accept_token[mask] = model_output.accept_tokens[mask] - eos_tokens_flat = model_output.eos_token_id.flatten() - isin_mask = paddle.isin(last_accept_token, eos_tokens_flat) - finished_batch_ids = isin_mask.any(axis=-1) - context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder - routing_replay_manager.put_finished_batch( - finished_batch_ids=finished_batch_ids, - seq_lens_decoder=context_lens, - ) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx # into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index a8682374ed6..45e60e8d656 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -140,6 +140,65 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.health_lock = threading.Lock() self.engine_output_token_hang = False + # Routing replay: attach to SharedMemory routing_host_buffer (lazy init after profiling) + self.routing_host_view = None + self._routing_host_view_init_attempted = False + self.routing_cache_manager = None # Set by Engine after profiling for local/rdma store dispatch + + def _init_routing_host_view(self): + """Attach to SharedMemory routing_host_buffer created by Engine. Called lazily.""" + self._routing_host_view_init_attempted = True + if not self.cfg.routing_replay_config.enable_routing_replay: + return + try: + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + ) + + rrc = self.cfg.routing_replay_config + cache_config = self.cfg.cache_config + + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + num_gpu_blocks = cache_config.total_block_num + max_num_kv_tokens = num_gpu_blocks * cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name) + self._routing_block_size = cache_config.block_size + llm_logger.info(f"[R3] TokenProcessor attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + llm_logger.warning("[R3] RoutingHostBuffer SharedMemory not found, routing gather disabled.") + except Exception as e: + llm_logger.warning(f"[R3] Failed to attach to RoutingHostBuffer: {e}") + + def _gather_routing_for_finished_request(self, task, seq_len: int): + """ + Gather complete routing data for a finished request from routing_host_buffer. + + Args: + task: Request task with block_tables + seq_len: Total sequence length + + Returns: + numpy array [seq_len, num_moe_layers, top_k] or None + """ + if self.routing_host_view is None and not self._routing_host_view_init_attempted: + self._init_routing_host_view() + if self.routing_host_view is None: + return None + + import math + + block_size = self._routing_block_size + block_ids = task.block_tables[: math.ceil(seq_len / block_size)] + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + + return self.routing_host_view.gather(slot_mapping) + def healthy(self): """ whether token processor is healthy @@ -274,6 +333,7 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: self._compute_speculative_status() if not is_prefill: self._record_completion_metrics(task, current_time) + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, batch_id, task, result, is_prefill) break return result @@ -337,6 +397,7 @@ def _process_batch_output_use_zmq(self, receive_datas): prompt_token_ids=task.prompt_token_ids, outputs=PoolingOutput(data=pooler_output), ) + self._finalize_routing(task_id, task, result, False) self._recycle_resources(task_id, i, task, result, False) batch_result.append(result) else: @@ -523,6 +584,47 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3): except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") + def _finalize_routing(self, task_id, task, result, is_prefill=False): + """ + Gather routing data before blocks are freed. + Must be called before _recycle_resources so that block_tables are still valid. + + - PD P node (is_prefill=True): gather prefill-only routing, attach to result for sending to D. + - Non-PD / D node (result.finished): gather full routing (prompt + output), + either attach to result ("response" mode) or dispatch to store ("local"/"rdma" mode). + """ + if not self.cfg.routing_replay_config.enable_routing_replay: + return + if result is None: + return + + try: + if is_prefill: + if result.error_code == 200: + seq_len = task.prompt_token_ids_len + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif result.finished: + store_type = self.cfg.routing_replay_config.routing_store_type + seq_len = ( + task.prompt_token_ids_len + len(task.output_token_ids) + if hasattr(task, "output_token_ids") + else task.prompt_token_ids_len + ) + if store_type == "response": + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif self.routing_cache_manager is not None: + self.routing_cache_manager.on_request_finished( + request_id=task_id, + block_table=task.block_tables, + seq_len=seq_len, + ) + except Exception as e: + llm_logger.warning(f"[R3] Failed to finalize routing for {task_id}: {e}") + def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False): """ recycle resources @@ -981,6 +1083,7 @@ def _process_batch_output(self): self.resource_manager.cache_output_tokens( task ) # when enable prefix caching, cache kv cache for output tokens + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, i, task, result, is_prefill) llm_logger.info(f"eos token {task_id} Recycle end.") break @@ -1102,6 +1205,7 @@ def clear_data(self): ), ) is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" + self._finalize_routing(task.request_id, task, result, is_prefill) self._recycle_resources(task.request_id, i, task, result, is_prefill) llm_logger.warning(f"clear data for task {task.request_id}") diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 5e09d2a5d07..6dee2af9008 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -912,11 +912,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = prompt_token_ids = request.prompt_token_ids self.proposer.start_request(idx, request.request_id, prompt_token_ids) - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - # 1.prefix task(need regist) 2. chunkend task(not need regist) - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token @@ -961,10 +956,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1323,11 +1314,10 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None - if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() - num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0] + gpu_routing_buffer = None + if self.routing_replay_manager is not None: + gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1354,7 +1344,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + gpu_routing_buffer=gpu_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -2196,7 +2186,7 @@ def _preprocess( if self.fd_config.routing_replay_config.enable_routing_replay: self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions( seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time_buffer"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], ) # Update state of logits processor @@ -2772,7 +2762,7 @@ def clear_requests(self): # Routing Replay if self.routing_replay_manager: - self.routing_replay_manager.clear_all_request() + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" @@ -2790,10 +2780,6 @@ def update_parameters(self, pid): # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() - # Rollout Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - # TODO(gongshaotian): Delete suspend func - self.routing_replay_manager.update_suspend_routing_replay() # Send single self.dynamic_weight_manager.finalize_update(pid) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index d72538ba8d7..a5fe3547e64 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -51,9 +51,6 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) -from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - RoutingReplayManager, -) from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata @@ -203,8 +200,6 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config) self.zmq_client = None self.async_output_queue = None @@ -786,11 +781,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.forward_batch_reqs_list[idx] = request has_prefill_task = True - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - if prefill_start_index == 0: - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token @@ -822,10 +812,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1239,9 +1225,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None + gpu_routing_buffer = None if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() + gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1268,7 +1254,8 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + routing_replay_table=None, + gpu_routing_buffer=gpu_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -1790,8 +1777,8 @@ def _dummy_run( # only need to capture prefill break - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_routing_table() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() @sot_warmup_guard(True) def capture_model(self) -> None: @@ -2302,7 +2289,7 @@ def _postprocess( and self.share_inputs["is_block_step"].sum() == 0 and self.share_inputs["is_chunk_step"].sum() == 0 ): - self.routing_replay_manager.put_table_to_store() + pass # Routing store submission now handled by RoutingCacheManager on Engine side return model_output_data, sampler_output, post_process_done def _save_model_output( @@ -2530,8 +2517,8 @@ def clear_requests(self): self.prompt_logprobs_reqs.clear() self.in_progress_prompt_logprobs.clear() self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.put_table_to_store() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 76419eba8cd..f3699d4fc3b 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -65,6 +65,7 @@ class Args: kvcache_storage_backend = None write_policy = "write_through" model_path = "test_model" + routing_replay_config = MagicMock(enable_routing_replay=False) # ========================== diff --git a/tests/e2e/utils/rollout_routing_replay_test_utils.py b/tests/e2e/utils/rollout_routing_replay_test_utils.py index 74af852a292..86d853d845b 100644 --- a/tests/e2e/utils/rollout_routing_replay_test_utils.py +++ b/tests/e2e/utils/rollout_routing_replay_test_utils.py @@ -156,11 +156,9 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode cur_save_routing_path = f"./R3_tmp/routing_replay_output_{model_name}/" model_path = os.getenv("MODEL_PATH") if model_path: - baseline_path = os.path.join( - model_path, f"R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}" - ) + baseline_path = os.path.join(model_path, f"R3_BaseLine_uint8_0424/routing_replay_output_baseline_{model_name}") else: - baseline_path = f"./R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}" + baseline_path = f"./R3_BaseLine_uint8_0424/routing_replay_output_baseline_{model_name}" stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream") nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream") diff --git a/tests/engine/test_engine.py b/tests/engine/test_engine.py index 762db4ea4ed..17de3b32bc2 100644 --- a/tests/engine/test_engine.py +++ b/tests/engine/test_engine.py @@ -68,6 +68,7 @@ def test_stop_profile_returns_true_on_success(self): parallel_config=types.SimpleNamespace(device_ids="0"), scheduler_config=types.SimpleNamespace(splitwise_role="decode"), cache_config=Mock(enable_prefix_caching=False, reset=Mock()), + routing_replay_config=types.SimpleNamespace(enable_routing_replay=False), ) eng.engine = types.SimpleNamespace( start_cache_service=lambda *_: None, diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 9398e07d9f5..b47344470de 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -65,6 +65,7 @@ class CacheConfig: model_config = ModelConfig() scheduler_config = SchedulerConfig() cache_config = CacheConfig() + routing_replay_config = MagicMock(enable_routing_replay=False) class MockTask: diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 6a692f24cc4..0fd4d1753ee 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -64,6 +64,7 @@ def __init__( ) self.max_num_seqs = max_num_seqs self.splitwise_version = "v1" + self.routing_replay_config = types.SimpleNamespace(enable_routing_replay=False) class _DummyResourceManager: From e0cad0f3f329bf80c47ac0ed6cfc27ceb3f9e463 Mon Sep 17 00:00:00 2001 From: huicongyao Date: Mon, 27 Apr 2026 20:07:30 +0800 Subject: [PATCH 067/143] [Cherry-Pick][Speculative Decoding][BugFix] overlap compute logprobs for speculative decoding (#7406) (#7585) * [Speculative Decoding] [BugFix] overlap compute logprobs for speculative decoding (#7406) * fix shape mismatch while cuda graph closed * fix * fix xpu typo * overlap compute logprobs * fix * optimize * fix * opt * fix unitest error and optimize code * fix --- custom_ops/gpu_ops/cpp_extensions.cc | 9 + .../build_sampling_params_logprob.cu | 129 +++++++++ .../model_executor/layers/sample/logprobs.py | 18 +- .../model_executor/layers/sample/sampler.py | 60 ++-- .../model_executor/pre_and_post_process.py | 6 +- .../xpu_pre_and_post_process.py | 2 +- fastdeploy/spec_decode/mtp.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 20 +- fastdeploy/worker/xpu_model_runner.py | 4 +- tests/layers/test_sampler.py | 3 +- tests/layers/test_speculative_sampler.py | 10 +- .../test_build_sampling_params_logprob.py | 269 ++++++++++++++++++ 12 files changed, 479 insertions(+), 53 deletions(-) create mode 100644 custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu create mode 100644 tests/operators/test_build_sampling_params_logprob.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 8e9cf6a3ddc..92478f71d75 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -783,6 +783,11 @@ std::vector BuildSamplingParams( const int64_t token_num_output_cpu, const int64_t increment_value); +std::vector BuildSamplingParamLogProb( + const paddle::Tensor& input_params, + const paddle::Tensor& token_num_per_batch, + int64_t token_num_output_cpu); + void SpecTokenPenaltyMultiScores( const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, @@ -1771,6 +1776,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &BuildSamplingParams, "build_sampling_params function"); + m.def("build_sampling_params_logprob", + &BuildSamplingParamLogProb, + "build_sampling_params_logprob function"); + m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); diff --git a/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu new file mode 100644 index 00000000000..790ba551485 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu @@ -0,0 +1,129 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +__global__ void BuildSamplingParamLogProbKernel( + T* output_params, + const T* input_params, + const int32_t* token_num_per_batch, + const int64_t token_num_output_cpu) { + const int bi = blockIdx.x; + const int tid = threadIdx.x; + + // Compute start offset: sum of token_num_per_batch[0..bi-1] + int start_offset = 0; + for (int i = 0; i < bi; i++) { + start_offset += token_num_per_batch[i]; + } + + int cur_token_num = token_num_per_batch[bi]; + + if (cur_token_num <= 0) { + return; + } + + // Read per-batch param into register + T val = input_params[bi]; + + // Fill output_params with bounds check against total output size + for (int i = tid; i < cur_token_num; i += blockDim.x) { + int64_t idx = static_cast(start_offset) + i; + if (idx < token_num_output_cpu) { + output_params[idx] = val; + } + } +} + +std::vector BuildSamplingParamLogProb( + const paddle::Tensor& input_params, + const paddle::Tensor& token_num_per_batch, + const int64_t token_num_output_cpu) { + auto cu_stream = input_params.stream(); + // Initialize output to safe defaults for use as divisors: + // int32/float32 -> 1, bool -> false + paddle::Tensor output_params; + switch (input_params.dtype()) { + case paddle::DataType::BOOL: + output_params = paddle::full({token_num_output_cpu}, + false, + input_params.dtype(), + input_params.place()); + break; + case paddle::DataType::INT32: + output_params = paddle::full({token_num_output_cpu}, + 1, + input_params.dtype(), + input_params.place()); + break; + case paddle::DataType::FLOAT32: + output_params = paddle::full({token_num_output_cpu}, + 1.0f, + input_params.dtype(), + input_params.place()); + break; + default: + PD_THROW( + "Unsupported data type for BuildSamplingParamLogProb. " + "Only bool, int32, float32 are supported."); + } + + int32_t num_blocks = token_num_per_batch.shape()[0]; + switch (input_params.dtype()) { + case paddle::DataType::BOOL: { + BuildSamplingParamLogProbKernel<<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + case paddle::DataType::INT32: { + BuildSamplingParamLogProbKernel + <<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + case paddle::DataType::FLOAT32: { + BuildSamplingParamLogProbKernel<<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + default: { + PD_THROW( + "Unsupported data type for BuildSamplingParamLogProb. " + "Only bool, int32, float32 are supported."); + } + } + + return {output_params}; +} + +PD_BUILD_STATIC_OP(build_sampling_params_logprob) + .Inputs({"input_params", "token_num_per_batch"}) + .Outputs({"output_params"}) + .Attrs({"token_num_output_cpu: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSamplingParamLogProb)); diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 80ccfc2fdd9..ac9e0edacbf 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -144,7 +144,8 @@ def build_output_logprobs( is_naive: bool = False, logprobs_mode: str = "default", compute_logprobs_fn: Optional[Callable] = None, -) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: + real_bsz: int = 0, +) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: """ Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. @@ -170,12 +171,13 @@ def build_output_logprobs( logprobs_tensors = None cu_batch_token_offset = None - real_bsz = share_inputs["seq_lens_this_time"].shape[0] + # NOTE(huicongyao) real_bsz is passed from _postprocess, remove this in future + max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] if is_naive: # NAIVE mode: one token per request, logits are already correct output_logits = logits - token_ids = share_inputs["accept_tokens"][:real_bsz, 0] + token_ids = share_inputs["accept_tokens"][:max_occupied_slots, 0] else: # Speculative mode: extract target logits for accepted positions from fastdeploy.model_executor.layers.sample.ops import ( @@ -183,8 +185,8 @@ def build_output_logprobs( ) batch_token_num = paddle.where( - share_inputs["seq_lens_encoder"][:real_bsz] != 0, - paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), + share_inputs["seq_lens_encoder"][:max_occupied_slots] != 0, + paddle.ones_like(share_inputs["seq_lens_encoder"][:max_occupied_slots]), share_inputs["seq_lens_this_time"], ).flatten() @@ -194,12 +196,12 @@ def build_output_logprobs( "int32" ) cu_batch_token_offset = paddle.concat( - [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] + [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:max_occupied_slots])] ).astype("int32") share_inputs["cu_batch_token_offset"] = cu_batch_token_offset output_logits = paddle.empty( - [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], + [share_inputs["accept_num"][:max_occupied_slots].sum(), logits.shape[1]], dtype=logits.dtype, ) speculate_get_target_logits( @@ -222,7 +224,7 @@ def build_output_logprobs( # Compute logprobs with temperature scaling and top_p normalization if logprobs_mode == "raw_logprobs": - raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) + raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata, real_bsz) elif logprobs_mode == "raw_logits": raw_logprobs = output_logits.clone() else: diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 7892b4d73ad..7157fc8d755 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -54,6 +54,7 @@ if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( build_sampling_params, + build_sampling_params_logprob, naive_update_model_status, ) @@ -833,6 +834,7 @@ def __init__(self, fd_config: FDConfig): self.spec_method = spec_config.method self.verify_strategy = spec_config.verify_strategy self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop + self.num_speculative_tokens = spec_config.num_speculative_tokens # Accept policy from config (can be overridden by function parameters) self.config_accept_all = spec_config.accept_policy == "accept_all" @@ -858,55 +860,57 @@ def compute_logprobs( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, + real_bsz: int = 0, ) -> paddle.Tensor: """compute logprobs""" share_inputs = sampling_metadata.share_inputs last_logits = logits - real_bsz = share_inputs["seq_lens_this_time"].shape[0] - batch_token_num = share_inputs["accept_num"][:real_bsz] + + # NOTE(huicongyao): temporarily used to provide a max_sized input, remove in the future + num_tokens = real_bsz * (self.num_speculative_tokens + 1) + padded_logits = paddle.zeros(shape=[num_tokens, last_logits.shape[1]], dtype=last_logits.dtype) + padded_logits[: logits.shape[0]] = last_logits + max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] + + batch_token_num = share_inputs["accept_num"][:max_occupied_slots] temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs if temp_scaled_logprobs is not None: - real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] - temperature = sampling_metadata.temperature[:real_bsz] - real_bsz_temp_scaled = ( - real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") - ) - temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) + real_bsz_temp_scaled = temp_scaled_logprobs[:max_occupied_slots] + temperature = sampling_metadata.temperature[:max_occupied_slots] + real_bsz_temp_scaled = build_sampling_params_logprob(real_bsz_temp_scaled, batch_token_num, num_tokens) + temperature = build_sampling_params_logprob(temperature, batch_token_num, num_tokens) temp_temperature = paddle.where( real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) ).unsqueeze(1) - last_logits = last_logits / temp_temperature + padded_logits = padded_logits / temp_temperature - last_logprobs = F.log_softmax(last_logits, axis=-1) + last_logprobs = F.log_softmax(padded_logits, axis=-1) top_p_logprob = None top_p_token_mask = None - if ( top_p_normalized_logprobs is not None and share_inputs is not None and sampling_metadata.top_p_normalized_logprobs_flag ): - real_token_top_p = ( - sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) - ) - top_p_normalized_logprobs = ( - top_p_normalized_logprobs[:real_bsz] - .astype("int32") - .squeeze(1) - .repeat_interleave(batch_token_num) - .astype("bool") - .unsqueeze(1) - ) + real_token_top_p = build_sampling_params_logprob( + sampling_metadata.top_p[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens + ).unsqueeze(1) + top_p_normalized_logprobs = build_sampling_params_logprob( + top_p_normalized_logprobs[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens + ).unsqueeze(1) top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) - if top_p_token_mask.any(): - probs = F.softmax(last_logits, axis=-1) - probs = top_p_normalize_probs_paddle(probs, real_token_top_p) - top_p_logprob = paddle.log(probs) + + probs = F.softmax(padded_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) if top_p_logprob is not None: last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) - return last_logprobs + + # NOTE(huicongyao) temporarily used for slice last_logprobs to its real shape, remove in the future + real_token_num = batch_token_num.sum().item() + return last_logprobs[:real_token_num] def gather_logprobs( self, @@ -1136,6 +1140,7 @@ def forward_cuda( increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, + real_bsz: int = 0, ) -> SamplerOutput: """ Forward pass for speculative sampling. @@ -1229,6 +1234,7 @@ def forward_cuda( is_naive=is_naive, logprobs_mode=self.logprobs_mode, compute_logprobs_fn=self.compute_logprobs, + real_bsz=real_bsz, ) sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0aa50f4ef68..6de08ae0b70 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -478,7 +478,7 @@ def save_output_normal( share_inputs["last_preempted_idx"][:] = 0 -def post_process_specualate( +def post_process_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, @@ -561,7 +561,7 @@ def post_process_specualate( # so that async D2H of logz_per_batch has more time to complete. -def save_output_specualate( +def save_output_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, @@ -755,7 +755,7 @@ def post_process( ) else: if speculative_decoding: - post_process_specualate( + post_process_speculate( sampler_or_pooler_output, model_output, share_inputs, diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 9e32ea34876..e5a1d9419c8 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -381,7 +381,7 @@ def xpu_post_process_normal( share_inputs["preempted_idx"][:] = 0 -def xpu_post_process_specualate( +def xpu_post_process_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: Dict[str, paddle.Tensor], diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index acf7bee27a5..0c1681e12a0 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -882,7 +882,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() else: if substep == 0: - token_num_cpu = real_bsz * (self.max_draft_token_num + 1) + token_num_cpu = self.model_inputs["target_hidden_states"].shape[0] else: token_num_cpu = real_bsz if token_num_cpu > 0: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6dee2af9008..8b0c1468acb 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -97,7 +97,7 @@ pre_process, rebuild_padding, save_output_normal, - save_output_specualate, + save_output_speculate, ) from fastdeploy.output.pooler import PoolerOutput from fastdeploy.worker.model_runner_base import ( @@ -137,8 +137,8 @@ def __init__( if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs ) - self.temp_scaled_logprobs = True - self.top_p_normalized_logprobs = True + self.temp_scaled_logprobs = False + self.top_p_normalized_logprobs = False self.prompt_logprobs_reqs: dict[str, Request] = {} self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)] @@ -1177,7 +1177,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p for req in self.forward_batch_reqs_list if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None ] - if len(logprobs_reqs): + self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs) + self.top_p_normalized_logprobs = any( + req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs + ) + if logprobs_reqs: self.max_logprobs = ( max( [ @@ -1188,10 +1192,6 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p if not self.speculative_decoding else 20 ) - self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs) - self.top_p_normalized_logprobs = any( - req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs - ) elif self.enable_logprob: self.max_logprobs = None if not self.speculative_decoding else 0 @@ -1731,6 +1731,7 @@ def _dummy_sampler_run( self.increment_value, accept_all_drafts, reject_all_drafts, + real_bsz=batch_size, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -2360,6 +2361,7 @@ def _postprocess( self.share_inputs, real_output_token_num, self.increment_value, + real_bsz=real_bsz, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -2512,7 +2514,7 @@ def _save_model_output( sampler_output, ): if self.speculative_decoding: - save_output_specualate( + save_output_speculate( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index bd585519520..ba250215345 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -58,7 +58,7 @@ from fastdeploy.model_executor.xpu_pre_and_post_process import ( step_xpu, xpu_post_process_normal, - xpu_post_process_specualate, + xpu_post_process_speculate, xpu_pre_process, xpu_process_output, ) @@ -1635,7 +1635,7 @@ class at the server level, which is too granular for ModelRunner. if self.speculative_decoding: # base model post process - xpu_post_process_specualate( + xpu_post_process_speculate( sampler_output, model_output_data, self.share_inputs, diff --git a/tests/layers/test_sampler.py b/tests/layers/test_sampler.py index cdc58eb33d1..9fde61f48cc 100644 --- a/tests/layers/test_sampler.py +++ b/tests/layers/test_sampler.py @@ -310,6 +310,7 @@ def test_speculative_sampler_basic(monkeypatch): enf_gen_phase_tag=False, verify_strategy="topp", accept_policy="normal", + num_speculative_tokens=1, ), parallel_config=types.SimpleNamespace(prefill_one_step_stop=False), ) @@ -327,7 +328,7 @@ def test_speculative_sampler_basic(monkeypatch): m.top_p_normalized_logprobs_flag = True m.share_inputs = { "seq_lens_this_time": paddle.to_tensor([[1]], dtype="int64"), - "accept_num": paddle.to_tensor([1], dtype="int64"), + "accept_num": paddle.to_tensor([1], dtype="int32"), } gathered = sampler.gather_logprobs(sampler.compute_logprobs(logits, m), 0, paddle.to_tensor([1], dtype="int64")) assert gathered.logprob_token_ids.shape[1] == 1 diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index 227e73db5a8..6321e6f08f3 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -221,7 +221,15 @@ def test_speculative_sampler_logprobs(): for logprobs_mode in logprobs_mode_list: fd_config.model_config.logprobs_mode = logprobs_mode sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) + sampler( + logits, + sampling_metadata, + max_model_len, + share_inputs, + token_num_output_cpu, + increment_value, + real_bsz=batch_size, + ) def test_mtp_sampler(): diff --git a/tests/operators/test_build_sampling_params_logprob.py b/tests/operators/test_build_sampling_params_logprob.py new file mode 100644 index 00000000000..9eb5e8e3052 --- /dev/null +++ b/tests/operators/test_build_sampling_params_logprob.py @@ -0,0 +1,269 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +# --- Import ops (bypass fastdeploy.__init__) --- +try: + import os + import sys + + _fd_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if _fd_root not in sys.path: + sys.path.insert(0, _fd_root) + from fastdeploy.import_ops import import_custom_ops + + _package = "fastdeploy.model_executor.ops.gpu" + import_custom_ops(_package, ".fastdeploy_ops", globals()) +except ImportError as e: + print(f"Import error: {e}") + raise + +CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers -- tensor creation / kernel invocation / output extraction +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict -> paddle tensors on GPU. Scalar attrs are passed through.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs, inputs): + """Call build_sampling_params_logprob with paddle tensors + scalar attrs.""" + return build_sampling_params_logprob( # noqa: F821 + paddle_inputs["input_params"], + paddle_inputs["token_num_per_batch"], + inputs["token_num_output_cpu"], + ) + + +def get_outputs(result) -> Dict[str, np.ndarray]: + """Extract output tensor to numpy.""" + return {"output_params": result.numpy()} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + + +def gen_inputs( + real_bsz=8, + max_tokens_per_batch=5, + dtype=np.float32, + seed=42, +) -> Dict[str, Any]: + """Generate randomized test inputs. + + Args: + real_bsz: number of batch items + max_tokens_per_batch: max token count per batch item + dtype: numpy dtype for input_params (np.float32, np.int32, np.bool_) + seed: random seed + """ + rng = np.random.default_rng(seed) + + # Random token counts per batch, allow zeros (empty slots) + token_num_per_batch = rng.integers(0, max_tokens_per_batch + 1, size=real_bsz).astype(np.int32) + token_num_output_cpu = int(token_num_per_batch.sum()) + + # Generate per-batch param values + if dtype == np.float32: + input_params = rng.uniform(0.0, 1.0, size=real_bsz).astype(np.float32) + elif dtype == np.int32: + input_params = rng.integers(0, 100, size=real_bsz).astype(np.int32) + elif dtype == np.bool_: + input_params = rng.choice([False, True], size=real_bsz) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + return { + "input_params": input_params, + "token_num_per_batch": token_num_per_batch, + "token_num_output_cpu": token_num_output_cpu, + } + + +# ============================================================ +# Layer 3: Reference implementation (pure Python/NumPy) +# ============================================================ + + +def reference_build_sampling_params_logprob(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference -- must match CUDA kernel logic exactly. + + Kernel logic: + 1. Initialize output with safe defaults (bool->False, int32->1, float32->1.0) + 2. For each batch bi, fill output[start_offset..start_offset+cur_token_num-1] + with input_params[bi], where start_offset = sum(token_num_per_batch[0..bi-1]) + """ + input_params = inputs["input_params"].copy() + token_num_per_batch = inputs["token_num_per_batch"].copy() + token_num_output_cpu = inputs["token_num_output_cpu"] + real_bsz = len(input_params) + dtype = input_params.dtype + + # Initialize output with safe defaults (matching kernel behavior) + if dtype == np.bool_: + output_params = np.full(token_num_output_cpu, False, dtype=dtype) + elif dtype == np.int32: + output_params = np.full(token_num_output_cpu, 1, dtype=dtype) + elif dtype == np.float32: + output_params = np.full(token_num_output_cpu, 1.0, dtype=dtype) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for bi in range(real_bsz): + start_offset = int(token_num_per_batch[:bi].sum()) + cur_token_num = int(token_num_per_batch[bi]) + if cur_token_num <= 0: + continue + val = input_params[bi] + for i in range(cur_token_num): + idx = start_offset + i + if idx < token_num_output_cpu: + output_params[idx] = val + + return {"output_params": output_params} + + +# ============================================================ +# Layer 4a: TEST_CONFIGS -- all pure-parameter test scenarios +# ============================================================ + +TEST_CONFIGS = [ + # --- basic coverage, float32 --- + {"name": "float32_small_batch", "real_bsz": 2, "max_tokens_per_batch": 3, "dtype": np.float32, "seed": 42}, + {"name": "float32_medium_batch", "real_bsz": 16, "max_tokens_per_batch": 8, "dtype": np.float32, "seed": 42}, + {"name": "float32_large_batch", "real_bsz": 64, "max_tokens_per_batch": 16, "dtype": np.float32, "seed": 42}, + # --- int32 dtype --- + {"name": "int32_small_batch", "real_bsz": 4, "max_tokens_per_batch": 5, "dtype": np.int32, "seed": 42}, + {"name": "int32_large_batch", "real_bsz": 32, "max_tokens_per_batch": 10, "dtype": np.int32, "seed": 42}, + # --- bool dtype --- + {"name": "bool_small_batch", "real_bsz": 4, "max_tokens_per_batch": 5, "dtype": np.bool_, "seed": 42}, + {"name": "bool_large_batch", "real_bsz": 32, "max_tokens_per_batch": 10, "dtype": np.bool_, "seed": 42}, + # --- edge cases --- + {"name": "single_batch_single_token", "real_bsz": 1, "max_tokens_per_batch": 1, "dtype": np.float32, "seed": 42}, + {"name": "single_batch_many_tokens", "real_bsz": 1, "max_tokens_per_batch": 64, "dtype": np.float32, "seed": 42}, + {"name": "many_batch_one_token", "real_bsz": 64, "max_tokens_per_batch": 1, "dtype": np.float32, "seed": 42}, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + + +class TestBuildSamplingParamLogprob(unittest.TestCase): + + # ------ shared helpers ------ + + def _run_and_get(self, inputs): + paddle_inputs = to_paddle_inputs(inputs) + result = run_kernel(paddle_inputs, inputs) + return get_outputs(result) + + def _check_all_outputs(self, inputs, outputs): + """Compare ALL output tensors against reference.""" + ref = reference_build_sampling_params_logprob(inputs) + np.testing.assert_array_equal(outputs["output_params"], ref["output_params"], err_msg="output_params mismatch") + + def _run_full_test(self, config): + inputs = gen_inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest (one subTest per config).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + def test_all_zero_token_counts(self): + """All batch items have zero tokens -- output should be empty array.""" + inputs = gen_inputs(real_bsz=4, max_tokens_per_batch=1, dtype=np.float32, seed=42) + # Force all token counts to zero + inputs["token_num_per_batch"] = np.zeros(4, dtype=np.int32) + inputs["token_num_output_cpu"] = 0 + outputs = self._run_and_get(inputs) + self.assertEqual(outputs["output_params"].size, 0) + + def test_exact_golden_float32(self): + """Exact golden values for float32 -- hand-verified.""" + inputs = { + "input_params": np.array([0.5, 0.9, 0.1], dtype=np.float32), + "token_num_per_batch": np.array([2, 3, 1], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([0.5, 0.5, 0.9, 0.9, 0.9, 0.1], dtype=np.float32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_exact_golden_int32(self): + """Exact golden values for int32 -- hand-verified.""" + inputs = { + "input_params": np.array([10, 20, 30], dtype=np.int32), + "token_num_per_batch": np.array([1, 2, 3], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([10, 20, 20, 30, 30, 30], dtype=np.int32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_exact_golden_bool(self): + """Exact golden values for bool -- hand-verified.""" + inputs = { + "input_params": np.array([True, False, True], dtype=np.bool_), + "token_num_per_batch": np.array([3, 2, 1], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([True, True, True, False, False, True], dtype=np.bool_) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_mixed_with_empty_slots(self): + """Some batch items have zero tokens (empty slots).""" + inputs = { + "input_params": np.array([0.5, 0.9, 0.1, 0.7], dtype=np.float32), + "token_num_per_batch": np.array([2, 0, 3, 0], dtype=np.int32), + "token_num_output_cpu": 5, + } + outputs = self._run_and_get(inputs) + # bi=0: tokens 0,1 -> 0.5; bi=1: empty; bi=2: tokens 2,3,4 -> 0.1; bi=3: empty + expected = np.array([0.5, 0.5, 0.1, 0.1, 0.1], dtype=np.float32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + +if __name__ == "__main__": + unittest.main() From eee8289626a2be6568248fa2fbb6c11495820ea5 Mon Sep 17 00:00:00 2001 From: ChowMingSing <610208940@qq.com> Date: Mon, 27 Apr 2026 21:42:49 +0800 Subject: [PATCH 068/143] [Bugfix]compile support SM100 (#7581) (#7629) * [Bugfix]compile support SM100 * [Bugfix]compile support SM100 --- custom_ops/setup_ops.py | 105 ++++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 52 deletions(-) diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 180116bf2c7..b9a2fe90dbc 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -470,58 +470,59 @@ def find_end_files(directory, end_str): # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS. os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py") - if cc >= 90: # Hopper and newer - # SM90 (Hopper) specific auto-generation and flags - if cc == 90: # Only for SM90 - nvcc_compile_args += [ - # The gencode for 90a is added in get_gencode_flags now - # "-gencode", - # "arch=compute_90a,code=compute_90a", - "-O3", - "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a - ] - print("SM90: Running SM90-specific FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") - - nvcc_compile_args += [ - "-DENABLE_SCALED_MM_SM90=1", - ] - sources += [ - "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", - "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", - ] - elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics - print("SM100 (Blackwell): Applying SM100 configurations.") - nvcc_compile_args += [ - # The gencode for 100a is added in get_gencode_flags - # "-gencode", - # "arch=compute_100a,code=compute_100a", - "-O3", # Common optimization flag - "-DNDEBUG", # Common debug flag - # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified - ] - # Placeholder for SM100-specific kernel auto-generation scripts - # These might be needed if Blackwell has new FP8 hardware features - # not covered by existing generic CUTLASS templates or SM90 scripts. - # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") - # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example - # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example - - # Add SM100 specific sources if any, e.g., for new hardware intrinsics - # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example - pass # No SM100 specific sources identified yet beyond what CUTLASS handles - else: # For cc >= 89 but not 90 or 100 (e.g. SM89) - print(f"SM{cc}: Running generic FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") - - else: # For cc == 89 (Ada) - print("SM89: Running generic FP8 kernel auto-generation.") + # Use non-exclusive checks against sm_versions so that building for + # multiple architectures (e.g. [80,90,100]) compiles kernels for ALL + # of them instead of only the highest one. + has_sm90 = 90 in sm_versions + has_sm100 = 100 in sm_versions and nvcc_version >= 12.9 + has_generic_fp8 = not has_sm90 and not has_sm100 # SM89 or other + + if has_sm90 or has_sm100: + nvcc_compile_args += [ + "-O3", + "-DNDEBUG", + ] + + if has_sm90: + print("SM90: Running SM90-specific FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") + + nvcc_compile_args += [ + "-DENABLE_SCALED_MM_SM90=1", + ] + sources += [ + "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", + "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", + ] + + if has_sm100: + print("SM100 (Blackwell): Applying SM100 configurations.") + # Placeholder for SM100-specific kernel auto-generation scripts + # These might be needed if Blackwell has new FP8 hardware features + # not covered by existing generic CUTLASS templates or SM90 scripts. + # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") + # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example + # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example + + # Add SM100 specific sources if any, e.g., for new hardware intrinsics + # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example + pass # No SM100 specific sources identified yet beyond what CUTLASS handles + + if has_generic_fp8: + # For SM89 (Ada) or other architectures without dedicated paths + print(f"SM{cc}: Running generic FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") + + if not has_sm90 and cc >= 90: + # When cc >= 90 but SM90 is not in the target list (e.g. only [80,100]), + # still run generic FP8 auto-generation for non-SM90 paths. + print(f"SM{cc}: Running generic FP8 kernel auto-generation (no SM90 target).") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") From 99444f636a6c32840dad268754e51c36a910b743 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 28 Apr 2026 10:48:38 +0800 Subject: [PATCH 069/143] fix fp8 infer error (#7627) (#7631) Co-authored-by: JYChen --- .../model_executor/layers/quantization/block_wise_fp8.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index 19f4597ab34..97f72e026fc 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -155,7 +155,7 @@ def deep_gemm_fp8_gemm_nt( linear_out, ) if bias is not None: - linear_out = paddle.add(linear_out, bias) + linear_out.add_(bias) else: fp8_gemm_nt( (x, x_scale_tensor), @@ -370,7 +370,7 @@ def apply(self, layer, x): ) x_scale_tensor = x_scale_tensor.T[: x.shape[0], ...] - if get_sm_version() == 100 and current_platform.is_cuda(): + if get_sm_version() >= 100 and current_platform.is_cuda(): deep_gemm_fp8_gemm_nt( x, x_scale_tensor, @@ -391,5 +391,4 @@ def apply(self, layer, x): ) if layer.with_bias: linear_out = paddle.add(linear_out, layer.bias) - return linear_out From 23e0a84314ea9081c089dd267aff400e4e0cdc5e Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 28 Apr 2026 11:04:37 +0800 Subject: [PATCH 070/143] [Cherry-Pick][CI] Pin Paddle to release/3.3 last_commit build in 2.6(#7547) (#7571) * [Cherry-Pick][CI] Pin PaddlePaddle to release/3.3 last_commit build(#7547) * [Cherry-Pick][BugFix] Fix deep gemm import(#7425) * [CI][BugFix] fix test_vl_prefix_caching_swap with pyarrow==24.0.0 --- .github/workflows/_accuracy_test.yml | 2 +- .github/workflows/_base_test.yml | 2 +- .github/workflows/_build_linux.yml | 2 +- .github/workflows/_build_linux_cu129.yml | 2 +- .github/workflows/_build_linux_cu130.yml | 2 +- .github/workflows/_build_linux_rl.yml | 2 +- .github/workflows/_golang_router_test.yml | 2 +- .github/workflows/_gpu_4cards_case_test.yml | 2 +- .github/workflows/_logprob_test_linux.yml | 2 +- .github/workflows/_pre_ce_test.yml | 2 +- .github/workflows/_stable_test.yml | 2 +- .github/workflows/_unit_test_coverage.yml | 2 +- fastdeploy/model_executor/models/deepseek_v3.py | 4 +--- scripts/run_pre_ce.sh | 3 +++ 14 files changed, 16 insertions(+), 15 deletions(-) diff --git a/.github/workflows/_accuracy_test.yml b/.github/workflows/_accuracy_test.yml index f06e3c88669..e2c8d40dbfe 100644 --- a/.github/workflows/_accuracy_test.yml +++ b/.github/workflows/_accuracy_test.yml @@ -180,7 +180,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 99bf7209747..7087183e447 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -213,7 +213,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml index bd8f9f257c2..5865a3cc7fd 100644 --- a/.github/workflows/_build_linux.yml +++ b/.github/workflows/_build_linux.yml @@ -196,7 +196,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ else - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_cu129.yml b/.github/workflows/_build_linux_cu129.yml index 9800795b1b2..aabf5bb16a9 100644 --- a/.github/workflows/_build_linux_cu129.yml +++ b/.github/workflows/_build_linux_cu129.yml @@ -183,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ else - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu129/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_cu130.yml b/.github/workflows/_build_linux_cu130.yml index 85593eada9a..a294c3557e4 100644 --- a/.github/workflows/_build_linux_cu130.yml +++ b/.github/workflows/_build_linux_cu130.yml @@ -183,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu130/ else - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu130/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda130-Cudnn913-Trt1013-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu130/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml index 9c3f2a47966..fb3a85a5685 100644 --- a/.github/workflows/_build_linux_rl.yml +++ b/.github/workflows/_build_linux_rl.yml @@ -166,7 +166,7 @@ jobs: cd FastDeploy python -m pip uninstall paddlepaddle-gpu -y || true - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu129/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_golang_router_test.yml b/.github/workflows/_golang_router_test.yml index 93c794482f6..62810d527a0 100644 --- a/.github/workflows/_golang_router_test.yml +++ b/.github/workflows/_golang_router_test.yml @@ -212,7 +212,7 @@ jobs: git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt diff --git a/.github/workflows/_gpu_4cards_case_test.yml b/.github/workflows/_gpu_4cards_case_test.yml index be580a08dd9..c4b771c15fd 100644 --- a/.github/workflows/_gpu_4cards_case_test.yml +++ b/.github/workflows/_gpu_4cards_case_test.yml @@ -208,7 +208,7 @@ jobs: cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml index b0ebec7d791..5ccd0be40fa 100644 --- a/.github/workflows/_logprob_test_linux.yml +++ b/.github/workflows/_logprob_test_linux.yml @@ -189,7 +189,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml index 8420c388ac3..0669d503d80 100644 --- a/.github/workflows/_pre_ce_test.yml +++ b/.github/workflows/_pre_ce_test.yml @@ -201,7 +201,7 @@ jobs: --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ python -m pip install ${fd_wheel_url} bash scripts/run_pre_ce.sh ' diff --git a/.github/workflows/_stable_test.yml b/.github/workflows/_stable_test.yml index fc89dfa6cac..8678490f9d7 100644 --- a/.github/workflows/_stable_test.yml +++ b/.github/workflows/_stable_test.yml @@ -193,7 +193,7 @@ jobs: -e TZ="Asia/Shanghai" \ -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index 9f06d2f1ccb..0a939c49583 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -224,7 +224,7 @@ jobs: git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install paddlepaddle-gpu==3.5.0.dev20260417 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index aa3f3af346e..3270bbf4308 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -75,8 +75,6 @@ radix_topk_ragged_transform, ) - paddle.enable_compat(scope={"deep_gemm"}) - class DeepSeekV3MLP(nn.Layer): """ @@ -665,7 +663,7 @@ def forward( # indexer write_cache indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt) - import deep_gemm + from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm if forward_meta.max_len_tensor_cpu[1]: diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 4bf56a290aa..069a2e938a0 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -13,6 +13,9 @@ python -m pip install \ https://paddle-qa.bj.bcebos.com/FastDeploy/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \ https://paddle-qa.bj.bcebos.com/FastDeploy/xgrammar-0.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +# fix tests/ci_use/Prefix_Caching_Swap/test_vl_prefix_caching_swap.py (requires new pyarrow memory behavior) +python -m pip install pyarrow==24.0.0 + failed_files=() run_path="$DIR/../tests/ci_use/" From 5582b5af1e1ee82b68aa109ecfff5a46f92081d2 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 28 Apr 2026 17:15:54 +0800 Subject: [PATCH 071/143] [BugFix][Speculative Decoding] Fix tokens_per_seq min value calculation for non-MTP speculative methods (#7623) (#7640) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [BugFix][Speculative Decoding] Fix tokens_per_seq calculation to only apply speculative tokens for MTP method ## Motivation 非 MTP 的 speculative 方法(如 NGRAM)也会有 num_speculative_tokens,但 tokens_per_seq 原来只判断 speculative_config is not None,导致非 MTP 场景下 max_num_batched_tokens 最小值校验偏大,错误拦截合法配置。 ## Modifications - fastdeploy/config.py: check() 方法中 tokens_per_seq 计算,增加 method == SpecMethod.MTP 条件判断 - fastdeploy/engine/sched/resource_manager_v1.py: 同上,保持两处逻辑一致 * [BugFix][Speculative Decoding] Fix tokens_per_seq: use method is not None instead of MTP check Co-authored-by: kevin --- fastdeploy/config.py | 2 +- fastdeploy/engine/sched/resource_manager_v1.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 3ee93724f28..8f16f789eab 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2271,7 +2271,7 @@ def check(self): ), f"max_num_seqs: {self.scheduler_config.max_num_seqs} should be larger than 1" tokens_per_seq = ( (getattr(self.speculative_config, "num_speculative_tokens", 0) + 1) - if self.speculative_config is not None + if self.speculative_config is not None and self.speculative_config.method is not None else 1 ) assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs * tokens_per_seq, ( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 27a1328c1ee..38b4eb14381 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -776,7 +776,7 @@ def get_enough_request(request, scheduled_reqs): error_reqs: list[tuple[str, str]] = [] tokens_per_seq = ( (self.config.speculative_config.num_speculative_tokens + 1) - if self.config.speculative_config is not None + if self.config.speculative_config is not None and self.config.speculative_config.method is not None else 1 ) num_running_decode_reqs = sum(1 for req in self.running if self._is_decoding(req)) From ecb31fb126c940637a8186d3d4d69662adef8807 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Tue, 28 Apr 2026 19:40:04 +0800 Subject: [PATCH 072/143] [KVCache] Support flush FD GPU/CPU Cache index by AttentionStore (#7644) (cherry picked from commit 6ead1397cea8c167cdac99094366275db48e4a17) --- fastdeploy/cache_manager/cache_tasks.py | 7 +++- .../cache_manager/cache_transfer_manager.py | 34 +++++++++++++++- .../cache_manager/prefix_cache_manager.py | 29 ++++++++++++-- fastdeploy/envs.py | 2 + .../test_prefix_cache_manager.py | 40 +++++++++++++++++++ 5 files changed, 107 insertions(+), 5 deletions(-) diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py index fe15263827a..d50c809c0c6 100644 --- a/fastdeploy/cache_manager/cache_tasks.py +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -15,7 +15,7 @@ """ from dataclasses import dataclass -from typing import List +from typing import List, Optional @dataclass(frozen=True, kw_only=True) @@ -35,3 +35,8 @@ class ReadStorageTask(CacheTask): @dataclass(frozen=True, kw_only=True) class WriteStorageTask(CacheTask): timeout: float = 30.0 + # Used in FD_AS_ONLY_FLUSH mode to indicate whether cache is present on this node. + # True = cache exists (request finish), False = cache gone (CPU eviction), None = not applicable. + flush_cache_exists: Optional[bool] = None + # Block index to start the write/flush operation from. Defaults to 0 (all blocks). + start_write_block_idx: int = 0 diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index f5e39fe0c0f..bfa0aca16d5 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -526,7 +526,8 @@ def _init_storage(self, args): try: # TODO: support cache scale for other backend if self.has_cache_scale and self.storage_backend_type is not None: - if self.storage_backend_type not in ["mooncake"]: + is_as_only_flush = envs.FD_AS_ONLY_FLUSH and self.storage_backend_type == "attention_store" + if not is_as_only_flush and self.storage_backend_type not in ["mooncake"]: raise ValueError( f"Unsupported storage backend ({self.storage_backend_type}) " "when cache quantization is block_wise_fp8" @@ -1153,6 +1154,34 @@ def _run_write_back_storage( ) return 0 + def _flush_only_storage_task(self, task: WriteStorageTask): + """ + AS-only flush mode: skip actual storage write, only report cache index to AttentionStore. + Used when FD_AS_ONLY_FLUSH is enabled — AS acts as index-only (no data storage). + + Args: + task: WriteStorageTask with flush_cache_exists indicating cache state: + True/None = cache present on this node (request finish) + False = cache gone from this node (eviction) + """ + try: + if (self.rank == 0) and self.storage_backend_type == "attention_store": + reside_in_gpu = task.flush_cache_exists if task.flush_cache_exists is not None else True + self.storage_backend.flush_token_index( + task.task_id, task.token_ids, task.start_write_block_idx, reside_in_gpu + ) + logger.info( + f"[AS_ONLY_FLUSH] flush token index reside_in_gpu={reside_in_gpu} " + f"start_block_idx={task.start_write_block_idx} for task {task.task_id}" + ) + except Exception as e: + logger.warning(f"[AS_ONLY_FLUSH] Failed to flush token index for task {task.task_id}, error: {e}") + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, []) + self.cache_task_queue.swap_to_storage_barrier.wait() + if self.rank == 0: + self.cache_task_queue.swap_to_storage_barrier.reset() + self.cache_task_queue.put_transfer_done_signal(result) + def write_back_storage_task(self, task: WriteStorageTask): """ Write cache to the storage backend from the GPU memory. @@ -1161,6 +1190,9 @@ def write_back_storage_task(self, task: WriteStorageTask): self.storage_backend ), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}" + if envs.FD_AS_ONLY_FLUSH: + return self._flush_only_storage_task(task) + try: gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 328fd224f5b..537f092a5ed 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -313,6 +313,7 @@ def launch_cache_manager( + visible_devices + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + + f" FD_AS_ONLY_FLUSH={int(envs.FD_AS_ONLY_FLUSH)}" + f" {sys.executable} {py_path}" + f" --device_id {int(device_ids[i])}" + f" --rank {i}" @@ -858,7 +859,7 @@ def request_match_blocks(self, task: Request, block_size, *args): storage_match_token_num = 0 match_storage_block_ids = [] - if self.kvcache_storage_backend and no_match_token_num >= block_size: + if self.kvcache_storage_backend and no_match_token_num >= block_size and not envs.FD_AS_ONLY_FLUSH: if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num, try_free_gpu_blocks=False): raise Exception( "request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache" @@ -1270,14 +1271,15 @@ def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): if self.kvcache_storage_backend is None: return - if len(task.keys) != len(task.gpu_block_ids): + if not envs.FD_AS_ONLY_FLUSH and len(task.keys) != len(task.gpu_block_ids): err_msg = ( f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" ) logger.error(err_msg) raise ValueError(err_msg) - self.task_write_back_event[task.task_id] = Event() + if is_sync: + self.task_write_back_event[task.task_id] = Event() self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: self.wait_write_storage_task(task.task_id) @@ -1464,6 +1466,7 @@ def free_block_ids_async(self, need_block_num): hash_value_swap_node_ids_map = defaultdict(list) hash_value_gpu_block_ids_map = defaultdict(list) + hash_value_flush_info = {} # {input_hash_value: (token_ids, min_depth)} total_gpu_free_count = 0 while True: @@ -1476,6 +1479,10 @@ def free_block_ids_async(self, need_block_num): self.gpu_lru_leaf_set.remove(node) if self.cache_config.num_cpu_blocks < need_block_num: if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收 + if envs.FD_AS_ONLY_FLUSH and self.kvcache_storage_backend == "attention_store": + key = node.input_hash_value + if key not in hash_value_flush_info or node.depth < hash_value_flush_info[key][1]: + hash_value_flush_info[key] = (node.input_ids, node.depth) self._handle_free_gpu_node_without_cpu(node) total_gpu_free_count += 1 cur_node = node @@ -1525,6 +1532,22 @@ def free_block_ids_async(self, need_block_num): f"free_block_ids_async: need_block_num {need_block_num}, free_block_num {total_gpu_free_count}." ) + if ( + envs.FD_AS_ONLY_FLUSH + and self.kvcache_storage_backend == "attention_store" + and hash_value_flush_info + ): + for input_hash_value, (token_ids, min_depth) in hash_value_flush_info.items(): + flush_task = WriteStorageTask( + task_id=str(uuid.uuid4()), + keys=[input_hash_value], + token_ids=token_ids, + gpu_block_ids=[], + flush_cache_exists=False, + start_write_block_idx=min_depth - 1, + ) + self.issue_write_back_storage_task(flush_task, is_sync=False) + # swap cache to cpu if hash_value_gpu_block_ids_map: self.cpu_free_future = None diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 96bc09934a8..7e0f809d5d3 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -154,6 +154,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), # Whether to clear cpu cache when clearing model weights. "FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")), + # AS-only flush mode: AttentionStore only reports cache index without storing actual data. + "FD_AS_ONLY_FLUSH": lambda: bool(int(os.getenv("FD_AS_ONLY_FLUSH", "0"))), # enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))), # Used to truncate the string inserted during thinking when reasoning in a model. ( for ernie-45-vl, \n\n\n for ernie-x1) diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 2ed9a0e6b02..07df533f626 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -1544,6 +1544,46 @@ def test_reset_sets_empty_cpu_free_list_when_no_cpu_blocks(self): manager.reset() self.assertEqual(manager.cpu_free_block_list, []) + @patch("fastdeploy.cache_manager.prefix_cache_manager.envs") + def test_free_gpu_block_ids_flushes_cache_gone_with_as_only_flush(self, mock_envs): + """Verify GPU-only eviction sends flush(flush_cache_exists=False) with correct start_write_block_idx.""" + mock_envs.FD_AS_ONLY_FLUSH = True + manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=0) + manager.kvcache_storage_backend = "attention_store" + + gpu_hash = get_hash_str([9, 10]) + node = BlockNode( + 91, + [9, 10], + gpu_hash, + 3, + 0, + 2, + gpu_hash, + 0, + parent=manager.radix_tree_root, + cache_status=CacheStatus.GPU, + ) + node.shared_count = 0 + node.block_id = 12 + manager.radix_tree_root.children[gpu_hash] = node + manager.node_map[node.node_id] = node + manager.gpu_lru_leaf_heap.append(node) + manager.gpu_lru_leaf_set.add(node) + + captured_tasks = [] + manager.issue_write_back_storage_task = lambda task, is_sync=True: captured_tasks.append(task) + + manager.free_block_ids_async(1) + + self.assertEqual(len(captured_tasks), 1) + flush_task = captured_tasks[0] + self.assertFalse(flush_task.flush_cache_exists) + self.assertEqual(flush_task.keys, [gpu_hash]) + self.assertEqual(flush_task.token_ids, [9, 10]) + self.assertEqual(flush_task.gpu_block_ids, []) + self.assertEqual(flush_task.start_write_block_idx, 2) + if __name__ == "__main__": unittest.main() From 37672f9b3692de30cc2c1e7f5acf644330afcf81 Mon Sep 17 00:00:00 2001 From: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com> Date: Tue, 28 Apr 2026 19:40:14 +0800 Subject: [PATCH 073/143] support different AS interface for GPU and XPU (#7380) (#7647) --- .../mooncake_store/attention_store.py | 107 +++++++++++++----- 1 file changed, 77 insertions(+), 30 deletions(-) diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 466caef59e9..9b561690a6e 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -25,12 +25,16 @@ KVCacheStorage, logger, ) +from fastdeploy.platforms import current_platform try: import attentionstore_sdk.api.common.common_pb2 as common_pb2 from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens from attentionstore_sdk.utils.err import AttentionStoreSDKError + if current_platform.is_cuda(): + from attentionstore_sdk.client.client import AttentionType + _ATTENTIONSTORE_AVAILABLE = True except Exception: AttentionStoreSDK = None @@ -63,18 +67,36 @@ def __init__(self, **args): try: logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}") - self.sdk = AttentionStoreSDK( - self.config.namespace, - self.config.pod_name, - self.config.model_version, - self.config.shard_id, - self.config.shard_num, - self.config.layer_num, - self.config.block_token_size, - self.config.bytes_per_shard_layer_per_block, - self.config.device_id, - self.config.dp_id, - ) + if current_platform.is_cuda(): + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + attention_type=AttentionType.MHA, + enable_as_kv_rw=True, + gpu_count=0, + ) + else: + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + ) self.wait_for_sdk_ready(timeout=300, delta_t=5) logger.info("[INIT] ✅ AttentionStore is initialized successfully!") except Exception as e: @@ -120,15 +142,27 @@ def read( v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: - num = self.sdk.read( - list(range(self.config.layer_num)), - tokens, - start_read_block_idx, - k_data_ptrs, - v_data_ptrs, - gpu_block_ids, - timeout, - ) + if current_platform.is_cuda(): + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + remote_addrs=None, + ) + else: + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}") except AttentionStoreSDKError: logger.error( @@ -154,15 +188,28 @@ def write( v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: - num = self.sdk.write( - list(range(self.config.layer_num)), - tokens, - start_write_block_idx, - k_data_ptrs, - v_data_ptrs, - gpu_block_ids, - timeout, - ) + if current_platform.is_cuda(): + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + h2h_copy=False, + params=None, + ) + else: + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}") except AttentionStoreSDKError: logger.error( From bfff3d9b6ea4a817ce7a4df31e7da53bb16131c0 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Tue, 28 Apr 2026 19:42:34 +0800 Subject: [PATCH 074/143] [Cherry-Pick][KVCache] Support environment variable overrides for AttentionStore config(#7455) (#7643) * [KVCache] Support environment variable overrides for AttentionStore config (cherry picked from commit 3f18128e8fb7a6021dcd9b0b20963f6468de8adc) * [KVCache] Support environment variable overrides for AttentionStore config (cherry picked from commit bba23b65b94049358b3a0894dffe3c224926e962) --- fastdeploy/cache_manager/cache_transfer_manager.py | 1 + .../transfer_factory/mooncake_store/attention_store.py | 9 +++++++++ tests/cache_manager/test_cache_transfer_manager.py | 2 ++ 3 files changed, 12 insertions(+) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index bfa0aca16d5..50def0a6c36 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -555,6 +555,7 @@ def _init_storage(self, args): * self.cache_item_bytes, device_id=self.device, dp_id=self.local_data_parallel_id, + splitwise_role=getattr(args, "splitwise_role", "mixed"), ) logger.info("Initialized attention store successfully!") elif args.kvcache_storage_backend == "file": diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 9b561690a6e..a938c043422 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import os import time import traceback from dataclasses import dataclass @@ -55,6 +56,7 @@ class AttentionStoreConfig: bytes_per_shard_layer_per_block: int = 1024 device_id: int = 0 dp_id: int = 0 + splitwise_role: str = "mixed" class AttentionStore(KVCacheStorage): @@ -66,6 +68,13 @@ def __init__(self, **args): self.config = AttentionStoreConfig(**args) try: + self.config.namespace = os.getenv("AS_NAMESPACE", self.config.namespace) + self.config.pod_name = os.getenv("AS_POD_NAME", self.config.pod_name) + if int(os.getenv("ENABLE_EP_DP_IN_FD", "1")): + self.config.pod_name = ( + self.config.pod_name + "_" + self.config.splitwise_role + "_" + str(self.config.dp_id) + ) + self.config.model_version = os.getenv("AS_MODEL_VERSION", self.config.model_version) logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}") if current_platform.is_cuda(): self.sdk = AttentionStoreSDK( diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index f3699d4fc3b..afcd64574dd 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -65,6 +65,7 @@ class Args: kvcache_storage_backend = None write_policy = "write_through" model_path = "test_model" + splitwise_role = "mixed" routing_replay_config = MagicMock(enable_routing_replay=False) @@ -724,6 +725,7 @@ class LocalArgs(Args): * manager.cache_item_bytes, device_id=manager.device, dp_id=manager.local_data_parallel_id, + splitwise_role=LocalArgs.splitwise_role, ) def test_invalid_write_policy_raises(self): From 188db3566deb3bab08bd759d2e5660548428487e Mon Sep 17 00:00:00 2001 From: RAM Date: Tue, 28 Apr 2026 21:52:02 +0800 Subject: [PATCH 075/143] [RL] Correct the semantics of max_num_batched_tokens with multimodal in R3 (#7655) * Correct the semantics of max_num_batched_tokens with multi mode * fix D2H bug --- fastdeploy/config.py | 2 +- .../model_executor/layers/moe/routing_indices_cache.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 8f16f789eab..3d51eaf4e47 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2449,7 +2449,7 @@ def reset_value(cls, value_name, key): ) reset_value(self.cache_config, "cache_dtype", "infer_model_dtype") - def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): + def get_max_chunk_tokens(self, mm_max_tokens_per_item=None) -> int: """ get max chunk tokens diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 534e303c89c..21e9f406366 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -132,7 +132,8 @@ def _init_routing_cache(self, dtype: str, total_block_num: int): max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size # Small GPU transient buffer: only current step's token routing - max_num_batched_tokens = self.fd_config.scheduler_config.max_num_batched_tokens + # TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens() + max_num_batched_tokens = self.fd_config.get_max_chunk_tokens() self.gpu_routing_buffer = paddle.full( shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, @@ -218,7 +219,7 @@ def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): positions = [] for i in range(seq_lens_this_time.shape[0]): - if seq_lens_this_time[i] == 0: + if increase_num[i] == 0: positions.append([]) continue repeated_base = np.repeat(starts[i], increase_num[i]) From 0aa3e25e8e0b35c6db64f5504ecd82f887ab5a34 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Wed, 29 Apr 2026 10:36:43 +0800 Subject: [PATCH 076/143] [Cherry-Pick][RL] rl support mix_quant (#7645) (#7650) * rl support mix_quant * code check --- .../layers/moe/fused_moe_cutlass_backend.py | 4 ++-- fastdeploy/rl/rollout_config.py | 7 ++++--- fastdeploy/utils.py | 8 ++++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 96423cc6e4f..92e039dd742 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -176,7 +176,7 @@ def apply_ep_prefill( ) ) - if paddlefleet_ops is not None: + if fastdeploy.envs.FD_USE_DEEP_GEMM: out = m_grouped_bf16_gemm_nn_contiguous( permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices ) @@ -192,7 +192,7 @@ def apply_ep_prefill( else: out = paddle.incubate.nn.functional.swiglu(out) - if paddlefleet_ops is not None: + if fastdeploy.envs.FD_USE_DEEP_GEMM: ffn_out = m_grouped_bf16_gemm_nn_contiguous( out, getattr(layer, self.added_weight_attrs[1]), m_indices ) diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 0caefd9ada1..59a7822c3ff 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -14,8 +14,9 @@ # limitations under the License. """ -from typing import Any, Dict, Optional +from typing import Dict, Optional, Union +from fastdeploy.utils import parse_quantization from fastdeploy.worker.worker_process import initialize_fd_config @@ -54,7 +55,7 @@ def __init__( expert_parallel_size: int = 1, enable_expert_parallel: bool = False, ori_vocab_size: int = None, - quantization: Optional[Dict[str, Any]] = None, + quantization: Optional[Union[Dict, str]] = None, guided_decoding_backend: str = "off", disable_any_whitespace: bool = True, enable_logprob: bool = False, @@ -108,7 +109,7 @@ def __init__( self.enable_expert_parallel = enable_expert_parallel self.data_parallel_size = data_parallel_size self.ori_vocab_size = ori_vocab_size - self.quantization = quantization + self.quantization = parse_quantization(quantization) self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace self.enable_logprob = enable_logprob diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 0a591dc2777..6c0b72ae8ae 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -38,7 +38,7 @@ from importlib.metadata import PackageNotFoundError, distribution from logging.handlers import BaseRotatingHandler from pathlib import Path -from typing import Any, Literal, TypeVar, Union +from typing import Any, Dict, Literal, TypeVar, Union import numpy as np import paddle @@ -1044,10 +1044,14 @@ def status(self) -> dict: } -def parse_quantization(value: str): +def parse_quantization(value: Union[Dict, str]) -> Dict: """ Parse a JSON string into a dictionary. """ + if isinstance(value, dict): + return value + if value is None: + value = "null" try: return json.loads(value) except ValueError: From 32d5f5b329500532e9d659bf5f4932416df43c43 Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:08:08 +0800 Subject: [PATCH 077/143] Refine metrics and trace for pd (#7613) (#7661) --- docs/online_serving/metrics.md | 78 +++++++++++++++++++ docs/zh/online_serving/metrics.md | 78 +++++++++++++++++++ fastdeploy/engine/common_engine.py | 27 +++++++ .../engine/sched/resource_manager_v1.py | 14 +++- fastdeploy/metrics/metrics.py | 33 +++++++- fastdeploy/output/token_processor.py | 28 ++++--- fastdeploy/trace/constants.py | 24 ++++++ .../test_token_processor_trace_print.py | 4 +- 8 files changed, 269 insertions(+), 17 deletions(-) diff --git a/docs/online_serving/metrics.md b/docs/online_serving/metrics.md index 8e2dd286888..4f30cf3a222 100644 --- a/docs/online_serving/metrics.md +++ b/docs/online_serving/metrics.md @@ -46,3 +46,81 @@ After FastDeploy is launched, it supports continuous monitoring of the FastDeplo - Access URL: `http://localhost:8000/metrics` - Metric Type: Prometheus format + +## Trace Events + +FastDeploy outputs structured trace events to `trace.log` at key stages of request processing, useful for diagnosing per-request latency bottlenecks. Each trace log entry contains fields such as `timestamp` (milliseconds), `request_id`, `event`, and `stage`. + +### Common Events (Mixed / All Instances) + +| Stage | Event | Description | +| :---: | --- | --- | +| PREPROCESSING | `PREPROCESSING_START` | API Server begins preprocessing the request | +| PREPROCESSING | `PREPROCESSING_END` | Engine receives the request, preprocessing complete | +| SCHEDULE | `REQUEST_SCHEDULE_START` | Request enters the scheduling flow | +| SCHEDULE | `REQUEST_QUEUE_START` | Request enters the scheduling queue | +| SCHEDULE | `REQUEST_QUEUE_END` | Request dequeued from the scheduling queue | +| SCHEDULE | `RESOURCE_ALLOCATE_START` | Resource allocation begins for the request | +| SCHEDULE | `PREPARE_PREFIX_CACHE_START` | Prefix cache block matching begins | +| SCHEDULE | `PREPARE_PREFIX_CACHE_END` | Prefix cache block matching complete | +| SCHEDULE | `RESOURCE_ALLOCATE_END` | Resource allocation complete | +| SCHEDULE | `REQUEST_SCHEDULE_END` | Scheduling flow complete | +| PREFILL | `INFERENCE_START` | Request sent to GPU for inference | +| PREFILL | `FIRST_TOKEN_GENERATED` | First token generated | +| DECODE | `DECODE_START` | Enters Decode phase | +| DECODE | `INFERENCE_END` | Inference complete (all tokens generated) | +| DECODE | `PREEMPTED` | Request preempted | +| DECODE | `RESCHEDULED_INFERENCE_START` | Preempted request rescheduled for execution | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_START` | Begins writing KV Cache to external storage | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_END` | KV Cache written to external storage | +| POSTPROCESSING | `POSTPROCESSING_START` | Post-processing begins | +| POSTPROCESSING | `POSTPROCESSING_END` | Post-processing complete, response sent | + +### PD Disaggregation — Prefill (P) Instance Events + +| Stage | Event | Description | +| :---: | --- | --- | +| SCHEDULE | `ASK_DECODE_RESOURCE_START` | P begins requesting resources from D (sends ZMQ request) | +| SCHEDULE | `ASK_DECODE_RESOURCE_END` | P receives resource allocation confirmation from D (with dest_block_ids) | +| PREFILL | `PREFILL_INFERENCE_END` | P instance Prefill inference complete | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_START` | P begins waiting for KV Cache transfer to complete | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_END` | KV Cache transfer confirmed, ready to send first token to D | + +### PD Disaggregation — Decode (D) Instance Events + +| Stage | Event | Description | +| :---: | --- | --- | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_START` | D begins processing resource allocation request from P | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_END` | D completes resource allocation and returns dest_block_ids to P | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_START` | D receives first token from P, begins processing Prefilled request | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_END` | D adds Prefilled request to running queue | +| DECODE | `DECODE_INFERENCE_END` | D instance Decode inference complete | + +### Request Lifecycle Sequence + +**Mixed mode** (single instance, full inference): +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END → INFERENCE_START +→ FIRST_TOKEN_GENERATED → DECODE_START → INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` + +**PD Disaggregation — Prefill (P) Instance**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ ASK_DECODE_RESOURCE_START → ASK_DECODE_RESOURCE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END +→ INFERENCE_START → PREFILL_INFERENCE_END +→ CHECK_CACHE_TRANSFER_START → CHECK_CACHE_TRANSFER_END → [send first token to D] +``` + +**PD Disaggregation — Decode (D) Instance**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ DECODE_PROCESS_PREALLOCATE_REQUEST_START → DECODE_PROCESS_PREALLOCATE_REQUEST_END +→ [wait for P to complete prefill and transfer KV Cache] +→ DECODE_PROCESS_PREFILLED_REQUEST_START → DECODE_PROCESS_PREFILLED_REQUEST_END +→ INFERENCE_START → DECODE_INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` diff --git a/docs/zh/online_serving/metrics.md b/docs/zh/online_serving/metrics.md index 630f68e2ff8..20da957bcf2 100644 --- a/docs/zh/online_serving/metrics.md +++ b/docs/zh/online_serving/metrics.md @@ -46,3 +46,81 @@ - 访问地址:`http://localhost:8000/metrics` - 指标类型:Prometheus 格式 + +## Trace 事件 + +FastDeploy 在请求处理的关键阶段输出结构化 trace 事件到 `trace.log`,用于定位请求级别的延迟瓶颈。每条 trace 日志包含 `timestamp`(毫秒)、`request_id`、`event`、`stage` 等字段。 + +### 通用事件(Mixed / 所有实例) + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| PREPROCESSING | `PREPROCESSING_START` | API Server 开始预处理请求 | +| PREPROCESSING | `PREPROCESSING_END` | Engine 收到请求,预处理完成 | +| SCHEDULE | `REQUEST_SCHEDULE_START` | 请求进入调度流程 | +| SCHEDULE | `REQUEST_QUEUE_START` | 请求进入调度队列等待 | +| SCHEDULE | `REQUEST_QUEUE_END` | 请求从调度队列取出 | +| SCHEDULE | `RESOURCE_ALLOCATE_START` | 开始为请求分配资源 | +| SCHEDULE | `PREPARE_PREFIX_CACHE_START` | 开始匹配前缀缓存块 | +| SCHEDULE | `PREPARE_PREFIX_CACHE_END` | 前缀缓存块匹配完成 | +| SCHEDULE | `RESOURCE_ALLOCATE_END` | 资源分配完成 | +| SCHEDULE | `REQUEST_SCHEDULE_END` | 调度流程结束 | +| PREFILL | `INFERENCE_START` | 请求送入 GPU 执行推理 | +| PREFILL | `FIRST_TOKEN_GENERATED` | 首 token 生成 | +| DECODE | `DECODE_START` | 进入 Decode 阶段 | +| DECODE | `INFERENCE_END` | 推理完成(所有 token 生成完毕) | +| DECODE | `PREEMPTED` | 请求被抢占 | +| DECODE | `RESCHEDULED_INFERENCE_START` | 被抢占的请求重新调度执行 | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_START` | 开始将 KV Cache 写入外部存储 | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_END` | KV Cache 写入外部存储完成 | +| POSTPROCESSING | `POSTPROCESSING_START` | 开始后处理 | +| POSTPROCESSING | `POSTPROCESSING_END` | 后处理完成,响应发送完毕 | + +### PD 分离 — Prefill (P) 实例专属事件 + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| SCHEDULE | `ASK_DECODE_RESOURCE_START` | P 开始向 D 申请资源(发送 ZMQ 请求) | +| SCHEDULE | `ASK_DECODE_RESOURCE_END` | P 收到 D 的资源分配确认(含 dest_block_ids) | +| PREFILL | `PREFILL_INFERENCE_END` | P 实例 Prefill 推理完成 | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_START` | P 开始等待 KV Cache 传输完成 | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_END` | KV Cache 传输完成确认,准备发送 first token 到 D | + +### PD 分离 — Decode (D) 实例专属事件 + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_START` | D 开始处理 P 发来的资源分配请求 | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_END` | D 完成资源分配并返回 dest_block_ids 给 P | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_START` | D 收到 P 的 first token,开始处理 Prefilled 请求 | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_END` | D 将 Prefilled 请求加入 running queue | +| DECODE | `DECODE_INFERENCE_END` | D 实例 Decode 推理完成 | + +### 请求生命周期时序图 + +**Mixed 模式**(单实例完整推理): +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END → INFERENCE_START +→ FIRST_TOKEN_GENERATED → DECODE_START → INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` + +**PD 分离 — Prefill (P) 实例**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ ASK_DECODE_RESOURCE_START → ASK_DECODE_RESOURCE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END +→ INFERENCE_START → PREFILL_INFERENCE_END +→ CHECK_CACHE_TRANSFER_START → CHECK_CACHE_TRANSFER_END → [发送 first token 到 D] +``` + +**PD 分离 — Decode (D) 实例**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ DECODE_PROCESS_PREALLOCATE_REQUEST_START → DECODE_PROCESS_PREALLOCATE_REQUEST_END +→ [等待 P 完成 prefill 并传输 KV Cache] +→ DECODE_PROCESS_PREFILLED_REQUEST_START → DECODE_PROCESS_PREFILLED_REQUEST_END +→ INFERENCE_START → DECODE_INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 833b4830824..3121b351898 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -934,6 +934,9 @@ def _fetch_request(): self.llm_logger.debug( f"P has allocated resources and then ask D resource for request: {task.request_id}" ) + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") + ) task.metrics.ask_decode_resource_start_time = time.time() while True: self.split_connector.send_splitwise_tasks([task], task.idx) @@ -945,6 +948,11 @@ def _fetch_request(): time.sleep(0.05) else: task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, + task.request_id, + getattr(task, "user", ""), + ) break self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") else: @@ -956,6 +964,9 @@ def _fetch_request(): self.llm_logger.debug( f"P has allocated resources and then ask D resource for req_id: {task.request_id}" ) + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") + ) task.metrics.ask_decode_resource_start_time = time.time() self.split_connector.send_splitwise_tasks([task], task.idx) @@ -963,6 +974,9 @@ def _fetch_request(): # assure fetch block ids from D status, msg = self.split_connector.check_decode_allocated(task) task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "") + ) if not status: error_msg = ( f"PD Error: prefill failed to apply for resource from decode, " @@ -979,6 +993,7 @@ def _fetch_request(): ) ] ) + main_process_metrics.reschedule_req_num.inc() need_delete_tasks.append(task) continue for tmp_task in need_delete_tasks: @@ -1086,6 +1101,7 @@ def _fetch_request(): f"preallocated request. req:{task.request_id} " ) self.llm_logger.error(msg) + main_process_metrics.reschedule_req_num.inc() self.scheduler.put_results( [ RequestOutput( @@ -2066,6 +2082,7 @@ def _process_allocate_resource_requests(): processed_indices = [] for idx, task in enumerate(allocate_resource_requests): is_success = False + trace_print(LoggingEventName.DECODE_PROCESS_PREALLOCATE_REQUEST_START, task.request_id, task.user) if envs.ENABLE_V1_KVCACHE_SCHEDULER: if self.resource_manager.preallocate_resource_in_d(task): @@ -2075,6 +2092,7 @@ def _process_allocate_resource_requests(): self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}") processed_indices.append(idx) is_success = True + main_process_metrics.decode_preallocated_req_num.inc() else: if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.llm_logger.debug(f"D Resource available, processing task {task.request_id}") @@ -2094,6 +2112,11 @@ def _process_allocate_resource_requests(): break for idx in sorted(processed_indices, reverse=True): + trace_print( + LoggingEventName.DECODE_PROCESS_PREALLOCAT_REQUEST_END, + allocate_resource_requests[idx].request_id, + allocate_resource_requests[idx].user, + ) allocate_resource_requests.pop(idx) def _process_prefilled_requests(): @@ -2109,6 +2132,7 @@ def _process_prefilled_requests(): continue req_output.finished = False ready_request_outputs.append(req_output) + trace_print(LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_START, req_output.request_id, "") self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}") prefilled_request_ouputs = waiting_request_outputs @@ -2121,6 +2145,8 @@ def _process_prefilled_requests(): else: for req_output in ready_request_outputs: request_id = req_output.request_id + main_process_metrics.decode_preallocated_req_num.dec() + trace_print(LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_END, request_id, "") if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue self.llm_logger.warning(f"{request_id} need not decode after first token") @@ -2134,6 +2160,7 @@ def _process_prefilled_requests(): self.llm_logger.warning( f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource." ) + main_process_metrics.failed_recv_first_token_req_num.inc() self.resource_manager.pre_recycle_resource(request_id) if request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[request_id] diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 38b4eb14381..819022cd4ba 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1615,7 +1615,9 @@ def clear_data(self): def update_metrics(self, verbose=False): # Update metrics - num_tasks = sum([1 if task else 0 for task in self.tasks_list]) + num_requests_running = len(self.running) + num_requests_waiting = len(self.waiting) + num_requests_queuing = max(int(getattr(self, "scheduler_unhandled_request_num", 0) or 0), 0) blocks_used_by_tasks = set() for task in self.tasks_list: if task is not None: @@ -1624,10 +1626,14 @@ def update_metrics(self, verbose=False): main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) - main_process_metrics.num_requests_running.set(len(self.running)) - main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running)) + main_process_metrics.num_requests_running.set(num_requests_running) + main_process_metrics.num_requests_waiting.set(num_requests_waiting) + main_process_metrics.num_requests_queuing.set(num_requests_queuing) if verbose: - llm_logger.info(f"update metrics: running={len(self.running)}, waiting={num_tasks - len(self.running)}") + llm_logger.info( + f"update metrics: running={num_requests_running}, " + f"waiting={num_requests_waiting}, queuing={num_requests_queuing}" + ) def log_status(self): llm_logger.info( diff --git a/fastdeploy/metrics/metrics.py b/fastdeploy/metrics/metrics.py index 42fd0231bf0..0daa36ad58a 100644 --- a/fastdeploy/metrics/metrics.py +++ b/fastdeploy/metrics/metrics.py @@ -136,6 +136,7 @@ class MetricsManager: num_requests_running: "Gauge" num_requests_waiting: "Gauge" + num_requests_queuing: "Gauge" time_to_first_token: "Histogram" time_per_output_token: "Histogram" request_inference_time: "Histogram" @@ -153,7 +154,6 @@ class MetricsManager: spec_decode_num_emitted_tokens_total: "Gauge" spec_decode_draft_single_head_acceptance_rate: "list[Gauge]" - # for YIYAN Adapter prefix_cache_token_num: "Counter" prefix_gpu_cache_token_num: "Counter" prefix_cpu_cache_token_num: "Counter" @@ -192,6 +192,11 @@ class MetricsManager: request_prompt_tokens: "Histogram" request_token_ratio: "Histogram" + # for pd + decode_preallocated_req_num: "Gauge" + reschedule_req_num: "Counter" + failed_recv_first_token_req_num: "Counter" + # 定义所有指标配置 # gauge指标在多进程中,会有pid隔离,需要特殊处理,因此手动定义出来 @@ -205,7 +210,13 @@ class MetricsManager: "num_requests_waiting": { "type": Gauge, "name": "fastdeploy:num_requests_waiting", - "description": "Number of requests currently waiting", + "description": "Number of requests currently waiting in resource manager", + "kwargs": {}, + }, + "num_requests_queuing": { + "type": Gauge, + "name": "fastdeploy:num_requests_queuing", + "description": "Number of requests currently queuing in local scheduler", "kwargs": {}, }, "gpu_cache_usage_perc": { @@ -298,6 +309,12 @@ class MetricsManager: "description": "Token-level GPU prefix cache hit rate", "kwargs": {}, }, + "decode_preallocated_req_num": { + "type": Gauge, + "name": "fastdeploy:decode_preallocated_req_num", + "description": "Number of preallocated requests in decode instance", + "kwargs": {}, + }, } METRICS = { @@ -459,6 +476,18 @@ class MetricsManager: ], }, }, + "reschedule_req_num": { + "type": Counter, + "name": "fastdeploy:reschedule_req_num", + "description": "Total number of reschedule requests", + "kwargs": {}, + }, + "failed_recv_first_token_req_num": { + "type": Counter, + "name": "fastdeploy:failed_recv_first_token_req_num", + "description": "Total number of failed requests to receive the first token in decode", + "kwargs": {}, + }, } SPECULATIVE_METRICS = {} diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 45e60e8d656..61d84e0666a 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -331,8 +331,7 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status() - if not is_prefill: - self._record_completion_metrics(task, current_time) + self._record_completion_metrics(task, current_time) self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, batch_id, task, result, is_prefill) break @@ -632,6 +631,8 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False if is_prefill: start_time = time.time() result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + while True: finished_task_ids = self.engine_worker_queue.get_finished_req() if len(finished_task_ids) > 0: @@ -648,6 +649,7 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False llm_logger.info( f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) result.metrics.send_request_output_to_decode_time = time.time() self.split_connector.send_first_token(task.disaggregate_info, [result]) if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -1071,10 +1073,8 @@ def _process_batch_output(self): llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status(result) - if not is_prefill: - self._record_completion_metrics(task, current_time) + self._record_completion_metrics(task, current_time) llm_logger.info(f"task {task_id} received eos token. Recycling.") - if ( envs.ENABLE_V1_KVCACHE_SCHEDULER and self.cfg.cache_config.enable_prefix_caching @@ -1116,13 +1116,21 @@ def _record_first_token_metrics(self, task, current_time): def _record_completion_metrics(self, task, current_time): """Record metrics when request completes""" + role = self.cfg.scheduler_config.splitwise_role metrics = task.metrics - if metrics.engine_recv_first_token_time: - decode_time = current_time - metrics.engine_recv_first_token_time - main_process_metrics.request_decode_time.observe(decode_time) - trace_print(LoggingEventName.INFERENCE_END, task.request_id, getattr(task, "user", "")) + + if role in ("mixed", "decode"): + if metrics.engine_recv_first_token_time: + decode_time = current_time - metrics.engine_recv_first_token_time + main_process_metrics.request_decode_time.observe(decode_time) + trace_print(LoggingEventName.INFERENCE_END, task.request_id, getattr(task, "user", "")) + + if role == "prefill": + trace_print(LoggingEventName.PREFILL_INFERENCE_END, task.request_id, getattr(task, "user", "")) + elif role == "decode": + trace_print(LoggingEventName.DECODE_INFERENCE_END, task.request_id, getattr(task, "user", "")) + trace_print(LoggingEventName.POSTPROCESSING_START, task.request_id, getattr(task, "user", "")) - main_process_metrics.num_requests_running.dec(1) main_process_metrics.request_success_total.inc() main_process_metrics.request_inference_time.observe(current_time - metrics.inference_start_time) main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id]) diff --git a/fastdeploy/trace/constants.py b/fastdeploy/trace/constants.py index b8ffc94271a..ff481af248a 100644 --- a/fastdeploy/trace/constants.py +++ b/fastdeploy/trace/constants.py @@ -41,6 +41,20 @@ class LoggingEventName(Enum): PREEMPTED = "PREEMPTED" RESCHEDULED_INFERENCE_START = "RESCHEDULED_INFERENCE_START" + # For Prefill Instance + ASK_DECODE_RESOURCE_START = "ASK_DECODE_RESOURCE_START" + ASK_DECODE_RESOURCE_END = "ASK_DECODE_RESOURCE_END" + CHECK_CACHE_TRANSFER_START = "CHECK_CACHE_TRANSFER_START" + CHECK_CACHE_TRANSFER_END = "CHECK_CACHE_TRANSFER_END" + PREFILL_INFERENCE_END = "PREFILL_INFERENCE_END" + + # For Decode Instance + DECODE_PROCESS_PREALLOCATE_REQUEST_START = "DECODE_PROCESS_PREALLOCATE_REQUEST_START" + DECODE_PROCESS_PREALLOCAT_REQUEST_END = "DECODE_PROCESS_PREALLOCAT_REQUEST_END" + DECODE_PROCESS_PREFILLED_REQUEST_START = "DECODE_PROCESS_PREFILLED_REQUEST_START" + DECODE_PROCESS_PREFILLED_REQUEST_END = "DECODE_PROCESS_PREFILLED_REQUEST_END" + DECODE_INFERENCE_END = "DECODE_INFERENCE_END" + class StageName(Enum): """ @@ -75,4 +89,14 @@ class StageName(Enum): LoggingEventName.WRITE_CACHE_TO_STORAGE_END: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_START: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_END: StageName.POSTPROCESSING, + LoggingEventName.ASK_DECODE_RESOURCE_START: StageName.SCHEDULE, + LoggingEventName.ASK_DECODE_RESOURCE_END: StageName.SCHEDULE, + LoggingEventName.CHECK_CACHE_TRANSFER_START: StageName.POSTPROCESSING, + LoggingEventName.CHECK_CACHE_TRANSFER_END: StageName.POSTPROCESSING, + LoggingEventName.PREFILL_INFERENCE_END: StageName.PREFILL, + LoggingEventName.DECODE_PROCESS_PREALLOCATE_REQUEST_START: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREALLOCAT_REQUEST_END: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_START: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_END: StageName.DECODE, + LoggingEventName.DECODE_INFERENCE_END: StageName.DECODE, } diff --git a/tests/output/test_token_processor_trace_print.py b/tests/output/test_token_processor_trace_print.py index 018038143f3..d43183705fb 100644 --- a/tests/output/test_token_processor_trace_print.py +++ b/tests/output/test_token_processor_trace_print.py @@ -25,6 +25,7 @@ def setup_method(self): self.mock_cfg = MagicMock() self.mock_cfg.parallel_config.local_data_parallel_id = 0 self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] + self.mock_cfg.scheduler_config.splitwise_role = "decode" self.mock_cached_tokens = MagicMock() self.mock_engine_queue = MagicMock() self.mock_split_connector = MagicMock() @@ -76,9 +77,10 @@ def test_record_completion_metrics(self, caplog): with caplog.at_level(logging.INFO): self.processor._record_completion_metrics(self.task, current_time) - assert len(caplog.records) == 2 + assert len(caplog.records) == 3 assert "[request_id=test123]" in caplog.text assert "[event=INFERENCE_END]" in caplog.text + assert "[event=DECODE_INFERENCE_END]" in caplog.text assert "[event=POSTPROCESSING_START]" in caplog.text # Verify metrics are updated From 3ac8ff2fe634ee01d860221e5ba78ca0c48faee2 Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:16:19 +0800 Subject: [PATCH 078/143] Remove recode info for request when finish sending cache (#7664) --- fastdeploy/output/token_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 61d84e0666a..e654acf87a1 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -646,6 +646,7 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False f"PD Error: prefill failed to send cache to decode, " f"{task_id}, {self.prefill_result_status[task_id]}" ) + self.prefill_result_status.pop(task_id) llm_logger.info( f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" ) From f8e38f633d3cea3d5dbbbd2d0015c1b8950c20a5 Mon Sep 17 00:00:00 2001 From: qwes5s5 <45442318+qwes5s5@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:17:24 +0800 Subject: [PATCH 079/143] abort requests fix2 (#7652) --- fastdeploy/engine/common_engine.py | 1 + .../engine/sched/resource_manager_v1.py | 2 ++ fastdeploy/output/token_processor.py | 30 +++++++++++++++++++ fastdeploy/scheduler/local_scheduler.py | 7 +++-- 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 3121b351898..50017baf5de 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1683,6 +1683,7 @@ def _wait_abort_complete(self, target_req_ids, stall_timeout=1): if not remaining: self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") return + self.llm_logger.debug(f"remaining:{remaining}") current_count = len(remaining) if current_count < prev_remaining_count: diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 819022cd4ba..3382e077d60 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -288,6 +288,8 @@ def recycle_abort_task(self, request_id): del self.requests[request_id] del self.req_dict[request_id] self.to_be_aborted_req_id_set.discard(request_id) + self.waiting_abort_req_id_set.discard(request_id) + llm_logger.debug(f"request_id:{request_id} recycle end") self.update_metrics() def _trigger_abort(self, request_id, scheduled_reqs): diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e654acf87a1..6f6a8043803 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -358,6 +358,7 @@ def _process_batch_output_use_zmq(self, receive_datas): ): llm_logger.info(f"start to recycle abort request_id {task_id}") self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set and token_ids[-1] == PREEMPTED_TOKEN_ID @@ -886,6 +887,7 @@ def _process_batch_output(self): if envs.ENABLE_V1_KVCACHE_SCHEDULER: if task_id in self.resource_manager.to_be_aborted_req_id_set: self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) if task_id in self.resource_manager.to_be_rescheduled_request_id_set: self.resource_manager.reschedule_preempt_task(task_id) continue @@ -920,6 +922,7 @@ def _process_batch_output(self): and token_id == PREEMPTED_TOKEN_ID ): self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) llm_logger.info(f"sync abortion for request_id {task_id} done.") if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set @@ -1192,6 +1195,33 @@ def _record_speculative_decoding_accept_num_per_request(self, req_id, accept_num self.accept_token_num_per_head_per_request[req_id][i] += 1 self.accept_token_num_per_head[i] += 1 + def _put_abort_results(self, task): + now = time.time() + eos_token_ids = getattr(task, "eos_token_ids", [0]) + abort_metrics = copy.copy(task.metrics) + for field in ( + "arrival_time", + "inference_start_time", + "engine_recv_latest_token_time", + "engine_recv_first_token_time", + "request_start_time", + ): + if not getattr(abort_metrics, field): + setattr(abort_metrics, field, now) + result = RequestOutput( + request_id=task.request_id, + finished=True, + outputs=CompletionOutput( + index=0, + send_idx=self.tokens_counter.get(task.request_id), + token_ids=[eos_token_ids[0]], + ), + metrics=abort_metrics, + error_code=200, + error_msg="Aborted", + ) + self.cached_generated_tokens.put_results([result]) + def clear_data(self): if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.clear_data() diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index fc4a64686b5..8fca9a4690d 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -129,8 +129,11 @@ def _recycle(self, request_id: Optional[str] = None): if request_id is not None: self.requests.pop(request_id, None) self.responses.pop(request_id, None) - self.ids.pop(self.ids.index(request_id)) - self.ids_read_cursor -= 1 + idx = self.ids.index(request_id) + self.ids.pop(idx) + if idx < self.ids_read_cursor: + self.ids_read_cursor -= 1 + scheduler_logger.debug(f"request_id : {request_id} has been recycled") return if self.max_size <= 0: From 97cda572f293df4736bbc9c1f935e6afce0c00b3 Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Wed, 29 Apr 2026 16:12:05 +0800 Subject: [PATCH 080/143] [Cherry-Pick][Optimize]Compute slot_mapping and position_ids(#7313 #7367) (#7638) * [Optimization] [OP] [Models] dsk del prefill mask (#7313) * dsk del prefill mask * dsk support 1M+ seq_len rope * update rope tests * [Optimization][DeepSeekV3.2]Reducing slot_mapping compute frequency from twice per layer to a single pre-processing step. (#7367) [Cherry-Pick][Optimize]Compute slot_mapping and position_ids --------- Co-authored-by: AIbin <37361953+chang-wenbin@users.noreply.github.com> --- custom_ops/gpu_ops/cpp_extensions.cc | 10 +- .../gpu_ops/fused_rotary_position_encoding.cu | 28 +++--- ...get_position_ids_and_mask_encoder_batch.cu | 32 +++--- .../gpu_ops/merge_prefill_decode_output.cu | 77 +++++++++------ fastdeploy/model_executor/forward_meta.py | 2 + .../layers/attention/dsa_attention_backend.py | 36 +------ .../model_executor/models/deepseek_v3.py | 99 ++++--------------- fastdeploy/worker/gpu_model_runner.py | 38 +++++++ fastdeploy/worker/input_batch.py | 5 + tests/distributed/chunked_moe.py | 1 + .../test_fused_rotary_position_encoding.py | 13 ++- ...get_position_ids_and_mask_encoder_batch.py | 16 +-- .../test_reorder_split_prefill_and_decode.py | 1 + 13 files changed, 157 insertions(+), 201 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 92478f71d75..34f2ffd9a55 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -540,12 +540,10 @@ std::vector count_tokens_per_expert_func( const paddle::Tensor& topk_ids, int64_t num_experts, bool compute_padded_cumsum = false); -void GetPositionIdsAndMaskEncoderBatch( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids, - const paddle::Tensor& mask_encoder_batch); +void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids); std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_nope, diff --git a/custom_ops/gpu_ops/fused_rotary_position_encoding.cu b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu index 05c94b60b74..8ec225d5c29 100644 --- a/custom_ops/gpu_ops/fused_rotary_position_encoding.cu +++ b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu @@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel( const int64_t key_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; + const int head_size, + const int num_tokens) { // 新增 num_tokens 参数用于边界检查 + + // 用2D grid表示token_idx,突破65535限制 + const int token_idx = blockIdx.x + blockIdx.y * gridDim.x; + if (token_idx >= num_tokens) return; // 边界保护 + int pos = position_ids[token_idx]; const T* cache_ptr = cos_sin_cache + pos * rot_dim; @@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding( int64_t query_stride = num_heads * head_size; int64_t key_stride = num_kv_heads * head_size; - if (num_tokens > 65535) { - PD_THROW( - "apply_rotary_embedding_kernel launch failed when num_tokens > 65535."); - } - - dim3 grid(num_tokens); + // 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024 + constexpr int MAX_GRID_X = 65535; + int grid_x = std::min(num_tokens, MAX_GRID_X); + int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X; + dim3 grid(grid_x, grid_y); dim3 block(std::min(num_heads * rot_dim / 2, 512)); + PD_DISPATCH_FLOATING_AND_HALF_TYPES( query.dtype(), "apply_rotary_embedding_kernel", [&] { if (is_neox) { @@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding( key_stride, num_heads, num_kv_heads, - head_size); + head_size, + num_tokens); } else { apply_rotary_embedding_kernel <<>>(query.data(), @@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding( key_stride, num_heads, num_kv_heads, - head_size); + head_size, + num_tokens); } }); } diff --git a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu index 946c9754072..63bc77c9afc 100644 --- a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu +++ b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu @@ -20,8 +20,7 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度 const int* seq_lens_this_time, int* position_ids, // 输出的一维 position_ids - int* mask_encoder_batch, - const int bsz) { // 批次大小 + const int bsz) { // 批次大小 // 当前线程索引(每个线程对应一个批次) int tid = threadIdx.x; if (tid >= bsz) return; @@ -43,7 +42,6 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( // 写入 encoder 的 position_ids for (int i = 0; i < encoder_len; i++) { position_ids[offset + i] = i; - mask_encoder_batch[offset + i] = 1; } offset += encoder_len; @@ -51,17 +49,14 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( if (decoder_len > 0) { for (int i = 0; i < seq_len_this_time; i++) { position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身 - mask_encoder_batch[offset + i] = 0; } } } -void GetPositionIdsAndMaskEncoderBatch( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids, - const paddle::Tensor& mask_encoder_batch) { +void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids) { const int bsz = seq_lens_this_time.shape()[0]; GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>( @@ -69,17 +64,16 @@ void GetPositionIdsAndMaskEncoderBatch( seq_lens_decoder.data(), seq_lens_this_time.data(), const_cast(position_ids.data()), - const_cast(mask_encoder_batch.data()), bsz); } PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch) - .Inputs({"seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", - "position_ids", - "mask_encoder_batch"}) - .Outputs({"position_ids_out", "mask_encoder_batch_out"}) - .SetInplaceMap({{"position_ids", "position_ids_out"}, - {"mask_encoder_batch", "mask_encoder_batch_out"}}) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "position_ids", + }) + .Outputs({"position_ids_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}}) .SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch)); diff --git a/custom_ops/gpu_ops/merge_prefill_decode_output.cu b/custom_ops/gpu_ops/merge_prefill_decode_output.cu index a57d72bdf3a..b158ca7f0e8 100644 --- a/custom_ops/gpu_ops/merge_prefill_decode_output.cu +++ b/custom_ops/gpu_ops/merge_prefill_decode_output.cu @@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data, return; } - const int load_idx = - ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4; + const int base_idx = + ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim; - *reinterpret_cast(encoder_res_data + load_idx) = - *reinterpret_cast(decoder_res_data + load_idx); + if (head_dim == 128) { + const int load_idx = base_idx + land_id * 4; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + } else if (head_dim == 192) { + const int load_idx = base_idx + land_id * 4; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + if (land_id < 16) { + *reinterpret_cast(encoder_res_data + load_idx + 128) = + *reinterpret_cast(decoder_res_data + load_idx + 128); + } + } else if (head_dim == 256) { + // float4 = 单条LDG.128,性能最优 + const int load_idx = base_idx + land_id * 8; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + } } +#define LAUNCH_KERNEL(T, WARPS) \ + FillEncoderDecoderResKernel \ + <<>>( \ + const_cast(encoder_res.data()), \ + const_cast(decoder_res.data()), \ + seq_lens_encoder.data(), \ + seq_lens_decoder.data(), \ + seq_lens_this_time.data(), \ + cu_seq_q.data(), \ + head_num, \ + head_dim) + +#define LAUNCH_KERNEL_BY_HEAD_DIM(T) \ + if (head_dim == 128) \ + LAUNCH_KERNEL(T, 4); \ + else if (head_dim == 192) \ + LAUNCH_KERNEL(T, 6); \ + else if (head_dim == 256) \ + LAUNCH_KERNEL(T, 8) + void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res, const paddle::Tensor &decoder_res, const paddle::Tensor &seq_lens_encoder, @@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res, const int head_num, const int head_dim, const int max_token) { - if (head_dim != 128) { - PD_THROW("Only supported head_dim = 128"); + if (head_dim != 128 && head_dim != 192 && head_dim != 256) { + PD_THROW("Only supported head_dim = 128, 192 or 256"); } const int batch_size = seq_lens_encoder.shape()[0]; - constexpr int warps = 4; + const int warps = head_dim / 32; const int tokens_block = (max_token + warps - 1) / warps; - dim3 grid_dims; - grid_dims.x = batch_size; - grid_dims.y = head_num; - grid_dims.z = tokens_block; + dim3 grid_dims(batch_size, head_num, tokens_block); if (encoder_res.dtype() == paddle::DataType::FLOAT16) { using T = phi::dtype::float16; - FillEncoderDecoderResKernel - <<>>( - const_cast(encoder_res.data()), - const_cast(decoder_res.data()), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seq_q.data(), - head_num, - head_dim); + LAUNCH_KERNEL_BY_HEAD_DIM(T); } else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) { using T = phi::dtype::bfloat16; - FillEncoderDecoderResKernel - <<>>( - const_cast(encoder_res.data()), - const_cast(decoder_res.data()), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seq_q.data(), - head_num, - head_dim); + LAUNCH_KERNEL_BY_HEAD_DIM(T); } } diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 64555bcbdfe..03a2734b41d 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -159,6 +159,8 @@ class ForwardMeta: exist_prefill: bool = False position_ids: Optional[paddle.Tensor] = None + # for kvcache slot + slot_mapping: Optional[paddle.Tensor] = None real_bsz: int = 0 diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index 66a92a52599..acb73f5420a 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -54,33 +54,6 @@ def yarn_get_mscale(scale=1, mscale=1): return 0.1 * mscale * math.log(scale) + 1.0 -def compute_slot_mapping( - block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req] - positions: paddle.Tensor, # [num_tokens] 每个token的位置 - batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求 - block_size: int, -) -> paddle.Tensor: - """ - 计算 slot_mapping - - 公式: slot = block_id * block_size + offset_in_block - """ - # 1. 计算每个 token 对应的 block 索引 - block_idx = positions // block_size # [num_tokens] - - # 2. 从 block_tables 中查表获取 block_id - # block_tables[batch_id_per_token, block_idx] - block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens] - - # 3. 计算在 block 内的偏移 - block_offset = positions % block_size # [num_tokens] - - # 4. 计算 slot_mapping - slot_mapping = block_ids * block_size + block_offset - - return slot_mapping.cast(paddle.int64) - - @dataclass class DSAAttentionMetadata(AttentionMetadata): """ @@ -347,18 +320,11 @@ def forward_mixed( k_range = paddle.tensor(200.0) scale = paddle.abs(compressed_kv).max() / k_range - slot_mapping = compute_slot_mapping( - forward_meta.block_tables, - forward_meta.position_ids, - forward_meta.batch_id_per_token, - self.block_size, - ) - dsk_attn_write_cache( compressed_kv, k_pe, latent_cache, - slot_mapping, + forward_meta.slot_mapping, scale.cast(paddle.float32), "fp8_ds_mla", ) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 3270bbf4308..457846874ee 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + per_token_group_quant_fp8, +) from fastdeploy.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, ) @@ -58,20 +61,11 @@ ) from fastdeploy.platforms import current_platform -if current_platform.is_cuda() or current_platform.is_maca(): - from fastdeploy.model_executor.ops.gpu import ( - get_position_ids_and_mask_encoder_batch, - ) - -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - per_token_group_quant_fp8, -) -from fastdeploy.platforms import current_platform - if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( cp_gather_indexer_k_quant_cache, indexer_k_quant_and_cache, + merge_prefill_decode_output, radix_topk_ragged_transform, ) @@ -343,7 +337,6 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ @@ -398,7 +391,6 @@ def forward( fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim] fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype) fmha_out = fmha_out_prefill if need_do_decode: # max_dec_len_this_time @@ -433,7 +425,17 @@ def forward( ) if need_do_prefill: - fmha_out += fmha_out_decode + merge_prefill_decode_output( + fmha_out, + fmha_out_decode, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_attention_heads_tp, + self.v_head_dim, + 1, + ) else: fmha_out = fmha_out_decode @@ -441,33 +443,6 @@ def forward( return output -def compute_slot_mapping( - block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req] - positions: paddle.Tensor, # [num_tokens] 每个token的位置 - batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求 - block_size: int, -) -> paddle.Tensor: - """ - 计算 slot_mapping - - 公式: slot = block_id * block_size + offset_in_block - """ - # 1. 计算每个 token 对应的 block 索引 - block_idx = positions // block_size # [num_tokens] - - # 2. 从 block_tables 中查表获取 block_id - # block_tables[batch_id_per_token, block_idx] - block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens] - - # 3. 计算在 block 内的偏移 - block_offset = positions % block_size # [num_tokens] - - # 4. 计算 slot_mapping - slot_mapping = block_ids * block_size + block_offset - - return slot_mapping.cast(paddle.int64) - - import triton import triton.language as tl @@ -651,17 +626,12 @@ def forward( weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5 weights = weights.squeeze(-1) - slot_mapping = compute_slot_mapping( - forward_meta.block_tables, - forward_meta.position_ids, - forward_meta.batch_id_per_token, - 64, - ) - indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32") # indexer write_cache - indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt) + indexer_k_quant_and_cache( + k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt + ) from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm @@ -925,7 +895,6 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states) @@ -1043,7 +1012,6 @@ def forward( hidden_states: paddle.Tensor, residual: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ if hidden_states.shape[0] > 0: @@ -1051,7 +1019,7 @@ def forward( hidden_states, residual_input=residual, forward_meta=forward_meta ) - hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch) + hidden_states = self.self_attn(forward_meta, hidden_states, position_ids) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) else: @@ -1107,7 +1075,6 @@ def forward( ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) @@ -1119,7 +1086,6 @@ def forward( hidden_states, residual, position_ids, - mask_encoder_batch, ) out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0] @@ -1154,12 +1120,6 @@ def __init__(self, fd_config: FDConfig): num_embeddings=fd_config.model_config.vocab_size, prefix="lm_head", ) - self.position_ids_buffer = paddle.empty( - [fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32 - ) - self.mask_encoder_batch_buffer = paddle.empty( - [fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32 - ) @classmethod def name(cls): @@ -1256,25 +1216,6 @@ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta logits[:, self.ori_vocab_size :] = -float("inf") return logits - def pre_process(self, forward_meta): - """ """ - seq_lens_encoder = forward_meta.seq_lens_encoder - seq_lens_decoder = forward_meta.seq_lens_decoder - seq_lens_this_time = forward_meta.seq_lens_this_time - - current_total_tokens = forward_meta.ids_remove_padding.shape[0] - position_ids = self.position_ids_buffer[:current_total_tokens] - mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens] - - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, - seq_lens_decoder, - seq_lens_this_time, - position_ids, - mask_encoder_batch, - ) - return position_ids, mask_encoder_batch - def empty_input_forward(self, forward_meta): """ empty_input_forward @@ -1295,12 +1236,10 @@ def forward( forward_meta: ForwardMeta, ): ids_remove_padding = inputs["ids_remove_padding"] - forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta) hidden_states = self.model( ids_remove_padding=ids_remove_padding, forward_meta=forward_meta, position_ids=forward_meta.position_ids, - mask_encoder_batch=mask_encoder_batch, ) return hidden_states diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8b0c1468acb..7a6b2e22d71 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -45,6 +45,12 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) +from fastdeploy.model_executor.layers.attention.dsa_attention_backend import ( + DSAAttentionBackend, +) +from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + MLAAttentionBackend, +) from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) @@ -78,6 +84,7 @@ speculate_schedule_cache, set_data_ipc, unset_data_ipc, + get_position_ids_and_mask_encoder_batch, ) import zmq @@ -1271,6 +1278,33 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p ) return token_num, token_num_event + def _compute_position_ids_and_slot_mapping(self) -> None: + """Compute position_ids and slot_mapping for KV cache addressing. + This is a general computation based on sequence length info and block tables, + applicable to all models that need per-token KV cache physical slot addresses. + Results are stored in self.forward_meta. + """ + # NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently. + if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)): + return + current_total_tokens = self.forward_meta.ids_remove_padding.shape[0] + position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens] + get_position_ids_and_mask_encoder_batch( + self.forward_meta.seq_lens_encoder, + self.forward_meta.seq_lens_decoder, + self.forward_meta.seq_lens_this_time, + position_ids, + ) + block_size = self.cache_config.block_size + block_idx = position_ids // block_size # [num_tokens] + assert self.forward_meta.batch_id_per_token.shape == block_idx.shape + block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx] # [num_tokens] + block_offset = position_ids % block_size # [num_tokens] + slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens] + paddle.assign((block_ids * block_size + block_offset).cast(paddle.int64), slot_mapping) + self.forward_meta.position_ids = position_ids + self.forward_meta.slot_mapping = slot_mapping + def _process_reorder(self) -> None: if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False): self.share_inputs.enable_pd_reorder = True @@ -1859,6 +1893,8 @@ def _dummy_run( # 2. Padding inputs for cuda graph self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph self.padding_cudagraph_inputs() + # Compute position_ids and slot_mapping + self._compute_position_ids_and_slot_mapping() model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] @@ -2196,6 +2232,8 @@ def _preprocess( # Padding inputs for cuda graph self.padding_cudagraph_inputs() + # Compute position_ids and slot_mapping + self._compute_position_ids_and_slot_mapping() model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 55a3f39a2ee..f47c7bccc6d 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -188,6 +188,11 @@ def init_share_inputs(self): self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") + # Initialize addressing buffers + _max_batched_tokens = self.scheduler_config.max_num_batched_tokens + self.position_ids_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int32) + self.slot_mapping_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int64) + # Declare AttentionBackend buffers self.decoder_batch_ids = None self.decoder_tile_ids_per_batch = None diff --git a/tests/distributed/chunked_moe.py b/tests/distributed/chunked_moe.py index 0fe9f9f3974..81561f5d829 100644 --- a/tests/distributed/chunked_moe.py +++ b/tests/distributed/chunked_moe.py @@ -85,6 +85,7 @@ class SchedulerConfig: name = "default" splitwise_role = "mixed" max_num_seqs = 2 + max_num_batched_tokens = 2048 parallel_config = ParallelConfig() scheduler_config = SchedulerConfig() diff --git a/tests/operators/test_fused_rotary_position_encoding.py b/tests/operators/test_fused_rotary_position_encoding.py index cbff608c7c4..8ab5fb2c2ca 100644 --- a/tests/operators/test_fused_rotary_position_encoding.py +++ b/tests/operators/test_fused_rotary_position_encoding.py @@ -116,9 +116,10 @@ def test_neox_mode(self): self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True) def test_large_num_tokens(self): - self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False) - - def test_exceed_max_tokens(self): + """ + 测试算子支持大量 tokens(超过 65535) + 算子使用 2D grid,理论上可支持 65535*65535 个 tokens + """ num_tokens, num_heads, head_size = 65537, 1, 4 num_kv_heads, rot_dim = 1, 4 query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32") @@ -126,8 +127,10 @@ def test_exceed_max_tokens(self): position_ids_np = np.arange(num_tokens, dtype="int32") cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim) - with self.assertRaises(Exception): - self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False) + # 不应该抛出异常,算子应该能处理大量 tokens + query_out, key_out = self._run_op( + query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False + ) if __name__ == "__main__": diff --git a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py index 41474b4726c..2d1dd8e2f7c 100644 --- a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py +++ b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py @@ -33,24 +33,17 @@ def test_basic_functionality(self): total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum()) position_ids = paddle.zeros([total_len], dtype="int32") - mask_encoder_batch = paddle.zeros([total_len], dtype="int32") # Call the custom operator - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch - ) + get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32) - expected_mask = np.array([1, 1, 1, 0, 1, 1, 0, 0], dtype=np.int32) - # Convert to numpy for comparison position_ids_np = position_ids.numpy() - mask_encoder_batch_np = mask_encoder_batch.numpy() # Assert equality np.testing.assert_array_equal(position_ids_np, expected_position_ids) - np.testing.assert_array_equal(mask_encoder_batch_np, expected_mask) def test_empty_decoder(self): # Test case where decoder length is 0 @@ -59,17 +52,12 @@ def test_empty_decoder(self): seq_lens_this_time = paddle.to_tensor([0], dtype="int32") position_ids = paddle.zeros([2], dtype="int32") - mask_encoder_batch = paddle.zeros([2], dtype="int32") - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch - ) + get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) expected_position_ids = np.array([0, 1], dtype=np.int32) - expected_mask = np.array([1, 1], dtype=np.int32) np.testing.assert_array_equal(position_ids.numpy(), expected_position_ids) - np.testing.assert_array_equal(mask_encoder_batch.numpy(), expected_mask) if __name__ == "__main__": diff --git a/tests/worker/test_reorder_split_prefill_and_decode.py b/tests/worker/test_reorder_split_prefill_and_decode.py index d2d9e3a1f61..147e9581201 100644 --- a/tests/worker/test_reorder_split_prefill_and_decode.py +++ b/tests/worker/test_reorder_split_prefill_and_decode.py @@ -59,6 +59,7 @@ def create_mock_config(): scheduler_config = Mock(spec=SchedulerConfig) scheduler_config.max_num_seqs = 10 + scheduler_config.max_num_batched_tokens = 2048 speculative_config = Mock(spec=SpeculativeConfig) speculative_config.method = None From 75f328c7fc5c377236c29bb6155e1ee4868078e3 Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:09:51 +0800 Subject: [PATCH 081/143] [Cherry-Pick][Optimization] Support logprob overlap in speculative decoding (#7600) (#7656) * support logprob overlap --- custom_ops/gpu_ops/cpp_extensions.cc | 23 ++-- .../speculate_logprob_utils.cu | 124 ++++++++++++++---- .../model_executor/layers/sample/logprobs.py | 39 ++---- .../layers/sample/ops/__init__.py | 4 +- .../sample/ops/speculate_logprob_utils.py | 18 ++- .../model_executor/layers/sample/sampler.py | 9 +- .../model_executor/pre_and_post_process.py | 8 ++ tests/layers/test_speculative_sampler.py | 10 +- ...speculate_get_accept_tokens_and_logits.py} | 78 ++++++----- 9 files changed, 207 insertions(+), 106 deletions(-) rename tests/operators/{test_speculate_get_target_logits.py => test_speculate_get_accept_tokens_and_logits.py} (61%) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 34f2ffd9a55..bc6f7e0783a 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -1145,13 +1145,16 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder); -void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, - const paddle::Tensor& logits, - const paddle::Tensor& cu_batch_token_offset, - const paddle::Tensor& ori_cu_batch_token_offset, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& accept_num); +void SpeculateGetAcceptTokensAndLogits( + const paddle::Tensor& token_ids, + const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num, + const paddle::Tensor& accept_tokens); std::vector UpdateAttnMaskOffsets( const paddle::Tensor& ids_remove_padding, @@ -1879,9 +1882,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &SpeculateInsertFirstToken, "speculate_insert_first_token function"); - m.def("speculate_get_target_logits", - &SpeculateGetTargetLogits, - "speculate_get_target_logits function"); + m.def("speculate_get_accept_tokens_and_logits", + &SpeculateGetAcceptTokensAndLogits, + "speculate_get_accept_tokens_and_logits function"); #endif m.def("update_attn_mask_offsets", diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu index 76a84f30d4f..eadcf015f70 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -184,24 +184,65 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, real_bsz); } +template +__global__ void compute_cu_batch_offset_kernel(int* cu_batch_token_offset, + const int* accept_num, + const int real_bsz) { + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int tid = threadIdx.x; + if (tid == 0) cu_batch_token_offset[0] = 0; + + int thread_data[ITEMS_PER_THREAD]; + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int batch_id = tid * ITEMS_PER_THREAD + i; + thread_data[i] = + batch_id < real_bsz ? accept_num[tid * ITEMS_PER_THREAD + i] : 0; + } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int batch_id = tid * ITEMS_PER_THREAD + i; + if (batch_id < real_bsz) { + cu_batch_token_offset[batch_id + 1] = thread_data[i]; + } + } +} + template -__global__ void speculate_get_target_logits_kernel( +__global__ void speculate_get_accept_tokens_and_logits_kernel( + int64_t* token_ids, float* target_logits, const float* logits, const int* cu_batch_token_offset, - const int* ori_cu_batch_token_offset, + const int* cu_seqlens_q_output, const int* seq_lens_this_time, const int* seq_lens_encoder, const int* accept_num, + const int64_t* accept_tokens, const int vocab_size, + const int max_draft_tokens, const int real_bsz) { AlignedVector src_vec; const int bid = blockIdx.x; const int tid = threadIdx.x; if (bid < real_bsz) { + // get token_ids + if (tid == 0) { + auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens; + for (int i = 0; i < accept_num[bid]; i++) { + token_ids[cu_batch_token_offset[bid] + i] = accept_tokens_now[i]; + } + } + + // get output_logits auto* target_logits_now = target_logits + cu_batch_token_offset[bid] * vocab_size; - auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + cu_seqlens_q_output[bid] * vocab_size; for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { if (seq_lens_encoder[bid] > 0) { Load(&logits_now[i], &src_vec); @@ -217,31 +258,64 @@ __global__ void speculate_get_target_logits_kernel( } } -void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, - const paddle::Tensor& logits, - const paddle::Tensor& cu_batch_token_offset, - const paddle::Tensor& ori_cu_batch_token_offset, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& accept_num) { +void SpeculateGetAcceptTokensAndLogits( + const paddle::Tensor& token_ids, + const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num, + const paddle::Tensor& accept_tokens) { auto cu_stream = seq_lens_this_time.stream(); const int vocab_size = logits.shape()[1]; - const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_occupied_slots = seq_lens_this_time.shape()[0]; + const int max_draft_tokens = accept_tokens.shape()[1]; + + const int BLOCK_DIM = 512; + PADDLE_ENFORCE_LE(max_occupied_slots, + 2048, + phi::errors::InvalidArgument( + "Only support bsz <= 2048, but received bsz is ", + max_occupied_slots)); + if (max_occupied_slots <= 512) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } else if (max_occupied_slots <= 1024) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } else if (max_occupied_slots <= 2048) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } constexpr int PackSize = VEC_16B / sizeof(float); - dim3 grid_dim(real_bsz); + dim3 grid_dim(max_occupied_slots); dim3 block_dim(128); - speculate_get_target_logits_kernel + speculate_get_accept_tokens_and_logits_kernel <<>>( + const_cast(token_ids.data()), const_cast(target_logits.data()), logits.data(), cu_batch_token_offset.data(), - ori_cu_batch_token_offset.data(), + cu_seqlens_q_output.data(), seq_lens_this_time.data(), seq_lens_encoder.data(), accept_num.data(), + accept_tokens.data(), vocab_size, - real_bsz); + max_draft_tokens, + max_occupied_slots); } PD_BUILD_STATIC_OP(speculate_get_logits) @@ -274,14 +348,20 @@ PD_BUILD_STATIC_OP(speculate_insert_first_token) .SetInplaceMap({{"token_ids", "token_ids_out"}}) .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken)); -PD_BUILD_STATIC_OP(speculate_get_target_logits) - .Inputs({"target_logits", +PD_BUILD_STATIC_OP(speculate_get_accept_tokens_and_logits) + .Inputs({"token_ids", + "target_logits", "logits", "cu_batch_token_offset", - "ori_cu_batch_token_offset", + "cu_seqlens_q_output", "seq_lens_this_time", "seq_lens_encoder", - "accept_num"}) - .Outputs({"target_logits_out"}) - .SetInplaceMap({{"target_logits", "target_logits_out"}}) - .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits)); + "accept_num", + "accept_tokens"}) + .Outputs({"token_ids_out", + "target_logits_out", + "cu_batch_token_offset_out"}) + .SetInplaceMap({{"token_ids", "token_ids_out"}, + {"target_logits", "target_logits_out"}, + {"cu_batch_token_offset", "cu_batch_token_offset_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetAcceptTokensAndLogits)); diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index ac9e0edacbf..9f14c1d5e0e 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -169,9 +169,8 @@ def build_output_logprobs( """ num_logprobs = sampling_metadata.max_num_logprobs logprobs_tensors = None - cu_batch_token_offset = None - # NOTE(huicongyao) real_bsz is passed from _postprocess, remove this in future + max_draft_token_num = share_inputs["accept_tokens"].shape[1] max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] if is_naive: @@ -181,43 +180,27 @@ def build_output_logprobs( else: # Speculative mode: extract target logits for accepted positions from fastdeploy.model_executor.layers.sample.ops import ( - speculate_get_target_logits, + speculate_get_accept_tokens_and_logits, ) - batch_token_num = paddle.where( - share_inputs["seq_lens_encoder"][:max_occupied_slots] != 0, - paddle.ones_like(share_inputs["seq_lens_encoder"][:max_occupied_slots]), - share_inputs["seq_lens_this_time"], - ).flatten() - - share_inputs["batch_token_num"] = batch_token_num - - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) - cu_batch_token_offset = paddle.concat( - [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:max_occupied_slots])] - ).astype("int32") - share_inputs["cu_batch_token_offset"] = cu_batch_token_offset - output_logits = paddle.empty( - [share_inputs["accept_num"][:max_occupied_slots].sum(), logits.shape[1]], + [real_bsz * max_draft_token_num, logits.shape[1]], dtype=logits.dtype, ) - speculate_get_target_logits( + token_ids = paddle.full([real_bsz * max_draft_token_num], fill_value=0, dtype="int64") + + speculate_get_accept_tokens_and_logits( + token_ids, output_logits, logits, - cu_batch_token_offset, - ori_cu_batch_token_offset, + share_inputs["cu_batch_token_offset"], + share_inputs["cu_seqlens_q_output"], share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], share_inputs["accept_num"], + share_inputs["accept_tokens"], ) - idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") - mask = idx < share_inputs["accept_num"].unsqueeze(1) - token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) - # Adapt for sampling mask if num_logprobs is None: return None, None, output_logits @@ -232,7 +215,7 @@ def build_output_logprobs( logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) # output_logits use to compute sampling_mask - return logprobs_tensors, cu_batch_token_offset, output_logits + return logprobs_tensors, share_inputs["cu_batch_token_offset"], output_logits def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz, logprobs_tensors: LogprobsTensors): diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index 911c1697497..eb2b79927bd 100644 --- a/fastdeploy/model_executor/layers/sample/ops/__init__.py +++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py @@ -20,7 +20,7 @@ reasoning_phase_token_constraint, ) from .speculate_logprob_utils import ( - speculate_get_target_logits, + speculate_get_accept_tokens_and_logits, speculate_insert_first_token, ) from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling @@ -31,6 +31,6 @@ "reasoning_phase_token_constraint", "top_k_top_p_sampling", "min_p_sampling", - "speculate_get_target_logits", + "speculate_get_accept_tokens_and_logits", "speculate_insert_first_token", ] diff --git a/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py index 2caaf4892b8..df9bcaf6195 100644 --- a/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py +++ b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py @@ -19,29 +19,35 @@ from fastdeploy.platforms import current_platform -def speculate_get_target_logits( +def speculate_get_accept_tokens_and_logits( + token_ids: paddle.Tensor, target_logits: paddle.Tensor, logits: paddle.Tensor, cu_batch_token_offset: paddle.Tensor, - ori_cu_batch_token_offset: paddle.Tensor, + cu_seqlens_q_output: paddle.Tensor, seq_lens_this_time: paddle.Tensor, seq_lens_encoder: paddle.Tensor, accept_num: paddle.Tensor, + accept_tokens: paddle.Tensor, ): """ - speculate_get_target_logits + speculate_get_accept_tokens_and_logits """ if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import speculate_get_target_logits + from fastdeploy.model_executor.ops.gpu import ( + speculate_get_accept_tokens_and_logits, + ) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 7157fc8d755..f06bd695149 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -908,9 +908,7 @@ def compute_logprobs( if top_p_logprob is not None: last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) - # NOTE(huicongyao) temporarily used for slice last_logprobs to its real shape, remove in the future - real_token_num = batch_token_num.sum().item() - return last_logprobs[:real_token_num] + return last_logprobs def gather_logprobs( self, @@ -1238,10 +1236,13 @@ def forward_cuda( ) sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: - sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() + cu_batch_token_offset_cpu = paddle.empty_like(cu_batch_token_offset, device="cpu").pin_memory() + cu_batch_token_offset_cpu.copy_(cu_batch_token_offset, False) + sampler_output.cu_batch_token_offset = cu_batch_token_offset_cpu if keep_sampling_mask: real_bsz = share_inputs["seq_lens_this_time"].shape[0] accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1]) + target_logits = target_logits[: accept_nums.sum()] # Derive target probs from already-extracted target_logits; avoids a second kernel call. target_probs = F.softmax(target_logits, axis=-1) # Compute sampling mask at accepted token positions. diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 6de08ae0b70..4c7a09b93c0 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -585,6 +585,14 @@ def save_output_speculate( # Renormalize logprobs with logz (deferred from post_process for better overlap). if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + # TODO (wangyanpeng): Currently, there is a bug when overlap is enabled. + # Please ensure overlap is disabled when using this functionality to avoid unexpected behavior. + real_token_num = share_inputs["accept_num_cpu"].sum() + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids[:real_token_num], + logprobs=sampler_output.logprobs_tensors.logprobs[:real_token_num], + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks[:real_token_num], + ) sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( sampler_output.logprobs_tensors.logprobs, sampler_output.logz_per_batch, diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index 6321e6f08f3..11247e3fe06 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -196,7 +196,15 @@ def test_speculative_sampler(): increment_value = (max_draft_token_num + 1) * 4 sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) + sampler( + logits, + sampling_metadata, + max_model_len, + share_inputs, + token_num_output_cpu, + increment_value, + real_bsz=batch_size, + ) def test_speculative_sampler_logprobs(): diff --git a/tests/operators/test_speculate_get_target_logits.py b/tests/operators/test_speculate_get_accept_tokens_and_logits.py similarity index 61% rename from tests/operators/test_speculate_get_target_logits.py rename to tests/operators/test_speculate_get_accept_tokens_and_logits.py index 5d930418ae1..70b346f4895 100644 --- a/tests/operators/test_speculate_get_target_logits.py +++ b/tests/operators/test_speculate_get_accept_tokens_and_logits.py @@ -17,7 +17,7 @@ import paddle from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import ( - speculate_get_target_logits, + speculate_get_accept_tokens_and_logits, ) @@ -35,34 +35,38 @@ def test_all_decode(self): seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[2], [2], [2]], dtype="int32") accept_num = paddle.to_tensor([1, 2, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, 21], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 2 - glod_logits[2][:] = 3 - glod_logits[3][:] = 4 + ref_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 2 + ref_logits[2][:] = 3 + ref_logits[3][:] = 4 + ref_token_ids = paddle.to_tensor([10, 20, 21, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) def test_partial_decode(self): token_num = 5 @@ -73,34 +77,38 @@ def test_partial_decode(self): seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32") accept_num = paddle.to_tensor([1, 2, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, 21], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 1 - glod_logits[2][:] = 2 - glod_logits[3][:] = 3 + ref_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 1 + ref_logits[2][:] = 2 + ref_logits[3][:] = 3 + ref_token_ids = paddle.to_tensor([10, 20, 21, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) def test_all_prefill(self): token_num = 3 @@ -111,33 +119,37 @@ def test_all_prefill(self): seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32") accept_num = paddle.to_tensor([1, 1, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, -1], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 1 - glod_logits[2][:] = 2 + ref_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 1 + ref_logits[2][:] = 2 + ref_token_ids = paddle.to_tensor([10, 20, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) if __name__ == "__main__": From d3a2c7104fcec0305a9c6ad2c03163472c83d9af Mon Sep 17 00:00:00 2001 From: kevin Date: Wed, 29 Apr 2026 19:51:43 +0800 Subject: [PATCH 082/143] [Cherry-Pick][BugFix][KVCache] Fix inference slowdown when enabling CPU cache (#7471) (#7651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [BugFix][KVCache] Fix inference slowdown when enabling CPU cache on Blackwell GPU 在 B 卡(Blackwell GPU)上开启 CPU cache(num_cpu_blocks > 0)时,推理性能出现明显降速。 根因是 `create_cache_tensor` 的判断逻辑将 `num_cpu_blocks > 0` 作为跳过 GPU cache tensor 创建的条件,导致 B 卡上错误地跳过了 GPU cache tensor 的初始化。 - `fastdeploy/worker/gpu_model_runner.py`:`create_cache_tensor` 判断中移除 `num_cpu_blocks > 0` 条件(两处:`init_cache` 和 `clear_cache`),保证开启 CPU cache 时 GPU cache tensor 仍正常创建 - `fastdeploy/cache_manager/prefix_cache_manager.py`:将 `--create_cache_tensor` 参数从非 splitwise 场景的条件判断中移出,统一归到 `kvcache_storage_backend` 配置路径下,逻辑更清晰 ```bash python -m fastdeploy.entrypoints.openai.api_server \ --num-cpu-blocks \ ... ``` * [BugFix][KVCache] Enlarge prealloc threshold for speculative decoding 投机解码场景下,每个调度步骤一次性消耗 `num_spec_tokens` 个 slot,原有的 `prealloc_dec_block_slot_num_threshold` 阈值偏小,导致块预分配触发不够及时, 影响推理性能。 在 `FDConfig` 初始化阶段,当启用 speculative decoding 时,将 `prealloc_dec_block_slot_num_threshold` 扩大为原值乘以 `num_spec_tokens`, 同时确保不超过 enc_dec_block 容量上限。 启用投机解码时,无需额外配置,阈值自动调整: ```bash python -m fastdeploy.entrypoints.openai.api_server \ --speculative-config '{"method": "draft_model", "num_speculative_tokens": 4}' \ ... ``` * [BugFix][KVCache][FDConfig] Fix prealloc threshold and create_cache_tensor for splitwise ## Motivation 两处 bug 修复: 1. speculative decoding 场景下,prealloc_dec_block_slot_num_threshold 的放大系数应为 (num_spec_tokens + 1) 而非 num_spec_tokens,确保预分配触发时机足够提前。 2. kvcache_storage_backend 启用时,--create_cache_tensor 参数只应在非 splitwise 模式下传入,避免 splitwise P 节点错误创建 cache tensor。 ## Modifications - fastdeploy/config.py: 修正 prealloc 放大系数为 (num_spec_tokens + 1),并添加 logger 打印变动前后的值 - fastdeploy/cache_manager/prefix_cache_manager.py: --create_cache_tensor 仅在非 splitwise 模式下追加 ## Usage or Command ```bash # 启动服务(含 speculative decoding + kvcache storage) python -m fastdeploy.entrypoints.openai.api_server \ --model \ --speculative-model \ --num-speculative-tokens 3 \ --kvcache-storage-backend ``` * [BugFix][SpecDecode] Fix create_cache_tensor condition in MTPProposer ## Motivation `create_cache_tensor` 的判断条件中包含 `num_cpu_blocks > 0`,导致在 B卡 CPU cache 场景下,MTP 的 kv cache 创建逻辑出现异常。 ## Modifications 移除 `create_cache_tensor` 判断中的 `num_cpu_blocks > 0` 条件,仅保留 `kvcache_storage_backend` 和 `splitwise_role` 的判断,避免冗余条件干扰。 Co-Authored-By: Claude Sonnet 4.6 * [BugFix][FDConfig] Remove auto_dispatch_tokens logic not in cherry-pick scope 在2.6分支中移除误加入的 auto_dispatch_tokens 及 num_max_dispatch_tokens_per_rank 相关逻辑,这些代码属于 develop 分支已有逻辑,不在 PR #7471 cherry-pick 范围内。 Co-Authored-By: Claude Sonnet 4.6 * [BugFix][KVCache] Address review comments: fix negative cap, sync runner fixes, update comments ## Modifications 1. **config.py**: Fix `enc_dec_block_num=0` causing negative upper bound for `prealloc_dec_block_slot_num_threshold`. Use `max(0, ...)` to guard against negative cap. Also fix comment to say `num_spec_tokens + 1` (matching code). 2. **metax_model_runner.py**: Sync the same fix from gpu_model_runner.py — remove `num_cpu_blocks > 0` from `create_cache_tensor` condition in both `initialize_kv_cache` and `clear_cache`. CPU cache should not prevent GPU runners from creating GPU cache tensors on Metax platform. 3. **mtp.py**: Remove `num_cpu_blocks > 0` from `clear_mtp_cache` condition. Restore the `if not create_cache_tensor` IPC unset block that was lost. 4. **gpu_model_runner.py / mtp.py / xpu_model_runner.py**: Update stale comments to clarify that CPU cache does NOT prevent GPU cache tensor creation; cache transfer manager handles CPU<->GPU swap on top. Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- .../cache_manager/prefix_cache_manager.py | 3 ++- fastdeploy/config.py | 19 +++++++++++++++++++ fastdeploy/spec_decode/mtp.py | 11 ++++++----- fastdeploy/worker/gpu_model_runner.py | 11 ++++++----- fastdeploy/worker/metax_model_runner.py | 8 ++++---- fastdeploy/worker/xpu_model_runner.py | 2 ++ 6 files changed, 39 insertions(+), 15 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 537f092a5ed..e12e47d3fd0 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -296,6 +296,8 @@ def launch_cache_manager( val_cache_arg_str = f" --value_cache_shape {val_shape_str}" if cache_config.kvcache_storage_backend: storage_arg_str = f" --kvcache_storage_backend {cache_config.kvcache_storage_backend}" + if not self.enable_splitwise: + storage_arg_str += " --create_cache_tensor" else: storage_arg_str = " " @@ -333,7 +335,6 @@ def launch_cache_manager( + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --default_dtype '{self.config.model_config.dtype}'" - + (" --create_cache_tensor" if not self.enable_splitwise else "") + storage_arg_str + f" --write_policy {cache_config.write_policy}" + f" --max_model_len {self.config.model_config.max_model_len}" diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 3d51eaf4e47..3219cd34a9c 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2195,6 +2195,25 @@ def postprocess(self): self.speculative_config.num_speculative_tokens = 1 self.speculative_config.num_model_steps = 1 + if self.speculative_config is not None and self.speculative_config.method is not None: + num_spec_tokens = self.speculative_config.num_speculative_tokens + # For speculative, enlarge the threshold to trigger block preallocation earlier, + # since each step consumes num_spec_tokens + 1 slots at once + old_prealloc_threshold = self.cache_config.prealloc_dec_block_slot_num_threshold + prealloc_dec_block_slot = self.cache_config.prealloc_dec_block_slot_num_threshold * (num_spec_tokens + 1) + max_prealloc_dec_block_slot = max( + 0, self.cache_config.block_size * self.cache_config.enc_dec_block_num - 1 + ) + self.cache_config.prealloc_dec_block_slot_num_threshold = min( + prealloc_dec_block_slot, max_prealloc_dec_block_slot + ) + logger.info( + f"prealloc_dec_block_slot_num_threshold updated: {old_prealloc_threshold} -> " + f"{self.cache_config.prealloc_dec_block_slot_num_threshold} " + f"(num_spec_tokens={num_spec_tokens}, block_size={self.cache_config.block_size}, " + f"enc_dec_block_num={self.cache_config.enc_dec_block_num})" + ) + if self.scheduler_config.splitwise_role == "mixed": self._disable_sequence_parallel_moe_if_needed("Mixed") self.model_config.moe_phase = MoEPhase(phase="prefill") diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 0c1681e12a0..2d8d310a469 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -242,10 +242,12 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. - # 2. If no need to profile, create kv cache if cache managers do not exist. + # 2. If no need to profile, create kv cache unless kvcache_storage_backend or + # p/d disaggregation is enabled. Note: CPU cache (num_cpu_blocks > 0) does NOT + # prevent GPU runner from creating GPU cache tensors; cache transfer manager + # handles CPU<->GPU swap on top of the GPU tensors created here. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -423,8 +425,7 @@ def clear_mtp_cache(self, profile=False): Clear allocated cacheKV """ create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) if not create_cache_tensor: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 7a6b2e22d71..1890a5e77a1 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1472,10 +1472,12 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. - # 2. If no need to profile, create kv cache if cache managers do not exist. + # 2. If no need to profile, create kv cache unless kvcache_storage_backend or + # p/d disaggregation is enabled. Note: CPU cache (num_cpu_blocks > 0) does NOT + # prevent GPU runner from creating GPU cache tensors; cache transfer manager + # handles CPU<->GPU swap on top of the GPU tensors created here. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -2752,8 +2754,7 @@ def cal_theortical_kvcache(self): def clear_cache(self, profile=False): """Clear cached data from shared inputs and forward metadata""" create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index a5fe3547e64..7e721107f1e 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1332,9 +1332,10 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. # 2. If no need to profile, create kv cache if cache managers do not exist. + # Note: even when CPU cache (num_cpu_blocks > 0) is enabled, GPU runner still + # creates GPU cache tensors; cache transfer manager handles CPU<->GPU swap. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -2478,8 +2479,7 @@ def not_need_stop(self) -> bool: def clear_cache(self, profile=False): """Clear cached data from shared inputs and forward metadata""" create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index ba250215345..ea83d7d0141 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -1251,6 +1251,8 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled. + # Note: even when CPU cache (num_cpu_blocks > 0) is enabled, GPU runner still + # creates GPU cache tensors; cache transfer manager handles CPU<->GPU swap. create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed" if not create_cache_tensor: logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}") From df1d64c88a497bb75f450a5c4fefe800d8d489bf Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Thu, 30 Apr 2026 09:40:48 +0800 Subject: [PATCH 083/143] Fix key error for updating mtp model weights (#7676) --- fastdeploy/rl/dynamic_weight_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index 46e686c38a6..a6f14068abc 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -74,10 +74,10 @@ def _capture_model_state(self): def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False): def valid_parameters(old_state_dict, new_state_dict): is_valid = True - for key in old_state_dict: - if key not in new_state_dict: + for key in new_state_dict: + if key not in old_state_dict: is_valid = False - logger.error(f"Invalid parameter: {key} not in new_state_dict") + logger.error(f"Invalid parameter: {key} not in old_state_dict") elif old_state_dict[key].shape != new_state_dict[key].shape: is_valid = False logger.error( @@ -128,8 +128,8 @@ def valid_parameters(old_state_dict, new_state_dict): raise ValueError(error_msg) update_start = time.perf_counter() - for name, target_param in old_state_dict.items(): - new_param = new_state_dict[name] + for name, new_param in new_state_dict.items(): + target_param = old_state_dict[name] if bootstrap_load and not target_param._is_initialized(): new_param = new_param.cuda() new_param._share_buffer_to(target_param) From c1f9714c01e1b363395b0934435a04c4a7929122 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:06:28 +0800 Subject: [PATCH 084/143] [Cherry-Pick] [BugFix] Fix get_tasks returns empty list and incorrect nnode computation (#7677) * [fix] reset exist task flag if no task exists * [BugFix] fix incorrect nnode computation --- fastdeploy/worker/worker_process.py | 30 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 865cbb909db..55ada820f50 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -320,6 +320,18 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int: paddle.distributed.broadcast_object_list(signal_list, src=src, group=group) return int(signal_list[0]) + def _get_exist_task_flag(self) -> bool: + if self.nnode > 1: + return self.task_queue.read_finish_flag.get() == 1 + else: + return self.exist_task_signal.value[0] == ExistTaskStatus.EXIST + + def _update_exist_task_flag(self, flag: bool) -> None: + if self.nnode > 1: + self.task_queue.read_finish_flag.set(1 if flag else 0) + else: + self.exist_task_signal.value[0] = ExistTaskStatus.EXIST if flag else ExistTaskStatus.EMPTY + def _tp_barrier_wait(self): if current_platform.is_xpu() or self.enable_overlap_schedule: self.task_queue.worker_process_tp_barrier.wait() @@ -473,7 +485,7 @@ def event_loop_normal(self) -> None: self._init_eplb_signal() tp_size = self.parallel_config.tensor_parallel_size # Currently, only support single node - self.nnode = (tp_size + self.max_chips_per_node) // self.max_chips_per_node + self.nnode = (tp_size + self.max_chips_per_node - 1) // self.max_chips_per_node max_occupied_batch_index = 0 tp_rank = self.local_rank % tp_size @@ -497,10 +509,9 @@ def event_loop_normal(self) -> None: if envs.ENABLE_V1_KVCACHE_SCHEDULER or not ( self.fd_config.enable_mm_runtime and self.worker.exist_prefill() ): - if self.nnode > 1: - self.task_queue.read_finish_flag.set(1) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EXIST + self._update_exist_task_flag(True) + else: + self._update_exist_task_flag(False) # Synchronize the signal set by tp_rank0 visiable to other workers self._tp_barrier_wait() if tp_size > 1 else None @@ -558,17 +569,14 @@ def event_loop_normal(self) -> None: ) # 所有 Rank 已同步唤醒,启动权重更新流程 continue - if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1: + if self._get_exist_task_flag(): logger.debug(f"Rank: {self.local_rank} Detected new requests.") tasks, read_finish = self.task_queue.get_tasks() # Only one of all tp_size client will get read_finish == True. if read_finish: - # Reset the two signal. - if self.nnode > 1: - self.task_queue.read_finish_flag.set(0) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY + self._update_exist_task_flag(False) + self._tp_barrier_wait() if tp_size > 1 else None req_dicts, control_reqs = [], [] for req_dict, bsz in tasks: From 0ec96258451e4d7ca30bae7175c5f8773e155e1c Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:33:44 +0800 Subject: [PATCH 085/143] [Cherry-Pick] [BugFix] fix preempted token id not returned when a full batch is aborted (#7633) (#7654) * [BugFix] fix preempted token id not returned when a full batch is aborted * [fix] changed fake_sampled_token_ids shape and filled value * [test] add test * [chore] move code place * [test] add more tests and docstring --- fastdeploy/worker/gpu_model_runner.py | 127 ++++++++++++++++-- tests/worker/test_gpu_model_runner.py | 181 ++++++++++++++++++++++++++ 2 files changed, 299 insertions(+), 9 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1890a5e77a1..bca07c82170 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -112,7 +112,12 @@ DistributedStatus, ModelRunnerBase, ) -from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput +from fastdeploy.worker.output import ( + LogprobsTensors, + ModelOutputData, + ModelRunnerOutput, + SamplerOutput, +) class GPUModelRunner(ModelRunnerBase): @@ -2107,6 +2112,91 @@ def _execute_empty_mtp_input(self, forward_meta) -> None: for _ in range(self.fd_config.speculative_config.num_model_steps): self.proposer.model.empty_input_forward(forward_meta) + def _make_preempted_batch_output(self): + """Build a minimal batch-shaped control output for preempted slots. + + This is used when the current step contains only preempted/aborted + requests and therefore produces no normal model tokens. The helper + fabricates a lightweight batch output so the existing save_output path + can still return PREEMPTED_TOKEN_ID for the affected slots. + """ + preempted_indices = paddle.nonzero(self.share_inputs["preempted_idx"] == 1) + bsz = int(preempted_indices[-1][0].item()) + 1 + + fake_sampled_token_ids = paddle.where( + self.share_inputs["preempted_idx"][:bsz] == 1, + PREEMPTED_TOKEN_ID, + -1, + ).astype("int64") + self.share_inputs["sampled_token_ids"][:bsz].copy_(fake_sampled_token_ids, False) + + fake_logprobs_tensors = None + if self.enable_logprob: + fake_logprobs_tensors = LogprobsTensors( + logprob_token_ids=paddle.zeros([bsz, 1], dtype="int64", device="cpu"), + logprobs=paddle.zeros([bsz, 1], dtype="float32", device="cpu"), + selected_token_ranks=paddle.zeros([bsz], dtype="int64", device="cpu"), + ) + + if self.speculative_decoding: + self.share_inputs["accept_tokens_cpu"][:bsz].fill_(0) + self.share_inputs["accept_num_cpu"][:bsz].fill_(0) + self.share_inputs["seq_lens_decoder_cpu"][:bsz].copy_(self.share_inputs["seq_lens_decoder"][:bsz], False) + self.share_inputs["prompt_lens_cpu"][:bsz].copy_(self.share_inputs["prompt_lens"][:bsz], False) + sampler_output = SamplerOutput( + sampled_token_ids=fake_sampled_token_ids, + logprobs_tensors=fake_logprobs_tensors, + token_num_per_batch=(self.share_inputs["accept_num_cpu"][:bsz] if self.enable_logprob else None), + cu_batch_token_offset=( + paddle.zeros([bsz + 1], dtype="int32", device="cpu") if self.enable_logprob else None + ), + ) + else: + sampler_output = SamplerOutput( + sampled_token_ids=fake_sampled_token_ids, + logprobs_tensors=fake_logprobs_tensors, + ) + + index_to_batch_id = { + i: self.share_inputs["index_to_batch_id"][i] + for i in range(bsz) + if i in self.share_inputs["index_to_batch_id"] + } + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + not_need_stop_device=self.share_inputs["not_need_stop_device"], + input_ids=self.share_inputs["input_ids"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=None, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + token_ids_all=self.share_inputs["token_ids_all"], + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], + prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], + prompt_logprobs_list=None, + index_to_batch_id=index_to_batch_id, + enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), + ) + return model_output_data, sampler_output + def execute_model( self, model_forward_batch: Optional[List[Request]] = None, @@ -2141,14 +2231,23 @@ def execute_model_normal( and self.parallel_config.use_ep ): self._execute_empty_mtp_input(self.forward_meta) - return - model_output_data, sampler_output, post_process_event = self._postprocess( - model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz - ) - if model_output_data is not None: - # synchronizes the async DtoH copies of sampled_token_ids. - post_process_event.synchronize() - self._save_model_output(model_output_data, sampler_output) + + if paddle.sum(self.share_inputs["preempted_idx"]) > 0: + logger.info( + f"All requests in batch are preempted, real_bsz: {real_bsz} preempted: {paddle.sum(self.share_inputs['preempted_idx'])}" + ) + model_output_data, sampler_output = self._make_preempted_batch_output() + self.share_inputs["last_preempted_idx"].copy_(self.share_inputs["preempted_idx"]) + self.share_inputs["preempted_idx"][:] = 0 + self._save_model_output(model_output_data, sampler_output) + else: + model_output_data, sampler_output, post_process_event = self._postprocess( + model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz + ) + if model_output_data is not None: + # synchronizes the async DtoH copies of sampled_token_ids. + post_process_event.synchronize() + self._save_model_output(model_output_data, sampler_output) def execute_model_overlap( self, @@ -2189,6 +2288,16 @@ def execute_model_overlap( and self.parallel_config.use_ep ): self._execute_empty_mtp_input(self.forward_meta) + + if paddle.sum(self.share_inputs["preempted_idx"]) > 0: + logger.info( + f"All requests in batch are preempted, real_bsz: {real_bsz} preempted: {paddle.sum(self.share_inputs['preempted_idx'])}" + ) + model_output_data, sampler_output = self._make_preempted_batch_output() + self.share_inputs["last_preempted_idx"].copy_(self.share_inputs["preempted_idx"]) + self.share_inputs["preempted_idx"][:] = 0 + self._save_model_output(model_output_data, sampler_output) + self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index 43ab5130cdb..8400fb79271 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -19,6 +19,7 @@ import numpy as np import paddle +from fastdeploy.config import PREEMPTED_TOKEN_ID from fastdeploy.engine.request import ImagePosition from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.gpu_model_runner import GPUModelRunner @@ -591,5 +592,185 @@ def test_wakeup_kvcache_is_idempotent(self, mock_print_memory): mock_print_memory.assert_not_called() +class TestMakePreemptedBatchOutput(unittest.TestCase): + def _make_runner(self, speculative_decoding=False, enable_logprob=False): + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.speculative_decoding = speculative_decoding + runner.enable_logprob = enable_logprob + runner.parallel_config = Mock(msg_queue_id=0, tensor_parallel_rank=0, use_ep=False) + + class _ShareInputs(dict): + enable_pd_reorder = False + + share_inputs = _ShareInputs() + share_inputs["preempted_idx"] = paddle.to_tensor( + [[0], [0], [0], [1], [0], [0], [1], [0], [0], [0]], dtype="int32" + ) + share_inputs["sampled_token_ids"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["index_to_batch_id"] = {i: i for i in range(10)} + share_inputs["next_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_flags"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["step_idx"] = 0 + share_inputs["max_dec_len"] = 16 + share_inputs["seq_lens_this_time"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["eos_token_id"] = paddle.zeros([1], dtype="int64") + share_inputs["not_need_stop"] = False + share_inputs["not_need_stop_device"] = paddle.zeros([1], dtype="bool") + share_inputs["input_ids"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["seq_lens_encoder"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["seq_lens_decoder"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["is_block_step"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["token_ids_all"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_seqs"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_seqs_len"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["min_dec_len"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["prompt_lens"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["mask_rollback"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["accept_tokens_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int64") + share_inputs["accept_num_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["seq_lens_decoder_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["prompt_lens_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["draft_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["actual_draft_token_num"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["accept_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["accept_num"] = paddle.zeros([10, 1], dtype="int32") + runner.share_inputs = share_inputs + return runner + + def test_make_preempted_batch_output_emits_sparse_preempt_mask(self): + runner = self._make_runner() + + model_output_data, sampler_output = runner._make_preempted_batch_output() + + expected = [-1, -1, -1, PREEMPTED_TOKEN_ID, -1, -1, PREEMPTED_TOKEN_ID] + self.assertEqual(sampler_output.sampled_token_ids.shape, [7, 1]) + self.assertEqual(sampler_output.sampled_token_ids.numpy().reshape([-1]).tolist(), expected) + self.assertEqual(runner.share_inputs["sampled_token_ids"][:7].numpy().reshape([-1]).tolist(), expected) + self.assertEqual(model_output_data.index_to_batch_id, {i: i for i in range(7)}) + + def test_make_preempted_batch_output_speculative_logprob(self): + runner = self._make_runner(speculative_decoding=True, enable_logprob=True) + runner.share_inputs["seq_lens_decoder"][:7] = paddle.arange(7, dtype="int32").reshape([7, 1]) + runner.share_inputs["prompt_lens"][:7] = paddle.arange(10, 17, dtype="int32").reshape([7, 1]) + + model_output_data, sampler_output = runner._make_preempted_batch_output() + + self.assertEqual(sampler_output.sampled_token_ids.shape, [7, 1]) + self.assertIsNotNone(sampler_output.logprobs_tensors) + self.assertEqual(sampler_output.logprobs_tensors.logprob_token_ids.shape, [7, 1]) + self.assertEqual(sampler_output.token_num_per_batch.shape, [7, 1]) + self.assertEqual(sampler_output.cu_batch_token_offset.shape, [8]) + self.assertEqual(runner.share_inputs["accept_tokens_cpu"][:7].numpy().reshape([-1]).tolist(), [0] * 7) + self.assertEqual(runner.share_inputs["accept_num_cpu"][:7].numpy().reshape([-1]).tolist(), [0] * 7) + self.assertEqual( + runner.share_inputs["seq_lens_decoder_cpu"][:7].numpy().reshape([-1]).tolist(), + list(range(7)), + ) + self.assertEqual( + runner.share_inputs["prompt_lens_cpu"][:7].numpy().reshape([-1]).tolist(), + list(range(10, 17)), + ) + self.assertIsNotNone(model_output_data.accept_tokens) + self.assertIsNotNone(model_output_data.accept_num) + + +class TestExecuteModel(unittest.TestCase): + def _make_runner(self): + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.speculative_decoding = False + runner.parallel_config = Mock(use_ep=False) + runner.fd_config = Mock() + runner.fd_config.speculative_config = Mock(method=None) + runner.proposer = Mock(model=Mock()) + runner.forward_meta = Mock() + runner._save_model_output = Mock() + runner._make_preempted_batch_output = Mock(return_value=("model_output", "sampler_output")) + runner._postprocess = Mock() + runner._execute_empty_mtp_input = Mock() + runner._cached_launch_token_num = 0 + runner._cached_real_bsz = 0 + + class _ShareInputs(dict): + pass + + share_inputs = _ShareInputs() + share_inputs["seq_lens_this_time_cpu"] = paddle.zeros([2, 1], dtype="int32") + share_inputs["preempted_idx"] = paddle.to_tensor([[1], [0]], dtype="int32") + share_inputs["last_preempted_idx"] = paddle.zeros([2, 1], dtype="int32") + runner.share_inputs = share_inputs + return runner + + def test_execute_model_dispatches_to_normal_path(self): + runner = self._make_runner() + runner.enable_overlap_schedule = False + runner.execute_model_normal = Mock() + runner.execute_model_overlap = Mock() + + runner.execute_model(model_forward_batch=["req"], num_running_requests=1) + + runner.execute_model_normal.assert_called_once_with(["req"], 1) + runner.execute_model_overlap.assert_not_called() + + def test_execute_model_dispatches_to_overlap_path(self): + runner = self._make_runner() + runner.enable_overlap_schedule = True + runner.execute_model_normal = Mock() + runner.execute_model_overlap = Mock() + + runner.execute_model(model_forward_batch=["req"], num_running_requests=1) + + runner.execute_model_overlap.assert_called_once_with(["req"], 1) + runner.execute_model_normal.assert_not_called() + + def test_execute_model_normal_zero_output_flushes_preempted_batch(self): + runner = self._make_runner() + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", None)) + runner._execute = Mock(return_value=None) + + runner.execute_model_normal() + + runner._make_preempted_batch_output.assert_called_once_with() + np.testing.assert_array_equal(runner.share_inputs["last_preempted_idx"].numpy(), np.array([[1], [0]])) + np.testing.assert_array_equal(runner.share_inputs["preempted_idx"].numpy(), np.array([[0], [0]])) + runner._save_model_output.assert_called_once_with("model_output", "sampler_output") + + def test_execute_model_normal_postprocess_saves_output_after_sync(self): + runner = self._make_runner() + runner.share_inputs["seq_lens_this_time_cpu"] = paddle.to_tensor([[1], [0]], dtype="int32") + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", None)) + runner._execute = Mock(return_value="model_output") + post_process_event = Mock() + runner._postprocess.return_value = ("model_output_data", "sampler_output", post_process_event) + + runner.execute_model_normal(model_forward_batch=["req"], num_running_requests=1) + + runner._make_preempted_batch_output.assert_not_called() + post_process_event.synchronize.assert_called_once_with() + runner._save_model_output.assert_called_once_with("model_output_data", "sampler_output") + + def test_execute_model_overlap_zero_output_flushes_preempted_batch(self): + runner = self._make_runner() + token_num_event = Mock() + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", token_num_event)) + runner._execute = Mock(return_value=None) + runner._predict_next_launch_token_num = Mock(return_value=(11, 22)) + runner._cached_model_output_data = None + runner._cached_sampler_output = "cached_sampler" + runner._cached_post_process_event = "cached_event" + + runner.execute_model_overlap() + + token_num_event.synchronize.assert_called_once_with() + runner._make_preempted_batch_output.assert_called_once_with() + np.testing.assert_array_equal(runner.share_inputs["last_preempted_idx"].numpy(), np.array([[1], [0]])) + np.testing.assert_array_equal(runner.share_inputs["preempted_idx"].numpy(), np.array([[0], [0]])) + runner._save_model_output.assert_called_once_with("model_output", "sampler_output") + self.assertIsNone(runner._cached_model_output_data) + self.assertIsNone(runner._cached_sampler_output) + self.assertIsNone(runner._cached_post_process_event) + self.assertEqual(runner._cached_launch_token_num, 11) + self.assertEqual(runner._cached_real_bsz, 22) + + if __name__ == "__main__": unittest.main() From 66dea60be82db5c4ea360d7af362d46d37a59697 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Fri, 1 May 2026 23:56:54 +0800 Subject: [PATCH 086/143] [BugFix] Fix get_tasks returns empty list and incorrect nnode computation (additional fixes) (#7685) --- fastdeploy/worker/worker_process.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 55ada820f50..f11837f4f52 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -503,6 +503,8 @@ def event_loop_normal(self) -> None: req_dicts = None self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time()) + self._tp_barrier_wait() if tp_size > 1 else None + # The first worker detects whether there are tasks in the task queue if tp_rank == 0: if self.task_queue.exist_tasks(): From d0a0b3e32e9dff83a30356d299c64a08c89295cd Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Fri, 8 May 2026 14:56:32 +0800 Subject: [PATCH 087/143] fix rl overlap (#7745) --- fastdeploy/worker/gpu_model_runner.py | 7 +++++++ fastdeploy/worker/input_batch.py | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bca07c82170..8b4d29bc081 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2900,6 +2900,13 @@ def clear_parameters(self, pid): self.clear_cache() paddle.device.cuda.empty_cache() + # clear overlap status + self._cached_model_output_data = None + self._cached_sampler_output = None + self._cached_post_process_event = None + self._cached_launch_token_num = -1 + self._cached_real_bsz = -1 + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") def clear_requests(self): diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index f47c7bccc6d..225eac1c05b 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -565,7 +565,6 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "step_idx", 0) # fill_paddle_tensor(self, "not_need_stop", False) fill_paddle_tensor(self, "not_need_stop_device", False) - fill_paddle_tensor(self, "sampled_token_ids", -1) fill_paddle_tensor(self, "stop_flags", True) fill_paddle_tensor(self, "bad_tokens", -1) @@ -687,6 +686,12 @@ def reset_share_inputs(self): # Reset other miscellaneous tensors fill_paddle_tensor(self, "mask_rollback", 0) fill_paddle_tensor(self, "preempted_idx", 0) + fill_paddle_tensor(self, "last_preempted_idx", 0) + + # Reset tensors for overlap + self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory() + self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory() + self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory() logger.info("share_inputs reset completed") except Exception as e: From d5af4597b6df8d92e5771dc33f9f4316c8f3db92 Mon Sep 17 00:00:00 2001 From: AIbin <37361953+chang-wenbin@users.noreply.github.com> Date: Sat, 9 May 2026 13:17:01 +0800 Subject: [PATCH 088/143] [Cherry-Pick] [BugFix] Fix stop token sequence pointer offset and actual length computation #7618 [BugFix][Scheduler]Fix FD_DISABLE_CHUNKED_PREFILL max_num_batched_tokens limit#7407 (#7755) * cp cwb_pr * del tests --------- Co-authored-by: chang-wenbin --- custom_ops/gpu_ops/stop_generation_multi_ends.cu | 5 +++-- fastdeploy/config.py | 5 ++++- fastdeploy/engine/args_utils.py | 6 +++++- fastdeploy/input/utils.py | 9 ++++++--- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu index d2a6dcbbf60..06cf99831d7 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu @@ -79,8 +79,9 @@ __global__ void set_value_by_flags(bool *stop_flags, // dealing stop_seqs const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; if (stop_seq_len <= 0) return; - const int64_t *stop_seq_now = - stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; + const int64_t *stop_seq_now = stop_seqs + + bid * stop_seqs_bs * stop_seqs_max_len + + tid * stop_seqs_max_len; const int64_t *pre_ids_now = token_ids_all + bid * max_model_len + prompt_lens[bid]; const int64_t step_idx_now = step_idx[bid]; diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 3219cd34a9c..49bd5faacd6 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2103,7 +2103,10 @@ def postprocess(self): if self.scheduler_config.max_num_batched_tokens is None: if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): - self.scheduler_config.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM + if int(envs.FD_DISABLE_CHUNKED_PREFILL): + self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len + else: + self.scheduler_config.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM else: if self.cache_config.enable_chunked_prefill: self.scheduler_config.max_num_batched_tokens = 2048 diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 5690dee0a33..5b443e91c42 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -1500,7 +1500,11 @@ def create_engine_config(self) -> FDConfig: if self.max_num_batched_tokens is None: if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): - if current_platform.is_maca() or current_platform.is_iluvatar(): + if ( + int(envs.FD_DISABLE_CHUNKED_PREFILL) + or current_platform.is_maca() + or current_platform.is_iluvatar() + ): self.max_num_batched_tokens = self.max_model_len else: self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM diff --git a/fastdeploy/input/utils.py b/fastdeploy/input/utils.py index 19a86f31574..74e1cf2c6fb 100644 --- a/fastdeploy/input/utils.py +++ b/fastdeploy/input/utils.py @@ -82,6 +82,7 @@ def process_stop_token_ids( update_stop_seq_fn: Callable[[List[str]], Tuple[List[List[int]], List[int]]], ) -> None: stop_token_ids_final = [] + stop_seqs_len_final = [] if request.get("stop_token_ids") is not None: stop_token_ids = request.get("stop_token_ids") @@ -89,17 +90,19 @@ def process_stop_token_ids( if isinstance(stop_token_ids[0], int): # List[int] -> List[List[int]] stop_token_ids_final.extend([[t] for t in stop_token_ids]) + stop_seqs_len_final.extend([1] * len(stop_token_ids)) elif isinstance(stop_token_ids[0], list): # Already List[List[int]] stop_token_ids_final.extend(stop_token_ids) + stop_seqs_len_final.extend([len(seq) for seq in stop_token_ids]) stop_sequences = request.get("stop", []) if stop_sequences: - stop_seqs, _ = update_stop_seq_fn(stop_sequences) + stop_seqs, stop_seqs_actual_lens = update_stop_seq_fn(stop_sequences) stop_token_ids_final.extend(stop_seqs) + stop_seqs_len_final.extend(stop_seqs_actual_lens) # Update request if stop_token_ids_final: - stop_seqs_len = [len(seq) for seq in stop_token_ids_final] request["stop_token_ids"] = stop_token_ids_final - request["stop_seqs_len"] = stop_seqs_len + request["stop_seqs_len"] = stop_seqs_len_final From a5fa7278fed1dc32c49973a91ef3ebc75f89653b Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Sat, 9 May 2026 17:45:35 +0800 Subject: [PATCH 089/143] [BugFix][KVCache][Speculative Decoding] Fix get_max_chunk_tokens for PD-split decode node in MTP scenario (#7756) (#7758) Co-authored-by: kevin --- fastdeploy/config.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 49bd5faacd6..6cf244d2bb8 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2484,7 +2484,13 @@ def get_max_chunk_tokens(self, mm_max_tokens_per_item=None) -> int: if paddle.is_compiled_with_xpu(): num_tokens = self.scheduler_config.max_num_batched_tokens else: - num_tokens = self.scheduler_config.max_num_seqs + # In MTP scenario, each sequence generates (num_speculative_tokens + 1) tokens per step + mtp_steps = ( + (getattr(self.speculative_config, "num_speculative_tokens", 0) + 1) + if self.speculative_config is not None and self.speculative_config.method is not None + else 1 + ) + num_tokens = self.scheduler_config.max_num_seqs * mtp_steps else: num_tokens = self.scheduler_config.max_num_batched_tokens if self.enable_mm_runtime and mm_max_tokens_per_item is not None: From d92163f147b591ab2c2bbfbefafda097777fb8a3 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Sat, 9 May 2026 20:01:47 +0800 Subject: [PATCH 090/143] [BugFix] Fix ZMQ multipart frame interleaving in Splitwise connector (#7763) * splitwise_zmq_lock * fix ut and comment --------- Co-authored-by: yuanlehome --- fastdeploy/splitwise/splitwise_connector.py | 62 ++++++++++++++++----- tests/splitwise/test_splitwise_connector.py | 22 ++++++-- 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index e5feb661f66..77f75ee4de7 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -15,10 +15,11 @@ """ import pickle +import threading import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List +from typing import Dict, List, Tuple import zmq @@ -58,6 +59,8 @@ def __init__(self, cfg, worker_queue, resource_manager): if self.cfg.scheduler_config.splitwise_role != "mixed": self.zmq_ctx = zmq.Context() self.push_sockets: Dict[str, zmq.Socket] = {} + self._push_socket_locks: Dict[str, threading.Lock] = {} + self._push_sockets_meta_lock = threading.Lock() self.pull_socket = None self.io_executor = ThreadPoolExecutor(max_workers=4) self._init_network() @@ -105,13 +108,21 @@ def start_receiver(self): self.logger.error(f"start_receiver: Receiver error: {e}, {str(traceback.format_exc())}") time.sleep(1) - def _get_push_socket(self, addr): - """获取或创建 DEALER socket""" + def _get_push_socket(self, addr) -> Tuple[zmq.Socket, threading.Lock]: + """ + 获取或创建 DEALER socket 及其发送锁。 + + Returns: + Tuple[zmq.Socket, threading.Lock]: 目标地址对应的 socket 和保护 multipart 发送的锁。 + """ - if addr in self.push_sockets: - sock = self.push_sockets[addr] - if not sock.closed: - return sock + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + sock = self.push_sockets[addr] + if not sock.closed: + return sock, self._push_socket_locks[addr] + del self.push_sockets[addr] + self._push_socket_locks.pop(addr, None) try: self.logger.info(f"_get_push_socket: Establishing new connection to {addr}") @@ -129,8 +140,18 @@ def _get_push_socket(self, addr): sock.connect(f"tcp://{addr}") - self.push_sockets[addr] = sock - return sock + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + existing_sock = self.push_sockets[addr] + if not existing_sock.closed: + sock.close() + return existing_sock, self._push_socket_locks[addr] + del self.push_sockets[addr] + self._push_socket_locks.pop(addr, None) + + self.push_sockets[addr] = sock + self._push_socket_locks[addr] = threading.Lock() + return sock, self._push_socket_locks[addr] except zmq.ZMQError as e: self.logger.error(f"_get_push_socket: Connection to {addr} failed: {e}") @@ -144,8 +165,11 @@ def _send_message(self, addr, msg_type: str, payload): message = self._serialize_message(msg_type, payload) try: self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}") - sock = self._get_push_socket(addr) - sock.send_multipart(message) + sock, lock = self._get_push_socket(addr) + with lock: + if sock.closed: + raise ConnectionError(f"Connection to {addr} is closed") + sock.send_multipart(message) self.logger.info(f"Sent {msg_type} to {addr}") @@ -164,9 +188,19 @@ def _close_connection(self, addr): """ Close the connection to the specified address. """ - if addr in self.push_sockets: - self.push_sockets[addr].close() - del self.push_sockets[addr] + sock = None + lock = None + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + sock = self.push_sockets.pop(addr) + lock = self._push_socket_locks.pop(addr, None) + + if sock is not None: + if lock is not None: + with lock: + sock.close() + else: + sock.close() def send_splitwise_tasks(self, tasks: List[Request], current_id): """ diff --git a/tests/splitwise/test_splitwise_connector.py b/tests/splitwise/test_splitwise_connector.py index 6a50b34cfc5..569e138fa06 100644 --- a/tests/splitwise/test_splitwise_connector.py +++ b/tests/splitwise/test_splitwise_connector.py @@ -17,6 +17,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from threading import Lock from typing import Any, Dict, List from unittest.mock import Mock, patch @@ -89,6 +90,10 @@ def _build_connector() -> SplitwiseConnector: connector = SplitwiseConnector(cfg=DummyCfg(), worker_queue=DummyWorkerQueue(), resource_manager=None) if not hasattr(connector, "push_sockets"): connector.push_sockets = {} + if not hasattr(connector, "_push_socket_locks"): + connector._push_socket_locks = {} + if not hasattr(connector, "_push_sockets_meta_lock"): + connector._push_sockets_meta_lock = Lock() return connector @@ -253,9 +258,11 @@ def test_get_push_socket_reuses_existing_and_handles_zmq_error(): open_socket = Mock() open_socket.closed = False connector.push_sockets["127.0.0.1:8000"] = open_socket + connector._push_socket_locks["127.0.0.1:8000"] = Lock() - same_socket = connector._get_push_socket("127.0.0.1:8000") + same_socket, same_lock = connector._get_push_socket("127.0.0.1:8000") assert same_socket is open_socket + assert same_lock is connector._push_socket_locks["127.0.0.1:8000"] connector.zmq_ctx = Mock() connector.zmq_ctx.socket.side_effect = zmq.ZMQError("boom") @@ -270,9 +277,10 @@ def test_get_push_socket_creates_and_configures_socket(): new_socket.closed = False connector.zmq_ctx.socket.return_value = new_socket - socket = connector._get_push_socket("127.0.0.1:7000") + socket, lock = connector._get_push_socket("127.0.0.1:7000") assert socket is new_socket + assert lock is connector._push_socket_locks["127.0.0.1:7000"] new_socket.connect.assert_called_once_with("tcp://127.0.0.1:7000") assert connector.push_sockets["127.0.0.1:7000"] is new_socket @@ -280,7 +288,8 @@ def test_get_push_socket_creates_and_configures_socket(): def test_send_message_serializes_and_sends_payload(): connector = _build_connector() mock_socket = Mock() - connector._get_push_socket = Mock(return_value=mock_socket) + mock_socket.closed = False + connector._get_push_socket = Mock(return_value=(mock_socket, Lock())) request = Request( request_id="req-send", prompt=None, @@ -317,14 +326,17 @@ def test_send_message_handles_missing_addr_and_errors(): connector._send_message("127.0.0.1:7000", "prefill", []) failing_socket = Mock() + failing_socket.closed = False failing_socket.send_multipart.side_effect = zmq.Again() - connector._get_push_socket = Mock(return_value=failing_socket) + connector._get_push_socket = Mock(return_value=(failing_socket, Lock())) connector._send_message("127.0.0.1:7001", "prefill", []) crash_socket = Mock() + crash_socket.closed = False crash_socket.send_multipart.side_effect = RuntimeError("boom") - connector._get_push_socket = Mock(return_value=crash_socket) + connector._get_push_socket = Mock(return_value=(crash_socket, Lock())) connector.push_sockets["127.0.0.1:7002"] = crash_socket + connector._push_socket_locks["127.0.0.1:7002"] = Lock() connector._send_message("127.0.0.1:7002", "prefill", []) assert "127.0.0.1:7002" not in connector.push_sockets From 228987aea6ff7f22dbe481294238b4ca0d001152 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Mon, 11 May 2026 09:40:43 +0800 Subject: [PATCH 091/143] [Cherry-Pick] [BugFix] [RL] Fix cpu cache for rl (#7764) (#7765) * [BugFix] [RL] Fix cpu cache for rl * [fix] fix stable ci * [fix] fix stable ci * [fix] fix stable ci --- .../cache_manager/cache_transfer_manager.py | 21 ++++++++++++------- fastdeploy/worker/gpu_model_runner.py | 16 ++++++++++---- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 50def0a6c36..36306ee5dc6 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -743,23 +743,28 @@ def _init_gpu_cache(self): logger.info("GPU KV cache is initialized") def _clear_gpu_cache(self): + if self.create_cache_tensor: logger.debug("Waiting for gpu runner to unlink cuda ipc") while self.cache_ready_signal.value[self.rank] != 0: time.sleep(0.1) logger.debug("Stop waiting! gpu runner has unlinked cuda ipc") - self.gpu_cache_kvs.clear() - self.gpu_cache_k_tensors.clear() - self.gpu_cache_v_tensors.clear() - if hasattr(self, "gpu_cache_scales_k_tensors"): - self.gpu_cache_scales_k_tensors.clear() - if hasattr(self, "gpu_cache_scales_v_tensors"): - self.gpu_cache_scales_v_tensors.clear() - paddle.device.cuda.empty_cache() else: for name, tensor in self.gpu_cache_kvs.items(): unset_data_ipc(tensor, name, True, False) logger.debug("Successfully unlinked gpu caches cuda ipc") + + self.gpu_cache_kvs.clear() + self.gpu_cache_k_tensors.clear() + self.gpu_cache_v_tensors.clear() + if hasattr(self, "gpu_cache_scales_k_tensors"): + self.gpu_cache_scales_k_tensors.clear() + if hasattr(self, "gpu_cache_scales_v_tensors"): + self.gpu_cache_scales_v_tensors.clear() + paddle.set_flags({"FLAGS_selected_gpus": f"{self.device}"}) + paddle.device.cuda.empty_cache() + + if not self.create_cache_tensor: self.cache_ready_signal.value[self.rank] = 0 while np.sum(self.cache_ready_signal.value) != 0: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8b4d29bc081..0b55aa97c6c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2868,10 +2868,18 @@ def clear_cache(self, profile=False): ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not create_cache_tensor: - for name, tensor in self.cache_kvs_map.items(): - unset_data_ipc(tensor, name, True, False) - self.cache_ready_signal.value[local_rank] = 0 + if not profile: + if create_cache_tensor: + if self.fd_config.cache_config.num_cpu_blocks > 0: + logger.info("Waiting for cache transfer manager to unlink cuda ipc") + while self.cache_ready_signal.value[local_rank] != 0: + time.sleep(0.1) + logger.info("Stop waiting! cache transfer manager has unlinked cuda ipc") + else: + for name, tensor in self.cache_kvs_map.items(): + unset_data_ipc(tensor, name, True, False) + self.cache_ready_signal.value[local_rank] = 0 + self.cache_kvs_map.clear() self.share_inputs.pop("caches", None) if self.forward_meta is not None: From ad431c77b31274841bc89aa101617de5f1e322a1 Mon Sep 17 00:00:00 2001 From: RAM Date: Mon, 11 May 2026 14:27:36 +0800 Subject: [PATCH 092/143] [RL] R3 Support Overlap Schedule (#7674) * Correct the semantics of max_num_batched_tokens with multi mode * fix D2H bug * R3 support Overlap Schedule * rewrite get_position_id kernel * fix bug when slice the pinned memory * fix test case * rename gpu_routing_buffer as device_routing_buffer --- custom_ops/gpu_ops/cpp_extensions.cc | 12 +- custom_ops/gpu_ops/get_position_ids.cu | 67 +++++++++ ...get_position_ids_and_mask_encoder_batch.cu | 79 ---------- custom_ops/setup_ops.py | 6 +- fastdeploy/model_executor/forward_meta.py | 2 +- fastdeploy/model_executor/layers/moe/moe.py | 4 +- .../layers/moe/routing_indices_cache.py | 139 +++++++++--------- .../model_executor/pre_and_post_process.py | 28 +--- fastdeploy/worker/gpu_model_runner.py | 25 ++-- fastdeploy/worker/metax_model_runner.py | 6 +- ...get_position_ids_and_mask_encoder_batch.py | 18 +-- tests/worker/test_gpu_model_runner.py | 1 + 12 files changed, 179 insertions(+), 208 deletions(-) create mode 100644 custom_ops/gpu_ops/get_position_ids.cu delete mode 100644 custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index bc6f7e0783a..15866c57643 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -540,10 +540,10 @@ std::vector count_tokens_per_expert_func( const paddle::Tensor& topk_ids, int64_t num_experts, bool compute_padded_cumsum = false); -void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids); +void GetPositionIds(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids); std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_nope, @@ -1639,9 +1639,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("is_zp_float")); #endif - m.def("get_position_ids_and_mask_encoder_batch", - &GetPositionIdsAndMaskEncoderBatch, - "get_position_ids_and_mask_encoder_batch function"); + m.def("get_position_ids", &GetPositionIds, "get_position_ids function"); /** * cutlass_scaled_mm.cu diff --git a/custom_ops/gpu_ops/get_position_ids.cu b/custom_ops/gpu_ops/get_position_ids.cu new file mode 100644 index 00000000000..3c04332934f --- /dev/null +++ b/custom_ops/gpu_ops/get_position_ids.cu @@ -0,0 +1,67 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void GetPositionIdsKernel(const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_this_time, + int* __restrict__ position_ids, + const int bsz) { + int current_bid = threadIdx.x; + if (current_bid >= bsz) return; + + // Caculate the offset of current batch in the position_ids buffer + int buffer_offset = 0; + for (int i = 0; i < current_bid; i++) { + buffer_offset += seq_lens_this_time[i]; + } + + // Caculate the token offset in the current batch + int token_offset = seq_lens_decoder[current_bid]; + int token_num_this_batch = seq_lens_this_time[current_bid]; + if (token_num_this_batch == 0) return; + +// Write position ids for current batch +#pragma unroll + for (int i = 0; i < token_num_this_batch; i++) { + position_ids[buffer_offset + i] = token_offset + i; + } +} + +void GetPositionIds(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids) { + const int bsz = seq_lens_this_time.shape()[0]; + + GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>( + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + const_cast(position_ids.data()), + bsz); +} + +PD_BUILD_STATIC_OP(get_position_ids) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "position_ids", + }) + .Outputs({"position_ids_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}}) + .SetKernelFn(PD_KERNEL(GetPositionIds)); diff --git a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu deleted file mode 100644 index 63bc77c9afc..00000000000 --- a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "helper.h" -#include "paddle/extension.h" - -__global__ void GetPositionIdsAndMaskEncoderBatchKernel( - const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度 - const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度 - const int* seq_lens_this_time, - int* position_ids, // 输出的一维 position_ids - const int bsz) { // 批次大小 - // 当前线程索引(每个线程对应一个批次) - int tid = threadIdx.x; - if (tid >= bsz) return; - - // 动态计算当前批次的偏移量 - int offset = 0; - for (int i = 0; i < tid; i++) { - offset += seq_lens_encoder[i]; - if (seq_lens_decoder[i] > 0) { - offset += seq_lens_this_time[i]; - } - } - - // 当前批次的 encoder 和 decoder 长度 - int encoder_len = seq_lens_encoder[tid]; - int decoder_len = seq_lens_decoder[tid]; - int seq_len_this_time = seq_lens_this_time[tid]; - - // 写入 encoder 的 position_ids - for (int i = 0; i < encoder_len; i++) { - position_ids[offset + i] = i; - } - offset += encoder_len; - - // 写入 decoder 的 position_ids - if (decoder_len > 0) { - for (int i = 0; i < seq_len_this_time; i++) { - position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身 - } - } -} - -void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids) { - const int bsz = seq_lens_this_time.shape()[0]; - - GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>( - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - const_cast(position_ids.data()), - bsz); -} - -PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch) - .Inputs({ - "seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", - "position_ids", - }) - .Outputs({"position_ids_out"}) - .SetInplaceMap({{"position_ids", "position_ids_out"}}) - .SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index b9a2fe90dbc..268cde02825 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -250,7 +250,7 @@ def find_end_files(directory, end_str): "gpu_ops/speculate_decoding/speculate_step.cu", "gpu_ops/speculate_decoding/speculate_step_system_cache.cu", "gpu_ops/speculate_decoding/speculate_update_v3.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/step_reschedule.cu", ] @@ -326,7 +326,7 @@ def find_end_files(directory, end_str): "gpu_ops/sample_kernels/rejection_top_p_sampling.cu", "gpu_ops/sample_kernels/top_k_renorm_probs.cu", "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", @@ -687,7 +687,7 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", "gpu_ops/limit_thinking_content_length.cu", "gpu_ops/update_attn_mask_offsets.cu", "gpu_ops/append_attn/mla_cache_kernel.cu", diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 03a2734b41d..effb8108422 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -147,7 +147,7 @@ class ForwardMeta: # Flag of profile run is_dummy_or_profile_run: bool = False # GPU transient routing buffer [max_num_batched_tokens, num_moe_layers, top_k] - gpu_routing_buffer: Optional[paddle.Tensor] = None + device_routing_buffer: Optional[paddle.Tensor] = None # chunked MoE related moe_num_chunk: int = 1 diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index cc427a3fe54..9cb9340cf01 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -724,11 +724,11 @@ def forward( topk_ids_hookfunc = None if self.enable_routing_replay: # When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None. - if forward_meta is not None and forward_meta.gpu_routing_buffer is not None: + if forward_meta is not None and forward_meta.device_routing_buffer is not None: moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index topk_ids_hookfunc = partial( save_routing_to_buffer_v2, - gpu_routing_buffer=forward_meta.gpu_routing_buffer, + device_routing_buffer=forward_meta.device_routing_buffer, layer_idx=moe_layer_idx, tp_size=self.fd_config.parallel_config.tensor_parallel_size, ep_size=self.fd_config.parallel_config.expert_parallel_size, diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 21e9f406366..57765f65255 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -31,7 +31,7 @@ @enable_compat_on_triton_kernel @triton.jit def _save_routing_kernel_v2( - GPU_ROUTING_BUFFER_PTR, + device_routing_buffer_PTR, TOPK_IDS_PTR, LAYER_IDX, TOKEN_NUM, @@ -55,13 +55,16 @@ def _save_routing_kernel_v2( STRIDE_TOKEN = NUM_MOE_LAYERS * TOP_K STRIDE_LAYER = TOP_K output_ptrs = ( - GPU_ROUTING_BUFFER_PTR + token_offsets[:, None] * STRIDE_TOKEN + LAYER_IDX * STRIDE_LAYER + k_offsets[None, :] + device_routing_buffer_PTR + + token_offsets[:, None] * STRIDE_TOKEN + + LAYER_IDX * STRIDE_LAYER + + k_offsets[None, :] ) tl.store(output_ptrs, topk_vals, mask=load_mask) def save_routing_to_buffer_v2( - gpu_routing_buffer: paddle.Tensor, + device_routing_buffer: paddle.Tensor, topk_ids: paddle.Tensor, layer_idx: int, tp_size: int, @@ -81,19 +84,21 @@ def save_routing_to_buffer_v2( topk_ids = topk_ids_all[:total_token_num, :] token_num, top_k = topk_ids.shape - buf_max_tokens, num_moe_layers, buf_top_k = gpu_routing_buffer.shape + buf_max_tokens, num_moe_layers, buf_top_k = device_routing_buffer.shape assert ( token_num <= buf_max_tokens - ), f"[R3] token_num={token_num} exceeds gpu_routing_buffer capacity={buf_max_tokens}" - assert top_k == buf_top_k, f"[R3] top_k mismatch: topk_ids.top_k={top_k} vs gpu_routing_buffer.top_k={buf_top_k}" + ), f"[R3] token_num={token_num} exceeds device_routing_buffer capacity={buf_max_tokens}" + assert ( + top_k == buf_top_k + ), f"[R3] top_k mismatch: topk_ids.top_k={top_k} vs device_routing_buffer.top_k={buf_top_k}" assert 0 <= layer_idx < num_moe_layers, f"[R3] layer_idx={layer_idx} out of range [0, {num_moe_layers})" BLOCK_SIZE_M = 128 BLOCK_SIZE_K = triton.next_power_of_2(top_k) grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) _save_routing_kernel_v2[grid]( - gpu_routing_buffer, + device_routing_buffer, topk_ids, LAYER_IDX=layer_idx, TOKEN_NUM=token_num, @@ -110,9 +115,8 @@ class RoutedExpertsCapturer: Does NOT manage request lifecycle — that is handled by RoutingCacheManager on the Engine side. """ - def __init__(self, fd_config: FDConfig, block_table, total_block_num): + def __init__(self, fd_config: FDConfig, total_block_num: int): self.fd_config = fd_config - self.block_table = block_table self.max_num_seqs = fd_config.scheduler_config.max_num_seqs # Read routing params from centralized config @@ -125,20 +129,23 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num): logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}") self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num) - self.pending_update_positions = None def _init_routing_cache(self, dtype: str, total_block_num: int): - """Initialize GPU transient buffer and prepare lazy SharedMemory attach.""" + """Initialize GPU transient buffer, staging buffers, and CPU pinned buffers.""" max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size # Small GPU transient buffer: only current step's token routing # TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens() max_num_batched_tokens = self.fd_config.get_max_chunk_tokens() - self.gpu_routing_buffer = paddle.full( - shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k], - fill_value=-1, - dtype=dtype, - ) + shape = [max_num_batched_tokens, self.num_moe_layers, self.moe_top_k] + + self.device_routing_buffer = paddle.full(shape=shape, fill_value=-1, dtype=dtype) + self.routing_staging_buf = paddle.full(shape=shape, fill_value=-1, dtype=dtype) + self.slot_mapping_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64) + + self.cpu_routing_buf = paddle.zeros(shape, dtype=dtype).pin_memory() + self.cpu_slot_mapping_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory() + self._pending_save = None # {"num_tokens": int} # Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling) self.routing_host_view = None @@ -149,9 +156,9 @@ def _init_routing_cache(self, dtype: str, total_block_num: int): self._routing_host_view_shape = (max_num_kv_tokens, self.num_moe_layers, self.moe_top_k) self._routing_host_view_dtype = dtype - gpu_buffer_bytes = int(np.prod(self.gpu_routing_buffer.shape)) * np.dtype(dtype).itemsize + gpu_buffer_bytes = int(np.prod(self.device_routing_buffer.shape)) * np.dtype(dtype).itemsize logger.info( - f"[R3] GPU transient routing buffer: {self.gpu_routing_buffer.shape} " + f"[R3] GPU transient routing buffer: {self.device_routing_buffer.shape} " f"({gpu_buffer_bytes / 1024:.1f} KB)" ) @@ -173,67 +180,59 @@ def _try_attach_routing_host_view(self): "Routing capture will be skipped." ) - def save_captured_routing(self, num_tokens: int, slot_mapping: np.ndarray): + def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor): + """ + Enqueue D2D + async D2H for routing data and slot_mapping. + Must be called before post_process_event.record(). + All ops are enqueued on the current CUDA stream; CPU returns immediately. + + 1. D2D (non-blocking): device_routing_buffer → routing_staging_buf + 2. D2D (non-blocking): slot_mapping_gpu → slot_mapping_staging_buf + 3. async D2H: routing_staging_buf → cpu_routing_buf + 4. async D2H: slot_mapping_staging_buf → cpu_slot_mapping_buf + """ + if num_tokens > 0: + # D2D: GPU → staging + self.routing_staging_buf.copy_(self.device_routing_buffer, False) + self.slot_mapping_staging_buf.copy_(slot_mapping_gpu, False) + # Async D2H: staging → CPU pinned + self.cpu_routing_buf.copy_(self.routing_staging_buf, False) + self.cpu_slot_mapping_buf.copy_(self.slot_mapping_staging_buf, False) + self._pending_save = {"num_tokens": num_tokens} + else: + self._pending_save = None + + def flush_pending_save(self): """ - After forward, scatter GPU buffer routing data to routing_host_buffer. - Called in step gap (post_process), not during forward. CUDAGraph compatible. + Pure CPU operation. Called after post_process_event.synchronize(), + which guarantees all D2D and D2H transfers have completed. + Scatter from CPU pinned buffers to SharedMemory. """ - assert slot_mapping.shape[0] == num_tokens - if num_tokens == 0: + pending = self._pending_save + if pending is None: return - - # Lazy attach to SharedMemory (Engine creates it after profiling completes) - if self.routing_host_view is None and not self._routing_host_view_attach_attempted: - self._try_attach_routing_host_view() + self._pending_save = None if self.routing_host_view is None: - return + if not self._routing_host_view_attach_attempted: + self._try_attach_routing_host_view() + if self.routing_host_view is None: + return - # D2H copy: GPU → CPU numpy, then scatter to SharedMemory - data = self.gpu_routing_buffer[:num_tokens].cpu().numpy() - self.routing_host_view.scatter(slot_mapping, data) + num_tokens = pending["num_tokens"] + # NOTE(gongshaotian): Slice pinned memory tensor maybe cause problem. + data = self.cpu_routing_buf.cpu()[:num_tokens].numpy() + slot_np = self.cpu_slot_mapping_buf.cpu()[:num_tokens].numpy() - def compute_slot_mapping_flat(self, positions) -> np.ndarray: - """ - Compute flat slot_mapping for all tokens in the step. - Returns a 1D numpy array of slot indices. - """ - all_slots = [] - block_size = self.fd_config.cache_config.block_size - for batch_id, position in enumerate(positions): - if len(position) == 0: - continue - block_table_indices = position // block_size - token_block_ids = self.block_table[batch_id, block_table_indices] - block_offset = position % block_size - token_cache_ids = np.array(token_block_ids) * block_size + block_offset - all_slots.append(token_cache_ids) - if all_slots: - return np.concatenate(all_slots) - return np.array([], dtype=np.int64) - - def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): - """Get token position of each sequence in a batch.""" - starts = seq_lens_decoder.numpy() - increase_num = seq_lens_this_time.numpy() - - positions = [] - for i in range(seq_lens_this_time.shape[0]): - if increase_num[i] == 0: - positions.append([]) - continue - repeated_base = np.repeat(starts[i], increase_num[i]) - positions.append(repeated_base + np.arange(0, increase_num[i])) - - return positions - - def get_gpu_routing_buffer(self) -> paddle.Tensor: - return self.gpu_routing_buffer + self.routing_host_view.scatter(slot_np, data) + + def get_device_routing_buffer(self) -> paddle.Tensor: + return self.device_routing_buffer def clear(self): - """Clear GPU buffer and pending positions. Used during RL round cleanup.""" - self.gpu_routing_buffer.fill_(-1) - self.pending_update_positions = None + """Clear GPU buffer and pending save state. Used during RL round cleanup.""" + self.device_routing_buffer.fill_(-1) + self._pending_save = None # Backward compatibility alias diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 4c7a09b93c0..94b917f3f9f 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -339,18 +339,10 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: - # Trigger lazy SharedMemory attach if not yet attempted - routing_replay_manager._try_attach_routing_host_view() - # GPU transient buffer → SharedMemory routing_host_buffer - slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( - positions=routing_replay_manager.pending_update_positions - ) - num_tokens = len(slot_mapping_flat) + slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) if routing_replay_manager.tp_rank == 0: - routing_replay_manager.save_captured_routing( - num_tokens=num_tokens, - slot_mapping=slot_mapping_flat, - ) + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu) # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): @@ -521,18 +513,10 @@ def post_process_speculate( # Routing replay if routing_replay_manager is not None: - # Trigger lazy SharedMemory attach if not yet attempted - routing_replay_manager._try_attach_routing_host_view() - # GPU transient buffer → SharedMemory routing_host_buffer - slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( - positions=routing_replay_manager.pending_update_positions - ) - num_tokens = len(slot_mapping_flat) + slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) if routing_replay_manager.tp_rank == 0: - routing_replay_manager.save_captured_routing( - num_tokens=num_tokens, - slot_mapping=slot_mapping_flat, - ) + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu) # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx # into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 0b55aa97c6c..47f700e83e5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -84,7 +84,7 @@ speculate_schedule_cache, set_data_ipc, unset_data_ipc, - get_position_ids_and_mask_encoder_batch, + get_position_ids, ) import zmq @@ -1290,11 +1290,14 @@ def _compute_position_ids_and_slot_mapping(self) -> None: Results are stored in self.forward_meta. """ # NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently. - if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)): + # Also needed when R3 (Routing Replay) is enabled for slot_mapping_buffer computation. + needs_slot_mapping = isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)) + needs_slot_mapping = (self.routing_replay_manager is not None) or needs_slot_mapping + if not needs_slot_mapping: return current_total_tokens = self.forward_meta.ids_remove_padding.shape[0] position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens] - get_position_ids_and_mask_encoder_batch( + get_position_ids( self.forward_meta.seq_lens_encoder, self.forward_meta.seq_lens_decoder, self.forward_meta.seq_lens_this_time, @@ -1354,9 +1357,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): """ # Initialize forward meta num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0] - gpu_routing_buffer = None + device_routing_buffer = None if self.routing_replay_manager is not None: - gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() + device_routing_buffer = self.routing_replay_manager.get_device_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1383,7 +1386,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - gpu_routing_buffer=gpu_routing_buffer, + device_routing_buffer=device_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -2247,6 +2250,8 @@ def execute_model_normal( if model_output_data is not None: # synchronizes the async DtoH copies of sampled_token_ids. post_process_event.synchronize() + if self.routing_replay_manager is not None: + self.routing_replay_manager.flush_pending_save() self._save_model_output(model_output_data, sampler_output) def execute_model_overlap( @@ -2263,6 +2268,8 @@ def execute_model_overlap( if self._cached_model_output_data is not None: # synchronizes the async DtoH copies of sampled_token_ids. self._cached_post_process_event.synchronize() + if self.routing_replay_manager is not None: + self.routing_replay_manager.flush_pending_save() self._save_model_output( self._cached_model_output_data, self._cached_sampler_output, @@ -2331,11 +2338,6 @@ def _preprocess( p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) self.sampler.pre_process(p_done_idxs) - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions( - seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"], - ) # Update state of logits processor for proc in self.sampling_metadata.logits_processors: @@ -3330,6 +3332,5 @@ def initialize_routing_replay_manager(self): # Use updated block number self.routing_replay_manager = RoutingReplayManager( fd_config=self.fd_config, - block_table=self.share_inputs["block_tables"], total_block_num=self.num_gpu_blocks, ) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 7e721107f1e..9b9bbe2bb76 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1225,9 +1225,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - gpu_routing_buffer = None + device_routing_buffer = None if self.routing_replay_manager is not None: - gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() + device_routing_buffer = self.routing_replay_manager.get_device_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1255,7 +1255,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], routing_replay_table=None, - gpu_routing_buffer=gpu_routing_buffer, + device_routing_buffer=device_routing_buffer, ) dist_status = self.collect_distributed_status() diff --git a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py index 2d1dd8e2f7c..54b34850780 100644 --- a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py +++ b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py @@ -17,27 +17,27 @@ import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import get_position_ids_and_mask_encoder_batch +from fastdeploy.model_executor.ops.gpu import get_position_ids -class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase): +class TestGetPositionIds(unittest.TestCase): def setUp(self): np.random.seed(42) paddle.set_device("gpu") def test_basic_functionality(self): # Test normal case with batch size 2 - seq_lens_encoder = paddle.to_tensor([3, 2], dtype="int32") - seq_lens_decoder = paddle.to_tensor([1, 2], dtype="int32") + seq_lens_encoder = paddle.to_tensor([1, 2], dtype="int32") + seq_lens_decoder = paddle.to_tensor([3, 2], dtype="int32") seq_lens_this_time = paddle.to_tensor([1, 2], dtype="int32") - total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum()) + total_len = int(seq_lens_this_time.numpy().sum()) position_ids = paddle.zeros([total_len], dtype="int32") # Call the custom operator - get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) - expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32) + expected_position_ids = np.array([3, 2, 3], dtype=np.int32) # Convert to numpy for comparison position_ids_np = position_ids.numpy() @@ -49,11 +49,11 @@ def test_empty_decoder(self): # Test case where decoder length is 0 seq_lens_encoder = paddle.to_tensor([2], dtype="int32") seq_lens_decoder = paddle.to_tensor([0], dtype="int32") - seq_lens_this_time = paddle.to_tensor([0], dtype="int32") + seq_lens_this_time = paddle.to_tensor([2], dtype="int32") position_ids = paddle.zeros([2], dtype="int32") - get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) expected_position_ids = np.array([0, 1], dtype=np.int32) diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index 8400fb79271..f0f44ea68c6 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -689,6 +689,7 @@ def _make_runner(self): runner._execute_empty_mtp_input = Mock() runner._cached_launch_token_num = 0 runner._cached_real_bsz = 0 + runner.routing_replay_manager = Mock() class _ShareInputs(dict): pass From 53af5cca43d6c33fde59e7e6c18c7e47e2ab9917 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Mon, 11 May 2026 14:50:06 +0800 Subject: [PATCH 093/143] [Cherry-Pick][CI] Remove checklist validation from CheckPRTemplate.py (#7760) (#7769) --- scripts/CheckPRTemplate.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/scripts/CheckPRTemplate.py b/scripts/CheckPRTemplate.py index c51d64b9bb6..5cb4536ab88 100644 --- a/scripts/CheckPRTemplate.py +++ b/scripts/CheckPRTemplate.py @@ -28,7 +28,6 @@ "## Modifications", "## Usage or Command", "## Accuracy Tests", - "## Checklist", ] } } @@ -65,27 +64,6 @@ def check_section_content(body, section_titles): return results -def parse_checklist(section_content): - """ - Parse a checklist section and return dict of items with checked status. - Example return: - { - 'Add at least a tag in the PR title.': False, - 'Format your code, run `pre-commit` before commit.': True, - ... - } - """ - items = {} - lines = section_content.splitlines() - for line in lines: - match = re.match(r"- \[( |x|X)\] (.+)", line) - if match: - checked = match.group(1).lower() == "x" - item_text = match.group(2).strip() - items[item_text] = checked - return items - - def check_pr_template(repo, body): """Check whether a PR description follows the expected template.""" body = remove_comments(body) @@ -108,21 +86,11 @@ def check_pr_template(repo, body): else: messages.append("❌ Missing sections: {}. Please complete them.".format(", ".join(missing))) - # Check Checklist items if present - checklist_content = results.get("## Checklist", "") - if checklist_content: - checklist_items = parse_checklist(checklist_content) - unchecked = [item for item, checked in checklist_items.items() if not checked] - if unchecked: - messages.append("❌ The following checklist items are not completed:") - for item in unchecked: - messages.append(f" - [ ] {item}") - if messages: messages.append( "\n💡 **Tips for fixing:**\n" "1. Each PR must follow the standard FastDeploy PR template.\n" - "2. Ensure every section (Motivation, Modifications, Usage, Accuracy Tests, Checklist) " + "2. Ensure every section (Motivation, Modifications, Usage, Accuracy Tests) " "is clearly filled with relevant details.\n" "3. You can refer to the official PR example: " "https://github.com/PaddlePaddle/FastDeploy/blob/develop/.github/pull_request_template.md\n" From f8a0cf2a4ddf1f3c0e811fc132e1151c0520aac7 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Mon, 11 May 2026 17:42:08 +0800 Subject: [PATCH 094/143] [BugFix][KSM] Fix sampling_mask reordering in recover_batch_index_for_sampler_output (#7773) --- fastdeploy/worker/input_batch.py | 10 ++ .../test_recover_batch_index_sampling_mask.py | 113 ++++++++++++++++++ 2 files changed, 123 insertions(+) create mode 100644 tests/worker/test_recover_batch_index_sampling_mask.py diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 225eac1c05b..ce679b90939 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -1208,3 +1208,13 @@ def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, en logits = sampler_output.logits real_logits = _recover_tensor(logits, src_order) sampler_output.logits = real_logits + + if sampler_output.sampling_mask is not None: + sampling_mask = sampler_output.sampling_mask + sort_len = len(src_order) + real_sampling_mask = [None] * len(sampling_mask) + for i in range(sort_len): + real_sampling_mask[i] = sampling_mask[src_order[i]] + for i in range(sort_len, len(sampling_mask)): + real_sampling_mask[i] = sampling_mask[i] + sampler_output.sampling_mask = real_sampling_mask diff --git a/tests/worker/test_recover_batch_index_sampling_mask.py b/tests/worker/test_recover_batch_index_sampling_mask.py new file mode 100644 index 00000000000..6119faa0685 --- /dev/null +++ b/tests/worker/test_recover_batch_index_sampling_mask.py @@ -0,0 +1,113 @@ +from unittest.mock import Mock + +import numpy as np +import paddle +import pytest + +from fastdeploy.worker.input_batch import recover_batch_index_for_sampler_output + + +def _make_sampler_output(batch_size, with_sampling_mask=True): + """Create a minimal mock SamplerOutput for testing reorder logic.""" + so = Mock() + so.sampled_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1) + so.logprobs_tensors = Mock() + so.logprobs_tensors.logprob_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1) + so.logprobs_tensors.logprobs = paddle.arange(batch_size, dtype="float32").unsqueeze(1) + so.logprobs_tensors.selected_token_ranks = paddle.zeros([batch_size, 1], dtype="int64") + so.token_num_per_batch = None + so.cu_batch_token_offset = None + so.logits = None + + if with_sampling_mask: + so.sampling_mask = [np.array([i * 10, i * 10 + 1, i * 10 + 2]) for i in range(batch_size)] + else: + so.sampling_mask = None + + return so + + +class TestRecoverBatchIndexSamplingMask: + """Test sampling_mask reordering in recover_batch_index_for_sampler_output.""" + + def test_no_sampling_mask_no_error(self): + """SamplerOutput without sampling_mask should not raise.""" + so = _make_sampler_output(batch_size=4, with_sampling_mask=False) + index_to_batch_id = {0: 2, 1: 0, 2: 3, 3: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + assert so.sampling_mask is None + + def test_sampling_mask_reorder_matches_token_ids(self): + """After reorder, sampling_mask[i] should correspond to sampled_token_ids[i].""" + batch_size = 4 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + + original_masks = [m.copy() for m in so.sampling_mask] + + # index_to_batch_id = {0:2, 1:0, 2:3, 3:1} + # src_order = [k for k,v in sorted(..., key=v)] = [1, 3, 0, 2] + # result[i] = src[src_order[i]] + index_to_batch_id = {0: 2, 1: 0, 2: 3, 3: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + reordered_token_ids = so.sampled_token_ids.numpy().flatten() + for i in range(batch_size): + token_id = int(reordered_token_ids[i]) + expected_mask = original_masks[token_id] + np.testing.assert_array_equal( + so.sampling_mask[i], + expected_mask, + err_msg=f"Position {i}: sampling_mask doesn't match sampled_token_ids", + ) + + def test_identity_reorder_is_noop(self): + """When index_to_batch_id is identity, function returns early without changes.""" + batch_size = 3 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + + index_to_batch_id = {0: 0, 1: 1, 2: 2} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + for i in range(batch_size): + np.testing.assert_array_equal(so.sampling_mask[i], original_masks[i]) + + def test_pd_reorder_disabled_is_noop(self): + """When enable_pd_reorder=False, nothing is reordered.""" + batch_size = 3 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + original_token_ids = so.sampled_token_ids.clone() + + index_to_batch_id = {0: 2, 1: 0, 2: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=False) + + assert paddle.equal_all(so.sampled_token_ids, original_token_ids) + for i in range(batch_size): + np.testing.assert_array_equal(so.sampling_mask[i], original_masks[i]) + + def test_sampling_mask_longer_than_sort_len(self): + """Tail elements beyond sort_len are preserved in place.""" + so = _make_sampler_output(batch_size=5, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + + # Only reorder first 3 positions; positions 3,4 should stay put + index_to_batch_id = {0: 1, 1: 2, 2: 0} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + # src_order = [2, 0, 1] + np.testing.assert_array_equal(so.sampling_mask[0], original_masks[2]) + np.testing.assert_array_equal(so.sampling_mask[1], original_masks[0]) + np.testing.assert_array_equal(so.sampling_mask[2], original_masks[1]) + np.testing.assert_array_equal(so.sampling_mask[3], original_masks[3]) + np.testing.assert_array_equal(so.sampling_mask[4], original_masks[4]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 7901aebc450f051a7b7c68d137d60b8a087de4f5 Mon Sep 17 00:00:00 2001 From: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> Date: Mon, 11 May 2026 20:35:51 +0800 Subject: [PATCH 095/143] [XPU][CI] fix XPU CI bug (#7778) * Update _build_xpu.yml * Update _xpu_4cards_case_test.yml * Update _xpu_8cards_case_test.yml --- .github/workflows/_build_xpu.yml | 2 +- .github/workflows/_xpu_4cards_case_test.yml | 2 +- .github/workflows/_xpu_8cards_case_test.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_build_xpu.yml b/.github/workflows/_build_xpu.yml index b9bab8381d0..1222f040812 100644 --- a/.github/workflows/_build_xpu.yml +++ b/.github/workflows/_build_xpu.yml @@ -159,7 +159,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi diff --git a/.github/workflows/_xpu_4cards_case_test.yml b/.github/workflows/_xpu_4cards_case_test.yml index f3c97f40dc6..a60f9f3aa33 100644 --- a/.github/workflows/_xpu_4cards_case_test.yml +++ b/.github/workflows/_xpu_4cards_case_test.yml @@ -178,7 +178,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi echo "安装上游任务编译的fastdeploy-xpu..." python -m pip install ${FASTDEPLOY_WHEEL_URL} diff --git a/.github/workflows/_xpu_8cards_case_test.yml b/.github/workflows/_xpu_8cards_case_test.yml index c9ed0fa2314..de746b05050 100644 --- a/.github/workflows/_xpu_8cards_case_test.yml +++ b/.github/workflows/_xpu_8cards_case_test.yml @@ -167,7 +167,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi echo "安装上游任务编译的fastdeploy-xpu..." python -m pip install ${FASTDEPLOY_WHEEL_URL} From a5191f27230912555943a6d4f7247c80eb04a6c7 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Mon, 11 May 2026 22:17:58 +0800 Subject: [PATCH 096/143] [Cherry-Pick][Cleanup] Replace torch proxy alias with public compat API (#7348) (#7780) Co-authored-by: Codex --- .../layers/attention/flash_attn_backend.py | 2 +- .../layers/attention/mla_attention_backend.py | 2 +- .../batch_invariant_ops/batch_invariant_ops.py | 6 +++--- fastdeploy/model_executor/layers/moe/ep.py | 4 ++-- .../layers/moe/flashinfer_cutedsl_moe.py | 2 +- .../model_executor/layers/quantization/fp8_utils.py | 2 +- .../model_executor/layers/quantization/mxfp4.py | 2 +- .../model_executor/layers/quantization/nvfp4.py | 2 +- fastdeploy/worker/worker_process.py | 2 +- tests/cache_manager/test_cache_messager.py | 4 ++-- tests/cache_manager/test_cache_transfer_manager.py | 12 +++--------- tests/cache_manager/test_prefix_cache_manager.py | 4 ++-- tests/engine/test_common_engine.py | 10 ++-------- tests/inter_communicator/test_e2w_queue.py | 5 ++--- tests/inter_communicator/test_zmq_server.py | 5 ++--- tests/layers/test_fused_moe_cutlass_backend.py | 4 ++-- tests/layers/test_fused_moe_triton_backend.py | 4 ++-- tests/layers/test_sampler.py | 4 ++-- tests/quantization/test_modelopt_nvfp4.py | 5 ++++- tests/splitwise/test_splitwise_connector.py | 9 ++------- tests/v1/test_resource_manager_v1.py | 4 ++-- 21 files changed, 39 insertions(+), 55 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 2549f9f5d87..b203fdbb221 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -84,7 +84,7 @@ def init_flash_attn_version(): sm_version = get_sm_version() if sm_version >= 100: try: - paddle.compat.enable_torch_proxy(scope={"cutlass"}) + paddle.enable_compat(scope={"cutlass"}) from flash_mask.cute.interface import flashmask_attention as fa4 global flashmask_attention_v4 diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 209817e69a2..932c13decc3 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -18,7 +18,7 @@ import paddle -paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla +paddle.enable_compat(scope={"flash_mla"}) # Enable paddle.enable_compat before importing flash_mla import math import os from dataclasses import dataclass, field diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py index c0df764c07c..21f2c2dedc4 100644 --- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py +++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py @@ -805,9 +805,9 @@ def enable_batch_invariant_mode(): if _batch_invariant_MODE: return - if hasattr(paddle, "compat") and hasattr(paddle.compat, "enable_torch_proxy"): - paddle.compat.enable_torch_proxy() - # TODO(liujundong): Enabling torch proxy here has a global effect. + if hasattr(paddle, "enable_compat"): + paddle.enable_compat() + # TODO(liujundong): Enabling paddle.enable_compat() here has a global effect. # Do NOT call this function from module import time, # otherwise it may affect other test cases during pytest collection. # (ex: Could not import module 'PretrainedTokenizer' or No module named 'paddle.distributed.tensor') diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 1b1df3748ad..1ddb6994878 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -40,8 +40,8 @@ def load_deep_ep() -> ModuleType: try: if envs.FD_USE_PFCC_DEEP_EP: - # Enable torch proxy before importing deep_ep (required by PFCC/PaddleFleet variants) - paddle.compat.enable_torch_proxy(scope={"deep_ep"}) + # Enable paddle.enable_compat before importing deep_ep (required by PFCC/PaddleFleet variants) + paddle.enable_compat(scope={"deep_ep"}) try: import paddlefleet.ops.deep_ep as deep_ep # type: ignore diff --git a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py index b449c246cc0..26c729cbf34 100644 --- a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py +++ b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py @@ -18,7 +18,7 @@ import paddle -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) def _dtype_str(dtype) -> str: diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index a5cd230f601..d7ad693c1cd 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -67,7 +67,7 @@ def load_deep_gemm(): if current_platform.is_cuda(): if get_sm_version() >= 100: # SM100 should use PFCC DeepGemm - paddle.compat.enable_torch_proxy(scope={"deep_gemm"}) + paddle.enable_compat(scope={"deep_gemm"}) try: import logging diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py index 9fa02866210..8732b1cc4d1 100644 --- a/fastdeploy/model_executor/layers/quantization/mxfp4.py +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -35,7 +35,7 @@ from ..moe import FusedMoE from .quant_base import QuantConfigBase, QuantMethodBase -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) logger = get_logger("config", "config.log") diff --git a/fastdeploy/model_executor/layers/quantization/nvfp4.py b/fastdeploy/model_executor/layers/quantization/nvfp4.py index 196f6af6755..66e15750b75 100644 --- a/fastdeploy/model_executor/layers/quantization/nvfp4.py +++ b/fastdeploy/model_executor/layers/quantization/nvfp4.py @@ -33,7 +33,7 @@ from .quant_base import QuantConfigBase, QuantMethodBase -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) from fastdeploy.platforms import current_platform diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f11837f4f52..3b0b8e1fe80 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1317,7 +1317,7 @@ def run_worker_proc() -> None: # Enable batch-invariant mode for deterministic inference. # This must happen AFTER worker creation but BEFORE model loading, - # because enable_batch_invariant_mode() calls paddle.compat.enable_torch_proxy() + # because enable_batch_invariant_mode() calls paddle.enable_compat() # which makes torch appear available via proxy. If called before worker creation, # the gpu_model_runner import chain (ernie4_5_vl_processor → paddleformers → # transformers) will fail when transformers tries to query torch metadata. diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index d053653e658..c69d27a24fa 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -19,8 +19,8 @@ import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None from fastdeploy.cache_manager import cache_messager diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index afcd64574dd..599e0b8c5e0 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -21,15 +21,9 @@ import paddle -# Ensure paddle exposes compat.enable_torch_proxy for fastdeploy import compatibility. -if not hasattr(paddle, "compat"): - - class _DummyCompat: - @staticmethod - def enable_torch_proxy(scope=None): - return None - - paddle.compat = _DummyCompat() +# Ensure paddle exposes enable_compat for fastdeploy import compatibility. +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None # Add the root directory to Python path so we can import fastdeploy sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 07df533f626..f2a4a5fa116 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -30,8 +30,8 @@ "ignore:ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead:DeprecationWarning" ) -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **_: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **_: None warnings.filterwarnings( "ignore", diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 53bd8462d81..84dbe2ce3c5 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -30,14 +30,8 @@ import paddle from e2e.utils.serving_utils import clean_ports -if not hasattr(paddle, "compat"): - - class _PaddleCompat: - @staticmethod - def enable_torch_proxy(scope=None): - return None - - paddle.compat = _PaddleCompat() +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.engine.args_utils import EngineArgs diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index 97a17346c91..d3cd657f01a 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ b/tests/inter_communicator/test_e2w_queue.py @@ -16,14 +16,13 @@ import threading import time -import types import unittest import numpy as np import paddle -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **_: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **_: None from fastdeploy import envs from fastdeploy.engine.request import Request diff --git a/tests/inter_communicator/test_zmq_server.py b/tests/inter_communicator/test_zmq_server.py index 57c9a0c479a..17925f219f1 100644 --- a/tests/inter_communicator/test_zmq_server.py +++ b/tests/inter_communicator/test_zmq_server.py @@ -6,7 +6,6 @@ import tempfile import threading import time -import types import unittest from collections import defaultdict from unittest import mock @@ -16,8 +15,8 @@ import zmq from zmq.utils import jsonapi -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **kwargs: None from fastdeploy import envs from fastdeploy.inter_communicator.zmq_server import ( diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 0a03ecc62a8..98185a04c38 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -23,8 +23,8 @@ import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None iluvatar_stub = types.ModuleType("fastdeploy.model_executor.ops.iluvatar") iluvatar_stub.moe_expert_ffn = lambda *args, **kwargs: None diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 1140cf72b16..7dacbbe390d 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -23,8 +23,8 @@ import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda scope=None: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None if not hasattr(paddle.nn.functional, "swiglu"): paddle.nn.functional.swiglu = lambda x: x diff --git a/tests/layers/test_sampler.py b/tests/layers/test_sampler.py index 9fde61f48cc..72475d4ba49 100644 --- a/tests/layers/test_sampler.py +++ b/tests/layers/test_sampler.py @@ -26,8 +26,8 @@ import fastdeploy # noqa: F401 -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None # Optional runtime deps are intentionally stubbed for unit isolation. if "triton" not in sys.modules: diff --git a/tests/quantization/test_modelopt_nvfp4.py b/tests/quantization/test_modelopt_nvfp4.py index 3bf4653c725..27b5ac1309a 100644 --- a/tests/quantization/test_modelopt_nvfp4.py +++ b/tests/quantization/test_modelopt_nvfp4.py @@ -22,6 +22,9 @@ import paddle +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None + import fastdeploy.model_executor.layers.quantization.nvfp4 as nvfp4_module from fastdeploy.model_executor.layers.linear import QKVParallelLinear from fastdeploy.model_executor.layers.moe import FusedMoE @@ -133,7 +136,7 @@ def test_module_import_with_flashinfer(self): """Test module reloading when flashinfer is available.""" mock_flashinfer = types.ModuleType("flashinfer") with mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer}): - with mock.patch("paddle.compat.enable_torch_proxy"): + with mock.patch("paddle.enable_compat"): importlib.reload(nvfp4_module) diff --git a/tests/splitwise/test_splitwise_connector.py b/tests/splitwise/test_splitwise_connector.py index 569e138fa06..cc39a52cd76 100644 --- a/tests/splitwise/test_splitwise_connector.py +++ b/tests/splitwise/test_splitwise_connector.py @@ -25,13 +25,8 @@ import pytest import zmq -if not hasattr(paddle, "compat"): - - class _CompatStub: - def enable_torch_proxy(self, scope=None): - return None - - paddle.compat = _CompatStub() +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None from fastdeploy import envs from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 5db1cc7f7f3..a93d5741d14 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -24,8 +24,8 @@ import numpy as np import paddle -if not hasattr(paddle, "compat"): - paddle.compat = SimpleNamespace(enable_torch_proxy=lambda scope: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig from fastdeploy.engine.args_utils import EngineArgs From fae4a8b35ce4b7e7ed2189c547023685cc8a07d9 Mon Sep 17 00:00:00 2001 From: Zero Rains Date: Tue, 12 May 2026 19:04:56 +0800 Subject: [PATCH 097/143] [BugFix] Fix KSM bug in MTP and Overlap (#7788) --- .../model_executor/layers/sample/sampler.py | 34 +++++---------- .../model_executor/pre_and_post_process.py | 41 ++++++++----------- fastdeploy/worker/output.py | 8 ---- 3 files changed, 26 insertions(+), 57 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index f06bd695149..a611769a9c9 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -238,9 +238,7 @@ def _extract_sparse_indices( """ Extract per-request sparse retained-token indices from CPU numpy arrays. - This is the CPU-side counterpart of _compute_sampling_mask. It should be - called after the sampling_mask_event has been synchronized, so that the - async D2H copy is guaranteed to be complete. + This is the CPU-side counterpart of _compute_sampling_mask. Args: indices_window_cpu: [B, max_k] int64 numpy array of sorted vocab indices. @@ -708,9 +706,7 @@ def forward_cuda( # All GPU ops; D2H is done via async copy_ with event sync in save_output. sampling_mask = None logz_per_batch = None - sampling_mask_event = None if sampling_metadata.keep_sampling_mask: - sampling_mask_event = paddle.device.cuda.create_event() indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( probs, sampling_metadata.top_p, @@ -726,8 +722,6 @@ def forward_cuda( ).pin_memory() indices_window_cpu.copy_(indices_window_gpu, False) mask_window_cpu.copy_(mask_window_gpu, False) - # Record event — sync this event before reading CPU buffers - sampling_mask_event.record() # Store deferred GPU→CPU data; sparse extraction happens in save_output sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) @@ -756,7 +750,6 @@ def forward_cuda( logits=logits, sampling_mask=sampling_mask, logz_per_batch=logz_per_batch, - sampling_mask_event=sampling_mask_event, ) return sampler_output @@ -1245,20 +1238,16 @@ def forward_cuda( target_logits = target_logits[: accept_nums.sum()] # Derive target probs from already-extracted target_logits; avoids a second kernel call. target_probs = F.softmax(target_logits, axis=-1) - # Compute sampling mask at accepted token positions. - # Expand top_p from [batch, 1] to [total_accepted, 1]. - accept_top_p = ( - sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + accept_top_p, accept_top_k, _ = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, ) - accept_top_k = None - if ( - sampling_metadata.top_k is not None - and sampling_metadata.top_k_list - and any(x > 0 for x in sampling_metadata.top_k_list) - ): - accept_top_k = ( - sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) - ) + indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( target_probs, accept_top_p, @@ -1274,11 +1263,8 @@ def forward_cuda( ).pin_memory() indices_window_cpu.copy_(indices_window_gpu, False) mask_window_cpu.copy_(mask_window_gpu, False) - sampling_mask_event = paddle.device.cuda.create_event() - sampling_mask_event.record() sampler_output.sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) sampler_output.logz_per_batch = logz_per_batch - sampler_output.sampling_mask_event = sampling_mask_event return sampler_output def forward_xpu( diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 94b917f3f9f..7cd71dabb51 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -386,17 +386,12 @@ def save_output_normal( save_each_rank: bool = False, sampling_mask_async_queue: Optional[queue.Queue] = None, ): - # Resolve deferred async D2H: sync event once at the top so all paths below - # can safely read sampling_mask and logz_per_batch. - if sampler_output.sampling_mask_event is not None: - sampler_output.sampling_mask_event.synchronize() - # Extract sparse indices from pinned CPU buffers - if sampler_output.sampling_mask is not None: - indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask - sampler_output.sampling_mask = _extract_sparse_indices( - indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz - ) - sampler_output.sampling_mask_event = None + # Extract sparse indices from pinned CPU buffers + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) # Renormalize logprobs with logz (deferred from post_process for better overlap). if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: @@ -558,24 +553,20 @@ def save_output_speculate( ): # Resolve deferred async D2H: sync event once at the top so all paths below # can safely read sampling_mask and logz_per_batch. - if sampler_output.sampling_mask_event is not None: - sampler_output.sampling_mask_event.synchronize() - if sampler_output.sampling_mask is not None: - indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask - sampler_output.sampling_mask = _extract_sparse_indices( - indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz - ) - sampler_output.sampling_mask_event = None + mask_bsz = None + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) # Renormalize logprobs with logz (deferred from post_process for better overlap). if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: - # TODO (wangyanpeng): Currently, there is a bug when overlap is enabled. - # Please ensure overlap is disabled when using this functionality to avoid unexpected behavior. - real_token_num = share_inputs["accept_num_cpu"].sum() + assert mask_bsz is not None sampler_output.logprobs_tensors = LogprobsTensors( - logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids[:real_token_num], - logprobs=sampler_output.logprobs_tensors.logprobs[:real_token_num], - selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks[:real_token_num], + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids[:mask_bsz], + logprobs=sampler_output.logprobs_tensors.logprobs[:mask_bsz], + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks[:mask_bsz], ) sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( sampler_output.logprobs_tensors.logprobs, diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 3f247d66197..44cc9cb9e16 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -180,11 +180,6 @@ class SamplerOutput: cu_batch_token_offset: Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None # Sparse sampling mask for top_p/top_k: - # Before sampling_mask_event sync: stored as a deferred tuple - # (indices_window_cpu, mask_window_cpu, real_bsz) where the CPU tensors - # are pinned-memory targets of async D2H copies. - # After event sync + _extract_sparse_indices: converted to the final - # List[np.ndarray] format below. # - Non-speculative decoding: per-request mask. This is a list of length # num_reqs, where element i is a 1-D int32 numpy array of vocab indices # retained by top_p/top_k for request i. Replaces the previous dense @@ -202,9 +197,6 @@ class SamplerOutput: # Used for renormalizing logprobs to match the truncated sampling distribution. # Shape: [num_reqs] logz_per_batch: Optional[np.ndarray] = None - # CUDA event that guards async D2H copy of sampling_mask / logz_per_batch. - # Must be synchronized before reading sampling_mask or logz_per_batch. - sampling_mask_event: Optional[object] = None @dataclass From ae5dac1d380402776c21113c2bc866c936cc1ce1 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Tue, 12 May 2026 20:20:35 +0800 Subject: [PATCH 098/143] [Cherry-Pick][Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660) (#7228) * enable trtllm_all_reduce fusion kernel in glm model * update flashinfer paddle version * format update modify test modify test support empty tensor and modify test fix test_linear config issues modify test name add edge test case modify format fix conflict modify default max token num in trtllm_allreduce_fusion add max token num branch for trtllm_allreduce_fusion fix format fix rmsnorm config issue modify 2025 to 2026 enable trtllm_allreduce fusion Revert "[Cherry-Pick][CI] Use GPU-Build-RL runner for _build_linux_rl.yml (#7186) (#7195)" This reverts commit ca2f38b93452ad6bf4593361cbbdb47832f3eab9. Revert "[Cherry-Pick][BugFix] prevent requests from entering running state without a slot(#7141) (#7181)" This reverts commit 80f4a7287c8802649dcb4e55c2dd87f3ee37264b. clean flashinfer cache and modify test fix dumpy patch issue fix some issues * remove redundent * enable moe reduce fusion * fix test * fix cuda context issue * update flashinfer version --- fastdeploy/config.py | 1 + fastdeploy/engine/args_utils.py | 11 + fastdeploy/engine/common_engine.py | 1 + fastdeploy/engine/engine.py | 1 + .../layers/flashinfer_comm_fusion.py | 209 +++++++ fastdeploy/model_executor/layers/linear.py | 16 +- .../model_executor/layers/normalization.py | 11 + .../layers/quantization/mxfp4.py | 9 +- fastdeploy/model_executor/models/glm4_moe.py | 14 +- fastdeploy/model_executor/utils.py | 6 + fastdeploy/worker/worker_process.py | 6 + requirements.txt | 2 +- .../test_rmsnorm_layer_batch_invariant.py | 1 + .../test_trtllm_allreduce_rms_fusion.py | 54 ++ tests/layers/trtllm_allreduce_rms_fusion.py | 574 ++++++++++++++++++ tests/model_executor/test_linear.py | 1 + 16 files changed, 903 insertions(+), 14 deletions(-) create mode 100644 fastdeploy/model_executor/layers/flashinfer_comm_fusion.py create mode 100644 tests/layers/test_trtllm_allreduce_rms_fusion.py create mode 100644 tests/layers/trtllm_allreduce_rms_fusion.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6cf244d2bb8..3559e89799e 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -675,6 +675,7 @@ def __init__( self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). self.disable_custom_all_reduce: bool = False + self.enable_flashinfer_allreduce_fusion: bool = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 5b443e91c42..79fb13d95c7 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -274,6 +274,11 @@ class EngineArgs: Flag to disable the custom all-reduce kernel. """ + enable_flashinfer_allreduce_fusion: bool = False + """ + Flag to enable all reduce fusion kernel in flashinfer. + """ + use_internode_ll_two_stage: bool = False """ Flag to use the internode_ll_two_stage kernel. @@ -1000,6 +1005,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_custom_all_reduce, help="Flag to disable custom all-reduce.", ) + parallel_group.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + default=EngineArgs.enable_flashinfer_allreduce_fusion, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parallel_group.add_argument( "--use-internode-ll-two-stage", action="store_true", diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 50017baf5de..98374941136 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2542,6 +2542,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index ba9049be0c6..210b6f4bf26 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -667,6 +667,7 @@ def _start_worker_service(self): "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py new file mode 100644 index 00000000000..7f27b52975d --- /dev/null +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -0,0 +1,209 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional, Tuple + +import paddle +import paddle.distributed as dist + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.utils import has_flashinfer +from fastdeploy.utils import get_logger + +logger = get_logger("flashinfer", "flashinfer.log") + +_flashinfer_comm = None +_workspace_manager = None + + +def _get_flashinfer_comm(): + """Lazily import flashinfer.comm to avoid side effects at module load time.""" + global _flashinfer_comm + if _flashinfer_comm is not None: + return _flashinfer_comm + if has_flashinfer(): + try: + with paddle.use_compat_guard(enable=True, scope={"flashinfer"}): + import flashinfer.comm as comm + + _flashinfer_comm = comm + except ImportError: + logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") + return _flashinfer_comm + + +class FlashInferWorkspaceManager: + def __init__(self): + self.workspace_tensor = None + self.ipc_handles = None + self.world_size = None + self.rank = None + self.initialized = False + + def initialize( + self, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + group=None, + use_fp32_lamport: bool = False, + ): + """Initialize workspace""" + if self.initialized and self.world_size == world_size: + return + + comm = _get_flashinfer_comm() + if comm is None: + logger.warning("FlashInfer comm not available, skipping workspace " "initialization") + return + + self.cleanup() + + self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + max_token_num, + hidden_dim, + group=group, + use_fp32_lamport=use_fp32_lamport, + ) + + self.world_size = world_size + self.rank = rank + self.initialized = True + + logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.ipc_handles is not None: + try: + comm = _get_flashinfer_comm() + if comm is not None: + comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group()) + except Exception as e: + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") + finally: + self.workspace_tensor = None + self.ipc_handles = None + self.initialized = False + + +_workspace_manager = FlashInferWorkspaceManager() + + +def ensure_workspace_initialized( + fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False +): + """Ensure workspace is initialized""" + comm = _get_flashinfer_comm() + if not has_flashinfer() or comm is None: + return False + + assert fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + return False + + rank = dist.get_rank() + + if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: + _workspace_manager.initialize( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + use_fp32_lamport=use_fp32_lamport, + ) + + return _workspace_manager.initialized + + +def flashinfer_allreduce_residual_rmsnorm( + fd_config: FDConfig, + input_tensor: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-6, + max_token_num: int = 2048, + use_oneshot: Optional[bool] = None, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Use FlashInfer's fused allreduce + residual + RMS norm operation + """ + comm = _get_flashinfer_comm() + if not has_flashinfer() or comm is None: + logger.debug("FlashInfer not available, falling back to standard " "implementation") + return None, None + + assert fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + logger.debug("Single GPU, no need for allreduce fusion") + return None, None + + assert input_tensor.shape[0] <= max_token_num + + if not ensure_workspace_initialized( + fd_config=fd_config, + max_token_num=max_token_num, + hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == paddle.float32), + ): + logger.debug("FlashInfer workspace not available") + return None, None + + token_num, hidden_dim = input_tensor.shape + + residual_out = paddle.empty_like(residual) + norm_out = paddle.empty_like(input_tensor) + # support empty tensor + if input_tensor.shape[0] == 0: + return norm_out, residual_out + comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + world_size=world_size, + world_rank=dist.get_rank(), + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=_workspace_manager.workspace_tensor, + launch_with_pdl=True, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm), + allreduce_out=None, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=eps, + scale_factor=None, + layout_code=None, + ) + + return norm_out, residual_out + + +def cleanup_flashinfer_workspace(): + global _workspace_manager + if _workspace_manager is not None: + _workspace_manager.cleanup() diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index b35d97d7660..b9138adad06 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -853,6 +853,7 @@ def __init__( skip_quant: bool = False, weight_dtype: str = "", layer_id: int = -1, + enable_all_reduce_fusion: bool = None, ): """ Initialize a linear layer with additional parameters for inference and quantization. @@ -864,9 +865,17 @@ def __init__( input_size (int): Number of input features. Defaults to None. output_size (int): Number of output features. Defaults to None. with_bias (bool): Whether to include bias or not. Defaults to False. - skip_quant (bool): Whether to skip quantization. Defaults to False. + skip_quant (bool): Whether to skip quantization or not. Defaults to False. + enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion. + If None, it is determined by the config flag and prefix. Defaults to None. """ self.fd_config = fd_config + if enable_all_reduce_fusion is None: + self.enable_all_reduce_fusion = False + else: + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion + ) self.ep_size = fd_config.parallel_config.expert_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_group = fd_config.parallel_config.tp_group @@ -944,7 +953,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: out = self.quant_method.apply(self, x) - if self.reduce_results and self.tp_size > 1: + need_tp_all_reduce = ( + self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) + ) + if need_tp_all_reduce: out = tensor_model_parallel_all_reduce(out, self.tp_group) return out diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 4532b8d27af..6f2e64ed6b2 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -35,6 +35,7 @@ is_batch_invariant_mode_enabled, rms_norm_batch_invariant, ) +from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm from .utils import get_tensor, modules_to_convert @@ -122,6 +123,10 @@ def __init__( self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank self.tp_group = self.fd_config.parallel_config.tp_group is_input_norm = prefix.endswith(".input_layernorm") + self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and ( + ("post_attention_layernorm" in prefix) or (("input_layernorm" in prefix and layer_id != 0)) + ) + self.is_last_norm = prefix.endswith(".norm") self.split_x = ( self.fd_config.parallel_config.use_sequence_parallel_moe @@ -240,6 +245,12 @@ def forward( norm_out = rms_norm(x, self.weight, self.eps) return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) + # enable trtllm all reduce fusion + elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps + ) + assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" else: if is_batch_invariant_mode_enabled(): # M-invariant path: per-row Triton kernel, no cross-row reduction diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py index 8732b1cc4d1..24ec38e696c 100644 --- a/fastdeploy/model_executor/layers/quantization/mxfp4.py +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -14,8 +14,6 @@ # limitations under the License. """ -import importlib -import importlib.util import math from enum import Enum from typing import Callable, Optional @@ -25,11 +23,12 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch + from fastdeploy.utils import get_logger from ..moe import FusedMoE @@ -59,10 +58,6 @@ def check_device_capability(num): return False -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None - - def round_up(a, b): return ((a + b - 1) // b) * b diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index fba36185a4a..befbe64dd3f 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -101,6 +101,7 @@ def __init__( output_size=fd_config.model_config.hidden_size, with_bias=False, reduce_results=reduce_results, + enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion, ) self.act_fn = SiluAndMul( @@ -130,10 +131,12 @@ def __init__( self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_group = fd_config.parallel_config.tp_group - self.use_ep = self.expert_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1 - + self.last_layer_id = fd_config.model_config.num_hidden_layers - 1 + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and layer_id != self.last_layer_id + ) self.n_routed_experts: int = fd_config.model_config.n_routed_experts self.n_shared_experts: int = fd_config.model_config.n_shared_experts @@ -201,8 +204,10 @@ def forward(self, x, forward_meta: ForwardMeta = None): if self.n_shared_experts > 0: out = out + self.shared_experts(x) if self.merge_ffn_tp: - # Both branches produced partial sums; combine first, then single all-reduce. - out = tensor_model_parallel_all_reduce(out, self.tp_group) + need_tp_all_reduce_fusion = self.enable_all_reduce_fusion and out.shape[0] <= 2048 + if not need_tp_all_reduce_fusion: + # Both branches produced partial sums; combine first, then single all-reduce. + out = tensor_model_parallel_all_reduce(out, self.tp_group) return out @@ -230,6 +235,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, output_size=fd_config.model_config.hidden_size, layer_id=layer_id, + enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion, ) self.attn = Attention( diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index e63603047be..de5ef678b01 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -14,6 +14,8 @@ # limitations under the License. """ +import importlib +import importlib.util import os import re from collections.abc import Mapping @@ -553,6 +555,10 @@ def fn(loaded_weight_name, is_moe): return fn +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None + + @cache def get_sm_version(): if paddle.cuda.is_available(): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 3b0b8e1fe80..5c988c5fae1 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -849,6 +849,12 @@ def parse_args(): default=None, help="Configuration of SpeculativeConfig.", ) + parser.add_argument( + "--enable_flashinfer_allreduce_fusion", + action="store_true", + default=False, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parser.add_argument( "--max_num_batched_tokens", type=int, diff --git a/requirements.txt b/requirements.txt index 2edef89b859..9a7bb3b3613 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,6 +46,6 @@ setproctitle aistudio_sdk p2pstore py-cpuinfo -flashinfer-python-paddle +flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.3-py3-none-any.whl flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 diff --git a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py index 121e74ee4b9..54ce40ca5d4 100644 --- a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py +++ b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py @@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"): layer.bias = None layer.split_x = False layer.allgather_out = False + layer.enable_all_reduce_fusion = False return layer diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py new file mode 100644 index 00000000000..8edd007cadd --- /dev/null +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -0,0 +1,54 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import subprocess +import sys + + +def test_run_distributed(): + """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess""" + # flashinfer_cache_dir = os.path.expanduser("~/.cache/flashinfer") + # if os.path.exists(flashinfer_cache_dir): + # print(f"=== Clearing flashinfer cache directory: {flashinfer_cache_dir} ===") + # subprocess.run(["rm", "-rf", flashinfer_cache_dir], check=True) + current_dir = os.path.dirname(os.path.abspath(__file__)) + run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py") + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + command = [ + sys.executable, + "-m", + "paddle.distributed.launch", + "--gpus", + "0,1", + run_script, + ] + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + try: + stdout, stderr = process.communicate(timeout=400) + return_code = process.returncode + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + return_code = -1 + print(f"=== Distributed test stdout ===\n{stdout}") + print(f"=== Distributed test stderr ===\n{stderr}") + assert return_code in (0, 250), f"Process exited with code {return_code}" + + +test_run_distributed() diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py new file mode 100644 index 00000000000..1417d2a6463 --- /dev/null +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -0,0 +1,574 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import time +import unittest +from unittest.mock import Mock, patch + +import numpy as np +import paddle +import paddle.distributed as dist + + +class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase): + """Test FlashInfer AllReduce + Residual + RMSNorm fused operator""" + + @classmethod + def setUpClass(cls): + """Set up test environment""" + if paddle.is_compiled_with_cuda(): + # Bind each rank to its own GPU explicitly; otherwise all ranks + # default to "gpu:0" and cudaIpcOpenMemHandle fails with + # "invalid device context". + local_rank = int( + os.environ.get("PADDLE_LOCAL_RANK", os.environ.get("FLAGS_selected_gpus", "0").split(",")[0]) + ) + paddle.set_device(f"gpu:{local_rank}") + + # paddle.distributed.launch remaps each rank's visible GPU to + # index 0 inside the worker process. flashinfer's IPC calls go + # through the cudart runtime API (cuda-python), which maintains + # its own primary context separate from Paddle's driver context. + # Explicitly activate cudart's primary context on device 0 here, + # otherwise cudaIpcOpenMemHandle reports "invalid device context". + try: + from cuda import cudart + + cudart.cudaSetDevice(0) + cudart.cudaFree(0) # force primary context creation + except ImportError: + pass + else: + paddle.set_device("cpu") + dist.init_parallel_env() + if paddle.is_compiled_with_cuda(): + # Force the CUDA primary context to be created on the current + # device before flashinfer's cudart IPC calls run. + paddle.zeros([1]).cuda() + paddle.device.cuda.synchronize() + + def setUp(self): + """Initialize each test case""" + # Fix random seed for reproducibility + paddle.seed(42) + np.random.seed(42) + + self.dtype = paddle.float32 + self.token_num = 128 + self.hidden_dim = 768 + self.eps = 1e-6 + self.epsilon = 1e-6 + self.max_token_num = 2048 + + # Create mock FDConfig + self.fd_config = Mock() + self.fd_config.parallel_config = Mock() + self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size() + self.begin_norm_axis = 1 + + # Performance test params - increase iterations for stability + self.warmup_iterations = 20 # Increase warmup + self.test_iterations = 200 # Increase test iterations + + def tearDown(self): + """Clean up resources""" + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + + def create_test_tensors(self): + """Create test tensors""" + input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + weight = paddle.randn([self.hidden_dim], dtype=self.dtype) + return input_tensor, residual, weight + + def compute_reference_output(self, input_tensor, residual, weight, eps): + """Reference implementation: manually compute AllReduce + Residual + RMSNorm""" + # # Step 1: AllReduce (identity on single device) + # allreduce_out = input_tensor.clone() + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + # Step 2: Add residual + residual_out = input_tensor + residual + + # Step 3: RMSNorm + variance = residual_out.pow(2).mean(axis=-1, keepdim=True) + norm_out = residual_out * paddle.rsqrt(variance + eps) + norm_out = norm_out * weight + + # dist.all_reduce(residual_out, op=dist.ReduceOp.SUM) + return norm_out, residual_out + + def paddle_rms_fuse(self, input_tensor, residual, weight, eps): + from paddle.incubate.nn.functional import fused_rms_norm + + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + out_fused = fused_rms_norm( + input_tensor, + norm_weight=weight, + norm_bias=None, + epsilon=eps, + begin_norm_axis=self.begin_norm_axis, + bias=None, + residual=residual, + ) + + return out_fused[0], out_fused[1] + + def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): + """FlashInfer fused operator""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=eps, + max_token_num=self.max_token_num, + use_oneshot=False, + ) + return norm_out, residual_out + + def benchmark_function(self, func, *args, name="", **kwargs): + """ + Improved performance benchmark + - Wait for GPU frequency stabilization + - Use median instead of mean (more stable) + - Filter outliers + """ + # Force GPU frequency stabilization + if paddle.is_compiled_with_cuda(): + for _ in range(5): + paddle.device.cuda.synchronize() + time.sleep(0.01) + + # Warmup - thorough warm-up + for _ in range(self.warmup_iterations): + result = func(*args, **kwargs) + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + # Extra wait to ensure GPU stability + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + time.sleep(0.1) + + # Benchmark run + times = [] + for i in range(self.test_iterations): + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + start = time.perf_counter() + result = func(*args, **kwargs) + + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + end = time.perf_counter() + elapsed = (end - start) * 1000 # Convert to milliseconds + times.append(elapsed) + + times = np.array(times) + + # Filter outliers using IQR method + q1, q3 = np.percentile(times, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + filtered_times = times[(times >= lower_bound) & (times <= upper_bound)] + + # Fall back to raw data if too many samples filtered out + if len(filtered_times) < self.test_iterations * 0.5: + filtered_times = times + + # Statistics + avg_time = np.mean(filtered_times) + median_time = np.median(filtered_times) + std_time = np.std(filtered_times) + min_time = np.min(filtered_times) + max_time = np.max(filtered_times) + cv = (std_time / avg_time) * 100 # Coefficient of variation (%) + + print(f"\n{'='*70}") + print(f"Performance Benchmark: {name}") + print(f"{'='*70}") + print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)") + print(f"Median: {median_time:.4f} ms (most stable metric)") + print(f"Average: {avg_time:.4f} ms") + print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)") + print(f"Min: {min_time:.4f} ms") + print(f"Max: {max_time:.4f} ms") + print(f"{'='*70}\n") + + # Return median (more stable) and result + return median_time, result + + def test_accuracy_fused_vs_reference(self): + """Test accuracy of fused operator vs reference implementation""" + input_tensor, residual, weight = self.create_test_tensors() + reference_output, ref_res = self.compute_reference_output( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + fused_output, paddle_res = self.paddle_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + # Verify results + np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5) + + +class TestFlashInferWorkspaceManager(unittest.TestCase): + """Test FlashInferWorkspaceManager""" + + def setUp(self): + """Initialize""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + self.manager = FlashInferWorkspaceManager() + + def test_initialization(self): + """Test initialization state""" + self.assertIsNone(self.manager.workspace_tensor) + self.assertIsNone(self.manager.ipc_handles) + self.assertIsNone(self.manager.world_size) + self.assertIsNone(self.manager.rank) + self.assertFalse(self.manager.initialized) + + def test_cleanup(self): + """Test cleanup functionality""" + self.manager.cleanup() + self.assertFalse(self.manager.initialized) + self.assertIsNone(self.manager.workspace_tensor) + + +class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase): + """Test FlashInferWorkspaceManager edge cases and fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + # Patch before importing to test fallback paths + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_initialization_early_return_when_already_initialized(self): + """Test line 47: early return when already initialized with same world_size""" + # Patch _flashinfer_comm to be available + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + + # First initialization + manager.initialized = True + manager.world_size = 2 + + # Mock the comm functions + mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock())) + + # Second initialization with same world_size - should return early + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + def test_initialization_warning_when_comm_none(self): + """Test lines 50-51: warning when _flashinfer_comm is None""" + # Patch to ensure _get_flashinfer_comm returns None + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm", + return_value=None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + + # Should not raise, just log warning and return + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + # Verify not initialized + self.assertFalse(manager.initialized) + + def test_cleanup_with_exception(self): + """Test lines 73-80: cleanup with exception handling""" + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = True + manager.ipc_handles = Mock() + manager.workspace_tensor = Mock() + + # Mock the destroy function to raise exception + mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error")) + + # Should not raise, just log warning + manager.cleanup() + + # Verify cleanup happened + self.assertFalse(manager.initialized) + self.assertIsNone(manager.workspace_tensor) + self.assertIsNone(manager.ipc_handles) + + def test_cleanup_without_initialization(self): + """Test cleanup when not initialized""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = False + + # Should not raise + manager.cleanup() + + # Verify state + self.assertFalse(manager.initialized) + + +class TestEnsureWorkspaceInitialized(unittest.TestCase): + """Test ensure_workspace_initialized fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_ensure_workspace_when_flashinfer_not_available(self): + """Test line 91: early return when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False (not initialized) + self.assertFalse(result) + + def test_ensure_workspace_when_comm_none(self): + """Test ensure_workspace_initialized when _flashinfer_comm is None""" + self.mock_has_flashinfer.return_value = True + + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm", + return_value=None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False + self.assertFalse(result) + + def test_ensure_workspace_single_gpu(self): + """Test line 96: early return when world_size <= 1""" + self.mock_has_flashinfer.return_value = True + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0): + result = ensure_workspace_initialized(fd_config) + + # Should return False for single GPU + self.assertFalse(result) + + +class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase): + """Test flashinfer_allreduce_residual_rmsnorm fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_flashinfer_not_available_fallback(self): + """Test lines 140-141: fallback when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None when flashinfer not available + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_single_gpu_fallback(self): + """Test lines 146-147: fallback for single GPU""" + self.mock_has_flashinfer.return_value = True + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None for single GPU + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_empty_tensor_handling(self): + """Test line 166: empty tensor handling""" + self.mock_has_flashinfer.return_value = True + + with ( + patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm, + patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized", + return_value=True, + ), + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + # Empty tensor (0 tokens) + input_tensor = paddle.zeros([0, 768]) + residual = paddle.zeros([0, 768]) + weight = paddle.randn([768]) + + # Mock the trtllm_allreduce_fusion to not be called + mock_comm.trtllm_allreduce_fusion = Mock() + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return empty tensors, not call flashinfer + self.assertEqual(norm_out.shape[0], 0) + self.assertEqual(residual_out.shape[0], 0) + mock_comm.trtllm_allreduce_fusion.assert_not_called() + + +class TestCleanupFlashInferWorkspace(unittest.TestCase): + """Test cleanup_flashinfer_workspace function""" + + def test_cleanup_workspace_function(self): + """Test lines 211-212: cleanup function""" + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as mock_manager: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + cleanup_flashinfer_workspace, + ) + + mock_manager.cleanup = Mock() + + cleanup_flashinfer_workspace() + + mock_manager.cleanup.assert_called_once() + + +if __name__ == "__main__": + """Run tests directly (called by subprocess after distributed launch)""" + unittest.main(verbosity=2) diff --git a/tests/model_executor/test_linear.py b/tests/model_executor/test_linear.py index 13f2bbe245e..aba98479303 100644 --- a/tests/model_executor/test_linear.py +++ b/tests/model_executor/test_linear.py @@ -58,6 +58,7 @@ def make_fd_config( expert_parallel_size=1, tp_group=None, use_sequence_parallel_moe=use_sequence_parallel_moe, + enable_flashinfer_allreduce_fusion=False, ), scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1), load_config=SimpleNamespace( From 00778227cc75fc2a93e2c843b9191b7d5536c74b Mon Sep 17 00:00:00 2001 From: SunLei Date: Wed, 13 May 2026 10:09:14 +0800 Subject: [PATCH 099/143] =?UTF-8?q?[FDConfig]=20=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=BC=80=E5=90=AF=20FD=5FENABLE=5FE2W=5FTENSOR=5FCONVERT=20?= =?UTF-8?q?=E5=92=8C=20FD=5FENGINE=5FTASK=5FQUEUE=5FWITH=5FSHM=20(#7746)?= =?UTF-8?q?=20(#7784)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [test] Stop server with /dev/shm cleanup * cleanup shm by clean_ports * kill_process_by_unix_socket * add engine_worker_queue.is_broken * Failed to connect to engine worker queue, retry after 5 seconds * test_Qwen2-7B-Instruct_offline * sys.path.insert(0, project_root) * Cleaning unix socket for all ports * add is_file_socket_available * clearup dev/shm/* for xpu --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> --- fastdeploy/engine/common_engine.py | 4 + fastdeploy/envs.py | 4 +- .../inter_communicator/engine_worker_queue.py | 10 + fastdeploy/utils.py | 35 +++ .../test_eblite_serving.py | 95 +------- .../test_Qwen2-7B-Instruct_offline.py | 48 ++-- tests/e2e/utils/serving_utils.py | 59 +++++ tests/utils/test_find_free_ports.py | 212 ++++++++++++++++++ tests/xpu_ci/conftest.py | 7 + 9 files changed, 354 insertions(+), 120 deletions(-) create mode 100644 tests/utils/test_find_free_ports.py diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 98374941136..886f275d098 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1166,6 +1166,10 @@ def _fetch_request(): except Exception as e: err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) + # Failed to connect to engine worker queue, retry after 5 seconds + if self.engine_worker_queue.is_broken(): + self.llm_logger.error("Failed to connect to engine worker queue, retry after 5 seconds") + time.sleep(5) def _get_scheduler_unhandled_request_num(self) -> int: """ diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 7e0f809d5d3..08cef5849e1 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -178,8 +178,8 @@ def _validate_split_kv_size(value: int) -> int: "PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES": lambda: int( os.getenv("PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", "1") ), - "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), - "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), + "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), + "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "1")), "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")), "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), "FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")), diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index a7876669f8f..b0fc9bb3385 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -848,3 +848,13 @@ def cleanup(self): """ if self.manager is not None and self.is_server: self.manager.shutdown() + + def is_broken(self): + try: + self.manager.connect() + return False + except (ConnectionRefusedError, ConnectionResetError, BrokenPipeError, EOFError, OSError): + llm_logger.error("Failed to connect to engine worker queue") + return True + except Exception: + return False diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 6c0b72ae8ae..965fca6d96c 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -626,6 +626,12 @@ def is_port_available(host, port): import errno import socket + # If FD_ENGINE_TASK_QUEUE_WITH_SHM is enabled, then check the file socket is available + if envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + socket_path = f"/dev/shm/fd_task_queue_{port}.sock" + if not is_file_socket_available(socket_path): + return False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -637,6 +643,35 @@ def is_port_available(host, port): return True +def is_file_socket_available(socket_path): + """ + Check the Unix domain socket (file socket) is available. + + Args: + socket_path: Path to the socket file, e.g. /dev/shm/fd_task_queue_8000.sock + + Returns: + True if the socket is available (not in use), False otherwise. + """ + import errno + import os + import socket + + if not os.path.exists(socket_path): + return True + + # File exists, try to connect to see if someone is listening + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + try: + s.connect(socket_path) + return False + except OSError as e: + if e.errno in (errno.ECONNREFUSED, errno.ENOENT): + # Stale socket file: exists but nobody is listening + return True + return False + + def find_free_ports( port_range: tuple[int, int] = (8000, 65535), num_ports: int = 1, diff --git a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py index 6d8dfac53fd..7c3c6657434 100644 --- a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py +++ b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py @@ -16,7 +16,6 @@ import queue import shutil import signal -import socket import subprocess import sys import time @@ -30,6 +29,7 @@ sys.path.insert(0, project_root) from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient +from e2e.utils.serving_utils import clean_ports, is_port_open env = os.environ.copy() @@ -79,88 +79,6 @@ def zmq_control_client(): return client -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(): - """ - Kill all processes occupying the ports listed in PORTS_TO_CLEAN. - """ - print(f"Cleaning ports: {PORTS_TO_CLEAN}") - for port in PORTS_TO_CLEAN: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in PORTS_TO_CLEAN: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) - - @pytest.fixture(scope="session", autouse=True) def setup_and_run_server(): """ @@ -170,8 +88,15 @@ def setup_and_run_server(): - Waits for server port to open (up to 30 seconds) - Tears down server after all tests finish """ + # 清理/dev/shm中的临时文件 + try: + subprocess.run("rm -rf /dev/shm/*", shell=True) + print("Successfully cleaned up /dev/shm.") + except Exception as e: + print(f"Failed to cleanup /dev/shm: {e}") + print("Pre-test port cleanup...") - clean_ports() + clean_ports(PORTS_TO_CLEAN) base_path = os.getenv("MODEL_PATH") if base_path: @@ -236,7 +161,7 @@ def setup_and_run_server(): print("\n===== Post-test server cleanup... =====") try: os.killpg(process.pid, signal.SIGTERM) - clean_ports() + clean_ports(PORTS_TO_CLEAN) print(f"API server (pid={process.pid}) terminated") except Exception as e: print(f"Failed to terminate API server: {e}") diff --git a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py index fde03d70ee1..b42799ce066 100644 --- a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py +++ b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py @@ -13,9 +13,7 @@ # limitations under the License. import os -import signal -import socket -import subprocess +import sys import time import traceback @@ -23,21 +21,17 @@ from fastdeploy import LLM, SamplingParams -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) -MAX_WAIT_SECONDS = 60 - +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..", "..")) +sys.path.insert(0, project_root) +from e2e.utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False +MAX_WAIT_SECONDS = 60 def format_chat_prompt(messages): @@ -74,19 +68,15 @@ def llm(model_path): """ Fixture to initialize the LLM model with a given model path """ - try: - output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip() - for pid in output.splitlines(): - os.kill(int(pid), signal.SIGKILL) - print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}") - except subprocess.CalledProcessError: - pass + # Clean ports before starting the test + clean_ports() try: start = time.time() llm = LLM( model=model_path, tensor_parallel_size=1, + port=FD_API_PORT, engine_worker_queue_port=FD_ENGINE_QUEUE_PORT, cache_queue_port=FD_CACHE_QUEUE_PORT, max_model_len=32768, @@ -94,15 +84,7 @@ def llm(model_path): logits_processors=["LogitBiasLogitsProcessor"], ) - # Wait for the port to be open - wait_start = time.time() - while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT): - if time.time() - wait_start > MAX_WAIT_SECONDS: - pytest.fail( - f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}" - ) - time.sleep(1) - + time.sleep(2) print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.") yield llm except Exception: diff --git a/tests/e2e/utils/serving_utils.py b/tests/e2e/utils/serving_utils.py index 6dd5e77c9b7..9e47ca177e7 100644 --- a/tests/e2e/utils/serving_utils.py +++ b/tests/e2e/utils/serving_utils.py @@ -98,6 +98,60 @@ def kill_process_on_port(port: int): pass +def kill_process_by_unix_socket( + socket_path: str, + force: bool = True, +): + """ + 根据 unix socket 文件路径杀掉对应进程 + cmd: ss -xlpn | grep /dev/shm/fd_task_queue_8664.sock + Args: + socket_path: 例如 /dev/shm/fd_task_queue_8664.sock + force: + True -> SIGKILL + False -> SIGTERM + Returns: + pid 或 None + """ + try: + output = subprocess.check_output( + ["ss", "-xlpn"], + text=True, + ) + for line in output.splitlines(): + if socket_path not in line: + continue + m = re.search(r"pid=(\d+)", line) + if not m: + continue + pid = int(m.group(1)) + os.kill( + pid, + signal.SIGKILL if force else signal.SIGTERM, + ) + return pid + except Exception: + pass + return None + + +def cleanup_unix_socket(socket_path: str): + if not os.path.exists(socket_path): + return + try: + pid = kill_process_by_unix_socket(socket_path) + print(f"Killed process by unix socket: {socket_path}, pid={pid}") + except Exception as e: + print(f"Failed to kill process by unix socket: {socket_path}, error={e}") + finally: + try: + if os.path.exists(socket_path): + os.remove(socket_path) + print(f"Cleaned unix socket: {socket_path}") + except Exception: + pass + + def clean_ports(ports=None): """ Kill all processes occupying the ports @@ -117,6 +171,11 @@ def clean_ports(ports=None): kill_process_on_port(port) time.sleep(1) + # Clean unix socket, fd_task_queue_*.sock, for FD_ENGINE_TASK_QUEUE_WITH_SHM = 1 + print("Cleaning unix socket") + for port in ports: + cleanup_unix_socket(f"/dev/shm/fd_task_queue_{port}.sock") + def clean(ports=None): """ diff --git a/tests/utils/test_find_free_ports.py b/tests/utils/test_find_free_ports.py new file mode 100644 index 00000000000..3ffe272443e --- /dev/null +++ b/tests/utils/test_find_free_ports.py @@ -0,0 +1,212 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from unittest.mock import patch + +import pytest + +from fastdeploy.utils import find_free_ports + + +class TestFindFreePorts: + """Unit tests for find_free_ports function.""" + + def test_find_single_free_port_success(self): + """Test finding a single free port successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=1) + assert len(ports) == 1 + assert 20000 <= ports[0] <= 20100 + + def test_find_multiple_free_ports_success(self): + """Test finding multiple free ports successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=5) + assert len(ports) == 5 + for port in ports: + assert 20000 <= port <= 20100 + + def test_find_ports_with_custom_host(self): + """Test finding ports with a custom host.""" + with patch("fastdeploy.utils.is_port_available", return_value=True) as mock_avail: + ports = find_free_ports(port_range=(30000, 30010), num_ports=2, host="127.0.0.1") + assert len(ports) == 2 + # Verify is_port_available was called with the custom host + for call in mock_avail.call_args_list: + assert call[0][0] == "127.0.0.1" + + def test_find_all_ports_in_range(self): + """Test finding all ports in a small range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(40000, 40002), num_ports=3) + assert len(ports) == 3 + # All ports should be from the range + expected_ports = {40000, 40001, 40002} + assert set(ports) == expected_ports + + def test_invalid_port_range_start_negative(self): + """Test ValueError when port range start is negative.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(-1, 1000)) + + def test_invalid_port_range_end_exceeds_max(self): + """Test ValueError when port range end exceeds 65535.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(1000, 65536)) + + def test_invalid_port_range_start_greater_than_end(self): + """Test ValueError when port range start is greater than end.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(10000, 9000)) + + def test_invalid_port_range_boundary_values(self): + """Test port range boundary at exactly 0 and 65535.""" + # Valid: start = 0 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(0, 100), num_ports=1) + assert len(ports) == 1 + + # Valid: end = 65535 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(65530, 65535), num_ports=1) + assert len(ports) == 1 + + def test_num_ports_zero_raises_error(self): + """Test ValueError when num_ports is zero.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + find_free_ports(port_range=(20000, 30000), num_ports=0) + + def test_num_ports_negative_raises_error(self): + """Test ValueError when num_ports is negative.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + find_free_ports(port_range=(20000, 30000), num_ports=-1) + + def test_num_ports_larger_than_range_size(self): + """Test ValueError when num_ports exceeds the range size.""" + # Range has only 5 ports (100-104), but requesting 6 + with pytest.raises(ValueError, match="num_ports is larger than range size"): + find_free_ports(port_range=(100, 104), num_ports=6) + + def test_not_enough_free_ports_raises_runtime_error(self): + """Test RuntimeError when not enough free ports are available.""" + # Mock to return False for all ports + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(20000, 20010), num_ports=3) + + def test_partial_free_ports_raises_runtime_error(self): + """Test RuntimeError when only some ports are free.""" + call_count = [0] + + def mock_availability(host, port): + # Only first 2 ports are available + call_count[0] += 1 + return call_count[0] <= 2 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with pytest.raises(RuntimeError, match="Only found 2 free ports"): + find_free_ports(port_range=(20000, 20005), num_ports=5) + + def test_random_start_offset(self): + """Test that port scanning starts from a random offset.""" + # Track the order of ports checked + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 0, ports should be checked in order + assert checked_ports[:3] == [100, 101, 102] + assert ports == [100, 101, 102] + + def test_random_start_offset_non_zero(self): + """Test port scanning with non-zero random offset.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + # With offset 2, scanning starts from port 102 + with patch("fastdeploy.utils.random.randint", return_value=2): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 2, ports are rotated: [102, 103, 104, 105, 100, 101] + assert checked_ports[:3] == [102, 103, 104] + assert ports == [102, 103, 104] + + def test_single_port_range(self): + """Test finding port from a single-port range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(12345, 12345), num_ports=1) + assert ports == [12345] + + def test_single_port_range_not_available(self): + """Test RuntimeError when the single port in range is not available.""" + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(12345, 12345), num_ports=1) + + def test_default_parameters(self): + """Test function with default parameters.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports() + assert len(ports) == 1 + assert 8000 <= ports[0] <= 65535 + + def test_stops_early_when_enough_ports_found(self): + """Test that scanning stops as soon as enough ports are found.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + # Range has 100 ports but we only need 2 + ports = find_free_ports(port_range=(20000, 20099), num_ports=2) + + # Should only check 2 ports, not all 100 + assert len(checked_ports) == 2 + assert len(ports) == 2 + + def test_skips_unavailable_ports(self): + """Test that unavailable ports are skipped.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + # Only odd ports are available + return port % 2 == 1 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 110), num_ports=3) + + # Should find 3 odd ports: 101, 103, 105 + assert len(ports) == 3 + assert all(p % 2 == 1 for p in ports) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/xpu_ci/conftest.py b/tests/xpu_ci/conftest.py index ae0c95d727a..dc6e4d30262 100644 --- a/tests/xpu_ci/conftest.py +++ b/tests/xpu_ci/conftest.py @@ -101,6 +101,13 @@ def safe_kill_cmd(cmd): for cmd in commands: safe_kill_cmd(cmd) + try: + # 清理/dev/shm下的所有文件 + subprocess.run("rm -rf /dev/shm/*", shell=True, check=True) + except subprocess.CalledProcessError: + print("Failed to remove files from /dev/shm") + pass + def cleanup_resources(): """ From 976cb7b9e12a92459b222160f005209945827523 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 13 May 2026 13:33:29 +0800 Subject: [PATCH 100/143] [BugFix] fix: cast image_mask.any() to bool for task queue serialization (#7793) (#7798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation 多模 RL 0 卡场景下,`image_mask.any()` 返回 `numpy.bool_` 类型,通过 task queue 传输时序列化失败导致报错。 ## Modifications - 将 `request.with_image = image_mask.any()` 改为 `request.with_image = bool(image_mask.any())`,转换为 Python 原生 bool 类型 ## Usage or Command 无需额外配置,启动多模 RL 推理服务即可验证: ```bash bash run.sh ``` Co-authored-by: kevin --- fastdeploy/engine/sched/resource_manager_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 3382e077d60..9c18f02da12 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -694,7 +694,7 @@ def _compute_audio_prefix_count(end_idx, end_patch_idx): num_new_tokens = new_end_idx - pre_end_idx image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id - request.with_image = image_mask.any() + request.with_image = bool(image_mask.any()) if request.with_image: pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item() if pre_boundary_idx == len(img_boundaries_idx): From 90c010da4a848902c5f10875314cd0fd587ee798 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 13 May 2026 15:04:18 +0800 Subject: [PATCH 101/143] [Cherry-Pick][Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap(#7323) (#7794) * [Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323) * support mtp overlap in pd-split mode with insert_task overlap * fix note --- fastdeploy/eplb/async_expert_loader.py | 21 ++- .../model_executor/pre_and_post_process.py | 43 +++-- .../xpu_pre_and_post_process.py | 23 +++ fastdeploy/spec_decode/mtp.py | 65 +++---- fastdeploy/worker/gpu_model_runner.py | 162 +++++++++++------- tests/worker/test_gpu_model_runner.py | 159 ++++++++++++++++- 6 files changed, 352 insertions(+), 121 deletions(-) diff --git a/fastdeploy/eplb/async_expert_loader.py b/fastdeploy/eplb/async_expert_loader.py index 0cf9eb0453e..2832a7f635f 100644 --- a/fastdeploy/eplb/async_expert_loader.py +++ b/fastdeploy/eplb/async_expert_loader.py @@ -24,8 +24,24 @@ import paddle try: - from cuda import cudart -except ImportError: + import cuda as _cuda_pkg + + _cuda_ver = getattr(_cuda_pkg, "__version__", None) + if _cuda_ver is None: + # cuda-python >= 13.x does not expose a top-level __version__; + # detect the version via the cuda-bindings package. + import importlib.metadata as _meta + + _cuda_ver = _meta.version("cuda-bindings") + _cuda_major = int(_cuda_ver.split(".")[0]) + if _cuda_major >= 13: + from cuda.bindings import runtime as cudart + else: + from cuda import cudart +except Exception as _e: + import warnings + + warnings.warn(f"cuda-python import failed, async_expert_loader will be unavailable: {_e}") cudart = None from fastdeploy.config import EPLBConfig @@ -98,6 +114,7 @@ def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, epl raise ImportError( "cuda-python not installed. Install the version matching your CUDA toolkit:\n" " CUDA 12.x → pip install cuda-python==12.*\n" + " CUDA 13.x → pip install cuda-python cuda-bindings\n" ) # Register memory with CUDA diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 7cd71dabb51..fd8811b1101 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -125,36 +125,33 @@ DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" -if current_platform.is_cuda(): - def async_set_value(tgt, src): - if isinstance(src, (int, float, bool)): - src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) - elif isinstance(src, (list, np.array)): - dtype_str = str(tgt.dtype).split(".")[1] - if isinstance(src, list): - src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32") +def async_set_value(tgt, src): + if isinstance(src, (int, float, bool)): + src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) + elif isinstance(src, (list, np.ndarray)): + dtype_str = str(tgt.dtype).split(".")[1] + if isinstance(src, list): + src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32") + if current_platform.is_cuda(): if str(src.dtype) != dtype_str: srt_tensor = paddle.empty(tgt.shape, dtype=str(src.dtype)) src = custom_numpy_to_tensor(src, srt_tensor) else: return custom_numpy_to_tensor(src, tgt) - elif isinstance(src, paddle.Tensor): - pass else: - raise ValueError("async_set_value unsupported src type: {}".format(type(src))) - if src.shape != tgt.shape: - src = src.reshape(tgt.shape) - if src.dtype != tgt.dtype: - src = src.cast(tgt.dtype) - if src.place != tgt.place: - src = src.to(tgt.place) - tgt.copy_(src, blocking=False) - -else: - - def async_set_value(*args, **kwargs): - raise RuntimeError("async_set_value is only available on CUDA") + src = paddle.to_tensor(src, dtype=tgt.dtype) + elif isinstance(src, paddle.Tensor): + pass + else: + raise ValueError("async_set_value unsupported src type: {}".format(type(src))) + if src.shape != tgt.shape: + src = src.reshape(tgt.shape) + if src.dtype != tgt.dtype: + src = src.cast(tgt.dtype) + if src.place != tgt.place: + src = src.to(tgt.place) + tgt.copy_(src, blocking=False) def pre_process( diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index e5a1d9419c8..9232989dd48 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -55,6 +55,29 @@ DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" +def async_set_value(tgt, src): + if isinstance(src, (int, float, bool)): + src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) + elif isinstance(src, (list, np.ndarray)): + dtype_str = str(tgt.dtype).split(".")[1] + np_dtype = dtype_str if dtype_str != "bfloat16" else "float32" + if isinstance(src, list): + src = np.array(src, dtype=np_dtype) + # TODO: support async_numpy_to_tensor + src = paddle.to_tensor(src, dtype=tgt.dtype) + elif isinstance(src, paddle.Tensor): + pass + else: + raise ValueError("async_set_value unsupported src type: {}".format(type(src))) + if src.shape != tgt.shape: + src = src.reshape(tgt.shape) + if src.dtype != tgt.dtype: + src = src.cast(tgt.dtype) + if src.place != tgt.place: + src = src.to(tgt.place) + tgt.copy_(src, blocking=False) + + def _build_stream_transfer_data( output_tokens: paddle.Tensor, pooler_outputs: List = None, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 2d8d310a469..70c626df28f 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -49,7 +49,10 @@ share_external_data, update_attn_mask_offsets, ) + + # temporary solution from fastdeploy.model_executor.xpu_pre_and_post_process import ( + async_set_value, xpu_pre_process, xpu_process_output, ) @@ -483,28 +486,32 @@ def insert_tasks_v1( input_ids = request.prompt_token_ids + request.output_token_ids self.model_inputs["input_ids_len"][idx] = length - 1 - self.model_inputs["pre_ids"][idx : idx + 1] = -1 + async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1) self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ idx : idx + 1, 1:length ] - self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ - "input_ids" - ][idx : idx + 1, 1:length].cpu() + # TODO: use token_all_ids replace with input_ids_cpu + if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs: + self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ + "input_ids" + ][idx : idx + 1, 1:length].cpu() encoder_block_num = len(request.block_tables) - self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 - self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( - request.block_tables, dtype="int32" + async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value( + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) - self.model_inputs["stop_flags"][idx : idx + 1] = False - self.model_inputs["batch_drop"][idx : idx + 1] = False - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length + async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False) + async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False) + + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True - self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length - self.model_inputs["step_idx"][idx : idx + 1] = ( - len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) + async_set_value( + self.model_inputs["step_idx"][idx : idx + 1], + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) if self.use_attn_mask_offset: inputs = request.multimodal_inputs @@ -522,18 +529,19 @@ def insert_tasks_v1( if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generates first token - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False - self.model_inputs["recompute_token_num"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1 + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1) # NOTE(liuzichang): # extra 1 : P-D split need rollback one step - self.model_inputs["mask_rollback"][idx : idx + 1] = 1 + + async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0) + async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1) # has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task encoder_block_num = len(request.block_tables) - self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables @@ -542,16 +550,13 @@ def insert_tasks_v1( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) - # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode - # has_decode_task = True - # continue else: - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 - self.model_inputs["stop_flags"][idx : idx + 1] = True - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 - self.model_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True) + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) + async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) + async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False) continue # TODO(liuzichang): Solve splitewise-p bug to restore diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 47f700e83e5..b06a8de0564 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -872,9 +872,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = input_ids = prompt_token_ids + request.output_token_ids prompt_len = len(prompt_token_ids) # prompt_tokens - self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len] = np.array( - prompt_token_ids, dtype="int64" - ) + async_set_value(self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len], prompt_token_ids) # generated_token_ids fill -1 self.share_inputs["token_ids_all"][idx : idx + 1, prompt_len:] = -1 @@ -884,33 +882,39 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.deterministic_logger.log_prefill_input( request.request_id, idx, prefill_start_index, prefill_end_index, input_ids ) - logger.debug( f"Handle prefill request {request} at idx {idx}, " f"{prefill_start_index=}, {prefill_end_index=}, " f"need_prefilled_token_num={len(input_ids)}" f"prompt_len={prompt_len}" ) - self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( - input_ids[prefill_start_index:prefill_end_index] + async_set_value( + self.share_inputs["input_ids"][idx : idx + 1, :length], + input_ids[prefill_start_index:prefill_end_index], ) encoder_block_num = len(request.block_tables) - self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 - self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( - request.block_tables, dtype="int32" + async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) + + async_set_value( + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) - self.share_inputs["stop_flags"][idx : idx + 1] = False - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index - self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + + async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], False) + + async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True - self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 - self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) - self.share_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.share_inputs["step_seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["prompt_lens"][idx : idx + 1], len(input_ids)) + + async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) - self.share_inputs["step_idx"][idx : idx + 1] = ( - len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + async_set_value( + self.share_inputs["step_idx"][idx : idx + 1], + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) # pooling model request.sampling_params is None if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: @@ -927,21 +931,37 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + # TODO: delete useless operation like this + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False - self._cached_launch_token_num = -1 + if self._cached_launch_token_num != -1: + token_num_one_step = ( + (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 + ) + self._cached_launch_token_num += token_num_one_step + self._cached_real_bsz += 1 if self.speculative_decoding: - # D speculate decode, seq_lens_this_time = length + 1 - self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + 1 - self.share_inputs["draft_tokens"][idx : idx + 1, 0 : length + 1] = paddle.to_tensor( - request.draft_token_ids[0 : length + 1], - dtype="int64", + # D first decode step, [Target first token, MTP first draft token] + # MTP in P only generate one draft token in any num_model_step config + draft_tokens_to_write = request.draft_token_ids[0:2] + if len(draft_tokens_to_write) != 2: + raise ValueError( + "Expected at least 2 draft tokens for speculative suffix decode, " + f"but got {len(draft_tokens_to_write)} for request {request.request_id}." + ) + async_set_value( + self.share_inputs["draft_tokens"][idx : idx + 1, 0:2], + draft_tokens_to_write, ) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 2) + logger.debug( + f"insert request {request.request_id} idx: {idx} suffix tokens {request.draft_token_ids}" + ) elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") encoder_block_num = len(request.block_tables) - self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables @@ -950,6 +970,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) + # CPU Tensor self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0 continue else: # preempted task @@ -958,12 +979,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = elif request.task_type.value == RequestType.ABORT.value: logger.info(f"Handle abort request {request} at idx {idx}") self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 - self.share_inputs["stop_flags"][idx : idx + 1] = True - self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 - self.share_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], True) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) + async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None @@ -971,53 +992,63 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens - self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) - - self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) - self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) - self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) - self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0) - self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) - self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) - self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) - self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) - self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False) - self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get( - "top_p_normalized_logprobs", False + self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) + async_set_value(self.share_inputs["eos_token_id"][:], request.eos_token_ids) + async_set_value(self.share_inputs["top_p"][idx : idx + 1], request.get("top_p", 0.7)) + async_set_value(self.share_inputs["top_k"][idx : idx + 1], request.get("top_k", 0)) + async_set_value(self.share_inputs["min_p"][idx : idx + 1], request.get("min_p", 0.0)) + async_set_value(self.share_inputs["temperature"][idx : idx + 1], request.get("temperature", 0.95)) + async_set_value(self.share_inputs["penalty_score"][idx : idx + 1], request.get("repetition_penalty", 1.0)) + async_set_value(self.share_inputs["frequency_score"][idx : idx + 1], request.get("frequency_penalty", 0.0)) + async_set_value(self.share_inputs["presence_score"][idx : idx + 1], request.get("presence_penalty", 0.0)) + async_set_value( + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1], request.get("temp_scaled_logprobs", False) ) - self.share_inputs["generated_modality"][idx : idx + 1] = request.get("generated_modality", 0) - - self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) - self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( - "max_tokens", self.model_config.max_model_len + async_set_value( + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1], + request.get("top_p_normalized_logprobs", False), + ) + async_set_value( + self.share_inputs["generated_modality"][idx : idx + 1], request.get("generated_modality", 0) + ) + async_set_value(self.share_inputs["min_dec_len"][idx : idx + 1], request.get("min_tokens", 1)) + async_set_value( + self.share_inputs["max_dec_len"][idx : idx + 1], + request.get("max_tokens", self.model_config.max_model_len), ) if request.get("seed") is not None: - self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + async_set_value(self.share_inputs["infer_seed"][idx : idx + 1], request.get("seed")) if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: bad_words_len = len(request.get("bad_words_token_ids")) - self.share_inputs["bad_tokens_len"][idx] = bad_words_len - self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( - request.get("bad_words_token_ids"), dtype="int64" + async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], bad_words_len) + async_set_value( + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len], request.get("bad_words_token_ids") ) else: - self.share_inputs["bad_tokens_len"][idx] = 1 - self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64") + async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], 1) + async_set_value(self.share_inputs["bad_tokens"][idx : idx + 1, :], -1) if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): request.sampling_params.stop_seqs_len.append(0) - self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( - request.sampling_params.stop_seqs_len, dtype="int32" + async_set_value( + self.share_inputs["stop_seqs_len"][idx : idx + 1, :], request.sampling_params.stop_seqs_len ) - self.share_inputs["stop_seqs"][ - idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) - ] = np.array(request.get("stop_token_ids"), dtype="int64") + # Pad each stop sequence to stop_seqs_max_len, then fill remaining rows + # and write the whole block at once to avoid partial slicing on the + # third dimension, which may cause async_set_value stride issues on + # non-contiguous memory. + stop_token_ids = request.get("stop_token_ids") + max_len = self.model_config.stop_seqs_max_len + padded = [seq + [-1] * (max_len - len(seq)) for seq in stop_token_ids] + padded.extend([[-1] * max_len] * (self.model_config.max_stop_seqs_num - stop_seqs_num)) + async_set_value(self.share_inputs["stop_seqs"][idx : idx + 1, :, :], padded) else: - self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + async_set_value(self.share_inputs["stop_seqs_len"][idx : idx + 1, :], 0) self.pooling_params = batch_pooling_params # For logits processors @@ -1026,7 +1057,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) self._process_mm_features(req_dicts) - if len(rope_3d_position_ids["position_ids_idx"]) > 0: + + if len(rope_3d_position_ids["position_ids_idx"]) > 0 and self.enable_mm: packed_position_ids = paddle.to_tensor( np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64" ) diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index f0f44ea68c6..008eafcaf62 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -14,7 +14,7 @@ import unittest from dataclasses import dataclass -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import paddle @@ -773,5 +773,162 @@ def test_execute_model_overlap_zero_output_flushes_preempted_batch(self): self.assertEqual(runner._cached_real_bsz, 22) +def _sync_async_set_value(tgt, src): + """Synchronous stand-in for async_set_value used in tests (no CUDA required). + + Writes to real numpy arrays; silently skips Mock objects (untracked share_inputs + fields whose values we do not assert on). + """ + from unittest.mock import MagicMock + + import numpy as np + + if isinstance(tgt, MagicMock): + return # untracked field — nothing to write + if isinstance(src, (int, float, bool)): + tgt[:] = src + elif isinstance(src, (list, np.ndarray)): + tgt[:] = np.array(src).reshape(tgt.shape) + elif hasattr(src, "numpy"): + tgt[:] = src.numpy() + else: + tgt[:] = src + + +class TestInsertTasksV1SplitwiseSuffix(unittest.TestCase): + """Tests for insert_tasks_v1 splitwise_role=\'decode\' + SpecMethod.SUFFIX branch.""" + + def _make_share_inputs(self, bsz=4, max_draft=6): + """Mock-backed share_inputs; only keys we assert on hold real numpy arrays.""" + import numpy as np + + # Keys whose values we want to inspect after the call + tracked = { + "seq_lens_encoder": np.zeros((bsz, 1), dtype=np.int32), + "draft_tokens": np.zeros((bsz, max_draft), dtype=np.int64), + "seq_lens_this_time_buffer": np.zeros((bsz, 1), dtype=np.int32), + "req_ids": [""] * bsz, + "preempted_idx": np.zeros((bsz, 1), dtype=np.int32), + "num_running_requests": 0, + "running_requests_ids": [], + } + + class _SI: + def get_index_by_batch_id(self, batch_id): + return batch_id + + def __getitem__(self, key): + # Return real array for tracked keys; Mock for everything else + if key in tracked: + return tracked[key] + return MagicMock() + + def __setitem__(self, key, value): + tracked[key] = value + + return _SI() + + def _make_runner(self, bsz=4, num_spec_tokens=3): + from unittest.mock import Mock + + from fastdeploy.spec_decode import SpecMethod + from fastdeploy.worker.gpu_model_runner import GPUModelRunner + + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.enable_mm = False + runner.is_pooling_model = False + runner.speculative_decoding = True + runner.spec_method = SpecMethod.SUFFIX + runner.speculative_config = Mock(num_speculative_tokens=num_spec_tokens) + runner.deterministic_logger = None + runner.routing_replay_manager = Mock() + runner.prompt_logprobs_reqs = {} + runner.in_progress_prompt_logprobs = {} + runner.forward_batch_reqs_list = [None] * bsz + runner._cached_launch_token_num = -1 + runner._cached_real_bsz = 0 + runner.exist_prefill_flag = True + runner.proposer = Mock() + runner.sampler = Mock() + runner.model_config = Mock(eos_tokens_lens=1) + runner.share_inputs = self._make_share_inputs(bsz=bsz, max_draft=num_spec_tokens + 2) + + fd_config = Mock() + fd_config.scheduler_config.splitwise_role = "decode" + fd_config.routing_replay_config.enable_routing_replay = False + runner.fd_config = fd_config + runner.scheduler_config = fd_config.scheduler_config + return runner + + def _make_prefill_request(self, idx, draft_token_ids): + from unittest.mock import Mock + + from fastdeploy.engine.request import RequestType + + req = Mock() + req.task_type = Mock(value=RequestType.PREFILL.value) + req.idx = idx + req.request_id = f"req_{idx}" + req.prompt_token_ids = [10, 20, 30] + req.output_token_ids = [99] + req.draft_token_ids = draft_token_ids + req.pooling_params = None + req.guided_json = None + req.guided_regex = None + req.structural_tag = None + req.guided_grammar = None + req.prefill_start_index = 0 + req.prefill_end_index = 3 + req.multimodal_inputs = None + req.get = Mock(return_value=None) + req.eos_token_ids = [2] + req.block_tables = [] + return req + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_draft_tokens_and_seq_lens_written(self, _mock_asv): + """draft_tokens[0:2] and seq_lens_this_time_buffer=2 are written.""" + runner = self._make_runner(num_spec_tokens=3) + req = self._make_prefill_request(idx=0, draft_token_ids=[101, 202, 303]) + runner.insert_tasks_v1([req], num_running_requests=1) + + self.assertEqual(runner.share_inputs["draft_tokens"][0, 0], 101) + self.assertEqual(runner.share_inputs["draft_tokens"][0, 1], 202) + self.assertEqual(runner.share_inputs["seq_lens_this_time_buffer"][0, 0], 2) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_exist_prefill_flag_cleared(self, _mock_asv): + runner = self._make_runner() + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + self.assertFalse(runner.exist_prefill_flag) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_cached_launch_token_num_incremented(self, _mock_asv): + runner = self._make_runner(num_spec_tokens=3) + runner._cached_launch_token_num = 10 + runner._cached_real_bsz = 2 + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + # token_num_one_step = num_speculative_tokens + 1 = 4 + self.assertEqual(runner._cached_launch_token_num, 14) + self.assertEqual(runner._cached_real_bsz, 3) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_cached_launch_token_num_skipped_when_negative_one(self, _mock_asv): + runner = self._make_runner(num_spec_tokens=3) + runner._cached_launch_token_num = -1 + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + self.assertEqual(runner._cached_launch_token_num, -1) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_raises_when_fewer_than_two_draft_tokens(self, _mock_asv): + runner = self._make_runner() + req = self._make_prefill_request(idx=0, draft_token_ids=[42]) + with self.assertRaises(ValueError): + runner.insert_tasks_v1([req], num_running_requests=1) + + if __name__ == "__main__": unittest.main() From 4e7a46e46784b78d858d7a08535b6803ee3b8c9f Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 13 May 2026 15:27:03 +0800 Subject: [PATCH 102/143] prepare request in prefill instance by multi threads (#7724) --- fastdeploy/cache_manager/cache_messager.py | 38 ++- fastdeploy/engine/common_engine.py | 215 +------------ .../engine/common_engine_prepare_mixin.py | 282 ++++++++++++++++++ fastdeploy/envs.py | 2 + .../inter_communicator/engine_worker_queue.py | 106 ------- tests/cache_manager/test_cache_messager.py | 24 +- tests/engine/test_common_engine.py | 95 ++---- tests/inter_communicator/test_e2w_queue.py | 26 +- 8 files changed, 374 insertions(+), 414 deletions(-) create mode 100644 fastdeploy/engine/common_engine_prepare_mixin.py diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index b934c3e74c7..08c8dea003a 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -613,12 +613,16 @@ def __init__( ) self.gpu_id = gpu_id - self.cache_info = dict() + self.cache_info = dict() # {'request_id': cache_info_dict} self.rank_id = self.rank + local_data_parallel_id * self.nranks self.engine_cache_task_thread_lock = threading.Lock() - self.engine_cache_tasks = [dict() for _ in range(512)] - self.idx_cache_task_dict = {} - self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step + self.engine_cache_tasks = [ + dict() for _ in range(512) + ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} + self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} + self.cache_prefilled_engine_ids_queue = ( + queue.Queue() + ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] if splitwise_role == "prefill": consume_signals_thread = threading.Thread(target=self.consume_signals) consume_signals_thread.daemon = True @@ -638,7 +642,6 @@ def _add_cache_task_thread(self): while True: try: cache_info = self.engine_worker_queue.get_cache_info() - finished_add_cache_task_req_ids = [] if cache_info: logger.debug(f"Get cache info from engine worker queue, {cache_info}") self.engine_worker_queue.cache_info_barrier.wait() @@ -647,7 +650,6 @@ def _add_cache_task_thread(self): self.cache_info[info["request_id"]].update(info) current_info = self.cache_info[info["request_id"]] assert "dest_block_ids" in current_info and "src_block_ids" in current_info - finished_add_cache_task_req_ids.append(info["request_id"]) decode_cached_block_num = len(current_info["src_block_ids"]) - len( current_info["dest_block_ids"] ) @@ -659,17 +661,13 @@ def _add_cache_task_thread(self): current_info["sended_layer_id"] = -1 current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["status"] = "init" - logger.info(f"Get cache info from D: finish add cache task: {current_info}") + logger.info(f"Get cache info and finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info self.idx_cache_task_dict[current_info["current_id"]] = current_info else: - logger.info(f"Get cache info from P: {info}") + logger.info(f"Get cache info: {info}") self.cache_info[info["request_id"]] = info - if finished_add_cache_task_req_ids: - logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}") - self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) - self.engine_worker_queue.finish_add_cache_task_barrier.wait() else: time.sleep(0.001) except Exception as e: @@ -687,10 +685,12 @@ def prefill_layerwise_send_cache_thread(self): block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: + self._maybe_wait_for_cache_task(engine_index) assert ( engine_index in self.idx_cache_task_dict ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}" block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] + prefilled_token_num = current_step_prefilled_token_num if ( prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] @@ -917,6 +917,20 @@ def _handle_connect_task(self): except Exception as e: logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}") + def _maybe_wait_for_cache_task(self, engine_index): + # If cache messager does not get cache task from engine, just hang here for now + wait_step = 1 + sleep_seconds = 0.005 + + while engine_index not in self.idx_cache_task_dict: + time.sleep(sleep_seconds) + wait_step += 1 + + if wait_step % 400 == 0: + logger.warning( + f"waiting cache task for engine_index: {engine_index}, cost_time: {wait_step * 0.005:.2f} s" + ) + def main(): device = args.device_id diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 886f275d098..65bfa6d606b 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -30,7 +30,6 @@ import time import traceback import weakref -from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -42,6 +41,7 @@ import fastdeploy.metrics.trace as tracing from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.config import FDConfig +from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( CompletionOutput, @@ -115,7 +115,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str: return message -class EngineService: +class EngineService(EngineServicePrepareMixin): """ Base class containing common engine functionality """ @@ -251,12 +251,13 @@ def start(self, async_llm_pid=None): self.start_worker_service(async_llm_pid) if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread( - target=self._schedule_request_to_worker_v1, daemon=True - ) + self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True) + self.prepare_request_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True) + self.schedule_request_thread.start() else: - self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) - self.insert_task_to_worker_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) + self.schedule_request_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role == "decode": @@ -879,215 +880,19 @@ def _schedule_request_to_worker_v1(self): Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ tracing.trace_set_thread_info("Scheduler Task to Work") - get_request_pool = ThreadPoolExecutor(max_workers=1) - is_fetching = False - - def _fetch_request(): - try: - with self._pause_cond: - self._pause_cond.wait_for(lambda: not self.is_paused) - nonlocal is_fetching - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - if self.cfg.scheduler_config.splitwise_role != "mixed": - max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens - else: - max_num_batched_tokens = self.cfg.model_config.max_model_len - - available_blocks = self.cfg.cache_config.max_block_num_per_seq - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - for task in tasks: - task.metrics.engine_get_req_time = time.time() - trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) - - if self.cfg.scheduler_config.splitwise_role == "decode": - # TODO: refine scheduler to remove this limitation - # Decode will process and schedule the request sent by prefill to engine, - # so the same request sent by the decode api server will be ignored - is_fetching = False - return - - if tasks: - self.llm_logger.debug( - f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" - ) - - if self.cfg.scheduler_config.splitwise_role == "prefill": - for task in tasks: - # start async preprocess - self.resource_manager.apply_async_preprocess(task) - need_delete_tasks = [] - if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for request: {task.request_id}" - ) - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") - ) - task.metrics.ask_decode_resource_start_time = time.time() - while True: - self.split_connector.send_splitwise_tasks([task], task.idx) - status, msg = self.split_connector.check_decode_allocated(task) - if not status: - self.llm_logger.warning( - f"D failed to allocate resource for request {task.request_id}, try again." - ) - time.sleep(0.05) - else: - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_END, - task.request_id, - getattr(task, "user", ""), - ) - break - self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") - else: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for req_id: {task.request_id}" - ) - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "") - ) - task.metrics.ask_decode_resource_start_time = time.time() - self.split_connector.send_splitwise_tasks([task], task.idx) - - for task in tasks: - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - task.metrics.ask_decode_resource_finish_time = time.time() - trace_print( - LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "") - ) - if not status: - error_msg = ( - f"PD Error: prefill failed to apply for resource from decode, " - f"req: {task.request_id}, msg:{msg}." - ) - self.llm_logger.error(error_msg) - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - error_msg=error_msg, - ) - ] - ) - main_process_metrics.reschedule_req_num.inc() - need_delete_tasks.append(task) - continue - for tmp_task in need_delete_tasks: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # to send cache info to cache messager - if tasks: - need_check_req_ids = [task.request_id for task in tasks] - self.split_connector.send_cache_info_to_messager(tasks, 0) - # ensure cache tasks has sent to cache_messager - need_check_req_ids = [task.request_id for task in tasks] - finished_ids, delete_tasks_list = [], [] - while need_check_req_ids: - finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) - self.llm_logger.debug( - f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" - ) - if finished_ids: - for task in tasks: - result = self.resource_manager.waiting_async_process(task) - if result is None: - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=task.error_code, - error_msg=task.error_message, - ) - ] - ) - need_check_req_ids.remove(task.request_id) - delete_tasks_list.append(task) - elif result is False: - if task.request_id in finished_ids: - need_check_req_ids.remove(task.request_id) - finished_ids.remove(task.request_id) - else: - time.sleep(0.001) - - for tmp_task in delete_tasks_list: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # Fetch requests and add them to the scheduling queue - if tasks: - for task in tasks: - task.metrics.add_req_to_resource_manager_time = time.time() - trace_print( - LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") - ) - if self.cfg.scheduler_config.splitwise_role == "prefill": - self.resource_manager.add_request_in_p(tasks) - self.llm_logger.info( - f"P add requests into running queue: {[task.request_id for task in tasks]}" - ) - else: - for task in tasks: - self.resource_manager.add_request(task) - is_fetching = False - except Exception as e: - self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") - is_fetching = False while self.running: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) + try: if self.engine_worker_queue.exist_tasks(): time.sleep(0.001) continue - if self.cfg.scheduler_config.splitwise_role != "mixed": - if not is_fetching: - is_fetching = True - get_request_pool.submit(_fetch_request) - - else: - if len(self.resource_manager.waiting) == 0 and (not is_fetching): - # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. - try: - is_fetching = True - get_request_pool.submit(_fetch_request) - except RuntimeError as e: - if "shutdown" in str(e): - self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") - break - else: - raise if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() + # 2. Schedule requests tasks, error_tasks = self.resource_manager.schedule() diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py new file mode 100644 index 00000000000..71327025458 --- /dev/null +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -0,0 +1,282 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import threading +import time +import traceback + +import fastdeploy.metrics.trace as tracing +from fastdeploy.engine.request import RequestOutput +from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.trace.constants import LoggingEventName +from fastdeploy.trace.trace_logger import print as trace_print +from fastdeploy.utils import envs + + +class EngineServicePrepareMixin: + def _fetch_request_mixed(self) -> bool: + """Fetch and prepare requests for a mixed instance. Returns True if tasks were fetched.""" + # FIXME: to validate if it's necessary for avoiding error when enable mtp + if len(self.resource_manager.waiting) > 0: + return False + + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.model_config.max_model_len + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request(task) + + return True + + def _fetch_request_decode(self) -> bool: + """Consume scheduler queue for decode instance to prevent memory accumulation. + Returns True if tasks were consumed.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + # Tasks are intentionally discarded - decode receives requests via _decode_process_splitwise_requests + return len(tasks) > 0 + + def _fetch_request_prefill(self) -> bool: + """Fetch and prepare requests for a prefill instance. Returns True if tasks were fetched.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + # Start async preprocess for all tasks in this batch + for task in tasks: + self.resource_manager.apply_async_preprocess(task) + + # P-side resource preallocation + D-side coordination + failed_tasks = [] + if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for request: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + while True: + self.split_connector.send_splitwise_tasks([task], task.idx) + status, msg = self.split_connector.check_decode_allocated(task) + if status: + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, + task.request_id, + getattr(task, "user", ""), + ) + break + else: + self.llm_logger.warning( + f"D failed to allocate resource for request {task.request_id}, try again." + ) + time.sleep(0.05) + + self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") + else: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for req_id: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + self.split_connector.send_splitwise_tasks([task], task.idx) + + for task in tasks: + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "")) + if not status: + error_msg = ( + f"PD Error: prefill failed to apply for resource from decode, " + f"req: {task.request_id}, msg:{msg}." + ) + self.llm_logger.error(error_msg) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=error_msg, + ) + ] + ) + main_process_metrics.reschedule_req_num.inc() + failed_tasks.append(task) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Check and wait async preprocess + if tasks: + need_check_req_ids = [task.request_id for task in tasks] + failed_tasks = [] + + while need_check_req_ids: + still_in_progress = False + for task in tasks: + if task.request_id not in need_check_req_ids: + continue + + result = self.resource_manager.waiting_async_process(task) + if result is False: # async preprocess success + need_check_req_ids.remove(task.request_id) + elif result is True: + still_in_progress = True + elif result is None: # async preprocess failed + failed_tasks.append(task) + need_check_req_ids.remove(task.request_id) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=task.error_code, + error_msg=task.error_message, + ) + ] + ) + + if still_in_progress: + time.sleep(0.005) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Send cache info to messager + if tasks: + self.split_connector.send_cache_info_to_messager(tasks, 0) + + # Fetch requests and add them to the scheduling queue + if tasks: + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request_in_p(tasks) + self.llm_logger.info(f"P add requests into running queue: {[task.request_id for task in tasks]}") + + return True + + def _fetch_loop(self, fetch_fn, thread_idx: int): + """Fetch loop run by each worker thread.""" + tracing.trace_set_thread_info(f"Prepare Request for Scheduling - thread {thread_idx}") + while self.running: + try: + with self._pause_cond: + self._pause_cond.wait_for(lambda: not self.is_paused) + fetch_fn() + time.sleep(0.002) + except Exception as e: + self.llm_logger.error(f"fetching request error in worker-{thread_idx}: {e} {traceback.format_exc()}") + time.sleep(0.002) + + def _prepare_request_v1(self): + """Prepare request and send to the queue for scheduling""" + tracing.trace_set_thread_info("Prepare Request for Scheduling") + role = self.cfg.scheduler_config.splitwise_role + num_workers = envs.FD_PREFILL_PREPARE_REQ_THREAD_NUM if role == "prefill" else 1 + self.llm_logger.info(f"prepare request for scheduling, role: {role}, num_workers: {num_workers}") + + fetch_fn = { + "mixed": self._fetch_request_mixed, + "prefill": self._fetch_request_prefill, + "decode": self._fetch_request_decode, + }[role] + + self._fetch_threads = [] + for i in range(num_workers): + t = threading.Thread( + target=self._fetch_loop, + args=(fetch_fn, i), + daemon=True, + name=f"fetch-{i}", + ) + t.start() + self._fetch_threads.append(t) + + # Keep this thread alive for graceful shutdown + while self.running: + time.sleep(1.0) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 08cef5849e1..734a04ab484 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -191,6 +191,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), + # Number of worker threads for prepare requests in prefill instance + "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "5")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index b0fc9bb3385..2cb1246aad3 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -92,7 +92,6 @@ class QueueManager(BaseManager): Value("i", 0) for _ in range(self.local_data_parallel_size) ] self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)] - self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] @@ -110,9 +109,6 @@ class QueueManager(BaseManager): self.connect_task_response_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # connect rdma task response - self.finish_add_cache_task_lock_init: List[threading.Lock] = [ - threading.Lock() for _ in range(self.local_data_parallel_size) - ] # finish add cache task self.finish_send_cache_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # finish send cache @@ -124,18 +120,12 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] - self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [ - [0] * self.num_client for _ in range(self.local_data_parallel_size) - ] self.client_get_finish_send_cache_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] self.can_put_next_connect_task_response_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] - self.can_put_next_add_task_finished_flag_init: List[Value] = [ - Value("i", 1) for _ in range(self.local_data_parallel_size) - ] self.can_put_next_send_cache_finished_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] @@ -147,9 +137,6 @@ class QueueManager(BaseManager): self.get_connect_task_response_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] - self.finish_add_cache_task_barrier = [ - threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) - ] self.begin_send_cache_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] @@ -188,11 +175,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.client_get_connect_task_response_flag_init[idx], proxytype=ListProxy, ) - QueueManager.register( - "get_client_get_finished_add_cache_task_flag_init", - callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx], - proxytype=ListProxy, - ) QueueManager.register( "get_client_get_finish_send_cache_flag_init", callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx], @@ -218,11 +200,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx], proxytype=ValueProxy, ) - QueueManager.register( - "get_can_put_next_add_task_finished_flag", - callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx], - proxytype=ValueProxy, - ) QueueManager.register( "get_can_put_next_send_cache_finished_flag", callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx], @@ -239,11 +216,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.connect_task_response_lock_init[idx], proxytype=AcquirerProxy, ) - QueueManager.register( - "get_finish_add_cache_task_lock", - callable=lambda idx: self.finish_add_cache_task_lock_init[idx], - proxytype=AcquirerProxy, - ) QueueManager.register( "get_finish_send_cache_lock", callable=lambda idx: self.finish_send_cache_lock_init[idx], @@ -268,12 +240,6 @@ class QueueManager(BaseManager): "get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy ) - QueueManager.register( - "get_finish_add_cache_task_queue", - callable=lambda idx: self.finished_add_cache_task_list[idx], - proxytype=ListProxy, - ) - QueueManager.register( "get_cache_infos", callable=lambda idx: self.cache_infos_init[idx], @@ -321,12 +287,6 @@ class QueueManager(BaseManager): "get_cache_info_barrier", callable=lambda idx: self.get_cache_info_barrier[idx], ) - - QueueManager.register( - "get_finish_add_cache_task_barrier", - callable=lambda idx: self.finish_add_cache_task_barrier[idx], - ) - QueueManager.register( "get_worker_process_tp_barrier", callable=lambda idx: self.worker_process_tp_barrier[idx], @@ -351,13 +311,11 @@ class QueueManager(BaseManager): QueueManager.register("get_exist_tasks_inter_signal") QueueManager.register("get_connected_client_counter") QueueManager.register("get_finish_request_queue") - QueueManager.register("get_finish_add_cache_task_queue") QueueManager.register("get_cache_infos") QueueManager.register("get_client_read_info_flag") QueueManager.register("get_lock_info") QueueManager.register("get_disaggregate_requests") QueueManager.register("get_finish_request_barrier") - QueueManager.register("get_finish_add_cache_task_barrier") QueueManager.register("get_connect_task_barrier") QueueManager.register("get_connect_task_response_barrier") QueueManager.register("get_finish_send_cache_barrier") @@ -366,16 +324,13 @@ class QueueManager(BaseManager): QueueManager.register("get_connect_rdma_tasks") QueueManager.register("get_client_get_connect_task_flag") QueueManager.register("get_client_get_connect_task_response_flag") - QueueManager.register("get_client_get_finished_add_cache_task_flag_init") QueueManager.register("get_client_get_finish_send_cache_flag_init") QueueManager.register("get_connect_rdma_tasks_responses") QueueManager.register("get_connect_task_lock") QueueManager.register("get_connect_task_response_lock") - QueueManager.register("get_finish_add_cache_task_lock") QueueManager.register("get_finish_send_cache_lock") QueueManager.register("get_worker_process_tp_barrier") QueueManager.register("get_can_put_next_connect_task_response_flag") - QueueManager.register("get_can_put_next_add_task_finished_flag") QueueManager.register("get_can_put_next_send_cache_finished_flag") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -398,9 +353,6 @@ class QueueManager(BaseManager): # p/d 分离获取 self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id) self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) - self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( - self.local_data_parallel_id - ) self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id) self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier( self.local_data_parallel_id @@ -410,9 +362,6 @@ class QueueManager(BaseManager): self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id) self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id) self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id) - self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue( - self.local_data_parallel_id - ) # p/d互联 self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag( @@ -421,9 +370,6 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag( self.local_data_parallel_id ) - self.client_get_finished_add_cache_task_flag = ( - self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id) - ) self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init( self.local_data_parallel_id ) @@ -433,12 +379,8 @@ class QueueManager(BaseManager): ) self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id) - self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id) self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id) - self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag( - self.local_data_parallel_id - ) self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag( self.local_data_parallel_id ) @@ -756,54 +698,6 @@ def get_finished_req(self) -> str: self.finish_send_cache_lock.release() return response - def put_finished_add_cache_task_req(self, req_ids) -> None: - """ - Put finished request ID into the queue. - - Args: - req_ids: Request ID to be added to the queue - """ - self.finish_add_cache_task_lock.acquire() - while not self.can_put_next_add_task_finished_flag.get(): - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - self.finished_add_cache_task_list.append(req_ids) - self.client_get_finished_add_cache_task_flag[self.client_id] = 1 - all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client - if all_client_put: - self.can_put_next_add_task_finished_flag.set(0) - self.finish_add_cache_task_lock.release() - return all_client_put - - def get_finished_add_cache_task_req(self) -> str: - """ - Get finished request ID from the queue. - - Returns: - str: Finished request ID - """ - response = [] - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) == 0: - self.finish_add_cache_task_lock.release() - return response - while sum(self.client_get_finished_add_cache_task_flag) < self.num_client: - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) > 0: - response = self.finished_add_cache_task_list[0] - for tmp_response in self.finished_add_cache_task_list: - assert ( - tmp_response == response - ), f"Inconsistent responses across workers: expected {response}, got {tmp_response}" - self.finished_add_cache_task_list[:] = list() - self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client - self.can_put_next_add_task_finished_flag.set(1) - self.finish_add_cache_task_lock.release() - return response - def disaggregate_queue_empty(self): """ Check if the disaggregated task queue is empty. diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index c69d27a24fa..3e415ebe9c8 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import os import sys import types @@ -22,7 +24,21 @@ if not hasattr(paddle, "enable_compat"): paddle.enable_compat = lambda *args, **kwargs: None -from fastdeploy.cache_manager import cache_messager +# Import the legacy cache_messager module directly from the .py file, +# because the cache_messager/ package shadows it and the legacy +# fallback (cache_messager_legacy) does not exist locally. +_cm_py_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "fastdeploy", + "cache_manager", + "cache_messager.py", +) +_spec = importlib.util.spec_from_file_location( + "fastdeploy.cache_manager.cache_messager_py", + _cm_py_path, +) +cache_messager = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(cache_messager) class _DummyBarrier: @@ -40,12 +56,10 @@ def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwarg self.cache_info_calls = 0 self.connect_task_calls = 0 self.cache_info_barrier = _DummyBarrier() - self.finish_add_cache_task_barrier = _DummyBarrier() self.finish_send_cache_barrier = _DummyBarrier() self.connect_task_barrier = _DummyBarrier() self.connect_task_response_barrier = _DummyBarrier() self.begin_send_cache_barrier = _DummyBarrier() - self.finished_add_cache_task_req_ids = [] self.finished_req_payloads = [] self.connect_task_responses = [] @@ -56,9 +70,6 @@ def get_cache_info(self): self.cache_info_calls += 1 return info - def put_finished_add_cache_task_req(self, req_ids): - self.finished_add_cache_task_req_ids.append(req_ids) - def put_finished_req(self, payload): self.finished_req_payloads.append(payload) @@ -376,7 +387,6 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): } with pytest.raises(SystemExit): messager._add_cache_task_thread() - assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]] assert messager.cache_info["req-2"]["status"] == "init" diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 84dbe2ce3c5..833bd5008da 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -579,6 +579,7 @@ def test_start_prefill_branch_cache_manager_and_worker_dead(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -624,6 +625,7 @@ def test_start_mixed_branch_cache_after_load_and_zmq(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -1379,21 +1381,18 @@ def test_schedule_request_to_worker_v1_mixed_single_iteration(self): task = Request(request_id="v1_r0", prompt_token_ids=[1], prompt_token_ids_len=1) task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) - eng.resource_manager = self._make_v1_decode_rm(eng, ([], []), with_add_request=True) + eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []), with_add_request=True) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False - eng.resource_manager.add_request.assert_called_once_with(task) + eng.engine_worker_queue.put_tasks.assert_called_once() self._detach_finalizer(eng) def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): @@ -1413,7 +1412,6 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(return_value=[]), ) eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) @@ -1425,11 +1423,13 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): try: with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", False), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", + False, + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() finally: eng.running = False @@ -1450,17 +1450,14 @@ def test_schedule_request_to_worker_v1_decode_preempted_and_errors(self): task.task_type = RequestType.PREEMPTED task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1484,16 +1481,13 @@ def test_schedule_request_to_worker_v1_decode_prefill_task_path(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1515,23 +1509,20 @@ def test_schedule_request_to_worker_v1_error_task_none_skips_send(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() eng.engine_worker_queue.put_tasks.assert_called_once() eng._send_error_response.assert_not_called() self._detach_finalizer(eng) - def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): + def test_schedule_request_to_worker_v1_no_tasks_sleeps(self): eng = self._make_mixed_engine() self._setup_v1_engine(eng) @@ -1539,17 +1530,7 @@ def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): eng.resource_manager = self._make_v1_decode_rm(eng, ([], [])) - class DummyExecutor: - def __init__(self, max_workers=None): - pass - - def submit(self, fn): - raise RuntimeError("cannot schedule new futures after shutdown") - - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", DummyExecutor), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() self._detach_finalizer(eng) @@ -1572,17 +1553,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_cache_success(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_ok"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1592,11 +1564,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.split_connector.send_splitwise_tasks.assert_called() eng.split_connector.send_cache_info_to_messager.assert_called_once() @@ -1624,17 +1597,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_wait_async_none(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=None) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_fail"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1644,11 +1608,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.scheduler.put_results.assert_called_once() eng.resource_manager.pre_recycle_resource.assert_called_once_with("pc_fail") diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index d3cd657f01a..333249cc66d 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ b/tests/inter_communicator/test_e2w_queue.py @@ -301,15 +301,15 @@ def test_wait_loops_and_tensor_conversion(self): client.get_finished_req() thread.join() - client.can_put_next_add_task_finished_flag.set(0) - thread = self._set_value_after_delay(client.can_put_next_add_task_finished_flag, 1) - client.put_finished_add_cache_task_req(["req-wait"]) + client.can_put_next_send_cache_finished_flag.set(0) + thread = self._set_value_after_delay(client.can_put_next_send_cache_finished_flag, 1) + client.put_finished_req([["req-wait", {"status": "ok"}]]) thread.join() - client.finished_add_cache_task_list.append(["req-wait"]) - client.client_get_finished_add_cache_task_flag[:] = [0] - thread = self._set_list_after_delay(client.client_get_finished_add_cache_task_flag, [1]) - client.get_finished_add_cache_task_req() + client.finished_send_cache_list.append(["req-wait", {"error": "bad"}]) + client.client_get_finish_send_cache_flag[:] = [0] + thread = self._set_list_after_delay(client.client_get_finish_send_cache_flag, [1]) + client.get_finished_req() thread.join() finally: paddle.set_device(previous_device) @@ -361,18 +361,6 @@ def test_finished_req_flow(self): finally: self._cleanup_queue_pair(server) - def test_finished_add_cache_task_req(self): - server, client = self._build_queue_pair() - try: - req_ids = ["req-2"] - self.assertTrue(client.put_finished_add_cache_task_req(req_ids)) - client.finished_add_cache_task_list.append(req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), []) - self.assertEqual(client.can_put_next_add_task_finished_flag.get(), 1) - finally: - self._cleanup_queue_pair(server) - def test_disaggregated_queue(self): server, client = self._build_queue_pair() try: From d38eeb8334331b903fb34c9cd270ff0c58a48adf Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Wed, 13 May 2026 15:50:31 +0800 Subject: [PATCH 103/143] [Scheduler] [Optimization] Only preempt decode requests and better manage reserved blocks in scheduler (#7444) (#7783) * [Optimization] Use new_token_ratio to control reserved blocks in scheduler * Only decode req can be preempted * Optimize scheduler for chunk prefill * [chore] add env var to switch reserved blocks policy * [chore] remove useless code * [test] fix some ci test * [test] fix embedding serving * [fix] fix for cache manager v1 * [fix] fix for cache manager v1 * [test] fix test_resource_manager_v1.py * [opt] stepped scheduling after model forward in mixed mode * Revert "[opt] stepped scheduling after model forward in mixed mode" This reverts commit 40f774eb3eb63dfd791859b51250b641308ae692. * [chore] remove unused code --------- Co-authored-by: juncaipeng <13006307475@163.com> Co-authored-by: rainyfly <1435317881@qq.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> --- fastdeploy/engine/request.py | 9 +- .../engine/sched/resource_manager_v1.py | 211 ++++++++++++------ fastdeploy/envs.py | 8 + fastdeploy/output/token_processor.py | 3 + tests/engine/test_resource_manager_v1.py | 4 +- tests/output/test_process_batch_output.py | 3 +- tests/output/test_token_processor.py | 15 +- tests/v1/test_resource_manager_v1.py | 5 +- 8 files changed, 183 insertions(+), 75 deletions(-) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 7ecac8ff126..05ee4a348ea 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -50,10 +50,11 @@ class RequestStatus(Enum): WAITING = 0 - RUNNING = 1 - PREEMPTED = 2 - FINISHED = 3 - ABORT = 4 + RUNNING_PREFILL = 1 + RUNNING_DECODE = 2 + PREEMPTED = 3 + FINISHED = 4 + ABORT = 5 class RequestType(Enum): diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 9c18f02da12..e25ed0e1231 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -218,18 +218,23 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.bos_client = None self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4) - self.init_reserve_output_block_num = ( - envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # int - self.decay_output_block_num = ( - envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # float - self.min_reserve_output_block_num = ( - envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # int - self.current_reserve_output_block_num = self.init_reserve_output_block_num - self.current_reserve_output_block_num_float = self.init_reserve_output_block_num - self.can_relax_prefill_strategy = True + self.use_new_token_ratio_reserve = envs.FD_USE_NEW_TOKEN_RATIO_RESERVE + if self.use_new_token_ratio_reserve: + self.init_new_token_ratio = envs.FD_INIT_NEW_TOKEN_RATIO + self.min_new_token_ratio = envs.FD_MIN_NEW_TOKEN_RATIO + self.new_token_ratio_decay = envs.FD_NEW_TOKEN_RATIO_DECAY + self.clip_max_new_tokens = envs.FD_CLIP_MAX_NEW_TOKENS + self.new_token_ratio = self.init_new_token_ratio + else: + self.init_reserve_output_block_num = envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + self.decay_output_block_num = envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + self.min_reserve_output_block_num = ( + envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + ) + self.current_reserve_output_block_num = self.init_reserve_output_block_num + self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num) + self.can_relax_prefill_strategy = True + # Scheduler-side requests that have not been moved into resource manager waiting queue yet. self.scheduler_unhandled_request_num = 0 @@ -312,15 +317,6 @@ def _info_each_block(self): f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables" ) - def _can_preempt(self): - """ - cannot preempt request which use extend block - """ - for req in self.running: - if not req.use_extend_tables: - return True - return False - def preempted_all(self): with self.lock: preempted_reqs = [] @@ -355,17 +351,49 @@ def wait_worker_inflight_requests_finish(self, timeout=60): f"still {len(self.to_be_rescheduled_request_id_set)} requests running" ) + def _select_preempt_candidate(self): + # Scan from back to front to find the last preemptable request + preempted_req = None + i = len(self.running) - 1 + while i >= 0: + candidate = self.running[i] + # Skip requests that are not in decode status + if candidate.status != RequestStatus.RUNNING_DECODE: + i -= 1 + continue + # Skip requests using extend tables + if candidate.use_extend_tables: + i -= 1 + continue + # Found a valid preempt target + preempted_req = candidate + break + return preempted_req, i + def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs): """ If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out. + Only requests that is in decode status can be preempted. """ can_schedule = False - while self._can_preempt(): - if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): - preempted_req = self.running.pop() - if preempted_req.use_extend_tables: - self.running.insert(0, preempted_req) - continue + while True: + if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): + # The request can be scheduled. + can_schedule = True + break + else: + # Try to find a candidate request to preempt. + preempted_req, preempted_idx = self._select_preempt_candidate() + if preempted_req is None: + can_schedule = False + llm_logger.warning( + f"Preemption is triggered while no preemptable request can be found, scheduler may be hung! " + f"Running requests: {self.running}" + ) + break + + # Remove the preempted request from the running list + self.running.pop(preempted_idx) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.config.scheduler_config.splitwise_role == "decode": @@ -397,33 +425,82 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re llm_logger.debug( f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}" ) + llm_logger.debug(self.info()) self._info_each_block() + self._reset_reserve_on_preemption() if preempted_req == request: # No more request to preempt. can_schedule = False break - else: - # The request can be scheduled. - can_schedule = True - break - self.current_reserve_output_block_num = self.init_reserve_output_block_num - self.current_reserve_output_block_num_float = self.init_reserve_output_block_num - self.can_relax_prefill_strategy = False + return can_schedule + def _reset_reserve_on_preemption(self): + """Reset reserved blocks on preemption.""" + if self.use_new_token_ratio_reserve: + if not self.running: + self.new_token_ratio = self.init_new_token_ratio + return + total_decoded_tokens = sum(len(req.output_token_ids) for req in self.running) + total_max_new_tokens = 0 + for req in self.running: + max_tokens = req.sampling_params.max_tokens + if max_tokens is None: + max_tokens = self.config.model_config.max_model_len - req.prompt_token_ids_len + total_max_new_tokens += max_tokens + num_running_decode = sum( + [1 if req.num_total_tokens > req.need_prefill_tokens else 0 for req in self.running] + ) + extra_decode_steps = ( + 16 * self.config.cache_config.block_size + ) # consider extra 16 blocks for each running decode request when estimating new token ratio + new_ratio = (total_decoded_tokens + extra_decode_steps * num_running_decode) / (total_max_new_tokens + 1) + self.new_token_ratio = min(new_ratio, self.init_new_token_ratio) + llm_logger.info( + f"Estimate new token ratio for preemption: {self.new_token_ratio}, " + f"total_decoded_tokens={total_decoded_tokens}, total_max_new_tokens={total_max_new_tokens}, num_running_decode={num_running_decode}" + ) + + else: + self.current_reserve_output_block_num = self.init_reserve_output_block_num + self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num) + self.can_relax_prefill_strategy = False + + def _get_running_request_reserve_blocks(self, request: Request) -> int: + """Estimate KV-cache blocks to reserve for a running request's future decode tokens. + + Aligned with SGLang's per-request budget estimation: + reserved_tokens = min(max_tokens - already_generated, CLIP_MAX_NEW_TOKENS) * new_token_ratio + then ceil-divided by block_size. The ratio decays each scheduling step so that + the reservation gradually relaxes; on preemption it resets to the initial value. + """ + max_tokens = getattr(request.sampling_params, "max_tokens", None) + if max_tokens is None: + max_tokens = self.config.model_config.max_model_len - request.prompt_token_ids_len + remaining_tokens = max_tokens - len(request.output_token_ids) + clipped_remaining = min(remaining_tokens, self.clip_max_new_tokens) + reserved_tokens = max(int(clipped_remaining * self.new_token_ratio), 0) + block_size = self.config.cache_config.block_size + return (reserved_tokens + block_size - 1) // block_size + def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block): - if self.can_relax_prefill_strategy: - can_schedule_block_num_threshold = num_chunk_new_block + """Compute the minimum free blocks required to admit a new prefill request.""" + if self.use_new_token_ratio_reserve: + reserve_blocks = sum(self._get_running_request_reserve_blocks(req) for req in self.running) + can_schedule_block_num_threshold = num_chunk_new_block + reserve_blocks else: - can_schedule_block_num_threshold = ( - num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num - ) - if self.config.speculative_config.method is not None: - can_schedule_block_num_threshold = min( - can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + if self.can_relax_prefill_strategy: + can_schedule_block_num_threshold = num_chunk_new_block + else: + can_schedule_block_num_threshold = ( + num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num ) + if self.config.speculative_config.method is not None: + can_schedule_block_num_threshold = min( + can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + ) return can_schedule_block_num_threshold def _update_mm_hashes(self, request): @@ -786,6 +863,7 @@ def get_enough_request(request, scheduled_reqs): self.config.scheduler_config.max_num_batched_tokens - num_running_decode_reqs * tokens_per_seq ) need_abort_requests = [] # users trigger abortion + chunk_prefill_in_running_not_satisfied = False # First, schedule the RUNNING requests. req_index = 0 @@ -922,22 +1000,17 @@ def _allocate_decode_and_extend(): req_index += 1 continue num_new_block = self.get_new_block_nums(request, num_new_tokens) + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(num_new_block) # Allocate blocks to prefill - if self.cache_manager.can_allocate_gpu_blocks(num_new_block): - request.block_tables.extend( - self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) - ) - # Prepare prefill task - scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) - else: # Not enough blocks to allocate, trigger preemption - can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) - if not can_schedule: - break + if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): request.block_tables.extend( self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) ) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + else: # Not enough blocks to allocate + chunk_prefill_in_running_not_satisfied = True + break # For chunk prefill request, if not satisfy condition for prefill, just break token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( @@ -955,7 +1028,7 @@ def _allocate_decode_and_extend(): self.running.remove(request) # Second, schedule the WAITING requests. - if not preempted_reqs: + if (not preempted_reqs) and (not chunk_prefill_in_running_not_satisfied): skip_requests: list[Request] = [] while self.waiting and token_budget > 0: if ( @@ -1041,7 +1114,7 @@ def _allocate_decode_and_extend(): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) - request.status = RequestStatus.RUNNING + request.status = RequestStatus.RUNNING_PREFILL if self.config.scheduler_config.splitwise_role == "mixed": allocated_position = self.get_available_position() request.idx = allocated_position @@ -1110,7 +1183,7 @@ def _allocate_decode_and_extend(): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) - request.status = RequestStatus.RUNNING + request.status = RequestStatus.RUNNING_PREFILL else: if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) @@ -1124,14 +1197,20 @@ def _allocate_decode_and_extend(): if scheduled_reqs: llm_logger.debug(f"schedued_reqs: {scheduled_reqs}") - self.current_reserve_output_block_num_float -= self.decay_output_block_num - self.current_reserve_output_block_num = max( - int(self.current_reserve_output_block_num_float), - self.min_reserve_output_block_num, - 0, - ) - if self.current_reserve_output_block_num == 0: - self.can_relax_prefill_strategy = True + if self.use_new_token_ratio_reserve: + self.new_token_ratio = max( + self.new_token_ratio - self.new_token_ratio_decay, + self.min_new_token_ratio, + ) + else: + self.current_reserve_output_block_num_float -= self.decay_output_block_num + self.current_reserve_output_block_num = max( + int(self.current_reserve_output_block_num_float), + self.min_reserve_output_block_num, + 0, + ) + if self.current_reserve_output_block_num == 0: + self.can_relax_prefill_strategy = True self._log_console_scheduler_metrics(scheduled_reqs) @@ -1355,6 +1434,7 @@ def pre_recycle_resource(self, request_id: str): def add_request_in_p(self, requests: list[Request]): with self.lock: for request in requests: + request.status = RequestStatus.RUNNING_PREFILL self.running.append(request) def preallocate_resource_in_p(self, request: Request): @@ -1487,6 +1567,7 @@ def add_prefilled_request(self, request_output: RequestOutput): ): request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids) request.need_prefill_tokens = len(request.prompt_token_ids) + 1 + request.status = RequestStatus.RUNNING_DECODE request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time @@ -1553,7 +1634,7 @@ def finish_requests_async(self, request_ids: Union[str, Iterable[str]]): def finish_requests(self, request_ids: Union[str, Iterable[str]]): llm_logger.info(f"recycle resources for requests: {request_ids}") - self.update_metrics(verbose=True) + self.update_metrics() try: if isinstance(request_ids, str): request_ids = (request_ids,) @@ -1608,7 +1689,7 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): except Exception as e: llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") finally: - self.update_metrics(verbose=True) + self.update_metrics() def clear_data(self): self.waiting: deque[Request] = deque() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 734a04ab484..509f9a768d9 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -217,6 +217,11 @@ def _validate_split_kv_size(value: int) -> int: # Whether to enable low latency in mixed scenario "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))), # Reserve output blocks for decoding requests when schedule new prefill requests + "FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")), + "FD_MIN_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO", "0.1")), + "FD_NEW_TOKEN_RATIO_DECAY": lambda: float(os.getenv("FD_NEW_TOKEN_RATIO_DECAY", "0.001")), + "FD_CLIP_MAX_NEW_TOKENS": lambda: int(os.getenv("FD_CLIP_MAX_NEW_TOKENS", "4096")), + # Legacy reserve block env vars (kept for backwards compatibility, no longer used) "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16") ), @@ -226,6 +231,9 @@ def _validate_split_kv_size(value: int) -> int: "FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0") ), + # When True, use per-request new_token_ratio to estimate reserved blocks (SGLang-style). + # When False, fall back to the legacy fixed-block reservation strategy. + "FD_USE_NEW_TOKEN_RATIO_RESERVE": lambda: bool(int(os.getenv("FD_USE_NEW_TOKEN_RATIO_RESERVE", "1"))), # Timeout for worker process health check in seconds "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), # File path for file storage backend diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 6f6a8043803..2a8328b28fe 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -37,6 +37,7 @@ Request, RequestMetrics, RequestOutput, + RequestStatus, SpeculateMetrics, ) from fastdeploy.inter_communicator import ZmqIpcServer @@ -950,6 +951,8 @@ def _process_batch_output(self): continue self.total_step += 1 + if task.status == RequestStatus.RUNNING_PREFILL: + task.status = RequestStatus.RUNNING_DECODE current_time = time.time() trace_carrier = None if self.tokens_counter[task_id] == 0: diff --git a/tests/engine/test_resource_manager_v1.py b/tests/engine/test_resource_manager_v1.py index 23275f29f70..716770294a6 100644 --- a/tests/engine/test_resource_manager_v1.py +++ b/tests/engine/test_resource_manager_v1.py @@ -72,7 +72,7 @@ def test_preempted_all_with_normal_requests(self): req1 = Mock(spec=Request) req1.request_id = "req1" req1.use_extend_tables = False - req1.status = RequestStatus.RUNNING + req1.status = RequestStatus.RUNNING_DECODE req1.block_tables = [1, 2, 3] req1.num_cached_blocks = 0 req1.idx = 0 @@ -80,7 +80,7 @@ def test_preempted_all_with_normal_requests(self): req2 = Mock(spec=Request) req2.request_id = "req2" req2.use_extend_tables = False - req2.status = RequestStatus.RUNNING + req2.status = RequestStatus.RUNNING_DECODE req2.block_tables = [4, 5] req2.num_cached_blocks = 0 req2.idx = 1 diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index b47344470de..c84514c06d5 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -21,7 +21,7 @@ import paddle -from fastdeploy.engine.request import RequestMetrics, RequestOutput +from fastdeploy.engine.request import RequestMetrics, RequestOutput, RequestStatus from fastdeploy.output.token_processor import TokenProcessor paddle.set_device("cpu") @@ -82,6 +82,7 @@ def __init__(self): self.ic_req_data = {} self.prompt_token_ids_len = 0 self.trace_carrier = {} + self.status = RequestStatus.RUNNING_DECODE now = time.time() self.metrics = RequestMetrics( diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 0fd4d1753ee..5c26db778c7 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -25,7 +25,12 @@ import pytest from fastdeploy import envs -from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput +from fastdeploy.engine.request import ( + Request, + RequestMetrics, + RequestOutput, + RequestStatus, +) from fastdeploy.output import token_processor from fastdeploy.output.token_processor import ( MAX_BSZ, @@ -671,6 +676,7 @@ def test_process_batch_output_consumes_tokens_and_finishes_task(): prompt_token_ids_len=0, num_total_tokens=1, block_tables=[1], + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda key, default=None: getattr(task, key, default) @@ -708,6 +714,7 @@ def test_process_batch_output_logprob_records_topk_and_caching(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -784,6 +791,7 @@ def test_process_batch_output_speculative_recovery_stop_finishes(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -911,6 +919,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -1076,6 +1085,7 @@ def test_process_batch_output_records_second_decode_token(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.metrics.inference_start_time = time.time() @@ -1145,6 +1155,7 @@ def test_process_batch_output_prefill_sets_draft_tokens(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -1186,6 +1197,7 @@ def test_process_batch_output_logs_recovery_stop_for_non_speculative(): prompt_token_ids_len=0, num_total_tokens=1, block_tables=[1], + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda k, d=None: getattr(task, k, d) @@ -1223,6 +1235,7 @@ def test_process_batch_output_sets_multimodal_token_counts(): num_total_tokens=1, block_tables=[1], multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5}, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda key, default=None: getattr(task, key, default) diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index a93d5741d14..d9ab6a59dbc 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -650,7 +650,7 @@ def test_schedule_decode_and_waiting_prefill(self): decode_request = _make_request(request_id="req-decode", prompt_token_ids=[1, 2]) decode_request.idx = 0 - decode_request.status = RequestStatus.RUNNING + decode_request.status = RequestStatus.RUNNING_DECODE decode_request.num_computed_tokens = 2 decode_request.output_token_ids = [99] decode_request.block_tables = [1] @@ -665,7 +665,7 @@ def test_schedule_decode_and_waiting_prefill(self): self.assertGreaterEqual(len(scheduled_reqs), 2) self.assertEqual(error_reqs, []) self.assertIn(decode_request.request_id, manager.using_extend_tables_req_id) - self.assertEqual(waiting_request.status, RequestStatus.RUNNING) + self.assertEqual(waiting_request.status, RequestStatus.RUNNING_PREFILL) def test_trigger_preempt_records_tasks(self): manager = _build_manager() @@ -678,6 +678,7 @@ def test_trigger_preempt_records_tasks(self): preempted_req = _make_request(request_id="req-preempted") preempted_req.idx = 0 preempted_req.use_extend_tables = False + preempted_req.status = RequestStatus.RUNNING_DECODE request = _make_request(request_id="req-target") request.idx = 1 manager.running = [request, preempted_req] From 5e76c8bab60ec816faf932f5d5b4892d991e8ec4 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Wed, 13 May 2026 16:31:37 +0800 Subject: [PATCH 104/143] fix(PrefixCache): fix garbled text in PD disaggregation by early return when no new tokens to cache (#7797) (#7802) - Add early return check when can_cache_computed_tokens <= num_cached_tokens - Avoid unnecessary cache insertion operations that cause garbled output - Only affects PD disaggregation scenarios Co-authored-by: kevin --- fastdeploy/cache_manager/prefix_cache_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index e12e47d3fd0..c41a6109029 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -731,6 +731,8 @@ def update_cache_blocks(self, task, block_size, num_computed_tokens): req_id = task.request_id last_node, num_cached_tokens = self.req_to_radix_tree_info[req_id] can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size + if can_cache_computed_tokens <= num_cached_tokens: + return if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later self.leaf_req_map[last_node].remove(req_id) logger.debug( From 33b22b33be3edc03418c0d2e0fec1f2096f285b0 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Wed, 13 May 2026 18:36:58 +0800 Subject: [PATCH 105/143] [Cherry-pick] [Optimization] Elemenwise fusion (#6880) (#7683) * support ele fusion in get score * bug fix * add enable_moe_scores_elementwise_fuse arg * fix conflict --- custom_ops/gpu_ops/cpp_extensions.cc | 11 + custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 206 ++++++++ custom_ops/setup_ops.py | 2 + fastdeploy/engine/args_utils.py | 12 +- fastdeploy/engine/engine.py | 1 + .../layers/moe/fused_cast_sigmoid_bias.py | 73 +++ .../layers/moe/fused_moe_cutlass_backend.py | 15 +- .../layers/moe/fused_moe_deepgemm_backend.py | 8 +- .../layers/moe/fused_moe_triton_backend.py | 10 +- fastdeploy/model_executor/layers/moe/moe.py | 14 +- fastdeploy/scheduler/config.py | 1 + fastdeploy/worker/worker_process.py | 5 + tests/layers/test_deepgemm_fused_moe.py | 21 +- tests/layers/test_fused_cast_sigmoid_bias.py | 497 ++++++++++++++++++ .../layers/test_fused_moe_cutlass_backend.py | 67 ++- tests/layers/test_fused_moe_triton_backend.py | 90 ++++ 16 files changed, 1024 insertions(+), 9 deletions(-) create mode 100644 custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu create mode 100644 fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py create mode 100644 tests/layers/test_fused_cast_sigmoid_bias.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 15866c57643..591cf363f06 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -691,6 +691,10 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type); + std::vector NoauxTcRedundant( paddle::Tensor& scores, paddle::Tensor& scores_with_bias, @@ -1700,6 +1704,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("fused_cast_sigmoid_bias", + &FusedCastSigmoidBias, + "Fused cast+sigmoid+bias for MoE gating scores", + py::arg("input"), + py::arg("bias"), + py::arg("cast_type") = std::string("float32")); + m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu new file mode 100644 index 00000000000..f25084076c4 --- /dev/null +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -0,0 +1,206 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +// Fused kernel: cast(input, cast_type) -> sigmoid -> scores, scores + bias -> +// scores_with_bias +// +// For each element (token i, expert j): +// scores[i][j] = OutT(sigmoid(float(input[i][j]))) +// scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j]) +// +// Input: input [num_tokens, num_experts] bf16/fp16/fp32 +// bias [num_experts] or [1, num_experts] fp32 +// Output: scores [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// scores_with_bias [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// +// Precision guarantee: +// All intermediate computations (cast, sigmoid, bias addition) are performed +// in float32, regardless of input/output types. The cast to OutT only happens +// at the final store. This matches the reference implementation: +// gate_fp32 = gate_out.cast("float32") +// scores_fp32 = sigmoid(gate_fp32) +// scores_with_bias_fp32 = scores_fp32 + bias // bias is always float32 +// scores = scores_fp32.cast(cast_type) +// scores_with_bias = scores_with_bias_fp32.cast(cast_type) +// +// When cast_type is "float32", the fused kernel is numerically identical to +// the reference. For fp16/bf16 output, the only precision loss comes from +// the final static_cast, equivalent to .cast() in the reference path. +// +// Note: bias is intentionally kept as float32 (not converted to OutT) to +// ensure the addition s + bias[j] is always computed in full float32 +// precision before the final downcast. + +template +__global__ void fused_cast_sigmoid_bias_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + // All intermediate computation in float32 for precision + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias[j] (float32) -> float32 addition, then downcast + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +// Vectorized version for better memory throughput +template +__global__ void fused_cast_sigmoid_bias_vec_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, // kept as float32 for full-precision add + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + using in_vec_t = AlignedVector; + using out_vec_t = AlignedVector; + using bias_vec_t = AlignedVector; // float32 bias vectors + + const int vec_count = num_experts / kVecSize; + for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { + const int base = idx * kVecSize; + in_vec_t in_vec; + bias_vec_t bias_vec; + Load(input + offset + base, &in_vec); + Load(bias + base, &bias_vec); + + out_vec_t s_vec, sb_vec; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + // All intermediate computation in float32 for precision + float val = static_cast(in_vec[i]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias_vec[i] (float32) -> float32 addition, then downcast + s_vec[i] = static_cast(s); + sb_vec[i] = static_cast(s + bias_vec[i]); + } + + Store(s_vec, scores + offset + base); + Store(sb_vec, scores_with_bias + offset + base); + } + + // Handle remaining elements (same float32 precision guarantee) + const int remaining_start = vec_count * kVecSize; + for (int j = remaining_start + threadIdx.x; j < num_experts; + j += blockDim.x) { + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +static paddle::DataType ParseCastType(const std::string& cast_type) { + if (cast_type == "float32") return paddle::DataType::FLOAT32; + if (cast_type == "float16") return paddle::DataType::FLOAT16; + if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16; + PD_THROW("Unsupported cast_type: " + cast_type + + ". Only float32, float16, bfloat16 are supported."); +} + +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type) { + auto input_shape = input.shape(); + PD_CHECK(input_shape.size() == 2, + "input must be 2D [num_tokens, num_experts]"); + auto bias_shape = bias.shape(); + // Support both [num_experts] and [1, num_experts] bias shapes + PD_CHECK( + bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1), + "bias must be 1D [num_experts] or 2D [1, num_experts]"); + + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; + PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); + PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32, + "bias must be float32, got ", + bias.dtype()); + + auto place = input.place(); + auto stream = input.stream(); + auto out_dtype = ParseCastType(cast_type); + + auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place); + auto scores_with_bias = + paddle::empty({num_tokens, num_experts}, out_dtype, place); + + if (num_tokens == 0) { + return {scores, scores_with_bias}; + } + + dim3 grid(num_tokens); + int block_size = std::min(static_cast(1024), num_experts); + // Round up to warp size + block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(block_size); + + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, { + DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, { + constexpr int kVecSize = 16 / sizeof(in_scalar_t); + if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { + fused_cast_sigmoid_bias_vec_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } else { + fused_cast_sigmoid_bias_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } + }); + }); + + return {scores, scores_with_bias}; +} + +std::vector FusedCastSigmoidBiasInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& bias_dtype, + std::string cast_type) { + auto out_dtype = ParseCastType(cast_type); + return {out_dtype, out_dtype}; +} + +std::vector> FusedCastSigmoidBiasInferShape( + const std::vector& input_shape, + const std::vector& bias_shape) { + return {input_shape, input_shape}; +} + +PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) + .Inputs({"input", "bias"}) + .Outputs({"scores", "scores_with_bias"}) + .Attrs({"cast_type: std::string"}) + .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 268cde02825..7ae1e964761 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -330,6 +330,7 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", @@ -684,6 +685,7 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 79fb13d95c7..892d6668859 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -337,6 +337,11 @@ class EngineArgs: Chunk size of moe input. """ + enable_moe_scores_elementwise_fuse: bool = False + """ + Flag to enable fused elementwise cast in get_moe_scores. Default is False (disabled). + """ + cache_transfer_protocol: str = "ipc,rdma" """ Protocol to use for cache transfer. @@ -1390,7 +1395,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_overlap_schedule, help="Enable overlapping schedule.", ) - + scheduler_group.add_argument( + "--enable-moe-scores-elementwise-fuse", + action="store_true", + default=EngineArgs.enable_moe_scores_elementwise_fuse, + help="Enable fused elementwise cast in get_moe_scores for MoE routing.", + ) model_group.add_argument( "--deploy-modality", type=str, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 210b6f4bf26..996a4f9d68a 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -668,6 +668,7 @@ def _start_worker_service(self): "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, + "enable_moe_scores_elementwise_fuse": self.cfg.scheduler_config.enable_moe_scores_elementwise_fuse, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..44d7e54ae88 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -0,0 +1,73 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +_FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = None + +try: + from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, + ) +except ImportError as e: + _fused_cast_sigmoid_bias_cuda = None + _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = e + + +def is_available() -> bool: + """Return whether the fused GPU custom op is available.""" + return _fused_cast_sigmoid_bias_cuda is not None + + +def fused_cast_sigmoid_bias( + gate_out: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + cast_type: str = "float32", +) -> tuple: + """ + Fused operation: cast gate_out to the specified type, apply sigmoid, and add bias. + + This function fuses the following three separate operations: + 1. gate_out = gate_out.cast(cast_type) + 2. scores = sigmoid(gate_out) + 3. scores_with_bias = scores + e_score_correction_bias + + Args: + gate_out: [num_tokens, num_experts], bf16/fp16/fp32 dtype - raw gate output + e_score_correction_bias: [num_experts], fp32 dtype - correction bias + cast_type: output dtype string, supports "float32", "float16", "bfloat16" + + Returns: + scores: [num_tokens, num_experts], cast_type dtype - result of sigmoid(gate_out) + scores_with_bias: [num_tokens, num_experts], cast_type dtype - scores with bias added + + Precision: + All intermediate computations (cast, sigmoid, bias addition) are performed + in float32 precision; conversion to cast_type happens only at the final store. + When cast_type is "float32", the result is bit-exact with the following + reference implementation: + gate_fp32 = gate_out.cast("float32") + scores = sigmoid(gate_fp32) + scores_with_bias = scores + bias + When cast_type is "float16"/"bfloat16", the only precision loss comes from + the final type conversion, equivalent to calling .cast(cast_type) after + computing in float32. + """ + if _fused_cast_sigmoid_bias_cuda is None: + raise ImportError( + "fused_cast_sigmoid_bias is not available. " "Please ensure the GPU custom ops are compiled." + ) from _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 92e039dd742..dbd0679ebcd 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -340,9 +340,13 @@ def apply_tp( Paddle Cutlass compute Fused MoE. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -351,8 +355,10 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -405,6 +411,11 @@ def apply_tp( return fused_moe_out if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -414,6 +425,7 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) ( @@ -438,6 +450,7 @@ def apply_tp( topk_only_mode=True, ) else: + gate_out = gate_out.cast("float32") ( permute_input, token_nums_per_expert, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index a16e5ccbe9c..adeac2f2cf4 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -742,9 +742,13 @@ def apply_tp( below is TP compute method. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( gate_out, layer.n_group, @@ -754,8 +758,10 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 65d1d23b9be..64e8aa38fe5 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -28,6 +28,7 @@ set_weight_attrs, weight_fully_copied, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase @@ -299,7 +300,7 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") + top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k @@ -307,6 +308,11 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -315,8 +321,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 9cb9340cf01..b98a911a225 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -41,6 +41,11 @@ logger.warning("import noaux_tc Failed!") import numpy as np +if current_platform.is_cuda(): + from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + ) + def get_moe_method(layer=None): """ @@ -91,14 +96,17 @@ def get_moe_scores( tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + use_fused_cast: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. """ - scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - scores_with_bias = scores + e_score_correction_bias - + if use_fused_cast and current_platform.is_cuda(): + scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32") + else: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias if envs.FD_USE_PHI_MOE_TOPK: # calculate renormalize and routed_scaling_factor value outside the noaux_tc original_renormalize = renormalize diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index 1422b2635f3..90016479063 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -273,6 +273,7 @@ def __init__(self, args): self.max_num_seqs = 34 self.splitwise_role = "mixed" self.enable_overlap_schedule = False + self.enable_moe_scores_elementwise_fuse = False self.config = None for key, value in args.items(): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 5c988c5fae1..37e23be2d50 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -906,6 +906,11 @@ def parse_args(): action="store_true", help="enable chunked moe", ) + parser.add_argument( + "--enable_moe_scores_elementwise_fuse", + action="store_true", + help="enable fused elementwise cast in get_moe_scores", + ) parser.add_argument( "--chunked_moe_size", type=int, diff --git a/tests/layers/test_deepgemm_fused_moe.py b/tests/layers/test_deepgemm_fused_moe.py index 5381ee866a3..4ec3e017e20 100644 --- a/tests/layers/test_deepgemm_fused_moe.py +++ b/tests/layers/test_deepgemm_fused_moe.py @@ -106,7 +106,10 @@ def __init__(self): # ep_size * this = max tokens buffer for masked GEMM; must be ≥ aligned M num_max_dispatch_tokens_per_rank=128, ) - self.scheduler_config = types.SimpleNamespace(max_num_batched_tokens=NUM_TOKENS) + self.scheduler_config = types.SimpleNamespace( + max_num_batched_tokens=NUM_TOKENS, + enable_moe_scores_elementwise_fuse=False, + ) self.parallel_config = types.SimpleNamespace(tensor_parallel_size=1) @@ -205,6 +208,22 @@ def hook(topk_ids): assert "topk_ids" in captured assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + @requires_deepgemm + def test_apply_tp_noaux_tc_with_use_fused_true(self): + """noaux_tc path with enable_moe_scores_elementwise_fuse=True: triggers use_fused=True (no gate_out.cast).""" + layer = DummyLayer() + layer.topk_method = "noaux_tc" + gate = DummyGate(layer.num_local_experts) + method = _make_method() + + x = paddle.randn([NUM_TOKENS, HIDDEN_SIZE], dtype="bfloat16") + + # Enable flag to exercise the fused path (use_fused=True) + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + + out = method.apply(layer, x, gate) + assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + @requires_deepgemm def test_apply_tp_aux_path(self): """Non-noaux_tc: moe_topk_select → fp8_quant_blockwise → moe_permute → deepgemm → moe_unpermute.""" diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..21bfb0901fd --- /dev/null +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -0,0 +1,497 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib +import os +import sys +from unittest import mock + +import paddle +import paddle.nn.functional as F +import pytest + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + is_available, +) + +DTYPE_MAP = { + "float16": paddle.float16, + "bfloat16": paddle.bfloat16, + "float32": paddle.float32, +} + + +def _ensure_gpu_test_environment(): + """Ensure GPU runtime and required custom ops are available for this test module.""" + if not paddle.is_compiled_with_cuda(): + pytest.skip( + "fused_cast_sigmoid_bias requires CUDA-enabled Paddle.", + allow_module_level=True, + ) + paddle.set_device("gpu") + + +_ensure_gpu_test_environment() + + +def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Reference implementation: compute in fp32, cast output to cast_type.""" + gate_fp32 = gate_out.cast("float32") + scores_fp32 = F.sigmoid(gate_fp32) + scores_with_bias_fp32 = scores_fp32 + bias + scores = scores_fp32.cast(cast_type) + scores_with_bias = scores_with_bias_fp32.cast(cast_type) + return scores, scores_with_bias + + +def test_functionality(): + """Test basic functionality: correct shapes and dtypes (default cast_type=float32).""" + print("=" * 60) + print("Test 1: Functionality (default cast_type=float32)") + print("=" * 60) + + for dtype_name in ["float16", "bfloat16", "float32"]: + for num_tokens in [1, 7, 128, 1024]: + for num_experts in [8, 64, 128, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + assert scores.shape == [ + num_tokens, + num_experts, + ], f"scores shape mismatch: {scores.shape} vs {[num_tokens, num_experts]}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert scores.dtype == paddle.float32, f"scores dtype mismatch: {scores.dtype}" + assert ( + scores_with_bias.dtype == paddle.float32 + ), f"scores_with_bias dtype mismatch: {scores_with_bias.dtype}" + + # Sigmoid output should be in [0, 1] + assert bool(paddle.all(scores >= 0.0).item()) and bool( + paddle.all(scores <= 1.0).item() + ), "scores out of [0,1] range" + print(f" [PASS] dtype={dtype_name}") + + print(" All functionality tests passed.\n") + + +def test_functionality_cast_types(): + """Test functionality with different cast_type values.""" + print("=" * 60) + print("Test 1b: Functionality with different cast_type") + print("=" * 60) + + for input_dtype in ["float16", "bfloat16", "float32"]: + for cast_type in ["float16", "bfloat16", "float32"]: + expected_paddle_dtype = DTYPE_MAP[cast_type] + for num_tokens in [1, 64, 256]: + for num_experts in [8, 64, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + assert scores.shape == [num_tokens, num_experts], f"scores shape mismatch: {scores.shape}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert ( + scores.dtype == expected_paddle_dtype + ), f"scores dtype mismatch: got {scores.dtype}, expected {expected_paddle_dtype}" + assert ( + scores_with_bias.dtype == expected_paddle_dtype + ), f"scores_with_bias dtype mismatch: got {scores_with_bias.dtype}, expected {expected_paddle_dtype}" + + print(f" [PASS] input_dtype={input_dtype}, cast_type={cast_type}") + + print(" All cast_type functionality tests passed.\n") + + +def test_accuracy(): + """Test numerical accuracy against reference implementation (default cast_type=float32).""" + print("=" * 60) + print("Test 2: Accuracy (default cast_type=float32)") + print("=" * 60) + + test_cases = [ + ("float16", 1, 8), + ("float16", 128, 256), + ("float16", 1024, 256), + ("bfloat16", 1, 8), + ("bfloat16", 128, 256), + ("bfloat16", 1024, 256), + ("float32", 1, 8), + ("float32", 128, 256), + ("float32", 1024, 256), + ] + + for dtype_name, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias) + + # Compare + scores_diff = paddle.abs(fused_scores - ref_scores).max().item() + scores_bias_diff = paddle.abs(fused_scores_with_bias - ref_scores_with_bias).max().item() + + atol = 1e-6 if dtype_name == "float32" else 1e-3 + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] dtype={dtype_name}, tokens={num_tokens}, experts={num_experts} | " + f"scores_max_diff={scores_diff:.2e}, scores_with_bias_max_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for dtype={dtype_name}, tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, scores_bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All accuracy tests passed.\n") + + +def test_accuracy_cast_types(): + """Test numerical accuracy with different cast_type values.""" + print("=" * 60) + print("Test 2b: Accuracy with different cast_type") + print("=" * 60) + + # (input_dtype, cast_type, num_tokens, num_experts) + test_cases = [ + # cast to float32 (original behavior) + ("float16", "float32", 128, 256), + ("bfloat16", "float32", 128, 256), + ("float32", "float32", 128, 256), + # cast to float16 + ("float16", "float16", 128, 256), + ("bfloat16", "float16", 128, 256), + ("float32", "float16", 128, 256), + # cast to bfloat16 + ("float16", "bfloat16", 128, 256), + ("bfloat16", "bfloat16", 128, 256), + ("float32", "bfloat16", 128, 256), + # different shapes + ("bfloat16", "float16", 1, 8), + ("bfloat16", "float16", 1024, 256), + ("float16", "bfloat16", 1, 8), + ("float16", "bfloat16", 1024, 256), + ] + + for input_dtype, cast_type, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Compare in float32 for stable diff computation + scores_diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + scores_bias_diff = ( + paddle.abs(fused_scores_with_bias.cast("float32") - ref_scores_with_bias.cast("float32")).max().item() + ) + + # Tolerance depends on cast_type precision + if cast_type == "float32": + atol = 1e-6 + elif cast_type == "bfloat16": + atol = 1e-2 # bfloat16 has fewer mantissa bits + else: # float16 + atol = 1e-3 + + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts} | " + f"scores_diff={scores_diff:.2e}, bias_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All cast_type accuracy tests passed.\n") + + +def test_accuracy_extreme_values(): + """Test accuracy with extreme input values.""" + print("=" * 60) + print("Test 3: Accuracy with extreme values") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for dtype_name in ["float16", "bfloat16"]: + # Large positive values -> sigmoid ~ 1.0 + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=dtype_name) + bias = paddle.zeros([num_experts], dtype="float32") + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large positive: max_diff={diff:.2e}") + + # Large negative values -> sigmoid ~ 0.0 + gate_out = paddle.full([num_tokens, num_experts], -10.0, dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large negative: max_diff={diff:.2e}") + + # Zero values -> sigmoid = 0.5 + gate_out = paddle.zeros([num_tokens, num_experts], dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + assert diff < 1e-6, f"Zero input test failed: diff={diff}" + print(f" [PASS] dtype={dtype_name}, zeros: max_diff={diff:.2e}") + + print(" All extreme value tests passed.\n") + + +def test_accuracy_extreme_values_cast_types(): + """Test accuracy with extreme values across different cast_type values.""" + print("=" * 60) + print("Test 3b: Accuracy with extreme values + different cast_type") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for input_dtype in ["float16", "bfloat16"]: + for cast_type in ["float16", "bfloat16", "float32"]: + bias = paddle.zeros([num_experts], dtype="float32") + + # Large positive + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + status = "PASS" if diff < atol else "FAIL" + print(f" [{status}] input={input_dtype}, cast={cast_type}, " f"large positive: diff={diff:.2e}") + + # Zero values + gate_out = paddle.zeros([num_tokens, num_experts], dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + assert diff < atol, f"Zero input test failed: input={input_dtype}, cast={cast_type}, diff={diff}" + print(f" [PASS] input={input_dtype}, cast={cast_type}, " f"zeros: diff={diff:.2e}") + + print(" All extreme value cast_type tests passed.\n") + + +@pytest.mark.skipif( + os.getenv("RUN_PERFORMANCE_TESTS") != "1", + reason="Performance benchmark is disabled by default. Set RUN_PERFORMANCE_TESTS=1 to enable.", +) +def test_performance(): + """Benchmark fused kernel vs reference implementation using CUDA events.""" + print("=" * 60) + print("Test 4: Performance (CUDA event timing)") + print("=" * 60) + + configs = [ + ("bfloat16", 1, 256), # single token decode + ("bfloat16", 8, 256), # small batch decode + ("bfloat16", 64, 256), # medium batch + ("bfloat16", 256, 256), # typical DeepSeek-V3 config + ("bfloat16", 1024, 256), # large prefill + ("bfloat16", 4096, 256), # very large prefill + ] + + warmup_iters = 100 + bench_iters = 500 + + for dtype_name, num_tokens, num_experts in configs: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Warmup fused + for _ in range(warmup_iters): + fused_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark fused with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + fused_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + fused_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + # Warmup reference + for _ in range(warmup_iters): + reference_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark reference with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + reference_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + ref_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + speedup = ref_time / fused_time if fused_time > 0 else float("inf") + print( + f" tokens={num_tokens:5d}, experts={num_experts:3d} | " + f"ref={ref_time:8.1f}us, fused={fused_time:8.1f}us, speedup={speedup:.2f}x" + ) + + print() + print(" Note: The CUDA custom op fuses cast+sigmoid+bias into a single kernel,") + print(" eliminating 2 intermediate tensors and reducing kernel launches from 3 to 1.") + print(" Expected speedup: ~3x over the reference 3-op implementation.") + print(" Performance benchmark complete.\n") + + +def test_is_available(): + """Test is_available() function returns True when GPU ops are available.""" + print("=" * 60) + print("Test: is_available()") + print("=" * 60) + + # In normal GPU test environment, is_available should return True + result = is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is True, f"is_available() should return True when GPU ops are compiled, got {result}" + print(f" [PASS] is_available() returned {result}") + print(" is_available() test passed.\n") + + +def test_import_error(): + """Test that ImportError is raised when GPU ops are not available.""" + print("=" * 60) + print("Test 5: Import error handling") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # The module should load successfully, but calling the function + # should raise ImportError because the cuda op is unavailable. + dummy_gate = paddle.randn([1, 8], dtype="float32") + dummy_bias = paddle.randn([8], dtype="float32") + try: + reloaded.fused_cast_sigmoid_bias(dummy_gate, dummy_bias) + raise AssertionError("Expected ImportError was not raised") + except ImportError as e: + assert "fused_cast_sigmoid_bias is not available" in str(e), f"Unexpected error message: {e}" + print(f" [PASS] ImportError raised with correct message: {e}") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" Import error handling test passed.\n") + + +def test_is_available_when_ops_unavailable(): + """Test is_available() returns False when GPU ops are not available.""" + print("=" * 60) + print("Test: is_available() when ops unavailable") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # is_available should return False when ops are not available + result = reloaded.is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is False, f"is_available() should return False when GPU ops are unavailable, got {result}" + print(f" [PASS] is_available() returned {result} when ops unavailable") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" is_available() when ops unavailable test passed.\n") + + +if __name__ == "__main__": + print("Running fused_cast_sigmoid_bias tests...\n") + + test_is_available() + test_functionality() + test_functionality_cast_types() + test_accuracy() + test_accuracy_cast_types() + test_accuracy_extreme_values() + test_accuracy_extreme_values_cast_types() + test_import_error() + test_is_available_when_ops_unavailable() + if os.getenv("RUN_PERFORMANCE_TESTS") == "1": + test_performance() + else: + print("Skipping performance benchmark. Set RUN_PERFORMANCE_TESTS=1 to enable.\n") + + print("=" * 60) + print("All tests passed!") + print("=" * 60) diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 98185a04c38..f6f92fb44da 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -58,6 +58,7 @@ class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.model_config = types.SimpleNamespace(model="dummy", prefix_layer_name="prefix") self.load_config = types.SimpleNamespace(load_choices=load_choices) + self.scheduler_config = types.SimpleNamespace(enable_moe_scores_elementwise_fuse=False) class DummyLayer(paddle.nn.Layer): @@ -394,7 +395,15 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): def fake_get_moe_scores( - gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) @@ -831,6 +840,62 @@ def spy_permute(*args, **kwargs): assert not paddle.isnan(out).any(), "output contains NaN" assert not paddle.isinf(out).any(), "output contains Inf" + def test_apply_tp_noaux_tc_with_use_fused_true(self, monkeypatch): + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): + return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) + + def fake_dispatch(*args, **kwargs): + return ( + paddle.ones([1, 2]), + paddle.to_tensor([1, 0]), + paddle.to_tensor([0]), + paddle.to_tensor([[0.6, 0.4]]), + paddle.to_tensor([[0, 1]]), + paddle.to_tensor([0]), + None, + None, + ) + + def fake_reduce(*args, **kwargs): + return paddle.ones([1, 2]) * 5 + + def fake_compute_ffn(*args, **kwargs): + return paddle.ones([1, 2]) * 2 + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores, raising=False) + monkeypatch.setattr(backend, "moe_expert_dispatch", fake_dispatch, raising=False) + monkeypatch.setattr(backend, "moe_expert_reduce", fake_reduce, raising=False) + + # Mock compute_ffn on the class to avoid real GPU op data type issues + monkeypatch.setattr(backend.CutlassMoEMethod, "compute_ffn", fake_compute_ffn) + + # Enable enable_moe_scores_elementwise_fuse and force is_cuda=True to trigger use_fused = True + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) + layer = DummyLayer(with_bias=False) + layer.topk_method = "noaux_tc" + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + # Add necessary attributes for compute_ffn access + layer.up_gate_proj_weight = paddle.zeros([2, 2 * 1], dtype="float16") + layer.down_proj_weight = paddle.zeros([2, 2], dtype="float16") + layer.activation = "silu" + + method = backend.CutlassMoEMethod(None) + + x = paddle.ones([1, 2]) + gate = paddle.nn.Identity() + + method.apply(layer, x, gate) + @requires_cuda def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch): """FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute / diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 7dacbbe390d..49ea579e562 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -57,6 +57,7 @@ class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.load_config = DummyLoadConfig(load_choices) self.model_config = types.SimpleNamespace(enable_cache=False) + self.scheduler_config = types.SimpleNamespace(enable_moe_scores_elementwise_fuse=False) class DummyGate(paddle.nn.Layer): @@ -695,3 +696,92 @@ def fake_transform_scale_ue8m0(sf, mn, weight_block_size=None): # Verify the quant_weight_ue8m0 branch was executed assert len(quant_calls) > 0, "quant_weight_ue8m0 should have been called" assert len(transform_calls) > 0, "transform_scale_ue8m0 should have been called" + + def test_triton_weight_only_apply_noaux_tc_with_use_fused_true(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Enable enable_moe_scores_elementwise_fuse and force is_cuda=True to trigger use_fused = True + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + + x = paddle.randn([1, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured + + def test_triton_weight_only_apply_noaux_tc_with_non_cuda(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + # Ensure topk_method is "noaux_tc" to enter the target branch + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Mock current_platform.is_cuda() to return False to trigger use_fused = False at line 313 + # This should trigger gate_out.cast("float32") at line 315 + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: False)) + + x = paddle.randn([2, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + def fake_get_moe_scores(*args, **kwargs): + gate_out = args[0] + token_num = gate_out.shape[0] + top_k = args[3] + topk_ids = paddle.zeros([token_num, top_k], dtype="int64") + topk_weights = paddle.ones([token_num, top_k], dtype="float32") + return gate_out, topk_weights, topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured From dc1fea1ad9e51e47f973364ec9596cfb1c9a77c0 Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Wed, 13 May 2026 20:06:14 +0800 Subject: [PATCH 106/143] [Cherry-Pick] [BugFix] Fix abort when enabling overlap schedule (#7800) (#7801) * [BugFix] Fix abort when enabling overlap schedule * [fix] fix pin memory tensor copy --- fastdeploy/worker/gpu_model_runner.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index b06a8de0564..72830fe0183 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2163,7 +2163,9 @@ def _make_preempted_batch_output(self): PREEMPTED_TOKEN_ID, -1, ).astype("int64") - self.share_inputs["sampled_token_ids"][:bsz].copy_(fake_sampled_token_ids, False) + sampled_token_ids = self.share_inputs["sampled_token_ids"].cpu() + sampled_token_ids[:bsz].copy_(fake_sampled_token_ids, True) + self.share_inputs["sampled_token_ids"].copy_(sampled_token_ids, True) fake_logprobs_tensors = None if self.enable_logprob: @@ -2174,10 +2176,12 @@ def _make_preempted_batch_output(self): ) if self.speculative_decoding: - self.share_inputs["accept_tokens_cpu"][:bsz].fill_(0) - self.share_inputs["accept_num_cpu"][:bsz].fill_(0) - self.share_inputs["seq_lens_decoder_cpu"][:bsz].copy_(self.share_inputs["seq_lens_decoder"][:bsz], False) - self.share_inputs["prompt_lens_cpu"][:bsz].copy_(self.share_inputs["prompt_lens"][:bsz], False) + self.share_inputs["accept_tokens"][:bsz].fill_(0) + self.share_inputs["accept_num"][:bsz].fill_(0) + self.share_inputs["accept_tokens_cpu"].copy_(self.share_inputs["accept_tokens"], True) + self.share_inputs["accept_num_cpu"].copy_(self.share_inputs["accept_num"], True) + self.share_inputs["seq_lens_decoder_cpu"].copy_(self.share_inputs["seq_lens_decoder"], True) + self.share_inputs["prompt_lens_cpu"].copy_(self.share_inputs["prompt_lens"], True) sampler_output = SamplerOutput( sampled_token_ids=fake_sampled_token_ids, logprobs_tensors=fake_logprobs_tensors, From 478c9faceeb8811032a751be4b9cca7f12071220 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Thu, 14 May 2026 10:35:52 +0800 Subject: [PATCH 107/143] [RL] pause: use abort pipeline with scheduling loop alive for graceful termination (#7806) Replace the old preempted_all + error_response approach in _control_pause with a two-phase design: Phase 1: Block new requests via _rejecting_new_requests (NOT is_paused) - Scheduling loop keeps running so _trigger_abort can process - add_abort_req_ids(ALL) marks all requests for abort - Scheduling loop catches them via _trigger_abort as they cycle through Phase 2: After drain, set is_paused=True to fully stop scheduling loop - Handle scheduler-only stragglers with direct _send_error_response - Wait for output queue empty, then reset Depends-on: #7615 (refact abort_requests to fire-and-forget) --- fastdeploy/engine/common_engine.py | 58 +++++++++++++++--------------- tests/engine/test_common_engine.py | 19 ++++++---- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 65bfa6d606b..7a8a641573a 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -139,6 +139,7 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.is_paused = False # pause request generation self._pause_cond = threading.Condition() + self._rejecting_new_requests = False # blocks new requests during abort drain self._ctrl_output_queues = {} self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) @@ -1106,7 +1107,7 @@ def _insert_zmq_task_to_scheduler(self): trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) self.llm_logger.debug(f"Receive request from api server: {request}") - if self.is_paused: + if self.is_paused or self._rejecting_new_requests: self.llm_logger.warning(f"Engine is paused, drop request: {request}") self._send_error_response( request.request_id, @@ -1226,39 +1227,20 @@ def _control_pause(self, control_request: ControlRequest): if self.is_paused: self.llm_logger.info("Engine is already paused, no need to pause again.") return - self.is_paused = True + self._rejecting_new_requests = True + self.resource_manager.log_status() - self.llm_logger.info("Abort running requests.") + # Scheduling loop picks them up via _trigger_abort when they enter resource_manager + all_req_ids = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) + self.llm_logger.info(f"Pause: aborting {len(all_req_ids)} total requests.") + if all_req_ids: + self.resource_manager.add_abort_req_ids(all_req_ids) + self._wait_inflight_drained() - self.resource_manager.log_status() - # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue - timeout, count = 60, 0 - while self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - count += 1 - if count >= timeout * 1000: - break - if count >= timeout * 1000: - error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" - self.llm_logger.error(error_msg) - raise Exception(error_msg) - running_reqs = self.resource_manager.preempted_all() - if len(running_reqs) > 0: - self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.") - self.resource_manager.get_real_bsz() - self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz)) - self.resource_manager.wait_worker_inflight_requests_finish(timeout=60) - # self.engine_worker_queue.clear_data() - self.token_processor.clear_data() + with self._pause_cond: + self.is_paused = True self.resource_manager.log_status() - # abort inflight requests to user - inflight_requests = self.scheduler.get_inflight_requests() - self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") - for req in inflight_requests: - self._send_error_response(req.request_id, "Request is aborted since engine is paused.") - self.scheduler.reset() - # pause cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to pause cache transfer.") @@ -1278,6 +1260,21 @@ def _control_pause(self, control_request: ControlRequest): self.llm_logger.info("Successfully paused request generation.") return None + def _wait_inflight_drained(self): + """ + Wait until resource_manager.requests is completely empty. + No timeout — abort pipeline will complete. Aligned with SGLang's poll-until-drained. + """ + start_time = time.time() + while ( + self.resource_manager.requests + or self.scheduler.requests + or self.resource_manager.waiting_abort_req_id_set + or self.resource_manager.to_be_aborted_req_id_set + ): + time.sleep(0.005) + self.llm_logger.info(f"All inflight requests drained, take time: {time.time() - start_time:.3f} seconds") + def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: """Control function for resuming request generation. @@ -1293,6 +1290,7 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: self.llm_logger.info("Engine is not paused, no need to resume.") return None self.is_paused = False + self._rejecting_new_requests = False self._pause_cond.notify_all() # resume cache transfer diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 833bd5008da..54e613c8637 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1139,22 +1139,29 @@ def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False eng._pause_cond = threading.Condition() - eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False)) eng.resource_manager = Mock( - preempted_all=Mock(return_value=[Request(request_id="r1", prompt_token_ids=[1], prompt_token_ids_len=1)]), - get_real_bsz=Mock(), - wait_worker_inflight_requests_finish=Mock(), + requests={"r1": Mock(output_token_ids=[1, 2, 3])}, + waiting_abort_req_id_set=set(), + to_be_aborted_req_id_set=set(), + add_abort_req_ids=Mock(), log_status=Mock(), cache_manager=Mock(reset=Mock()), - real_bsz=1, ) eng.token_processor = Mock(clear_data=Mock()) - eng.scheduler = Mock(get_inflight_requests=Mock(return_value=[]), reset=Mock()) + mock_scheduler = Mock(reset=Mock()) + mock_scheduler.requests = {} + mock_scheduler.mutex = threading.Lock() + mock_scheduler.responses = {} + mock_scheduler.batch_responses_per_step = [] + eng.scheduler = mock_scheduler eng._send_error_response = Mock() + eng._wait_inflight_drained = Mock() with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", True): eng._control_pause(ControlRequest(request_id="ctrl1", method="pause")) self.assertTrue(eng.is_paused) + eng.resource_manager.add_abort_req_ids.assert_called_once() eng._control_resume(ControlRequest(request_id="ctrl2", method="resume")) self.assertFalse(eng.is_paused) From d02f3ba8e2f8a054c5378b53492827e4e719163a Mon Sep 17 00:00:00 2001 From: xuanyuanminzheng <53077335+xuanyuanminzheng@users.noreply.github.com> Date: Thu, 14 May 2026 16:42:49 +0800 Subject: [PATCH 108/143] [Feature] Add TritonMoEMethod for BF16 MoE inference (#7815) --- .../model_executor/layers/moe/__init__.py | 3 +- .../layers/moe/fused_moe_triton_backend.py | 268 +++++- fastdeploy/model_executor/layers/moe/moe.py | 5 + .../layers/moe/triton_moe_kernels.py | 139 +++ tests/layers/test_fused_moe_triton_backend.py | 791 +++++++++++++++++- 5 files changed, 1151 insertions(+), 55 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/__init__.py b/fastdeploy/model_executor/layers/moe/__init__.py index 540a0828ae5..7f2ded19cb6 100644 --- a/fastdeploy/model_executor/layers/moe/__init__.py +++ b/fastdeploy/model_executor/layers/moe/__init__.py @@ -17,7 +17,7 @@ CutlassW4AFP8MoEMethod, CutlassWeightOnlyMoEMethod, ) -from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod +from .fused_moe_triton_backend import TritonMoEMethod, TritonWeightOnlyMoEMethod from .moe import FusedMoE __all__ = [ @@ -26,4 +26,5 @@ CutlassW4AFP8MoEMethod, FusedMoE, TritonWeightOnlyMoEMethod, + TritonMoEMethod, ] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 64e8aa38fe5..bd669deedc0 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -20,6 +20,17 @@ from paddle import nn import fastdeploy +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.moe.triton_moe_kernels import ( + fused_moe_kernel_bf16, + fused_moe_kernel_paddle, +) +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + fused_stack_transpose_quant, + quant_weight_ue8m0, + transform_scale_ue8m0, +) +from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import ( TensorTracker, @@ -31,21 +42,15 @@ from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op -from ..quantization.quant_base import QuantMethodBase - try: - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + import triton.language as tl - from .triton_moe_kernels import fused_moe_kernel_paddle + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func except ImportError: pass -from fastdeploy.model_executor.layers.moe.moe import get_moe_scores -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - fused_stack_transpose_quant, - quant_weight_ue8m0, - transform_scale_ue8m0, -) -from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant + +from ..quantization.quant_base import QuantMethodBase +from .fused_moe_backend_base import UnquantizedFusedMoEMethod class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -778,8 +783,8 @@ def apply( stride_am=x_q.strides[0], stride_ak=x_q.strides[1], stride_be=layer.up_gate_proj_weight.strides[0], - stride_bk=layer.up_gate_proj_weight.strides[2], - stride_bn=layer.up_gate_proj_weight.strides[1], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], stride_cm=up_gate_proj_out.strides[0], stride_cn=up_gate_proj_out.strides[1], # @@ -1846,3 +1851,240 @@ def apply( self.quant_config, topk_ids_hookfunc, ) + + +class TritonMoEMethod(UnquantizedFusedMoEMethod): + """ + Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. + + Activated via: export FD_MOE_BACKEND=triton + Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj. + This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA. + """ + + def __init__(self, quant_config=None): + super().__init__(quant_config) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """Stack individual expert weights into the stacked parameter.""" + up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) + layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0)) + layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0)) + + def _get_default_config(self, M: int, E: int) -> dict: + """ + Heuristic tile config for BF16 MoE, ported verbatim from vLLM's + `get_default_config` (bf16/fp16 non-block_shape branch). + See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319. + + M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count. + E: number of (local) experts. + """ + + # Tile sizes scale with batch: small batches are memory-bound + # (favor tall-K tiles), large batches are compute-bound (favor + # large M/N tiles with more warps). + if M <= 32: + block_m = 16 + elif M <= 96: + block_m = 32 + elif M <= 512: + block_m = 64 + else: + block_m = 128 + + block_n = 64 if M <= 64 else 128 + + block_k = 64 + + # Grouping adjacent M-blocks lets them share weight tiles in L2. + # Only helps when there are enough M-blocks per expert to group; + # with many experts each one sees few tokens so grouping is useless. + tokens_per_expert = M // max(E, 1) + group_m = 16 if tokens_per_expert > 128 else 1 + + # Large batches have enough blocks to saturate the GPU, so we + # use more warps per block to increase arithmetic intensity. + num_warps = 4 if M <= 128 else 8 + + num_stages = 4 if M <= 32 else 3 + + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + } + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, + ) -> paddle.Tensor: + """ + BF16 Triton Fused MoE forward. + + Pipeline: + 1. Gate + topk routing + 2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded + 3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] + 4. SwiGLU activation + 5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] + (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) + 6. Reshape + sum over topk dim + """ + token_num = x.shape[0] + if token_num == 0: + return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) + + top_k = layer.top_k + num_local_experts = layer.num_local_experts + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + + # --- 1. Routing --- + gate_out = gate(x) + + if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") + + _, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + use_fused_cast=use_fused, + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + ) + else: + gate_out = gate_out.cast("float32") + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight + False, + ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + + # --- 2. Preprocess: sort tokens by expert assignment --- + num_token_expert_pairs = token_num * top_k + # vLLM convention: pass num_tokens (pre-expansion), NOT tokens*top_k. + cfg = self._get_default_config(token_num, num_local_experts) + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + + # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- + # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn + up_gate_proj_out = paddle.empty( + [num_token_expert_pairs, moe_intermediate_size * 2], + dtype=x.dtype, + ) + grid1 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid1]( + x, + layer.up_gate_proj_weight, + up_gate_proj_out, + None, # topk_weights_ptr (no weight mul on GEMM1) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=moe_intermediate_size * 2, + K=hidden_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type=tl.bfloat16, + even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 4. SwiGLU activation --- + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- + # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn + down_proj_out = paddle.empty( + (num_token_expert_pairs, hidden_size), + dtype=x.dtype, + ) + grid2 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid2]( + down_proj_input, + layer.down_proj_weight, + down_proj_out, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=hidden_size, + K=moe_intermediate_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + top_k=1, + compute_type=tl.bfloat16, + even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 6. Reduce over topk --- + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) + return out + + def apply_ep_prefill( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.") + + def apply_ep_decode( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + raise NotImplementedError("TritonMoEMethod does not support EP decode yet.") diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index b98a911a225..df7b7f12c4b 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -53,6 +53,11 @@ def get_moe_method(layer=None): """ if current_platform.is_cuda(): + moe_backend = envs.FD_MOE_BACKEND.lower() + if moe_backend == "triton": + from .fused_moe_triton_backend import TritonMoEMethod + + return TritonMoEMethod(None) from .fused_moe_cutlass_backend import CutlassMoEMethod return CutlassMoEMethod(None) diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index ac5dfa96fcc..ac9e18480b6 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -198,3 +198,142 @@ def fused_moe_kernel_paddle( c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + + +# --------------------------------------------------------------------------- +# BF16-native MoE kernel, ported from vLLM fused_moe_kernel (BF16-only path). +# +# Key differences from fused_moe_kernel_paddle (the wint8/fp8 kernel above): +# 1. compute_type is a tl.constexpr parameter (not hardcoded bfloat16). +# 2. offs_token is cast to int64 to prevent stride-multiplication overflow. +# 3. b matrix load always uses a K-boundary mask (no even_Ks special path). +# 4. Router-weight multiplication is done in fp32 before the final cast. +# 5. No quantization paths (use_fp8/int8 removed for clarity). +# --------------------------------------------------------------------------- +@enable_compat_on_triton_kernel +@triton.jit +def fused_moe_kernel_bf16( # pragma: no cover -- Triton JIT; body compiles to GPU code + # Pointers + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Dimensions (runtime scalars) + N, + K, + EM, + num_valid_tokens, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters (compile-time constants) + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + # naive_block_assignment: tl.constexpr = False, + even_Ks: tl.constexpr = False, +): + """ + BF16 Fused-MoE GEMM kernel, ported from vLLM. + + A: [num_tokens, K] – input activations (bf16) + B: [E, K, N] – expert weights (bf16) + C: [num_tokens * top_k, N] – output (bf16) + + sorted_token_ids: [EM] flat token-expert pair indices (int32) + expert_ids: [EM // BLOCK_SIZE_M] expert index per M-block (int32) + + When naive_block_assignment=True, each M-block processes exactly one + token-expert pair (skipping the preprocess/sort step). In this mode: + - expert_ids[pid_m] holds the expert index for token-expert pair pid_m + - sorted_token_ids_ptr is unused + - offs_token is constructed as [pid_m, invalid, invalid, ...] + This avoids the preprocess kernel overhead for very small token counts. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs = tl.arange(0, BLOCK_SIZE_M) + + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + + # if not naive_block_assignment: + # offs_token_id = pid_m * BLOCK_SIZE_M + offs + # offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + # else: + # # Each block handles exactly one token-expert pair: + # # row 0 = pid_m (the token-expert pair index), remaining rows are + # # set to num_valid_tokens which will fail the < mask check. + # offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + + # Cast to int64 to prevent overflow: stride_cm * offs_token can exceed int32 + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # A pointer: a_ptr[token_idx, :K] where token_idx = offs_token // top_k + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + + # B pointer: b_ptr[expert, :K, offs_bn] — B layout is [E, K, N] + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if even_Ks: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Router-weight multiplication in fp32 (before precision conversion) + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 49ea579e562..b42db5cc3d3 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -20,6 +20,7 @@ import sys import types +import numpy as np import paddle import pytest @@ -37,6 +38,7 @@ def __init__(self, is_checkpoint_bf16=False, weight_block_size=(2, 2), name_valu self.weight_block_size = weight_block_size self._name_value = name_value self.deepgemm_scale_ue8m0 = False + self.moe_blockwise_gemm_scale_ue8m0 = False def name(self): return self._name_value @@ -57,7 +59,12 @@ class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.load_config = DummyLoadConfig(load_choices) self.model_config = types.SimpleNamespace(enable_cache=False) - self.scheduler_config = types.SimpleNamespace(enable_moe_scores_elementwise_fuse=False) + self.scheduler_config = types.SimpleNamespace( + enable_moe_scores_elementwise_fuse=False, + splitwise_role="mixed", + max_num_seqs=8, + max_num_batched_tokens=256, + ) class DummyGate(paddle.nn.Layer): @@ -89,9 +96,14 @@ def __init__( self.n_group = 1 self.topk_group = 1 self.routed_scaling_factor = 1.0 + self.routed_scaling_factor_learnable = False self.renormalize = True self.gate_correction_bias = paddle.zeros([num_local_experts], dtype="float32") self.topk_method = "noaux_tc" + self.with_bias = False + self.ep_size = 1 + self.activation = "swiglu" + self.moe_quant_config = types.SimpleNamespace() self.fd_config = DummyFDConfig(load_choices) self.weight_dtype = weight_dtype self.quant_method = DummyQuantMethod(quant_config) @@ -210,10 +222,15 @@ def test_backend_imports_kernel_module(self, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) reloaded = importlib.reload(backend) assert hasattr(reloaded, "fused_moe_kernel_paddle") + # Restore the real module: reload() permanently rebinds module-level names + # (e.g. fused_moe_kernel_bf16) to the fake, and monkeypatch cannot undo that. + # A second reload after monkeypatch restores sys.modules fixes the binding. + monkeypatch.undo() + importlib.reload(backend) def test_triton_weight_only_create_and_apply(self, fake_ops, monkeypatch): quant_config = DummyQuantConfig(is_checkpoint_bf16=False) @@ -322,7 +339,7 @@ def test_wfp8afp8_method_apply_paths(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -396,7 +413,7 @@ def test_wfp8afp8_apply_noaux_and_empty(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) _ = method.apply( @@ -436,7 +453,7 @@ def test_tensorwise_prequant_and_apply(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -459,7 +476,7 @@ def test_python_op_fused_moe_kernel_paddle(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -643,6 +660,7 @@ def test_blockwise_process_weights_ue8m0_branch(self, fake_ops, monkeypatch): """Test the quant_weight_ue8m0 branch in BlockWiseFP8MoEMethod.process_weights_after_loading.""" quant_config = DummyQuantConfig(is_checkpoint_bf16=True, weight_block_size=(128, 128)) quant_config.deepgemm_scale_ue8m0 = True + quant_config.moe_blockwise_gemm_scale_ue8m0 = True layer = DummyLayer(quant_config, weight_dtype="bfloat16") method = backend.BlockWiseFP8MoEMethod(quant_config) method.create_weights(layer, model_format="torch") @@ -697,7 +715,7 @@ def fake_transform_scale_ue8m0(sf, mn, weight_block_size=None): assert len(quant_calls) > 0, "quant_weight_ue8m0 should have been called" assert len(transform_calls) > 0, "transform_scale_ue8m0 should have been called" - def test_triton_weight_only_apply_noaux_tc_with_use_fused_true(self, fake_ops, monkeypatch): + def test_triton_weight_only_apply_noaux_tc_with_fd_enable_rl(self, fake_ops, monkeypatch): quant_config = DummyQuantConfig(is_checkpoint_bf16=False) layer = DummyLayer(quant_config) layer.topk_method = "noaux_tc" @@ -721,9 +739,9 @@ def test_triton_weight_only_apply_noaux_tc_with_use_fused_true(self, fake_ops, m kernel = DummyKernel() monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) - # Enable enable_moe_scores_elementwise_fuse and force is_cuda=True to trigger use_fused = True - monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) - layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + # Set FD_ENABLE_RL=True to trigger use_fused = False at line 313 + # This should trigger gate_out.cast('float32') at line 315 + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) x = paddle.randn([1, layer.hidden_size], dtype="float32") gate = DummyGate(layer.num_local_experts) @@ -736,39 +754,555 @@ def hook(topk_ids): _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) assert "topk_ids" in captured - def test_triton_weight_only_apply_noaux_tc_with_non_cuda(self, fake_ops, monkeypatch): - quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + def test_python_op_learnable_scaling(self, fake_ops, monkeypatch): + """routed_scaling_factor_learnable=True: per_expert_scale applied to topk_weights inside python_op.""" + quant_config = DummyQuantConfig(is_checkpoint_bf16=False, weight_block_size=(2, 2)) layer = DummyLayer(quant_config) - # Ensure topk_method is "noaux_tc" to enter the target branch - layer.topk_method = "noaux_tc" - method = backend.TritonWeightOnlyMoEMethod(quant_config) - method.create_weights(layer, model_format="torch") + layer.routed_scaling_factor_learnable = True + layer.per_expert_scale = paddle.ones([layer.num_local_experts], dtype="float32") + + kernel = DummyKernel() + monkeypatch.setitem( + sys.modules, + "fastdeploy.model_executor.layers.moe.triton_moe_kernels", + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), + ) + monkeypatch.setattr( + paddle.static, + "MetaTensor", + lambda shape, dtype: types.SimpleNamespace(shape=shape, dtype=dtype), + raising=False, + ) + + x = paddle.randn([2, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + gate_out = gate(x) + + up_weight = paddle.randn( + [layer.num_local_experts, layer.moe_intermediate_size * 2, layer.hidden_size], dtype="float32" + ) + down_weight = paddle.randn( + [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size], dtype="float32" + ) + up_scale = paddle.ones([layer.num_local_experts, 2, 2], dtype="float32") + down_scale = paddle.ones([layer.num_local_experts, 2, 2], dtype="float32") + + captured = {} + + def hook(topk_ids): + captured["topk"] = topk_ids + + config = {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1} + + _ = backend.python_op_fused_moe_kernel_paddle( + x, + up_weight, + up_scale, + down_weight, + down_scale, + gate_out, + layer.gate_correction_bias, + layer.top_k, + up_weight.shape[1], + down_weight.shape[1], + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + config, + quant_config, + hook, + ) + + assert "topk" in captured + + +class DummyBF16Kernel: + """ + Simulates fused_moe_kernel_bf16[grid](...). + Writes zeros into the output tensor (3rd positional argument). + """ + + def __init__(self): + self.calls = [] + + def __getitem__(self, grid): + def _runner(*args, **kwargs): + # output tensor is the 3rd positional argument (index 2) + if len(args) > 2 and isinstance(args[2], paddle.Tensor): + args[2].set_value(paddle.zeros_like(args[2])) + self.calls.append({"grid": grid, "kwargs": kwargs}) + + return _runner + + +class DummyTL: + """Minimal stub for triton.language so tests don't need a real Triton install.""" + + bfloat16 = "bfloat16" + float16 = "float16" + + +class TestTritonMoEMethod: + """Unit tests for TritonMoEMethod. + + Pattern mirrors TestFusedMoeTritonBackend: + - DummyLayer / DummyGate / DummyFDConfig (reused from module top) + - fake_ops fixture patches routing + preprocess ops + - DummyBF16Kernel patches fused_moe_kernel_bf16 + - No real GPU kernels are executed; output shapes / attributes are verified + """ + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def _make_layer(self, num_experts=2, hidden_size=8, intermediate_size=4, top_k=2): + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + return layer + + def _create_weights(self, method, layer): + """Call create_weights with the mandatory kwargs that the real MoE layer supplies. + + TritonMoEMethod targets the CUDA non-torch weight layout: + up_gate_proj_weight: [E, hidden_size, inter*2] (K-major) + down_proj_weight: [E, inter, hidden_size] (K-major) + Therefore we must NOT pass model_format="torch"; any non-"torch" value + (or omitting the key) lets UnquantizedFusedMoEMethod take the CUDA branch. + """ + method.create_weights( + layer, + model_format="default", + num_experts=layer.num_local_experts, + hidden_size=layer.hidden_size, + moe_intermediate_size=layer.moe_intermediate_size, + ) + + def _patch_bf16_kernel(self, monkeypatch): + kernel = DummyBF16Kernel() + monkeypatch.setattr(backend, "fused_moe_kernel_bf16", kernel, raising=False) + # Patch tl so that `compute_type=tl.bfloat16` inside apply() does not + # raise NameError when triton is not installed in the test environment. + monkeypatch.setattr(backend, "tl", DummyTL(), raising=False) + return kernel + + # ------------------------------------------------------------------ + # __init__ / basic construction + # ------------------------------------------------------------------ + + def test_init_sets_weight_attrs(self): + """TritonMoEMethod.__init__ must expose the two weight attr names.""" + method = backend.TritonMoEMethod() + assert "up_gate_proj_weight" in method.added_weight_attrs + assert "down_proj_weight" in method.added_weight_attrs + + def test_init_none_quant_config(self): + method = backend.TritonMoEMethod(quant_config=None) + assert method.quant_config is None + + # ------------------------------------------------------------------ + # create_weights + # ------------------------------------------------------------------ + + def test_create_weights_registers_parameters(self): + """After create_weights the layer should have up_gate_proj_weight and down_proj_weight.""" + method = backend.TritonMoEMethod() + layer = self._make_layer() + self._create_weights(method, layer) + assert hasattr(layer, "up_gate_proj_weight") + assert hasattr(layer, "down_proj_weight") + + def test_create_weights_shapes(self): + """Weight tensors must have the correct [E, K, N] / [E, N, K] layout.""" + E, H, N = 3, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + # up_gate: [E, hidden_size, intermediate*2] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + # down: [E, intermediate, hidden_size] + assert list(layer.down_proj_weight.shape) == [E, N, H] + + # ------------------------------------------------------------------ + # process_loaded_weights + # ------------------------------------------------------------------ + + def test_process_loaded_weights_stacks_experts(self): + """process_loaded_weights must stack per-expert tensors into the stacked param.""" + E, H, N = 2, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + + # Provide per-expert tensors via extract_moe_ffn_weights + up_weights = [paddle.ones([H, N * 2], dtype="bfloat16") * (i + 1) for i in range(E)] + down_weights = [paddle.ones([N, H], dtype="bfloat16") * (i + 1) for i in range(E)] + layer._up_weights = up_weights + layer._down_weights = down_weights - layer._up_weights = [ - paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( - [layer.hidden_size, layer.moe_intermediate_size * 2] - ) - for _ in range(layer.num_local_experts) - ] - layer._down_weights = [ - paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( - [layer.moe_intermediate_size, layer.hidden_size] - ) - for _ in range(layer.num_local_experts) - ] method.process_loaded_weights(layer, state_dict={}) - kernel = DummyKernel() - monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + # After stacking, shape should be [E, ...] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + assert list(layer.down_proj_weight.shape) == [E, N, H] + # Verify each expert's data is correctly stacked (expert i has value i+1) + for i in range(E): + expected_up = float(i + 1) + expected_down = float(i + 1) + actual_up = float(layer.up_gate_proj_weight[i].cast("float32").mean()) + actual_down = float(layer.down_proj_weight[i].cast("float32").mean()) + assert ( + abs(actual_up - expected_up) < 1e-3 + ), f"Expert {i} up_gate weight mean={actual_up}, expected {expected_up}" + assert ( + abs(actual_down - expected_down) < 1e-3 + ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" + + # ------------------------------------------------------------------ + # ------------------------------------------------------------------ + # _get_default_config — tile heuristic + # ------------------------------------------------------------------ + + def test_get_default_config_decode(self): + """M<=32 decode path → 16x64x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=4, E=8) + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_N"] == 64 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_mid(self): + """96 < M <= 512 mid path → 64x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=128, E=8) + assert cfg["BLOCK_SIZE_M"] == 64 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_prefill(self): + """M > 512 prefill path → 128x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=1024, E=8) + assert cfg["BLOCK_SIZE_M"] == 128 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_boundary_32(self): + """M==32 is decode (<=32).""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=32, E=8) + assert cfg["BLOCK_SIZE_M"] == 16 + + def test_get_default_config_boundary_96(self): + """M==96 is small-mid (32 < M <= 96) → BLOCK_SIZE_M=32.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=96, E=8) + assert cfg["BLOCK_SIZE_M"] == 32 + + def test_get_default_config_boundary_512(self): + """M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=512, E=8) + assert cfg["BLOCK_SIZE_M"] == 64 + + def test_get_default_config_has_group_size_m(self): + """All configs must include GROUP_SIZE_M key.""" + method = backend.TritonMoEMethod() + for M in (1, 64, 1024): + cfg = method._get_default_config(M=M, E=8) + assert "GROUP_SIZE_M" in cfg + + def test_get_default_config_block_n_boundary(self): + """M<=64 → BLOCK_SIZE_N=64; M>64 → BLOCK_SIZE_N=128.""" + method = backend.TritonMoEMethod() + cfg64 = method._get_default_config(M=64, E=8) + assert cfg64["BLOCK_SIZE_N"] == 64 + cfg65 = method._get_default_config(M=65, E=8) + assert cfg65["BLOCK_SIZE_N"] == 128 + + def test_get_default_config_group_m_16(self): + """tokens_per_expert > 128 → GROUP_SIZE_M=16.""" + method = backend.TritonMoEMethod() + # M=1024, E=1 → tokens_per_expert=1024 > 128 → group_m=16 + cfg = method._get_default_config(M=1024, E=1) + assert cfg["GROUP_SIZE_M"] == 16 + + def test_get_default_config_group_m_1(self): + """tokens_per_expert <= 128 → GROUP_SIZE_M=1.""" + method = backend.TritonMoEMethod() + # M=128, E=8 → tokens_per_expert=16 <= 128 → group_m=1 + cfg = method._get_default_config(M=128, E=8) + assert cfg["GROUP_SIZE_M"] == 1 + + def test_get_default_config_num_warps(self): + """M<=128 → num_warps=4; M>128 → num_warps=8.""" + method = backend.TritonMoEMethod() + cfg128 = method._get_default_config(M=128, E=8) + assert cfg128["num_warps"] == 4 + cfg256 = method._get_default_config(M=256, E=8) + assert cfg256["num_warps"] == 8 + + def test_get_default_config_num_stages(self): + """M<=32 → num_stages=4; M>32 → num_stages=3.""" + method = backend.TritonMoEMethod() + cfg32 = method._get_default_config(M=32, E=8) + assert cfg32["num_stages"] == 4 + cfg33 = method._get_default_config(M=33, E=8) + assert cfg33["num_stages"] == 3 + + # ------------------------------------------------------------------ + # apply — empty-batch fast path + # ------------------------------------------------------------------ + + def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): + """apply() with 0 tokens must return a zero tensor of shape [0, hidden_size].""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.zeros([0, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) - # Mock current_platform.is_cuda() to return False to trigger use_fused = False at line 313 - # This should trigger gate_out.cast("float32") at line 315 - monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: False)) + assert list(out.shape) == [0, layer.hidden_size] - x = paddle.randn([2, layer.hidden_size], dtype="float32") + # ------------------------------------------------------------------ + # apply — normal forward (noaux_tc routing path) + # ------------------------------------------------------------------ + + def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): + """apply() noaux_tc path: output shape must be [token_num, hidden_size].""" + T, H = 4, 8 + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=H) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([T, H], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [T, H] + + def test_apply_noaux_tc_topk_hook_called(self, fake_ops, monkeypatch): + """topk_ids_hookfunc must be called with topk_ids kwarg during apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert "topk_ids" in captured + + def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): + """fused_moe_kernel_bf16 must be launched twice (GEMM1 + GEMM2) per forward pass.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts)) + + assert len(kernel.calls) == 2, f"Expected 2 kernel launches (GEMM1 + GEMM2), got {len(kernel.calls)}" + + # ------------------------------------------------------------------ + # apply — non-noaux routing path (moe_topk_select) + # ------------------------------------------------------------------ + + def test_apply_aux_routing_path(self, fake_ops, monkeypatch): + """When topk_method != 'noaux_tc', the moe_topk_select path is used.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_method = "aux" + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["ids"] = topk_ids + + x = paddle.randn([3, layer.hidden_size], dtype="bfloat16") + out = method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert list(out.shape) == [3, layer.hidden_size] + assert "ids" in captured + + # ------------------------------------------------------------------ + # apply_tp delegates to apply + # ------------------------------------------------------------------ + + def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): + """apply_tp() must produce the same output shape as apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply_tp(layer, x, gate) + + assert list(out.shape) == [2, layer.hidden_size] + + # ------------------------------------------------------------------ + # EP methods raise NotImplementedError + # ------------------------------------------------------------------ + + def test_apply_ep_prefill_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_prefill(layer, None, None) + + def test_apply_ep_decode_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_decode(layer, None, None) + + # ------------------------------------------------------------------ + # apply — kernel argument verification + # ------------------------------------------------------------------ + + def test_apply_kernel_even_ks_true(self, fake_ops, monkeypatch): + """When hidden_size is divisible by BLOCK_SIZE_K, even_Ks=True in GEMM1.""" + method = backend.TritonMoEMethod() + # hidden_size=64, BLOCK_SIZE_K=64 → even_Ks=True for GEMM1 + layer = self._make_layer(hidden_size=64, intermediate_size=32) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["even_Ks"] is True + + def test_apply_kernel_even_ks_false(self, fake_ops, monkeypatch): + """When hidden_size is NOT divisible by BLOCK_SIZE_K, even_Ks=False in GEMM1.""" + method = backend.TritonMoEMethod() + # hidden_size=8, BLOCK_SIZE_K=64 → even_Ks=False for GEMM1 + layer = self._make_layer(hidden_size=8, intermediate_size=4) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["even_Ks"] is False + + def test_apply_gemm2_top_k_always_1(self, fake_ops, monkeypatch): + """GEMM2 must always be called with top_k=1 (flat token-expert pairs).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8, top_k=4) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["top_k"] == layer.top_k + assert kernel.calls[1]["kwargs"]["top_k"] == 1 + + def test_apply_gemm1_no_mul_weight_gemm2_mul_weight(self, fake_ops, monkeypatch): + """GEMM1 has MUL_ROUTED_WEIGHT=False, GEMM2 has MUL_ROUTED_WEIGHT=True.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) - def fake_get_moe_scores(*args, **kwargs): + assert kernel.calls[0]["kwargs"]["MUL_ROUTED_WEIGHT"] is False + assert kernel.calls[1]["kwargs"]["MUL_ROUTED_WEIGHT"] is True + + def test_apply_large_batch_config(self, fake_ops, monkeypatch): + """Large token count picks larger tile config (BLOCK_SIZE_M=128, num_warps=8).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + # 1024 tokens → prefill config: BLOCK_SIZE_M=128 + x = paddle.randn([1024, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["BLOCK_SIZE_M"] == 128 + assert kernel.calls[0]["kwargs"]["num_warps"] == 8 + + def test_apply_single_token_output_shape(self, fake_ops, monkeypatch): + """Single token decode scenario.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=128, hidden_size=16, intermediate_size=8, top_k=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + + def test_get_moe_method_triton_branch(self, monkeypatch): + """get_moe_method() returns TritonMoEMethod when FD_MOE_BACKEND='triton' and is_cuda().""" + from fastdeploy.model_executor.layers.moe import moe as moe_module + + monkeypatch.setattr(moe_module, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) + monkeypatch.setattr(moe_module.envs, "FD_MOE_BACKEND", "triton") + result = moe_module.get_moe_method() + assert isinstance(result, backend.TritonMoEMethod) + + def test_apply_use_fused_false(self, fake_ops, monkeypatch): + """FD_ENABLE_RL=True triggers use_fused=False branch (gate_out.cast('float32')).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + assert list(out.shape) == [2, layer.hidden_size] + + def test_apply_tp_with_topk_reduce_func(self, fake_ops, monkeypatch): + """topk_reduce_func attribute is passed through to get_moe_scores.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_reduce_func = lambda x: x + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + scores_kwargs = {} + + def tracking_get_moe_scores(*args, **kwargs): + scores_kwargs.update(kwargs) gate_out = args[0] token_num = gate_out.shape[0] top_k = args[3] @@ -776,12 +1310,187 @@ def fake_get_moe_scores(*args, **kwargs): topk_weights = paddle.ones([token_num, top_k], dtype="float32") return gate_out, topk_weights, topk_ids - monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores) + monkeypatch.setattr(backend, "get_moe_scores", tracking_get_moe_scores) - captured = {} + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) - def hook(topk_ids): - captured["topk_ids"] = topk_ids + assert "topk_reduce_func" in scores_kwargs - _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) - assert "topk_ids" in captured + +# =========================================================================== +# Precision tests: TritonMoEMethod vs. CutlassMoEMethod (BF16) +# =========================================================================== + + +def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k): + """ + Build a DummyLayer with random BF16 weights and a TritonMoEMethod. + + Weight layout (CUDA non-torch): [E, H, 2N] for up_gate_proj, [E, N, H] for down_proj. + Returns (layer, None, triton_method) for compatibility with existing test signatures. + """ + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + + triton_method = backend.TritonMoEMethod() + + # Create weight parameters (CUDA non-torch layout) + triton_method.create_weights( + layer, + model_format="default", + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + ) + + # Fill with Xavier-like random BF16 weights to produce meaningful output magnitudes. + # W1: [E, H, 2N] — scale by 1/sqrt(H) so GEMM1 output ~O(1) + # W2: [E, N, H] — scale by 1/sqrt(N) so GEMM2 output ~O(1) + paddle.seed(42) + w1_scale = 1.0 / (hidden_size**0.5) + w2_scale = 1.0 / (intermediate_size**0.5) + layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * w1_scale).cast("bfloat16")) + layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * w2_scale).cast("bfloat16")) + return layer, None, triton_method + + +def _uniform_gate(layer): + """Gate that outputs uniform logits so every expert gets equal probability.""" + + class _Gate(paddle.nn.Layer): + def __init__(self, num_experts): + super().__init__() + self.num_experts = num_experts + + def forward(self, x): + return paddle.ones([x.shape[0], self.num_experts], dtype="float32") + + return _Gate(layer.num_local_experts) + + +# Shapes to exercise: (token_num, hidden_size, intermediate_size, num_experts, top_k) +# Small/medium sizes to keep test runtime reasonable. +_PRECISION_SHAPES = [ + pytest.param(1, 64, 32, 8, 2, id="decode_T1_H64"), + pytest.param(16, 64, 32, 8, 2, id="decode_T16_H64"), + pytest.param(64, 128, 64, 8, 2, id="mid_T64_H128"), + pytest.param(128, 128, 64, 8, 2, id="mid_T128_H128_E8"), + pytest.param(256, 256, 128, 8, 4, id="prefill_T256_H256"), +] + + +@pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA") +# @pytest.mark.skipif(not _triton_ops_available(), reason="triton MoE ops not available (custom ops not compiled)") +class TestTritonMoEPrecision: + """ + Precision tests: Triton BF16 path vs. Cutlass BF16 path. + + Both paths are activated in production via the FD_MOE_BACKEND env var + (triton vs cutlass). This test verifies they produce numerically equivalent + results on the same shared BF16 weights and identical inputs. + + All tests run real GPU kernels (no mocking). + Tolerance: atol=1e-2, rtol=1e-2 (both kernels use BF16 arithmetic with + fp32 accumulation; differences come from tile ordering / rounding). + """ + + # Tolerance for comparing two independent BF16 GEMM implementations. + # BF16 has ~7-bit mantissa (eps ~0.008). After GEMM1 + SwiGLU + GEMM2, + # rounding differences accumulate. Use np.allclose style: + # |triton - cutlass| <= ATOL + RTOL * |cutlass| + ATOL = 1e-3 + RTOL = 1e-3 + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_vs_cutlass(self, T, H, N, E, K): + """Triton BF16 MoE output must agree with CUTLASS BF16 MoE output. + + Both paths use the same weight layout, routing logic, and BF16 arithmetic. + Differences should only come from tile ordering / rounding in GEMM. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, + ) + + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + + # CUTLASS method shares the same weights (already created by _make_precision_layer_pair) + cutlass_method = CutlassMoEMethod(None) + + paddle.seed(0) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + + # Use a deterministic non-uniform gate to ensure consistent routing + # across multiple calls of noaux_tc (avoids tie-breaking ambiguity) + class _DeterministicGate(paddle.nn.Layer): + def __init__(self, num_experts, T): + super().__init__() + self.num_experts = num_experts + paddle.seed(123) + self._scores = paddle.randn([T, num_experts], dtype="float32") * 2.0 + + def forward(self, x): + return self._scores[: x.shape[0]] + + gate = _DeterministicGate(E, T) + + # --- Run Triton path --- + triton_out = triton_method.apply(layer, x, gate).cast("float32").numpy() + + # --- Run CUTLASS path --- + cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() + + # np.allclose style: |a - b| <= atol + rtol * |b| + tol = self.ATOL + self.RTOL * np.abs(cutlass_out) + violations = np.abs(triton_out - cutlass_out) > tol + num_violations = int(violations.sum()) + total_elements = triton_out.size + + assert num_violations == 0, ( + f"[T={T},H={H},N={N},E={E},K={K}] " + f"{num_violations}/{total_elements} elements exceed tolerance " + f"(atol={self.ATOL}, rtol={self.RTOL}). " + f"Max abs diff: {float(np.abs(triton_out - cutlass_out).max()):.2e}, " + f"max |cutlass|: {float(np.abs(cutlass_out).max()):.2e}" + ) + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_shape(self, T, H, N, E, K): + """Output shape must always be [T, H] regardless of batch size.""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert list(out.shape) == [T, H], f"Expected [{T}, {H}], got {list(out.shape)}" + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_dtype_is_bfloat16(self, T, H, N, E, K): + """Output dtype must match input dtype (bfloat16).""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert out.dtype == paddle.bfloat16, f"Expected bfloat16, got {out.dtype}" + + def test_zero_input_gives_zero_output(self): + """All-zero input must produce all-zero output.""" + T, H, N, E, K = 8, 64, 32, 8, 2 + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = paddle.zeros([T, H], dtype="bfloat16") + gate = _uniform_gate(layer) + + out = triton_method.apply(layer, x, gate).cast("float32").numpy() + np.testing.assert_allclose( + out, + np.zeros_like(out), + atol=1e-6, + err_msg="triton: zero input should produce zero output", + ) From df637af7ba19dde7e76c413d44497200329d81e2 Mon Sep 17 00:00:00 2001 From: qwes5s5 <45442318+qwes5s5@users.noreply.github.com> Date: Thu, 14 May 2026 17:02:08 +0800 Subject: [PATCH 109/143] refact abort requests (#7808) --- docs/online_serving/README.md | 2 +- docs/zh/online_serving/README.md | 2 +- fastdeploy/engine/common_engine.py | 189 ++++-------------- .../engine/sched/resource_manager_v1.py | 3 +- fastdeploy/entrypoints/engine_client.py | 12 ++ fastdeploy/entrypoints/openai/api_server.py | 9 +- .../entrypoints/openai/response_processors.py | 4 +- fastdeploy/router/router.py | 45 +---- tests/engine/test_common_engine.py | 176 +--------------- tests/entrypoints/openai/test_api_server.py | 34 +--- tests/router/test_router.py | 78 ++------ 11 files changed, 109 insertions(+), 445 deletions(-) diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md index 2b447476020..c9dba035339 100644 --- a/docs/online_serving/README.md +++ b/docs/online_serving/README.md @@ -577,4 +577,4 @@ DeltaFunctionCall: - `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset. - `/v1/resume` - Resume generation. - `/v1/is_paused` - Check if generation is paused. -- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts. +- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md index 21f16d06e32..0264c928bd5 100644 --- a/docs/zh/online_serving/README.md +++ b/docs/zh/online_serving/README.md @@ -563,4 +563,4 @@ DeltaFunctionCall: /v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。 /v1/resume - 恢复推理生成。 /v1/is_paused - 检查推理生成是否已暂停。 -/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。 +/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。 diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 7a8a641573a..c31c9039b40 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -44,11 +44,9 @@ from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( - CompletionOutput, ControlRequest, ControlResponse, Request, - RequestMetrics, RequestOutput, RequestStatus, RequestType, @@ -1087,10 +1085,27 @@ def _insert_zmq_task_to_scheduler(self): self.request_worker_map[req_id_for_map] = worker_pid status_value = data.get("status", None) if status_value is not None and status_value == RequestStatus.ABORT.value: - req_id = data["request_id"] - self.llm_logger.info(f"Receive abort request, req_id: {req_id}") - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager.add_abort_req_ids(req_id) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.llm_logger.info("abort requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") + else: + abort_all = data.get("abort_all", False) + req_ids = data.get("req_ids", []) + if abort_all or req_ids: + target_req_ids = self._resolve_abort_targets(abort_all, req_ids) + self.llm_logger.info( + f"Receive abort_reqs, abort_all={abort_all}, " + f"input={len(req_ids)}, resolved={len(target_req_ids)}" + ) + self.resource_manager.add_abort_req_ids(target_req_ids) + else: + req_id = data.get("request_id", None) + if not req_id: + self.llm_logger.warning( + "Receive abort request without request_id, skip invalid abort message" + ) + continue + self.llm_logger.info(f"Receive abort request, req_id: {req_id}") + self.resource_manager.add_abort_req_ids(req_id) continue err_msg = None try: @@ -1373,150 +1388,6 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d return responses - def _control_abort_requests(self, control_req: ControlRequest): - if not envs.ENABLE_V1_KVCACHE_SCHEDULER: - raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") - args = control_req.get_args() - abort_all = args.get("abort_all", False) - req_ids = args.get("req_ids", []) - matched_input_ids = set() - now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) - - # Step 1: Determine target request list - if abort_all: - # all requests in running + waiting - target_req_ids = now_reqs - else: - # filter out requests that actually exist - target_req_ids = [] - for rid in req_ids: - if rid in now_reqs: - target_req_ids.append(rid) - matched_input_ids.add(rid) - elif f"{rid}_0" in now_reqs: - target_req_ids.append(f"{rid}_0") - matched_input_ids.add(rid) - - if not target_req_ids: - return {"aborted": [], "not_found": req_ids if not abort_all else []} - - # Step 2: Collect partial results - aborted_info = [] - results = [] - for req_id in target_req_ids: - request = self.resource_manager.requests.get(req_id) - if request is None: - scheduled_req = self.scheduler.requests.get(req_id) - if scheduled_req is None: - continue - request = scheduled_req.raw - - partial_token_ids = list(request.output_token_ids) - - # Construct finished response with partial results - now = time.time() - abort_metrics = RequestMetrics( - arrival_time=request.metrics.arrival_time if request.metrics else now, - inference_start_time=request.metrics.inference_start_time if request.metrics else now, - engine_recv_latest_token_time=now, - engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, - request_start_time=request.metrics.arrival_time if request.metrics else now, - ) - eos_token_ids = getattr(request, "eos_token_ids", [0]) - result = RequestOutput( - request_id=req_id, - finished=True, - outputs=CompletionOutput( - index=0, - send_idx=len(partial_token_ids), - token_ids=[eos_token_ids[0]], - ), - metrics=abort_metrics, - error_code=200, - error_msg="Aborted", - ) - results.append(result) - aborted_info.append( - { - "request_id": req_id, - "output_token_count": len(partial_token_ids), - } - ) - - # Step 3: Execute abort — add all requests to waiting_abort_req_id_set - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - for req_id in target_req_ids: - self.resource_manager.add_abort_req_ids(req_id) - time.sleep(0.0001) - if self.cfg.scheduler_config.splitwise_role != "prefill": - self._wait_abort_complete(target_req_ids) - - # Add results to scheduler, engine will have a thread calling get_results, - # then cleanup and call send_response to send to client. - # When client disconnects, send_response will automatically ignore - if self.cfg.scheduler_config.splitwise_role != "prefill": - try: - # self.send_response_server.send_response(req_id, [result]) - self.scheduler.put_results(results) - except Exception: - pass # client may have disconnected - - not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] - - return {"aborted": aborted_info, "not_found": not_found} - - def _wait_abort_complete(self, target_req_ids, stall_timeout=1): - """ - Wait for all abort requests to complete. - - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet - - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, - reset progress state if any, then continue monitoring - """ - target_set = set(target_req_ids) - target_set = target_set & (set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) - prev_remaining_count = len(target_set) - last_progress_time = time.time() - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - while remaining: - alive_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) - finished_reqs = target_set - alive_reqs - if finished_reqs: - self.llm_logger.info(f"abort targets already finished, skip: {finished_reqs}") - for req_id in finished_reqs: - self.resource_manager.waiting_abort_req_id_set.discard(req_id) - self.resource_manager.to_be_aborted_req_id_set.discard(req_id) - target_set -= finished_reqs - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - if not remaining: - self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") - return - self.llm_logger.debug(f"remaining:{remaining}") - - current_count = len(remaining) - if current_count < prev_remaining_count: - # progress made: recycle_abort_task was called - self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") - last_progress_time = time.time() - prev_remaining_count = current_count - - if time.time() - last_progress_time > stall_timeout: - # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) - stuck = remaining & self.resource_manager.to_be_aborted_req_id_set - if stuck: - self.llm_logger.warning( - f"no abort progress for {stall_timeout}s, " - f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" - ) - for req_id in list(stuck): - self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") - self.resource_manager.recycle_abort_task(req_id) - # reset progress state - last_progress_time = time.time() - prev_remaining_count = current_count - len(stuck) - # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow - - time.sleep(0.005) - def _parse_tags(self, control_request: ControlRequest): """ Parse tags from control request. @@ -2562,3 +2433,21 @@ def detect_thread(): except Exception: pass return True + + def _resolve_abort_targets(self, abort_all, req_ids): + """ + Resolve abort target request IDs. + """ + now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) + self.llm_logger.debug(f"now_reqs: {now_reqs}") + + if abort_all: + return list(now_reqs) + + target_req_ids = [] + for rid in req_ids: + if rid in now_reqs: + target_req_ids.append(rid) + elif f"{rid}_0" in now_reqs: + target_req_ids.append(f"{rid}_0") + return target_req_ids diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index e25ed0e1231..de89ab3adca 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -294,7 +294,7 @@ def recycle_abort_task(self, request_id): del self.req_dict[request_id] self.to_be_aborted_req_id_set.discard(request_id) self.waiting_abort_req_id_set.discard(request_id) - llm_logger.debug(f"request_id:{request_id} recycle end") + llm_logger.debug(f"request_id:{request_id} recycle abort task end") self.update_metrics() def _trigger_abort(self, request_id, scheduled_reqs): @@ -307,6 +307,7 @@ def _trigger_abort(self, request_id, scheduled_reqs): scheduled_reqs.append(self._prepare_abort_task(abort_request)) self.to_be_aborted_req_id_set.add(request_id) self.waiting_abort_req_id_set.discard(request_id) + llm_logger.debug(f"request_id:{request_id} trigger abort") def _info_each_block(self): """ diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 278d6e576ab..7fcb2e0fbc4 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -1046,6 +1046,18 @@ async def abort(self, request_id, n=1) -> None: api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids)) + async def abort_reqs(self, req_ids=None, abort_all=False): + """ + Fire-and-forget: abort multiple requests in one ZMQ message. + Used by /v1/abort_requests API. + """ + data = { + "status": RequestStatus.ABORT.value, + "abort_all": abort_all, + "req_ids": req_ids or [], + } + self._send_task(data) + def process_messages(self, messages): for message in messages: if message["role"] == "assistant" and "tool_calls" in message: diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index b96d93ab312..4e76a62ca65 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -486,13 +486,8 @@ async def abort_requests(request: Request): if not abort_all and not req_ids: return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"}) - control_request = ControlRequest( - request_id=f"control-{uuid.uuid4()}", - method="abort_requests", - args={"abort_all": abort_all, "req_ids": req_ids or []}, - ) - control_response = await app.state.engine_client.run_control_method(control_request) - return control_response.to_api_json_response() + await app.state.engine_client.abort_reqs(req_ids=req_ids or [], abort_all=abort_all) + return Response(status_code=200) def wrap_streaming_generator(original_generator: AsyncGenerator): diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py index 41761963be8..acdfcaeade4 100644 --- a/fastdeploy/entrypoints/openai/response_processors.py +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -89,7 +89,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ decode_type = request_output["outputs"].get("decode_type", 0) or 0 if decode_type == 0: # text tts = req_id in self._audio_buffer - if token_ids[-1] == self.eos_token_id: + if token_ids and token_ids[-1] == self.eos_token_id: all_audio_tokens = self._audio_buffer.pop(req_id, []) else: all_audio_tokens = None @@ -186,7 +186,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ else: self.accumulate_token_ids(request_output) token_ids = request_output["outputs"]["token_ids"] - if token_ids[-1] == self.eos_token_id: + if token_ids and token_ids[-1] == self.eos_token_id: multipart = [] num_image_tokens = 0 for part in self._multipart_buffer: diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index bdb9c5b9c6a..1e1adf5fd9b 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -18,7 +18,7 @@ import aiohttp import uvicorn from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse +from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastdeploy.router.utils import ( InstanceInfo, @@ -29,6 +29,7 @@ from fastdeploy.utils import router_logger as logger app = FastAPI() +_background_tasks = set() @dataclass @@ -588,39 +589,15 @@ async def abort_requests(request: Request): decode_servers = app.state.router.decode_servers all_servers = prefill_servers + decode_servers - async with aiohttp.ClientSession() as session: - tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] - responses = await asyncio.gather(*tasks, return_exceptions=True) - - # Aggregate results from Node D only - all_aborted = [] - all_not_found = [] - errors = [] - decode_start = len(prefill_servers) - for i, (server, resp) in enumerate(zip(all_servers, responses)): - if i < decode_start: - continue - if isinstance(resp, Exception): - errors.append({"server": server.url(), "error": str(resp)}) - elif resp.status == 200: - data = await resp.json() - result = data.get("result") or {} - all_aborted.extend(result.get("aborted", [])) - all_not_found.extend(result.get("not_found", [])) - else: - errors.append({"server": server.url(), "status": resp.status}) - - return JSONResponse( - content={ - "request_id": f"router-{uuid4()}", - "status": "success" if not errors else "error", - "error_message": None if not errors else str(errors), - "result": { - "aborted": all_aborted, - "not_found": list(set(all_not_found)), - }, - } - ) + async def _forward_abort(): + async with aiohttp.ClientSession() as session: + tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] + await asyncio.gather(*tasks, return_exceptions=True) + + task = asyncio.create_task(_forward_abort()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + return Response(status_code=200) def launch_router(router_args: RouterArgs): diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 54e613c8637..7e7c660964b 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -3496,7 +3496,7 @@ def _fake_sleep(s): self.assertGreaterEqual(call_count[0], 1) self._detach_finalizer(eng) - # ── _control_abort_requests / _wait_abort_complete ─────────────── + # ── _resolve_abort_targets / _build_abort_results ─────────────── def _make_abort_engine(self, splitwise_role="mixed"): """Create an engine wired up for abort tests.""" @@ -3537,179 +3537,21 @@ def _make_fake_request(self, output_token_ids=None): req.metrics.engine_recv_first_token_time = 1000.2 return req - def test_control_abort_requests_not_v1_raises(self): - """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off.""" - eng = self._make_abort_engine() - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): - with self.assertRaises(Exception) as ctx: - eng._control_abort_requests(control_req) - self.assertIn("only supported", str(ctx.exception)) - self._detach_finalizer(eng) - - def test_control_abort_requests_abort_all(self): - """abort_all=True aborts all requests in resource_manager + scheduler.""" + def test_resolve_abort_targets_abort_all(self): + """abort_all=True returns all requests in resource_manager + scheduler.""" eng = self._make_abort_engine() eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])} eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))} - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - # Simulate immediate abort completion - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 2) - self.assertEqual(result["not_found"], []) - ids = {a["request_id"] for a in result["aborted"]} - self.assertEqual(ids, {"req-1_0", "req-2_0"}) - # put_results should have been called (not prefill) - eng.scheduler.put_results.assert_called_once() - self._detach_finalizer(eng) - - def test_control_abort_requests_by_req_ids_with_suffix_match(self): - """req_ids match both exact and _0 suffix.""" - eng = self._make_abort_engine() - eng.resource_manager.requests = { - "req-A_0": self._make_fake_request([1, 2, 3]), - "req-B": self._make_fake_request([4, 5]), - } - - control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["req-A", "req-B", "req-C"], - }, - ) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - aborted_ids = {a["request_id"] for a in result["aborted"]} - self.assertIn("req-A_0", aborted_ids) # matched via _0 suffix - self.assertIn("req-B", aborted_ids) # exact match - self.assertEqual(result["not_found"], ["req-C"]) + target = eng._resolve_abort_targets(abort_all=True, req_ids=[]) + self.assertEqual(set(target), {"req-1_0", "req-2_0"}) self._detach_finalizer(eng) - def test_control_abort_requests_no_match(self): - """No requests found returns empty aborted and all in not_found.""" + def test_resolve_abort_targets_no_match(self): + """No matching request ids returns empty list.""" eng = self._make_abort_engine() - control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["nonexistent"], - }, - ) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"], []) - self.assertEqual(result["not_found"], ["nonexistent"]) - self._detach_finalizer(eng) - - def test_control_abort_requests_prefill_skips_wait_and_put(self): - """Prefill role skips _wait_abort_complete and put_results.""" - eng = self._make_abort_engine(splitwise_role="prefill") - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - eng.resource_manager.add_abort_req_ids = MagicMock() - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 1) - eng.scheduler.put_results.assert_not_called() - self._detach_finalizer(eng) - - def test_control_abort_requests_output_token_count(self): - """output_token_count reflects partial_token_ids length.""" - eng = self._make_abort_engine() - eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"][0]["output_token_count"], 5) - self._detach_finalizer(eng) - - def test_wait_abort_complete_immediate(self): - """_wait_abort_complete returns immediately when all requests already cleaned.""" - eng = self._make_abort_engine() - # Empty abort sets → remaining is empty → returns immediately - eng._wait_abort_complete(["req-1_0"]) - self._detach_finalizer(eng) - - def test_wait_abort_complete_progress(self): - """_wait_abort_complete exits when background thread cleans up.""" - eng = self._make_abort_engine() - eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"} - # Add the request to requests dict so it won't be filtered out - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - call_count = [0] - - def fake_sleep(s): - call_count[0] += 1 - # Simulate background thread cleaning up after first sleep - eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0") - - with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep): - eng._wait_abort_complete(["req-1_0"]) - - self.assertGreaterEqual(call_count[0], 1) - self._detach_finalizer(eng) - - def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self): - """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set.""" - eng = self._make_abort_engine() - eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"} - # Add the request to requests dict so it won't be filtered out - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - def mock_recycle(req_id): - eng.resource_manager.to_be_aborted_req_id_set.discard(req_id) - - eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle) - - # Make time.time() advance past stall_timeout - time_values = [100.0, 100.0, 102.0, 102.0, 102.0] - time_idx = [0] - - def fake_time(): - idx = min(time_idx[0], len(time_values) - 1) - time_idx[0] += 1 - return time_values[idx] - - with ( - patch("fastdeploy.engine.common_engine.time.time", fake_time), - patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None), - ): - eng._wait_abort_complete(["req-1_0"], stall_timeout=1) - - eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0") + target = eng._resolve_abort_targets(abort_all=False, req_ids=["nonexistent"]) + self.assertEqual(target, []) self._detach_finalizer(eng) diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 48704e026b6..301e77489c1 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -828,44 +828,30 @@ def _mock_abort_control_response(api_server, result, status_code=200): async def test_abort_requests_with_req_ids(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - "aborted": [{"request_id": "req-1_0", "output_token_count": 10}], - "not_found": ["req-999"], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.method == "abort_requests" - assert control_req.args["req_ids"] == ["req-1", "req-999"] - assert control_req.args["abort_all"] is False + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["req_ids"] == ["req-1", "req-999"] + assert call_kwargs["abort_all"] is False @pytest.mark.asyncio async def test_abort_requests_with_abort_all(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - "aborted": [ - {"request_id": "req-1_0", "output_token_count": 5}, - {"request_id": "req-2_0", "output_token_count": 12}, - ], - "not_found": [], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"abort_all": True}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.args["abort_all"] is True - assert control_req.args["req_ids"] == [] + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["abort_all"] is True + assert call_kwargs["req_ids"] == [] @pytest.mark.asyncio diff --git a/tests/router/test_router.py b/tests/router/test_router.py index ba21b814870..3970cd8b46d 100644 --- a/tests/router/test_router.py +++ b/tests/router/test_router.py @@ -20,6 +20,7 @@ We mock it at the network boundary to test Router's registration and selection logic. """ +import asyncio import unittest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -178,7 +179,7 @@ async def _coro(): @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health): - """P and D both receive the request, but only D results are aggregated.""" + """Router returns 200 immediately and forwards to all (P + D) servers in background.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -189,24 +190,8 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 200 - decode_resp.json = AsyncMock( - return_value={ - "request_id": "control-d", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []}, - } - ) mock_session = self._make_mock_session([prefill_resp, decode_resp]) mock_request = AsyncMock() @@ -214,18 +199,17 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + # Give the background task a chance to run + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(len(body["result"]["aborted"]), 1) - self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15) - self.assertEqual(body["status"], "success") + self.assertEqual(resp.status_code, 200) + # Forwarded to both prefill + decode self.assertEqual(mock_session.post.call_count, 2) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_error_returns_error_status(self, mock_health): - """When D node returns a non-200 status, status should be 'error'.""" + async def test_abort_returns_200_even_when_decode_errors(self, mock_health): + """Router fire-and-forgets: still returns 200 when D returns non-200.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -236,14 +220,6 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 500 @@ -253,16 +229,14 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIsNotNone(body["error_message"]) + self.assertEqual(resp.status_code, 200) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_exception_returns_error(self, mock_health): - """When D node connection fails (exception), error should be captured.""" + async def test_abort_returns_200_when_decode_raises(self, mock_health): + """Router fire-and-forgets: still returns 200 when a downstream raises.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -273,30 +247,20 @@ async def test_abort_decode_exception_returns_error(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) - - # D node raises exception — but asyncio.gather(return_exceptions=True) captures it - # So we pass the exception as a response directly + mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder call_idx = [0] def post_with_exception(*args, **kwargs): call_idx[0] += 1 if call_idx[0] == 1: - # prefill: normal + async def _coro(): return prefill_resp return _coro() else: - # decode: raise (gather with return_exceptions=True will catch) + async def _coro_err(): raise ConnectionError("refused") @@ -308,12 +272,10 @@ async def _coro_err(): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIn("refused", body["error_message"]) + self.assertEqual(resp.status_code, 200) if __name__ == "__main__": From 18cab83c9d3a7bee758afcce56c1ffe8bc7f48e1 Mon Sep 17 00:00:00 2001 From: JYChen Date: Thu, 14 May 2026 19:54:49 +0800 Subject: [PATCH 110/143] fix paddle optional get assert in sm103 (#7820) --- custom_ops/setup_ops.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 7ae1e964761..dcb514504e8 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -174,6 +174,12 @@ def get_gencode_flags(archs): "-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}", ] + elif cc_val == 103: + arch_code = "103a" + flags += [ + "-gencode", + f"arch=compute_{arch_code},code=sm_{arch_code}", + ] else: flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"] return flags @@ -476,9 +482,10 @@ def find_end_files(directory, end_str): # of them instead of only the highest one. has_sm90 = 90 in sm_versions has_sm100 = 100 in sm_versions and nvcc_version >= 12.9 - has_generic_fp8 = not has_sm90 and not has_sm100 # SM89 or other + has_sm103 = 103 in sm_versions and nvcc_version >= 13.0 + has_generic_fp8 = not has_sm90 and not has_sm100 and not has_sm103 # SM89 or other - if has_sm90 or has_sm100: + if has_sm90 or has_sm100 or has_sm103: nvcc_compile_args += [ "-O3", "-DNDEBUG", @@ -501,8 +508,8 @@ def find_end_files(directory, end_str): "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", ] - if has_sm100: - print("SM100 (Blackwell): Applying SM100 configurations.") + if has_sm100 or has_sm103: + print("SM100 / 103 (Blackwell): Applying SM100 / SM103 configurations.") # Placeholder for SM100-specific kernel auto-generation scripts # These might be needed if Blackwell has new FP8 hardware features # not covered by existing generic CUTLASS templates or SM90 scripts. From 72beb9e896c7e170acdcd5f87489c6dfd4f8f76a Mon Sep 17 00:00:00 2001 From: yongqiangma Date: Thu, 14 May 2026 20:16:45 +0800 Subject: [PATCH 111/143] opt moe_align_kernel (#7786) --- custom_ops/gpu_ops/helper.h | 2 + custom_ops/gpu_ops/moe/moe_align_kernel.cu | 604 ++++++++++++++++++ .../gpu_ops/moe/tritonmoe_preprocess.cu | 165 ++--- custom_ops/setup_ops.py | 2 + tests/operators/test_tritonmoe_preprocess.py | 392 +++++++++++- 5 files changed, 1035 insertions(+), 130 deletions(-) create mode 100644 custom_ops/gpu_ops/moe/moe_align_kernel.cu diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 83f3ad1077d..cb8c2e3e623 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -73,6 +73,8 @@ namespace cub = hipcub; using json = nlohmann::json; #endif +#define CEILDIV(a, b) (((a + b - 1) / b)) + #define CUDA_CHECK(call) \ do { \ const cudaError_t error_code = call; \ diff --git a/custom_ops/gpu_ops/moe/moe_align_kernel.cu b/custom_ops/gpu_ops/moe/moe_align_kernel.cu new file mode 100644 index 00000000000..4d2a01d8dd9 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_align_kernel.cu @@ -0,0 +1,604 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Reference +// https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/sgl-kernel/csrc/moe/moe_align_kernel.cu +// Licensed under Apache License 2.0 +// with further performance optimizations applied. + +#include + +#include "helper.h" +#include "paddle/extension.h" + +#define VEC_SIZE 4 +using Vec = int4; + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ int warp_exclusive_scan( + int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} +#endif + +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size, + int32_t max_num_tokens_padded) { + // Use a separate thread block to populate sorted_token_ids + if (blockIdx.x == 1) { + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = threadIdx.x; i < total_vecs; i += blockDim.x) { + out_ptr[i] = fill_vec; + } + } + return; + } + + extern __shared__ int32_t smem[]; + int32_t* shared_counts = smem; // [num_experts] + int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] + int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + __shared__ int32_t s_total_tokens_post_pad; + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + atomicAdd(&shared_counts[expert_id], 1); + } + + __syncthreads(); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + +#ifndef __CUDA_ARCH__ // HIP + + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + + __syncthreads(); + + // Blelloch scan + int offset = 1; +#pragma unroll + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + __syncthreads(); + } + + // down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + __syncthreads(); + +#pragma unroll + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + __syncthreads(); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + +#else // CUDA + + // Intra warp prefix sum + int32_t* warp_sums = scan_buf + scan_size; // [<= 32] + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid & (WARP_SIZE - 1); + const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE; + const int warp_sum = warp_exclusive_scan(padded_count) + padded_count; + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum; + __syncthreads(); + + // warp0 accumulate all the block's prefix sum + if (tid < WARP_SIZE) { + int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0; + int incl = warp_exclusive_scan(val) + val; + warp_sums[tid] = incl; + } + __syncthreads(); + + // Every thread obtains the whole block's sum + if (tid == 0) { + prefix[num_experts] = warp_sums[num_warps_for_scan - 1]; + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + + // Fill 0 to scan_buf extended area (tid >= num_expert) + if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0; + __syncthreads(); + + // Perform 2 level exclusive-prefix-sum to scan_buf + int v = (tid < scan_size) ? scan_buf[tid] : 0; + int pre = warp_exclusive_scan(v); + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v; + __syncthreads(); + + if (warp_id == 0) { + int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0; + warp_sums[lane_id] = warp_exclusive_scan(val); + } + __syncthreads(); + + int offset = warp_sums[warp_id]; + if (tid < scan_size) scan_buf[tid] = pre + offset; + __syncthreads(); + + // Write prefix[0..num_experts - 1] and cumsum + if (tid < num_experts) prefix[tid] = scan_buf[tid]; +#endif + + if (tid <= num_experts) { + cumsum[tid] = prefix[tid]; + } + // fill expert_ids + const int32_t num_blocks = s_total_tokens_post_pad / block_size; + for (int32_t i = tid; i < num_blocks; i += stride) { + int32_t block_start = i * block_size; + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (prefix[mid] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 2; + } +} + +// ===== Cooperative fused kernel for large batch (single launch, grid.sync) + +namespace cg = cooperative_groups; + +template +__global__ void moe_align_block_size_cooperative_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ global_counts, // [num_experts+1], zeroed by caller + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + cg::grid_group grid = cg::this_grid(); + + extern __shared__ int32_t smem[]; + // smem layout: [num_experts] local_hist + [num_experts+1] expert_starts + int32_t* local_hist = smem; + int32_t* expert_starts_local = smem + num_experts; + + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int nthreads = blockDim.x; + const int nblocks = gridDim.x; + + __shared__ int32_t s_total; + + // ===== Stage 0: Cooperative initialization ===== + // Fill sorted_token_ids with sentinel value (all blocks cooperate) + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = + static_cast(numel); + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = bid * nthreads + tid; i < total_vecs; + i += nblocks * nthreads) { + out_ptr[i] = fill_vec; + } + } + + // Initialize local histogram to 0 + for (int i = tid; i < num_experts; i += nthreads) { + local_hist[i] = 0; + } + __syncthreads(); + + // ===== Stage 1: Local histogram + global atomic merge ===== + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + atomicAdd(&local_hist[expert_id], 1); + } + __syncthreads(); + + // Merge local counts into global via atomic fetch-and-add. + // Return value = prefix_before (reuse local_hist to store it). + for (int i = tid; i < num_experts; i += nthreads) { + int32_t count = local_hist[i]; + int32_t prefix_before = atomicAdd(&global_counts[i], count); + local_hist[i] = prefix_before; + } + + grid.sync(); // all histograms merged, global_counts has totals + + // ===== Stage 2: Redundant prefix sum per block ===== + if (tid == 0) { + int32_t running_sum = 0; + for (int i = 0; i < num_experts; i++) { + int32_t count = global_counts[i]; + int32_t padded = (count + block_size - 1) / block_size * block_size; + expert_starts_local[i] = running_sum; + running_sum += padded; + } + expert_starts_local[num_experts] = running_sum; // total + s_total = running_sum; + } + + grid.sync(); + + // Block 0 writes total_tokens_post_pad and cumsum (global_counts) + if (bid == 0) { + // Write expert starts to global_counts for external consumers + if (tid <= num_experts) { + global_counts[tid] = expert_starts_local[tid]; + } + if (tid == 0) { + *total_tokens_post_pad = s_total; + } + } + + // ===== Stage 3: Fill expert_ids (all blocks cooperate) ===== + const int32_t num_blocks_out = s_total / block_size; + for (int32_t i = bid * nthreads + tid; i < num_blocks_out; + i += nblocks * nthreads) { + int32_t block_start = i * block_size; + // Binary search: find the expert whose start <= block_start < next start + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (expert_starts_local[mid + 1] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 1; // expert indexing: topk_ids uses +1 offset + } + + // ===== Stage 4: Scatter tokens using shared memory atomics ===== + // local_hist[i] currently holds prefix_before for this block. + // We do atomic_add on local_hist to get each token's rank within the expert, + // then add expert_starts_local to get the final position. + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + int32_t rank = atomicAdd(&local_hist[expert_id], 1); + int32_t pos = rank + expert_starts_local[expert_id]; + sorted_token_ids[pos] = i; + } +} + +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + // Adapted from + // https://github.com/vllm-project/vllm/pull/29642/files#diff-5647b1413f4ae9aacba904eca8f8a8aee9079321eadff4c10101a2c6962dcc53R226 + // Use an additional group of threads to fill sorted_token_ids. + // Since the kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. + if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + if (pad_sorted_token_ids) { + for (int32_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[it] = numel; + } + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + ++tokens_cnts[(tid + 1) * num_experts + expert_id]; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid - 1; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } +} + +template +void moe_align_block_size(const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t block_size, + paddle::Tensor& sorted_token_ids, + paddle::Tensor& experts_ids, + paddle::Tensor& num_tokens_post_pad, + paddle::Tensor& cumsum_buffer, + bool pad_sorted_token_ids) { + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + auto stream = topk_ids.stream(); + + const size_t numel = topk_ids.numel(); + const int64_t max_num_tokens_padded = sorted_token_ids.shape()[0]; + + bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t expert_threads = max((int32_t)num_experts, WARP_SIZE); + constexpr int32_t fill_threads = 256; + const int32_t shared_mem_size = + ((expert_threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + auto small_batch_expert_kernel = + moe_align_block_size_small_batch_expert_kernel; + small_batch_expert_kernel<<<1, + fill_threads + expert_threads, + shared_mem_size, + stream>>>(topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + pad_sorted_token_ids, + max_num_tokens_padded); + } else { + // Use cooperative fused kernel for large inputs where multi-block + // parallelism outweighs cooperative launch overhead + if (numel >= 16384) { + const int coop_threads = 256; + const size_t coop_smem = (2 * num_experts + 1) * sizeof(int32_t); + + auto coop_kernel = moe_align_block_size_cooperative_kernel; + + static int cached_max_blocks_per_sm = 0; + static int cached_num_sms = 0; + if (cached_num_sms == 0) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&cached_max_blocks_per_sm, + (void*)coop_kernel, + coop_threads, + coop_smem); + int device_id; + cudaGetDevice(&device_id); + cudaDeviceGetAttribute( + &cached_num_sms, cudaDevAttrMultiProcessorCount, device_id); + } + + int max_coop_blocks = cached_max_blocks_per_sm * cached_num_sms; + int desired_blocks = std::max( + 1, std::min(256, static_cast(numel / (coop_threads * 4)))); + int coop_blocks = std::min(desired_blocks, max_coop_blocks); + if (coop_blocks < 1) coop_blocks = 1; + + const scalar_t* topk_ids_ptr = topk_ids.data(); + int32_t* sorted_token_ids_ptr = sorted_token_ids.data(); + int32_t* experts_ids_ptr = experts_ids.data(); + int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data(); + int32_t* cumsum_ptr = cumsum_buffer.data(); + int32_t num_experts_i32 = static_cast(num_experts); + int32_t block_size_i32 = static_cast(block_size); + size_t numel_val = numel; + bool pad_val = pad_sorted_token_ids; + int32_t max_padded_i32 = static_cast(max_num_tokens_padded); + + void* args[] = {&topk_ids_ptr, + &sorted_token_ids_ptr, + &experts_ids_ptr, + &num_tokens_post_pad_ptr, + &cumsum_ptr, + &num_experts_i32, + &block_size_i32, + &numel_val, + &pad_val, + &max_padded_i32}; + + cudaError_t err = cudaLaunchCooperativeKernel((void*)coop_kernel, + dim3(coop_blocks), + dim3(coop_threads), + args, + coop_smem, + stream); + + if (err == cudaSuccess) { + return; + } + // Fall through to original path if cooperative launch failed + } + + // Original 2-kernel approach (for medium inputs or cooperative fallback) + auto align_kernel = moe_align_block_size_kernel; + + const size_t scan_size = next_pow_2(num_experts); + const size_t shared_mem_size = + (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * + sizeof(int32_t); + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + cumsum_buffer.data(), + pad_sorted_token_ids, + scan_size, + max_num_tokens_padded); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = ((int)numel + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data(), + sorted_token_ids.data(), + cumsum_buffer.data(), + numel); + } +} + +// Explicit instantiations for use from other translation units (e.g. +// tritonmoe_preprocess.cu) +template void moe_align_block_size(const paddle::Tensor&, + int64_t, + int64_t, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + bool); +template void moe_align_block_size(const paddle::Tensor&, + int64_t, + int64_t, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + bool); diff --git a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu index 071e0a9b418..eb680ea744e 100644 --- a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu +++ b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu @@ -15,83 +15,40 @@ #include "helper.h" #include "paddle/extension.h" -#define CEILDIV(a, b) (((a + b - 1) / b)) - template -__global__ void count_and_sort_expert_tokens_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - size_t numel) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; - } -} - -template -__global__ void moe_align_block_size_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t GEMM_BLOCK_SIZE_M, - size_t numel, - int32_t* __restrict__ cumsum_buffer) { - __shared__ int32_t tokens_per_ep[num_experts]; - - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { - tokens_per_ep[i] = 0; - } - - __syncthreads(); - - for (int i = threadIdx.x; i < numel; i += blockDim.x) { - int expert_id = topk_ids[i]; - atomicAdd(&tokens_per_ep[expert_id], 1); - } - - __syncthreads(); - - if (threadIdx.x == 0) { - cumsum_buffer[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = tokens_per_ep[i - 1]; - cumsum_buffer[i] = - cumsum_buffer[i - 1] + - CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M; - } - *total_tokens_post_pad = cumsum_buffer[num_experts]; - } - - __syncthreads(); - - if (threadIdx.x < num_experts) { - for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1]; - i += GEMM_BLOCK_SIZE_M) { - expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x; - } - } -} +void moe_align_block_size(const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t block_size, + paddle::Tensor& sorted_token_ids, + paddle::Tensor& experts_ids, + paddle::Tensor& num_tokens_post_pad, + paddle::Tensor& cumsum_buffer, + bool pad_sorted_token_ids); std::vector> tritonmoe_preprocessInferShape( const std::vector& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) { - int topk_ids_numel = topk_ids[0] * topk_ids[1]; - int max_num_tokens_padded = - topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1); + int topk_ids_numel = 1; + for (int64_t dim : topk_ids) { + topk_ids_numel *= static_cast(dim); + } + int max_num_tokens_padded; + if (topk_ids_numel < num_experts + 1) { + max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M; + } else { + max_num_tokens_padded = + topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1); + } std::vector sorted_ids = {max_num_tokens_padded}; - int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M; - std::vector expert_ids = {max_num_m_blocks}; + int max_num_m_blocks = + (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M; + std::vector experts_ids = {max_num_m_blocks}; std::vector num_tokens_post_pad = {1}; - return {sorted_ids, expert_ids, num_tokens_post_pad}; + return {sorted_ids, experts_ids, num_tokens_post_pad}; } std::vector tritonmoe_preprocessIferDtype( @@ -127,76 +84,50 @@ std::vector tritonmoe_preprocess_kernel( const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) { - int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1]; - int max_num_tokens_padded = - topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1); + int topk_ids_numel = static_cast(topk_ids.numel()); + + int max_num_tokens_padded; + if (topk_ids_numel < num_experts + 1) { + max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M; + } else { + max_num_tokens_padded = + topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1); + } auto sorted_ids = paddle::full({max_num_tokens_padded}, topk_ids_numel, paddle::DataType::INT32, topk_ids.place()); - int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M; + int max_num_m_blocks = + (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M; - auto expert_ids = paddle::empty( + auto experts_ids = paddle::empty( {max_num_m_blocks}, paddle::DataType::INT32, topk_ids.place()); auto num_tokens_post_pad = paddle::empty({1}, paddle::DataType::INT32, topk_ids.place()); - auto cumsum_buffer = paddle::empty( - {num_experts + 1}, paddle::DataType::INT32, topk_ids.place()); + auto cumsum_buffer = paddle::zeros( + {num_experts + 2}, paddle::DataType::INT32, topk_ids.place()); - auto stream = topk_ids.stream(); using scalar_t = int64_t; - -#define run_align_kernel(num_experts) \ - auto align_kernel = moe_align_block_size_kernel; \ - align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data(), \ - expert_ids.data(), \ - num_tokens_post_pad.data(), \ - GEMM_BLOCK_SIZE_M, \ - topk_ids_numel, \ - cumsum_buffer.data()); - - if (num_experts == 8) { - run_align_kernel(8); - } else if (num_experts == 256) { - run_align_kernel(256); - } else if (num_experts == 2) { - run_align_kernel(2); - } else if (num_experts == 64) { - run_align_kernel(64); - } else if (num_experts == 128) { - run_align_kernel(128); - } else if (num_experts == 160) { - run_align_kernel(160); - } else if (num_experts == 32) { - run_align_kernel(32); - } else { - PD_THROW("Not support num_experts: %d", num_experts); - } - - const int block_threads = 256; - const int num_blocks = CEILDIV(topk_ids_numel, block_threads); - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - auto sort_kernel = count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - topk_ids.data(), - sorted_ids.data(), - cumsum_buffer.data(), - topk_ids_numel); - - return {sorted_ids, expert_ids, num_tokens_post_pad}; + moe_align_block_size(topk_ids, + num_experts + 1, + GEMM_BLOCK_SIZE_M, + sorted_ids, + experts_ids, + num_tokens_post_pad, + cumsum_buffer, + true); + + return {sorted_ids, experts_ids, num_tokens_post_pad}; } PD_BUILD_STATIC_OP(tritonmoe_preprocess) .Inputs({"topk_ids"}) .Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"}) - .Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"}) + .Outputs({"sorted_ids", "experts_ids", "num_tokens_post_pad"}) .SetKernelFn(PD_KERNEL(tritonmoe_preprocess_kernel)) .SetInferShapeFn(PD_INFER_SHAPE(tritonmoe_preprocessInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(tritonmoe_preprocessIferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index dcb514504e8..2e0de0123b6 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -243,6 +243,7 @@ def find_end_files(directory, end_str): "gpu_ops/set_data_ipc.cu", "gpu_ops/unset_data_ipc.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/step_system_cache.cu", "gpu_ops/get_output_ep.cc", "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu", @@ -702,6 +703,7 @@ def find_end_files(directory, end_str): "gpu_ops/append_attn/mla_cache_kernel.cu", "gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/moe/moe_topk_select.cu", "gpu_ops/get_img_boundaries.cc", "gpu_ops/remote_cache_kv_ipc.cc", diff --git a/tests/operators/test_tritonmoe_preprocess.py b/tests/operators/test_tritonmoe_preprocess.py index 94d85c956e1..7071e275225 100644 --- a/tests/operators/test_tritonmoe_preprocess.py +++ b/tests/operators/test_tritonmoe_preprocess.py @@ -12,12 +12,159 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Correctness tests for tritonmoe_preprocess +========================================== + +Tests the fastdeploy wrapper: + tritonmoe_preprocess(topk_ids, num_experts, block_size) + -> (sorted_token_ids, expert_ids, num_tokens_post_padded) + +The verification approach mirrors FlagTree/python/tutorials/tle/02-moe_align_block_size.py: + - Use paddle.bincount as an independent reference (no second kernel to cross-compare). + - Validate three dimensions: + 1. num_tokens_post_padded – total token count after per-expert block alignment + 2. expert_ids – each block is mapped to the correct expert + 3. sorted_token_ids – every token is routed to the right expert's slot, + and padding slots carry sentinel values >= num_tokens +""" + import unittest import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess +# --------------------------------------------------------------------------- +# Import guard – skip entire module when CUDA is unavailable or +# fastdeploy is not installed (e.g. CPU-only CI environments). +# --------------------------------------------------------------------------- +try: + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess + + _AVAILABLE = paddle.device.is_compiled_with_cuda() +except Exception: + _AVAILABLE = False + +DEVICE = "gpu" + +# 仅对小规模 case 打印详细 tensor,超过此阈值只打印统计摘要 +_PRINT_TENSOR_NUMEL_LIMIT = 64 + + +def _fmt_tensor(t: paddle.Tensor, name: str) -> str: + t_cpu = t.cpu() + if t_cpu.numel() <= _PRINT_TENSOR_NUMEL_LIMIT: + return f"{name}{list(t_cpu.shape)} = {t_cpu.tolist()}" + return ( + f"{name}{list(t_cpu.shape)} | " + f"min={int(t_cpu.min())} max={int(t_cpu.max())} " + f"mean={float(t_cpu.cast('float32').mean()):.2f} numel={t_cpu.numel()}" + ) + + +# --------------------------------------------------------------------------- +# Reference helpers (CPU, independent of the kernel under test) +# --------------------------------------------------------------------------- + + +def _ref_counts_and_cumsum(topk_ids_flat: paddle.Tensor, num_experts: int, block_size: int): + """ + Compute per-expert token counts and the cumulative sum of block-aligned counts. + + Returns: + counts : int32 tensor of shape (num_experts,) + cumsum : int32 tensor of shape (num_experts,) – cumulative aligned counts + """ + # Only consider valid expert ids [0, num_experts); ignore -1 (EP filtered) + valid_mask = (topk_ids_flat >= 0) & (topk_ids_flat < num_experts) + valid_ids = topk_ids_flat[valid_mask] + counts = paddle.bincount(valid_ids.cast("int64"), minlength=num_experts).cast("int32") + aligned = ((counts + block_size - 1) // block_size) * block_size + cumsum = paddle.cumsum(aligned, axis=0).cast("int32") + return counts, cumsum + + +# --------------------------------------------------------------------------- +# Core verification logic (shared across all test cases) +# --------------------------------------------------------------------------- + + +def _verify(topk_ids: paddle.Tensor, block_size: int, num_experts: int, label: str = ""): + """ + Run tritonmoe_preprocess and verify all three output tensors. + topk_ids may be 1-D or 2-D; dtype int32 or int64. + Prints inputs, golden references, kernel outputs, and per-check comparison. + """ + tag = f"[{label}] " if label else "" + sep = "=" * 70 + + sorted_token_ids, expert_ids, num_tokens_post_pad = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + topk_ids_flat = topk_ids.flatten().cast("int64").cpu() + num_tokens = topk_ids_flat.numel() + + counts, cumsum = _ref_counts_and_cumsum(topk_ids_flat, num_experts, block_size) + aligned = ((counts + block_size - 1) // block_size) * block_size + valid_length = int(cumsum[-1].item()) + num_blocks = valid_length // block_size + + expected_expert_ids = paddle.repeat_interleave( + paddle.arange(num_experts, dtype="int32"), # CPU + (aligned // block_size).cast("int32"), + ) + + np.testing.assert_array_equal( + num_tokens_post_pad.cpu().numpy(), + cumsum[-1:].cpu().numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 2: expert_ids – each block maps to the expected expert # + # ------------------------------------------------------------------ # + got_eids = expert_ids[:num_blocks].cpu() + want_eids = expected_expert_ids.cpu() + np.testing.assert_array_equal( + got_eids.numpy(), + want_eids.numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 3: sorted_token_ids – routing correctness per expert # + # ------------------------------------------------------------------ # + + start = 0 + for expert_id in range(num_experts): + end = int(cumsum[expert_id].item()) + tokens = sorted_token_ids[start:end].cpu() + valid_tokens = tokens[tokens < num_tokens] + # padding_tokens = tokens[tokens >= num_tokens] + + want_count = int(counts[expert_id].item()) + got_count = valid_tokens.numel() + count_ok = got_count == want_count + + assert count_ok, f"expert {expert_id}: expected {want_count} valid tokens, got {got_count}" + if counts[expert_id] > 0: + np.testing.assert_array_equal( + topk_ids_flat[valid_tokens.cast("int64")].numpy(), + paddle.full_like(valid_tokens, expert_id).numpy(), + ) + start = end + + # padding 区域哨兵检查 + if valid_length < sorted_token_ids.numel(): + padding_region = sorted_token_ids[valid_length:].cpu() + sentinel_ok = paddle.all(padding_region >= num_tokens).item() + + assert sentinel_ok, "padding slots beyond valid_length contain non-sentinel values" + + print(f"\n{tag}ALL CHECKS PASSED") + print(sep) + + +# --------------------------------------------------------------------------- +# Original unittest-based tests (kept for backward compatibility) +# --------------------------------------------------------------------------- class TestTritonMOEPreprocess(unittest.TestCase): @@ -35,10 +182,14 @@ def _check_output_shapes( self, sorted_ids, expert_ids, num_tokens_post_pad, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M ): """Check output shapes and dtypes""" - expected_max_num_tokens_padded = topk_ids_np.size + num_experts * (GEMM_BLOCK_SIZE_M - 1) + if topk_ids_np.size < num_experts + 1: + expected_max_num_tokens_padded = topk_ids_np.size * GEMM_BLOCK_SIZE_M + else: + expected_max_num_tokens_padded = topk_ids_np.size + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1) + self.assertEqual(sorted_ids.shape[0], expected_max_num_tokens_padded) - expected_max_num_m_blocks = expected_max_num_tokens_padded // GEMM_BLOCK_SIZE_M + expected_max_num_m_blocks = (expected_max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) // GEMM_BLOCK_SIZE_M self.assertEqual(expert_ids.shape[0], expected_max_num_m_blocks) self.assertEqual(num_tokens_post_pad.shape[0], 1) @@ -104,17 +255,232 @@ def test_basic_case(self): ) self._check_output_values_basic(sorted_ids, expert_ids, num_tokens_post_pad) - def test_unsupported_num_experts(self): - """Test unsupported num_experts raises OSError""" - topk_ids_np = np.array([[0, 1], [1, 0]], dtype=np.int64) - unsupported_experts = [3, 9, 65, 129] - GEMM_BLOCK_SIZE_M = 4 - for num_experts in unsupported_experts: - with self.subTest(num_experts=num_experts): - with self.assertRaises(OSError): - self._run_op(topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M) +# --------------------------------------------------------------------------- +# Correctness tests (ported from test_moe_align_block_size.py) +# --------------------------------------------------------------------------- + + +class TestTritonMoePreprocessBasic(unittest.TestCase): + """Basic / small cases – easy to reason about manually.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_docstring_example(self): + """Reproduce the example from the function docstring.""" + topk_ids = paddle.to_tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], dtype="int64") + _verify(topk_ids, block_size=4, num_experts=5, label="docstring_example") + + def test_single_token_single_expert(self): + """Minimal input: one token assigned to one expert.""" + topk_ids = paddle.to_tensor([[0]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="single_token_single_expert") + + def test_all_tokens_same_expert(self): + """All tokens go to expert 0 – only one expert's slot is used.""" + topk_ids = paddle.zeros((64, 1), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="all_tokens_same_expert") + + def test_uniform_1d(self): + """1-D topk_ids (top_k=1 squeezed) with uniform distribution.""" + paddle.seed(42) + topk_ids = paddle.randint(0, 8, (128,), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="uniform_1d") + + def test_topk_equals_num_experts(self): + """Every token selects all experts (top_k == num_experts).""" + num_experts = 4 + topk_ids = paddle.arange(num_experts, dtype="int64").unsqueeze(0).expand((8, num_experts)) + _verify(topk_ids, block_size=4, num_experts=num_experts, label="topk_equals_num_experts") + + def test_num_tokens_less_than_num_experts(self): + """Fewer tokens than experts – exercises the small-input branch.""" + topk_ids = paddle.to_tensor([[0], [3]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=64, label="num_tokens_less_than_num_experts") + + def test_exact_block_boundary(self): + """Token count per expert is exactly block_size – no padding needed.""" + block_size = 16 + num_experts = 4 + topk_ids = paddle.concat([paddle.full((block_size,), e, dtype="int64") for e in range(num_experts)]) + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label="exact_block_boundary") + + def test_block_size_1(self): + """block_size=1 means no padding is ever added.""" + paddle.seed(0) + topk_ids = paddle.randint(0, 16, (64,), dtype="int64") + _verify(topk_ids, block_size=1, num_experts=16, label="block_size_1") + + +class TestTritonMoePreprocessEdgeCases(unittest.TestCase): + """Edge / boundary cases.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_empty_topk_ids(self): + """Zero-token input should not crash; num_tokens_post_pad == 0.""" + topk_ids = paddle.empty((0,), dtype="int64").cuda() + sorted_ids, expert_ids_out, num_post = tritonmoe_preprocess(topk_ids, 8, 16) + got = int(num_post.item()) + + self.assertEqual(got, 0) + + def test_one_expert(self): + """Single expert: all tokens must end up in expert 0's bucket.""" + paddle.seed(1) + topk_ids = paddle.zeros((32,), dtype="int64") + _verify(topk_ids, block_size=8, num_experts=1, label="one_expert") + + def test_large_block_size(self): + """block_size larger than total tokens.""" + topk_ids = paddle.randint(0, 4, (8,), dtype="int64") + _verify(topk_ids, block_size=128, num_experts=4, label="large_block_size") + + def test_int64_dtype(self): + """topk_ids in int64 – the kernel should handle dtype conversion.""" + paddle.seed(7) + topk_ids = paddle.randint(0, 8, (64, 2), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="int64_dtype") + + +class TestTritonMoePreprocessRealistic(unittest.TestCase): + """Larger, more realistic MoE shapes.""" + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def _run_uniform_distribution(self, num_tokens, num_experts, block_size): + """Uniform random token-to-expert assignment across common MoE shapes.""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"uniform_T{num_tokens}_E{num_experts}_B{block_size}", + ) + + def test_uniform_distribution(self): + """Uniform random token-to-expert assignment across common MoE shapes.""" + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 128, 128), + (16384, 256, 128), + (16384, 512, 256), + (32768, 512, 256), + (32768, 512, 64), + (163840, 1024, 256), + ]: + with self.subTest(num_tokens=num_tokens, num_experts=num_experts, block_size=block_size): + self._run_uniform_distribution(num_tokens, num_experts, block_size) + + def _run_topk_2d(self, num_tokens, top_k, num_experts, block_size): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens, top_k), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"topk2d_T{num_tokens}_K{top_k}_E{num_experts}_B{block_size}", + ) + + def test_topk_2d(self): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + with self.subTest(num_tokens=num_tokens, top_k=top_k, num_experts=num_experts, block_size=block_size): + self._run_topk_2d(num_tokens, top_k, num_experts, block_size) + + def _run_zipf_distribution(self, alpha): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + num_tokens, num_experts, block_size = 8192, 64, 16 + ranks = paddle.arange(1, num_experts + 1, dtype="float32") + probs = 1.0 / ranks**alpha + probs = probs / probs.sum() + paddle.seed(0) + topk_ids = paddle.multinomial(probs, num_tokens, replacement=True).cast("int64") + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label=f"zipf_alpha{alpha}") + + def test_zipf_distribution(self): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + for alpha in [0.5, 1.2, 2.0]: + with self.subTest(alpha=alpha): + self._run_zipf_distribution(alpha) + + def test_deterministic_with_fixed_seed(self): + """Same seed must produce the same outputs (kernel is deterministic).""" + num_tokens, num_experts, block_size = 4096, 64, 16 + + paddle.seed(99) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s1, e1, n1 = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + paddle.seed(99) + topk_ids2 = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s2, e2, n2 = tritonmoe_preprocess(topk_ids2, num_experts, block_size) + + valid = int(n1.item()) + + np.testing.assert_array_equal(n1.numpy(), n2.numpy()) + np.testing.assert_array_equal(e1[: valid // block_size].numpy(), e2[: valid // block_size].numpy()) + np.testing.assert_array_equal(paddle.sort(s1[:valid]).numpy(), paddle.sort(s2[:valid]).numpy()) + + +# --------------------------------------------------------------------------- +# Direct-run entry point (python test_tritonmoe_preprocess.py) +# --------------------------------------------------------------------------- if __name__ == "__main__": - unittest.main() + if not _AVAILABLE: + print("SKIP: CUDA or fastdeploy not available.") + else: + basic = TestTritonMoePreprocessBasic() + basic.test_docstring_example() + basic.test_single_token_single_expert() + basic.test_all_tokens_same_expert() + basic.test_uniform_1d() + basic.test_topk_equals_num_experts() + basic.test_num_tokens_less_than_num_experts() + basic.test_exact_block_boundary() + basic.test_block_size_1() + + edge = TestTritonMoePreprocessEdgeCases() + edge.test_empty_topk_ids() + edge.test_one_expert() + edge.test_large_block_size() + edge.test_int64_dtype() + + real = TestTritonMoePreprocessRealistic() + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 256, 128), + ]: + real._run_uniform_distribution(num_tokens, num_experts, block_size) + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + real._run_topk_2d(num_tokens, top_k, num_experts, block_size) + for alpha in [0.5, 1.2, 2.0]: + real._run_zipf_distribution(alpha) + real.test_deterministic_with_fixed_seed() + + print("\n*** All direct-run tests passed ***") From 04e4ae8db25e79ec4a6466c155331ab90c97191a Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Fri, 15 May 2026 11:47:22 +0800 Subject: [PATCH 112/143] [Cherry-Pick][BugFix] Fix pause drain hang caused by stale abort markers(#7825) (#7826) --- fastdeploy/engine/common_engine.py | 31 +++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index c31c9039b40..f553a4f8ee5 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1278,17 +1278,30 @@ def _control_pause(self, control_request: ControlRequest): def _wait_inflight_drained(self): """ Wait until resource_manager.requests is completely empty. - No timeout — abort pipeline will complete. Aligned with SGLang's poll-until-drained. + No timeout — abort pipeline will complete. + Logs a warning every 30 seconds while waiting to help diagnose potential hangs. """ - start_time = time.time() - while ( - self.resource_manager.requests - or self.scheduler.requests - or self.resource_manager.waiting_abort_req_id_set - or self.resource_manager.to_be_aborted_req_id_set - ): + start_time = time.monotonic() + next_warn_time = start_time + 30 + + while self.resource_manager.requests or self.scheduler.requests: + now = time.monotonic() + + if now >= next_warn_time: + self.llm_logger.warning( + "Still waiting for inflight requests to drain, " + f"elapsed: {now - start_time:.3f} seconds, " + f"resource_manager.requests: {len(self.resource_manager.requests)}, " + f"scheduler.requests: {len(self.scheduler.requests)}", + ) + next_warn_time = now + 30 + time.sleep(0.005) - self.llm_logger.info(f"All inflight requests drained, take time: {time.time() - start_time:.3f} seconds") + + self.llm_logger.info( + "All inflight requests drained, take time: %.3f seconds", + time.monotonic() - start_time, + ) def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: """Control function for resuming request generation. From d71bddacc59ce2ce87b9e4edc5cd70c45fd285dc Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Fri, 15 May 2026 16:30:58 +0800 Subject: [PATCH 113/143] [Cherry-Pick][CI] Optimize clean_ports logic by removing redundant code(#7809) (#7830) --- .github/workflows/_unit_test_coverage.yml | 2 +- tests/engine/test_common_engine.py | 17 ++- tests/entrypoints/openai/test_run_batch.py | 127 ++------------------- tests/layers/test_plas_attention.py | 11 ++ tests/model_loader/test_load_ernie_vl.py | 103 ++--------------- 5 files changed, 47 insertions(+), 213 deletions(-) diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index 0a939c49583..a2b72eda854 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -394,7 +394,7 @@ jobs: wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..." fi if [ -f "${filename}" ];then - echo "Failed test cases:" + echo "GPU Patch Coverage Details:" if command -v jq >/dev/null 2>&1; then jq . "${filename}" else diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 7e7c660964b..9c6baeae348 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -28,7 +28,7 @@ import numpy as np import paddle -from e2e.utils.serving_utils import clean_ports +from e2e.utils.serving_utils import PORTS_TO_CLEAN, clean_ports if not hasattr(paddle, "enable_compat"): paddle.enable_compat = lambda scope=None: None @@ -512,6 +512,21 @@ def _make_cfg(self, **kwargs): engine_worker_queue_port = [engine_worker_queue_port + 21 + i for i in range(dp // nnode)] cache_queue_port = [cache_queue_port + 21 + i for i in range(dp // nnode)] + # Add ports to cleanup list + ports_to_add = [] + if isinstance(engine_worker_queue_port, list): + ports_to_add.extend(engine_worker_queue_port) + else: + ports_to_add.append(engine_worker_queue_port) + if isinstance(cache_queue_port, list): + ports_to_add.extend(cache_queue_port) + else: + ports_to_add.append(cache_queue_port) + + for port in ports_to_add: + if port not in PORTS_TO_CLEAN: + PORTS_TO_CLEAN.append(port) + if kwargs.get("num_gpu_blocks_override") is not None and "kv_cache_ratio" not in kwargs: kwargs["kv_cache_ratio"] = 1 diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index 50410ccf236..db871cc7a73 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -19,7 +19,6 @@ import os import shutil import signal -import socket import subprocess import sys import tempfile @@ -63,124 +62,16 @@ write_local_file, ) -# Read ports from environment variables; use default values if not set -FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) -FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) - -# List of ports to clean before and after tests -PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] - - -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def _clean_cuda_process(): - """ - Kill processes that are using CUDA devices. - NOTE: Do not call this function directly, use the `clean` function instead. - """ - try: - subprocess.run("fuser -k /dev/nvidia*", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(ports=None): - """ - Kill all processes occupying the ports - """ - if ports is None: - ports = PORTS_TO_CLEAN - - print(f"Cleaning ports: {ports}") - for port in ports: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in ports: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) - - -def clean(ports=None): - """ - Clean up resources used during testing. - """ - clean_ports(ports) - - # Clean CUDA devices before and after tests. - # NOTE: It is dangerous to use this flag on development machines, as it may kill other processes - clean_cuda = int(os.getenv("CLEAN_CUDA", "0")) == 1 - if clean_cuda: - _clean_cuda_process() +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +from e2e.utils.serving_utils import ( + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) INPUT_BATCH = """ {"custom_id": "req-00001", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Can you write a short poem? (id=1)"}], "temperature": 0.7, "max_tokens": 200}} diff --git a/tests/layers/test_plas_attention.py b/tests/layers/test_plas_attention.py index 663b27dc9ab..e593595fa5a 100644 --- a/tests/layers/test_plas_attention.py +++ b/tests/layers/test_plas_attention.py @@ -12,8 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys + import paddle +tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, tests_dir) + +from e2e.utils.serving_utils import clean_ports + try: from fastdeploy.model_executor.ops.gpu import ( fused_block_mean_and_rope, @@ -338,6 +346,9 @@ def test_plas_attention(self): self.compare_attn(attn_out, qk_gate_topk_idx) def test_server(self): + # Clean ports before starting the test + clean_ports() + if get_cur_cu_seq_len_k is None: return os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN" diff --git a/tests/model_loader/test_load_ernie_vl.py b/tests/model_loader/test_load_ernie_vl.py index abbdeb542f5..129c6076533 100644 --- a/tests/model_loader/test_load_ernie_vl.py +++ b/tests/model_loader/test_load_ernie_vl.py @@ -15,7 +15,6 @@ import json import os import signal -import socket import subprocess import sys import time @@ -28,96 +27,14 @@ if project_root not in sys.path: sys.path.insert(0, project_root) -# Read ports from environment variables; use default values if not set -FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) -FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) - -# List of ports to clean before and after tests -PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] - - -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(): - """ - Kill all processes occupying the ports listed in PORTS_TO_CLEAN. - """ - print(f"Cleaning ports: {PORTS_TO_CLEAN}") - for port in PORTS_TO_CLEAN: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in PORTS_TO_CLEAN: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) +from e2e.utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean_ports, + is_port_open, +) @pytest.fixture(scope="session", autouse=True) @@ -184,8 +101,8 @@ def setup_and_run_server(): start_new_session=True, # Enables killing full group via os.killpg ) - # Wait up to 10 minutes for API server to be ready - for _ in range(10 * 60): + # Wait up to 5 minutes for API server to be ready + for _ in range(300): if is_port_open("127.0.0.1", FD_API_PORT): print(f"API server is up on port {FD_API_PORT}") break From 514ed5c6bd784c7be142d1cae3755f0c914bc89a Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Mon, 18 May 2026 11:29:16 +0800 Subject: [PATCH 114/143] [Cherry-Pick][Op][Optimization]Kernel fusion: cast+sigmoid+bias+noauxtc(#7777) (#7818) * [Op][Optimization]Kernel fusion: cast+sigmoid+bias+noauxtc (#7777) [Cherry-Pick][Op][Optimization]Kernel fusion: cast+sigmoid+bias+noauxtc (#7777) * Bug fixes and modifications to the fused kernel switch. * fix replicated env args --- custom_ops/gpu_ops/cpp_extensions.cc | 11 + custom_ops/gpu_ops/grouped_topk_kernels.cu | 786 ++++++++++++++++++ custom_ops/setup_ops.py | 2 + fastdeploy/engine/args_utils.py | 4 +- .../layers/moe/fused_moe_triton_backend.py | 17 +- fastdeploy/model_executor/layers/moe/moe.py | 35 +- tests/operators/test_grouped_topk_op.py | 485 +++++++++++ 7 files changed, 1319 insertions(+), 21 deletions(-) create mode 100644 custom_ops/gpu_ops/grouped_topk_kernels.cu create mode 100644 tests/operators/test_grouped_topk_op.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 591cf363f06..e2a2fc1b92f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -691,6 +691,15 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor); + std::vector FusedCastSigmoidBias(const paddle::Tensor& input, const paddle::Tensor& bias, std::string cast_type); @@ -1704,6 +1713,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("grouped_topk", &grouped_topk, "fused grouped topk for MoE routing"); + m.def("fused_cast_sigmoid_bias", &FusedCastSigmoidBias, "Fused cast+sigmoid+bias for MoE gating scores", diff --git a/custom_ops/gpu_ops/grouped_topk_kernels.cu b/custom_ops/gpu_ops/grouped_topk_kernels.cu new file mode 100644 index 00000000000..ef5ed8533f0 --- /dev/null +++ b/custom_ops/gpu_ops/grouped_topk_kernels.cu @@ -0,0 +1,786 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "helper.h" + +namespace cg = cooperative_groups; + +constexpr unsigned FUSED_FULL_WARP_MASK = 0xffffffff; + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ inline float cuda_cast(__half val) { + return __half2float(val); +} + +template <> +__device__ inline __half cuda_cast<__half, float>(float val) { + return __float2half(val); +} + +// Numerically stable sigmoid via tanh: σ(x) = 0.5 * tanh(0.5*x) + 0.5 +template +__device__ __forceinline__ T sigmoid_device(T x) { + float xf = cuda_cast(x); + return cuda_cast(0.5f * tanhf(0.5f * xf) + 0.5f); +} + +// Sigmoid matching fused_cast_sigmoid_bias: 1 / (1 + exp(-x)). +// Must use the same formula to get bit-identical results when comparing +// against the fused_cast_sigmoid_bias + noaux_tc path. +template +__device__ __forceinline__ float sigmoid_to_float(InT x) { + float xf = cuda_cast(x); + return 1.0f / (1.0f + expf(-xf)); +} + +template +__device__ inline T neg_inf() { + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + +template +__device__ inline bool is_finite_val(T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + +namespace warp_topk_fused { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) return 0; + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, + T baseline, + idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + return res; +} + +template +struct BitonicMerge { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) + is_better = is_better_than( + val, other_val, idx_arr[i], idx_arr[other_i]); + else + is_better = is_better_than(val, other_val); + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = + __shfl_xor_sync(FUSED_FULL_WARP_MASK, *idx_arr, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FUSED_FULL_WARP_MASK, idx, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + __device__ __forceinline__ idxT get_idx(int i = 0) const { + return idx_arr_[i]; + } + __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + int const lane_; + idxT const k_; + T const dummy_; +}; + +// WarpSelect WITHOUT __syncthreads() in done() — safe when only one warp is +// active. +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_idx_(0), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + else + do_add = is_better_than(val, k_th_); + + uint32_t mask = __ballot_sync(FUSED_FULL_WARP_MASK, do_add); + if (mask == 0) return; + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + // NOTE: no __syncthreads() here — callers must sync externally if needed. + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync( + FUSED_FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) + k_th_idx_ = __shfl_sync( + FUSED_FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + T& old = val_arr_[max_arr_len_ - 1]; + bool is_better; + if constexpr (is_stable) + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + else + is_better = is_better_than(val, old); + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + BitonicMerge::merge( + val_arr_, idx_arr_); + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; + +} // namespace warp_topk_fused + +// --------------------------------------------------------------------------- +// Fused kernel: group-score computation + group selection + expert topk +// + sparse scores write-back, in one kernel launch. +// +// gridDim = num_tokens (one block per token) +// blockDim = n_group * WARP_SIZE (one warp per group) +// --------------------------------------------------------------------------- +template +__global__ void grouped_topk_fused_kernel( + float* scores, // output: sparse routing weights [num_tokens, num_experts] + float* topk_values, // output: topk routing weights [num_tokens, topk] + IdxT* topk_indices, // output: topk expert indices [num_tokens, topk] + InT const* gating_output, // input: raw logits (float or bf16) + // [num_tokens, num_experts] + float const* e_score_correction_bias, // input: bias [num_experts] + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double routed_scaling_factor) { + int32_t const token_id = static_cast(blockIdx.x); + if (token_id >= static_cast(num_tokens)) return; + + int32_t const warp_id = threadIdx.x / WARP_SIZE; + int32_t const lane_id = threadIdx.x % WARP_SIZE; + int32_t const n_group_i32 = static_cast(n_group); + int32_t const topk_group_i32 = static_cast(topk_group); + int32_t const topk_i32 = static_cast(topk); + int32_t const num_warps = blockDim.x / WARP_SIZE; + + if (warp_id >= n_group_i32 || num_warps < n_group_i32) return; + + int32_t const num_experts_per_group = + static_cast(num_experts) / n_group_i32; + int32_t const align_epg = warp_topk_fused::round_up_to_multiple_of( + num_experts_per_group); + + InT const* gate_token = gating_output + (int64_t)token_id * num_experts; + float* scores_token = scores + (int64_t)token_id * num_experts; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + // smem layout: [val_staging (256B-aligned) | idx_staging | (16B pad) | + // s_group_scores] + extern __shared__ char smem_buf[]; + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + uintptr_t ptr = + (reinterpret_cast(smem_buf + val_aligned + idx_bytes) + 15) & + ~static_cast(15); + float* s_group_scores = reinterpret_cast(ptr); + float* s_topk_value = + reinterpret_cast(smem_buf); // val_staging (256B-aligned) + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // ------------------------------------------------------------------ + // Phase 1 (all warps): compute group score = top2 sum of (gate + bias) + // ------------------------------------------------------------------ + { + int32_t const offset = warp_id * num_experts_per_group; + InT const* gate_g = gate_token + offset; + float const* bias_g = e_score_correction_bias + offset; + + float largest = neg_inf(); + float second_largest = neg_inf(); + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + float val = sigmoid_to_float(gate_g[i]) + bias_g[i]; + if (val > largest) { + second_largest = largest; + largest = val; + } else if (val > second_largest) { + second_largest = val; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) + largest = sigmoid_to_float(gate_g[i]) + bias_g[i]; + } + __syncwarp(); + float max1 = cg::reduce(tile, largest, cg::greater()); + float max2 = max1; + int cnt = __popc(__ballot_sync(FUSED_FULL_WARP_MASK, largest == max1)); + if (cnt == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + if (lane_id == 0) s_group_scores[warp_id] = max1 + max2; + } + + __syncthreads(); // __syncwarp() maybe better? + + // ------------------------------------------------------------------ + // Phase 2 (warp0 only): group selection → expert selection → output + // ------------------------------------------------------------------ + if (warp_id != 0) return; + + float value = neg_inf(); + float topk_group_value = neg_inf(); + int32_t num_equalto_topkth_group; + if (token_id < num_tokens) { + int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && (isfinite(s_group_scores[lane_id]))) { + value = s_group_scores[lane_id]; + } + + int neg_inf_num = WARP_SIZE - n_group; + int last_neg_inf_num = 0; + // Use loop to find the largset top_group + while (neg_inf_num < want_neg_inf_num) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = neg_inf(); + } + last_neg_inf_num = neg_inf_num; + + neg_inf_num = __popc( + __ballot_sync(FUSED_FULL_WARP_MASK, (value == neg_inf()))); + } + // There is a possible case: + // may have many different group holding the same score! + // but we only accept some of them! + num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num; + } + __syncwarp(); + + warp_topk_fused::WarpSelect + queue((int32_t)topk, neg_inf()); + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = (topk_group_value != neg_inf()); + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((s_group_scores[i_group] > topk_group_value) || + ((s_group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) { + float candidates = neg_inf(); + if (i < num_experts_per_group) { + float biased = sigmoid_to_float(gate_token[offset + i]) + + e_score_correction_bias[offset + i]; + if (is_finite_val(biased)) candidates = biased; + } + queue.add(candidates, offset + i); + } + if (s_group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + } + + float topk_sum = 1e-20; + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk_fused::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + int32_t idx = i / WARP_SIZE; + float value = + i < topk ? sigmoid_to_float(gate_token[queue.get_idx(idx)]) : 0.0f; + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += cg::reduce(tile, value, cg::plus()); + } + } + __syncwarp(); + + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores_token[i] = 0; + } + } + __syncwarp(); + + topk_values += (int64_t)token_id * topk; + topk_indices += (int64_t)token_id * topk; + if (token_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = s_topk_value[i] / topk_sum * routed_scaling_factor; + } else { + value = s_topk_value[i] * routed_scaling_factor; + } + int32_t idx = i / WARP_SIZE; // topk may be bigger than WARP_SIZE + scores_token[queue.get_idx(idx)] = value; + topk_indices[i] = queue.get_idx(idx); + topk_values[i] = value; + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + int32_t idx = i / WARP_SIZE; + topk_indices[i] = queue.get_idx(idx); + topk_values[i] = static_cast(1.0f / topk); + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// --------------------------------------------------------------------------- +// Launch wrapper +// --------------------------------------------------------------------------- +template +void invokeFusedNoAuxTc(InT* gating_output, + float* e_score_correction_bias, + float* scores, + float* topk_values, + IdxT* topk_indices, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + cudaStream_t const stream) { + auto* kernel = &grouped_topk_fused_kernel; + + // blockDim = n_group * WARP_SIZE (one warp per group) + int32_t const num_warps = static_cast(n_group); + + // smem = WarpSelect staging (float) + 16B pad + group_scores buffer (float) + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(float); + size_t const smem_bytes = val_aligned + idx_bytes + extra_bytes; + + cudaLaunchConfig_t config; + config.gridDim = static_cast(num_tokens); + config.blockDim = static_cast(n_group) * WARP_SIZE; + config.dynamicSmemBytes = smem_bytes; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, + kernel, + scores, + topk_values, + topk_indices, + gating_output, + e_score_correction_bias, + num_tokens, + num_experts, + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor); +} + +#define INSTANTIATE_FUSED_NOAUX_TC(InT, IdxT) \ + template void invokeFusedNoAuxTc( \ + InT * gating_output, \ + float* e_score_correction_bias, \ + float* scores, \ + float* topk_values, \ + IdxT* topk_indices, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ + cudaStream_t const stream); + +INSTANTIATE_FUSED_NOAUX_TC(float, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__nv_bfloat16, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__half, int64_t); + +// --------------------------------------------------------------------------- +// Paddle op wrapper +// --------------------------------------------------------------------------- +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor) { + auto input_shape = gating_output.shape(); + PD_CHECK(input_shape.size() == 2); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto place = gating_output.place(); + PD_CHECK(n_group <= 32, "grouped_topk: n_group must be <= 32"); + PD_CHECK(topk <= 32, "grouped_topk: topk must be <= WARP_SIZE (32)"); + + // Outputs are always float32 regardless of input dtype + auto scores = paddle::empty( + {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + auto topk_values = + paddle::empty({num_tokens, topk}, paddle::DataType::FLOAT32, place); + auto topk_indices = + paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place); + + auto stream = gating_output.stream(); + auto dtype = gating_output.dtype(); + + float* scores_ptr = reinterpret_cast(scores.data()); + float* topk_values_ptr = reinterpret_cast(topk_values.data()); + int64_t* topk_idx_ptr = + reinterpret_cast(topk_indices.data()); + float* bias_ptr = + reinterpret_cast(e_score_correction_bias.data()); + + if (dtype == paddle::DataType::BFLOAT16) { + invokeFusedNoAuxTc<__nv_bfloat16, int64_t>( + reinterpret_cast<__nv_bfloat16*>( + gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else if (dtype == paddle::DataType::FLOAT16) { + invokeFusedNoAuxTc<__half, int64_t>( + reinterpret_cast<__half*>(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else { + PD_CHECK( + dtype == paddle::DataType::FLOAT32, + "grouped_topk: gating_output must be float32, float16, or bfloat16"); + invokeFusedNoAuxTc( + reinterpret_cast(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } + + return {scores, topk_values, topk_indices}; +} + +std::vector GroupedTopkInferDtype( + const paddle::DataType& /*gating_output_dtype*/, + const paddle::DataType& /*e_score_correction_bias_dtype*/) { + // Outputs are always float32: cast is fused into the kernel. + return {paddle::DataType::FLOAT32, + paddle::DataType::FLOAT32, + paddle::DataType::INT64}; +} + +std::vector> GroupedTopkInferShape( + const std::vector& gating_output_shape, + const std::vector&, + const int topk) { + auto num_tokens = gating_output_shape[0]; + auto num_experts = gating_output_shape[1]; + return {{num_tokens, num_experts}, {num_tokens, topk}, {num_tokens, topk}}; +} + +PD_BUILD_STATIC_OP(grouped_topk) + .Inputs({"gating_output", "e_score_correction_bias"}) + .Outputs({"output_tensor", "topk_values", "topk_indices"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk: int", + "renormalize: bool", + "routed_scaling_factor: float"}) + .SetKernelFn(PD_KERNEL(grouped_topk)) + .SetInferShapeFn(PD_INFER_SHAPE(GroupedTopkInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GroupedTopkInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 2e0de0123b6..ed6ba5a5ef8 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -337,6 +337,7 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", @@ -693,6 +694,7 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 892d6668859..b285f66b42a 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -339,7 +339,7 @@ class EngineArgs: enable_moe_scores_elementwise_fuse: bool = False """ - Flag to enable fused elementwise cast in get_moe_scores. Default is False (disabled). + Flag to enable fused elementwise in get_moe_scores. Default is False (disabled). """ cache_transfer_protocol: str = "ipc,rdma" @@ -1399,7 +1399,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-moe-scores-elementwise-fuse", action="store_true", default=EngineArgs.enable_moe_scores_elementwise_fuse, - help="Enable fused elementwise cast in get_moe_scores for MoE routing.", + help="Enable fused elementwise in get_moe_scores for MoE routing.", ) model_group.add_argument( "--deploy-modality", diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index bd669deedc0..be215065db3 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -701,7 +701,6 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") top_k = layer.top_k num_local_experts = layer.num_local_experts moe_intermediate_size = layer.moe_intermediate_size @@ -709,6 +708,11 @@ def apply( E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -717,8 +721,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -783,8 +789,8 @@ def apply( stride_am=x_q.strides[0], stride_ak=x_q.strides[1], stride_be=layer.up_gate_proj_weight.strides[0], - stride_bk=layer.up_gate_proj_weight.strides[1], - stride_bn=layer.up_gate_proj_weight.strides[2], + stride_bk=layer.up_gate_proj_weight.strides[2], + stride_bn=layer.up_gate_proj_weight.strides[1], stride_cm=up_gate_proj_out.strides[0], stride_cn=up_gate_proj_out.strides[1], # @@ -1952,10 +1958,11 @@ def apply_tp( gate_out = gate(x) if layer.topk_method == "noaux_tc": - use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) if not use_fused: gate_out = gate_out.cast("float32") - _, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index df7b7f12c4b..0e4fa6ee9dd 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -36,16 +36,15 @@ from fastdeploy.worker.experts_manager import RedundantExpertManger try: - from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant + from fastdeploy.model_executor.ops.gpu import ( + grouped_topk, + noaux_tc, + noaux_tc_redundant, + ) except: logger.warning("import noaux_tc Failed!") import numpy as np -if current_platform.is_cuda(): - from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( - fused_cast_sigmoid_bias, - ) - def get_moe_method(layer=None): """ @@ -107,11 +106,6 @@ def get_moe_scores( compute moe scores using e_score_correction_bias. """ assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - if use_fused_cast and current_platform.is_cuda(): - scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32") - else: - scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias if envs.FD_USE_PHI_MOE_TOPK: # calculate renormalize and routed_scaling_factor value outside the noaux_tc original_renormalize = renormalize @@ -119,7 +113,9 @@ def get_moe_scores( renormalize = False routed_scaling_factor = 1.0 - if expert_id_to_ep_rank_array is None: + if expert_id_to_ep_rank_array is None and not use_fused_cast: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, @@ -129,9 +125,20 @@ def get_moe_scores( renormalize, routed_scaling_factor, ) + elif expert_id_to_ep_rank_array is None and use_fused_cast: + # fused kernel: cast + sigmoid + add + noaux_tc + scores, topk_values, topk_idx = grouped_topk( + gating_output, + e_score_correction_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + ) else: - # noaux_tc_redundant returns 4 values: scores, topk_values, topk_idx, - # and tokens_per_expert_stats_list_out (inplace updated) + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx, _ = noaux_tc_redundant( scores, scores_with_bias, diff --git a/tests/operators/test_grouped_topk_op.py b/tests/operators/test_grouped_topk_op.py new file mode 100644 index 00000000000..1e76328eb93 --- /dev/null +++ b/tests/operators/test_grouped_topk_op.py @@ -0,0 +1,485 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the `grouped_topk` custom CUDA op (low-level interface). + +grouped_topk fuses sigmoid into the kernel and accepts raw logits directly, +unlike noaux_tc which requires Python-side sigmoid preprocessing. + +Algorithm: + 1. scores = sigmoid(gating_output) [fused inside kernel] + 2. scores_with_bias = scores + e_score_correction_bias + 3. group_scores = sum of top-2 biased expert scores per group + 4. Select top-topk_group groups + 5. Within selected groups select top-topk experts by biased score + 6. Gather unbiased sigmoid scores for selected experts as topk_values + 7. Optionally renormalize and scale by routed_scaling_factor + +Model configs covered: + DeepSeek-V3 / R1 num_experts=256, n_group=8, topk_group=4, topk=8, renorm=True, scale=2.5 + GLM-4.5-Air num_experts=128, n_group=1, topk_group=1, topk=8, renorm=True, scale=1.0 + Qwen3-30B-A3B num_experts=128, n_group=4, topk_group=2, topk=8, renorm=False, scale=1.0 + Kimi-K2 num_experts=384, n_group=8, topk_group=2, topk=8, renorm=False, scale=1.0 +""" + +import unittest + +import numpy as np +import paddle + +try: + from fastdeploy.model_executor.ops.gpu import grouped_topk + + _GROUPED_TOPK_AVAILABLE = True +except Exception: + _GROUPED_TOPK_AVAILABLE = False + + +def native_grouped_topk( + gating_output: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): + """Pure-Python reference that mirrors the grouped_topk kernel semantics. + + Args: + gating_output: raw logits, shape [num_tokens, num_experts] + e_score_correction_bias: bias added to sigmoid scores, shape [1, num_experts] or [num_experts] + n_group: number of expert groups + topk_group: number of groups selected per token + topk: number of experts selected per token + renormalize: whether to L1-normalise the selected weights + routed_scaling_factor: multiplicative scale applied after renorm + + Returns: + (scores_out, topk_values, topk_indices) + scores_out – sparse score tensor, shape [num_tokens, num_experts] + topk_values – weights for selected experts, shape [num_tokens, topk] + topk_indices – expert indices, shape [num_tokens, topk] (int64) + """ + num_tokens, num_experts = gating_output.shape + experts_per_group = num_experts // n_group + + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias + + # Step 1: group scores = sum of top-2 biased scores per group + biased = scores_with_bias.reshape([num_tokens, n_group, experts_per_group]) + group_scores = biased.topk(min(2, experts_per_group), axis=-1)[0].sum(axis=-1) + + # Step 2: select top-topk_group groups + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores) + group_mask.put_along_axis_(group_idx, paddle.ones_like(group_idx, dtype=group_mask.dtype), axis=-1) + score_mask = ( + group_mask.unsqueeze(-1).expand([num_tokens, n_group, experts_per_group]).reshape([num_tokens, num_experts]) + ) + + # Step 3: select top-topk experts within selected groups (biased score) + tmp_scores = scores_with_bias.masked_fill(~score_mask.cast(paddle.bool), float("-inf")) + topk_indices = paddle.topk(tmp_scores, topk, axis=-1)[1] + + # Step 4: gather unbiased sigmoid scores + topk_values = paddle.take_along_axis(scores, topk_indices, axis=1) + + # Step 5: renormalize + scale + if renormalize: + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + if routed_scaling_factor != 1.0: + topk_values = topk_values * routed_scaling_factor + + scores_out = paddle.zeros_like(scores) + scores_out.put_along_axis_(topk_indices, topk_values, axis=1) + + return scores_out, topk_values, topk_indices.cast(paddle.int64) + + +@unittest.skipUnless(_GROUPED_TOPK_AVAILABLE, "grouped_topk custom op not available (not compiled)") +class TestGroupedTopkOp(unittest.TestCase): + """Tests for the grouped_topk custom CUDA op.""" + + ATOL = 1e-3 + RTOL = 1e-3 + + def setUp(self): + paddle.seed(42) + + # ------------------------------------------------------------------ + # Parametrised helper + # ------------------------------------------------------------------ + def _run_case( + self, + num_tokens: int, + num_experts: int, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + input_dtype=paddle.float32, + bias_scale: float = 0.1, + seed: int = 42, + ): + paddle.seed(seed) + gating = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = (paddle.rand([1, num_experts], dtype=paddle.float32) - 0.5) * bias_scale + + # Reference always runs in fp32 + gating_fp32 = gating.cast(paddle.float32) if input_dtype != paddle.float32 else gating + ref_scores, ref_tv, ref_ti = native_grouped_topk( + gating_fp32.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + op_scores, op_tv, op_ti = grouped_topk( + gating.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + label = ( + f"T={num_tokens}, E={num_experts}, n_group={n_group}, " + f"topk_group={topk_group}, topk={topk}, " + f"renorm={renormalize}, scale={routed_scaling_factor}, dtype={input_dtype}" + ) + + self.assertEqual(op_tv.shape, [num_tokens, topk], f"[{label}] topk_values shape") + self.assertEqual(op_ti.shape, [num_tokens, topk], f"[{label}] topk_indices shape") + self.assertEqual(op_ti.dtype, paddle.int64, f"[{label}] topk_indices dtype") + self.assertEqual(op_tv.dtype, paddle.float32, f"[{label}] topk_values dtype") + + # Compare set-level index match (position order not guaranteed) + ref_sorted = paddle.sort(ref_ti, axis=-1) + op_sorted = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_sorted, op_sorted).item(): + n_diff = (ref_sorted != op_sorted).sum().item() + self.fail(f"[{label}] topk_indices set mismatch: {n_diff} positions differ") + + # Align values by expert index before comparing + ref_ord = paddle.argsort(ref_ti, axis=-1) + op_ord = paddle.argsort(op_ti, axis=-1) + ref_tv_s = paddle.take_along_axis(ref_tv, ref_ord, axis=-1) + op_tv_s = paddle.take_along_axis(op_tv, op_ord, axis=-1) + if not paddle.allclose(op_tv_s, ref_tv_s, atol=self.ATOL, rtol=self.RTOL).item(): + max_diff = (op_tv_s - ref_tv_s).abs().max().item() + self.fail(f"[{label}] topk_values max_diff={max_diff:.2e}") + + # ------------------------------------------------------------------ + # GLM-4.5-Air: n_experts=128, n_group=1, topk_group=1, topk=8, renorm=True + # ------------------------------------------------------------------ + def test_glm45air_T1(self): + self._run_case(1, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T32(self): + self._run_case(32, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T128(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T512(self): + self._run_case(512, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T1024(self): + self._run_case(1024, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T4096(self): + self._run_case(4096, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T8192(self): + self._run_case(8192, 128, 1, 1, 8, True, 1.0) + + # ------------------------------------------------------------------ + # DeepSeek-V3 / R1: n_experts=256, n_group=8, topk_group=4, topk=8, + # renorm=True, scale=2.5 + # ------------------------------------------------------------------ + def test_deepseek_v3_T1(self): + self._run_case(1, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T32(self): + self._run_case(32, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T128(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T512(self): + self._run_case(512, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T4096(self): + self._run_case(4096, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T8192(self): + self._run_case(8192, 256, 8, 4, 8, True, 2.5) + + # ------------------------------------------------------------------ + # Qwen3-30B-A3B: n_experts=128, n_group=4, topk_group=2, topk=8, + # renorm=False + # ------------------------------------------------------------------ + def test_qwen3_30b_T1(self): + self._run_case(1, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T128(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T512(self): + self._run_case(512, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T4096(self): + self._run_case(4096, 128, 4, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # Kimi-K2: n_experts=384, n_group=8, topk_group=2, topk=8, renorm=False + # ------------------------------------------------------------------ + def test_kimi_k2_T1(self): + self._run_case(1, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T128(self): + self._run_case(128, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T512(self): + self._run_case(512, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T4096(self): + self._run_case(4096, 384, 8, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # bfloat16 input path: kernel should cast internally to fp32 + # ------------------------------------------------------------------ + def test_bf16_input_glm45air(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0, input_dtype=paddle.bfloat16) + + def test_bf16_input_deepseek_v3(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5, input_dtype=paddle.bfloat16) + + def test_bf16_input_qwen3_30b(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0, input_dtype=paddle.bfloat16) + + # ------------------------------------------------------------------ + # Output shape and dtype sanity + # ------------------------------------------------------------------ + def test_output_shapes(self): + """Verify output shapes for various (T, E, topk) combinations.""" + cases = [ + (1, 128, 1, 1, 8), + (32, 256, 8, 4, 8), + (64, 384, 8, 2, 8), + ] + for T, E, ng, tkg, topk in cases: + gating = paddle.randn([T, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertEqual(tv.shape, [T, topk], f"T={T},E={E}: topk_values shape") + self.assertEqual(ti.shape, [T, topk], f"T={T},E={E}: topk_indices shape") + + def test_output_dtype_is_float32(self): + """topk_values must always be float32 regardless of input dtype.""" + for dtype in [paddle.float32, paddle.bfloat16]: + gating = paddle.randn([16, 128], dtype=dtype) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertEqual(tv.dtype, paddle.float32, f"dtype={dtype}: topk_values not float32") + self.assertEqual(ti.dtype, paddle.int64, f"dtype={dtype}: topk_indices not int64") + + # ------------------------------------------------------------------ + # Correctness invariants + # ------------------------------------------------------------------ + def test_topk_indices_in_valid_range(self): + """All selected expert indices must lie in [0, num_experts).""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8), (384, 8, 2, 8)]: + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertTrue((ti >= 0).all().item(), f"E={E}: negative index found") + self.assertTrue((ti < E).all().item(), f"E={E}: index >= num_experts") + + def test_no_duplicate_experts_per_token(self): + """Each token must select exactly topk distinct experts.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + for row in ti.numpy(): + self.assertEqual(len(set(row.tolist())), topk, f"E={E}: duplicate expert indices in row {row}") + + def test_topk_values_non_negative(self): + """Sigmoid output is in (0,1); routing weights must be >= 0.""" + gating = paddle.randn([64, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertTrue((tv >= 0).all().item(), "topk_values contains negative weights") + + def test_renormalized_weights_sum_to_one(self): + """With renormalize=True and scale=1.0, per-token weights sum ≈ 1.""" + num_tokens = 64 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.ones(num_tokens, dtype=np.float32), + atol=1e-3, + err_msg="Renormalized weights do not sum to 1 per token", + ) + + def test_scaled_weights_sum(self): + """With renormalize=True and scale=2.5, per-token weights sum ≈ 2.5.""" + num_tokens, scale = 64, 2.5 + gating = paddle.randn([num_tokens, 256], dtype=paddle.float32) + bias = paddle.zeros([1, 256], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 8, 4, 8, True, scale) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.full(num_tokens, scale, dtype=np.float32), + atol=1e-2, + err_msg=f"Scaled weights do not sum to {scale} per token", + ) + + def test_no_renorm_weights_are_raw_sigmoid(self): + """With renormalize=False, topk_values must equal sigmoid(logits) at selected positions.""" + num_tokens, E = 32, 128 + gating = paddle.randn([num_tokens, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, False, 1.0) + expected = paddle.take_along_axis(paddle.nn.functional.sigmoid(gating), ti, axis=1) + np.testing.assert_allclose( + tv.numpy(), + expected.numpy(), + atol=1e-4, + err_msg="Without renorm, topk_values should equal sigmoid(gating) at selected positions", + ) + + def test_deterministic(self): + """Two identical calls must produce bit-for-bit identical outputs.""" + gating = paddle.randn([32, 256], dtype=paddle.float32) + bias = (paddle.rand([1, 256], dtype=paddle.float32) - 0.5) * 0.1 + _, tv1, ti1 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + _, tv2, ti2 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + self.assertTrue( + paddle.allclose(tv1, tv2, atol=0.0, rtol=0.0).item(), + "topk_values not deterministic across two identical calls", + ) + self.assertTrue( + paddle.equal_all(ti1, ti2).item(), + "topk_indices not deterministic across two identical calls", + ) + + def test_zero_bias(self): + """All-zero bias: biased == unbiased; reference and op must agree.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(16) + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + self.assertTrue( + paddle.equal_all(ref_s, op_s).item(), + f"E={E}/zero_bias: topk_indices set mismatch", + ) + + def test_large_bias_steers_routing(self): + """Large positive bias on first half of experts must dominate selection.""" + E, topk = 128, 8 + paddle.seed(17) + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.concat( + [ + paddle.full([1, E // 2], 2.0, dtype=paddle.float32), + paddle.full([1, E // 2], -2.0, dtype=paddle.float32), + ], + axis=1, + ) + _, _, ti = grouped_topk(gating, bias, 1, 1, topk, True, 1.0) + self.assertTrue( + (ti < E // 2).all().item(), + "Large positive bias on experts [0, E/2) did not steer all selections there", + ) + + def test_extreme_logits_no_nan_inf(self): + """Very large logits must not produce NaN or Inf in outputs.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(18) + gating = paddle.randn([8, E], dtype=paddle.float32) * 50.0 + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, ng, tkg, topk, False, 1.0) + self.assertFalse(paddle.isnan(tv).any().item(), f"E={E}: NaN in topk_values") + self.assertFalse(paddle.isinf(tv).any().item(), f"E={E}: Inf in topk_values") + + def test_single_expert_selected(self): + """topk=1: each token selects exactly one expert; weight == 1.0 with renorm.""" + num_tokens = 16 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 1, True, 1.0) + self.assertEqual(tv.shape, [num_tokens, 1]) + self.assertEqual(ti.shape, [num_tokens, 1]) + np.testing.assert_allclose( + tv.numpy(), + np.ones((num_tokens, 1), dtype=np.float32), + atol=1e-5, + err_msg="With topk=1 and renorm=True, each weight should be 1.0", + ) + + def test_sparse_scores_consistency(self): + """Sparse scores tensor: non-zero at selected positions must equal topk_values; zero elsewhere.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([16, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + s, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + gathered = paddle.take_along_axis(s, ti, axis=1) + np.testing.assert_allclose( + gathered.numpy(), + tv.numpy(), + atol=1e-6, + err_msg=f"E={E}: sparse scores at topk positions != topk_values", + ) + nonzero_count = (s != 0).sum(axis=-1) + self.assertTrue( + (nonzero_count == topk).all().item(), + f"E={E}: non-zero count per token != topk", + ) + + def test_irregular_token_counts(self): + """Non-power-of-2 token counts must produce correct shapes and values.""" + irregular_T = [3, 7, 15, 33, 65, 127, 129, 257, 511, 513, 900] + for T in irregular_T: + gating = paddle.randn([T, 128], dtype=paddle.float32) + bias = (paddle.rand([1, 128], dtype=paddle.float32) - 0.5) * 0.1 + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + self.assertEqual(op_tv.shape, [T, 8], f"T={T}: topk_values shape mismatch") + self.assertEqual(op_ti.shape, [T, 8], f"T={T}: topk_indices shape mismatch") + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_s, op_s).item(): + n_diff = (ref_s != op_s).sum().item() + self.fail(f"T={T}: topk_indices mismatch, {n_diff} positions differ") + + +if __name__ == "__main__": + unittest.main() From 9894b326b6ab47aa79cc4adcdb7ee2a04d1f5e63 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Mon, 18 May 2026 15:49:46 +0800 Subject: [PATCH 115/143] [Cherry-Pick][RL] Support cpu tensor broadcast(#7833) (#7840) * support cpu tensor broadcast * fix place * fix group * fix init * fix shutdown process group --- fastdeploy/rl/dynamic_weight_manager.py | 7 +++++++ fastdeploy/worker/worker_process.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index a6f14068abc..c30da6f9124 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -348,6 +348,13 @@ def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None: if shutdown_process_group: paddle.distributed.shutdown_process_group(self.parallel_config.ep_group) if shutdown_process_group: + # ProcessGroupGloo has no shutdown(); remove it from paddle's registry + # before the global sweep to avoid AttributeError. + from paddle.distributed.collective import _get_group_map_by_name + + for name, pg in list(_get_group_map_by_name().items()): + if pg.process_group is not None and not hasattr(pg.process_group, "shutdown"): + _get_group_map_by_name().pop(name, None) paddle.distributed.shutdown_process_group() self._update_shared_status(pid, ModelWeightsStatus.CLEARED) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 37e23be2d50..61e6fcda85f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -174,6 +174,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) -> self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule self.cached_control_reqs = [] + if self.ranks > 1: + self.gloo_group = dist.new_group(list(range(self.ranks)), backend="gloo") + else: + self.gloo_group = None def init_control(self): engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port @@ -316,9 +320,12 @@ def update_weights_from_tensor(self, mmap_infos): self.experts_manager.tensor_infos = None def _broadcast_model_weights_signal(self, src: int, group) -> int: - signal_list = [self.model_weights_signal[0]] - paddle.distributed.broadcast_object_list(signal_list, src=src, group=group) - return int(signal_list[0]) + model_weights_signal_tensor = paddle.full( + shape=[1], fill_value=self.model_weights_signal[0], dtype="int32", device="cpu" + ) + paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group) + value = model_weights_signal_tensor.numpy()[0] + return int(value) def _get_exist_task_flag(self) -> bool: if self.nnode > 1: @@ -498,7 +505,7 @@ def event_loop_normal(self) -> None: if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) if self.ranks > 1: - self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None) + self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=self.gloo_group) req_dicts = None self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time()) @@ -563,7 +570,7 @@ def event_loop_normal(self) -> None: self.model_weights_signal[0] = self.model_weights_status.value[0] if self.ranks > 1: self.model_weights_signal[0] = self._broadcast_model_weights_signal( - src=0, group=None + src=0, group=self.gloo_group ) time.sleep(1) self.model_weights_status.value[0] = ( From ab3c5f4ce60b2dbbb495c2be92b3d5cd8cdf4968 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Tue, 19 May 2026 11:30:58 +0800 Subject: [PATCH 116/143] [Cherry-Pick][CI] Set --workers=1 to avoid intermittent timeout failures (#7846) (#7848) --- .github/workflows/_base_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 7087183e447..9f75b2b4b35 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -272,7 +272,7 @@ jobs: curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }" + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--workers\": 1, \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }" check_service 90 python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1 From 41d44d61e917e4918ed029656b36c3827bf50772 Mon Sep 17 00:00:00 2001 From: qwes5s5 <45442318+qwes5s5@users.noreply.github.com> Date: Tue, 19 May 2026 11:41:44 +0800 Subject: [PATCH 117/143] fix refact abort (#7838) --- fastdeploy/engine/common_engine.py | 103 +++++++++++++++++++- fastdeploy/splitwise/splitwise_connector.py | 35 +++++++ 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index f553a4f8ee5..4ae03e18a30 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1129,6 +1129,22 @@ def _insert_zmq_task_to_scheduler(self): "Request is aborted since LLM Engine is paused.", worker_pid=worker_pid, ) + # PD ghost prevention: notify decode side to recycle its + # scheduler entry, otherwise it would sit there as a ghost + # since prefill will never deliver any first token. + if ( + self.cfg.scheduler_config.splitwise_role == "prefill" + and getattr(request, "disaggregate_info", None) + and self.split_connector is not None + ): + try: + self.split_connector.send_drop_signal( + request.request_id, request.disaggregate_info + ) + except Exception as e: + self.llm_logger.warning( + f"Failed to send drop signal for {request.request_id}: {e}" + ) continue except Exception as e: self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") @@ -1278,15 +1294,47 @@ def _control_pause(self, control_request: ControlRequest): def _wait_inflight_drained(self): """ Wait until resource_manager.requests is completely empty. - No timeout — abort pipeline will complete. - Logs a warning every 30 seconds while waiting to help diagnose potential hangs. + Logs a warning and remove scheduler-only request every 30 seconds while waiting to help diagnose potential hangs. """ start_time = time.monotonic() next_warn_time = start_time + 30 + GHOST_REAP_AFTER = 30.0 while self.resource_manager.requests or self.scheduler.requests: now = time.monotonic() + late_ids = list( + set(self.resource_manager.requests.keys()) + - self.resource_manager.waiting_abort_req_id_set + - self.resource_manager.to_be_aborted_req_id_set + ) + if late_ids: + self.resource_manager.add_abort_req_ids(late_ids) + self.llm_logger.info(f"Pause drain: late-arrived requests added to abort set: {late_ids}") + + if now - start_time >= GHOST_REAP_AFTER: + scheduler_only_ids = list( + set(self.scheduler.requests.keys()) - set(self.resource_manager.requests.keys()) + ) + if scheduler_only_ids: + ghost_outputs = [ + RequestOutput( + request_id=req_id, + finished=True, + error_code=499, + error_msg=(f"forced cleanup after {GHOST_REAP_AFTER}s"), + ) + for req_id in scheduler_only_ids + ] + self.scheduler.put_results(ghost_outputs) + self.llm_logger.warning( + f"Pause drain timeout: reaped {len(scheduler_only_ids)} " + f"scheduler-only ghost(s) after {GHOST_REAP_AFTER}s: " + f"{scheduler_only_ids}" + ) + # Reset to avoid re-reaping on the next tick + start_time = now + if now >= next_warn_time: self.llm_logger.warning( "Still waiting for inflight requests to drain, " @@ -1751,6 +1799,31 @@ def _fetch_requests(): items = self.engine_worker_queue.get_disaggregated_tasks() for item in items: + msg_type = item[0] + + # PD pause race: P drops a request via paused gate and notifies us + # to recycle our scheduler entry (otherwise it becomes a ghost that + # blocks pause/abort drain forever). Synthesize a finished + # RequestOutput so it walks the normal put_results -> _recycle path + # and the client gets a 499 error response. + if msg_type == "decode_drop": + drop_outputs = [ + RequestOutput( + request_id=req_id, + finished=True, + error_code=499, + error_msg="Aborted: prefill dropped this request (paused gate)", + ) + for req_id in item[1] + ] + if drop_outputs: + self.scheduler.put_results(drop_outputs) + self.llm_logger.info( + "Decode recycled scheduler ghost(s) via P-side drop signal: " + f"{[r.request_id for r in drop_outputs]}" + ) + continue + tasks = item[1] if isinstance(tasks[0], Request): self.llm_logger.debug( @@ -1815,9 +1888,17 @@ def _process_prefilled_requests(): nonlocal prefilled_request_ouputs ready_request_outputs = [] waiting_request_outputs = [] + ghost_request_outputs = [] for req_output in prefilled_request_ouputs: - if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id): + req_id = req_output.request_id + if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_id): + if ( + req_id in self.resource_manager.waiting_abort_req_id_set + or req_id in self.resource_manager.to_be_aborted_req_id_set + ): + ghost_request_outputs.append(req_output) + continue # ensure the api_server and scheduler in decode have # received the request sent by the client waiting_request_outputs.append(req_output) @@ -1828,6 +1909,22 @@ def _process_prefilled_requests(): self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}") prefilled_request_ouputs = waiting_request_outputs + + for req_output in ghost_request_outputs: + req_id = req_output.request_id + self.llm_logger.warning( + f"Pause drain: reaping prefilled-output ghost {req_id} " + "(scheduler never registered, marked for abort -- breaks deadlock)" + ) + try: + self.resource_manager.pre_recycle_resource(req_id) + except Exception as e: + self.llm_logger.warning(f"pre_recycle_resource({req_id}) failed: {e}") + self.resource_manager.waiting_abort_req_id_set.discard(req_id) + self.resource_manager.to_be_aborted_req_id_set.discard(req_id) + if req_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[req_id] + if self.cfg.splitwise_version == "v1": # decode return first token to client self.scheduler.put_results(ready_request_outputs) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 77f75ee4de7..27c608a4d1f 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -236,6 +236,27 @@ def send_first_token(self, prefill_msg, tasks_list): ) self._send_message(addr, "decode", tasks_list) + def send_drop_signal(self, request_id: str, disaggregate_info: dict): + """ + Notify the decode side that this prefill request has been dropped + (e.g. paused gate rejected it on P). The decode side should recycle + its scheduler entry for this request_id, otherwise it would sit + there forever as a ghost and pause/abort drain would hang. + """ + if not disaggregate_info: + return + decode_ip = disaggregate_info.get("decode_ip") + decode_port = disaggregate_info.get("decode_connector_port") + if not decode_ip or not decode_port: + self.logger.warning( + f"send_drop_signal: missing decode_ip/decode_connector_port in " + f"disaggregate_info for {request_id}; skip" + ) + return + addr = f"{decode_ip}:{decode_port}" + self.logger.info(f"send_drop_signal: addr={addr}, request_id={request_id}") + self._send_message(addr, "drop", {"request_id": request_id}) + def check_decode_allocated(self, task): """Check whether the requests have been allocated resources in decode.""" self.logger.debug(f"check_decode_allocated: {task.request_id}") @@ -382,6 +403,8 @@ def _process_message(self, frames: List[bytes]): self._handle_prefill(payload) elif msg_type == "decode": self._handle_decode(payload) + elif msg_type == "drop": + self._handle_drop(payload) elif msg_type == "cache_sync": for task in payload: self.logger.info(f"_process_message: cache_sync task: {task}") @@ -412,3 +435,15 @@ def _handle_decode(self, payload): for task in payload: tasks.append(RequestOutput.from_dict(task)) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks)) + + def _handle_drop(self, payload): + """ + Handle drop signal from prefill: forward to engine worker queue so the + decode engine main loop can recycle the corresponding scheduler entry. + """ + request_id = payload.get("request_id") if isinstance(payload, dict) else None + if not request_id: + self.logger.warning(f"_handle_drop: invalid payload {payload}") + return + self.logger.info(f"_handle_drop: request_id={request_id}") + self.engine_worker_queue.put_disaggregated_tasks(("decode_drop", [request_id])) From 8c4f5a6a1dce8982381b8f32c70e6d2e50d93161 Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Wed, 20 May 2026 14:18:01 +0800 Subject: [PATCH 118/143] [Cherry-Pick] update fleet_ops(#7859) (#7858) * update fleet_ops * use try_import in ep --- fastdeploy/model_executor/layers/moe/ep.py | 5 ++- .../layers/quantization/fp8_utils.py | 32 ++----------------- fastdeploy/model_executor/utils.py | 24 ++++++++++++++ 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 1ddb6994878..022a26ea74e 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -27,8 +27,11 @@ import fastdeploy from fastdeploy import envs from fastdeploy.config import MoEPhase +from fastdeploy.model_executor.utils import try_import from fastdeploy.utils import singleton +paddlefleet_ops = try_import(["paddlefleet.ops"]) + def load_deep_ep() -> ModuleType: """ @@ -43,7 +46,7 @@ def load_deep_ep() -> ModuleType: # Enable paddle.enable_compat before importing deep_ep (required by PFCC/PaddleFleet variants) paddle.enable_compat(scope={"deep_ep"}) try: - import paddlefleet.ops.deep_ep as deep_ep # type: ignore + import paddlefleet_ops.deep_ep as deep_ep # type: ignore logger.info("FD use PaddleFleet/DeepEP now.") return deep_ep diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index d7ad693c1cd..6b3f1326765 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -14,44 +14,18 @@ # limitations under the License. """ -import importlib - import paddle import triton from paddleformers.utils.log import logger +from fastdeploy.model_executor.layers.utils import get_sm_version from fastdeploy.model_executor.ops.triton_ops import _per_token_group_quant_fp8 +from fastdeploy.model_executor.utils import try_import from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import per_token_group_fp8_quant -from ..utils import get_sm_version - - -def try_import(modules, name=None, fail_msg=None): - """ - try_import - """ - if not isinstance(modules, (list, tuple)): - modules = [modules] - - for m in modules: - assert isinstance(m, str), m - try: - m = importlib.import_module(m) - except ImportError: - m = None - - if m is not None: - if name is None: - return m - elif hasattr(m, name): - return getattr(m, name) - - if fail_msg is not None: - logger.warning(fail_msg) - paddlefleet_ops = try_import(["paddlefleet.ops"]) @@ -71,7 +45,7 @@ def load_deep_gemm(): try: import logging - import paddlefleet.ops.deep_gemm as deep_gemm + import paddlefleet_ops.deep_gemm as deep_gemm logging.getLogger().handlers.clear() logger.info("Detected sm100, use PaddleFleet DeepGEMM") diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index de5ef678b01..960d8f23f7e 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -555,6 +555,30 @@ def fn(loaded_weight_name, is_moe): return fn +def try_import(modules, name=None, fail_msg=None): + """ + try_import + """ + if not isinstance(modules, (list, tuple)): + modules = [modules] + + for m in modules: + assert isinstance(m, str), m + try: + m = importlib.import_module(m) + except ImportError: + m = None + + if m is not None: + if name is None: + return m + elif hasattr(m, name): + return getattr(m, name) + + if fail_msg is not None: + logger.warning(fail_msg) + + def has_flashinfer(): return importlib.util.find_spec("flashinfer") is not None From 31b12ee8f942c4d2a218a016603e80dc8af42357 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Wed, 20 May 2026 19:27:10 +0800 Subject: [PATCH 119/143] [Cherry-Pick][Optimization] Reduce logprob processing overhead by using actual topk instead of fixed K+1 (#7860) (#7861) * opt logprob process * fix test_token_processor * fix xpu --- .../gpu_ops/get_output_msg_with_topk.cc | 12 ++++++--- .../gpu_ops/save_output_msg_with_topk.cc | 21 ++++++++------- .../src/ops/get_output_msg_with_topk.cc | 12 ++++++--- .../src/ops/save_output_msg_with_topk.cc | 21 ++++++++------- fastdeploy/output/token_processor.py | 27 ++++++++++++++----- tests/output/test_token_processor.py | 6 +++-- 6 files changed, 62 insertions(+), 37 deletions(-) diff --git a/custom_ops/gpu_ops/get_output_msg_with_topk.cc b/custom_ops/gpu_ops/get_output_msg_with_topk.cc index e70f7c2c24d..363274d6aac 100644 --- a/custom_ops/gpu_ops/get_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/get_output_msg_with_topk.cc @@ -88,13 +88,17 @@ void GetOutputTopK(const paddle::Tensor& x, return; } - int bsz = msg_rcv.mtext[1]; + // Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1]. + // This matches the packing in save_output_msg_with_topk.cc: + // mtext[1] = bsz | (max_num_logprobs << 16) + int bsz = msg_rcv.mtext[1] & 0xFFFF; + int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF; out_data[0] = (int64_t)msg_rcv.mtext[0]; - out_data[1] = (int64_t)msg_rcv.mtext[1]; + out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks for (int i = 0; i < bsz; i++) { - for (int j = 0; j < k + 1; j++) { - const int64_t offset = i * (K + 1) + j; + for (int j = 0; j < actual_topk; j++) { + const int64_t offset = i * actual_topk + j; out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2]; scores_data[offset] = msg_rcv.mtext_f[offset]; } diff --git a/custom_ops/gpu_ops/save_output_msg_with_topk.cc b/custom_ops/gpu_ops/save_output_msg_with_topk.cc index 0a7d2ab6eac..3069cb3929b 100644 --- a/custom_ops/gpu_ops/save_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/save_output_msg_with_topk.cc @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, : -inference_msg_id_from_env; int bsz = x.shape()[0]; int max_num_logprobs = logprob_token_ids.shape()[1]; - msg_sed.mtext[1] = bsz; + // Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1]. + // token_processor unpacks both fields to avoid reading unused topk slots. + msg_sed.mtext[1] = bsz | (max_num_logprobs << 16); for (int i = 0; i < bsz; i++) { - for (int j = 0; j < K + 1; j++) { - const int64_t offset = i * (K + 1) + j; + // Loop only over actual logprob columns (max_num_logprobs) instead of the + // fixed K+1=21, and use max_num_logprobs as the stride so data is packed + // densely in the message buffer. + for (int j = 0; j < max_num_logprobs; j++) { + const int64_t offset = i * max_num_logprobs + j; if (j == 0) { msg_sed.mtext[offset + 2] = (int)x_data[i]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; - } else if (j < max_num_logprobs) { - msg_sed.mtext[offset + 2] = - (int)logprob_token_ids_data[i * max_num_logprobs + j]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } else { - msg_sed.mtext[offset + 2] = -1; - msg_sed.mtext_f[offset] = 0.0; + msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } if (preempted_idx_data[i] == 1) { msg_sed.mtext[offset + 2] = -9; diff --git a/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc b/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc index 04d8efe71e7..cb50725fdbc 100644 --- a/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc +++ b/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc @@ -82,13 +82,17 @@ void GetOutputTopK(const paddle::Tensor& x, return; } - int bsz = msg_rcv.mtext[1]; + // Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1]. + // This matches the packing in save_output_msg_with_topk.cc: + // mtext[1] = bsz | (max_num_logprobs << 16) + int bsz = msg_rcv.mtext[1] & 0xFFFF; + int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF; out_data[0] = (int64_t)msg_rcv.mtext[0]; - out_data[1] = (int64_t)msg_rcv.mtext[1]; + out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks for (int i = 0; i < bsz; i++) { - for (int j = 0; j < k + 1; j++) { - const int64_t offset = i * (K + 1) + j; + for (int j = 0; j < actual_topk; j++) { + const int64_t offset = i * actual_topk + j; out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2]; scores_data[offset] = msg_rcv.mtext_f[offset]; } diff --git a/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc b/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc index 455e0fa18fb..154affbbde6 100644 --- a/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc +++ b/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, : -inference_msg_id_from_env; int bsz = x.shape()[0]; int max_num_logprobs = logprob_token_ids.shape()[1]; - msg_sed.mtext[1] = bsz; + // Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1]. + // token_processor unpacks both fields to avoid reading unused topk slots. + msg_sed.mtext[1] = bsz | (max_num_logprobs << 16); for (int i = 0; i < bsz; i++) { - for (int j = 0; j < K + 1; j++) { - const int64_t offset = i * (K + 1) + j; + // Loop only over actual logprob columns (max_num_logprobs) instead of the + // fixed K+1=21, and use max_num_logprobs as the stride so data is packed + // densely in the message buffer. + for (int j = 0; j < max_num_logprobs; j++) { + const int64_t offset = i * max_num_logprobs + j; if (j == 0) { msg_sed.mtext[offset + 2] = (int)x_data[i]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; - } else if (j < max_num_logprobs) { - msg_sed.mtext[offset + 2] = - (int)logprob_token_ids_data[i * max_num_logprobs + j]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } else { - msg_sed.mtext[offset + 2] = -1; - msg_sed.mtext_f[offset] = 0.0; + msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } if (preempted_idx_data[i] == 1) { msg_sed.mtext[offset + 2] = -9; diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 2a8328b28fe..a8544cb5979 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -846,10 +846,22 @@ def _process_batch_output(self): batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] elif self.use_logprobs: - batch = self.output_tokens[1, 0] - tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)] - scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)] + # mtext[1] packs bsz (low 16 bits) and actual_topk (high 16 bits). + # actual_topk = max_num_logprobs written by save_output_topk, which + # equals the actual number of logprob columns in this step's message + # (top_logprobs+1 across the batch). Using actual_topk as stride + # avoids processing the K+1=21 fixed-size slots when fewer are needed. + packed = int(self.output_tokens[1, 0]) + batch = packed & 0xFFFF + actual_topk = (packed >> 16) & 0xFFFF + tokens = tokens[2 : batch * actual_topk + 2].reshape([batch, actual_topk]) + scores = self.output_scores[: batch * actual_topk].numpy().reshape([batch, actual_topk]) ranks = self.output_ranks[:batch].numpy() + # Pre-convert the full [batch, actual_topk] arrays to Python lists once, + # avoiding per-row .tolist() calls inside the loop below. + tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1, 0] tokens = tokens[2 : batch + 2] @@ -1022,10 +1034,11 @@ def _process_batch_output(self): topk_logprobs = scores[i, batch_token_index, :].tolist() sampled_rank = ranks[i, batch_token_index].item() else: - result.outputs.logprob = float(scores[i, 0]) - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() + # Use pre-converted lists (batch .tolist() done before the loop). + result.outputs.logprob = scores_lists[i][0] + topk_token_ids = tokens_lists[i] + topk_logprobs = scores_lists[i] + sampled_rank = ranks_list[i] if result.outputs.top_logprobs is None: result.outputs.top_logprobs = LogprobsLists( diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 5c26db778c7..4240e84c75a 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -719,7 +719,8 @@ def test_process_batch_output_logprob_records_topk_and_caching(): task.trace_carrier = None rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 1 + # mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits) + processor.output_tokens[1, 0] = 1 | ((K + 1) << 16) token_block = np.arange(K + 1, dtype=np.int64) + 3 processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1])) processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32") @@ -842,7 +843,8 @@ def test_process_batch_output_prefill_chunk_and_adapter_skip(): task.get = lambda key, default=None: getattr(task, key, default) rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 1 + # mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits) + processor.output_tokens[1, 0] = 1 | ((K + 1) << 16) processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64)) processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32") processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64") From b5c8290be6319d88a9055719d434e31ffcf4f00d Mon Sep 17 00:00:00 2001 From: RAM Date: Thu, 21 May 2026 11:15:43 +0800 Subject: [PATCH 120/143] [RL] Reset buffer size of `slot_mapping` (#7868) * Reset buffer size of R3 * refine code --- fastdeploy/worker/gpu_model_runner.py | 4 +++- fastdeploy/worker/input_batch.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 72830fe0183..aec0be0e746 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1337,7 +1337,9 @@ def _compute_position_ids_and_slot_mapping(self) -> None: ) block_size = self.cache_config.block_size block_idx = position_ids // block_size # [num_tokens] - assert self.forward_meta.batch_id_per_token.shape == block_idx.shape + assert ( + self.forward_meta.batch_id_per_token.shape == block_idx.shape + ), f"batch_id_per_token.shape:{self.forward_meta.batch_id_per_token.shape} != block_idx.shape:{block_idx.shape}" block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx] # [num_tokens] block_offset = position_ids % block_size # [num_tokens] slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens] diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index ce679b90939..241ccaf6b71 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -189,9 +189,8 @@ def init_share_inputs(self): self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") # Initialize addressing buffers - _max_batched_tokens = self.scheduler_config.max_num_batched_tokens - self.position_ids_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int32) - self.slot_mapping_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int64) + self.position_ids_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int32) + self.slot_mapping_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int64) # Declare AttentionBackend buffers self.decoder_batch_ids = None From b562b8d53e03c677bd7192eb92a39c94dd5271b0 Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Thu, 21 May 2026 14:14:10 +0800 Subject: [PATCH 121/143] fix ce bug (#7874) --- fastdeploy/model_executor/layers/moe/ep.py | 8 ++++---- .../model_executor/layers/quantization/fp8_utils.py | 7 +++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 022a26ea74e..05c36a68f48 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -27,11 +27,8 @@ import fastdeploy from fastdeploy import envs from fastdeploy.config import MoEPhase -from fastdeploy.model_executor.utils import try_import from fastdeploy.utils import singleton -paddlefleet_ops = try_import(["paddlefleet.ops"]) - def load_deep_ep() -> ModuleType: """ @@ -46,7 +43,10 @@ def load_deep_ep() -> ModuleType: # Enable paddle.enable_compat before importing deep_ep (required by PFCC/PaddleFleet variants) paddle.enable_compat(scope={"deep_ep"}) try: - import paddlefleet_ops.deep_ep as deep_ep # type: ignore + try: + import paddlefleet.ops.deep_ep as deep_ep # type: ignore + except: + import paddlefleet_ops.deep_ep as deep_ep # type: ignore logger.info("FD use PaddleFleet/DeepEP now.") return deep_ep diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index 6b3f1326765..89b9467ecc6 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -27,7 +27,7 @@ from fastdeploy.model_executor.ops.gpu import per_token_group_fp8_quant -paddlefleet_ops = try_import(["paddlefleet.ops"]) +paddlefleet_ops = try_import(["paddlefleet.ops", "paddlefleet_ops"]) def load_deep_gemm(): @@ -45,7 +45,10 @@ def load_deep_gemm(): try: import logging - import paddlefleet_ops.deep_gemm as deep_gemm + try: + import paddlefleet.ops.deep_gemm as deep_gemm + except: + import paddlefleet_ops.deep_gemm as deep_gemm logging.getLogger().handlers.clear() logger.info("Detected sm100, use PaddleFleet DeepGEMM") From 485f6c2cbfd5ddeb8d40e6ab6898ea5ac41e189d Mon Sep 17 00:00:00 2001 From: CSWYF3634076 Date: Fri, 22 May 2026 15:28:29 +0800 Subject: [PATCH 122/143] [Cherry-Pick][Feature][Log]console metrics log for pd disaggregation #7843 (#7845) * [Feature]console metrics log for pd disaggregation * [Feature]console metrics log for pd disaggregation fix test --- fastdeploy/engine/common_engine.py | 1 + .../engine/sched/resource_manager_v1.py | 29 +++++++---- .../engine/sched/scheduler_metrics_logger.py | 52 +++++++++++++++++-- tests/engine/test_scheduler_metrics_logger.py | 34 +++++++++++- tests/v1/test_resource_manager_v1.py | 22 ++++++++ 5 files changed, 124 insertions(+), 14 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 4ae03e18a30..16f705cf23f 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -196,6 +196,7 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.scheduler_metrics_logger = SchedulerMetricsLogger( enabled=True, dp_rank=self.cfg.parallel_config.local_data_parallel_id, + splitwise_role=self.cfg.scheduler_config.splitwise_role, ) self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index de89ab3adca..2c63d7b70df 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1759,15 +1759,26 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule prefill_reqs = [r for r in scheduled_reqs if isinstance(r, Request) and r.task_type == RequestType.PREFILL] has_decode = any(getattr(r, "task_type", None) == RequestType.DECODE for r in scheduled_reqs) - self.scheduler_metrics_logger.log_prefill_batch( - prefill_reqs=prefill_reqs, - running_cnt=running_cnt, - queue_cnt=queue_cnt, - tokens_used=tokens_used, - token_usage=token_usage, - free_blocks=free_blocks, - evictable_blocks=evictable_blocks, - ) + if self.config.scheduler_config.splitwise_role == "decode": + self.scheduler_metrics_logger.log_decode_bootstrap_batch( + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + else: + self.scheduler_metrics_logger.log_prefill_batch( + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) if has_decode: has_prefill = len(prefill_reqs) > 0 graph_opt_cfg = self.config.graph_opt_config diff --git a/fastdeploy/engine/sched/scheduler_metrics_logger.py b/fastdeploy/engine/sched/scheduler_metrics_logger.py index 0aaa29e246c..b989584fe07 100644 --- a/fastdeploy/engine/sched/scheduler_metrics_logger.py +++ b/fastdeploy/engine/sched/scheduler_metrics_logger.py @@ -29,9 +29,10 @@ class SchedulerMetricsLogger: DEFAULT_DECODE_LOG_INTERVAL = 5 - def __init__(self, enabled: bool = True, dp_rank: int = 0) -> None: + def __init__(self, enabled: bool = True, dp_rank: int = 0, splitwise_role: str = "mixed") -> None: self.enabled = enabled self.dp_rank = dp_rank + self.splitwise_role = splitwise_role decode_log_interval = envs.FD_CONSOLE_DECODE_LOG_INTERVAL if decode_log_interval <= 0: decode_log_interval = self.DEFAULT_DECODE_LOG_INTERVAL @@ -65,8 +66,9 @@ def on_decode_tokens(self, num_tokens: int) -> None: with self._lock: self._decode_tokens_since_last += num_tokens - def log_prefill_batch( + def _log_prefill_like_batch( self, + batch_name: str, prefill_reqs: Iterable, running_cnt: int, queue_cnt: int, @@ -91,8 +93,9 @@ def log_prefill_batch( cached_tokens += getattr(req, "num_cached_tokens", 0) or 0 msg = ( - "Prefill batch, " + f"{batch_name}, " f"dp_rank: {self.dp_rank}, " + f"splitwise_role: {self.splitwise_role}, " f"#new-seq: {len(prefill_reqs)}, " f"#new-token: {new_tokens}, " f"#cached-token: {cached_tokens}, " @@ -104,6 +107,48 @@ def log_prefill_batch( ) self._logger.info(msg) + def log_prefill_batch( + self, + prefill_reqs: Iterable, + running_cnt: int, + queue_cnt: int, + tokens_used: int, + token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, + ) -> None: + self._log_prefill_like_batch( + batch_name="Prefill batch", + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + + def log_decode_bootstrap_batch( + self, + prefill_reqs: Iterable, + running_cnt: int, + queue_cnt: int, + tokens_used: int, + token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, + ) -> None: + self._log_prefill_like_batch( + batch_name="Decode bootstrap batch from prefill", + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + def log_decode_batch( self, running_cnt: int, @@ -132,6 +177,7 @@ def log_decode_batch( msg = ( "Decode batch, " f"dp_rank: {self.dp_rank}, " + f"splitwise_role: {self.splitwise_role}, " f"#running-req: {running_cnt}, " f"#token: {tokens_used}, " f"token usage: {token_usage:.2f}, " diff --git a/tests/engine/test_scheduler_metrics_logger.py b/tests/engine/test_scheduler_metrics_logger.py index c1305a3daa6..cab38350c49 100644 --- a/tests/engine/test_scheduler_metrics_logger.py +++ b/tests/engine/test_scheduler_metrics_logger.py @@ -32,7 +32,7 @@ def test_on_decode_tokens_accumulates(): def test_log_prefill_batch_logs_expected_message(): - logger = SchedulerMetricsLogger(enabled=True, dp_rank=2) + logger = SchedulerMetricsLogger(enabled=True, dp_rank=2, splitwise_role="prefill") logger._logger = mock.Mock() reqs = [ @@ -46,6 +46,7 @@ def test_log_prefill_batch_logs_expected_message(): message = logger._logger.info.call_args[0][0] assert "Prefill batch" in message assert "dp_rank: 2" in message + assert "splitwise_role: prefill" in message assert "#new-seq: 2" in message assert "#new-token: 4" in message assert "#cached-token: 3" in message @@ -54,8 +55,31 @@ def test_log_prefill_batch_logs_expected_message(): assert "#queue-req: 6" in message +def test_log_decode_bootstrap_batch_logs_expected_message(): + logger = SchedulerMetricsLogger(enabled=True, dp_rank=0, splitwise_role="decode") + logger._logger = mock.Mock() + + reqs = [types.SimpleNamespace(prefill_start_index=4, prefill_end_index=5, num_cached_tokens=4)] + + logger.log_decode_bootstrap_batch( + prefill_reqs=reqs, + running_cnt=1, + queue_cnt=0, + tokens_used=5, + token_usage=0.25, + ) + + logger._logger.info.assert_called_once() + message = logger._logger.info.call_args[0][0] + assert "Decode bootstrap batch" in message + assert "splitwise_role: decode" in message + assert "#new-seq: 1" in message + assert "#new-token: 1" in message + assert "#cached-token: 4" in message + + def test_log_decode_batch_computes_throughput(monkeypatch): - logger = SchedulerMetricsLogger(enabled=True, dp_rank=1) + logger = SchedulerMetricsLogger(enabled=True, dp_rank=1, splitwise_role="decode") logger._logger = mock.Mock() logger._decode_batch_count = logger._decode_log_interval - 1 logger._decode_tokens_since_last = 10 @@ -69,6 +93,7 @@ def test_log_decode_batch_computes_throughput(monkeypatch): message = logger._logger.info.call_args[0][0] assert "Decode batch" in message assert "dp_rank: 1" in message + assert "splitwise_role: decode" in message assert "gen throughput (token/s): 5.00" in message assert "#queue-req: 7" in message assert logger._decode_tokens_since_last == 0 @@ -99,3 +124,8 @@ def test_decode_log_interval_non_positive_falls_back_to_default(monkeypatch): monkeypatch.setenv("FD_CONSOLE_DECODE_LOG_INTERVAL", "0") logger = SchedulerMetricsLogger(enabled=True, dp_rank=0) assert logger._decode_log_interval == SchedulerMetricsLogger.DEFAULT_DECODE_LOG_INTERVAL + + +def test_default_splitwise_role_is_mixed(): + logger = SchedulerMetricsLogger(enabled=True, dp_rank=0) + assert logger.splitwise_role == "mixed" diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index d9ab6a59dbc..6c51adb63a5 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -27,6 +27,7 @@ if not hasattr(paddle, "enable_compat"): paddle.enable_compat = lambda scope=None: None +from fastdeploy import envs from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.request import ( @@ -36,6 +37,7 @@ RequestMetrics, RequestOutput, RequestStatus, + RequestType, ) from fastdeploy.engine.sched.resource_manager_v1 import ( ResourceManagerV1, @@ -569,6 +571,26 @@ def test_preallocate_resource_in_p_and_d(self): self.assertEqual(request_d.num_computed_tokens, request_d.need_prefill_tokens) self.assertEqual(request_d.disaggregate_info["block_tables"], [4, 5]) + def test_decode_role_prefill_task_logs_decode_bootstrap_batch(self): + manager = _build_manager(splitwise_role="decode", enable_prefix_caching=False) + _register_manager_cleanup(self, manager) + manager.cache_manager = MagicMock() + manager.cache_manager.num_gpu_blocks = 8 + manager.cache_manager.gpu_free_block_list = [0, 1, 2, 3] + manager.scheduler_metrics_logger = MagicMock() + + request = _make_request(prompt_token_ids=[1, 2, 3, 4]) + request.task_type = RequestType.PREFILL + request.prefill_start_index = 4 + request.prefill_end_index = 5 + batch_request = [request] + + with patch.object(envs, "FD_CONSOLE_SCHEDULER_METRICS", True): + manager._log_console_scheduler_metrics(batch_request) + + manager.scheduler_metrics_logger.log_decode_bootstrap_batch.assert_called_once() + manager.scheduler_metrics_logger.log_prefill_batch.assert_not_called() + def test_prefilled_request_flow_and_resource_check(self): manager = _build_manager(splitwise_role="decode", speculative_method="mtp") _register_manager_cleanup(self, manager) From e7815bef2a3e029fe22b14e2aaca96801d5ceebd Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Fri, 22 May 2026 18:18:00 +0800 Subject: [PATCH 123/143] [Cherry-Pick][Benchmark] Add inner benchmark metrics component (#7881) (#7831) * Add inner benchmark metrics component * Add window_mode * remove temp scripts * fix ut * increase coverage lines --- docs/benchmark.md | 97 ++++ docs/zh/benchmark.md | 103 ++++ fastdeploy/config.py | 88 +++ fastdeploy/engine/args_utils.py | 24 + fastdeploy/engine/common_engine.py | 12 + .../metrics/benchmark_metrics_logger.py | 222 ++++++++ fastdeploy/output/token_processor.py | 27 + .../metrics/test_benchmark_metrics_logger.py | 499 ++++++++++++++++++ tests/output/test_process_batch_output.py | 1 + 9 files changed, 1073 insertions(+) create mode 100644 fastdeploy/metrics/benchmark_metrics_logger.py create mode 100644 tests/metrics/test_benchmark_metrics_logger.py diff --git a/docs/benchmark.md b/docs/benchmark.md index 1a2e6f88031..7abdd68aac1 100644 --- a/docs/benchmark.md +++ b/docs/benchmark.md @@ -40,3 +40,100 @@ python benchmark_serving.py \ --max-concurrency 1 \ --save-result ``` + +## In-Process Benchmark Metrics Logger + +FastDeploy provides a built-in performance monitoring module that runs inside the inference process. It collects per-request timing data and computes rolling statistics aligned with `benchmark_serving.py`, writing results to a JSONL file for real-time monitoring and post-hoc analysis. + +### Enable + +Add `--benchmark-metrics-config` with a JSON string to the service startup command: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-0.3B-Base-Paddle \ + --benchmark-metrics-config '{"enable": true}' +``` + +### Configuration Parameters + +| Parameter | Type | Default | Description | +| :-------- | :--- | :------ | :---------- | +| `enable` | bool | `false` | Whether to enable the benchmark metrics logger. Must be set to `true` to activate. | +| `window_size` | int | `0` | Number of recent requests to aggregate. `0` = cumulative (all requests since start). | +| `window_mode` | str | `"sliding"` | Window aggregation mode. `"sliding"` = sliding window (keeps last N records, oldest automatically dropped). `"tumbling"` = tumbling window (clears and restarts after every N records). | +| `percentiles` | str | `"50,90,95,99"` | Comma-separated percentile values to compute. | +| `metrics` | str | `"all"` | Comma-separated metric names to report, or `"all"` for all metrics. | + +### Available Metrics + +Metrics are aligned with `benchmark_serving.py --percentile-metrics`: + +| Metric Name | Description | Unit | +| :---------- | :---------- | :--- | +| `ttft` | Time to First Token (client arrival → first token) | ms | +| `s_ttft` | Server TTFT (inference start → first token) | ms | +| `tpot` | Time per Output Token (excluding first token) | ms | +| `s_itl` | Infer Inter-token Latency | ms | +| `e2el` | End-to-end Latency (client arrival → last token) | ms | +| `s_e2el` | Server E2EL (inference start → last token) | ms | +| `s_decode` | Decode speed (excluding first token) | tok/s | +| `input_len` | Prefix cache hit token count ("Cached Tokens") | tokens | +| `s_input_len` | Infer input length (total prompt tokens) | tokens | +| `output_len` | Output token length per request | tokens | + +In addition, the following throughput metrics are always computed (not user-selectable) when there are 2+ records: + +| Metric | Description | Unit | +| :----- | :---------- | :--- | +| `request_throughput` | Request throughput | req/s | +| `output_throughput` | Output token throughput | tok/s | +| `total_throughput` | Total token throughput (input + output) | tok/s | + +### Window Modes + +**Sliding Window** (`"sliding"`, default): + +The window keeps the most recent N records. When a new record arrives and the window is full, the oldest record is automatically dropped. Each output line reflects the statistics of the latest N requests. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "sliding"}' +``` + +**Tumbling Window** (`"tumbling"`): + +The window accumulates records up to N, then clears and starts fresh. Each output line still reflects the current window's accumulated statistics, but the window resets at every boundary. This is useful for RL training scenarios where each step has a fixed batch size and you want per-step independent analysis. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "tumbling"}' +``` + +**No Window** (`window_size: 0`): + +All completed requests are accumulated. Statistics reflect the entire lifetime of the service. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 0}' +``` + +### Output + +Results are written to `{FD_LOG_DIR}/benchmark_metrics.jsonl` (default: `./log/benchmark_metrics.jsonl`). Each line is a JSON object representing the window statistics at the time of a request completion. + +Example output line: + +```json +{ + "timestamp": "2026-05-14T10:30:05.123", + "window_size": 64, + "window_mode": "sliding", + "completed": 64, + "total_input_tokens": 8192, + "total_output_tokens": 16384, + "request_throughput": 5.2, + "output_throughput": 1250.0, + "total_throughput": 2500.0, + "ttft_ms": {"mean": 45.0, "median": 42.1, "p50": 42.1, "p90": 68.5, "p95": 82.3, "p99": 120.5}, + "s_decode": {"mean": 67.3, "median": 67.5, "p50": 67.5, "p90": 70.1, "p95": 71.2, "p99": 73.0} +} +``` diff --git a/docs/zh/benchmark.md b/docs/zh/benchmark.md index e4a58d93b1e..e0c55c63ef3 100644 --- a/docs/zh/benchmark.md +++ b/docs/zh/benchmark.md @@ -40,3 +40,106 @@ python benchmark_serving.py \ --max-concurrency 1 \ --save-result ``` + +## 进程内性能监控(Benchmark Metrics Logger) + +FastDeploy 提供了内置的进程内性能监控模块,在推理进程内部运行,复用已有的请求时间戳数据,每个请求完成时计算滚动统计并写入 JSONL 文件,可用于实时监控和事后分析。 + +### 启用方式 + +在服务启动命令中添加 `--benchmark-metrics-config` 参数,传入 JSON 配置字符串: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-0.3B-Base-Paddle \ + --benchmark-metrics-config '{"enable": true}' +``` + +### 配置参数 + +| 参数 | 类型 | 默认值 | 说明 | +| :--- | :--- | :----- | :--- | +| `enable` | bool | `false` | 是否启用性能监控。必须设置为 `true` 才会激活。 | +| `window_size` | int | `0` | 统计窗口大小。`0` = 累计模式(统计所有请求);`>0` = 统计最近 N 个请求。 | +| `window_mode` | str | `"sliding"` | 窗口聚合模式。`"sliding"` = 滑动窗口(保持最近 N 条,旧记录自动淘汰);`"tumbling"` = 翻滚窗口(满 N 条后清空重新累积)。 | +| `percentiles` | str | `"50,90,95,99"` | 要计算的分位值,逗号分隔。 | +| `metrics` | str | `"all"` | 要统计的指标子集,逗号分隔,或 `"all"` 表示全部指标。 | + +### 可用指标 + +指标与 `benchmark_serving.py --percentile-metrics` 对齐: + +| 指标名称 | 说明 | 单位 | +| :------- | :--- | :--- | +| `ttft` | 首 Token 时延(客户端到达 → 首 Token) | ms | +| `s_ttft` | 服务端首 Token 时延(推理开始 → 首 Token) | ms | +| `tpot` | 每 Token 输出时延(不含首 Token) | ms | +| `s_itl` | 推理 Token 间时延 | ms | +| `e2el` | 端到端时延(客户端到达 → 最后一个 Token) | ms | +| `s_e2el` | 服务端端到端时延(推理开始 → 最后一个 Token) | ms | +| `s_decode` | 解码速度(不含首 Token) | tok/s | +| `input_len` | 前缀缓存命中 Token 数("Cached Tokens") | tokens | +| `s_input_len` | 推理输入长度(总 prompt token 数) | tokens | +| `output_len` | 输出 Token 长度 | tokens | + +此外,以下吞吐量指标在有 2 个以上请求完成时自动计算(不受 `metrics` 参数控制): + +| 指标 | 说明 | 单位 | +| :--- | :--- | :--- | +| `request_throughput` | 请求吞吐量 | req/s | +| `output_throughput` | 输出 Token 吞吐量 | tok/s | +| `total_throughput` | 总 Token 吞吐量(输入 + 输出) | tok/s | + +### 窗口模式 + +**滑动窗口**(`"sliding"`,默认): + +窗口始终保持最近 N 条记录。当新记录到达且窗口已满时,最旧的记录自动淘汰。每行输出反映最近 N 个请求的统计值。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "sliding"}' +``` + +**翻滚窗口**(`"tumbling"`): + +窗口累积到 N 条后清空重新开始。每行输出反映当前窗口已累积请求的统计值,窗口在边界处重置。适用于 RL 训练场景,每个 step 有固定 batch size,需要逐 step 独立分析。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "tumbling"}' +``` + +**无窗口**(`window_size: 0`): + +所有已完成请求持续累积,统计值反映服务启动以来的全量数据。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 0}' +``` + +### 输出说明 + +结果写入 `{FD_LOG_DIR}/benchmark_metrics.jsonl`(默认路径:`./log/benchmark_metrics.jsonl`)。每行为一个 JSON 对象,表示某个请求完成时刻窗口内的统计快照。 + +输出示例: + +```json +{ + "timestamp": "2026-05-14T10:30:05.123", + "window_size": 64, + "window_mode": "sliding", + "completed": 64, + "total_input_tokens": 8192, + "total_output_tokens": 16384, + "request_throughput": 5.2, + "output_throughput": 1250.0, + "total_throughput": 2500.0, + "ttft_ms": {"mean": 45.0, "median": 42.1, "p50": 42.1, "p90": 68.5, "p95": 82.3, "p99": 120.5}, + "s_decode": {"mean": 67.3, "median": 67.5, "p50": 67.5, "p90": 70.1, "p95": 71.2, "p99": 73.0} +} +``` + +读取最后一行即可获取当前最新的性能快照: + +```bash +tail -1 log/benchmark_metrics.jsonl | python -m json.tool +``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 3559e89799e..f5d37cbc7ff 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1896,6 +1896,65 @@ def __str__(self): return self.to_json_string() +class BenchmarkMetricsConfig: + """Configuration for in-process benchmark metrics logger. + + Args (passed as JSON dict via --benchmark-metrics-config): + enable: Whether to enable the benchmark metrics logger. Default: False. + window_size: Number of recent requests to aggregate. 0 = all requests (cumulative). + window_mode: Window aggregation mode. Default: "sliding". + "sliding" = sliding window (keep last N records), + "tumbling" = tumbling window (clear and restart after every N records). + percentiles: Comma-separated percentile values to compute, e.g. "50,90,95,99". + metrics: Comma-separated metric names to report, or "all". + Available metrics (aligned with benchmark_serving.py --percentile-metrics): + ttft - Time to First Token (client arrival → first token) + s_ttft - Server TTFT (inference start → first token) + tpot - Time per Output Token (excluding first token) + s_itl - Infer Inter-token Latency + e2el - End-to-end Latency (client arrival → last token) + s_e2el - Server E2EL (inference start → last token) + s_decode - Decode speed (tokens/s, excluding first token) + input_len - Prefix cache hit token count ("Cached Tokens" in benchmark_serving) + s_input_len - Infer input length (total prompt tokens on inference side) + output_len - Output token length per request + """ + + _DEFAULTS = { + "enable": False, + "window_size": 0, + "window_mode": "sliding", + "percentiles": "50,90,95,99", + "metrics": "all", + } + + _ALL_METRICS = [ + "ttft", # Time to First Token + "s_ttft", # Server TTFT + "tpot", # Time per Output Token + "s_itl", # Infer Inter-token Latency + "e2el", # End-to-end Latency + "s_e2el", # Server E2EL + "s_decode", # Decode speed (tok/s) + "input_len", # Prefix cache hit tokens (= "Cached Tokens" in benchmark_serving) + "s_input_len", # Infer input length (total prompt tokens) + "output_len", # Output token length + ] + + def __init__(self, args: Optional[dict] = None): + for key, value in self._DEFAULTS.items(): + setattr(self, key, value) + if args: + for key, value in args.items(): + if key in self._DEFAULTS: + setattr(self, key, value) + self.percentile_values = [float(p.strip()) for p in self.percentiles.split(",") if p.strip()] + if self.metrics == "all": + self.selected_metrics = set(self._ALL_METRICS) + else: + self.selected_metrics = {m.strip() for m in self.metrics.split(",") if m.strip()} + + class FDConfig: """ The configuration class which contains all fastdeploy-related configuration. This @@ -1930,6 +1989,7 @@ def __init__( tool_parser: str = None, test_mode=False, routing_replay_config: Optional[RoutingReplayConfig] = None, + benchmark_metrics_config=None, deploy_modality: DeployModality = DeployModality.MIXED, ): self.model_config: ModelConfig = model_config # type: ignore @@ -1947,6 +2007,7 @@ def __init__( self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config self.router_config: RouterConfig = router_config self.routing_replay_config = routing_replay_config + self.benchmark_metrics_config = benchmark_metrics_config self.deploy_modality: DeployModality = deploy_modality # Initialize cuda graph capture list @@ -2395,6 +2456,33 @@ def check(self): " CUDA 12.x → pip install cuda-python==12.*\n" ) + if self.benchmark_metrics_config is not None: + cfg = self.benchmark_metrics_config + assert isinstance( + cfg.enable, bool + ), f"BenchmarkMetricsConfig: 'enable' must be a bool, got {type(cfg.enable).__name__}" + assert ( + isinstance(cfg.window_size, int) and cfg.window_size >= 0 + ), f"BenchmarkMetricsConfig: 'window_size' must be a non-negative integer, got {cfg.window_size!r}" + assert cfg.window_mode in ( + "sliding", + "tumbling", + ), f"BenchmarkMetricsConfig: 'window_mode' must be 'sliding' or 'tumbling', got {cfg.window_mode!r}" + assert ( + isinstance(cfg.percentiles, str) and cfg.percentiles.strip() + ), f"BenchmarkMetricsConfig: 'percentiles' must be a non-empty string, got {cfg.percentiles!r}" + for p in cfg.percentile_values: + assert 0 <= p <= 100, f"BenchmarkMetricsConfig: percentile value {p} out of range [0, 100]" + assert ( + isinstance(cfg.metrics, str) and cfg.metrics.strip() + ), f"BenchmarkMetricsConfig: 'metrics' must be a non-empty string, got {cfg.metrics!r}" + if cfg.metrics != "all": + invalid = cfg.selected_metrics - set(BenchmarkMetricsConfig._ALL_METRICS) + assert not invalid, ( + f"BenchmarkMetricsConfig: unknown metric(s): {invalid}. " + f"Valid metrics: {BenchmarkMetricsConfig._ALL_METRICS}" + ) + def print(self): """ print all config diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index b285f66b42a..da4cc40ca6e 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -188,6 +188,10 @@ class EngineArgs: """ Configuration for speculative execution. """ + benchmark_metrics_config: Optional[Dict[str, Any]] = None + """ + Configuration for in-process benchmark metrics logger. + """ dynamic_load_weight: bool = False """ dynamic load weight @@ -847,6 +851,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.speculative_config, help="Configuration for speculative execution.", ) + model_group.add_argument( + "--benchmark-metrics-config", + type=json.loads, + default=EngineArgs.benchmark_metrics_config, + help="Configuration for in-process benchmark metrics logger. " + "Pass '{}' for defaults or a JSON with keys: " + "window_size (int, 0=all requests), " + "percentiles (str, e.g. '50,90,95,99'), " + "metrics (str, 'all' or comma-separated subset).", + ) model_group.add_argument( "--dynamic-load-weight", action="store_true", @@ -1431,6 +1445,14 @@ def create_speculative_config(self) -> SpeculativeConfig: return SpeculativeConfig(speculative_args) + def create_benchmark_metrics_config(self): + """Create BenchmarkMetricsConfig if --benchmark-metrics-config is provided.""" + if self.benchmark_metrics_config is None: + return None + from fastdeploy.config import BenchmarkMetricsConfig + + return BenchmarkMetricsConfig(self.benchmark_metrics_config) + def create_scheduler_config(self) -> SchedulerConfig: """ Create and return a SchedulerConfig object based on the current settings. @@ -1510,6 +1532,7 @@ def create_engine_config(self) -> FDConfig: self.tensor_parallel_size = model_cfg.tensor_parallel_size speculative_cfg = self.create_speculative_config() + benchmark_metrics_cfg = self.create_benchmark_metrics_config() if not self.enable_chunked_prefill: if (current_platform.is_cuda() or current_platform.is_maca()) and self.splitwise_role == "mixed": # default enable chunked prefill @@ -1574,5 +1597,6 @@ def create_engine_config(self) -> FDConfig: plas_attention_config=plas_attention_config, early_stop_config=early_stop_cfg, routing_replay_config=routing_replay_config, + benchmark_metrics_config=benchmark_metrics_cfg, deploy_modality=DeployModality.from_str(self.deploy_modality), ) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 16f705cf23f..0a49be9e73a 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -201,6 +201,18 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger) + if self.cfg.benchmark_metrics_config is not None and self.cfg.benchmark_metrics_config.enable: + from fastdeploy.metrics.benchmark_metrics_logger import ( + BenchmarkMetricsLogger, + ) + + self.benchmark_metrics_logger = BenchmarkMetricsLogger( + config=self.cfg.benchmark_metrics_config, + log_dir=envs.FD_LOG_DIR, + dp_rank=self.cfg.parallel_config.local_data_parallel_id, + ) + self.token_processor.set_benchmark_logger(self.benchmark_metrics_logger) + self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) for idx in range(1, self.cfg.max_num_partial_prefills + 1): self.partial_chunked_tokens[idx] = ( diff --git a/fastdeploy/metrics/benchmark_metrics_logger.py b/fastdeploy/metrics/benchmark_metrics_logger.py new file mode 100644 index 00000000000..7e381fb3cc6 --- /dev/null +++ b/fastdeploy/metrics/benchmark_metrics_logger.py @@ -0,0 +1,222 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import os +import threading +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import numpy as np + +from fastdeploy.config import BenchmarkMetricsConfig + + +@dataclass(slots=True) +class CompletedRequestRecord: + """Raw timing data collected when a request completes.""" + + request_id: str + completion_time: float + arrival_time: float + inference_start_time: float + first_token_time: float + last_token_time: float + input_len: int + output_len: int + num_cached_tokens: int = 0 + itl_samples: list = field(default_factory=list) + + +class BenchmarkMetricsLogger: + """ + In-process performance monitoring that produces metrics aligned with + benchmark_serving.py. Uses a lock-free deque for data collection and + a background daemon thread for stats computation and file I/O. + """ + + def __init__(self, config: BenchmarkMetricsConfig, log_dir: str, dp_rank: int = 0): + self.config = config + self.enabled = config.enable + self.dp_rank = dp_rank + + if config.window_mode == "sliding" and config.window_size > 0: + self._window: deque = deque(maxlen=config.window_size) + else: + self._window: deque = deque() + + self._pending: deque = deque() + self._condition = threading.Condition() + self._stop_event = threading.Event() + + os.makedirs(log_dir, exist_ok=True) + self._file_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + self._file = open(self._file_path, "a", encoding="utf-8") + + self._thread = threading.Thread( + target=self._writer_loop, + daemon=True, + name=f"BenchmarkMetricsLogger-dp{dp_rank}", + ) + self._thread.start() + + def on_request_completed(self, record: CompletedRequestRecord) -> None: + """Called from token processor on request completion. Lock-free append.""" + self._pending.append(record) + with self._condition: + self._condition.notify() + + def _writer_loop(self) -> None: + """Background thread: wait for new records, compute stats, write JSONL.""" + while not self._stop_event.is_set(): + with self._condition: + self._condition.wait(timeout=1.0) + self._process_pending() + + def _process_pending(self) -> None: + """Process all pending records, write one JSONL line per record.""" + while True: + try: + record = self._pending.popleft() + except IndexError: + break + self._window.append(record) + stats = self._compute_rolling_stats() + line = json.dumps(stats, ensure_ascii=False) + self._file.write(line + "\n") + # Tumbling window: clear after reaching window_size + if ( + self.config.window_mode == "tumbling" + and self.config.window_size > 0 + and len(self._window) >= self.config.window_size + ): + self._window.clear() + self._file.flush() + + def _compute_rolling_stats(self) -> dict: + """Compute aggregate statistics over the current window.""" + records = list(self._window) + n = len(records) + if n == 0: + return {"timestamp": datetime.now().isoformat(), "completed": 0} + + selected = self.config.selected_metrics + percentile_values = self.config.percentile_values + + ttfts = [] + s_ttfts = [] + tpots = [] + all_itls = [] + e2els = [] + s_e2els = [] + decode_speeds = [] + input_lens = [] + s_input_lens = [] + output_lens = [] + + for r in records: + if r.first_token_time and r.arrival_time: + ttfts.append((r.first_token_time - r.arrival_time) * 1000) + if r.first_token_time and r.inference_start_time: + s_ttfts.append((r.first_token_time - r.inference_start_time) * 1000) + if r.output_len > 1 and r.first_token_time and r.arrival_time: + e2el_s = r.last_token_time - r.arrival_time + ttft_s = r.first_token_time - r.arrival_time + tpots.append(((e2el_s - ttft_s) / (r.output_len - 1)) * 1000) + if r.itl_samples: + all_itls.extend([x * 1000 for x in r.itl_samples]) + if r.last_token_time and r.arrival_time: + e2els.append((r.last_token_time - r.arrival_time) * 1000) + if r.last_token_time and r.inference_start_time: + s_e2els.append((r.last_token_time - r.inference_start_time) * 1000) + if r.output_len > 1 and r.first_token_time and r.last_token_time: + decode_time = r.last_token_time - r.first_token_time + if decode_time > 0: + decode_speeds.append((r.output_len - 1) / decode_time) + input_lens.append(r.num_cached_tokens) + s_input_lens.append(r.input_len) + output_lens.append(r.output_len) + + # Throughput: based on window time span + total_input = sum(s_input_lens) + total_output = sum(output_lens) + if n >= 2: + duration = records[-1].completion_time - records[0].arrival_time + else: + duration = 0.0 + + result: dict[str, Any] = { + "timestamp": datetime.now().isoformat(), + "window_size": self.config.window_size, + "window_mode": self.config.window_mode, + "completed": n, + "total_input_tokens": total_input, + "total_output_tokens": total_output, + } + + if duration > 0: + result["request_throughput"] = round(n / duration, 2) + result["output_throughput"] = round(total_output / duration, 2) + result["total_throughput"] = round((total_input + total_output) / duration, 2) + + if "ttft" in selected: + result["ttft_ms"] = self._stats(ttfts, percentile_values) + if "s_ttft" in selected: + result["s_ttft_ms"] = self._stats(s_ttfts, percentile_values) + if "tpot" in selected: + result["tpot_ms"] = self._stats(tpots, percentile_values) + if "s_itl" in selected: + result["s_itl_ms"] = self._stats(all_itls, percentile_values) + if "e2el" in selected: + result["e2el_ms"] = self._stats(e2els, percentile_values) + if "s_e2el" in selected: + result["s_e2el_ms"] = self._stats(s_e2els, percentile_values) + if "s_decode" in selected: + result["s_decode"] = self._stats(decode_speeds, percentile_values) + if "input_len" in selected: + result["input_len"] = self._stats(input_lens, percentile_values) + if "s_input_len" in selected: + result["s_input_len"] = self._stats(s_input_lens, percentile_values) + if "output_len" in selected: + result["output_len"] = self._stats(output_lens, percentile_values) + + return result + + @staticmethod + def _stats(values: list, percentiles: list[float]) -> dict: + """Compute mean/median/percentiles for a list of values.""" + if not values: + return {} + arr = np.array(values) + result = { + "mean": round(float(np.mean(arr)), 2), + "median": round(float(np.median(arr)), 2), + } + for p in percentiles: + key = f"p{int(p)}" if int(p) == p else f"p{p}" + result[key] = round(float(np.percentile(arr, p)), 2) + return result + + def shutdown(self) -> None: + """Stop the writer thread and close the file.""" + self._stop_event.set() + with self._condition: + self._condition.notify() + self._thread.join(timeout=5) + self._process_pending() + self._file.close() diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index a8544cb5979..7505513d7f9 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -69,6 +69,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.cached_generated_tokens = cached_generated_tokens self.resource_manager = None self.scheduler_metrics_logger = None + self._benchmark_logger = None self.engine_worker_queue = engine_worker_queue self.tokens_counter = Counter() self.split_connector = split_connector @@ -236,6 +237,9 @@ def set_resource_manager(self, resource_manager): def set_scheduler_metrics_logger(self, scheduler_metrics_logger): self.scheduler_metrics_logger = scheduler_metrics_logger + def set_benchmark_logger(self, benchmark_logger): + self._benchmark_logger = benchmark_logger + def _is_decode_stage(self, task): if task is None: return False @@ -1120,6 +1124,10 @@ def _record_metrics(self, task, current_time, token_ids): if hasattr(task, "last_token_time") and task.last_token_time is not None: token_gen_time = current_time - task.last_token_time main_process_metrics.time_per_output_token.observe(token_gen_time) + if self._benchmark_logger: + if not hasattr(task, "_itl_samples"): + task._itl_samples = [] + task._itl_samples.append(token_gen_time) task.last_token_time = current_time # Record generation metrics @@ -1155,6 +1163,25 @@ def _record_completion_metrics(self, task, current_time): main_process_metrics.request_inference_time.observe(current_time - metrics.inference_start_time) main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id]) + if self._benchmark_logger: + from fastdeploy.metrics.benchmark_metrics_logger import ( + CompletedRequestRecord, + ) + + record = CompletedRequestRecord( + request_id=task.request_id, + completion_time=current_time, + arrival_time=metrics.arrival_time or 0.0, + inference_start_time=metrics.inference_start_time or 0.0, + first_token_time=metrics.engine_recv_first_token_time or 0.0, + last_token_time=metrics.engine_recv_latest_token_time or current_time, + input_len=getattr(task, "prompt_token_ids_len", 0) or 0, + output_len=self.tokens_counter[task.request_id], + num_cached_tokens=getattr(task, "num_cached_tokens", 0) or 0, + itl_samples=getattr(task, "_itl_samples", []), + ) + self._benchmark_logger.on_request_completed(record) + def _record_speculative_decoding_metrics(self, accept_num): """Record metrics of speculative decoding""" if not hasattr(main_process_metrics, "spec_decode_draft_acceptance_rate"): diff --git a/tests/metrics/test_benchmark_metrics_logger.py b/tests/metrics/test_benchmark_metrics_logger.py new file mode 100644 index 00000000000..4f291327ad0 --- /dev/null +++ b/tests/metrics/test_benchmark_metrics_logger.py @@ -0,0 +1,499 @@ +""" +Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +import os +import time +from unittest.mock import MagicMock, patch + +import pytest + +from fastdeploy.config import BenchmarkMetricsConfig, FDConfig +from fastdeploy.metrics.benchmark_metrics_logger import ( + BenchmarkMetricsLogger, + CompletedRequestRecord, +) + + +def _make_record(request_id, now, offset, input_len=100, output_len=50): + return CompletedRequestRecord( + request_id=request_id, + completion_time=now + offset, + arrival_time=now + offset - 0.05, + inference_start_time=now + offset - 0.04, + first_token_time=now + offset - 0.02, + last_token_time=now + offset, + input_len=input_len, + output_len=output_len, + itl_samples=[0.02, 0.021, 0.019], + ) + + +def test_config_defaults(): + config = BenchmarkMetricsConfig(None) + assert config.enable is False + assert config.window_size == 0 + assert config.window_mode == "sliding" + assert config.percentile_values == [50.0, 90.0, 95.0, 99.0] + assert config.selected_metrics == set(BenchmarkMetricsConfig._ALL_METRICS) + + +def test_config_custom(): + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 200, "window_mode": "tumbling", "percentiles": "50,99", "metrics": "ttft,e2el"} + ) + assert config.enable is True + assert config.window_size == 200 + assert config.window_mode == "tumbling" + assert config.percentile_values == [50.0, 99.0] + assert config.selected_metrics == {"ttft", "e2el"} + + +def test_config_empty_dict(): + config = BenchmarkMetricsConfig({}) + assert config.enable is False + assert config.window_size == 0 + assert config.window_mode == "sliding" + assert config.percentile_values == [50.0, 90.0, 95.0, 99.0] + + +def test_config_enable_only(): + config = BenchmarkMetricsConfig({"enable": True}) + assert config.enable is True + assert config.window_mode == "sliding" + + +def test_logger_writes_jsonl(tmp_path): + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "percentiles": "50,99", "metrics": "ttft,e2el"}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i * 0.1)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + assert os.path.exists(jsonl_path) + + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + last_record = json.loads(lines[-1]) + assert last_record["completed"] == 5 + assert "ttft_ms" in last_record + assert "e2el_ms" in last_record + assert "tpot_ms" not in last_record + assert last_record["ttft_ms"]["mean"] > 0 + + +def test_logger_sliding_window(tmp_path): + """Sliding window: keeps the last N records, never clears.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 3, "window_mode": "sliding", "percentiles": "50", "metrics": "all"} + ) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + # After 5 records with window_size=3, the window always has at most 3 + rec3 = json.loads(lines[2]) # 3rd record: window full (3 records) + assert rec3["completed"] == 3 + + rec4 = json.loads(lines[3]) # 4th record: still 3 (oldest dropped) + assert rec4["completed"] == 3 + + last_record = json.loads(lines[-1]) + assert last_record["completed"] == 3 + assert last_record["window_size"] == 3 + assert last_record["window_mode"] == "sliding" + + +def test_logger_tumbling_window(tmp_path): + """Tumbling window: clears after reaching window_size, then restarts.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 3, "window_mode": "tumbling", "percentiles": "50", "metrics": "all"} + ) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + # Records 1,2,3 accumulate then clear; records 4,5 start fresh + rec1 = json.loads(lines[0]) + assert rec1["completed"] == 1 + + rec3 = json.loads(lines[2]) # 3rd record: window full (3 records), then clears + assert rec3["completed"] == 3 + + rec4 = json.loads(lines[3]) # 4th record: window restarted, 1 record + assert rec4["completed"] == 1 + + rec5 = json.loads(lines[4]) # 5th record: 2 records in new window + assert rec5["completed"] == 2 + assert rec5["window_mode"] == "tumbling" + + +def test_logger_no_output_when_no_requests(tmp_path): + config = BenchmarkMetricsConfig({"enable": True}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + time.sleep(0.3) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + assert os.path.exists(jsonl_path) + with open(jsonl_path) as f: + content = f.read() + assert content == "" + + +def test_logger_enabled_flag(tmp_path): + """Logger with enable=False should have enabled=False.""" + config = BenchmarkMetricsConfig({"enable": False}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + assert logger.enabled is False + logger.shutdown() + + +def test_logger_enabled_true(tmp_path): + """Logger with enable=True should have enabled=True.""" + config = BenchmarkMetricsConfig({"enable": True}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + assert logger.enabled is True + logger.shutdown() + + +def test_stats_computation(): + stats = BenchmarkMetricsLogger._stats([10.0, 20.0, 30.0, 40.0, 50.0], [50.0, 99.0]) + assert stats["mean"] == 30.0 + assert stats["median"] == 30.0 + assert "p50" in stats + assert "p99" in stats + assert stats["p50"] == 30.0 + + +def test_stats_empty_list(): + stats = BenchmarkMetricsLogger._stats([], [50.0]) + assert stats == {} + + +def test_throughput_in_output(tmp_path): + """Throughput fields should appear when there are 2+ records.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "percentiles": "50", "metrics": "ttft"}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(3): + logger.on_request_completed(_make_record(f"req-{i}", now, i * 0.5)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + # First record has no throughput (only 1 sample, duration=0) + rec1 = json.loads(lines[0]) + assert "request_throughput" not in rec1 + + # Last record should have throughput + last = json.loads(lines[-1]) + assert "request_throughput" in last + assert "output_throughput" in last + assert "total_throughput" in last + assert last["request_throughput"] > 0 + + +# ============================================================ +# Validation tests (via FDConfig.check()) +# ============================================================ + + +def _make_fd_config_with_benchmark(benchmark_cfg): + """Create a mock FDConfig with valid base attributes, only benchmark_metrics_config is real.""" + cfg = object.__new__(FDConfig) + # Mock all attributes accessed by check() before benchmark validation + cfg.scheduler_config = MagicMock() + cfg.scheduler_config.max_num_seqs = 128 + cfg.scheduler_config.max_num_batched_tokens = 8192 + cfg.scheduler_config.splitwise_role = "mixed" + cfg.scheduler_config.check = MagicMock() + cfg.model_config = MagicMock() + cfg.model_config.max_model_len = 8192 + cfg.cache_config = MagicMock() + cfg.cache_config.enable_chunked_prefill = True + cfg.cache_config.block_size = 64 + cfg.speculative_config = None + cfg.eplb_config = None + cfg.structured_outputs_config = None + cfg.graph_opt_config = MagicMock() + cfg.graph_opt_config.graph_opt_level = 0 + cfg.nnode = 1 + cfg.max_num_partial_prefills = 1 + cfg.max_long_partial_prefills = 1 + cfg.long_prefill_token_threshold = 0 + cfg.benchmark_metrics_config = benchmark_cfg + return cfg + + +@patch("fastdeploy.config.envs") +def test_valid_config_passes_check(mock_envs): + """Valid configs should pass FDConfig.check() without errors.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + configs = [ + {"enable": True}, + {"enable": True, "window_size": 64, "window_mode": "tumbling"}, + {"enable": False, "window_size": 0, "window_mode": "sliding"}, + {"enable": True, "percentiles": "50,90,99", "metrics": "ttft,e2el,s_decode"}, + ] + for args in configs: + benchmark_cfg = BenchmarkMetricsConfig(args) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + fd_cfg.check() # Should not raise + + +@patch("fastdeploy.config.envs") +def test_invalid_enable(mock_envs): + """enable must be a bool.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": "true"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'enable' must be a bool"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_size_negative(mock_envs): + """window_size must be non-negative.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_size": -1}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_size' must be a non-negative integer"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_size_type(mock_envs): + """window_size must be an integer.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_size": 3.5}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_size' must be a non-negative integer"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_mode(mock_envs): + """window_mode must be 'sliding' or 'tumbling'.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_mode": "fixed"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_mode' must be 'sliding' or 'tumbling'"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_percentile_out_of_range(mock_envs): + """Percentile values must be in [0, 100].""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "percentiles": "50,101"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="percentile value .* out of range"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_percentile_negative(mock_envs): + """Percentile values must be >= 0.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "percentiles": "-1,50"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="percentile value .* out of range"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_metrics_unknown(mock_envs): + """Unknown metric names should fail validation.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "metrics": "ttft,unknown_metric"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="unknown metric"): + fd_cfg.check() + + +# ============================================================ +# Direct method tests (bypass daemon thread for coverage) +# ============================================================ + + +def test_process_pending_direct(tmp_path): + """Directly call _process_pending to cover lines 98-109.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,99"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + # Add records directly to _pending without relying on background thread + for i in range(3): + logger._pending.append(_make_record(f"req-{i}", now, i * 0.5)) + + # Call _process_pending directly from main thread (coverage-tracked) + logger._process_pending() + + assert len(logger._window) == 3 + logger.shutdown() + + jsonl_path = os.path.join(str(tmp_path), "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + assert len(lines) == 3 + rec = json.loads(lines[-1]) + assert rec["completed"] == 3 + assert "ttft_ms" in rec + assert "tpot_ms" in rec + assert "e2el_ms" in rec + assert "s_ttft_ms" in rec + assert "s_e2el_ms" in rec + assert "s_decode" in rec + assert "input_len" in rec + assert "s_input_len" in rec + assert "output_len" in rec + assert "request_throughput" in rec + assert "output_throughput" in rec + assert "total_throughput" in rec + + +def test_process_pending_tumbling_clear(tmp_path): + """Tumbling window clears after reaching window_size via direct call.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 2, "window_mode": "tumbling", "metrics": "ttft", "percentiles": "50"} + ) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + for i in range(3): + logger._pending.append(_make_record(f"req-{i}", now, i * 0.5)) + + logger._process_pending() + + # After 3 records with window_size=2: first 2 fill window then clear, 3rd starts fresh + assert len(logger._window) == 1 + logger.shutdown() + + +def test_compute_rolling_stats_empty_window(tmp_path): + """_compute_rolling_stats with empty window returns minimal result.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + result = logger._compute_rolling_stats() + assert result["completed"] == 0 + logger.shutdown() + + +def test_compute_rolling_stats_single_record(tmp_path): + """Single record: no throughput, no tpot (needs output_len > 1 check).""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,99"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + # output_len=1 means tpot and decode_speed won't be computed + logger._window.append( + CompletedRequestRecord( + request_id="r1", + completion_time=now, + arrival_time=now - 0.05, + inference_start_time=now - 0.04, + first_token_time=now - 0.02, + last_token_time=now, + input_len=100, + output_len=1, + itl_samples=[], + ) + ) + + result = logger._compute_rolling_stats() + assert result["completed"] == 1 + assert "request_throughput" not in result # duration=0 for single record + assert result["ttft_ms"]["mean"] > 0 + assert result["tpot_ms"] == {} # no tpot with output_len=1 + assert result["s_itl_ms"] == {} # no itl samples + logger.shutdown() + + +def test_compute_rolling_stats_multiple_records(tmp_path): + """Multiple records: throughput and all metrics computed.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,95"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + for i in range(3): + logger._window.append(_make_record(f"req-{i}", now, i * 0.5)) + + result = logger._compute_rolling_stats() + assert result["completed"] == 3 + assert result["request_throughput"] > 0 + assert result["output_throughput"] > 0 + assert result["total_throughput"] > 0 + assert result["ttft_ms"]["mean"] > 0 + assert result["s_ttft_ms"]["mean"] > 0 + assert result["tpot_ms"]["mean"] > 0 + assert result["s_itl_ms"]["mean"] > 0 + assert result["e2el_ms"]["mean"] > 0 + assert result["s_e2el_ms"]["mean"] > 0 + assert result["s_decode"]["mean"] > 0 + assert "p50" in result["ttft_ms"] + assert "p95" in result["ttft_ms"] + logger.shutdown() + + +def test_stats_with_float_percentile(): + """Percentile key uses float format when not integer.""" + stats = BenchmarkMetricsLogger._stats([1.0, 2.0, 3.0], [99.9]) + assert "p99.9" in stats diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index c84514c06d5..2853e47be15 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -169,6 +169,7 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): processor.accept_token_num_per_head_per_request = {} processor.accept_token_num_per_head = [0] * MAX_DRAFT_TOKENS processor.use_sampling_mask = False + processor._benchmark_logger = None # processor._recycle_resources = Mock() From 5d1898447db4f2462bd4462dc840811f14c3a0bd Mon Sep 17 00:00:00 2001 From: kevin Date: Mon, 25 May 2026 10:43:46 +0800 Subject: [PATCH 124/143] fix(kvcache): buffer early layer0 signals (#7896) --- fastdeploy/cache_manager/cache_messager.py | 70 ++++++--- tests/cache_manager/test_cache_messager.py | 167 ++++++++++++++++++++- 2 files changed, 213 insertions(+), 24 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 08c8dea003a..4a188edd161 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -620,6 +620,8 @@ def __init__( dict() for _ in range(512) ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} + self.pending_layer0_signals = {} + self.pending_layer0_signal_lock = threading.Lock() self.cache_prefilled_engine_ids_queue = ( queue.Queue() ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] @@ -663,7 +665,28 @@ def _add_cache_task_thread(self): current_info["status"] = "init" logger.info(f"Get cache info and finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info - self.idx_cache_task_dict[current_info["current_id"]] = current_info + current_id = current_info["current_id"] + with self.engine_cache_task_thread_lock: + self.idx_cache_task_dict[current_id] = current_info + with self.pending_layer0_signal_lock: + recovered_signal = self.pending_layer0_signals.pop(current_id, None) + if recovered_signal is not None: + _, prefilled_token_num = recovered_signal + if prefilled_token_num <= current_info["need_prefill_tokens"]: + recovered_signal_batch = [recovered_signal] + logger.info( + "cache_task_register_recover_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal_batch: {recovered_signal_batch}" + ) + self.cache_prefilled_engine_ids_queue.put(recovered_signal_batch) + else: + logger.info( + "cache_task_register_drop_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal: {recovered_signal}, " + f"need_prefill_tokens: {current_info['need_prefill_tokens']}" + ) else: logger.info(f"Get cache info: {info}") self.cache_info[info["request_id"]] = info @@ -842,9 +865,12 @@ def prefill_layerwise_send_cache_thread(self): logger.info( f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}" ) - self.engine_cache_tasks[task["current_id"]] = dict() + current_id = task["current_id"] + self.engine_cache_tasks[current_id] = dict() del self.cache_info[task["request_id"]] - del self.idx_cache_task_dict[task["current_id"]] + del self.idx_cache_task_dict[current_id] + with self.pending_layer0_signal_lock: + self.pending_layer0_signals.pop(current_id, None) break except Exception as e: logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}") @@ -856,32 +882,42 @@ def consume_signals(self): while True: try: get_output_kv_signal(kv_signal_data, self.rank_id, 1) # wait_flag - if not self.cache_info: - time.sleep(0.01) - continue - tasks_count = kv_signal_data[0] + has_cache_info = bool(self.cache_info) + tasks_count = kv_signal_data[0].item() if tasks_count == -1: continue + if not has_cache_info: + logger.debug("consume_signals get kv signal before cache info is ready") layer_id = kv_signal_data[1].item() if layer_id == self.num_layers - 1: logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}") - batch_engine_signals = [] + ready_engine_signals = [] + pending_engine_signals = [] # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)] with self.engine_cache_task_thread_lock: for bi in range(tasks_count): engine_idx = kv_signal_data[3 * bi + 2].item() chuck_token_offset = kv_signal_data[3 * bi + 3].item() current_seq_len = kv_signal_data[3 * bi + 4].item() + prefilled_token_num = chuck_token_offset + current_seq_len self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id - self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = ( - chuck_token_offset + current_seq_len - ) - batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len)) - if layer_id == 0: - logger.info( - f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue" - ) - self.cache_prefilled_engine_ids_queue.put(batch_engine_signals) + self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num + if layer_id == 0: + if engine_idx in self.idx_cache_task_dict: + ready_engine_signals.append((engine_idx, prefilled_token_num)) + else: + pending_engine_signals.append((engine_idx, prefilled_token_num)) + if pending_engine_signals: + with self.pending_layer0_signal_lock: + for engine_idx, prefilled_token_num in pending_engine_signals: + self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num) + if pending_engine_signals: + logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}") + if ready_engine_signals: + logger.info( + f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue" + ) + self.cache_prefilled_engine_ids_queue.put(ready_engine_signals) except Exception as e: logger.error(f"Consume signals get exception: {e}") diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index 3e415ebe9c8..07ff5054f2a 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -124,6 +124,14 @@ def error(self, msg): self.messages.append(("error", msg)) +class _QueueRecorder: + def __init__(self): + self.items = [] + + def put(self, item): + self.items.append(item) + + class _DummySignalValue: def __init__(self, sequence): self.sequence = list(sequence) @@ -390,6 +398,111 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): assert messager.cache_info["req-2"]["status"] == "init" +def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 64) + messager.pending_layer0_signals[4] = (4, 64) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {4: (4, 64)} + assert messager.cache_prefilled_engine_ids_queue.items == [[(3, 64)]] + + +def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 256) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch): class _OneShotQueue: def __init__(self): @@ -435,10 +548,12 @@ def get(self): } messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64} messager.cache_info["req-3"] = messager.idx_cache_task_dict[0] + messager.pending_layer0_signals = {0: (0, 64), 1: (1, 64)} with pytest.raises(SystemExit): messager.prefill_layerwise_send_cache_thread() assert dummy_queue.finished_req_payloads assert dummy_queue.finished_req_payloads[0][0][0] == "req-3" + assert messager.pending_layer0_signals == {1: (1, 64)} def test_cache_messager_v1_handle_connect_task(monkeypatch): @@ -562,13 +677,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch): monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) - class _QueueRecorder: - def __init__(self): - self.items = [] - - def put(self, item): - self.items.append(item) - counter = {"calls": 0} def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): @@ -600,12 +708,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): rdma_port="2222", ) messager.cache_info["req-4"] = {"request_id": "req-4"} + messager.idx_cache_task_dict[2] = {"request_id": "req-4", "current_id": 2} messager.cache_prefilled_engine_ids_queue = _QueueRecorder() with pytest.raises(SystemExit): messager.consume_signals() assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]] +def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch): + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + signals = [(5, 7, 9), (5, 17, 19)] + + def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): + if not signals: + raise SystemExit + engine_idx, chuck_token_offset, current_seq_len = signals.pop(0) + data = np.full(kv_signal_data.shape, -1, dtype="int32") + data[0] = 1 + data[1] = 0 + data[2] = engine_idx + data[3] = chuck_token_offset + data[4] = current_seq_len + kv_signal_data.set_value(data) + + monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal) + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + + with pytest.raises(SystemExit): + messager.consume_signals() + + assert messager.pending_layer0_signals == {5: (5, 36)} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_main_initializes_cache_and_exits(monkeypatch): monkeypatch.setattr(cache_messager, "set_device", lambda device: None) monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None) From 3ffeb445d3ce511ba06afbb621978ef356471450 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Mon, 25 May 2026 13:34:30 +0800 Subject: [PATCH 125/143] [Cherry-Pick][CI] Restore self-hosted runners for GitHub workflows(#7906) (#7909) --- .github/workflows/CheckPRTemplate.yml | 3 +- .github/workflows/Codestyle-Check.yml | 3 +- .github/workflows/_unit_test_coverage.yml | 6 ++- .github/workflows/approve.yml | 6 ++- .github/workflows/cancel_ci_iluvatar.yml | 6 ++- .github/workflows/cancel_ci_xpu.yml | 6 ++- .../workflows/cancel_pr_build_and_test.yml | 6 ++- .github/workflows/ce_job.yml | 42 +++++++++++++--- .github/workflows/check-bypass.yml | 7 ++- .github/workflows/cherry-pick.yml | 6 ++- .github/workflows/ci_image_update.yml | 12 ++++- .github/workflows/ci_metax.yml | 33 ------------- .github/workflows/gh-pages.yml | 6 ++- .github/workflows/pr_build_and_test.yml | 6 ++- .github/workflows/publish_job.yml | 48 +++++++++++++++---- .github/workflows/remove-skip-ci-labels.yml | 6 ++- .github/workflows/rerun.yml | 3 +- 17 files changed, 141 insertions(+), 64 deletions(-) delete mode 100644 .github/workflows/ci_metax.yml diff --git a/.github/workflows/CheckPRTemplate.yml b/.github/workflows/CheckPRTemplate.yml index e5b3dcd3ad9..ba1cd676a03 100644 --- a/.github/workflows/CheckPRTemplate.yml +++ b/.github/workflows/CheckPRTemplate.yml @@ -10,7 +10,8 @@ jobs: check: name: Check PR Template if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BASE_BRANCH: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/Codestyle-Check.yml b/.github/workflows/Codestyle-Check.yml index 6811e3fb38d..0470068d417 100644 --- a/.github/workflows/Codestyle-Check.yml +++ b/.github/workflows/Codestyle-Check.yml @@ -10,7 +10,8 @@ jobs: pre-commit: name: Pre Commit if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BRANCH: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index a2b72eda854..1cb1ca41213 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -416,11 +416,15 @@ jobs: diff_coverage_report: needs: run_tests_with_coverage if: always() - runs-on: ubuntu-latest + runs-on: + group: APPROVAL timeout-minutes: 15 env: all_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.all_cov_file_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: diff --git a/.github/workflows/approve.yml b/.github/workflows/approve.yml index 6de30d6f564..39b1844da74 100644 --- a/.github/workflows/approve.yml +++ b/.github/workflows/approve.yml @@ -13,11 +13,15 @@ jobs: Approval: name: Approval if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BRANCH: ${{ github.event.pull_request.base.ref }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Checkout base repo uses: actions/checkout@v6 with: diff --git a/.github/workflows/cancel_ci_iluvatar.yml b/.github/workflows/cancel_ci_iluvatar.yml index 9dba9a7d1e0..1bb5ae247d4 100644 --- a/.github/workflows/cancel_ci_iluvatar.yml +++ b/.github/workflows/cancel_ci_iluvatar.yml @@ -13,8 +13,12 @@ concurrency: jobs: cancel: name: Cancel ILUVATAR-CI for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel ILUVATAR-CI run: | exit 0 diff --git a/.github/workflows/cancel_ci_xpu.yml b/.github/workflows/cancel_ci_xpu.yml index befd59796e9..dab6c9ce79a 100644 --- a/.github/workflows/cancel_ci_xpu.yml +++ b/.github/workflows/cancel_ci_xpu.yml @@ -13,8 +13,12 @@ concurrency: jobs: cancel: name: Cancel CI_XPU for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel CI_XPU run: | exit 0 diff --git a/.github/workflows/cancel_pr_build_and_test.yml b/.github/workflows/cancel_pr_build_and_test.yml index bb488a529ea..0cc0f3d0671 100644 --- a/.github/workflows/cancel_pr_build_and_test.yml +++ b/.github/workflows/cancel_pr_build_and_test.yml @@ -12,8 +12,12 @@ concurrency: jobs: cancel: name: Cancel PR Build and Test for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel PR Build and Test run: | exit 0 diff --git a/.github/workflows/ce_job.yml b/.github/workflows/ce_job.yml index 30775a455a6..c9ce4400ae3 100644 --- a/.github/workflows/ce_job.yml +++ b/.github/workflows/ce_job.yml @@ -14,7 +14,8 @@ concurrency: jobs: ce_job_pre_check: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }} CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }} @@ -26,6 +27,9 @@ jobs: sm8090_match: ${{ steps.set_output.outputs.sm8090_match }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Set Version id: set_output env: @@ -78,9 +82,13 @@ jobs: done print_ce_job_pre_check_outputs: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: ce_job_pre_check steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print outputs as JSON run: | echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}' @@ -89,12 +97,16 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: ce_job_pre_check if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }} outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -154,8 +166,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print repo_archive_url path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" @@ -207,13 +223,17 @@ jobs: environment: CodeSync name: CE_UPLOAD needs: build_sm8090 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} COMPILE_ARCH: "80,90" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -257,13 +277,17 @@ jobs: environment: CodeSync name: CE_UPLOAD_RL needs: build_sm8090_rl - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }} COMPILE_ARCH: "80,90" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -303,13 +327,17 @@ jobs: environment: CodeSync name: CE_UPLOAD needs: build_sm8689 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }} COMPILE_ARCH: "86,89" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/check-bypass.yml b/.github/workflows/check-bypass.yml index c9256e7a6cf..a799bbe3a41 100644 --- a/.github/workflows/check-bypass.yml +++ b/.github/workflows/check-bypass.yml @@ -18,7 +18,8 @@ on: jobs: check-bypass: name: Check bypass - runs-on: ubuntu-latest + runs-on: + group: APPROVAL permissions: contents: read env: @@ -64,7 +65,9 @@ jobs: exit 0 fi - files=$(gh pr view ${{ github.event.pull_request.number }} --repo ${{ github.repository }} --json files --jq '.files[].path') + files=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/files?per_page=100" \ + | jq -r '.[].filename') echo "$files" can_skip_docs=true diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml index c6e1bad992e..407acbea687 100644 --- a/.github/workflows/cherry-pick.yml +++ b/.github/workflows/cherry-pick.yml @@ -22,8 +22,12 @@ jobs: github.event.action == 'labeled' || contains(join(github.event.pull_request.labels.*.name, ' '), 'cherry-pick') ) - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Checkout uses: actions/checkout@v6 with: diff --git a/.github/workflows/ci_image_update.yml b/.github/workflows/ci_image_update.yml index 762cad91023..ae6b1b5d0e8 100644 --- a/.github/workflows/ci_image_update.yml +++ b/.github/workflows/ci_image_update.yml @@ -16,10 +16,14 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -64,8 +68,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" diff --git a/.github/workflows/ci_metax.yml b/.github/workflows/ci_metax.yml deleted file mode 100644 index c983ae38590..00000000000 --- a/.github/workflows/ci_metax.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: CI_METAX - -on: - pull_request_target: - types: - - opened - - synchronize - branches: - - never-trigger-this - -permissions: - contents: read - -concurrency: - group: jenkins-pr-${{ github.event.pull_request.number }} - cancel-in-progress: true - -jobs: - trigger-jenkins: - name: Trigger Jenkins for PR - runs-on: ubuntu-latest - environment: Metax_ci - - steps: - - name: Trigger Jenkins job - timeout-minutes: 120 - uses: MetaX-MACA/simple-jenkins-githubaction@v1.1 - with: - job_name: paddle_fastdeploy_metax_smoketest - username: fastdeploy_builder - api_token: ${{ secrets.METAX_JENKINS_API_TOKEN }} - pr_number: ${{ github.event.pull_request.number }} - project_branch: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 6c06ed0a6aa..17a64cf1d88 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -9,8 +9,12 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: diff --git a/.github/workflows/pr_build_and_test.yml b/.github/workflows/pr_build_and_test.yml index 9ffcd75ee5c..bbad1ee939c 100644 --- a/.github/workflows/pr_build_and_test.yml +++ b/.github/workflows/pr_build_and_test.yml @@ -32,8 +32,12 @@ jobs: resultshow: name: Use Build Output needs: build - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}" diff --git a/.github/workflows/publish_job.yml b/.github/workflows/publish_job.yml index 9207d58a497..e5b98392665 100644 --- a/.github/workflows/publish_job.yml +++ b/.github/workflows/publish_job.yml @@ -19,7 +19,8 @@ concurrency: jobs: publish_pre_check: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL if: | github.event.repository.fork == false && ( @@ -40,6 +41,9 @@ jobs: compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Get tag version if: github.ref_type == 'tag' run: | @@ -108,9 +112,13 @@ jobs: echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT print_publish_pre_check_outputs: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: publish_pre_check steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print outputs as JSON run: | echo '${{ toJSON(needs.publish_pre_check.outputs) }}' @@ -118,12 +126,16 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: publish_pre_check if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }} outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -168,8 +180,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" @@ -235,12 +251,16 @@ jobs: environment: CodeSync name: CE_UPLOAD_FD_ROUTER needs: build_fd_router - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FD_ROUTER_URL: ${{ needs.build_fd_router.outputs.fd_router_path }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -291,12 +311,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu126 needs: build_cu126 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -323,12 +347,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu129 needs: build_cu129 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu129.outputs.wheel_path_cu129 }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -355,12 +383,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu130 needs: build_cu130 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu130.outputs.wheel_path_cu130 }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/remove-skip-ci-labels.yml b/.github/workflows/remove-skip-ci-labels.yml index 978f70ea240..aace7ae5af6 100644 --- a/.github/workflows/remove-skip-ci-labels.yml +++ b/.github/workflows/remove-skip-ci-labels.yml @@ -10,8 +10,12 @@ permissions: jobs: remove-skip-ci-labels: name: Remove skip-ci labels on new commits - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Get PR labels id: get-labels uses: actions/github-script@v8 diff --git a/.github/workflows/rerun.yml b/.github/workflows/rerun.yml index bbc96edd37e..6527ccf0679 100644 --- a/.github/workflows/rerun.yml +++ b/.github/workflows/rerun.yml @@ -7,7 +7,8 @@ on: jobs: re-run: if: ${{ github.event.issue.pull_request && contains(github.event.comment.body, '/re-run') && github.event.comment.user.login == github.event.issue.user.login }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: - name: Cleanup run: | From 85399db2eb7827955d2eb9a3e3626e6379542e3e Mon Sep 17 00:00:00 2001 From: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> Date: Mon, 25 May 2026 16:10:54 +0800 Subject: [PATCH 126/143] [Cherry-pick][XPU][CI] fix logs update bug (#7915) * Update _xpu_4cards_case_test.yml * Update _xpu_8cards_case_test.yml --- .github/workflows/_xpu_4cards_case_test.yml | 1 + .github/workflows/_xpu_8cards_case_test.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/_xpu_4cards_case_test.yml b/.github/workflows/_xpu_4cards_case_test.yml index a60f9f3aa33..0548ea2afcb 100644 --- a/.github/workflows/_xpu_4cards_case_test.yml +++ b/.github/workflows/_xpu_4cards_case_test.yml @@ -213,6 +213,7 @@ jobs: - name: Upload case logs if: always() + continue-on-error: true uses: actions/upload-artifact@v6 with: name: xpu-4cards-case-logs diff --git a/.github/workflows/_xpu_8cards_case_test.yml b/.github/workflows/_xpu_8cards_case_test.yml index de746b05050..a0afceab1ad 100644 --- a/.github/workflows/_xpu_8cards_case_test.yml +++ b/.github/workflows/_xpu_8cards_case_test.yml @@ -201,6 +201,7 @@ jobs: - name: Upload case logs if: always() + continue-on-error: true uses: actions/upload-artifact@v6 with: name: xpu-8cards-case-logs From e7a02e217599ea0791274676a17faf338609c693 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Mon, 25 May 2026 16:37:06 +0800 Subject: [PATCH 127/143] supoort glm yarn rope (#7894) --- .../model_executor/layers/rotary_embedding.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index dd77cf2bc0d..6c7286e2606 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -263,7 +263,7 @@ def forward( return query, key -class GptOssScalingRotaryEmbedding: +class YarnScalingRotaryEmbedding: def __init__( self, rotary_dim, @@ -340,10 +340,29 @@ def get_rope_impl( rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) elif architecture.startswith("Glm"): - rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor) + rope_scaling = getattr(model_config, "rope_scaling", None) + if ( + rope_scaling is not None + and isinstance(rope_scaling, dict) + and rope_scaling.get("rope_type", rope_scaling.get("type", "")) == "yarn" + and "factor" in rope_scaling + ): + yarn_rotary_dim = int(rotary_dim * partial_rotary_factor) if partial_rotary_factor < 1.0 else rotary_dim + rotary_emb_layer = YarnScalingRotaryEmbedding( + rotary_dim=yarn_rotary_dim, + base=base, + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + scale=rope_scaling["factor"], + mscale=rope_scaling.get("mscale", 1.0), + beta_fast=rope_scaling.get("beta_fast", 32), + beta_slow=rope_scaling.get("beta_slow", 1), + use_neox_rotary_style=False, + ) + else: + rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) elif architecture.startswith("GptOss"): - rotary_emb_layer = GptOssScalingRotaryEmbedding( + rotary_emb_layer = YarnScalingRotaryEmbedding( rotary_dim=model_config.head_dim, base=model_config.rope_theta, original_max_position_embeddings=model_config.rope_scaling["original_max_position_embeddings"], From 0a5d4b65d4dfd86a745b5b2b825112ff09c7d25b Mon Sep 17 00:00:00 2001 From: zccjjj <62829461+zccjjj@users.noreply.github.com> Date: Tue, 26 May 2026 10:57:50 +0800 Subject: [PATCH 128/143] [bugfix] AS block leaks (#7895) Co-authored-by: kevin --- fastdeploy/engine/sched/resource_manager_v1.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 2c63d7b70df..04aceda7806 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1063,9 +1063,12 @@ def _allocate_decode_and_extend(): self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend ): - if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size + ) + if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) @@ -1124,6 +1127,7 @@ def _allocate_decode_and_extend(): self.req_dict[request.request_id] = allocated_position llm_logger.debug(f"req_id:{request.request_id} allocate pos end") else: + # Warning: _free_blocks before update_cache_blocks may cause storage blocks leak if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break @@ -1139,9 +1143,12 @@ def _allocate_decode_and_extend(): self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend ): - if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size + ) + if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) @@ -1186,6 +1193,7 @@ def _allocate_decode_and_extend(): ) request.status = RequestStatus.RUNNING_PREFILL else: + # Warning: _free_blocks before update_cache_blocks may cause storage blocks leak if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break From bf0daced98bda2f00c35cef84a44484586de061b Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Tue, 26 May 2026 16:18:14 +0800 Subject: [PATCH 129/143] [Scheduler] Increase sleep interval in fetch loops and cancel schedule threashold for prefill instance (#7871) --- fastdeploy/engine/common_engine_prepare_mixin.py | 4 ++-- fastdeploy/engine/sched/resource_manager_v1.py | 10 ++++++++-- fastdeploy/envs.py | 2 +- fastdeploy/output/token_processor.py | 2 +- fastdeploy/splitwise/splitwise_connector.py | 2 +- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py index 71327025458..60ccb7ccd09 100644 --- a/fastdeploy/engine/common_engine_prepare_mixin.py +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -248,10 +248,10 @@ def _fetch_loop(self, fetch_fn, thread_idx: int): with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) fetch_fn() - time.sleep(0.002) + time.sleep(0.02) except Exception as e: self.llm_logger.error(f"fetching request error in worker-{thread_idx}: {e} {traceback.format_exc()}") - time.sleep(0.002) + time.sleep(0.02) def _prepare_request_v1(self): """Prepare request and send to the queue for scheduling""" diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 04aceda7806..e7792f44dc0 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -245,7 +245,7 @@ def get_new_block_nums(self, request: Request, num_new_tokens: int): block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size - len(request.block_tables) - + block_num = max(block_num, 0) if self.config.speculative_config.method is not None: block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq) else: @@ -1001,7 +1001,13 @@ def _allocate_decode_and_extend(): req_index += 1 continue num_new_block = self.get_new_block_nums(request, num_new_tokens) - can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(num_new_block) + if self.config.scheduler_config.splitwise_role == "prefill": + # for prefill instance, do not set threshold for running requests + can_schedule_block_num_threshold = 0 + else: + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( + num_new_block + ) # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): request.block_tables.extend( diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 509f9a768d9..955f3dfdd39 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -192,7 +192,7 @@ def _validate_split_kv_size(value: int) -> int: # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), # Number of worker threads for prepare requests in prefill instance - "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "5")), + "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "3")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 7505513d7f9..8b06c96c9d8 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -672,7 +672,7 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False # TODO: Refine checking sending cache and do not keep waiting if time.time() - start_time > 30: llm_logger.warning(f"wait for sending cache, {task_id}") - time.sleep(0.002) + time.sleep(0.005) else: if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.finish_requests_async(task_id) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 27c608a4d1f..9f896b694a3 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -267,7 +267,7 @@ def check_decode_allocated(self, task): return True, "" while self.current_request_ids[task.request_id] == "init": - time.sleep(0.001) + time.sleep(0.005) if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS: del self.current_request_ids[task.request_id] return False, "prefill waits for decode resource timeout" From a095d6fdfc295be44669964d37f324d7bc8e543f Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Tue, 26 May 2026 16:19:10 +0800 Subject: [PATCH 130/143] [Cherry-Pick][Feature] support decode unified attention for mix(#7688) (#7729) * support c8 decode attention * support c16 attention && backend * opt kernel * fix * opt larger batch * inplace out * fix input_batch && remove fast_math * fix xpu * fix bug * fix ci * opt and fix mtp * fix merge * clean code * fix merge * update * update test * fix test * fix test * opt buffer * fix conflict --------- Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> --- custom_ops/gpu_ops/cpp_extensions.cc | 102 ++ .../gpu_ops/decode_unified_attention.cu | 428 ++++++ .../attention_func.cuh | 1231 +++++++++++++++++ .../config_for_attention.cu | 409 ++++++ .../cu_tensor_map.cuh | 124 ++ .../decode_unified_attention_c16_impl.cuh | 492 +++++++ .../decode_unified_attention_c8_impl.cuh | 706 ++++++++++ .../decode_unified_attention/mem_util.cuh | 389 ++++++ .../mma_tensor_op.cuh | 296 ++++ .../template_config.json | 78 ++ .../decode_unified_attention/utils.cuh | 689 +++++++++ .../gpu_ops/decoder_write_cache_with_rope.cu | 326 +++++ custom_ops/setup_ops.py | 7 + .../utils/auto_gen_template_attention.py | 227 +++ fastdeploy/envs.py | 2 + .../layers/attention/append_attn_backend.py | 24 + .../layers/attention/flash_attn_backend.py | 220 ++- .../layers/attention/ops/__init__.py | 6 + .../attention/ops/config_for_attention.py | 58 + .../attention/ops/decode_unified_attention.py | 105 ++ .../ops/decoder_write_cache_with_rope.py | 97 ++ fastdeploy/spec_decode/mtp.py | 22 + fastdeploy/worker/gpu_model_runner.py | 11 + fastdeploy/worker/input_batch.py | 14 + fastdeploy/worker/metax_model_runner.py | 2 + ..._ernie_21b_mtp_decode_unified_attention.py | 381 +++++ .../test_decode_unified_attention_c16.py | 868 ++++++++++++ .../test_decode_unified_attention_c8.py | 921 ++++++++++++ 28 files changed, 8172 insertions(+), 63 deletions(-) create mode 100644 custom_ops/gpu_ops/decode_unified_attention.cu create mode 100644 custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu create mode 100644 custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh create mode 100644 custom_ops/gpu_ops/decode_unified_attention/template_config.json create mode 100644 custom_ops/gpu_ops/decode_unified_attention/utils.cuh create mode 100644 custom_ops/gpu_ops/decoder_write_cache_with_rope.cu create mode 100644 custom_ops/utils/auto_gen_template_attention.py create mode 100644 fastdeploy/model_executor/layers/attention/ops/config_for_attention.py create mode 100644 fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py create mode 100644 fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py create mode 100644 tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py create mode 100644 tests/operators/attention/test_decode_unified_attention_c16.py create mode 100644 tests/operators/attention/test_decode_unified_attention_c8.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index e2a2fc1b92f..d74f4240260 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -189,6 +189,84 @@ std::vector AppendAttentionWithOutput( const int sliding_window, const int sink_size); +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder); + +std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window); + +void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -1962,4 +2040,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("per_token_group_fp8_quant", &PerTokenGroupQuantFp8, "per_token_group_quant_fp8"); + + /** + * decoder_write_cache_with_rope.cu + * decoder_write_cache_with_rope + */ + m.def("decoder_write_cache_with_rope", + &DecoderWriteCacheWithRoPE, + "decoder write cache with RoPE function"); + + /** + * decode_unified_attention.cu + * decode_unified_attention + */ + m.def("decode_unified_attention", + &DecodeUnifiedAttention, + "decoder append attention function"); + + /** + * config_for_attention.cu + * config_for_attention + */ + m.def("config_for_attention", + &ConfigForAttention, + "config for attention function"); } diff --git a/custom_ops/gpu_ops/decode_unified_attention.cu b/custom_ops/gpu_ops/decode_unified_attention.cu new file mode 100644 index 00000000000..257134d1e95 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention.cu @@ -0,0 +1,428 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decode_unified_attention/decode_unified_attention_c8_impl.cuh" +#include "decode_unified_attention/decode_unified_attention_c16_impl.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_num = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; + + auto stream = qkv.stream(); + bool is_fp8 = + cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8"; + bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8"; + bool is_c16 = cache_quant_type == "none"; + + if (max_just_dec_len_this_time > 0) { + if (is_c16) { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, max_tokens_per_batch, Q_TILE_SIZE, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})}) + } else { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, + max_tokens_per_batch, + Q_TILE_SIZE, + {DISPATCH_DyCfp8( + is_dynamic_cfp8, + IsDynamicC8, + {DISPATCH_IS_FP8(is_fp8, IsFP8, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})})})}) + } + } + return {fmha_out}; +} + +std::vector> DecodeUnifiedAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& block_indices_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& sinks_shape, + const std::vector& fmha_out_shape, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_shape}; +} + +std::vector DecodeUnifiedAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& sinks_dtype, + const paddle::DataType& fmha_out_dtype, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_dtype}; +} + +PD_BUILD_STATIC_OP(decode_unified_attention) + .Inputs({"qkv", + "key_cache", + "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "block_indices", + "num_blocks", + "chunk_size", + "set_max_lengths", + paddle::Optional("attn_mask"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("mask_offset"), + paddle::Optional("sinks"), + "fmha_out"}) + .Outputs({"fmha_out_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) + .Attrs({ + "cache_quant_type: std::string", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "max_tokens_per_batch: int", + "causal: bool", + "sliding_window: int", + }) + .SetKernelFn(PD_KERNEL(DecodeUnifiedAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(DecodeUnifiedAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecodeUnifiedAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh new file mode 100644 index 00000000000..ee74570e5d8 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh @@ -0,0 +1,1231 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +template +__device__ __forceinline__ void load_block_table_per_chunk( + const int32_t* block_table_chunk_start, + int32_t* block_table_smem, + uint32_t chunk_start, + uint32_t chunk_end, + uint32_t tid, + uint32_t wid) { + uint32_t len = chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE; + for (uint32_t i = 0; i < div_up(len, 128); i++) { + uint32_t offset = wid * kWarpSize + tid + i * 128; + if (offset < len) { + block_table_smem[offset] = block_table_chunk_start[offset]; + } + } +} + +// load q from global memory to shared memory +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; +} + +template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + // 1 warp 32 tokens + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } +} + +template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } +} + +template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T* cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; +} + +template +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0) { + const uint32_t tx = threadIdx.x; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + bool out_of_boundary; + if (mask_offset) { + const int2 mo = reinterpret_cast(mask_offset)[q_idx]; + out_of_boundary = + q_idx < qo_len ? (kv_idx >= mo.y || kv_idx < mo.x) : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t j_id = j * 2; + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_tmp = s_frag[fx][fz] + j_id; + float m_local = max(max(s_frag_tmp[0], s_frag_tmp[1]), + max(s_frag_tmp[4], s_frag_tmp[5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; + float2 fp2_scale = make_float2(o_scale, o_scale); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_frag_ptr = reinterpret_cast(o_frag[fx][fy] + j_id); + o_frag_ptr[0] = fast_float2_mul(o_frag_ptr[0], fp2_scale); + o_frag_ptr[2] = fast_float2_mul(o_frag_ptr[2], fp2_scale); + } + float tmp_m = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_ptr = s_frag[fx][fz] + j_id; + s_frag_ptr[0] = __expf(s_frag_ptr[0] - tmp_m); + s_frag_ptr[1] = __expf(s_frag_ptr[1] - tmp_m); + s_frag_ptr[4] = __expf(s_frag_ptr[4] - tmp_m); + s_frag_ptr[5] = __expf(s_frag_ptr[5] - tmp_m); + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T* cache_v_scale) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } + } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid, + const bool normalize = false) { + // Padded row stride (33 instead of 32) to avoid cross-row bank conflicts. + constexpr uint32_t kRowStride = 33; + // o_smem row stride in floats: kRowStride * 8 = 264 + constexpr uint32_t kORowStride = kRowStride * 8; + // md_smem base offset: after all o_smem data + // NUM_WARPS(4) * num_frags_x * num_frags_y * kORowStride floats + constexpr uint32_t kOMemFloats = 4 * num_frags_x * num_frags_y * kORowStride; + float2* smem_md = reinterpret_cast(md_smem + kOMemFloats); + + // Phase 1: Write m/d to smem only (2KB, no o data yet) +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * kRowStride + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } + __syncthreads(); + + // Phase 2: Compute global m/d and scale own o_frag in registers + float scale_j[2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = + smem_md[((i * num_frags_x + fx) * 2 + j) * kRowStride + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = fmaf(d_prev, expf(m_prev - m_new), md.y * expf(md.x - m_new)); + } + float own_scale = expf(m[fx][j] - m_new); + m[fx][j] = m_new; + d[fx][j] = d_new; + float d_rcp = normalize ? (1.f / d_new) : 1.f; + scale_j[j] = own_scale * d_rcp; + } + // Apply scale to o_frag using WGMMA fragment layout: + // regs 0,1→j=0, 2,3→j=1, 4,5→j=0, 6,7→j=1 + // i.e., float2 index k → j = k % 2 +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t k = 0; k < 4; ++k) { + float s = scale_j[k % 2]; + o_frag[fx][fy][2 * k + 0] *= s; + o_frag[fx][fy][2 * k + 1] *= s; + } + } + } + + // Phase 3: Write pre-scaled o_frag to smem with padded stride +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_smem_start = + (float2*)(md_smem + + ((wid * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + o_smem_start[i * kRowStride] = ((float2*)(&o_frag[fx][fy][0]))[i]; + } + } + } + __syncthreads(); + + // Phase 4: Accumulate all warps' scaled o_frag +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_new_fp2 = reinterpret_cast(&o_frag[fx][fy][0]); +#pragma unroll + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + o_new_fp2[o_id] = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi_fp2; + float2* o_smem_start = + (float2*)(md_smem + + ((i * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + oi_fp2[reg_id] = o_smem_start[reg_id * kRowStride]; + } +#pragma unroll + for (uint32_t reg_fp2_id = 0; reg_fp2_id < 4; ++reg_fp2_id) { + o_new_fp2[reg_fp2_id].x += oi_fp2[reg_fp2_id].x; + o_new_fp2[reg_fp2_id].y += oi_fp2[reg_fp2_id].y; + } + } + } + } +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); + + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); + + const uint32_t tx_offset = tx / 8; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + if (n_offset < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +// C16 (fp16/bf16 KV cache) helper functions + +template +__device__ __forceinline__ void produce_kv_blockwise(smem_t smem, + uint32_t* smem_offset, + T** gptr, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; + *gptr += + num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + uint32_t a_frag[num_frags_x][4], b_frag[4]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + *q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * num_vecs_per_head; + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } + } + } + *k_smem_offset_r = + k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * num_vecs_per_head; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * 2; +} + +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t b_frag[4]; + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); + } + *v_smem_offset_r = + v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy); + } + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) - + 2 * num_frags_y; + } + *v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head; +} + +template +__global__ void merge_chunks_kernel( + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + const int* __restrict__ chunk_size_ptr, + T* __restrict__ out, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int head_dim, + const int token_num, + const int max_tokens_per_batch = 5) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int hid = blockIdx.y; + // After intra-warp reduction, only bdy/2 results need smem storage + __shared__ T smem[(bdy / 2) * HEAD_DIM]; + __shared__ float md_smem[(bdy / 2) * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + // Phase 1: Fast path — all ty participate independently (no smem, no + // syncthreads) Each ty handles a different qid with stride gridDim.x * bdy + using LoadT = AlignedVector; + for (int qid = blockIdx.x + ty * gridDim.x; qid < token_num; + qid += gridDim.x * bdy) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; + if (seq_lens_encoder[bid] > 0) continue; // skip prefill batches + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq != 1) continue; // handled in Phase 2 + + LoadT load_vec; + uint32_t offset = + ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * num_heads + + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + Store( + load_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + + // Phase 2: Slow path — merge multi-chunk results + // Optimization: use warp-shuffle reduction within each warp, then cross-warp + // via smem. This eliminates the large smem[bdy * HEAD_DIM] buffer and reduces + // syncthreads from 2 per qid to 1 per qid. + // Block layout: (blockx=16, bdy=8) => 4 warps, each warp has 2 ty values + // Warp 0: ty=0,1 Warp 1: ty=2,3 Warp 2: ty=4,5 Warp 3: ty=6,7 + // Lane layout within warp: lanes 0-15 = (ty_low, vid), lanes 16-31 = + // (ty_high, vid) + const int lane_id = (ty * blockDim.x + vid) % 32; + + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; // uniform skip — no syncthreads needed + if (seq_lens_encoder[bid] > 0) continue; + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq == 1) continue; // handled in Phase 1 + + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } + + // Step 1: Each ty iterates over its chunk subset and does local online + // softmax merge +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks + i) * + num_heads + + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * + num_heads + + i * num_heads + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + + // Step 2: Intra-warp reduction via warp shuffle + // Each warp has 2 ty values: ty_low at lanes 0-15, ty_high at lanes 16-31 + // Merge ty_high into ty_low using shuffle + const int partner_lane = lane_id ^ 16; // flip bit 4 to swap low/high ty + const float m_partner = __shfl_sync(0xffffffff, m, partner_lane); + const float d_partner = __shfl_sync(0xffffffff, d, partner_lane); + // Pack adjacent 16-bit pairs into 32-bit for efficient shuffle. + // AlignedVector alignment >= 4 bytes, so uint32 reinterpret is safe + // — no OOB read, no type confusion. This halves shuffle count vs + // per-element memcpy for bf16/fp16. + constexpr int PACKED_SIZE = vec_size * sizeof(T) / sizeof(unsigned); + const unsigned* packed_res = reinterpret_cast(&res_vec); + unsigned packed_partner[PACKED_SIZE]; +#pragma unroll + for (int j = 0; j < PACKED_SIZE; j++) { + packed_partner[j] = __shfl_sync(0xffffffff, packed_res[j], partner_lane); + } + LoadT partner_vec; + memcpy(&partner_vec, packed_partner, sizeof(partner_vec)); + + // Merge partner into self (only the "low ty" keeps the result) + float m_new = max(m, m_partner); + const float scale1 = __expf(m - m_new); + const float scale2 = __expf(m_partner - m_new); + float d_new = d * scale1 + d_partner * scale2; + if ((ty & 1) == 0) { // low ty keeps merged result + m = m_new; + d = d_new; + const T scale1_T = static_cast(scale1); + const T scale2_T = static_cast(scale2); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + partner_vec[j] * scale2_T; + } + } + + // Cross-warp: only even ty (0,2,4,6) write to smem + if ((ty & 1) == 0) { + Store(res_vec, &smem[(ty / 2) * head_dim + vid * vec_size]); + md_smem[ty] = m; + md_smem[ty + 1] = d; + } + __syncthreads(); + + if (ty == 0) { + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy / 2; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); + } + Store( + out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + __syncthreads(); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu new file mode 100644 index 00000000000..7033cbd10bf --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu @@ -0,0 +1,409 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "cute/tensor.hpp" +#include "helper.h" +#include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#endif +#include "utils.cuh" + +template +__global__ void GetMaxLenKernel(const int* seq_lens_decoder, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int* max_lens, + const int batch_size) { + const int tid = threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int max_len_this_time_this_thread = 0; + int max_len_encoder_this_thread = 0; + int max_len_decoder_this_thread = 0; + int max_len_this_thread = 0; + int max_just_dec_len_this_thread = 0; + int max_len_kv_this_thread = 0; + for (int i = tid; i < batch_size; i += blockDim.x) { + const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; + max_len_this_time_this_thread = + max(seq_len_this_time, max_len_this_time_this_thread); + max_len_encoder_this_thread = + max(seq_lens_encoder[i], max_len_encoder_this_thread); + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; + max_len_this_thread = + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); + max_just_dec_len_this_thread = + max(max_just_dec_len_this_thread, max_just_dec_len_now); + + if (seq_len_decoder == 0) continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); + } + int total_max_len_this_time = + BlockReduce(temp_storage) + .Reduce(max_len_this_time_this_thread, MaxOp()); + int total_max_len_encoder = + BlockReduce(temp_storage) + .Reduce(max_len_encoder_this_thread, MaxOp()); + int total_max_len_decoder = + BlockReduce(temp_storage) + .Reduce(max_len_decoder_this_thread, MaxOp()); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + int total_just_dec = BlockReduce(temp_storage) + .Reduce(max_just_dec_len_this_thread, MaxOp()); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); + if (tid == 0) { + max_lens[0] = total_max_len_this_time; + max_lens[1] = total_max_len_encoder; + max_lens[2] = total_max_len_decoder; + max_lens[3] = total; + max_lens[4] = total_just_dec; + max_lens[5] = total_max_len_kv; + } +} + +template +__global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + int4* __restrict__ block_indices, + int* __restrict__ num_blocks, + int* __restrict__ chunk_size, + const int bsz, + const int group_size, + const int kv_num_heads, + const int q_tile_size, + const int max_tokens_per_batch, + const int config_gridx) { + const int tid = threadIdx.x, wid = threadIdx.y; + const uint32_t warp_size = blockDim.x; + __shared__ int num_block_all_shared[block_size]; + __shared__ int chunk_size_res[1]; + + const int lane_id = tid + wid * warp_size; + + // Merged Step 1+2: single bsz loop computing both Scheme E metrics and + // split-KV block counts per lane. Avoids redundant seq_lens reads and + // shared intermediate values (token_num, kv_len, q_tile_num). + const int target_blocks = config_gridx / 3; // sm_count * 3 + // Search chunk_size from 512 with step 128: {512, 640, 768, ...} + + const int cur_chunk_size = + min(min_chunk_size + lane_id * chunk_step, max_chunk_size); + int num_block_no_chunk = 0; + int max_kv_len_no_chunk = 0; + int num_block_all = 0; + for (int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + const int token_num_cur_batch = seq_lens_this_time[bid]; + const int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + const int q_tile_num = + div_up(token_num_cur_batch * group_size, q_tile_size); + num_block_no_chunk += q_tile_num * kv_num_heads; + max_kv_len_no_chunk = max(max_kv_len_no_chunk, kv_len_cur_batch); + const int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + num_block_all_shared[lane_id] = num_block_all; + __syncthreads(); + + // Step 3: find best chunk_size, then decide Scheme E vs split-KV + if (tid == 0 && wid == 0) { + // Strategy: + // 1. Must fill target_blocks (2*sm_count) to maintain SM concurrency + // 2. Among valid choices, prefer minimum per-SM max KV traffic + // (= waves * chunk_size, since kernel time = slowest SM) + // 3. Within 5% of minimum KV traffic, prefer larger chunk_size + int chunk_size_best = min_chunk_size; + int num_block_all_best = num_block_all_shared[0]; + // Step 1: find minimum kv_traffic among chunk_sizes that fill SMs + int64_t kv_traffic_min = INT64_MAX; + for (int i = 0; i < static_cast(block_size); i++) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic < kv_traffic_min) { + kv_traffic_min = kv_traffic; + } + } + // Step 2: if no chunk_size fills SMs, fall back to smallest + if (kv_traffic_min == INT64_MAX) { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + } else { + // Step 3: scan from largest chunk_size downward; accept the first + // one that fills SMs AND has kv_traffic within 20% of minimum + for (int i = block_size - 1; i >= 0; i--) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic <= kv_traffic_min + kv_traffic_min / 4) { + chunk_size_best = cs; + num_block_all_best = nb; + break; + } + } + } + + // Decide Scheme E: prefer when blocks fill SMs AND estimated latency + // is no worse than split-KV. + // Scheme E: waves_E * max_kv_len (few heavy blocks) + // Split-KV: waves_split * chunk_size_best (many light blocks) + // When no splitting is needed (num_block_all_best == num_block_no_chunk), + // Scheme E is strictly better (saves merge overhead). + bool use_scheme_e = false; + if (num_block_no_chunk >= target_blocks) { + if (num_block_all_best == num_block_no_chunk) { + use_scheme_e = true; + } else { + // target_blocks = sm_count * 3 ≈ CTAs per wave (sm_count × occupancy). + // Using target_blocks as denominator correctly accounts for occupancy + // in wave count estimation. + const int waves_e = div_up(num_block_no_chunk, target_blocks); + const int waves_split = div_up(num_block_all_best, target_blocks); + use_scheme_e = (static_cast(waves_e) * max_kv_len_no_chunk <= + static_cast(waves_split) * chunk_size_best); + } + } + + if (use_scheme_e) { + num_blocks[0] = num_block_no_chunk; + chunk_size[0] = INT_MAX; + chunk_size_res[0] = INT_MAX; + } else { + num_blocks[0] = num_block_all_best; + chunk_size[0] = chunk_size_best; + chunk_size_res[0] = chunk_size_best; + } + } + + __syncthreads(); + if (wid == 0) { + const int chunk_size_final = chunk_size_res[0]; + + int prev_offset = 0; + for (int base = 0; base < bsz; base += warp_size) { + const int bid = base + tid; + int num_block_cur = 0; + int q_tile_num = 0; + int kv_chunk_num = 0; + + if (bid < bsz) { + int token_num_cur_batch = seq_lens_this_time[bid]; + if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { + token_num_cur_batch = 0; + } + q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + const int kv_len_cur_batch = + seq_lens_decoder[bid] + token_num_cur_batch; + kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_final); + num_block_cur = q_tile_num * kv_chunk_num * kv_num_heads; + } + + // inclusive prefix sum + int x = num_block_cur; + for (int offset = 1; offset < warp_size; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (tid >= offset) x += y; + } + int bid_offset = x - num_block_cur; + int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1); + + // Write block_indices using int4 vectorized stores. + // Each entry is exactly 4 ints (bid, kv_head_id, kv_chunk_id, q_tile_id), + // matching int4 layout. This reduces 4 scalar stores to 1 vector store. + if (bid < bsz && num_block_cur > 0) { + int4* write_ptr = block_indices + prev_offset + bid_offset; + int flat_idx = 0; + const int kv_chunk_num_x_q_tile_num = kv_chunk_num * q_tile_num; +#pragma unroll 2 + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + const int head_base = kv_head_id * kv_chunk_num_x_q_tile_num; +#pragma unroll 2 + for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; kv_chunk_id++) { + const int chunk_base = head_base + kv_chunk_id * q_tile_num; +#pragma unroll + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + write_ptr[flat_idx] = + make_int4(bid, kv_head_id, kv_chunk_id, q_tile_id); + flat_idx++; + } + } + } + } + prev_offset += tile_sum; + } + } +} + +void ConfigForAttention( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace, shape:[block_num,4], block's + // indices with 4 dimension[batch_idx, + // kv_head_idx, kv_chunk_idx, q_tile_idx] + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + auto stream = seq_lens_encoder.stream(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); + int max_just_dec_len_this_time = max_len_cpu_ptr[4]; + + const uint32_t block_indices_ele_num = block_indices.size(); + + // decoder + if (max_just_dec_len_this_time > 0) { + CUDA_CHECK(cudaMemsetAsync(block_indices.data(), + 0, + block_indices_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK( + cudaMemsetAsync(num_blocks.data(), 0, sizeof(int32_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(chunk_size.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + const int config_gridx = sm_cout * 6; + + const int q_tile_size = 16; + dim3 blocks(32, 4); + // Cast block_indices to int4* for vectorized stores. + // Each block_indices entry is 4 ints = 16 bytes = sizeof(int4), + // and block_num * 4 ints = block_num int4s, so the reinterpret is valid. + int4* block_indices_i4 = reinterpret_cast(block_indices.data()); + if (cache_quant_type == "cache_int4_zp") { + config_decode_attn<512, 256, 128, 32768> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } else { + config_decode_attn<512, 128, 128, 16384> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } + } +} + +std::vector> ConfigForAttentionInferShape( + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& max_len_tensor_cpu_shape, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +std::vector ConfigForAttentionInferDtype( + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& max_len_tensor_cpu_dtype, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +PD_BUILD_STATIC_OP(config_for_attention) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "block_indices", + "num_blocks", + "chunk_size", + "max_len_tensor_cpu", + }) + .Outputs({ + + }) + .Attrs({"cache_quant_type: std::string", + "group_size: int", + "kv_num_heads: int", + "max_tokens_per_batch: int"}) + .SetKernelFn(PD_KERNEL(ConfigForAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(ConfigForAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConfigForAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh new file mode 100644 index 00000000000..ff84e1cd3f6 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +template +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template +CUtensorMap makeTensorMapForKVCache(T const* addr, + uint32_t block_num, + uint32_t kv_num_head, + uint32_t second_size, + uint32_t last_size) { + CUtensorMap tensorMap{}; + + uint32_t elem_bytes = sizeof(T); + + uint32_t const last_size_bytes = elem_bytes * last_size; + // VLLM Layout + CUtensorMapDataType data_dtype = cu_tensor_map_type_traits::type; + constexpr uint32_t rank = 4; + uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num}; + uint64_t global_strides[] = {last_size_bytes, + second_size * last_size_bytes, + kv_num_head * second_size * last_size_bytes}; + + uint32_t box_dims[] = {last_size, second_size, 1, 1}; + uint32_t elem_strides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (last_size_bytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache last_size"); + } + }(); + CUresult res = cuTensorMapEncodeTiled( + &tensorMap, + data_dtype, + rank, + reinterpret_cast(const_cast(addr)), + global_dims, + global_strides, + box_dims, + elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + switch (res) { + case CUDA_SUCCESS: + printf("CUDA_SUCCESS!\n"); + break; + case CUDA_ERROR_INVALID_VALUE: + printf("CUDA_ERROR_INVALID_VALUE\n"); + break; + case CUDA_ERROR_OUT_OF_MEMORY: + printf("CUDA_ERROR_OUT_OF_MEMORY\n"); + break; + case CUDA_ERROR_NOT_INITIALIZED: + printf("CUDA_ERROR_NOT_INITIALIZED\n"); + break; + case CUDA_ERROR_DEINITIALIZED: + printf("CUDA_ERROR_DEINITIALIZED\n"); + break; + case CUDA_ERROR_PROFILER_DISABLED: + printf("CUDA_ERROR_PROFILER_DISABLED\n"); + break; + default: + throw std::runtime_error("unsupported res!"); + } + + return tensorMap; +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh new file mode 100644 index 00000000000..e30588a01ab --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh @@ -0,0 +1,492 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +#include "attention_func.cuh" + +template +__global__ void decode_unified_attention_c16_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x * 16 + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + T* o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + BLOCK_SIZE); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (BLOCK_SIZE); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx = chunk_start; + int block_table_idx = kv_idx / BLOCK_SIZE; + int block_id = __ldg(&block_table_now[block_table_idx]); + int block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T* cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T* cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if (iter + 1 < num_iterations) { + block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + } + + wait_group<1>(); + __syncthreads(); + + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx += BLOCK_SIZE; + block_table_idx++; + + block_id = block_id_next; + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC16Attention( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size_0 = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(NV_TYPE); + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = + decode_unified_attention_c16_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = + reinterpret_cast(const_cast(cache_k.data())); + params.cache_v = + reinterpret_cast(const_cast(cache_v.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + 0.f, + 0.f, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh new file mode 100644 index 00000000000..00a20165555 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh @@ -0,0 +1,706 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +// #include "cu_tensor_map.cuh" +#include "attention_func.cuh" + +template +void print_params(AttentionParams const params) { + printf("max_model_len: %d\n", params.max_model_len); + printf("max_kv_len: %d\n", params.max_kv_len); + printf("max_blocks_per_seq: %d\n", params.max_blocks_per_seq); + printf("softmax_scale: %f\n", params.softmax_scale); + printf("quant_max_bound: %f\n", params.quant_max_bound); + printf("quant_min_bound: %f\n", params.quant_min_bound); + printf("max_tokens_per_batch: %d\n", params.max_tokens_per_batch); + printf("attn_mask_len: %d\n", params.attn_mask_len); + printf("sliding_window: %d\n", params.sliding_window); + printf("q_num_heads: %d\n", params.q_num_heads); + printf("kv_num_heads: %d\n", params.kv_num_heads); + printf("max_num_chunks: %d\n", params.max_num_chunks); + printf("max_tile_q: %d\n", params.max_tile_q); + printf("batch_size: %d\n", params.batch_size); +} + +template +__global__ void decode_unified_attention_c8_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto cache_k_scale = params.cache_k_scale; + const auto cache_v_scale = params.cache_v_scale; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + smem_t k_scale_smem; + smem_t v_scale_smem; + T* k_smem_scale_ptr = nullptr; + T* v_smem_scale_ptr = nullptr; + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + T cache_k_scale_reg[IsDynamicC8 + ? num_frags_z * 2 + : (is_scale_channel_wise ? num_frags_y * 4 : 1)]; + T cache_v_scale_reg[IsDynamicC8 + ? num_frags_z * 4 + : (is_scale_channel_wise ? num_frags_y * 2 : 1)]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T* cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T* cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + T* o_base_ptr_T = nullptr; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_idx * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg(k_smem_scale_ptr, + cache_k_scale_reg); + } + + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg(v_smem_scale_ptr, + cache_v_scale_reg); + } + + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC8Attention(const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::Tensor& cache_k_scale, + const paddle::Tensor& cache_v_scale, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size_0 = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = decode_unified_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = decode_unified_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = const_cast(cache_k.data()); + params.cache_v = const_cast(cache_v.data()); + params.cache_k_scale = + reinterpret_cast(const_cast(cache_k_scale.data())); + params.cache_v_scale = + reinterpret_cast(const_cast(cache_v_scale.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.quant_max_bound = quant_max_bound; + params.quant_min_bound = quant_min_bound; + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh new file mode 100644 index 00000000000..18788858923 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh @@ -0,0 +1,389 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +enum class SharedMemFillMode { kFillZero, kNoFill }; + +enum class PrefetchMode { kNoPrefetch, kPrefetch }; + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, + T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +__device__ __forceinline__ void commit_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } +#else + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } + } +#endif +} + +template +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_32b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4)); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128, "num_bits must be 128"); + load_128b(smem_ptr, gmem_ptr); +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32, + "num_bits must be 128, 64 or 32."); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 64) { + pred_load_64b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 32) { + pred_load_32b(smem_ptr, gmem_ptr, predicate); + } +} + +using b32_t = uint32_t; +using b64_t = uint2; +using b128_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { + return sizeof(b128_t) / sizeof(T); +} + +struct smem_t { + // The base pointer. + b128_t* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + if constexpr (inv_stride <= 1) { + return i * stride + (j ^ (i % 8)); + } else { + return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^ + ((i / inv_stride) % 8); + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + if constexpr (row_stride == 2) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } else if constexpr (row_stride == 4) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_row(uint32_t offset) { + if constexpr (row_stride == 2) { + static_assert(step_size == 16 || step_size % 32 == 0, + "Unsupported step size"); + if constexpr (step_size == 16) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 32 == 0 + return offset + step_size * row_stride; + } + } else if constexpr (row_stride == 4) { + static_assert(step_size == 8 || step_size % 16 == 0, + "Unsupported step size"); + if constexpr (step_size == 8) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 16 == 0 + return offset + step_size * row_stride; + } + } else { + static_assert(step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_impl(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_trans_impl(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr, + bool predicate) { + b128_t* smem_ptr = base + offset; + pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr) { + b128_t* smem_ptr = base + offset; + load_128b(smem_ptr, + reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; diff --git a/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh new file mode 100644 index 00000000000..8662ee298d2 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh @@ -0,0 +1,296 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32( + int* C, // 8 + uint32_t* A, // 4 + uint32_t* B) { // 4 + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + } else { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), + "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(C[4]), + "r"(C[5]), + "r"(C[6]), + "r"(C[7])); + } +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + float* C, uint32_t* A, uint32_t* B) { + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same::value) { // fp16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } else { // bf16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } + } else { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } + } +} + +template +__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { + static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); + uint32_t* s_u32 = (uint32_t*)(s); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1006648320), + "r"(1006648320), + "f"(d[0]), + "f"(d[1])); + } else { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1065369472), + "r"(1065369472), + "f"(d[0]), + "f"(d[1])); + } +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/template_config.json b/custom_ops/gpu_ops/decode_unified_attention/template_config.json new file mode 100644 index 00000000000..d768c93a1ad --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/template_config.json @@ -0,0 +1,78 @@ +{ + "multiquery_attention_c8": { + "name": "decode_unified_attention_c8_kernel", + "function_name": "decode_unified_attention_c8_kernel", + "impl_file": "decode_unified_attention_c8_impl.cuh", + "template_params": [ + "T", + "CacheT", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_y", + "num_frags_z", + "is_scale_channel_wise", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "CacheT": ["uint8_t"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_y": [8], + "num_frags_z": [1], + "is_scale_channel_wise": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_unified_attention_c8", + "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + }, + "multiquery_attention_c16": { + "name": "decode_unified_attention_c16_kernel", + "function_name": "decode_unified_attention_c16_kernel", + "impl_file": "decode_unified_attention_c16_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_z", + "num_frags_y" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_z": [1], + "num_frags_y": [8] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_unified_attention_c16", + "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + } +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/utils.cuh b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh new file mode 100644 index 00000000000..7111ad23fb7 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh @@ -0,0 +1,689 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include "helper.h" +#include "mem_util.cuh" + +#define NUM_WARPS_PER_BLOCK 4 +#define NUM_THREADS_PER_BLOCK 128 +#define kWarpSize 32 + +#define HOSTDEVICE __host__ __device__ + +/*-------------------------------------traits-----------------------------------------*/ +template +struct type_traits { + using paddle_type = T; + using phi_type = T; + using nv_type = T; + using nv2_type = T; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::float16; +// using nv_type = half; +// using nv2_type = half2; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::bfloat16; +// using nv_type = __nv_bfloat16; +// using nv2_type = __nv_bfloat162; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat16> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat162> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT8_E4M3FN; +// using phi_type = phi::dtype::float8_e4m3fn; +// using nv_type = __nv_fp8_e4m3; +// using nv2_type = __nv_fp8x2_e4m3; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8x2_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; +/*---------------------------------1. type + * traits--------------------------------------*/ + +/*---------------------------------2. fast + * convert--------------------------------------*/ +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { + printf("Do not support fp8 to half although it's very easy.\n"); +} + +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); + + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif +} + +inline __device__ static void convert_int8( + half* result, const uint32_t& source) { // 4 int8 each time + uint32_t* fp16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +} + +inline __device__ static void convert_int8( + __nv_bfloat16* result, const uint32_t& source) { // 4 int8 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; // (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +} +/*---------------------------------2. fast + * convert--------------------------------------*/ + +/*---------------------------------3. vector + * cast--------------------------------------*/ +template +__forceinline__ HOSTDEVICE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(float* dst, + const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(half* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast( + float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} +/*---------------------------------3. vector + * cast--------------------------------------*/ + +/*-------------------------------------4. + * func-----------------------------------------*/ +__forceinline__ HOSTDEVICE int div_up(int a, int b) { + return a / b + (a % b != 0); +} + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +HOSTDEVICE __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { + uint8_t eight_bits; + float quant_value; + if constexpr (is_need_kv_quant) { + quant_value = static_cast(scale * value); + } else { + quant_value = static_cast(value); + } + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + + if constexpr (IsFP8) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + quant_value = quant_value > 448.0f ? 448.0f : quant_value; + quant_value = quant_value < -448.0f ? -448.0f : quant_value; + auto tmp = static_cast<__nv_fp8_e4m3>(quant_value); + eight_bits = *(reinterpret_cast(&tmp)); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif + } else { + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + eight_bits = static_cast(quant_value + 128.0f); + } + return eight_bits; +} + +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { + if constexpr (IsFP8) { + convert_fp8(result, source); + } else { + convert_int8(result, source); + } +} + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +#define CHECK_CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +__device__ __forceinline__ float2 fast_float2_mul(const float2& a, + const float2& b) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, 0.0;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, 0.0;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + return res; +} + +__device__ __forceinline__ float2 fast_float2_fma(float2& a, + const float2& b, + const float2& c) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, %6;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, %7;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), + "f"(a.y), + "f"(b.x), + "f"(b.y), + "f"(c.x), + "f"(c.y) // 输入操作数 + ); + return res; +} + +// __device__ __forceinline__ float2 fast_bfloat162_fma(__nv_bfloat162& a_bf162, +// const __nv_bfloat162& b_bf162, const __nv_bfloat162& c_bf162) { +// // 使用向量化PTX指令同时处理x/y分量 +// asm volatile ( +// "{\n" +// " fma.rn.b16 %0, %2, %4, %0;\n" // res.x = a.x * b.x +// " fma.rn.b16 %1, %3, %5, %1;\n" // res.y = a.y * b.y +// "}" +// : "=r"(a_bf162.x), "=r"(a_bf162.y) // 输出操作数 +// : "r"(b_bf162.x), "r"(b_bf162.y), +// "r"(c_bf162.x), "r"(c_bf162.y) // 输入操作数 +// ); +// float2 res = __bfloat1622float2_rn(a_bf162); +// return res; +// } + +__device__ __forceinline__ float2 fast_float2_sub_expf(const float2& a, + const float2& b) { + float2 res; + // 使用向量化减法指令(PTX sub.rn.f32) + asm volatile( + "{\n" + " sub.f32 %0, %2, %4;\n" // res.x = a.x - b.x + " sub.f32 %1, %3, %5;\n" // res.y = a.y - b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + res.x = expf(res.x); + res.y = expf(res.y); + return res; +} + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = static_cast(ori_out_vec[i]); + printf("Fatal! Unimplemented StoreFunc for cascade append attention\n"); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + 127.0f * + static_cast((ori_out_vec[i] + shift_bias_vec[i]) * + smooth_weight_vec[i]) * + in_scale; + quant_value = rintf(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; +/*-------------------------------------4. + * func-----------------------------------------*/ + +/*-----------------------------------5. + * dispatch---------------------------------------*/ +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_Q_TILE_SIZE( \ + group_size, max_tokens_per_batch, Q_TILE_SIZE, ...) \ + { \ + constexpr size_t Q_TILE_SIZE = 16; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \ + block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_IS_FP8(is_fp8, IS_FP8, ...) \ + if (is_fp8) { \ + constexpr bool IS_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IS_FP8 = false; \ + __VA_ARGS__ \ + } + +struct AppendAttnMetaData { + int batch_size; + int block_size; + int q_num_heads; + int kv_num_heads; + int token_num; + int head_dims; + int head_dims_v; + int max_blocks_per_seq; + const int* mask_offset = nullptr; +}; + +template +struct AttentionParams { + T* __restrict__ qkv; + CacheT* __restrict__ cache_k; + CacheT* __restrict__ cache_v; + T* __restrict__ cache_k_scale; + T* __restrict__ cache_v_scale; + int* __restrict__ seq_lens_q; + int* __restrict__ seq_lens_kv; + int* __restrict__ block_indices; + int* __restrict__ num_blocks_ptr; + int* __restrict__ chunk_size_ptr; + int* __restrict__ cu_seqlens_q; + int* __restrict__ block_table; + int* __restrict__ mask_offset; + bool* __restrict__ attn_mask; + T* __restrict__ tmp_o; + float* __restrict__ tmp_m; + float* __restrict__ tmp_d; + int max_model_len; + int max_kv_len; + int max_blocks_per_seq; + float softmax_scale; + float quant_max_bound; + float quant_min_bound; + int num_blocks_x; + int attn_mask_len; + bool sliding_window; + int q_num_heads; + int kv_num_heads; + int max_num_chunks; + int max_tile_q; + int batch_size; + int token_num; + int head_dims; + int max_tokens_per_batch; +}; diff --git a/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu new file mode 100644 index 00000000000..7878e9926c5 --- /dev/null +++ b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu @@ -0,0 +1,326 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + auto stream = qkv.stream(); + + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + + if (max_just_dec_len_this_time > 0) { + if (speculate_decoder) { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } else { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } + } + return {qkv}; +} + +std::vector> DecoderWriteCacheWithRoPEInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_shape}; +} + +std::vector DecoderWriteCacheWithRoPEInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "set_max_lengths", + paddle::Optional("rotary_embs"), + paddle::Optional("qkv_bias"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .Attrs({ + "rms_norm_eps: float", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "speculate_decoder: bool", + }) + .SetKernelFn(PD_KERNEL(DecoderWriteCacheWithRoPE)) + .SetInferShapeFn(PD_INFER_SHAPE(DecoderWriteCacheWithRoPEInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecoderWriteCacheWithRoPEInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index ed6ba5a5ef8..1dbd443e7dd 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -553,6 +553,13 @@ def find_end_files(directory, end_str): sources += find_end_files(fp8_auto_gen_directory, ".cu") if cc >= 90 and nvcc_version >= 12.0: + # decode unified attention + os.system( + "python utils/auto_gen_template_attention.py --config gpu_ops/decode_unified_attention/template_config.json --output gpu_ops/decode_unified_attention/template_instantiation/autogen" + ) + sources += ["gpu_ops/decode_unified_attention.cu"] + sources += ["gpu_ops/decoder_write_cache_with_rope.cu"] + sources += find_end_files("gpu_ops/decode_unified_attention", ".cu") # Hopper optimized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] diff --git a/custom_ops/utils/auto_gen_template_attention.py b/custom_ops/utils/auto_gen_template_attention.py new file mode 100644 index 00000000000..5658f6645e7 --- /dev/null +++ b/custom_ops/utils/auto_gen_template_attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" + +import argparse +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template + + +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + for param_name in config.template_params: + if param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + def _build_template_args(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in config.template_params: + if param_name in params: + template_args_parts.append(str(params[param_name])) + + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _build_params_template_args(self, params: Dict[str, Any]) -> str: + """Build template arguments for AttentionParams.""" + params_template_args = [] + if "T" in params: + params_template_args.append(str(params["T"])) + else: + raise ValueError("Template parameter 'T' not found in dispatch_params") + + if "CacheT" in params: + params_template_args.append(str(params["CacheT"])) + else: + # C16 kernels use AttentionParams - T is repeated for both args + params_template_args.append(str(params["T"])) + + return f"<{', '.join(params_template_args)}>" + + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, params_template_args: str + ) -> str: + """Generate function signature.""" + if config.function_signature: + signature = config.function_signature.format( + function_name=config.function_name, + template_args=template_args, + params_template_args=params_template_args, + ) + + return signature + else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. + +#pragma once + +#include "../../{config.impl_file}" +""" + + def _generate_template_instantiation(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, params) + params_template_args = self._build_params_template_args(params) + return self._generate_function_signature(config, template_args, params_template_args) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + def generate_combinations_for_type(self, config: TemplateConfig) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] + + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + def generate_file_content( + self, + config: TemplateConfig, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + + for params in combinations: + content += self._generate_template_instantiation(config, params) + + return content + + def generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") + + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + combinations = self.generate_combinations_for_type(config) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + self._clean_output_directory(output_dir) + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Universal template instantiation generator") + parser.add_argument( + "--config", + "-c", + type=str, + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 955f3dfdd39..bd56b21b99e 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -70,6 +70,8 @@ def _validate_split_kv_size(value: int) -> int: # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" # and "MLA_ATTN" can be set currently. "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), + # enable decode attention + "USE_DECODE_UNIFIED_ATTENTION": lambda: bool(int(os.getenv("USE_DECODE_UNIFIED_ATTENTION", "0"))), # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently. diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 15b657c249d..905e5941aa9 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -73,6 +73,8 @@ def allocate_launch_related_buffer( num_heads, kv_num_heads, block_size, + head_dim=128, + dtype="bfloat16", ): # Initialize AttentionBackend buffers assert num_heads % kv_num_heads == 0 @@ -107,6 +109,28 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + # Decode attention split ops buffers + if envs.USE_DECODE_UNIFIED_ATTENTION: + min_chunk_size = 512 + max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 if decoder_step_token_num * group_size <= 16 else 32 + q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // q_tile_size + res["decode_block_indices"] = paddle.full( + [max_batch_size * kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + res["decode_num_blocks"] = paddle.full([1], 0, dtype="int32") + res["decode_chunk_size"] = paddle.full([1], 0, dtype="int32") + res["decode_tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], 0, dtype=dtype + ) + res["decode_tmp_m"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + res["decode_tmp_d"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + return res diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b203fdbb221..37401b1b314 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -43,6 +43,9 @@ ) from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, get_attn_mask_q, get_block_shape_and_split_kv_block, gqa_rope_write_cache, @@ -272,8 +275,10 @@ def __init__( self.rope_3d = False # Note(ZKK): here must be consistent with append_attn_backend.py self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024)) + self.max_tokens_per_batch: int = self.speculate_max_draft_token_num + 1 if FLASH_ATTN_VERSION is None: init_flash_attn_version() + print(f"num_heads: {self.num_heads}, kv_num_heads: {self.kv_num_heads}") def get_attention_meta(self): """get_attention_meta""" @@ -414,6 +419,20 @@ def forward_mixed( ) else: forward_meta.attn_mask_q = None + if envs.USE_DECODE_UNIFIED_ATTENTION: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0 @@ -468,73 +487,148 @@ def forward_mixed( head_dim=self.head_dim, )[0].reshape([-1, self.attn_outputsize_tp]) - res_decoder = append_attention( - qkv, - cache_k, - cache_v, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - forward_meta.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, - forward_meta.kv_batch_ids, - forward_meta.kv_tile_ids_per_batch, - forward_meta.kv_num_blocks_x_cpu, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, - forward_meta.rotary_embs, - forward_meta.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - cache_k_scales, - cache_v_scales, - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - forward_meta.attn_mask_offsets, - metadata.kv_signal_data_list[layer.layer_id], - q_norm_weight, - k_norm_weight, - getattr(layer, "sinks", None), - getattr(layer, "rms_norm_eps", 1e-6), - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - self.encoder_block_shape_q, - self.decoder_block_shape_q, - self.max_partition_size, - self.max_seq_len, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - ) - - if use_fa_do_prefill: - merge_prefill_decode_output( - res_encoder, - res_decoder, + if envs.USE_DECODE_UNIFIED_ATTENTION: + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + layer.qkv_bias, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + if use_fa_do_prefill: + res_decoder = res_encoder + else: + res_decoder = paddle.empty( + [qkv.shape[0], self.num_heads * self.head_dim], + dtype=qkv.dtype, + ) + decode_unified_attention( + qkv_out, + cache_k, + cache_v, + forward_meta.decode_tmp_workspace, + forward_meta.decode_tmp_m, + forward_meta.decode_tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - self.num_heads, - self.head_dim, + forward_meta.block_tables, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + res_decoder, + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), self.speculate_max_draft_token_num + 1, + self.causal, ) - return res_encoder - else: return res_decoder + else: + res_decoder = append_attention( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + forward_meta.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + forward_meta.attn_mask_offsets, + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "sinks", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.max_partition_size, + self.max_seq_len, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + + if use_fa_do_prefill: + merge_prefill_decode_output( + res_encoder, + res_decoder, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_heads, + self.head_dim, + self.speculate_max_draft_token_num + 1, + ) + return res_encoder + else: + return res_decoder diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index e0175573fa3..d5d6c45afa7 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,9 @@ """ from .append_attention import append_attention, append_attention_with_output +from .config_for_attention import config_for_attention +from .decode_unified_attention import decode_unified_attention +from .decoder_write_cache_with_rope import decoder_write_cache_with_rope from .flash_attn_v4 import flash_attn_v4 from .flash_mask_attention import flash_mask_attention from .get_attn_mask_q import get_attn_mask_q @@ -37,4 +40,7 @@ "flash_attn_v4", "flash_mask_attention", "get_attn_mask_q", + "config_for_attention", + "decoder_write_cache_with_rope", + "decode_unified_attention", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py new file mode 100644 index 00000000000..d8226aad4b1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + config_for_attention as config_for_attention_cuda, + ) + + +def config_for_attention( + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, + cache_quant_type: str = "none", + group_size: int = 1, + kv_num_heads: int = 1, + max_tokens_per_batch: int = 1, +): + """ + append_attention + """ + if current_platform.is_cuda(): + config_for_attention_cuda( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_indices, + num_blocks, + chunk_size, + max_len_tensor_cpu, + cache_quant_type, + group_size, + kv_num_heads, + max_tokens_per_batch, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py new file mode 100644 index 00000000000..fedfc33dc7c --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py @@ -0,0 +1,105 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decode_unified_attention as decode_unified_attention_cuda, + ) + + +def decode_unified_attention( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + set_max_lengths: paddle.Tensor, + attn_mask: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + fmha_out: Optional[paddle.Tensor] = None, + cache_quant_type: str = "none", + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + max_tokens_per_batch: int = 1, + causal: bool = True, + sliding_window: int = 0, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + out = decode_unified_attention_cuda( + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + set_max_lengths, + attn_mask, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + mask_offset, + sinks, + fmha_out, + cache_quant_type, + max_input_length, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + causal, + sliding_window, + ) + return out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py new file mode 100644 index 00000000000..b10f6cd1bf6 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decoder_write_cache_with_rope as decoder_write_cache_with_rope_cuda, + ) + + +def decoder_write_cache_with_rope( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + set_max_lengths: paddle.Tensor, + rotary_embs: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + speculate_decoder: bool = False, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + qkv_out = decoder_write_cache_with_rope_cuda( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + set_max_lengths, + rotary_embs, + qkv_bias, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + speculate_decoder, + ) + return qkv_out + else: + raise NotImplementedError diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 70c626df28f..41b0da93819 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -407,6 +407,19 @@ def _initialize_attn_backend( self.target_model_inputs["kv_num_blocks_x_cpu"] ).cpu() + # Decode attention split ops buffers + if ( + "decode_block_indices" in self.target_model_inputs + and self.target_model_inputs["decode_block_indices"] is not None + ): + self.model_inputs["decode_block_indices"] = self.target_model_inputs["decode_block_indices"] + + self.model_inputs["decode_num_blocks"] = self.target_model_inputs["decode_num_blocks"] + self.model_inputs["decode_chunk_size"] = self.target_model_inputs["decode_chunk_size"] + self.model_inputs["decode_tmp_workspace"] = self.target_model_inputs["decode_tmp_workspace"] + self.model_inputs["decode_tmp_m"] = self.target_model_inputs["decode_tmp_m"] + self.model_inputs["decode_tmp_d"] = self.target_model_inputs["decode_tmp_d"] + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -678,6 +691,15 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.model_inputs: + self.forward_meta.decode_block_indices = self.model_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.model_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.model_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.model_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.model_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.model_inputs["decode_tmp_d"] + # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index aec0be0e746..a6033693542 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1423,6 +1423,15 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): device_routing_buffer=device_routing_buffer, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.share_inputs: + self.forward_meta.decode_block_indices = self.share_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.share_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.share_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.share_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.share_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.share_inputs["decode_tmp_d"] + dist_status = self.collect_distributed_status() if_only_decode = dist_status.if_only_decode @@ -1651,6 +1660,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 241ccaf6b71..e09d5c81193 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -205,6 +205,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers (initialized by _initialize_attn_backend) + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Initialize thinking related buffers self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool") @@ -814,6 +821,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Input tokens self.draft_tokens = paddle.full( diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 9b9bbe2bb76..2673386a927 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1452,6 +1452,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) diff --git a/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py new file mode 100644 index 00000000000..0083d70e769 --- /dev/null +++ b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + is_port_open, +) + +os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN" +os.environ["FLAGS_flash_attn_version"] = "3" +os.environ["USE_DECODE_UNIFIED_ATTENTION"] = "1" + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean() + + print("log dir clean ") + if os.path.exists("log") and os.path.isdir("log"): + shutil.rmtree("log") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle") + else: + model_path = "./ernie-4_5-21b-a3b-bf16-paddle" + mtp_model_path = os.path.join(model_path, "mtp") + speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path} + + log_path = "server.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "2", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "32768", + "--max-num-seqs", + "128", + "--quantization", + "wint4", + "--speculative-config", + json.dumps(speculative_config), + "--graph-optimization-config", + '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + ] + + # Start subprocess in new process group + # 清除log目录 + if os.path.exists("log"): + shutil.rmtree("log") + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(300): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"Server is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + print(f"server (pid={process.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. + """ + return {"Content-Type": "application/json"} + + +def send_request(url, payload, timeout=60): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + # print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + # print("Prefill Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_mtp_accept_ratio(api_url): + """测试mtp接受率""" + payload = { + "model": "default", + "messages": [ + { + "role": "user", + "content": "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、" + "公共管理等多领域。在理论层面,COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性" + "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估" + "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等,并衍生出风险共享、" + "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展," + "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理" + "的文献研究综述,300字以内", + }, + ], + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "temperature": 0, + "seed": 23, + "top_p": 0, + } + + print("fastdeploy answer is :") + + try: + # TODO: 第一次和第二次存在diff,后面正常,暂时多请求一次 + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + for idx, chunk in enumerate(chunks): + print(f"\nchunk[{idx}]:\n{json.dumps(chunk, ensure_ascii=False)}") + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics = chunks[-2]["choices"][0]["speculate_metrics"] + except Exception as e: + print(f"解析失败: {e}") + print("\nresult:\n", result) + + baseline = ( + "国外项目风险管理研究起步早、体系成熟。" + "早期聚焦保险与金融领域,后拓展至多领域。" + "理论层面,COSO《企业风险管理——整合框架》及ISO31000标准提供系统性指导," + "强调全流程管理。" + "风险识别方法多样,如故障树、事件树分析;" + "评估常用VaR模型、蒙特卡洛模拟等量化工具。" + "应对策略丰富,涵盖规避、转移等基本策略及风险共享、升级等复杂策略。" + "组织文化与管理层支持对风险管理有效性影响大。" + "近年来,科技发展促使人工智能、大数据等融入," + "推动风险管理向智能化、自动化迈进 。" + ) + + baseline_ratio = { + "accepted_tokens": 130, + "rejected_tokens": 20, + "accept_ratio": 0.42307692307692313, + "average_accept_length": 1.7333333333333334, + "accepted_tokens_per_head": [75, 55], + "accept_ratio_per_head": [0.7333333333333333], + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result_2 = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics_2 = chunks[-2]["choices"][0]["speculate_metrics"] + print("chunks:", chunks[-2]) + print("baseline", speculate_metrics) + print("speculate_metrics_2", speculate_metrics_2) + assert result_2 == baseline, f"与baseline存在diff,result_2: {result}\n baseline: {baseline}" + assert speculate_metrics_2 == baseline_ratio, ( + f"speculate_metrics存在diff," f"speculate_metrics_2: {speculate_metrics_2}\n " f"baseline: {baseline_ratio}" + ) + assert speculate_metrics_2["accept_ratio"] > 0, "accept_ratio异常" + prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] + cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] + assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" diff --git a/tests/operators/attention/test_decode_unified_attention_c16.py b/tests/operators/attention/test_decode_unified_attention_c16.py new file mode 100644 index 00000000000..0d17d17ccd6 --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c16.py @@ -0,0 +1,868 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention as append_attention_op, +) +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + emb = paddle.unsqueeze(emb, 2) + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, start_pos=0): + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, start_pos : start_pos + seq, ...] + sin = sin[:, :, start_pos : start_pos + seq, ...] + cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos)) + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def naive_attention_impl(query, key, value, cache_k=None, cache_v=None, mask=None, scale=1.0): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if cache_k is not None: + cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim]) + cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1]) + cache_k = cache_k.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([cache_k, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if cache_v is not None: + cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim]) + cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1]) + cache_v = cache_v.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([cache_v, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + if mask is not None: + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len): + """Read K/V from paged cache and return as [batch, num_head, seq_len, dim_head].""" + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeUnifiedAttentionC16(unittest.TestCase): + """Base test class for decode append attention with cache_quant_type='none' (fp16/bf16 KV cache). + + Uses append_attention for prefill (verified correct by test_append_attention_c16.py) + and then tests decode_unified_attention (new split ops) against the same naive reference. + + Subclasses override setUp to vary batch_size, max_tokens_per_batch, dtype, etc. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + + # Use small seq_len for fast testing; can increase later + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def init_tensor(self): + self.rope = RopeEmbedding(self.use_neox_rotary_style) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + + # Encoder phase: prefill with seq_len tokens + self.enc_q, self.enc_k, self.enc_v, self.enc_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.seq_len, + self.head_dim, + self.place, + self.dtype, + ) + + # Decoder phase: max_tokens_per_batch decode tokens + self.dec_q, self.dec_k, self.dec_v, self.dec_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + + def _get_block_shape_buffers(self, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time): + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + return { + "decoder_batch_ids": decoder_batch_ids, + "decoder_tile_ids_per_batch": decoder_tile_ids_per_batch, + "decoder_num_blocks_cpu": decoder_num_blocks_cpu, + "encoder_batch_ids": encoder_batch_ids, + "encoder_tile_ids_per_batch": encoder_tile_ids_per_batch, + "encoder_num_blocks_cpu": encoder_num_blocks_cpu, + "kv_batch_ids": kv_batch_ids, + "kv_tile_ids_per_batch": kv_tile_ids_per_batch, + "kv_num_blocks_x_cpu": kv_num_blocks_x_cpu, + "max_len_tensor_cpu": max_len_tensor_cpu, + } + + def run_append_attention( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention op.""" + buffers = self._get_block_shape_buffers(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time) + + qkv_copy = copy.deepcopy(qkv) + cache_k_copy = copy.deepcopy(cache_k) + cache_v_copy = copy.deepcopy(cache_v) + + out = append_attention_op( + qkv_copy, + cache_k_copy, + cache_v_copy, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + -1, + 64, + 16, + 1024, + self.max_model_len, + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, + self.max_tokens_per_batch > 1, # speculate_decoder + ) + return out, cache_k_copy, cache_v_copy + + def _build_decode_buffer(self): + """Build buffer for new split decode ops.""" + buffer = {} + min_chunk_size = 512 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + return buffer + + def _run_decode_unified_attention( + self, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run config_for_attention + decoder_write_cache_with_rope + decode_unified_attention.""" + buffer = self._build_decode_buffer() + + config_for_attention( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + + dec_cache_k = copy.deepcopy(cache_k) + dec_cache_v = copy.deepcopy(cache_v) + dec_qkv = copy.deepcopy(self.dec_qkv) + + decoder_write_cache_with_rope( + dec_qkv, + dec_cache_k, + dec_cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["max_len_tensor_cpu"], + self.rotary_embs, + None, # qkv_bias + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_unified_attention( + dec_qkv, + dec_cache_k, + dec_cache_v, + buffer["tmp_workspace"], + buffer["tmp_m"], + buffer["tmp_d"], + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + None, # attn_mask + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks + paddle.empty([dec_qkv.shape[0], self.q_num_head * self.head_dim], dtype=dec_qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + ) + return out, dec_cache_k, dec_cache_v + + def do_prefill_with_append_attention(self): + """Prefill using append_attention. Returns cache_k, cache_v after prefill.""" + seq_lens_encoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + seq_lens_decoder = paddle.to_tensor([0] * self.batch_size, "int32") + seq_lens_this_time = copy.deepcopy(seq_lens_encoder) + + batch_id_per_token, cu_seqlens_q, _ = get_padding_offset(self.batch_size, seq_lens_this_time) + + _, cache_k, cache_v = self.run_append_attention( + self.enc_qkv, + self.cache_k, + self.cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ) + return cache_k, cache_v + + def compute_naive_decode_ref(self, cache_k, cache_v): + """Compute naive reference for decode step using cache from paged cache.""" + # Read K/V from paged cache + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + cache_k, cache_v, self.batch_size, self.block_tables, self.seq_len + ) + + # Only use the first decode token (seq_lens_this_time=1 per batch) + dec_q = self.dec_q[:, :, :1, :] + dec_k = self.dec_k[:, :, :1, :] + dec_v = self.dec_v[:, :, :1, :] + + # Apply RoPE to decode Q/K at position seq_len + dec_q_rope, dec_k_rope = self.rope._apply_rope(self.rotary_embs, dec_q, dec_k, start_pos=self.seq_len) + + # Compute naive attention + out_ref = naive_attention_impl( + dec_q_rope, + dec_k_rope, + dec_v, + cache_k=naive_cache_k, + cache_v=naive_cache_v, + scale=self.softmax_scale, + ) + + dec_seq_lens_this_time = paddle.to_tensor([1] * self.batch_size, "int32") + dec_token_num = self.batch_size + _, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + out_ref = remove_padding(dec_seq_lens_this_time, dec_cu_seqlens_q, out_ref, dec_token_num) + return out_ref + + def test_naive_vs_append_attention_decode(self): + """Test: prefill with append_attention, then decode with append_attention. Compare to naive.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_dec, _, _ = self.run_append_attention( + self.dec_qkv, + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_dec_f = out_dec.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_dec_f = out_dec_f[:dec_token_num] + + np.testing.assert_allclose( + out_dec_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="append_attention decode output doesn't match naive reference", + ) + + def test_naive_vs_decode_unified_attention(self): + """Test: prefill with append_attention, then decode with new split decode ops.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with new split ops + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_decode_f = out.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention output doesn't match naive reference", + ) + + def test_append_vs_decode_unified_attention(self): + """Test: append_attention decode vs new split decode ops should produce same result.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append, _, _ = self.run_append_attention( + self.dec_qkv, + copy.deepcopy(cache_k), + copy.deepcopy(cache_v), + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with new split ops + out_decode, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_append_f = out_append.astype("float32").numpy() + out_decode_f = out_decode.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_append_f = out_append_f[:dec_token_num] + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention doesn't match append_attention decode", + ) + + +class TestDecodeUnifiedAttentionC16Speculate(TestDecodeUnifiedAttentionC16): + """Test with speculate decode: max_tokens_per_batch=2. + + When max_tokens_per_batch > 1, naive ref only computes 1 token while ops + compute multiple tokens. So naive comparison tests are skipped; only + append_attention vs decode_unified_attention comparison is kept. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_unified_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +class TestDecodeUnifiedAttentionC16MultiBatch(TestDecodeUnifiedAttentionC16): + """Test with multiple batches.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiHead(TestDecodeUnifiedAttentionC16): + """Test with multiple KV heads (GQA).""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 2 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16FP16(TestDecodeUnifiedAttentionC16): + """Test with float16 dtype.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "float16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16NoCausal(TestDecodeUnifiedAttentionC16): + """Test with causal=False.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = False + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiBatchSpeculate(TestDecodeUnifiedAttentionC16): + """Test with multi-batch + speculate decode. + + When max_tokens_per_batch > 1, the naive reference only computes 1 token + while ops compute multiple tokens. So we only compare append_attention vs + decode_unified_attention (both should produce same result), and skip the + naive comparison tests. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_unified_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/attention/test_decode_unified_attention_c8.py b/tests/operators/attention/test_decode_unified_attention_c8.py new file mode 100644 index 00000000000..d5ec0e5354c --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c8.py @@ -0,0 +1,921 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention, + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + pre_cache_len_concat, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, D/2] + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + # shape: [B, S, 1, D/2] + emb = paddle.unsqueeze(emb, 2) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeUnifiedAttention(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + def init_tensor(self): + # seq_lens + if self.seq_len_dec is None: + self.seq_lens_dec = [ + self.cache_len, + ] * self.batch_size + else: + self.batch_size = len(self.seq_lens_dec) + self.seq_lens_decoder = paddle.to_tensor( + self.seq_lens_dec, + "int32", + ) + if self.seq_lens_this_time is None: + self.seq_lens_this_time = [ + self.max_tokens_per_batch, + ] * self.batch_size + self.token_num = sum(self.seq_lens_this_time) + self.seq_lens_this_time = paddle.to_tensor(self.seq_lens_this_time, "int32") + + self.seq_lens_enc = [0] * self.batch_size + + self.seq_lens_encoder = paddle.to_tensor( + self.seq_lens_enc, + "int32", + ) + + # self.qkv = paddle.rand([self.token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim], dtype=self.dtype) + self.q, self.k, self.v, self.qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + self.qkv = paddle.to_tensor(self.qkv, dtype=self.dtype) + + # qk_norm + self.q_norm_weight = None + self.k_norm_weight = None + if self.use_qk_norm: + q_norm_weight_np = np.random.random([self.head_dim]) / 10 + k_norm_weight_np = np.random.random([self.head_dim]) / 10 + self.q_norm_weight = paddle.to_tensor(q_norm_weight_np, dtype="float32") + self.k_norm_weight = paddle.to_tensor(k_norm_weight_np, dtype="float32") + + # rotary embedding + self.rope = RopeEmbedding(False) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache_kv && scale + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + + if self.cache_quant_type == "block_wise_fp8": + self.cache_scale_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_k_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_v_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_k_out_scale = None + self.cache_v_out_scale = None + else: + self.cache_k_scale = ( + self.quant_max_bound / self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) + ).astype(self.dtype) + self.cache_v_scale = ( + self.quant_max_bound / self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) + ).astype(self.dtype) + + self.cache_k_out_scale = ( + self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) / self.quant_max_bound + ).astype(self.dtype) + self.cache_v_out_scale = ( + self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) / self.quant_max_bound + ).astype(self.dtype) + + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + + ( + self.batch_id_per_token, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, self.seq_lens_this_time) + + # mask offset + self.mask_offset = None + if self.use_mask_offset: + self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32") + for i in range(self.batch_size): + self.mask_offset[i * 2] = 0 + self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1 + + # buffer + self.buffer = {} + min_chunk_size = 512 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + self.group_size = self.q_num_head // self.kv_num_head + q_tile_size = 16 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + # block_indices: Launched block's indices with 4 dimensions [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] in decode append attention backend + self.buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + # num_blocks: Number of Launched blocks in decode append attention backend, researched by config_for_attention op + self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + # chunk_size: Chunk size for split kv cache in decode append attention backend, researched by config_for_attention op + self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + # tmp_workspace: Workspace tensor for temporary store the result before merging in decode append attention backend + self.buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + # tmp_m: Tmp_m tensor for temporary store the max value before merging in decode append attention backend + self.buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + # tmp_d: Tmp_d tensor for temporary store the exponential sum before merging in decode append attention backend + self.buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + + def append_attention_with_args( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention with explicit arguments.""" + # buffer + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + out = append_attention( + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + max_len_tensor_cpu, + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + self.q_norm_weight, + self.k_norm_weight, + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + -1, + 64, + 16, + self.max_model_len, + 1024, + self.max_tokens_per_batch, + self.causal, + self.max_tokens_per_batch > 1, + self.sliding_window, + ) + return out, cache_k, cache_v + + def append_attention(self): + """Convenience wrapper using default self members.""" + return self.append_attention_with_args( + copy.deepcopy(self.qkv), + copy.deepcopy(self.cache_k), + copy.deepcopy(self.cache_v), + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + ) + + def decode_unified_attention(self): + paddle.disable_static() + + config_for_attention( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + # print(f"num_blocks: {self.buffer['num_blocks']}") + decoder_write_cache_with_rope( + self.qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["max_len_tensor_cpu"], + self.rotary_embs, # rotary_embs + None, # qkv_bias + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + self.q_norm_weight, # q_norm_weight + self.k_norm_weight, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_unified_attention( + self.qkv, + self.cache_k, + self.cache_v, + self.buffer["tmp_workspace"], + self.buffer["tmp_m"], + self.buffer["tmp_d"], + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], # set_max_lengths + None, # attn_mask + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks # sinks + paddle.empty([self.qkv.shape[0], self.q_num_head * self.head_dim], dtype=self.qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + self.sliding_window, + ) + return self.qkv, out + + def prefill(self): + # init seq_len + seq_lens_encoder = copy.deepcopy(self.seq_lens_decoder) + seq_lens_decoder = paddle.zeros([self.batch_size], dtype="int32") + seq_lens_this_time = seq_lens_encoder + token_num = seq_lens_this_time.sum().item() + qkv_np = np.random.random([token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim]) - 0.5 + qkv = paddle.to_tensor(qkv_np, dtype=self.dtype) + + ( + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(self.batch_size, seq_lens_this_time) + # buffer + decode_max_tile_size = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + ( + cu_seqlens_k, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + kv_token_num_cpu, + ) = pre_cache_len_concat( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + max_len_tensor_cpu[2], + self.block_size, + ) + q, k, v, _ = gqa_rope_write_cache( + qkv, + self.cache_k, + self.cache_v, + cu_seqlens_q, + cu_seqlens_k, + self.rotary_embs, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + self.block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + self.q_norm_weight, + self.k_norm_weight, + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + kv_token_num_cpu[0].item(), + self.max_model_len, + self.rms_norm_eps, + False, # use_neox_rotary_style + self.cache_quant_type, + self.rope_3d, + ) + + k = k.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + v = v.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + return k, v + + def test_all(self): + """Compare append_attention vs decode_unified_attention output for consistency.""" + # Step 1: Prefill - just write K/V to cache via gqa_rope_write_cache + self.prefill() + + # Step 2: Decode with append_attention (copy cache so it's not modified) + dec_seq_lens_encoder = paddle.zeros([self.batch_size], dtype="int32") + dec_seq_lens_decoder = copy.deepcopy(self.seq_lens_decoder) + + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, dtype="int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append_dec, _, _ = self.append_attention_with_args( + copy.deepcopy(self.qkv), + copy.deepcopy(self.cache_k), + copy.deepcopy(self.cache_v), + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with decode_unified_attention (uses self.cache_k/v directly) + _, out_decode = self.decode_unified_attention() + + # Step 4: Compare + out_append_f = out_append_dec.astype("float32").numpy() + out_decode_f = out_decode.astype("float32").numpy() + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention output doesn't match append_attention output", + ) + + +class TestDecodeUnifiedAttentionMultiBatch(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 60 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionSpeculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionMultiHead(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionMultiSpeculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionSpeculateBs128Mtp4(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 128 + self.max_tokens_per_batch = 4 + self.cache_len = 508 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 2048 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8MultiBatch(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 60 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8Speculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionQKNorm(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = True + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +if __name__ == "__main__": + unittest.main() From c52b06310675a1a3a1ae8d0f6d472a4ae47668c1 Mon Sep 17 00:00:00 2001 From: sunxin <68891411+Sunny-bot1@users.noreply.github.com> Date: Tue, 26 May 2026 17:35:40 +0800 Subject: [PATCH 131/143] [Cherry-Pick][Optimization][Speculative Decoding]opt mtp logprob (#7883) (#7884) * opt mtp logprob * fix * fix test and log * fix bits * Adapt logprobs baseline update in test_ernie_21b_mtp_multistep.py --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> --- .../mtp_save_first_token_with_topk.cc | 32 ++++++++--------- .../speculate_get_output_with_topk.cc | 8 +++-- .../speculate_save_output_with_topk.cc | 27 +++++++------- fastdeploy/output/token_processor.py | 36 ++++++++++++------- fastdeploy/worker/gpu_model_runner.py | 14 +++----- tests/e2e/test_ernie_21b_mtp_multistep.py | 4 +-- tests/output/test_process_batch_output.py | 19 +++++----- tests/output/test_token_processor.py | 4 +-- 8 files changed, 76 insertions(+), 68 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc index 02203a51cff..0ec49b854ae 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -119,9 +119,11 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i] || @@ -139,29 +141,24 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; if (j == 0) { // first token has full logprobs - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { - // only for first token - cur_tokens[k] = - (int)logprob_token_ids_data[(token_offset + j) * - (SPEC_LOGPROB_K + 1) + - k]; - cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; @@ -174,7 +171,8 @@ void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", message_flag: " << msg_sed.meta[1] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 4fd7d4103c4..3e5ed2430b0 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -75,8 +75,11 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, int bsz = msg_rcv.meta[2]; output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + // Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from + // meta[1]. Keep packed value; Python unpacks message_flag and actual_topk. output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + int actual_topk = msg_rcv.meta[1] >> 8; int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; for (int i = 0; i < bsz; i++) { @@ -89,7 +92,7 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; for (int j = 0; j < cur_token_num; j++) { - for (int k = 0; k < real_k + 1; k++) { + for (int k = 0; k < actual_topk; k++) { cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = @@ -102,7 +105,8 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, #ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << output_tokens_data[0] - << ", message_flag: " << output_tokens_data[1] + << ", message_flag: " << (output_tokens_data[1] & 0xFF) + << ", max_num_logprobs: " << (output_tokens_data[1] >> 8) << ", bsz: " << output_tokens_data[2] << std::endl; for (int i = 0; i < output_tokens_data[2]; i++) { int cur_token_num = output_tokens_data[3 + i]; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 0b3de384cee..a11897b7ff3 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -121,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i]) { @@ -139,24 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; - for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else if (k < max_num_logprobs) { + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; + } else { cur_tokens[k] = (int) - logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k]; cur_scores[k] = - logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + - k]; - } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; @@ -165,7 +163,8 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", message_flag: " << msg_sed.meta[1] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 8b06c96c9d8..21e29e4c571 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -796,12 +796,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, metrics=None, ) - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + tokens_i = tokens[i].tolist() + scores_i = scores[i].tolist() + ranks_i = ranks[i].tolist() + token_ids = [row[0] for row in tokens_i[: accept_num[i]]] for batch_token_index in range(len(token_ids)): - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_i[batch_token_index][0] + topk_token_ids = tokens_i[batch_token_index] + topk_logprobs = scores_i[batch_token_index] + sampled_rank = ranks_i[batch_token_index] if result.outputs.draft_top_logprobs is None: result.outputs.draft_top_logprobs = LogprobsLists( @@ -828,16 +831,19 @@ def _process_batch_output(self): mtype = 3 if self.cfg.speculative_config.method: if self.use_logprobs: - mtype = int(self.output_tokens[1, 0].item()) + # meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits). + packed_meta1 = int(self.output_tokens[1, 0].item()) + mtype = packed_meta1 & 0xFF + actual_topk = packed_meta1 >> 8 batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( [batch, MAX_DRAFT_TOKENS, K + 1] - ) + )[:, :, :actual_topk] scores = ( self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] .numpy() - .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + .reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk] ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) @@ -846,6 +852,10 @@ def _process_batch_output(self): batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) self.postprocess(batch_result, mtype) return + # Pre-convert full arrays to Python lists once for MTP target token path. + tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -914,7 +924,7 @@ def _process_batch_output(self): llm_logger.info(f"recovery stop signal found at task {task_id}") token_ids = [RECOVERY_STOP_SIGNAL] elif self.use_logprobs: - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]] else: token_ids = tokens[ 2 @@ -1033,10 +1043,10 @@ def _process_batch_output(self): task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_lists[i][batch_token_index][0] + topk_token_ids = tokens_lists[i][batch_token_index] + topk_logprobs = scores_lists[i][batch_token_index] + sampled_rank = ranks_list[i][batch_token_index] else: # Use pre-converted lists (batch .tolist() done before the loop). result.outputs.logprob = scores_lists[i][0] diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a6033693542..69ef6a6bbe3 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1226,15 +1226,11 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs ) if logprobs_reqs: - self.max_logprobs = ( - max( - [ - self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs - for req in logprobs_reqs - ] - ) - if not self.speculative_decoding - else 20 + self.max_logprobs = max( + [ + self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs + for req in logprobs_reqs + ] ) elif self.enable_logprob: self.max_logprobs = None if not self.speculative_decoding else 0 diff --git a/tests/e2e/test_ernie_21b_mtp_multistep.py b/tests/e2e/test_ernie_21b_mtp_multistep.py index 8c4e3b6bab4..9f84b495f8b 100644 --- a/tests/e2e/test_ernie_21b_mtp_multistep.py +++ b/tests/e2e/test_ernie_21b_mtp_multistep.py @@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url): if os.getenv("BASELINE") == "1": baseline_manager.save("base_21b_step3", result) baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2) - baseline_manager.save("base_21b_logprobs_step3", logprobs_2) + baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2) baseline_result = baseline_manager.load("base_21b_step3") baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3") - baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3") + baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new") assert logprobs == logprobs_2, ( "logprobs 前后不一致\n" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 2853e47be15..04aa08935af 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -211,8 +211,9 @@ def test_speculative_decoding_use_logprobs(self): # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3, decode = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # meta[1] packs mtype (low 8 bits) and actual_topk (high 16 bits) + actual_topk = K + 1 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8))) # batch processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # accept_num @@ -244,12 +245,12 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.token_ids) == accept_num[i] assert len(request_output.outputs.top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk + assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.top_logprobs[2]) == accept_num[i] # mtype = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4)) + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8))) processor._process_batch_output() cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens for c in cached_generated_tokens.cache: @@ -258,8 +259,8 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.top_logprobs) == 3 assert len(request_output.outputs.draft_top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk + assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self): @@ -282,8 +283,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s # Set up output tokens with negative token # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # mtype target = 3, actual_topk packed in high bits + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8))) # batch = 2 (so batch_id=0 is < batch_size-1=1) processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 4240e84c75a..ca63c17c903 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -749,7 +749,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch(): ) processor._batch_result_buffer = [target] processor.cached_generated_tokens = mock.Mock() - processor.output_tokens[1, 0] = 4 + processor.output_tokens[1, 0] = 4 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 1 @@ -926,7 +926,7 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores(): task.trace_carrier = None rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 3 + processor.output_tokens[1, 0] = 3 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 2 token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3 From 261041b6a67ff4398101142b61919747bc2d46de Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 27 May 2026 14:01:00 +0800 Subject: [PATCH 132/143] [Cherry-Pick][Bugfix] Fix clear bug in RL causing CUDA error 700 during CUDAGraph recapture(#7934) (#7933) * fix clear bug in rl * fix: use self.max_chunk_tokens instead of fd_config.get_max_chunk_tokens() for buffer recreation fd_config.get_max_chunk_tokens() without mm_max_tokens_per_item arg may return a smaller value than the actual initial buffer size when enable_mm and mm_max_tokens_per_item is None. Use self.max_chunk_tokens which is already computed during __init__ and consistent with first CUDAGraph capture. --- fastdeploy/worker/input_batch.py | 42 ++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index e09d5c81193..687ae24088f 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -936,8 +936,12 @@ def reset_model_inputs(self) -> None: self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) fill_paddle_tensor(self, "input_ids_cpu", -1) - # acceptance rate decline when reset seq_lens_this_time - # self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"]) + # NOTE(fix): Must reset seq_lens_this_time_buffer to avoid stale values from previous + # RL round causing illegal memory access during CUDAGraph recapture (error 700). + # When draft_model_use_cudagraph=true, padding_cudagraph_inputs() uses the full + # seq_lens_this_time_buffer tensor; residual non-zero values in high-index slots + # (from previous round) will make attention kernel access invalid block_table entries. + fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0) self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"]) self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"]) @@ -946,8 +950,19 @@ def reset_model_inputs(self) -> None: self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"]) self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"]) self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu") + self.not_need_stop_device = paddle.to_tensor([False], dtype="bool") self.index_to_batch_id = {} if current_platform.is_cuda(): + # NOTE(fix): These tensors get reshaped during runtime inference, so we must + # recreate them at full initial size instead of cloning the (possibly resized) + # target_model_input_batch tensors. Otherwise CUDAGraph replay will write + # beyond tensor boundaries causing CUDA error(700). + max_num_seqs = self.scheduler_config.max_num_seqs + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") + self.batch_id_per_token_output = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32" + ) if "token_ids_all" in self.target_model_input_batch: self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"]) # TODO: delete pre_ids in mtp @@ -967,13 +982,24 @@ def reset_model_inputs(self) -> None: self.token_ids_all = None else: self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) - self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) - self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) - self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"]) - self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"]) - # Reset target hidden states - fill_paddle_tensor(self, "target_hidden_states", 0) + # NOTE(fix): These tensors are dynamically resized during runtime inference. + # Must recreate at full initial size to avoid CUDAGraph replay OOB access. + max_num_seqs = self.scheduler_config.max_num_seqs + self.ids_remove_padding = paddle.full([max_num_seqs * self.max_chunk_tokens], 0, dtype="int64") + self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32") + self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") + self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") + + # Reset target hidden states - must recreate at full size + self.target_hidden_states = paddle.full( + [ + self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens, + self.model_config.hidden_size, + ], + 0, + dtype="bfloat16", + ) # Reset rope embedding by recreating with default position_ids tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) From 8a1e71d0c13b4cda6086c7b3d26ae984bcf1099b Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 27 May 2026 20:34:24 +0800 Subject: [PATCH 133/143] [PD] PD send cache via storage & Refine swap_cache_layout op (#7839) * PD send cache via storage & Refine swap_cache_layout op * skip messager * up * consider write cache error * fix ci * up --- custom_ops/gpu_ops/swap_cache_layout.cu | 305 ++++++++++++++---- examples/cache_storage/run_03b_pd_storage.sh | 2 +- fastdeploy/cache_manager/cache_messager.py | 1 + .../cache_manager/cache_transfer_manager.py | 68 +++- .../cache_manager/prefix_cache_manager.py | 250 +++++++++++++- fastdeploy/engine/args_utils.py | 4 + fastdeploy/engine/common_engine.py | 26 ++ .../engine/common_engine_prepare_mixin.py | 4 +- .../engine/sched/resource_manager_v1.py | 8 +- fastdeploy/envs.py | 4 + .../layers/attention/append_attn_backend.py | 10 +- .../layers/attention/dsa_attention_backend.py | 11 +- .../layers/attention/flash_attn_backend.py | 10 +- .../attention/flash_mask_attn_backend.py | 11 +- .../layers/attention/mla_attention_backend.py | 15 +- fastdeploy/output/token_processor.py | 106 ++++-- fastdeploy/splitwise/splitwise_connector.py | 2 +- .../test_cache_transfer_manager.py | 9 +- .../test_prefix_cache_manager.py | 1 + 19 files changed, 705 insertions(+), 142 deletions(-) diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu index 62adccb2d04..08f64197f9b 100644 --- a/custom_ops/gpu_ops/swap_cache_layout.cu +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -15,74 +15,264 @@ #include "helper.h" #include "paddle/extension.h" -// #define SWAP_DEBUG +// D2H: Each thread block handles ALL layers for one swap block. +// This produces perfectly contiguous host writes (1 block × all layers), +// maximizing write-combining efficiency. +template +__global__ void swap_d2h_kernel(T** __restrict__ layer_ptrs, + T* __restrict__ cpu_buffer, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int block_idx = blockIdx.x; + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + int64_t num_vec_per_layer = (block_stride * sizeof(T)) / sizeof(float4); + + T* dst_base = cpu_buffer + (int64_t)block_idx * layer_num * block_stride; + + for (int layer_idx = 0; layer_idx < layer_num; layer_idx++) { + const T* src = layer_ptrs[layer_idx] + gpu_block * block_stride; + float4* dst4 = + reinterpret_cast(dst_base + layer_idx * block_stride); + const float4* src4 = reinterpret_cast(src); + + for (int64_t i = threadIdx.x; i < num_vec_per_layer; i += blockDim.x) { + dst4[i] = src4[i]; + } + } +} + +// H2D: scatter from contiguous staging buffer to scattered GPU layer tensors +template +__global__ void scatter_blocks_kernel(T** __restrict__ layer_ptrs, + const T* __restrict__ staging, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int pair_idx = blockIdx.x; + int block_idx = pair_idx / layer_num; + int layer_idx = pair_idx % layer_num; + + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + const T* src = staging + (int64_t)block_idx * layer_num * block_stride + + layer_idx * block_stride; + T* dst = layer_ptrs[layer_idx] + gpu_block * block_stride; + + int64_t num_vec = (block_stride * sizeof(T)) / sizeof(float4); + const float4* src4 = reinterpret_cast(src); + float4* dst4 = reinterpret_cast(dst); + + for (int64_t i = threadIdx.x; i < num_vec; i += blockDim.x) { + dst4[i] = src4[i]; + } +} + +static void* g_staging_buffer = nullptr; +static size_t g_staging_buffer_size = 0; +static void* g_device_block_ids = nullptr; +static size_t g_device_block_ids_size = 0; +static void* g_device_layer_ptrs = nullptr; +static size_t g_device_layer_ptrs_size = 0; + +static void ensure_staging_buffer(size_t required_size) { + if (g_staging_buffer_size < required_size) { + if (g_staging_buffer) cudaFree(g_staging_buffer); + cudaError_t err = cudaMalloc(&g_staging_buffer, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc staging buffer failed: %s", + cudaGetErrorString(err))); + g_staging_buffer_size = required_size; + } +} + +static void ensure_device_block_ids(size_t required_size) { + if (g_device_block_ids_size < required_size) { + if (g_device_block_ids) cudaFree(g_device_block_ids); + cudaError_t err = cudaMalloc(&g_device_block_ids, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device block_ids failed: %s", + cudaGetErrorString(err))); + g_device_block_ids_size = required_size; + } +} + +static void ensure_device_layer_ptrs(size_t required_size) { + if (g_device_layer_ptrs_size < required_size) { + if (g_device_layer_ptrs) cudaFree(g_device_layer_ptrs); + cudaError_t err = cudaMalloc(&g_device_layer_ptrs, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device layer_ptrs failed: %s", + cudaGetErrorString(err))); + g_device_layer_ptrs_size = required_size; + } +} + +static bool is_cpu_block_ids_sequential( + const std::vector& cpu_block_ids) { + if (cpu_block_ids.empty()) return true; + int64_t start = cpu_block_ids[0]; + for (size_t i = 1; i < cpu_block_ids.size(); i++) { + if (cpu_block_ids[i] != start + static_cast(i)) return false; + } + return true; +} template -void SwapCacheImpLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_pointer, // cpu - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int mode) { - /* - mode is 0: gpu to cpu; 1: cpu to gpu - - cache layout: layer_num * [block_num, head_num, block_size, head_dim] - scale layout: layer_num * [block_num, head_num, block_size] - cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim] - scale buffer layout: [block_num, layer_num, head_num, block_size] - */ +void SwapCacheImpLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_pointer, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int mode) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int64_t layer_number = cache_gpu_tensors.size(); int64_t cache_block_stride = 1; - for (int i = 1; i < cache_shape.size(); i++) { + for (size_t i = 1; i < cache_shape.size(); i++) { cache_block_stride *= cache_shape[i]; } + const int n_blocks = gpu_block_ids.size(); + if (n_blocks == 0) return; + auto stream = cache_gpu_tensors[0].stream(); - const cudaMemcpyKind copy_kind = - (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; - - for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) { - const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; - data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); - auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); - - for (int block_idx = 0; block_idx < gpu_block_ids.size(); block_idx++) { - auto cur_gpu_block_id = gpu_block_ids[block_idx]; - auto cur_cpu_block_id = cpu_block_ids[block_idx]; - auto* cache_gpu_ptr_now = - cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; - auto* cache_cpu_ptr_now = - cache_cpu_ptr + cur_cpu_block_id * cache_block_stride * layer_number + - layer_idx * cache_block_stride; - - cudaError_t status = cudaMemcpyAsync( - (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now - : cache_gpu_ptr_now, - (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now - : cache_cpu_ptr_now, - cache_block_stride * sizeof(DataType_), - copy_kind, - stream); + const size_t block_bytes = cache_block_stride * sizeof(DataType_); + const size_t total_bytes = (size_t)n_blocks * layer_number * block_bytes; + + bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids); + + // float4 vectorized kernels require block_bytes to be 16-byte aligned + // and cache_cpu_base to be 16-byte aligned for correct float4 access. + if (use_optimized && (block_bytes % sizeof(float4) != 0)) { + use_optimized = false; + } + if (use_optimized) { + int64_t cpu_start_block = cpu_block_ids[0]; + uintptr_t cpu_base_addr = + static_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride * sizeof(DataType_); + if (cpu_base_addr % sizeof(float4) != 0) { + use_optimized = false; + } + } + if (use_optimized) { + ensure_device_block_ids(n_blocks * sizeof(int64_t)); + ensure_device_layer_ptrs(layer_number * sizeof(DataType_*)); + + cudaError_t status = cudaMemcpyAsync(g_device_block_ids, + gpu_block_ids.data(), + n_blocks * sizeof(int64_t), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync block_ids H2D failed: %s", + cudaGetErrorString(status))); + + std::vector h_layer_ptrs(layer_number); + for (int64_t i = 0; i < layer_number; i++) { + h_layer_ptrs[i] = reinterpret_cast( + const_cast(cache_gpu_tensors[i].data())); + } + status = cudaMemcpyAsync(g_device_layer_ptrs, + h_layer_ptrs.data(), + layer_number * sizeof(DataType_*), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync layer_ptrs H2D failed: %s", + cudaGetErrorString(status))); + + int64_t cpu_start_block = cpu_block_ids[0]; + auto* cache_cpu_base = reinterpret_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride; + + int grid_size = n_blocks * layer_number; + + if (mode == 0) { + // GPU→CPU: direct kernel write to pinned host memory + // Multi-layer kernel: each block handles all layers for one swap block + swap_d2h_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + cache_cpu_base, + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } else { + // CPU→GPU: DMA memcpy to staging then scatter kernel + ensure_staging_buffer(total_bytes); + + status = cudaMemcpyAsync(g_staging_buffer, + cache_cpu_base, + total_bytes, + cudaMemcpyHostToDevice, + stream); PADDLE_ENFORCE_EQ(status, cudaSuccess, - phi::errors::External("cudaMemcpyAsync failed: %s", + phi::errors::External("cudaMemcpyAsync H2D failed: %s", cudaGetErrorString(status))); -#ifdef SWAP_DEBUG - cudaStreamSynchronize(stream); - std::cout << "mode:" << mode << ", layer_idx:" << layer_idx - << ", block_idx:" << block_idx << ", cache_cpu_ptr_now data:" - << static_cast(*cache_cpu_ptr_now) << std::endl; -#endif + scatter_blocks_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + reinterpret_cast(g_staging_buffer), + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } + } else { + const cudaMemcpyKind copy_kind = + (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; + for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) { + const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; + data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); + auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); + + for (int block_idx = 0; block_idx < n_blocks; block_idx++) { + auto cur_gpu_block_id = gpu_block_ids[block_idx]; + auto cur_cpu_block_id = cpu_block_ids[block_idx]; + auto* cache_gpu_ptr_now = + cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; + auto* cache_cpu_ptr_now = + cache_cpu_ptr + + cur_cpu_block_id * cache_block_stride * layer_number + + layer_idx * cache_block_stride; + + cudaError_t status = cudaMemcpyAsync( + (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now + : cache_gpu_ptr_now, + (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now + : cache_cpu_ptr_now, + block_bytes, + copy_kind, + stream); + PADDLE_ENFORCE_EQ(status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync failed: %s", + cudaGetErrorString(status))); + } } } + cudaError_t sync_status = cudaStreamSynchronize(stream); PADDLE_ENFORCE_EQ(sync_status, cudaSuccess, @@ -90,15 +280,14 @@ void SwapCacheImpLayout( cudaGetErrorString(sync_status))); } -void SwapCacheLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_ptrs, // cpu memory pointer - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int rank, - int mode) { - cudaSetDevice(rank); // used for distributed launch +void SwapCacheLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_ptrs, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int rank, + int mode) { + cudaSetDevice(rank); assert(cache_gpu_tensors.size() > 0); switch (cache_gpu_tensors[0].dtype()) { case paddle::DataType::BFLOAT16: diff --git a/examples/cache_storage/run_03b_pd_storage.sh b/examples/cache_storage/run_03b_pd_storage.sh index 5577a0ebf27..c940fe9a8ef 100644 --- a/examples/cache_storage/run_03b_pd_storage.sh +++ b/examples/cache_storage/run_03b_pd_storage.sh @@ -18,7 +18,7 @@ metadata_port=15002 export MOONCAKE_MASTER_SERVER_ADDR="${master_ip}:${master_port}" export MOONCAKE_METADATA_SERVER="http://${master_ip}:${metadata_port}/metadata" -export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" +export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" # 50GB # export MOONCAKE_PROTOCOL="tcp" export MOONCAKE_PROTOCOL="rdma" # export MOONCAKE_RDMA_DEVICES="mlx5_0" diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 4a188edd161..33407b785e5 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -705,6 +705,7 @@ def prefill_layerwise_send_cache_thread(self): try: batch_engine_signals = self.cache_prefilled_engine_ids_queue.get() self.engine_worker_queue.begin_send_cache_barrier.wait() + block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 36306ee5dc6..d3f26511372 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -1131,13 +1131,42 @@ def _run_write_back_storage( target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2) start_time = time.time() - self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes) + result = self.storage_backend.batch_set( + keys=keys, target_locations=target_locations, target_sizes=target_sizes + ) write_cost_time = time.time() - start_time + # Per-block success validation (same pattern as _run_read_storage) + # batch_set returns List[int]: 0 = success, negative = error + if k_scale_keys and v_scale_keys: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + k_scale_result = result[2 * block_num : 3 * block_num] + v_scale_result = result[3 * block_num :] + success_block_num = 0 + for k, v, ks, vs in zip(k_result, v_result, k_scale_result, v_scale_result): + if not (k == 0 and v == 0 and ks == 0 and vs == 0): + break + success_block_num += 1 + else: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if not (k == 0 and v == 0): + break + success_block_num += 1 + + if success_block_num < block_num: + logger.error( + f"_run_write_back_storage partial failure: " + f"{success_block_num}/{block_num} blocks written, task_id: {task_id}" + ) + logger.debug( f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" ) - return block_num + return success_block_num elif self.storage_backend_type == "attention_store": key_cache = [] @@ -1222,14 +1251,13 @@ def write_back_storage_task(self, task: WriteStorageTask): if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") - gpu_block_ids = [] else: try: k_cache_keys = k_cache_keys[match_block_num:] v_cache_keys = v_cache_keys[match_block_num:] k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None - gpu_block_ids = gpu_block_ids[match_block_num:] + write_gpu_block_ids = gpu_block_ids[match_block_num:] cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count write_block_num = self._run_write_back_storage( @@ -1240,19 +1268,28 @@ def write_back_storage_task(self, task: WriteStorageTask): v_cache_keys, k_scale_keys, v_scale_keys, - gpu_block_ids, + write_gpu_block_ids, cpu_block_ids, task.timeout, ) logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) - # Write routing data to storage (shares dedup with KVCache) - remaining_keys = task.keys[match_block_num:] - self._write_routing_to_storage(remaining_keys, gpu_block_ids) + # Check for partial write failure + if write_block_num < len(write_gpu_block_ids): + logger.error( + f"Partial write failure for task {task.task_id}: " + f"{write_block_num}/{len(write_gpu_block_ids)} blocks written" + ) + # Report: match_block_num (already cached) + write_block_num (newly written) + gpu_block_ids = gpu_block_ids[: match_block_num + write_block_num] + # Write routing data to storage only for actually-written blocks + written_block_ids = write_gpu_block_ids[:write_block_num] + remaining_keys = task.keys[match_block_num : match_block_num + len(written_block_ids)] + self._write_routing_to_storage(remaining_keys, written_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") - gpu_block_ids = [] + gpu_block_ids = gpu_block_ids[:match_block_num] finally: try: if (self.rank == 0) and self.storage_backend_type == "attention_store": @@ -1265,14 +1302,19 @@ def write_back_storage_task(self, task: WriteStorageTask): result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids) self.cache_task_queue.swap_to_storage_barrier.wait() - if self.rank == 0: # 只有当rank为0时执行同步操作 - self.cache_task_queue.swap_to_storage_barrier.reset() - self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 - logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") + self.cache_task_queue.put_transfer_done_signal(result) + logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") except Exception as e: logger.error( f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) + # Prevent caller from blocking forever: send empty done signal + try: + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, []) + self.cache_task_queue.swap_to_storage_barrier.wait() + self.cache_task_queue.put_transfer_done_signal(result) + except Exception as barrier_err: + logger.error(f"Failed to send failure signal for task {task.task_id}: {barrier_err}") def _do_swap_to_cpu_task( self, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index c41a6109029..b21b4349172 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -97,6 +97,7 @@ def __init__( self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend self.write_policy = self.cache_config.write_policy self.task_write_back_event = {} + self.storage_write_back_result = {} self.task_prefetch_event = {} self.storage_prefetch_block_ids = {} @@ -1186,9 +1187,15 @@ def write_cache_to_storage(self, request: Request): ) logger.debug(f"issue write storage task: {write_storage_task}") tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"write cache back to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) def write_cache_to_storage_decode(self, request: Request): @@ -1257,6 +1264,118 @@ def write_cache_to_storage_decode(self, request: Request): # Incremental logic is handled by CacheTransferManager.write_back_storage_task() req_id = request.request_id logger.info(f"[D instance] start write cache to storage, req_id: {req_id}, block num: {len(keys)}") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) + + write_storage_task = WriteStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + ) + + tic = time.time() + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) + cost_time = time.time() - tic + if not success: + logger.error( + f"[D instance] write cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + + def _compute_pd_storage_keys(self, request: Request, input_token_ids: list): + """ + Compute cache keys (including :partial:N suffix for last incomplete block) + for PD storage-pool mode. Used by both write_all_cache_to_storage (P/D) and + read_cache_from_storage_for_pd (D) to ensure consistent key computation. + + Args: + request: The request object (needed for get_block_hash_extra_keys). + input_token_ids: The token IDs to compute keys for. + + Returns: + list: The computed hash keys for each block. + """ + keys = [] + prefix_block_key = [] + block_size = self.config.cache_config.block_size + mm_idx = 0 + + for i in range(0, len(input_token_ids), block_size): + block_token_ids = input_token_ids[i : i + block_size] + actual_token_num = len(block_token_ids) + + if actual_token_num < block_size: + # Last incomplete block: compute key with actual tokens + partial marker + key = get_hash_str(block_token_ids, prefix_block_key) + key = f"{key}:partial:{actual_token_num}" + keys.append(key) + else: + # Full block: compute key normally + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=i, + end_idx=i + block_size, + mm_idx=mm_idx, + ) + prefix_block_key.extend(extra_keys) + key = get_hash_str(block_token_ids, prefix_block_key) + keys.append(key) + + prefix_block_key = [key] + + return keys + + def write_all_cache_to_storage(self, request: Request, include_output=True): + """ + Write ALL token cache (including last incomplete block) to storage. + Used in PD storage-pool mode where P writes to storage instead of RDMA to D, + and D writes back all cache (including output tokens) on request completion. + + Unlike write_cache_to_storage_decode which skips incomplete blocks, this method + writes the last incomplete block by padding it to block_size in the storage key + computation (using a ':partial:N' suffix on the key). + + The actual GPU block is still full-sized, so swap_cache_layout works normally. + + Args: + request: The request object. + include_output: If True, include output_token_ids in the write (used by D). + If False, only write prompt_token_ids (used by P). + + Returns: + bool: True if all blocks written successfully, False otherwise. + """ + if self.kvcache_storage_backend is None: + return True + + # 1. Get complete token_ids + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + + input_token_ids = token_ids + request.output_token_ids if include_output else token_ids + + # 2. Calculate cache keys using shared helper + keys = self._compute_pd_storage_keys(request, input_token_ids) + + if not keys: + return True + + # 3. Get corresponding gpu_block_ids + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Construct WriteStorageTask and send + req_id = request.request_id + logger.info( + f"[PD Storage] start write all cache to storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) write_storage_task = WriteStorageTask( task_id=req_id, @@ -1266,13 +1385,92 @@ def write_cache_to_storage_decode(self, request: Request): ) tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"[PD Storage] write all cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info( + f"[PD Storage] finish write all cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s" + ) + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + return success + + def read_cache_from_storage_for_pd(self, request: Request): + """ + PD storage-pool mode: D instance reads cache from storage that P wrote. + + This is different from request_match_blocks() storage read: + - Called on D instance after receiving first_token notification from P + - Reads ALL blocks (including last partial block) that P wrote to storage + - Target gpu_block_ids are D's pre-allocated blocks + + Returns: + list: gpu_block_ids if all blocks fetched successfully, + empty list if any block failed to fetch (caller should abort this request). + """ + if self.kvcache_storage_backend is None: + return [] + + # 1. Get token_ids (same as what P prefilled) + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + input_token_ids = token_ids + + # 2. Calculate cache keys using shared helper (same algorithm as write_all_cache_to_storage) + keys = self._compute_pd_storage_keys(request, token_ids) + + if not keys: + return [] + + # 3. gpu_block_ids = D's pre-allocated block_tables + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Issue ReadStorageTask + req_id = request.request_id + logger.info( + f"[PD Storage] D start read cache from storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + + read_task = ReadStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + start_read_block_idx=0, + ) + + tic = time.time() + storage_block_ids = self.issue_prefetch_storage_task(read_task, is_sync=True) + cost_time = time.time() - tic + + if len(storage_block_ids) != len(keys): + logger.error( + f"[PD Storage] D failed to read all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return [] + else: + logger.info( + f"[PD Storage] D finish reading the cache of all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return storage_block_ids def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): + """ + Issue a write-back storage task. + Returns True if all blocks written successfully (sync mode), True always (async mode). + """ if self.kvcache_storage_backend is None: - return + return True if not envs.FD_AS_ONLY_FLUSH and len(task.keys) != len(task.gpu_block_ids): err_msg = ( @@ -1285,15 +1483,37 @@ def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): self.task_write_back_event[task.task_id] = Event() self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: - self.wait_write_storage_task(task.task_id) + return self.wait_write_storage_task(task.task_id, expected_block_num=len(task.gpu_block_ids)) + return True - def wait_write_storage_task(self, req_id): + def wait_write_storage_task(self, req_id, expected_block_num=0, timeout=60.0): """ - Sync write back task + Sync write back task. + Returns True if all expected blocks written successfully across all TP ranks. + + Args: + req_id: request ID + expected_block_num: number of blocks expected to be written + timeout: max wait time in seconds """ if req_id in self.task_write_back_event: - self.task_write_back_event[req_id].wait() + success = self.task_write_back_event[req_id].wait(timeout=timeout) del self.task_write_back_event[req_id] + if not success: + logger.error(f"[PD Storage] write storage task timeout after {timeout}s, req_id: {req_id}") + self.storage_write_back_result.pop(req_id, None) + return False + # Check actual written block count vs expected + written_block_ids = self.storage_write_back_result.pop(req_id, []) + actual_written = len(written_block_ids) + if expected_block_num > 0 and actual_written < expected_block_num: + logger.error( + f"[PD Storage] write storage incomplete: {actual_written}/{expected_block_num} blocks, " + f"req_id: {req_id}" + ) + return False + return True + return True def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True): """ @@ -2226,8 +2446,16 @@ def recv_data_transfer_result(self): elif event_type.value == CacheStatus.GPU2STORAGE.value: logger.debug(f"recv_data_transfer_result: {data}") task_id, hash_keys, block_ids = data[1:] - if task_id in self.task_write_back_event: - self.task_write_back_event[task_id].set() + # Collect results from all TP ranks (same pattern as STORAGE2GPU path) + if task_id not in self.storage_write_back_result: + self.storage_write_back_result[task_id] = [] + saved_results = self.storage_write_back_result[task_id] + saved_results.append(block_ids) + if len(saved_results) == self.tensor_parallel_size: + # Take minimum across all ranks (conservative, same as read path) + self.storage_write_back_result[task_id] = min(saved_results, key=len) + if task_id in self.task_write_back_event: + self.task_write_back_event[task_id].set() else: ( event_type, diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index da4cc40ca6e..a96a6ace489 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -641,6 +641,10 @@ def __post_init__(self): "kvcache_storage_backend is only supported when ENABLE_V1_KVCACHE_SCHEDULER=1" ) + if envs.FD_PD_TRANSFER_VIA_STORAGE: + if self.kvcache_storage_backend is None: + raise ValueError("Must set kvcache_storage_backend when FD_PD_TRANSFER_VIA_STORAGE=1") + valid_model_impls = ["auto", "fastdeploy", "paddleformers"] if self.model_impl not in valid_model_impls: raise NotImplementedError( diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 0a49be9e73a..429cce3b0b7 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1971,6 +1971,32 @@ def _process_prefilled_requests(): self.token_processor.tokens_counter[request_id] = 1 if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance self.scheduler.put_results([req_output]) + + # Storage pool mode: D reads cache from storage before adding to running queue + if envs.FD_PD_TRANSFER_VIA_STORAGE: + request = self.resource_manager.requests[request_id] + self.llm_logger.info(f"[PD Storage] D reading cache from storage, request_id: {request_id}") + storage_block_ids = self.resource_manager.cache_manager.read_cache_from_storage_for_pd(request) + if not storage_block_ids: + self.llm_logger.error( + f"[PD Storage] D failed to read cache from storage, " f"request_id: {request_id}" + ) + self.resource_manager.pre_recycle_resource(request_id) + if request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[request_id] + req_output.error_code = 502 + req_output.error_msg = ( + f"PD Storage Error: D failed to read all blocks from storage, " + f"request_id: {request_id}" + ) + req_output.finished = True + self.scheduler.put_results([req_output]) + continue + self.llm_logger.info( + f"[PD Storage] D successfully read cache from storage, " + f"request_id: {request_id}, blocks: {len(storage_block_ids)}" + ) + self.resource_manager.add_prefilled_request(req_output) self.llm_logger.info(f"D has successfully added prefilled request, {request_id}") diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py index 60ccb7ccd09..c6ea2f3ee9d 100644 --- a/fastdeploy/engine/common_engine_prepare_mixin.py +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -226,8 +226,8 @@ def _fetch_request_prefill(self) -> bool: tasks.remove(tmp_task) self.resource_manager.pre_recycle_resource(tmp_task.request_id) - # Send cache info to messager - if tasks: + # Send cache info to messager (skip in storage pool mode - messager is bypassed) + if tasks and not envs.FD_PD_TRANSFER_VIA_STORAGE: self.split_connector.send_cache_info_to_messager(tasks, 0) # Fetch requests and add them to the scheduling queue diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index e7792f44dc0..04ae315f2c1 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -869,6 +869,7 @@ def get_enough_request(request, scheduled_reqs): # First, schedule the RUNNING requests. req_index = 0 num_decoding_req_nums = 0 + while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] need_block_num = self.need_block_num_signal.value[request.idx] @@ -1687,7 +1688,12 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled for req in need_postprocess_reqs: - if self.config.scheduler_config.splitwise_role == "decode": + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: P already writes cache in token_processor before notifying D, + # only D needs to write here (including output tokens generated during decode) + if self.config.scheduler_config.splitwise_role == "decode": + self.cache_manager.write_all_cache_to_storage(req) + elif self.config.scheduler_config.splitwise_role == "decode": # D instance uses simplified write method (does not rely on Radix Tree) self.cache_manager.write_cache_to_storage_decode(req) else: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index bd56b21b99e..b2f5ec74fa5 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -254,6 +254,10 @@ def _validate_split_kv_size(value: int) -> int: "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), + # PD disaggregation cache transfer mode: + # 0 (default): Direct transfer mode, P writes cache to D's GPU via RDMA/IPC + # 1: Storage pool mode, P writes cache to global storage pool, D reads from storage pool + "FD_PD_TRANSFER_VIA_STORAGE": lambda: int(os.getenv("FD_PD_TRANSFER_VIA_STORAGE", "0")), # Whether to enable KV cache lock, enforcing mutual exclusion between # PrefixCacheManager and Worker when accessing GPU KV cache. # Under certain DP+EP configurations, concurrent access (even read-only) diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 905e5941aa9..e003b1be089 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -234,7 +234,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -242,7 +246,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -330,7 +334,7 @@ def forward_mixed( # 64 is gpt-oss assert forward_meta.rotary_embs.shape[4] in [128, 32, 64] - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index acb73f5420a..88f28467569 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -28,6 +28,7 @@ if current_platform.is_cuda(): paddle.enable_compat(scope={"flash_mla"}) +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -243,7 +244,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -251,7 +256,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -304,7 +309,7 @@ def forward_mixed( # speculate_decoder = self.speculative_method is not None # speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 37401b1b314..d5044dbb544 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -304,7 +304,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -312,7 +316,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -339,7 +343,7 @@ def forward_mixed( ): metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 5b3c5ecdd3a..4c663c1a702 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -22,6 +22,7 @@ import paddle +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -146,7 +147,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # metadata only save pd_disaggregation info. metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -154,7 +159,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -197,7 +202,7 @@ def forward_mixed( cache_k_scales = getattr(layer, "cache_k_scale", None) cache_v_scales = getattr(layer, "cache_v_scale", None) - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 932c13decc3..c2a0c9e6a2d 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -34,6 +34,7 @@ logger.debug(f"flash_attention_v3_varlen not available: {e}") flash_attention_v3_varlen = None +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -358,7 +359,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -366,7 +371,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -405,7 +410,7 @@ def forward_extend( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -459,7 +464,7 @@ def forward_decode( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -549,7 +554,7 @@ def forward_mixed( speculate_decoder = self.speculative_method is not None speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 21e29e4c571..902cf55328a 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -635,44 +635,80 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False recycle resources """ if is_prefill: - start_time = time.time() - result.metrics.wait_for_sending_cache_time = time.time() - trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) - - while True: - finished_task_ids = self.engine_worker_queue.get_finished_req() - if len(finished_task_ids) > 0: - for finished_task_id in finished_task_ids: - llm_logger.info(f"finished_task_id: {finished_task_id}") - self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] - if task_id in self.prefill_result_status: - if self.prefill_result_status[task_id] != "finished": - result.error_code = 501 - result.error_msg = ( - f"PD Error: prefill failed to send cache to decode, " - f"{task_id}, {self.prefill_result_status[task_id]}" - ) - self.prefill_result_status.pop(task_id) - llm_logger.info( - f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: bypass CacheMessager entirely. + # At this point, all transformer layers are complete and KV cache is in GPU memory. + # Directly write cache to storage and send first token to D. + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + if result.error_code == 200: + write_cache_start_time = time.time() + llm_logger.info(f"[PD Storage] P writing cache to storage (direct), request_id: {task_id}") + write_success = self.resource_manager.cache_manager.write_all_cache_to_storage( + task, include_output=False ) - trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) - result.metrics.send_request_output_to_decode_time = time.time() - self.split_connector.send_first_token(task.disaggregate_info, [result]) - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager.finish_requests_async(task_id) + if not write_success: + result.error_code = 501 + result.error_msg = f"P instance failed to write cache to storage for request {task_id}" + llm_logger.error(f"[PD Storage] {result.error_msg}") else: - self.resource_manager.stop_flags[index] = True - self.resource_manager.tasks_list[index] = None - self.resource_manager._recycle_block_tables(task) - if task_id in self.resource_manager.req_dict: - del self.resource_manager.req_dict[task_id] - break + llm_logger.info( + f"[PD Storage] P finished writing cache to storage (direct), " + f"request_id: {task_id}, cost: {time.time()-write_cache_start_time:.5f}s" + ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) else: - # TODO: Refine checking sending cache and do not keep waiting - if time.time() - start_time > 30: - llm_logger.warning(f"wait for sending cache, {task_id}") - time.sleep(0.005) + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + else: + # RDMA/IPC mode: poll CacheMessager for transfer completion + start_time = time.time() + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + + while True: + finished_task_ids = self.engine_worker_queue.get_finished_req() + if len(finished_task_ids) > 0: + for finished_task_id in finished_task_ids: + llm_logger.info(f"finished_task_id: {finished_task_id}") + self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] + if task_id in self.prefill_result_status: + if self.prefill_result_status[task_id] != "finished": + result.error_code = 501 + result.error_msg = ( + f"PD Error: prefill failed to send cache to decode, " + f"{task_id}, {self.prefill_result_status[task_id]}" + ) + self.prefill_result_status.pop(task_id) + llm_logger.info( + f"wait for sending cache, request_id: {task_id}, " + f"cost seconds: {time.time()-start_time:.5f}" + ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + break + else: + # TODO: Refine checking sending cache and do not keep waiting + if time.time() - start_time > 30: + llm_logger.warning(f"wait for sending cache, {task_id}") + time.sleep(0.005) else: if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.finish_requests_async(task_id) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 9f896b694a3..acbe71411da 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -412,7 +412,7 @@ def _process_message(self, frames: List[bytes]): self.current_request_ids[task["request_id"]] = current_status if self.enable_decode_cache_task: del self.current_request_ids[task["request_id"]] - if current_status == "finished": + if current_status == "finished" and not envs.FD_PD_TRANSFER_VIA_STORAGE: self.engine_worker_queue.put_cache_info(payload) except Exception as e: diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 599e0b8c5e0..5d2a054761e 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -449,7 +449,7 @@ def test_write_back_storage_task_skips_cached_keys(self): self.manager._run_write_back_storage.assert_not_called() self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( - (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], []) + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], [0, 1]) ) def test_read_storage_task_no_matches(self): @@ -737,7 +737,7 @@ class LocalArgs(Args): def test_write_back_storage_task_nonzero_rank_no_signal(self): self.manager.cache_task_queue.swap_to_storage_barrier = MagicMock() self.manager.cache_task_queue.put_transfer_done_signal = MagicMock() - self.manager._run_write_back_storage = MagicMock() + self.manager._run_write_back_storage = MagicMock(return_value=1) self.manager.rank = 1 # Mock storage backend to return 0 matches (no keys exist) @@ -761,7 +761,10 @@ def test_write_back_storage_task_nonzero_rank_no_signal(self): [0], 0.1, ) - self.manager.cache_task_queue.put_transfer_done_signal.assert_not_called() + # After the refactor, the done signal is always sent regardless of rank. + self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "9", ["k1"], [0]) + ) def test_get_key_prefix_from_version(self): with patch("fastdeploy.cache_manager.cache_transfer_manager.yaml.safe_load") as mock_load: diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index f2a4a5fa116..8dd9b5162c4 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -1485,6 +1485,7 @@ def test_recv_data_transfer_result_handles_storage_events(self): (CacheStatus.STORAGE2GPU, "pref", ["h1"], [1, 2]), (CacheStatus.STORAGE2GPU, "pref", ["h2"], [1]), (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), + (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), ] manager.cache_task_queue = _FakeTransferQueue(payloads) with self.assertRaises(SystemExit): From 2b0fd532cbfc0ebc3801a30ce8a7d0780d99e99f Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Thu, 28 May 2026 10:38:32 +0800 Subject: [PATCH 134/143] [Cherry-Pick][Optimization]support fused noauxtc kernel on ep mode(#7936) (#7917) * support fused noauxtc kernel on ep mode * fix unit test --- fastdeploy/model_executor/layers/moe/ep.py | 5 +++++ tests/model_executor/test_ep.py | 1 + 2 files changed, 6 insertions(+) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 05c36a68f48..967c2a2fd02 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -27,6 +27,7 @@ import fastdeploy from fastdeploy import envs from fastdeploy.config import MoEPhase +from fastdeploy.platforms import current_platform from fastdeploy.utils import singleton @@ -531,6 +532,9 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): if layer.topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) score, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -540,6 +544,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( diff --git a/tests/model_executor/test_ep.py b/tests/model_executor/test_ep.py index 373e8899396..b099c7ad57e 100644 --- a/tests/model_executor/test_ep.py +++ b/tests/model_executor/test_ep.py @@ -419,6 +419,7 @@ def fake_get_moe_scores(*_args, **_kwargs): routed_scaling_factor=1.0, gate_correction_bias=None, renormalize=False, + fd_config=SimpleNamespace(scheduler_config=SimpleNamespace(enable_moe_scores_elementwise_fuse=False)), ) gate_out = paddle.randn([1, 4], dtype="float32") From 1e7ee2287581d8a334a03c2cd2503ed0e12fbe70 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Thu, 28 May 2026 15:10:33 +0800 Subject: [PATCH 135/143] [Cherry-Pick] [Optimization] TopP=1.0 using _random_sample (#7892) and Triton SamplerBackend (#7639) (#7910) * [CP][Feature] support new sampler backend with triton (#7639) * [Optimization] TopP=1.0 using _random_sample (#7892) * code check * add env FD_ENABLE_TOP_P_ONE_OPT control top_p=1 opt * defalut FD_ENABLE_TOP_P_ONE_OPT=0 * change FD_ENABLE_TOP_P_ONE_OPT=1 * fix mtp triton seed * change triton seed int64 * fix triton sampler * add seed for mtp triton sampler --------- Co-authored-by: Zero Rains Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> --- .flake8 | 1 + fastdeploy/envs.py | 4 +- .../model_executor/layers/sample/meta_data.py | 1 + .../layers/sample/ops/__init__.py | 7 +- .../layers/sample/ops/top_k_top_p_sampling.py | 36 +- .../layers/sample/ops/top_k_top_p_triton.py | 992 ++++++++++++++++++ .../model_executor/layers/sample/sampler.py | 219 +++- fastdeploy/worker/gpu_model_runner.py | 2 + fastdeploy/worker/input_batch.py | 3 + scripts/run_pre_ce.sh | 14 +- tests/layers/test_triton_sampler.py | 429 ++++++++ 11 files changed, 1652 insertions(+), 56 deletions(-) create mode 100644 fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py create mode 100644 tests/layers/test_triton_sampler.py diff --git a/.flake8 b/.flake8 index 1656330a998..eeec63740e8 100644 --- a/.flake8 +++ b/.flake8 @@ -5,3 +5,4 @@ max-line-length = 119 # E402: module level import not at top of file per-file-ignores = __init__.py:F401,F403,E402 + fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266 diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index b2f5ec74fa5..58edb8ca026 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -72,7 +72,7 @@ def _validate_split_kv_size(value: int) -> int: "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), # enable decode attention "USE_DECODE_UNIFIED_ATTENTION": lambda: bool(int(os.getenv("USE_DECODE_UNIFIED_ATTENTION", "0"))), - # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. + # Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently. "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), @@ -293,6 +293,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Whether to enable FP8 quantization with pow2scale. "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), + # Whether to enable top_p=1.0 optimization. + "FD_ENABLE_TOP_P_ONE_OPT": lambda: bool(int(os.getenv("FD_ENABLE_TOP_P_ONE_OPT", "1"))), } diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index e2ecb276957..b51ecb84010 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -42,6 +42,7 @@ class SamplingMetadata: step_idx: paddle.Tensor top_p: paddle.Tensor + top_p_list: Optional[list] = None # only GPU used bad_words_token_len: Optional[paddle.Tensor] = None top_k: Optional[paddle.Tensor] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index eb2b79927bd..3b272ede7b3 100644 --- a/fastdeploy/model_executor/layers/sample/ops/__init__.py +++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py @@ -23,7 +23,11 @@ speculate_get_accept_tokens_and_logits, speculate_insert_first_token, ) -from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling +from .top_k_top_p_sampling import ( + dispatch_top_k_renorm_probs, + min_p_sampling, + top_k_top_p_sampling, +) __all__ = [ "apply_penalty_multi_scores", @@ -33,4 +37,5 @@ "min_p_sampling", "speculate_get_accept_tokens_and_logits", "speculate_insert_first_token", + "dispatch_top_k_renorm_probs", ] diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index ff072e1a8ef..cc7c1c11277 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -34,6 +34,20 @@ def _reset_cuda_generator_for_determinism(): paddle.framework.core.default_cuda_generator(0).manual_seed(_DETERMINISTIC_RNG_SEED) +def dispatch_top_k_renorm_probs(probs, top_k): + try: + if current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import top_k_renorm_probs + else: + from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs + probs = top_k_renorm_probs(probs, top_k) + + except ImportError: + logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.") + + return probs + + def top_k_top_p_sampling( x: paddle.Tensor, top_p: paddle.Tensor, @@ -70,7 +84,6 @@ def top_k_top_p_sampling( """ top_p_class = envs.FD_SAMPLING_CLASS.lower() - topp_seed_device = None # In deterministic mode, reset CUDA generator offset before sampling. # paddle.tensor.top_p_sampling uses the global GPU generator offset even @@ -85,29 +98,17 @@ def top_k_top_p_sampling( _ = None else: if top_k_list and any(x > 0 for x in top_k_list): - try: - if current_platform.is_iluvatar(): - from fastdeploy.model_executor.ops.iluvatar import ( - top_k_renorm_probs, - ) - else: - from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs - x = top_k_renorm_probs(x, top_k) - except ImportError: - logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.") + x = dispatch_top_k_renorm_probs(x, top_k) if top_p_class == "air": _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode) elif top_p_class == "base_non_truncated": - if topp_seed is not None: - topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype) - topp_seed_device.copy_(topp_seed, False) _, ids = paddle.tensor.top_p_sampling( x, top_p, threshold=threshold, - topp_seed=topp_seed_device, + topp_seed=topp_seed, seed=seed, k=k, mode="non-truncated", @@ -122,14 +123,11 @@ def top_k_top_p_sampling( _, ids = native_top_p_sampling(x, top_p) else: - if topp_seed is not None: - topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype) - topp_seed_device.copy_(topp_seed, False) _, ids = paddle.tensor.top_p_sampling( x, top_p, threshold=threshold, - topp_seed=topp_seed_device, + topp_seed=topp_seed, seed=seed, k=k, mode="truncated", diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py new file mode 100644 index 00000000000..cc2fe4faafa --- /dev/null +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py @@ -0,0 +1,992 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +""" +Combined Top-K and Top-P Triton kernels. + +Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs +using Pivot-based Truncation and Selection" By Park et al. +(https://arxiv.org/abs/2602.01518) + +""" + +import warnings + +import paddle +from paddle.utils.deprecated import VisibleDeprecationWarning + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + enable_compat_on_triton_kernel, +) + +# Suppress the VisibleDeprecationWarning from use_triton_in_paddle that fires +# on every Triton kernel launch (paddle.device.cuda.current_stream / +# synchronize). In serving hot-paths this produces thousands of log lines per +# second and the I/O overhead alone can cause client-visible timeouts. +warnings.filterwarnings("ignore", category=VisibleDeprecationWarning) + +import triton # noqa: E402 +import triton.language as tl # noqa: E402 + +_TRITON_TABLE_CACHE: dict[tuple[paddle.device], tuple[paddle.Tensor, paddle.Tensor]] = {} +_TRITON_BUFFER_CACHE: dict[tuple[paddle.device, paddle.dtype, int], paddle.Tensor] = {} + +# fmt: off +_NORMAL_CDF_TO_SIGMA_TABLE = [ + 3.656, 3.650, 3.650, 3.650, 3.626, 3.626, 3.626, 3.514, 3.514, 3.503, + 3.503, 3.434, 3.434, 3.428, 3.428, 3.387, 3.380, 3.380, 3.376, 3.373, + 3.373, 3.356, 3.354, 3.354, 3.291, 3.249, 3.234, 3.214, 3.198, 3.198, + 3.185, 3.177, 3.177, 3.165, 3.164, 3.161, 3.138, 3.120, 3.115, 3.113, + 3.093, 3.066, 3.054, 3.043, 3.037, 3.023, 2.993, 2.991, 2.976, 2.970, + 2.952, 2.946, 2.932, 2.908, 2.902, 2.895, 2.886, 2.874, 2.861, 2.844, + 2.836, 2.810, 2.801, 2.790, 2.784, 2.779, 2.767, 2.757, 2.745, 2.733, + 2.723, 2.716, 2.693, 2.678, 2.671, 2.656, 2.649, 2.629, 2.611, 2.595, + 2.592, 2.585, 2.574, 2.550, 2.543, 2.534, 2.521, 2.518, 2.497, 2.485, + 2.468, 2.450, 2.441, 2.430, 2.412, 2.402, 2.389, 2.383, 2.377, 2.364, + 2.349, 2.338, 2.332, 2.319, 2.310, 2.301, 2.282, 2.274, 2.266, 2.250, + 2.242, 2.236, 2.226, 2.215, 2.207, 2.196, 2.179, 2.171, 2.162, 2.147, + 2.135, 2.121, 2.109, 2.095, 2.085, 2.073, 2.063, 2.045, 2.030, 2.016, + 2.003, 1.992, 1.983, 1.972, 1.960, 1.949, 1.940, 1.928, 1.912, 1.897, + 1.881, 1.869, 1.854, 1.838, 1.824, 1.807, 1.792, 1.779, 1.764, 1.751, + 1.739, 1.726, 1.711, 1.697, 1.685, 1.668, 1.652, 1.636, 1.622, 1.603, + 1.585, 1.568, 1.551, 1.534, 1.513, 1.499, 1.480, 1.464, 1.441, 1.422, + 1.394, 1.373, 1.347, 1.320, 1.296, 1.270, 1.246, 1.219, 1.190, 1.163, + 1.135, 1.104, 1.073, 1.041, 1.006, 0.969, 0.931, 0.894, 0.851, 0.806, + 0.757, 0.702, 0.643, 0.574, 0.498, 0.405, 0.288, 0.134, -0.110, -3.813 +] + +_PERCENTILE_TO_STD_TABLE = [ + 2.576, 2.319, 2.178, 2.064, 1.968, 1.892, 1.819, 1.757, 1.708, 1.659, + 1.616, 1.568, 1.526, 1.492, 1.456, 1.420, 1.382, 1.342, 1.309, 1.280, + 1.249, 1.221, 1.193, 1.169, 1.145, 1.121, 1.095, 1.073, 1.050, 1.030, + 1.008, 0.987, 0.966, 0.945, 0.926, 0.910, 0.891, 0.871, 0.854, 0.837, + 0.819, 0.803, 0.784, 0.767, 0.753, 0.734, 0.719, 0.702, 0.690, 0.675, + 0.658, 0.640, 0.625, 0.609, 0.595, 0.578, 0.564, 0.550, 0.537, 0.521, + 0.509, 0.495, 0.481, 0.466, 0.453, 0.439, 0.424, 0.410, 0.397, 0.383, + 0.370, 0.356, 0.343, 0.330, 0.316, 0.302, 0.289, 0.274, 0.261, 0.247, + 0.235, 0.223, 0.209, 0.196, 0.184, 0.172, 0.159, 0.149, 0.137, 0.124, + 0.112, 0.100, 0.086, 0.074, 0.062, 0.050, 0.035, 0.023, 0.009, -0.003, + -0.015, -0.027, -0.039, -0.052, -0.063, -0.074, -0.085, -0.097, -0.109, -0.122, + -0.134, -0.147, -0.158, -0.171, -0.184, -0.196, -0.210, -0.223, -0.235, -0.248, + -0.261, -0.275, -0.289, -0.302, -0.317, -0.328, -0.341, -0.353, -0.368, -0.382, + -0.396, -0.410, -0.426, -0.439, -0.452, -0.465, -0.480, -0.493, -0.507, -0.521, + -0.537, -0.551, -0.568, -0.582, -0.597, -0.614, -0.628, -0.643, -0.658, -0.673, + -0.691, -0.706, -0.721, -0.738, -0.754, -0.769, -0.789, -0.808, -0.824, -0.838, + -0.857, -0.877, -0.893, -0.912, -0.929, -0.947, -0.965, -0.983, -1.003, -1.027, + -1.050, -1.070, -1.092, -1.117, -1.139, -1.162, -1.189, -1.216, -1.241, -1.272, + -1.300, -1.330, -1.367, -1.404, -1.441, -1.485, -1.523, -1.564, -1.607, -1.658, + -1.710, -1.778, -1.832, -1.901, -1.978, -2.068, -2.174, -2.325, -2.577, -3.813 +] +# fmt: on + + +@triton.jit +def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel): + """Update running (min, count) of values above a pivot across tiles. + + Tracks the smallest value strictly above a pivot and how many times + it occurs. Called once per tile per pivot; the running state is + carried across tiles via `min_larger` / `num_min_larger`. + + Merge rule: + - tile min < running min → replace both + - tile min == running min → accumulate count + - tile min > running min → keep running values + """ + tile_min = tl.min(tl.where(above_mask, data, sentinel)) + tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9) + tile_cnt = tl.sum(tile_eq) + is_new = tile_min < min_larger + is_same = tl.abs(tile_min - min_larger) < 1e-9 + num_min_larger = tl.where(is_new, tile_cnt, num_min_larger + tile_cnt * is_same) + min_larger = tl.minimum(min_larger, tile_min) + return min_larger, num_min_larger + + +@enable_compat_on_triton_kernel +@triton.jit +def _topk_topp_kernel( + LOGITS, + BUFFER, + MASK_OUT, + PERCENTILE_TO_STD_TABLE, + NORMAL_CDF_TO_SIGMA_TABLE, + K, + P, + BATCH_SIZE, + VOCAB_SIZE: tl.constexpr, + MASK_VALUE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_TRUNC: tl.constexpr, + TOPK_ENABLED: tl.constexpr, + TOPP_ENABLED: tl.constexpr, + WRITE_MASK: tl.constexpr, +): + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, BATCH_SIZE, num_programs): + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + + final_pivot = -float("inf") + duplicate_logit = float("inf") + num_duplicate_logit = tl.zeros((), dtype=tl.uint32) + num_keep = tl.zeros((), dtype=tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + max_logit = -float("inf") + min_logit = float("inf") + + if TOPK_ENABLED: + k = tl.load(K + row_id) + if k < VOCAB_SIZE: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.15 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + num_finite_total = tl.zeros((), dtype=tl.uint32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk_mask = logits_blk > -float("inf") + finite_blk = tl.where(finite_blk_mask, logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + num_finite_total += tl.sum(finite_blk_mask & mask_n) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + # Second passes: Ternary search for pivots + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # Single fused pass: compute k_pivots_num, + # min_larger, and num_min_larger together to avoid + # a second data scan. See _update_min_larger_stats + # for the tile-level merge logic. + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")) + + above_0 = logits_blk2 > k_pivot_0 + above_1 = logits_blk2 > k_pivot_1 + k_pivots_num_0 += tl.sum(above_0) + k_pivots_num_1 += tl.sum(above_1) + + min_larger_0, num_min_larger_0 = _update_min_larger_stats( + logits_blk2, + above_0, + min_larger_0, + num_min_larger_0, + float("inf"), + ) + min_larger_1, num_min_larger_1 = _update_min_larger_stats( + logits_blk2, + above_1, + min_larger_1, + num_min_larger_1, + float("inf"), + ) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # Single fused pass over full vocab (same approach + # as the buffer path above). + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + above_0 = logits_blk2 > k_pivot_0 + above_1 = logits_blk2 > k_pivot_1 + k_pivots_num_0 += tl.sum(above_0) + k_pivots_num_1 += tl.sum(above_1) + + min_larger_0, num_min_larger_0 = _update_min_larger_stats( + logits_blk2, + above_0, + min_larger_0, + num_min_larger_0, + float("inf"), + ) + min_larger_1, num_min_larger_1 = _update_min_larger_stats( + logits_blk2, + above_1, + min_larger_1, + num_min_larger_1, + float("inf"), + ) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k only path. If there are fewer finite values + # than k (e.g. grammar mask), keep everything. + final_pivot = k_pivot if num_finite_total > k else -float("inf") + + if TOPP_ENABLED and num_finite_total > k: + #### TOP-P SAMPLING AFTER TOP-K #### + p = tl.load(P + row_id) + if p < 1.0: + min_logit = k_pivot + sum_exp_logits = 0.0 + num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Third pass: Calculate exp logits and sum, gather outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling for Top-k + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, + # retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load( + LOGITS_ROW + offs_n, + mask=mask_n, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate logit handling for Top-k + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, + tl.int32, + ) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and (p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and (p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k + Top-p path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + if TOPP_ENABLED and final_pivot == -float("inf"): + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk = tl.where(logits_blk > -float("inf"), logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + + outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) + + # Second pass: Calculate softmax and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + max_range = tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # Re-populate the buffer with full softmax probabilities + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + + # Sixth pass: Apply mask and store final output. + # If the pivot >= max logit (or is NaN), no token would + # survive the strict `>` keep_mask. Skip masking. + # Using `not <` instead of `>=` so that NaN is also caught. + if not (final_pivot < max_logit): + final_pivot = -float("inf") + elif final_pivot != -float("inf"): + if WRITE_MASK: + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + keep_mask = (logits_blk > final_pivot) & mask_n + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) + + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + if WRITE_MASK: + tl.store(MASK_ROW + offs_n, keep_mask, mask=mask_n) + + # When no masking was applied (final_pivot == -inf), all tokens are kept. + if WRITE_MASK and final_pivot == -float("inf"): + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + tl.store(MASK_ROW + offs_n, mask_n, mask=mask_n) + + +def apply_top_k_top_p_triton( + logits: paddle.Tensor, + k: paddle.Tensor | None, + p: paddle.Tensor | None, + mask_value: float = float("-inf"), + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k and top-p masking using Triton. + + Top-k is applied first (by logit value), then top-p is applied + to the remaining k values (by probability). + + Args: + logits: [batch_size, vocab_size] float32 tensor, modified in-place + k: [batch_size] int32 tensor of top-k values per row, or None to disable top-k + p: [batch_size] float32 tensor of top-p values per row (0 to 1), + or None to disable top-p + mask_value: Value for masked positions (default: -inf) + return_mask: If True, also return a bool mask [batch_size, vocab_size] + where True = retained token. The mask is computed inside the kernel + with zero extra memory bandwidth cost. + + Returns: + logits if return_mask is False, else (logits, mask). + """ + assert logits.ndim == 2 + assert logits.dtype == paddle.float32 + + batch_size, vocab_size = logits.shape + + topk_enabled = k is not None + topp_enabled = p is not None + + if batch_size == 0 or not (topk_enabled or topp_enabled): + if return_mask: + mask = paddle.ones(logits.shape, dtype=paddle.bool) + return logits, mask + return logits + + if k is not None: + assert k.ndim == 1 and k.shape[0] == batch_size + k_ptr = k.to(paddle.int32) + else: + k_ptr = logits # Dummy pointer (won't be read) + + if p is not None: + assert p.ndim == 1 and p.shape[0] == batch_size + p_ptr = p.to(paddle.float32) + else: + p_ptr = logits # Dummy pointer (won't be read) + + num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count + NUM_PROGRAMS = min(num_sm, batch_size) + + # Cache per-Triton Program buffer on each device. + buf_key = (logits.device, logits.dtype, vocab_size) + buffer = _TRITON_BUFFER_CACHE.get(buf_key) + if buffer is None or buffer.shape[0] < NUM_PROGRAMS: + size = min(triton.next_power_of_2(NUM_PROGRAMS), num_sm) + buffer = paddle.empty((size, vocab_size), dtype=logits.dtype) + _TRITON_BUFFER_CACHE[buf_key] = buffer + if buffer.shape[0] > NUM_PROGRAMS: + buffer = buffer[:NUM_PROGRAMS] + + # Allocate mask output if requested. + write_mask = return_mask + if write_mask: + mask_out = paddle.empty(logits.shape, dtype=paddle.int8) + else: + mask_out = logits # Dummy pointer (won't be written) + + # Cache lookup table entries on each device. + tables = _TRITON_TABLE_CACHE.get(logits.device) + if tables is None: + normal_cdf_to_sigma_table = paddle.to_tensor( + _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place + ) + percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place) + _TRITON_TABLE_CACHE[logits.device] = ( + normal_cdf_to_sigma_table, + percentile_to_std_table, + ) + else: + normal_cdf_to_sigma_table, percentile_to_std_table = tables + + _topk_topp_kernel[(NUM_PROGRAMS,)]( + logits, + buffer, + mask_out, + percentile_to_std_table, + normal_cdf_to_sigma_table, + k_ptr, + p_ptr, + BATCH_SIZE=batch_size, + MASK_VALUE=mask_value, + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=8192, + BLOCK_SIZE_TRUNC=4096, + TOPK_ENABLED=topk_enabled, + TOPP_ENABLED=topp_enabled, + WRITE_MASK=write_mask, + ) + + if return_mask: + return logits, mask_out.astype(paddle.bool) + return logits + + +@enable_compat_on_triton_kernel +@triton.jit +def _seeded_gumbel_kernel( + OUT_ptr, + SEEDS_ptr, + stride_out_batch, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Generate -log(u) with per-row Philox seeds, fully on GPU.""" + pid = tl.program_id(0) + seed = tl.load(SEEDS_ptr + pid) + for start in tl.range(0, VOCAB_SIZE, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < VOCAB_SIZE + u = tl.rand(seed, offsets) + u = tl.maximum(u, 1e-10) + q = -tl.log(u) + tl.store(OUT_ptr + pid * stride_out_batch + offsets, q, mask=mask) + + +def seeded_gumbel_noise(probs: paddle.Tensor, seeds: paddle.Tensor) -> paddle.Tensor: + """ + Generate Gumbel noise q = -log(u) with per-row Philox seeds on GPU. + + Args: + probs: [batch_size, vocab_size] — used only for shape/dtype. + seeds: [batch_size] int64 per-request seeds (GPU). + + Returns: + q: [batch_size, vocab_size] float tensor of Gumbel noise. + """ + batch_size, vocab_size = probs.shape + q = paddle.empty_like(probs) + BLOCK_SIZE = min(triton.next_power_of_2(vocab_size), 4096) + _seeded_gumbel_kernel[(batch_size,)]( + q, + seeds, + q.strides[0], + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return q + + +def reset_buffer_cache(): + _TRITON_BUFFER_CACHE.clear() + _TRITON_TABLE_CACHE.clear() + paddle.accelerator.empty_cache() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index a611769a9c9..cac8e7249a8 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -27,7 +27,7 @@ from fastdeploy import envs from fastdeploy.config import FDConfig -from fastdeploy.envs import FD_FILL_BITMASK_BATCH +from fastdeploy.envs import FD_FILL_BITMASK_BATCH, FD_SAMPLING_CLASS from fastdeploy.logger.deterministic_logger import _record_logits_diagnostic from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( @@ -41,11 +41,16 @@ from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, + dispatch_top_k_renorm_probs, min_p_sampling, reasoning_phase_token_constraint, speculate_insert_first_token, top_k_top_p_sampling, ) +from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + apply_top_k_top_p_triton, + seeded_gumbel_noise, +) from fastdeploy.platforms import current_platform from fastdeploy.reasoning import ReasoningParser from fastdeploy.spec_decode import SpecMethod, VerifyStrategy @@ -59,6 +64,73 @@ ) +def _apply_triton_top_k_top_p( + logits: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k/top-p masking on logits using the Triton kernel. + Masked positions are set to -inf in-place. Call this BEFORE softmax. + + Args: + return_mask: If True, return (logits, mask) where mask is a bool + tensor [B, V] computed inside the Triton kernel (zero extra cost). + + Returns: + logits if return_mask is False, else (logits, mask). + """ + if top_p is None and top_k is None: + return logits + batch_size = logits.shape[0] + + top_p = top_p[:batch_size].squeeze(axis=-1) + + has_top_k = top_k_list and any(x > 0 for x in top_k_list) + if has_top_k: + top_k = top_k[:batch_size].squeeze(axis=-1) + else: + top_k = None + + return apply_top_k_top_p_triton(logits.astype("float32"), k=top_k, p=top_p, return_mask=return_mask) + + +def _random_sample( + probs: paddle.Tensor, + topp_seed: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + """ + Sample from probabilities using the Gumbel-max trick. + + Equivalent to multinomial sampling but avoids CPU-GPU synchronization. + When ``topp_seed`` is provided and Triton is available, a Triton kernel + generates per-row deterministic Gumbel noise using Philox PRNG entirely + on GPU, eliminating the Python for-loop and CPU-GPU sync overhead. + + Args: + probs: [batch_size, vocab_size] float32 probabilities. + topp_seed: [batch_size, 1] int64 per-request seeds, or None. + + Returns: + Token ids of shape [batch_size, 1]. + + Reference: vllm/v1/sample/ops/topk_topp_sampler.py::random_sample + """ + # Sample from Exp(1): q = -log(u), u ~ Uniform(0, 1) + if topp_seed is not None: + seeds = topp_seed[: probs.shape[0]].reshape([-1]) + if not seeds.place.is_gpu_place(): + seeds = seeds.cuda() + q = seeded_gumbel_noise(probs, seeds) + else: + u = paddle.uniform(probs.shape, dtype=probs.dtype, min=0.0, max=1.0) + q = -paddle.log(u.clip(min=1e-10)) + # Gumbel-max: argmax(probs / q) is equivalent to multinomial(probs) + return (probs / q).argmax(axis=-1).reshape([-1, 1]) + + def top_p_normalize_probs_paddle( probs: paddle.Tensor, top_ps: paddle.Tensor, @@ -252,6 +324,57 @@ def _extract_sparse_indices( return [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] +def _sample_from_probs(probs, sampling_metadata, top_p=None, top_k=None, topp_seed=None): + """Sample next tokens from probability distributions with optional top-k and top-p filtering. + + When ``top_p_list`` is all 1.0 (no top-p filtering needed), uses + :func:`_random_sample` with an optional top-k renormalization pass via + :func:`dispatch_top_k_renorm_probs`. Otherwise dispatches through + :func:`top_k_top_p_sampling` to apply joint top-k/top-p constraints. + + Args: + probs: [token_num, vocab_size] float32 probability tensor (normalized logits). + sampling_metadata: Metadata carrying top_p, top_k, seed, top_k_list, + and top_p_list for the current batch of requests. + top_p: Override for per-row top-p values, shape [token_num, 1] or None. + top_k: Override for per-row top-k values, shape [token_num, 1] or None. + topp_seed: Override for per-row random seeds, shape [token_num, 1] or None. + + Returns: + Sampled token ids of shape [token_num, 1]. + """ + token_num = probs.shape[0] + if top_p is None: + top_p = sampling_metadata.top_p + if top_k is None: + top_k = sampling_metadata.top_k + if topp_seed is None: + topp_seed = sampling_metadata.seed + top_k_list = sampling_metadata.top_k_list + top_p_list = sampling_metadata.top_p_list + need_top_k_sampling = False + need_top_p_sampling = True + if top_k_list is not None: + top_k_list = top_k_list[:token_num] + need_top_k_sampling = any(k > 0 for k in top_k_list) + if top_p_list is not None: + top_p_list = top_p_list[:token_num] + need_top_p_sampling = any(p != 1.0 for p in top_p_list) + if not need_top_p_sampling and current_platform.is_cuda() and envs.FD_ENABLE_TOP_P_ONE_OPT: + if need_top_k_sampling: + probs = dispatch_top_k_renorm_probs(probs, top_k) + next_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + _, next_tokens = top_k_top_p_sampling( + probs, + top_p, + top_k, + top_k_list, + topp_seed=topp_seed, + ) + return next_tokens + + class GuidedDecoding: """ processor for guided decoding. @@ -694,6 +817,16 @@ def forward_cuda( elif self.logprobs_mode == "processed_logits": raw_logprobs = logits.clone() + # Triton path: mask logits in-place BEFORE softmax (no probs→log round-trip). + if FD_SAMPLING_CLASS.lower() == "triton": + logits = _apply_triton_top_k_top_p( + logits, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + return_mask=False, + ) + probs = F.softmax(logits) # Record post-penalty logits and probs MD5 for determinism diagnosis @@ -725,13 +858,10 @@ def forward_cuda( # Store deferred GPU→CPU data; sparse extraction happens in save_output sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) - _, next_tokens = top_k_top_p_sampling( - probs, - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.top_k_list, - topp_seed=sampling_metadata.seed, - ) + if FD_SAMPLING_CLASS.lower() == "triton": + next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed) + else: + next_tokens = _sample_from_probs(probs, sampling_metadata) logprobs_tensors = ( None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) @@ -967,6 +1097,7 @@ def _verify_and_sample( increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, + topp_seed: Optional[paddle.Tensor] = None, ) -> SamplerOutput: """ Verify draft tokens against target model output and produce final samples. @@ -998,19 +1129,22 @@ def _verify_and_sample( target_tokens, candidate_ids, candidate_scores, candidate_lens = None, None, None, None if self.verify_strategy == VerifyStrategy.TARGET_MATCH: - # Only TARGET_MATCH needs stochastic sampling - top_p, top_k, topp_seed = build_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - share_inputs["seq_lens_this_time"], - share_inputs["cu_seqlens_q_output"], - token_num_output_cpu, - increment_value, - ) - _, target_tokens = top_k_top_p_sampling( - probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed - ) + if FD_SAMPLING_CLASS.lower() == "triton": + target_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + # Only TARGET_MATCH needs stochastic sampling + top_p, top_k, topp_seed = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, + ) + target_tokens = _sample_from_probs( + probs, sampling_metadata, top_p=top_p, top_k=top_k, topp_seed=topp_seed + ) elif self.verify_strategy == VerifyStrategy.GREEDY: # GREEDY: deterministic argmax in target_tokens, no candidates needed target_tokens = paddle.argmax(probs, axis=-1) @@ -1075,6 +1209,7 @@ def _normal_sample( probs: paddle.Tensor, sampling_metadata: SamplingMetadata, share_inputs: List[paddle.Tensor], + topp_seed: Optional[paddle.Tensor] = None, ) -> SamplerOutput: """ Normal sampling without draft token verification. @@ -1096,13 +1231,16 @@ def _normal_sample( probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) # Sample tokens - _, next_tokens = top_k_top_p_sampling( - probs, - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.top_k_list, - topp_seed=sampling_metadata.seed, - ) + if FD_SAMPLING_CLASS.lower() == "triton": + next_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + next_tokens = _sample_from_probs( + probs, + sampling_metadata, + top_p=sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + topp_seed=sampling_metadata.seed, + ) # Scatter sampled tokens into accept_tokens using cu_seqlens_q_output to # correctly handle mixed prefill+decode batches where token index != batch index. @@ -1196,12 +1334,32 @@ def forward_cuda( self.line_break_id, ) + logits_ori = None + topp_seed = None + if FD_SAMPLING_CLASS.lower() == "triton": + logits_ori = logits.clone() + top_p, top_k, topp_seed = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, + ) + logits = _apply_triton_top_k_top_p( + logits, + top_p, + top_k=top_k, + top_k_list=sampling_metadata.top_k_list, + ) + probs = F.softmax(logits) # Route based on spec_method is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE if is_naive: - sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs) + sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs, topp_seed=topp_seed) else: sampler_output = self._verify_and_sample( logits, @@ -1213,13 +1371,14 @@ def forward_cuda( increment_value, accept_all_drafts, reject_all_drafts, + topp_seed=topp_seed, ) keep_sampling_mask = sampling_metadata.keep_sampling_mask # Build logprobs via unified path (outside of sampling logic) if sampling_metadata.max_num_logprobs is not None or keep_sampling_mask: logprobs_tensors, cu_batch_token_offset, target_logits = build_output_logprobs( - logits, + logits if logits_ori is None else logits_ori, sampling_metadata, share_inputs, is_naive=is_naive, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 69ef6a6bbe3..0ceec13bd92 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -992,6 +992,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens + self.share_inputs["top_p_list"][idx] = request.get("top_p", 0.7) self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0) self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) async_set_value(self.share_inputs["eos_token_id"][:], request.eos_token_ids) @@ -1283,6 +1284,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_p_list=self.share_inputs["top_p_list"], top_k=self.share_inputs["top_k"], top_k_list=self.share_inputs["top_k_list"], min_p=self.share_inputs["min_p"], diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 687ae24088f..546c3ee36ff 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -121,6 +121,7 @@ def init_share_inputs(self): ) self.eos_token_id = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64") self.top_p = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.top_p_list = [self.model_config.top_p] * max_num_seqs self.top_k = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.top_k_list = [0] * max_num_seqs self.min_p = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") @@ -403,6 +404,7 @@ def swap_data(tensor, idx1, idx2): # swap_data(self.recompute_token_num, i1, i2) # # Swap list-based arrays (lists don't need clone) + self.top_p_list[i1], self.top_p_list[i2] = self.top_p_list[i2], self.top_p_list[i1] self.top_k_list[i1], self.top_k_list[i2] = self.top_k_list[i2], self.top_k_list[i1] self.min_p_list[i1], self.min_p_list[i2] = self.min_p_list[i2], self.min_p_list[i1] @@ -554,6 +556,7 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "top_p_normalized_logprobs", False) # Reset list variables (not paddle tensors) + self.top_p_list = [self.model_config.top_p] * max_num_seqs self.top_k_list = [0] * max_num_seqs self.min_p_list = [0.0] * max_num_seqs diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 069a2e938a0..f16eea8dd82 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -7,11 +7,15 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p python -m pip install -r requirements.txt python -m pip install jsonschema aistudio_sdk==0.3.5 -# Use prebuilt wheel files to install xgrammar==0.1.19 and torch==2.6.0 specifically for the CI environment -python -m pip install \ - https://paddle-qa.bj.bcebos.com/FastDeploy/torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl \ - https://paddle-qa.bj.bcebos.com/FastDeploy/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \ - https://paddle-qa.bj.bcebos.com/FastDeploy/xgrammar-0.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +# Use prebuilt wheel files to install xgrammar==0.1.19 triton==3.4.0 nvidia_nccl_cu12==2.27.3 and torch==2.8.0 specifically for the CI environment +python -m pip install --no-deps \ + https://paddle-qa.bj.bcebos.com/FastDeploy/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/xgrammar-0.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + +# install runtime dependencies for torch and xgrammar +python -m pip install pydantic sentencepiece tiktoken ninja filelock sympy jinja2 fsspec # fix tests/ci_use/Prefix_Caching_Swap/test_vl_prefix_caching_swap.py (requires new pyarrow memory behavior) python -m pip install pyarrow==24.0.0 diff --git a/tests/layers/test_triton_sampler.py b/tests/layers/test_triton_sampler.py new file mode 100644 index 00000000000..f4c12e10082 --- /dev/null +++ b/tests/layers/test_triton_sampler.py @@ -0,0 +1,429 @@ +""" +Unit tests for the triton sampling path introduced in commit 16e692f. + +Covers: + - _apply_triton_top_k_top_p / apply_top_k_top_p_triton Python wrapper + - _random_sample / seeded_gumbel_noise Python wrapper + - Sampler.forward_cuda triton branch (FD_SAMPLING_CLASS="triton") + - SpeculativeSampler triton branches +""" + +import sys +import types + +import paddle +import pytest + +import fastdeploy # noqa: F401 + +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None + +# Stub triton for unit isolation (same pattern as test_sampler.py). +if "triton" not in sys.modules: + triton_stub = types.ModuleType("triton") + triton_stub.jit = lambda fn: fn + triton_stub.next_power_of_2 = lambda n: 1 << (n - 1).bit_length() + triton_lang_stub = types.ModuleType("triton.language") + triton_lang_stub.constexpr = int + sys.modules["triton"] = triton_stub + sys.modules["triton.language"] = triton_lang_stub + +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata + +# Must import after stubs are in place. +from fastdeploy.model_executor.layers.sample.sampler import ( + Sampler, + SpeculativeSampler, + _apply_triton_top_k_top_p, + _random_sample, +) +from fastdeploy.spec_decode import VerifyStrategy + +# --------------------------------------------------------------------------- +# Fixtures & helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _patch_gpu_deps(monkeypatch): + """Patch only GPU-specific calls so Python wrapper code can execute on CPU.""" + import fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton as triton_mod + + # Patch the kernel launch inside apply_top_k_top_p_triton: replace + # _topk_topp_kernel so it becomes a no-op (logits left unchanged → + # equivalent to "keep all" when no real GPU masking happens). + # This lets the Python wrapper (lines 830-936) run for coverage. + def _fake_kernel_call(grid, kwargs): + pass + + monkeypatch.setattr(triton_mod._topk_topp_kernel, "__call__", _fake_kernel_call) + + # Patch paddle.device.cuda.get_device_properties used inside + # apply_top_k_top_p_triton to avoid "no CUDA device" error. + fake_props = types.SimpleNamespace(multi_processor_count=1) + monkeypatch.setattr( + paddle.device.cuda, + "get_device_properties", + lambda idx: fake_props, + ) + + # Patch _seeded_gumbel_kernel similarly so seeded_gumbel_noise (lines + # 960-981) runs its Python logic without real GPU. + def _fake_gumbel_kernel_call(grid, kwargs): + pass + + monkeypatch.setattr(triton_mod._seeded_gumbel_kernel, "__call__", _fake_gumbel_kernel_call) + + # Patch batched_count_greater_than (used in gather_logprobs). + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.batched_count_greater_than", + lambda x, y: (x >= y).sum(-1), + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.logprobs.batched_count_greater_than", + lambda x, y: (x >= y).sum(-1), + ) + + # Patch current_platform so is_cuda() returns True (needed for + # build_sampling_params import). + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.current_platform.is_cuda", + lambda: True, + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.current_platform.is_xpu", + lambda: False, + ) + + +@pytest.fixture +def mock_ops(monkeypatch): + """Patch heavy GPU ops that are not the focus of triton tests.""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.apply_penalty_multi_scores", + lambda *a, **k: a[1], + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.apply_speculative_penalty_multi_scores", + lambda *a, **k: a[2], + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.min_p_sampling", + lambda probs, *a, **k: probs, + ) + return monkeypatch + + +@pytest.fixture +def triton_mode(monkeypatch): + """Set FD_SAMPLING_CLASS to triton for the duration of the test.""" + import fastdeploy.envs as envs + + monkeypatch.setattr(envs, "FD_SAMPLING_CLASS", "triton") + + +def _create_metadata(batch_size=1, min_seq_len=1, max_seq_len=3, max_num_logprobs=None, **overrides): + m = SamplingMetadata( + temperature=paddle.full(shape=[batch_size, 1], fill_value=0.9, dtype="float32"), + top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"), + prompt_lens=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + token_ids_all=paddle.full(shape=[batch_size, max_seq_len], fill_value=-1, dtype="int64"), + frequency_penalties=paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32"), + presence_penalties=paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32"), + repetition_penalties=paddle.full(shape=[batch_size, 1], fill_value=1.0, dtype="float32"), + min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"), + bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"), + bad_words_token_len=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"), + min_p=paddle.zeros([batch_size], dtype="float32"), + seed=paddle.full([batch_size, 1], 7, dtype="int64"), + logits_processors=None, + ) + m.max_num_logprobs = max_num_logprobs + m.top_k = paddle.full([batch_size, 1], 5, dtype="int64") + m.top_k_list = [5 for _ in range(batch_size)] + m.min_p_list = [0.0 for _ in range(batch_size)] + m.enable_early_stop = True + m.stop_flags = paddle.zeros([batch_size, 1], dtype="int32") + m.share_inputs = { + "seq_lens_this_time": paddle.ones([batch_size, 1], dtype="int64"), + "seq_lens_encoder": paddle.zeros([batch_size, 1], dtype="int64"), + "seq_lens_decoder": paddle.zeros([batch_size, 1], dtype="int64"), + } + for k, v in overrides.items(): + setattr(m, k, v) + return m + + +def _make_stubbed_sampler(mode="processed_logprobs"): + s = Sampler.__new__(Sampler) + s.guided_decoding = types.SimpleNamespace(apply_token_mask=lambda logits, p_done_idxs: logits) + s.logprobs_mode = mode + s.early_stopper = types.SimpleNamespace(process=lambda probs, next_tokens, stop_flags: None) + return s + + +# --------------------------------------------------------------------------- +# Tests for _apply_triton_top_k_top_p (direct call) +# --------------------------------------------------------------------------- + + +class TestApplyTritonTopKTopP: + """Tests for _apply_triton_top_k_top_p.""" + + def test_returns_logits_unchanged_when_both_none(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=None, top_k=None) + assert paddle.equal_all(result, logits) + + def test_top_p_only_no_error(self): + """top_p filtering runs through apply_top_k_top_p_triton wrapper.""" + logits = paddle.to_tensor([[1.0, 2.0, 5.0]], dtype="float32") + top_p = paddle.to_tensor([[0.7]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p) + assert result.shape == [1, 3] + + def test_top_k_disabled_when_list_none(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[1.0]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, top_k=None, top_k_list=None) + assert result.shape == [1, 3] + + def test_return_mask_false(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[0.9]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, return_mask=False) + assert isinstance(result, paddle.Tensor) + + def test_return_mask_true(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[0.5]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, return_mask=True) + assert isinstance(result, tuple) + assert len(result) == 2 + logits_out, mask = result + assert logits_out.shape == [1, 3] + assert mask.shape == [1, 3] + assert mask.dtype == paddle.bool + + def test_output_dtype_is_float32(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float16") + top_p = paddle.to_tensor([[0.9]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p) + assert result.dtype == paddle.float32 + + def test_combined_top_k_top_p(self): + logits = paddle.to_tensor([[1.0, 5.0, 3.0, 2.0, 4.0]], dtype="float32") + top_p = paddle.to_tensor([[0.5]], dtype="float32") + top_k = paddle.to_tensor([[3]], dtype="int64") + top_k_list = [3] + result = _apply_triton_top_k_top_p(logits, top_p=top_p, top_k=top_k, top_k_list=top_k_list) + assert result.shape == [1, 5] + + +# --------------------------------------------------------------------------- +# Tests for _random_sample (direct call) +# --------------------------------------------------------------------------- + + +class TestRandomSample: + """Tests for _random_sample.""" + + def test_output_shape_and_dtype(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], dtype="float32") + result = _random_sample(probs) + assert result.shape == [2, 1] + assert result.dtype == paddle.int64 + + def test_without_seed(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7]], dtype="float32") + result = _random_sample(probs, topp_seed=None) + assert 0 <= result[0, 0].item() < 3 + + def test_with_seed(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7]], dtype="float32") + seed = paddle.to_tensor([[42]], dtype="int64") + result = _random_sample(probs, topp_seed=seed) + assert result.shape == [1, 1] + + def test_greedy_with_peak_distribution(self): + probs = paddle.zeros([1, 10], dtype="float32") + probs[0, 5] = 1.0 + result = _random_sample(probs) + assert result[0, 0].item() == 5 + + def test_batch_multiple_requests(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7], [0.0, 0.0, 1.0]], dtype="float32") + result = _random_sample(probs) + assert result.shape == [2, 1] + assert 0 <= result[0, 0].item() < 3 + assert result[1, 0].item() == 2 + + +# --------------------------------------------------------------------------- +# Tests for Sampler.forward_cuda with triton path +# --------------------------------------------------------------------------- + + +class TestSamplerTritonPath: + """Test Sampler.forward_cuda with FD_SAMPLING_CLASS=triton.""" + + def test_forward_cuda_triton_path(self, mock_ops, triton_mode): + """Sampler.forward_cuda should call _apply_triton_top_k_top_p and _random_sample.""" + sampler = _make_stubbed_sampler("processed_logprobs") + m = _create_metadata(batch_size=1, max_num_logprobs=2) + + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + output = sampler.forward_cuda(logits, m) + assert output.sampled_token_ids.shape == [1, 1] + assert output.logprobs_tensors is not None + + +# --------------------------------------------------------------------------- +# Tests for SpeculativeSampler triton branches +# --------------------------------------------------------------------------- + + +def _make_spec_sampler(verify_strategy=VerifyStrategy.TARGET_MATCH, spec_method=None): + """Create a SpeculativeSampler with stubbed internals.""" + s = SpeculativeSampler.__new__(SpeculativeSampler) + s.verify_strategy = verify_strategy + s.spec_method = spec_method # None → NAIVE path + s.enf_gen_phase_tag = False + s.config_accept_all = False + s.config_reject_all = False + s.speculative_benchmark_mode = False + s.speculative_max_candidate_len = 1 + s.speculative_verify_window = 2 + s.think_end_id = 1 + s.line_break_id = 2 + s.logprobs_mode = "processed_logprobs" + return s + + +def _spec_share_inputs(batch_size=1): + return { + "seq_lens_this_time": paddle.ones([batch_size, 1], dtype="int64"), + "seq_lens_encoder": paddle.zeros([batch_size, 1], dtype="int64"), + "cu_seqlens_q_output": paddle.to_tensor([0] + [1] * batch_size, dtype="int32"), + "batch_id_per_token_output": paddle.zeros([batch_size], dtype="int32"), + "accept_tokens": paddle.zeros([batch_size, 1], dtype="int64"), + "accept_num": paddle.zeros([batch_size], dtype="int32"), + "draft_tokens": paddle.zeros([batch_size, 1], dtype="int64"), + "stop_flags": paddle.zeros([batch_size, 1], dtype="int32"), + "is_block_step": paddle.zeros([batch_size], dtype="int32"), + "reasoning_status": paddle.zeros([batch_size, 1], dtype="int32"), + "max_dec_len": paddle.full([batch_size, 1], 1024, dtype="int64"), + "step_idx": paddle.zeros([batch_size, 1], dtype="int64"), + } + + +class TestSpeculativeSamplerTritonPath: + """Test SpeculativeSampler triton branches (lines 916, 1016-1017, 1120-1132).""" + + def test_verify_and_sample_target_match_triton(self, mock_ops, triton_mode, monkeypatch): + """_verify_and_sample with TARGET_MATCH + triton → calls _random_sample (line 916).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.build_sampling_params", + lambda *a, **k: ( + paddle.to_tensor([[0.9]], dtype="float32"), + paddle.to_tensor([[5]], dtype="int64"), + paddle.to_tensor([[7]], dtype="int64"), + ), + ) + # verify_draft_tokens is lazily imported inside _verify_and_sample + import fastdeploy.model_executor.ops.gpu as gpu_ops + + monkeypatch.setattr(gpu_ops, "verify_draft_tokens", lambda *a, **k: None) + monkeypatch.setattr(gpu_ops, "top_p_candidates", lambda *a, **k: (None, None, None)) + + sampler = _make_spec_sampler(verify_strategy=VerifyStrategy.TARGET_MATCH, spec_method="ngram") + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + probs = paddle.nn.functional.softmax(logits, axis=-1) + seeds = paddle.ones([probs.shape[0], 1], dtype="int64") + + out = sampler._verify_and_sample( + logits, + probs, + m, + max_model_len=8, + share_inputs=_spec_share_inputs(), + token_num_output_cpu=1, + increment_value=1, + topp_seed=seeds, + ) + assert out.sampled_token_ids is not None + + def test_normal_sample_triton(self, mock_ops, triton_mode, monkeypatch): + """_normal_sample with triton → calls _random_sample (line 1016-1017).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.naive_update_model_status", + lambda *a, **k: None, + ) + + sampler = _make_spec_sampler(spec_method=None) # None → NAIVE + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + probs = paddle.nn.functional.softmax(logits, axis=-1) + seeds = paddle.ones([probs.shape[0], 1], dtype="int64") + + out = sampler._normal_sample(logits, probs, m, share_inputs=_spec_share_inputs(), topp_seed=seeds) + assert out.sampled_token_ids is not None + + def test_forward_cuda_triton_logit_mask(self, mock_ops, triton_mode, monkeypatch): + """SpeculativeSampler.forward_cuda with triton → masks logits (lines 1120-1132).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.build_sampling_params", + lambda *a, **k: ( + paddle.to_tensor([[0.9]], dtype="float32"), + paddle.to_tensor([[5]], dtype="int64"), + paddle.to_tensor([[7]], dtype="int64"), + ), + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.naive_update_model_status", + lambda *a, **k: None, + ) + + sampler = _make_spec_sampler(spec_method=None) # NAIVE → _normal_sample + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + + out = sampler.forward_cuda( + logits, + m, + max_model_len=8, + share_inputs=_spec_share_inputs(), + token_num_output_cpu=1, + increment_value=1, + ) + assert out.sampled_token_ids is not None + + +# --------------------------------------------------------------------------- +# Tests for triton Python wrapper functions (top_k_top_p_triton.py coverage) +# --------------------------------------------------------------------------- + + +class TestTritonWrapperFunctions: + """Cover the Python wrapper functions in top_k_top_p_triton.py.""" + + def test_reset_buffer_cache(self, monkeypatch): + """reset_buffer_cache should run without error.""" + from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + reset_buffer_cache, + ) + + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton.paddle.accelerator", + types.SimpleNamespace(empty_cache=lambda: None), + raising=False, + ) + reset_buffer_cache() + + +if __name__ == "__main__": + pytest.main([__file__]) From fefbcff4cf11f97a84ad01c2150776491d86ddfa Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Thu, 28 May 2026 20:29:28 +0800 Subject: [PATCH 136/143] [Cherry-Pick] [BugFix] fix all reduce fusion accurate issue (#7923) (#7922) * fix accurate issue * fix acc issue in ep + tp mode --------- Co-authored-by: root --- .../model_executor/layers/normalization.py | 28 ++- fastdeploy/model_executor/models/glm4_moe.py | 6 +- tests/layers/trtllm_allreduce_rms_fusion.py | 217 +++++++++++++++++- 3 files changed, 238 insertions(+), 13 deletions(-) diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 6f2e64ed6b2..8efe0056eb5 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -124,7 +124,8 @@ def __init__( self.tp_group = self.fd_config.parallel_config.tp_group is_input_norm = prefix.endswith(".input_layernorm") self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and ( - ("post_attention_layernorm" in prefix) or (("input_layernorm" in prefix and layer_id != 0)) + ("post_attention_layernorm" in prefix) + or (("input_layernorm" in prefix and layer_id != 0) and not fd_config.parallel_config.use_ep) ) self.is_last_norm = prefix.endswith(".norm") @@ -239,6 +240,13 @@ def forward( if residual_input is None: residual_out = x + use_allreduce_fused = ( + self.enable_all_reduce_fusion + and self.tp_size > 1 + and x.shape[0] <= 2048 + and residual_input is not None + and current_platform.is_cuda() + ) if proxy_rmsnorm is None: if current_platform.is_gcu(): if residual_input is None: @@ -246,7 +254,7 @@ def forward( return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) # enable trtllm all reduce fusion - elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: + elif use_allreduce_fused: norm_out = flashinfer_allreduce_residual_rmsnorm( fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps ) @@ -272,9 +280,19 @@ def forward( quant_min_bound=self.quant_min_bound, ) else: - if residual_input is not None: - x = x + residual_input - norm_out = proxy_rmsnorm(x, self.weight, self.eps), x + if use_allreduce_fused: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=x, + residual=residual_input, + weight=self.weight, + eps=self.eps, + ) + assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" + else: + if residual_input is not None: + x = x + residual_input + norm_out = proxy_rmsnorm(x, self.weight, self.eps), x out = norm_out[0].astype(x_dtype) if residual_input is not None: diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index befbe64dd3f..0e5974c0410 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -64,6 +64,10 @@ def __init__( reduce_results: bool = True, ) -> None: super().__init__() + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results + ) + # shared experts not split when use_sequence_parallel_moe in ep + tp if ( fd_config.parallel_config.use_sequence_parallel_moe @@ -101,7 +105,7 @@ def __init__( output_size=fd_config.model_config.hidden_size, with_bias=False, reduce_results=reduce_results, - enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion, + enable_all_reduce_fusion=self.enable_all_reduce_fusion, ) self.act_fn = SiluAndMul( diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py index 1417d2a6463..eea59d26743 100644 --- a/tests/layers/trtllm_allreduce_rms_fusion.py +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -67,9 +67,13 @@ def setUp(self): paddle.seed(42) np.random.seed(42) - self.dtype = paddle.float32 + # NOTE: switched fp32 -> bf16 to mirror real model dtype on B GPUs. + # Combined with use_oneshot=None below, this exercises the bf16 + + # oneshot Lamport path, which is the suspected garbled-output path + # on Blackwell (sm100). + self.dtype = paddle.bfloat16 self.token_num = 128 - self.hidden_dim = 768 + self.hidden_dim = 4096 self.eps = 1e-6 self.epsilon = 1e-6 self.max_token_num = 2048 @@ -144,7 +148,9 @@ def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): weight=weight, eps=eps, max_token_num=self.max_token_num, - use_oneshot=False, + # NOTE: do NOT pass use_oneshot=False here. We want the auto path + # (use_oneshot=None) so the oneshot Lamport kernel is exercised, + # matching how normalization.py calls it in the real model. ) return norm_out, residual_out @@ -235,11 +241,21 @@ def test_accuracy_fused_vs_reference(self): flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse( input_tensor.clone(), residual.clone(), weight.clone(), self.eps ) + + # bf16 needs much looser tolerance than fp32. Cast to fp32 for + # comparison to avoid numpy bf16 issues. + if self.dtype == paddle.bfloat16: + rtol, atol = 5e-2, 5e-2 + to_np = lambda t: t.astype("float32").numpy() # noqa: E731 + else: + rtol, atol = 1e-5, 1e-5 + to_np = lambda t: t.numpy() # noqa: E731 + # Verify results - np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(to_np(fused_output), to_np(reference_output), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(ref_res), to_np(paddle_res), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(flashinfer_output), to_np(reference_output), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(ref_res), to_np(flashinfer_res), rtol=rtol, atol=atol) class TestFlashInferWorkspaceManager(unittest.TestCase): @@ -569,6 +585,193 @@ def test_cleanup_workspace_function(self): mock_manager.cleanup.assert_called_once() +class TestRMSNormProxyAllreduceFused(unittest.TestCase): + @classmethod + def setUpClass(cls): + # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py + # has already done paddle.set_device + init_parallel_env, so we don't + # repeat that here. (unittest.main runs in the same process.) + cls.tp_size = dist.get_world_size() + cls.tp_rank = dist.get_rank() + + def _make_fd_config(self, enable_fusion: bool): + """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.""" + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = self.tp_size + fd_config.parallel_config.tensor_parallel_rank = self.tp_rank + fd_config.parallel_config.tp_group = dist.get_group() + fd_config.parallel_config.expert_parallel_size = 1 + fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion + fd_config.parallel_config.use_sequence_parallel_moe = False + fd_config.model_config = Mock() + fd_config.model_config.moe_layer_start_index = -1 + fd_config.quant_config = None + return fd_config + + def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1): + """Build a real RMSNorm whose enable_all_reduce_fusion resolves to + `enable_fusion` (use post_attention_layernorm prefix to ensure the + prefix-match in __init__ passes).""" + from fastdeploy.model_executor.layers.normalization import RMSNorm + + fd_config = self._make_fd_config(enable_fusion=enable_fusion) + norm = RMSNorm( + fd_config=fd_config, + hidden_size=hidden_size, + eps=1e-6, + prefix=f"model.layers.{layer_id}.post_attention_layernorm", + layer_id=layer_id, + dtype="bfloat16", + ) + # Initialize weight to a known reproducible value (constant=1.0 by default). + with paddle.no_grad(): + paddle.seed(2024) + new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16) + dist.broadcast(new_w, src=0) + norm.weight.set_value(new_w) + return norm + + @staticmethod + def _proxy_rmsnorm_fn(x, weight, eps): + """Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula + in fp32 to keep reference numerics clean.""" + x_fp32 = x.astype("float32") + var = x_fp32.pow(2).mean(axis=-1, keepdim=True) + out = x_fp32 * paddle.rsqrt(var + eps) + out = out * weight.astype("float32") + return out.astype(x.dtype) + + def _reference(self, x_partial, residual, weight, eps): + """Manual: all_reduce(x_partial) + residual, then standard RMSNorm. + Mirrors what proxy path WOULD produce after explicit allreduce+add.""" + x = x_partial.clone() + dist.all_reduce(x, op=dist.ReduceOp.SUM) + residual_out = x + residual + norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps) + return norm_out, residual_out + + def _make_inputs(self, token_num, hidden_size, seed=123): + """Each rank gets a different x_partial (simulates RowParallelLinear's + un-reduced output); residual is identical across ranks.""" + paddle.seed(seed + self.tp_rank * 7919) + x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1 + paddle.seed(seed + 99) + residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + return x_partial, residual + + def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""): + a32 = a.astype("float32").numpy() + b32 = b.astype("float32").numpy() + np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg) + + # ---------- Tests ---------- + + def test_proxy_path_takes_fused_branch(self): + """fusion=on, tp>1, shape<=2048, residual!=None + -> proxy branch picks flashinfer_allreduce_residual_rmsnorm. + Verify by patching the symbol and asserting it was called. + """ + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + self.assertTrue(norm.enable_all_reduce_fusion) + x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden) + + # Patch within the normalization module's namespace. + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm", + wraps=__import__( + "fastdeploy.model_executor.layers.normalization", fromlist=["flashinfer_allreduce_residual_rmsnorm"] + ).flashinfer_allreduce_residual_rmsnorm, + ) as spy: + out, res = norm.forward( + x_partial.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + spy.assert_called_once() + + # Numerics: must match reference (allreduce + add + std rmsnorm). + ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch") + self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch") + + def test_proxy_path_falls_back_when_fusion_disabled(self): + """fusion=off -> proxy branch must call proxy_rmsnorm directly, + no fused allreduce path used. Input is treated as already-reduced.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden) + self.assertFalse(norm.enable_all_reduce_fusion) + + # Each rank uses the SAME x (already-reduced) — that's the contract + # when fusion is off (RowParallelLinear has done its own allreduce). + paddle.seed(777) + x = paddle.randn([64, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([64, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + proxy_called = {"n": 0} + + def proxy_spy(_x, _w, _eps): + proxy_called["n"] += 1 + return self._proxy_rmsnorm_fn(_x, _w, _eps) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=proxy_spy, + ) + fused_spy.assert_not_called() + + self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once") + + # Reference: x is already full -> just add + rmsnorm, no allreduce. + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch") + self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch") + + def test_proxy_path_falls_back_when_token_too_large(self): + """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused; + in this regime upstream RowParallelLinear didn't skip its own + all-reduce, so x is already full and proxy_rmsnorm is invoked directly.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 256 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + # shape[0] > 2048 forces use_allreduce_fused=False + token_num = 2049 + paddle.seed(555) + x = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + fused_spy.assert_not_called() + + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch") + self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch") + + if __name__ == "__main__": """Run tests directly (called by subprocess after distributed launch)""" unittest.main(verbosity=2) From ac24fcc360f77c5b803273d7941afef851dc52ba Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Fri, 29 May 2026 15:57:22 +0800 Subject: [PATCH 137/143] [Cherry-Pick][BugFix] fix mtp reset bugs in rl (#7957) (#7958) --- fastdeploy/worker/input_batch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 546c3ee36ff..af3032949b5 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -989,6 +989,10 @@ def reset_model_inputs(self) -> None: # NOTE(fix): These tensors are dynamically resized during runtime inference. # Must recreate at full initial size to avoid CUDAGraph replay OOB access. max_num_seqs = self.scheduler_config.max_num_seqs + if self.enable_mm and self.model_config.mm_max_tokens_per_item is None: + self.max_chunk_tokens = self.model_config.max_model_len + else: + self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item) self.ids_remove_padding = paddle.full([max_num_seqs * self.max_chunk_tokens], 0, dtype="int64") self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32") self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") From 7198b5883a0ebdc85b91a455aa07d0856911c9a7 Mon Sep 17 00:00:00 2001 From: RAM Date: Mon, 1 Jun 2026 21:21:02 +0800 Subject: [PATCH 138/143] [RL] Fix the incorrect routing of EOS tokens, which leads to changes in accuracy (#7960) * Reset buffer size of R3 * refine code * R3 fix Eos bug * pre-commit * fix r3 ci and support dsa * refine code * refine code * reset ci dir * refine code * fix dsv3 --- custom_ops/gpu_ops/cpp_extensions.cc | 11 + .../get_position_ids_and_slot_mapping.cu | 108 ++++++ custom_ops/setup_ops.py | 1 + .../cache_manager/routing_cache_manager.py | 37 +- fastdeploy/config.py | 5 + fastdeploy/model_executor/layers/moe/moe.py | 2 + .../layers/moe/routing_indices_cache.py | 106 +++++- .../model_executor/models/deepseek_v3.py | 2 +- .../model_executor/pre_and_post_process.py | 6 +- fastdeploy/output/token_processor.py | 2 + fastdeploy/worker/gpu_model_runner.py | 55 +-- fastdeploy/worker/input_batch.py | 2 +- .../rollout_routing_replay_test_utils.py | 8 +- .../test_get_position_ids_and_slot_mapping.py | 345 ++++++++++++++++++ 14 files changed, 658 insertions(+), 32 deletions(-) create mode 100644 custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu create mode 100644 tests/operators/test_get_position_ids_and_slot_mapping.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index d74f4240260..0a65b6de6d3 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -622,6 +622,14 @@ void GetPositionIds(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& position_ids); +void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& position_ids, + const paddle::Tensor& slot_mapping, + const int block_size); std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_nope, @@ -1731,6 +1739,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { #endif m.def("get_position_ids", &GetPositionIds, "get_position_ids function"); + m.def("get_position_ids_and_slot_mapping", + &GetPositionIdsAndSlotMapping, + "get_position_ids_and_slot_mapping function"); /** * cutlass_scaled_mm.cu diff --git a/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu b/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu new file mode 100644 index 00000000000..5c57a071461 --- /dev/null +++ b/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu @@ -0,0 +1,108 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void GetPositionIdsAndSlotMappingKernel( + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_this_time, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ block_tables, + const int bsz, + const int max_num_blocks, + const int block_size, + int64_t* __restrict__ position_ids, + int64_t* __restrict__ slot_mapping) { + int current_bid = threadIdx.x; + if (current_bid >= bsz) return; + + // Calculate the offset of current batch in the position_ids buffer + int buffer_offset = 0; + for (int i = 0; i < current_bid; i++) { + buffer_offset += seq_lens_this_time[i]; + } + + // Calculate the token offset in the current batch + int token_offset = seq_lens_decoder[current_bid]; + int token_num_this_batch = seq_lens_this_time[current_bid]; + if (token_num_this_batch == 0) return; + + // Write position ids and slot mapping for current batch +#pragma unroll + for (int i = 0; i < token_num_this_batch; i++) { + int pos_id = token_offset + i; + int idx = buffer_offset + i; + + // Write position_id + position_ids[idx] = pos_id; + + // Calculate slot mapping directly + int block_idx = pos_id / block_size; + int block_offset = pos_id % block_size; + int batch_id = batch_id_per_token[idx]; + + // Get block_id from block_tables + int block_id = block_tables[batch_id * max_num_blocks + block_idx]; + + // Calculate slot mapping + slot_mapping[idx] = static_cast( + static_cast(block_id) * block_size + block_offset); + } +} + +void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& position_ids, + const paddle::Tensor& slot_mapping, + const int block_size) { + const int bsz = seq_lens_this_time.shape()[0]; + const int total_token_num = position_ids.shape()[0]; + const int max_num_blocks = block_tables.shape()[1]; + + GetPositionIdsAndSlotMappingKernel<<<1, + bsz, + 0, + seq_lens_this_time.stream()>>>( + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + batch_id_per_token.data(), + block_tables.data(), + bsz, + max_num_blocks, + block_size, + const_cast(position_ids.data()), + const_cast(slot_mapping.data())); +} + +PD_BUILD_STATIC_OP(get_position_ids_and_slot_mapping) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "block_tables", + "position_ids", + "slot_mapping", + }) + .Attrs({"block_size: int"}) + .Outputs({"position_ids_out", "slot_mapping_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}, + {"slot_mapping", "slot_mapping_out"}}) + .SetKernelFn(PD_KERNEL(GetPositionIdsAndSlotMapping)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 1dbd443e7dd..1ba47905f1f 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -334,6 +334,7 @@ def find_end_files(directory, end_str): "gpu_ops/sample_kernels/top_k_renorm_probs.cu", "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu", "gpu_ops/get_position_ids.cu", + "gpu_ops/get_position_ids_and_slot_mapping.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", diff --git a/fastdeploy/cache_manager/routing_cache_manager.py b/fastdeploy/cache_manager/routing_cache_manager.py index 68dff10b37d..f086b6e5821 100644 --- a/fastdeploy/cache_manager/routing_cache_manager.py +++ b/fastdeploy/cache_manager/routing_cache_manager.py @@ -181,6 +181,7 @@ def __init__(self, fd_config, num_gpu_blocks: int): self.routing_dtype = routing_replay_config.routing_dtype self.only_last_turn = routing_replay_config.only_last_turn self.use_fused_put = routing_replay_config.use_fused_put + self.debug_mode = routing_replay_config.debug_mode self.block_size = fd_config.cache_config.block_size self.return_mode = ( routing_replay_config.routing_store_type @@ -235,7 +236,41 @@ def gather_routing_for_request(self, block_table, seq_len: int) -> np.ndarray: block_indices = positions // self.block_size offsets = positions % self.block_size slot_mapping = np.array(block_ids)[block_indices] * self.block_size + offsets - return self.host_view.gather(slot_mapping) + routing_data = self.host_view.gather(slot_mapping) + + if self.debug_mode: + expected_routing = np.arange(seq_len, dtype=routing_data.dtype)[:, None, None] + expected_routing = np.broadcast_to(expected_routing, (seq_len, self.num_moe_layers, self.moe_top_k)) + if not np.array_equal(routing_data, expected_routing): + # Find all mismatched tokens + mismatch_mask = (routing_data != expected_routing).any(axis=(1, 2)) + mismatched_token_indices = np.where(mismatch_mask)[0] + # Check for duplicate slots in gather + unique_slots, counts = np.unique(slot_mapping, return_counts=True) + num_duplicates = np.sum(counts > 1) + dup_info = "" + if num_duplicates > 0: + dup_indices = np.where(counts > 1)[0] + dup_slots = unique_slots[dup_indices] + dup_info = f", duplicate_slots={list(dup_slots)}" + logger.error( + f"[R3 Debug] Gather mismatch! seq_len={seq_len}, mismatched_tokens={len(mismatched_token_indices)}, " + f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]{dup_info}" + ) + logger.error(f"Mismatched token indices: {mismatched_token_indices}") + for idx in mismatched_token_indices: # Print all mismatches tokens + logger.error( + f" position={idx}, slot={slot_mapping[idx]}, " + f"expected={expected_routing[idx, 0, 0]}, actual={routing_data[idx, 0, 0]}" + ) + raise ValueError("[R3 Debug]Routing gather validation failed.") + else: + logger.debug( + f"[R3 Debug] Gather validation passed: seq_len={seq_len}, " + f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]" + ) + + return routing_data def on_request_finished(self, request_id: str, block_table, seq_len: int) -> Optional[np.ndarray]: """ diff --git a/fastdeploy/config.py b/fastdeploy/config.py index f5d37cbc7ff..e0298a015ba 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1855,6 +1855,9 @@ def __init__(self, args) -> None: # Fused routing of all layers self.use_fused_put: bool = False + # Debug mode: hack topk_ids to use position_ids for validation + self.debug_mode: bool = False + # Auto-filled by FDConfig from ModelConfig (do not set manually) self.routing_dtype: str = "" # "uint8" / "uint16" / "uint32" self.num_moe_layers: int = 0 @@ -1885,6 +1888,8 @@ def postprocess(self, model_config: "ModelConfig") -> None: self.routing_dtype = "uint32" else: raise ValueError(f"num_experts {num_experts} exceeds uint32 range") + if self.debug_mode: + self.routing_dtype = "int64" def to_json_string(self): """ diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 0e4fa6ee9dd..b2354327631 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -754,6 +754,8 @@ def forward( ep_size=self.fd_config.parallel_config.expert_parallel_size, tp_group=self.fd_config.parallel_config.tp_group, total_token_num=forward_meta.batch_id_per_token.shape[0], + position_ids=forward_meta.position_ids, + debug_mode=self.fd_config.routing_replay_config.debug_mode, ) if current_platform.is_intel_hpu(): out = self.forward_normal( diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 57765f65255..392eb3a238d 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -71,6 +71,8 @@ def save_routing_to_buffer_v2( ep_size: int, tp_group: dist.communication.group.Group, total_token_num: int = -1, + position_ids: paddle.Tensor = None, + debug_mode: bool = False, ): token_num_per_rank = topk_ids.shape[0] if token_num_per_rank == 0: @@ -83,6 +85,12 @@ def save_routing_to_buffer_v2( ), f"[R3] total_token_num={total_token_num} < token_num_per_rank={token_num_per_rank}" topk_ids = topk_ids_all[:total_token_num, :] + if debug_mode and position_ids is not None: + token_num, top_k = topk_ids.shape + hack_ids = position_ids[:token_num].cast(topk_ids.dtype) + hack_ids = hack_ids.unsqueeze(1).expand([-1, top_k]) + topk_ids = hack_ids + token_num, top_k = topk_ids.shape buf_max_tokens, num_moe_layers, buf_top_k = device_routing_buffer.shape @@ -124,7 +132,9 @@ def __init__(self, fd_config: FDConfig, total_block_num: int): self.num_moe_layers = rrc.num_moe_layers self.moe_top_k = rrc.moe_top_k self.routing_dtype = rrc.routing_dtype + self.debug_mode = rrc.debug_mode self.tp_rank = fd_config.parallel_config.tensor_parallel_rank + self.token_num_overlap = 0 logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}") @@ -133,6 +143,7 @@ def __init__(self, fd_config: FDConfig, total_block_num: int): def _init_routing_cache(self, dtype: str, total_block_num: int): """Initialize GPU transient buffer, staging buffers, and CPU pinned buffers.""" max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size + self.max_num_kv_tokens = max_num_kv_tokens # Save for slot range validation # Small GPU transient buffer: only current step's token routing # TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens() @@ -145,6 +156,14 @@ def _init_routing_cache(self, dtype: str, total_block_num: int): self.cpu_routing_buf = paddle.zeros(shape, dtype=dtype).pin_memory() self.cpu_slot_mapping_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory() + + if self.debug_mode: + self.position_ids_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64) + self.cpu_position_ids_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory() + else: + self.position_ids_staging_buf = None + self.cpu_position_ids_buf = None + self._pending_save = None # {"num_tokens": int} # Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling) @@ -180,7 +199,9 @@ def _try_attach_routing_host_view(self): "Routing capture will be skipped." ) - def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor): + def prepare_pending_save( + self, num_tokens: int, slot_mapping_gpu: paddle.Tensor, position_ids_gpu: paddle.Tensor = None + ): """ Enqueue D2D + async D2H for routing data and slot_mapping. Must be called before post_process_event.record(). @@ -190,14 +211,25 @@ def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor) 2. D2D (non-blocking): slot_mapping_gpu → slot_mapping_staging_buf 3. async D2H: routing_staging_buf → cpu_routing_buf 4. async D2H: slot_mapping_staging_buf → cpu_slot_mapping_buf + 5. async D2H (debug mode): position_ids_gpu → cpu_position_ids_buf """ if num_tokens > 0: + if self.fd_config.scheduler_config.enable_overlap_schedule: + num_tokens = self.token_num_overlap + slot_mapping_gpu = slot_mapping_gpu[:num_tokens] + position_ids_gpu = position_ids_gpu[:num_tokens] + # D2D: GPU → staging self.routing_staging_buf.copy_(self.device_routing_buffer, False) self.slot_mapping_staging_buf.copy_(slot_mapping_gpu, False) # Async D2H: staging → CPU pinned self.cpu_routing_buf.copy_(self.routing_staging_buf, False) self.cpu_slot_mapping_buf.copy_(self.slot_mapping_staging_buf, False) + + if self.debug_mode and position_ids_gpu is not None and self.cpu_position_ids_buf is not None: + self.position_ids_staging_buf.copy_(position_ids_gpu, False) + self.cpu_position_ids_buf.copy_(self.position_ids_staging_buf, False) + self._pending_save = {"num_tokens": num_tokens} else: self._pending_save = None @@ -222,7 +254,77 @@ def flush_pending_save(self): num_tokens = pending["num_tokens"] # NOTE(gongshaotian): Slice pinned memory tensor maybe cause problem. data = self.cpu_routing_buf.cpu()[:num_tokens].numpy() - slot_np = self.cpu_slot_mapping_buf.cpu()[:num_tokens].numpy() + slot_cpu = self.cpu_slot_mapping_buf.cpu() + slot_cpu_slice = slot_cpu[:num_tokens] + slot_np = slot_cpu_slice.numpy() + + if self.debug_mode and self.cpu_position_ids_buf is not None: + position_ids = self.cpu_position_ids_buf.cpu()[:num_tokens].numpy() + expected_routing = position_ids[:, None, None] + expected_routing = np.broadcast_to(expected_routing, (num_tokens, self.num_moe_layers, self.moe_top_k)) + if not np.array_equal(data, expected_routing): + # 1. Check routing capture + mismatch_mask = (data != expected_routing).any(axis=(1, 2)) + mismatched_token_indices = np.where(mismatch_mask)[0] + logger.error( + f"[R3 Debug] flush mismatch! num_tokens={num_tokens}, mismatched_tokens={len(mismatched_token_indices)}" + ) + logger.error(f"Mismatched token indices: {mismatched_token_indices}") + for idx in mismatched_token_indices: + logger.error( + f" token={idx}, position_id={position_ids[idx]}, slot={slot_np[idx]}, " + f"expected={expected_routing[idx, :, :]}, actual={data[idx, :, :]}" + ) + raise ValueError("Routing data verification failed.") + else: + # 2. Check slot mapping generation and validate slot indices (should be >= 0) + if slot_cpu_slice.min() < 0: + error_parts = [f"[R3 Debug] Invalid slot indices: num_tokens={num_tokens}"] + error_parts.append(" token |slot_staging | slot_pinned | slot_cpu | position_id | data[0,0]") + error_parts.append(" " + "-" * 50) + for i in range(num_tokens): + error_parts.append( + f" {i:4d} | {int(self.slot_mapping_staging_buf[i]):7d} | {int(self.cpu_slot_mapping_buf[i]):7d} | {int(slot_cpu[i]):7d} | {int(position_ids[i]):11d} | {int(data[i, 0, 0])}" + ) + raise AssertionError("\n".join(error_parts)) + # 2.1 Check slot range (should be < max_num_kv_tokens) + max_slot = slot_cpu_slice.max() + if max_slot >= self.max_num_kv_tokens: + invalid_slots = np.where(slot_np >= self.max_num_kv_tokens)[0] + error_parts = [ + f"[R3 Debug] Slot indices out of range: num_tokens={num_tokens}, " + f"max_slot={max_slot}, max_num_kv_tokens={self.max_num_kv_tokens}" + ] + error_parts.append(f" Invalid slot indices: {invalid_slots[:10]}... ({len(invalid_slots)} total)") + error_parts.append(" token |slot | position_id | data[0,0]") + error_parts.append(" " + "-" * 50) + for idx in invalid_slots[:10]: + error_parts.append( + f" {idx:4d} | {int(slot_np[idx]):6d} | {int(position_ids[idx]):11d} | {int(data[idx, 0, 0])}" + ) + raise AssertionError("\n".join(error_parts)) + # 3. Check slot mapping duplicates + unique_slots, counts = np.unique(slot_np, return_counts=True) + num_unique = len(unique_slots) + num_duplicates = np.sum(counts > 1) + if num_duplicates > 0: + duplicate_indices = np.where(counts > 1)[0] + dup_slots_info = [] + for slot_idx in duplicate_indices[:5]: + slot = unique_slots[slot_idx] + count = counts[slot_idx] + dup_token_indices = np.where(slot_np == slot)[0] + dup_slots_info.append(f"slot={slot} count={count} indices={dup_token_indices}") + logger.error( + f"[R3 Debug] flush validation passed but found duplicate slots! " + f"num_tokens={num_tokens}, unique_slots={num_unique}, duplicates={num_duplicates}. " + f"Details: {'; '.join(dup_slots_info)}" + ) + else: + logger.debug( + f"[R3 Debug] flush validation passed: num_tokens={num_tokens}, " + f"slots=[{slot_np[0]}...{slot_np[-1]}], unique_slots={num_unique}" + ) self.routing_host_view.scatter(slot_np, data) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 457846874ee..32a37f21967 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -1085,7 +1085,7 @@ def forward( forward_meta, hidden_states, residual, - position_ids, + position_ids.cast(paddle.int32), ) out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0] diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index fd8811b1101..a9309f5a3af 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -337,9 +337,10 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + position_ids_gpu = share_inputs.get("position_ids_buffer") num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) if routing_replay_manager.tp_rank == 0: - routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu) + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu, position_ids_gpu) # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): @@ -506,9 +507,10 @@ def post_process_speculate( # Routing replay if routing_replay_manager is not None: slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + position_ids_gpu = share_inputs.get("position_ids_buffer") num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) if routing_replay_manager.tp_rank == 0: - routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu) + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu, position_ids_gpu) # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx # into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 902cf55328a..9e37cc52f11 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -617,6 +617,8 @@ def _finalize_routing(self, task_id, task, result, is_prefill=False): if hasattr(task, "output_token_ids") else task.prompt_token_ids_len ) + if task.output_token_ids[-1] in task.eos_token_ids: + seq_len = seq_len - 1 # Ignore eos token if store_type == "response": routing_data = self._gather_routing_for_finished_request(task, seq_len) if routing_data is not None: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 0ceec13bd92..4fb295434be 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -71,11 +71,13 @@ ) share_external_data = None + get_position_ids_and_slot_mapping = None elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx recover_decode_task = None share_external_data = None + get_position_ids_and_slot_mapping = None else: from fastdeploy.model_executor.ops.gpu import ( recover_decode_task, @@ -84,7 +86,7 @@ speculate_schedule_cache, set_data_ipc, unset_data_ipc, - get_position_ids, + get_position_ids_and_slot_mapping, ) import zmq @@ -1313,7 +1315,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p ) return token_num, token_num_event - def _compute_position_ids_and_slot_mapping(self) -> None: + def _compute_position_ids_and_slot_mapping(self, total_token_num) -> None: """Compute position_ids and slot_mapping for KV cache addressing. This is a general computation based on sequence length info and block tables, applicable to all models that need per-token KV cache physical slot addresses. @@ -1325,25 +1327,32 @@ def _compute_position_ids_and_slot_mapping(self) -> None: needs_slot_mapping = (self.routing_replay_manager is not None) or needs_slot_mapping if not needs_slot_mapping: return - current_total_tokens = self.forward_meta.ids_remove_padding.shape[0] - position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens] - get_position_ids( + # Directly write to existing buffers (no memory allocation or copy needed) + position_ids_buffer = self.share_inputs["position_ids_buffer"][:total_token_num] + slot_mapping_buffer = self.share_inputs["slot_mapping_buffer"][:total_token_num] + get_position_ids_and_slot_mapping( self.forward_meta.seq_lens_encoder, self.forward_meta.seq_lens_decoder, self.forward_meta.seq_lens_this_time, - position_ids, + self.forward_meta.batch_id_per_token, + self.forward_meta.block_tables, + position_ids_buffer, + slot_mapping_buffer, + self.cache_config.block_size, ) - block_size = self.cache_config.block_size - block_idx = position_ids // block_size # [num_tokens] - assert ( - self.forward_meta.batch_id_per_token.shape == block_idx.shape - ), f"batch_id_per_token.shape:{self.forward_meta.batch_id_per_token.shape} != block_idx.shape:{block_idx.shape}" - block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx] # [num_tokens] - block_offset = position_ids % block_size # [num_tokens] - slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens] - paddle.assign((block_ids * block_size + block_offset).cast(paddle.int64), slot_mapping) - self.forward_meta.position_ids = position_ids - self.forward_meta.slot_mapping = slot_mapping + # Store views in forward_meta + self.forward_meta.position_ids = position_ids_buffer + self.forward_meta.slot_mapping = slot_mapping_buffer + + # Debug: print all tokens' position_ids and slot_mapping in R3 debug mode + if self.routing_replay_manager is not None and self.routing_replay_manager.debug_mode: + logger.info(f"[R3 Debug] token mapping: num_tokens={total_token_num}") + logger.info(" token | position_id | slot ") + logger.info(" " + "-" * 30) + for i in range(total_token_num): + logger.info( + f" {i:4d} | {int(self.forward_meta.position_ids[i]):8d} | {int(self.forward_meta.slot_mapping[i]):7d}" + ) def _process_reorder(self) -> None: if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False): @@ -1942,12 +1951,13 @@ def _dummy_run( while True: # 1. Initialize forward meta and attention meta data - self._prepare_inputs(is_dummy_or_profile_run=True) + token_num, _ = self._prepare_inputs(is_dummy_or_profile_run=True) # 2. Padding inputs for cuda graph self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph self.padding_cudagraph_inputs() # Compute position_ids and slot_mapping - self._compute_position_ids_and_slot_mapping() + + self._compute_position_ids_and_slot_mapping(total_token_num=token_num) model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] @@ -2327,7 +2337,10 @@ def execute_model_overlap( # ensuring that the token count for the current batch is ready to be computed and reused in the subsequent batch. token_num_event.synchronize() next_launch_token_num, next_real_bsz = self._predict_next_launch_token_num() - real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item() + seq_lens_this_time_cpu_numpy = self.share_inputs["seq_lens_this_time_cpu"].numpy() + real_bsz = (seq_lens_this_time_cpu_numpy > 0).sum().item() + if self.routing_replay_manager is not None: + self.routing_replay_manager.token_num_overlap = seq_lens_this_time_cpu_numpy.sum().item() if real_bsz > 0 and model_output is not None: model_output_data, sampler_output, post_process_event = self._postprocess( model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz @@ -2393,7 +2406,7 @@ def _preprocess( # Padding inputs for cuda graph self.padding_cudagraph_inputs() # Compute position_ids and slot_mapping - self._compute_position_ids_and_slot_mapping() + self._compute_position_ids_and_slot_mapping(total_token_num=current_launch_token_num) model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index af3032949b5..53f6db55649 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -190,7 +190,7 @@ def init_share_inputs(self): self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") # Initialize addressing buffers - self.position_ids_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int32) + self.position_ids_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int64) self.slot_mapping_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int64) # Declare AttentionBackend buffers diff --git a/tests/e2e/utils/rollout_routing_replay_test_utils.py b/tests/e2e/utils/rollout_routing_replay_test_utils.py index 86d853d845b..7ff1808514c 100644 --- a/tests/e2e/utils/rollout_routing_replay_test_utils.py +++ b/tests/e2e/utils/rollout_routing_replay_test_utils.py @@ -21,8 +21,8 @@ def calculate_routing_ratio(expected_routing: paddle.Tensor, actual_routing: pad print(f"token index {i}:\n expected_routing:{expected_routing[i]}\n actual_routing: {actual_routing[i]}\n") assert ( - expected_routing_length == actual_routing_length - ), f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}." + expected_routing_length + ) == actual_routing_length, f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}." total_rows, elements_per_row = expected_routing.shape mask1 = paddle.any(expected_routing != -1, axis=1) @@ -156,9 +156,9 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode cur_save_routing_path = f"./R3_tmp/routing_replay_output_{model_name}/" model_path = os.getenv("MODEL_PATH") if model_path: - baseline_path = os.path.join(model_path, f"R3_BaseLine_uint8_0424/routing_replay_output_baseline_{model_name}") + baseline_path = os.path.join(model_path, f"R3_BaseLine_uint8_0530/routing_replay_output_baseline_{model_name}") else: - baseline_path = f"./R3_BaseLine_uint8_0424/routing_replay_output_baseline_{model_name}" + baseline_path = f"./R3_BaseLine_uint8_0530/routing_replay_output_baseline_{model_name}" stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream") nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream") diff --git a/tests/operators/test_get_position_ids_and_slot_mapping.py b/tests/operators/test_get_position_ids_and_slot_mapping.py new file mode 100644 index 00000000000..22bf32a3323 --- /dev/null +++ b/tests/operators/test_get_position_ids_and_slot_mapping.py @@ -0,0 +1,345 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import ( + get_position_ids, + get_position_ids_and_slot_mapping, +) + + +class TestGetPositionIdsAndSlotMapping(unittest.TestCase): + """Test the fused get_position_ids_and_slot_mapping kernel. + + Variable meanings: + - seq_lens_encoder: 0 if decode stage, else prefill length in current step + - seq_lens_decoder: total context length (processed history, prefill + decode) + - seq_lens_this_time: tokens to process in current step + """ + + def setUp(self): + np.random.seed(42) + paddle.set_device("gpu") + + def _compute_position_ids_and_slot_mapping_old( + self, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ): + """Old implementation for comparison.""" + sum_token_num = int(seq_lens_this_time.numpy().sum()) + + # get_position_ids expects int32, so use int32 and then cast to int64 + position_ids_int32 = paddle.zeros([sum_token_num], dtype="int32") + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids_int32) + + block_idx = position_ids_int32 // block_size + block_ids = block_tables[batch_id_per_token[:sum_token_num], block_idx] + block_offset = position_ids_int32 % block_size + slot_mapping = (block_ids * block_size + block_offset).cast(paddle.int64) + + # Cast position_ids to int64 for comparison with new kernel + position_ids = position_ids_int32.cast(paddle.int64) + + return position_ids.numpy(), slot_mapping.numpy() + + def _compute_position_ids_and_slot_mapping_new( + self, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ): + """New fused kernel implementation.""" + sum_token_num = int(seq_lens_this_time.numpy().sum()) + # Create output buffers (int64 for kernel compatibility) + position_ids = paddle.zeros([sum_token_num], dtype="int64") + slot_mapping = paddle.zeros([sum_token_num], dtype="int64") + + # Kernel writes directly to buffers + get_position_ids_and_slot_mapping( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + position_ids, + slot_mapping, + block_size, + ) + + return position_ids.numpy(), slot_mapping.numpy() + + def _generate_batch_id_per_token(self, seq_lens_this_time, bsz): + """Generate batch_id_per_token based on seq_lens_this_time.""" + total_tokens = int(seq_lens_this_time.numpy().sum()) + batch_id_per_token = np.zeros([total_tokens], dtype=np.int32) + offset = 0 + for bid in range(bsz): + seq_len = int(seq_lens_this_time[bid].numpy()) + batch_id_per_token[offset : offset + seq_len] = bid + offset += seq_len + return paddle.to_tensor(batch_id_per_token, dtype="int32", place=paddle.CUDAPlace(0)) + + def _generate_block_tables(self, bsz, max_num_blocks): + """Generate block_tables with sequential block ids for reproducibility.""" + block_tables = np.arange(bsz * max_num_blocks, dtype=np.int32).reshape(bsz, max_num_blocks) + return paddle.to_tensor(block_tables, dtype="int32", place=paddle.CUDAPlace(0)) + + def test_single_batch_decode(self): + """Test single batch in decode stage.""" + # Decode stage: already processed 10 tokens, now decode 1 more + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode stage + seq_lens_decoder = paddle.to_tensor([10], dtype="int32") # history length + seq_lens_this_time = paddle.to_tensor([1], dtype="int32") # current step + + batch_id_per_token = paddle.to_tensor([0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10], slot_old=[10] (block_id=0, block_offset=10, slot=0*64+10=10) + # logger.info(f"test_single_batch_decode: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Verify position_id starts from seq_lens_decoder + self.assertEqual(pos_new[0], 10) + + def test_single_batch_prefill(self): + """Test single batch in prefill stage.""" + # Prefill stage: no history, processing 5 tokens + seq_lens_encoder = paddle.to_tensor([5], dtype="int32") # prefill length + seq_lens_decoder = paddle.to_tensor([0], dtype="int32") # no history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current step + + batch_id_per_token = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[0,1,2,3,4], slot_old=[0,1,2,3,4] (all in block 0) + # logger.info(f"test_single_batch_prefill: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Verify position_ids start from 0 + np.testing.assert_array_equal(pos_new, np.array([0, 1, 2, 3, 4])) + + def test_multiple_batches_decode(self): + """Test multiple batches all in decode stage.""" + # Batch 0: history 10, now 1 + # Batch 1: history 20, now 2 + seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32") # both decode + seq_lens_decoder = paddle.to_tensor([10, 20], dtype="int32") # history lengths + seq_lens_this_time = paddle.to_tensor([1, 2], dtype="int32") # current step + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 2) + block_tables = self._generate_block_tables(2, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,20,21] + # Batch 0: position_id=10, block_id=0, block_offset=10, slot=10 + # Batch 1: position_ids=[20,21], batch_id=1, block_tables[1][0]=100 + # slot[1]=100*64+20=6420, slot[2]=100*64+21=6421 + # logger.info(f"test_multiple_batches_decode: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Batch 0: position_id = 10 + # Batch 1: position_ids = [20, 21] + np.testing.assert_array_equal(pos_new, np.array([10, 20, 21])) + + def test_different_block_sizes(self): + """Test with different block sizes.""" + for block_size in [1, 8, 16, 32, 64]: + with self.subTest(block_size=block_size): + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([10], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current + batch_id_per_token = paddle.to_tensor([0] * 5, dtype="int32") + block_tables = self._generate_block_tables(1, 100) + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ) + # Expected: pos_old=[10,11,12,13,14] + # For block_size=64: block_id=0, slot=[10,11,12,13,14] + # For block_size=16: block_id=0 for all (10-14<16), slot=[10,11,12,13,14] + # logger.info(f"test_different_block_sizes[{block_size}]: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + def test_block_boundary_crossing(self): + """Test tokens crossing block boundaries.""" + # block_size=64, history=60, so position_ids will be [60, 61, 62, 63, 64] + # This crosses the block boundary (60-63 in block 0, 64 in block 1) + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([60], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current + batch_id_per_token = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[60,61,62,63,64] + # position 60-63: block_id=0, block_offset=60-63, slot=60-63 + # position 64: block_id=1, block_offset=0, slot=64 + # logger.info(f"test_block_boundary_crossing: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Verify position_ids + np.testing.assert_array_equal(pos_new, np.array([60, 61, 62, 63, 64])) + + def test_large_batch(self): + """Test with larger batch size.""" + bsz = 16 + # All in decode stage + seq_lens_encoder = paddle.to_tensor([0] * bsz, dtype="int32") + seq_lens_decoder = paddle.to_tensor(np.random.randint(0, 100, size=bsz), dtype="int32") + seq_lens_this_time = paddle.to_tensor(np.random.randint(1, 5, size=bsz), dtype="int32") + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, bsz) + block_tables = self._generate_block_tables(bsz, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Too many tokens to list expected values + # logger.info(f"test_large_batch: shape pos_old={pos_old.shape}, slot_old={slot_old.shape}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + def test_empty_batch(self): + """Test with some batches having zero tokens this step.""" + # Batch 0: decode (1 token) + # Batch 1: skip (0 tokens) + # Batch 2: decode (2 tokens) + seq_lens_encoder = paddle.to_tensor([0, 0, 0], dtype="int32") # all decode + seq_lens_decoder = paddle.to_tensor([10, 20, 5], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([1, 0, 2], dtype="int32") # current + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 3) + block_tables = self._generate_block_tables(3, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,5,6] + # Batch 0: position_id=10, batch_id=0, block_id=0, slot=10 + # Batch 2: position_ids=[5,6], batch_id=2, block_tables[2][0]=200 + # slot[1]=200*64+5=12805, slot[2]=200*64+6=12806 + # logger.info(f"test_empty_batch: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Batch 0: position_id = 10 + # Batch 1: skipped + # Batch 2: position_ids = [5, 6] + np.testing.assert_array_equal(pos_new, np.array([10, 5, 6])) + + def test_mtp_scenario(self): + """Test MTP scenario where seq_lens_this_time varies per batch.""" + # All in decode stage, different accepted tokens per batch + seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([10, 20], dtype="int32") # history + # Batch 0: 2 accepted tokens, Batch 1: 1 accepted token + seq_lens_this_time = paddle.to_tensor([2, 1], dtype="int32") + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 2) + block_tables = self._generate_block_tables(2, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,11,20] + # Batch 0: position_ids=[10,11], batch_id=0, block_id=0, slot=[10,11] + # Batch 1: position_ids=[20], batch_id=1, block_tables[1][0]=100 + # slot[2]=100*64+20=6420 + # logger.info(f"test_mtp_scenario: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Batch 0: position_ids = [10, 11] + # Batch 1: position_id = [20] + np.testing.assert_array_equal(pos_new, np.array([10, 11, 20])) + + +if __name__ == "__main__": + unittest.main() From eeed8a37a5dce34be939567f02e13b5313e2f51d Mon Sep 17 00:00:00 2001 From: RAM Date: Tue, 2 Jun 2026 12:07:54 +0800 Subject: [PATCH 139/143] [RL] Fix Ernie mm bug (#7966) * Reset buffer size of R3 * refine code * R3 fix Eos bug * pre-commit * fix r3 ci and support dsa * refine code * refine code * reset ci dir * refine code * fix dsv3 * fix ernie5 mm bug --- fastdeploy/model_executor/layers/moe/routing_indices_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 392eb3a238d..a4edbbfb724 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -217,7 +217,8 @@ def prepare_pending_save( if self.fd_config.scheduler_config.enable_overlap_schedule: num_tokens = self.token_num_overlap slot_mapping_gpu = slot_mapping_gpu[:num_tokens] - position_ids_gpu = position_ids_gpu[:num_tokens] + if position_ids_gpu is not None: + position_ids_gpu = position_ids_gpu[:num_tokens] # D2D: GPU → staging self.routing_staging_buf.copy_(self.device_routing_buffer, False) From 780c00089d5b6ad72d21193b6a5a51777255ebde Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Wed, 3 Jun 2026 10:57:56 +0800 Subject: [PATCH 140/143] [Cherry-Pick][RL][Feature] Add GDR streaming weight update path (#7951) (#7971) * Add GDR streaming weight update path * [RL] Unify GDR and IPC weight update --- fastdeploy/config.py | 2 + fastdeploy/envs.py | 2 + fastdeploy/model_executor/utils.py | 33 + fastdeploy/rl/dynamic_weight_manager.py | 200 +++++- fastdeploy/worker/gpu_model_runner.py | 99 ++- fastdeploy/worker/worker_process.py | 1 + .../test_model_executor_utils.py | 131 ++++ tests/rl/test_dynamic_weight_gdr.py | 574 ++++++++++++++++++ 8 files changed, 1020 insertions(+), 22 deletions(-) create mode 100644 tests/rl/test_dynamic_weight_gdr.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e0298a015ba..1d32f386910 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1454,6 +1454,8 @@ def __init__( self.rsync_config: Optional[Dict[str, Any]] = None for key, value in args.items(): if hasattr(self, key): + if key == "rsync_config" and isinstance(value, str): + value = json.loads(value) setattr(self, key, value) def __str__(self) -> str: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 58edb8ca026..f030478f1b3 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -277,6 +277,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Whether to use GDR CheckpointTransfer for dynamic weight updates. + "FD_USE_GDR_CHECKPOINT_TRANSFER": lambda: bool(int(os.getenv("FD_USE_GDR_CHECKPOINT_TRANSFER", "0"))), # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 960d8f23f7e..952b82b2a60 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -131,6 +131,35 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1): return weight_or_paramter +def _is_gdr_checkpoint_transfer_dynamic_load_config(fd_config: FDConfig) -> bool: + load_config = fd_config.load_config + if not load_config.dynamic_load_weight: + return False + return envs.FD_USE_GDR_CHECKPOINT_TRANSFER + + +def _copy_gdr_checkpoint_transfer_transposed_weight_attrs(src, dst): + attr_names = ( + "weight_loader", + "output_dim", + "weight_need_transpose", + "is_distributed", + "split_axis", + "tp_row_bias", + ) + for name in attr_names: + if hasattr(src, name): + setattr(dst, name, getattr(src, name)) + if hasattr(src, "output_dim") and src.output_dim is not None: + dst.output_dim = not src.output_dim + dst.weight_need_transpose = not getattr(src, "weight_need_transpose", False) + if hasattr(src, "split_axis"): + if len(src.shape) == 2 and src.split_axis in (0, 1): + dst.split_axis = 1 - src.split_axis + elif len(src.shape) == 3 and src.split_axis in (1, 2): + dst.split_axis = 3 - src.split_axis + + def process_weight_transpose(layer, weight_name): weight = getattr(layer, weight_name) if len(weight.shape) == 2: @@ -143,6 +172,8 @@ def process_weight_transpose(layer, weight_name): default_initializer=paddle.nn.initializer.Constant(0), is_bias=False, ) + if _is_gdr_checkpoint_transfer_dynamic_load_config(layer.fd_config): + _copy_gdr_checkpoint_transfer_transposed_weight_attrs(weight, weight_tmp) if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False): free_tensor(weight) setattr(layer, weight_name, weight_tmp) @@ -348,6 +379,8 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" ) loaded_weight = get_tensor(loaded_weight) + if not param._is_initialized(): + param.initialize() param.copy_(loaded_weight, False) return fn diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index c30da6f9124..5a4666d46b5 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -14,19 +14,21 @@ # limitations under the License. """ +import asyncio import gc import glob import os import re import time from multiprocessing.shared_memory import SharedMemory -from typing import Any, Dict, List +from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np import paddle import yaml from paddleformers.utils.log import logger +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus @@ -52,10 +54,15 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): self.model_list = models self._capture_model_state() self.rdma_handle = None - if self.load_config.load_strategy == "rsync": - self.update_weights_by_rdma() + self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER + + if self.use_gdr_checkpoint_transfer: + self.update_weights_by_gdr() else: - self.update_parameters() + if self.load_config.load_strategy == "rsync": + self.update_weights_by_rdma() + else: + self.update_parameters() self.finalize_update() logger.info( @@ -64,14 +71,20 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): ) @paddle.no_grad() - def _capture_model_state(self): + def _capture_model_state(self, log_params: bool = True): """Capture and store initial model parameters state.""" + self.state_dict = {} for model in self.model_list: for name, param in model.state_dict().items(): - logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}") + if log_params: + logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}") self.state_dict[name] = param - def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False): + def update_weights_by_rdma( + self, + version: str = None, + verify_checksum: bool = False, + ): def valid_parameters(old_state_dict, new_state_dict): is_valid = True for key in new_state_dict: @@ -92,14 +105,7 @@ def valid_parameters(old_state_dict, new_state_dict): ) return is_valid - bootstrap_load = version is None or version == "" - if bootstrap_load: - version = self.read_model_version_from_file() - if version is None or version == "": - raise Exception( - "rsync model version not set, please set it in 1) {model_version}/version.yaml " - "or 2) interface arguments 'version'" - ) + version, bootstrap_load = self._resolve_weight_update_version(version) logger.info( f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, " @@ -151,6 +157,164 @@ def valid_parameters(old_state_dict, new_state_dict): "rank": self.local_rank, } + def update_weights_by_gdr( + self, version: str = None, verify_checksum: bool = False, restore_cleared_params: bool = False + ): + """Unified weight update via CheckpointTransfer (supports GDR and IPC backends).""" + config = dict(self.fd_config.load_config.rsync_config or {}) + is_ipc = self.load_config.load_strategy != "rsync" + + if is_ipc: + step_id = version or "0" + else: + version, _ = self._resolve_weight_update_version(version) + step_id = version + + logger.info( + f"START rank:{self.local_rank}/{self.nranks} update_weights_by_gdr, " + f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}" + ) + + from checkpoint_transfer.transfer import CheckpointTransfer + + transfer_config = self._build_ct_transfer_config(config) + logger.info(f"CheckpointTransfer config:{transfer_config}") + ct_handle = CheckpointTransfer(transfer_config) + + total_start = time.perf_counter() + asyncio.run(ct_handle.initialize()) + try: + weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle") + + if restore_cleared_params: + for name, target_param in self.state_dict.items(): + if not target_param._is_initialized(): + paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param) + logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}") + update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator) + finally: + asyncio.run(ct_handle.cleanup()) + self._capture_model_state(log_params=False) + total_cost = time.perf_counter() - total_start + logger.info( + f"END update_weights_by_gdr, cost {total_cost:.2f} seconds, " + f"weights:{update_count}, mtp_cached_weights:{mtp_cache_count}, " + f"step_id:{step_id}, local_rank:{self.local_rank}" + ) + return { + "update_cost": total_cost, + "total_cost": total_cost, + "version": step_id, + "rank": self.local_rank, + "update_count": update_count, + "mtp_cache_count": mtp_cache_count, + } + + def _build_ct_transfer_config(self, config: dict): + from dataclasses import fields + + from checkpoint_transfer.config import Phase1Backend, Role, TransferConfig + + transfer_config = dict(config) + if "device_name" in transfer_config and "device" not in transfer_config: + transfer_config["device"] = transfer_config.pop("device_name") + else: + transfer_config.pop("device_name", None) + + transfer_config["role"] = Role.INFERENCE + + if self.load_config.load_strategy == "rsync": + node_index = int(transfer_config.pop("index", 0)) + transfer_config["global_rank"] = node_index * self.nranks + self.local_rank + transfer_config["phase1_backend"] = Phase1Backend.GPU_DIRECT + transfer_config["group_size"] = int(transfer_config.get("group_size", self.nranks)) + else: + transfer_config.pop("index", None) + gpu_id = int(os.getenv("FLAGS_selected_gpus", "0")) + transfer_config["global_rank"] = gpu_id + transfer_config["phase1_backend"] = Phase1Backend.IPC + transfer_config["group_size"] = int(transfer_config.get("group_size", self.nranks)) + transfer_config["qsize"] = int(transfer_config.get("qsize", 2)) + + transfer_config_keys = {field.name for field in fields(TransferConfig)} + transfer_config = {key: value for key, value in transfer_config.items() if key in transfer_config_keys} + return TransferConfig(**transfer_config) + + def _resolve_weight_update_version(self, version: Optional[str]) -> Tuple[str, bool]: + bootstrap_load = version is None or version == "" + if bootstrap_load: + version = self.read_model_version_from_file() + if version is None or version == "": + raise Exception( + "rsync model version not set, please set it in 1) {model_version}/version.yaml " + "or 2) interface arguments 'version'" + ) + return version, bootstrap_load + + def _load_models_from_weight_iterator( + self, + weights_iterator: Iterable[Tuple[str, Any]], + ) -> Tuple[int, int]: + update_count = 0 + + if len(self.model_list) == 1: + + def count_weights(): + nonlocal update_count + for item in weights_iterator: + update_count += 1 + yield item + + self.model_list[0].load_weights(count_weights()) + return update_count, 0 + + mtp_models = self.model_list[1:] + config = self.fd_config.load_config.rsync_config or {} + mtp_chunk_size = max(1, int(config.get("gdr_mtp_chunk_size", 16))) + mtp_chunk: List[Tuple[str, Any]] = [] + mtp_cache_count = 0 + mtp_weight_tokens = ["mtp_", "mtp_block"] + for model in mtp_models: + model_config = getattr(getattr(model, "fd_config", None), "model_config", None) + start_layer = getattr(model, "mtp_start_layer_idx", None) + num_layers = getattr(model, "num_mtp_layers", None) + start_layer = start_layer if start_layer is not None else getattr(model_config, "start_layer_index", None) + num_layers = ( + num_layers if num_layers is not None else getattr(model_config, "num_nextn_predict_layers", None) + ) + if start_layer is None or num_layers is None: + continue + for layer_id in range(int(start_layer), int(start_layer) + int(num_layers)): + mtp_weight_tokens.append(f"layers.{layer_id}.") + mtp_weight_tokens.append(f".layers.{layer_id}.") + + def flush_mtp_chunk(): + nonlocal mtp_chunk + if not mtp_chunk: + return + for model in mtp_models: + model.load_weights(iter(mtp_chunk)) + mtp_chunk = [] + + def cache_mtp_weights(): + nonlocal update_count, mtp_cache_count + for item in weights_iterator: + name, _ = item + update_count += 1 + if any(token in name for token in mtp_weight_tokens): + mtp_chunk.append(item) + mtp_cache_count += 1 + yield item + if len(mtp_chunk) >= mtp_chunk_size: + flush_mtp_chunk() + + self.model_list[0].load_weights(cache_mtp_weights()) + flush_mtp_chunk() + if mtp_cache_count == 0: + raise ValueError("No MTP weights were cached from the GDR stream for auxiliary model loading.") + + return update_count, mtp_cache_count + def update_parameters(self, pid: int = 0, restart_process_group=False) -> None: """Core method to update model parameters based on strategy.""" start_time = time.perf_counter() @@ -414,7 +578,7 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T if src.shape != dst.shape: raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}") - def finalize_update(self, pid: int = 0): + def finalize_update(self, pid: Optional[int] = None): """Finalize update process with verification.""" self._verify_parameters("update") @@ -479,8 +643,10 @@ def _log_memory(self, context: str): f"current_reserved: {curr_reserved:.2f}GB" ) - def _update_shared_status(self, pid: int, status: int) -> None: + def _update_shared_status(self, pid: Optional[int], status: int) -> None: """Update shared memory status flag for inter-process communication.""" + if pid is None: + pid = self.parallel_config.local_engine_worker_queue_port array = np.zeros([1], dtype=np.int32) shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}") value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4fb295434be..07cb4dfb67b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -94,7 +94,7 @@ from fastdeploy import envs from fastdeploy.engine.tasks import PoolingTask from fastdeploy.input.ernie4_5_vl_processor import DataProcessor -from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient +from fastdeploy.inter_communicator import IPCSignal, KVCacheStatus, ZmqIpcClient from fastdeploy.logger.deterministic_logger import DeterministicLogger from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata @@ -2994,9 +2994,16 @@ def clear_requests(self): def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" # Update parameters - self.dynamic_weight_manager.update_parameters( - pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle - ) + if self.dynamic_weight_manager.use_gdr_checkpoint_transfer: + if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: + self.dynamic_weight_manager.restart_communication_group() + if self.dynamic_weight_manager.parallel_config.enable_expert_parallel: + self.dynamic_weight_manager.recreate_deepep_buffer() + self.dynamic_weight_manager.update_weights_by_gdr(restore_cleared_params=True) + else: + self.dynamic_weight_manager.update_parameters( + pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle + ) # Reset share_inputs self.share_inputs.reset_share_inputs() @@ -3013,7 +3020,89 @@ def update_parameters(self, pid): self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def update_weights(self, version: str = None, verify_checksum: bool = False): - return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) + if self.dynamic_weight_manager.use_gdr_checkpoint_transfer: + release_cache = bool((self.fd_config.load_config.rsync_config or {}).get("gdr_release_cache", False)) + + cache_clear_cost = 0.0 + cache_rebuild_cost = 0.0 + if release_cache: + clear_start = time.perf_counter() + self._clear_cache_for_gdr_weight_update() + cache_clear_cost = time.perf_counter() - clear_start + + result = self.dynamic_weight_manager.update_weights_by_gdr(version, verify_checksum) + + if release_cache: + rebuild_start = time.perf_counter() + self._rebuild_cache_after_gdr_weight_update() + cache_rebuild_cost = time.perf_counter() - rebuild_start + + result["release_cache"] = release_cache + result["cache_clear_cost"] = cache_clear_cost + result["cache_rebuild_cost"] = cache_rebuild_cost + self.dynamic_weight_manager.finalize_update() + return result + else: + result = self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) + self.dynamic_weight_manager.finalize_update() + return result + + def _clear_cache_for_gdr_weight_update(self): + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.CLEARING + if self.use_cudagraph: + self.model.clear_graph_opt_backend() + if envs.FD_USE_BLOCK_WISE_CUDA_GRAPH: + from fastdeploy.model_executor.graph_optimization.cuda_graph_op import ( + clear_all_block_wise_graphs, + ) + + clear_all_block_wise_graphs() + if ( + self.speculative_decoding + and self.spec_method == SpecMethod.MTP + and self.graph_opt_config.draft_model_use_cudagraph + ): + self.proposer.model.clear_graph_opt_backend() + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: + self.proposer.clear_mtp_cache() + self.clear_cache() + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.CLEARED: + time.sleep(0.01) + paddle.device.cuda.empty_cache() + self._cached_model_output_data = None + self._cached_sampler_output = None + self._cached_post_process_event = None + self._cached_launch_token_num = -1 + self._cached_real_bsz = -1 + + def _rebuild_cache_after_gdr_weight_update(self): + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.UPDATING + self.share_inputs.reset_share_inputs() + if self.spec_method == SpecMethod.MTP: + self.proposer.model_inputs.reset_model_inputs() + if not self.enable_cache_manager_v1: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) + self.initialize_kv_cache() + if self.use_cudagraph: + self.capture_model() + if self.fd_config.routing_replay_config.enable_routing_replay: + self.routing_replay_manager.update_suspend_routing_replay() + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.NORMAL: + time.sleep(0.01) def sleep(self, tags): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 61e6fcda85f..c69dd2c859f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -260,6 +260,7 @@ def init_health_status(self) -> None: suffix=self.parallel_config.local_engine_worker_queue_port, create=False, ) + self.worker.model_runner.kv_cache_status = self.kv_cache_status # init exist_task_signal workers_exist_task = np.zeros([1], dtype=np.int32) diff --git a/tests/model_executor/test_model_executor_utils.py b/tests/model_executor/test_model_executor_utils.py index 98cba5c3302..701be987251 100644 --- a/tests/model_executor/test_model_executor_utils.py +++ b/tests/model_executor/test_model_executor_utils.py @@ -13,11 +13,16 @@ # limitations under the License. import unittest +import unittest.mock +from types import SimpleNamespace + +import paddle from fastdeploy.model_executor.utils import ( BitMaskTracker, TensorTracker, WeightsMapper, + process_weight_transpose, remap_weight_keys, set_weight_attrs, slice_fn, @@ -157,6 +162,132 @@ class Param: set_weight_attrs(p, None) # should not raise +class TestProcessWeightTranspose(unittest.TestCase): + def _make_layer(self, shape, dynamic_load_weight=True, load_strategy="rsync", rsync_config=None): + class Layer(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fd_config = SimpleNamespace( + load_config=SimpleNamespace( + dynamic_load_weight=dynamic_load_weight, + load_strategy=load_strategy, + rsync_config=rsync_config or {}, + ), + model_config=SimpleNamespace(enable_cache=False), + ) + self.weight = self.create_parameter( + shape=shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + is_bias=False, + ) + + return Layer() + + def test_gdr_dynamic_transpose_preserves_loading_attrs_for_future_updates(self): + def loader(): + return None + + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.weight_need_transpose = False + layer.weight.weight_loader = loader + layer.weight.is_distributed = True + layer.weight.split_axis = 1 + layer.weight.tensor_track = object() + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertIs(layer.weight.weight_loader, loader) + self.assertTrue(layer.weight.is_distributed) + self.assertEqual(layer.weight.split_axis, 0) + self.assertFalse(hasattr(layer.weight, "tensor_track")) + + def test_gpu_direct_dynamic_transpose_preserves_loading_attrs(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 0) + + def test_rdma_dynamic_transpose_does_not_preserve_loading_attrs(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(hasattr(layer.weight, "output_dim")) + self.assertFalse(hasattr(layer.weight, "weight_need_transpose")) + self.assertFalse(hasattr(layer.weight, "split_axis")) + + def test_ct_ipc_dynamic_transpose_preserves_loading_attrs(self): + layer = self._make_layer([8, 4], load_strategy="ipc") + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 0) + + def test_gdr_transpose_preserves_loading_attrs_for_3d_weight(self): + layer = self._make_layer([2, 8, 4]) + layer.weight.output_dim = False + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [2, 4, 8]) + self.assertTrue(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 2) + + def test_gdr_transpose_clears_weight_need_transpose_for_torch_format(self): + """Production scenario: torch format sets weight_need_transpose=True. + After transpose, param is in HF layout so no transpose needed on reload.""" + + def loader(): + return None + + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.weight_need_transpose = True + layer.weight.weight_loader = loader + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(layer.weight.output_dim) + self.assertFalse(layer.weight.weight_need_transpose) + self.assertIs(layer.weight.weight_loader, loader) + self.assertEqual(layer.weight.split_axis, 0) + + def test_gdr_transpose_preserves_none_output_dim(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = None + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertIsNone(layer.weight.output_dim) + + class TestSliceFn(unittest.TestCase): def test_1d_slice(self): import numpy as np diff --git a/tests/rl/test_dynamic_weight_gdr.py b/tests/rl/test_dynamic_weight_gdr.py new file mode 100644 index 00000000000..f6b1be0baad --- /dev/null +++ b/tests/rl/test_dynamic_weight_gdr.py @@ -0,0 +1,574 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib.util +import sys +import types +import unittest +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from unittest.mock import MagicMock, patch + +_DYNAMIC_WEIGHT_MODULE = None + + +def _install_dynamic_weight_manager_stubs(): + """Install minimal stubs so this unit test can run without Paddle installed.""" + + def no_grad(): + def decorator(func): + return func + + return decorator + + fake_paddle = types.SimpleNamespace( + Tensor=object, + no_grad=no_grad, + distributed=types.SimpleNamespace( + get_world_size=lambda: 1, + get_rank=lambda: 0, + barrier=lambda *args, **kwargs: None, + restart_process_group=lambda *args, **kwargs: None, + shutdown_process_group=lambda *args, **kwargs: None, + ), + device=types.SimpleNamespace( + cuda=types.SimpleNamespace( + synchronize=lambda: None, + empty_cache=lambda: None, + max_memory_allocated=lambda: 0, + max_memory_reserved=lambda: 0, + memory_allocated=lambda: 0, + memory_reserved=lambda: 0, + ) + ), + base=types.SimpleNamespace( + core=types.SimpleNamespace(LoDTensor=types.SimpleNamespace(_new_shared_cuda=MagicMock())) + ), + load=MagicMock(), + empty=MagicMock(), + to_tensor=MagicMock(), + ) + fake_logger = types.SimpleNamespace( + info=MagicMock(), + warning=MagicMock(), + error=MagicMock(), + debug=MagicMock(), + ) + fake_fastdeploy = types.ModuleType("fastdeploy") + fake_fastdeploy.__path__ = [] + fake_config = types.ModuleType("fastdeploy.config") + fake_config.FDConfig = object + fake_model_executor = types.ModuleType("fastdeploy.model_executor") + fake_model_executor.__path__ = [] + fake_model_executor_utils = types.ModuleType("fastdeploy.model_executor.utils") + fake_model_executor_utils.process_final_after_loading = MagicMock() + fake_numpy = types.ModuleType("numpy") + fake_envs = types.ModuleType("fastdeploy.envs") + fake_envs.FD_USE_GDR_CHECKPOINT_TRANSFER = False + fake_inter_communicator = types.ModuleType("fastdeploy.inter_communicator") + fake_inter_communicator.KVCacheStatus = types.SimpleNamespace() + fake_inter_communicator.ModelWeightsStatus = types.SimpleNamespace(NORMAL=0, CLEARED=1) + fake_yaml = types.ModuleType("yaml") + fake_yaml.safe_load = MagicMock(return_value={}) + fake_yaml.YAMLError = Exception + + sys.modules.update( + { + "paddle": fake_paddle, + "numpy": fake_numpy, + "yaml": fake_yaml, + "paddleformers": types.ModuleType("paddleformers"), + "paddleformers.utils": types.ModuleType("paddleformers.utils"), + "paddleformers.utils.log": types.SimpleNamespace(logger=fake_logger), + "fastdeploy": fake_fastdeploy, + "fastdeploy.envs": fake_envs, + "fastdeploy.config": fake_config, + "fastdeploy.model_executor": fake_model_executor, + "fastdeploy.model_executor.utils": fake_model_executor_utils, + "fastdeploy.inter_communicator": fake_inter_communicator, + } + ) + + +def _load_dynamic_weight_manager_from_file(): + module_path = Path(__file__).resolve().parents[2] / "fastdeploy" / "rl" / "dynamic_weight_manager.py" + spec = importlib.util.spec_from_file_location("dynamic_weight_manager_under_test", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _load_dynamic_weight_manager_module(): + global _DYNAMIC_WEIGHT_MODULE + if _DYNAMIC_WEIGHT_MODULE is not None: + return _DYNAMIC_WEIGHT_MODULE + + fastdeploy_module = sys.modules.get("fastdeploy") + if fastdeploy_module is not None and not hasattr(fastdeploy_module, "__path__"): + _DYNAMIC_WEIGHT_MODULE = _load_dynamic_weight_manager_from_file() + return _DYNAMIC_WEIGHT_MODULE + + try: + from fastdeploy.rl import dynamic_weight_manager + + _DYNAMIC_WEIGHT_MODULE = dynamic_weight_manager + return dynamic_weight_manager + except ModuleNotFoundError as exc: + if exc.name not in ("numpy", "paddle", "yaml"): + raise + + for name in list(sys.modules): + if name == "fastdeploy" or name.startswith("fastdeploy."): + sys.modules.pop(name, None) + _install_dynamic_weight_manager_stubs() + + _DYNAMIC_WEIGHT_MODULE = _load_dynamic_weight_manager_from_file() + return _DYNAMIC_WEIGHT_MODULE + + +class _FakeModel: + def __init__(self): + self.loaded = [] + self.params = {} + + def load_weights(self, weights_iterator): + self.loaded.extend(list(weights_iterator)) + + def state_dict(self): + return self.params + + +class _FakeMTPModel(_FakeModel): + def __init__(self, mtp_start_layer_idx=2, num_mtp_layers=1): + super().__init__() + self.mtp_start_layer_idx = mtp_start_layer_idx + self.num_mtp_layers = num_mtp_layers + + +def _make_manager(rsync_config=None, load_strategy="rsync"): + DynamicWeightManager = _load_dynamic_weight_manager_module().DynamicWeightManager + + manager = object.__new__(DynamicWeightManager) + fd_config = MagicMock() + fd_config.load_config.rsync_config = rsync_config or { + "backend": "mooncake", + "output_framework": "paddle", + } + fd_config.load_config.load_strategy = load_strategy + fd_config.parallel_config.data_parallel_rank = 2 + fd_config.parallel_config.data_parallel_size = 1 + fd_config.parallel_config.tensor_parallel_rank = 1 + fd_config.parallel_config.tensor_parallel_size = 4 + manager.fd_config = fd_config + manager.load_config = fd_config.load_config + manager.parallel_config = fd_config.parallel_config + manager.local_rank = 5 + manager.nranks = 8 + manager.rdma_handle = None + manager.model_list = [_FakeModel()] + manager.state_dict = {} + manager.use_gdr_checkpoint_transfer = True + return manager + + +class _FakeRole(Enum): + TRAINER = "trainer" + INFERENCE = "inference" + + +class _FakePhase1Backend(Enum): + GPU_DIRECT = "gpu_direct" + MOONCAKE = "mooncake" + IPC = "ipc" + + +@dataclass +class _FakeTransferConfig: + role: object + global_rank: int + group_size: int = 1 + phase1_backend: object = _FakePhase1Backend.GPU_DIRECT + phase2_backend: object = None + phase2_fan_out: int = 4 + bucket_size_mb: int = 512 + num_buffers: int = 2 + redis_host: str = "127.0.0.1" + redis_port: int = 6379 + discover_timeout_s: float = 60.0 + redis_ttl_s: int = 60 + recv_bucket_timeout_s: float = 60.0 + session_total_timeout_s: float = 600.0 + device: str = None + log_level: str = None + log_file: str = None + perf_log_file: str = None + materialize_tensors: bool = True + qsize: int = 3 + gpu_id: int = -1 + + def __post_init__(self): + self.kwargs = dict(self.__dict__) + self.kwargs.pop("kwargs", None) + + +def _patch_gdr_checkpoint_transfer(fake_checkpoint_transfer): + class FakeCheckpointTransferWithLifecycle(fake_checkpoint_transfer): + async def initialize(self): + self.initialized = True + + async def cleanup(self): + self.cleaned = True + + fake_config_module = types.SimpleNamespace( + Role=_FakeRole, + TransferConfig=_FakeTransferConfig, + Phase1Backend=_FakePhase1Backend, + ) + fake_transfer_module = types.SimpleNamespace(CheckpointTransfer=FakeCheckpointTransferWithLifecycle) + return patch.dict( + sys.modules, + { + "checkpoint_transfer.config": fake_config_module, + "checkpoint_transfer.transfer": fake_transfer_module, + }, + ) + + +class TestDynamicWeightGDR(unittest.TestCase): + def test_update_weights_by_gdr_gdr_mode(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + self.step_id = step_id + self.output_framework = output_framework + yield "model.layers.0.weight", object() + + manager = _make_manager() + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-1") + + self.assertEqual(result["version"], "step-1") + self.assertEqual(result["update_count"], 1) + self.assertIn("total_cost", result) + self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight") + self.assertTrue(created[0].initialized) + self.assertTrue(created[0].cleaned) + self.assertEqual(created[0].step_id, "step-1") + self.assertEqual(created[0].output_framework, "paddle") + self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE) + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT) + self.assertEqual(created[0].config.kwargs["global_rank"], 5) + self.assertEqual(created[0].config.kwargs["group_size"], 8) + self.assertNotIn("backend", created[0].config.kwargs) + self.assertNotIn("output_framework", created[0].config.kwargs) + + def test_update_weights_by_gdr_ipc_mode(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + self.step_id = step_id + yield "model.layers.0.weight", object() + + manager = _make_manager( + rsync_config={"redis_host": "10.0.0.1", "redis_port": 6379}, + load_strategy="ipc", + ) + + with ( + _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer), + patch.dict("os.environ", {"FLAGS_selected_gpus": "3"}), + ): + result = manager.update_weights_by_gdr() + + self.assertEqual(result["version"], "0") + self.assertEqual(created[0].step_id, "0") + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.IPC) + self.assertEqual(created[0].config.kwargs["global_rank"], 3) + self.assertEqual(created[0].config.kwargs["qsize"], 2) + + def test_gdr_checkpoint_transfer_receive_exception_propagates(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.weight", object() + raise RuntimeError("receive failed") + + class IncrementalModel(_FakeModel): + def load_weights(self, weights_iterator): + for item in weights_iterator: + self.loaded.append(item) + + manager = _make_manager() + manager.model_list = [IncrementalModel()] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + with self.assertRaisesRegex(RuntimeError, "receive failed"): + manager.update_weights_by_gdr(version="step-error") + + def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self): + loaded_param = object() + + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.weight", loaded_param + + class RefreshingModel(_FakeModel): + def load_weights(self, weights_iterator): + super().load_weights(weights_iterator) + self.params["model.weight"] = loaded_param + + manager = _make_manager() + manager.model_list = [RefreshingModel()] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-refresh") + + self.assertIs(manager.state_dict["model.weight"], loaded_param) + + def test_gdr_checkpoint_transfer_caches_mtp_subset_for_auxiliary_model(self): + objects = [object() for _ in range(4)] + + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.self_attn.q_proj.weight", objects[0] + yield "model.layers.2.self_attn.q_proj.weight", objects[1] + yield "model.layers.20.self_attn.q_proj.weight", objects[2] + yield "ernie.mtp_linear_proj.0.weight", objects[3] + + manager = _make_manager() + main_model = _FakeModel() + mtp_model = _FakeMTPModel(mtp_start_layer_idx=2, num_mtp_layers=1) + manager.model_list = [main_model, mtp_model] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-5") + + self.assertEqual(result["update_count"], 4) + self.assertEqual(result["mtp_cache_count"], 2) + self.assertEqual( + [name for name, _ in main_model.loaded], + [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.2.self_attn.q_proj.weight", + "model.layers.20.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ) + self.assertEqual( + [name for name, _ in mtp_model.loaded], + [ + "model.layers.2.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ) + + def test_gdr_checkpoint_transfer_flushes_mtp_subset_by_chunk_limit(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.2.self_attn.q_proj.weight", object() + yield "ernie.mtp_linear_proj.0.weight", object() + yield "model.layers.2.self_attn.o_proj.weight", object() + + class ChunkRecordingMTPModel(_FakeMTPModel): + def __init__(self): + super().__init__(mtp_start_layer_idx=2, num_mtp_layers=1) + self.load_calls = [] + + def load_weights(self, weights_iterator): + chunk = list(weights_iterator) + self.load_calls.append([name for name, _ in chunk]) + self.loaded.extend(chunk) + + manager = _make_manager( + { + "backend": "mooncake", + "output_framework": "paddle", + "gdr_mtp_chunk_size": 2, + } + ) + main_model = _FakeModel() + mtp_model = ChunkRecordingMTPModel() + manager.model_list = [main_model, mtp_model] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-8") + + self.assertEqual(result["mtp_cache_count"], 3) + self.assertEqual( + mtp_model.load_calls, + [ + [ + "model.layers.2.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ["model.layers.2.self_attn.o_proj.weight"], + ], + ) + + def test_gdr_checkpoint_transfer_multi_model_requires_mtp_subset(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.self_attn.q_proj.weight", object() + + manager = _make_manager() + manager.model_list = [_FakeModel(), _FakeMTPModel(mtp_start_layer_idx=2, num_mtp_layers=1)] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + with self.assertRaisesRegex(ValueError, "No MTP weights"): + manager.update_weights_by_gdr(version="step-5") + + def test_gdr_checkpoint_transfer_config_not_forwarded_to_transfer_config(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + manager = _make_manager( + { + "backend": "mooncake", + "output_framework": "paddle", + } + ) + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-6") + + self.assertNotIn("gpu_direct", created[0].config.kwargs) + self.assertNotIn("output_framework", created[0].config.kwargs) + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT) + + def test_gdr_checkpoint_transfer_computes_global_rank_from_node_index(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + manager = _make_manager( + { + "index": 1, + "backend": "mooncake", + "output_framework": "paddle", + "group_size": 16, + } + ) + manager.local_rank = 5 + manager.nranks = 8 + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-index") + + self.assertEqual(created[0].config.kwargs["global_rank"], 13) + self.assertEqual(created[0].config.kwargs["group_size"], 16) + self.assertNotIn("index", created[0].config.kwargs) + + def test_gdr_checkpoint_transfer_config_deep_copied_before_forwarding(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + rsync_config = { + "backend": "mooncake", + "output_framework": "paddle", + "device_name": "mlx5_0", + } + manager = _make_manager(rsync_config) + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-7") + + self.assertEqual(created[0].config.kwargs["device"], "mlx5_0") + self.assertEqual(rsync_config["device_name"], "mlx5_0") + + def test_finalize_update_uses_worker_queue_port_status_suffix(self): + module = _load_dynamic_weight_manager_module() + manager = _make_manager() + manager.first_load = False + manager.rank = 0 + manager.parallel_config.tensor_parallel_size = 1 + manager.parallel_config.enable_expert_parallel = False + manager.parallel_config.local_engine_worker_queue_port = 60572 + manager._verify_parameters = MagicMock() + + class FakeArray: + shape = (1,) + dtype = "int32" + nbytes = 4 + + class FakeValue: + def __init__(self): + self.writes = {} + + def __setitem__(self, key, value): + self.writes[key] = value + + fake_value = FakeValue() + with ( + patch.object(module.np, "int32", "int32", create=True), + patch.object(module.np, "zeros", return_value=FakeArray(), create=True), + patch.object(module.np, "ndarray", return_value=fake_value, create=True), + patch.object(module, "SharedMemory") as fake_shared_memory, + ): + manager.finalize_update() + + fake_shared_memory.assert_called_once_with(create=False, size=4, name="model_weights_status.60572") + self.assertEqual(fake_value.writes[0], module.ModelWeightsStatus.NORMAL) + + +if __name__ == "__main__": + unittest.main() From 99c7df1a127c569f3bb5213fe168e89e93e40caa Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 3 Jun 2026 11:41:25 +0800 Subject: [PATCH 141/143] fix moe accurate issue --- fastdeploy/model_executor/models/glm4_moe.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 0e5974c0410..afd3aaf51a4 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -308,11 +308,17 @@ def __init__( ): self.mlp = Glm4Moe(fd_config, layer_id, prefix=f"{prefix}.mlp") else: + expert_parallel_size = fd_config.parallel_config.expert_parallel_size + tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size + use_tp = tensor_parallel_size > 1 + use_ep = expert_parallel_size > 1 + merge_ffn_tp = use_tp and not use_ep self.mlp = Glm4MoeMLP( fd_config, intermediate_size=fd_config.model_config.intermediate_size, layer_id=layer_id, prefix=f"{prefix}.mlp", + reduce_results=not merge_ffn_tp, ) self.input_layernorm = RMSNorm( From f232ed97cfb14175fb0d299479a0b28ebd2d9add Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 3 Jun 2026 12:56:20 +0800 Subject: [PATCH 142/143] fix bug --- fastdeploy/model_executor/models/glm4_moe.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index afd3aaf51a4..0df0f7103cb 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -64,8 +64,12 @@ def __init__( reduce_results: bool = True, ) -> None: super().__init__() - self.enable_all_reduce_fusion = ( - fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results + self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size + self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size + self.use_tp = self.tensor_parallel_size > 1 + self.use_ep = self.expert_parallel_size > 1 + self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and ( + self.use_tp and not self.use_ep ) # shared experts not split when use_sequence_parallel_moe in ep + tp @@ -308,17 +312,11 @@ def __init__( ): self.mlp = Glm4Moe(fd_config, layer_id, prefix=f"{prefix}.mlp") else: - expert_parallel_size = fd_config.parallel_config.expert_parallel_size - tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size - use_tp = tensor_parallel_size > 1 - use_ep = expert_parallel_size > 1 - merge_ffn_tp = use_tp and not use_ep self.mlp = Glm4MoeMLP( fd_config, intermediate_size=fd_config.model_config.intermediate_size, layer_id=layer_id, prefix=f"{prefix}.mlp", - reduce_results=not merge_ffn_tp, ) self.input_layernorm = RMSNorm( From b5ec4fa99db8b625bbb19f88386c4d13b4705ebd Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Wed, 3 Jun 2026 17:15:34 +0800 Subject: [PATCH 143/143] add test --- tests/layers/trtllm_allreduce_rms_fusion.py | 68 +++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py index eea59d26743..117e2edbe32 100644 --- a/tests/layers/trtllm_allreduce_rms_fusion.py +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -772,6 +772,74 @@ def test_proxy_path_falls_back_when_token_too_large(self): self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch") +class TestGlm4MoeMLPInit(unittest.TestCase): + """Cover Glm4MoeMLP.__init__ attribute assignments (glm4_moe.py:67-71).""" + + def _make_fd_config(self, tp_size, ep_size, enable_fusion, use_seq_parallel_moe=False): + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = tp_size + fd_config.parallel_config.expert_parallel_size = ep_size + fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion + fd_config.parallel_config.use_sequence_parallel_moe = use_seq_parallel_moe + fd_config.model_config = Mock() + fd_config.model_config.hidden_size = 64 + fd_config.model_config.hidden_act = "silu" + fd_config.model_config.moe_layer_start_index = 0 + return fd_config + + def _build(self, fd_config, layer_id=1): + # Patch heavy submodules so Glm4MoeMLP.__init__ runs without real deps. + with ( + patch("fastdeploy.model_executor.models.glm4_moe.MergedColumnParallelLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.MergedReplicatedLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.RowParallelLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.ReplicatedLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.SiluAndMul"), + ): + from fastdeploy.model_executor.models.glm4_moe import Glm4MoeMLP + + return Glm4MoeMLP( + fd_config=fd_config, + intermediate_size=128, + layer_id=layer_id, + prefix="model.layers.1.mlp", + ) + + def test_tp_only_fusion_enabled(self): + """tp>1, ep=1, fusion=True -> enable_all_reduce_fusion=True.""" + fd_config = self._make_fd_config(tp_size=4, ep_size=1, enable_fusion=True) + mlp = self._build(fd_config) + self.assertEqual(mlp.expert_parallel_size, 1) + self.assertEqual(mlp.tensor_parallel_size, 4) + self.assertTrue(mlp.use_tp) + self.assertFalse(mlp.use_ep) + self.assertTrue(mlp.enable_all_reduce_fusion) + + def test_ep_disables_fusion(self): + """ep>1 -> enable_all_reduce_fusion forced False even if flag is True.""" + fd_config = self._make_fd_config(tp_size=2, ep_size=2, enable_fusion=True) + mlp = self._build(fd_config) + self.assertTrue(mlp.use_tp) + self.assertTrue(mlp.use_ep) + self.assertFalse(mlp.enable_all_reduce_fusion) + + def test_single_gpu_no_fusion(self): + """tp=1, ep=1 -> use_tp/use_ep False, fusion False.""" + fd_config = self._make_fd_config(tp_size=1, ep_size=1, enable_fusion=True) + mlp = self._build(fd_config) + self.assertFalse(mlp.use_tp) + self.assertFalse(mlp.use_ep) + self.assertFalse(mlp.enable_all_reduce_fusion) + + def test_fusion_flag_off(self): + """flag False -> enable_all_reduce_fusion False regardless of tp.""" + fd_config = self._make_fd_config(tp_size=4, ep_size=1, enable_fusion=False) + mlp = self._build(fd_config) + self.assertTrue(mlp.use_tp) + self.assertFalse(mlp.enable_all_reduce_fusion) + + if __name__ == "__main__": """Run tests directly (called by subprocess after distributed launch)""" unittest.main(verbosity=2)