Commit 6b25f1d

sm disagg implementation
Signed-off-by: Qiang Xu <qiangx@nvidia.com>
1 parent 69b4e52 commit 6b25f1d

File tree

15 files changed: +629 -170 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 23 additions & 2 deletions
@@ -14,7 +14,7 @@
                               EagleDecodingConfig, KvCacheConfig,
                               MTPDecodingConfig, PeftCacheConfig,
                               SamplerType, SchedulerConfig,
-                              SparseAttentionConfig,
+                              SmDisaggConfig, SparseAttentionConfig,
                               SpeculativeConfig, TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -29,7 +29,7 @@
 from .guided_decoder import GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
-from .llm_request import ExecutorResponse
+from .llm_request import ExecutorResponse, LlmRequestState
 from .mamba_cache_manager import MambaHybridCacheManager
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
@@ -665,6 +665,8 @@ def create_py_executor_instance(
         max_batch_size: Optional[int] = None,
         max_beam_width: Optional[int] = None,
         max_num_tokens: Optional[int] = None,
+        ctx_model_engine: Optional[PyTorchModelEngine] = None,
+        sm_disagg_config: Optional[SmDisaggConfig] = None,
         peft_cache_config: Optional[PeftCacheConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
         cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
@@ -789,6 +791,23 @@ def create_py_executor_instance(
                                          ctx_chunk_config)
     scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

+    if sm_disagg_config is not None:
+        scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size
+        capacity_scheduler = BindCapacityScheduler(
+            scheduler_capacity,
+            kv_cache_manager.impl if kv_cache_manager is not None else None,
+            peft_cache_manager.impl if peft_cache_manager is not None else None,
+            scheduler_config.capacity_scheduler_policy,
+            two_step_lookahead=mapping.has_pp())
+        mb_scheduler = BindMicroBatchScheduler(
+            sm_disagg_config.context_max_batch_size,
+            sm_disagg_config.context_max_num_tokens,
+            ctx_chunk_config,
+            no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS)
+        ctx_scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
+    else:
+        ctx_scheduler = None
+
     config = model_engine.model.model_config.pretrained_config
     attention_type = AttentionTypeCpp.MLA if is_mla(
         config) else AttentionTypeCpp.DEFAULT
@@ -801,6 +820,8 @@ def create_py_executor_instance(
         model_engine=model_engine,
         sampler=sampler,
         drafter=drafter,
+        ctx_scheduler=ctx_scheduler,
+        ctx_model_engine=ctx_model_engine,
         dist=dist,
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
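
Note on the hunk above: when sm_disagg_config is set, create_py_executor_instance builds a second SimpleScheduler (ctx_scheduler) dedicated to the context phase and enlarges the shared capacity pool by one context batch per pipeline stage. A minimal sketch of that capacity arithmetic, with made-up numbers (only the field names context_max_batch_size / context_max_num_tokens come from the diff):

# Illustrative values only -- not defaults from this commit.
context_max_batch_size = 8    # sm_disagg_config.context_max_batch_size
pp_size = 2                   # mapping.pp_size
scheduler_capacity = 64       # capacity already sized for the generation scheduler

# Extra headroom so context-phase requests can be scheduled alongside generation:
scheduler_capacity += context_max_batch_size * pp_size
assert scheduler_capacity == 80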

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 15 additions & 6 deletions
@@ -47,10 +47,15 @@ def is_control_request(self):
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""

-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(self,
+                 dist: Distributed,
+                 enable_attention_dp: bool,
+                 max_batch_size: int,
+                 max_beam_width: int,
+                 max_num_active_requests: int,
+                 enable_iter_perf_stats: bool,
+                 batch_wait_timeout_ms: float,
+                 is_sm_disagg: bool = False):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -59,6 +64,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.max_batch_size = max_batch_size
         self.max_beam_width = max_beam_width
         self.max_num_active_requests = max_num_active_requests
+        self.is_sm_disagg = is_sm_disagg
         self.enqueue_lock = threading.Lock()
         self.next_request_id = max_batch_size
         self.enable_iter_perf_stats = enable_iter_perf_stats
@@ -333,12 +339,15 @@ def _fetch_and_process_requests(

     @nvtx_range("_fetch_new_requests")
     def fetch_new_requests(
-            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
+            self, activate_requests: List[LlmRequest],
+            num_active_requests_on_engine: int) -> List[LlmRequest]:

         if self.enable_attention_dp:
             return self._fetch_new_requests_attention_dp(activate_requests)
         else:
-            return self._fetch_new_requests_attention_tp(len(activate_requests))
+            num_active_requests = num_active_requests_on_engine if self.is_sm_disagg else len(
+                activate_requests)
+            return self._fetch_new_requests_attention_tp(num_active_requests)

     def _fetch_new_requests_attention_tp(
             self, num_active_requests: int) -> List[LlmRequest]:
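
The hunk above lets the caller pass an engine-wide active-request count into fetch_new_requests; with SM-level disaggregation that count, rather than the length of the list handed to this call, bounds how many new requests are admitted. A hypothetical standalone restatement of the added branch (not a method of ExecutorRequestQueue):

from typing import List

def resolve_num_active_requests(active_requests: List[object],
                                num_active_requests_on_engine: int,
                                is_sm_disagg: bool) -> int:
    # Mirrors the branch added above: under SM disaggregation use the count
    # supplied by the owning engine; otherwise count the local list.
    return num_active_requests_on_engine if is_sm_disagg else len(active_requests)

assert resolve_num_active_requests(["r1", "r2"], 5, is_sm_disagg=True) == 5
assert resolve_num_active_requests(["r1", "r2"], 5, is_sm_disagg=False) == 2
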
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from cuda.bindings import driver
+
+from tensorrt_llm.runtime.generation import CUASSERT
+
+
+def green_ctx_create_streams(res_list, device):
+    streams = []
+    for res in res_list:
+        desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0]
+        green_ctx = CUASSERT(
+            driver.cuGreenCtxCreate(
+                desc, device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
+            )
+        )[0]
+        stream = CUASSERT(
+            driver.cuGreenCtxStreamCreate(
+                green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0
+            )
+        )[0]
+        stream = torch.cuda.get_stream_from_external(stream, device)
+        streams.append(stream)
+    return streams
+
+
+def green_ctx_split_percent(sm_percent: float, device_id: int = 0):
+    device = CUASSERT(driver.cuDeviceGet(device_id))[0]
+
+    res = CUASSERT(
+        driver.cuDeviceGetDevResource(device, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM)
+    )[0]
+    sm_count = res.sm.smCount
+
+    major = CUASSERT(
+        driver.cuDeviceGetAttribute(
+            driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device
+        )
+    )[0]
+    if major >= 9:
+        sm_min = 8
+        sm_align = 8
+    else:
+        sm_min = 4 if major == 8 else 2
+        sm_align = 2
+
+    def green_ctx_split_aligned(sm_g1):
+        sm_g1 = round(sm_g1 / sm_align) * sm_align
+        sm_g1 = min(max(sm_g1, sm_min), sm_count - sm_min)
+        result = CUASSERT(
+            driver.cuDevSmResourceSplitByCount(
+                1,  # nbGroups
+                res,
+                0,  # useFlags
+                sm_g1,
+            )
+        )
+        res_split = (result[0][0], result[2])
+        streams = green_ctx_create_streams(res_split, device)
+        return streams, res_split
+
+    sm_g1 = round(sm_count * sm_percent)
+    sm_g2 = sm_count - sm_g1
+    # Choose the split closer to sm_percent when sm_count is not divisible by sm_align
+    sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align))
+    sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align))
+    if sm_g1_dist <= sm_g2_dist:
+        (stream_g1, stream_g2), (res_g1, res_g2) = green_ctx_split_aligned(sm_g1)
+    else:
+        (stream_g2, stream_g1), (res_g2, res_g1) = green_ctx_split_aligned(sm_g2)
+    return (stream_g1, stream_g2), (res_g1, res_g2)
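
The two helpers above carve a device's SMs into two green-context partitions and expose each partition as a torch.cuda external stream. A usage sketch, assuming the helpers are importable (the new file's path is not shown in this excerpt); the 30/70 split and tensor sizes are arbitrary:

import torch

# Reserve roughly 30% of the SMs for the context phase, the rest for generation.
(ctx_stream, gen_stream), (ctx_res, gen_res) = green_ctx_split_percent(0.3, device_id=0)
print(f"context SMs: {ctx_res.sm.smCount}, generation SMs: {gen_res.sm.smCount}")

x = torch.randn(4096, 4096, device="cuda")
with torch.cuda.stream(ctx_stream):
    y = x @ x      # kernels launched here run on the context partition
with torch.cuda.stream(gen_stream):
    z = x + 1.0    # and these on the remaining SMs
torch.cuda.synchronize()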

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 8 additions & 0 deletions
@@ -816,3 +816,11 @@ def get_draft_token_length(request: LlmRequest) -> int:
     if request.py_draft_tokens is not None:
         return len(request.py_draft_tokens)
     return 0
+
+
+def get_context_requests(requests: List[LlmRequest]):
+    return [req for req in requests if req.is_context_init_state]
+
+
+def get_generation_requests(requests: List[LlmRequest]):
+    return [req for req in requests if not req.is_context_init_state]
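
get_context_requests and get_generation_requests simply partition a request list on is_context_init_state. A minimal sketch with stand-in objects (constructing real LlmRequest instances needs executor state):

from types import SimpleNamespace

reqs = [SimpleNamespace(is_context_init_state=True),    # still in prefill
        SimpleNamespace(is_context_init_state=False)]   # already generating
assert len(get_context_requests(reqs)) == 1
assert len(get_generation_requests(reqs)) == 1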

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 18 additions & 4 deletions
@@ -135,10 +135,12 @@ def __init__(
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
+        is_sm_disagg_ctx_phase: bool = False,
         is_draft_model: bool = False,
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
+        weight_sharing_model: Optional[torch.nn.Module] = None,
     ):
         self.forward_pass_callable = None
         self.ub_buffers = None
@@ -148,6 +150,9 @@ def __init__(
             max_seq_len,
             max_batch_size,
         ) = llm_args.get_runtime_sizes()
+        if is_sm_disagg_ctx_phase:
+            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size

         self.batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
@@ -165,6 +170,7 @@ def __init__(
         if dist is not None:
             ExpertStatistic.create(self.dist.rank)
         self.llm_args = llm_args
+        self.sm_disagg_enabled = llm_args.sm_disagg_config is not None
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
         self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

@@ -195,6 +201,7 @@ def __init__(
             max_num_tokens=self.max_num_tokens,
             max_seq_len=self.max_seq_len,
             lora_config=lora_config,
+            weight_sharing_model=weight_sharing_model,
         )
         self.model, moe_load_balancer = loader.load(
             checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader)
@@ -368,6 +375,7 @@ def __init__(
         if self.use_mrope:
             self.mrope_position_ids_cuda = torch.empty(
                 (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda')
+        self.iter_counter = 0

         # Pre-allocated buffers for draft model to avoid implicit synchronization
         # These are used to build index tensors without creating tensors from Python lists
@@ -1452,8 +1460,10 @@ def _prepare_tp_inputs(
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
-            # (3) the first step in the generation server of disaggregated serving
-            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+            # (3) the first step in the generation server of disaggregated serving; or
+            # (4) the first step in the generation phase of SM-level disaggregation
+            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
+                    or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
                 # get token ids, including input token ids and draft token ids. For these dummy requests,
                 # no need to copy the token ids.
                 if not (request.is_attention_dp_dummy
@@ -1577,8 +1587,10 @@ def _prepare_tp_inputs(
             # the request has no previous tensor:
             # (1) new_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
-            # (3) the first step in the generation server of disaggregated serving
-            if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+            # (3) the first step in the generation server of disaggregated serving; or
+            # (4) the first step in the generation phase of SM-level disaggregation
+            if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
+                    or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
                 # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
                 # can be aligned to the correct positions.
                 if not request.is_cuda_graph_dummy:
@@ -2547,6 +2559,8 @@ def forward(self,
                 spec_decoding_tensor: Optional[SpecDecodingTensor] = None,
                 num_accepted_tokens_device: Optional[torch.Tensor] = None,
                 req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
+        self.iter_counter += 1
+
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)

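The _prepare_tp_inputs hunks above add a fourth case to the "no previous tensor" check: under SM-level disaggregation, a request taking its first generation step after the context phase has produced no new-token tensor on the generation engine yet. A hypothetical restatement of the updated predicate, written out only for readability (not a PyTorchModelEngine method):

def has_no_previous_tensor(new_tokens_device, request, sm_disagg_enabled: bool) -> bool:
    return (new_tokens_device is None             # (1) overlap scheduler disabled
            or request.is_dummy                   # (2) dummy request
            or request.py_batch_idx is None       # (3) first step on a disagg generation server
            or (sm_disagg_enabled                 # (4) first generation step after the
                and request.max_num_generated_tokens == 0))  # SM-disagg context phase
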
tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 9 additions & 1 deletion
@@ -191,7 +191,8 @@ def __init__(self,
                  sparse_attention_config: Optional["SparseAttentionConfig"],
                  max_num_tokens: int,
                  max_seq_len: Optional[int],
-                 lora_config: Optional[LoraConfig] = None):
+                 lora_config: Optional[LoraConfig] = None,
+                 weight_sharing_model: Optional[torch.nn.Module] = None):
         """
         Initializes the ModelLoader.

@@ -210,6 +211,7 @@ def __init__(self,
         self.max_num_tokens = max_num_tokens
         self.max_seq_len = max_seq_len
         self.lora_config = lora_config
+        self.weight_sharing_model = weight_sharing_model

     def load(
         self,
@@ -307,6 +309,12 @@ def init_meta_tensor(t: torch.Tensor):
             moe_load_balancer.finalize_model()
             logger.info("moe_load_balancer finalize model done")

+        if self.weight_sharing_model is not None:
+            model.load_state_dict(self.weight_sharing_model.state_dict(),
+                                  assign=True)
+            # Free up duplicate model weights allocated before weight sharing
+            torch.cuda.empty_cache()
+
         torch.cuda.current_stream().synchronize()

         return model, moe_load_balancer
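
The weight-sharing block above relies on nn.Module.load_state_dict(..., assign=True) (available in PyTorch 2.1+), which makes the target module adopt the source tensors instead of copying values into its own freshly allocated ones; torch.cuda.empty_cache() then returns the now-unreferenced duplicate weights to the allocator. A standalone sketch of the same trick on toy modules:

import torch
import torch.nn as nn

src = nn.Linear(16, 16, device="cuda")
dst = nn.Linear(16, 16, device="cuda")

# assign=True swaps dst's parameters for src's tensors rather than copying values.
dst.load_state_dict(src.state_dict(), assign=True)
assert dst.weight.data_ptr() == src.weight.data_ptr()  # same storage, no second copy

torch.cuda.empty_cache()  # release dst's original weight allocation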
