
[None][perf] Scheme X L2-aware dispatcher and PDL launchers for sparse-attention GVR Top-K #13477

Open · longcheng-nv wants to merge 6 commits into NVIDIA:main from longcheng-nv:feat/gvr-topk-v123

Conversation

longcheng-nv (Collaborator) commented Apr 26, 2026

Summary

Follow-up to #12385 (Temporally-Correlated Heuristic-guided Indexer TopK).
This PR adds the per-(BS, N) Scheme X dispatcher, PDL launchers, and
CUDA Graph warmup, and renames the inner micro-kernel symbols to
GVR (Guess-Verify-Refine).

  • Scheme X dispatcher (v1.1 + v1.2.3): per-(BS, N) routing between the
    heuristic kernel and the radix fallback, derived from MultiProcessorCount
    and L2CacheSize queried once at runtime. Closes the regression band
    around BS=128 / N=70K where the heuristic alone was slower than the
    radix path.
  • PDL on heuristic launchers: switch launchHeuristicTopKDecode to
    cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization;
    call cudaTriggerProgrammaticLaunchCompletion() at the kernel epilogue
    (see the device-side sketch after this list). Symmetric with the radix
    path, which already used PDL via invokeIndexerTopKDecode.
  • CUDA Graph safety: warmup_heuristic_topk_decode helper invoked from
    the Indexer setup hook (layer_idx == 0) so the dispatcher's
    cudaGetDevice / cudaDeviceGetAttribute queries land outside any
    capture region; the cached sm_count and L2CacheSize are populated
    before any graph capture begins.
  • Rename: inner micro-kernel symbols change from heuristicTopK* to
    gvrTopK* (GVR = Guess-Verify-Refine, the algorithm described in the
    upcoming algorithm note). Public dispatcher / launcher names are unchanged.
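
For reference, a minimal device-side sketch of the PDL epilogue described in the PDL bullet above. It is illustrative only: the kernel name and body are hypothetical, while the guard mirrors the __CUDA_ARCH__ >= 900 gate used in this PR.

```cuda
// Illustrative PDL-aware kernel epilogue. Once the kernel's dependent work is
// done, it signals programmatic launch completion so the next kernel launched
// with cudaLaunchAttributeProgrammaticStreamSerialization can start early.
__global__ void pdlAwareKernelSketch(float const* in, float* out, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = in[i]; // stand-in for the real per-element work
    }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    // Hopper+ only, matching the __CUDA_ARCH__ >= 900 guard in this PR.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
```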

Commit Breakdown (5 commits)

  1. fix: Warm up heuristic TopK dispatcher for CUDA Graph safety
  2. feat: L2-aware BS-threshold dispatcher (Scheme X v1.1) for GVR Top-K
  3. feat: Lower GVR Top-K small-N threshold to 12288 (Scheme X v1.2.3)
  4. perf: Enable PDL on heuristic TopK launcher and kernel epilogues
  5. refactor: Rename inner GVR micro-kernel symbols

Key Files

| File | Change |
| --- | --- |
| cpp/tensorrt_llm/kernels/heuristic_topk.cuh | GVR rename + phase comments (P1+P2 = Guess, P3 = Verify, P4 = Refine) + PDL trigger in standalone kernel |
| cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu | PDL via cudaLaunchKernelEx + GVR rename + 2 PDL triggers |
| cpp/tensorrt_llm/kernels/indexerTopK.cu | Scheme X v1.1 + v1.2.3 dispatcher (kBsWave, kBsL2, kSeqSmall=12288) |
| tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py | warmup_heuristic_topk_decode Python helper |
| tensorrt_llm/_torch/attention_backend/sparse/dsa.py | Indexer setup hook calling the warmup helper |

API

No new user-facing API. The dispatcher is engaged automatically when
enable_heuristic_topk=True (per DeepSeekSparseAttentionConfig, already
shipped in #12385). Override knob for benchmarking: the
TRTLLM_HEURISTIC_NMIN=<int> env var (defaults to the dispatcher-decided
threshold).

Test plan

  • Unit tests: pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py -k test_indexer_topk_decode: 284 passed / 0 failed (32.4 s)
  • Scheme X T2 / T3 / T4 / T5 verification suite on B200 sm_100 — all PASS
  • Multi-day regression suite (nsys pooled / CUDA Graph / Ragged / CDF) — all clean
  • All output indices verified identical to torch.topk (set equality; outputs are unordered; see the comparison sketch after this list)
  • CUDA Graph capture/replay safe with the new warmup helper
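
The set-equality check referenced above, as a minimal sketch (hypothetical helper name; the real assertions live in test_indexer_topk.py):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Compare two top-k index lists as sets: the kernels do not promise an
// ordering within the top-k, so sort both sides before comparing.
static bool sameTopKSet(std::vector<int32_t> a, std::vector<int32_t> b)
{
    if (a.size() != b.size())
    {
        return false;
    }
    std::sort(a.begin(), a.end());
    std::sort(b.begin(), b.end());
    return a == b; // equal as multisets => same top-k selection
}
```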

Correctness coverage

| Test | Cases | Result | Notes |
| --- | --- | --- | --- |
| test_indexer_topk_decode | 32 | ✅ PASS | bs={1,64,512,2048} × next_n={1,2} × index_topk={2048,128} × num_tokens={4K,8K} |
| test_indexer_topk_decode_dist | 252 | ✅ PASS | Distribution-parameterized: beta / lognorm / logistic / weibull_min × MTP × success_ratio |
| Scheme X T2 (indices) | 7 N values | ✅ PASS | 12288 / 13312 / 14336 / 15360 / 16384 / 32768 / 70688 — Heuristic + Fallback |
| Scheme X T3 (crossover band) | 6 N values | ✅ PASS | Heuristic ≤ Fallback at all N in [12288, 16384) |
| Scheme X T5 (Scenario B) | 7 N values | ✅ PASS | preIdx=null path byte-identical to torch.topk |

Performance Results (B200 sm_100)

The heuristic TopK micro-kernel (gvrTopKJob — single-CTA single-row,
called from heuristicTopKMultiRowKernel) is benchmarked against the
default radix-sort path (topKPerRowDecode).

BS=1 single-op vs N — realistic input (DeepSeek V3.2 SWE-Bench-64K decode logits)

Profiled across 9 layers × 17 decode steps (N ≈ 68.7K – 70.7K):

| Layer | v0 latency (μs) | Speedup vs Radix |
| --- | --- | --- |
| L0 | 33.01 | 1.602× |
| L1 | 29.57 | 1.767× |
| L20 | 29.54 | 1.757× |
| L21 | 24.22 | 2.115× |
| L22 | 25.58 | 2.003× |
| L40 | 26.13 | 1.963× |
| L41 | 26.64 | 1.896× |
| L42 | 24.72 | 2.144× |
| L60 | 25.30 | 2.027× |

Average BS=1 speedup: 1.91× across 9 layers (consistent with the
1.81× on the synthetic DeepSeek workload reported in #12385). T4 spot
check at N=70688 with random-correlated preIdx: 34.82 μs Heuristic vs
55.30 μs Radix → 1.59×.

BS sweep — pooled across 9 layers × 16 rows

| BS | v0 latency (μs) | Speedup vs Radix |
| --- | --- | --- |
| 1 | 26.40 | 1.958× |
| 2 | 26.80 | 1.940× |
| 4 | 26.72 | 1.917× |
| 8 | 27.04 | 1.903× |
| 16 | 27.22 | 1.902× |
| 32 | 27.81 | 1.876× |
| 64 | 28.64 | 1.839× |
| 128 | 31.10 | 1.738× |
| 148 | 32.05 | 1.704× |
| 256 | 45.18 | 1.528× |
| 312 | 60.64 | 1.399× |
| 400 | 67.34 | 1.360× |
| 432 | 92.53 | 1.000× ◀ dispatcher routes to Radix at BS ≥ 432 (Scheme X kBsLarge) |
| 512 | 133.39 | 1.000× |

The Scheme X dispatcher routes BS ≥ 432 back to the radix path on a
148-SM B200, where the heuristic loses its lead past the wave-occupancy
crossover. All 15 BS values are within ±2 μs of the v1.1 baseline
(max abs Δ = -1.09 μs, mean Δ = -0.07 μs).

Sweet spot — small-N + large-BS (Scheme X v1.2.3 vs ForceHeuristic)

kSeqSmall=12288 lets the dispatcher route small-N + large-BS regimes
to the radix path where Heuristic histogram overhead dominates:

| BS \ N | 4096 | 8192 | 12288 | 16384 |
| --- | --- | --- | --- | --- |
| 1 | 1.546× | 1.427× | 0.959× | 0.999× |
| 64 | 1.832× | 1.785× | 1.168× | 0.999× |
| 128 | 1.834× | 1.855× | 1.208× | 1.000× |
| 256 | 4.357× | 2.470× | 1.125× | 1.000× |

Crossover band [12288, 16384) — T3 verification

At kSeqSmall=12288, the heuristic kernel must remain ≤ the radix
fallback in the BS=1 + correlated-preIdx case:

| N | Heur (correlated, μs) | Heur (random, μs) | Radix fallback (μs) | OK |
| --- | --- | --- | --- | --- |
| 12288 | 22.56 | 20.51 | 24.58 | ✓ |
| 13312 | 22.56 | 24.58 | 24.61 | ✓ |
| 14336 | 20.51 | 22.53 | 24.58 | ✓ |
| 15360 | 22.54 | 22.59 | 26.62 | ✓ |
| 16383 | 22.59 | 22.56 | 26.66 | ✓ |
| 16384 | 24.58 | 22.56 | 26.62 | ✓ |

Regression guard

Verified against the Scheme X v1.1 baseline:

| Verification | Coverage | Result |
| --- | --- | --- |
| nsys pooled | 2160 cells (15 BS × 9 layers × 16 rows × 2 variants) | v0 max abs Δ = -1.09 μs (BS=512), mean Δ = -0.07 μs |
| CUDA Graph | 18 cells (6 configs × 3 BS) | 18/18 capture=True & correctness=True; replay/eager ratio max Δ = +0.058 (cold capture warmup) |
| Ragged | 33 intra + 400-step Poisson trace | 33/33 correctness=True; 400/400 GVR-routed; Σv0_us = 20619.7 (vs 20773.7, -0.7%) |
| CDF data-sim | 9 layers × 2024 rows × 18216 total | 9 layers byte-identical to v1.1 baseline |

GitHub Bot Help

/bot run --disable-fail-fast

Summary by CodeRabbit

  • New Features

    • Added warmup utility function for heuristic top-K decode to pre-populate caches before CUDA Graph capture.
  • Bug Fixes

    • Improved tie-breaking in top-K selection to prevent dropping strictly-greater candidates at the threshold boundary.
  • Performance

    • Optimized heuristic top-K decode dispatch with refined path selection logic and hardware-aware thresholds.
    • Enhanced execution efficiency on newer GPU architectures with programmatic launch completion support.
  • Configuration

    • Added environment variables (TRTLLM_HEURISTIC_NMIN, TRTLLM_SCHEMEX_DEBUG, PDL control) for tuning heuristic top-K behavior.

longcheng-nv and others added 5 commits April 26, 2026 14:04
Add ``warmup_heuristic_topk_decode`` helper that issues one small
``indexer_topk_decode`` call from the ``Indexer`` setup hook before
any CUDA Graph capture begins. This forces the C++ Scheme X
dispatcher to run its one-time ``cudaGetDevice`` /
``cudaDeviceGetAttribute`` host queries outside capture, so the
cached ``sm_count`` and ``L2CacheSize`` values are populated up front
and not frozen into a captured graph.

The warmup is gated on ``enable_heuristic_topk`` to match the runtime
configuration already used to select the heuristic path; cold-start
rows stay cold (no ``threshold_pred`` is passed), so Opt-M semantics
are unaffected.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
feat: L2-aware BS-threshold dispatcher (Scheme X v1.1) for GVR Top-K

Host-side dispatcher in `invokeIndexerTopKDecode` routes between the
heuristic and radix Top-K paths by comparing `numRows` against an
architecture-derived threshold:

  kBsLarge = min(3*SM - SM/8, 0.9*L2 / (4*N))

The occupancy bound (3*SM - SM/8) reflects the per-CTA SMEM budget
(~58 KB kernel SMEM vs B200's 228 KB dynamic SMEM → 3 CTA/SM max),
with a -SM/8 margin for CTA-setup and L2-ingestion overhead.
On B200 (148 SM): 3×148 − 18 = 426.

The L2 bound (0.9*L2 / (4*N)) reflects per-row logits fit into the
GPU L2 — once `concurrent_CTAs × N × 4B` exceeds L2, eviction
dominates and the heuristic kernel loses to radix. The two
constraints cross near N ≈ 73K; for SWE-Bench N ≈ 70K both yield
≈426, so this is a zero-regression change for DSv3.2 decode. For
larger N (e.g. 128K → 227, 196K → 148), the L2 bound auto-tightens
and preserves the no-regression guarantee.

Both hardware attributes (`cudaDevAttrMultiProcessorCount`,
`cudaDevAttrL2CacheSize`) are queried once and static-cached. Zero
data-dependent hyperparameters; zero kernel changes; +10 host LOC.
An opt-in `TRTLLM_SCHEMEX_DEBUG=1` env var emits a per-launch
dispatch trace for introspection.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
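
A host-side sketch of the threshold arithmetic in the commit above (a schematic under the commit's stated formula, not the exact source):

```cpp
#include <algorithm>
#include <cstdint>

// Scheme X v1.1 bound: the heuristic path is taken only while
// numRows < kBsLarge = min(occupancy bound, L2 bound).
static int schemeXBsLarge(int smCount, int64_t l2Bytes, int numColumns)
{
    // Occupancy bound: ~3 CTAs/SM from the SMEM budget, minus an SM/8 margin.
    int const kBsWave = smCount * 3 - smCount / 8; // B200: 3*148 - 18 = 426
    // L2 bound: concurrent rows whose fp32 logits still fit in 90% of L2.
    int const kBsL2
        = static_cast<int>(l2Bytes * 9 / 10 / (static_cast<int64_t>(numColumns) * 4));
    return std::min(kBsWave, kBsL2 > 0 ? kBsL2 : kBsWave);
}
```

On a 148-SM B200 at N ≈ 70K both bounds land near 426; for larger N the L2 bound tightens first, matching the 128K → 227 figure above.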
feat: Lower GVR Top-K small-N threshold to 12288 (Scheme X v1.2.3)

Add a small-N lower bound `kSeqSmall` to `canUseHeuristic` so the
GVR heuristic Top-K path engages wherever the original TRT-LLM
Radix-radix branch would have triggered. Below `kSeqSmall` the
existing Insertion / Radix-radix path stays active.

Rationale:

* GVR Heuristic has a fixed per-launch overhead (P1 preIdx stats
  + P4 2048-bin histogram snap, ~11 us regardless of N). For small
  N the fixed cost dominates and the kernel loses to the existing
  insertion-sort path. The crossover N depends on data: random
  benchmarks show Heuristic reaching parity at N=16384, but real
  SWE-Bench workloads see Heuristic ~6.3 us faster than random
  (preIdx-vs-logits ~99% correlated -> P1 stats accurate -> P2
  secant converges in 1-2 iterations), shifting the real crossover
  into the [12288, 16384] band.

* Setting `kSeqSmall = 12288` lets the Heuristic axis take over
  wherever the original Radix-radix branch would fire on real
  workload, while keeping N < 12288 on the insertion path (where
  GVR's fixed overhead remains uncompetitive).

The original radix dispatcher constants (`kSortingAlgorithmThreshold
= 12288`, `kDefaultSplitWorkThreshold = 200000`) are NOT touched --
when `canUseHeuristic` is false (e.g. preIdx missing, BS too large,
N < kSeqSmall), the dispatcher falls back to BYTE-IDENTICAL
pre-Scheme-X behavior.

Tunables:

* `TRTLLM_HEURISTIC_NMIN` env (range [1024, 200000]) overrides the
  default kSeqSmall at process start; cached after first query.

* `TRTLLM_SCHEMEX_DEBUG=1` env now prints `kSeqSmall` plus a
  "(small-N route)" marker when N < kSeqSmall, alongside the
  existing `kBsWave / kBsL2 / kBsLarge` trace.

Validation (B200, single-GPU):

* Indices set match `torch.topk` for N in {12288, 13312, 14336,
  15360, 16384, 32768, 70688} -- both Heuristic and fallback paths.

* Crossover-band perf at [12288, 16384) (correlated preIdx, BS=1):
  Heuristic 20-23 us vs Radix-radix fallback 24-26 us -- 1.09-1.20x
  speedup at the new boundary.

* Production N=70688 path unchanged: 32.77 us Heuristic vs 55.30 us
  Radix-radix fallback (1.69x preserved).

* Scenario B regression guard (preIdx=null) -- all N produce
  indices set matching `torch.topk` reference.

No kernel-body or radix dispatcher changes; +57 lines (mostly
explanatory comments around the new `kSeqSmall` cache + the updated
Heuristic eligibility predicate), -5 lines.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wire the heuristic TopK decode path into the project's Programmatic
Dependent Launch (PDL) pipeline so it stays symmetric with the radix
and insertion fallbacks (both of which already use PDL via
``cudaLaunchKernelEx`` in ``invokeIndexerTopKDecode``).

Changes:

* ``launchHeuristicTopKDecode`` now launches via ``cudaLaunchKernelEx``
  with ``cudaLaunchAttributeProgrammaticStreamSerialization`` set from
  ``tensorrt_llm::common::getEnvEnablePDL()``. When
  ``TRTLLM_ENABLE_PDL=1`` (the default in production), the heuristic
  kernel is allowed to start before the preceding PDL kernel's tail
  drains, giving the same overlap the radix path already enjoyed.

* Both heuristic kernel entry points -- the multi-row
  ``heuristicTopKMultiRowKernel`` and the standalone
  ``heuristicTopKKernel`` -- now call
  ``cudaTriggerProgrammaticLaunchCompletion()`` at their exit points
  (guarded by ``__CUDA_ARCH__ >= 900``) so the next PDL kernel can
  likewise pre-launch.

No functional change when PDL is disabled. No kernel-body or
dispatcher-logic changes; this only adjusts the launch attributes
and adds the device-side trigger.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rename the single-CTA single-row micro-kernel symbols inside
`heuristic_topk.cuh` from the `heuristicTopK*` family to the
`gvrTopK*` family, to align the in-source vocabulary with the
upcoming algorithm note "Guess-Verify-Refine: Data-Aware Top-K
for Sparse-Attention Decoding on Blackwell via Temporal
Correlation".

Renames (inner micro-kernel only):

* `heuristicTopKJob`    -> `gvrTopKJob`     (`__device__` per-CTA
                                              algorithm body)
* `heuristicTopKKernel` -> `gvrTopKKernel`  (`__global__` single-row
                                              wrapper used by the
                                              standalone launcher
                                              `launchHeuristicTopK`)

Outer wrappers keep their `heuristic*` / `kHeuristic*` names because
they sit on the public dispatcher / launch surface that callers
already depend on:

* `launchHeuristicTopKDecode`     -- multi-row decode launcher
* `heuristicTopKMultiRowKernel`   -- multi-row global wrapper
* `launchHeuristicTopK`           -- single-row standalone launcher
* `canUseHeuristic`, `kHeuristicTopK`, `kHeuristicSize` -- dispatcher
                                                          predicates +
                                                          public
                                                          constants

Phase comments inside `gvrTopKJob` are also annotated with the GVR
mapping for cross-reference with the algorithm note:

* P1 (preIdx Min/Max/Mean)        -> GVR Guess, part 1
* P2 (Secant threshold search)    -> GVR Guess, part 2
* P3 (Ballot-free collect)        -> GVR Verify
* P4 (Histogram snap + partition) -> GVR Refine

Pure source-level rename; no signature changes on any public symbol,
no kernel-body changes, no dispatcher-logic changes. Rebuilt
`tensorrt_llm` + `th_common` and re-ran the indices-correctness +
crossover-band + Scenario-B suite -- all PASS, with N=70688 timing
byte-identical to the pre-rename run (32.77 us / 1.69x vs
Radix-radix).

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
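
To make the Guess-Verify-Refine vocabulary concrete, here is a toy CPU analogue of the four phases (a sketch only, not the gvrTopKJob device code; the secant search and the 2048-bin histogram are replaced by simpler stand-ins):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Toy CPU analogue of the GVR phases (assumes preIdx is non-empty and
// k <= logits.size()). preIdx plays the role of the temporally-correlated
// hint: positions that were in the previous step's top-k.
std::vector<int> gvrTopKToy(std::vector<float> const& logits, std::vector<int> const& preIdx, size_t k)
{
    // P1 (Guess, part 1): seed a threshold from the hinted positions; the
    // minimum hinted value approximates the k-th largest logit.
    float threshold = logits[preIdx[0]];
    for (int idx : preIdx)
    {
        threshold = std::min(threshold, logits[idx]);
    }

    // P2 (Guess, part 2) + P3 (Verify): count survivors and lower the
    // threshold until at least k elements clear it. (The kernel runs a
    // secant search; a fixed-step walk stands in for it here.)
    auto countAbove = [&](float t) {
        size_t c = 0;
        for (float v : logits)
        {
            c += (v >= t) ? 1 : 0;
        }
        return c;
    };
    while (countAbove(threshold) < k)
    {
        threshold -= 1.0f;
    }

    // P4 (Refine): collect the survivors and keep exactly the k largest
    // (the kernel snaps with a 2048-bin histogram instead of a sort).
    std::vector<int> cand;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        if (logits[i] >= threshold)
        {
            cand.push_back(static_cast<int>(i));
        }
    }
    std::partial_sort(cand.begin(), cand.begin() + static_cast<std::ptrdiff_t>(k), cand.end(),
        [&](int a, int b) { return logits[a] > logits[b]; });
    cand.resize(k);
    return cand;
}
```

With a well-correlated hint the first Verify pass already counts close to k, which is why the correlated SWE-Bench traces converge faster than random data in the benchmarks above.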
longcheng-nv (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #45585 [ run ] triggered by Bot. Commit: 7f6ed4c

coderabbitai (bot) commented Apr 26, 2026

📝 Walkthrough

Implements a "Guess-Verify-Refine" (GVR) micro-kernel architecture for TopK decoding with renamed kernels (heuristicTopKJob → gvrTopKJob), reorganized phase control flow, two-pass candidate emission to avoid drops during ties, programmatic launch completion on SM 9.0+ (Hopper and newer) architectures, dynamic dispatch eligibility thresholds, and a Python warmup utility for cache pre-population before CUDA graph capture.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| GVR Kernel Implementation: cpp/tensorrt_llm/kernels/heuristic_topk.cuh, cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu | Renames heuristicTopKJob to gvrTopKJob and reorganizes kernel phases to the "Guess-Verify-Refine" mapping; replaces single-pass emission with two-pass (strictly-greater, then equal-to-threshold) to prevent candidate loss at the K-th boundary; adds cudaTriggerProgrammaticLaunchCompletion() for __CUDA_ARCH__ >= 900; switches the host launch to cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization controlled by the PDL environment variable. |
| Dispatch Eligibility Logic: cpp/tensorrt_llm/kernels/indexerTopK.cu | Computes cached hardware-derived thresholds (kBsWave, kBsL2, kBsLarge) based on device SM count and L2 cache; adds a small-N cutoff (kSeqSmall) from the TRTLLM_HEURISTIC_NMIN environment variable (default 12288); tightens canUseHeuristic conditions to require numColumns >= kSeqSmall and a valid scratch buffer; adds optional debug output via TRTLLM_SCHEMEX_DEBUG. |
| Python Warmup Integration: tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py | Adds a layer-0-only warmup call during Indexer initialization; introduces the public utility warmup_heuristic_topk_decode(), which executes a one-time heuristic TopK decode with minimal tensors to pre-populate static caches outside CUDA graph capture scope. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically summarizes the main changes: Scheme X L2-aware dispatcher and PDL launchers for sparse-attention GVR Top-K, directly reflecting the core performance improvements and refactoring objectives. |
| Description check | ✅ Passed | The PR description comprehensively covers all template sections, with clear explanations of what changed and why, extensive test coverage documentation, and performance results supporting the implementation. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


coderabbitai (bot) left a review comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)

1161-1191: Make the warmup idempotent.

This helper initializes process-global dispatcher caches, but it reruns the CUDA allocations and torch.cuda.synchronize() every time it is called. Memoizing it per (device, top_k, hint_size, num_cols) would avoid repeated init-time stalls when multiple models/indexers are constructed in one process.

♻️ Suggested direction
+_heuristic_topk_warmups: set[tuple[int, int, int, int]] = set()
+
 def warmup_heuristic_topk_decode(top_k: int = 2048,
                                  hint_size: int = 2048,
                                  num_cols: int = 4096) -> None:
+    device_idx = torch.cuda.current_device()
+    key = (device_idx, top_k, hint_size, num_cols)
+    if key in _heuristic_topk_warmups:
+        return
-    device = torch.device("cuda")
+    device = torch.device(f"cuda:{device_idx}")
     logits = torch.zeros((1, num_cols), dtype=torch.float32, device=device)
     seq_lens = torch.tensor([num_cols], dtype=torch.int32, device=device)
     indices = torch.empty((1, top_k), dtype=torch.int32, device=device)
     pre_idx = torch.zeros((1, hint_size), dtype=torch.int32, device=device)
     scratch = torch.empty((top_k, ), dtype=torch.float32, device=device)
     torch.ops.trtllm.indexer_topk_decode(logits,
                                          seq_lens,
                                          indices,
                                          1,
                                          top_k,
                                          pre_idx=pre_idx,
                                          heuristic_scratch=scratch)
     torch.cuda.synchronize()
+    _heuristic_topk_warmups.add(key)

ℹ️ Review info: configuration from .coderabbit.yaml; review profile CHILL.

📥 Commits

Reviewing files that changed from the base of the PR and between dd907c0 and 7f6ed4c.

📒 Files selected for processing (5)
  • cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu
  • cpp/tensorrt_llm/kernels/heuristic_topk.cuh
  • cpp/tensorrt_llm/kernels/indexerTopK.cu
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Comment on lines 802 to 806 of cpp/tensorrt_llm/kernels/heuristic_topk.cuh:

         cudaFuncSetAttribute(gvrTopKKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smemSize));
     }

-    heuristicTopKKernel<<<1, BLOCK_SIZE, smemSize, stream>>>(
+    gvrTopKKernel<<<1, BLOCK_SIZE, smemSize, stream>>>(
         input, N, preIdx, M, topK, outputValues, outputIndices, thresholdPos);
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Call sites / definition for launchHeuristicTopK:"
rg -nP '\blaunchHeuristicTopK\s*\('

echo
echo "Current single-row launcher:"
sed -n '790,807p' cpp/tensorrt_llm/kernels/heuristic_topk.cuh

echo
echo "PDL-enabled decode launcher for comparison:"
sed -n '113,125p' cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu

🏁 Script executed:

# Check for all call sites of launchHeuristicTopK across the repository
rg -n '\blaunchHeuristicTopK\s*\(' --type cpp --type cu --type cuh -A 2

echo "---"
echo "Checking gvrTopKKernel implementation for cudaTriggerProgrammaticLaunchCompletion:"
sed -n '770,780p' cpp/tensorrt_llm/kernels/heuristic_topk.cuh

🏁 Script executed:

# Search for all references to launchHeuristicTopK in the repo
rg -i 'launchHeuristicTopK' --no-type -n

echo "---"
echo "Checking if launchHeuristicTopK is declared in any header:"
rg -i 'launchHeuristicTopK' --no-type -n | grep -E '\.(h|hpp|cuh):'

🏁 Script executed:

# Simple search for all launchHeuristicTopK references
rg 'launchHeuristicTopK' -n

echo "---"
echo "Check the full function definition and context:"
sed -n '782,810p' cpp/tensorrt_llm/kernels/heuristic_topk.cuh


Enable PDL on launchHeuristicTopK() to match the kernel's PDL-aware design.

The kernel gvrTopKKernel calls cudaTriggerProgrammaticLaunchCompletion() (lines 772–774, Hopper+ only), but this launcher uses <<<...>>> syntax and doesn't enable PDL. While no internal call sites exist, the function is template-instantiated and exported (line 811), making it part of the public API. The mismatch between the kernel's PDL-aware completion signaling and the launcher's inability to enable stream serialization should be resolved.

Update the launcher to use cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization to match the pattern used in heuristicTopKDecode.cu (lines 113–125).

🔧 Suggested fix
-    gvrTopKKernel<<<1, BLOCK_SIZE, smemSize, stream>>>(
-        input, N, preIdx, M, topK, outputValues, outputIndices, thresholdPos);
+    cudaLaunchConfig_t config;
+    config.gridDim = 1;
+    config.blockDim = BLOCK_SIZE;
+    config.dynamicSmemBytes = smemSize;
+    config.stream = stream;
+
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+
+    cudaLaunchKernelEx(&config, gvrTopKKernel, input, N, preIdx, M, topK, outputValues, outputIndices, thresholdPos);

Comment on lines +740 to +803 of cpp/tensorrt_llm/kernels/indexerTopK.cu:

    static int sCachedSmCount = 0;
    static int sCachedL2Bytes = 0;
    if (sCachedSmCount == 0 || sCachedL2Bytes == 0)
    {
        int dev = 0;
        cudaGetDevice(&dev);
        cudaDeviceGetAttribute(&sCachedSmCount, cudaDevAttrMultiProcessorCount, dev);
        cudaDeviceGetAttribute(&sCachedL2Bytes, cudaDevAttrL2CacheSize, dev);
    }
    int const kBsWave = (sCachedSmCount > 0) ? (sCachedSmCount * 3 - sCachedSmCount / 8) : 426;
    int const kBsL2 = (sCachedL2Bytes > 0 && numColumns > 0)
        ? (int) ((int64_t) sCachedL2Bytes * 9 / 10 / ((int64_t) numColumns * 4))
        : kBsWave;
    int const kBsLarge = std::min(kBsWave, kBsL2 > 0 ? kBsL2 : kBsWave);

    // v1.2: small-N lower bound — set to kSortingAlgorithmThreshold (12288) so
    // the Heuristic axis takes over wherever the original Radix-radix branch
    // would have triggered. Random-data benchmarks suggested 16384, but real
    // SWE-Bench workloads see Heuristic ~6.3 us faster than random (preIdx-vs-
    // logits ~99% correlated → P1 stats accurate → P2 secant 1-2 iter), shifting
    // the real crossover into the [12288, 16384] band. Below 12288 the Insertion
    // path is still used (canUseHeuristic gating + dispatcher fallback both
    // route there). Configurable via TRTLLM_HEURISTIC_NMIN env (>=1024).
    static int sCachedNMin = 0;
    if (sCachedNMin == 0)
    {
        constexpr int kSeqSmallDefault = 12288;
        char const* env = std::getenv("TRTLLM_HEURISTIC_NMIN");
        if (env != nullptr)
        {
            int const v = std::atoi(env);
            sCachedNMin = (v >= 1024 && v <= 200000) ? v : kSeqSmallDefault;
        }
        else
        {
            sCachedNMin = kSeqSmallDefault;
        }
    }
    int const kSeqSmall = sCachedNMin;

    bool const canUseHeuristic = preIdx != nullptr && stride1 == 1 && topK == kHeuristicTopK
        && preIdxCount == kHeuristicSize && preIdxStride >= preIdxCount && numColumns < effectiveSplitWorkThreshold
-       && heuristicScratch != nullptr;
+       && numColumns >= kSeqSmall && heuristicScratch != nullptr && numRows < kBsLarge;

    // Optional env-gated dispatch trace (set TRTLLM_SCHEMEX_DEBUG=1 to enable)
    {
        static bool sDebugChecked = false;
        static bool sDebug = false;
        if (!sDebugChecked)
        {
            char const* env = std::getenv("TRTLLM_SCHEMEX_DEBUG");
            sDebug = (env != nullptr && env[0] == '1');
            sDebugChecked = true;
        }
        if (sDebug)
        {
            fprintf(stderr,
                "[Scheme X v1.2] numRows=%d numColumns=%d kBsWave=%d kBsL2=%d kBsLarge=%d kSeqSmall=%d smCount=%d "
                "L2=%dMB -> %s path%s\n",
                numRows, numColumns, kBsWave, kBsL2, kBsLarge, kSeqSmall, sCachedSmCount,
                sCachedL2Bytes / (1024 * 1024), canUseHeuristic ? "Heuristic" : "Radix",
                (numColumns < kSeqSmall) ? " (small-N route)" : "");
        }
    }
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the indexerTopK.cu file
find . -name "indexerTopK.cu" -type f

🏁 Script executed:

# Check the specific lines 740-803 in the file
if [ -f "cpp/tensorrt_llm/kernels/indexerTopK.cu" ]; then
  wc -l cpp/tensorrt_llm/kernels/indexerTopK.cu
fi

🏁 Script executed:

# Read lines around 740-803
if [ -f "cpp/tensorrt_llm/kernels/indexerTopK.cu" ]; then
  sed -n '740,803p' cpp/tensorrt_llm/kernels/indexerTopK.cu
fi

🏁 Script executed:

# Search for getEnvEnablePDL to verify the reference pattern
rg "getEnvEnablePDL" --type cpp --type h -A 10 -B 2

🏁 Script executed:

# Check for std::call_once usage in the codebase
rg "std::call_once" --type cpp --type h -B 2 -A 5

🏁 Script executed:

# Look at the function containing these statics to understand the context
grep -n "invokeIndexerTopKDecode" cpp/tensorrt_llm/kernels/indexerTopK.cu | head -5


Protect Scheme X cache initialization with std::call_once to prevent data races.

The static caches sCachedSmCount, sCachedL2Bytes, sCachedNMin, and the debug flags are mutable function-local statics updated without synchronization. In multi-threaded scenarios, multiple host threads calling invokeIndexerTopKDecode() concurrently can race during initialization—multiple threads may observe zero values, enter the if-block, and write to the same statics simultaneously. Use std::call_once with a static std::once_flag, following the established pattern in getEnvEnablePDL() and elsewhere in the codebase.

Suggested fix
+#include <mutex>
 ...
+    static std::once_flag sSchemeXInit;
     static int sCachedSmCount = 0;
     static int sCachedL2Bytes = 0;
-    if (sCachedSmCount == 0 || sCachedL2Bytes == 0)
+    static int sCachedNMin = 0;
+    static bool sDebugChecked = false;
+    static bool sDebug = false;
+    
+    std::call_once(sSchemeXInit, [] {
+        int dev = 0;
+        cudaGetDevice(&dev);
+        cudaDeviceGetAttribute(&sCachedSmCount, cudaDevAttrMultiProcessorCount, dev);
+        cudaDeviceGetAttribute(&sCachedL2Bytes, cudaDevAttrL2CacheSize, dev);
+        
+        constexpr int kSeqSmallDefault = 12288;
+        char const* env = std::getenv("TRTLLM_HEURISTIC_NMIN");
+        if (env != nullptr)
+        {
+            int const v = std::atoi(env);
+            sCachedNMin = (v >= 1024 && v <= 200000) ? v : kSeqSmallDefault;
+        }
+        else
+        {
+            sCachedNMin = kSeqSmallDefault;
+        }
+        
+        env = std::getenv("TRTLLM_SCHEMEX_DEBUG");
+        sDebug = (env != nullptr && env[0] == '1');
+        sDebugChecked = true;
+    });
-    {
-        int dev = 0;
-        cudaGetDevice(&dev);
-        cudaDeviceGetAttribute(&sCachedSmCount, cudaDevAttrMultiProcessorCount, dev);
-        cudaDeviceGetAttribute(&sCachedL2Bytes, cudaDevAttrL2CacheSize, dev);
-    }

tensorrt-cicd (Collaborator) commented:
PR_Github #45585 [ run ] completed with state SUCCESS. Commit: 7f6ed4c
/LLM/main/L0_MergeRequest_PR pipeline #35801 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


longcheng-nv added a commit:

Three fixes responding to the automated review of the Scheme X / GVR
Top-K PR:

1. heuristic_topk.cuh: switch launchHeuristicTopK to cudaLaunchKernelEx
   with cudaLaunchAttributeProgrammaticStreamSerialization so the kernel
   epilogue's cudaTriggerProgrammaticLaunchCompletion() actually takes
   effect. Honors TRTLLM_ENABLE_PDL=0 via std::getenv to stay self-
   contained (this header is also reused by the standalone JIT-compiled
   PyTorch extension under ablation_study/, which cannot pull in
   tensorrt_llm/common headers).

2. indexerTopK.cu: wrap the three function-local static caches inside
   invokeIndexerTopKDecode (sm count + L2 capacity, kSeqSmall, debug
   flag) in std::call_once with a once_flag to remove the data race on
   first concurrent entry. Pattern matches getEnvEnablePDL() in
   tensorrt_llm/common/envUtils.cpp.

3. cpp_custom_ops.py: add a module-level idempotency guard keyed by
   (device, top_k, hint_size, num_cols) around warmup_heuristic_topk_
   decode so repeated Indexer constructions in the same process do not
   re-allocate scratch tensors or issue redundant synchronizations.

Verified: rebuilt tensorrt_llm + th_common, smoke-tested both Heuristic
and Radix dispatch paths plus the standalone JIT extension.

Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
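
The self-contained PDL gate described in fix 1, as a minimal sketch (assumed shape and default handling; the real helper may differ):

```cpp
#include <cstdlib>

// Standalone check for TRTLLM_ENABLE_PDL so heuristic_topk.cuh can compile
// without pulling in tensorrt_llm/common. Here the variable must be set to
// "1" to enable PDL; the in-tree getEnvEnablePDL() default may differ.
static bool envEnablePdlStandalone()
{
    static int sCached = -1; // -1 = not yet queried
    if (sCached < 0)
    {
        char const* env = std::getenv("TRTLLM_ENABLE_PDL");
        sCached = (env != nullptr && env[0] == '1') ? 1 : 0;
    }
    return sCached == 1;
}
```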
longcheng-nv (Collaborator, Author) commented:
/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #45604 [ run ] triggered by Bot. Commit: 7c98137

tensorrt-cicd (Collaborator) commented:

PR_Github #45604 [ run ] completed with state SUCCESS. Commit: 7c98137
/LLM/main/L0_MergeRequest_PR pipeline #35819 completed with status: 'FAILURE'

longcheng-nv (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #45647 [ run ] triggered by Bot. Commit: 7c98137

tensorrt-cicd (Collaborator) commented:

PR_Github #45647 [ run ] completed with state FAILURE. Commit: 7c98137
/LLM/main/L0_MergeRequest_PR pipeline #35860 completed with status: 'ABORTED'

longcheng-nv (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #45666 [ run ] triggered by Bot. Commit: 7c98137

tensorrt-cicd (Collaborator) commented:

PR_Github #45666 [ run ] completed with state SUCCESS. Commit: 7c98137
/LLM/main/L0_MergeRequest_PR pipeline #35874 completed with status: 'SUCCESS'
