Add MLX_SDPA_BLOCKS env var for 2-pass vector kernel block-count override #3455
Open
adurham wants to merge 3 commits into ml-explore:main
…count
The 2-pass SDPA vector kernel picks `blocks` (the partial-tile count)
from a heuristic over device class + sequence length. The defaults are
sensible across the upstream-tested matrix but leave money on the table
on combinations the heuristic doesn't anticipate.
This adds an `MLX_SDPA_BLOCKS` env var that overrides the heuristic to
a positive integer. Unset / non-positive: heuristic unchanged.
Empirical example: on a 2-rank M4-Ultra cluster running long-context
MoE inference at K~50k, the heuristic picks 1024 but `blocks=88` is
+6.5% decode tps with a sharp cliff at 92+, matching the
~352-concurrent-simdgroup capacity (4 kv_heads × 88 ≈ 1.1 dispatch
rounds). Different workloads will sit at different optima — letting
operators sweep without recompiling MLX is the value.
Also adds a regression test (`test_sdpa_blocks_env_override`) that
sweeps {16, 64, 256} and asserts numerics match the heuristic-default
output, so future changes to the 2-pass dispatch path don't silently
break the override.
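The claim that output numerics should not depend on `blocks` can be illustrated with a small reference model. The sketch below is plain Python, not the Metal kernel: the function names are hypothetical, and the strided assignment of keys to blocks is an assumption made for illustration. It computes softmax attention for one query vector in a single pass, and again as per-block partials merged in a second pass.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_reference(q, keys, values):
    """Single-pass softmax attention for one query vector."""
    scores = [_dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * v[d] for w, v in zip(weights, values)) / total
        for d in range(dim)
    ]


def attention_two_pass(q, keys, values, blocks):
    """Two-pass variant: per-block partials, then a merge over blocks."""
    dim = len(values[0])
    partials = []
    # Pass 1: each block reduces its own subset of keys to
    # (local max, local sum of exponentials, unnormalized output).
    # Assumption: keys are assigned to blocks with a stride of `blocks`.
    for b in range(blocks):
        idx = range(b, len(keys), blocks)
        if len(idx) == 0:
            continue  # more blocks than keys: this block has no work
        scores = [_dot(q, keys[i]) for i in idx]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        out = [sum(wi * values[i][d] for wi, i in zip(w, idx)) for d in range(dim)]
        partials.append((m, sum(w), out))
    # Pass 2: rescale every partial to the global max and combine.
    gmax = max(m for m, _, _ in partials)
    denom = sum(s * math.exp(m - gmax) for m, s, _ in partials)
    return [
        sum(o[d] * math.exp(m - gmax) for m, _, o in partials) / denom
        for d in range(dim)
    ]
```

In this model the block count only changes how the reduction is grouped, so results agree up to floating-point rounding for any positive `blocks`, including values larger than the number of keys.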
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zcbenz
reviewed
Apr 26, 2026
Collaborator
zcbenz
left a comment
The test is not really testing anything, I'm OK with no test for the env.
zcbenz
reviewed
Apr 26, 2026
- Rename helper away from C++ keyword `override` (`sdpa_2pass_blocks_override` -> `sdpa_2pass_blocks_from_env`, and the local `int override` at the call site -> `int blocks_env`).
- Use existing `env::get_var` helper instead of manual `std::getenv` + `std::atoi`; drop the now-unused `<cstdlib>` include.
- Drop `test_sdpa_blocks_env_override` per reviewer (it doesn't exercise behavior the existing 2-pass tests don't already cover).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author
Thanks for the review! Addressed in 4b87e47:
Left as a separate commit for review; happy to squash before merge.
zcbenz
reviewed
Apr 27, 2026
zcbenz noted the helper function was unnecessary. Inline
env::get_var("MLX_SDPA_BLOCKS", 0) directly at the call site.
Author
Done in a7a77ab: helper removed, `env::get_var` inlined at the call site as suggested.
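For readers who want the override semantics spelled out, here is a minimal Python model of the behavior described in this PR: a positive `MLX_SDPA_BLOCKS` replaces the heuristic value, while unset or non-positive leaves it unchanged. `resolve_blocks` is a hypothetical name, and falling back to the heuristic on a non-integer value is an assumption of this sketch, not something the PR states.

```python
import os


def resolve_blocks(heuristic_blocks, environ=None):
    """Return the block count after applying the env override, if any."""
    environ = os.environ if environ is None else environ
    raw = environ.get("MLX_SDPA_BLOCKS", "0")
    try:
        blocks_env = int(raw)
    except ValueError:
        # Assumption: an unparsable value is treated like "unset".
        blocks_env = 0
    # Only a positive integer overrides the heuristic.
    return blocks_env if blocks_env > 0 else heuristic_blocks
```

Passing `environ` explicitly keeps the sketch easy to exercise without mutating the real process environment.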
Summary
Adds an `MLX_SDPA_BLOCKS` env var that overrides the heuristic-chosen `blocks` (partial-tile count) for the 2-pass SDPA vector kernel. Unset or non-positive: heuristic unchanged.

The 2-pass kernel currently picks `blocks` from a device-class + sequence-length matrix in `scaled_dot_product_attention.cpp`. The defaults work well for most upstream-tested combinations but leave headroom on (device, head-count, sequence-length) tuples the heuristic doesn't anticipate. Letting operators sweep at runtime without rebuilding MLX is the value.

Empirical example
On a 2-rank M4-Ultra cluster running a 256-expert long-context MoE at K≈50k:
- The heuristic picks `blocks=1024`.
- `MLX_SDPA_BLOCKS=88` is +6.5% decode tps with a sharp cliff at 92+.

Different workloads will sit at different optima. We've been carrying this in a fork for a few weeks specifically to support that kind of tuning sweep.
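A tuning sweep of this kind could be driven by a short script along the following lines. This is a sketch under stated assumptions: `sweep_blocks` and `best_blocks` are hypothetical helpers, the benchmark is assumed to be a Python script that prints its decode tps as the last line of stdout, and each candidate runs in a fresh interpreter so the variable is set before MLX is loaded.

```python
import os
import subprocess
import sys


def sweep_blocks(script_args, candidates):
    """Run a benchmark once per candidate block count.

    `script_args` are the arguments passed to a fresh Python interpreter,
    e.g. ["bench.py", "--ctx", "50000"] (hypothetical script and flags).
    Returns {blocks: throughput}.
    """
    results = {}
    for blocks in candidates:
        # Copy the current environment and set the override for this run.
        env = dict(os.environ, MLX_SDPA_BLOCKS=str(blocks))
        proc = subprocess.run(
            [sys.executable, *script_args],
            env=env,
            capture_output=True,
            text=True,
            check=True,
        )
        # Assumption: the last stdout line is the throughput number.
        results[blocks] = float(proc.stdout.strip().splitlines()[-1])
    return results


def best_blocks(results):
    """Return the candidate with the highest measured throughput."""
    return max(results, key=results.get)
```

Because single runs can be noisy, a real sweep would likely repeat each candidate a few times and compare medians before settling on a value.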
Files
- `mlx/backend/metal/scaled_dot_product_attention.cpp`: `sdpa_2pass_blocks_override()` helper + one call site after the existing heuristic block
- `python/tests/test_fast_sdpa.py`: `test_sdpa_blocks_env_override` sweeps `{16, 64, 256}` and asserts numerics match heuristic-default output

Notes for reviewers
`#define` but a runtime env var is the only thing that supports sweeping without rebuilding. Happy to switch to an `mx.metal.set_sdpa_blocks(int)` runtime API if you'd rather not key off env vars.