Add MLX_SDPA_BLOCKS env var for 2-pass vector kernel block-count override #3455

Open
adurham wants to merge 3 commits into ml-explore:main from adurham:sdpa-blocks-override
adurham commented Apr 26, 2026

Summary

Adds an MLX_SDPA_BLOCKS env var that overrides the heuristic-chosen blocks (partial-tile count) for the 2-pass SDPA vector kernel. Unset or non-positive: heuristic unchanged.

The 2-pass kernel currently picks blocks from a device-class + sequence-length matrix in scaled_dot_product_attention.cpp. The defaults work well for most upstream-tested combinations but leave headroom on (device, head-count, sequence-length) tuples the heuristic doesn't anticipate. The value of the override is that operators can sweep block counts at runtime without rebuilding MLX.
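The override rule described above (a positive integer wins, anything else leaves the heuristic alone) can be modeled in a few lines. This is a Python sketch of the semantics only, not the C++ code in the PR. The function name resolve_blocks is hypothetical, and treating a malformed value as "unset" is an assumption about how the C++ side parses it.

```python
import os


def resolve_blocks(heuristic_blocks, env=None):
    """Return the block count the 2-pass kernel would use (model only).

    heuristic_blocks: the value the existing heuristic picked.
    env: a mapping to read MLX_SDPA_BLOCKS from; defaults to os.environ.
    """
    env = os.environ if env is None else env
    raw = env.get("MLX_SDPA_BLOCKS", "")
    try:
        value = int(raw)
    except ValueError:
        # Assumption: an unset or unparsable value behaves like 0.
        value = 0
    # Only a positive integer overrides the heuristic.
    return value if value > 0 else heuristic_blocks
```

With this model, resolve_blocks(1024, {"MLX_SDPA_BLOCKS": "88"}) yields 88, while an empty mapping, "0", or "-4" all fall back to 1024.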

Empirical example

On a 2-rank M4-Ultra cluster running a 256-expert long-context MoE at K≈50k:

  • Heuristic picks blocks=1024
  • MLX_SDPA_BLOCKS=88 gives +6.5% decode tps, with a sharp cliff at 92 and above
  • Matches the ~352-concurrent-simdgroup capacity (4 kv_heads × 88 ≈ 1.1 dispatch rounds)

Different workloads will sit at different optima. We've been carrying this in a fork for a few weeks specifically to support that kind of tuning sweep.
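A tuning sweep of this kind could be driven by a small harness like the one below. It is a sketch under stated assumptions: the benchmark command is hypothetical, each candidate runs in a fresh process so the variable is guaranteed to be read at startup, and wall-clock time is only a crude stand-in for decode tps.

```python
import os
import subprocess
import time

ENV_VAR = "MLX_SDPA_BLOCKS"


def sweep_blocks(command, candidates, runner=subprocess.run):
    """Run `command` once per candidate and return {candidate: seconds}.

    A candidate of None means "leave the variable unset", which is the
    heuristic baseline. `runner` is injectable so the loop can be tested
    without launching real processes.
    """
    timings = {}
    for blocks in candidates:
        env = dict(os.environ)  # copy so the parent environment is untouched
        if blocks is None:
            env.pop(ENV_VAR, None)
        else:
            env[ENV_VAR] = str(blocks)
        start = time.perf_counter()
        runner(command, env=env, check=True)
        timings[blocks] = time.perf_counter() - start
    return timings


# Example call (bench_decode.py is a hypothetical benchmark script):
# sweep_blocks(["python", "bench_decode.py"], [None, 64, 80, 88, 92, 96])
```

In practice the benchmark would report its own tokens-per-second figure, and the harness would parse that rather than timing the whole process.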

Files

  • mlx/backend/metal/scaled_dot_product_attention.cpp: sdpa_2pass_blocks_override() helper + one call site after the existing heuristic block
  • python/tests/test_fast_sdpa.py: test_sdpa_blocks_env_override, which sweeps {16, 64, 256} and asserts numerics match heuristic-default output
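The test's expectation that numerics do not depend on the block count follows from how a two-pass softmax reduction works in general. The sketch below is a generic pure-Python illustration of that technique for a single query, not MLX's Metal kernel. The strided assignment of keys to blocks and both function names are assumptions made for the example.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_single_pass(q, keys, values):
    """Reference: softmax(q . K) @ V computed over all keys at once."""
    scores = [_dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def attention_two_pass(q, keys, values, blocks):
    """Same result, computed as per-block partials merged in a second pass."""
    dim = len(values[0])
    partials = []
    # Pass 1: each block reduces its own strided slice of the keys and
    # records (local max, local sum of exp, unnormalized output).
    for b in range(blocks):
        idx = range(b, len(keys), blocks)
        if len(idx) == 0:
            continue  # more blocks than keys: this block has no work
        scores = [_dot(q, keys[i]) for i in idx]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        acc = [sum(wi * values[i][d] for wi, i in zip(w, idx))
               for d in range(dim)]
        partials.append((m, sum(w), acc))
    # Pass 2: rescale every partial to the global max, then normalize.
    gmax = max(m for m, _, _ in partials)
    denom = sum(s * math.exp(m - gmax) for m, s, _ in partials)
    return [sum(acc[d] * math.exp(m - gmax) for m, _, acc in partials) / denom
            for d in range(dim)]
```

Because the second pass rescales each partial to a common maximum, the block count changes only the order of floating-point accumulation, so results agree up to rounding for any positive block count.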

Notes for reviewers

  • Currently only wired into the bf16 2-pass path (the only one in upstream). The fork additionally applies the same override to a quantized 2-pass kernel that doesn't exist upstream — easy to extend later if you ever land that.
  • I considered a build-time #define, but a runtime env var is the only option that supports sweeping without rebuilding. Happy to switch to a runtime API such as mx.metal.set_sdpa_blocks(int) if you'd rather not key off env vars.
  • Default (unset) preserves the existing heuristic exactly.

…count

The 2-pass SDPA vector kernel picks `blocks` (the partial-tile count)
from a heuristic over device class + sequence length. The defaults are
sensible across the upstream-tested matrix but leave money on the table
on combinations the heuristic doesn't anticipate.

This adds an `MLX_SDPA_BLOCKS` env var that overrides the heuristic to
a positive integer. Unset / non-positive: heuristic unchanged.

Empirical example: on a 2-rank M4-Ultra cluster running long-context
MoE inference at K~50k, the heuristic picks 1024 but `blocks=88` is
+6.5% decode tps with a sharp cliff at 92+, matching the
~352-concurrent-simdgroup capacity (4 kv_heads × 88 ≈ 1.1 dispatch
rounds). Different workloads will sit at different optima — letting
operators sweep without recompiling MLX is the value.

Also adds a regression test (`test_sdpa_blocks_env_override`) that
sweeps {16, 64, 256} and asserts numerics match the heuristic-default
output, so future changes to the 2-pass dispatch path don't silently
break the override.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collaborator

zcbenz left a comment

The test is not really testing anything; I'm OK with no test for the env var.

Comment thread mlx/backend/metal/scaled_dot_product_attention.cpp Outdated
Comment thread mlx/backend/metal/scaled_dot_product_attention.cpp Outdated
- Rename helper away from the C++ contextual keyword `override`
  (sdpa_2pass_blocks_override -> sdpa_2pass_blocks_from_env, and the
  local `int override` at the call site -> `int blocks_env`).
- Use existing env::get_var helper instead of manual std::getenv +
  std::atoi; drop the now-unused <cstdlib> include.
- Drop test_sdpa_blocks_env_override per reviewer (it doesn't exercise
  behavior the existing 2-pass tests don't already cover).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author

adurham commented Apr 27, 2026

Thanks for the review! Addressed in 4b87e47:

  • Renamed helper away from override (sdpa_2pass_blocks_override → sdpa_2pass_blocks_from_env, and the local int override at the call site → int blocks_env).
  • Switched to env::get_var("MLX_SDPA_BLOCKS", 0) instead of the manual std::getenv + std::atoi; dropped the now-unused <cstdlib> include.
  • Dropped test_sdpa_blocks_env_override.

Left as a separate commit for review; happy to squash before merge.

Comment thread mlx/backend/metal/scaled_dot_product_attention.cpp Outdated
zcbenz noted the helper function was unnecessary. Inline
env::get_var("MLX_SDPA_BLOCKS", 0) directly at the call site.
Author

adurham commented Apr 27, 2026

Done in a7a77ab — helper removed, env::get_var inlined at the call site as suggested.

Collaborator

zcbenz left a comment

Looks good to me!

2 participants