[TRTLLM-12242][feat] Add Marlin NVFP4 backend for MoE and Linear on Hopper#13476
Draft
xuantengh wants to merge 1 commit intoNVIDIA:mainfrom
Draft
[TRTLLM-12242][feat] Add Marlin NVFP4 backend for MoE and Linear on Hopper#13476xuantengh wants to merge 1 commit intoNVIDIA:mainfrom
xuantengh wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
Introduce a Hopper-only (SM 9.x) W4A16 NVFP4 execution path backed by the
Marlin kernel family. The new backend is strictly opt-in: users select it
via ``moe_config.backend: MARLIN`` for MoE layers and
``nvfp4_gemm_config.allowed_backends: [marlin]`` for dense / Mamba Linear
layers. Default paths on Blackwell (CUTLASS / cuBLASLt / CUDA-core /
CuteDSL) and the TRT backend are unchanged.
Kernels (cpp/):
* cpp/tensorrt_llm/kernels/marlin/: W4A16 NVFP4 GEMM, fused MoE GEMM, and
GPTQ-style weight repack kernels with BF16 MMA and in-register FP4→BF16
dequant.
* cpp/tensorrt_llm/thop/marlinNvfp4MM.cpp, marlinNvfp4MoeMM.cpp,
marlinRepack.cpp: Torch operator bindings registered as
``trtllm::marlin_nvfp4_gemm``, ``trtllm::marlin_nvfp4_moe_gemm``, and
``trtllm::gptq_marlin_repack``.
Python backend wiring:
* ``MarlinNVFP4Runner`` plugs into ``NVFP4GemmUnifiedRunner`` alongside
the existing CUTLASS / cuBLASLt / CUDA-core / CuteDSL runners; only
contributes tactics when SM is 90-99. Weight repack is done eagerly in
``get_valid_tactics`` so that ``forward()`` is CUDA-graph safe.
* ``MarlinNVFP4LinearMethod`` handles the dense Linear path, gated by
``nvfp4_allowed_backends == ["marlin"]``.
* ``MarlinFusedMoE`` (NVFP4 only) is registered in ``create_moe``,
``ConfigurableMoE``, and ``MoeConfig.backend`` literal.
* Attention / MLP / RMSNorm learn about the Marlin-only mode via a new
``is_marlin_only`` gate that disables CUTLASS-specific fused quantized
paths which cannot feed a Marlin Linear.
Production correctness:
* The existing SM 9.x/10.x range for ``fused_add_rms_norm_quant`` is
preserved; Marlin only further gates it when marlin-only is selected.
* ``MLP._use_fused_relu2_quant`` now requires SM >= 100 (matches the
kernel's ``__CUDA_ARCH__ >= 1000`` guard in fusedActivationQuant.cu)
and is also disabled for marlin-only mode.
* Mamba2 flashinfer path additionally requires ``d_state`` in
{64, 128, 256} so Nemotron-Super-class models fall back to the native
path instead of silently running an unsupported config.
Tests:
* ``tests/unittest/trt/functional/test_fp4_gemm.py``: standalone NVFP4
Marlin GEMM test (Hopper-only).
* ``tests/unittest/_torch/thop/parallel/test_fp4_linear.py``: Linear
``nvfp4_allowed_backends=['marlin']`` path test.
* ``tests/unittest/_torch/modules/moe/test_moe_backend.py``: fused Marlin
MoE GEMM test plus ``MoeBackendType.MARLIN`` wiring in backend/quantize
utilities.
* ``tests/integration/defs/accuracy/test_llm_api_pytorch.py``:
``TestNemotronHNvFP4Marlin.test_nvfp4_marlin`` MMLU accuracy task.
``examples/configs/curated/nemotron-super-marlin.yaml`` provides a ready
reference serving config.
Signed-off-by: Xuanteng Huang <xuantengh@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Introduce a Hopper-only (SM 9.x) W4A16 NVFP4 execution path backed by the Marlin kernel family. The new backend is strictly opt-in: users select it via
moe_config.backend: MARLINfor MoE layers andnvfp4_gemm_config.allowed_backends: [marlin]for dense / Mamba Linear layers. Default paths on Blackwell (CUTLASS / cuBLASLt / CUDA-core / CuteDSL) and the TRT backend are unchanged.Kernels (cpp/):
trtllm::marlin_nvfp4_gemm,trtllm::marlin_nvfp4_moe_gemm, andtrtllm::gptq_marlin_repack.Python backend wiring:
MarlinNVFP4Runnerplugs intoNVFP4GemmUnifiedRunneralongside the existing CUTLASS / cuBLASLt / CUDA-core / CuteDSL runners; only contributes tactics when SM is 90-99. Weight repack is done eagerly inget_valid_tacticsso thatforward()is CUDA-graph safe.MarlinNVFP4LinearMethodhandles the dense Linear path, gated bynvfp4_allowed_backends == ["marlin"].MarlinFusedMoE(NVFP4 only) is registered increate_moe,ConfigurableMoE, andMoeConfig.backendliteral.is_marlin_onlygate that disables CUTLASS-specific fused quantized paths which cannot feed a Marlin Linear.Production correctness:
fused_add_rms_norm_quantis preserved; Marlin only further gates it when marlin-only is selected.MLP._use_fused_relu2_quantnow requires SM >= 100 (matches the kernel's__CUDA_ARCH__ >= 1000guard in fusedActivationQuant.cu) and is also disabled for marlin-only mode.d_statein {64, 128, 256} so Nemotron-Super-class models fall back to the native path instead of silently running an unsupported config.Tests:
tests/unittest/trt/functional/test_fp4_gemm.py: standalone NVFP4 Marlin GEMM test (Hopper-only).tests/unittest/_torch/thop/parallel/test_fp4_linear.py: Linearnvfp4_allowed_backends=['marlin']path test.tests/unittest/_torch/modules/moe/test_moe_backend.py: fused Marlin MoE GEMM test plusMoeBackendType.MARLINwiring in backend/quantize utilities.tests/integration/defs/accuracy/test_llm_api_pytorch.py:TestNemotronHNvFP4Marlin.test_nvfp4_marlinMMLU accuracy task.examples/configs/curated/nemotron-super-marlin.yamlprovides a ready reference serving config.@coderabbitai summary
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.