From f952063d07b2a13eb5be1aad422c3e72a5377275 Mon Sep 17 00:00:00 2001 From: Xibin Liu Date: Mon, 22 Jun 2026 21:39:19 +0000 Subject: [PATCH] Consolidate and align NNX sharding helpers with Flax Linen Pure NNX training runs previously used custom logical sharding resolution helpers that diverged from the standard Flax Linen path, causing logical-axis fallback mismatches and DuplicateSpecError failures when multiple logical dimensions mapped to a single physical axis. This change aligns the NNX path with Flax Linen and consolidates utilities: 1. Replaced the custom rules resolution logic with the standard Flax Linen `logical_to_mesh_axes` to ensure identical rule-mapping behavior. 2. Added the `remove_size_one_mesh_axis` reduction step inside the NNX variable resolver to strip size-1 axes from the PartitionSpec, preventing JAX from raising DuplicateSpecError on models with overlapping axis mappings. 3. Aligned the variable wrappers and extraction lifecycle: - `sharding.nnx_construct_named_sharding` and `sharding.get_nnx_var_named_sharding_with_scan_axis` retain standard Flax NNX `Variable` / `Param` wrappers to maintain structural type compatibility during multi-tree maps in trainer setup. - `maxtext_utils_nnx.nnx_extract_named_sharding` extracts clean JAX-native `NamedSharding` trees for compilation and device dispatch. 4. Cleaned up comments and updated the unit tests (in `sharding_nnx_test.py` and `maxtext_utils_nnx_test.py`) so they verify behavior on local meshes and run in CPU-only test environments by avoiding host offloading during JIT.
--- src/maxtext/inference/maxengine/maxengine.py | 5 +- src/maxtext/utils/maxtext_utils.py | 100 +----- src/maxtext/utils/maxtext_utils_nnx.py | 11 +- src/maxtext/utils/model_creation_utils.py | 30 +- src/maxtext/utils/sharding.py | 67 +++- tests/unit/maxtext_utils_nnx_test.py | 10 +- tests/unit/maxtext_utils_test.py | 87 +---- tests/unit/sharding_nnx_test.py | 332 +++++++++++++++---- 8 files changed, 384 insertions(+), 258 deletions(-) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index d9b686b182..3c4d6b2f71 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -47,6 +47,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import sharding from maxtext.utils import maxtext_utils_nnx from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled @@ -398,7 +399,7 @@ def _load_params_nnx(self, params, rng): # axis metadata but no physical .sharding. Resolve logical to physical here so # device_put actually reshards instead of being a no-op. with nn_partitioning.axis_rules(self.config.logical_axis_rules): - target_shardings = maxtext_utils.get_nnx_named_sharding_with_scan_axis(params_abs, self._mesh) + target_shardings = sharding.nnx_construct_named_sharding(params_abs, self._mesh) params_state = jax.device_put(params, target_shardings) # We only need a concrete `rest` (RNG vars) for nnx.merge. create_nnx_sharded_model # builds the model with a jitted out_shardings so params are produced already @@ -406,7 +407,7 @@ def _load_params_nnx(self, params, rng): # large models). self.model is abstract with no .sharding, so pass an explicit one. 
_, full_abs = nnx.split(self.model) with nn_partitioning.axis_rules(self.config.logical_axis_rules): - full_sharding = maxtext_utils.get_nnx_named_sharding_with_scan_axis(full_abs, self._mesh) + full_sharding = sharding.nnx_construct_named_sharding(full_abs, self._mesh) concrete_model = maxtext_utils_nnx.create_nnx_sharded_model( self.model, self._create_model_fn, mesh=self._mesh, named_sharding=full_sharding ) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 238758da92..67c23fc98f 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,7 +20,6 @@ from typing import Sequence from flax import nnx, linen as nn -from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning from flax.training.train_state import TrainState @@ -28,7 +27,7 @@ import jax import jax.numpy as jnp -from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.sharding import AxisType, Mesh, NamedSharding from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load @@ -1612,88 +1611,6 @@ def move(path, x): ) -def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: - """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. - - Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also - inserts the partition_name axis at the correct scan_axis position for parameters created by - _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a - 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. - - Args: - abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). - mesh: JAX physical mesh. 
- - Returns: - Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. - """ - - def _make_named_sharding(v): - val = v.get_value() - if not hasattr(val, "shape"): - # `val` is either truly leafless (e.g. optax MaskedNode) or a composite - # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables — - # a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter - # we must emit a parallel tree of NamedSharding leaves so the downstream - # `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)` - # finds a real Sharding at every position. Replicated sharding is a safe - # default — AQT serve-mode QTensors are normally small (per-channel scale - # factors and packed int8 weights) and don't need axis-aware sharding. - if jax.tree_util.tree_leaves(val): - replicated = NamedSharding(mesh, PartitionSpec()) - return v.replace(jax.tree.map(lambda _: replicated, val)) - return v - metadata = v.get_metadata() - out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") - if not out_sharding: - pspec = PartitionSpec() - else: - # Insert the scan axis for parameters created by _create_scanned_layers. - # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the - # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. - if nnx.PARTITION_NAME in metadata: - partition_name = metadata[nnx.PARTITION_NAME] - # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits - # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode - # scan_axis=0 for non-Param types. stacked_rest non-Param variables have - # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. 
- scan_axis = metadata.get("param_scan_axis", 0) - out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) - # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames - # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted - # the scan axis. Only insert if not already present. - if partition_name not in out_sharding: - out_sharding.insert(scan_axis, partition_name) - out_sharding = tuple(out_sharding) - # Convert logical axis names to physical mesh axes using current context rules. - context_rules = get_logical_axis_rules() - local_rules = metadata.get("sharding_rules", ()) - if context_rules or local_rules: - rules = composite_rules(context_rules, local_rules) - raw_sharding = from_sharding_rules(out_sharding, rules) - mesh_axis_names = mesh.axis_names if mesh is not None else () - - # from_sharding_rules leaves a logical name with no matching rule unchanged, so a - # name missing from logical_axis_rules (e.g. concat_embed on the MTP kernel) - # reaches NamedSharding and is rejected as an unknown mesh axis. Map any such - # leftover name to None (replicated), matching Linen, whose logical_to_mesh_axes - # replicates unmatched names. - def _sanitize(x): - if isinstance(x, list): - x = tuple(x) - if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple): - return x - return None - - sanitized_sharding = [_sanitize(x) for x in raw_sharding] - pspec = PartitionSpec(*sanitized_sharding) - else: - pspec = PartitionSpec(*out_sharding) - return v.replace(NamedSharding(mesh, pspec)) - - return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) - - def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): """Calculates the abstract sharded state and memory placement for an NNX TrainState. 
@@ -1724,26 +1641,25 @@ def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=Tru with nn_partitioning.axis_rules(config.logical_axis_rules): # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply - # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # nnx_construct_named_sharding which correctly inserts the stacked-layers # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls # var.shape for every variable when a global mesh is active, but masked optimizer # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() - # which has no .shape and would raise AttributeError. We handle sharding - # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not - # needed here. + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via nnx_construct_named_sharding, so auto-assignment is not needed here. 
abs_model = nnx.eval_shape(nnx_init_trainstate_fn) _, abs_var_state = nnx.split(abs_model) - named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh) abstract_state = jax.tree.map( lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_var_state, named_sharding_state, ) - state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) if is_training and config.shard_optimizer_over_data: # Add data to sharding for optimizer state @@ -1849,10 +1765,10 @@ def _nnx_cache_partition_specs(abstract_model, config, mesh): way it does for the Linen helpers below. """ _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) - # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # nnx_construct_named_sharding reads logical axis rules from the # active flax partitioning context, so wrap. with nn_partitioning.axis_rules(config.logical_axis_rules): - named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + named_state = sharding.nnx_construct_named_sharding(cache_state, mesh) return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py index 5b645b85ca..df5d7a5f06 100644 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Utils for MaxText NNX. 
""" +"""Utils for MaxText NNX.""" from functools import partial from typing import Callable @@ -51,14 +51,17 @@ def create_nnx_rngs( return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference -def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: +def nnx_extract_named_sharding(abstract_state: nnx.State) -> nnx.State: """Get named sharding from NNX abstract state. Args: abstract_state: NNX model abstract state created from nnx.get_abstract_model. Returns: - named sharding structure + A tree of raw NamedSharding objects (stripping out any nnx.Variable / Param + wrappers). This clean structure is expected by JAX compiler APIs (like JIT + out_shardings). Contrast with sharding.nnx_construct_named_sharding, which + retains wrappers for abstract tree zipping compatibility. """ # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we # get the existing sharding from the abstract_state. @@ -156,7 +159,7 @@ def create_nnx_sharded_model( if named_sharding is None: # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) # we get the sharding directly from it. 
- named_sharding = get_named_sharding_nnx(abstract_state) + named_sharding = nnx_extract_named_sharding(abstract_state) if mesh is None: mesh = abstract_model.mesh diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 6a2ba297d6..8fb41ff21e 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -51,7 +51,7 @@ from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging -from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx, sharding import numpy as np from orbax import checkpoint as ocp @@ -112,13 +112,13 @@ def _zero_pad_axis(arr, axis, extra): if extra == 0: return arr - sharding = getattr(arr, "sharding", None) + arr_sharding = getattr(arr, "sharding", None) pad_width = [(0, 0)] * arr.ndim - if isinstance(sharding, jax.sharding.NamedSharding): - spec = sharding.spec + if isinstance(arr_sharding, jax.sharding.NamedSharding): + spec = arr_sharding.spec partition = spec[axis] if axis < len(spec) else None - shards_along_axis = _partition_size(partition, sharding.mesh) + shards_along_axis = _partition_size(partition, arr_sharding.mesh) if shards_along_axis > 1: if extra % shards_along_axis != 0: raise ValueError( @@ -131,7 +131,7 @@ def _zero_pad_axis(arr, axis, extra): def _pad_local(x): return jnp.pad(x, pad_width) - return jax.shard_map(_pad_local, mesh=sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr) + return jax.shard_map(_pad_local, mesh=arr_sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr) pad_width[axis] = (0, extra) return jnp.pad(arr, pad_width) @@ -257,11 +257,11 @@ def _maybe_fuse(path, ckpt_node): # Determine the number of shards (TP degree) along the concatenated axis n_shards = 1 - sharding = getattr(wi_model, "sharding", None) - if isinstance(sharding, 
jax.sharding.NamedSharding): - spec = sharding.spec + wi_sharding = getattr(wi_model, "sharding", None) + if isinstance(wi_sharding, jax.sharding.NamedSharding): + spec = wi_sharding.spec partition = spec[axis] if axis < len(spec) else None - n_shards = _partition_size(partition, sharding.mesh) + n_shards = _partition_size(partition, wi_sharding.mesh) # Target size for a single half (wi_0 or wi_1) AFTER padding target_half_dim = wi_model.shape[-1] // 2 @@ -325,13 +325,13 @@ def _stored_shape_evenly_shardable(restore_arg, stored_shape): (each device receives only its local slice), avoiding the multi-GB replicated fanout that fully-replicated loading produces for large MoE weights. """ - sharding = restore_arg.sharding - if not isinstance(sharding, jax.sharding.NamedSharding): + restore_sharding = restore_arg.sharding + if not isinstance(restore_sharding, jax.sharding.NamedSharding): return False - spec = sharding.spec + spec = restore_sharding.spec for axis_idx, dim in enumerate(stored_shape): partition = spec[axis_idx] if axis_idx < len(spec) else None - if dim % _partition_size(partition, sharding.mesh) != 0: + if dim % _partition_size(partition, restore_sharding.mesh) != 0: return False return True @@ -581,7 +581,7 @@ def create_nnx_abstract_model( # wrap is unnecessary here. abs_model = nnx.eval_shape(_create_model) graphdef, abs_var_state = nnx.split(abs_model) - named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh) abstract_state = jax.tree.map( lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_var_state, diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 50115cae72..789f6420e9 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -13,9 +13,10 @@ # limitations under the License. 
# pylint: disable=line-too-long, disable=bare-except, consider-using-generator -""" Utils that are only interesting to MaxText and sharding related. """ +"""Utils that are only interesting to MaxText and sharding related.""" from flax import linen as nn, nnx +from flax.core.spmd import get_logical_axis_rules from collections.abc import Iterable @@ -179,6 +180,70 @@ def remove_size_one_mesh_axis(spec, mesh): return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced) +def get_nnx_var_named_sharding_with_scan_axis(v: nnx.Variable, mesh) -> nnx.Variable: + """Compute NamedSharding for an NNX variable, correctly handling the scan axis.""" + val = v.get_value() + if not hasattr(val, "shape"): + # `val` is either truly leafless (e.g. optax MaskedNode) or a composite + # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables). + # Replicated sharding is a safe default. + if jax.tree_util.tree_leaves(val): + replicated = NamedSharding(mesh, P()) + return v.replace(jax.tree.map(lambda _: replicated, val)) + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = P() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. 
+ context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + local_rules_list = list(local_rules) if local_rules is not None else [] + context_rules_list = list(context_rules) if context_rules is not None else [] + rules = local_rules_list + context_rules_list + pspec = logical_to_mesh_axes(out_sharding, mesh, rules=rules) + else: + pspec = P(*out_sharding) + if mesh is not None: + pspec = remove_size_one_mesh_axis(pspec, mesh) + return v.replace(NamedSharding(mesh, pspec)) + + +def nnx_construct_named_sharding(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state with leaf values replaced with NamedSharding. + Note that it preserves the original nnx.Variable / Param wrapper nodes to maintain + type structure matching abs_var_state (necessary for multi-tree maps). Use + maxtext_utils_nnx.nnx_extract_named_sharding to retrieve clean raw NamedShardings. 
+ """ + return jax.tree.map( + lambda x: get_nnx_var_named_sharding_with_scan_axis(x, mesh), + abs_var_state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + def logical_to_mesh_axes(logical_names, mesh, rules=None): """Remove size one mesh axes given logical names.""" tensor_spec = nn.logical_to_mesh_axes(logical_names, rules=rules) diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py index 10e2b8621f..87ba463963 100644 --- a/tests/unit/maxtext_utils_nnx_test.py +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Tests for the common MaxText NNX utilities """ +"""Tests for the common MaxText NNX utilities""" import unittest from dataclasses import dataclass from typing import Any @@ -104,7 +104,7 @@ def test_get_set_named_sharding_nnx(self): _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) # 2. Test extraction - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) # Verify kernel and bias match the P("data") annotations from TinyModel self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) @@ -136,7 +136,7 @@ def update_spec_fn(path, leaf_sharding): # 4. Verify named sharding is preserved after NNX merge (update) and split (state) model = self.tiny_model_init_fn() nnx.update(model, updated_abstract) - re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) + re_extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(nnx.state(model)) # Verify kernel and bias have expected sharding self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) @@ -148,7 +148,7 @@ def test_create_nnx_sharded_model(self): abstract_model = nnx.merge(graphdef, abstract_state) # 2. 
Modify shardings to trigger host offloading - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) # 3. Run the sharded creation @@ -165,7 +165,7 @@ def test_get_partition_spec_nnx(self): """Verifies extraction of PartitionSpecs from NamedShardings.""" # 1. Create abstract state and get sharding _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) # 2. Execute extraction spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 2e90880a83..6857cabca6 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -1523,7 +1523,7 @@ class MockConfig: optimizer_memory_host_offload: bool = False parameter_memory_host_offload: bool = False param_scan_axis: int = 0 - logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]], ["model", ["model"]]]) class MockTrainState(nnx.Module): """Simulates a TrainState with params and optimizer state.""" @@ -1542,6 +1542,13 @@ def setUp(self): devices = jax.local_devices() self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) self.config = self.MockConfig() + # Stub remove_size_one_mesh_axis so that resolved specs are returned unreduced, + # allowing verification of naming resolution without being stripped by size-one mesh dims. 
+ self._old_remove_size_one_mesh_axis = sharding.remove_size_one_mesh_axis + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def tearDown(self): + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis def nnx_init_trainstate_wrapper(self): """Wrapper to initialize the mock NNX model.""" @@ -1613,83 +1620,5 @@ def test_invalid_init_fn(self): maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) -class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): - """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. - - The helper resolves a NamedSharding for each NNX Variable and — unlike - flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at - `param_scan_axis` when scanned-layers metadata is present. - """ - - def setUp(self): - # Mesh needs to contain every axis name the tests reference in partition specs. - self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) - - def _build_state(self, **variables): - """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" - return nnx.State(variables) - - def _run(self, state): - return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) - - def test_scan_axis_inserted_at_param_scan_axis(self): - """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((3, 4, 8)), - out_sharding=(None, "fsdp"), - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - self.assertIsInstance(result_sharding, NamedSharding) - # 'layers' must be inserted at position 1 (param_scan_axis=1). 
- self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) - - def test_scan_axis_not_inserted_when_already_present(self): - """Guard against double-insertion when partition_name is already in out_sharding.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((2, 2, 2)), - out_sharding=("layers", None, "fsdp"), - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - # 'layers' must appear exactly once — the same PartitionSpec we started with. - self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) - - def test_masked_node_preserved_as_is(self): - """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" - masked = nnx.Variable(optax.MaskedNode()) - state = self._build_state(masked=masked) - out = self._run(state) - # The leaf must be the original Variable, not a NamedSharding wrapper. - self.assertIs(out["masked"], masked) - - def test_empty_out_sharding_yields_empty_pspec(self): - """A Variable without any sharding metadata should resolve to PartitionSpec().""" - with jax.set_mesh(self.mesh): - # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() - v = nnx.Param(jnp.zeros((4,))) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - self.assertIsInstance(result_sharding, NamedSharding) - self.assertEqual(result_sharding.spec, PartitionSpec()) - - def test_string_out_sharding_is_wrapped_into_tuple(self): - """A single-string out_sharding value should still produce a valid PartitionSpec.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((4,)), - out_sharding="fsdp", - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - # The single string 'fsdp' is turned into a list, and 'layers' is prepended. 
- self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) - - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_nnx_test.py b/tests/unit/sharding_nnx_test.py index d182f6d18d..62bdcb4348 100644 --- a/tests/unit/sharding_nnx_test.py +++ b/tests/unit/sharding_nnx_test.py @@ -18,10 +18,12 @@ import unittest from flax import nnx +from flax.linen import partitioning as nn_partitioning import jax from jax.sharding import Mesh, NamedSharding, PartitionSpec from maxtext.common import train_state_nnx from maxtext.utils import sharding +import jax.numpy as jnp import numpy as np import optax @@ -39,6 +41,18 @@ def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 4, rngs=rngs) +def _create_2d_test_mesh(axis_names=("data", "model")): + devices = jax.local_devices() + num_devices = len(devices) + if num_devices >= 4: + mesh_devices = np.array(devices[:4]).reshape(2, 2) + elif num_devices >= 2: + mesh_devices = np.array(devices[:2]).reshape(2, 1) + else: + mesh_devices = np.array(devices[:1]).reshape(1, 1) + return Mesh(devices=mesh_devices, axis_names=axis_names) + + def _build_state_mesh_shardings(model, tx): """Build an nnx.State of NamedShardings mirroring the TrainStateNNX layout. 
@@ -48,9 +62,7 @@ def _build_state_mesh_shardings(model, tx): optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) state_obj = train_state_nnx.TrainStateNNX(model, optimizer) state = nnx.state(state_obj) - mesh = Mesh( - np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model") - ) + mesh = _create_2d_test_mesh() def _to_sharding(var): val = var.get_value() @@ -62,9 +74,7 @@ def _to_sharding(var): pspec = PartitionSpec("data", "model") return var.replace(NamedSharding(mesh, pspec)) - return jax.tree.map( - _to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + return jax.tree.map(_to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) class TestMaybeUpdateParamsShardingWithOptNNX(unittest.TestCase): @@ -76,12 +86,8 @@ def setUp(self): def test_dispatch_from_main_helper_when_pure_nnx(self): """maybe_update_params_sharding_with_opt should dispatch to the NNX variant.""" cfg = _Cfg(pure_nnx=True, shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - prev, updated = sharding.maybe_update_params_sharding_with_opt( - cfg, state_mesh_shardings - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, updated = sharding.maybe_update_params_sharding_with_opt(cfg, state_mesh_shardings) # prev is the param-only view (no rngs / non-Param nodes) self.assertIsInstance(prev, nnx.State) self.assertIn("linear", prev) @@ -91,15 +97,9 @@ def test_dispatch_from_main_helper_when_pure_nnx(self): def test_extract_param_only_skips_non_param_variables(self): """prev_params_shardings must contain Params only — RngKey/RngCount/OptVariable filtered out.""" cfg = _Cfg(shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) - leaves = jax.tree.leaves( - prev, is_leaf=lambda x: isinstance(x, 
nnx.Variable) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + leaves = jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) # Every surviving leaf is wrapped as an nnx.Param. self.assertTrue(all(isinstance(leaf, nnx.Param) for leaf in leaves)) # The model has linear.kernel and linear.bias — exactly two Param leaves. @@ -108,25 +108,17 @@ def test_extract_param_only_skips_non_param_variables(self): def test_returns_unchanged_when_shard_optimizer_over_data_false(self): """When shard_optimizer_over_data=False, the second return value must be the input object.""" cfg = _Cfg(shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - _, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) self.assertIs(updated, state_mesh_shardings) def test_zero1_propagates_mu_sharding_to_model_params(self): """Zero-1: model param shardings must be replaced with the optimizer mu shardings.""" cfg = _Cfg(shard_optimizer_over_data=True) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) # Mutate the optimizer mu leaves in place so the function picks up a distinct PartitionSpec. 
- mesh = Mesh( - np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model") - ) + mesh = _create_2d_test_mesh() target_pspec = PartitionSpec(("data", "model")) new_mu_sharding = NamedSharding(mesh, target_pspec) @@ -135,20 +127,14 @@ def test_zero1_propagates_mu_sharding_to_model_params(self): # After _build_state_mesh_shardings, every leaf's value is a NamedSharding (no .shape), # so we just override every Variable leaf in mu in place via set_value (modern API). mu_state = state_mesh_shardings.optimizer.opt_state[0]["mu"] - for var in jax.tree.leaves( - mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable) - ): + for var in jax.tree.leaves(mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable)): if isinstance(var, nnx.Variable): var.set_value(new_mu_sharding) - _, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) # All Param leaves under updated.model must now share the new mu sharding. 
- param_leaves = jax.tree.leaves( - updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + param_leaves = jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) param_leaves = [v for v in param_leaves if isinstance(v, nnx.Param)] self.assertGreater(len(param_leaves), 0) for leaf in param_leaves: @@ -157,13 +143,9 @@ def test_zero1_propagates_mu_sharding_to_model_params(self): def test_raises_when_no_adam_state_present(self): """Stateless optimizers (e.g., SGD) have no mu — function must raise NotImplementedError.""" cfg = _Cfg(shard_optimizer_over_data=True) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.sgd(1e-3) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.sgd(1e-3)) with self.assertRaises(NotImplementedError): - sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) def test_chained_optimizer_recursion_finds_adam_mu(self): """A nested optax.chain(clip, adam) wraps mu under multiple containers — recursion must find it.""" @@ -172,24 +154,254 @@ def test_chained_optimizer_recursion_finds_adam_mu(self): state_mesh_shardings = _build_state_mesh_shardings(self.model, chained) # Should not raise; verify update happens (params replaced with mu shardings). - prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) self.assertIsInstance(prev, nnx.State) self.assertIsInstance(updated, nnx.State) # Same number of Param leaves before and after. 
- n_prev = len( - jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) + n_prev = len(jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable))) + n_after = len( + [ + v + for v in jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + if isinstance(v, nnx.Param) + ] ) - n_after = len([ - v - for v in jax.tree.leaves( - updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) - if isinstance(v, nnx.Param) - ]) self.assertEqual(n_prev, n_after) +class TestNnxConstructNamedSharding(unittest.TestCase): + """Unit tests for nnx_construct_named_sharding covering every branch. + + The helper resolves a NamedSharding for each NNX Variable inside an nnx.State and + — unlike flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. + self.mesh = _create_2d_test_mesh(axis_names=("fsdp", "stage")) + # In local test environments (e.g. single-device CPU), all mesh axes have size 1. + # We stub remove_size_one_mesh_axis to act as a no-op so that resolved physical PartitionSpecs + # are returned unreduced (e.g. retaining "fsdp", "stage", etc.), allowing us to verify naming + # resolution. The actual size-one axis removal is tested separately in the test_removes_size_one_mesh_axes* tests below.
+ self._old_remove_size_one_mesh_axis = sharding.remove_size_one_mesh_axis + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def tearDown(self): + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return sharding.nnx_construct_named_sharding(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' resolves to physical axis 'stage' and is inserted at position 1 (param_scan_axis=1). + self.assertEqual(result_sharding.spec, PartitionSpec(None, "stage", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'stage' must appear exactly once — the same PartitionSpec we started with. 
+ self.assertEqual(result_sharding.spec, PartitionSpec("stage", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. + self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' (resolving to 'stage') is prepended. 
+ self.assertEqual(result_sharding.spec, PartitionSpec("stage", "fsdp")) + + def test_sequential_matching_first_match_wins(self): + """Multiple rules for the same logical axis are matched sequentially, first-match-wins.""" + # We define rules for 'embed' mapping to 'fsdp' (specific) then 'stage' (fallback) + rules = ( + ("embed", "fsdp"), + ("embed", "stage"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the first rule ('fsdp'), not the second ('stage'). + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + + def test_prevents_duplicate_physical_axes(self): + """If multiple dimensions map to the same physical axis, the subsequent ones are skipped (mapped to None).""" + # Setup rules where 'embed' maps to 'fsdp' and 'mlp' also maps to 'fsdp'. + rules = ( + ("embed", "fsdp"), + ("mlp", "fsdp"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed", "mlp"), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: Dim 0 matches 'embed' -> 'fsdp'. + # Dim 1 tries 'mlp' -> 'fsdp', but 'fsdp' is already assigned to Dim 0. + # So it skips the rule and falls back to matching nothing -> None. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp", None)) + + def test_fallback_to_next_physical_axis_when_duplicated(self): + """When a physical axis is already assigned, fallback priority rules should map to the next available physical option.""" + # Setup rules where 'embed' maps to 'fsdp', and 'mlp' maps to 'fsdp' (priority 1) or 'stage' (priority 2). 
+ rules = ( + ("embed", "fsdp"), + ("mlp", "fsdp"), + ("mlp", "stage"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed", "mlp"), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: Dim 0 matches 'embed' -> 'fsdp'. + # Dim 1 tries 'mlp' -> 'fsdp', but 'fsdp' is already assigned to Dim 0. + # So it falls to the next item in the list -> 'stage'. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp", "stage")) + + def test_resolves_when_context_rules_is_none(self): + """When context_rules is None but local_rules are defined, resolution should succeed.""" + # Ensure get_logical_axis_rules() returns None (which is the default outside axis_rules) + # We define local rules on the variable metadata. + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + sharding_rules=(("embed", "fsdp"),), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the local rules even when context_rules is None. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + + def test_composite_pytree_variable_resolved_to_replicated_shardings(self): + """A Variable holding a composite pytree (e.g. 
tuple of arrays) is resolved to replicated NamedShardings.""" + with jax.set_mesh(self.mesh): + v = nnx.Variable((jnp.zeros((2, 2)), jnp.zeros((3, 3)))) + out = self._run(self._build_state(w=v)) + result = out["w"].get_value() + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + self.assertIsInstance(result[0], NamedSharding) + self.assertIsInstance(result[1], NamedSharding) + self.assertEqual(result[0].spec, PartitionSpec()) + self.assertEqual(result[1].spec, PartitionSpec()) + + def test_rules_merged_when_both_context_and_local_rules_present(self): + """When both local rules and context rules are present, they are concatenated in order of local then context.""" + # Local rules map 'embed' to 'stage'. Context rules map 'embed' to 'fsdp'. + # Because local rules come first, 'embed' should resolve to 'stage'. + context_rules = (("embed", "fsdp"),) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(context_rules): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + sharding_rules=(("embed", "stage"),), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertEqual(result_sharding.spec, PartitionSpec("stage")) + + def test_removes_size_one_mesh_axes_no_rules(self): + """When no rules are defined but mesh is present, size-1 physical axes in the spec are removed.""" + # Temporarily restore the original remove_size_one_mesh_axis function + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + try: + # 'fsdp' has size 1 in self.mesh only on single-device hosts (_create_2d_test_mesh makes it 2 otherwise). + # out_sharding directly specifies the physical axis 'fsdp'. + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding=("fsdp",), + ) + out_v = sharding.get_nnx_var_named_sharding_with_scan_axis(v, self.mesh) + result_sharding = out_v.get_value() + # 'fsdp' has size 1, so it gets reduced to None.
+ self.assertEqual(result_sharding.spec, PartitionSpec(None)) + finally: + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def test_removes_size_one_mesh_axes(self): + """When remove_size_one_mesh_axis is active, physical axes with size 1 are removed (mapped to None).""" + # Temporarily restore the original remove_size_one_mesh_axis function + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + try: + # Setup axis rules mapping 'embed' to 'fsdp' and 'layers' to 'stage'. + # In this mesh, fsdp and stage have size 1. + rules = (("embed", "fsdp"), ("layers", "stage")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed",), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + # Resolve sharding + out_v = sharding.get_nnx_var_named_sharding_with_scan_axis(v, self.mesh) + result_sharding = out_v.get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: P("fsdp", "stage") gets reduced to P(None, None) since both fsdp and stage have size 1. + self.assertEqual(result_sharding.spec, PartitionSpec(None, None)) + finally: + # Re-apply the stub to keep other tests working + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + if __name__ == "__main__": unittest.main()