
Commit f637ba7

[Cherry-Pick] MTP split draft_tokens into standalone post-processing path (#5205) (#5232)

* merge code
* fix Request CONFLICT
* remove unused unittest

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
1 parent bbcd92c commit f637ba7

File tree

3 files changed, +249 -27 lines changed

fastdeploy/engine/request.py

Lines changed: 15 additions & 1 deletion
@@ -23,6 +23,7 @@

 import numpy as np

+from fastdeploy import envs
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
@@ -273,7 +274,20 @@ def set(self, key, value):
         setattr(self, key, value)

     def __repr__(self) -> str:
-        return ""
+        """Safe string representation that ignores private and None fields."""
+        try:
+            if not envs.FD_DEBUG:
+                return f"Request(request_id={self.request_id})"
+            else:
+                attrs_snapshot = dict(vars(self))
+                non_none_fields = [
+                    f"{attr}={value!r}"
+                    for attr, value in attrs_snapshot.items()
+                    if value is not None and not attr.startswith("_")
+                ]
+                return f"Request({', '.join(non_none_fields)})"
+        except Exception as e:
+            return f"<Request repr failed: {e}>"


 @dataclass(slots=True)
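
The rewritten __repr__ replaces the old empty string with a two-tier format: a cheap request_id-only line by default, and a full dump of public, non-None fields when FD_DEBUG is set, all wrapped in a try/except so repr can never raise from a logging path. Below is a minimal standalone sketch of the same pattern; the DEBUG flag and Sample class are illustrative stand-ins, not FastDeploy APIs.

import os

# Illustrative stand-in for envs.FD_DEBUG (hypothetical wiring, not FastDeploy's).
DEBUG = os.getenv("FD_DEBUG", "0") == "1"


class Sample:
    def __init__(self, request_id, prompt=None):
        self.request_id = request_id
        self.prompt = prompt
        self._secret = "hidden"  # private: never shown

    def __repr__(self) -> str:
        try:
            if not DEBUG:
                # Cheap default form.
                return f"Sample(request_id={self.request_id})"
            # Snapshot attributes, skipping private and None fields.
            fields = [
                f"{k}={v!r}"
                for k, v in dict(vars(self)).items()
                if v is not None and not k.startswith("_")
            ]
            return f"Sample({', '.join(fields)})"
        except Exception as e:  # repr must never raise
            return f"<Sample repr failed: {e}>"


print(repr(Sample("req-1")))  # Sample(request_id=req-1)

Snapshotting vars(self) and catching every exception matters because repr is usually invoked from logging code, where a crash would mask the original error being logged.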

fastdeploy/output/token_processor.py

Lines changed: 77 additions & 26 deletions
@@ -338,6 +338,60 @@ def _compute_speculative_status(self):
             self.total_step = 0
         self.speculative_stats_step += 1

+    def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, ranks):
+        """
+        Process batch draft tokens and generate corresponding request outputs
+
+        Args:
+            mtype (int): Message type (3 = target token, 4 = draft token)
+            batch (int): Batch size
+            accept_num (list): List of accepted token counts per request
+            tokens (paddle.Tensor): Generated draft token IDs tensor
+            scores (paddle.Tensor): Token scores tensor
+            ranks (paddle.Tensor): Token sampling ranks tensor
+
+        Returns:
+            list[RequestOutput]: List containing processed results for all requests
+        """
+        batch_result = list()
+        for i in range(batch):
+            if self.resource_manager.stop_flags[i]:
+                continue
+            task = self.resource_manager.tasks_list[i]
+            task_id = task.request_id
+            result = RequestOutput(
+                request_id=task_id,
+                output_type=mtype,
+                outputs=CompletionOutput(
+                    index=i,
+                    send_idx=None,
+                    token_ids=[],
+                    draft_token_ids=[],
+                ),
+                finished=False,
+                metrics=None,
+            )
+
+            token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
+            for batch_token_index in range(len(token_ids)):
+                result.outputs.logprob = float(scores[i, batch_token_index, 0])
+                topk_token_ids = tokens[i, batch_token_index, :].tolist()
+                topk_logprobs = scores[i, batch_token_index, :].tolist()
+                sampled_rank = ranks[i, batch_token_index].item()
+
+                if result.outputs.draft_top_logprobs is None:
+                    result.outputs.draft_top_logprobs = LogprobsLists(
+                        logprob_token_ids=[topk_token_ids],
+                        logprobs=[topk_logprobs],
+                        sampled_token_ranks=[sampled_rank],
+                    )
+                else:
+                    result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                    result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
+                    result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
+            batch_result.append(result)
+        return batch_result
+
     def _process_batch_output(self):
         """
         batch post-processing function
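
The new method builds one RequestOutput per live request and appends one top-k row per accepted draft step into outputs.draft_top_logprobs. A self-contained sketch of that accumulation, using plain lists instead of paddle tensors and a stand-in TopK dataclass in place of the real LogprobsLists:

from dataclasses import dataclass, field


# Stand-in for fastdeploy's LogprobsLists; field names mirror the diff above.
@dataclass
class TopK:
    logprob_token_ids: list = field(default_factory=list)
    logprobs: list = field(default_factory=list)
    sampled_token_ranks: list = field(default_factory=list)


def collect_draft_topk(tokens_row, scores_row, ranks_row, accept_num):
    """Accumulate top-k rows for the first accept_num draft steps of one request."""
    topk = TopK()
    for step in range(accept_num):
        topk.logprob_token_ids.append(list(tokens_row[step]))
        topk.logprobs.append(list(scores_row[step]))
        topk.sampled_token_ranks.append(ranks_row[step])
    return topk


# Two accepted draft steps, K + 1 = 3 candidates per step.
tk = collect_draft_topk(
    tokens_row=[[7, 8, 9], [4, 5, 6]],
    scores_row=[[-0.1, -1.2, -2.3], [-0.2, -1.1, -2.9]],
    ranks_row=[0, 1],
    accept_num=2,
)
print(tk.sampled_token_ranks)  # [0, 1]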
@@ -362,6 +416,12 @@ def _process_batch_output(self):
                 .reshape([batch, MAX_DRAFT_TOKENS, K + 1])
             )
             ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
+
+            # split draft_tokens into standalone post-processing path for MTP + logprobs
+            if mtype == 4:
+                batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
+                self.postprocess(batch_result, mtype)
+                return
         else:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
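
With this hunk, draft-token messages (mtype == 4) short-circuit into the dedicated handler instead of falling through the shared target-token path below. A toy sketch of the dispatch shape, with illustrative names only:

# Message types as used in the diff: 3 = target tokens, 4 = draft tokens.
TARGET, DRAFT = 3, 4


def process_output(mtype, payload, handle_draft, handle_shared):
    if mtype == DRAFT:
        # Standalone post-processing path for MTP draft tokens.
        return handle_draft(payload)
    return handle_shared(payload)


print(process_output(DRAFT, "p", lambda p: f"draft:{p}", lambda p: f"shared:{p}"))  # draft:p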
@@ -479,9 +539,11 @@ def _process_batch_output(self):
                     token_id = token_ids[batch_token_index]
                     self.tokens_counter[task_id] += 1
                     if token_id != RECOVERY_STOP_SIGNAL:
-                        result.outputs.token_ids.append(token_id)
-                        if mtype == 3:  # target_tokens
-                            task.output_token_ids.append(token_id)
+                        if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
+                            result.outputs.token_ids.append(token_id)
+
+                        task.output_token_ids.append(token_id)
+
                     if self.use_logprobs:
                         if self.cfg.speculative_config.method:
                             result.outputs.logprob = float(scores[i, batch_token_index, 0])
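
The new guard changes which tokens reach the client when FD_ENABLE_INTERNAL_ADAPTER is on: an eos token is withheld from result.outputs.token_ids but still appended to the task-side history. A minimal sketch of that filtering; the flag and token lists here are illustrative:

INTERNAL_ADAPTER = True  # stand-in for envs.FD_ENABLE_INTERNAL_ADAPTER
eos_token_ids = {2}

emitted, history = [], []
for token_id in [11, 12, 2]:
    # eos is hidden from the client output but kept in the task history.
    if not (INTERNAL_ADAPTER and token_id in eos_token_ids):
        emitted.append(token_id)
    history.append(token_id)

print(emitted)  # [11, 12]
print(history)  # [11, 12, 2]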
@@ -494,29 +556,18 @@ def _process_batch_output(self):
                             topk_logprobs = scores[i, :].tolist()
                             sampled_rank = ranks[i].item()

-                        if mtype == 3:  # top_logprobs
-                            if result.outputs.top_logprobs is None:
-                                result.outputs.top_logprobs = LogprobsLists(
-                                    logprob_token_ids=[topk_token_ids],
-                                    logprobs=[topk_logprobs],
-                                    sampled_token_ranks=[sampled_rank],
-                                )
-                            else:
-                                result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
-                                result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
-                                result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
-                        elif mtype == 4:  # draft_top_logprobs
-                            if result.outputs.draft_top_logprobs is None:
-                                result.outputs.draft_top_logprobs = LogprobsLists(
-                                    logprob_token_ids=[topk_token_ids],
-                                    logprobs=[topk_logprobs],
-                                    sampled_token_ranks=[sampled_rank],
-                                )
-                            else:
-                                result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
-                                result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
-                                result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
-                    if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
+                        if result.outputs.top_logprobs is None:
+                            result.outputs.top_logprobs = LogprobsLists(
+                                logprob_token_ids=[topk_token_ids],
+                                logprobs=[topk_logprobs],
+                                sampled_token_ranks=[sampled_rank],
+                            )
+                        else:
+                            result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                            result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
+                            result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
+
+                    if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                         result.finished = True
                         if recovery_stop:
                             result.error_msg = "Recover is not supported, the result is incomplete!"
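
Because mtype == 4 messages now exit through _process_batch_draft_tokens earlier, the finish check no longer needs its mtype == 3 qualifier: any of eos, prefill, or recovery stop ends the request. The condition reduces to a small predicate (names illustrative):

def is_finished(token_id, eos_token_ids, is_prefill, recovery_stop):
    """Mirror of the simplified finish condition in the diff above."""
    return token_id in eos_token_ids or is_prefill or recovery_stop


print(is_finished(2, {2}, False, False))  # True: eos reached
print(is_finished(7, {2}, True, False))   # True: prefill step finishes the request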
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+import paddle
+
+from fastdeploy.engine.request import RequestOutput
+from fastdeploy.output.token_processor import TokenProcessor
+
+
+class TestProcessBatchDraftTokens(unittest.TestCase):
+
+    def setUp(self):
+        # Mock cfg
+        cfg = MagicMock()
+        cfg.speculative_config = MagicMock()
+        cfg.speculative_config.method = "mtp"
+        cfg.speculative_config.num_speculative_tokens = 3
+        cfg.model_config = MagicMock()
+        cfg.model_config.enable_logprob = True
+
+        self.processor = TokenProcessor(
+            cfg=cfg, cached_generated_tokens=MagicMock(), engine_worker_queue=MagicMock(), split_connector=MagicMock()
+        )
+
+        # Mock resource_manager
+        self.processor.resource_manager = MagicMock()
+        self.processor.resource_manager.stop_flags = [False] * 512
+        self.processor.resource_manager.tasks_list = [MagicMock()] * 512
+
+        for task in self.processor.resource_manager.tasks_list:
+            task.request_id = "test_request"
+            task.eos_token_ids = [2]
+
+    def test_process_batch_draft_tokens_normal_case(self):
+        """Test draft-token processing in the normal case"""
+        batch = 2
+        accept_num = [3, 2]
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for i, result in enumerate(results):
+            self.assertIsInstance(result, RequestOutput)
+            self.assertEqual(result.output_type, 4)
+            self.assertEqual(result.outputs.index, i)
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids), accept_num[i])
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs), accept_num[i])
+            self.assertEqual(len(result.outputs.draft_top_logprobs.sampled_token_ranks), accept_num[i])
+
+    def test_process_batch_draft_tokens_with_stop_flag(self):
+        """Test the case where a stop flag is set"""
+        batch = 3
+        self.processor.resource_manager.stop_flags[1] = True  # the second request is stopped
+
+        accept_num = [3, 2, 1]
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), 2)
+        self.assertEqual(results[0].outputs.index, 0)
+        self.assertEqual(results[1].outputs.index, 2)
+
+    def test_process_batch_draft_tokens_empty_accept(self):
+        """Test the case where accept_num is 0"""
+        batch = 2
+        accept_num = [0, 0]
+
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for result in results:
+            self.assertIsNone(result.outputs.draft_top_logprobs)
+
+    def test_process_batch_draft_tokens_different_k_values(self):
+        """Test with a different K value"""
+        batch = 2
+        accept_num = [3, 2]
+
+        K = 5
+        MAX_DRAFT_TOKENS = 6
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for i, result in enumerate(results):
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids[0]), K + 1)
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs[0]), K + 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
