Reduce peak memory of large half-precision random.uniform/normal #3439
Draft
dogukanveziroglu wants to merge 5 commits into ml-explore:main from
Conversation
A new primitive that runs the entire uniform RNG pipeline (threefry
hash → fp32 normalize → clip → cast → affine) per-thread in registers
for half-precision GPU outputs. Avoids materializing the fp32
intermediate buffer that the standard bits()/divide()/astype() chain
requires; peak memory drops from 3x to 1x of the target size.
Activation conditions (all required): half-precision dtype (bf16 or
fp16), even total output size, scalar low/high, single key (shape
{2}), GPU stream. Bit-exact with vanilla on the same seed; matches
the rbitsc kernel's interleaved counter layout.
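The activation conditions above can be collapsed into a single predicate. A minimal Python sketch (the helper and argument names are illustrative, not MLX API):

```python
# Illustrative sketch of the fused-path activation rule.
# use_fused_uniform and its parameters are hypothetical names.
from math import prod

def use_fused_uniform(shape, dtype, low_is_scalar, high_is_scalar,
                      key_shape, is_gpu_stream):
    """All conditions must hold for the fused kernel to be dispatched."""
    half_precision = dtype in ("bfloat16", "float16")
    even_size = prod(shape) % 2 == 0       # the kernel emits two outputs/thread
    single_key = tuple(key_shape) == (2,)  # one threefry key of shape {2}
    return (half_precision and even_size and low_is_scalar
            and high_is_scalar and single_key and is_gpu_stream)

# e.g. a (16384, 16384) bf16 draw on a GPU stream qualifies:
print(use_fused_uniform((16384, 16384), "bfloat16", True, True, (2,), True))  # → True
```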
Performance: 15.4x faster on (16384, 16384) bf16 (108 ms → 7 ms)
because the fp32 intermediate no longer transits L2/HBM. Small
shapes (<1 MB) pay a slight kernel-launch overhead — chunked path
threshold ensures the fast path only activates when the win
dominates.
CUDA mirror added in a follow-up commit (untested; algorithmic
transcription of the validated Metal kernel).
Splits large GPU random calls (output ≥ ~512 MB fp32-equivalent) along
axis 0 into K independent sub-key chunks, computes each via the
existing fp32-then-cast pipeline, and writes into a pre-allocated
output via slice_update with eval per chunk. Per-chunk fp32 transients
are freed between iterations; peak drops from 3x to ~1+2/K of target
(1.09x at K=33 on the canary shape).
Heuristic: K = ceil(fp32_bytes / 256 MB), clamped to [4, 256].
Profiled in path-c/19-K-isolation.md: theory matches measurement
within 5% at K ≥ 32; allocator overhead at small K (2-16) adds 17-30%
but amortizes away.
Sub-key derivation via random::split is cryptographically independent
and seed-deterministic. The same seed produces the same chunked output
across runs, but the bit pattern differs from vanilla (which uses one
key for the whole shape). Same trade-off class as PR ml-explore#904;
statistical quality is preserved per-chunk (chunked unique-value count
≥ vanilla baseline).
Activation rule (all required): GPU stream, scalar lo/hi, single key,
fp32-equiv output size ≥ 512 MB, axis-0 dim ≥ 4. Falls back to the
vanilla path for everything else (small shapes, multi-key, broadcast
bounds, CPU). normal() uses the same chunked pipeline when the target
dtype is bf16/fp16/fp32.
Resolves OOM on (46341, 46341) bf16 normal: vanilla aborts at 12.88 GB
peak, chunked completes at 4.69 GB. Tolerates up to ~11 GB of
concurrent allocations on M4 16 GB before swap kicks in
(path-c/21-active-ballast.md).
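The chunk-count heuristic and peak model can be sketched in a few lines of pure Python (helper names are illustrative; only the K formula and the (1 + 2/K) factor come from the commit message, and this simple model predicts a somewhat lower peak than the 4.69 GB measured):

```python
# Sketch of the chunk-count heuristic and the peak-memory model.
import math

CHUNK_TARGET_BYTES = 256 * 2**20  # ~256 MB of fp32 intermediate per chunk

def pick_chunk_count(fp32_bytes):
    # K = ceil(fp32_bytes / 256 MB), clamped to [4, 256]
    return max(4, min(256, math.ceil(fp32_bytes / CHUNK_TARGET_BYTES)))

def chunked_peak(output_bytes, k):
    # Peak ~ (1 + 2/K) x target, vs ~3x target for the vanilla pipeline
    return (1 + 2 / k) * output_bytes

# Canary shape (46341, 46341) in bf16:
n = 46341 * 46341
k = pick_chunk_count(4 * n)  # 4 bytes/element of fp32 intermediate
print(k)                     # → 33, matching the K reported for the canary
print(round(chunked_peak(2 * n, k) / 1e9, 2))  # model prediction in GB
```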
CUDA mirror of the Metal RandomUniform kernel (same threefry counter
mapping, same per-thread fp32-then-cast in registers, same output
dtype templating). Marked untested in code: no NVIDIA hardware on this
branch's CI; algorithmic equivalence to the validated Metal kernel
verified by inspection.
TestRandomChunked: 8 tests targeting the chunked path (shapes ≥ 1 GB
so chunking activates). Each test uses a 5σ/√N statistical tolerance
for distribution stats (not hand-tuned); a seed-reproducibility test
confirms deterministic output; an odd-first-dim test exercises
chunk-remainder handling; a unique-bit test asserts ≥ 2000 distinct
bf16 values per million samples (the PR ml-explore#2361 quality
floor). Brings test_random.py coverage from 14 to 22 tests; the full
pytest run remains 696 passed / 4 skipped / 9283 subtests on M4.
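A 5σ/√N tolerance can be derived rather than hand-tuned. A sketch for the sample mean of U(0, 1) (the concrete assertions in test_random.py may differ):

```python
# Deriving a 5-sigma tolerance for the sample mean of U(0, 1).
import math
import random

N = 1_000_000
# For U(0, 1): variance = 1/12, so the std-error of the mean is sqrt(1/12/N)
tol = 5 * math.sqrt(1 / 12 / N)  # a 5-sigma band, ~0.00144 at N = 1e6

random.seed(0)
sample_mean = sum(random.random() for _ in range(N)) / N
assert abs(sample_mean - 0.5) < tol  # fails with probability ~6e-7
print(round(tol, 5))
```

The point of the 5σ width is that a correct generator essentially never trips the assertion, while a biased one fails deterministically at this sample size.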
The chunked dispatch in mlx/random.cpp had two correctness gaps
discovered by an adversarial drawback sweep against vanilla:
1. fp32 chunking is strictly worse than vanilla. Vanilla fp32
uniform/normal already operate at ~1x output peak (the
intermediate IS the target dtype), so chunking adds K-fold
sub-key derivation + slice_update overhead with zero memory
benefit. Measured ~25% latency regression and ~25% higher
peak memory at 12K^2+ shapes. Restrict chunkable_dtype to
{bfloat16, float16}.
2. Both the fused RandomUniform primitive and the chunked path
are illegal inside mx.compile / mx.vmap / mx.grad: the fused
primitive throws on RandomUniform::vmap, and the chunked
path's per-chunk eval() is rejected by the tracer. Gate
both dispatches on !detail::in_tracing() so any transform
falls back to the vanilla pipeline (which uses RandomBits,
DEFINE_VMAP()-supported).
Headline canary unchanged: (46341, 46341) bf16 normal still
peaks at 4.7 GB on M4-16GB.
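After both fixes, the chunked-dispatch decision reduces to a rule like the following sketch (Python stand-in for the C++ dispatch; in_tracing mirrors detail::in_tracing(), all names are illustrative):

```python
# Sketch of the chunked-dispatch rule after both fixes: chunking is
# restricted to half precision (fix 1) and disabled under tracing (fix 2).
CHUNK_THRESHOLD = 512 * 2**20  # fp32-equivalent bytes

def pick_path(dtype, fp32_bytes, axis0_dim, in_tracing):
    if in_tracing:
        return "vanilla"  # compile/vmap/grad must see RandomBits
    chunkable = dtype in ("bfloat16", "float16")  # fix 1: no fp32 chunking
    if chunkable and fp32_bytes >= CHUNK_THRESHOLD and axis0_dim >= 4:
        return "chunked"
    return "vanilla"

print(pick_path("float32", 8 * 2**30, 46341, False))   # fp32 stays vanilla
print(pick_path("bfloat16", 8 * 2**30, 46341, True))   # tracing stays vanilla
print(pick_path("bfloat16", 8 * 2**30, 46341, False))  # large bf16 is chunked
```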
Pre-PR cleanup pass: remove internal investigation references
("Variant D1/D4", "Phase X", *.md filenames, drawback Phase
references) from comments, compress the chunked-path docstrings,
and tighten throw messages to drop implementation detail leakage.
Also:
- mlx/random.cpp: replace if-cascade clamps in pick_chunk_count
with std::clamp; mark chunked_fp32_then_cast and pick_chunk_count
static; drop the redundant inner key-shape check (single_key
already guarantees Shape{2}); inline single-use bool 'even'.
- mlx/backend/metal/kernels/random.metal: collapse the two-output
per-thread block to one expression each; drop the "Step 4"
reference.
- mlx/backend/metal/primitives.cpp: drop the 7-line debugging
postmortem about constant-buffer packing.
- .gitignore: drop the path-c-only .venv / python/mlx/lib entries.
No behavior change. 22/22 random tests + 708 full pytest pass;
canary (46341, 46341) bf16 normal still peaks at 4.69 GB.
Net diff: -69 lines.
@angeloskath could you give some thoughts please, thank you.
Hi guys, so I've been messing around with mlx.random for the last few
days, initially just trying to figure out why
mx.random.normal((46341, 46341), dtype=mx.bfloat16) was crashing on my
M4 16GB. I realised it uses much more memory than it's supposed to, so
I tried some changes to how the calculation is done. Opening as a
draft because I want your gut check before I polish anything further.
I might be trying to climb a vertical flat wall, I'm not sure :D. If
what I made is dumb, please tell me and guide me; I would love to get
some criticism. Heads up: I haven't tried this on an NVIDIA GPU yet, I
will soon...
What

mx.random.normal((46341, 46341), dtype=mx.bfloat16) aborts with
"Insufficient Memory" on a 16 GB Apple-Silicon device. The standard
bits → divide → cast → minimum → mul → add chain holds three
fp32-sized buffers at the same time — about 12.88 GB of peak for a
4.3 GB output.

This PR adds two narrow GPU paths to fix it. The headline canary goes
from abort (12.88 GB) to success at 4.69 GB peak on M4-16GB.
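The byte arithmetic behind the headline numbers, using the 3x-of-target framing from the commit messages:

```python
# Back-of-envelope check of the headline numbers: a (46341, 46341) bf16
# output and the ~3x-of-target peak the vanilla chain reaches.
n = 46341 * 46341             # ~2.15e9 elements (just past the int32 limit)
out_gb = 2 * n / 1e9          # bf16 output: 2 bytes/element
vanilla_peak_gb = 3 * out_gb  # three output-sized buffers live at once
print(round(out_gb, 2), round(vanilla_peak_gb, 2))  # → 4.29 12.88
```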
How

Two new dispatch paths in mlx/random.cpp:

1. Fused RandomUniform Metal kernel for half-precision uniform when
bounds are scalar, total size is even, and a single key is in flight.
Computes bits → divide → clip → cast → affine per thread in registers,
so no fp32 intermediate ever lands in global memory. The output is
bit-identical to vanilla, peak drops from ~3x to 1x, and it's 5–10×
faster on large shapes.

2. Chunked path that splits the existing fp32 pipeline along axis 0
into K independent sub-keys (K = ⌈bits_bytes / 256 MB⌉, clamped to
[4, 256]). Triggers when the fp32-equivalent size is ≥ 512 MB. Peak
drops to about (1 + 2/K) × output. Bytes differ from vanilla because
the sub-keys differ; each chunk still does fp32-then-cast.

Both paths are skipped when detail::in_tracing() is true, so anything
inside vmap/compile/vjp/jvp falls back to the vanilla pipeline. fp32
is excluded from the chunked path on purpose — vanilla fp32 already
runs at ~1x output peak (the intermediate IS the target dtype), so
chunking only adds overhead.
Dispatch summary
Numbers (M4-16GB, 30 reps, fresh subprocess)
Half-precision uniform (fused kernel):
Half-precision normal (chunked):
fp32 (deliberately unchanged):
Quality

- mean/var/min/max identical for 10 tested seeds at 16384² (vs vanilla)
- inter-chunk Pearson correlation < 0.011 at all chunk boundaries
- unique bf16 value count 2315 at 100K samples
- identical seeds, same peak memory, ~2% better steady-state tok/s
Trade-offs

Real but small:
- Small shapes are slower because the fused kernel's launch overhead
dominates. Crossover above ~1K² gives the 5–90× speedups. Easy to add
a min-size guard if you'd prefer.
- Chunked normal is +9–11% slower in exchange for the 42–60% peak cut.
- Odd total sizes fall back to the vanilla uniform path (the kernel
emits two outputs per thread).
Tests

python/tests/test_random.py (8 new in TestRandomChunked with
mathematically derived 5σ/√N tolerances)

Files

- mlx/random.cpp: dispatch + chunking helper
- mlx/primitives.{h,cpp}: RandomUniform class
- mlx/backend/metal/kernels/random.metal: runiformc<T> kernel
- mlx/backend/metal/primitives.cpp: RandomUniform::eval_gpu
- mlx/backend/cpu/primitives.cpp: RandomUniform::eval_cpu stub
- mlx/backend/cuda/random.cu: CUDA mirror
- python/tests/test_random.py: TestRandomChunked

Repro the canary
Checklist

Put an x in the boxes that apply.
- [ ] pre-commit run --all-files to format my code / installed
pre-commit prior to committing changes