Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlx/backend/cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}

void RandomUniform::eval_cpu(const std::vector<array>&, array&) {
  // RandomUniform exists solely to cut GPU peak memory, so no CPU kernel
  // is provided. The dispatch in random::uniform never schedules this
  // primitive on a CPU stream; landing here means that contract was broken.
  throw std::runtime_error("[RandomUniform::eval_cpu] GPU-only primitive.");
}

void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Forward the single input to the shared CPU reshape helper.
  reshape(inputs[0], out);
}
Expand Down
98 changes: 98 additions & 0 deletions mlx/backend/cuda/random.cu
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,106 @@ __global__ void rbits(
}
}

// Fused uniform RNG in the target half-precision dtype, transcribed from
// the Metal kernel in mlx/backend/metal/kernels/random.metal.
//
// Launch layout: 1-D grid, one thread per PAIR of outputs. Thread y makes
// a single threefry2x32 call and writes out[y] and out[y + half_size],
// reproducing the rbitsc bit layout so the seed -> output mapping stays
// bit-identical to the vanilla bits()/divide()/astype()/affine() pipeline
// while skipping the fp32 intermediate buffer in global memory.
//
// Preconditions (enforced by the host-side eval_gpu):
//   * total output size N == 2 * half_size (N even);
//   * 2 * half_size - 1 fits in a uint32_t.
//
// upper_clip is the largest value below 1 representable in T, so T(f)
// stays strictly below 1 after rounding to the target dtype.
//
// NOTE(review): untested on real NVIDIA hardware on this branch; the
// host-side dispatch logic is shared with the validated Metal path, but
// CI on a CUDA-capable device is still required.
template <typename T>
__global__ void runiformc_fused(
    const uint32_t* __restrict__ keys,
    T* __restrict__ out,
    uint32_t half_size,
    float lo,
    float range,
    float upper_clip) {
  uint32_t y = blockIdx.x * blockDim.x + threadIdx.x;
  if (y >= half_size) {
    return;
  }
  uint2 key2 = uint2{keys[0], keys[1]};
  rbits hash = threefry2x32_hash(key2, uint2{y, y + half_size});

  // Scale/shift factors converted to the target dtype once; the floating
  // point operation order below (divide -> clip -> astype -> scale ->
  // shift) must match the vanilla pipeline exactly for bit-exactness.
  T r_dt = T(range);
  T lo_dt = T(lo);

  float f0 = fminf(float(hash.val.x) / 4294967295.0f, upper_clip);
  out[y] = r_dt * T(f0) + lo_dt;

  float f1 = fminf(float(hash.val.y) / 4294967295.0f, upper_clip);
  out[y + half_size] = r_dt * T(f1) + lo_dt;
}

} // namespace cu

void RandomUniform::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("RandomUniform::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  if (out.size() == 0) {
    return;
  }
  size_t N = out.size();
  // Two outputs per thread (rbitsc layout) requires an even N for a
  // bit-exact match with the vanilla pipeline.
  if (N % 2 != 0) {
    throw std::runtime_error(
        "[RandomUniform::eval_gpu] CUDA fast path requires even N; "
        "fall back to vanilla pipeline at the random.cpp level.");
  }
  auto& keys = inputs[0];
  encoder.set_input_array(keys);
  encoder.set_output_array(out);
  size_t half = N / 2;
  // The kernel computes out[y + half_size] in 32-bit arithmetic, so the
  // largest index (2 * half - 1) must fit in a uint32_t. The previous
  // bound (half < UINT32_MAX) still allowed y + half_size to wrap for
  // half > 2^31, silently scattering writes.
  if (half > 0x80000000ull) {
    throw std::runtime_error(
        "[RandomUniform::eval_gpu] output too large for 32-bit indexing.");
  }
  float lo = low_;
  float range = high_ - low_;
  // Largest value below 1 representable in the target dtype; keeps
  // T(f) strictly below 1 after rounding (see runiformc_fused).
  float upper_clip = (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f;
  uint32_t threads = 256;
  uint32_t blocks = static_cast<uint32_t>((half + threads - 1) / threads);
  if (dtype_ == bfloat16) {
    encoder.add_kernel_node(
        cu::runiformc_fused<__nv_bfloat16>,
        dim3{blocks},
        dim3{threads},
        gpu_ptr<uint32_t>(keys),
        gpu_ptr<__nv_bfloat16>(out),
        static_cast<uint32_t>(half),
        lo,
        range,
        upper_clip);
  } else {
    encoder.add_kernel_node(
        cu::runiformc_fused<__half>,
        dim3{blocks},
        dim3{threads},
        gpu_ptr<uint32_t>(keys),
        gpu_ptr<__half>(out),
        static_cast<uint32_t>(half),
        lo,
        range,
        upper_clip);
  }
}

void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("RandomBits::eval_gpu");
assert(inputs.size() == 1);
Expand Down
36 changes: 36 additions & 0 deletions mlx/backend/metal/kernels/random.metal
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,39 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
}
}
}

