Support GLM-5.2 (glm_moe_dsa, DeepSeek-V3.2-style DSA MoE)#1370
Support GLM-5.2 (glm_moe_dsa, DeepSeek-V3.2-style DSA MoE)#1370sufubao wants to merge 8 commits into
Conversation
The fp8kv_dsa NSA path (--llm_kv_type fp8kv_dsa) addressed the FlashMLA sparse kernels with the wrong index space. It was never caught because the default/benchmarked config uses --llm_kv_type None (bf16 NSA path). Decode (sgl_kernel.flash_mla.flash_mla_with_kvcache): - pass indices=topk_mem_indices (absolute KV-pool slots) instead of the raw sequence-space / global-ragged topk_indices - pass an empty (bs, 0) block_table; the kernel ignores block_table for sparse indices, and the previously-computed paged block_table was both wrong (the flat allocator is not 64-page aligned) and dead - pad query heads up to the supported FlashMLA decode variants (64/128) - drop the per-layer cache_seqlens.max().item() host sync (only needed for the removed block_table; also unblocks cuda graph capture) Prefill: the no-prefix branch indexed the local prefill_cache_kv buffer with mem-pool slots; use the local topk_indices (b_topk_index) instead. Mem manager: pad the kv buffer token dim to a multiple of 64 so the decode 64-token page view keeps every valid slot addressable. Contract verified against the SGLang reference (dsa_backend.py::_forward_flashmla_kv); still needs GPU validation on the fp8kv_dsa path (batch>1 decode).
Every MTP draft model (Deepseek3MTP, Qwen3MOEMTP, Mistral, Glm4MoeLite, Glm5_2) already sets the class attribute is_mtp_draft_model = True, so the getattr term alone covers all of them; the "<name> in str(__class__)" chain is unreachable (and was never extended for the GLM models). Reduce the predicate to the getattr.
_init_config already backfills config["rope_theta"] with 1e6 when absent, so _init_glm5_2_rotary's 8e6 fallback is unreachable and disagrees with the value actually used. Align the default to 1e6.
…rnel) Folds the residual elementwise-add into the RMSNorm pass, removing a separate tiny add kernel per residual junction at decode. Variance is taken from the bf16-rounded sum, so the result is bit-identical to `add_` then rmsnorm (test/kernel/test_fused_add_rmsnorm.py: out_max_abs=0). Exposed as RMSNormWeight.fused_add_forward; wired into the decode path in a follow-up.
…Norm
The decode attention-output junction now uses flashinfer kARResidualRMSNorm
(SGLang #22390): the all-reduce, residual add, and ffn RMSNorm fuse into one
oneshot-lamport kernel (fp32_acc) instead of three. LightLLM already launched
flashinfer's allreduce_fusion in kAllReduce mode; this adds the fused pattern via
all_reduce_residual_rmsnorm, with a Triton fused_add_rmsnorm fallback when the
flashinfer AR fast path is inactive (large messages / SP mode). Keeps o
un-reduced (Deepseek2._get_o reduce=False) and inlines the decode attention.
Cuda-graph safe; GSM8K unchanged. Gated by LIGHTLLM_FUSED_{ADD,AR}_RMSNORM=1.
…uant The block-wise-fp8 down projection quantized its activation in a separate per_token_group_quant launch per MoE layer. silu_and_mul_group_quant_fwd now emits the silu output directly as fp8 + row-major group scales (byte-matching per_token_group_quant_fp8: scales bit-identical, 99.9% fp8-exact), and grouped_matmul skips its internal quant when the input is already fp8. Gated by LIGHTLLM_FUSED_SILU_QUANT=1; GSM8K unchanged.
BF16/FP8 + MTP, glm_moe_dsa (DeepSeek-V3.2-style DSA MoE).
There was a problem hiding this comment.
Code Review
This pull request introduces support for GLM-5.2 and GLM-5.2 MTP models, adds fused kernels such as fused_add_rmsnorm and moe_silu_and_mul_group_quant, integrates fused all-reduce + residual-add + RMSNorm via FlashInfer, and improves tokenizer loading fallbacks. Feedback on the changes includes a critical fix for PyTorch advanced indexing in load_index_kv_buffer to prevent silent swapping failures, a shape correction for topk_mem_indices in _nsa_decode_att, and performance optimizations for the new Triton kernels by eliminating redundant loops and tuning warp counts.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def load_index_kv_buffer(self, index, load_tensor_dict): | ||
| self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) | ||
| self.indexer_k_buffer[:, index].copy_(load_tensor_dict["indexer_k_buffer"]) |
There was a problem hiding this comment.
In load_index_kv_buffer, self.kv_buffer[:, index] uses advanced indexing when index is a tensor of indices (which is the case during swapping). In PyTorch, advanced indexing returns a copy of the tensor rather than a view. Therefore, calling .copy_() on self.kv_buffer[:, index] will copy the data into a temporary tensor that is immediately discarded, leaving the original self.kv_buffer unmodified. This will cause silent correctness issues where swapped-in KV cache is not actually written to the buffer.
To fix this, use direct in-place assignment which correctly updates the original tensor even with advanced indexing.
| def load_index_kv_buffer(self, index, load_tensor_dict): | |
| self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) | |
| self.indexer_k_buffer[:, index].copy_(load_tensor_dict["indexer_k_buffer"]) | |
| def load_index_kv_buffer(self, index, load_tensor_dict): | |
| self.kv_buffer[:, index] = load_tensor_dict["kv_buffer"] | |
| self.indexer_k_buffer[:, index] = load_tensor_dict["indexer_k_buffer"] |
| o_tensor, _ = flash_mla_with_kvcache( | ||
| q=q_all, | ||
| k_cache=kv, | ||
| block_table=None, | ||
| cache_seqlens=None, | ||
| block_table=block_table, | ||
| cache_seqlens=cache_seqlens, | ||
| head_dim_v=kv_lora_rank, | ||
| tile_scheduler_metadata=self.flashmla_sched_meta, | ||
| tile_scheduler_metadata=tile_scheduler_metadata, | ||
| num_splits=num_splits, | ||
| softmax_scale=softmax_scale, | ||
| causal=False, | ||
| is_fp8_kvcache=True, | ||
| indices=topk_mem_indices, | ||
| indices=topk_mem_indices.to(dtype=torch.int32), | ||
| ) |
There was a problem hiding this comment.
In _nsa_decode_att, topk_mem_indices is passed directly to flash_mla_with_kvcache as indices=topk_mem_indices.to(dtype=torch.int32).
However, topk_mem_indices is a 2D tensor of shape [batch_size, topk], whereas flash_mla_with_kvcache expects indices to be a 3D tensor of shape [batch_size, q_seqlen, topk] (where q_seqlen = 1 in decode). Passing a 2D tensor can cause shape mismatch errors or silent correctness issues in the kernel.
We should unsqueeze indices to 3D if it is 2D, similar to how it is handled in _nsa_prefill_att.
| o_tensor, _ = flash_mla_with_kvcache( | |
| q=q_all, | |
| k_cache=kv, | |
| block_table=None, | |
| cache_seqlens=None, | |
| block_table=block_table, | |
| cache_seqlens=cache_seqlens, | |
| head_dim_v=kv_lora_rank, | |
| tile_scheduler_metadata=self.flashmla_sched_meta, | |
| tile_scheduler_metadata=tile_scheduler_metadata, | |
| num_splits=num_splits, | |
| softmax_scale=softmax_scale, | |
| causal=False, | |
| is_fp8_kvcache=True, | |
| indices=topk_mem_indices, | |
| indices=topk_mem_indices.to(dtype=torch.int32), | |
| ) | |
| indices = topk_mem_indices.to(dtype=torch.int32) | |
| if indices.ndim == 2: | |
| indices = indices.unsqueeze(1) | |
| o_tensor, _ = flash_mla_with_kvcache( | |
| q=q_all, | |
| k_cache=kv, | |
| block_table=block_table, | |
| cache_seqlens=cache_seqlens, | |
| head_dim_v=kv_lora_rank, | |
| tile_scheduler_metadata=tile_scheduler_metadata, | |
| num_splits=num_splits, | |
| softmax_scale=softmax_scale, | |
| causal=False, | |
| is_fp8_kvcache=True, | |
| indices=indices, | |
| ) |
| # pass 1: residual = residual + x (in place), accumulate variance of the updated residual | ||
| _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) | ||
| for off in range(0, N, BLOCK_SIZE): | ||
| cols = off + tl.arange(0, BLOCK_SIZE) | ||
| mask = cols < N | ||
| x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) | ||
| r = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32) | ||
| # round the updated residual to the storage dtype first, then accumulate variance | ||
| # from the rounded value so this matches the unfused (store; reload; rmsnorm) path | ||
| # bit-for-bit instead of using the higher-precision fp32 sum. | ||
| s = (r + x).to(RESIDUAL.dtype.element_ty) | ||
| tl.store(RESIDUAL + cols, s, mask=mask) | ||
| s = s.to(tl.float32) | ||
| _var += s * s | ||
| var = tl.sum(_var, axis=0) / N | ||
| rstd = 1 / tl.sqrt(var + eps) | ||
| # pass 2: normalize the (rounded) updated residual and optionally apply weight | ||
| for off in range(0, N, BLOCK_SIZE): | ||
| cols = off + tl.arange(0, BLOCK_SIZE) | ||
| mask = cols < N | ||
| if HAS_WEIGHT: | ||
| w = tl.load(W + cols, mask=mask).to(tl.float32) | ||
| s = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32) | ||
| y = s * rstd | ||
| if HAS_WEIGHT: | ||
| y = y * w | ||
| tl.store(Y + cols * 1, y.to(Y.dtype.element_ty), mask=mask) |
There was a problem hiding this comment.
In _fused_add_rmsnorm_fwd, the kernel uses a loop for off in range(0, N, BLOCK_SIZE): and performs a redundant load from RESIDUAL in pass 2.
However, fused_add_rmsnorm_forward guarantees that N <= BLOCK_SIZE (otherwise it raises a RuntimeError). This means the loop always executes exactly once, and off is always 0.
We can completely eliminate the loop and avoid the redundant tl.load from RESIDUAL in pass 2 by keeping the computed residual s in registers. This simplifies the control flow and significantly reduces HBM read bandwidth, improving kernel performance.
row = tl.program_id(0)
X += row * x_stride0
RESIDUAL += row * residual_stride0
Y += row * y_stride0
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
r = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32)
# round the updated residual to the storage dtype first, then accumulate variance
# from the rounded value so this matches the unfused (store; reload; rmsnorm) path
# bit-for-bit instead of using the higher-precision fp32 sum.
s = (r + x).to(RESIDUAL.dtype.element_ty)
tl.store(RESIDUAL + cols, s, mask=mask)
s_f32 = s.to(tl.float32)
var = tl.sum(s_f32 * s_f32, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
y = s_f32 * rstd * w
else:
y = s_f32 * rstd
tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)| layout=layout, | ||
| USE_LIMIT_AND_ALPHA=limit is not None and alpha is not None, | ||
| USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), | ||
| num_warps=1, |
There was a problem hiding this comment.
The num_warps parameter is hardcoded to 1 in the _silu_and_mul_group_quant_kernel call. Since group_size is typically 128, using only 1 warp (32 threads) means each thread has to process 4 elements sequentially, which limits instruction-level parallelism and GPU occupancy.
Using num_warps=4 (or dynamically setting it based on group_size // 32) would allow each thread to process 1 element, leading to better coalesced memory access and higher performance.
| num_warps=1, | |
| num_warps=max(1, group_size // 32), |
Summary
Adds support for GLM-5.2 (
glm_moe_dsa), a DeepSeek-V3.2-style architecture: MLA attention + DSA lightning-indexer sparse attention + 256-expert FP8 MoE, 78 layers, with a built-in MTP (nextn) head for EAGLE speculative decoding. Supports BF16/FP8 weights and MTP.What's included
703de087):lightllm/models/glm5_2/(+glm5_2_mtp), registered via@ModelRegistry, reusing the DeepSeek-V3.2 (deepseek3_2) / DeepSeek-V2 layer infrastructure.2654a70f): correct the fp8 DSA FlashMLA prefill/decode kernel contract (index space / pagedblock_table/ 64→128 q-head pad) to match the reference; verified on GPU with--llm_kv_type fp8kv_dsa(batch>1 decode).b0ce5b37,115ca558): drop a dead MTP class-name list; fix a contradictoryrope_thetadefault.fused_add_rmsnorm(7e3dde2c) — fold the post-attention residual add into the ffn RMSNorm (bit-identical). GateLIGHTLLM_FUSED_ADD_RMSNORM.94c8b20e,kARResidualRMSNorm). GateLIGHTLLM_FUSED_AR_RMSNORM.d46cbddc) — fold the down-proj per-token-group fp8 quant intosilu_and_mul. GateLIGHTLLM_FUSED_SILU_QUANT.Verification (8×H200, TP8 + DP8 + EP-MoE, bf16 KV)
fp8kv_dsa0.975; MTP step-1 lossless (0.960). Invalid rate 0 across all configs.test/kernel/test_fused_add_rmsnorm.py(bit-exact),test/kernel/test_silu_and_mul_group_quant.py(scales bit-identical, cos = 1.0).Notes
1) and fall back cleanly when disabled or unsupported.