-
Notifications
You must be signed in to change notification settings - Fork 729
【TI-Consisent】Added Metric logits_stats to the ZMQ branch #6979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
86c539b
44d4367
e045231
a8259f6
9abc1e1
4488f97
67e0aa1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -644,6 +644,7 @@ def _start_worker_service(self): | |
| "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, | ||
| "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, | ||
| "enable_logprob": self.cfg.model_config.enable_logprob, | ||
| "compute_logits_stats": self.cfg.model_config.compute_logits_stats, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
common_engine.py 中也得加这个参数 |
||
| "lm_head_fp32": self.cfg.model_config.lm_head_fp32, | ||
| "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, | ||
| "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -83,6 +83,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.speculative_decoding = self.cfg.speculative_config.method is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_logprobs = self.cfg.model_config.enable_logprob | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.compute_logits_stats = self.cfg.model_config.compute_logits_stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.speculative_decoding: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -350,6 +351,26 @@ def _process_batch_output_use_zmq(self, receive_datas): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list: LogprobsLists = stream_data.logprobs.tolists() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.outputs.logprob = float(logprobs_list.logprobs[0][0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.outputs.top_logprobs = logprobs_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Extract logits statistics if available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.compute_logits_stats: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_min is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_min is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_max is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_max is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_mean is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_mean is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_std is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_std is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+356
to
+367
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | |
| logprobs_list.logits_min is not None | |
| ), "logits_min is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_max is not None | |
| ), "logits_max is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_mean is not None | |
| ), "logits_mean is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_std is not None | |
| ), "logits_std is None when compute_logits_stats is enabled" | |
| missing_fields = [] | |
| if logprobs_list.logits_min is None: | |
| missing_fields.append("logits_min") | |
| if logprobs_list.logits_max is None: | |
| missing_fields.append("logits_max") | |
| if logprobs_list.logits_mean is None: | |
| missing_fields.append("logits_mean") | |
| if logprobs_list.logits_std is None: | |
| missing_fields.append("logits_std") | |
| if missing_fields: | |
| # When compute_logits_stats is enabled, all logits_* fields must be present | |
| raise ValueError( | |
| "Missing logits stats fields when compute_logits_stats is enabled: " | |
| + ", ".join(missing_fields) | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,11 @@ class LogprobsLists(NamedTuple): | |
| logprobs: list[list[float]] | ||
| # [num_reqs] | ||
| sampled_token_ranks: list[int] | ||
| # Logits statistics for each sequence (optional) | ||
| logits_min: Optional[list[float]] = None # [num_reqs] | ||
| logits_max: Optional[list[float]] = None # [num_reqs] | ||
| logits_mean: Optional[list[float]] = None # [num_reqs] | ||
| logits_std: Optional[list[float]] = None # [num_reqs] | ||
|
Comment on lines
44
to
+51
|
||
|
|
||
| def slice_columns(self, start: int, end: int): | ||
| """ | ||
|
|
@@ -54,6 +59,14 @@ def slice_columns(self, start: int, end: int): | |
| [row[start:end] for row in self.logprob_token_ids], | ||
| [row[start:end] for row in self.logprobs], | ||
| self.sampled_token_ranks, # unchanged | ||
| # [row[start:end] for row in self.logits_min], | ||
| # [row[start:end] for row in self.logits_max], | ||
| # [row[start:end] for row in self.logits_mean], | ||
| # [row[start:end] for row in self.logits_std], | ||
| self.logits_min, # unchanged | ||
| self.logits_max, # unchanged | ||
| self.logits_mean, # unchanged | ||
| self.logits_std, # unchanged | ||
| ) | ||
|
|
||
| def slice_rows(self, start: int, end: int): | ||
|
|
@@ -65,6 +78,10 @@ def slice_rows(self, start: int, end: int): | |
| self.logprob_token_ids[start:end], | ||
| self.logprobs[start:end], | ||
| self.sampled_token_ranks[start:end], | ||
| self.logits_min[start:end] if self.logits_min is not None else None, | ||
| self.logits_max[start:end] if self.logits_max is not None else None, | ||
| self.logits_mean[start:end] if self.logits_mean is not None else None, | ||
| self.logits_std[start:end] if self.logits_std is not None else None, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -77,13 +94,22 @@ class LogprobsTensors(NamedTuple): | |
| logprobs: paddle.Tensor | ||
| # [num_reqs] | ||
| selected_token_ranks: paddle.Tensor | ||
| # Logits statistics for each sequence (optional) | ||
| logits_min: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_max: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_mean: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_std: Optional[paddle.Tensor] = None | ||
|
|
||
| def tolists(self): | ||
| """Convert to lists.""" | ||
| return LogprobsLists( | ||
| self.logprob_token_ids.tolist(), | ||
| self.logprobs.tolist(), | ||
| self.selected_token_ranks.tolist(), | ||
| self.logits_min.tolist() if self.logits_min is not None else None, | ||
| self.logits_max.tolist() if self.logits_max is not None else None, | ||
| self.logits_mean.tolist() if self.logits_mean is not None else None, | ||
| self.logits_std.tolist() if self.logits_std is not None else None, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
@@ -97,6 +123,10 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens | |
| logprob_token_ids=logprob_token_ids, | ||
| logprobs=logprobs, | ||
| selected_token_ranks=selected_token_ranks, | ||
| logits_min=None, | ||
| logits_max=None, | ||
| logits_mean=None, | ||
| logits_std=None, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
@@ -110,6 +140,10 @@ def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors" | |
| logprob_token_ids=logprob_token_ids, | ||
| logprobs=logprobs, | ||
| selected_token_ranks=selected_token_ranks, | ||
| logits_min=None, | ||
| logits_max=None, | ||
| logits_mean=None, | ||
| logits_std=None, | ||
| ) | ||
|
|
||
| def slice_rows(self, start: int, end: int): | ||
|
|
@@ -122,6 +156,26 @@ def slice_rows(self, start: int, end: int): | |
| paddle.to_tensor(self.logprob_token_ids.cpu()[start:end], place="cpu"), | ||
| paddle.to_tensor(self.logprobs.cpu()[start:end], place="cpu"), | ||
| paddle.to_tensor(self.selected_token_ranks.cpu()[start:end], place="cpu"), | ||
| ( | ||
| paddle.to_tensor(self.logits_min.cpu()[start:end], place="cpu") | ||
| if self.logits_min is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_max.cpu()[start:end], place="cpu") | ||
| if self.logits_max is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_mean.cpu()[start:end], place="cpu") | ||
| if self.logits_mean is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_std.cpu()[start:end], place="cpu") | ||
| if self.logits_std is not None | ||
| else None | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR 标题目前为“【TI-Consisent】...”,不符合仓库要求的
`[CLASS]Title` 格式(模板里给出的 tag 列表如 `[Feature]`/`[BugFix]` 等)。建议将标题改为类似 `[Feature] Add logits_stats metric for ZMQ logprobs`,并修正 Consisent 的拼写(应为 Consistent)以便后续检索与自动化流程识别。