From 86c539b0be1d388cd5e05969804f7df9c6578d94 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Mon, 23 Mar 2026 21:18:02 +0800 Subject: [PATCH 1/5] Added Metric logits_stats to the ZMQ branch to ensure training-inference consistency --- fastdeploy/config.py | 1 + fastdeploy/engine/args_utils.py | 12 +++++ fastdeploy/engine/engine.py | 1 + fastdeploy/engine/request.py | 1 + fastdeploy/entrypoints/openai/protocol.py | 1 + fastdeploy/entrypoints/openai/serving_chat.py | 47 ++++++++++++---- fastdeploy/output/token_processor.py | 21 ++++++++ fastdeploy/worker/output.py | 54 +++++++++++++++++++ fastdeploy/worker/worker_process.py | 5 ++ 9 files changed, 134 insertions(+), 9 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index a26e694a0c4..c1879ae4a7c 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -209,6 +209,7 @@ def __init__( self.max_model_len = 0 self.dtype = "bfloat16" self.enable_logprob = False + self.compute_logits_stats = False self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" self.redundant_experts_num = 0 diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index f07c2d4157b..13d44698ea0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -459,6 +459,12 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + compute_logits_stats: bool = False + """ + Flag to enable per-token logits statistics (min/max/mean/std) output. + Only effective when enable_logprob is True. + """ + max_logprobs: int = 20 """ Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the @@ -887,6 +893,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--compute-logits-stats", + action="store_true", + default=EngineArgs.compute_logits_stats, + help="Enable per-token logits statistics (min/max/mean/std) output.", + ) model_group.add_argument( "--max-logprobs", type=int, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 11b71c382a9..531f20176e2 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -644,6 +644,7 @@ def _start_worker_service(self): "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, "enable_logprob": self.cfg.model_config.enable_logprob, + "compute_logits_stats": self.cfg.model_config.compute_logits_stats, "lm_head_fp32": self.cfg.model_config.lm_head_fp32, "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle, diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 391e2038534..668fab95609 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -727,6 +727,7 @@ class CompletionOutput: delta_message: Optional[DeltaMessage] = None multipart: Optional[list[Any]] = None num_image_tokens: Optional[int] = None + logits_stats: Optional[dict[str, float]] = None def to_dict(self): """ diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index f8cd70bca08..fd441a9d33d 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ 
b/fastdeploy/entrypoints/openai/protocol.py @@ -294,6 +294,7 @@ class LogProbEntry(BaseModel): logprob: float bytes: Optional[List[int]] = None top_logprobs: Optional[List[LogProbEntry]] = None + logits_stats: Optional[Dict[str, float]] = None class LogProbs(BaseModel): diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index d5dce8db685..2e528bc85ad 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -817,22 +817,49 @@ def _create_chat_logprobs( request_decode_flag: Optional[bool] = True, ) -> Optional[LogProbs]: """Create OpenAI-style logprobs for chat completions.""" - if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs): + if ( + output_top_logprobs is None + or len(output_top_logprobs) < 3 + or any(not lst for lst in output_top_logprobs[:3]) + ): # check top 3 because logits_stats maybe None return None logprobs_res: Optional[LogProbs] = None - for logprob_token_ids, logprobs, sampled_token_ranks in zip( - output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2] - ): - top_logprobs = LogprobsLists( - logprob_token_ids=[logprob_token_ids], - logprobs=[logprobs], - sampled_token_ranks=[sampled_token_ranks], - ) + + # Extract logits stats from LogprobsLists if available + has_logits_stats = False if output_top_logprobs.logits_min is None else True + + # Iterate by index over mandatory fields; optionally include logits stats + num_tokens = len(output_top_logprobs.logprobs) + for idx in range(num_tokens): + logits_stats = None + if has_logits_stats: + top_logprobs = LogprobsLists( + logprob_token_ids=[output_top_logprobs.logprob_token_ids[idx]], + logprobs=[output_top_logprobs.logprobs[idx]], + sampled_token_ranks=[output_top_logprobs.sampled_token_ranks[idx]], + logits_min=[output_top_logprobs.logits_min[idx]], + logits_max=[output_top_logprobs.logits_max[idx]], + logits_mean=[output_top_logprobs.logits_mean[idx]], + logits_std=[output_top_logprobs.logits_std[idx]], + ) + logits_stats = { + "min": float(output_top_logprobs.logits_min[idx]), + "max": float(output_top_logprobs.logits_max[idx]), + "mean": float(output_top_logprobs.logits_mean[idx]), + "std": float(output_top_logprobs.logits_std[idx]), + } + else: + top_logprobs = LogprobsLists( + logprob_token_ids=[output_top_logprobs.logprob_token_ids[idx]], + logprobs=[output_top_logprobs.logprobs[idx]], + sampled_token_ranks=[output_top_logprobs.sampled_token_ranks[idx]], + ) step_logprobs_res = self._build_logprobs_response( request_logprobs=request_logprobs, response_logprobs=top_logprobs, request_top_logprobs=request_top_logprobs, request_decode_flag=request_decode_flag, + logits_stats=logits_stats, ) if logprobs_res is None: logprobs_res = step_logprobs_res @@ -846,6 +873,7 @@ def _build_logprobs_response( response_logprobs: Optional[LogprobsLists], request_top_logprobs: int, request_decode_flag: bool, + logits_stats: Optional[dict[str, float]] = None, ) -> Optional[LogProbs]: """ Construct a logprobs response object in line with the OpenAI style. 
@@ -893,6 +921,7 @@ def _build_logprobs_response( logprob=top_logprob_entries[0].logprob, bytes=top_logprob_entries[0].bytes, top_logprobs=top_logprob_entries[1:], # Here are the complete topk candidates + logits_stats=logits_stats, ) return LogProbs(content=[sampled_entry]) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1ab0b48f350..463103eb7fa 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -83,6 +83,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.compute_logits_stats = self.cfg.model_config.compute_logits_stats self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: @@ -350,6 +351,26 @@ def _process_batch_output_use_zmq(self, receive_datas): logprobs_list: LogprobsLists = stream_data.logprobs.tolists() result.outputs.logprob = float(logprobs_list.logprobs[0][0]) result.outputs.top_logprobs = logprobs_list + # Extract logits statistics if available + if self.compute_logits_stats: + assert ( + logprobs_list.logits_min is not None + ), "logits_min is None when compute_logits_stats is enabled" + assert ( + logprobs_list.logits_max is not None + ), "logits_max is None when compute_logits_stats is enabled" + assert ( + logprobs_list.logits_mean is not None + ), "logits_mean is None when compute_logits_stats is enabled" + assert ( + logprobs_list.logits_std is not None + ), "logits_std is None when compute_logits_stats is enabled" + result.outputs.logits_stats = { + "min": float(logprobs_list.logits_min[0]), + "max": float(logprobs_list.logits_max[0]), + "mean": float(logprobs_list.logits_mean[0]), + "std": float(logprobs_list.logits_std[0]), + } except Exception as e: llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}") if getattr(stream_data, "prompt_logprobs", None) is not None: diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 365fec12475..39c05e6f628 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -44,6 +44,11 @@ class LogprobsLists(NamedTuple): logprobs: list[list[float]] # [num_reqs] sampled_token_ranks: list[int] + # Logits statistics for each sequence (optional) + logits_min: Optional[list[float]] = None # [num_reqs] + logits_max: Optional[list[float]] = None # [num_reqs] + logits_mean: Optional[list[float]] = None # [num_reqs] + logits_std: Optional[list[float]] = None # [num_reqs] def slice_columns(self, start: int, end: int): """ @@ -54,6 +59,14 @@ def slice_columns(self, start: int, end: int): [row[start:end] for row in self.logprob_token_ids], [row[start:end] for row in self.logprobs], self.sampled_token_ranks, # unchanged + # [row[start:end] for row in self.logits_min], + # [row[start:end] for row in self.logits_max], + # [row[start:end] for row in self.logits_mean], + # [row[start:end] for row in self.logits_std], + self.logits_min, # unchanged + self.logits_max, # unchanged + self.logits_mean, # unchanged + self.logits_std, # unchanged ) def slice_rows(self, start: int, end: int): @@ -65,6 +78,10 @@ def slice_rows(self, start: int, end: int): self.logprob_token_ids[start:end], self.logprobs[start:end], self.sampled_token_ranks[start:end], + self.logits_min[start:end] if self.logits_min is not None else None, + self.logits_max[start:end] if self.logits_max is not None else None, + 
self.logits_mean[start:end] if self.logits_mean is not None else None, + self.logits_std[start:end] if self.logits_std is not None else None, ) @@ -77,6 +94,11 @@ class LogprobsTensors(NamedTuple): logprobs: paddle.Tensor # [num_reqs] selected_token_ranks: paddle.Tensor + # Logits statistics for each sequence (optional) + logits_min: Optional[paddle.Tensor] = None # [num_reqs] + logits_max: Optional[paddle.Tensor] = None # [num_reqs] + logits_mean: Optional[paddle.Tensor] = None # [num_reqs] + logits_std: Optional[paddle.Tensor] = None def tolists(self): """Convert to lists.""" @@ -84,6 +106,10 @@ def tolists(self): self.logprob_token_ids.tolist(), self.logprobs.tolist(), self.selected_token_ranks.tolist(), + self.logits_min.tolist() if self.logits_min is not None else None, + self.logits_max.tolist() if self.logits_max is not None else None, + self.logits_mean.tolist() if self.logits_mean is not None else None, + self.logits_std.tolist() if self.logits_std is not None else None, ) @staticmethod @@ -97,6 +123,10 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens logprob_token_ids=logprob_token_ids, logprobs=logprobs, selected_token_ranks=selected_token_ranks, + logits_min=None, + logits_max=None, + logits_mean=None, + logits_std=None, ) @staticmethod @@ -110,6 +140,10 @@ def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors" logprob_token_ids=logprob_token_ids, logprobs=logprobs, selected_token_ranks=selected_token_ranks, + logits_min=None, + logits_max=None, + logits_mean=None, + logits_std=None, ) def slice_rows(self, start: int, end: int): @@ -122,6 +156,26 @@ def slice_rows(self, start: int, end: int): paddle.to_tensor(self.logprob_token_ids.cpu()[start:end], place="cpu"), paddle.to_tensor(self.logprobs.cpu()[start:end], place="cpu"), paddle.to_tensor(self.selected_token_ranks.cpu()[start:end], place="cpu"), + ( + paddle.to_tensor(self.logits_min.cpu()[start:end], place="cpu") + if self.logits_min is not None + else None + ), + ( + paddle.to_tensor(self.logits_max.cpu()[start:end], place="cpu") + if self.logits_max is not None + else None + ), + ( + paddle.to_tensor(self.logits_mean.cpu()[start:end], place="cpu") + if self.logits_mean is not None + else None + ), + ( + paddle.to_tensor(self.logits_std.cpu()[start:end], place="cpu") + if self.logits_std is not None + else None + ), ) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 29deb2db4d4..f9a2c063e3f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -939,6 +939,11 @@ def parse_args(): action="store_true", help="Enable output of token-level log probabilities.", ) + parser.add_argument( + "--compute_logits_stats", + action="store_true", + help="Enable per-token logits statistics (min/max/mean/std) output.", + ) parser.add_argument( "--max_logprobs", type=int, From 44d436709aadae4a6e17002ff9ac0da732da5272 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Tue, 24 Mar 2026 13:25:39 +0800 Subject: [PATCH 2/5] fix ci --- fastdeploy/worker/gpu_model_runner.py | 5 ++--- fastdeploy/worker/metax_model_runner.py | 5 ++--- fastdeploy/worker/xpu_model_runner.py | 5 ++--- tests/worker/test_gpu_prompt_logprobs.py | 7 ++++++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index c374d227dcb..7a0d82a0395 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ 
b/fastdeploy/worker/gpu_model_runner.py @@ -2957,9 +2957,8 @@ def _get_prompt_logprobs_list( raw_logprobs = self.sampler.compute_logprobs(logits) elif logprobs_mode == "raw_logits": raw_logprobs = logits - token_ids, logprobs, ranks = self.sampler.gather_logprobs( - raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor - ) + gathered = self.sampler.gather_logprobs(raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor) + token_ids, logprobs, ranks = gathered.logprob_token_ids, gathered.logprobs, gathered.selected_token_ranks # Synchronize before using token_ids, logprobs and ranks to ensure async copy are completed. paddle.device.synchronize() chunk_slice = slice(start_idx, start_idx + num_logits) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index fa7daa41cfb..2a8c822c9b3 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -2833,9 +2833,8 @@ def _get_prompt_logprobs_list( raw_logprobs = self.sampler.compute_logprobs(logits) elif logprobs_mode == "raw_logits": raw_logprobs = logits - token_ids, logprobs, ranks = self.sampler.gather_logprobs( - raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor - ) + gathered = self.sampler.gather_logprobs(raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor) + token_ids, logprobs, ranks = gathered.logprob_token_ids, gathered.logprobs, gathered.selected_token_ranks chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False) logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index a984e8788c4..19b6920eb51 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -271,9 +271,8 @@ def _get_prompt_logprobs_list(self, hidden_states: paddle.Tensor) -> list[Option raw_logprobs = logits else: raw_logprobs = self.sampler.compute_logprobs(logits) - token_ids, logprobs, ranks = self.sampler.gather_logprobs( - raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor - ) + gathered = self.sampler.gather_logprobs(raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor) + token_ids, logprobs, ranks = gathered.logprob_token_ids, gathered.logprobs, gathered.selected_token_ranks chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False) logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False) diff --git a/tests/worker/test_gpu_prompt_logprobs.py b/tests/worker/test_gpu_prompt_logprobs.py index d26bc915339..58dd91db966 100644 --- a/tests/worker/test_gpu_prompt_logprobs.py +++ b/tests/worker/test_gpu_prompt_logprobs.py @@ -206,9 +206,14 @@ def test_prompt_logprobs(self): ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits) token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64") - ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs( + gathered = model_runner.sampler.gather_logprobs( ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is ) + ref_token_ids, ref_logprobs, ref_ranks = ( + gathered.logprob_token_ids, + gathered.logprobs, + gathered.selected_token_ranks, + ) prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0] np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04) np.testing.assert_allclose( From 
a8259f634166c2c0ae714208aca8400e40c96a25 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Tue, 24 Mar 2026 18:03:14 +0800 Subject: [PATCH 3/5] fix ci --- fastdeploy/engine/request.py | 2 + fastdeploy/entrypoints/llm.py | 6 +- fastdeploy/entrypoints/openai/serving_chat.py | 69 +++++++++++++------ .../entrypoints/openai/serving_completion.py | 6 +- 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 668fab95609..09e36a13ac7 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -746,6 +746,7 @@ def to_dict(self): "text": self.text, "reasoning_content": self.reasoning_content, "reasoning_token_num": self.reasoning_token_num, + "logits_stats": self.logits_stats, } @classmethod @@ -771,6 +772,7 @@ def __repr__(self) -> str: f"logprobs={self.logprobs}, " f"top_logprobs={self.top_logprobs}, " f"draft_top_logprobs={self.draft_top_logprobs}, " + f"logits_stats={self.logits_stats}, " ) def get(self, key: str, default_value=None): diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index c452c84c44f..43f7c2ae8bb 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -450,7 +450,11 @@ def _build_prompt_logprobs( tensors. """ - token_ids, logprobs, ranks = prompt_logprobs_tensors + token_ids, logprobs, ranks = ( + prompt_logprobs_tensors.logprob_token_ids, + prompt_logprobs_tensors.logprobs, + prompt_logprobs_tensors.selected_token_ranks, + ) # Detokenize non-incrementally. # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 2e528bc85ad..e4cb5389836 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -821,38 +821,63 @@ def _create_chat_logprobs( output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs[:3]) - ): # check top 3 because logits_stats maybe None + ): return None logprobs_res: Optional[LogProbs] = None - # Extract logits stats from LogprobsLists if available - has_logits_stats = False if output_top_logprobs.logits_min is None else True + # Check if output_top_logprobs is a LogprobsLists object(NamedTuple) or a list + is_logprobslists = hasattr(output_top_logprobs, "logprob_token_ids") + + # Extract logits stats if available + if is_logprobslists: + # output_top_logprobs is LogprobsLists namedtuple + has_logits_stats = output_top_logprobs.logits_min is not None + else: + # list from msgpack: [logprob_token_ids, logprobs, sampled_token_ranks, logits_min, logits_max, logits_mean, logits_std] + has_logits_stats = len(output_top_logprobs) >= 7 and output_top_logprobs[3] is not None + + if is_logprobslists: + num_tokens = len(output_top_logprobs.logprobs) + _tk_ids = lambda idx: output_top_logprobs.logprob_token_ids[idx] + _lps = lambda idx: output_top_logprobs.logprobs[idx] + _ranks = lambda idx: output_top_logprobs.sampled_token_ranks[idx] + _lmin = lambda idx: output_top_logprobs.logits_min[idx] + _lmax = lambda idx: output_top_logprobs.logits_max[idx] + _lmean = lambda idx: output_top_logprobs.logits_mean[idx] + _lstd = lambda idx: output_top_logprobs.logits_std[idx] + else: + num_tokens = len(output_top_logprobs[1]) + _tk_ids = lambda idx: output_top_logprobs[0][idx] + _lps = lambda idx: output_top_logprobs[1][idx] + _ranks = lambda idx: output_top_logprobs[2][idx] + _lmin = lambda idx: 
output_top_logprobs[3][idx] + _lmax = lambda idx: output_top_logprobs[4][idx] + _lmean = lambda idx: output_top_logprobs[5][idx] + _lstd = lambda idx: output_top_logprobs[6][idx] - # Iterate by index over mandatory fields; optionally include logits stats - num_tokens = len(output_top_logprobs.logprobs) for idx in range(num_tokens): logits_stats = None if has_logits_stats: top_logprobs = LogprobsLists( - logprob_token_ids=[output_top_logprobs.logprob_token_ids[idx]], - logprobs=[output_top_logprobs.logprobs[idx]], - sampled_token_ranks=[output_top_logprobs.sampled_token_ranks[idx]], - logits_min=[output_top_logprobs.logits_min[idx]], - logits_max=[output_top_logprobs.logits_max[idx]], - logits_mean=[output_top_logprobs.logits_mean[idx]], - logits_std=[output_top_logprobs.logits_std[idx]], + logprob_token_ids=[_tk_ids(idx)], + logprobs=[_lps(idx)], + sampled_token_ranks=[_ranks(idx)], + logits_min=[_lmin(idx)], + logits_max=[_lmax(idx)], + logits_mean=[_lmean(idx)], + logits_std=[_lstd(idx)], ) logits_stats = { - "min": float(output_top_logprobs.logits_min[idx]), - "max": float(output_top_logprobs.logits_max[idx]), - "mean": float(output_top_logprobs.logits_mean[idx]), - "std": float(output_top_logprobs.logits_std[idx]), + "min": float(_lmin(idx)), + "max": float(_lmax(idx)), + "mean": float(_lmean(idx)), + "std": float(_lstd(idx)), } else: top_logprobs = LogprobsLists( - logprob_token_ids=[output_top_logprobs.logprob_token_ids[idx]], - logprobs=[output_top_logprobs.logprobs[idx]], - sampled_token_ranks=[output_top_logprobs.sampled_token_ranks[idx]], + logprob_token_ids=[_tk_ids(idx)], + logprobs=[_lps(idx)], + sampled_token_ranks=[_ranks(idx)], ) step_logprobs_res = self._build_logprobs_response( request_logprobs=request_logprobs, @@ -943,7 +968,11 @@ def _build_prompt_logprobs( tensors. """ - token_ids, logprobs, ranks = prompt_logprobs_tensors + token_ids, logprobs, ranks = ( + prompt_logprobs_tensors.logprob_token_ids, + prompt_logprobs_tensors.logprobs, + prompt_logprobs_tensors.selected_token_ranks, + ) # Normalize to plain Python lists (support both Tensor and list inputs) if hasattr(token_ids, "tolist"): diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 8b49ea82c16..b7c74f6cc85 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -900,7 +900,11 @@ def _build_prompt_logprobs( tensors. 
""" - token_ids, logprobs, ranks = prompt_logprobs_tensors + token_ids, logprobs, ranks = ( + prompt_logprobs_tensors.logprob_token_ids, + prompt_logprobs_tensors.logprobs, + prompt_logprobs_tensors.selected_token_ranks, + ) # Normalize to plain Python lists (support both Tensor and list inputs) if hasattr(token_ids, "tolist"): From 4488f977c9967696c76344a540f9d63d304e44e5 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Tue, 24 Mar 2026 21:13:42 +0800 Subject: [PATCH 4/5] fix ci ut --- fastdeploy/entrypoints/llm.py | 6 +----- fastdeploy/entrypoints/openai/serving_chat.py | 6 +----- .../entrypoints/openai/serving_completion.py | 6 +----- tests/ce/server/test_logprobs.py | 15 +++++++++++++++ tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py | 13 +++++++++++++ .../4cards_cases/test_ernie_21b_tp1_dp4_mtp.py | 13 +++++++++++++ 6 files changed, 44 insertions(+), 15 deletions(-) diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 43f7c2ae8bb..14882872e8d 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -450,11 +450,7 @@ def _build_prompt_logprobs( tensors. """ - token_ids, logprobs, ranks = ( - prompt_logprobs_tensors.logprob_token_ids, - prompt_logprobs_tensors.logprobs, - prompt_logprobs_tensors.selected_token_ranks, - ) + token_ids, logprobs, ranks = prompt_logprobs_tensors[:3] # Detokenize non-incrementally. # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index e4cb5389836..89c45bb234b 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -968,11 +968,7 @@ def _build_prompt_logprobs( tensors. """ - token_ids, logprobs, ranks = ( - prompt_logprobs_tensors.logprob_token_ids, - prompt_logprobs_tensors.logprobs, - prompt_logprobs_tensors.selected_token_ranks, - ) + token_ids, logprobs, ranks = prompt_logprobs_tensors[:3] # Normalize to plain Python lists (support both Tensor and list inputs) if hasattr(token_ids, "tolist"): diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index b7c74f6cc85..460f6d31a8f 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -900,11 +900,7 @@ def _build_prompt_logprobs( tensors. 
""" - token_ids, logprobs, ranks = ( - prompt_logprobs_tensors.logprob_token_ids, - prompt_logprobs_tensors.logprobs, - prompt_logprobs_tensors.selected_token_ranks, - ) + token_ids, logprobs, ranks = prompt_logprobs_tensors[:3] # Normalize to plain Python lists (support both Tensor and list inputs) if hasattr(token_ids, "tolist"): diff --git a/tests/ce/server/test_logprobs.py b/tests/ce/server/test_logprobs.py index 83ca89486c9..aa737a46e7e 100644 --- a/tests/ce/server/test_logprobs.py +++ b/tests/ce/server/test_logprobs.py @@ -3,6 +3,17 @@ from core import TEMPLATE, URL, build_request_payload, send_request +def _strip_logits_stats(obj): + """Recursively remove 'logits_stats' keys from logprobs response.""" + if isinstance(obj, dict): + obj.pop("logits_stats", None) + for v in obj.values(): + _strip_logits_stats(v) + elif isinstance(obj, list): + for item in obj: + _strip_logits_stats(item) + + def test_unstream_with_logprobs(): """ 测试非流式响应开启 logprobs 后,返回的 token 概率信息是否正确。 @@ -21,6 +32,7 @@ def test_unstream_with_logprobs(): response = send_request(URL, payload) print(json.dumps(response.json(), indent=2, ensure_ascii=False)) resp_json = response.json() + _strip_logits_stats(resp_json) # 校验返回内容与概率信息 assert resp_json["choices"][0]["message"]["content"] == "牛顿的" @@ -99,6 +111,7 @@ def test_stream_with_logprobs(): print(json.dumps(result_chunk, indent=2, ensure_ascii=False)) break + _strip_logits_stats(result_chunk) # 校验概率字段 assert result_chunk["choices"][0]["delta"]["content"] == "牛顿" assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" @@ -184,6 +197,7 @@ def test_stream_with_temp_scaled_logprobs(): print(json.dumps(result_chunk, indent=2, ensure_ascii=False)) break + _strip_logits_stats(result_chunk) # 校验概率字段 assert result_chunk["choices"][0]["delta"]["content"] == "牛顿" assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" @@ -229,6 +243,7 @@ def test_stream_with_top_p_normalized_logprobs(): print(json.dumps(result_chunk, indent=2, ensure_ascii=False)) break + _strip_logits_stats(result_chunk) # 校验概率字段 assert result_chunk["choices"][0]["delta"]["content"] == "牛顿" assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿" diff --git a/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py b/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py index 4fb178d4582..5752434c009 100644 --- a/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py +++ b/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py @@ -23,6 +23,18 @@ import pytest import requests + +def _strip_logits_stats(obj): + """Recursively remove 'logits_stats' keys from logprobs response.""" + if isinstance(obj, dict): + obj.pop("logits_stats", None) + for v in obj.values(): + _strip_logits_stats(v) + elif isinstance(obj, list): + for item in obj: + _strip_logits_stats(item) + + tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) sys.path.insert(0, tests_dir) @@ -606,6 +618,7 @@ def test_non_stream_with_logprobs(api_url): resp_json = send_request(url=api_url, payload=payload).json() logprobs = resp_json["choices"][0]["logprobs"] + _strip_logits_stats(logprobs) base_path = os.getenv("MODEL_PATH") if base_path: diff --git a/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py b/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py index 6e4e36f7392..e77e8b4e8d3 100644 --- a/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py +++ b/tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py @@ -23,6 +23,18 @@ import pytest import requests + +def _strip_logits_stats(obj): 
+ """Recursively remove 'logits_stats' keys from logprobs response.""" + if isinstance(obj, dict): + obj.pop("logits_stats", None) + for v in obj.values(): + _strip_logits_stats(v) + elif isinstance(obj, list): + for item in obj: + _strip_logits_stats(item) + + tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) sys.path.insert(0, tests_dir) @@ -512,6 +524,7 @@ def test_non_stream_with_logprobs(api_url): resp_json = send_request(url=api_url, payload=payload).json() logprobs = resp_json["choices"][0]["logprobs"] + _strip_logits_stats(logprobs) base_path = os.getenv("MODEL_PATH") From 67e0aa1683b337e95380106f56c4c95ed0da9257 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Wed, 25 Mar 2026 12:56:09 +0800 Subject: [PATCH 5/5] fix ci --- tests/output/test_process_batch_output.py | 6 +++--- tests/output/test_token_processor.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 46282cd386a..7362b119948 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -238,7 +238,7 @@ def test_speculative_decoding_use_logprobs(self): for i, request_output in enumerate(batch_result_buffer): assert isinstance(request_output, RequestOutput) assert len(request_output.outputs.token_ids) == accept_num[i] - assert len(request_output.outputs.top_logprobs) == 3 + assert len(request_output.outputs.top_logprobs) == 7 # tokens, scores, ranks assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 @@ -251,8 +251,8 @@ def test_speculative_decoding_use_logprobs(self): for c in cached_generated_tokens.cache: assert isinstance(request_output, RequestOutput) assert len(request_output.outputs.token_ids) == accept_num[i] - assert len(request_output.outputs.top_logprobs) == 3 - assert len(request_output.outputs.draft_top_logprobs) == 3 + assert len(request_output.outputs.top_logprobs) == 7 + assert len(request_output.outputs.draft_top_logprobs) == 7 # tokens, scores, ranks assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index c0609094a2b..94e14e49796 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -55,7 +55,7 @@ def __init__( num_speculative_tokens=2, enable_draft_logprob=True, ) - self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob) + self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob, compute_logits_stats=False) self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode") self.cache_config = types.SimpleNamespace( enable_prefix_caching=enable_prefix_caching,