 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -387,6 +388,7 @@ def __init__(
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
+        self.aux_stream = aux_stream
 
         self.gate = Qwen3NextGate(
             hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
                                          dtype=config.torch_dtype,
                                          quant_config=None)
 
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -450,22 +457,33 @@ def forward(
                                       dim=0,
                                       sizes=all_rank_num_tokens)
 
-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states
 
+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states
 
-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output
 
         if not self.enable_attention_dp and self.mapping.tp_size > 1:
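Note on the hunk above: the routed-expert and shared-expert computations are wrapped in closures so `maybe_execute_in_parallel` can overlap them on two CUDA streams, using the two events created in `__init__` to order the work. The sketch below illustrates the overlap pattern only; the helper's exact signature and its behavior when no auxiliary stream is given are assumptions, not the library implementation.

import torch

def overlap_two_branches(fn0, fn1, event0, event1, aux_stream=None):
    # Assumed semantics: run fn1 on aux_stream while fn0 runs on the current
    # stream, then make the current stream wait before the results are combined.
    if aux_stream is None:
        return fn0(), fn1()  # no auxiliary stream: run the branches sequentially
    event0.record()                      # inputs for fn1 are ready at this point
    out0 = fn0()                         # e.g. routed experts, main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event0)    # do not start before inputs exist
        out1 = fn1()                     # e.g. shared expert, auxiliary stream
        event1.record()
    torch.cuda.current_stream().wait_event(event1)  # re-join before the add
    return out0, out1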
@@ -543,22 +561,21 @@ def fused_qkvzba_split_reshape_cat(
 ):
     batch, seq_len = mixed_qkvz.shape[0], 1
     qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
-    mixed_qkv = torch.empty(
-        [batch * seq_len, qkv_dim_t],
-        dtype=mixed_qkvz.dtype,
-        device=mixed_qkvz.device,
-    )
-    z = torch.empty(
-        [batch * seq_len, num_heads_v, head_v],
-        dtype=mixed_qkvz.dtype,
-        device=mixed_qkvz.device,
-    )
-    b = torch.empty(
-        [batch * seq_len, num_heads_v],
-        dtype=mixed_ba.dtype,
-        device=mixed_ba.device,
-    )
-    a = torch.empty_like(b)
+    batch_seq = batch * seq_len
+
+    # Directly allocate output tensors in their final shapes (no intermediate buffers)
+    mixed_qkv = torch.empty((batch_seq, qkv_dim_t),
+                            dtype=mixed_qkvz.dtype,
+                            device=mixed_qkvz.device)
+    z = torch.empty((batch_seq, num_heads_v, head_v),
+                    dtype=mixed_qkvz.dtype,
+                    device=mixed_qkvz.device)
+    b = torch.empty((batch_seq, num_heads_v),
+                    dtype=mixed_ba.dtype,
+                    device=mixed_ba.device)
+    a = torch.empty((batch_seq, num_heads_v),
+                    dtype=mixed_ba.dtype,
+                    device=mixed_ba.device)
     grid = (batch * seq_len, num_heads_qk)
     fused_qkvzba_split_reshape_cat_kernel[grid](
         mixed_qkv,
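The hunk above mostly restyles the allocations; the one substantive change is spelling out the allocation of `a` (previously `torch.empty_like(b)`) and hoisting `batch * seq_len`. A trivial standalone check, with placeholder sizes, that the two spellings are equivalent in shape, dtype, and device:

import torch

batch_seq, num_heads_v = 8, 16                # placeholder sizes
b = torch.empty((batch_seq, num_heads_v), dtype=torch.float16)

a_old = torch.empty_like(b)                   # previous spelling
a_new = torch.empty((batch_seq, num_heads_v), # new explicit spelling
                    dtype=b.dtype,
                    device=b.device)

assert a_old.shape == a_new.shape
assert a_old.dtype == a_new.dtype and a_old.device == a_new.device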
@@ -765,43 +782,42 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
         """
         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
         """
-        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
-            self.num_k_heads // self.attn_tp_size,
-            (self.head_k_dim + self.head_k_dim +
-             (self.head_v_dim + self.head_v_dim) * self.num_v_heads //
-             self.num_k_heads),
-        )
-        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
-            self.num_k_heads // self.attn_tp_size,
-            2 * self.num_v_heads // self.num_k_heads,
-        )
-
-        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
-        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
-
-        split_arg_list_qkvz = [
-            self.head_k_dim,
-            self.head_k_dim,
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-        ]
-        split_arg_list_ba = [
-            self.num_v_heads // self.num_k_heads,
-            self.num_v_heads // self.num_k_heads,
-        ]
-
-        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
-        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
-        (query, key, value, z) = torch.split(mixed_qkvz,
-                                             split_arg_list_qkvz,
-                                             dim=2)
-        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
-
-        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
-        value = value.reshape(value.size(0), -1, self.head_v_dim)
-        z = z.reshape(z.size(0), -1, self.head_v_dim)
-        b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
-        a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
+        batch_size = mixed_qkvz.size(0)
+        num_k_heads_local = self.num_k_heads // self.attn_tp_size
+        num_v_heads_local = self.num_v_heads // self.attn_tp_size
+        heads_ratio = self.num_v_heads // self.num_k_heads
+
+        # Reshape qkvz: [b, d] -> [b, ng, (2*hk + 2*np/ng*hv)]
+        qkvz_dim_per_head = (self.head_k_dim * 2 +
+                             self.head_v_dim * heads_ratio * 2)
+        mixed_qkvz = mixed_qkvz.view(batch_size, num_k_heads_local,
+                                     qkvz_dim_per_head)
+
+        # Reshape ba: [b, d] -> [b, ng, 2*np/ng]
+        mixed_ba = mixed_ba.view(batch_size, num_k_heads_local, heads_ratio * 2)
+
+        # Direct slicing instead of torch.split for better performance
+        # Compute split boundaries once
+        q_end = self.head_k_dim
+        k_end = q_end + self.head_k_dim
+        v_end = k_end + heads_ratio * self.head_v_dim
+        z_end = v_end + heads_ratio * self.head_v_dim
+
+        # Slice qkvz components: [b, ng, dim] -> individual components
+        query = mixed_qkvz[..., :q_end]
+        key = mixed_qkvz[..., q_end:k_end]
+
+        # Layout: [v_concat | z_concat]; these last-dim slices are not
+        # contiguous, so use reshape (which copies only if needed), not view
+        value = mixed_qkvz[..., k_end:v_end].reshape(
+            batch_size, num_v_heads_local, self.head_v_dim)
+        z = mixed_qkvz[..., v_end:z_end].reshape(batch_size, num_v_heads_local,
+                                                 self.head_v_dim)
+
+        # Slice ba components: [b, ng, 2*np/ng] -> [b, np] each; reshape for
+        # the same reason (these slices are not contiguous either)
+        b = mixed_ba[..., :heads_ratio].reshape(batch_size, num_v_heads_local)
+        a = mixed_ba[..., heads_ratio:].reshape(batch_size, num_v_heads_local)
 
         return query, key, value, z, b, a
 
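A note on why the rewritten method reshapes instead of viewing: slicing only the last dimension of a 3-D tensor produces a non-contiguous view, so `.view()` would raise here, while `.reshape()` copies only when it has to and matches the old `torch.split` + `reshape` result. A small standalone check with made-up sizes (the head dimensions and ratio below are placeholders, not the model's real values):

import torch

bsz, ng, hk, hv, ratio = 2, 4, 8, 16, 2       # placeholder sizes
D = 2 * hk + 2 * ratio * hv
mixed = torch.randn(bsz, ng, D)

v_slice = mixed[..., 2 * hk:2 * hk + ratio * hv]
assert not v_slice.is_contiguous()            # last-dim slice is strided

value = v_slice.reshape(bsz, ng * ratio, hv)  # copies because of the stride
ref = torch.split(mixed, [hk, hk, ratio * hv, ratio * hv], dim=2)[2]
assert torch.equal(value, ref.reshape(bsz, ng * ratio, hv))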
@@ -817,7 +833,6 @@ def forward_decode(
         a = kwargs["a"]
         b = kwargs["b"]
         cache_indices = kwargs["cache_indices"]
-
         query_start_loc = torch.arange(0,
                                        num_decodes + 1,
                                        device=cu_seqlens.device).to(torch.long)
@@ -831,15 +846,11 @@ def forward_decode(
             conv_state_indices=cache_indices,
         )
 
-        query, key, value = torch.split(
-            mixed_qkv,
-            [
-                self.key_dim // self.attn_tp_size,
-                self.key_dim // self.attn_tp_size,
-                self.value_dim // self.attn_tp_size,
-            ],
-            dim=-1,
-        )
+        # Direct slicing instead of torch.split for better performance
+        key_size = self.key_dim // self.attn_tp_size
+        query = mixed_qkv[..., :key_size]
+        key = mixed_qkv[..., key_size:key_size * 2]
+        value = mixed_qkv[..., key_size * 2:]
         # Reshape from [l, h*d] to [1, l, h, d]
         seq_len = query.shape[0]
         num_heads = query.shape[1] // self.head_k_dim
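The decode path applies the same `torch.split`-to-slicing change as above. A tiny standalone equivalence check (the 32/64 sizes are placeholders, not the real per-rank key/value dims): the three slices match what `torch.split` returned, and both forms are views that copy no data:

import torch

key_size, value_size = 32, 64                 # placeholder projection sizes
mixed_qkv = torch.randn(10, 2 * key_size + value_size)

q, k, v = torch.split(mixed_qkv, [key_size, key_size, value_size], dim=-1)
assert torch.equal(q, mixed_qkv[..., :key_size])
assert torch.equal(k, mixed_qkv[..., key_size:key_size * 2])
assert torch.equal(v, mixed_qkv[..., key_size * 2:])

# both approaches return views into mixed_qkv; nothing is copied here
assert mixed_qkv[..., :key_size].data_ptr() == mixed_qkv.data_ptr()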
@@ -925,8 +936,7 @@ def forward_extend(
             conv_states=conv_states_to_use,
             has_initial_state=has_initial_states,
             cache_indices=cache_indices,
-            query_start_loc=query_start_loc,
-        ).transpose(0, 1)
+            query_start_loc=query_start_loc).transpose(0, 1)
 
         key_split_dim = self.key_dim // self.attn_tp_size
         value_split_dim = self.value_dim // self.attn_tp_size
@@ -1024,9 +1034,8 @@ def forward(
 
         projected_states_qkvz = self.in_proj_qkvz(hidden_states)
         projected_states_ba = self.in_proj_ba(hidden_states)
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba)
 
+        # Use fused kernel when possible to avoid elementwise ops
         if self.num_v_heads // self.num_k_heads in [1, 2,
                                                     4]:  # and is_cuda_graph:
             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
@@ -1060,17 +1069,11 @@ def forward(
             "num_prefill": num_prefills,
             "num_decode": num_decodes,
         }
-
-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)
 
         z_shape_og = z.shape
         # reshape input data into 2D tensor
@@ -1125,7 +1128,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp
 
-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()
 
         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1287,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp
 
-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()
 
         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp