Add cost estimates for ragged sort kernels #4228
0065932 to a2caaca
| inner_kernel() | ||
| def get_cost_estimate( |
Can we allow users to plumb in their own cost estimate and override this? The kernels are not very efficient, so a cost estimated from theory does not reflect the time actually taken.
E.g., just like splash, which has fwd and bwd cost estimates in base.yml.
a2caaca to d8b31fe
| # ragged gather SparseCore kernel. When false (default), use the SparseCore kernel. | ||
| ragged_gather_reduce_fallback: false # when true, unconditionally use the JAX reference implementation instead of the | ||
| # ragged gather reduce SparseCore kernel. When false (default), use the SparseCore kernel. | ||
| ragged_gather_cost_estimate_flops: -1 # -1 means auto-compute, any > 0 value overrides the flop cost estimate for the ragged gather kernel |
Is there any benefit to exposing four separate cost estimates (fwd dispatch, fwd combine, bwd dispatch, bwd combine)? I guess there are only two main kernels, gather and gather reduce, so we only need these two?
I think two is enough, though there is a slight difference between using and not using weights in the ragged gather kernel.
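For illustration, the "one override flag per kernel" idea from this thread could be sketched as below. This is a minimal sketch, not the PR's code: `resolve_flops` is a hypothetical helper, and only `ragged_gather_cost_estimate_flops` appears in the diff; the gather-reduce key name is assumed for the example.

```python
from typing import Mapping

AUTO = -1  # sentinel from the base.yml comment: -1 means "auto-compute"


def resolve_flops(config: Mapping[str, int], key: str, auto_flops: int) -> int:
  """Returns the user override if it is positive, else the auto-computed value."""
  override = config.get(key, AUTO)
  return override if override > 0 else auto_flops


# One flag per kernel. The first key appears in this PR's base.yml diff;
# the second is a hypothetical name used only for illustration.
config = {
    "ragged_gather_cost_estimate_flops": -1,
    "ragged_gather_reduce_cost_estimate_flops": 5_000_000,
}

gather_flops = resolve_flops(
    config, "ragged_gather_cost_estimate_flops", auto_flops=1024)
reduce_flops = resolve_flops(
    config, "ragged_gather_reduce_cost_estimate_flops", auto_flops=1024)
```

With these values the gather kernel keeps its auto-computed estimate while the gather-reduce kernel uses the override. Because the check is `> 0`, a value of 0 also falls back to the auto-computed estimate.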
| hidden_size: int, | ||
| dtype_bytes: int, | ||
| has_weights: bool, | ||
| flops_override: int = -1, |
nit: I would just name it cost_estimate instead of flops_override, keeping the behavior where -1 selects a sensible default (identical code, just a different variable name).
Oh, never mind, I see this only overrides the flops, so maybe the name makes sense. I forgot that a cost estimate has both bytes and flops. I'm not sure which is the more useful tuning knob for e2e performance (maybe even both), but I think this is fine for now. The end goal is to help XLA schedule collectives.
It makes more sense to me to allow overriding bytes_accessed, since that is the main cost here, but I don't understand exactly how these values affect our goal of tuning XLA schedules.
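A rough sketch of what overriding both fields might look like is given below. It is an assumption-heavy illustration, not the PR's implementation: `CostEstimate` is a local stand-in type, `bytes_override` is a hypothetical parameter, the function takes the padded and aligned sizes directly for brevity, and the byte-traffic model is assumed. Only the flop formula and the `flops_override` semantics are taken from the diff.

```python
from typing import NamedTuple


class CostEstimate(NamedTuple):
  """Local stand-in for a kernel cost estimate (illustration only)."""
  flops: int
  bytes_accessed: int
  transcendentals: int = 0


def get_cost_estimate(
    padded_input_size: int,
    aligned_hidden_size: int,
    dtype_bytes: int,
    has_weights: bool,
    flops_override: int = -1,
    bytes_override: int = -1,  # hypothetical knob, not present in the diff
) -> CostEstimate:
  # Flop formula copied from the diff.
  flops = 2 * padded_input_size * aligned_hidden_size
  # Assumed traffic model: each row is read once and written once, one
  # int32 index is read per row, plus one weight per row when weights are used.
  row_bytes = aligned_hidden_size * dtype_bytes
  bytes_accessed = 2 * padded_input_size * row_bytes + 4 * padded_input_size
  if has_weights:
    bytes_accessed += padded_input_size * dtype_bytes
  # Any positive override replaces the corresponding auto-computed value.
  if flops_override > 0:
    flops = flops_override
  if bytes_override > 0:
    bytes_accessed = bytes_override
  return CostEstimate(flops=flops, bytes_accessed=bytes_accessed)


auto = get_cost_estimate(8, 128, dtype_bytes=2, has_weights=False)
tuned = get_cost_estimate(8, 128, dtype_bytes=2, has_weights=False,
                          bytes_override=1_000_000)
```

Keeping the two overrides independent lets a user tune whichever field turns out to influence scheduling, without having to supply both.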
| if flops_override > 0: | ||
| flops = flops_override | ||
| else: | ||
| flops = 2 * padded_input_size * aligned_hidden_size |
I'm not sure there are actually any flops here. Do the flops of a cost_estimate refer to MXU flops? The additions performed here won't happen on the MXU. The flops are tiny anyway unless the user sets a large override to tune scheduling, so maybe this is fine, but we should default to 0 if the flops of this cost_estimate refer to MXU flops.
The flops account for token/gradient accumulation. Agreed, they are not very useful.
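To put the "flops are tiny" remark in numbers, the default formula can be compared against a simple estimate of the bytes moved. The sizes below are illustrative, and the traffic model (each element read once and written once) is an assumption rather than the kernel's measured behavior.

```python
padded_input_size = 8192    # illustrative row count
aligned_hidden_size = 4096  # illustrative hidden size
dtype_bytes = 2             # bfloat16

# Default flop formula from the diff.
flops = 2 * padded_input_size * aligned_hidden_size
# Assumed traffic: every element is read once and written once.
bytes_moved = 2 * padded_input_size * aligned_hidden_size * dtype_bytes
# Flops per byte moved; a low value indicates a memory-bound kernel.
intensity = flops / bytes_moved
```

Under these assumptions the kernel performs about half a flop per byte moved, which is consistent with the view in this thread that bytes, not flops, dominate the cost.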
d8b31fe to 5a9aca2
Description
Add cost estimates for the ragged gather and ragged gather reduce kernels to improve XLA compiler scheduling.
This PR also adds four flags for ragged kernel cost estimations:
Each flag defaults to -1 (auto-compute); any positive value replaces the corresponding flop or byte estimate.
FIXES: b/525538961
Tests