Commit bed20f7

sm disagg implementation
Signed-off-by: Qiang Xu <qiangx@nvidia.com>
1 parent 1a1d617 commit bed20f7

File tree: 11 files changed, +453 −259 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 21 additions & 2 deletions
@@ -14,7 +14,7 @@
                                   EagleDecodingConfig, KvCacheConfig,
                                   MTPDecodingConfig, PeftCacheConfig,
                                   SamplerType, SchedulerConfig,
-                                  SparseAttentionConfig,
+                                  SmDisaggConfig, SparseAttentionConfig,
                                   SpeculativeConfig, TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -38,7 +38,7 @@
 from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
                       TRTLLMSampler)
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
-                        SimpleScheduler)
+                        SimpleScheduler, SmDisaggCtxScheduler)
 from .seq_slot_manager import SeqSlotManager
 
 GB = 1 << 30
@@ -665,6 +665,8 @@ def create_py_executor_instance(
         max_batch_size: Optional[int] = None,
         max_beam_width: Optional[int] = None,
         max_num_tokens: Optional[int] = None,
+        ctx_model_engine: Optional[PyTorchModelEngine] = None,
+        sm_disagg_config: Optional[SmDisaggConfig] = None,
         peft_cache_config: Optional[PeftCacheConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
         cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
@@ -789,6 +791,21 @@ def create_py_executor_instance(
                                             ctx_chunk_config)
         scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
 
+    if sm_disagg_config is not None:
+        scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size
+        capacity_scheduler = BindCapacityScheduler(
+            scheduler_capacity,
+            kv_cache_manager.impl if kv_cache_manager is not None else None,
+            peft_cache_manager.impl if peft_cache_manager is not None else None,
+            scheduler_config.capacity_scheduler_policy,
+            two_step_lookahead=mapping.has_pp())
+        mb_scheduler = BindMicroBatchScheduler(
+            sm_disagg_config.context_max_batch_size,
+            sm_disagg_config.context_max_num_tokens, ctx_chunk_config)
+        ctx_scheduler = SmDisaggCtxScheduler(capacity_scheduler, mb_scheduler)
+    else:
+        ctx_scheduler = None
+
     config = model_engine.model.model_config.pretrained_config
     attention_type = AttentionTypeCpp.MLA if is_mla(
         config) else AttentionTypeCpp.DEFAULT
@@ -801,6 +818,8 @@
         model_engine=model_engine,
         sampler=sampler,
         drafter=drafter,
+        ctx_scheduler=ctx_scheduler,
+        ctx_model_engine=ctx_model_engine,
         dist=dist,
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
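
In short: when sm_disagg_config is set, create_py_executor_instance builds a second capacity/micro-batch scheduler pair dedicated to the context phase and passes it to the executor as ctx_scheduler, alongside the optional ctx_model_engine. A minimal sketch of that wiring, factored into a hypothetical helper (the helper name is illustrative; the arguments are the same objects used in the hunk above):

    # Sketch only: the context-scheduler wiring from the hunk above, expressed
    # as a standalone helper. Returns None when SM disaggregation is disabled.
    def build_ctx_scheduler(sm_disagg_config, scheduler_capacity, mapping,
                            kv_cache_manager, peft_cache_manager,
                            scheduler_config, ctx_chunk_config):
        if sm_disagg_config is None:
            return None
        # Reserve extra scheduler slots for context-phase requests on every PP rank.
        scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size
        capacity_scheduler = BindCapacityScheduler(
            scheduler_capacity,
            kv_cache_manager.impl if kv_cache_manager is not None else None,
            peft_cache_manager.impl if peft_cache_manager is not None else None,
            scheduler_config.capacity_scheduler_policy,
            two_step_lookahead=mapping.has_pp())
        mb_scheduler = BindMicroBatchScheduler(
            sm_disagg_config.context_max_batch_size,
            sm_disagg_config.context_max_num_tokens, ctx_chunk_config)
        return SmDisaggCtxScheduler(capacity_scheduler, mb_scheduler)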

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 27 additions & 3 deletions
@@ -50,7 +50,8 @@ class ExecutorRequestQueue:
     def __init__(self, dist: Distributed, enable_attention_dp: bool,
                  max_batch_size: int, max_beam_width: int,
                  max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float, is_disaggregated: bool):
+                 batch_wait_timeout_ms: float, is_disaggregated: bool,
+                 is_sm_disagg: bool):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -60,6 +61,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.max_beam_width = max_beam_width
         self.max_num_active_requests = max_num_active_requests
         self.is_disaggregated = is_disaggregated
+        self.is_sm_disagg = is_sm_disagg
         self.enqueue_lock = threading.Lock()
         self.next_request_id = max_batch_size
         self.enable_iter_perf_stats = enable_iter_perf_stats
@@ -333,13 +335,35 @@ def _fetch_and_process_requests(
 
     @nvtx_range("_fetch_new_requests")
     def fetch_new_requests(
-            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
+            self, activate_requests: List[LlmRequest],
+            num_active_requests_on_engine: int) -> List[LlmRequest]:
 
-        if self.enable_attention_dp:
+        if self.is_sm_disagg:
+            return self._fetch_new_requests_sm_disagg(
+                len(activate_requests), num_active_requests_on_engine)
+        elif self.enable_attention_dp:
             return self._fetch_new_requests_attention_dp(activate_requests)
         else:
             return self._fetch_new_requests_attention_tp(len(activate_requests))
 
+    def _fetch_new_requests_sm_disagg(
+            self, num_active_requests: int,
+            num_active_requests_on_engine: int) -> List[LlmRequest]:
+        """Handle SM-level disaggregation request fetching."""
+        total_max_num_active_requests = (self.max_num_active_requests +
+                                         num_active_requests -
+                                         num_active_requests_on_engine)
+
+        # fetch and process requests into waiting queue
+        new_requests = self._fetch_and_process_requests(
+            num_active_requests_on_engine,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
+
+        # Merge requests and add to active list
+        merged_requests = self._merge_requests(new_requests)
+        return merged_requests
+
     def _fetch_new_requests_attention_tp(
             self, num_active_requests: int) -> List[LlmRequest]:
         """Handle standard (non-attention DP) request fetching."""

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 8 additions & 0 deletions
@@ -797,3 +797,11 @@ def get_draft_token_length(request: LlmRequest) -> int:
     if request.py_draft_tokens is not None:
         return len(request.py_draft_tokens)
     return 0
+
+
+def get_context_requests(requests: List[LlmRequest]):
+    return [req for req in requests if req.is_context_init_state]
+
+
+def get_generation_requests(requests: List[LlmRequest]):
+    return [req for req in requests if not req.is_context_init_state]
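
These helpers simply partition a request list by phase, which lets an SM-disaggregated executor route context-init requests to the context engine and the rest to the generation engine. A minimal usage sketch (active_requests is an illustrative variable):

    # Usage sketch: split one scheduled batch by phase.
    ctx_requests = get_context_requests(active_requests)
    gen_requests = get_generation_requests(active_requests)
    assert len(ctx_requests) + len(gen_requests) == len(active_requests)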

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 15 additions & 4 deletions
@@ -136,10 +136,12 @@ def __init__(
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
+        is_sm_disagg_ctx_phase: bool = False,
         is_draft_model: bool = False,
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
+        weight_sharing_model: Optional[torch.nn.Module] = None,
     ):
         self.forward_pass_callable = None
         self.ub_buffers = None
@@ -149,6 +151,9 @@ def __init__(
             max_seq_len,
             max_batch_size,
         ) = llm_args.get_runtime_sizes()
+        if is_sm_disagg_ctx_phase:
+            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
 
         self.batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
@@ -166,6 +171,7 @@ def __init__(
         if dist is not None:
             ExpertStatistic.create(self.dist.rank)
         self.llm_args = llm_args
+        self.sm_disagg_enabled = llm_args.sm_disagg_config is not None
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
         self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
@@ -196,6 +202,7 @@ def __init__(
             max_num_tokens=self.max_num_tokens,
             max_seq_len=self.max_seq_len,
             lora_config=lora_config,
+            weight_sharing_model=weight_sharing_model,
         )
         self.model, moe_load_balancer = loader.load(
             checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader)
@@ -1352,8 +1359,10 @@ def _prepare_tp_inputs(
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
-            # (3) the first step in the generation server of disaggregated serving
-            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+            # (3) the first step in the generation server of disaggregated serving; or
+            # (4) the first step in the generation phase of SM-level disaggregation
+            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
+                    or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
                 # get token ids, including input token ids and draft token ids. For these dummy requests,
                 # no need to copy the token ids.
                 if not (request.is_attention_dp_dummy
@@ -1456,8 +1465,10 @@ def _prepare_tp_inputs(
             # the request has no previous tensor:
             # (1) new_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
-            # (3) the first step in the generation server of disaggregated serving
-            if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+            # (3) the first step in the generation server of disaggregated serving; or
+            # (4) the first step in the generation phase of SM-level disaggregation
+            if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
+                    or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
                 # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
                 # can be aligned to the correct positions.
                 if not request.is_cuda_graph_dummy:
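
Taken together, a context-phase engine is a second PyTorchModelEngine that (a) overrides its runtime sizes with the context_max_* values from sm_disagg_config and (b) borrows the generation engine's weights instead of loading its own copy. A rough sketch of how such a pair might be constructed; the two new keyword names come from this diff, while the call site and the remaining constructor arguments are elided:

    # Sketch only: pairing a context-phase engine with a generation engine.
    # Other required constructor arguments are omitted for brevity.
    gen_engine = PyTorchModelEngine(model_path=model_path, llm_args=llm_args)
    ctx_engine = PyTorchModelEngine(
        model_path=model_path,
        llm_args=llm_args,
        is_sm_disagg_ctx_phase=True,            # use context_max_* runtime sizes
        weight_sharing_model=gen_engine.model,  # reuse the already-loaded weights
    )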

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 9 additions & 1 deletion
@@ -191,7 +191,8 @@ def __init__(self,
                  sparse_attention_config: Optional["SparseAttentionConfig"],
                  max_num_tokens: int,
                  max_seq_len: Optional[int],
-                 lora_config: Optional[LoraConfig] = None):
+                 lora_config: Optional[LoraConfig] = None,
+                 weight_sharing_model: Optional[torch.nn.Module] = None):
         """
         Initializes the ModelLoader.
 
@@ -210,6 +211,7 @@ def __init__(self,
         self.max_num_tokens = max_num_tokens
         self.max_seq_len = max_seq_len
         self.lora_config = lora_config
+        self.weight_sharing_model = weight_sharing_model
 
     def load(
         self,
@@ -307,6 +309,12 @@ def init_meta_tensor(t: torch.Tensor):
             moe_load_balancer.finalize_model()
             logger.info("moe_load_balancer finalize model done")
 
+        if self.weight_sharing_model is not None:
+            model.load_state_dict(self.weight_sharing_model.state_dict(),
+                                  assign=True)
+            # Free up duplicate model weights allocated before weight sharing
+            torch.cuda.empty_cache()
+
         torch.cuda.current_stream().synchronize()
 
         return model, moe_load_balancer
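
The sharing mechanism is load_state_dict(..., assign=True) (available in PyTorch 2.1+): instead of copying the source tensors into the freshly allocated parameters, the destination module's parameters are rebound to the source storage, and torch.cuda.empty_cache() then releases the now-unreferenced allocations. A tiny standalone illustration of that behavior:

    import torch

    # assign=True rebinds dst's parameters to the tensors from src's state dict,
    # so both modules end up referencing the same underlying storage.
    src = torch.nn.Linear(4, 4)
    dst = torch.nn.Linear(4, 4)
    dst.load_state_dict(src.state_dict(), assign=True)
    assert dst.weight.data_ptr() == src.weight.data_ptr()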
