Commit 6e83433

fix reviewer comments
Signed-off-by: Qiang Xu <qiangx@nvidia.com>
1 parent: b24849a

6 files changed: +117 additions, −96 deletions

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 4 deletions
@@ -29,7 +29,7 @@
 from .guided_decoder import GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
-from .llm_request import ExecutorResponse
+from .llm_request import ExecutorResponse, LlmRequestState
 from .mamba_cache_manager import MambaHybridCacheManager
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
@@ -38,7 +38,7 @@
 from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
                       TRTLLMSampler)
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
-                        SimpleScheduler, SmDisaggCtxScheduler)
+                        SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager

 GB = 1 << 30
@@ -801,8 +801,10 @@ def create_py_executor_instance(
             two_step_lookahead=mapping.has_pp())
         mb_scheduler = BindMicroBatchScheduler(
             sm_disagg_config.context_max_batch_size,
-            sm_disagg_config.context_max_num_tokens, ctx_chunk_config)
-        ctx_scheduler = SmDisaggCtxScheduler(capacity_scheduler, mb_scheduler)
+            sm_disagg_config.context_max_num_tokens,
+            ctx_chunk_config,
+            no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS)
+        ctx_scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
     else:
         ctx_scheduler = None
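The net effect of this hunk is that the context-phase scheduler is now an ordinary SimpleScheduler whose micro-batch scheduler simply stops scheduling a request once it reaches GENERATION_IN_PROGRESS. A minimal sketch of that state-window idea, using plain-Python stand-ins rather than the real LlmRequestState binding (the enum values and the exact boundary semantics are assumptions for illustration):

from enum import IntEnum


class State(IntEnum):  # illustrative stand-in for LlmRequestState; values are made up
    CONTEXT_INIT = 0
    GENERATION_IN_PROGRESS = 1
    GENERATION_COMPLETE = 2


def schedulable(state: State,
                no_schedule_until_state: State = State.CONTEXT_INIT,
                no_schedule_after_state: State = State.GENERATION_COMPLETE) -> bool:
    # A request is eligible once it has reached the "until" state and
    # before it reaches the "after" state.
    return no_schedule_until_state <= state < no_schedule_after_state


# Default window: context and generation requests are both eligible.
assert schedulable(State.GENERATION_IN_PROGRESS)
# Context-only window (as configured above): generation-phase requests are skipped.
assert not schedulable(State.GENERATION_IN_PROGRESS,
                       no_schedule_after_state=State.GENERATION_IN_PROGRESS)
assert schedulable(State.CONTEXT_INIT,
                   no_schedule_after_state=State.GENERATION_IN_PROGRESS)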

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 4 additions & 23 deletions
@@ -342,31 +342,12 @@ def fetch_new_requests(
             self, activate_requests: List[LlmRequest],
             num_active_requests_on_engine: int) -> List[LlmRequest]:

-        if self.is_sm_disagg:
-            return self._fetch_new_requests_sm_disagg(
-                len(activate_requests), num_active_requests_on_engine)
-        elif self.enable_attention_dp:
+        if self.enable_attention_dp:
             return self._fetch_new_requests_attention_dp(activate_requests)
         else:
-            return self._fetch_new_requests_attention_tp(len(activate_requests))
-
-    def _fetch_new_requests_sm_disagg(
-            self, num_active_requests: int,
-            num_active_requests_on_engine: int) -> List[LlmRequest]:
-        """Handle SM-level disaggregation request fetching."""
-        total_max_num_active_requests = (self.max_num_active_requests +
-                                         num_active_requests -
-                                         num_active_requests_on_engine)
-
-        # fetch and process requests into waiting queue
-        new_requests = self._fetch_and_process_requests(
-            num_active_requests_on_engine,
-            total_max_num_active_requests,
-            enable_attention_dp=False)
-
-        # Merge requests and add to active list
-        merged_requests = self._merge_requests(new_requests)
-        return merged_requests
+            num_active_requests = num_active_requests_on_engine if self.is_sm_disagg else len(
+                activate_requests)
+            return self._fetch_new_requests_attention_tp(num_active_requests)

     def _fetch_new_requests_attention_tp(
             self, num_active_requests: int) -> List[LlmRequest]:
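With the dedicated SM-disagg fetch path gone, the only mode-specific behavior left here is which active-request count is handed to the shared attention-TP path. A toy illustration of that selection (a standalone function, not the real ExecutorRequestQueue method):

from typing import List


def pick_num_active_requests(is_sm_disagg: bool, activate_requests: List[str],
                             num_active_requests_on_engine: int) -> int:
    # In SM-disaggregated mode, the engine-wide active count (covering both the
    # context and generation loops) is used; otherwise only this loop's own
    # active list length matters.
    return num_active_requests_on_engine if is_sm_disagg else len(activate_requests)


assert pick_num_active_requests(True, ["req-1", "req-2"], 5) == 5
assert pick_num_active_requests(False, ["req-1", "req-2"], 5) == 2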
tensorrt_llm/_torch/pyexecutor/green_ctx.py (new file)

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from cuda.bindings import driver
+
+from tensorrt_llm.runtime.generation import CUASSERT
+
+
+def green_ctx_create_streams(res_list, device):
+    streams = []
+    for res in res_list:
+        desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0]
+        green_ctx = CUASSERT(
+            driver.cuGreenCtxCreate(
+                desc, device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
+            )
+        )[0]
+        stream = CUASSERT(
+            driver.cuGreenCtxStreamCreate(
+                green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0
+            )
+        )[0]
+        stream = torch.cuda.get_stream_from_external(stream, device)
+        streams.append(stream)
+    return streams
+
+
+def green_ctx_split_percent(sm_percent: float, device_id: int = 0):
+    device = CUASSERT(driver.cuDeviceGet(device_id))[0]
+
+    res = CUASSERT(
+        driver.cuDeviceGetDevResource(device, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM)
+    )[0]
+    sm_count = res.sm.smCount
+
+    major = CUASSERT(
+        driver.cuDeviceGetAttribute(
+            driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device
+        )
+    )[0]
+    if major >= 9:
+        sm_min = 8
+        sm_align = 8
+    else:
+        sm_min = 4 if major == 8 else 2
+        sm_align = 2
+
+    def green_ctx_split_aligned(sm_g1):
+        sm_g1 = round(sm_g1 / sm_align) * sm_align
+        sm_g1 = min(max(sm_g1, sm_min), sm_count - sm_min)
+        result = CUASSERT(
+            driver.cuDevSmResourceSplitByCount(
+                1,  # nbGroups
+                res,
+                0,  # useFlags
+                sm_g1,
+            )
+        )
+        res_split = (result[0][0], result[2])
+        streams = green_ctx_create_streams(res_split, device)
+        return streams, res_split
+
+    sm_g1 = round(sm_count * sm_percent)
+    sm_g2 = sm_count - sm_g1
+    # Choose the split closer to sm_percent when sm_count is not divisible by sm_align
+    sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align))
+    sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align))
+    if sm_g1_dist <= sm_g2_dist:
+        (stream_g1, stream_g2), (res_g1, res_g2) = green_ctx_split_aligned(sm_g1)
+    else:
+        (stream_g2, stream_g1), (res_g2, res_g1) = green_ctx_split_aligned(sm_g2)
+    return (stream_g1, stream_g2), (res_g1, res_g2)
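A hypothetical usage sketch of the new helper. It assumes a GPU and driver with green-context support, the cuda-python bindings installed, and that the module lives at tensorrt_llm/_torch/pyexecutor/green_ctx.py (inferred from the relative import added to py_executor.py below):

import torch

from tensorrt_llm._torch.pyexecutor.green_ctx import green_ctx_split_percent

torch.cuda.init()  # make sure the CUDA driver is initialized before the raw driver calls

# Reserve roughly 30% of the device's SMs for the context phase; the helper rounds
# to the architecture's SM alignment and returns one stream per partition.
(stream_ctx, stream_gen), (res_ctx, res_gen) = green_ctx_split_percent(0.3, device_id=0)
print(f"context SMs: {res_ctx.sm.smCount}, generation SMs: {res_gen.sm.smCount}")

x = torch.randn(1024, 1024, device="cuda")
with torch.cuda.stream(stream_ctx):
    y = x @ x      # confined to the SM partition reserved for the context phase
with torch.cuda.stream(stream_gen):
    z = x + 1.0    # confined to the partition reserved for the generation phase
torch.cuda.synchronize()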

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 13 additions & 40 deletions
@@ -42,6 +42,7 @@
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
+from .green_ctx import green_ctx_split_percent
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
 from .handle_logits import HandleLogits
@@ -215,9 +216,10 @@ def __init__(self,
         self.responses = {}
         self.result_wait_queues = {}

-        self.sm_disagg_lock = threading.Lock()
-        self.ctx_request_cv = threading.Condition(self.sm_disagg_lock)
-        self.gen_request_cv = threading.Condition(self.sm_disagg_lock)
+        if self.ctx_model_engine is not None:
+            self.sm_disagg_lock = threading.Lock()
+            self.ctx_request_cv = threading.Condition(self.sm_disagg_lock)
+            self.gen_request_cv = threading.Condition(self.sm_disagg_lock)

         # kv cache events
         self.kv_cache_manager = self.resource_manager.resource_managers.get(
@@ -229,6 +231,9 @@ def __init__(self,
         self.max_input_len = max_input_len
         # _executor_loop private data
         self.max_num_active_requests = model_engine.get_max_num_sequences()
+        if self.ctx_model_engine is not None:
+            self.max_num_active_requests += ctx_model_engine.get_max_num_sequences(
+            )
         self.active_requests: List[LlmRequest] = []
         self.expected_num_active_requests = 0
         self.ctx_in_transmission_requests = dict()
@@ -1694,7 +1699,11 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream):
                              iter_stats=iter_stats)

     def _executor_loop_sm_disagg(self):
-        stream_ctx, stream_gen = self.split_device_green_ctx()
+        (stream_ctx, stream_gen), (res_ctx, res_gen) = green_ctx_split_percent(
+            self.sm_disagg_ctx_sm_percent, self.device_id)
+        logger.info(
+            f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase."
+        )

         thread_ctx = threading.Thread(target=self._executor_loop_sm_disagg_ctx,
                                       args=(stream_ctx, ),
@@ -1705,42 +1714,6 @@ def _executor_loop_sm_disagg(self):

         thread_ctx.join()

-    def split_device_green_ctx(self):
-        device = torch.device("cuda", self.device_id)
-        device_properties = torch.cuda.get_device_properties(device)
-        sm_count = device_properties.multi_processor_count
-        if device_properties.major >= 9:
-            sm_min = 8
-            sm_align = 8
-        else:
-            sm_min = 4 if device_properties.major == 8 else 2
-            sm_align = 2
-
-        from flashinfer import green_ctx
-
-        def split_device_green_ctx_aligned(sm_s1):
-            sm_s1 = round(sm_s1 / sm_align) * sm_align
-            sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min)
-            return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1])
-
-        sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent)
-        sm_gen = sm_count - sm_ctx
-        # Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align
-        sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align))
-        sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align))
-        if sm_gen_dist < sm_ctx_dist:
-            (stream_gen,
-             stream_ctx), (res_gen,
-                           res_ctx) = split_device_green_ctx_aligned(sm_gen)
-        else:
-            (stream_ctx,
-             stream_gen), (res_ctx,
-                           res_gen) = split_device_green_ctx_aligned(sm_ctx)
-        logger.info(
-            f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase."
-        )
-        return stream_ctx, stream_gen
-
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
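A worked example of the splitting policy implemented both by the removed split_device_green_ctx method and by the new green_ctx_split_percent helper. The numbers assume a 132-SM Hopper GPU (alignment 8, minimum 8) and a 30% context share; the snippet below just replays the helper's rounding rules in plain Python:

sm_count, sm_align, sm_min, sm_percent = 132, 8, 8, 0.3

sm_g1 = round(sm_count * sm_percent)            # 40 SMs requested for the context phase
sm_g2 = sm_count - sm_g1                        # 92 SMs left for the generation phase
sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align))   # 0: already aligned
sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align))   # 4: would round to 88 or 96
# The helper splits on whichever side sits closer to an aligned count, here the context side.
assert (sm_g1_dist, sm_g2_dist) == (0, 4)
sm_g1 = min(max(round(sm_g1 / sm_align) * sm_align, sm_min), sm_count - sm_min)
assert sm_g1 == 40        # context partition: 40 SMs; the generation partition gets the rest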

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ def allocation_scope(current_stage: ExecutorMemoryType,
                 attn_runtime_features=attn_runtime_features,
                 dist=dist,
                 spec_config=spec_config,
+                is_sm_disagg_ctx_phase=True,
                 weight_sharing_model=model_engine.model,
             )
         else:

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 11 additions & 29 deletions
@@ -7,7 +7,7 @@
 from tensorrt_llm.bindings import internal as tb_internal
 from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy

-from .llm_request import LlmRequest, LlmRequestState, get_context_requests
+from .llm_request import LlmRequest, LlmRequestState

 RequestList = list[LlmRequest]

@@ -79,6 +79,9 @@ def __init__(
         scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.
         GUARANTEED_NO_EVICT,
         two_step_lookahead: bool = False,
+        no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT,
+        no_schedule_after_state: LlmRequestState = LlmRequestState.
+        GENERATION_COMPLETE,
     ):
         super(BindCapacityScheduler, self).__init__()
         self.kv_cache_manager = kv_cache_manager
@@ -89,8 +92,8 @@ def __init__(
             capacity_scheduler_policy=scheduler_policy._to_pybind(),
             has_kv_cache_manager=kv_cache_manager is not None,
             two_step_lookahead=two_step_lookahead,
-            no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
-            no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE)
+            no_schedule_until_state=no_schedule_until_state,
+            no_schedule_after_state=no_schedule_after_state)

     def schedule_request(
         self, active_requests: RequestList
@@ -175,6 +178,9 @@ def __init__(
         max_batch_size: int,
         max_num_tokens: int = None,
         ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None,
+        no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT,
+        no_schedule_after_state: LlmRequestState = LlmRequestState.
+        GENERATION_COMPLETE,
     ) -> None:
         super(BindMicroBatchScheduler, self).__init__()
         self.max_batch_size = max_batch_size
@@ -186,7 +192,8 @@ def __init__(
             ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1])

         self.impl = tb_internal.algorithms.MicroBatchScheduler(
-            ctx_chunk_config_cpp, max_num_tokens)
+            ctx_chunk_config_cpp, max_num_tokens, no_schedule_until_state,
+            no_schedule_after_state)

     def schedule(
         self, active_requests: RequestList, inflight_request_ids: set[int]
@@ -216,28 +223,3 @@ def schedule_request(self, active_requests: RequestList,
                                list(generation_requests), list(paused_requests),
                                list(fitting_disagg_gen_init_requests),
                                len(fitting_requests))
-
-
-class SmDisaggCtxScheduler(RequestScheduler):
-
-    def __init__(self, capacity_scheduler: CapacityScheduler,
-                 micro_batch_scheduler: MicroBatchScheduler):
-        super(SmDisaggCtxScheduler, self).__init__()
-        self.capacity_scheduler = capacity_scheduler
-        self.micro_batch_scheduler = micro_batch_scheduler
-
-    def schedule_request(self, active_requests: RequestList,
-                         inflight_request_ids: set[int]) -> SchedulerOutput:
-        fitting_requests, fitting_disagg_gen_init_requests, paused_requests = self.capacity_scheduler.schedule_request(
-            active_requests)
-
-        fitting_requests = get_context_requests(fitting_requests)
-
-        context_requests, generation_requests = self.micro_batch_scheduler.schedule(
-            fitting_requests, inflight_request_ids)
-        # Convert from binding type RequestVector to list[LlmRequest],
-        # so Python fields on LlmRequest won't be stripped away
-        return SchedulerOutput(list(context_requests),
-                               list(generation_requests), list(paused_requests),
-                               list(fitting_disagg_gen_init_requests),
-                               len(fitting_requests))
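With these parameters exposed, a context-only scheduler no longer needs its own class: it is the existing capacity/micro-batch pair with a narrower scheduling window. A sketch of that wiring, mirroring the _util.py call site above (build_ctx_only_scheduler is a hypothetical helper; the capacity scheduler and the batch/token limits come from the caller):

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm._torch.pyexecutor.scheduler import (BindMicroBatchScheduler,
                                                      SimpleScheduler)


def build_ctx_only_scheduler(capacity_scheduler, context_max_batch_size: int,
                             context_max_num_tokens: int, ctx_chunk_config=None):
    # Stop scheduling a request here as soon as it enters generation, so only
    # context (prefill) work is ever batched by this scheduler.
    mb_scheduler = BindMicroBatchScheduler(
        context_max_batch_size,
        context_max_num_tokens,
        ctx_chunk_config,
        no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS)
    return SimpleScheduler(capacity_scheduler, mb_scheduler)

The GENERATION_IN_PROGRESS cutoff pushes the context-only filtering down into the MicroBatchScheduler binding, which is what made the removed SmDisaggCtxScheduler and its get_context_requests call redundant.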
