Reduce peak memory of large half-precision random.uniform/normal#3439

Draft
dogukanveziroglu wants to merge 5 commits into ml-explore:main from dogukanveziroglu:reduce-random-peak-memory

Conversation

@dogukanveziroglu
Contributor

Hi guys, I've been messing around with mlx.random for the last few days, initially just trying to figure out why mx.random.normal((46341, 46341), dtype=mx.bfloat16) was crashing on my M4 16GB. I realised it uses far more memory than it should, so I tried some changes to the calculation. I'm opening this as a draft because I want your gut check before I polish anything further. I might be trying to climb a flat vertical wall here, I'm not sure :D. If what I made is misguided, please tell me and point me in the right direction; I'd genuinely welcome the criticism.

Heads up: I haven't tried this on an NVIDIA GPU yet; I'll do that soon.

What

mx.random.normal((46341, 46341), dtype=mx.bfloat16) aborts with
"Insufficient Memory" on a 16 GB Apple-Silicon device. The standard
bits → divide → cast → minimum → mul → add chain holds three
fp32-sized buffers at the same time — about 12.88 GB of peak for a
4.3 GB output.

This PR adds two narrow GPU paths to fix it. The headline canary
goes from abort (12.88 GB) to success at 4.69 GB peak on
M4-16GB.

How

Two new dispatch paths in mlx/random.cpp:

1. Fused RandomUniform Metal kernel for half-precision uniform
when bounds are scalar, total size is even, and a single key is in
flight. Computes bits → divide → clip → cast → affine per thread
in registers, so no fp32 intermediate ever lands in global memory.
The output is bit-identical to vanilla, peak drops from ~3x to
1x, and it's 5–10× faster on large shapes.

2. Chunked path that splits the existing fp32 pipeline along
axis 0 into K independent sub-keys (K = ⌈bits_bytes / 256MB⌉,
clamped to [4, 256]). Triggers when the fp32-equivalent size is
≥ 512 MB. Peak drops to about (1 + 2/K) × output. Bytes differ
from vanilla because the sub-keys differ; each chunk still does
fp32-then-cast.
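
For intuition, here is a pure-NumPy analogue of the chunked strategy. This is illustrative only, not the actual MLX code: `chunked_normal` is a hypothetical name, and `SeedSequence.spawn` stands in for `random::split`.

```python
import numpy as np

def chunked_normal(shape, n_chunks, seed=0, out_dtype=np.float16):
    # NumPy analogue of the chunked path: split axis 0 into n_chunks
    # sub-ranges, derive an independent sub-key per chunk, generate each
    # chunk in fp32, cast, and write into a preallocated output so only
    # one chunk-sized fp32 transient is live at a time.
    out = np.empty(shape, dtype=out_dtype)
    row_chunks = np.array_split(np.arange(shape[0]), n_chunks)
    sub_keys = np.random.SeedSequence(seed).spawn(n_chunks)  # stand-in for random::split
    for rows, key in zip(row_chunks, sub_keys):
        rng = np.random.default_rng(key)
        chunk = rng.standard_normal((len(rows),) + shape[1:], dtype=np.float32)
        out[rows[0]:rows[-1] + 1] = chunk.astype(out_dtype)  # fp32 transient dropped here
    return out

x = chunked_normal((1024, 64), n_chunks=4)
```

As in the PR, the same seed reproduces the same output, but the bit pattern differs from a single-key draw because each chunk consumes its own sub-key.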

Both paths are skipped when detail::in_tracing() is true, so
anything inside vmap / compile / vjp / jvp falls back to
the vanilla pipeline. fp32 is excluded from the chunked path on
purpose — vanilla fp32 already runs at ~1x output peak (the
intermediate IS the target dtype), so chunking only adds overhead.

Dispatch summary

When                                                                          What runs
GPU + bf16/fp16 + scalar bounds + even N + single key + outside transforms    fused kernel (bit-equal to vanilla)
GPU + bf16/fp16 + ≥ 512 MB + scalar bounds + single key + outside transforms  chunked path
Everything else                                                               vanilla (unchanged)

Numbers (M4-16GB, 30 reps, fresh subprocess)

Half-precision uniform (fused kernel):

Shape          Vanilla ms / GB    This PR ms / GB
4096²  bf16    7.83 / 0.10        1.67 / 0.034
16384² bf16    113  / 1.61        12.2 / 0.54
32768² bf16    448  / 6.44        44.7 / 2.15

Half-precision normal (chunked):

Shape          Vanilla ms / GB      This PR ms / GB    Δ ms    Δ peak
16384² bf16    164 / 1.61           182  / 0.94        +11%    −42%
32768² bf16    649 / 6.44           718  / 2.55        +11%    −60%
46341² bf16    abort at 12.88 GB    1502 / 4.69        n/a     fits

fp32 (deliberately unchanged):

Shape                 Vanilla ms / GB    This PR ms / GB
16384² fp32 normal    164.61 / 1.07      164.57 / 1.07

Quality

  • bf16 uniform fused output is bit-equal to vanilla
    (mean/var/min/max identical for 10 tested seeds at 16384²).
  • bf16 normal chunked output: KS one-sample distribution matches
    vanilla; inter-chunk Pearson correlation < 0.011 at all chunk
    boundaries; unique bf16 value count is 2315 at 100K samples.
  • mlx-lm Qwen2.5-0.5B inference: identical sampled tokens at
    identical seeds, same peak memory, ~2% better steady-state tok/s.
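
The KS-style check can be reproduced with a small two-sample statistic. This is an illustrative NumPy sketch, not the PR's actual test harness; the sample sizes and the 4-way sub-stream setup are assumptions.

```python
import numpy as np

def ks_statistic(a, b):
    # Two-sample Kolmogorov-Smirnov statistic: the maximum gap between
    # the two empirical CDFs, evaluated at every observed point.
    grid = np.sort(np.concatenate([a, b]))
    cdf_a = np.searchsorted(np.sort(a), grid, side="right") / a.size
    cdf_b = np.searchsorted(np.sort(b), grid, side="right") / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))

# "Vanilla": one stream. "Chunked": four independent sub-streams.
vanilla = np.random.default_rng(0).standard_normal(100_000)
chunked = np.concatenate([
    np.random.default_rng(k).standard_normal(25_000)
    for k in np.random.SeedSequence(1).spawn(4)
])
stat = ks_statistic(vanilla, chunked)  # small value => same distribution
```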

Trade-offs

Real but small:

  • Half-precision uniform at very small shapes (≤256²) is +35–43%
    slower because the fused kernel's launch overhead dominates.
    Crossover above ~1K² gives the 5–90× speedups. Easy to add a
    min-size guard if you'd prefer.
  • Half-precision normal at chunked-trigger shapes (≥11K²) is
    +9–11% slower in exchange for the 42–60% peak cut.
  • Shapes with odd total element count fall back to vanilla on the
    uniform path (the kernel emits two outputs per thread).

Tests

  • 22/22 in python/tests/test_random.py (8 new in TestRandomChunked
    with mathematically derived 5σ/√N tolerances)
  • 708/708 full pytest

Files

  • mlx/random.cpp — dispatch + chunking helper
  • mlx/primitives.{h,cpp} — RandomUniform class
  • mlx/backend/metal/kernels/random.metal — runiformc<T> kernel
  • mlx/backend/metal/primitives.cpp — RandomUniform::eval_gpu
  • mlx/backend/cpu/primitives.cpp — RandomUniform::eval_cpu stub
  • mlx/backend/cuda/random.cu — CUDA mirror
  • python/tests/test_random.py — TestRandomChunked

Repro the canary

python -c "
import mlx.core as mx
mx.clear_cache(); mx.reset_peak_memory()
mx.random.seed(0)
a = mx.random.normal(shape=(46341, 46341), dtype=mx.bfloat16); mx.eval(a)
print(f'CANARY peak={mx.get_peak_memory()/1e9:.2f} GB')
"
# Vanilla: abort. This PR: 4.69 GB.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

A new primitive that runs the entire uniform RNG pipeline (threefry
hash → fp32 normalize → clip → cast → affine) per-thread in registers
for half-precision GPU outputs. Avoids materializing the fp32
intermediate buffer that the standard bits()/divide()/astype() chain
requires; peak memory drops 3x → 1x of target.
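
As a rough NumPy analogue of what each thread computes in registers (illustrative only; `uniform_from_bits` is a hypothetical name, and the steps follow the bits → divide → cast → minimum → mul → add order described above):

```python
import numpy as np

def uniform_from_bits(bits, low, high, out_dtype=np.float16):
    # Per-element analogue of the fused pipeline:
    # bits -> divide (fp32 normalize) -> cast to half -> minimum (clip
    # below 1.0 in the target dtype) -> mul/add (affine to [low, high)).
    u = bits.astype(np.float32) / np.float32(2**32)  # fp32 lives only "in registers"
    u = u.astype(out_dtype)
    below_one = out_dtype(1.0 - 2**-11)              # largest fp16 value below 1.0
    u = np.minimum(u, below_one)
    return (out_dtype(low) + u * out_dtype(high - low)).astype(out_dtype)

bits = np.array([0, 2**31, 2**32 - 1], dtype=np.uint32)
vals = uniform_from_bits(bits, low=-1.0, high=1.0)   # stays inside [-1, 1)
```

The clip constant is dtype-specific (for bf16 the largest value below 1.0 is 1 − 2⁻⁸ instead); the fp16 constant here is just for the sketch.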

Activation conditions (all required): half-precision dtype (bf16 or
fp16), even total output size, scalar low/high, single key (shape
{2}), GPU stream. Bit-exact with vanilla on the same seed; matches
the rbitsc kernel's interleaved counter layout.

Performance: 15.4x faster on (16384, 16384) bf16 (108 ms → 7 ms)
because the fp32 intermediate no longer transits L2/HBM. Small
shapes (<1 MB) pay a slight kernel-launch overhead — chunked path
threshold ensures the fast path only activates when the win
dominates.

CUDA mirror added in a follow-up commit (untested; algorithmic
transcription of the validated Metal kernel).

Splits large GPU random calls (output ≥ ~512 MB fp32-equivalent)
along axis 0 into K independent sub-key chunks, computes each via
the existing fp32-then-cast pipeline, and writes into a pre-
allocated output via slice_update with eval per chunk. Per-chunk
fp32 transients are freed between iterations; peak drops from 3x
to ~1+2/K of target (1.09x at K=33 on the canary shape).

Heuristic: K = ceil(fp32_bytes / 256 MB), clamped to [4, 256].
Profiled in path-c/19-K-isolation.md: theory matches measurement
within 5% at K ≥ 32; allocator overhead at small K (2-16) adds
17-30% but amortizes away.
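
The heuristic and its predicted peak are simple enough to sketch. Helper names here are hypothetical; the constants and the canary K follow the commit message above.

```python
import math

CHUNK_TARGET = 256 * 2**20  # 256 MB of fp32-equivalent bytes per chunk

def pick_chunk_count(fp32_bytes, lo=4, hi=256):
    # K = ceil(fp32_bytes / 256 MB), clamped to [4, 256]
    return max(lo, min(hi, math.ceil(fp32_bytes / CHUNK_TARGET)))

def predicted_peak_ratio(k):
    # Theory: the output (1x) plus ~2 chunk-sized fp32 transients -> 1 + 2/K
    return 1 + 2 / k

fp32_bytes = 46341 * 46341 * 4    # canary shape, fp32-equivalent bytes
k = pick_chunk_count(fp32_bytes)  # 33 on the canary shape
ratio = predicted_peak_ratio(k)   # formula gives ~1.06; measured ~1.09x
```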

Sub-key derivation via random::split is cryptographically
independent and seed-deterministic. Same seed produces same
chunked output across runs, but the bit pattern differs from
vanilla (which uses one key for the whole shape). Same trade-off
class as PR ml-explore#904; statistical quality preserved per-chunk
(chunked unique-value count ≥ vanilla baseline).

Activation rule (all required): GPU stream, scalar lo/hi, single
key, fp32-equiv output size ≥ 512 MB, axis-0 dim ≥ 4. Falls back
to vanilla path for everything else (small shapes, multi-key,
broadcast bounds, CPU). normal() uses the same chunked pipeline
when target dtype is bf16/fp16/fp32.

Resolves OOM on (46341, 46341) bf16 normal: vanilla aborts at
12.88 GB peak, chunked completes at 4.69 GB. Tolerates up to
~11 GB of concurrent allocations on M4 16 GB before swap kicks
in (path-c/21-active-ballast.md).

CUDA mirror of the Metal RandomUniform kernel (same threefry counter
mapping, same per-thread fp32-then-cast in registers, same output
dtype templating). Marked untested in code: no NVIDIA hardware on
this branch's CI; algorithmic equivalence to the validated Metal
kernel verified by inspection.

TestRandomChunked: 8 tests targeting the chunked path (shapes ≥
1 GB so chunking activates). Each test uses 5σ/√N statistical
tolerance for distribution stats (not hand-tuned); seed
reproducibility test confirms deterministic output; odd-first-dim
test exercises chunk-remainder handling; unique-bit test asserts
≥ 2000 distinct bf16 values per million samples (PR ml-explore#2361 quality
floor).
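
The 5σ/√N tolerance idea, sketched in NumPy (illustrative; the PR's actual tests live in python/tests/test_random.py, and `mean_within_tolerance` is a hypothetical helper):

```python
import numpy as np

def mean_within_tolerance(samples, expected_mean, expected_std, sigmas=5.0):
    # The sample mean of N i.i.d. draws has standard error std/sqrt(N),
    # so a 5-sigma band around the expected mean gives a mathematically
    # derived (not hand-tuned) tolerance with a tiny false-failure rate.
    n = samples.size
    tol = sigmas * expected_std / n**0.5
    return abs(float(samples.mean()) - expected_mean) < tol

samples = np.random.default_rng(0).standard_normal(1_000_000)
ok = mean_within_tolerance(samples, expected_mean=0.0, expected_std=1.0)
```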

Brings test_random.py coverage from 14 to 22 tests; full pytest
remains 696 passed / 4 skipped / 9283 subtests on M4.

The chunked dispatch in mlx/random.cpp had two correctness gaps
discovered by an adversarial drawback sweep against vanilla:

1. fp32 chunking is strictly worse than vanilla. Vanilla fp32
   uniform/normal already operate at ~1x output peak (the
   intermediate IS the target dtype), so chunking adds K-fold
   sub-key derivation + slice_update overhead with zero memory
   benefit. Measured ~25% latency regression and ~25% higher
   peak memory at 12K^2+ shapes. Restrict chunkable_dtype to
   {bfloat16, float16}.

2. Both the fused RandomUniform primitive and the chunked path
   are illegal inside mx.compile / mx.vmap / mx.grad: the fused
   primitive throws on RandomUniform::vmap, and the chunked
   path's per-chunk eval() is rejected by the tracer. Gate
   both dispatches on !detail::in_tracing() so any transform
   falls back to the vanilla pipeline (which uses RandomBits,
   DEFINE_VMAP()-supported).

Headline canary unchanged: (46341, 46341) bf16 normal still
peaks at 4.7 GB on M4-16GB.

Pre-PR cleanup pass: remove internal investigation references
("Variant D1/D4", "Phase X", *.md filenames, drawback Phase
references) from comments, compress the chunked-path docstrings,
and tighten throw messages to drop implementation detail leakage.

Also:
- mlx/random.cpp: replace if-cascade clamps in pick_chunk_count
  with std::clamp; mark chunked_fp32_then_cast and pick_chunk_count
  static; drop the redundant inner key-shape check (single_key
  already guarantees Shape{2}); inline single-use bool 'even'.
- mlx/backend/metal/kernels/random.metal: collapse the two-output
  per-thread block to one expression each; drop the "Step 4"
  reference.
- mlx/backend/metal/primitives.cpp: drop the 7-line debugging
  postmortem about constant-buffer packing.
- .gitignore: drop the path-c-only .venv / python/mlx/lib entries.

No behavior change. 22/22 random tests + 708 full pytest pass;
canary (46341, 46341) bf16 normal still peaks at 4.69 GB.
Net diff: -69 lines.

@dogukanveziroglu
Contributor Author

@angeloskath could you share your thoughts, please? Thank you.
