diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index f0c300209b..d8e6d614a1 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -226,6 +226,10 @@ ragged_gather_fallback: false # when true, unconditionally use the JAX reference # ragged gather SparseCore kernel. When false (default), use the SparseCore kernel. ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the # ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel. +ragged_gather_cost_estimate_flops: -1 # -1 means auto-compute, any > 0 value overrides the flop cost estimate for the ragged gather kernel +ragged_gather_reduce_cost_estimate_flops: -1 # -1 means auto-compute, any > 0 value overrides the flop cost estimate for the ragged gather reduce kernel +ragged_gather_cost_estimate_bytes_accessed: -1 # -1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate for the ragged gather kernel +ragged_gather_reduce_cost_estimate_bytes_accessed: -1 # -1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate for the ragged gather reduce kernel # tunable tiling dimensions used for mlp gmm # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`) # tokamax ragged dot - supports all 18 configs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 98b24ff451..345b84d81e 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -760,6 +760,26 @@ class MoEGeneral(BaseModel): description="When true, unconditionally use the JAX reference implementation instead of the ragged gather " "reduce SparseCore kernel. When false (default), use the SparseCore kernel.", ) + ragged_gather_cost_estimate_flops: int = Field( + -1, + description="Flop cost estimate override for the ragged gather kernel. " + "-1 means auto-compute, any > 0 value overrides the flop cost estimate.", + ) + ragged_gather_reduce_cost_estimate_flops: int = Field( + -1, + description="Flop cost estimate override for the ragged gather reduce kernel. " + "-1 means auto-compute, any > 0 value overrides the flop cost estimate.", + ) + ragged_gather_cost_estimate_bytes_accessed: int = Field( + -1, + description="Bytes-accessed cost estimate override for the ragged gather kernel. " + "-1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate.", + ) + ragged_gather_reduce_cost_estimate_bytes_accessed: int = Field( + -1, + description="Bytes-accessed cost estimate override for the ragged gather reduce kernel. " + "-1 means auto-compute, any > 0 value overrides the bytes_accessed cost estimate.", + ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") moe_fsdp_use_two_stage_all_gather: bool = Field( diff --git a/src/maxtext/kernels/ragged/ragged_gather.py b/src/maxtext/kernels/ragged/ragged_gather.py index 174507e2d0..14aa989c88 100644 --- a/src/maxtext/kernels/ragged/ragged_gather.py +++ b/src/maxtext/kernels/ragged/ragged_gather.py @@ -271,6 +271,59 @@ def dma_write_loop(col_vmem_start): inner_kernel() +def get_cost_estimate( + out_size: int, + hidden_size: int, + dtype_bytes: int, + has_weights: bool, + flops_override: int = -1, + bytes_accessed_override: int = -1, +) -> pl.CostEstimate: + """Returns a cost estimate for the ragged gather kernel. + + The ragged gather is primarily a data-movement kernel: it gathers rows from + an input table according to an index array. When ``has_weights`` is True an + additional element-wise multiply is performed. + + Args: + out_size: Number of output rows (after padding). + hidden_size: Number of columns in the input / output. + dtype_bytes: Size of one element in bytes (e.g. 2 for bf16, 4 for f32). + has_weights: Whether per-row weighting is applied. + flops_override: If > 0, use this value as the flop count instead of + auto-computing. -1 (default) means auto-compute. + bytes_accessed_override: If > 0, use this value as bytes_accessed instead + of auto-computing. -1 (default) means auto-compute. + + Returns: + A ``pl.CostEstimate`` suitable for XLA scheduling. + """ + # Flops: one multiply per element when weighting is enabled. + if flops_override > 0: + flops = flops_override + else: + flops = out_size * hidden_size if has_weights else 0 + + if bytes_accessed_override > 0: + bytes_accessed = bytes_accessed_override + else: + # Bytes accessed: + # read – gathered input rows + indices (int32) + optional weights (f32) + # write – output rows + bytes_in = out_size * hidden_size * dtype_bytes # input rows read + bytes_in += out_size * 4 # indices (int32) + if has_weights: + bytes_in += out_size * 4 # weights (float32) + bytes_out = out_size * hidden_size * dtype_bytes # output rows written + bytes_accessed = bytes_in + bytes_out + + return pl.CostEstimate( + flops=flops, + bytes_accessed=bytes_accessed, + transcendentals=0, + ) + + def _fallback_implementation( x: jax.Array, indices: jax.Array, @@ -308,7 +361,9 @@ def calculate_col_size(hidden_size: int) -> int: return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes -@functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback")) +@functools.partial( + jax.jit, static_argnames=("has_weights", "enforce_fallback", "flops_override", "bytes_accessed_override") +) def ragged_gather( x: jax.Array, indices: jax.Array, @@ -317,6 +372,8 @@ def ragged_gather( weights: jax.Array | None = None, has_weights: bool = False, enforce_fallback: bool = False, + flops_override: int = -1, + bytes_accessed_override: int = -1, ) -> jax.Array: """Perform gather on indices within dynamic array start and end. @@ -333,6 +390,10 @@ def ragged_gather( enforce_fallback: Static bool flag. When ``True``, unconditionally use the JAX reference implementation instead of the SparseCore kernel. When ``False`` (default), use the SparseCore kernel and raise any error. + flops_override: If > 0, use this value as the flop count instead of + auto-computing. -1 (default) means auto-compute. + bytes_accessed_override: If > 0, use this value as bytes_accessed instead + of auto-computing. -1 (default) means auto-compute. Returns: Gathered output of shape ``(indices_size, hidden_size)``. @@ -395,6 +456,14 @@ def ragged_gather( compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args **_COMPILER_PARAMS, ), + cost_estimate=get_cost_estimate( + out_size=out_size + out_pad_size, + hidden_size=aligned_hidden_size, + dtype_bytes=jax.dtypes.itemsize_bits(dtype) // 8, + has_weights=has_weights, + flops_override=flops_override, + bytes_accessed_override=bytes_accessed_override, + ), mesh=vector_mesh, name="sc_ragged_gather", **{ diff --git a/src/maxtext/kernels/ragged/ragged_gather_reduce.py b/src/maxtext/kernels/ragged/ragged_gather_reduce.py index e6a908c011..49f1346bad 100644 --- a/src/maxtext/kernels/ragged/ragged_gather_reduce.py +++ b/src/maxtext/kernels/ragged/ragged_gather_reduce.py @@ -49,6 +49,61 @@ def _align_to(a, b): return ((a + b - 1) // b) * b +def get_cost_estimate( + padded_input_size: int, + aligned_hidden_size: int, + reduce_group_size: int, + input_dtype_bytes: int, + bytes_accessed_override: int = -1, + flops_override: int = -1, +) -> pl.CostEstimate: + """Returns a cost estimate for the ragged gather-reduce kernel. + + The kernel gathers rows, multiplies each by a scalar weight, and reduces + (sums) every ``reduce_group_size`` rows into one output row. + + Args: + padded_input_size: Total number of source rows (after padding). + aligned_hidden_size: Number of columns (after alignment). + reduce_group_size: Number of source rows reduced into each output row. + input_dtype_bytes: Size of one input element in bytes. + bytes_accessed_override: If > 0, use this value as bytes_accessed instead + of auto-computing. -1 (default) means auto-compute. + flops_override: If > 0, use this value as the flop count instead of + auto-computing. -1 (default) means auto-compute. + + Returns: + A ``pl.CostEstimate`` suitable for XLA scheduling. + """ + # Flops: + # - one multiply per element for weighting: padded_input_size * aligned_hidden_size + # - one add per element for reduction: padded_input_size * aligned_hidden_size + if flops_override > 0: + flops = flops_override + else: + flops = 2 * padded_input_size * aligned_hidden_size + + if bytes_accessed_override > 0: + bytes_accessed = bytes_accessed_override + else: + # Bytes accessed: + # read – input rows + src_indices (int32) + dst_indices (int32) + topk_weights (f32) + # write – output rows (float32) + bytes_in = padded_input_size * aligned_hidden_size * input_dtype_bytes # input rows + bytes_in += padded_input_size * 4 # src_indices (int32) + bytes_in += padded_input_size * 4 # dst_indices (int32) + bytes_in += padded_input_size * 4 # topk_weights (float32) + output_rows = padded_input_size // reduce_group_size + bytes_out = output_rows * aligned_hidden_size * 4 # output rows (float32) + bytes_accessed = bytes_in + bytes_out + + return pl.CostEstimate( + flops=flops, + bytes_accessed=bytes_accessed, + transcendentals=0, + ) + + def _fallback_implementation( x: jax.Array, indices: jax.Array, @@ -370,7 +425,9 @@ def _preprocess( ) -@functools.partial(jax.jit, static_argnames=("reduce_group_size", "enforce_fallback")) +@functools.partial( + jax.jit, static_argnames=("reduce_group_size", "enforce_fallback", "flops_override", "bytes_accessed_override") +) def ragged_gather_reduce( x: jax.Array, indices: jax.Array, @@ -378,6 +435,8 @@ def ragged_gather_reduce( valid_rows_mask: jax.Array, reduce_group_size: int, enforce_fallback: bool = False, + flops_override: int = -1, + bytes_accessed_override: int = -1, ) -> jax.Array: """Gathers `x` according to `indices`, applies weights and masks, and reduces. @@ -502,6 +561,14 @@ def ragged_gather_reduce( compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args **_COMPILER_PARAMS, ), + cost_estimate=get_cost_estimate( + padded_input_size=padded_input_size, + aligned_hidden_size=aligned_hidden_size, + reduce_group_size=reduce_group_size, + input_dtype_bytes=dtype_bytes, + flops_override=flops_override, + bytes_accessed_override=bytes_accessed_override, + ), mesh=vector_mesh, name="sc_ragged_gather_reduce", **{ diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index 67fc12da94..a6a0485918 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -30,6 +30,10 @@ def ring_ragged_sort( buffer_size=None, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False, + gather_flops_override=-1, + gather_reduce_flops_override=-1, + gather_bytes_accessed_override=-1, + gather_reduce_bytes_accessed_override=-1, ): """Ragged-gather variant for AG-RS Expert Parallelism token routing. @@ -105,6 +109,8 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): shard_output_start[None], shard_output_end[None], enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) else: local_buffer_size = buffer_size @@ -126,6 +132,8 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): jnp.int32(0)[None], gather_end[None], enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) out = (x, group_sizes_local, topk_argsort_revert_indices) @@ -219,6 +227,10 @@ def ring_ragged_unsort( topk_weights, enforce_gather_fallback=False, enforce_gather_reduce_fallback=False, + gather_flops_override=-1, + gather_reduce_flops_override=-1, + gather_bytes_accessed_override=-1, + gather_reduce_bytes_accessed_override=-1, ): """Dual of :func:`ring_ragged_sort`. diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 5925b0bb4e..7a9755b14b 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -864,6 +864,10 @@ def permute( buffer_size=buffer_size, enforce_gather_fallback=self.config.ragged_gather_fallback, enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback, + gather_flops_override=self.config.ragged_gather_cost_estimate_flops, + gather_reduce_flops_override=self.config.ragged_gather_reduce_cost_estimate_flops, + gather_bytes_accessed_override=self.config.ragged_gather_cost_estimate_bytes_accessed, + gather_reduce_bytes_accessed_override=self.config.ragged_gather_reduce_cost_estimate_bytes_accessed, ) else: flatten_selected_experts = jnp.ravel(selected_experts) @@ -945,6 +949,10 @@ def unpermute( topk_weights=flat_weights, enforce_gather_fallback=self.config.ragged_gather_fallback, enforce_gather_reduce_fallback=self.config.ragged_gather_reduce_fallback, + gather_flops_override=self.config.ragged_gather_cost_estimate_flops, + gather_reduce_flops_override=self.config.ragged_gather_reduce_cost_estimate_flops, + gather_bytes_accessed_override=self.config.ragged_gather_cost_estimate_bytes_accessed, + gather_reduce_bytes_accessed_override=self.config.ragged_gather_reduce_cost_estimate_bytes_accessed, ) else: unsort_intermediate = _sort_activations(