Add 1-bit affine quantization support #3478
Closed
dusterbloom wants to merge 4 commits into ml-explore:main from
Conversation
Collaborator
Closing as a duplicate of #3161.
Summary
Adds support for `bits=1` to `affine_quantize` / `affine_dequantize` / `quantized_matmul`, enabling 1.25-bpw model serving in MLX (1 bit per weight + an fp16 scale and bias per group of 128 input columns).
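A minimal usage sketch of the new path, assuming the public `mx.quantize` / `mx.dequantize` wrappers dispatch to the affine primitives named above (shapes are illustrative; `bits=1` is only accepted with this PR applied):

```python
import mlx.core as mx

# Illustrative 1.25-bpw round trip: 1 bit per weight plus an fp16 scale and
# bias for every group of 128 input columns.
w = mx.random.normal(shape=(4096, 4096))

w_q, scales, biases = mx.quantize(w, group_size=128, bits=1)
w_hat = mx.dequantize(w_q, scales, biases, group_size=128, bits=1)

# Each group of 128 weights collapses to two representable values
# (bias and scale + bias), so w_hat is a coarse reconstruction of w.
print(mx.abs(w - w_hat).mean())
```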
The 3 commits (originally authored by Pasha Khosravi, cherry-picked from research work) are:

- `309f947` — relaxes the `bits < 2` guard, adds the bit-0/bit-1 → w_min/w_max codepath (sketched below), ships the Metal kernel `affine_dequantize_*_gs_128_b_1`, and adds Python tests
- `17edfc9`, `06ee46c` — `qmv_fast` tail iteration for non-aligned K (26-line correctness fix) and dispatch hardening

Total: 11 files changed, +484 / −98 in the kernel commit, plus 27 lines of tail-iteration / dispatch hardening.
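To make the bit-0/bit-1 → w_min/w_max codepath concrete, here is a rough per-group sketch of what 1-bit affine (de)quantization computes; this is an illustration of the scheme (helper names invented), not the Metal kernel itself:

```python
import numpy as np

def quantize_1bit_group(w):
    # One group of 128 weights -> 1 bit each, plus a scale and bias
    # (stored as fp16 in the real format). Bit 0 reproduces w_min,
    # bit 1 reproduces w_max.
    w_min, w_max = w.min(), w.max()
    scale, bias = w_max - w_min, w_min
    q = (np.abs(w - w_max) < np.abs(w - w_min)).astype(np.uint8)
    return q, scale, bias

def dequantize_1bit_group(q, scale, bias):
    # bit 0 -> bias (= w_min), bit 1 -> scale + bias (= w_max)
    return q * scale + bias

group = np.random.randn(128).astype(np.float32)
q, scale, bias = quantize_1bit_group(group)
recon = dequantize_1bit_group(q, scale, bias)  # every entry is w_min or w_max
```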
Motivation: 1.25-bpw production models
Several models in the wild are now shipping in 1.25-bpw (1 bit/weight + group-128 affine scale/bias) format:

- `prism-ml/Bonsai-1.7B-mlx-1bit` (~260 MB residency)
- `prism-ml/Bonsai-8B-mlx-1bit` (~1.25 GB residency)
- `prism-ml/Bonsai-4B-mlx-1bit`

These checkpoints declare `quantization.bits == 1` in `config.json` and require the `affine_dequantize_*_gs_128_b_1` Metal kernel that this PR adds.
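For loaders, the relevant block in those checkpoints' `config.json` is roughly the standard mlx-lm style quantization entry; the exact field layout assumed below is an illustration, not copied from the Bonsai repos:

```python
import json

# Hypothetical loader-side check for the 1-bit format this PR enables.
with open("config.json") as f:
    config = json.load(f)

quant = config.get("quantization", {})
needs_1bit_kernel = quant.get("bits") == 1 and quant.get("group_size") == 128
# If True, dequantization needs the affine_dequantize_*_gs_128_b_1 kernel.
```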
Validated end-to-end

The kernels have been integrated into the higgs inference engine (PR #142) and validated on real Bonsai-8B inference: runtime parity against the original `feat/magic-canvas` research branch matches within thermal noise on an M4.
Test coverage
The first commit (`309f947`) ships:

- `python/tests/test_quantized.py` — adds `test_quantize_1bit`, covering `affine_quantize` round-trip plus `quantized_matmul` correctness for the 1-bit path (96 lines added; a rough sketch of the kind of check involved follows below)
- `python/tests/cuda_skip.py` — an explicit skip entry for the new test on CUDA

The two follow-up commits are correctness fixes guarded by the existing
`mlx_quantize` / `mlx_qmm` test suite.
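For context, the correctness side of that test is conceptually along these lines; a sketch only, with invented shapes and tolerance, assuming the public `mx` wrappers route to the same affine path:

```python
import mlx.core as mx

# Compare 1-bit quantized_matmul against a dense matmul over the
# dequantized weights.
x = mx.random.normal(shape=(2, 512))
w = mx.random.normal(shape=(256, 512))

w_q, scales, biases = mx.quantize(w, group_size=128, bits=1)
w_hat = mx.dequantize(w_q, scales, biases, group_size=128, bits=1)

y_ref = x @ w_hat.T
y_q = mx.quantized_matmul(x, w_q, scales, biases,
                          transpose=True, group_size=128, bits=1)
assert mx.allclose(y_ref, y_q, atol=1e-2).item()
```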
Downstream chain

Once this PR merges, two follow-up PRs drop the fork dependency for downstream Rust consumers:

- `ml-explore/mlx-c` — bump the submodule to a merged-mlx SHA; no signature changes anticipated (the v0.6.0-3 C bindings already accept `global_scale` as a nullable `mlx_array`)
- `oxideai/mlx-rs` — 12-line plumbing fix in `src/ops/quantization.rs` to pass null `global_scale` arrays through `mlx_quantize` / `mlx_dequantize` / `mlx_qqmm` (matches the v0.6.0-3 C signature)

Acknowledgements
All three commits authored by Pasha Khosravi.
This PR is a packaging step to bring his research work into upstream MLX so
production model loaders can drop their fork chains.
Cherry-picks are clean against current `main` (the branch base sits on upstream `ce45c52`, "[CUDA] Use qmv kernel for fp quantizations (#3239)").

🤖 PR prepared with Claude Code