diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index f1d83dd306..2d24e65325 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -332,6 +332,13 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+void RandomUniform::eval_cpu(const std::vector<array>&, array&) {
+  // GPU-only primitive (the peak-memory amplification it solves is a
+  // GPU buffer problem). The dispatch in random::uniform skips this
+  // primitive on CPU streams, so reaching here indicates misuse.
+  throw std::runtime_error("[RandomUniform::eval_cpu] GPU-only primitive.");
+}
+
 void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
   reshape(inputs[0], out);
 }
diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu
index 72f6c84222..78d943c172 100644
--- a/mlx/backend/cuda/random.cu
+++ b/mlx/backend/cuda/random.cu
@@ -130,8 +130,106 @@ __global__ void rbits(
   }
 }
 
+// Fused per-thread uniform RNG for half-precision targets. Mirrors the
+// Metal kernel in mlx/backend/metal/kernels/random.metal: each thread
+// emits TWO output elements at positions y and y + half_size from a
+// single threefry call, matching the rbitsc bit-layout exactly so that
+// the seed -> output mapping is bit-identical to the vanilla
+// bits()/divide()/astype()/affine() pipeline (no fp32 intermediate
+// buffer in global memory).
+//
+// NOTE: This CUDA path is untested on this branch (no NVIDIA hardware
+// available during development). Algorithmically it is a direct CUDA
+// transcription of the validated Metal kernel, and the host-side
+// dispatch in mlx/random.cpp is shared between backends. CI /
+// maintainer review is required to confirm correctness on a
+// CUDA-capable device.
+template <typename T>
+__global__ void runiformc_fused(
+    const uint32_t* keys,
+    T* out,
+    uint32_t half_size,
+    float lo,
+    float range,
+    float upper_clip) {
+  uint32_t y = blockIdx.x * blockDim.x + threadIdx.x;
+  if (y >= half_size) {
+    return;
+  }
+  uint2 key2 = uint2{keys[0], keys[1]};
+  rbits hash = threefry2x32_hash(key2, uint2{y, y + half_size});
+
+  float f0 = float(hash.val.x) / 4294967295.0f;
+  if (f0 > upper_clip) f0 = upper_clip;
+  T t0 = T(f0);
+  T r_dt = T(range);
+  T lo_dt = T(lo);
+  T tr0 = r_dt * t0;
+  out[y] = tr0 + lo_dt;
+
+  float f1 = float(hash.val.y) / 4294967295.0f;
+  if (f1 > upper_clip) f1 = upper_clip;
+  T t1 = T(f1);
+  T tr1 = r_dt * t1;
+  out[y + half_size] = tr1 + lo_dt;
+}
+
 } // namespace cu
 
+void RandomUniform::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("RandomUniform::eval_gpu");
+  assert(inputs.size() == 1);
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
+  if (out.size() == 0) {
+    return;
+  }
+  size_t N = out.size();
+  if (N % 2 != 0) {
+    throw std::runtime_error(
+        "[RandomUniform::eval_gpu] CUDA fast path requires even N; "
+        "fall back to the vanilla pipeline at the random.cpp level.");
+  }
+  auto& keys = inputs[0];
+  encoder.set_input_array(keys);
+  encoder.set_output_array(out);
+  size_t half = N / 2;
+  if (half >= UINT32_MAX) {
+    throw std::runtime_error(
+        "[RandomUniform::eval_gpu] half output size exceeds UINT32_MAX");
+  }
+  float lo = low_;
+  float range = high_ - low_;
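+  // The clip thresholds below are the largest representable values
+  // strictly below 1.0 in each target dtype: bfloat16 has a 7-bit
+  // mantissa, so its largest value below 1.0 is 1 - 2^-8 = 0.99609375;
+  // float16 has a 10-bit mantissa, giving 1 - 2^-11 = 0.99951171875.
+  // Clipping before the cast keeps samples inside [low, high).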
+  float upper_clip =
+      (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f;
+  uint32_t threads = 256;
+  uint32_t blocks =
+      static_cast<uint32_t>((half + threads - 1) / threads);
+  if (dtype_ == bfloat16) {
+    encoder.add_kernel_node(
+        cu::runiformc_fused<__nv_bfloat16>,
+        dim3{blocks},
+        dim3{threads},
+        gpu_ptr<uint32_t>(keys),
+        gpu_ptr<__nv_bfloat16>(out),
+        static_cast<uint32_t>(half),
+        lo,
+        range,
+        upper_clip);
+  } else {
+    encoder.add_kernel_node(
+        cu::runiformc_fused<__half>,
+        dim3{blocks},
+        dim3{threads},
+        gpu_ptr<uint32_t>(keys),
+        gpu_ptr<__half>(out),
+        static_cast<uint32_t>(half),
+        lo,
+        range,
+        upper_clip);
+  }
+}
+
 void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("RandomBits::eval_gpu");
   assert(inputs.size() == 1);
diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal
index ccbd464d3b..a4eba1f065 100644
--- a/mlx/backend/metal/kernels/random.metal
+++ b/mlx/backend/metal/kernels/random.metal
@@ -101,3 +101,39 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
     }
   }
 }
+
+// Fused per-thread uniform RNG for half-precision targets. Each thread
+// emits two outputs at positions y and y + grid_dim.y from a single
+// threefry call, matching the rbitsc bit-layout exactly so the
+// seed -> output mapping is bit-identical to the vanilla
+// bits()/divide()/astype()/affine() pipeline (no fp32 intermediate
+// buffer in global memory).
+template <typename T>
+[[kernel]] void runiformc(
+    device const uint32_t* keys,
+    device T* out,
+    constant const float& lo,
+    constant const float& range,
+    constant const float& upper_clip,
+    uint2 grid_dim [[threads_per_grid]],
+    uint2 index [[thread_position_in_grid]]) {
+  uint2 key2 = uint2(keys[0], keys[1]);
+  uint y = index.y;
+  uint half_size = grid_dim.y;
+  union rbits hash = threefry2x32_hash(key2, uint2(y, y + half_size));
+
+  T r_dt = T(range);
+  T lo_dt = T(lo);
+
+  float f0 = min(float(hash.val.x) / 4294967295.0f, upper_clip);
+  out[y] = r_dt * T(f0) + lo_dt;
+
+  float f1 = min(float(hash.val.y) / 4294967295.0f, upper_clip);
+  out[y + half_size] = r_dt * T(f1) + lo_dt;
+}
+
+#define instantiate_runiformc(tname, type) \
+  instantiate_kernel("runiformc_" #tname, runiformc, type)
+
+instantiate_runiformc(float16, half)
+instantiate_runiformc(bfloat16, bfloat16_t)
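+
+// Worked example of the layout (an illustration, not extra dispatch
+// logic): for N = 8 the grid has half_size = 4 threads, and thread
+// y = 1 computes threefry2x32_hash(key, {1, 5}), writing hash.val.x to
+// out[1] and hash.val.y to out[5], exactly the positions the rbitsc
+// path fills. That is what keeps the mapping bit-identical.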
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index d5bbf797e4..dd18c90557 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -200,6 +200,42 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
+void RandomUniform::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.set_data(allocator::malloc(out.nbytes()));
+  if (out.size() == 0) {
+    return;
+  }
+  // Two outputs per thread (rbitsc layout) require an even N for a
+  // bit-exact match with vanilla.
+  size_t N = out.size();
+  if (N % 2 != 0) {
+    throw std::runtime_error("[RandomUniform::eval_gpu] N must be even.");
+  }
+  auto& keys = inputs[0];
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  std::string tname = (dtype_ == bfloat16) ? "bfloat16" : "float16";
+  auto kernel = d.get_kernel("runiformc_" + tname);
+
+  float lo = low_;
+  float range = high_ - low_;
+  // Largest values below 1.0 representable in bfloat16 / float16; see
+  // the CUDA dispatch for the derivation.
+  float upper_clip = (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f;
+
+  size_t half = N / 2;
+  MTL::Size grid_dims = MTL::Size(1, half, 1);
+  auto group_dims = get_block_dims(1, half, 1);
+  auto& compute_encoder = metal::get_command_encoder(s);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(keys, 0);
+  compute_encoder.set_output_array(out, 1);
+  compute_encoder.set_bytes(lo, 2);
+  compute_encoder.set_bytes(range, 3);
+  compute_encoder.set_bytes(upper_clip, 4);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
 void QRF::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 8e209eeb26..f94fb7b5bf 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3822,6 +3822,20 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
   return shape_ == r_other.shape_ && width_ == r_other.width_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> RandomUniform::vmap(
+    const std::vector<array>&,
+    const std::vector<int>&) {
+  // Defense-in-depth: random::uniform already skips this primitive when
+  // detail::in_tracing() is true, so vmap should never reach here.
+  throw std::runtime_error(
+      "[RandomUniform::vmap] not supported; use float32 uniform and astype.");
+}
+
+bool RandomUniform::is_equivalent(const Primitive& other) const {
+  const RandomUniform& r_other = static_cast<const RandomUniform&>(other);
+  return shape_ == r_other.shape_ && dtype_ == r_other.dtype_ &&
+      low_ == r_other.low_ && high_ == r_other.high_;
+}
+
 std::vector<array> Real::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 75fb978dce..bc957ddc39 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1734,6 +1734,37 @@ class RandomBits : public UnaryPrimitive {
   int width_;
 };
 
+// Fused uniform-in-target-dtype primitive for half-precision outputs.
+// Avoids the 3x peak-memory overhead of the bits()->divide()->astype()
+// pipeline by computing the entire transform per thread in registers.
+class RandomUniform : public UnaryPrimitive {
+ public:
+  explicit RandomUniform(
+      Stream stream,
+      const Shape& shape,
+      Dtype dtype,
+      float low,
+      float high)
+      : UnaryPrimitive(stream),
+        shape_(shape),
+        dtype_(dtype),
+        low_(low),
+        high_(high) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_NAME(RandomUniform)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  Shape shape_;
+  Dtype dtype_;
+  float low_;
+  float high_;
+};
+
 class Real : public UnaryPrimitive {
  public:
  explicit Real(Stream stream) : UnaryPrimitive(stream) {}
diff --git a/mlx/random.cpp b/mlx/random.cpp
index def3169cb5..1ab954fef1 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <algorithm>
 #include <cmath>
 #include <cstdint>
 
@@ -7,6 +8,8 @@
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/random.h"
+#include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core::random {
@@ -92,6 +95,60 @@ T below_one() {
   return f;
 }
 
+// Chunks the output along axis 0 across K sub-keys, writing each chunk
+// via slice_update + eval to free the fp32 transients between chunks.
+// Peak memory ~ (1 + 2/K) * output. The caller must guard with
+// !detail::in_tracing(); the per-chunk eval is illegal under tracing.
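+//
+// Worked example with the constants below: a (16384, 16384) bf16 output
+// has a 1 GiB fp32 equivalent, so pick_chunk_count returns
+// K = ceil(1 GiB / 256 MiB) = 4 and each chunk covers 4096 rows with a
+// ~256 MB fp32 transient instead of the full 1 GiB.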
+template <typename ProcessChunk>
+static array chunked_fp32_then_cast(
+    const Shape& shape,
+    Dtype dtype,
+    const array& key,
+    size_t K,
+    const Stream& stream,
+    ProcessChunk&& process_chunk) {
+  auto subkeys = random::split(key, static_cast<int>(K), stream);
+  auto out = zeros(shape, dtype, stream);
+  Shape strides_one(shape.size(), 1);
+  int base_dim = shape[0] / static_cast<int>(K);
+  int remainder = shape[0] % static_cast<int>(K);
+  int cursor = 0;
+  for (size_t i = 0; i < K; ++i) {
+    int this_dim = base_dim + (static_cast<int>(i) < remainder ? 1 : 0);
+    Shape chunk_shape = shape;
+    chunk_shape[0] = this_dim;
+    auto subkey_i = reshape(
+        slice(
+            subkeys,
+            Shape{static_cast<int>(i), 0},
+            Shape{static_cast<int>(i) + 1, 2},
+            Shape{1, 1},
+            stream),
+        Shape{2},
+        stream);
+    auto chunk_out = process_chunk(chunk_shape, subkey_i);
+    Shape start = Shape(shape.size(), 0);
+    start[0] = cursor;
+    Shape stop = shape;
+    stop[0] = cursor + this_dim;
+    out = slice_update(out, chunk_out, start, stop, strides_one, stream);
+    eval(out);
+    cursor += this_dim;
+  }
+  return out;
+}
+
+// Targets a ~256 MB fp32 transient per chunk. K is clamped to [4, 256]
+// and never exceeds shape[0].
+static inline size_t pick_chunk_count(
+    size_t bits_bytes_fp32,
+    int first_dim) {
+  constexpr size_t kChunkBytes = 256ULL * 1024 * 1024;
+  size_t K = (bits_bytes_fp32 + kChunkBytes - 1) / kChunkBytes;
+  K = std::clamp(K, size_t{4}, size_t{256});
+  return std::min(K, static_cast<size_t>(first_dim));
+}
+
 array uniform(
     const array& low,
     const array& high,
@@ -133,7 +190,54 @@ array uniform(
   };
   auto upper = get_upper();
+
+  // Fused per-thread uniform for half-precision GPU outputs.
+  // Bit-identical to vanilla; cuts peak memory from 3x to 1x output.
+  // Skipped under transforms (no vmap rule); falls through to vanilla.
+  bool half = (dtype == bfloat16 || dtype == float16);
+  size_t total = 1;
+  for (auto d : shape) {
+    total *= static_cast<size_t>(d);
+  }
+  bool scalar_lohi = (lo.size() == 1) && (hi.size() == 1);
+  bool single_key = !key || (key->shape() == Shape{2});
+  bool gpu_stream = (stream.device.type == Device::gpu);
+  if (half && (total % 2 == 0) && scalar_lohi && single_key && gpu_stream &&
+      !detail::in_tracing()) {
+    auto eff_key = key ? *key : KeySequence::default_().next();
+    // .item<float>() forces evaluation, so no explicit eval is needed.
+    float lo_f = astype(low, float32, stream).item<float>();
+    float hi_f = astype(high, float32, stream).item<float>();
+    return array(
+        shape,
+        dtype,
+        std::make_shared<RandomUniform>(stream, shape, dtype, lo_f, hi_f),
+        {eff_key});
+  }
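+  // Illustrative dispatch for the block above (a hypothetical call, not
+  // new logic): mx.random.uniform(shape=(4096, 4096), dtype=mx.bfloat16)
+  // on a GPU stream takes the fused primitive; the same call while a
+  // transform is tracing (detail::in_tracing() == true), on a CPU
+  // stream, or with an odd element count falls through to the vanilla
+  // pipeline below.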
+
   auto maxval = array(std::numeric_limits<uint32_t>::max(), float32);
+
+  // Chunked fp32-then-cast for large half-precision uniform. Sub-key
+  // derivation differs from vanilla, so the output bytes change
+  // (precedent: #904); the per-element distribution is preserved. fp32
+  // is excluded (vanilla is already at 1x peak) and transforms are
+  // skipped (the per-chunk eval is illegal while tracing).
+  size_t bits_bytes = total * size_of(float32);
+  constexpr size_t kChunkBytes = 256ULL * 1024 * 1024;
+  if (half && gpu_stream && scalar_lohi && single_key &&
+      bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 &&
+      !detail::in_tracing()) {
+    auto eff_key = key ? *key : KeySequence::default_().next();
+    size_t K = pick_chunk_count(bits_bytes, shape[0]);
+    return chunked_fp32_then_cast(
+        shape, dtype, eff_key, K, stream,
+        [&](const Shape& chunk_shape, const array& subkey_i) {
+          auto bits_i = bits(chunk_shape, size_of(float32), subkey_i, stream);
+          auto fp_i = divide(bits_i, maxval, stream);
+          auto clipped = astype(minimum(fp_i, upper, stream), dtype, stream);
+          return add(multiply(range, clipped, stream), lo, stream);
+        });
+  }
+
   auto out = bits(shape, size_of(float32), key, stream);
   out = divide(out, maxval, stream);
   out = astype(minimum(out, upper, stream), dtype, stream);
@@ -187,8 +291,49 @@ array normal(
   }
   auto stream = to_stream(s);
+  // Keep normal() on the fp32 sampling path: native bf16 sampling
+  // produces only ~382 unique values per 100K samples vs ~2254 for
+  // fp32-then-cast (PR #2361).
   auto low = array(std::nextafter(-1.0f, 0.0f), float32);
   auto high = array(1.0f, float32);
+
+  // Chunked normal pipeline; same dispatch rationale as the chunked
+  // uniform path above. Lets a (46341, 46341) bf16 normal fit on 16 GB.
+  bool chunkable_dtype = (dtype == bfloat16 || dtype == float16);
+  bool gpu_stream = (stream.device.type == Device::gpu);
+  bool single_key = !key || (key->shape() == Shape{2});
+  size_t total = 1;
+  for (auto d : shape) {
+    total *= static_cast<size_t>(d);
+  }
+  size_t bits_bytes = total * size_of(float32);
+  constexpr size_t kChunkBytes = 256ULL * 1024 * 1024;
+  if (chunkable_dtype && gpu_stream && single_key &&
+      bits_bytes >= 2 * kChunkBytes && !shape.empty() && shape[0] >= 4 &&
+      !detail::in_tracing()) {
+    auto eff_key = key ? *key : KeySequence::default_().next();
+    size_t K = pick_chunk_count(bits_bytes, shape[0]);
+    auto applied_scale = array(std::sqrt(2.0), dtype);
+    if (scale.has_value()) {
+      applied_scale =
+          multiply(applied_scale, astype(*scale, dtype, stream), stream);
+    }
+    array loc_dt =
+        loc.has_value() ? astype(*loc, dtype, stream) : array(0.0, dtype);
+    bool has_loc = loc.has_value();
+    return chunked_fp32_then_cast(
+        shape, dtype, eff_key, K, stream,
+        [&](const Shape& chunk_shape, const array& subkey_i) {
+          auto u = uniform(low, high, chunk_shape, float32, subkey_i, stream);
+          auto chunk_dt = astype(erfinv(u, stream), dtype, stream);
+          chunk_dt = multiply(applied_scale, chunk_dt, stream);
+          if (has_loc) {
+            chunk_dt = add(loc_dt, chunk_dt, stream);
+          }
+          return chunk_dt;
+        });
+  }
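+  // The chunk lambda is the standard inverse-CDF construction: for
+  // U ~ Uniform(-1, 1), sqrt(2) * erfinv(U) is a standard normal
+  // sample, so casting to the target dtype only after erfinv keeps the
+  // fp32 sampling quality cited above.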
+
   auto samples = uniform(low, high, shape, float32, key, stream);
   auto applied_scale = array(std::sqrt(2.0), dtype);
   if (scale.has_value()) {
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 551c32993c..907bb8470b 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -388,5 +388,112 @@ def test_broadcastable_scale_loc(self):
         self.assertEqual(sample.dtype, mx.float16)
 
 
+class TestRandomChunked(mlx_tests.MLXTestCase):
+    """Tests specifically exercising the chunked random path added in
+    path-c. The path activates when the fp32-equivalent output size is
+    >= ~512 MB (2 * 256 MB chunk threshold). Shapes below that stay on
+    the vanilla path and are covered by TestRandom already.
+    """
+
+    # 16384 * 16384 = 268M elements; fp32-equiv = 1.07 GB -> triggers
+    # chunking for any half dtype. Big enough to hit the chunked path,
+    # small enough to fit in M4 CI memory.
+    CHUNK_SHAPE = (16384, 16384)
+
+    def test_uniform_bf16_chunked_no_inf(self):
+        mx.random.seed(42)
+        a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a)
+        self.assertFalse(bool(mx.any(mx.isinf(a)).item()))
+        self.assertFalse(bool(mx.any(mx.isnan(a)).item()))
+
+    def test_normal_bf16_chunked_no_inf(self):
+        mx.random.seed(42)
+        a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a)
+        self.assertFalse(bool(mx.any(mx.isinf(a)).item()))
+        self.assertFalse(bool(mx.any(mx.isnan(a)).item()))
+
+    def test_uniform_bf16_chunked_reproducibility(self):
+        # Same seed -> same output, even on the chunked path.
+        mx.random.seed(7)
+        a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.random.seed(7)
+        b = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a, b)
+        self.assertTrue(bool(mx.array_equal(a, b).item()))
+
+    def test_normal_bf16_chunked_reproducibility(self):
+        mx.random.seed(7)
+        a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.random.seed(7)
+        b = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a, b)
+        self.assertTrue(bool(mx.array_equal(a, b).item()))
+
+    def test_uniform_bf16_chunked_distribution(self):
+        mx.random.seed(0)
+        a = mx.random.uniform(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a)
+        # Uniform [0, 1): population std = 1/sqrt(12) ≈ 0.2887.
+        # Sample-mean SE = std/sqrt(N); a 5-sigma tolerance gives a
+        # ~1-in-1.7M false-fail rate. bf16 quantization spacing near
+        # 0.5 is 2^-8 ≈ 0.004; use max(5 * sigma, 0.003) as the bound.
+        mean = float(a.astype(mx.float32).mean())
+        std = float(a.astype(mx.float32).std())
+        n = float(self.CHUNK_SHAPE[0] * self.CHUNK_SHAPE[1])
+        sigma_mean = (1.0 / math.sqrt(12.0)) / math.sqrt(n)
+        bound_mean = max(5 * sigma_mean, 0.003)
+        self.assertLess(abs(mean - 0.5), bound_mean)
+        self.assertLess(abs(std - 0.2887), 0.005)
+        self.assertGreaterEqual(float(a.min()), 0.0)
+        self.assertLess(float(a.max()), 1.0)
+
+    def test_normal_bf16_chunked_distribution(self):
+        mx.random.seed(0)
+        a = mx.random.normal(shape=self.CHUNK_SHAPE, dtype=mx.bfloat16)
+        mx.eval(a)
+        mean = float(a.astype(mx.float32).mean())
+        std = float(a.astype(mx.float32).std())
+        # Normal(0, 1): mean SE = 1/sqrt(N); 5-sigma plus the bf16
+        # quantization floor.
+        n = float(self.CHUNK_SHAPE[0] * self.CHUNK_SHAPE[1])
+        sigma_mean = 1.0 / math.sqrt(n)
+        bound_mean = max(5 * sigma_mean, 0.003)
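+        # With n = 2.68e8, sigma_mean = 1/16384 ≈ 6.1e-5 and
+        # 5 * sigma_mean ≈ 3.1e-4, so the 0.003 quantization floor is
+        # the binding bound here.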
+        self.assertLess(abs(mean), bound_mean)
+        self.assertLess(abs(std - 1.0), 0.005)
+        # Quality bar: the chunked path must produce a healthy variety
+        # of bit patterns (PR #2361 baseline: ~2254 unique values for
+        # vanilla bf16 normal at 100K samples; chunked at 268M should
+        # easily exceed 2000).
+        bits = a.view(mx.uint16).flatten()[:1_000_000]
+        host = set(bits.tolist())
+        self.assertGreaterEqual(
+            len(host),
+            2000,
+            f"chunked bf16 normal unique-value count {len(host)} below quality floor",
+        )
+
+    def test_uniform_chunked_odd_first_dim(self):
+        # Odd first axis to verify chunk-remainder handling.
+        shape = (16385, 16384)
+        mx.random.seed(3)
+        a = mx.random.uniform(shape=shape, dtype=mx.bfloat16)
+        mx.eval(a)
+        self.assertEqual(a.shape, shape)
+        self.assertFalse(bool(mx.any(mx.isinf(a)).item()))
+        self.assertFalse(bool(mx.any(mx.isnan(a)).item()))
+
+    def test_normal_fp32_chunked(self):
+        # fp32 normal at a chunk-triggering size. The dispatch currently
+        # gates on half dtypes, so today this exercises the vanilla fp32
+        # path at scale (fp32 chunking is deferred to Phase 5).
+        shape = (16384, 16384)  # 1 GiB of fp32, above the size threshold
+        mx.random.seed(5)
+        a = mx.random.normal(shape=shape, dtype=mx.float32)
+        mx.eval(a)
+        self.assertEqual(a.shape, shape)
+        mean = float(a.mean())
+        std = float(a.std())
+        self.assertLess(abs(mean), 0.001)
+        self.assertLess(abs(std - 1.0), 0.005)
+
+
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()