Add cost estimates for ragged sort kernels #4228

Open
NuojCheng wants to merge 1 commit into main from chengnuojin-ragged-cost
Conversation

@NuojCheng NuojCheng commented Jun 22, 2026

Collaborator

Description

Add cost estimates for ragged gather and ragged gather reduce kernels, for better XLA compiler scheduling.

This PR also adds four flags for ragged kernel cost estimations:

  • ragged_gather_cost_estimate_flops
  • ragged_gather_reduce_cost_estimate_flops
  • ragged_gather_cost_estimate_bytes_accessed
  • ragged_gather_reduce_cost_estimate_bytes_accessed

Each flag defaults to -1, which keeps the auto-computed estimate. Any positive value overrides the corresponding flops or bytes-accessed estimate.
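
A minimal sketch of the override rule described above. The helper name `resolve_estimate` and the numeric values are hypothetical and are not taken from the PR; only the flag names and the "-1 means auto, positive overrides" behavior come from the description.

```python
def resolve_estimate(auto_value: int, override: int = -1) -> int:
  """Return `override` when it is positive, otherwise the auto-computed value."""
  return override if override > 0 else auto_value


# Hypothetical flag values for two of the four flags listed above.
flags = {
    "ragged_gather_cost_estimate_flops": -1,  # keep the auto-computed flops
    "ragged_gather_cost_estimate_bytes_accessed": 4096,  # override bytes accessed
}

# Made-up auto-computed values, for illustration only.
auto_flops, auto_bytes = 1024, 2048

flops = resolve_estimate(auto_flops, flags["ragged_gather_cost_estimate_flops"])
bytes_accessed = resolve_estimate(
    auto_bytes, flags["ragged_gather_cost_estimate_bytes_accessed"]
)
```

With these inputs, `flops` keeps the auto-computed value and `bytes_accessed` takes the override.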

FIXES: b/525538961

Tests

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 22, 2026


Codecov Report

❌ Patch coverage is 48.48485% with 17 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/kernels/ragged/ragged_gather_reduce.py | 13.33% | 13 Missing ⚠️
src/maxtext/kernels/ragged/ragged_gather.py | 71.42% | 2 Missing and 2 partials ⚠️

inner_kernel()


def get_cost_estimate(

Collaborator

can we allow users to plumb in their own cost estimate and override this? The kernels are not very efficient so estimating cost from theory is not reflective of the time taken

Collaborator

e.g. just like for splash there is fwd and bwd cost estimate in base.yml

Collaborator Author

done

@Shuwen-Fang Shuwen-Fang self-requested a review June 23, 2026 00:08
@NuojCheng NuojCheng force-pushed the chengnuojin-ragged-cost branch from a2caaca to d8b31fe Compare June 23, 2026 01:32
@NuojCheng NuojCheng requested a review from michelle-yooh as a code owner June 23, 2026 01:32
# ragged gather SparseCore kernel. When false (default), use the SparseCore kernel.
ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the
# ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel.
ragged_gather_cost_estimate_flops: -1 # -1 means auto-compute, any > 0 value overrides the flop cost estimate for the ragged gather kernel

Collaborator

Is there benefit to exposing four separate cost estimates (fwd dispatch, fwd combine, bwd dispatch, bwd combine)? I guess there are only two main kernels - gather and gather reduce, so we only need these two?

Collaborator Author

I think two is enough, though there is slight diff between using weight/not using weight for ragged gather kernel

hidden_size: int,
dtype_bytes: int,
has_weights: bool,
flops_override: int = -1,

Collaborator

nit: I would just name it cost_estimate instead of flops_override, with the same -1 gets set to a nice default behavior (identical code just variable name)

@gobbleturk gobbleturk Jun 23, 2026

Collaborator

oh nvm I see this just overrides the flops, maybe this makes sense. I forgot that the cost estimate has both bytes and flops... I'm not sure which is a more useful tuning knob for e2e performance (maybe even both...), I think this is fine for now. The end goal is to help XLA schedule collectives

Collaborator

it makes more sense to me that we can override the bytes_accessed since that is the main cost here, but I don't understand exactly how these affect our goal of helping tune XLA schedules

Collaborator Author

Added both now

if flops_override > 0:
  flops = flops_override
else:
  flops = 2 * padded_input_size * aligned_hidden_size

Collaborator

I'm not sure there are any flops here actually - do flops of a cost_estimate refer to MXU flops? The additions performed here won't happen on the MXU. The flops here are tiny anyway unless the user is setting a large override to tune scheduling, so maybe this is fine, but we should default to 0 if flops of this cost_estimate refers to MXU flops

Collaborator Author

The flops are for tokens/gradients accumulation... Agree they are not very useful..
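
To make the flops-versus-bytes discussion in this thread concrete, here is a rough sketch of a cost-estimate helper with both overrides. The flops formula and the names `padded_input_size`, `aligned_hidden_size`, `dtype_bytes`, and `flops_override` come from the diff snippets above. Everything else is an assumption for illustration, not the PR's implementation: the `CostSketch` type, the `estimate_cost` function name, the `bytes_override` parameter name, and in particular the bytes-accessed formula.

```python
from typing import NamedTuple


class CostSketch(NamedTuple):
  """Hypothetical stand-in for a kernel cost estimate (flops plus bytes accessed)."""

  flops: int
  bytes_accessed: int


def estimate_cost(
    padded_input_size: int,
    aligned_hidden_size: int,
    dtype_bytes: int,
    flops_override: int = -1,
    bytes_override: int = -1,  # assumed name; the diff only shows flops_override
) -> CostSketch:
  # Flops formula as shown in the diff.
  flops = 2 * padded_input_size * aligned_hidden_size
  # ASSUMPTION: every element is read once and written once; index and
  # weight traffic is ignored. The PR's actual formula may differ.
  bytes_accessed = 2 * padded_input_size * aligned_hidden_size * dtype_bytes
  # A positive override replaces the auto-computed value.
  if flops_override > 0:
    flops = flops_override
  if bytes_override > 0:
    bytes_accessed = bytes_override
  return CostSketch(flops, bytes_accessed)


# Auto-computed estimate for a made-up shape with a 2-byte dtype.
auto = estimate_cost(padded_input_size=1024, aligned_hidden_size=2048, dtype_bytes=2)
# Same shape, with bytes accessed overridden to tune scheduling.
tuned = estimate_cost(1024, 2048, 2, bytes_override=10**9)
```

Under this assumed traffic model, the kernel performs only `1 / dtype_bytes` flops per byte moved, which is consistent with the reviewers' point that bytes accessed is likely the more meaningful knob.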

@NuojCheng NuojCheng force-pushed the chengnuojin-ragged-cost branch from d8b31fe to 5a9aca2 Compare June 23, 2026 02:17
3 participants