
Commit 06d767c

nv-guomingz authored and codego7250 committed
[None][chroe] Polish qwen3-next modeling code. (NVIDIA#8902)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent a801799 commit 06d767c

2 files changed (+96, -95 lines)


tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 96 additions & 93 deletions
@@ -50,9 +50,10 @@
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -387,6 +388,7 @@ def __init__(
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
+        self.aux_stream = aux_stream

         self.gate = Qwen3NextGate(
             hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
             dtype=config.torch_dtype,
             quant_config=None)

+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -450,22 +457,33 @@ def forward(
                 dim=0,
                 sizes=all_rank_num_tokens)

-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states

+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states

-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output

         if not self.enable_attention_dp and self.mapping.tp_size > 1:
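
The refactor above moves the routed experts and the shared expert onto separate CUDA streams via maybe_execute_in_parallel. The helper itself lives in ..modules.multi_stream_utils; the snippet below is only a hypothetical sketch of the event/stream handshake such a helper typically performs (the run_in_parallel name and the sequential fallback are assumptions, not the library code):

import torch

def run_in_parallel(fn_main, fn_aux, event_main, event_aux, aux_stream):
    """Illustrative sketch: overlap two callables on the main and aux CUDA streams.

    Hypothetical helper for illustration only; the real maybe_execute_in_parallel
    may differ, e.g. in how it handles a missing aux stream.
    """
    if aux_stream is None:
        # No side stream available: run both callables sequentially.
        return fn_main(), fn_aux()

    # Mark where the main stream currently is so the aux stream can wait on it.
    event_main.record(torch.cuda.current_stream())
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)   # aux work starts after prior main work
        aux_out = fn_aux()
        event_aux.record(aux_stream)

    main_out = fn_main()                    # runs concurrently on the main stream
    torch.cuda.current_stream().wait_event(event_aux)  # join before using aux_out
    return main_out, aux_out

In the commit, the first callable is the routed-expert path and the second is the shared-expert MLP, so the small shared branch can overlap with the MoE dispatch.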
@@ -543,22 +561,21 @@ def fused_qkvzba_split_reshape_cat(
 ):
     batch, seq_len = mixed_qkvz.shape[0], 1
     qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
-    mixed_qkv = torch.empty(
-        [batch * seq_len, qkv_dim_t],
-        dtype=mixed_qkvz.dtype,
-        device=mixed_qkvz.device,
-    )
-    z = torch.empty(
-        [batch * seq_len, num_heads_v, head_v],
-        dtype=mixed_qkvz.dtype,
-        device=mixed_qkvz.device,
-    )
-    b = torch.empty(
-        [batch * seq_len, num_heads_v],
-        dtype=mixed_ba.dtype,
-        device=mixed_ba.device,
-    )
-    a = torch.empty_like(b)
+    batch_seq = batch * seq_len
+
+    # Directly allocate output tensors in their final shapes (no intermediate buffers)
+    mixed_qkv = torch.empty((batch_seq, qkv_dim_t),
+                            dtype=mixed_qkvz.dtype,
+                            device=mixed_qkvz.device)
+    z = torch.empty((batch_seq, num_heads_v, head_v),
+                    dtype=mixed_qkvz.dtype,
+                    device=mixed_qkvz.device)
+    b = torch.empty((batch_seq, num_heads_v),
+                    dtype=mixed_ba.dtype,
+                    device=mixed_ba.device)
+    a = torch.empty((batch_seq, num_heads_v),
+                    dtype=mixed_ba.dtype,
+                    device=mixed_ba.device)
     grid = (batch * seq_len, num_heads_qk)
     fused_qkvzba_split_reshape_cat_kernel[grid](
         mixed_qkv,
@@ -765,43 +782,42 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
         """
         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
         """
-        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
-            self.num_k_heads // self.attn_tp_size,
-            (self.head_k_dim + self.head_k_dim +
-             (self.head_v_dim + self.head_v_dim) * self.num_v_heads //
-             self.num_k_heads),
-        )
-        new_tensor_shape_ba = mixed_ba.size()[:-1] + (
-            self.num_k_heads // self.attn_tp_size,
-            2 * self.num_v_heads // self.num_k_heads,
-        )
-
-        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
-        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
-
-        split_arg_list_qkvz = [
-            self.head_k_dim,
-            self.head_k_dim,
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
-        ]
-        split_arg_list_ba = [
-            self.num_v_heads // self.num_k_heads,
-            self.num_v_heads // self.num_k_heads,
-        ]
-
-        # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
-        # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
-        (query, key, value, z) = torch.split(mixed_qkvz,
-                                             split_arg_list_qkvz,
-                                             dim=2)
-        (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
-
-        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
-        value = value.reshape(value.size(0), -1, self.head_v_dim)
-        z = z.reshape(z.size(0), -1, self.head_v_dim)
-        b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
-        a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
+        batch_size = mixed_qkvz.size(0)
+        num_k_heads_local = self.num_k_heads // self.attn_tp_size
+        num_v_heads_local = self.num_v_heads // self.attn_tp_size
+        heads_ratio = self.num_v_heads // self.num_k_heads
+
+        # Reshape qkvz: [b, d] -> [b, ng, (2*hk + 2*np/ng*hv)]
+        qkvz_dim_per_head = (self.head_k_dim * 2 +
+                             self.head_v_dim * heads_ratio * 2)
+        mixed_qkvz = mixed_qkvz.view(batch_size, num_k_heads_local,
+                                     qkvz_dim_per_head)
+
+        # Reshape ba: [b, d] -> [b, ng, 2*np/ng]
+        mixed_ba = mixed_ba.view(batch_size, num_k_heads_local, heads_ratio * 2)
+
+        # Direct slicing instead of torch.split for better performance
+        # Compute split boundaries once
+        q_end = self.head_k_dim
+        k_end = q_end + self.head_k_dim
+        v_end = k_end + heads_ratio * self.head_v_dim
+        z_end = v_end + heads_ratio * self.head_v_dim
+
+        # Slice qkvz components: [b, ng, dim] -> individual components
+        query = mixed_qkvz[..., :q_end]
+        key = mixed_qkvz[..., q_end:k_end]
+
+        # Optimize: Use view (zero-copy) instead of reshape for contiguous slices
+        # Layout: [v_concat | z_concat], need to reshape each separately
+        value = mixed_qkvz[..., k_end:v_end].view(batch_size, num_v_heads_local,
+                                                  self.head_v_dim)
+        z = mixed_qkvz[..., v_end:z_end].view(batch_size, num_v_heads_local,
+                                              self.head_v_dim)
+
+        # Slice ba components: [b, ng, 2*np/ng] -> [b, np] each
+        # Optimize: Use view instead of reshape (zero-copy for contiguous data)
+        b = mixed_ba[..., :heads_ratio].view(batch_size, num_v_heads_local)
+        a = mixed_ba[..., heads_ratio:].view(batch_size, num_v_heads_local)

         return query, key, value, z, b, a
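
To make the new packed layout easier to follow, here is a small, self-contained toy example of splitting a per-key-head projection into query/key/value/z by computing the boundaries once and slicing. The dimensions are invented for illustration, and reshape is used where, with these toy strides, the sliced v/z views are not contiguous:

import torch

# Toy dimensions for illustration only (not the Qwen3-Next config).
batch, num_k_heads, head_k, head_v, ratio = 2, 4, 8, 16, 2
num_v_heads = num_k_heads * ratio
per_head = 2 * head_k + 2 * ratio * head_v

mixed_qkvz = torch.randn(batch, num_k_heads * per_head)
packed = mixed_qkvz.view(batch, num_k_heads, per_head)

# Split boundaries computed once, then plain slicing on the last dim.
q_end = head_k
k_end = q_end + head_k
v_end = k_end + ratio * head_v

query = packed[..., :q_end]                 # [b, ng, hk]
key = packed[..., q_end:k_end]              # [b, ng, hk]
# The v/z slices are not contiguous here, so reshape (which may copy) is the
# safe way to fold [b, ng, ratio*hv] into [b, ng*ratio, hv] in this toy.
value = packed[..., k_end:v_end].reshape(batch, num_v_heads, head_v)
z = packed[..., v_end:].reshape(batch, num_v_heads, head_v)

assert query.shape == (batch, num_k_heads, head_k)
assert value.shape == (batch, num_v_heads, head_v)

The b/a gating tensors follow the same boundary-then-slice pattern on the separate ba projection.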
@@ -817,7 +833,6 @@ def forward_decode(
         a = kwargs["a"]
         b = kwargs["b"]
         cache_indices = kwargs["cache_indices"]
-
         query_start_loc = torch.arange(0,
                                        num_decodes + 1,
                                        device=cu_seqlens.device).to(torch.long)
@@ -831,15 +846,11 @@ def forward_decode(
             conv_state_indices=cache_indices,
         )

-        query, key, value = torch.split(
-            mixed_qkv,
-            [
-                self.key_dim // self.attn_tp_size,
-                self.key_dim // self.attn_tp_size,
-                self.value_dim // self.attn_tp_size,
-            ],
-            dim=-1,
-        )
+        # Direct slicing instead of torch.split for better performance
+        key_size = self.key_dim // self.attn_tp_size
+        query = mixed_qkv[..., :key_size]
+        key = mixed_qkv[..., key_size:key_size * 2]
+        value = mixed_qkv[..., key_size * 2:]
         # Reshape from [l, h*d] to [1, l, h, d]
         seq_len = query.shape[0]
         num_heads = query.shape[1] // self.head_k_dim
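
As a point of reference for the "direct slicing instead of torch.split" change, this toy comparison (sizes made up) shows the two styles produce identical views; slicing simply avoids building the intermediate tuple on the decode hot path:

import torch

# Toy projection: [tokens, 2*key_dim + value_dim], sizes are illustrative only.
key_dim, value_dim = 4, 8
mixed_qkv = torch.randn(6, 2 * key_dim + value_dim)

# torch.split returns a tuple of views in one call...
q_s, k_s, v_s = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1)

# ...while direct slicing yields the same views without the tuple construction.
q = mixed_qkv[..., :key_dim]
k = mixed_qkv[..., key_dim:2 * key_dim]
v = mixed_qkv[..., 2 * key_dim:]

assert torch.equal(q, q_s) and torch.equal(k, k_s) and torch.equal(v, v_s)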
@@ -925,8 +936,7 @@ def forward_extend(
             conv_states=conv_states_to_use,
             has_initial_state=has_initial_states,
             cache_indices=cache_indices,
-            query_start_loc=query_start_loc,
-        ).transpose(0, 1)
+            query_start_loc=query_start_loc).transpose(0, 1)

         key_split_dim = self.key_dim // self.attn_tp_size
         value_split_dim = self.value_dim // self.attn_tp_size
@@ -1024,9 +1034,8 @@ def forward(

         projected_states_qkvz = self.in_proj_qkvz(hidden_states)
         projected_states_ba = self.in_proj_ba(hidden_states)
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba)

+        # Use fused kernel when possible to avoid elementwise ops
         if self.num_v_heads // self.num_k_heads in [1, 2,
                                                     4]:  # and is_cuda_graph:
             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
@@ -1060,17 +1069,11 @@ def forward(
             "num_prefill": num_prefills,
             "num_decode": num_decodes,
         }
-
-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)

         z_shape_og = z.shape
         # reshape input data into 2D tensor
@@ -1125,7 +1128,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1287,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp

tensorrt_llm/_torch/modules/fla/chunk.py

Lines changed: 0 additions & 2 deletions
@@ -90,8 +90,6 @@ def forward(
         cu_seqlens: Optional[torch.LongTensor] = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
-        pass
-
         if use_qk_l2norm_in_kernel:
             q = l2norm_fwd(q)
             k = l2norm_fwd(k)
