
Support GLM-5.2 (glm_moe_dsa, DeepSeek-V3.2-style DSA MoE) #1370

Open
sufubao wants to merge 8 commits into ModelTC:main from sufubao:support_glm52


@sufubao sufubao commented Jun 25, 2026

Collaborator

Summary

Adds support for GLM-5.2 (glm_moe_dsa), a DeepSeek-V3.2-style architecture: MLA attention + DSA lightning-indexer sparse attention + 256-expert FP8 MoE, 78 layers, with a built-in MTP (nextn) head for EAGLE speculative decoding. Supports BF16/FP8 weights and MTP.

What's included

  • Model support (703de087): lightllm/models/glm5_2/ (+ glm5_2_mtp), registered via @ModelRegistry, reusing the DeepSeek-V3.2 (deepseek3_2) / DeepSeek-V2 layer infrastructure.
  • DSA fp8 fix (2654a70f): correct the fp8 DSA FlashMLA prefill/decode kernel contract (index space / paged block_table / 64→128 q-head pad) to match the reference; verified on GPU with --llm_kv_type fp8kv_dsa (batch>1 decode).
  • Cleanups (b0ce5b37, 115ca558): drop a dead MTP class-name list; fix a contradictory rope_theta default.
  • Decode operator fusions (env-gated, default on):
    • fused_add_rmsnorm (7e3dde2c) — fold the post-attention residual add into the ffn RMSNorm (bit-identical). Gate LIGHTLLM_FUSED_ADD_RMSNORM.
    • flashinfer fused AR + residual + RMSNorm (94c8b20e, kARResidualRMSNorm). Gate LIGHTLLM_FUSED_AR_RMSNORM.
    • MoE silu+mul+fp8-quant fusion (d46cbddc) — fold the down-proj per-token-group fp8 quant into silu_and_mul. Gate LIGHTLLM_FUSED_SILU_QUANT.

Verification (8×H200, TP8 + DP8 + EP-MoE, bf16 KV)

  • GSM8K 5-shot: full set (1319q) 0.951, on par with SGLang 0.965 (200q); fp8kv_dsa 0.975; MTP step-1 lossless (0.960). Invalid rate 0 across all configs.
  • CUDA graph: capture verified on all 8 ranks with every fusion enabled.
  • Fusion correctness unit-tested: test/kernel/test_fused_add_rmsnorm.py (bit-exact), test/kernel/test_silu_and_mul_group_quant.py (scales bit-identical, cos = 1.0).
  • Serving (completions 512×128): peak throughput is competitive with or ahead of SGLang at high concurrency; the three fusions reduce low-concurrency inter-token latency (ITL) by ~2–3%.

Notes

  • All three fusions are env-gated (default 1) and fall back cleanly when disabled or unsupported.
  • Docs updated: GLM-5.2 added to the supported-models list (CN + EN).
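The env gating mentioned in the notes could look roughly like the sketch below. It is only an illustration: the helper name `env_flag` and the set of values it accepts are assumptions, and only the three variable names and the default of 1 come from the description above.

```python
import os


def env_flag(name: str, default: str = "1") -> bool:
    """Hypothetical helper: "1"/"true"/"on" enable the gate, anything else disables it."""
    return os.environ.get(name, default).strip().lower() in ("1", "true", "on")


# Gate names are the ones listed in this PR; the parsing rule is an assumption.
FUSION_GATES = (
    "LIGHTLLM_FUSED_ADD_RMSNORM",
    "LIGHTLLM_FUSED_AR_RMSNORM",
    "LIGHTLLM_FUSED_SILU_QUANT",
)


def enabled_fusions() -> dict:
    """Return the on/off state of every fusion gate, defaulting to on."""
    return {name: env_flag(name) for name in FUSION_GATES}
```

Under this sketch, exporting `LIGHTLLM_FUSED_SILU_QUANT=0` before launch would select the unfused path for that one fusion and leave the other two on.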

sufubao added 8 commits June 19, 2026 00:13
The fp8kv_dsa NSA path (--llm_kv_type fp8kv_dsa) addressed the FlashMLA
sparse kernels with the wrong index space. It was never caught because
the default/benchmarked config uses --llm_kv_type None (bf16 NSA path).

Decode (sgl_kernel.flash_mla.flash_mla_with_kvcache):
- pass indices=topk_mem_indices (absolute KV-pool slots) instead of the
  raw sequence-space / global-ragged topk_indices
- pass an empty (bs, 0) block_table; the kernel ignores block_table for
  sparse indices, and the previously-computed paged block_table was both
  wrong (the flat allocator is not 64-page aligned) and dead
- pad query heads up to the supported FlashMLA decode variants (64/128)
- drop the per-layer cache_seqlens.max().item() host sync (only needed
  for the removed block_table; also unblocks cuda graph capture)

Prefill: the no-prefix branch indexed the local prefill_cache_kv buffer
with mem-pool slots; use the local topk_indices (b_topk_index) instead.

Mem manager: pad the kv buffer token dim to a multiple of 64 so the
decode 64-token page view keeps every valid slot addressable.
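The index-space and padding rules in this commit message can be sketched in plain Python. This is a hedged illustration rather than the code in the PR: `req_to_token`, the `-1` padding convention, and all function names are hypothetical, and the head-padding rule (pick the smallest supported variant) is an assumption.

```python
def round_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple of `multiple`."""
    return (n + multiple - 1) // multiple * multiple


def padded_kv_tokens(num_tokens: int, page: int = 64) -> int:
    """Pad the KV buffer token dim so a 64-token page view covers every slot."""
    return round_up(num_tokens, page)


def padded_q_heads(num_heads: int, supported=(64, 128)) -> int:
    """Pick the smallest supported decode head count that fits num_heads."""
    for candidate in supported:
        if num_heads <= candidate:
            return candidate
    raise ValueError(f"{num_heads} query heads exceed supported variants {supported}")


def to_mem_indices(topk_positions, req_to_token):
    """Map sequence-space top-k positions to absolute KV-pool slots.

    `req_to_token[p]` is the pool slot holding position p of one request;
    -1 marks an unused top-k entry (an assumed convention).
    """
    return [req_to_token[p] if p >= 0 else -1 for p in topk_positions]
```

For example, `to_mem_indices([2, 0, -1], [907, 13, 455])` translates positions 2 and 0 into pool slots 455 and 907 and leaves the padding entry untouched.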

Contract verified against the SGLang reference
(dsa_backend.py::_forward_flashmla_kv); still needs GPU validation on the
fp8kv_dsa path (batch>1 decode).

Every MTP draft model (Deepseek3MTP, Qwen3MOEMTP, Mistral, Glm4MoeLite,
Glm5_2) already sets the class attribute is_mtp_draft_model = True, so
the getattr term alone covers all of them; the "<name> in str(__class__)"
chain is unreachable (and was never extended for the GLM models). Reduce
the predicate to the getattr.

_init_config already backfills config["rope_theta"] with 1e6 when absent,
so _init_glm5_2_rotary's 8e6 fallback is unreachable and disagrees with
the value actually used. Align the default to 1e6.
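Why the 8e6 fallback could never take effect can be shown with a small stand-alone sketch. The two functions below are hypothetical stand-ins for `_init_config` and `_init_glm5_2_rotary`, assuming the backfill behaves like `dict.setdefault`.

```python
def init_config(config: dict) -> dict:
    # Backfill the key when the checkpoint config omits it (assumed behaviour).
    config.setdefault("rope_theta", 1e6)
    return config


def init_rotary(config: dict, fallback: float = 8e6) -> float:
    # By the time this runs the key always exists, so `fallback` is never used.
    return config.get("rope_theta", fallback)


theta_missing = init_rotary(init_config({}))
theta_present = init_rotary(init_config({"rope_theta": 5e5}))
```

In both cases the rotary setup sees the value written by the config step, which is why aligning the unused fallback to 1e6 changes no behaviour.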
…rnel)

Folds the residual elementwise-add into the RMSNorm pass, removing a separate
tiny add kernel per residual junction at decode. Variance is taken from the
bf16-rounded sum, so the result is bit-identical to `add_` then rmsnorm
(test/kernel/test_fused_add_rmsnorm.py: out_max_abs=0). Exposed as
RMSNormWeight.fused_add_forward; wired into the decode path in a follow-up.
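The rounding order described here can be checked with a PyTorch reference on CPU. This is a sketch of the numerics only, not the Triton kernel: the function names, the `eps` value, and the tensor sizes are assumptions.

```python
import torch


def add_then_rmsnorm(x, residual, weight, eps=1e-6):
    """Unfused path: bf16 residual add, then RMSNorm on the stored result."""
    residual = residual + x  # result is rounded to bf16
    s = residual.float()
    rstd = torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps)
    return (s * rstd * weight.float()).to(x.dtype), residual


def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    """Fused path: add in fp32, round to bf16 FIRST, then take the variance."""
    s = (residual.float() + x.float()).to(residual.dtype)
    s32 = s.float()
    rstd = torch.rsqrt(s32.pow(2).mean(-1, keepdim=True) + eps)
    return (s32 * rstd * weight.float()).to(x.dtype), s


torch.manual_seed(0)
x = torch.randn(4, 256).to(torch.bfloat16)
r = torch.randn(4, 256).to(torch.bfloat16)
w = torch.randn(256).to(torch.bfloat16)

y_ref, r_ref = add_then_rmsnorm(x, r, w)
y_fused, r_fused = fused_add_rmsnorm_reference(x, r, w)
```

Taking the variance from the unrounded fp32 sum instead would be slightly more precise but would no longer be guaranteed to match the unfused path bit for bit, which is the property the commit is after.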
…Norm

The decode attention-output junction now uses flashinfer kARResidualRMSNorm
(SGLang #22390): the all-reduce, residual add, and ffn RMSNorm fuse into one
oneshot-lamport kernel (fp32_acc) instead of three. LightLLM already launched
flashinfer's allreduce_fusion in kAllReduce mode; this adds the fused pattern via
all_reduce_residual_rmsnorm, with a Triton fused_add_rmsnorm fallback when the
flashinfer AR fast path is inactive (large messages / SP mode). Keeps o
un-reduced (Deepseek2._get_o reduce=False) and inlines the decode attention.
Cuda-graph safe; GSM8K unchanged. Gated by LIGHTLLM_FUSED_{ADD,AR}_RMSNORM=1.
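What the fused junction computes can be written as a single-process PyTorch reference, with a list of tensors standing in for the per-rank partial outputs. This is a semantic sketch only: it does not call flashinfer, every name is hypothetical, and the exact rounding points of the real kernel may differ.

```python
import torch


def rmsnorm(x, weight, eps=1e-6):
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (x32 * rstd * weight.float()).to(x.dtype)


def ar_residual_rmsnorm_reference(partials, residual, weight, eps=1e-6):
    """partials: one un-reduced attention output per simulated TP rank."""
    reduced = torch.stack([p.float() for p in partials]).sum(0)  # all-reduce (sum) in fp32
    new_residual = (reduced + residual.float()).to(residual.dtype)  # residual add
    return rmsnorm(new_residual, weight, eps), new_residual  # ffn RMSNorm


torch.manual_seed(0)
partials = [torch.randn(2, 64).to(torch.bfloat16) for _ in range(8)]
residual = torch.randn(2, 64).to(torch.bfloat16)
weight = torch.ones(64, dtype=torch.bfloat16)

normed, new_residual = ar_residual_rmsnorm_reference(partials, residual, weight)
```

The unfused equivalent is three launches (all-reduce, add, RMSNorm); the reference above only pins down the result they are expected to agree on.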
…uant

The block-wise-fp8 down projection quantized its activation in a separate
per_token_group_quant launch per MoE layer. silu_and_mul_group_quant_fwd now
emits the silu output directly as fp8 + row-major group scales (byte-matching
per_token_group_quant_fp8: scales bit-identical, 99.9% fp8-exact), and
grouped_matmul skips its internal quant when the input is already fp8.
Gated by LIGHTLLM_FUSED_SILU_QUANT=1; GSM8K unchanged.
BF16/FP8 + MTP, glm_moe_dsa (DeepSeek-V3.2-style DSA MoE).

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces support for GLM-5.2 and GLM-5.2 MTP models, adds fused kernels such as fused_add_rmsnorm and moe_silu_and_mul_group_quant, integrates fused all-reduce + residual-add + RMSNorm via FlashInfer, and improves tokenizer loading fallbacks. Feedback on the changes includes a critical fix for PyTorch advanced indexing in load_index_kv_buffer to prevent silent swapping failures, a shape correction for topk_mem_indices in _nsa_decode_att, and performance optimizations for the new Triton kernels by eliminating redundant loops and tuning warp counts.


Comment on lines +65 to +67
def load_index_kv_buffer(self, index, load_tensor_dict):
    self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
    self.indexer_k_buffer[:, index].copy_(load_tensor_dict["indexer_k_buffer"])


critical

In load_index_kv_buffer, self.kv_buffer[:, index] uses advanced indexing when index is a tensor of indices (which is the case during swapping). In PyTorch, advanced indexing returns a copy of the tensor rather than a view. Therefore, calling .copy_() on self.kv_buffer[:, index] will copy the data into a temporary tensor that is immediately discarded, leaving the original self.kv_buffer unmodified. This will cause silent correctness issues where swapped-in KV cache is not actually written to the buffer.

To fix this, use direct in-place assignment which correctly updates the original tensor even with advanced indexing.

Suggested change
- def load_index_kv_buffer(self, index, load_tensor_dict):
-     self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
-     self.indexer_k_buffer[:, index].copy_(load_tensor_dict["indexer_k_buffer"])
+ def load_index_kv_buffer(self, index, load_tensor_dict):
+     self.kv_buffer[:, index] = load_tensor_dict["kv_buffer"]
+     self.indexer_k_buffer[:, index] = load_tensor_dict["indexer_k_buffer"]
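The copy-versus-view behaviour the reviewer describes is easy to reproduce on a small CPU tensor. The snippet below is a minimal demonstration with made-up shapes, not the memory manager's code.

```python
import torch

buf = torch.zeros(2, 6)
index = torch.tensor([1, 4])  # tensor index -> advanced indexing
src = torch.ones(2, 2)

# __getitem__ with a tensor index gathers into a new tensor, so copy_ writes
# into that temporary and buf is left untouched.
buf[:, index].copy_(src)
unchanged = bool((buf == 0).all())

# __setitem__ scatters into buf itself.
buf[:, index] = src
written = buf[:, index].clone()

# index_copy_ is an explicit in-place alternative along one dimension.
buf2 = torch.zeros(2, 6)
buf2.index_copy_(1, index, src)
```

If `index` were a slice or an integer, `buf[:, index]` would be a view and the original `.copy_()` form would write through, which is likely why the bug only shows up on the tensor-index swap path.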

Comment on lines +218 to 230
  o_tensor, _ = flash_mla_with_kvcache(
      q=q_all,
      k_cache=kv,
-     block_table=None,
-     cache_seqlens=None,
+     block_table=block_table,
+     cache_seqlens=cache_seqlens,
      head_dim_v=kv_lora_rank,
-     tile_scheduler_metadata=self.flashmla_sched_meta,
+     tile_scheduler_metadata=tile_scheduler_metadata,
+     num_splits=num_splits,
      softmax_scale=softmax_scale,
      causal=False,
      is_fp8_kvcache=True,
-     indices=topk_mem_indices,
+     indices=topk_mem_indices.to(dtype=torch.int32),
  )


high

In _nsa_decode_att, topk_mem_indices is passed directly to flash_mla_with_kvcache as indices=topk_mem_indices.to(dtype=torch.int32).

However, topk_mem_indices is a 2D tensor of shape [batch_size, topk], whereas flash_mla_with_kvcache expects indices to be a 3D tensor of shape [batch_size, q_seqlen, topk] (where q_seqlen = 1 in decode). Passing a 2D tensor can cause shape mismatch errors or silent correctness issues in the kernel.

We should unsqueeze indices to 3D if it is 2D, similar to how it is handled in _nsa_prefill_att.

Suggested change
- o_tensor, _ = flash_mla_with_kvcache(
-     q=q_all,
-     k_cache=kv,
-     block_table=block_table,
-     cache_seqlens=cache_seqlens,
-     head_dim_v=kv_lora_rank,
-     tile_scheduler_metadata=tile_scheduler_metadata,
-     num_splits=num_splits,
-     softmax_scale=softmax_scale,
-     causal=False,
-     is_fp8_kvcache=True,
-     indices=topk_mem_indices.to(dtype=torch.int32),
- )
+ indices = topk_mem_indices.to(dtype=torch.int32)
+ if indices.ndim == 2:
+     indices = indices.unsqueeze(1)
+ o_tensor, _ = flash_mla_with_kvcache(
+     q=q_all,
+     k_cache=kv,
+     block_table=block_table,
+     cache_seqlens=cache_seqlens,
+     head_dim_v=kv_lora_rank,
+     tile_scheduler_metadata=tile_scheduler_metadata,
+     num_splits=num_splits,
+     softmax_scale=softmax_scale,
+     causal=False,
+     is_fp8_kvcache=True,
+     indices=indices,
+ )
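The shape normalization in this suggestion can be isolated into a small helper and checked without the kernel. The helper name and the extra rank check are assumptions added for illustration; the expected `[batch_size, q_seqlen, topk]` layout is the one stated in the review comment.

```python
import torch


def as_decode_indices(topk_mem_indices: torch.Tensor) -> torch.Tensor:
    """Return int32 indices shaped [batch, q_seqlen, topk], with q_seqlen = 1 in decode."""
    indices = topk_mem_indices.to(dtype=torch.int32)
    if indices.ndim == 2:  # [batch, topk] -> [batch, 1, topk]
        indices = indices.unsqueeze(1)
    if indices.ndim != 3:
        raise ValueError(f"expected 2D or 3D indices, got {indices.ndim}D")
    return indices


idx = as_decode_indices(torch.arange(8, dtype=torch.int64).reshape(2, 4))
```

`unsqueeze` returns a view, so the extra dimension costs no copy beyond the dtype conversion.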

Comment on lines +28 to +54
    # pass 1: residual = residual + x (in place), accumulate variance of the updated residual
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        r = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32)
        # round the updated residual to the storage dtype first, then accumulate variance
        # from the rounded value so this matches the unfused (store; reload; rmsnorm) path
        # bit-for-bit instead of using the higher-precision fp32 sum.
        s = (r + x).to(RESIDUAL.dtype.element_ty)
        tl.store(RESIDUAL + cols, s, mask=mask)
        s = s.to(tl.float32)
        _var += s * s
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # pass 2: normalize the (rounded) updated residual and optionally apply weight
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        if HAS_WEIGHT:
            w = tl.load(W + cols, mask=mask).to(tl.float32)
        s = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32)
        y = s * rstd
        if HAS_WEIGHT:
            y = y * w
        tl.store(Y + cols * 1, y.to(Y.dtype.element_ty), mask=mask)


medium

In _fused_add_rmsnorm_fwd, the kernel uses a loop for off in range(0, N, BLOCK_SIZE): and performs a redundant load from RESIDUAL in pass 2.

However, fused_add_rmsnorm_forward guarantees that N <= BLOCK_SIZE (otherwise it raises a RuntimeError). This means the loop always executes exactly once, and off is always 0.

We can completely eliminate the loop and avoid the redundant tl.load from RESIDUAL in pass 2 by keeping the computed residual s in registers. This simplifies the control flow and significantly reduces HBM read bandwidth, improving kernel performance.

    row = tl.program_id(0)
    X += row * x_stride0
    RESIDUAL += row * residual_stride0
    Y += row * y_stride0

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    r = tl.load(RESIDUAL + cols, mask=mask, other=0.0).to(tl.float32)

    # round the updated residual to the storage dtype first, then accumulate variance
    # from the rounded value so this matches the unfused (store; reload; rmsnorm) path
    # bit-for-bit instead of using the higher-precision fp32 sum.
    s = (r + x).to(RESIDUAL.dtype.element_ty)
    tl.store(RESIDUAL + cols, s, mask=mask)

    s_f32 = s.to(tl.float32)
    var = tl.sum(s_f32 * s_f32, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        y = s_f32 * rstd * w
    else:
        y = s_f32 * rstd
    tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)
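The single-iteration argument rests on the host wrapper choosing `BLOCK_SIZE` so that a whole row fits in one block. A plain-Python sketch of such a guard follows, modelled on the pattern used in Triton's layer-norm tutorial. The 64 KB cap, the function names, and the example row width are assumptions, not values read from `fused_add_rmsnorm_forward`.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_block_size(n_cols: int, elem_bytes: int, max_fused_bytes: int = 65536) -> int:
    """Choose a block covering the whole row, or refuse rows that are too wide."""
    max_fused = max_fused_bytes // elem_bytes
    block = min(max_fused, next_power_of_2(n_cols))
    if n_cols > block:
        raise RuntimeError("row too wide for a single-block kernel")
    return block


# Arbitrary example: a 6144-wide bf16 row (2 bytes per element).
block_size = pick_block_size(6144, 2)
```

With a guard of this kind, `range(0, N, BLOCK_SIZE)` always yields exactly one offset, so removing the loop is safe only as long as the guard stays in the wrapper.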

layout=layout,
USE_LIMIT_AND_ALPHA=limit is not None and alpha is not None,
USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(),
num_warps=1,


medium

The num_warps parameter is hardcoded to 1 in the _silu_and_mul_group_quant_kernel call. Since group_size is typically 128, using only 1 warp (32 threads) means each thread has to process 4 elements sequentially, which limits instruction-level parallelism and GPU occupancy.

Using num_warps=4 (or dynamically setting it based on group_size // 32) would allow each thread to process 1 element, leading to better coalesced memory access and higher performance.

Suggested change
- num_warps=1,
+ num_warps=max(1, group_size // 32),
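The arithmetic behind this suggestion can be written out in plain Python. The sketch also rounds the warp count down to a power of two, since Triton is generally expected to receive a power-of-two `num_warps`; that rounding and the cap of 8 warps are assumptions that go beyond the suggested one-liner.

```python
def warps_for_group(group_size: int, threads_per_warp: int = 32, max_warps: int = 8) -> int:
    """Largest power-of-two warp count that gives each thread at least one element."""
    warps = max(1, group_size // threads_per_warp)
    warps = 1 << (warps.bit_length() - 1)  # round down to a power of two
    return min(warps, max_warps)


def elements_per_thread(group_size: int, num_warps: int, threads_per_warp: int = 32) -> float:
    """How many elements of one quantization group each thread handles."""
    return group_size / (num_warps * threads_per_warp)
```

For `group_size = 128` this gives 4 warps and one element per thread, versus four elements per thread with `num_warps=1`. Whether that is actually faster for this kernel would need to be benchmarked.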
