Fix logical sharding resolution in NNX #4205

Open
xibinliu wants to merge 1 commit into main from xibin/nnx_sharding

@xibinliu xibinliu commented Jun 19, 2026

Collaborator

Description

In pure NNX training runs, model variables retrieve physical PartitionSpecs via get_nnx_named_sharding_with_scan_axis in maxtext_utils.py. Previously, this helper used Flax core SPMD's from_sharding_rules to map logical names to physical axes. However, from_sharding_rules resolves rules by converting the rules list into a dictionary (last-write-wins). This caused fallback rules sharing the same logical name (e.g. 'embed') to overwrite preceding specific rules, dropping essential axes like fsdp_transpose and leading to unsharded parameter percentage assertion errors.
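The last-write-wins effect described above can be reproduced with plain Python, independent of Flax. This is a minimal sketch: the two rules below are made up for illustration (they are not copied from base.yaml), and `dict(rules)` only stands in for the dictionary conversion mentioned in the description.

```python
# Illustrative rules only: a specific rule first, a fallback for the same
# logical name second.
rules = (
    ("embed", ("fsdp", "fsdp_transpose")),  # specific rule, listed first
    ("embed", "fsdp"),                      # fallback rule, listed later
)

# Converting the rule list to a dict keeps only the LAST entry per key,
# so the fallback silently replaces the specific rule.
as_dict = dict(rules)

# Scanning the list in order and stopping at the first match keeps the
# specific rule instead.
first_match = next(axes for name, axes in rules if name == "embed")

print(as_dict["embed"])  # the fallback rule
print(first_match)       # the specific rule
```

With the dictionary, `fsdp_transpose` never reaches the resolved spec, which is the kind of dropped axis the description refers to.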

Additionally, resolving specifications independently for each dimension without tracking assigned axes could bind a single physical axis (like fsdp_transpose) to multiple positional dimensions of a tensor, causing DuplicateSpecError.

To fix this:

  1. Replaced from_sharding_rules with a Rules-first resolution loop that matches rules sequentially (first-match-wins), matching Flax Linen's mapping behavior.
  2. Implemented an assigned_axes tracker within the loop to ensure physical mesh axes are bound to at most one dimension per tensor.
  3. Added unit tests covering sequential matching (first-match-wins) and duplicate physical axis prevention during resolution.
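The resolution loop from steps 1 and 2 can be sketched roughly as follows. This is a pure-Python illustration, not the code in this PR: `resolve_first_match` is a hypothetical name, specs are plain lists rather than PartitionSpec objects, and it assumes a rule is skipped when any of its physical axes is already bound.

```python
def resolve_first_match(logical_axes, rules):
    """Map each logical axis name to the first rule whose mesh axes are free."""
    assigned_axes = set()  # physical axes already bound to some dimension
    spec = []
    for name in logical_axes:
        chosen = None  # stays None (unsharded) if no usable rule is found
        for rule_name, mesh_axes in rules:
            if rule_name != name:
                continue
            # Normalise a single axis name to a tuple for the overlap check.
            axes = (mesh_axes,) if isinstance(mesh_axes, str) else tuple(mesh_axes)
            if assigned_axes.isdisjoint(axes):
                chosen = mesh_axes
                assigned_axes.update(axes)
                break  # first usable match wins
        spec.append(chosen)
    return spec


# Sequential matching: the first 'embed' rule is used, not the last.
print(resolve_first_match(("embed",), (("embed", "fsdp"), ("embed", "stage"))))

# Duplicate prevention: 'fsdp' is bound by 'embed', so 'mlp' stays unsharded.
print(resolve_first_match(("embed", "mlp"), (("embed", ("fsdp", "stage")), ("mlp", "fsdp"))))
```

The two calls mirror the two unit-test scenarios named in step 3.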

Tests

Log with Gemma3-12B (2x v6e-256)

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 19, 2026

Codecov Report

❌ Patch coverage is 90.56604% with 5 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/model_creation_utils.py | 73.33% | 2 Missing and 2 partials ⚠️ |
| src/maxtext/utils/sharding.py | 96.87% | 0 Missing and 1 partial ⚠️ |

Comment thread src/maxtext/utils/maxtext_utils.py Outdated
)


def _resolve_logical_sharding(out_sharding, context_rules, local_rules) -> list:

Collaborator

Could you move this function to utils/sharding.py? The goal is to collect all sharding-related utility functions in that file.

Collaborator Author

Moved get_nnx_named_sharding_with_scan_axis() to sharding.py and moved the corresponding unit tests along with it.

Comment thread tests/unit/maxtext_utils_test.py Outdated
# We define rules for 'embed' mapping to 'fsdp' (specific) then 'layers' (fallback)
rules = (
("embed", "fsdp"),
("embed", "layers"),

Collaborator

"layers" is not a physical axis name, is it? Maybe use something else, e.g. "expert"?

Collaborator Author

Changed to "stage"; "layers" is now used only as a logical name.

Comment thread tests/unit/maxtext_utils_test.py Outdated
("embed", "layers"),
)
with nn_partitioning.axis_rules(rules):
with jax.set_mesh(self.mesh):

Collaborator

with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules):

Collaborator Author

Modified.

Comment thread tests/unit/maxtext_utils_test.py Outdated
# When matching 'mlp', 'fsdp' is already bound, so it is skipped (unassigned/None).
rules = (
("embed", ("fsdp", "layers")),
("mlp", "fsdp"),

Collaborator

IMO this is an error caused by the rule. We should raise an error if there is a conflict instead of silently resolving it.

Collaborator Author

In base.yaml, the rules for both mlp and embed include the "fsdp_transpose" physical axis:

                      ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
                      ...
                      ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],

And this is used in MLP layer:

            kernel_axes=("embed", "mlp"),

I guess this is not a config issue.

Linen resolved this by calling remove_size_one_mesh_axis on the physical specs. We will now apply the same to NNX.

Collaborator

I would expect 'embed' to be downgraded to the next item,

['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],

so the conflict no longer exists. Since your function automatically moves on to the next rule in the list, I think the error you mentioned should be resolved?

Collaborator Author

You are right. I checked the Linen implementation, and it does fall through to the next item.
I modified the NNX code to reuse Linen's nn.logical_to_mesh_axes() instead of implementing the same function for NNX.
Now they should have the same behaviour.
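The fall-through behaviour discussed in this thread can be sketched in pure Python. This is an illustration only, not Flax's source: `resolve_in_rule_order` is a hypothetical name, the three rules are the ones quoted above, and their relative order (mlp before embed) is assumed from the order in which they were quoted.

```python
def resolve_in_rule_order(logical_axes, rules):
    """Walk the rules once, in order; apply a rule only if its dimension is
    still unassigned and none of its mesh axes is already in use."""
    result = [None] * len(logical_axes)
    used = set()
    for rule_name, mesh_axes in rules:
        if rule_name not in logical_axes:
            continue
        pos = logical_axes.index(rule_name)
        axes = (mesh_axes,) if isinstance(mesh_axes, str) else tuple(mesh_axes)
        if result[pos] is None and used.isdisjoint(axes):
            result[pos] = axes
            used.update(axes)
    return result


rules = (
    ("mlp", ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive")),
    ("embed", ("fsdp", "fsdp_transpose", "tensor_transpose", "context", "expert")),
    ("embed", ("fsdp", "tensor_transpose", "context", "expert")),
)

# kernel_axes=("embed", "mlp") as used in the MLP layer.
kernel_spec = resolve_in_rule_order(("embed", "mlp"), rules)
print(kernel_spec)
```

Under these assumptions the first 'embed' rule is skipped because 'fsdp_transpose' is already taken by 'mlp', and 'embed' lands on the next rule, so no physical axis appears twice in the spec.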

@NuojCheng NuojCheng left a comment

Collaborator

Overall I like this feature. I was concerned that this function would silently hide some errors that should be raised explicitly, e.g. conflicting logical rules. I agree we should improve from_sharding_rules if it simply treats the logical rules as a dictionary.

@xibinliu xibinliu force-pushed the xibin/nnx_sharding branch from 0730724 to 80b40f7 Compare June 22, 2026 22:59
@xibinliu xibinliu force-pushed the xibin/nnx_sharding branch 9 times, most recently from 94b165c to c4c6d80 Compare June 23, 2026 01:30
@xibinliu
Collaborator Author

> Overall I like this feature. I was concerned that this function would silently hide some errors that should be raised explicitly, e.g. conflicting logical rules. I agree we should improve from_sharding_rules if it simply treats the logical rules as a dictionary.

With a pure NNX implementation of the MaxText models, are we still allowed to import Linen functions? If the answer is yes, then we can continue to reuse the Linen implementation. I think we can decide this during the Linen removal phase. (PR#12)

Pure NNX training runs previously used custom logical sharding resolution helpers which diverged from the standard Flax Linen path, causing logical axis fallback mismatch and DuplicateSpecErrors when multiple logical dimensions mapped to a single physical axis.

This change aligns the NNX path with Flax Linen and consolidates utilities:
1. Replaced the custom rules resolution logic with standard Flax Linen `logical_to_mesh_axes` to ensure identical behavior for rules mapping.
2. Added the `remove_size_one_mesh_axis` reduction step inside the NNX variable resolver to strip size-1 axes from the PartitionSpec, preventing JAX from raising DuplicateSpecError on models with overlapping axis mappings.
3. Aligned the variable wrappers and extraction lifecycle:
   - `sharding.nnx_construct_named_sharding` and `sharding.get_nnx_var_named_sharding_with_scan_axis` retain standard Flax NNX `Variable` / `Param` wrappers to maintain structural type compatibility during multi-tree maps in trainer setup.
   - `maxtext_utils_nnx.nnx_extract_named_sharding` extracts clean JAX-native `NamedSharding` trees for compilation and device dispatch.
4. Cleaned up comments and unit tests (in `sharding_nnx_test.py` and `maxtext_utils_nnx_test.py`) to verify behavior on local meshes and support CPU-only testing environments by avoiding host offloading during JIT.
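
The size-1 reduction in item 2 can be illustrated with a small pure-Python sketch. `drop_size_one_axes` is a hypothetical helper written for this example; it is not the implementation of `remove_size_one_mesh_axis`, it works on plain lists rather than PartitionSpec objects, and the mesh shape below is made up.

```python
def drop_size_one_axes(spec, mesh_shape):
    """Remove mesh axes of size 1 from every entry of a partition spec."""
    cleaned = []
    for entry in spec:
        if entry is None:
            cleaned.append(None)
            continue
        axes = (entry,) if isinstance(entry, str) else tuple(entry)
        kept = tuple(a for a in axes if mesh_shape[a] > 1)
        if not kept:
            cleaned.append(None)     # nothing left: dimension is unsharded
        elif len(kept) == 1:
            cleaned.append(kept[0])  # collapse a 1-tuple to a bare name
        else:
            cleaned.append(kept)
    return cleaned


# Hypothetical mesh in which 'fsdp_transpose' has size 1.
mesh_shape = {"fsdp": 4, "fsdp_transpose": 1, "tensor": 2}

# 'fsdp_transpose' appears in both dimensions, which would be a duplicate.
spec = [("fsdp", "fsdp_transpose"), ("fsdp_transpose", "tensor")]
cleaned_spec = drop_size_one_axes(spec, mesh_shape)
print(cleaned_spec)
```

Sharding over an axis of size 1 does not split anything, so dropping it leaves the layout unchanged while removing the repeated axis name from the spec.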
@xibinliu xibinliu force-pushed the xibin/nnx_sharding branch from c4c6d80 to 0750360 Compare June 23, 2026 02:41