-
Notifications
You must be signed in to change notification settings - Fork 542
Add cost estimates for ragged sort kernels #4228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -271,6 +271,59 @@ def dma_write_loop(col_vmem_start): | |
| inner_kernel() | ||
|
|
||
|
|
||
| def get_cost_estimate( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we allow users to plumb in their own cost estimate and override this? The kernels are not very efficient so estimating cost from theory is not reflective of the time taken
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. e.g. just like for splash there is fwd and bwd cost estimate in base.yml
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| out_size: int, | ||
| hidden_size: int, | ||
| dtype_bytes: int, | ||
| has_weights: bool, | ||
| flops_override: int = -1, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I would just name it cost_estimate instead of flops_override, keeping the same behavior where -1 falls back to a sensible default (identical code, just a different variable name)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. oh nvm, I see this just overrides the flops, so maybe this makes sense. I forgot to consider that the cost estimate has both bytes and flops... I'm not sure which is the more useful tuning knob for e2e performance (maybe even both...), so I think this is fine for now. The end goal is to help XLA schedule collectives
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it makes more sense to me that we can override the bytes_accessed since that is the main cost here, but I don't understand exactly how these affect our goal of helping tune XLA schedules
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added both now |
||
| bytes_accessed_override: int = -1, | ||
| ) -> pl.CostEstimate: | ||
| """Returns a cost estimate for the ragged gather kernel. | ||
|
|
||
| The ragged gather is primarily a data-movement kernel: it gathers rows from | ||
| an input table according to an index array. When ``has_weights`` is True an | ||
| additional element-wise multiply is performed. | ||
|
|
||
| Args: | ||
| out_size: Number of output rows (after padding). | ||
| hidden_size: Number of columns in the input / output. | ||
| dtype_bytes: Size of one element in bytes (e.g. 2 for bf16, 4 for f32). | ||
| has_weights: Whether per-row weighting is applied. | ||
| flops_override: If > 0, use this value as the flop count instead of | ||
| auto-computing. -1 (default) means auto-compute. | ||
| bytes_accessed_override: If > 0, use this value as bytes_accessed instead | ||
| of auto-computing. -1 (default) means auto-compute. | ||
|
|
||
| Returns: | ||
| A ``pl.CostEstimate`` suitable for XLA scheduling. | ||
| """ | ||
| # Flops: one multiply per element when weighting is enabled. | ||
| if flops_override > 0: | ||
| flops = flops_override | ||
| else: | ||
| flops = out_size * hidden_size if has_weights else 0 | ||
|
|
||
| if bytes_accessed_override > 0: | ||
| bytes_accessed = bytes_accessed_override | ||
| else: | ||
| # Bytes accessed: | ||
| # read – gathered input rows + indices (int32) + optional weights (f32) | ||
| # write – output rows | ||
| bytes_in = out_size * hidden_size * dtype_bytes # input rows read | ||
| bytes_in += out_size * 4 # indices (int32) | ||
| if has_weights: | ||
| bytes_in += out_size * 4 # weights (float32) | ||
| bytes_out = out_size * hidden_size * dtype_bytes # output rows written | ||
| bytes_accessed = bytes_in + bytes_out | ||
|
|
||
| return pl.CostEstimate( | ||
| flops=flops, | ||
| bytes_accessed=bytes_accessed, | ||
| transcendentals=0, | ||
| ) | ||
|
|
||
|
|
||
| def _fallback_implementation( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
|
|
@@ -308,7 +361,9 @@ def calculate_col_size(hidden_size: int) -> int: | |
| return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes | ||
|
|
||
|
|
||
| @functools.partial(jax.jit, static_argnames=("has_weights", "enforce_fallback")) | ||
| @functools.partial( | ||
| jax.jit, static_argnames=("has_weights", "enforce_fallback", "flops_override", "bytes_accessed_override") | ||
| ) | ||
| def ragged_gather( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
|
|
@@ -317,6 +372,8 @@ def ragged_gather( | |
| weights: jax.Array | None = None, | ||
| has_weights: bool = False, | ||
| enforce_fallback: bool = False, | ||
| flops_override: int = -1, | ||
| bytes_accessed_override: int = -1, | ||
| ) -> jax.Array: | ||
| """Perform gather on indices within dynamic array start and end. | ||
|
|
||
|
|
@@ -333,6 +390,10 @@ def ragged_gather( | |
| enforce_fallback: Static bool flag. When ``True``, unconditionally use the | ||
| JAX reference implementation instead of the SparseCore kernel. | ||
| When ``False`` (default), use the SparseCore kernel and raise any error. | ||
| flops_override: If > 0, use this value as the flop count instead of | ||
| auto-computing. -1 (default) means auto-compute. | ||
| bytes_accessed_override: If > 0, use this value as bytes_accessed instead | ||
| of auto-computing. -1 (default) means auto-compute. | ||
|
|
||
| Returns: | ||
| Gathered output of shape ``(indices_size, hidden_size)``. | ||
|
|
@@ -395,6 +456,14 @@ def ragged_gather( | |
| compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args | ||
| **_COMPILER_PARAMS, | ||
| ), | ||
| cost_estimate=get_cost_estimate( | ||
| out_size=out_size + out_pad_size, | ||
| hidden_size=aligned_hidden_size, | ||
| dtype_bytes=jax.dtypes.itemsize_bits(dtype) // 8, | ||
| has_weights=has_weights, | ||
| flops_override=flops_override, | ||
| bytes_accessed_override=bytes_accessed_override, | ||
| ), | ||
| mesh=vector_mesh, | ||
| name="sc_ragged_gather", | ||
| **{ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,6 +49,61 @@ def _align_to(a, b): | |
| return ((a + b - 1) // b) * b | ||
|
|
||
|
|
||
| def get_cost_estimate( | ||
| padded_input_size: int, | ||
| aligned_hidden_size: int, | ||
| reduce_group_size: int, | ||
| input_dtype_bytes: int, | ||
| bytes_accessed_override: int = -1, | ||
| flops_override: int = -1, | ||
| ) -> pl.CostEstimate: | ||
| """Returns a cost estimate for the ragged gather-reduce kernel. | ||
|
|
||
| The kernel gathers rows, multiplies each by a scalar weight, and reduces | ||
| (sums) every ``reduce_group_size`` rows into one output row. | ||
|
|
||
| Args: | ||
| padded_input_size: Total number of source rows (after padding). | ||
| aligned_hidden_size: Number of columns (after alignment). | ||
| reduce_group_size: Number of source rows reduced into each output row. | ||
| input_dtype_bytes: Size of one input element in bytes. | ||
| bytes_accessed_override: If > 0, use this value as bytes_accessed instead | ||
| of auto-computing. -1 (default) means auto-compute. | ||
| flops_override: If > 0, use this value as the flop count instead of | ||
| auto-computing. -1 (default) means auto-compute. | ||
|
|
||
| Returns: | ||
| A ``pl.CostEstimate`` suitable for XLA scheduling. | ||
| """ | ||
| # Flops: | ||
| # - one multiply per element for weighting: padded_input_size * aligned_hidden_size | ||
| # - one add per element for reduction: padded_input_size * aligned_hidden_size | ||
| if flops_override > 0: | ||
| flops = flops_override | ||
| else: | ||
| flops = 2 * padded_input_size * aligned_hidden_size | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure there are any flops here actually - do flops of a cost_estimate refer to MXU flops? The additions performed here won't happen on the MXU. The flops here are tiny anyway unless the user is setting a large override to tune scheduling, so maybe this is fine, but we should default to 0 if flops of this cost_estimate refers to MXU flops
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The flops are for token/gradient accumulation. Agreed, they are not very useful. |
||
|
|
||
| if bytes_accessed_override > 0: | ||
| bytes_accessed = bytes_accessed_override | ||
| else: | ||
| # Bytes accessed: | ||
| # read – input rows + src_indices (int32) + dst_indices (int32) + topk_weights (f32) | ||
| # write – output rows (float32) | ||
| bytes_in = padded_input_size * aligned_hidden_size * input_dtype_bytes # input rows | ||
| bytes_in += padded_input_size * 4 # src_indices (int32) | ||
| bytes_in += padded_input_size * 4 # dst_indices (int32) | ||
| bytes_in += padded_input_size * 4 # topk_weights (float32) | ||
| output_rows = padded_input_size // reduce_group_size | ||
| bytes_out = output_rows * aligned_hidden_size * 4 # output rows (float32) | ||
| bytes_accessed = bytes_in + bytes_out | ||
|
|
||
| return pl.CostEstimate( | ||
| flops=flops, | ||
| bytes_accessed=bytes_accessed, | ||
| transcendentals=0, | ||
| ) | ||
|
|
||
|
|
||
| def _fallback_implementation( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
|
|
@@ -370,14 +425,18 @@ def _preprocess( | |
| ) | ||
|
|
||
|
|
||
| @functools.partial(jax.jit, static_argnames=("reduce_group_size", "enforce_fallback")) | ||
| @functools.partial( | ||
| jax.jit, static_argnames=("reduce_group_size", "enforce_fallback", "flops_override", "bytes_accessed_override") | ||
| ) | ||
| def ragged_gather_reduce( | ||
| x: jax.Array, | ||
| indices: jax.Array, | ||
| topk_weights: jax.Array, | ||
| valid_rows_mask: jax.Array, | ||
| reduce_group_size: int, | ||
| enforce_fallback: bool = False, | ||
| flops_override: int = -1, | ||
| bytes_accessed_override: int = -1, | ||
| ) -> jax.Array: | ||
| """Gathers `x` according to `indices`, applies weights and masks, and reduces. | ||
|
|
||
|
|
@@ -502,6 +561,14 @@ def ragged_gather_reduce( | |
| compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args | ||
| **_COMPILER_PARAMS, | ||
| ), | ||
| cost_estimate=get_cost_estimate( | ||
| padded_input_size=padded_input_size, | ||
| aligned_hidden_size=aligned_hidden_size, | ||
| reduce_group_size=reduce_group_size, | ||
| input_dtype_bytes=dtype_bytes, | ||
| flops_override=flops_override, | ||
| bytes_accessed_override=bytes_accessed_override, | ||
| ), | ||
| mesh=vector_mesh, | ||
| name="sc_ragged_gather_reduce", | ||
| **{ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a benefit to exposing four separate cost estimates (fwd dispatch, fwd combine, bwd dispatch, bwd combine)? I guess there are only two main kernels - gather and gather reduce - so we only need these two?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think two is enough, though there is a slight difference between using weights and not using weights for the ragged gather kernel