Add Experts4bit for 4-bit quantization of fused MoE experts #1965
Open
pjordanandrsn wants to merge 1 commit into
…ytes-foundation#1849)

transformers v5 stores fused MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit walker skips. The experts stay in full precision and load_in_4bit barely shrinks the model (issue bitsandbytes-foundation#1849).

Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain nn.Parameter buffers, with per-expert absmax kept on the module itself. The forward pass dequantizes one expert at a time (a per-expert loop), mirroring the reference fused-experts forward. There is no Params4bit tensor-subclass machinery, so the module serializes through the default state_dict with no custom hooks.

- from_float() quantizes existing bf16/fp16 expert stacks
- enforces in_features % blocksize == 0 for clean per-expert blocking
- double-quant (compress_statistics) and grouped-GEMM intentionally deferred for a first cut
- tests: quant round-trip, forward vs. full-precision reference, state_dict round-trip, and validation guards
What
Adds bitsandbytes.nn.Experts4bit, a module that stores fused Mixture-of-Experts weights in 4-bit (NF4/FP4) precision.
Fixes the memory issue in #1849: transformers v5 stores MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExperts: gate_up_proj [num_experts, 2*intermediate, hidden], down_proj [num_experts, hidden, intermediate]). The nn.Linear-based 4-bit walker only swaps nn.Linear, so these fused experts are skipped, stay in full precision, and dominate the loaded footprint.
Design
This follows the approach @matthewdouglas outlined on the issue:
- Plain nn.Parameter for the packed weights (not Params4bit), with per-expert absmax kept on the module as buffers. This avoids bending Params4bit's tensor-subclass + device-movement machinery around a 3D stack, and the module serializes through the default state_dict, with no custom save/load hooks.
- Per-expert loop forward (mirrors the reference fused-experts forward in OlmoeExperts / FP8Experts): one expert's weight is dequantized, used, and freed at a time. This keeps the runtime working set small and leaves a clean path to a grouped-GEMM kernel later.
- Enforces in_features % blocksize == 0 so per-expert quantization blocks tile each expert exactly and never straddle an expert boundary.
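The blocking constraint in the last point can be illustrated with a minimal pure-Python sketch. It is not the PR's implementation: `quantize_expert`, `dequantize_expert`, and the evenly spaced `CODEBOOK` are hypothetical stand-ins (the real NF4/FP4 tables differ), and only the blockwise-absmax structure is meant to carry over. When every row of `in_features` elements is a whole number of blocks, each expert owns a contiguous, whole-numbered run of absmax entries.

```python
# Stand-in 16-level codebook, evenly spaced in [-1, 1]. This is an assumption
# for illustration; the real NF4/FP4 tables use different values.
CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]


def quantize_expert(weights, in_features, blocksize):
    """Quantize one expert's flattened [out, in] weights block by block."""
    if blocksize <= 0 or blocksize % 2 != 0:
        raise ValueError("blocksize must be a positive even number")
    if in_features % blocksize != 0:
        # A partial trailing block per row would let blocks drift across
        # row (and eventually expert) boundaries.
        raise ValueError("in_features must be divisible by blocksize")
    if len(weights) % in_features != 0:
        raise ValueError("weights must hold whole rows of in_features")

    codes, absmax = [], []
    for start in range(0, len(weights), blocksize):
        block = weights[start:start + blocksize]
        scale = max(abs(w) for w in block) or 1.0  # avoid dividing by zero
        absmax.append(scale)
        for w in block:
            target = w / scale
            codes.append(min(range(16), key=lambda i: abs(CODEBOOK[i] - target)))

    # Pack two 4-bit codes into each byte, high nibble first.
    packed = bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))
    return packed, absmax


def dequantize_expert(packed, absmax, blocksize):
    """Invert quantize_expert up to quantization error."""
    out = []
    for n, byte in enumerate(packed):
        for k, code in enumerate((byte >> 4, byte & 0x0F)):
            index = 2 * n + k
            out.append(CODEBOOK[code] * absmax[index // blocksize])
    return out
```

With this layout the packed bytes and absmax values for a given expert can be sliced out independently of its neighbours, which is what a per-expert forward relies on.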
Relationship to replace_parameter_4bit (#1720): that generic parametrization also quantizes arbitrary nn.Parameters, but dequantizes the entire [num_experts, …] stack on every access. Experts4bit is MoE-aware: it only touches the experts a batch actually routes to, which is what enables the grouped-GEMM follow-up.
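The difference from whole-stack dequantization can be sketched in a few lines of pure Python. This is a deliberately simplified illustration, not the module's forward: it assumes top-1 routing, a single weight matrix per expert, and no gating or routing weights, and `experts_forward`, `matvec`, and the `dequantize_expert` callback are hypothetical names.

```python
def matvec(weight, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def experts_forward(tokens, expert_ids, dequantize_expert):
    """Top-1 routed forward that materializes one expert at a time.

    dequantize_expert(e) is a caller-supplied callback (an assumption of this
    sketch) returning expert e's full-precision weight matrix.
    """
    out = [None] * len(tokens)
    for e in sorted(set(expert_ids)):      # experts with no tokens are never touched
        weight = dequantize_expert(e)      # only this expert is in full precision
        for t, routed in enumerate(expert_ids):
            if routed == e:
                out[t] = matvec(weight, tokens[t])
        del weight                         # drop it before the next expert
    return out
```

In this sketch an expert that receives no tokens is never dequantized, whereas an approach that dequantizes the entire stack on access pays for every expert regardless of routing.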
Intentionally deferred for this first cut (per the issue discussion): double-quant (compress_statistics), a grouped-GEMM forward, and the transformers-side walker wiring.
API
Footprint & validation (measured on an RTX A2000 12 GB, sm_86)
For one real OLMoE-1B-7B layer (num_experts=64, hidden=2048, intermediate=1024, NF4, blocksize 64, no double-quant), measured Experts4bit vs. the bf16 stack: Experts4bit (192 MB packed + 24 MB absmax) is 3.56× smaller for the expert weights, which are the bulk of the model. Combined with the existing Linear4bit path on the non-expert layers, this takes OLMoE-1B-7B from ~13 GB to ~3.5 GB (fits a single 12 GB card). A forward over the real-sized layer peaks at 1295 MB of VRAM: because experts are dequantized one at a time, the working set never materializes the full bf16 stack, which is the property that makes the grouped-GEMM follow-up worthwhile.
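As a sanity check on the footprint figures, the sizes can be recomputed from the stated shapes. The sketch below assumes float32 absmax values (4 bytes per block) and that "MB" means MiB; `expert_footprint_mib` is a hypothetical helper, not part of the PR. Under those assumptions it reproduces the 192 MB, 24 MB, and 3.56× figures, and implies a 768 MiB bf16 stack (a derived number, not one quoted in the PR).

```python
MIB = 1024 ** 2


def expert_footprint_mib(num_experts, hidden, intermediate,
                         blocksize=64, absmax_bytes=4):
    """Estimate storage for one layer's fused experts, in MiB."""
    gate_up = 2 * intermediate * hidden    # gate_up_proj weights per expert
    down = hidden * intermediate           # down_proj weights per expert
    n_weights = num_experts * (gate_up + down)

    bf16 = n_weights * 2 / MIB             # 2 bytes per weight
    packed = n_weights / 2 / MIB           # two 4-bit codes per byte
    absmax = (n_weights // blocksize) * absmax_bytes / MIB  # one scale per block
    return {
        "bf16": bf16,
        "packed": packed,
        "absmax": absmax,
        "ratio": bf16 / (packed + absmax),
    }


sizes = expert_footprint_mib(num_experts=64, hidden=2048, intermediate=1024)
print(sizes)
```

The absmax overhead is one scale per 64 weights, which is why it adds 24 MiB on top of the 192 MiB of packed codes when double-quant is off.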
Testing
tests/test_experts4bit.py (11 cases, all green on the CPU default backend):
- packed-weight / absmax shape + dtype assertions
- forward vs. a full-precision reference forward (gated + non-gated), float32 compute, rtol=atol=1e-4
- state_dict round-trip: bit-exact restore of packed weights + absmax, identical forward after reload
- validation guards (in_features % blocksize, invalid quant_type)
On CUDA (A2000, bnb 0.49.2 / torch 2.4.1) the NF4 round-trip mean-abs error is 0.0073 and the forward matches the full-precision reference exactly (max-abs 0.0).
Closes #1849.
cc @matthewdouglas @SunMarc