
Handle non-multiple-of-8 spatial dims in depthwise conv2d Metal path#3446

Open
qflen wants to merge 1 commit into ml-explore:main from qflen:fix/depthwise-conv2d-nonmod8

Conversation


@qflen qflen commented Apr 24, 2026

Summary

Fixes #3324. Relaxes the `oS % 8 == 0` dispatch gate on the depthwise Metal kernel and makes the kernel tail-aware, so non-mod-8 output spatial dims stay on the fast path instead of falling back to the slower gemm path.

Root cause

depthwise_conv_2D_gpu dispatches (tc=8, tw=8, th=4) threadgroups across (C/tc, oW/tw, (oH/th)*N). In dispatch_conv_2D_gpu the selection guard is

conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0

so any non-mod-8 output spatial shape routes to implicit_gemm_conv_2D_gpu / explicit_gemm_conv_group_ND_gpu, which is 2-5x slower for depthwise on M-series. The issue title says "mod 16" but the real boundary is mod 8: 240 (not mod 16, but mod 8) stays on the fast path, while 60, 30, 15 do not.

Change

  • mlx/backend/metal/conv.cpp
    • depthwise_conv_2D_gpu: ceil-div grid_dims for the two spatial dims so tail tiles are dispatched.
    • dispatch_conv_2D_gpu: drop the oS % 8 == 0 guard.
  • mlx/backend/metal/kernels/conv.metal
    • depthwise_conv_2d: ceil-div n_tgblocks_h to match the dispatch, and early-return from threads whose oh / ow are past oS before the output store. The shared-memory input load is already bounds-checked (ih/iw guards), and out-of-range threads still cooperate in that load, so only the final store needs guarding.

The C, kernel-size, stride, and wt_strides constraints on the depthwise path are unchanged.

Perf (M5 32GB, issue reproducer, 3x3 / pad=1 / stride=1 / depthwise, median of 3 runs)

| channels | res   | baseline ms | patched ms | speedup |
|---------:|:------|------------:|-----------:|--------:|
|      384 | 60x60 |        1.64 |       0.34 |    4.8x |
|      768 | 30x30 |        0.94 |       0.28 |    3.4x |
|     1536 | 15x15 |        0.61 |       0.23 |    2.7x |

Mod-8 shapes are unchanged (within +/-2% across runs). The previously-slow shapes now match their mod-8 siblings: 60x60 @ 384ch is 0.34 ms vs 64x64 @ 384ch at 0.35 ms.

Tests

  • Added test `gpu depthwise conv2d non-mod-8 spatial` in tests/gpu_tests.cpp. It computes depthwise conv2d on GPU and CPU for six shapes (issue reproducer sizes, non-square, asymmetric spatial alignment, stride 2) and asserts allclose at 1e-4.
  • Full ./build/tests/tests suite: 245 passed, 0 failed.

Checklist

  • Ran clang-format on all changed files (no changes).
  • Added a test (tests/gpu_tests.cpp).
  • Benchmarked before / after, numbers above.
  • No API changes; no doc changes needed.

The depthwise Metal kernel dispatches (tc=8, tw=8, th=4)-sized
threadgroups across (C/tc, oW/tw, (oH/th)*N) blocks and gates itself
in dispatch_conv_2D_gpu with `oS[0] % 8 == 0 && oS[1] % 8 == 0`. Any
other output spatial shape falls back to the generic implicit/explicit
gemm path, which is 2-5x slower on M-series. Multi-stage vision
encoders hit this at every non-mod-8 intermediate feature map.

Relax the gate and make the kernel tail-aware:
- dispatch_conv_2D_gpu: ceil-div `grid_dims` for oS so tail tiles are
  launched, and drop the `oS % 8 == 0` gate.
- depthwise_conv_2d: ceil-div `n_tgblocks_h` to match the dispatch,
  and skip the output write when `oh` or `ow` is past `oS`. The load
  path is already bounds-checked, and out-of-range threads still
  cooperate in shared-memory loads — only the store needs guarding.

Measured on M5 32GB with the issue reproducer (3x3, pad=1, stride=1,
depthwise, median of 3 runs):

| channels | res    | baseline ms | patched ms | speedup |
|---------:|:-------|------------:|-----------:|--------:|
|      384 | 60x60  |        1.64 |       0.34 |   4.8x  |
|      768 | 30x30  |        0.94 |       0.28 |   3.4x  |
|     1536 | 15x15  |        0.61 |       0.23 |   2.7x  |

Mod-8 shapes are unaffected (within ±2% across runs). Adds a GPU-vs-CPU
allclose test in tests/gpu_tests.cpp covering non-mod-8 3x3/5x5 shapes,
asymmetric spatial alignment, and stride-2 downsampling.

Closes ml-explore#3324.
