Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ ragged_gather_fallback: false # when true, unconditionally use the JAX reference
# ragged gather SparseCore kernel. When false (default), use the SparseCore kernel.
ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the
# ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel.
ragged_gather_cost_estimate_flops: -1 # -1 (or any value <= 0) means auto-compute; any value > 0 overrides the flop cost estimate for the ragged gather kernel

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any benefit to exposing four separate cost estimates (fwd dispatch, fwd combine, bwd dispatch, bwd combine)? I guess there are only two main kernels - gather and gather reduce - so we only need these two?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think two is enough, though there is a slight difference between using weights and not using weights for the ragged gather kernel.

ragged_gather_reduce_cost_estimate_flops: -1 # -1 (or any value <= 0) means auto-compute; any value > 0 overrides the flop cost estimate for the ragged gather reduce kernel
ragged_gather_cost_estimate_bytes_accessed: -1 # -1 (or any value <= 0) means auto-compute; any value > 0 overrides the bytes_accessed cost estimate for the ragged gather kernel
ragged_gather_reduce_cost_estimate_bytes_accessed: -1 # -1 (or any value <= 0) means auto-compute; any value > 0 overrides the bytes_accessed cost estimate for the ragged gather reduce kernel
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
# tokamax ragged dot - supports all 18 configs
Expand Down
20 changes: 20 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,26 @@ class MoEGeneral(BaseModel):
description="When true, unconditionally use the JAX reference implementation instead of the ragged gather "
"reduce SparseCore kernel. When false (default), use the SparseCore kernel.",
)
ragged_gather_cost_estimate_flops: int = Field(
-1,
description="Flop cost estimate override for the ragged gather kernel. "
"-1 means auto-compute, any > 0 value overrides the flop cost estimate.",
)
ragged_gather_reduce_cost_estimate_flops: int = Field(
-1,
description="Flop cost estimate override for the ragged gather reduce kernel. "
"-1 means auto-compute, any > 0 value overrides the flop cost estimate.",
)
ragged_gather_cost_estimate_bytes_accessed: int = Field(
-1,
description="Bytes-accessed cost estimate override for the ragged gather kernel. "
"-1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate.",
)
ragged_gather_reduce_cost_estimate_bytes_accessed: int = Field(
-1,
description="Bytes-accessed cost estimate override for the ragged gather reduce kernel. "
"-1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate.",
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
moe_fsdp_use_two_stage_all_gather: bool = Field(
Expand Down
71 changes: 70 additions & 1 deletion src/maxtext/kernels/ragged/ragged_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,59 @@ def dma_write_loop(col_vmem_start):
inner_kernel()


def get_cost_estimate(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we allow users to plumb in their own cost estimate and override this? The kernels are not very efficient, so a cost estimated from theory does not reflect the time actually taken.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g., just as there are fwd and bwd cost estimates for splash in base.yml.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

out_size: int,
hidden_size: int,
dtype_bytes: int,
has_weights: bool,
flops_override: int = -1,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would just name it cost_estimate instead of flops_override, keeping the same behavior where -1 gets set to a nice default (identical code, just a different variable name).

@gobbleturk gobbleturk Jun 23, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, never mind - I see this just overrides the flops, so maybe this makes sense. I forgot to consider that the cost estimate has both bytes and flops... I'm not sure which is the more useful tuning knob for e2e performance (maybe even both...). I think this is fine for now. The end goal is to help XLA schedule collectives.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes more sense to me that we can override bytes_accessed, since that is the main cost here, but I don't understand exactly how these affect our goal of helping tune XLA schedules.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added both now

bytes_accessed_override: int = -1,
) -> pl.CostEstimate:
"""Returns a cost estimate for the ragged gather kernel.

The ragged gather is primarily a data-movement kernel: it gathers rows from
an input table according to an index array. When ``has_weights`` is True an
additional element-wise multiply is performed.

Args:
out_size: Number of output rows (after padding).
hidden_size: Number of columns in the input / output.
dtype_bytes: Size of one element in bytes (e.g. 2 for bf16, 4 for f32).
has_weights: Whether per-row weighting is applied.
flops_override: If > 0, use this value as the flop count instead of
auto-computing. -1 (default) means auto-compute.
bytes_accessed_override: If > 0, use this value as bytes_accessed instead
of auto-computing. -1 (default) means auto-compute.

Returns:
A ``pl.CostEstimate`` suitable for XLA scheduling.
"""
# Flops: one multiply per element when weighting is enabled.
if flops_override > 0:
flops = flops_override
else:
flops = out_size * hidden_size if has_weights else 0

if bytes_accessed_override > 0:
bytes_accessed = bytes_accessed_override
else:
# Bytes accessed:
# read – gathered input rows + indices (int32) + optional weights (f32)
# write – output rows
bytes_in = out_size * hidden_size * dtype_bytes # input rows read
bytes_in += out_size * 4 # indices (int32)
if has_weights:
bytes_in += out_size * 4 # weights (float32)
bytes_out = out_size * hidden_size * dtype_bytes # output rows written
bytes_accessed = bytes_in + bytes_out

return pl.CostEstimate(
flops=flops,
bytes_accessed=bytes_accessed,
transcendentals=0,
)


def _fallback_implementation(
x: jax.Array,
indices: jax.Array,
Expand Down Expand Up @@ -308,7 +361,9 @@ def calculate_col_size(hidden_size: int) -> int:
return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes


@functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback"))
@functools.partial(
jax.jit, static_argnames=("has_weights", "enforce_fallback", "flops_override", "bytes_accessed_override")
)
def ragged_gather(
x: jax.Array,
indices: jax.Array,
Expand All @@ -317,6 +372,8 @@ def ragged_gather(
weights: jax.Array | None = None,
has_weights: bool = False,
enforce_fallback: bool = False,
flops_override: int = -1,
bytes_accessed_override: int = -1,
) -> jax.Array:
"""Perform gather on indices within dynamic array start and end.

Expand All @@ -333,6 +390,10 @@ def ragged_gather(
enforce_fallback: Static bool flag. When ``True``, unconditionally use the
JAX reference implementation instead of the SparseCore kernel.
When ``False`` (default), use the SparseCore kernel and raise any error.
flops_override: If > 0, use this value as the flop count instead of
auto-computing. -1 (default) means auto-compute.
bytes_accessed_override: If > 0, use this value as bytes_accessed instead
of auto-computing. -1 (default) means auto-compute.

Returns:
Gathered output of shape ``(indices_size, hidden_size)``.
Expand Down Expand Up @@ -395,6 +456,14 @@ def ragged_gather(
compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args
**_COMPILER_PARAMS,
),
cost_estimate=get_cost_estimate(
out_size=out_size + out_pad_size,
hidden_size=aligned_hidden_size,
dtype_bytes=jax.dtypes.itemsize_bits(dtype) // 8,
has_weights=has_weights,
flops_override=flops_override,
bytes_accessed_override=bytes_accessed_override,
),
mesh=vector_mesh,
name="sc_ragged_gather",
**{
Expand Down
69 changes: 68 additions & 1 deletion src/maxtext/kernels/ragged/ragged_gather_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,61 @@ def _align_to(a, b):
return ((a + b - 1) // b) * b


def get_cost_estimate(
padded_input_size: int,
aligned_hidden_size: int,
reduce_group_size: int,
input_dtype_bytes: int,
bytes_accessed_override: int = -1,
flops_override: int = -1,
) -> pl.CostEstimate:
"""Returns a cost estimate for the ragged gather-reduce kernel.

The kernel gathers rows, multiplies each by a scalar weight, and reduces
(sums) every ``reduce_group_size`` rows into one output row.

Args:
padded_input_size: Total number of source rows (after padding).
aligned_hidden_size: Number of columns (after alignment).
reduce_group_size: Number of source rows reduced into each output row.
input_dtype_bytes: Size of one input element in bytes.
bytes_accessed_override: If > 0, use this value as bytes_accessed instead
of auto-computing. -1 (default) means auto-compute.
flops_override: If > 0, use this value as the flop count instead of
auto-computing. -1 (default) means auto-compute.

Returns:
A ``pl.CostEstimate`` suitable for XLA scheduling.
"""
# Flops:
# - one multiply per element for weighting: padded_input_size * aligned_hidden_size
# - one add per element for reduction: padded_input_size * aligned_hidden_size
if flops_override > 0:
flops = flops_override
else:
flops = 2 * padded_input_size * aligned_hidden_size

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there are actually any flops here - do the flops of a cost_estimate refer to MXU flops? The additions performed here won't happen on the MXU. The flops here are tiny anyway unless the user sets a large override to tune scheduling, so maybe this is fine, but we should default to 0 if the flops of this cost_estimate refer to MXU flops.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flops are for token/gradient accumulation... I agree they are not very useful.


if bytes_accessed_override > 0:
bytes_accessed = bytes_accessed_override
else:
# Bytes accessed:
# read – input rows + src_indices (int32) + dst_indices (int32) + topk_weights (f32)
# write – output rows (float32)
bytes_in = padded_input_size * aligned_hidden_size * input_dtype_bytes # input rows
bytes_in += padded_input_size * 4 # src_indices (int32)
bytes_in += padded_input_size * 4 # dst_indices (int32)
bytes_in += padded_input_size * 4 # topk_weights (float32)
output_rows = padded_input_size // reduce_group_size
bytes_out = output_rows * aligned_hidden_size * 4 # output rows (float32)
bytes_accessed = bytes_in + bytes_out

return pl.CostEstimate(
flops=flops,
bytes_accessed=bytes_accessed,
transcendentals=0,
)


def _fallback_implementation(
x: jax.Array,
indices: jax.Array,
Expand Down Expand Up @@ -370,14 +425,18 @@ def _preprocess(
)


@functools.partial(jax.jit, static_argnames=("reduce_group_size", "enforce_fallback"))
@functools.partial(
jax.jit, static_argnames=("reduce_group_size", "enforce_fallback", "flops_override", "bytes_accessed_override")
)
def ragged_gather_reduce(
x: jax.Array,
indices: jax.Array,
topk_weights: jax.Array,
valid_rows_mask: jax.Array,
reduce_group_size: int,
enforce_fallback: bool = False,
flops_override: int = -1,
bytes_accessed_override: int = -1,
) -> jax.Array:
"""Gathers `x` according to `indices`, applies weights and masks, and reduces.

Expand Down Expand Up @@ -502,6 +561,14 @@ def ragged_gather_reduce(
compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args
**_COMPILER_PARAMS,
),
cost_estimate=get_cost_estimate(
padded_input_size=padded_input_size,
aligned_hidden_size=aligned_hidden_size,
reduce_group_size=reduce_group_size,
input_dtype_bytes=dtype_bytes,
flops_override=flops_override,
bytes_accessed_override=bytes_accessed_override,
),
mesh=vector_mesh,
name="sc_ragged_gather_reduce",
**{
Expand Down
12 changes: 12 additions & 0 deletions src/maxtext/kernels/ragged/ragged_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def ring_ragged_sort(
buffer_size=None,
enforce_gather_fallback=False,
enforce_gather_reduce_fallback=False,
gather_flops_override=-1,
gather_reduce_flops_override=-1,
gather_bytes_accessed_override=-1,
gather_reduce_bytes_accessed_override=-1,
):
"""Ragged-gather variant for AG-RS Expert Parallelism token routing.

Expand Down Expand Up @@ -105,6 +109,8 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
shard_output_start[None],
shard_output_end[None],
enforce_fallback=enforce_gather_fallback,
flops_override=gather_flops_override,
bytes_accessed_override=gather_bytes_accessed_override,
)
else:
local_buffer_size = buffer_size
Expand All @@ -126,6 +132,8 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
jnp.int32(0)[None],
gather_end[None],
enforce_fallback=enforce_gather_fallback,
flops_override=gather_flops_override,
bytes_accessed_override=gather_bytes_accessed_override,
)

out = (x, group_sizes_local, topk_argsort_revert_indices)
Expand Down Expand Up @@ -219,6 +227,10 @@ def ring_ragged_unsort(
topk_weights,
enforce_gather_fallback=False,
enforce_gather_reduce_fallback=False,
gather_flops_override=-1,
gather_reduce_flops_override=-1,
gather_bytes_accessed_override=-1,
gather_reduce_bytes_accessed_override=-1,
):
"""Dual of :func:`ring_ragged_sort`.

Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,10 @@ def permute(
buffer_size=buffer_size,
enforce_gather_fallback=self.config.ragged_gather_fallback,
enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback,
gather_flops_override=self.config.ragged_gather_cost_estimate_flops,
gather_reduce_flops_override=self.config.ragged_gather_reduce_cost_estimate_flops,
gather_bytes_accessed_override=self.config.ragged_gather_cost_estimate_bytes_accessed,
gather_reduce_bytes_accessed_override=self.config.ragged_gather_reduce_cost_estimate_bytes_accessed,
)
else:
flatten_selected_experts = jnp.ravel(selected_experts)
Expand Down Expand Up @@ -945,6 +949,10 @@ def unpermute(
topk_weights=flat_weights,
enforce_gather_fallback=self.config.ragged_gather_fallback,
enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback,
gather_flops_override=self.config.ragged_gather_cost_estimate_flops,
gather_reduce_flops_override=self.config.ragged_gather_reduce_cost_estimate_flops,
gather_bytes_accessed_override=self.config.ragged_gather_cost_estimate_bytes_accessed,
gather_reduce_bytes_accessed_override=self.config.ragged_gather_reduce_cost_estimate_bytes_accessed,
)
else:
unsort_intermediate = _sort_activations(
Expand Down
Loading