Skip to content

Add Experts4bit for 4-bit quantization of fused MoE experts#1965

Open
pjordanandrsn wants to merge 1 commit into
bitsandbytes-foundation:mainfrom
pjordanandrsn:feature/experts-4bit
Open

Add Experts4bit for 4-bit quantization of fused MoE experts#1965
pjordanandrsn wants to merge 1 commit into
bitsandbytes-foundation:mainfrom
pjordanandrsn:feature/experts-4bit

Conversation

@pjordanandrsn

Copy link
Copy Markdown

What

Adds bitsandbytes.nn.Experts4bit, a module that stores fused Mixture-of-Experts
weights in 4-bit (NF4/FP4) precision.

Fixes the memory issue in #1849: transformers v5 stores MoE experts as a single 3D
nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExpertsgate_up_proj
[num_experts, 2*intermediate, hidden], down_proj [num_experts, hidden, intermediate]).
The nn.Linear-based 4-bit walker only swaps nn.Linear, so these fused experts are
skipped, stay in full precision, and dominate the loaded footprint.

Design

This follows the approach @matthewdouglas outlined on the issue:

  • Plain nn.Parameter for the packed weights (not Params4bit), with per-expert
    absmax kept on the module as buffers
    . This avoids bending Params4bit's
    tensor-subclass + device-movement machinery around a 3D stack, and the module
    serializes through the default state_dict — no custom save/load hooks.
  • Per-expert dequant loop in forward (mirrors the reference fused-experts forward in
    OlmoeExperts / FP8Experts): one expert's weight is dequantized, used, and freed at a
    time. This keeps the runtime working set small and leaves a clean path to a grouped-GEMM
    kernel later.
  • Enforces in_features % blocksize == 0 so per-expert quantization blocks tile each
    expert exactly and never straddle an expert boundary.

Relationship to replace_parameter_4bit (#1720): that generic parametrization also
quantizes arbitrary nn.Parameters, but dequantizes the entire [num_experts, …] stack
on every access. Experts4bit is MoE-aware — it only touches the experts a batch actually
routes to — which is what enables the grouped-GEMM follow-up.

Intentionally deferred for this first cut (per the issue discussion): double-quant
(compress_statistics), a grouped-GEMM forward, and the transformers-side walker wiring.

API

from bitsandbytes.nn import Experts4bit

# Quantize an existing fp16/bf16 fused-expert stack:
experts = Experts4bit.from_float(gate_up_proj, down_proj, quant_type="nf4")
out = experts(hidden_states, top_k_index, top_k_weights)

# Or construct empty + load_state_dict (e.g. pre-quantized checkpoints):
experts = Experts4bit(num_experts, hidden_dim, intermediate_dim)
experts.load_state_dict(sd)

Footprint & validation (measured on an RTX A2000 12 GB, sm_86)

For one real OLMoE-1B-7B layer (num_experts=64, hidden=2048, intermediate=1024, NF4,
blocksize 64, no double-quant), measured Experts4bit vs. the bf16 stack:

per layer full model (×16 layers)
experts, bf16 (today) 768.0 MB 12.00 GB
experts, Experts4bit (192 MB packed + 24 MB absmax) 216.0 MB 3.38 GB

3.56× smaller for the expert weights, which are the bulk of the model — combined with
the existing Linear4bit path on the non-expert layers this takes OLMoE-1B-7B from ~13 GB
to ~3.5 GB (fits a single 12 GB card). A forward over the real-sized layer peaks at
1295 MB of VRAM: because experts are dequantized one at a time, the working set never
materializes the full bf16 stack — the property that makes the grouped-GEMM follow-up
worthwhile.

Testing

tests/test_experts4bit.py — 11 cases, all green on the CPU default backend:

  • quant round-trip per expert (NF4/FP4 × fp16/bf16/fp32) within 4-bit tolerance, with
    packed-weight / absmax shape + dtype assertions
  • forward vs. a full-precision reference forward (gated + non-gated), float32 compute,
    rtol=atol=1e-4
  • state_dict round-trip: bit-exact restore of packed weights + absmax, identical forward
    after reload
  • validation guards (in_features % blocksize, invalid quant_type)

On CUDA (A2000, bnb 0.49.2 / torch 2.4.1) the NF4 round-trip mean-abs error is 0.0073 and
the forward matches the full-precision reference exactly (max-abs 0.0).

Closes #1849.

cc @matthewdouglas @SunMarc

…ytes-foundation#1849)

transformers v5 stores fused MoE experts as a single 3D nn.Parameter
(e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit
walker skips. The experts stay in full precision and load_in_4bit barely
shrinks the model (issue bitsandbytes-foundation#1849).

Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain
nn.Parameter buffers, with per-expert absmax kept on the module itself.
The forward pass dequantizes one expert at a time (a per-expert loop),
mirroring the reference fused-experts forward. There is no Params4bit
tensor-subclass machinery, so the module serializes through the default
state_dict with no custom hooks.

- from_float() quantizes existing bf16/fp16 expert stacks
- enforces in_features % blocksize == 0 for clean per-expert blocking
- double-quant (compress_statistics) and grouped-GEMM intentionally
  deferred for a first cut
- tests: quant round-trip, forward vs. full-precision reference,
  state_dict round-trip, and validation guards
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Failed to quant MoE models with fused expert weights in transformers v5

1 participant