// Fused uniform RNG in the target half-precision dtype.
//
// Launch layout: one thread per PAIR of outputs along grid dim y
// (grid_dim.y == N / 2). Thread y makes a single threefry2x32 call and
// writes out[y] and out[y + grid_dim.y], reproducing the rbitsc bit
// layout so the seed -> output mapping is bit-identical to the vanilla
// bits()/divide()/astype()/affine() pipeline, without materializing an
// fp32 intermediate buffer in global memory.
//
// upper_clip is the largest value below 1 representable in T (set by the
// host), so T(f) stays strictly below 1 after rounding to the target
// dtype. NOTE(review): the host must guarantee 2 * grid_dim.y - 1 fits
// in a 32-bit uint; y + half_size wraps otherwise.
template <typename T>
[[kernel]] void runiformc(
    device const uint32_t* keys,
    device T* out,
    constant const float& lo,
    constant const float& range,
    constant const float& upper_clip,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  uint2 key2 = uint2(keys[0], keys[1]);
  uint y = index.y;
  uint half_size = grid_dim.y;
  // One counter pair -> 64 random bits -> two outputs.
  union rbits hash = threefry2x32_hash(key2, uint2(y, y + half_size));

  // Scale/shift factors converted to the target dtype once per thread.
  T r_dt = T(range);
  T lo_dt = T(lo);

  // Operation order (divide -> clip -> astype -> scale -> shift) must
  // match the vanilla pipeline exactly for bit-exactness.
  float f0 = min(float(hash.val.x) / 4294967295.0f, upper_clip);
  out[y] = r_dt * T(f0) + lo_dt;

  float f1 = min(float(hash.val.y) / 4294967295.0f, upper_clip);
  out[y + half_size] = r_dt * T(f1) + lo_dt;
}

// Instantiate the fused kernel for the two half-precision targets; the
// host looks these up by name ("runiformc_float16" / "runiformc_bfloat16").
#define instantiate_runiformc(tname, type) \
instantiate_kernel("runiformc_" #tname, runiformc, type)

instantiate_runiformc(float16, half)
instantiate_runiformc(bfloat16, bfloat16_t)
36 changes: 36 additions & 0 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,42 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void RandomUniform::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }
  // Two outputs per thread (rbitsc layout) requires an even N for
  // bit-exact match with vanilla.
  size_t N = out.size();
  if (N % 2 != 0) {
    throw std::runtime_error("[RandomUniform::eval_gpu] N must be even.");
  }
  size_t half = N / 2;
  // The kernel reads half_size from grid_dim.y as a 32-bit uint and
  // writes out[y + half_size]; the largest index (2 * half - 1) must fit
  // in a uint32_t, so reject half > 2^31 (same guard as the CUDA path).
  if (half > 0x80000000ull) {
    throw std::runtime_error(
        "[RandomUniform::eval_gpu] output too large for 32-bit indexing.");
  }
  auto& keys = inputs[0];
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Kernel name selects the dtype instantiation (see random.metal).
  std::string tname = (dtype_ == bfloat16) ? "bfloat16" : "float16";
  auto kernel = d.get_kernel("runiformc_" + tname);

  float lo = low_;
  float range = high_ - low_;
  // Largest value below 1 representable in the target dtype; keeps the
  // converted sample strictly below 1 after rounding.
  float upper_clip = (dtype_ == bfloat16) ? 0.99609375f : 0.99951171875f;

  // One thread per output pair along grid dim y.
  MTL::Size grid_dims = MTL::Size(1, half, 1);
  auto group_dims = get_block_dims(1, half, 1);
  auto& compute_encoder = metal::get_command_encoder(s);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(keys, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder.set_bytes(lo, 2);
  compute_encoder.set_bytes(range, 3);
  compute_encoder.set_bytes(upper_clip, 4);
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down
14 changes: 14 additions & 0 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3822,6 +3822,20 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
return shape_ == r_other.shape_ && width_ == r_other.width_;
}

std::pair<std::vector<array>, std::vector<int>> RandomUniform::vmap(
    const std::vector<array>&, const std::vector<int>&) {
  // Pure defense-in-depth: random::uniform already bypasses this
  // primitive whenever detail::in_tracing() is true, so a vmap transform
  // should never encounter it.
  throw std::runtime_error(
      "[RandomUniform::vmap] not supported; use float32 uniform and astype.");
}

bool RandomUniform::is_equivalent(const Primitive& other) const {
const RandomUniform& r_other = static_cast<const RandomUniform&>(other);
return shape_ == r_other.shape_ && dtype_ == r_other.dtype_ &&
low_ == r_other.low_ && high_ == r_other.high_;
}

std::vector<array> Real::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down
31 changes: 31 additions & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,37 @@ class RandomBits : public UnaryPrimitive {
int width_;
};

// Fused uniform-in-target-dtype primitive for half-precision outputs.
// Avoids the 3x peak memory overhead of bits()->divide()->astype()
// pipelines by computing the entire transform per-thread in registers.
// GPU-only: eval_cpu throws (random::uniform never routes CPU streams
// here).
class RandomUniform : public UnaryPrimitive {
 public:
  // shape/dtype describe the output array; samples are drawn uniformly
  // from [low, high) in the target dtype.
  explicit RandomUniform(
      Stream stream,
      const Shape& shape,
      Dtype dtype,
      float low,
      float high)
      : UnaryPrimitive(stream),
        shape_(shape),
        dtype_(dtype),
        low_(low),
        high_(high) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_NAME(RandomUniform)
  bool is_equivalent(const Primitive& other) const override;

 private:
  Shape shape_; // output shape
  Dtype dtype_; // target half-precision dtype (float16 or bfloat16)
  float low_; // inclusive lower bound
  float high_; // exclusive upper bound
};

class Real : public UnaryPrimitive {
public:
explicit Real(Stream stream) : UnaryPrimitive(stream) {}
Expand Down
Loading