Fix logical sharding resolution in NNX #4205
def _resolve_logical_sharding(out_sharding, context_rules, local_rules) -> list:
Could you move this function to utils/sharding.py? The goal is to move all sharding-related util functions to this file.
Moved get_nnx_named_sharding_with_scan_axis() to sharding.py and moved the corresponding unit tests along with it.
# We define rules for 'embed' mapping to 'fsdp' (specific) then 'layers' (fallback)
rules = (
    ("embed", "fsdp"),
    ("embed", "layers"),
"layers" is not a physical axis name, is it? Maybe use something else, e.g. "expert"?
Changed to "stage"; "layers" is now used only as a logical name.
    ("embed", "layers"),
)
with nn_partitioning.axis_rules(rules):
    with jax.set_mesh(self.mesh):
with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules):
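A single `with` statement that lists two context managers enters them left to right and exits them in reverse, just like the nested form. The minimal sketch below illustrates that ordering with a hypothetical stand-in, `ctx`, rather than `jax.set_mesh` or `nn_partitioning.axis_rules` themselves.

```python
from contextlib import contextmanager

events = []


@contextmanager
def ctx(name):
    """Hypothetical stand-in for a mesh or axis-rules context manager."""
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


# Combined form: equivalent to `with ctx("mesh"):` wrapping `with ctx("rules"):`.
with ctx("mesh"), ctx("rules"):
    events.append("body")

print(events)
```

Note that the combined form as written makes the mesh context the outer one, whereas the snippet under review entered the axis rules first.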
# When matching 'mlp', 'fsdp' is already bound, so it is skipped (unassigned/None).
rules = (
    ("embed", ("fsdp", "layers")),
    ("mlp", "fsdp"),
IMO this is an error caused by the rule. We should raise an error if there is a conflict instead of silently resolving it.
In base.yml, both mlp and embed list "fsdp_transpose" as a physical axis:
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
...
['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
And both are used in the MLP layer:
kernel_axes=("embed", "mlp"),
I guess this is not a config issue.
Linen resolved this by calling remove_size_one_mesh_axis on the physical specs. We will now apply the same to NNX.
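To make the size-one reduction concrete, here is a minimal pure-Python sketch. It is an illustration only: `drop_size_one_axes`, the tuple-based spec, and the example mesh shape are hypothetical, and MaxText's actual `remove_size_one_mesh_axis` may differ in signature and details.

```python
def drop_size_one_axes(spec, mesh_shape):
    """Return `spec` with every mesh axis of size 1 removed.

    `spec` is a tuple with one entry per tensor dimension; each entry is
    None, a mesh axis name, or a tuple of mesh axis names (hypothetical
    stand-in for a PartitionSpec). `mesh_shape` maps axis name -> size.
    """
    cleaned = []
    for entry in spec:
        if entry is None:
            cleaned.append(None)
        elif isinstance(entry, str):
            cleaned.append(entry if mesh_shape[entry] > 1 else None)
        else:
            kept = tuple(axis for axis in entry if mesh_shape[axis] > 1)
            cleaned.append(kept if kept else None)
    return tuple(cleaned)


# Assumed mesh: only 'fsdp' is larger than 1.
mesh_shape = {"fsdp": 8, "fsdp_transpose": 1, "tensor": 1}

# 'fsdp_transpose' appears on both dimensions, which would be a duplicate.
spec = (("fsdp", "fsdp_transpose"), ("fsdp_transpose", "tensor"))

print(drop_size_one_axes(spec, mesh_shape))
```

In this sketch the duplicated axis disappears only because its size is 1; a duplicate on an axis larger than 1 would still remain after the reduction.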
I would expect 'embed' to be downgraded to the next item; see maxtext/src/maxtext/configs/base.yml, line 558 in 0b9f604.
You are right. I checked the Linen implementation, and it does fall back to the next item.
I modified the NNX code to re-use Linen's nn.logical_to_mesh_axes() instead of implementing the same function for NNX.
Now the two paths should behave the same.
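The fallback behaviour discussed in this thread can be sketched in pure Python. This is a simplified illustration, not the Flax source: `resolve_first_match` is a hypothetical name, and the code only mirrors the priority-ordered idea (rules are tried in order, a dimension keeps its first successful match, and a rule is skipped when any of its mesh axes is already taken).

```python
_UNASSIGNED = object()


def resolve_first_match(dim_names, rules):
    """Map logical dimension names to mesh axes, trying rules in order."""
    result = [_UNASSIGNED] * len(dim_names)
    used = set()
    for logical, physical in rules:
        if logical not in dim_names:
            continue
        pos = dim_names.index(logical)
        if result[pos] is not _UNASSIGNED:
            continue  # this dimension already matched an earlier rule
        if physical is None:
            axes = ()
        elif isinstance(physical, str):
            axes = (physical,)
        else:
            axes = tuple(physical)
        if used.intersection(axes):
            continue  # a mesh axis is already bound; try the next rule
        result[pos] = physical
        used.update(axes)
    return tuple(None if r is _UNASSIGNED else r for r in result)


# 'embed' cannot take 'fsdp' (bound to 'mlp'), so it falls back to 'stage'.
fallback_rules = (("mlp", "fsdp"), ("embed", "fsdp"), ("embed", "stage"))
print(resolve_first_match(("embed", "mlp"), fallback_rules))

# No fallback rule for 'mlp': it stays unassigned (None).
no_fallback_rules = (("embed", ("fsdp", "stage")), ("mlp", "fsdp"))
print(resolve_first_match(("embed", "mlp"), no_fallback_rules))
```

Because each mesh axis is bound at most once per tensor, this style of resolution cannot assign the same physical axis to two dimensions.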
NuojCheng
left a comment
Overall I like this feature. I was concerned that this function would silently hide some errors that should be raised explicitly, e.g. conflicting logical rules. I agree we should improve from_sharding_rules if it simply treats the logical rules as a dictionary.
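The dictionary concern is easy to reproduce in plain Python. The sketch below does not call Flax; it only shows what happens to an ordered rule list with a repeated logical name when it is turned into a `dict`, using example rules assumed for illustration.

```python
# Ordered rules: the specific mapping comes first, the fallback second.
rules = (("embed", "fsdp"), ("embed", "stage"), ("mlp", "tensor"))

# Converting to a dict keeps only the LAST entry per logical name,
# so the fallback silently replaces the specific rule.
as_dict = dict(rules)
print(as_dict["embed"])

# Scanning the list in order keeps the FIRST entry per logical name.
first_match = next(phys for name, phys in rules if name == "embed")
print(first_match)
```

Any resolver that needs priority or fallback semantics therefore has to iterate over the rules as a sequence rather than look them up by key.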
With a pure NNX implementation of MaxText models, are we still allowed to import Linen functions? If the answer is yes, then we can continue to re-use the Linen implementation. I think we can decide during the Linen removal phase. (PR#12)
Pure NNX training runs previously used custom logical sharding resolution helpers which diverged from the standard Flax Linen path, causing logical axis fallback mismatch and DuplicateSpecErrors when multiple logical dimensions mapped to a single physical axis. This change aligns the NNX path with Flax Linen and consolidates utilities:

1. Replaced the custom rules resolution logic with standard Flax Linen `logical_to_mesh_axes` to ensure identical behavior for rules mapping.
2. Added the `remove_size_one_mesh_axis` reduction step inside the NNX variable resolver to strip size-1 axes from the PartitionSpec, preventing JAX from raising DuplicateSpecError on models with overlapping axis mappings.
3. Aligned the variable wrappers and extraction lifecycle:
   - `sharding.nnx_construct_named_sharding` and `sharding.get_nnx_var_named_sharding_with_scan_axis` retain standard Flax NNX `Variable` / `Param` wrappers to maintain structural type compatibility during multi-tree maps in trainer setup.
   - `maxtext_utils_nnx.nnx_extract_named_sharding` extracts clean JAX-native `NamedSharding` trees for compilation and device dispatch.
4. Cleaned up comments and unit tests (in `sharding_nnx_test.py` and `maxtext_utils_nnx_test.py`) to verify behavior on local meshes and support CPU-only testing environments by avoiding host offloading during JIT.
Description
In pure NNX training runs, model variables retrieve physical PartitionSpecs via `get_nnx_named_sharding_with_scan_axis` in `maxtext_utils.py`. Previously, this helper used Flax core SPMD's `from_sharding_rules` to map logical names to physical axes. However, `from_sharding_rules` resolves rules by converting the rules list into a dictionary (last-write-wins). This caused fallback rules sharing the same logical name (e.g. 'embed') to overwrite preceding specific rules, dropping essential axes like `fsdp_transpose` and leading to unsharded parameter percentage assertion errors.

Additionally, resolving specifications independently for each dimension without tracking assigned axes could bind a single physical axis (like `fsdp_transpose`) to multiple positional dimensions of a tensor, causing `DuplicateSpecError`.

To fix this:
- Replaced `from_sharding_rules` with a rules-first resolution loop that matches rules sequentially (first-match-wins), matching Flax Linen's mapping behavior.
- Added an `assigned_axes` tracker within the loop to ensure physical mesh axes are bound to at most one dimension per tensor.

Tests
Log with Gemma3-12B (2x v6e-256)