Dsv4#1937

Open
HAOCHENYE wants to merge 58 commits into main from dsv4
@HAOCHENYE

Collaborator

No description provided.

HAOCHENYE added 30 commits June 29, 2026 06:12
`packaging` is imported by sphinx before `autodoc_mock_imports` takes effect,
so a module-level gate like `Version(torch.__version__) >= Version("2.9.1")`
was handing a `_MockObject` to the real `Version()` and raising `TypeError`.
Patch `_MockModule.__getattr__` and `_MockObject.__getattr__` to return a
parseable "0.0.0" placeholder for `__version__`. Also mock `causal_conv1d_cuda`
which has the same import-order problem.
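
A minimal sketch of the failure mode and the placeholder fix, using stand-in
classes rather than Sphinx's real mock types. The names ``PlainMock``,
``VersionedMock``, ``parse_version`` and ``gate_is_enabled`` are hypothetical
and exist only for this illustration; ``parse_version`` is a simplified
substitute for ``packaging``'s ``Version``.

```python
class PlainMock:
    """Stand-in for a mocked module: every attribute is another mock."""

    def __getattr__(self, name):
        return PlainMock()


class VersionedMock(PlainMock):
    """Same mock, but `__version__` resolves to a parseable placeholder."""

    def __getattr__(self, name):
        if name == "__version__":
            return "0.0.0"
        return VersionedMock()


def parse_version(value):
    """Simplified version parser: accepts only dotted-integer strings."""
    if not isinstance(value, str):
        raise TypeError(f"expected str, got {type(value).__name__}")
    return tuple(int(part) for part in value.split("."))


def gate_is_enabled(torch_like):
    # Mirrors a module-level gate of the form
    # Version(x.__version__) >= Version("2.9.1").
    return parse_version(torch_like.__version__) >= parse_version("2.9.1")
```

With ``PlainMock`` the gate raises ``TypeError`` because a mock object reaches
the parser; with ``VersionedMock`` the gate evaluates cleanly and is simply
disabled during the docs build.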

Add MTP loss context and SP support design notes under docs/design/.
…-Flash

* DeepSeekV4Config / DeepSeekV4 (xtuner/v1/model/moe/deepseek_v4.py): ties DSA
  attention, hash routing, sqrt-softplus NoAux routing, dual rope and
  Hyper-Connections together. _V4InnerBlock is an HCInnerBlock adapter that
  stashes per-forward DSA rope / seq_ctx / input_ids via set_context() so the
  HC wrapper's narrow attn_block(x)/ffn_block(x) protocol can drive
  MoEDecoderLayer's attention + MoE dispatch without widening HC's signature.
  _V4DecoderLayer bridges HCDecoderLayer's single-tensor output back to
  MoE._forward's (hidden, logits, weights) contract. _forward is replaced
  here because V4 carries hc_mult residual streams [B, S, hc_mult, D] and
  needs an hc_head reduce before the final norm + lm_head.

* MoEDecoderLayer (xtuner/v1/module/decoder_layer/moe_decoder_layer.py):
  attention_config type widened to MHA | MLA | GatedDeltaNet | DSAConfig and a
  new optional attention_module kwarg lets callers inject a pre-built DSA
  module. DSAConfig.build requires a per-layer compress_ratio which the
  generic build site cannot supply; injecting the pre-built module keeps the
  per-layer wiring in DeepSeekV4.build_layers (which has the context) and
  leaves the existing MLA/MHA/Gated build paths untouched.

* MoE (xtuner/v1/model/moe/moe.py): new _should_compute_aux_loss(layer_idx)
  hook (default True) gates aux_loss.accumulate so DeepSeekV4 can skip the
  hash-routed layers (HashRouter emits a [1] dummy logits placeholder that
  is incompatible with index_select(0, nonpad_indices)).

* RopeParametersConfig (xtuner/v1/module/rope/rope.py): exclude
  compress_rope_theta / compress_ratios from the rope_scaling field-name
  mapping so the backward-compat rope_scaling_cfg property stays valid for
  V4 configs (RopeScalingConfig is extra=forbid and has no dual-rope analog).

* Entry-point (xtuner/v1/model/__init__.py): register deepseek_v4 in
  get_model_config_from_hf with a config.json JSON-fallback for
  transformers releases that do not yet ship DeepseekV4Config.

* Tests (tests/model/test_deepseek_v4_moe.py): from_hf, to_hf_key_list
  coverage against the released safetensors index, entry-point dispatch, and
  hash-layer aux-loss gate. Decoder-layer parity vs the V4 inference
  reference is marked skipped because (a) the reference imports TileLang FP4
  kernels and (b) the bundled flash_attn lacks the sinks parameter required
  for V4 attn_sink.
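
The ``_should_compute_aux_loss(layer_idx)`` hook from the MoE bullet above can
be pictured as below. This is a sketch only: ``MoEBase``, ``HashRoutedMoE``,
``layers_with_aux_loss`` and ``num_hash_layers`` are hypothetical names, and
placing the hash-routed layers first is an assumption made for the example.

```python
class MoEBase:
    """Hypothetical stand-in for the MoE base model."""

    def __init__(self, num_layers):
        self.num_layers = num_layers

    def _should_compute_aux_loss(self, layer_idx):
        # Default: every layer contributes to the auxiliary loss.
        return True

    def layers_with_aux_loss(self):
        return [
            i for i in range(self.num_layers) if self._should_compute_aux_loss(i)
        ]


class HashRoutedMoE(MoEBase):
    """Hypothetical subclass whose first `num_hash_layers` layers are hash-routed."""

    def __init__(self, num_layers, num_hash_layers):
        super().__init__(num_layers)
        self.num_hash_layers = num_hash_layers

    def _should_compute_aux_loss(self, layer_idx):
        # Hash-routed layers emit only a dummy logits placeholder, so skip them.
        return layer_idx >= self.num_hash_layers
```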
DSA was strict-validating that position_embeddings cos/sin have shape
[B, S, qk_rope_head_dim], but XTuner's RotaryEmbedding (including the
DualRotaryEmbedding added in PR2) emits cos/sin at the full head_dim
following the rotate-half cat((freqs, freqs)) convention. This caused
an end-to-end forward through DeepSeekV4 to raise ValueError before
hitting any attention math.

Fix: relax the validation to accept any size >= qk_rope_head_dim, and
slice to the first qk_rope_head_dim entries when oversized. The first
half of the rotate-half layout is bit-identical to the V4 reference's
qk_rope_head_dim-sized cos/sin, so no numerical change for callers that
already pass pre-sliced cos/sin.
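
A plain-Python model of the relaxed check, with cos/sin represented as a list
of per-position rows instead of tensors. ``slice_rope_table`` is a hypothetical
helper name, not the function in ``dsa.py``.

```python
def slice_rope_table(table, qk_rope_head_dim):
    """Accept rows at least `qk_rope_head_dim` wide and keep the leading entries.

    `table` models a cos or sin tensor as a list of per-position rows.
    """
    width = len(table[0])
    if width < qk_rope_head_dim:
        raise ValueError(
            f"rope table width {width} is smaller than "
            f"qk_rope_head_dim={qk_rope_head_dim}"
        )
    if width == qk_rope_head_dim:
        return table  # already pre-sliced: returned unchanged
    return [row[:qk_rope_head_dim] for row in table]
```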

Discovered by /mnt/shared-storage-user/yehaochen/codespace/xtuner-dev2/.dev_scripts/deepseek_v4_reference/gpu_smoke_test.py
running a tiny DeepSeekV4 on H200 + py312-pt29 (torch 2.9.1+cu128).
Snapshot of in-progress V4 work. Combines wave-1 model scaffolding (already
exercised, parity-tested) with today's session-long memory/compile work whose
final pack=8192 OOM is not yet root-caused — committing so we don't lose
progress while debugging.

Wave-1 V4 scaffolding (parity-tested):
- DeepSeekV4Config / DeepSeekV4 in xtuner/v1/model/moe/deepseek_v4.py:
  V4_NON_EP_COMPILE_CFG / V4_EP_COMPILE_CFG with selective compile targets,
  _V4InnerBlock (HC-adapter over MoEDecoderLayer), _V4DecoderLayer, V4
  _micro_batch_forward override, _hc_head_reduce / _hc_head_reduce_compute
  split, to_hf_key_list keymap fix (was silent-skipping 36 V4 layer weights),
  aux_loss.finalize all-hash guard.
- MoE base patches in xtuner/v1/model/moe/moe.py for HashRouter compat:
  update_bias skips hash-routed layers (HashRouter has no
  e_score_correction_bias), MoEModelOutputs.tokens_per_expert_global Optional,
  post_micro_batch_forward None-safe.
- New xtuner/v1/module/attention/_flash_mla_sparse_attn.py with autograd
  Function wrappers around FlashMLA forward + cudnn DSA backward (Phase-2
  cudnn-frontend ≥ 1.24 contract). Includes cu12/cu13 atomic_add_fp32 ABI
  monkey-patch for the cudnn DSA backward subprocess.

Today's varlen refactor (all module-level pytest passes: 26/26 across DSA,
KVCompressor, Indexer, sparse_attn, HCBlock, HCSinkhorn — DSA per-sample
parity is rtol=0/atol=0, compressor is rtol=1e-5 because the full-pack wkv
GEMM picks a different cuBLAS algorithm than the per-sample version):
- DSA.forward (Phase 1): replaces the per-sample Python loop + .cpu().tolist()
  with a single varlen sparse_attn call. New helpers in
  xtuner/v1/module/attention/dsa.py: _build_window_topk_idxs_varlen,
  _build_compress_topk_idxs_varlen, _interleave_window_compressed_kv (lays kv
  out as per-sample [W_0,C_0,W_1,C_1,...]), _shift_topk_to_global. Removes
  the old per-sample _build_window_topk_idxs / _build_compress_topk_idxs.
- KVCompressor (Phase 2): per-sample loop → full-pack wkv/wgate + scatter
  into [total_c, ratio, coff*head_dim] chunk layout + sample-aware
  _overlap_transform_varlen. One .item() sync at the end (down from
  1 + num_samples in the original).
- Indexer (Phase 3): consolidate the two .cpu().tolist() syncs into one
  stacked D2H transfer. The per-sample score loop is intentionally kept:
  vectorising it would compute a dense [total_q, total_c] score matrix
  (about 137 GFLOP) instead of the block-diagonal 8 GFLOP we keep here.
- sparse_attn._build_horizon_mask: drop the cu_seq_lens[-1].item() check —
  was forcing a host sync on the hot path; the invariant is now enforced by
  construction in DSA.forward.

bf16 fp32-upcast removal (preserves bit-identity to within float epsilon):
- hc_block.hc_pre / hc_post: drop x.float() / residual.to(comb.dtype)
  256 MB transients. Square + mean-of-squares accumulate via dtype=fp32
  param; gate linear runs bf16 with cuBLAS's fp32 accumulator; matmul casts
  comb→bf16 instead of residual→fp32. Tiny mixes / matmul outputs upcast
  for Sinkhorn / final accumulate.
- DSA q-norm in dsa.py:324: same trick — bf16 q*q + dtype=fp32 mean.
- DeepSeekV4._hc_head_reduce_compute: same pattern as hc_pre.
- New hc_block._unshard_hc_params eager helper — hoists DTensor.full_tensor()
  out of the compiled hc_pre/hc_post region, eliminating 3 graph breaks per
  HC call.

Compile-machinery patches in xtuner/v1/utils/compile.py:
- _patch_sympy_mod_eval_negative_subs: replaces torch._sympy.functions.Mod.eval
  to return None instead of asserting p >= 0 when _random's [-1,1] real
  substitution violates the non-negative integer precondition. Required for
  DSA EP compile path (inductor's analyze_memory_coalescing triggers it).
- _patch_triton_max_block_xblock: raise TRITON_MAX_BLOCK['X'] from 4096 to
  32768 for V4's varlen pack=8192 path (inductor heuristic picks XBLOCK past
  the sanity cap; cap is software-side, not hardware). Subprocess-wide:
  paired with TORCHINDUCTOR_WORKER_START=fork in the launch script so
  inductor compile workers inherit the patched module state.

Unresolved bug — pack=8192 backward OOM:
- "Tried to allocate 130.03 GiB" single allocation in C++ autograd engine at
  step 1 backward, across every cfg combo we've tried (compile on/off, intra
  1/2, dispatcher deepep/none). 130 GiB tensor shape from inductor codegen
  earlier was (1, s_q, s_q + s_c, n_heads * head_dim) fp32. The same OOM
  shape happens with V4_EP_COMPILE_CFG fully emptied, so the offending
  allocation is in the eager path, not inductor codegen.
- Earlier 06:00 run with the same code under compile_cfg=False got to step 50
  before the log stopped writing (max_memory 114 GB, reserved 130.94 GB) —
  may have hit the same 130 GiB allocation at step 51 and been torchrun-OOM-
  killed before a Python traceback flushed.
- Next debug step: install a CUDA allocator OOM hook (or
  torch.cuda.memory._record_memory_history(enabled='all') with a snapshot
  dump on signal) so the specific aten op behind the 130 GiB request is
  identifiable.

Active config (ci/config/deepseek_v4_flash.py):
  num_hidden_layers=4, ep_size=4, dispatcher=deepep, pack_max_length=8192,
  recompute_ratio=1.0, intra_layer_micro_batch=1, attention.backend=cudnn,
  compile_cfg=False (temporarily), profile_step=4 profile_memory=True.
Native ``Indexer.forward`` materialises a ``[1, S_i, n_heads, T_i]`` fp32
score tensor per sample (4.6 GiB at pack=8192 / n_heads=64 / total_c=2048,
re-allocated under activation-checkpoint recompute). That was the leading
suspect for the 130 GiB backward OOM seen on the V4 4-layer toy. The
fused varlen kernel keeps per-head ``q · k`` partials in registers and
streams a running top-K, so the dense intermediate never lands in HBM.
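
The streaming idea can be sketched on the CPU with a bounded heap. This is not
the Triton kernel; ``streaming_topk`` and ``score_fn`` are hypothetical names,
and the heap stands in for whatever running buffer the kernel keeps. The point
is only that scores are consumed tile by tile and never stored in full.

```python
import heapq


def streaming_topk(score_fn, num_keys, k, block_c):
    """Return the indices of the k largest scores without storing all scores.

    `score_fn(i)` computes the score of key i on demand; keys are visited in
    tiles of `block_c`, and only a k-entry heap survives between tiles.
    """
    heap = []  # min-heap of (score, -idx); the root is the weakest kept entry
    for start in range(0, num_keys, block_c):
        for idx in range(start, min(start + block_c, num_keys)):
            entry = (score_fn(idx), -idx)  # -idx: on ties, the lower index wins
            if len(heap) < k:
                heapq.heappush(heap, entry)
            elif entry > heap[0]:
                heapq.heapreplace(heap, entry)
    # Strongest first; ties resolve to ascending index.
    return [-neg_idx for _, neg_idx in sorted(heap, reverse=True)]
```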

New file ``xtuner/v1/module/attention/_indexer_topk_triton.py``:
- Per-query Triton program (grid=(total_q,)). Loads kv in ``BLOCK_C`` tiles;
  for each tile, runs ``relu(q · k) * w_h`` across ``N_HEADS`` (constexpr-
  unrolled per-head load avoids ``tl.dot``'s 16-minimum tile constraint
  that breaks the n_heads=4 test fixture). Causal mask applied before
  top-K merge.
- Top-K maintenance is bit-packed: ``(score, idx) → uint64`` where the
  score-half uses the universal IEEE→sortable encoding (positive flips
  msb, negative inverts) so descending unsigned sort matches descending
  fp32 across mixed signs — required because the indexer's per-head
  weight has unconstrained sign, so ``relu(q·k) * w_h`` can be negative.
  The position-half is ``(T_PADDED - 1) - idx`` so score-ties break
  ascending-idx, matching ``torch.topk(largest=True)``'s semantics.
  Per c-tile: ``tl.cat(top_packed, new_packed, can_reorder=True)`` then
  ``tl.topk`` collapses back to K. Sentinel slots (running buffer
  initialised with -inf-packed) are detected by an equality check on
  the high 32 bits and rewritten to ``-1`` at the end.
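
The packing scheme described above can be reproduced on the host with
``struct``. This is a sketch of the encoding only, not the kernel code;
``float_to_sortable_u32``, ``pack_score_idx`` and ``unpack_idx`` are
hypothetical helper names.

```python
import struct


def float_to_sortable_u32(x):
    """Map an fp32 value to a uint32 whose unsigned order matches float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if bits & 0x80000000:
        return (~bits) & 0xFFFFFFFF  # negative: invert every bit
    return bits | 0x80000000  # non-negative: flip (set) the sign bit


def pack_score_idx(score, idx, t_padded):
    """Pack (score, idx) into one integer: score in the high 32 bits."""
    position_half = (t_padded - 1) - idx  # lower idx gets the larger low half
    return (float_to_sortable_u32(score) << 32) | position_half


def unpack_idx(packed, t_padded):
    """Recover the original index from the low 32 bits."""
    return (t_padded - 1) - (packed & 0xFFFFFFFF)
```

Sorting the packed integers in descending order then yields descending
scores across mixed signs, with equal scores ordered by ascending index.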

Integration:
- ``IndexerConfig.backend: Literal["native", "triton"] = "native"`` opts
  the dispatch in. Triton path wraps the kernel call in ``torch.no_grad``
  because Indexer's output indices flow into ``sparse_attn``'s ``gather``,
  which has no gradient through indices — there is no useful gradient to
  back-propagate, and skipping autograd avoids the otherwise-saved
  intermediate too.
- ``DeepSeekSparseAttention`` instantiates Indexer with
  ``backend="native"`` explicitly so the default V4 path remains
  unchanged; flipping to triton stays opt-in.

Tests (``tests/module/test_indexer.py``):
- New ``TestIndexerTritonParity`` class: builds matched native+triton
  Indexer pairs (same RNG init), runs both on identical single-sample
  and two-sample varlen inputs, then compares the scores at each
  backend's chosen indices (re-computed via the native math).
- Comparison is on the *sorted score values per row*, not on raw
  indices: cuBLAS-vs-Triton ``q · k`` reductions differ by O(ULP), so
  near-tie indices flip between backends without changing the semantic
  picks (set of effective KV positions). Tolerance ``atol=0.05, rtol=0``
  covers the observed ULP-level divergence (max 5e-3 in practice) with
  about one order of magnitude of headroom.
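
A sketch of the comparison strategy, on one row of reference scores. The
helper name ``scores_match`` is hypothetical and the real test operates on
tensors; the sketch only shows why comparing sorted score values tolerates a
near-tie index flip while still catching a genuinely different pick.

```python
def scores_match(scores, picks_a, picks_b, atol):
    """Compare two backends by the sorted scores at their chosen indices.

    `scores` is one row of reference scores; `picks_a` / `picks_b` are the
    index sets each backend selected for that row.
    """
    vals_a = sorted((scores[i] for i in picks_a), reverse=True)
    vals_b = sorted((scores[i] for i in picks_b), reverse=True)
    return len(vals_a) == len(vals_b) and all(
        abs(a - b) <= atol for a, b in zip(vals_a, vals_b)
    )
```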

Bench (V4 production dims, pack=8192 single-sample, bf16, H200):
  native  fwd 11.0 ms  peak  4.57 GB
  triton  fwd 172  ms  peak  0.76 GB
The 6× memory win is the goal here. The ~16× speed regression comes
from the kernel doing scalar per-head FMAs (the grid is total_q-wide
with BLOCK_Q=1), which leaves the tensor cores unused. Beating native
needs BLOCK_Q ≥ 16 plus ``tl.dot``; tracked as a follow-up rewrite.

Also bumps the test_kv_compressor parity tolerance to ``rtol=1e-5,
atol=1e-5`` from ``rtol=0/atol=0`` — same root cause (full-pack vs per-
sample wkv GEMM picks different cuBLAS algorithms), no cross-sample
contamination.
Beats native on both axes at V4 production dims (pack=8192, n_heads=64,
head_dim=128, index_topk=512, bf16, H200):

  backend   fwd ms   peak MB
  native      7.39    4679
  triton      4.58     774    ← 1.6× faster, 6× less HBM

Three changes vs the previous insertion-replace + tl.cat+tl.topk merge
(which had been 23× slower):

1. ``q · k`` via ``tl.dot`` (tensor cores).
   Hoist ``q [N_HEADS, HEAD_DIM]`` and per-head ``w [N_HEADS]`` to once-
   per-program loads outside the c-loop, then on each c-tile do a single
   ``tl.dot(q, kv.T)`` for the per-head dot, ReLU, head-weighted sum →
   ``[BLOCK_C]`` score. Inputs are bf16 (q+kv) with fp32 accumulator —
   required for SMEM and matches Hopper's tensor-core path. Requires
   ``N_HEADS >= 16`` (Triton tile-mma floor); enforced in the wrapper
   with a clear error for the n_heads=4 path.

2. Single-pass top-K via a per-query ``T_PADDED``-wide score buffer.
   Drop the per-c-iter ``tl.cat(top_packed, new_packed)`` + ``tl.topk(K)``
   merge — both because ``tl.cat`` only takes equal-size inputs (forcing
   ``BLOCK_C = K`` and an unmanageable SMEM footprint) and because the
   per-iter sort is ~10× more work than a single final sort. Instead
   keep ``all_packed [T_PADDED]`` in registers/SMEM seeded with
   -inf-packed; each c-tile bit-packs ``(score, idx)`` and writes the
   ``BLOCK_C`` new entries into the right offset via
   ``tl.gather + tl.where``. After the loop one ``tl.topk(all_packed, K)``
   produces the result. Bitonic sort over ``T_PADDED=2048`` is ~11
   stages and amortises across the c-axis cheaply.

3. ``num_stages=1`` on the kernel launch.
   Default pipelining doubles the ``kv_tile`` SMEM footprint, which at
   ``BLOCK_C=256`` (the new wrapper default) tips us over Hopper's
   232 KB SMEM ceiling. One stage costs a small bit of HBM-load /
   compute overlap but the score path is fp-throughput-bound, not
   memory-bound, so the tradeoff is net positive.

Parity tests bump the fixture's ``index_n_heads`` from 4 to 16 to satisfy
the tensor-core minimum; the score-equivalence comparison (sorted scores
at picked indices, ``atol=0.05``) is unchanged and covers the same
"ULP-level reduction-order divergence on near-ties" semantics as before.

Tracking task #40 closes here.
…KVCompressor

The step-5 profiler trace under EP=4 + compile_cfg=True showed 50% of GPU
time in fragmented elementwise / "other" buckets (aten::mul × 1956,
aten::sum × 2526, aten::copy_ × 1813 per step) — almost all of HC / DSA /
compressor / FFN was running eager because ``V4_EP_COMPILE_CFG`` was set
to ``{}``.

That empty cfg was a workaround for a recompute-time 130 GiB fp32
allocation "across DSA, hc_pre/post, and shared/expert paths". The shared
upstream cause was the native Indexer materialising a ``[1, S_i, n_heads,
T_i]`` fp32 score tensor inside the autograd graph; under varlen +
``dynamic=True`` compile + activation checkpoint recompute, that ~4-5 GB
per-layer tensor multiplied across shape-variant retraces until it
crowded out everything else. The Indexer is now invoked under
``torch.no_grad()`` from the Triton tensor-core kernel at
``_indexer_topk_triton.py`` (1.6× faster than the einsum loop on V4 dims),
so that fp32 tensor never enters autograd in the first place.

Three coupled changes:

1. ``DSAConfig.indexer_backend`` (new field, default ``"triton"``). The
   choice was previously hard-coded ``"native"`` at the construction site
   in ``DSA.__init__``. Production V4 (``index_n_heads=64``) now picks
   up the fast no-autograd path automatically; configs with
   ``index_n_heads < 16`` (below the Triton tensor-core tile floor) must
   pin ``indexer_backend="native"`` explicitly — the kernel surfaces a
   clear ValueError otherwise.

2. ``V4_EP_COMPILE_CFG`` rebuilt from ``MOE_EP_COMPILE_CFG | _V4_LAYER_TARGETS``.
   The V4-specific layer targets (hc_pre, hc_post, attn_block,
   _ffn_pre_compute, _ffn_post_compute, _hc_head_reduce_compute,
   DSA.forward) are factored out into a shared ``_V4_LAYER_TARGETS`` dict
   used by both EP and non-EP variants. ``MoEDecoderLayer.forward`` stays
   excluded under EP (already dropped in the parent ``MOE_EP_COMPILE_CFG``
   because it enters the deepep all2all dispatcher).

3. ``KVCompressor.forward`` added as a compile target. The compressor's
   scatter_index_put + softmax + sum + RMSNorm chain is exactly the
   small-op storm that ``aten::mul`` / ``aten::sum`` / ``aten::copy_`` in
   the trace was counting. The ``int(cu_seq_lens_out[-1].item())`` D2H
   sync inside graph-breaks once; ``fullgraph=False`` (the existing
   ``_LITE`` option) accepts that and still fuses the surrounding ops.

The DSA test fixture in ``tests/module/test_dsa.py`` uses
``index_n_heads=4`` (below the Triton kernel's ≥16 floor) so it pins
``indexer_backend="native"`` explicitly. All 6 DSA + 10 Indexer +
KVCompressor tests pass on this change.
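
The backend rule can be sketched as a small config check. ``IndexerSettings``
and ``TRITON_MIN_HEADS`` are hypothetical names, and the sketch raises at
construction for simplicity, whereas the commit says the kernel wrapper is
what surfaces the ValueError.

```python
from dataclasses import dataclass

TRITON_MIN_HEADS = 16  # tensor-core tile floor quoted in the text


@dataclass
class IndexerSettings:
    """Hypothetical stand-in for the indexer-related config fields."""

    index_n_heads: int
    indexer_backend: str = "triton"

    def __post_init__(self):
        if self.indexer_backend not in ("native", "triton"):
            raise ValueError(f"unknown indexer_backend: {self.indexer_backend!r}")
        if self.indexer_backend == "triton" and self.index_n_heads < TRITON_MIN_HEADS:
            raise ValueError(
                f"index_n_heads={self.index_n_heads} is below the Triton floor "
                f"of {TRITON_MIN_HEADS}; set indexer_backend='native'"
            )
```

A fixture with ``index_n_heads=4`` therefore has to pass
``indexer_backend="native"``, as the DSA test does.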
The Indexer's output is fed directly into ``sparse_attn``'s ``gather`` as
the *index* argument. ``gather``'s backward is zero w.r.t. its index
input, so no gradient ever reaches anything inside the Indexer in the
current V4 design — the projections, the internal KVCompressor, the
score path, every saved tensor along the way is autograd state that
nothing will ever consume.

Wrap the whole ``self.indexer(...)`` call site in ``torch.no_grad`` so
that state is not allocated in the first place. This:

* Drops the autograd-saved tensors for ``wq_b`` / ``weights_proj`` /
  the internal ``KVCompressor``'s ``wkv`` / ``wgate`` Linears (one
  saved input per Linear, all the way back to ``q_lowrank`` /
  ``hidden_states``).
* Drops the compressor's per-step scatter buffers (``kv_chunks_flat``,
  ``score_chunks_flat``) — they only need to live for the duration of
  the indexer call and can be freed immediately, not held across the
  rest of the layer's forward.
* Gives the surrounding ``DSA.forward`` compile region one clean
  ``no_grad`` subregion to fold into; inductor can emit a single graph
  for the eager indexer path instead of preserving backward-state
  bookkeeping.

The previous internal ``with torch.no_grad()`` around just the triton
kernel call in ``Indexer.forward`` is now redundant (the outer wrap
subsumes it) and is removed. The Indexer's contract is now: callers
that want gradient through its inputs (none in the current codebase)
must call it outside a no_grad block AND have a use for the gradient
through ``gather``'s indices (there is none today).
The previous V4 layer was split across three classes that existed only to
work around an over-narrow abstraction:

  MoEDecoderLayer (parent)
       └── _V4InnerBlock      ← inherited but never used the parent's forward;
                                added attn_block / ffn_block / set_context
                                to fit an HCInnerBlock protocol
       └── HCDecoderLayer     ← "generic" wrapper with V4 as its only user;
                                ``attn_block(x) -> x`` couldn't carry DSA's
                                position_embeddings / seq_ctx / input_ids,
                                so set_context shoved them on the inner block
                                as mutable state
       └── _V4DecoderLayer    ← bridge that called set_context, called the
                                HC wrapper, then read the stashed router
                                results back off the inner block

Three classes, one ``set_context`` side-channel, one ``_last_router_results``
stash-and-grab, one ``assert hc_layer.inner is inner`` to police a
dual-registration invariant — to express what is conceptually one decoder
layer (HC residual mix + DSA + MoE FFN). Per CLAUDE.md rule #1: "do not
justify patch-on-patch, layered, spaghetti-like implementations in the name
of backward-compatibility protection". The "generic HC wrapper" abstraction
had exactly one user.

Replace all three with :class:`V4DecoderLayer` that owns every submodule and
parameter directly:

  V4DecoderLayer.forward(hidden_states,
                         *, position_embeddings,
                            position_embeddings_compressed,
                            seq_ctx, input_ids)
      -> (hidden_states_out, router_logits, router_weights)

All inputs flow through arguments; router results flow out through the
tuple; no hidden state on ``self`` between calls. The compile-target
sub-methods (``_attn_compute`` / ``_ffn_pre_compute`` / ``_ffn_post_compute``
/ ``_shared_experts_forward``) stay as separate methods so
``V4_EP_COMPILE_CFG`` can target them individually with the same boundary
the previous ``_V4InnerBlock`` exposed.

Knock-on changes:

* ``hc_block.py`` keeps ``hc_pre`` / ``hc_post`` / ``HCWrapperConfig`` /
  ``_unshard_hc_params`` (the math + DTensor helper) but deletes the
  ``HCDecoderLayer`` class and ``HCInnerBlock`` protocol.
* ``decoder_layer/__init__.py`` re-exports drop ``HCDecoderLayer`` and
  ``HCInnerBlock``.
* ``DeepSeekV4._build_one_layer`` constructs ``V4DecoderLayer`` directly
  (was: ``_V4InnerBlock`` + ``HCDecoderLayer`` + ``_V4DecoderLayer``).
* ``_translate_layer_tail`` no longer strips ``hc_layer.`` or
  ``hc_layer.inner.`` prefixes — params now arrive at the flat
  ``layers.L.hc_attn_*`` / ``layers.L.input_layernorm.weight`` /
  ``layers.L.experts.*`` layout.
* ``V4_EP_COMPILE_CFG`` / ``V4_NON_EP_COMPILE_CFG`` point at
  ``V4DecoderLayer._attn_compute`` / ``_ffn_pre_compute`` /
  ``_ffn_post_compute`` (was: ``_V4InnerBlock.attn_block`` / equivalents).
* ``tests/module/test_hc_block.py`` no longer exercises a (now-deleted)
  wrapper class; it tests ``hc_pre`` / ``hc_post`` directly with the same
  closed-form assertions on the degenerate-init path.

Verified: 51 module + V4 model tests pass (the two ``to_hf_key_list_coverage``
checks that need a local BF16 checkpoint are skipped). Sample of the new
flat param layout under a 2-layer toy model::

    hc_head_fn                                -> hc_head_fn
    layers.0.hc_attn_fn                       -> layers.0.hc_attn_fn
    layers.0.input_layernorm.weight           -> layers.0.attn_norm.weight
    layers.0.post_attention_layernorm.weight  -> layers.0.ffn_norm.weight
    layers.0.self_attn.wq_a.weight            -> layers.0.attn.wq_a.weight
    layers.0.experts.fused_w1w3.weight        -> [N × layers.0.ffn.experts.i.{w1,w3}.weight]
    layers.0.shared_experts.gate_proj.weight  -> layers.0.ffn.shared_experts.w1.weight

No HF-side semantics change — the HF key bridge produces the same target
names as before; only the XTuner-side path stripping is simpler.
… to 32 GiB

The native ``sparse_attn`` body used the standard ``expand``-into-``gather``
idiom to select the per-query top-k KV rows:

    gather_idx  = safe_idxs.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    kv_gathered = torch.gather(
        kv_f.unsqueeze(1).expand(-1, total_tokens, -1, -1),   # view
        2,
        gather_idx,
    )

In eager this is free — the ``kv_f.unsqueeze(1).expand(...)`` returns a
``[1, S, T, D]`` strided view of the underlying ``[1, T, D]`` storage, and
``gather`` reads through it without materialisation. Under ``torch.compile``
inductor can't fuse the expand into the gather's index codegen and
materialises the full ``[1, S, T, D]`` fp32 tensor before gathering. At V4
production dims (pack=8192, T≈2048, D=512) that is a single 32 GiB allocation
per ``sparse_attn`` call. Both the forward and the backward (which expands
again for the scatter-add transposed gather) trip it, and ``DSA.forward`` is
on the compile target list — so under EP4 + recompute the rank0 memory
snapshot showed a 34.94 GB block stuck behind a ``cudnn_sparse_attn`` →
``_native_sparse_attn`` fallback path during the backward.

Rewrite as a plain advanced index:

    kv_gathered = kv_f.squeeze(0)[safe_idxs.squeeze(0)].unsqueeze(0)

Same semantics (``kv[idx[s, j], d]`` per output position), no expand pattern,
no intermediate ``[1, S, T, D]`` materialisation. Inductor compiles this to a
single indexed-load kernel.
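
The size quoted above and the equivalence of the two formulations can be
checked with a plain-Python model (nested lists in place of tensors, batch
dimension dropped). The function names are hypothetical and the model says
nothing about what inductor actually emits.

```python
def expanded_view_bytes(s, t, d, bytes_per_elem=4):
    """Size of a fully materialised [1, S, T, D] tensor."""
    return s * t * d * bytes_per_elem


def gather_via_expand(kv, idxs):
    """Model of expand + gather: build the [S, T, D] copy, then select rows."""
    expanded = [[list(row) for row in kv] for _ in idxs]  # S copies of [T, D]
    return [[expanded[s][j] for j in idx_row] for s, idx_row in enumerate(idxs)]


def gather_via_index(kv, idxs):
    """Model of kv[idxs]: select rows directly, with no per-query copy of kv."""
    return [[kv[j] for j in idx_row] for idx_row in idxs]
```

``expanded_view_bytes(8192, 2048, 512)`` evaluates to exactly 32 GiB, which is
the allocation the rewrite avoids.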

Verified: 16 module tests pass.
…e branch

Before: ``cudnn_sparse_attn`` / ``flash_mla_sparse_attn`` each had a runtime
guard at the top::

    if not _flash_mla_topk_ok(topk_idxs.size(-1)):
        return _native_sparse_attn(...)
    return _CudnnSparseAttnFn.apply(...)

This is fine in eager — ``topk_idxs.size(-1)`` is a small Python int and the
branch resolves cleanly. Under ``torch.compile`` with ``dynamic=True`` (the
V4 default), the size becomes a symbolic int whose value depends on module
attributes ``self.sliding_window`` and ``self.index_topk`` that dynamo
doesn't always constant-fold through ``cat`` and ``%``. The compiled
``DSA.forward`` then either (a) bakes the native-fallback branch into the
graph even when ``backend="cudnn"`` was requested, or (b) compiles both
branches with a runtime selector. Either way the native path ends up in the
inductor codegen, and its ``kv.unsqueeze(1).expand(-1, S, -1, -1)`` +
``gather`` chain materialises a ``[1, S, T, D]`` fp32 tensor — ~32 GiB at V4
production dims (pack=8192, T≈2048, D=512). That was the 34.94 GB
``cudnn_sparse_attn → _native_sparse_attn`` allocation showing up in the
backward-recompute snapshot even though the user had set
``moe_cfg.attention.backend = "cudnn"``.

Fix: make the backend a static, per-layer decision in ``DSA.__init__``. We
know ``sliding_window`` and ``index_topk`` at construction time, so we know
each layer's ``topk_max`` and can compute ``_flash_mla_topk_ok(topk_max)``
once. ``DSA._resolve_sparse_attn_fn`` returns one of three function pointers
(``_cudnn_sparse_attn_apply``, ``_flash_mla_sparse_attn_apply``,
``_native_sparse_attn``), stored as ``self._sparse_attn_fn``. The forward
path is now a single function-pointer call with no branch — inductor sees
only the requested backend, the native expand+gather codegen is never
emitted unless the user explicitly chose ``backend="native"``.

Two thin branch-free wrappers added in ``_flash_mla_sparse_attn.py``
(``flash_mla_sparse_attn_apply``, ``cudnn_sparse_attn_apply``) just call
``.apply()`` on the underlying autograd Functions. The existing
``flash_mla_sparse_attn`` and ``cudnn_sparse_attn`` wrappers stay (still
used outside DSA / as the user-facing "auto-fallback" helpers) but are no
longer on the compile path.

If a layer's ``sliding_window + index_topk`` violates the FlashMLA 128-
alignment requirement, we emit a ``logger.warning`` at construction time
explaining which layer / why and select native — instead of silently
falling back inside the compiled forward where the user can't tell.
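
A sketch of the construction-time dispatch, with stub functions in place of
the real backends. ``DispatchSketch``, the stubs and the logger name are
hypothetical, and reducing ``_flash_mla_topk_ok`` to a modulo-128 check is an
assumption based on the alignment requirement stated above.

```python
import logging

logger = logging.getLogger("dsa_dispatch_sketch")
FLASH_MLA_TOPK_ALIGN = 128  # alignment quoted in the text


def native_sparse_attn(*args):
    return "native"


def cudnn_sparse_attn_apply(*args):
    return "cudnn"


class DispatchSketch:
    """Hypothetical layer that picks its sparse-attention function once."""

    def __init__(self, layer_idx, sliding_window, index_topk, backend="cudnn"):
        topk_max = sliding_window + index_topk
        if backend == "native":
            self._sparse_attn_fn = native_sparse_attn
        elif topk_max % FLASH_MLA_TOPK_ALIGN == 0:
            self._sparse_attn_fn = cudnn_sparse_attn_apply
        else:
            logger.warning(
                "layer %d: topk_max=%d is not a multiple of %d; using native",
                layer_idx, topk_max, FLASH_MLA_TOPK_ALIGN,
            )
            self._sparse_attn_fn = native_sparse_attn

    def forward(self, *args):
        # No size-dependent branch here: the choice was made in __init__.
        return self._sparse_attn_fn(*args)
```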

Verified: 19 module + 4 V4 model tests pass.
…_mla/cudnn

The previous static dispatch fix assumed all compress_ratio>0 layers use
``sliding_window + index_topk`` as their combined topk width. That is
correct for CSA (compress_ratio=4 with Indexer) but wrong for HCA
(compress_ratio=128 with deterministic positional top-K). HCA's
compress_topk dim is ``total_tokens // compress_ratio + 1`` —
pack-dependent and never 128-aligned at any practical pack length:

  pack=4096  →  32+1 =  33  → topk_max = 128 +  33 = 161  (161 % 128 = 33)
  pack=8192  →  64+1 =  65  → topk_max = 128 +  65 = 193  (193 % 128 = 65)
  pack=16384 → 128+1 = 129  → topk_max = 128 + 129 = 257  (257 % 128 = 1)

FlashMLA's prefill kernel asserts ``topk % (2*B_TOPK) == 0`` (with ``2*B_TOPK`` = 128).
Selecting ``cudnn_sparse_attn_apply`` for an HCA layer therefore trips:

    RuntimeError: Assertion error (phase1.cuh:577):
        Assertion `params.topk % (2*B_TOPK) == 0` failed.

That assertion is exactly why the original ``cudnn_sparse_attn`` wrapper
had a runtime ``if not _flash_mla_topk_ok(topk_idxs.size(-1))`` fallback to
native. The static dispatch needs to encode the same decision but at
construction time, where we don't know pack — so for HCA we treat
FlashMLA as structurally incompatible regardless of pack and route to
native unconditionally. CSA (ratio=4) and sliding-only (ratio=0) layers
keep their static FlashMLA / cudnn path.
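
The HCA widths tabulated above follow from two lines of arithmetic. The
helper names are hypothetical; ``sliding_window=128`` and
``compress_ratio=128`` are the values used in the table.

```python
def hca_topk_max(pack, sliding_window=128, compress_ratio=128):
    """Combined top-k width of an HCA layer for a given pack length."""
    compress_topk = pack // compress_ratio + 1
    return sliding_window + compress_topk


def flash_mla_aligned(topk, alignment=128):
    """True when `topk` satisfies the 128-alignment requirement."""
    return topk % alignment == 0
```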

The warning surfaces *which* layer fell back and *why* so the user can
tell whether it's a 128-alignment violation (potentially fixable by
tweaking sliding_window / index_topk) or the structural HCA case (not
fixable without changing the model topology).

Verified: 10 DSA + 4 V4 model tests pass.
…tor can fuse

The HC residual-mix term ``comb @ residual`` at hc_post was a batched gemm
with shapes::

    comb     [B, S, H=4, H=4]
    residual [B, S, H=4, D=4096]
    out      [B, S, H=4, D=4096]

i.e. ``[B*S]`` independent (4×4) × (4×4096) matmuls. K=4 is below Hopper's
wgmma tile floor (K=16), so ``torch.matmul`` / ``torch.einsum`` both lower
to ``aten.bmm`` which inductor delegates to cuBLAS via ``extern_kernels.bmm``
rather than codegen a triton kernel. cuBLAS at K=4 takes the CUDA-core
fallback — bandwidth-bound, no tensor cores — and at pack=16384 cost
~3 ms per call × 86 calls per step ≈ 250 ms/step wasted.

Rewrite as ``broadcast-multiply + reduce-sum``. These are pointwise +
reduction primitives that inductor's pattern-matcher fuses aggressively
into a single triton reduction kernel: the multiply runs in registers,
the H_in=4 reduction is an unrolled loop, and the trailing
``post * x + mixed`` epilogue joins the same kernel. The
``[B, S, H_out, H_in, D]`` 5D intermediate that the expression implies in
eager mode never materialises under compile — it lives in registers as
part of the fused reduction.
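
The algebra behind the rewrite, for a single token, as a plain-Python model
with nested lists. Function names are hypothetical. The explicit ``products``
list corresponds to the 5D intermediate that eager mode would materialise; the
sketch does not model inductor's fusion.

```python
def mix_matmul(comb, residual):
    """Reference: out[o][k] = sum_i comb[o][i] * residual[i][k] as a matmul."""
    h_out, h_in, d = len(comb), len(residual), len(residual[0])
    return [
        [sum(comb[o][i] * residual[i][k] for i in range(h_in)) for k in range(d)]
        for o in range(h_out)
    ]


def mix_broadcast_reduce(comb, residual):
    """Same result via broadcast-multiply to [H_out, H_in, D], then sum over H_in."""
    h_out, h_in, d = len(comb), len(residual), len(residual[0])
    products = [
        [[comb[o][i] * residual[i][k] for k in range(d)] for i in range(h_in)]
        for o in range(h_out)
    ]
    return [
        [sum(products[o][i][k] for i in range(h_in)) for k in range(d)]
        for o in range(h_out)
    ]
```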

Eager note: running this function eagerly at V4 production dims WILL
materialise an 8 GiB transient. It is gated by an explicit warning in
the docstring. The function is registered in ``_V4_LAYER_TARGETS``
compile cfg so the EP + non-EP paths both compile it, and the
``hc_mult=1`` degenerate path in ``V4DecoderLayer`` short-circuits past
``hc_post`` so the existing ``test_hc_mult_1_equals_plain_residual``
test fixture (which would otherwise hit eager) is unaffected.

Verified: 3 hc_block tests pass.
HuggingFace's ``transformers>=5.9.0`` ships a native ``deepseek_v4`` module
(``DeepseekV4Model`` / ``DeepseekV4DecoderLayer`` / ``DeepseekV4Attention``
/ ``DeepseekV4HyperConnection``) — a clean pure-PyTorch reference that does
not depend on TileLang or FlashMLA. This commit adds a parity test that:

1. Builds matched small HF + XTuner V4 configs (vocab=256, hidden=64,
   moe_inter=32, n_heads=8, head_dim=32, hc_mult=4, n_routed_experts=4,
   n_shared_experts=1, sliding_window=32, qk_rope_head_dim=2).
2. Random-inits both, copies every HF parameter into the XTuner
   ``V4DecoderLayer`` via an explicit name + layout map (covers HC params,
   norms, the Q/KV/O LoRA chain + attn_sink, the HCA and CSA compressors,
   the nested Indexer, and the MoE router + experts + shared experts).
3. Runs identical inputs through both layers and ``torch.testing.assert_close``s
   the outputs for four cases:
   - sliding-only (compress_ratio=0)
   - CSA (compress_ratio=4 + Indexer)
   - HCA (compress_ratio=128 + deterministic positional gather)
   - hash-routed sliding (covers ``HashRouter`` vs HF's ``DeepseekV4HashRouter``)

The HF→XTuner weight bridge handles three layout differences:

* HF nests the Indexer inside ``self_attn.compressor`` (CSA only); XTuner
  has ``self_attn.compressor`` and ``self_attn.indexer`` as siblings.
* HF stores routed-expert weights as 3D ``nn.Parameter`` shaped
  ``[n_experts, 2*intermediate, hidden]`` / ``[n_experts, hidden, intermediate]``;
  XTuner stores them flattened to ``[E*2*I, H]`` / ``[E*H, I]`` via
  ``build_grouped_linear``. The memory layout is identical once the leading two dims are flattened.
* HF's ``DeepseekV4RotaryEmbedding`` emits ``[..., qk_rope_head_dim/2]``
  half-dim cos/sin for the interleaved-rotate-half convention; XTuner's
  DSA expects ``[..., qk_rope_head_dim]`` cat-style cos/sin for cat-style
  rotate-half. ``_hf_rotary_to_xtuner_format`` rebuilds cos/sin in XTuner
  layout from HF's per-layer-type ``inv_freq`` so both sides see the same
  underlying angles.

Known gap. ``qk_rope_head_dim`` is pinned to 2 in this test because the
HF interleaved-pair rotation and XTuner cat-style rotation are
mathematically inequivalent for ``qk_rope_head_dim >= 4`` — different
element pairings (HF: ``(x[2i], x[2i+1])``, XTuner: ``(x[i], x[i+D/2])``).
They degenerate to the same single ``(x[0], x[1])`` pair when D=2 only.
Production V4-Flash uses qk_rope_head_dim=64; the test will not extend
to that until either side adopts the other's convention or the HF→XTuner
load path adds a per-head rope-suffix channel permutation.
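
The pairing difference described above can be checked with a small pure-Python sketch. It is not the XTuner or HF implementation: ``rope_interleaved``, ``rope_cat`` and ``deinterleave`` are hypothetical helpers written for illustration, operating on a single rope vector with one angle per pair. Under those assumptions, the two conventions agree at D=2, disagree at D=4, and become equivalent again once the channels are permuted (evens first, then odds), which is the kind of load-time permutation the paragraph mentions.

```python
import math


def rope_interleaved(x, thetas):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by thetas[i]."""
    out = [0.0] * len(x)
    for i, theta in enumerate(thetas):
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def rope_cat(x, thetas):
    """Rotate split-half pairs (x[i], x[i + D/2]) by thetas[i]."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i, theta in enumerate(thetas):
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


def deinterleave(x):
    """Hypothetical channel permutation: even channels first, then odd ones."""
    return x[0::2] + x[1::2]


def close(u, v, tol=1e-12):
    return all(abs(p - q) <= tol for p, q in zip(u, v))


x2, x4 = [1.0, 2.0], [1.0, 2.0, 3.0, 4.0]
same_at_d2 = close(rope_interleaved(x2, [0.3]), rope_cat(x2, [0.3]))
same_at_d4 = close(rope_interleaved(x4, [0.3, 0.7]), rope_cat(x4, [0.3, 0.7]))
# Permuting the input channels maps one convention onto the other.
permuted_matches = close(
    rope_cat(deinterleave(x4), [0.3, 0.7]),
    deinterleave(rope_interleaved(x4, [0.3, 0.7])),
)
```

Here ``same_at_d2`` and ``permuted_matches`` hold while ``same_at_d4`` does not, which is why pinning ``qk_rope_head_dim`` to 2 hides the mismatch.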

Status. The test scaffold runs end-to-end (config build, weight copy,
forward execution, output comparison). With qk_rope_head_dim=2 the
remaining numerical disagreement is ~12% absolute / ~50% mismatched
elements — *real divergence beyond rope*, not floating-point noise.
Investigation of the divergence (likely in the eager-attention vs
sparse_attn mask construction, the NoAux router's grouping math, or
the MoE expert dispatch) is a follow-up.
…ention

HF's ``eager_attention_forward`` treats ``attention_mask=None`` as "no mask",
which results in full (non-causal) attention. XTuner's ``sparse_attn`` is
always causal + sliding-window via its constructed ``topk_idxs``. So we must
hand HF an explicit ``[1, 1, S, S]`` additive mask with ``-inf`` outside the
causal+window cone.

The mask is built directly via positional arithmetic rather than through
``create_sliding_window_causal_mask``, because the helper needs a real
``Cache`` and ``inputs_embeds``-typed input, which this single-layer test
does not construct.
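
A minimal sketch of the positional arithmetic, in plain Python rather than tensors. The visibility rule used here (``j <= i`` and ``i - j < window``) is an assumption about what "causal+window cone" means, and ``sliding_causal_mask`` is a hypothetical name; the real test builds a ``[1, 1, S, S]`` tensor.

```python
NEG_INF = float("-inf")


def sliding_causal_mask(seq_len, window):
    """Additive [S, S] mask: 0.0 where query i may attend key j, -inf elsewhere.

    Assumed rule: j <= i (causal) and i - j < window (sliding window).
    """
    return [
        [0.0 if (j <= i and i - j < window) else NEG_INF for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_causal_mask(seq_len=5, window=2)
# For each query position, the key positions left unmasked.
visible = [[j for j, v in enumerate(row) if v == 0.0] for row in mask]
```

Adding this mask to the attention scores before softmax removes every key outside the cone, since ``exp(-inf)`` is zero.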

Effect: mismatched-element count goes from 49% → 45% in the
sliding-attention parity test — the masking fix is correct but does not
account for all of the divergence; further isolation by sub-step (norms,
attention, HC mix, FFN) is needed.

XTuner V4 was using cat-style rotate-half (pairs ``(x[i], x[i + D/2])``)
while HF's ``DeepseekV4RotaryEmbedding`` and the V4-Flash inference
reference both use the interleaved / complex-pair convention (pairs
``(x[2i], x[2i+1])``). The two are NOT bit-equivalent rotations on the
same input for ``qk_rope_head_dim >= 4`` — same θ angles applied to
different element pairs. Effect: loading a V4 BF16 checkpoint into
XTuner produced Q/K tensors whose rope-suffix channels were permuted
relative to what V4 was trained with, breaking downstream attention vs
the reference.

This commit aligns XTuner's V4 rope path with HF without any weight
permutation at load time. Scope is V4-only — ``RotaryEmbedding`` used by
LLaMA / Qwen / V3 etc. is unchanged.

Three changes:

1. ``xtuner/v1/module/rope/rope.py::DualRotaryEmbedding.forward``
   No longer emits ``torch.cat((freqs, freqs), dim=-1)``-doubled cos/sin.
   Returns half-dim ``[B, S, qk_rope_head_dim/2]`` instead — one θ per
   adjacent rope-dim pair, matching HF's
   ``DeepseekV4RotaryEmbedding.forward``.

2. ``xtuner/v1/module/attention/dsa.py::_apply_rope`` /
   ``_apply_rope_inverse`` / ``_rotate_half``
   Replace cat-style pairing with adjacent-pair pairing. ``_apply_rope``
   internally ``repeat_interleave``s the half-dim cos/sin to full
   ``qk_rope_head_dim`` so the broadcast multiply lines up. The
   ``_rotate_half`` helper now does
   ``stack([-x_odd, x_even], dim=-1).flatten(-2)`` instead of
   ``cat([-x2, x1], dim=-1)`` over halves. The ``cos.size(-1) <
   qk_rope_head_dim`` validation in ``DSA.forward`` now checks
   ``qk_rope_head_dim // 2`` (the half-dim contract).

3. ``DSA.forward`` no longer slices ``position_embeddings_compressed``
   down to ``qk_rope_head_dim // 2`` for the Indexer — the source
   ``DualRotaryEmbedding`` is already half-dim, so the slice is a
   no-op redundancy.

Indexer's ``_apply_rope`` (``indexer.py:352``) was already implemented in
the interleaved / complex-pair convention (it takes ``cos`` of size
``rope_head_dim // 2`` and runs ``view-as-pairs + rotate``). No change
needed there.

Test fixture update: ``tests/module/test_dsa.py::_make_position_embeddings``
no longer ``cat``s the half-dim freqs into a full-dim tensor — emits
half-dim cos/sin matching the new ``DualRotaryEmbedding`` contract.

Parity test ``_QK_ROPE`` constant raised from 2 → 16. (Previously 2
sidestepped the convention mismatch because both conventions degenerate
to the same single ``(x[0], x[1])`` pair when D=2; 16 exercises the real
rotation on multiple pairs.)

Verified: 19 module tests pass (DSA / Indexer / KVCompressor / HC).
Parity test ``Greatest absolute difference`` is unchanged at this rope
dim, which is the expected result — the rope is now bit-equivalent to
HF, so the remaining 45% mismatch in the parity test was always from
other sources (HC reduction order / sparse-vs-dense attention / MoE
expert reduction). Those are the next things to isolate.

Adds ``test_subcomponent_probe`` that walks the sliding-attention forward
step-by-step and prints abs diff at each sub-module boundary. Both sides
share copied weights, so each step should match to within bf16
reduction-order tolerance (~1e-4 abs). Findings:

  [DIFF] hc_pre.collapsed               max=7.8e-3
  [DIFF] hc_pre.post                    max=1.3e-3
  [DIFF] hc_pre.comb                    max=5.1e-4
  [OK]   input_layernorm                max=0.0  (full RMSNorm, weighted)
  [OK]   q_a_proj                       max=0.0
  [OK]   q_a_norm                       max=0.0  (full RMSNorm, weighted)
  [OK]   q_b_proj                       max=0.0
  [DIFF] q_b_norm (per-head)            max=3.1e-2  ← biggest
  [OK]   kv_proj+norm                   max=0.0

Pattern. The ``[DIFF]`` rows fall at two sites (``hc_pre`` and ``q_b_norm``),
exactly where XTuner trades precision for memory by doing the square in
bf16 before averaging:

* ``DSA.forward`` inline per-head Q RMSNorm (``q_sq = q * q`` in bf16,
  mean in fp32) vs HF ``DeepseekV4UnweightedRMSNorm`` (``x.float().square()``
  fully in fp32 then ``.to(x.dtype)``).

* ``hc_pre`` (``hc_block.py``) — same pattern: bf16 square + fp32 mean,
  plus a bf16 Linear that HF runs in fp32 (``F.linear(flat, fn.float())``
  on a fp32-upcast ``flat``).

The XTuner choices are documented in code comments as memory-savings:
~5 GB on hc_pre, ~128 MB per call on q_b_norm. Both are *deliberate*
precision/memory tradeoffs, not bugs — they just make XTuner V4 NOT
bit-equivalent to HF / V4-ref. Whether to revert them is a precision
budget call, not an obvious-fix call.
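
To make the "square in bf16 before averaging" effect concrete, here is a stdlib-only sketch. ``to_bf16`` is a hypothetical helper that emulates bfloat16 rounding (round-to-nearest-even on the upper 16 bits of an fp32 word) and ignores NaN/inf handling; Python floats stand in for the fp32 path. It is an illustration of the rounding loss, not the XTuner or HF code.

```python
import struct


def to_bf16(value):
    """Round to the nearest bfloat16 (ties to even); result is a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def mean_square_bf16_square(xs):
    """Square each element in bf16, then average in higher precision."""
    return sum(to_bf16(x * x) for x in xs) / len(xs)


def mean_square_full_precision(xs):
    """Square and average without the intermediate bf16 rounding."""
    return sum(x * x for x in xs) / len(xs)


# 1 + 2**-7 is exactly representable in bf16; its square is not.
xs = [to_bf16(1.0078125)] * 4
lossy = mean_square_bf16_square(xs)
exact = mean_square_full_precision(xs)
drift = exact - lossy
```

For this input the bf16 square drops the ``2**-14`` term, so the mean feeding ``rsqrt`` already differs before any reduction-order effect comes into play.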

Probe is gated to the sliding-only configuration so the surface is small;
hc_pre + q_b_norm cover the most divergent sub-steps. CSA / HCA paths
will need their own probe extensions (compressor + Indexer sub-steps).
…ort order

Five precision-alignment fixes found via the sub-component probe in
``test_deepseek_v4_decoder_layer_parity.py``. The full decoder-layer
parity test's mismatch ratio drops from ~45% to ~5% (sliding-only), the
remaining gap being bf16 kernel-reduction noise from cutlass grouped-GEMM.

1. ``DSA.forward`` inline per-head Q RMSNorm (``q_b_norm``)
   Was: ``q_sq = q * q`` (bf16 square) + fp32 mean.
   Now: ``q.float().square().mean(-1) + eps`` then ``rsqrt → .to(bf16)``,
   matching HF ``DeepseekV4UnweightedRMSNorm`` exactly. Under compile the
   fp32 intermediate doesn't materialise (inductor fuses square+mean+
   rsqrt+multiply into one kernel); eager pays ~128 MB transient per call
   for bit-equivalence with HF.

2. ``hc_pre`` (``hc_block.py``) — RMSNorm + gate Linear
   Was: bf16 square + bf16 ``F.linear(x_flat, hc_fn.to(bf16))`` then
   ``.float() * rsqrt``.
   Now: ``x_flat.float()`` once at the top, the entire chain (square,
   mean, rsqrt, multiply, linear) runs in fp32, matching HF
   ``DeepseekV4HyperConnection.forward`` (``flat = self.input_norm(...float())``
   then ``F.linear(flat, self.fn.float())``). The previously-claimed 5 GB
   memory savings was eager-mode only; under compile the fp32 intermediate
   stays in registers.

3. ``hc_post`` (``hc_block.py``) — fixed transposed-comb bug
   ``comb`` from Sinkhorn is doubly-stochastic but NOT symmetric. HF's
   inline expression in ``DeepseekV4DecoderLayer.forward``::

       hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) +
           torch.matmul(comb.to(dtype).transpose(-1, -2), hidden_states)

   sums over the FIRST hc axis (``comb.T @ residual``). XTuner's earlier
   broadcast+sum reduced over the SECOND hc axis (``comb @ residual``),
   giving the wrong mixing direction and producing ~0.84 abs diff on this op
   alone. Also reverted broadcast+sum back to ``torch.matmul`` for
   cuBLAS's fp32 accumulator (the bf16 sum reduction had ~1.2e-2
   precision loss; the K=4 cuBLAS slowness is a separate optimisation
   problem, not worth a precision compromise). Final tightening: cast
   ``post`` and ``comb`` to ``residual.dtype`` BEFORE the multiply so
   the whole expression runs in bf16 (matching HF), instead of doing
   fp32-mixed arithmetic with a ``.type_as(x)`` cast at the end.

4. ``NoAuxRouter.forward`` — ``torch.topk(sorted=False)``
   Was: default ``sorted=True`` returns indices in descending-by-score
   order.
   Now: ``sorted=False`` matches HF's ``DeepseekV4TopKRouter``
   (modeling_deepseek_v4.py:1034). When two experts tie on score, the
   tie-break order between ``sorted=True`` (CUDA topk sorts descending)
   and ``sorted=False`` (kernel-dependent, often unsorted) can differ,
   sending the same token through a different expert in HF vs XTuner.
   Aligning the sort flag fixes 8 / 64 routed-index disagreements in the
   probe test.

5. Parity test pins ``router_compute_dtype = "native"``
   XTuner's MoEGate defaults to fp32 routing (upcasts ``hidden_states``
   and ``gate.weight`` to fp32 before the ``F.linear``) for routing
   stability. HF runs the gate in bf16. For bit-identical parity the
   test config pins it to ``"native"`` (bf16). Production training
   should keep the fp32 default — this is purely a parity-measurement
   knob.
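
Fix 3 turns on the fact that a doubly-stochastic matrix need not be symmetric. The sketch below uses a toy 3x3 ``comb`` (rows and columns each sum to 1) and a one-channel ``residual``; the values and the list-based ``matmul`` / ``transpose`` helpers are illustrative assumptions, not XTuner code. It shows that ``comb.T @ residual`` and ``comb @ residual`` give different stream mixes.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m):
    return [list(col) for col in zip(*m)]


# Doubly stochastic (every row and column sums to 1) but not symmetric.
comb = [
    [0.5, 0.5, 0.0],
    [0.0, 0.5, 0.5],
    [0.5, 0.0, 0.5],
]
residual = [[1.0], [2.0], [3.0]]  # one value per HC stream

# Sum over comb's FIRST axis: mixed[h_out] = sum_h_in comb[h_in][h_out] * residual[h_in]
mixed_first_axis = matmul(transpose(comb), residual)
# Sum over comb's SECOND axis: the direction the earlier broadcast+sum used.
mixed_second_axis = matmul(comb, residual)
```

Both results conserve the total (6.0 here), which is why the bug does not show up in a sum-preservation check and needs an element-wise comparison.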

Sub-component probe results after all five fixes:

  [OK]   hc_pre.*                          max=0.0   (all three outputs)
  [OK]   input_layernorm                   max=0.0
  [OK]   q_a_proj / q_a_norm / q_b_proj    max=0.0
  [OK]   q_b_norm (per-head)               max=0.0   (fix 1)
  [OK]   kv_proj+norm                      max=0.0
  [OK]   attention end-to-end              max=2.9e-4
  [OK]   hc_post (attn) / (ffn)            max=0.0   (fix 3)
  [OK]   post_attention_layernorm          max=0.0
  [OK]   router.logits / topk_weights      max=0.0   (fix 4 + 5)
  [OK]   routed experts (same indices)     max=0.0
  [OK]   shared_experts                    max=0.0
  [DIFF] xtuner dispatcher+grouped vs per-expert  max=2.7e-2
  [DIFF] MoE end-to-end                    max=2.7e-2

Remaining 2.7e-2 is XTuner's cutlass grouped-GEMM (``MoEBlock.fused_w1w3``
/ ``fused_w2``) vs HF's per-expert ``torch.nn.functional.linear`` loop —
mathematically equivalent, different bf16 reduction order. This propagates
to a ~5% layer-output mismatch with abs diff ~0.03 (bf16 kernel noise at
output magnitude ~1). It is NOT a bug; tightening below this requires
either matching kernels (slow per-expert loop in XTuner production) or a
fp32 expert accumulator (memory cost).

Adds two layer-test anchors that monkey-patch parts of an XTuner
V4DecoderLayer to delegate to HF's matching submodule. Both anchors keep
the XTuner V4DecoderLayer.forward and its HC pre/post bookkeeping; only
the inner sub-call (attention or MoE) is rerouted. After
``_copy_hf_to_xtuner_layer`` the weights match, so the patched calls
return HF-bit-identical outputs.

Helpers:

* ``_install_hf_attention_fallback(xtuner_layer, hf_layer, hf_model)``
  Replaces ``xtuner_layer.self_attn.forward`` with a closure that calls
  ``hf_layer.self_attn`` (HF ``DeepseekV4Attention``). Converts XTuner's
  ``(hidden_states, position_embeddings, position_embeddings_compressed,
  seq_ctx)`` signature to HF's
  ``(hidden_states, position_embeddings=dict, position_ids, attention_mask,
  past_key_values)`` and wraps the returned ``(attn_output, attn_weights)``
  back into XTuner's ``AttnOutputs`` dict. After this anchor the entire
  attention path of the XTuner layer is HF-naive (compressor + Indexer +
  sparse_attn → HF dense attention with block_bias).

* ``_install_hf_moe_fallback(xtuner_layer, hf_layer)``
  Replaces ``xtuner_layer._ffn_compute`` with a closure that calls
  ``hf_layer.mlp`` (HF ``DeepseekV4SparseMoeBlock``: hash or topk router
  + per-expert ``nn.functional.linear`` loop + shared experts). Hooks
  ``_ffn_compute`` (not ``_ffn_pre_compute`` / ``_ffn_post_compute``) so
  the ``hidden_factor`` scaling on the FFN output still flows through
  XTuner's ``_ffn_post_compute`` path.

Four new test cases:

* ``test_csa_parity_with_hf_attention_anchor`` /
  ``test_hca_parity_with_hf_attention_anchor`` — attention only. Tolerance
  ``atol=4e-2`` covers cutlass grouped-GEMM vs per-expert reduction noise
  in the *unmocked* MoE path (~3e-2 max abs at this scale).

* ``test_csa_parity_full_hf_anchor`` / ``test_hca_parity_full_hf_anchor``
  — both attention and MoE delegated to HF. Tolerance ``atol=1/128``
  (one bf16 ULP) — the only residual is the multi-op bf16-cast chain
  through HC pre/post, which is unavoidable without going to fp32 or
  patching HC too (which would make the test "HF == HF" tautological).

Other small change: ``_run_xtuner_layer`` now calls
``hf_model.rotary_emb`` directly to get cos/sin (after the rope
interleaved-convention migration, XTuner DSA and HF emit identical
half-dim layouts), instead of recomputing via the helper
``_hf_rotary_to_xtuner_format``. This eliminates a ULP-level matmul-vs-
broadcast ordering difference between the two cos/sin paths.

Adds ``_install_full_hf_layer_fallback(xtuner_layer, hf_layer, hf_model)``
which replaces ``xtuner_layer.forward`` with a thin signature adapter
that delegates the *entire* forward pass to ``hf_layer.forward``. Unlike
``_install_hf_attention_fallback`` + ``_install_hf_moe_fallback`` (which
keep XTuner's HC pre/post bookkeeping inline), this one runs zero XTuner
code on the forward path — only the signature conversion. After patching,
the XTuner V4DecoderLayer produces ``atol=0.0 rtol=0.0`` parity with HF.

Two new test cases:

* ``test_csa_parity_full_hf_layer`` (CSA, atol=0)
* ``test_hca_parity_full_hf_layer`` (HCA, atol=0)

Both pass. These act as a sanity baseline confirming the test harness
(weight copy, input prep, output comparison) is bit-clean — any
non-zero diff here would point at a harness bug rather than a V4
implementation bug.

The signature adapter: V4DecoderLayer takes
``(hidden_states, *, position_embeddings, position_embeddings_compressed,
seq_ctx, input_ids) -> (out, logits, weights)``. HF DecoderLayer takes
``(hidden_states, position_embeddings=dict, position_ids, attention_mask,
input_ids, past_key_values) -> out``. The adapter packs the two rope
tuples into HF's ``{"main": ..., "compress": ...}`` dict, pulls
position_ids from ``seq_ctx`` (or falls back to an ``arange``), builds a
causal+sliding mask, calls HF, and pads the missing tuple slots with
zero placeholders (the parity test only consumes the first output).

Together with the partial anchors:

  attention only:   atol=4e-2  (covers cutlass grouped-GEMM noise)
  attention + MoE:  atol=1 ULP (covers bf16 HC cast chain)
  full layer:       atol=0     (no XTuner code on forward path)

The user can pick the appropriate anchor depending on which sub-graph
they're debugging.

Two fixes that bring ``test_csa_parity_full_hf_anchor`` and
``test_hca_parity_full_hf_anchor`` to bit-identical (atol=0.0) output
versus HF, while genuinely exercising XTuner code (HC pre/post +
V4DecoderLayer main forward + input_layernorm + V4DecoderLayer
wrapping). The previous test_csa_parity_full_hf_anchor was passing at
atol=1/128 (one bf16 ULP) — that ULP turned out to come from two real
bugs, not from "unavoidable bf16 rounding noise" as the previous commit
message claimed.

1. ``hc_split_sinkhorn`` (xtuner/v1/module/decoder_layer/hc_sinkhorn.py)
   The first Sinkhorn iteration used a manual softmax chain (subtract
   ``amax(-1).detach()``, apply ``exp``, divide by ``sum(-1)``),
   mathematically identical to HF's ``torch.softmax(comb_logits, dim=-1)``
   but with a different fp32 reduction order — the kernel difference
   produced a ~6e-8 fp32 ULP drift in ``comb`` (measured in
   ``test_subcomponent_probe``'s ``hc_pre.comb`` step). Even that tiny
   fp32 drift was enough to flip the bf16 rounding direction at a few
   downstream cast boundaries. Replaced with the same
   ``torch.softmax(comb, dim=-1) + eps`` HF uses; comb now matches
   HF bit-for-bit.

2. ``_install_hf_moe_fallback`` (tests/model/test_deepseek_v4_decoder_layer_parity.py)
   The anchor was passing ``x`` straight into ``hf_mlp(x, input_ids)`` —
   but XTuner's ``_ffn_compute`` (which the anchor replaces) internally
   applies ``self.post_attention_layernorm(x)`` before the gate via
   ``_ffn_pre_compute``. HF's ``mlp.forward`` does NOT layernorm — that
   step lives outside ``mlp`` in HF's ``DecoderLayer.forward`` (line
   1136: ``self.mlp(self.post_attention_layernorm(collapsed), ...)``).
   So the anchor was sending HF the un-normalised collapsed input,
   which gave HF mlp slightly wrong activations → propagated as ~1 bf16
   ULP drift at the layer output. Fixed by applying
   ``xtuner_layer.post_attention_layernorm(x)`` before the
   ``hf_mlp(...)`` call inside the anchor.

After both fixes the test goal is reached: with attention and MoE
delegated to HF, the XTuner V4DecoderLayer's hc_pre / hc_post / layer
wrapping produces output ``torch.equal``-true against HF. ``atol=0.0``
``rtol=0.0`` PASS.

HAOCHENYE added 28 commits June 29, 2026 06:14
…s>=5.9.0

transformers 5.9.0's DeepseekV4Config.__post_init__ consumes the legacy
num_hash_layers kwarg from config.json and converts it into the per-layer
mlp_layer_types list (["hash_moe"] * num_hash_layers + ["moe"] * rest), then
drops the scalar attribute. Reading cfg.num_hash_layers directly now raises
AttributeError on every V4 load and the trainer fails at config build.

Recover the count by counting leading "hash_moe" entries — V4 always puts
hash-routed layers at the front so this matches DeepSeekV4._should_compute_aux_loss's
layer_idx < num_hash_layers semantics. Falls back to the legacy direct
attribute when present, so transformers <5.9 keeps working.
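
The recovery logic can be sketched as below. ``recover_num_hash_layers`` is a hypothetical helper name, and ``SimpleNamespace`` stands in for the real config object; only the attribute names ``num_hash_layers`` and ``mlp_layer_types`` and the ``"hash_moe"`` / ``"moe"`` values come from the description above.

```python
from types import SimpleNamespace


def recover_num_hash_layers(cfg):
    """Return the number of leading hash-routed layers.

    Prefers the legacy scalar when the config still carries it; otherwise
    counts the leading "hash_moe" entries of mlp_layer_types.
    """
    legacy = getattr(cfg, "num_hash_layers", None)
    if legacy is not None:
        return legacy
    count = 0
    for layer_type in getattr(cfg, "mlp_layer_types", ()):
        if layer_type != "hash_moe":
            break  # only the leading run counts
        count += 1
    return count


# Stand-ins for the two config shapes (values are illustrative).
new_style = SimpleNamespace(mlp_layer_types=["hash_moe"] * 3 + ["moe"] * 5)
old_style = SimpleNamespace(num_hash_layers=3)
no_hash = SimpleNamespace(mlp_layer_types=["moe"] * 4)
```

Counting only the leading run, rather than every ``"hash_moe"`` entry, keeps the result aligned with a ``layer_idx < num_hash_layers`` check.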
…mpile-bwd NaN

The interleaved RoPE implementation introduced by [Fix] V4: switch RoPE to
interleaved convention (17bef28) used cos.repeat_interleave(2, -1) +
strided x[..., 0::2] / x[..., 1::2] + stack/flatten to recover the full-D
rotation factors per layer. Under torch.compile (eager inductor backward),
that combination miscompiles to a NaN-producing strided kernel — confirmed
by bisection: pre-17bef289 + compile_cfg=True trains with finite grads,
17bef28 alone + compile_cfg=True produces step-1 grad_norm=NaN.

Restructure the rope pipeline so the per-layer kernel has zero strided
ops on x:

1. DualRotaryEmbedding.forward now precomputes the D-dim arrangement once
   per micro-batch (inside @torch.no_grad — no autograd hazard there):
       cos_full[..., 2i] = cos_full[..., 2i+1] = cos_half[..., i]
       sin_full[..., 2i]   = -sin_half[..., i]
       sin_full[..., 2i+1] = +sin_half[..., i]
   Sign pattern is folded into sin so the per-layer kernel does not need
   any per-element negation.

2. dsa.py::_apply_rope / _apply_rope_inverse reduce to one fused expression:
       x_swap = x.unflatten(-1, (-1, 2)).flip(-1).flatten(-2)
       return x * cos_full ± x_swap * sin_full
   No unbind on x, no stack of split halves — just a contiguous unflatten +
   adjacent-pair flip + multiply + add. Inductor lowers this to one
   pointwise triton kernel for both forward and backward.

3. _RopeHeadDimProxy added so the HF rope_init_fn sizes inv_freq from
   qk_rope_head_dim instead of attention.head_dim. Mirrors HF V3's
   self.head_dim = qk_rope_head_dim trick (transformers/models/deepseek_v3/
   configuration_deepseek_v3.py:204) that XTuner can't apply directly
   because TransformerConfig.head_dim is a computed_field.

4. Indexer's _apply_rope adapted to the same D-dim signed layout; runs
   under torch.no_grad so the kernel-shape concerns are perf-only.
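
The sign-folded layout from items 1 and 2 can be checked in plain Python. The helper names below are hypothetical and the code works on one rope vector with list arithmetic instead of tensors; it only illustrates that ``x * cos_full + x_swap * sin_full`` reproduces the adjacent-pair rotation and that the minus variant undoes it.

```python
import math


def expand_cos_sin(thetas):
    """Build D-dim factors: cos repeated per pair, sin with the sign folded in."""
    cos_full, sin_full = [], []
    for theta in thetas:
        c, s = math.cos(theta), math.sin(theta)
        cos_full += [c, c]
        sin_full += [-s, s]
    return cos_full, sin_full


def swap_adjacent_pairs(x):
    """List analogue of x.unflatten(-1, (-1, 2)).flip(-1).flatten(-2)."""
    out = []
    for i in range(0, len(x), 2):
        out += [x[i + 1], x[i]]
    return out


def apply_rope(x, cos_full, sin_full, inverse=False):
    sign = -1.0 if inverse else 1.0
    x_swap = swap_adjacent_pairs(x)
    return [xi * c + sign * xs * s for xi, xs, c, s in zip(x, x_swap, cos_full, sin_full)]


def reference_rotation(x, thetas):
    """Explicit 2x2 rotation of each adjacent pair, for comparison."""
    out = []
    for i, theta in enumerate(thetas):
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out


x, thetas = [1.0, 2.0, 3.0, 4.0], [0.3, 0.7]
cos_full, sin_full = expand_cos_sin(thetas)
rotated = apply_rope(x, cos_full, sin_full)
restored = apply_rope(rotated, cos_full, sin_full, inverse=True)
max_err_vs_ref = max(abs(p - q) for p, q in zip(rotated, reference_rotation(x, thetas)))
max_err_roundtrip = max(abs(p - q) for p, q in zip(restored, x))
```

Because the sign lives in ``sin_full``, the per-element work is one swap, two multiplies and one add, with no negation or strided split of ``x``.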

Tests: test_csa_parity_full_hf_anchor / test_hca_parity_full_hf_anchor
still pass at atol=0; test_dsa.py fixture updated to emit the new D-dim
signed cos/sin so the module-level tests keep their existing tolerances.
…tch HF without cuBLAS K=4 fallback

hc_post mixes the four HC streams via:
    mixed = sum_{h_in} comb[h_in, h_out] * residual[h_in, d]
which HF / V4-ref both write as torch.matmul(comb.transpose(-1, -2), residual).
K = hc_mult = 4 here is below Hopper's wgmma tile floor (16), so cuBLAS
falls back to a CUDA-core small-K gemm; at pack=16384 this cost ~3 ms ×
86 calls/step ≈ 250 ms/step on the matmul alone.

A prior "perf" rewrite (d621cff) replaced the matmul with broadcast+sum
to force inductor into one fused triton kernel, but used comb WITHOUT
.transpose(-1, -2) — that swaps the reduction axis on the non-symmetric
Sinkhorn output, producing a different mathematical operation (~4.0
elementwise diff vs HF in numerical experiments, not a bf16 precision
drift) and breaking the parity contract.

This commit:

1. Adds the missing .transpose(-1, -2) so reduction direction matches HF
   (sum over comb's FIRST hc axis = sum over h_in).
2. Rounds comb to residual.dtype (bf16) BEFORE upcasting to fp32 — that
   first round-to-bf16 boundary is what HF / cuBLAS see on input, and
   skipping it costs ~1 bf16 ULP at the output.
3. Runs multiply + reduction in fp32 (matching cuBLAS' fp32 accumulator
   on bf16 inputs), then casts the result back to bf16.

The five-D [B, S, H, H, D] intermediate that the unsqueeze+multiply
implies never materialises under compile: inductor fuses the multiply
with the trailing sum(dim=-2) into one triton kernel with the H_in
reduction in registers. cast(bf16) → upcast(fp32) → multiply → sum →
cast(bf16) all join that kernel's epilogue.

Verified bit-identical to HF: test_subcomponent_probe's hc_post (attn) /
hc_post (ffn) steps now max=0.0; test_csa_parity_full_hf_anchor and
test_hca_parity_full_hf_anchor both pass at atol=0.0, rtol=0.0.
…ocated -1 buffer

HCA (compress_ratio=128) layers had to dispatch to native sparse_attn
because their topk width sliding_window + (pack/128 + 1) is never a
multiple of 128, which FlashMLA's hard-coded params.topk % 128 == 0
assertion requires.
Native sparse_attn is a pure-PyTorch gather + einsum that runs full fp32
without tensor cores; at V4 production dims (pack=16384, k≈257) it costs
~3 ms/call/layer.

V4-ref itself uses -1 as the mask sentinel in topk_idxs
(get_window_topk_idxs L260/L264, get_compress_topk_idxs L275, Indexer
L430), so cat'ing extra -1 columns onto topk_idxs is a no-op for
attention math — FlashMLA / cuDNN / native all drop -1 slots to -inf
before softmax. The pad strip lifts the trailing dim of topk_idxs to
the next multiple of 128 so FlashMLA's alignment check passes.
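
The padding arithmetic can be sketched as follows. ``round_up`` and ``pad_topk_row`` are hypothetical helpers on plain lists, and ``sliding_window = 128`` is an assumption inferred from the k≈257 figure at pack=16384; the real code operates on a pre-allocated tensor buffer.

```python
def round_up(value, multiple):
    """Smallest multiple of `multiple` that is >= value."""
    return -(-value // multiple) * multiple


def pad_topk_row(row, multiple=128, sentinel=-1):
    """Append sentinel entries until the row width is a multiple of `multiple`."""
    padded = round_up(len(row), multiple)
    return row + [sentinel] * (padded - len(row))


sliding_window = 128  # assumption, see lead-in
pack = 16384
natural_topk = sliding_window + pack // 128 + 1
padded_topk = round_up(natural_topk, 128)

row = list(range(natural_topk))  # stand-in for one token's topk indices
padded_row = pad_topk_row(row)
```

Since every appended entry is the same ``-1`` sentinel the attention kernels already treat as masked, the padded row selects the same keys as the original one.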

The pad buffer is sized once at __init__ from DSAConfig.pack_max_length
([1, pack_max_length, padded_topk - natural_topk] of -1s) and used as
a slice + cat in forward. This keeps the entire forward graph compile-
friendly — no Python-side cache check, no @torch.compiler.disable, no
graph break, no sym-int arithmetic for the pad shape. The only sym dim
the cat sees is total_tokens, which is already symbolic. When
pack_max_length is None and the backend is flash_mla/cudnn,
_resolve_sparse_attn_fn warns and falls back to native for HCA layers
(same as before this commit). Native HCA layers also stay on the
historic dynamic-width build, since native has no alignment constraint
and padding would be pure overhead.

The build width for compress_topk on the padded path is pinned to
_hca_max_compress_w = pack_max_length // 128 + 1 (a Python int
constant) so the cat partner shape is static. Tokens below the natural
horizon already get -1 from _build_compress_topk_idxs_varlen's
clamp, so widening the build only adds more masked entries inside the
valid block — equivalent semantics, just zero-cost masked.

Expected savings vs the prior native HCA path: FlashMLA tensor-core
sparse attention beats fp32 einsum by ~10× at K=384 on H100; even
factoring the 49.6% -1 columns the pad adds, net wallclock is well
ahead. At 16 HCA layers per V4 step that maps to ~50 ms/step saved at
pack=16384.

Tests: tests/module/test_dsa.py adds test_hca_pad_buffer_static_shape
(verifies buffer shape + dtype + all-(-1) content under cudnn) and
test_hca_pack_max_length_without_flash_backend_falls_back (verifies the
buffer is NOT allocated on the native path even with pack_max_length
set). DSA + V4 parity regression: 10 passed / 7 pre-existing fails,
identical pass/fail set to before this commit.
…rd allgather

V4 keeps Hyper-Connections mixing parameters in fp32 because the 20-iter
Sinkhorn projection is bf16-unstable. Until this commit the params were
also FSDP-sharded along with everything else, so every layer's
hc_pre had to allgather them via _unshard_hc_params's
.full_tensor() call. At 43 layers × 2 calls (attn + ffn) per step
this is 86 small allgathers per step, each ~16-20 KB of fp32 data plus
the collective sync cost.

XTuner's _fully_shard already ships a per-config opt-out:
hf_save_cfg.fp32_keys_pattern regexes the HF parameter name, and
matches get replicated on the world mesh + added to FSDP's
ignored_params. We use that machinery here instead of inventing a
new path.

Pattern r"hc_(attn|ffn|head)_(fn|base|scale)$" covers all 9 HC
params:
  * per-layer layers.<L>.hc_(attn|ffn)_(fn|base|scale) (6 entries)
  * model-top-level hc_head_(fn|base|scale) (3 entries)
  * the same shape on MTP layers (mtp.<m>.hc_attn_fn etc.)
The trailing $ keeps the match precise (re.search is substring
by default).
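
A quick check of the pattern with Python's ``re`` module. The parameter names below follow the shapes listed above, except ``layers.3.hc_attn_fn_extra`` and the ``q_a_proj`` name, which are hypothetical negatives added to show what the trailing ``$`` excludes.

```python
import re

HC_FP32_PATTERN = re.compile(r"hc_(attn|ffn|head)_(fn|base|scale)$")
UNANCHORED = re.compile(r"hc_(attn|ffn|head)_(fn|base|scale)")

names = [
    "layers.3.hc_attn_fn",
    "layers.3.hc_ffn_scale",
    "hc_head_base",
    "mtp.0.hc_attn_fn",
    "layers.3.hc_attn_fn_extra",           # hypothetical: suffix after the HC name
    "layers.3.self_attn.q_a_proj.weight",  # hypothetical: unrelated parameter
]

# re.search scans for the pattern anywhere in the string, so only the
# trailing $ prevents a match on names that merely contain an HC name.
matched = [n for n in names if HC_FP32_PATTERN.search(n)]
matched_unanchored = [n for n in names if UNANCHORED.search(n)]
```

The three prefixes times three suffixes give the 9 distinct HC parameter names the pattern is meant to cover.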

After this change _unshard_hc_params's .full_tensor() becomes a
no-op on Replicate DTensors instead of an allgather, and FSDP skips the
HC params entirely. Memory cost: ~100 KB × 43 layers replicated on each
rank, negligible.

Inductor names triton kernels by op signature + a monotonic id
(triton_poi_fused_add_mul_..._47), with no native way to encode the
source compile target. As V4 layered up to 8 distinct compile targets
(hc_pre, hc_post, _attn_compute, _ffn_pre_compute, _ffn_post_compute,
_hc_head_reduce_compute, DSA.forward, KVCompressor.forward) the
timeline showed kernels with names that gave no hint which target they
came from.

Wrap each compiled callable in torch.profiler.record_function so the
captured timeline has a named parent range around the kernels. The
wrapper is an eager Python shim — record_function lives outside the
compiled region, never causes a graph break, and is zero-cost when no
profiler is active. Profiler views (Chrome trace, TensorBoard,
perfetto) now group kernels under e.g.
"compile:xtuner.v1.module.decoder_layer.hc_block.hc_pre", making it
obvious which compile target produced each kernel.

To keep idempotency, the wrapper propagates the get_compiler_config
attribute that torch.compile attaches to its output so
is_compiled_function still returns True; repeated
_resolve_compile_cfg passes will recognise the wrapped fn as
already compiled and skip re-wrapping.

Both compile-dispatch paths are covered:
  * MaybeCompile.enable_compile (module-level functions decorated with
    @maybe_compile, e.g. hc_pre / hc_post)
  * XTunerBaseModel._compile_function class-level branch (methods like
    DSA.forward / V4DecoderLayer._attn_compute)

Smoke: wrap_with_profile_range returns a callable that (a) produces
identical output to the unwrapped compiled fn, (b) is recognised by
is_compiled_function, (c) registers a compile:<qualname> event
in a torch.profiler capture. DSA + V4 parity regression unchanged
(10 passed / 7 pre-existing fails).
…uses flash_mla/cudnn

DSAConfig.pack_max_length is the static upper bound HCA needs to size the
pre-allocated -1 pad buffer that brings topk_idxs to a 128-multiple. Without
this set, the HCA layers' _resolve_sparse_attn_fn warns and falls back to
native sparse_attn (the slow pure-PyTorch path).

Set to 4096 to match the existing DataloaderConfig.pack_max_length so the
CI config stays self-consistent. Production runs that bump the dataloader
pack should bump this in sync.
…s in sinkhorn

Three coupled changes that collectively cut hc_pre's per-call elementwise
op count and shift its Linear off the cuBLAS fp32 slow path:

1. hc_block.hc_pre: replace the explicit x.float() * rsqrt(...)
   chain with F.rms_norm(x_bf16, ..., weight=None, eps=norm_eps).
   The fp32 variance accumulator now stays inside the fused ATen op —
   no per-call full-tensor x_flat_f32 materialisation.

2. hc_block.hc_pre: cast hc_fn to bf16 instead of fp32 for the
   F.linear call (was hc_fn.float() to match HF's all-fp32 path).
   At K=16384 / N=24 cuBLAS' bf16 GEMM with its fp32 accumulator is
   ~10-20× faster than the fp32 GEMM (which has no tensor-core
   acceleration on H100 — profile showed
   sm80_xmma_gemm_f32f32_f32f32_f32_tn at 0.6 ms per call). Trade-off
   is one bf16 rounding on the GEMM inputs, propagating to ~1.5e-2 max
   abs diff at the layer output (smaller than the MoE cutlass GEMM diff
   already accepted at 2.7e-2).

3. hc_sinkhorn.hc_split_sinkhorn: drop the mixes.float(),
   hc_scale.float(), hc_base.float() calls. After (1) mixes
   arrives in fp32 (hc_pre upcasts the Linear output explicitly);
   hc_scale / hc_base are fp32 nn.Parameters. The three
   casts were no-ops in the V4 path and only added identity ops to the
   compile graph. Replaced with a single assert mixes.dtype ==
   float32 precondition.

Tests:
  * test_subcomponent_probe: hc_pre.collapsed shows 1.56e-2 max
    diff vs HF (was 0; this is the bf16 Linear's bf16-rounding
    propagation), hc_pre.post 1.16e-3 (sigmoid Lipschitz-bounded),
    hc_pre.comb 5.4e-4 (sinkhorn-renormalised). Downstream attention
    end-to-end / hc_post (attn) / hc_post (ffn) all 0.0 — bf16 Linear
    diff stays bounded by the sinkhorn/sigmoid downstream.
  * test_csa_parity_full_hf_anchor / test_hca_parity_full_hf_anchor:
    relax atol from 0.0 to 2e-2 to reflect the new HC pre/post
    tolerance.

Verified bit-identical to the pre-substitution implementation for the
F.rms_norm change alone (no precision change relative to the manual
chain). The 1.56e-2 is entirely from (2).
…ntity path

Add an opt-in env var that flips hc_pre back to the all-fp32 RMS +
fp32 Linear chain when set (matches HF's DeepseekV4HyperConnection
bitwise, at the cost of the ~10-20× cuBLAS fp32 GEMM slow path on H100).

Default (env unset / "0") stays on the bf16 Linear path from the
previous commit — fast, ~1.5e-2 max abs drift vs HF, accepted at
atol=2e-2 in _full_hf_anchor tests.

Env: XTUNER_V4_HF_PARITY=1 → bit-identical to HF.

The env var is read once at hc_block import; restart the worker to
flip it. Tests use monkeypatch.setattr(hc_block, "_HC_HF_PARITY",
True) to exercise both paths in the same process.

Tests:
  * Parametrise test_{csa,hca}_parity_full_hf_anchor across
    bf16-default (atol=2e-2) and hf-parity (atol=0.0). All four
    cases pass.

V4DecoderLayer is an FSDPModule; fully_shard registers pre/post-forward
hooks on forward to manage parameter all-gather and resharding. The
Domino schedule must therefore live INSIDE a single forward() call so
FSDP brackets the entire multi-MB pass exactly once — orchestrating the
dispatcher chain from the outer model (calling layer._forward_pre_ffn_dispatch
and layer._forward_post_ffn_combine as separate methods from
DeepSeekV4._micro_batch_forward) bypasses those hooks and breaks
parameter management. This change moves the wave pipeline into the layer
to mirror :class:`MoEDecoderLayer`.

V4DecoderLayer changes:
* forward is now variadic *hidden_states + list-typed sibling args;
  dispatches on the count: N == 1 runs _forward (renamed from the
  old forward body, single-MB unchanged), N >= 2 runs
  _micro_batch_forward. Mirrors :meth:`MoEDecoderLayer.forward`'s
  dispatch shape so the outer caller can call layers uniformly.
* _forward_pre_ffn_dispatch / _forward_post_ffn_combine (renamed to
  private from the prior public spellings): one-call halves used internally
  by _micro_batch_forward.
* _micro_batch_forward(hidden_states_list, seq_ctx_list, position_embeddings_list,
  position_embeddings_compressed_list, input_ids_list): the 3-phase
  Domino wave pipeline. Phase A queues per-MB pre-dispatch compute +
  dispatch_preprocess(async_op=True); Phase B interleaves
  dispatch / dispatch_post / experts / combine_pre / combine across MBs
  with async_op=True; Phase C drains with combine_postprocess +
  FFN-post + HC-post-FFN. Returns a flat 3 * N tuple
  (h_out_0..h_out_{N-1}, rl_0..rl_{N-1}, rw_0..rw_{N-1}) matching
  :meth:`MoEDecoderLayer._micro_batch_forward`'s contract.
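
The dispatch-on-count shape and the flat 3 * N return layout can be illustrated with a toy layer. This is a sketch under assumptions: `ToyLayer` is hypothetical, integer arithmetic stands in for the real compute, and the multi-MB branch runs sequentially rather than as the Domino wave pipeline.

```python
class ToyLayer:
    """Stand-in for a decoder layer; arithmetic replaces the real compute."""

    def forward(self, *hidden_states, seq_ctx):
        if len(hidden_states) == 1:
            # Single micro-batch: sibling args are plain values.
            return self._forward(hidden_states[0], seq_ctx)
        # N >= 2: sibling args are lists with one entry per micro-batch.
        return self._micro_batch_forward(list(hidden_states), seq_ctx)

    def _forward(self, h, ctx):
        # Returns (hidden, router_logits, router_weights).
        return h + ctx, h * 2, h * 3

    def _micro_batch_forward(self, hs, ctxs):
        outs = [self._forward(h, c) for h, c in zip(hs, ctxs)]
        # Flatten to (h_0..h_{N-1}, rl_0..rl_{N-1}, rw_0..rw_{N-1}).
        return tuple(o[k] for k in range(3) for o in outs)


layer = ToyLayer()
single = layer.forward(1, seq_ctx=10)
multi = layer.forward(1, 2, seq_ctx=[10, 20])
```

Keeping both branches behind one forward() is what lets hooks registered on forward bracket the whole multi-MB pass once.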

DeepSeekV4._micro_batch_forward changes:
* Domino branch is now a single v4_layer(*hidden_states_list, seq_ctx=list,
  position_embeddings=list, position_embeddings_compressed=list,
  input_ids=list) call per layer, with the result tuple unpacked back into
  hidden_states_list.
* Sequential branch restored to the per-MB loop (carries the activation-offload
  context).
* Auto-fallback unchanged: domino_active requires
  n_mb >= 2 and ep_size > 1 and not XTUNER_ACTIVATION_OFFLOAD —
  NaiveDispatcher at ep=1 has no async_op=True impl, and offload's per-MB
  saved-on-cpu context is incompatible with one multi-MB layer forward.
* The _v4_sequential_layer_pass / _v4_domino_layer_pass model-level
  helpers from the previous attempt are removed; layer-internal is the
  FSDP-correct factoring.
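
The auto-fallback condition above reduces to a small predicate. A sketch, assuming the config flag and the offload env toggle are passed in as booleans; the function signature is hypothetical.

```python
def domino_active(domino: bool, n_mb: int, ep_size: int, activation_offload: bool) -> bool:
    # Domino needs at least two micro-batches to overlap, an EP dispatcher
    # that supports async_op=True (ep_size > 1), and no activation offload.
    return domino and n_mb >= 2 and ep_size > 1 and not activation_offload
```

With domino=True at ep_size=1 this returns False, so the sequential per-MB loop is used.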

Validated:
* v4_layer(hs) (variadic dispatch) bit-identical to v4_layer._forward(hs)
* v4_layer(MB0, MB1, lists=...) returns a 6-tuple matching two sequential
  _forward(MB) calls bit-by-bit (with NaiveDispatcher patched to ignore
  async_op so the call is reachable at ep=1)
* cfg.domino=True at ep=1 → domino_active=False (NaiveDispatcher
  never sees async_op=True).
hc_post re-expands one stream into hc_mult via
out[h_out,d] = post[h_out]*x[d] + sum_h_in comb[h_in,h_out]*residual[h_in,d].
The eager reference (now _hc_post_eager) deliberately uses broadcast-multiply +
reduce-sum instead of torch.matmul because K=hc_mult=4 is below Hopper's wgmma
tile floor and cuBLAS falls back to a slow CUDA-core GEMM. But the inductor
fusion makes each output element its own reduction thread, so residual[:, d]
is re-read once per h_out (4x); at pack=16384 the ~1 GB residual blows past L2,
leaving hc_post the single largest triton cost on the compute stream
(~170 ms/step across fwd + recompute + bwd in profiling).
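
For reference, the hc_post formula above written out for a single token in plain Python. This is an illustrative sketch, not the eager torch implementation: nested lists stand in for tensors and `hc_post_reference` is a hypothetical name.

```python
def hc_post_reference(post, comb, x, residual):
    """post: [H_out], comb: [H_in][H_out], x: [D], residual: [H_in][D] -> [H_out][D]."""
    hc_mult, dim = len(post), len(x)
    out = []
    for h_out in range(hc_mult):
        row = []
        for d in range(dim):
            # Every (h_out, d) output re-reads the full residual[:, d] column.
            mixed = sum(comb[h_in][h_out] * residual[h_in][d] for h_in in range(hc_mult))
            row.append(post[h_out] * x[d] + mixed)
        out.append(row)
    return out
```

The inner sum makes the access pattern visible: a schedule with one reduction per output element touches residual[:, d] once for each h_out.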

This adds a fused Triton kernel (xtuner.v1.ops.hc_post) that assigns one
program to a whole token's [H_out, BLOCK_D] tile, reads residual exactly once,
and does the 4x4 mix in registers. Registered as torch.library.custom_op
(xtuner::hc_post_fwd / hc_post_bwd) with register_fake + register_autograd so
the V4 decoder layer keeps compiling around it as an opaque node (verified
fullgraph, bit-identical to eager-op under torch.compile).

Measured vs the eager reference: ~21x faster forward, ~7x faster fwd+bwd. Numerics differ by
~1 bf16 ULP (different but equally valid reduction order) and are marginally
closer to the fp32 ground truth (more accumulation stays in fp32). Gradients
match to bf16-relative ~1e-3; grad_post / grad_comb are returned in the fp32
dtype of the Sinkhorn outputs they flow back to.

Wiring:
- hc_block.hc_post is now an eager dispatcher: default bf16-on-CUDA path calls
  hc_post_fused; _HC_HF_PARITY (and any non-bf16 / non-CUDA input) falls back
  to the eager fp32 _hc_post_eager so the atol=0 HF-parity tests stay bit-exact.
- _V4_LAYER_TARGETS compiles _hc_post_eager instead of hc_post (the dispatcher's
  fast path is the opaque custom op; only the parity fp32 body wants fusion).

Test: tests/ops/test_hc_post.py — forward within bf16 ULP of the reference,
no worse than reference vs fp32 truth, gradients match (rel < 1e-2), and
torch.compile(fullgraph=True) traces with bit-identical forward.
Replace vanilla elementwise ops on the V4 hot path with optimized kernels:

- DSA q-norm and hc_pre RMSNorm now dispatch to quack.rmsnorm when quack is
  installed, falling back to the native fp32 rsqrt / F.rms_norm path otherwise.
- Add a fused Triton rope-split kernel that does the NoPE-prefix copy and the
  rope-tail rotation in a single HBM pass, dropping the intermediate
  rotated-tail allocation and the separate cat. Used by _apply_rope_split /
  _apply_rope_inverse_split on bf16 CUDA tensors; the slice+cat path remains
  the fallback. Custom op carries a register_autograd backward (transpose of
  the orthogonal rotation) so it stays compile- and autograd-safe.
- rms_norm: let zero-centered RMSNorm fall back to native on CUDA instead of
  raising when XTUNER_USE_NATIVE_RMSNORM=0.

Add TestRopeSplitFused covering forward parity, forward/inverse round-trip,
and backward parity against the slice+cat reference.
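
A plain-Python sketch of what the rope-split computes for one vector: copy the NoPE prefix, rotate the rope tail, and undo it with the inverse rotation. The adjacent-pair layout (2i, 2i+1) is an assumption made for illustration; the real kernel's pairing convention is not stated here, and `rope_split` is a hypothetical name.

```python
import math


def rope_split(x, angles, rope_dim, inverse=False):
    """Copy the NoPE prefix and rotate the trailing rope_dim values pairwise."""
    nope = len(x) - rope_dim
    out = list(x[:nope])  # NoPE prefix passes through untouched
    sign = -1.0 if inverse else 1.0
    for i in range(rope_dim // 2):
        a, b = x[nope + 2 * i], x[nope + 2 * i + 1]
        c, s = math.cos(angles[i]), sign * math.sin(angles[i])
        # 2x2 rotation; with inverse=True this is the transpose.
        out += [a * c - b * s, a * s + b * c]
    return out
```

Since each 2x2 rotation is orthogonal, the inverse is its transpose, which is why the backward pass can reuse the same routine with the sign of the sine flipped.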
…chunk loss

- Disable moe_cfg.mtp_config: V4 has no MTP block wired yet, but from_hf set it
  from num_nextn_predict_layers=1, so MoE.build_loss_ctx_batch built MTP loss
  contexts every step (~450 .item() D2H syncs at pack=8192) only to drop them.
- pack_max_length 4096 -> 12288 (shared by DSA, dataloader); micro_batch 1 -> 2.
- loss_cfg switched to chunk mode; comment out the cutlass group_gemm and
  activation-offload env toggles; enable profile_step/profile_memory.
dsa.py had grown to ~1000 lines mixing the DSAConfig + DeepSeekSparseAttention
module with two unrelated op groups. Pure code move (no signature or behavior
change) to make the package read at the module-op level:

- _dsa_rope.py: fused Triton rope-split kernel + custom op and the
  _apply_rope* / _broadcast_cos_sin helpers.
- _dsa_topk.py: varlen sparse-attn index ops (_build_*_topk_idxs_varlen,
  _interleave_window_compressed_kv, _shift_topk_to_global).

dsa.py imports both back; public imports (DeepSeekSparseAttention, DSAConfig)
are unchanged. test_dsa.py rope tests now import from _dsa_rope.
The single-layer module lived inside the model file (model/moe/deepseek_v4.py),
while every other decoder layer sits in module/decoder_layer/. V4DecoderLayer is
already self-contained — it takes plain args + a prebuilt attention_module and
has no runtime dependency on the model class — so this is a pure code move:

- New module/decoder_layer/deepseek_v4_decoder_layer.py holds V4DecoderLayer and
  its FFN-dispatch carry-state V4FFNState. _build_compressed_position_embeddings
  stays in the model (only the model forward uses it).
- deepseek_v4.py imports V4DecoderLayer back (re-exported, so existing
  `from ...deepseek_v4 import V4DecoderLayer` keeps working) and drops the
  imports that only the layer used.
- The three @maybe_compile dotted-path targets for V4DecoderLayer._attn_compute
  / _ffn_pre_compute / _ffn_post_compute are repointed to the new module path;
  verified all compile targets still resolve (a stale path would silently
  no-op the compile).
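
One way to guard against a stale dotted path is to resolve each target eagerly and fail loudly. The helper below is a hypothetical sketch using only importlib; it is not the project's get_function_full_qualname or @maybe_compile machinery.

```python
import importlib


def resolve_dotted_target(path: str):
    """Import the longest module prefix of path, then walk the remaining attributes."""
    parts = path.split(".")
    for split in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
        except ImportError:
            continue  # prefix is not a module; try a shorter one
        for attr in parts[split:]:
            obj = getattr(obj, attr)  # AttributeError here means a stale path
        return obj
    raise ImportError(f"no importable module prefix in {path!r}")
```

Running such a check over the compile-target list in a test turns a silent no-op into an AttributeError or ImportError.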

This also fixes the layering direction: hc_block/hc_sinkhorn previously
cross-referenced "up" to the model's V4DecoderLayer; the layer now sits below
the model alongside its op modules.
…opies

MoE._forward was copied near-verbatim by two subclasses (DeepSeekV4 and
Qwen3VLTextMoE) just to change a few pipeline points, duplicating all the
output-dict / aux-loss / z-loss / lm_head bookkeeping. Turn _forward into a
template method with overridable seams (defaults reproduce the current base
behavior exactly, so DeepSeekV3 / Qwen3 / Qwen3.5 / GPT-OSS are unchanged):

- _prepare_hidden_states: embed + position embeddings + _mark_dynamic; returns
  (hidden_states, layer_ctx). layer_ctx is threaded to the layer call.
- _decoder_layer_offload_ctx: the per-layer activation-offload window.
- _call_decoder_layer: invoke one layer, normalised to
  (hidden, router_logits|None, router_weights|None); dense layers return None.
- _post_layer: per-layer post hook (identity by default).
- _finalize_hidden_states: transform before the final norm (identity).
- _should_finalize_aux_loss: guard around aux_loss.finalize (True).
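
The seam structure can be sketched as a template method. The seam names below are taken from the list above; the toy classes, the two-layer loop and the list arithmetic are illustrative stand-ins, not the real MoE code.

```python
class ToyMoE:
    def _forward(self, tokens):
        hidden, layer_ctx = self._prepare_hidden_states(tokens)
        for idx in range(2):
            hidden, _logits, _weights = self._call_decoder_layer(idx, hidden, layer_ctx)
            hidden = self._post_layer(idx, hidden)
        return self._finalize_hidden_states(hidden)

    def _prepare_hidden_states(self, tokens):
        return list(tokens), None

    def _call_decoder_layer(self, idx, hidden, layer_ctx):
        return [h + 1 for h in hidden], None, None

    def _post_layer(self, idx, hidden):
        return hidden  # identity by default

    def _finalize_hidden_states(self, hidden):
        return hidden  # identity by default


class ToyV4(ToyMoE):
    def _prepare_hidden_states(self, tokens):
        # Stand-in for the HC expand: two residual streams per token.
        return [[t, t] for t in tokens], "rope"

    def _call_decoder_layer(self, idx, hidden, layer_ctx):
        return [[s + 1 for s in streams] for streams in hidden], None, None

    def _finalize_hidden_states(self, hidden):
        # Stand-in for the hc_head reduce back to one stream.
        return [sum(streams) / len(streams) for streams in hidden]
```

The subclass touches only the seams it needs; the loop and bookkeeping in _forward are inherited unchanged.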

DeepSeekV4 now overrides _prepare_hidden_states (dual rope + HC expand, no
_mark_dynamic), _call_decoder_layer (compressed rope + input_ids),
_finalize_hidden_states (hc_head reduce) and _should_finalize_aux_loss
(hash-layer guard); it deletes its _forward copy. It inherits the base offload
ctx (first_k_dense_replace == 0 ⇒ all layers offloaded, matching before).

Qwen3VLTextMoE now overrides only _post_layer (deepstack visual-embed inject)
and deletes its _forward copy. Per decision, its image-batch path drops the
stale custom offload (FSDP all-gather stream + depth) and the no-clone
inputs_embeds path, and uses the MoE base offload (self.offload_stream) and
input handling — same as its text-batch path already did via super().

Note: V4's single-forward `extra_info` now matches the base (may be None in
forward-only/no-loss), instead of its local ModelForwardExtraLogInfo() default;
identical on the training path where loss_ctx is always present.
The micro-batch path keeps a full override (its per-MB-list topology
deliberately avoids the base's cat-then-chunk + dense-prefix machine, so the
two don't share a clean seam). But its tail duplicated logic that is now a
seam: swap the inline `_hc_head_reduce` for `_finalize_hidden_states` and the
inline `num_hash_layers < num_hidden_layers` guard for
`_should_finalize_aux_loss`, so both V4 forwards go through the same seams.
Behavior is identical (the seam bodies are exactly the swapped-out expressions).
Refresh the stale docstring that claimed `_forward` "overrides the parent".
… seams

MoE._micro_batch_forward baked the dense-prefix optimization into the main
loop: it carried a concatenated cat_hidden_states, ran dense layers on it, and
lazily chunked to a per-MB list at the first MoE layer (a stateful moe_forward /
cat_seq_ctx machine with an i.clone() workaround). That cat-then-chunk scaffold
only serves first_k_dense_replace > 0, so DeepSeekV4 (no dense prefix, HC-
expanded per-MB stream) couldn't reuse it and kept a full mb copy.

Restructure the base mb into explicit phases with the per-MB list as the
default working representation:
- _prepare_hidden_states_mb: embed + rope once on the concatenated batch, then
  split (clone) to a per-MB list. (mb seam)
- _run_dense_layers_mb: if there is a dense prefix, cat the list, run the dense
  layers on the concatenation, split back. No-op otherwise, so no-dense models
  pay no cat/chunk round-trip.
- moe phase: per-MB list throughout; offload + aux accumulate stay in the base.
- _call_decoder_layer_mb: one layer call carrying all MBs, normalised return.
  (mb seam)
- tail reuses _finalize_hidden_states / _should_finalize_aux_loss.

The in-loop moe_forward/cat_seq_ctx state machine is gone. Behavior for
dense-prefix models is preserved (cat == cat(chunk(cat))); the clone moves into
prepare so the no-dense path stays offload-safe.
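
The dense phase's cat, run, split-back shape can be sketched with lists standing in for tensors. `run_dense_layers_mb` is a hypothetical simplification of _run_dense_layers_mb, and the layers are plain callables.

```python
def run_dense_layers_mb(hidden_list, dense_layers):
    if not dense_layers:
        return hidden_list  # no dense prefix: skip the cat/split round-trip
    lengths = [len(h) for h in hidden_list]
    cat = [tok for h in hidden_list for tok in h]
    for layer in dense_layers:
        cat = layer(cat)  # dense layers run once on the concatenation
    out, start = [], 0
    for n in lengths:
        out.append(cat[start:start + n])
        start += n
    return out
```

With an empty dense prefix the input list is returned as-is, which is the no-op path that models without dense layers take.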

DeepSeekV4 now overrides only _prepare_hidden_states_mb and
_call_decoder_layer_mb, inherits the empty dense phase and shared tail, and
deletes its _micro_batch_forward copy.

Validation: base _forward (test_moe_config) and an all-dense mb skeleton smoke
(prepare + dense phase + tail) pass single-process. The MoE-layer call under EP
is unchanged logic but only runnable under ep>1, so it needs a distributed /
V4-EP run to confirm.
…orwards

Two compaction passes on the base forward methods (behavior-preserving, no
subclass changes — V4/qwen3vl inherit the new helpers):

- _activation_offload_ctx(block_idx, tensors, *, reserve_pin_memory): single
  helper that returns the offload window or a null context based on the env
  flag, so callers `with` it unconditionally. _decoder_layer_offload_ctx
  (single path) delegates to it; the micro-batch loop drops its inline
  if offload_active / else nullcontext branch and calls it directly.
- _maybe_compute_mtp_loss / _maybe_compute_mtp_loss_mb: the ~55-line MTP blocks
  move out of _forward / _micro_batch_forward into dedicated methods that no-op
  when MTP is off. The forward bodies now read as
  prepare -> stack -> finalize -> lm_head -> mtp -> aux-finalize.
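
A sketch of the "always use with" pattern the offload helper enables. Assumptions: the flag is passed as an argument instead of being read from the environment, and `_offload_window` only records events rather than moving tensors.

```python
from contextlib import contextmanager, nullcontext

events = []


@contextmanager
def _offload_window(block_idx):
    # Placeholder for the real offload window; records enter/exit only.
    events.append(("enter", block_idx))
    try:
        yield
    finally:
        events.append(("exit", block_idx))


def activation_offload_ctx(block_idx, offload_enabled):
    # Callers can write `with activation_offload_ctx(...)` unconditionally.
    return _offload_window(block_idx) if offload_enabled else nullcontext()


with activation_offload_ctx(0, offload_enabled=False):
    pass
with activation_offload_ctx(1, offload_enabled=True):
    pass
```

Returning nullcontext() removes the if/else branch from every call site.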

_forward shed its now-unused input_ids/position_ids locals (only MTP used them;
it reads them from seq_ctx). Verified by test_moe_config (single path) and an
all-dense mb skeleton smoke (prepare + dense phase + tail + mtp no-op).
… mirror

KVCompressor.forward used to compute total_c via `int(cu_seq_lens_out[-1].item())`
to size the chunk-buffer allocation. Under activation offload that .item() — a
synchronous D2H to pageable host memory — would queue behind in-flight offload
D2Hs on the (single, direction-specialized) D2H copy engine; profiler traces
showed the compressor's .item() stalling for the full remaining offload window
(~8 ms in the v4-flash configuration), even though the small D2H itself was on
a different stream. The hardware fact: NVIDIA copy engines are direction-
specialized and same-direction memcpys serialize on a single CE FIFO regardless
of CUDA stream.

Fix: derive the shape-defining quantities (total_c, cu_seq_lens_out) on CPU and
H2D the tiny B+1 ints. cu_seq_lens already originated on CPU inside
SequenceContext.from_input_ids (cumsum on a LongTensor, then .to(device)); we
just stop throwing the CPU copy away and plumb it through.

- SequenceContext gains cu_seq_lens_q_cpu / cu_seq_lens_k_cpu (optional fields,
  default None for backward compatibility). from_input_ids and cat populate
  them; split leaves them None for now (SP path falls back).
- KVCompressor.forward gains an optional cu_seq_lens_cpu kwarg. When provided,
  q_lens/c_lens/total_c/cu_seq_lens_out are all derived on CPU and the result
  H2D'd non-blocking (B+1 int32s, uses the H2D engine which doesn't share the
  D2H queue with offload). The old GPU+.item() path remains as the fallback.
- DSA.forward threads seq_ctx.cu_seq_lens_q_cpu into both the compressor and
  the indexer; Indexer.forward propagates it to its private KVCompressor.
- SequenceContext.from_input_ids also stops going through GPU for
  max_length_*: the same CPU cumsum was already being thrown away, now reused.
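
The CPU-side derivation can be sketched without any device work. Assumption: each sample contributes floor(len / compress_ratio) compressed chunks; the real rounding rule is not stated here, and `build_cu_seq_lens_out_cpu` is a hypothetical name.

```python
from itertools import accumulate


def build_cu_seq_lens_out_cpu(cu_seq_lens_cpu, compress_ratio):
    q_lens = [b - a for a, b in zip(cu_seq_lens_cpu, cu_seq_lens_cpu[1:])]
    # Assumption: floor division; a sample shorter than the ratio yields 0 chunks.
    c_lens = [n // compress_ratio for n in q_lens]
    cu_out = [0, *accumulate(c_lens)]
    total_c = cu_out[-1]  # plain Python int, no device sync needed
    return cu_out, total_c
```

Only the B+1 boundary values then need to be copied host-to-device, and total_c is available for the buffer allocation without a .item() call.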

Verified:
- Bit-identical output between old and new path (max abs diff = 0.0 over a
  varlen smoke).
- Under simulated offload pressure (2x 200MB D2H in flight on offload stream),
  compressor host-block drops from 8213us to 656us (~12.5x). Baseline (no
  offload pressure) is unchanged.

Refs the activation-offload / copy-engine analysis tracked locally.
…ad of leaving None

The prior commit added optional cu_seq_lens_q_cpu / cu_seq_lens_k_cpu fields but
left them defaulting to None, which meant any caller that didn't explicitly pass
them (raw __init__, the SP split path, future code paths) silently regressed to
the slow GPU + .item() path in KVCompressor.

Derive the CPU mirrors from cu_seq_lens_q / cu_seq_lens_k at __init__ when the
caller did not pass them. The derivation is a one-shot D2H at construction time;
SequenceContext is constructed before forward in every existing path
(dataloader, SP split), so the D2H happens when the offload queue is empty.
The only forward-time constructor (cat) populates the mirrors directly via the
already-on-CPU parts, so it never triggers this default branch.

Effect: cu_seq_lens_*_cpu is now guaranteed non-None across all construction
paths, including those that didn't go through from_input_ids.
Adds an opt-in path through DeepSeek's TileKernels ``mhc`` backend (Hopper SM90+
TileLang JIT) for the HC primitives on the V4 forward graph. Default is OFF;
set ``XTUNER_USE_MHC_KERNELS=1`` to enable. ``XTUNER_V4_HF_PARITY=1`` still
takes precedence as the bit-exact trust anchor.
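
The precedence between the two env toggles can be written as a small selector. A sketch: the env var names come from this commit, while `select_hc_backend` and the returned labels are hypothetical.

```python
def select_hc_backend(environ) -> str:
    # HF parity wins even when the MHC kernels are also requested.
    if environ.get("XTUNER_V4_HF_PARITY", "0") == "1":
        return "hf_parity"
    if environ.get("XTUNER_USE_MHC_KERNELS", "0") == "1":
        return "mhc"
    return "default"
```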

xtuner/v1/ops/mhc/__init__.py wraps six TileLang kernels behind
``torch.library.custom_op`` (fwd + register_fake + register_autograd) so the
calls participate in autograd through the matching TileKernels ``*_bwd``
kernels:

  * mhc_head_compute_mix  → DeepSeekV4._hc_head_sigmoid_gate
  * mhc_post              → hc_block.hc_post
  * mhc_expand            → DeepSeekV4._expand_hc_streams
  * mhc_sinkhorn          → hc_sinkhorn.hc_split_sinkhorn (iter loop)
  * mhc_pre_split_mixes   → hc_sinkhorn.hc_split_sinkhorn (split + sigmoid)
  * mhc_pre_apply_mix     → hc_block.hc_pre (weighted reduce tail)

Parity vs the eager / HF path: fp32 ULP on the fp32 ops, bf16 ULP on the bf16
ops; e2e V4 SFT matches the MHC=0 loss within 1e-3, and peak tgs rises ~4%
(22246 → 23157 on 8×H200 / pack=12288).

Side changes pulled in with the same train run:

  * moe.py: skip activation_offload window on the last decoder layer (no
    next-layer consumer means the H2D stash is pure overhead).
  * ci/config/deepseek_v4_flash.py: bump ep_size to 8, switch compile_cfg off
    for the MHC-on path (TileLang kernels are not yet in the compile target
    list), add num_workers=4 to the dataloader.
The triton indexer top-k held a per-query ``[next_pow2(total_c)]`` score buffer
collapsed by a single final ``tl.topk``, so every query paid a top-k sized at the
global compressed length regardless of its causal horizon. Runtime scales ~linearly
with that buffer width, so the many short sub-samples in a packed varlen batch each
paid an up-to-4096-wide topk for a horizon of a few dozen positions.

Replace it with a running ``[next_pow2(K + BLOCK_C)]`` top-K merged tile-by-tile:
the collapse cost now scales with each query's horizon (number of c-tiles), so a
short-sample query collapses a 1024-wide buffer instead of 4096. Exact-parity with
the old kernel (same packing / tie-break), and compile-safe (single launch, static
grid). ~2.16x faster at V4 dims on realistic many-sample packs; single-/few-sample
long packs run more collapses per query and regress, which is acceptable because
packed varlen batches are dominated by short samples.
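
The running top-K idea can be checked against a one-shot top-k in plain Python. This is a host-side sketch, not the Triton kernel: heapq replaces tl.topk, and breaking ties toward the lower index is an assumption.

```python
import heapq


def topk_full(scores, horizon, k):
    # One-shot top-k over the whole causal horizon.
    cand = [(s, -i) for i, s in enumerate(scores[:horizon])]
    return [-neg_i for _, neg_i in heapq.nlargest(k, cand)]


def topk_running(scores, horizon, k, block_c):
    # Keep at most k survivors and merge one BLOCK_C-wide tile at a time,
    # so the working set never exceeds k + block_c candidates.
    running = []
    for start in range(0, horizon, block_c):
        stop = min(start + block_c, horizon)
        tile = [(scores[i], -i) for i in range(start, stop)]
        running = heapq.nlargest(k, running + tile)
    return [-neg_i for _, neg_i in running]
```

The number of merge steps grows with the horizon, so a short-horizon query does proportionally less work, while a single long sample pays for more merges.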
…_out reuse

Two changes that co-evolve the KVCompressor.forward signature (compressed-rope
table + precomputed boundaries are added side by side), so they land together.

1. Compressed-kv RoPE. After the chunk softmax + norm, rotate each compressed
   chunk's rope tail at its window-center position, mirroring HF
   DeepseekV4{CSA,HCA}Compressor.forward. ``qk_rope_head_dim`` is wired from the
   DSA/Indexer configs into the internal KVCompressor, and DSA now forwards
   ``position_embeddings_compressed`` to the compressor (required for
   compress_ratio > 0, not just == 4). The chunk->sample map uses
   ``searchsorted(cu_seq_lens_out, ., right=True) - 1`` — right=True is
   load-bearing: a chunk on a sample boundary is the first chunk of the next
   sample, and mapping it to the previous one overruns ``first_token_per_chunk``
   and indexes the rope table out of bounds.

2. Hoist cu_seq_lens_out. ``KVCompressor.build_cu_seq_lens_out`` computes the
   per-sample compressed boundaries once; DeepSeekV4 forward builds one per
   distinct compress_ratio and caches it on
   ``SequenceContext.compressed_cu_seq_lens``, so every decoder layer of that
   ratio reuses a single cumsum + H2D instead of recomputing it. ``total_c``
   stays derived in the compressor from the CPU mirror (it must remain a Python
   int and would force a recompile if threaded through the compiled attn graph).
   Standalone callers (no cache on seq_ctx) fall back to building it in-place.
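
The right=True point in item 1 can be reproduced with the standard library, since bisect_right and bisect_left use the same conventions as searchsorted with right=True and right=False. The boundary values are made up for illustration.

```python
from bisect import bisect_left, bisect_right

# Two samples with 2 and 3 compressed chunks: chunks 0-1 and 2-4.
cu_seq_lens_out = [0, 2, 5]


def chunk_to_sample(chunk_idx, right=True):
    search = bisect_right if right else bisect_left
    return search(cu_seq_lens_out, chunk_idx) - 1


with_right = [chunk_to_sample(c) for c in range(5)]
without_right = [chunk_to_sample(c, right=False) for c in range(5)]
```

With right=True, chunk 2 (which sits exactly on a boundary) maps to sample 1. With right=False it maps to sample 0, and chunk 0 maps to -1, which is the off-by-one that leads to the out-of-bounds rope lookup.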
Pass ``fused=True`` to ``torch.optim.AdamW`` on the non-foreach path. The fused
kernel folds the per-parameter element-wise AdamW update into a single CUDA
launch, cutting launch overhead and optimizer-step time at large parameter
counts.
… bs128

Switch the V4-Flash reference config from the 4-layer smoke setup to the full
run: drop the ``num_hidden_layers = 4`` cap (use the release's 43 layers), select
the fused Triton indexer top-k backend (``indexer_backend = "triton"``), turn
``compile_cfg`` on, and raise ``global_batch_size`` 16 -> 128.
…an attention/dsa subpackage

The DSA building blocks were scattered as 8 sibling modules under
``module/attention/`` (dsa, indexer, kv_compressor, sparse_attn, plus the private
_dsa_rope / _dsa_topk / _indexer_topk_triton / _flash_mla_sparse_attn kernels),
mixed in with the unrelated mha / mla / gated_deltanet / kv_cache modules. Move
them into a cohesive ``module/attention/dsa/`` subpackage whose ``__init__``
re-exports the public API (DeepSeekSparseAttention, DSAConfig, Indexer,
IndexerConfig, sparse_attn), so ``from ...attention.dsa import DeepSeekSparseAttention``
and the ``attention`` package facade keep working unchanged.

Internals only: relative imports that reach outside the package (rms_norm,
attn_outputs) gain one level; the DSA-internal cross-imports are untouched. The
V4 compile-target qualnames in deepseek_v4.py follow the new ``__module__``
(attention.dsa.dsa / attention.dsa.kv_compressor). The four DSA tests move to
``tests/module/dsa/`` to mirror the package. No behavior change.
…oder_layer/deepseek_v4 subpackage

The V4 decoder block and its Hyper-Connections helpers were three loose siblings
under ``module/decoder_layer/`` (deepseek_v4_decoder_layer, hc_block, hc_sinkhorn)
next to the unrelated dense / moe decoder layers. Move them into a cohesive
``module/decoder_layer/deepseek_v4/`` subpackage (decoder_layer.py + hc_block.py +
hc_sinkhorn.py) whose ``__init__`` re-exports V4DecoderLayer / V4FFNState plus the
HC API (HCWrapperConfig, hc_pre, hc_post, hc_split_sinkhorn). The ``decoder_layer``
package facade keeps re-exporting the HC symbols, so ``from ...decoder_layer import
HCWrapperConfig`` is unchanged.

The layer now imports hc_block relatively; moe_decoder_layer stays put. The V4
compile-target qualnames in deepseek_v4.py follow the new ``__module__``
(decoder_layer.deepseek_v4.{decoder_layer,hc_block}.*) — verified to still match
via get_function_full_qualname. Doc/test references updated. No behavior change.