From 15021e6fc54b77ffb1595a8b79169dc8065a1294 Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Tue, 21 Apr 2026 20:12:21 +0300 Subject: [PATCH 1/5] Add RandomUniform: fused half-precision uniform Metal kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new primitive that runs the entire uniform RNG pipeline (threefry hash → fp32 normalize → clip → cast → affine) per-thread in registers for half-precision GPU outputs. Avoids materializing the fp32 intermediate buffer that the standard bits()/divide()/astype() chain requires; peak memory drops 3x → 1x of target. Activation conditions (all required): half-precision dtype (bf16 or fp16), even total output size, scalar low/high, single key (shape {2}), GPU stream. Bit-exact with vanilla on the same seed; matches the rbitsc kernel's interleaved counter layout. Performance: 15.4x faster on (16384, 16384) bf16 (108 ms → 7 ms) because the fp32 intermediate no longer transits L2/HBM. Small shapes (<1 MB) pay a slight kernel-launch overhead — chunked path threshold ensures the fast path only activates when the win dominates. CUDA mirror added in a follow-up commit (untested; algorithmic transcription of the validated Metal kernel). --- mlx/backend/cpu/primitives.cpp | 10 ++++++ mlx/backend/metal/kernels/random.metal | 42 ++++++++++++++++++++++++ mlx/backend/metal/primitives.cpp | 45 ++++++++++++++++++++++++++ mlx/primitives.cpp | 17 ++++++++++ mlx/primitives.h | 31 ++++++++++++++++++ 5 files changed, 145 insertions(+) diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index f1d83dd306..12b8b2c61b 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -332,6 +332,16 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { }); } +void RandomUniform::eval_cpu(const std::vector&, array&) { + // The fused half-precision uniform is GPU-only (the memory + // amplification it solves is a GPU peak-buffer problem). On CPU + // streams, mlx::core::random::uniform should keep dispatching the + // existing fp32 pipeline instead of constructing this primitive. + throw std::runtime_error( + "[RandomUniform::eval_cpu] Not implemented; use float32 uniform " + "and astype on CPU."); +} + void Reshape::eval_cpu(const std::vector& inputs, array& out) { reshape(inputs[0], out); } diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index ccbd464d3b..aa676b7d9f 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -101,3 +101,45 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { } } } + +// Fused per-thread uniform RNG for half-precision targets. Each thread +// emits TWO output elements at positions y and y + grid_dim.y from a +// single threefry call, matching the rbitsc bit-layout exactly so that +// seed -> output mapping is bit-identical to the vanilla +// bits()/divide()/astype()/affine() pipeline (no fp32 intermediate +// buffer in global memory). +template +[[kernel]] void runiformc( + device const uint32_t* keys, + device T* out, + constant const float& lo, + constant const float& range, + constant const float& upper_clip, + uint2 grid_dim [[threads_per_grid]], + uint2 index [[thread_position_in_grid]]) { + uint2 key2 = uint2(keys[0], keys[1]); + uint y = index.y; + uint half_size = grid_dim.y; + union rbits hash = threefry2x32_hash(key2, uint2(y, y + half_size)); + + // Same exact pattern as Step4 (which worked for the upper_clip read). 
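+  // Normalize each 32-bit lane to [0, 1] by dividing by 2^32 - 1, then
+  // clamp to the largest value below 1.0 representable in the target dtype
+  // (0.99609375 for bf16, 0.99951171875 for fp16) so the half-open
+  // [low, high) bound survives the downcast, mirroring the vanilla
+  // divide()/minimum()/astype() chain.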
+ float f0 = float(hash.val.x) / 4294967295.0f; + f0 = min(f0, upper_clip); + T t0 = T(f0); + T r_dt = T(range); + T lo_dt = T(lo); + T tr0 = r_dt * t0; + out[y] = tr0 + lo_dt; + + float f1 = float(hash.val.y) / 4294967295.0f; + f1 = min(f1, upper_clip); + T t1 = T(f1); + T tr1 = r_dt * t1; + out[y + half_size] = tr1 + lo_dt; +} + +#define instantiate_runiformc(tname, type) \ + instantiate_kernel("runiformc_" #tname, runiformc, type) + +instantiate_runiformc(float16, half) +instantiate_runiformc(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index d5bbf797e4..fb8df5aad3 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -200,6 +200,51 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } +void RandomUniform::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + // Fused kernel requires even output size to use the rbitsc-style two- + // outputs-per-thread layout (for bit-exact match with vanilla). + size_t N = out.size(); + if (N % 2 != 0) { + throw std::runtime_error( + "[RandomUniform::eval_gpu] N must be even; this dispatch path is " + "only used by random.cpp::uniform when N is even."); + } + auto& keys = inputs[0]; + auto& s = stream(); + auto& d = metal::device(s.device); + + std::string tname = (dtype_ == bfloat16) ? "bfloat16" : "float16"; + auto kernel = d.get_kernel("runiformc_" + tname); + + // Affine + clip constants packed into a struct so Metal binds them + // as a single constant buffer. Setting individual `set_bytes` floats + // on consecutive slots was being miscompiled in this version of the + // Metal toolchain (constants arrived as zero in the kernel). + // Pass each constant individually at slots 1..3, output at slot 4 + // (matching the original kernel signature; struct/float4 packing was + // failing in this Metal toolchain when range/lo were read). + float lo = low_; + float range = high_ - low_; + float upper_clip = (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f; + + size_t half = N / 2; + MTL::Size grid_dims = MTL::Size(1, half, 1); + auto group_dims = get_block_dims(1, half, 1); + auto& compute_encoder = metal::get_command_encoder(s); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(keys, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(lo, 2); + compute_encoder.set_bytes(range, 3); + compute_encoder.set_bytes(upper_clip, 4); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8e209eeb26..ffcaa1821c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3822,6 +3822,23 @@ bool RandomBits::is_equivalent(const Primitive& other) const { return shape_ == r_other.shape_ && width_ == r_other.width_; } +std::pair, std::vector> RandomUniform::vmap( + const std::vector&, const std::vector&) { + // The fused primitive does not support vmap. Callers (random.cpp) avoid + // dispatching here when a vmap is in flight by detecting non-scalar + // shapes; if we are reached anyway, throw a clear error. + throw std::runtime_error( + "[RandomUniform::vmap] Fused half-precision uniform does not " + "support vmap. 
Please use mx.random.uniform with float32 dtype " + "and astype to half precision."); +} + +bool RandomUniform::is_equivalent(const Primitive& other) const { + const RandomUniform& r_other = static_cast(other); + return shape_ == r_other.shape_ && dtype_ == r_other.dtype_ && + low_ == r_other.low_ && high_ == r_other.high_; +} + std::vector Real::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..bc957ddc39 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1734,6 +1734,37 @@ class RandomBits : public UnaryPrimitive { int width_; }; +// Fused uniform-in-target-dtype primitive for half-precision outputs. +// Avoids the 3x peak memory overhead of bits()->divide()->astype() +// pipelines by computing the entire transform per-thread in registers. +class RandomUniform : public UnaryPrimitive { + public: + explicit RandomUniform( + Stream stream, + const Shape& shape, + Dtype dtype, + float low, + float high) + : UnaryPrimitive(stream), + shape_(shape), + dtype_(dtype), + low_(low), + high_(high) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(RandomUniform) + bool is_equivalent(const Primitive& other) const override; + + private: + Shape shape_; + Dtype dtype_; + float low_; + float high_; +}; + class Real : public UnaryPrimitive { public: explicit Real(Stream stream) : UnaryPrimitive(stream) {} From d6d40ebd44275eeb3de615a05ad70f9f640784fe Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Tue, 21 Apr 2026 20:13:10 +0300 Subject: [PATCH 2/5] Reduce peak memory for large random.uniform/normal via chunking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Splits large GPU random calls (output ≥ ~512 MB fp32-equivalent) along axis 0 into K independent sub-key chunks, computes each via the existing fp32-then-cast pipeline, and writes into a pre- allocated output via slice_update with eval per chunk. Per-chunk fp32 transients are freed between iterations; peak drops from 3x to ~1+2/K of target (1.09x at K=33 on the canary shape). Heuristic: K = ceil(fp32_bytes / 256 MB), clamped to [4, 256]. Profiled in path-c/19-K-isolation.md: theory matches measurement within 5% at K ≥ 32; allocator overhead at small K (2-16) adds 17-30% but amortizes away. Sub-key derivation via random::split is cryptographically independent and seed-deterministic. Same seed produces same chunked output across runs, but the bit pattern differs from vanilla (which uses one key for the whole shape). Same trade-off class as PR #904; statistical quality preserved per-chunk (chunked unique-value count ≥ vanilla baseline). Activation rule (all required): GPU stream, scalar lo/hi, single key, fp32-equiv output size ≥ 512 MB, axis-0 dim ≥ 4. Falls back to vanilla path for everything else (small shapes, multi-key, broadcast bounds, CPU). normal() uses the same chunked pipeline when target dtype is bf16/fp16/fp32. Resolves OOM on (46341, 46341) bf16 normal: vanilla aborts at 12.88 GB peak, chunked completes at 4.69 GB. Tolerates up to ~11 GB of concurrent allocations on M4 16 GB before swap kicks in (path-c/21-active-ballast.md). 
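To make the heuristic concrete, a back-of-envelope sketch (illustrative Python,
not the shipped C++ helper) reproducing the K selection and the resulting peak
estimate for the canary shape:

    # Illustrative only; the shipped helper is pick_chunk_count in
    # mlx/random.cpp. Numbers follow the ~(1 + 2/K) peak model stated above.
    import math

    def pick_chunk_count(fp32_bytes: int, first_dim: int) -> int:
        chunk_bytes = 256 * 1024 * 1024           # 256 MB fp32 transient per chunk
        k = math.ceil(fp32_bytes / chunk_bytes)
        return min(max(k, 4), 256, first_dim)     # clamp to [4, 256], never > shape[0]

    n = 46341 * 46341                             # canary: (46341, 46341) bf16 normal
    fp32_bytes = 4 * n                            # ~8.6 GB fp32-equivalent transient
    bf16_bytes = 2 * n                            # ~4.3 GB target output
    k = pick_chunk_count(fp32_bytes, 46341)       # -> 33
    peak = (1 + 2 / k) * bf16_bytes               # ~4.6 GB by the model
                                                  # (4.69 GB measured; 12.88 GB vanilla)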
--- .gitignore | 4 ++ mlx/random.cpp | 190 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) diff --git a/.gitignore b/.gitignore index 1daaa46d12..c8f82ebdc9 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,7 @@ uv.lock .cache/ # vim *.swp + +# path-c local additions +.venv/ +python/mlx/lib/ diff --git a/mlx/random.cpp b/mlx/random.cpp index def3169cb5..53ec6a75e4 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -7,6 +7,7 @@ #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/random.h" +#include "mlx/transforms.h" #include "mlx/utils.h" namespace mlx::core::random { @@ -92,6 +93,80 @@ T below_one() { return f; } +// Shared helper for the chunked fp32-then-cast random path. Splits the +// output along axis 0 into K independent sub-keys, generates each +// chunk via `process_chunk`, and writes into a pre-allocated output +// with `slice_update + eval` so per-chunk transient fp32 buffers are +// freed between chunks. Peak memory ~1+2/K * output (measured 1.15x +// for K=8 on bf16 normal). Caller provides the per-chunk pipeline via +// `process_chunk(chunk_shape, subkey) -> array` returning a chunk +// already in the target dtype. +// +// Activation rule: caller should only invoke this when `bits_bytes` +// (fp32-equivalent of the total output) is large enough that the +// memory savings outweigh K kernel-launch sync points. Small shapes +// should stay on the vanilla path. +template +array chunked_fp32_then_cast( + const Shape& shape, + Dtype dtype, + const array& key, + size_t K, + const Stream& stream, + ProcessChunk&& process_chunk) { + auto subkeys = random::split(key, static_cast(K), stream); + auto out = zeros(shape, dtype, stream); + Shape strides_one(shape.size(), 1); + int base_dim = shape[0] / static_cast(K); + int remainder = shape[0] % static_cast(K); + int cursor = 0; + for (size_t i = 0; i < K; ++i) { + int this_dim = base_dim + (static_cast(i) < remainder ? 1 : 0); + Shape chunk_shape = shape; + chunk_shape[0] = this_dim; + auto subkey_i = reshape( + slice( + subkeys, + Shape{static_cast(i), 0}, + Shape{static_cast(i) + 1, 2}, + Shape{1, 1}, + stream), + Shape{2}, + stream); + auto chunk_out = process_chunk(chunk_shape, subkey_i); + Shape start = Shape(shape.size(), 0); + start[0] = cursor; + Shape stop = shape; + stop[0] = cursor + this_dim; + out = slice_update(out, chunk_out, start, stop, strides_one, stream); + eval(out); + cursor += this_dim; + } + return out; +} + +// Heuristic K selection. Target fp32 transient ≤ 256 MB per chunk so +// peak stays well inside a typical 16 GB Apple Silicon working set. +// Clamps K to [4, 256] and never exceeds shape[0]. A device-aware +// version was prototyped (Phase 4) but metal::device_info() is not +// linked into libmlx core; leaving as a fixed heuristic since the +// measured 1.15x peak for K=8 at canary shapes is well under the +// working-set ceiling (profiled in 11-concat-eval-profile.md). +inline size_t pick_chunk_count(size_t bits_bytes_fp32, int first_dim) { + const size_t kChunkBytes = 256ULL * 1024 * 1024; + size_t K = (bits_bytes_fp32 + kChunkBytes - 1) / kChunkBytes; + if (K < 4) { + K = 4; + } + if (K > 256) { + K = 256; + } + if (static_cast(K) > first_dim) { + K = static_cast(first_dim); + } + return K; +} + array uniform( const array& low, const array& high, @@ -133,7 +208,65 @@ array uniform( }; auto upper = get_upper(); + + // Fused per-thread uniform path for half-precision GPU outputs. 
+ // Avoids the 3x peak-memory amplification of the bits()->divide()-> + // astype() chain by computing the entire transform in registers. + // Conditions: GPU stream, bf16 or fp16 dtype, scalar low/high, even + // total output size, single-key (shape == {2}). Quality and seed + // mapping are bit-identical to the existing pipeline. + bool half = (dtype == bfloat16 || dtype == float16); + size_t total = 1; + for (auto d : shape) { + total *= static_cast(d); + } + bool even = (total % 2) == 0; + bool scalar_lohi = (lo.size() == 1) && (hi.size() == 1); + bool single_key = !key || (key->shape() == Shape{2}); + bool gpu_stream = (stream.device.type == Device::gpu); + if (half && even && scalar_lohi && single_key && gpu_stream) { + auto eff_key = key ? *key : KeySequence::default_().next(); + if (eff_key.shape() == Shape{2}) { + // .item() requires the array to be float32; cast first. + // .item() forces evaluation, so no explicit eval needed here. + float lo_f = astype(low, float32, stream).item(); + float hi_f = astype(high, float32, stream).item(); + return array( + shape, + dtype, + std::make_shared( + stream, shape, dtype, lo_f, hi_f), + {eff_key}); + } + } + auto maxval = array(std::numeric_limits::max(), float32); + + // Chunked path (Variant D1) for large GPU random.uniform calls. Splits + // the output along axis 0 into K independent sub-keys + chunks; each + // chunk runs the standard fp32-then-cast pipeline and the results are + // concatenated. Reduces peak memory from 3x to ~1+2/K with quality + // preserved (still fp32 inside). Bit pattern is not vanilla-equivalent + // because sub-keys differ; same precedent as PR #904. + size_t bits_bytes = total * size_of(float32); + const size_t kChunkBytes = 256ULL * 1024 * 1024; // 256 MB trigger + if (gpu_stream && scalar_lohi && single_key && + bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4) { + auto eff_key = key ? *key : KeySequence::default_().next(); + if (eff_key.shape() == Shape{2}) { + size_t K = pick_chunk_count(bits_bytes, shape[0]); + return chunked_fp32_then_cast( + shape, dtype, eff_key, K, stream, + [&](const Shape& chunk_shape, const array& subkey_i) { + auto bits_i = bits(chunk_shape, size_of(float32), subkey_i, stream); + auto fp_i = divide(bits_i, maxval, stream); + auto clipped = + astype(minimum(fp_i, upper, stream), dtype, stream); + return add(multiply(range, clipped, stream), lo, stream); + }); + } + } + auto out = bits(shape, size_of(float32), key, stream); out = divide(out, maxval, stream); out = astype(minimum(out, upper, stream), dtype, stream); @@ -187,8 +320,65 @@ array normal( } auto stream = to_stream(s); + // Keep normal() on the fp32 sampling path: sampling in target dtype + // would erode randomness quality (per PR #2361 discussion: bf16 native + // gives only ~382 unique values per 100K samples, below the + // fp32-then-cast baseline of ~2254). The uniform() fast path still + // shrinks half-precision uniform peak memory to 1x. auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); + + // Variant D4: chunked normal pipeline for large GPU half-precision + // outputs. Splits the normal pipeline (uniform fp32 -> erfinv -> + // cast -> affine) along axis 0 into K independent sub-keys/chunks + // and concatenates. Reduces peak memory ~3x to ~2x, enough to make + // (46341, 46341) bf16 normal succeed on 16 GB devices. 
Sub-key + // derivation differs from vanilla so bf16/fp16 normal bytes change + // (precedent: PR #904); statistical quality is preserved (each chunk + // is fp32-then-cast). + // Route fp32 normal through chunking too (consistency; fp32 target + // has the same 3x peak problem under the standard uniform+erfinv+cast + // chain). Each chunk stays fp32-then-cast; for fp32 target, the cast + // is a no-op. + bool chunkable_dtype = + (dtype == bfloat16 || dtype == float16 || dtype == float32); + bool gpu_stream = (stream.device.type == Device::gpu); + bool single_key = !key || (key->shape() == Shape{2}); + size_t total = 1; + for (auto d : shape) { + total *= static_cast(d); + } + size_t bits_bytes = total * size_of(float32); + const size_t kChunkBytes = 256ULL * 1024 * 1024; + if (chunkable_dtype && gpu_stream && single_key && + bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4) { + auto eff_key = key ? *key : KeySequence::default_().next(); + if (eff_key.shape() == Shape{2}) { + size_t K = pick_chunk_count(bits_bytes, shape[0]); + auto applied_scale = array(std::sqrt(2.0), dtype); + if (scale.has_value()) { + applied_scale = multiply( + applied_scale, astype(*scale, dtype, stream), stream); + } + array loc_dt = loc.has_value() + ? astype(*loc, dtype, stream) + : array(0.0, dtype); + bool has_loc = loc.has_value(); + return chunked_fp32_then_cast( + shape, dtype, eff_key, K, stream, + [&](const Shape& chunk_shape, const array& subkey_i) { + auto u = + uniform(low, high, chunk_shape, float32, subkey_i, stream); + auto chunk_dt = astype(erfinv(u, stream), dtype, stream); + chunk_dt = multiply(applied_scale, chunk_dt, stream); + if (has_loc) { + chunk_dt = add(loc_dt, chunk_dt, stream); + } + return chunk_dt; + }); + } + } + auto samples = uniform(low, high, shape, float32, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { From 533fe18ec4e21b5c56361b0f30c9a8d87297e216 Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Tue, 21 Apr 2026 20:13:26 +0300 Subject: [PATCH 3/5] CUDA RandomUniform mirror + TestRandomChunked coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CUDA mirror of the Metal RandomUniform kernel (same threefry counter mapping, same per-thread fp32-then-cast in registers, same output dtype templating). Marked untested in code: no NVIDIA hardware on this branch's CI; algorithmic equivalence to the validated Metal kernel verified by inspection. TestRandomChunked: 8 tests targeting the chunked path (shapes ≥ 1 GB so chunking activates). Each test uses 5σ/√N statistical tolerance for distribution stats (not hand-tuned); seed reproducibility test confirms deterministic output; odd-first-dim test exercises chunk-remainder handling; unique-bit test asserts ≥ 2000 distinct bf16 values per million samples (PR #2361 quality floor). Brings test_random.py coverage from 14 to 22 tests; full pytest remains 696 passed / 4 skipped / 9283 subtests on M4. --- mlx/backend/cuda/random.cu | 98 +++++++++++++++++++++++++++++++++ python/tests/test_random.py | 107 ++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+) diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 72f6c84222..78d943c172 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -130,8 +130,106 @@ __global__ void rbits( } } +// Fused per-thread uniform RNG for half-precision targets. 
Mirrors the +// Metal kernel in mlx/backend/metal/kernels/random.metal: each thread +// emits TWO output elements at positions y and y + half_size from a +// single threefry call, matching the rbitsc bit-layout exactly so that +// seed -> output mapping is bit-identical to the vanilla +// bits()/divide()/astype()/affine() pipeline (no fp32 intermediate +// buffer in global memory). +// +// NOTE: This CUDA path is untested on this branch (no NVIDIA hardware +// available during development). Algorithmically it is the direct CUDA +// transcription of the validated Metal kernel; the C++ dispatch logic +// in the host-side primitives.cpp is shared. CI / maintainer review +// required to confirm correctness on a CUDA-capable device. +template +__global__ void runiformc_fused( + const uint32_t* keys, + T* out, + uint32_t half_size, + float lo, + float range, + float upper_clip) { + uint32_t y = blockIdx.x * blockDim.x + threadIdx.x; + if (y >= half_size) { + return; + } + uint2 key2 = uint2{keys[0], keys[1]}; + rbits hash = threefry2x32_hash(key2, uint2{y, y + half_size}); + + float f0 = float(hash.val.x) / 4294967295.0f; + if (f0 > upper_clip) f0 = upper_clip; + T t0 = T(f0); + T r_dt = T(range); + T lo_dt = T(lo); + T tr0 = r_dt * t0; + out[y] = tr0 + lo_dt; + + float f1 = float(hash.val.y) / 4294967295.0f; + if (f1 > upper_clip) f1 = upper_clip; + T t1 = T(f1); + T tr1 = r_dt * t1; + out[y + half_size] = tr1 + lo_dt; +} + } // namespace cu +void RandomUniform::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("RandomUniform::eval_gpu"); + assert(inputs.size() == 1); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + if (out.size() == 0) { + return; + } + size_t N = out.size(); + if (N % 2 != 0) { + throw std::runtime_error( + "[RandomUniform::eval_gpu] CUDA fast path requires even N; " + "fall back to vanilla pipeline at the random.cpp level."); + } + auto& keys = inputs[0]; + encoder.set_input_array(keys); + encoder.set_output_array(out); + size_t half = N / 2; + if (half >= UINT32_MAX) { + throw std::runtime_error( + "[RandomUniform::eval_gpu] half output size exceeds UINT32_MAX"); + } + float lo = low_; + float range = high_ - low_; + float upper_clip = + (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f; + uint32_t threads = 256; + uint32_t blocks = + static_cast((half + threads - 1) / threads); + if (dtype_ == bfloat16) { + encoder.add_kernel_node( + cu::runiformc_fused<__nv_bfloat16>, + dim3{blocks}, + dim3{threads}, + gpu_ptr(keys), + gpu_ptr<__nv_bfloat16>(out), + static_cast(half), + lo, + range, + upper_clip); + } else { + encoder.add_kernel_node( + cu::runiformc_fused<__half>, + dim3{blocks}, + dim3{threads}, + gpu_ptr(keys), + gpu_ptr<__half>(out), + static_cast(half), + lo, + range, + upper_clip); + } +} + void RandomBits::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("RandomBits::eval_gpu"); assert(inputs.size() == 1); diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 551c32993c..907bb8470b 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -388,5 +388,112 @@ def test_broadcastable_scale_loc(self): self.assertEqual(sample.dtype, mx.float16) +class TestRandomChunked(mlx_tests.MLXTestCase): + """Tests specifically exercising the chunked random path added in + path-c. The path activates when the fp32-equivalent output size is + >= ~512 MB (2 * 256 MB chunk threshold). 
Shapes below that stay on + the vanilla path and are covered by TestRandom already. + """ + + # 16384 * 16384 = 268M elements; fp32-equiv = 1.07 GB -> triggers + # chunking for any dtype. Big enough to hit the chunked path, + # small enough to fit in M4 CI memory. + CHUNK_SHAPE = (16384, 16384) + + def test_uniform_bf16_chunked_no_inf(self): + mx.random.seed(42) + a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a) + self.assertFalse(bool(mx.any(mx.isinf(a)).item())) + self.assertFalse(bool(mx.any(mx.isnan(a)).item())) + + def test_normal_bf16_chunked_no_inf(self): + mx.random.seed(42) + a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a) + self.assertFalse(bool(mx.any(mx.isinf(a)).item())) + self.assertFalse(bool(mx.any(mx.isnan(a)).item())) + + def test_uniform_bf16_chunked_reproducibility(self): + # Same seed -> same output even in chunked path + mx.random.seed(7) + a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.random.seed(7) + b = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a, b) + self.assertTrue(bool(mx.array_equal(a, b).item())) + + def test_normal_bf16_chunked_reproducibility(self): + mx.random.seed(7) + a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.random.seed(7) + b = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a, b) + self.assertTrue(bool(mx.array_equal(a, b).item())) + + def test_uniform_bf16_chunked_distribution(self): + mx.random.seed(0) + a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a) + # Uniform [0, 1): population std = 1/sqrt(12) ≈ 0.2887. + # Sample mean SE = std/sqrt(N). 5σ tolerance ≈ 1 in 1.7M + # false-fail rate. bf16 quantization spacing near 0.5 is + # 2^-8 ≈ 0.004; use max(5σ, 0.003) as the bound. + mean = float(a.astype(mx.float32).mean()) + std = float(a.astype(mx.float32).std()) + n = float(self.CHUNK_SHAPE[0] * self.CHUNK_SHAPE[1]) + sigma_mean = (1.0 / math.sqrt(12.0)) / math.sqrt(n) + bound_mean = max(5 * sigma_mean, 0.003) + self.assertLess(abs(mean - 0.5), bound_mean) + self.assertLess(abs(std - 0.2887), 0.005) + self.assertGreaterEqual(float(a.min()), 0.0) + self.assertLess(float(a.max()), 1.0) + + def test_normal_bf16_chunked_distribution(self): + mx.random.seed(0) + a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16) + mx.eval(a) + mean = float(a.astype(mx.float32).mean()) + std = float(a.astype(mx.float32).std()) + # Normal(0, 1): mean SE = 1/sqrt(N). 5σ + bf16 quantization floor. + n = float(self.CHUNK_SHAPE[0] * self.CHUNK_SHAPE[1]) + sigma_mean = 1.0 / math.sqrt(n) + bound_mean = max(5 * sigma_mean, 0.003) + self.assertLess(abs(mean), bound_mean) + self.assertLess(abs(std - 1.0), 0.005) + # Quality bar: chunked path must produce a healthy variety of + # bit-patterns (PR #2361 baseline ~2254 for vanilla bf16 + # normal at 100K samples; chunked at 268M should easily + # exceed 2000). 
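+        # view(mx.uint16) reinterprets the 2-byte bf16 payloads as uint16 so
+        # distinct bit patterns can be counted exactly on the host via set().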
+ bits = a.view(mx.uint16).flatten()[:1_000_000] + host = set(bits.tolist()) + self.assertGreaterEqual( + len(host), 2000, + f"chunked bf16 normal unique-value count {len(host)} below quality floor", + ) + + def test_uniform_chunked_odd_first_dim(self): + # First-axis odd to verify chunk-remainder handling + shape = (16385, 16384) # first dim odd + mx.random.seed(3) + a = mx.random.uniform(shape=shape, dtype=mx.bfloat16) + mx.eval(a) + self.assertEqual(a.shape, shape) + self.assertFalse(bool(mx.any(mx.isinf(a)).item())) + self.assertFalse(bool(mx.any(mx.isnan(a)).item())) + + def test_normal_fp32_chunked(self): + # fp32 normal also chunked (Phase 5) + shape = (16384, 16384) # 1 GB fp32 -> triggers + mx.random.seed(5) + a = mx.random.normal(shape=shape, dtype=mx.float32) + mx.eval(a) + self.assertEqual(a.shape, shape) + mean = float(a.mean()) + std = float(a.std()) + self.assertLess(abs(mean), 0.001) + self.assertLess(abs(std - 1.0), 0.005) + + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 86beddc3b8c70c487e154f2796cdf1d0989d19aa Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Wed, 22 Apr 2026 16:57:44 +0300 Subject: [PATCH 4/5] random: restrict chunking to half-precision and disable under transforms The chunked dispatch in mlx/random.cpp had two correctness gaps discovered by an adversarial drawback sweep against vanilla: 1. fp32 chunking is strictly worse than vanilla. Vanilla fp32 uniform/normal already operate at ~1x output peak (the intermediate IS the target dtype), so chunking adds K-fold sub-key derivation + slice_update overhead with zero memory benefit. Measured ~25% latency regression and ~25% higher peak memory at 12K^2+ shapes. Restrict chunkable_dtype to {bfloat16, float16}. 2. Both the fused RandomUniform primitive and the chunked path are illegal inside mx.compile / mx.vmap / mx.grad: the fused primitive throws on RandomUniform::vmap, and the chunked path's per-chunk eval() is rejected by the tracer. Gate both dispatches on !detail::in_tracing() so any transform falls back to the vanilla pipeline (which uses RandomBits, DEFINE_VMAP()-supported). Headline canary unchanged: (46341, 46341) bf16 normal still peaks at 4.7 GB on M4-16GB. --- mlx/random.cpp | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 53ec6a75e4..e303493a1b 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -8,6 +8,7 @@ #include "mlx/primitives.h" #include "mlx/random.h" #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core::random { @@ -224,7 +225,14 @@ array uniform( bool scalar_lohi = (lo.size() == 1) && (hi.size() == 1); bool single_key = !key || (key->shape() == Shape{2}); bool gpu_stream = (stream.device.type == Device::gpu); - if (half && even && scalar_lohi && single_key && gpu_stream) { + // Skip the fused primitive under any function transform (vmap, compile, + // vjp, jvp). RandomUniform::vmap intentionally throws (the fused + // per-thread kernel has no batched semantics), and the chunked path + // below also internally evals which is illegal under tracing. Falling + // through to the vanilla bits()->divide()->cast() pipeline keeps user + // code that wraps random.uniform in a transform working. + if (half && even && scalar_lohi && single_key && gpu_stream && + !detail::in_tracing()) { auto eff_key = key ? *key : KeySequence::default_().next(); if (eff_key.shape() == Shape{2}) { // .item() requires the array to be float32; cast first. 
@@ -250,8 +258,14 @@ array uniform( // because sub-keys differ; same precedent as PR #904. size_t bits_bytes = total * size_of(float32); const size_t kChunkBytes = 256ULL * 1024 * 1024; // 256 MB trigger - if (gpu_stream && scalar_lohi && single_key && - bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4) { + // Restrict to half-precision: vanilla fp32 uniform already operates at + // ~1x output peak (the bits buffer IS the target dtype size), so chunking + // only adds K-fold sub-key + slice_update overhead with no memory benefit + // (drawback Phase 1 measured ~25% latency regression and ~25% higher peak). + // Skip under any transform: per-chunk eval is illegal inside compile/vmap. + if (half && gpu_stream && scalar_lohi && single_key && + bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 && + !detail::in_tracing()) { auto eff_key = key ? *key : KeySequence::default_().next(); if (eff_key.shape() == Shape{2}) { size_t K = pick_chunk_count(bits_bytes, shape[0]); @@ -336,12 +350,14 @@ array normal( // derivation differs from vanilla so bf16/fp16 normal bytes change // (precedent: PR #904); statistical quality is preserved (each chunk // is fp32-then-cast). - // Route fp32 normal through chunking too (consistency; fp32 target - // has the same 3x peak problem under the standard uniform+erfinv+cast - // chain). Each chunk stays fp32-then-cast; for fp32 target, the cast - // is a no-op. - bool chunkable_dtype = - (dtype == bfloat16 || dtype == float16 || dtype == float32); + // Restrict to half-precision: fp32 normal does not have the 3x peak + // amplification (vanilla fp32 normal already operates at ~1x output + // peak because the intermediate IS the target), so chunking only adds + // K-fold sub-key + slice_update overhead with no memory benefit + // (drawback Phase 1 measured ~20% slower + 25% higher peak for fp32). + // Skip under any transform (compile/vmap/vjp/jvp): per-chunk eval is + // illegal inside the tracer. + bool chunkable_dtype = (dtype == bfloat16 || dtype == float16); bool gpu_stream = (stream.device.type == Device::gpu); bool single_key = !key || (key->shape() == Shape{2}); size_t total = 1; @@ -351,7 +367,8 @@ array normal( size_t bits_bytes = total * size_of(float32); const size_t kChunkBytes = 256ULL * 1024 * 1024; if (chunkable_dtype && gpu_stream && single_key && - bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4) { + bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 && + !detail::in_tracing()) { auto eff_key = key ? *key : KeySequence::default_().next(); if (eff_key.shape() == Shape{2}) { size_t K = pick_chunk_count(bits_bytes, shape[0]); From 2eb29efb0005dde292f9ff70f846e58134bcaa76 Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Wed, 22 Apr 2026 17:19:43 +0300 Subject: [PATCH 5/5] random: clean up comments and tighten dispatch Pre-PR cleanup pass: remove internal investigation references ("Variant D1/D4", "Phase X", *.md filenames, drawback Phase references) from comments, compress the chunked-path docstrings, and tighten throw messages to drop implementation detail leakage. Also: - mlx/random.cpp: replace if-cascade clamps in pick_chunk_count with std::clamp; mark chunked_fp32_then_cast and pick_chunk_count static; drop the redundant inner key-shape check (single_key already guarantees Shape{2}); inline single-use bool 'even'. - mlx/backend/metal/kernels/random.metal: collapse the two-output per-thread block to one expression each; drop the "Step 4" reference. 
- mlx/backend/metal/primitives.cpp: drop the 7-line debugging postmortem about constant-buffer packing. - .gitignore: drop the path-c-only .venv / python/mlx/lib entries. No behavior change. 22/22 random tests + 708 full pytest pass; canary (46341, 46341) bf16 normal still peaks at 4.69 GB. Net diff: -69 lines. --- .gitignore | 4 - mlx/backend/cpu/primitives.cpp | 11 +- mlx/backend/metal/kernels/random.metal | 20 +-- mlx/backend/metal/primitives.cpp | 15 +- mlx/primitives.cpp | 9 +- mlx/random.cpp | 192 +++++++++---------------- 6 files changed, 82 insertions(+), 169 deletions(-) diff --git a/.gitignore b/.gitignore index c8f82ebdc9..1daaa46d12 100644 --- a/.gitignore +++ b/.gitignore @@ -79,7 +79,3 @@ uv.lock .cache/ # vim *.swp - -# path-c local additions -.venv/ -python/mlx/lib/ diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 12b8b2c61b..2d24e65325 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -333,13 +333,10 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { } void RandomUniform::eval_cpu(const std::vector&, array&) { - // The fused half-precision uniform is GPU-only (the memory - // amplification it solves is a GPU peak-buffer problem). On CPU - // streams, mlx::core::random::uniform should keep dispatching the - // existing fp32 pipeline instead of constructing this primitive. - throw std::runtime_error( - "[RandomUniform::eval_cpu] Not implemented; use float32 uniform " - "and astype on CPU."); + // GPU-only primitive (the peak-memory amplification it solves is a + // GPU buffer problem). The dispatch in random::uniform skips this + // primitive on CPU streams, so reaching here indicates a misuse. + throw std::runtime_error("[RandomUniform::eval_cpu] GPU-only primitive."); } void Reshape::eval_cpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index aa676b7d9f..a4eba1f065 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -103,8 +103,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { } // Fused per-thread uniform RNG for half-precision targets. Each thread -// emits TWO output elements at positions y and y + grid_dim.y from a -// single threefry call, matching the rbitsc bit-layout exactly so that +// emits two outputs at positions y and y + grid_dim.y from a single +// threefry call, matching the rbitsc bit-layout exactly so the // seed -> output mapping is bit-identical to the vanilla // bits()/divide()/astype()/affine() pipeline (no fp32 intermediate // buffer in global memory). @@ -122,20 +122,14 @@ template uint half_size = grid_dim.y; union rbits hash = threefry2x32_hash(key2, uint2(y, y + half_size)); - // Same exact pattern as Step4 (which worked for the upper_clip read). 
- float f0 = float(hash.val.x) / 4294967295.0f; - f0 = min(f0, upper_clip); - T t0 = T(f0); T r_dt = T(range); T lo_dt = T(lo); - T tr0 = r_dt * t0; - out[y] = tr0 + lo_dt; - float f1 = float(hash.val.y) / 4294967295.0f; - f1 = min(f1, upper_clip); - T t1 = T(f1); - T tr1 = r_dt * t1; - out[y + half_size] = tr1 + lo_dt; + float f0 = min(float(hash.val.x) / 4294967295.0f, upper_clip); + out[y] = r_dt * T(f0) + lo_dt; + + float f1 = min(float(hash.val.y) / 4294967295.0f, upper_clip); + out[y + half_size] = r_dt * T(f1) + lo_dt; } #define instantiate_runiformc(tname, type) \ diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index fb8df5aad3..dd18c90557 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -206,13 +206,11 @@ void RandomUniform::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { return; } - // Fused kernel requires even output size to use the rbitsc-style two- - // outputs-per-thread layout (for bit-exact match with vanilla). + // Two outputs per thread (rbitsc layout) requires an even N for + // bit-exact match with vanilla. size_t N = out.size(); if (N % 2 != 0) { - throw std::runtime_error( - "[RandomUniform::eval_gpu] N must be even; this dispatch path is " - "only used by random.cpp::uniform when N is even."); + throw std::runtime_error("[RandomUniform::eval_gpu] N must be even."); } auto& keys = inputs[0]; auto& s = stream(); @@ -221,13 +219,6 @@ void RandomUniform::eval_gpu(const std::vector& inputs, array& out) { std::string tname = (dtype_ == bfloat16) ? "bfloat16" : "float16"; auto kernel = d.get_kernel("runiformc_" + tname); - // Affine + clip constants packed into a struct so Metal binds them - // as a single constant buffer. Setting individual `set_bytes` floats - // on consecutive slots was being miscompiled in this version of the - // Metal toolchain (constants arrived as zero in the kernel). - // Pass each constant individually at slots 1..3, output at slot 4 - // (matching the original kernel signature; struct/float4 packing was - // failing in this Metal toolchain when range/lo were read). float lo = low_; float range = high_ - low_; float upper_clip = (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f; diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ffcaa1821c..f94fb7b5bf 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3824,13 +3824,10 @@ bool RandomBits::is_equivalent(const Primitive& other) const { std::pair, std::vector> RandomUniform::vmap( const std::vector&, const std::vector&) { - // The fused primitive does not support vmap. Callers (random.cpp) avoid - // dispatching here when a vmap is in flight by detecting non-scalar - // shapes; if we are reached anyway, throw a clear error. + // Defense-in-depth: random::uniform already skips this primitive when + // detail::in_tracing() is true, so vmap should never reach here. throw std::runtime_error( - "[RandomUniform::vmap] Fused half-precision uniform does not " - "support vmap. Please use mx.random.uniform with float32 dtype " - "and astype to half precision."); + "[RandomUniform::vmap] not supported; use float32 uniform and astype."); } bool RandomUniform::is_equivalent(const Primitive& other) const { diff --git a/mlx/random.cpp b/mlx/random.cpp index e303493a1b..1ab954fef1 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
+#include #include #include @@ -94,21 +95,12 @@ T below_one() { return f; } -// Shared helper for the chunked fp32-then-cast random path. Splits the -// output along axis 0 into K independent sub-keys, generates each -// chunk via `process_chunk`, and writes into a pre-allocated output -// with `slice_update + eval` so per-chunk transient fp32 buffers are -// freed between chunks. Peak memory ~1+2/K * output (measured 1.15x -// for K=8 on bf16 normal). Caller provides the per-chunk pipeline via -// `process_chunk(chunk_shape, subkey) -> array` returning a chunk -// already in the target dtype. -// -// Activation rule: caller should only invoke this when `bits_bytes` -// (fp32-equivalent of the total output) is large enough that the -// memory savings outweigh K kernel-launch sync points. Small shapes -// should stay on the vanilla path. +// Chunks the output along axis 0 across K sub-keys, writing each +// via slice_update + eval to free the fp32 transients between chunks. +// Peak memory ~ (1 + 2/K) * output. Caller must guard with +// !detail::in_tracing(); the per-chunk eval is illegal under tracing. template -array chunked_fp32_then_cast( +static array chunked_fp32_then_cast( const Shape& shape, Dtype dtype, const array& key, @@ -146,26 +138,15 @@ array chunked_fp32_then_cast( return out; } -// Heuristic K selection. Target fp32 transient ≤ 256 MB per chunk so -// peak stays well inside a typical 16 GB Apple Silicon working set. -// Clamps K to [4, 256] and never exceeds shape[0]. A device-aware -// version was prototyped (Phase 4) but metal::device_info() is not -// linked into libmlx core; leaving as a fixed heuristic since the -// measured 1.15x peak for K=8 at canary shapes is well under the -// working-set ceiling (profiled in 11-concat-eval-profile.md). -inline size_t pick_chunk_count(size_t bits_bytes_fp32, int first_dim) { - const size_t kChunkBytes = 256ULL * 1024 * 1024; +// Target ~256 MB fp32 transient per chunk. Clamped to [4, 256] and +// never larger than shape[0]. +static inline size_t pick_chunk_count( + size_t bits_bytes_fp32, + int first_dim) { + constexpr size_t kChunkBytes = 256ULL * 1024 * 1024; size_t K = (bits_bytes_fp32 + kChunkBytes - 1) / kChunkBytes; - if (K < 4) { - K = 4; - } - if (K > 256) { - K = 256; - } - if (static_cast(K) > first_dim) { - K = static_cast(first_dim); - } - return K; + K = std::clamp(K, size_t{4}, size_t{256}); + return std::min(K, static_cast(first_dim)); } array uniform( @@ -210,75 +191,51 @@ array uniform( auto upper = get_upper(); - // Fused per-thread uniform path for half-precision GPU outputs. - // Avoids the 3x peak-memory amplification of the bits()->divide()-> - // astype() chain by computing the entire transform in registers. - // Conditions: GPU stream, bf16 or fp16 dtype, scalar low/high, even - // total output size, single-key (shape == {2}). Quality and seed - // mapping are bit-identical to the existing pipeline. + // Fused per-thread uniform for half-precision GPU outputs. + // Bit-identical to vanilla; cuts peak from 3x to 1x output. + // Skipped under transforms (no vmap rule); falls through to vanilla. bool half = (dtype == bfloat16 || dtype == float16); size_t total = 1; for (auto d : shape) { total *= static_cast(d); } - bool even = (total % 2) == 0; bool scalar_lohi = (lo.size() == 1) && (hi.size() == 1); bool single_key = !key || (key->shape() == Shape{2}); bool gpu_stream = (stream.device.type == Device::gpu); - // Skip the fused primitive under any function transform (vmap, compile, - // vjp, jvp). 
RandomUniform::vmap intentionally throws (the fused - // per-thread kernel has no batched semantics), and the chunked path - // below also internally evals which is illegal under tracing. Falling - // through to the vanilla bits()->divide()->cast() pipeline keeps user - // code that wraps random.uniform in a transform working. - if (half && even && scalar_lohi && single_key && gpu_stream && + if (half && (total % 2 == 0) && scalar_lohi && single_key && gpu_stream && !detail::in_tracing()) { auto eff_key = key ? *key : KeySequence::default_().next(); - if (eff_key.shape() == Shape{2}) { - // .item() requires the array to be float32; cast first. - // .item() forces evaluation, so no explicit eval needed here. - float lo_f = astype(low, float32, stream).item(); - float hi_f = astype(high, float32, stream).item(); - return array( - shape, - dtype, - std::make_shared( - stream, shape, dtype, lo_f, hi_f), - {eff_key}); - } + // .item() forces evaluation, so no explicit eval needed here. + float lo_f = astype(low, float32, stream).item(); + float hi_f = astype(high, float32, stream).item(); + return array( + shape, + dtype, + std::make_shared(stream, shape, dtype, lo_f, hi_f), + {eff_key}); } auto maxval = array(std::numeric_limits::max(), float32); - // Chunked path (Variant D1) for large GPU random.uniform calls. Splits - // the output along axis 0 into K independent sub-keys + chunks; each - // chunk runs the standard fp32-then-cast pipeline and the results are - // concatenated. Reduces peak memory from 3x to ~1+2/K with quality - // preserved (still fp32 inside). Bit pattern is not vanilla-equivalent - // because sub-keys differ; same precedent as PR #904. + // Chunked fp32-then-cast for large half-precision uniform. Sub-key + // derivation differs from vanilla so output bytes change (precedent: + // #904); per-element distribution preserved. fp32 excluded (vanilla + // already at 1x peak); transforms skipped (per-chunk eval illegal). size_t bits_bytes = total * size_of(float32); - const size_t kChunkBytes = 256ULL * 1024 * 1024; // 256 MB trigger - // Restrict to half-precision: vanilla fp32 uniform already operates at - // ~1x output peak (the bits buffer IS the target dtype size), so chunking - // only adds K-fold sub-key + slice_update overhead with no memory benefit - // (drawback Phase 1 measured ~25% latency regression and ~25% higher peak). - // Skip under any transform: per-chunk eval is illegal inside compile/vmap. + constexpr size_t kChunkBytes = 256ULL * 1024 * 1024; if (half && gpu_stream && scalar_lohi && single_key && bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 && !detail::in_tracing()) { auto eff_key = key ? 
*key : KeySequence::default_().next(); - if (eff_key.shape() == Shape{2}) { - size_t K = pick_chunk_count(bits_bytes, shape[0]); - return chunked_fp32_then_cast( - shape, dtype, eff_key, K, stream, - [&](const Shape& chunk_shape, const array& subkey_i) { - auto bits_i = bits(chunk_shape, size_of(float32), subkey_i, stream); - auto fp_i = divide(bits_i, maxval, stream); - auto clipped = - astype(minimum(fp_i, upper, stream), dtype, stream); - return add(multiply(range, clipped, stream), lo, stream); - }); - } + size_t K = pick_chunk_count(bits_bytes, shape[0]); + return chunked_fp32_then_cast( + shape, dtype, eff_key, K, stream, + [&](const Shape& chunk_shape, const array& subkey_i) { + auto bits_i = bits(chunk_shape, size_of(float32), subkey_i, stream); + auto fp_i = divide(bits_i, maxval, stream); + auto clipped = astype(minimum(fp_i, upper, stream), dtype, stream); + return add(multiply(range, clipped, stream), lo, stream); + }); } auto out = bits(shape, size_of(float32), key, stream); @@ -334,29 +291,14 @@ array normal( } auto stream = to_stream(s); - // Keep normal() on the fp32 sampling path: sampling in target dtype - // would erode randomness quality (per PR #2361 discussion: bf16 native - // gives only ~382 unique values per 100K samples, below the - // fp32-then-cast baseline of ~2254). The uniform() fast path still - // shrinks half-precision uniform peak memory to 1x. + // Keep normal() on the fp32 sampling path: native bf16 sampling produces + // only ~382 unique values per 100K samples vs ~2254 for fp32-then-cast + // (PR #2361). auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); - // Variant D4: chunked normal pipeline for large GPU half-precision - // outputs. Splits the normal pipeline (uniform fp32 -> erfinv -> - // cast -> affine) along axis 0 into K independent sub-keys/chunks - // and concatenates. Reduces peak memory ~3x to ~2x, enough to make - // (46341, 46341) bf16 normal succeed on 16 GB devices. Sub-key - // derivation differs from vanilla so bf16/fp16 normal bytes change - // (precedent: PR #904); statistical quality is preserved (each chunk - // is fp32-then-cast). - // Restrict to half-precision: fp32 normal does not have the 3x peak - // amplification (vanilla fp32 normal already operates at ~1x output - // peak because the intermediate IS the target), so chunking only adds - // K-fold sub-key + slice_update overhead with no memory benefit - // (drawback Phase 1 measured ~20% slower + 25% higher peak for fp32). - // Skip under any transform (compile/vmap/vjp/jvp): per-chunk eval is - // illegal inside the tracer. + // Chunked normal pipeline; same dispatch rationale as the uniform + // chunked path above. Lets (46341, 46341) bf16 normal fit on 16 GB. bool chunkable_dtype = (dtype == bfloat16 || dtype == float16); bool gpu_stream = (stream.device.type == Device::gpu); bool single_key = !key || (key->shape() == Shape{2}); @@ -365,35 +307,31 @@ array normal( total *= static_cast(d); } size_t bits_bytes = total * size_of(float32); - const size_t kChunkBytes = 256ULL * 1024 * 1024; + constexpr size_t kChunkBytes = 256ULL * 1024 * 1024; if (chunkable_dtype && gpu_stream && single_key && bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 && !detail::in_tracing()) { auto eff_key = key ? 
*key : KeySequence::default_().next(); - if (eff_key.shape() == Shape{2}) { - size_t K = pick_chunk_count(bits_bytes, shape[0]); - auto applied_scale = array(std::sqrt(2.0), dtype); - if (scale.has_value()) { - applied_scale = multiply( - applied_scale, astype(*scale, dtype, stream), stream); - } - array loc_dt = loc.has_value() - ? astype(*loc, dtype, stream) - : array(0.0, dtype); - bool has_loc = loc.has_value(); - return chunked_fp32_then_cast( - shape, dtype, eff_key, K, stream, - [&](const Shape& chunk_shape, const array& subkey_i) { - auto u = - uniform(low, high, chunk_shape, float32, subkey_i, stream); - auto chunk_dt = astype(erfinv(u, stream), dtype, stream); - chunk_dt = multiply(applied_scale, chunk_dt, stream); - if (has_loc) { - chunk_dt = add(loc_dt, chunk_dt, stream); - } - return chunk_dt; - }); + size_t K = pick_chunk_count(bits_bytes, shape[0]); + auto applied_scale = array(std::sqrt(2.0), dtype); + if (scale.has_value()) { + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); } + array loc_dt = loc.has_value() ? astype(*loc, dtype, stream) + : array(0.0, dtype); + bool has_loc = loc.has_value(); + return chunked_fp32_then_cast( + shape, dtype, eff_key, K, stream, + [&](const Shape& chunk_shape, const array& subkey_i) { + auto u = uniform(low, high, chunk_shape, float32, subkey_i, stream); + auto chunk_dt = astype(erfinv(u, stream), dtype, stream); + chunk_dt = multiply(applied_scale, chunk_dt, stream); + if (has_loc) { + chunk_dt = add(loc_dt, chunk_dt, stream); + } + return chunk_dt; + }); } auto samples = uniform(low, high, shape, float32, key, stream);