[JAX] Support for cuDNN-backed flex attention #2985
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary
This PR adds an experimental cuDNN frontend
Confidence Score: 5/5
This PR is safe to merge: the core execution path is correct with no functional bugs found. The variant_pack UID-to-buffer mappings are correct, BHSD stride reinterpretation of BSHD tensors is accurate, packed-QKV unpacking covers all three layout branches, and the custom_vjp residuals carry exactly what the backward rule consumes. Version matching between the Python and C++ cuDNN frontend is enforced at both graph-build and execution time. All findings are non-blocking style and observability suggestions.
build_tools/jax.py: the silent omission of the cudnn-frontend include dir conflicts with the unconditional #include in attention.cpp and would produce confusing build errors.
Reviews (7): Last reviewed commit: "Address JAX score_mod review feedback"
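The "BHSD stride reinterpretation of BSHD tensors" mentioned in the summary can be illustrated without any GPU library. The sketch below is not the PR's code; the helper names (contiguous_strides, bshd_as_bhsd, offset) are hypothetical and only show the general idea of describing a BSHD buffer with BHSD logical dimensions by permuting strides instead of moving data.

```python
def contiguous_strides(shape):
    """Row-major element strides for a densely packed tensor of this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def bshd_as_bhsd(b, s, h, d):
    """Return BHSD logical dims and strides for a buffer laid out as BSHD.

    Hypothetical helper: only the dims and strides are permuted, the
    underlying memory is left untouched.
    """
    stride_b, stride_s, stride_h, stride_d = contiguous_strides((b, s, h, d))
    return (b, h, s, d), (stride_b, stride_h, stride_s, stride_d)


def offset(index, strides):
    """Flat element offset of a multi-dimensional index under the given strides."""
    return sum(i * st for i, st in zip(index, strides))


dims, strides = bshd_as_bhsd(2, 4, 3, 8)
print(dims, strides)
```

Both views address the same element: indexing the BHSD view at (batch, head, seq, dim) gives the same flat offset as indexing the original BSHD layout at (batch, seq, head, dim).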
Description
Adds experimental JAX fused-attention score_mod support through cuDNN frontend SDPA graphs.

This introduces a score_mod(graph, score, tensors) callback path for fused_attn, plus optional score_mod_bprop(graph, dscore, tensors) support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.

The Flax API now plumbs score_mod through DotProductAttention, MultiHeadAttention, and TransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.

Users are responsible for supplying a mathematically correct score_mod_bprop for the corresponding score_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.

Current score_mod limitations:
BSHD_BSHD_BSHD Q/K/V tensors only.

Fixes # (issue)
#2492
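Because gradient semantics are not validated, a forward/backward pair is worth checking numerically before it is wired into a graph. The following is a minimal, framework-free sketch of such a check. It does not use the score_mod(graph, score, tensors) callback signature or any cuDNN frontend call. The soft-cap modification, the CAP value, and all function names are assumptions chosen for illustration, operating on plain Python floats.

```python
import math

CAP = 30.0  # hypothetical soft-cap value, chosen only for illustration


def softcap(score):
    """Forward modification: squash a raw attention score into (-CAP, CAP)."""
    return CAP * math.tanh(score / CAP)


def softcap_bprop(score, dscore):
    """Hand-written backward: chain dscore through d/ds [CAP * tanh(s / CAP)]."""
    t = math.tanh(score / CAP)
    return dscore * (1.0 - t * t)


def finite_difference(fn, score, eps=1e-5):
    """Central-difference estimate of the derivative of fn at score."""
    return (fn(score + eps) - fn(score - eps)) / (2.0 * eps)


def max_bprop_error(fn, bprop, scores):
    """Largest gap between the analytic and numerical derivative over the samples."""
    return max(abs(bprop(s, 1.0) - finite_difference(fn, s)) for s in scores)


sample_scores = [-80.0, -5.0, 0.0, 0.5, 12.0, 80.0]
error = max_bprop_error(softcap, softcap_bprop, sample_scores)
print(f"max |analytic - numerical| = {error:.2e}")
```

A backward rule that does not match its forward, for example one that passes dscore through unchanged, shows up as a large error at scores where the soft cap saturates. The same comparison can be repeated on real tensors once the callbacks are expressed as graph operations.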