Handle non-multiple-of-8 spatial dims in depthwise conv2d Metal path #3446
Open
qflen wants to merge 1 commit into ml-explore:main from
The depthwise Metal kernel dispatches (tc=8, tw=8, th=4)-sized threadgroups across (C/tc, oW/tw, (oH/th)*N) blocks and gates itself in dispatch_conv_2D_gpu with `oS[0] % 8 == 0 && oS[1] % 8 == 0`. Any other output spatial shape falls back to the generic implicit/explicit gemm path, which is 2-5x slower on M-series. Multi-stage vision encoders hit this at every non-mod-8 intermediate feature map.

Relax the gate and make the kernel tail-aware:

- dispatch_conv_2D_gpu: ceil-div `grid_dims` for oS so tail tiles are launched, and drop the `oS % 8 == 0` gate.
- depthwise_conv_2d: ceil-div `n_tgblocks_h` to match the dispatch, and skip the output write when `oh` or `ow` is past `oS`. The load path is already bounds-checked, and out-of-range threads still cooperate in shared-memory loads, so only the store needs guarding.

Measured on M5 32GB with the issue reproducer (3x3, pad=1, stride=1, depthwise, median of 3 runs):

| channels | res   | baseline ms | patched ms | speedup |
|---------:|:------|------------:|-----------:|--------:|
|      384 | 60x60 |        1.64 |       0.34 |    4.8x |
|      768 | 30x30 |        0.94 |       0.28 |    3.4x |
|     1536 | 15x15 |        0.61 |       0.23 |    2.7x |

Mod-8 shapes are unaffected (within ±2% across runs).

Adds a GPU-vs-CPU allclose test in tests/gpu_tests.cpp covering non-mod-8 3x3/5x5 shapes, asymmetric spatial alignment, and stride-2 downsampling.

Closes ml-explore#3324.
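The tail handling described above can be modeled on the CPU to see why a ceil-div grid plus a store guard is enough. The following is a minimal sketch, not the Metal kernel: `ceil_div`, `uncovered_by_floor_div`, `simulate_tiled_store`, and `covered_exactly_once` are hypothetical names, the channel and batch dimensions are ignored, and only the (tw=8, th=4) tile shape is taken from the description.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Spatial tile shape taken from the description (tw=8, th=4).
// The constant names are illustrative, not the identifiers used in MLX.
constexpr int kTileW = 8;
constexpr int kTileH = 4;

// Number of tiles needed to cover n elements, rounding up.
constexpr int ceil_div(int n, int tile) { return (n + tile - 1) / tile; }

// Elements left without a tile if the tile count were floor(n / tile).
constexpr int uncovered_by_floor_div(int n, int tile) {
  return n - (n / tile) * tile;
}

// CPU model of the dispatch: launch ceil-div tiles over an oH x oW output
// and count how many times each output pixel is written.
std::vector<int> simulate_tiled_store(int oH, int oW) {
  assert(oH > 0 && oW > 0);
  std::vector<int> writes(static_cast<std::size_t>(oH) * oW, 0);
  const int tiles_h = ceil_div(oH, kTileH);
  const int tiles_w = ceil_div(oW, kTileW);
  for (int by = 0; by < tiles_h; ++by) {
    for (int bx = 0; bx < tiles_w; ++bx) {
      for (int ty = 0; ty < kTileH; ++ty) {
        for (int tx = 0; tx < kTileW; ++tx) {
          const int oh = by * kTileH + ty;
          const int ow = bx * kTileW + tx;
          // In the real kernel every thread, including out-of-range ones,
          // would take part in the cooperative shared-memory load here.
          if (oh >= oH || ow >= oW) {
            continue; // store guard: tail threads skip the write
          }
          ++writes[static_cast<std::size_t>(oh) * oW + ow];
        }
      }
    }
  }
  return writes;
}

// True when every output pixel is written exactly once.
bool covered_exactly_once(int oH, int oW) {
  for (int count : simulate_tiled_store(oH, oW)) {
    if (count != 1) {
      return false;
    }
  }
  return true;
}
```

For a 60-wide output, `ceil_div(60, 8)` gives 8 tiles where floor division would give 7 and leave `uncovered_by_floor_div(60, 8)` = 4 columns unwritten; in this model the eighth tile has 4 active columns and 4 guarded ones.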
Summary

Fixes #3324. Relaxes the `oS % 8 == 0` dispatch gate on the depthwise Metal kernel and makes it tail-aware, so non-mod-8 output spatial dims stay on the fast path instead of falling back to the gemm path.

Root cause
`depthwise_conv_2D_gpu` dispatches `(tc=8, tw=8, th=4)` threadgroups across `(C/tc, oW/tw, (oH/th)*N)`. In `dispatch_conv_2D_gpu` the selection guard requires `oS[0] % 8 == 0 && oS[1] % 8 == 0`, so any non-mod-8 output spatial shape routes to `implicit_gemm_conv_2D_gpu` / `explicit_gemm_conv_group_ND_gpu`, which is 2-5x slower for depthwise on M-series. The issue title says "mod 16" but the real boundary is mod 8: `240` (not mod 16, but mod 8) stays on the fast path, while `60`, `30`, `15` do not.

Change
`mlx/backend/metal/conv.cpp`

- `depthwise_conv_2D_gpu`: ceil-div `grid_dims` for the two spatial dims so tail tiles are dispatched.
- `dispatch_conv_2D_gpu`: drop the `oS % 8 == 0` guard.

`mlx/backend/metal/kernels/conv.metal`

- `depthwise_conv_2d`: ceil-div `n_tgblocks_h` to match the dispatch, and early-return from threads whose `oh`/`ow` are past `oS` before the output store. The shared-memory input load is already bounds-checked (`ih`/`iw` guards), and out-of-range threads still cooperate in that load, so only the final store needs guarding.

The C, kernel-size, stride, and `wt_strides` constraints on the depthwise path are unchanged.

Perf (M5 32GB, issue reproducer, 3x3 / pad=1 / stride=1 / depthwise, median of 3 runs)
Mod-8 shapes are unchanged (within +/-2% across runs). The previously-slow shapes now match their mod-8 siblings: `60x60 @ 384ch` is `0.34 ms` vs `64x64 @ 384ch` at `0.35 ms`.

Tests
`test gpu depthwise conv2d non-mod-8 spatial` in `tests/gpu_tests.cpp`. It computes depthwise conv2d on GPU and CPU for six shapes (issue reproducer sizes, non-square, asymmetric spatial alignment, stride 2) and asserts `allclose` at `1e-4`.

`./build/tests/tests` suite: 245 passed, 0 failed.

Checklist

- `clang-format` on all changed files (no changes).
- Tests added (`tests/gpu_tests.cpp`).
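
The GPU-vs-CPU comparison described under Tests needs a trusted reference on one side. As a rough illustration only, the sketch below shows what such a reference computes for a single image; it is not the code in `tests/gpu_tests.cpp` and does not call MLX. The function names, the HWC / `[C, K, K]` layouts, and the absolute-tolerance-only `all_close` are assumptions made for the example; only the `1e-4` tolerance and the shape sizes come from the description.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Output extent along one spatial axis for kernel size k, given stride
// and symmetric zero padding (dilation 1).
constexpr int conv_out_dim(int in, int k, int stride, int pad) {
  return (in + 2 * pad - k) / stride + 1;
}

// Naive depthwise conv2d for one image.
// Assumed layouts (illustrative): in is [H, W, C], wt is [C, K, K],
// and the result is [oH, oW, C]. Each channel uses only its own filter.
std::vector<float> depthwise_conv2d_ref(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    int H, int W, int C, int K, int stride, int pad) {
  const int oH = conv_out_dim(H, K, stride, pad);
  const int oW = conv_out_dim(W, K, stride, pad);
  std::vector<float> out(static_cast<std::size_t>(oH) * oW * C, 0.0f);
  for (int oh = 0; oh < oH; ++oh) {
    for (int ow = 0; ow < oW; ++ow) {
      for (int c = 0; c < C; ++c) {
        float acc = 0.0f;
        for (int kh = 0; kh < K; ++kh) {
          for (int kw = 0; kw < K; ++kw) {
            const int ih = oh * stride - pad + kh;
            const int iw = ow * stride - pad + kw;
            // Zero padding: taps outside the input contribute nothing.
            if (ih < 0 || ih >= H || iw < 0 || iw >= W) {
              continue;
            }
            acc += in[static_cast<std::size_t>((ih * W + iw) * C + c)] *
                   wt[static_cast<std::size_t>((c * K + kh) * K + kw)];
          }
        }
        out[static_cast<std::size_t>((oh * oW + ow) * C + c)] = acc;
      }
    }
  }
  return out;
}

// Simplified closeness check using an absolute tolerance only.
bool all_close(
    const std::vector<float>& a,
    const std::vector<float>& b,
    float atol = 1e-4f) {
  if (a.size() != b.size()) {
    return false;
  }
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol) {
      return false;
    }
  }
  return true;
}
```

With a 3x3 kernel, pad=1 and stride=2, `conv_out_dim` maps 60 to 30 and 30 to 15, which is how a stride-2 downsampling stage ends up on the non-mod-8 sizes listed in the reproducer.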