[PyTorch] Support for cuDNN-backed flex attention #2984
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds experimental cuDNN-backed flex attention support to
Confidence Score: 4/5
Safe to merge for non-training inference and explicitly-provided bprop cases; the shared per-device cuDNN handle and the untested omit-bprop training path are worth addressing before wider use.
The core graph-build, caching, and variant-pack wiring are correct for the fully-specified path (both score_mod and score_mod_bprop supplied). The main concerns are: the shared cuDNN handle being overwritten in multi-stream environments, the untested/undocumented gradient behaviour when score_mod_bprop is omitted in training mode, and the order-sensitive tensor-dict cache key. None of these block typical single-stream training use, but any could cause silent correctness issues in less common configurations.
flex_attention.py: per-device handle sharing and tensor-dict metadata ordering; test_flex_attention.py: training-without-bprop path not exercised and CPU-tensor device-mismatch bugs in CUDA correctness tests (noted in prior review threads).
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant DPA as DotProductAttention.forward
participant GAS as get_attention_backend
participant FA as FusedAttention.forward
participant FASMF as FusedAttentionWithScoreModFunc
participant Cache as _cudnn_score_mod_graph_cache
participant cuDNN as cuDNN Frontend
User->>DPA: forward(q,k,v, score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors)
DPA->>DPA: validate score_mod inputs
DPA->>GAS: "AttentionParams + AttentionRuntimeFlags(has_score_mod=True)"
GAS->>GAS: filter: disable Flash/Unfused, check dtype/format/mask/dropout
GAS->>GAS: select NVTE_F16_arbitrary_seqlen backend
GAS-->>DPA: "use_fused_attention=True"
DPA->>FA: forward(..., score_mod, score_mod_bprop, ...)
FA->>FA: "assert constraints (no FP8, no mask, vanilla softmax, dropout=0)"
FA->>FASMF: apply(is_training, q,k,v, score_mod, ...)
rect rgb(200, 230, 255)
Note over FASMF,cuDNN: Forward Pass
FASMF->>Cache: lookup fwd cache key
alt cache miss
FASMF->>cuDNN: "pygraph() + sdpa(score_mod=wrapped)"
cuDNN->>cuDNN: validate + build_operation_graph + build_plans
cuDNN-->>FASMF: compiled fwd graph entry
FASMF->>Cache: store entry
end
FASMF->>cuDNN: "execute(variant_pack={q,k,v,output,stats,score_mod_tensors})"
cuDNN-->>FASMF: output, stats
FASMF->>FASMF: save_for_backward(q,k,v,output,stats,...)
end
rect rgb(255, 220, 200)
Note over FASMF,cuDNN: Backward Pass
FASMF->>Cache: lookup bwd cache key
alt cache miss
FASMF->>cuDNN: pygraph() + sdpa_backward(score_mod, score_mod_bprop)
cuDNN->>cuDNN: validate + build_operation_graph + build_plans
cuDNN-->>FASMF: compiled bwd graph entry
FASMF->>Cache: store entry
end
FASMF->>cuDNN: "execute(variant_pack={q,k,v,o,dO,stats,dq,dk,dv,...})"
cuDNN-->>FASMF: dq, dk, dv
end
FASMF-->>User: output (fwd) / dq,dk,dv (bwd)
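The "lookup, build on miss, store" steps in the diagram follow an ordinary get-or-build pattern. The following is a minimal sketch of that pattern only, not the PR's implementation: `GraphCache`, `get_or_build`, and the example key are hypothetical names, the builder stands in for the cuDNN `pygraph()` + validate + build_plans sequence, and treating a key of `None` as "do not cache" is an assumption.

```python
from typing import Any, Callable, Dict, Hashable, Optional


class GraphCache:
    """Hypothetical get-or-build cache for compiled graph entries."""

    def __init__(self) -> None:
        self._entries: Dict[Hashable, Any] = {}
        self.builds = 0  # number of times a builder was actually invoked

    def get_or_build(self, key: Optional[Hashable], build: Callable[[], Any]) -> Any:
        """Return the cached entry for key, building and storing it on a miss."""
        if key is None:
            # Assumption: no stable key means a fresh graph on every call.
            self.builds += 1
            return build()
        if key not in self._entries:
            self.builds += 1
            self._entries[key] = build()
        return self._entries[key]


cache = GraphCache()
fwd_key = ("fwd", "bf16", (2, 8, 128, 64))  # illustrative key contents only
first = cache.get_or_build(fwd_key, lambda: object())
second = cache.get_or_build(fwd_key, lambda: object())
```

The second call returns the stored entry without invoking the builder again, which is the cache-hit path in the diagram.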
Reviews (6): Last reviewed commit: "Validate score_mod bprop tensor inputs"
| def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: | ||
| """Create a stable cache key for a score_mod callable.""" | ||
| if callback is None: | ||
| return None | ||
| self_obj = getattr(callback, "__self__", None) | ||
| func_obj = getattr(callback, "__func__", None) | ||
| if self_obj is not None and func_obj is not None: | ||
| return ("bound_method", id(self_obj), id(func_obj)) | ||
| return ("callable", id(callback)) |
id()-based cache key is unsafe for parameterized bound-method score_mods
id(self_obj) identifies a Python object by its memory address. When a bound-method instance is garbage-collected, Python may immediately reuse that memory for a new instance. If the new instance belongs to the same class (same id(func_obj)), the cache key is identical, so _get_cudnn_score_mod_fwd_graph returns the old compiled graph even though the new instance might construct a structurally different computation — e.g., a score_mod class whose forward loops self.n_layers times. The wrong graph is executed without any error, silently producing incorrect attention outputs.
For stateless module-level functions this is fine (they're never GC'd), but any stateful class-based score_mod where different instances produce different graph topologies can hit this bug in long-running programs. Consider using type(self_obj) and a per-class sequence counter, or requiring callers to provide an explicit cache key.
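One way to read the "explicit cache key" suggestion is sketched below, under stated assumptions. `score_mod_graph_cache_key()` is the hook named in the PR description; `UNCACHEABLE`, `callback_cache_key`, `LayeredScoreMod`, and `UnkeyedScoreMod` are hypothetical names introduced for illustration. The key is built from the owner's type and its self-reported configuration, so it does not depend on `id()` of an instance that may be garbage-collected.

```python
from typing import Any, Callable, Optional

UNCACHEABLE = object()  # hypothetical sentinel: build a fresh graph each call


def callback_cache_key(callback: Optional[Callable]) -> Any:
    """Derive a cache key for a bound-method callback without using id()."""
    if callback is None:
        return None
    owner = getattr(callback, "__self__", None)
    func = getattr(callback, "__func__", None)
    if owner is None or func is None:
        # Plain functions are out of scope for this sketch.
        return UNCACHEABLE
    key_fn = getattr(owner, "score_mod_graph_cache_key", None)
    if key_fn is None:
        # Unkeyed bound method: safer to rebuild than to risk a stale graph.
        return UNCACHEABLE
    owner_type = type(owner)
    return ("bound_method", owner_type.__module__, owner_type.__qualname__,
            func.__name__, key_fn())


class LayeredScoreMod:
    """Stateful callback whose graph topology depends on n_layers."""

    def __init__(self, n_layers: int) -> None:
        self.n_layers = n_layers

    def score_mod_graph_cache_key(self):
        return ("n_layers", self.n_layers)

    def forward(self, graph, score, tensors):
        return score  # placeholder body; a real callback would add graph ops


class UnkeyedScoreMod:
    def forward(self, graph, score, tensors):
        return score


key_a = callback_cache_key(LayeredScoreMod(2).forward)
key_b = callback_cache_key(LayeredScoreMod(2).forward)
key_c = callback_cache_key(LayeredScoreMod(3).forward)
```

Two instances with the same configuration share a key, while a different `n_layers` yields a different key regardless of where the objects live in memory.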
| fused_attention_backend = tex.get_fused_attn_backend( | ||
| self.training, | ||
| q_type, | ||
| q_type, | ||
| dpa_utils.QKVLayout["bshd_bshd_bshd"], | ||
| dpa_utils.AttnBiasType["no_bias"], | ||
| dpa_utils.AttnMaskType["no_mask"], | ||
| dpa_utils.SoftmaxType["vanilla"], |
get_fused_attn_backend availability check always uses bshd_bshd_bshd regardless of actual format
The score_mod path hard-codes dpa_utils.QKVLayout["bshd_bshd_bshd"] for the backend probe, even when the user passes qkv_format="sbhd". The result is only used to gate on NVTE_No_Backend, so in practice it likely works today because backend availability for a given dtype is layout-independent. However, if a future cuDNN version makes SBHD/BSHD support diverge, this probe would give a false-positive (accepts sbhd even though no backend supports it) or false-negative (rejects sbhd when it is actually supported). Using the real layout for the probe would make the check self-documenting and future-proof.
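A small sketch of deriving the probe layout from the caller's format rather than hard-coding it. It assumes Q, K, and V are separate tensors that share one format, so the layout name is the format repeated three times (as in the "bshd_bshd_bshd" key above). `probe_layout_key` is a hypothetical helper, and limiting it to "bshd" and "sbhd" reflects only the formats parametrized in this PR's tests.

```python
_SCORE_MOD_FORMATS = ("bshd", "sbhd")  # formats exercised by the score_mod tests


def probe_layout_key(qkv_format: str) -> str:
    """Map a qkv_format string to the matching separate-Q/K/V layout name."""
    if qkv_format not in _SCORE_MOD_FORMATS:
        raise ValueError(f"score_mod does not support qkv_format={qkv_format!r}")
    return "_".join([qkv_format] * 3)


# Intended use (not executed here): dpa_utils.QKVLayout[probe_layout_key(qkv_format)]
sbhd_key = probe_layout_key("sbhd")
```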
| ) | ||
|
|
||
| if context_parallel: | ||
| if score_mod is not None: |
I think this should be in the else branch, because it doesn't support context parallelism. Something like this:
if context_parallel:
    ...
elif score_mod is not None:
    ...
else:
    ...
| raise ValueError( | ||
| "score_mod requires a cuDNN FusedAttention backend, but no fused " | ||
| "attention backend supports the provided inputs." | ||
| ) |
For the score_mod path, I don't think we need to call tex.get_fused_attn_backend() and check whether it's supported. If anything, we should add graph.validate() -> .... graph.build_plans() to dpa_utils.get_attention_backend(attention_params), but if that's too heavy-handed, we can do only the checks you had above (the asserts). Once those checks are added to dpa_utils.get_attention_backend, whether the FusedAttention backend runs will be controlled by the following logic (just like in the non-score_mod cases):
(
    use_flash_attention,
    flash_attention_backend,
    use_fused_attention,
    fused_attention_backend,
    use_unfused_attention,
    _,
) = dpa_utils.get_attention_backend(attention_params)
|
|
||
| def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): | ||
| """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" | ||
| import cudnn # pylint: disable=import-outside-toplevel |
Can you import the cudnn from 3rdparty/cudnn-frontend, instead of from the environment/system-wide installation? We have control over the version in 3rdparty/cudnn-frontend, but not the system one.
| @pytest.mark.parametrize("dtype", param_types) | ||
| @pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) | ||
| @pytest.mark.parametrize("scalar_loss", [False, True]) | ||
| def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): |
Would @pytest.mark.parametrize("score_mod", ["causal", "softcap", "post_scale_bias"]) simplify the tests a bit, so that we don't have 3 separate tests with a lot of repeated code?
| score_mod: Callable, | ||
| score_mod_tensors: Optional[Dict[str, torch.Tensor]], | ||
| output_layer: torch.Tensor, | ||
| stats_bhs1: Optional[torch.Tensor], |
I think we can just call this stats, even though it might only support bhs1 shape right now. On the C++ side, cuDNN does support th1 (for THD format) as well. Could we leave the name generic for now in case we want to add more support to it in the future?
| return output.contiguous() | ||
|
|
||
|
|
||
| def _bhsd_dim_stride( |
We have a lot of small utility functions here - is there a way to pack them up a bit or group them in some way, so the code is easier to read? I know this is Python and we probably do need more than 2 functions (fwd+bwd) but could you please have a look into this? Thanks.
I agree with this, and it was my first thought too.
We should club these functions into a couple of classes that can sit in this file at the very least.
However, I think this approach is still not the right approach. We should have a separate flex_attention.py file similar to context_parallel.py, and backends.py can import it similar to how it imports the CP functions right now.
I strongly recommend this for two reasons:
- When we refactored attention as a whole early last year, the idea was to modularize attention. That was the reason CP was moved out of attention. With Flex attention's functionality and code in here being fairly decoupled from vanilla DPA, it should be easier to move it out. Leaving this code in here would add ~1000 lines of code that is not related to the vanilla DPA and would practically be undoing the refactoring work we did early last year. The same reason for moving CP to its own file should also apply to Flex attention.
- A developer/user of TE PyT DPA should not have to worry about the details of flex attention. Similarly, someone modifying flex should not be bogged down by the details of vanilla fused attn. Hence, decoupling is important to aid with debugging as well.
| ) | ||
|
|
||
|
|
||
| def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors): |
We can just call this "post_scale_bias" to be consistent with our nomenclature elsewhere.
| if ( | ||
| inspect.isfunction(callback) | ||
| and callback.__closure__ is None | ||
| and "<locals>" not in callback.__qualname__ | ||
| ): | ||
| return ("function", callback.__module__, callback.__qualname__) |
Module-level lambdas all share the same __qualname__ = "<lambda>", so two different lambdas defined at module scope in the same file (e.g., sm1 = lambda g, s, t: s and sm2 = lambda g, s, t: g.neg(input=s)) would produce the identical cache key ("function", module, "<lambda>"). The second lambda would silently reuse the compiled graph from the first, computing wrong attention scores with no error. Named module-level functions are safe because their qualnames are unique, but lambdas are not. Excluding <lambda> from the cacheable path makes them _SCORE_MOD_UNCACHEABLE, which builds a fresh graph every call. This is the same safe fallback already used for closures and nested functions.
| if ( | |
| inspect.isfunction(callback) | |
| and callback.__closure__ is None | |
| and "<locals>" not in callback.__qualname__ | |
| ): | |
| return ("function", callback.__module__, callback.__qualname__) | |
| if ( | |
| inspect.isfunction(callback) | |
| and callback.__closure__ is None | |
| and "<locals>" not in callback.__qualname__ | |
| and "<lambda>" not in callback.__qualname__ | |
| ): | |
| return ("function", callback.__module__, callback.__qualname__) |
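The collision described in this comment can be reproduced with plain Python. The sketch below restates the suggested condition as a standalone function; `function_cache_key` is a hypothetical name, and returning `None` for the uncacheable case is a simplification of the `_SCORE_MOD_UNCACHEABLE` fallback mentioned above.

```python
import inspect
from typing import Callable, Optional, Tuple


def function_cache_key(callback: Callable) -> Optional[Tuple[str, str, str]]:
    """Return a name-based key only when the name identifies the function uniquely."""
    if (
        inspect.isfunction(callback)
        and callback.__closure__ is None
        and "<locals>" not in callback.__qualname__
        and "<lambda>" not in callback.__qualname__
    ):
        return ("function", callback.__module__, callback.__qualname__)
    return None  # caller treats this as uncacheable and rebuilds the graph


# Two different lambdas, as in the review comment.
sm1 = lambda g, s, t: s
sm2 = lambda g, s, t: g.neg(input=s)

# Both report the same qualified name, so a name-based key cannot tell them apart.
same_name = sm1.__qualname__ == sm2.__qualname__
```

With the extra `<lambda>` check, both lambdas fall through to the uncacheable path, while a named module-level function still receives a stable key.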
| score_mod_kwargs = { | ||
| "score_mod": _score_mod_causal, | ||
| "score_mod_bprop": _score_mod_causal_bprop, | ||
| "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, | ||
| "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)}, | ||
| } |
The neg_inf and zero tensors are created on CPU (torch.full defaults to CPU), but the attention computation runs on CUDA. When cuDNN executes the graph it calls into CUDA kernels and expects all variant-pack tensors to reside on the compute device. Passing CPU tensors here will produce a device-mismatch error at graph execution time, causing both the "causal" test cases to fail.
| score_mod_kwargs = { | |
| "score_mod": _score_mod_causal, | |
| "score_mod_bprop": _score_mod_causal_bprop, | |
| "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, | |
| "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)}, | |
| } | |
| score_mod_kwargs = { | |
| "score_mod": _score_mod_causal, | |
| "score_mod_bprop": _score_mod_causal_bprop, | |
| "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9, device="cuda")}, | |
| "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0, device="cuda")}, | |
| } |
| score_mod : bool, default = False | ||
| Whether a score_mod callback was provided. | ||
| score_mod_bprop : bool, default = False | ||
| Whether a score_mod bprop callback was provided. |
nit: If this is a bool, to match has_attention_mask, consider has_score_mod and has_score_mod_bprop instead ?
| logger.debug("Disabling all backends for max_logit with FP8 attention") | ||
|
|
||
| # Filter: score_mod | ||
| if score_mod_bprop and not score_mod: |
What happens (is expected to happen) if score_mod_bprop=False and score_mod=True ?
It's a perfectly legal case, if, for instance, score_mod is used only for masking.
| use_flash_attention = False | ||
| use_flash_attention_2 = False | ||
| use_flash_attention_3 = False | ||
| use_flash_attention_4 = False | ||
| use_fused_attention = False | ||
| use_unfused_attention = False |
nit: Outside the scope of this PR, but it would be good to do in this or a subsequent PR: add a helper function (or something similar) for performing an action or query on all of the flash_attention variables at once.
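A rough sketch of what such a helper could look like, assuming the flags were grouped in one object. `BackendFlags`, `any_flash`, and `disable_flash` are hypothetical names; the current code keeps these flags as separate local variables, so this is only an illustration of the idea.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class BackendFlags:
    """Hypothetical container for the per-backend enable flags."""

    use_flash_attention: bool = True
    use_flash_attention_2: bool = True
    use_flash_attention_3: bool = True
    use_flash_attention_4: bool = True
    use_fused_attention: bool = True
    use_unfused_attention: bool = True

    def _flash_names(self) -> List[str]:
        # Every field whose name starts with the flash prefix is a flash flag.
        return [f.name for f in fields(self) if f.name.startswith("use_flash_attention")]

    def any_flash(self) -> bool:
        """Query: is any FlashAttention variant still enabled?"""
        return any(getattr(self, name) for name in self._flash_names())

    def disable_flash(self) -> None:
        """Action: turn off every FlashAttention variant in one call."""
        for name in self._flash_names():
            setattr(self, name, False)


flags = BackendFlags()
flags.disable_flash()
```

A new `use_flash_attention_N` field would be picked up automatically by the prefix match, which is the main benefit over listing each variable by hand.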
| if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4: | ||
| logger.debug("Disabling FlashAttention for score_mod") | ||
| use_flash_attention = False | ||
| use_flash_attention_2 = False | ||
| use_flash_attention_3 = False | ||
| use_flash_attention_4 = False |
Consider this maybe ?:
| if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4: | |
| logger.debug("Disabling FlashAttention for score_mod") | |
| use_flash_attention = False | |
| use_flash_attention_2 = False | |
| use_flash_attention_3 = False | |
| use_flash_attention_4 = False | |
| if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4: | |
| logger.debug("Disabling FlashAttention for score_mod") | |
| use_flash_attention = False | |
| use_flash_attention_2 = False | |
| use_flash_attention_3 = False | |
| use_flash_attention_4 = False |
unless there's a good reason to do otherwise ?
| if use_unfused_attention: | ||
| logger.debug("Disabling UnfusedDotProductAttention for score_mod") | ||
| use_unfused_attention = False |
Consider this maybe ?
| if use_unfused_attention: | |
| logger.debug("Disabling UnfusedDotProductAttention for score_mod") | |
| use_unfused_attention = False | |
| if use_unfused_attention: | |
| logger.debug("Disabling UnfusedDotProductAttention for score_mod") | |
| use_unfused_attention = False |
unless there's a good reason to do otherwise ?
| ) | ||
| global _attention_backends | ||
| if is_in_onnx_export_mode(): | ||
| if is_in_onnx_export_mode() and score_mod is None: |
Is this necessary here if dpa_utils.get_attention_backend(attention_params) does get called in the else block below ?
The flash, fused, and unfused flags would be set in there anyway, right?
Or am I missing something ?
cc: @cyanguwa
| def _import_cudnn_frontend(): | ||
| """Import the vendored cuDNN frontend if built, otherwise use the installed package.""" | ||
| cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) | ||
| cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" | ||
| if ( | ||
| any(cudnn_frontend_package.glob("_compiled_module*")) | ||
| and cudnn_frontend_path not in sys.path | ||
| ): | ||
| sys.path.insert(0, cudnn_frontend_path) | ||
| return importlib.import_module("cudnn") | ||
|
|
How about this?:
def _import_cudnn_frontend():
    cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH)
    cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn"
    if (
        any(cudnn_frontend_package.glob("_compiled_module*"))
        and cudnn_frontend_path not in sys.path
    ):
        sys.path.insert(0, cudnn_frontend_path)
        return importlib.import_module("cudnn")
    # Fall back
    if importlib.util.find_spec("cudnn") is not None:
        return importlib.import_module("cudnn")
    # Fail with a message
    raise ImportError(
        "cuDNN Frontend Python package not found. "
        "Install it with: pip install nvidia-cudnn-frontend"
    )
| return out, max_logit, (None, None, None, d_softmax_offset) | ||
|
|
||
|
|
||
| def _score_mod_causal(score_mod_graph, score_tensor, tensors): |
I would strongly recommend that, similar to the CP tests, we have a separate Flex attention test file: firstly for modularization, and secondly because the Flex attention tests do not really end up using the test_dot_product_attention() base test like other DPA tests in the file do, so there's no code-reuse reason for keeping them here either.
These isolated ~800 lines of code can sit in their own file if they aren't really using any of the functions in here directly and the flex tests are written as "new" tests; otherwise, the flex tests must reuse the DPA setup in here and integrate into it.
I've also shared more details on this in my comment in the backends.py file.
cc: @cyanguwa
Thanks for creating this PR @vcherepanov-nv I was curious about:
Thanks for the thorough review!
I haven't done any benchmarking. Reportedly (from a Slack thread), score_mod can lead to significant perf gains if it allows avoiding mask materialization. For causal, I think I observed cuDNN choosing exactly the same kernel with score_mod and the explicit causal flag.
Sure, thanks for linking!
Thanks for the PR! A few comments:
0. Agree with all the comments from @KshitijLakhani and @cyanguwa, so just +1ed them
- A user doc specifying the design choices and the building blocks of graph caching would be valuable.
- score_mod seems like an argument more than a feature, so the error messaging could use something more substantial like "(TE/cuDNN) Flex Attention"
- New arguments of the form has_* in AttentionParams could be avoided. If passing score_mod, score_mod_tensors (which are hefty) is the blocker, could we create an encapsulating dataclass and pass that instead?
- user_supplied_seqlens is a bit vague; it seems like just a derived variable. Does it degenerate to mean pad_between_seqs=True?
Among other nits
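To make the "encapsulating dataclass" suggestion concrete, here is a minimal sketch under stated assumptions. `FlexAttentionConfig` and its properties are hypothetical names, and tensors are typed as `Any` to keep the example free of a torch dependency. The validation mirrors the rule discussed earlier in this thread: a bprop callback without a forward callback is rejected, while a forward callback without bprop is legal.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class FlexAttentionConfig:
    """Hypothetical bundle of the score_mod arguments."""

    score_mod: Optional[Callable] = None
    score_mod_bprop: Optional[Callable] = None
    score_mod_tensors: Dict[str, Any] = field(default_factory=dict)
    score_mod_bprop_tensors: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if self.score_mod is None and self.score_mod_bprop is not None:
            raise ValueError("Flex Attention: score_mod_bprop requires score_mod")

    @property
    def enabled(self) -> bool:
        """Derived flag replacing a separate has_score_mod argument."""
        return self.score_mod is not None

    @property
    def has_bprop(self) -> bool:
        """Derived flag replacing a separate has_score_mod_bprop argument."""
        return self.score_mod_bprop is not None


masking_only = FlexAttentionConfig(score_mod=lambda graph, score, tensors: score)
```

Backend selection could then read `enabled` and `has_bprop` from a single object instead of receiving separate boolean arguments.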
Description
Adds experimental PyTorch support for cuDNN-backed flex attention in DotProductAttention via a new score_mod callback path.

Users can pass:
- score_mod(graph, score, tensors) -> score for forward score modification
- score_mod_bprop(graph, dP, tensors) -> dP for backward

When score_mod_bprop is supplied, it is the user's responsibility to make it mathematically consistent with score_mod. TE forwards this callback to cuDNN as provided and does not derive or validate the backward score transformation automatically.

Supported score_mod configuration
The current cuDNN-backed Flex Attention path supports:
- DotProductAttention/FusedAttention
- torch.Tensor Q/K/V inputs
- attn_mask_type="no_mask"
- core_attention_bias_type="no_bias" with no explicit bias tensor
- attention_dropout=0.0
- num_splits=1

The path is currently not supported with FP8, fp8_output, THD format, explicit cu_seqlens inputs, pad_between_seqs, attention masks, attention bias, ALiBi, sliding-window attention, sink attention, dropout, KV cache, context parallelism, CUDA graph capture, checkpointed core attention, or return_max_logit.

For deterministic execution, TE passes the deterministic setting through backend selection and forwards it to cuDNN Frontend sdpa_backward as use_deterministic_algorithm. The score_mod forward sdpa call does not take a separate deterministic flag.

Fixes #2492.
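The consistency requirement between score_mod and score_mod_bprop described above can be illustrated with a toy example. This is a sketch under assumptions, not cuDNN code: `ToyGraph` is a stand-in that works on Python floats, the callback names are hypothetical, and it assumes score_mod_bprop maps the gradient with respect to the modified score to the gradient with respect to the original score. The `neg(input=...)` call shape mirrors the `g.neg(input=s)` lambda quoted earlier in this thread.

```python
class ToyGraph:
    """Stand-in for a cuDNN frontend graph; operates on plain Python floats."""

    def neg(self, input):  # keyword name follows g.neg(input=s) from the thread
        return -input


def negate_score_mod(graph, score, tensors):
    """Forward: modified_score = -score."""
    return graph.neg(input=score)


def negate_score_mod_bprop(graph, d_score, tensors):
    """Backward: d(-score)/d(score) = -1, so the incoming gradient is negated."""
    return graph.neg(input=d_score)


graph = ToyGraph()
eps = 1e-3
s = 2.0
# Finite-difference slope of the forward callback at s.
numeric = (negate_score_mod(graph, s + eps, {}) - negate_score_mod(graph, s, {})) / eps
# Gradient reported by the bprop callback for a unit upstream gradient.
analytic = negate_score_mod_bprop(graph, 1.0, {})
```

When the two callbacks are consistent, the numeric slope and the bprop result agree up to floating-point error; a check of this kind is something the user has to supply, since TE does not validate the pair.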
Type of change
Changes
- FusedAttentionWithScoreModFunc, a cuDNN frontend Python graph path for SDPA forward/backward with score_mod and score_mod_bprop.
- DotProductAttention/FusedAttention APIs with score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors.
- score_mod only selects supported cuDNN fused attention configurations.
- score_mod_graph_cache_key() for stateful callbacks, while leaving unsafe unkeyed bound methods uncached.

Checklist: