diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index d9b686b182..3c4d6b2f71 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -47,6 +47,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import sharding from maxtext.utils import maxtext_utils_nnx from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled @@ -398,7 +399,7 @@ def _load_params_nnx(self, params, rng): # axis metadata but no physical .sharding. Resolve logical to physical here so # device_put actually reshards instead of being a no-op. with nn_partitioning.axis_rules(self.config.logical_axis_rules): - target_shardings = maxtext_utils.get_nnx_named_sharding_with_scan_axis(params_abs, self._mesh) + target_shardings = sharding.nnx_construct_named_sharding(params_abs, self._mesh) params_state = jax.device_put(params, target_shardings) # We only need a concrete `rest` (RNG vars) for nnx.merge. create_nnx_sharded_model # builds the model with a jitted out_shardings so params are produced already @@ -406,7 +407,7 @@ def _load_params_nnx(self, params, rng): # large models). self.model is abstract with no .sharding, so pass an explicit one. _, full_abs = nnx.split(self.model) with nn_partitioning.axis_rules(self.config.logical_axis_rules): - full_sharding = maxtext_utils.get_nnx_named_sharding_with_scan_axis(full_abs, self._mesh) + full_sharding = sharding.nnx_construct_named_sharding(full_abs, self._mesh) concrete_model = maxtext_utils_nnx.create_nnx_sharded_model( self.model, self._create_model_fn, mesh=self._mesh, named_sharding=full_sharding ) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 238758da92..67c23fc98f 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,7 +20,6 @@ from typing import Sequence from flax import nnx, linen as nn -from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning from flax.training.train_state import TrainState @@ -28,7 +27,7 @@ import jax import jax.numpy as jnp -from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.sharding import AxisType, Mesh, NamedSharding from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load @@ -1612,88 +1611,6 @@ def move(path, x): ) -def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: - """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. - - Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also - inserts the partition_name axis at the correct scan_axis position for parameters created by - _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a - 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. - - Args: - abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). - mesh: JAX physical mesh. - - Returns: - Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. - """ - - def _make_named_sharding(v): - val = v.get_value() - if not hasattr(val, "shape"): - # `val` is either truly leafless (e.g. optax MaskedNode) or a composite - # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables — - # a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter - # we must emit a parallel tree of NamedSharding leaves so the downstream - # `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)` - # finds a real Sharding at every position. Replicated sharding is a safe - # default — AQT serve-mode QTensors are normally small (per-channel scale - # factors and packed int8 weights) and don't need axis-aware sharding. - if jax.tree_util.tree_leaves(val): - replicated = NamedSharding(mesh, PartitionSpec()) - return v.replace(jax.tree.map(lambda _: replicated, val)) - return v - metadata = v.get_metadata() - out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") - if not out_sharding: - pspec = PartitionSpec() - else: - # Insert the scan axis for parameters created by _create_scanned_layers. - # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the - # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. - if nnx.PARTITION_NAME in metadata: - partition_name = metadata[nnx.PARTITION_NAME] - # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits - # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode - # scan_axis=0 for non-Param types. stacked_rest non-Param variables have - # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. - scan_axis = metadata.get("param_scan_axis", 0) - out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) - # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames - # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted - # the scan axis. Only insert if not already present. - if partition_name not in out_sharding: - out_sharding.insert(scan_axis, partition_name) - out_sharding = tuple(out_sharding) - # Convert logical axis names to physical mesh axes using current context rules. - context_rules = get_logical_axis_rules() - local_rules = metadata.get("sharding_rules", ()) - if context_rules or local_rules: - rules = composite_rules(context_rules, local_rules) - raw_sharding = from_sharding_rules(out_sharding, rules) - mesh_axis_names = mesh.axis_names if mesh is not None else () - - # from_sharding_rules leaves a logical name with no matching rule unchanged, so a - # name missing from logical_axis_rules (e.g. concat_embed on the MTP kernel) - # reaches NamedSharding and is rejected as an unknown mesh axis. Map any such - # leftover name to None (replicated), matching Linen, whose logical_to_mesh_axes - # replicates unmatched names. - def _sanitize(x): - if isinstance(x, list): - x = tuple(x) - if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple): - return x - return None - - sanitized_sharding = [_sanitize(x) for x in raw_sharding] - pspec = PartitionSpec(*sanitized_sharding) - else: - pspec = PartitionSpec(*out_sharding) - return v.replace(NamedSharding(mesh, pspec)) - - return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) - - def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): """Calculates the abstract sharded state and memory placement for an NNX TrainState. @@ -1724,26 +1641,25 @@ def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=Tru with nn_partitioning.axis_rules(config.logical_axis_rules): # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply - # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # nnx_construct_named_sharding which correctly inserts the stacked-layers # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls # var.shape for every variable when a global mesh is active, but masked optimizer # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() - # which has no .shape and would raise AttributeError. We handle sharding - # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not - # needed here. + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via nnx_construct_named_sharding, so auto-assignment is not needed here. abs_model = nnx.eval_shape(nnx_init_trainstate_fn) _, abs_var_state = nnx.split(abs_model) - named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh) abstract_state = jax.tree.map( lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_var_state, named_sharding_state, ) - state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) if is_training and config.shard_optimizer_over_data: # Add data to sharding for optimizer state @@ -1849,10 +1765,10 @@ def _nnx_cache_partition_specs(abstract_model, config, mesh): way it does for the Linen helpers below. """ _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) - # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # nnx_construct_named_sharding reads logical axis rules from the # active flax partitioning context, so wrap. with nn_partitioning.axis_rules(config.logical_axis_rules): - named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + named_state = sharding.nnx_construct_named_sharding(cache_state, mesh) return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py index 5b645b85ca..df5d7a5f06 100644 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Utils for MaxText NNX. """ +"""Utils for MaxText NNX.""" from functools import partial from typing import Callable @@ -51,14 +51,17 @@ def create_nnx_rngs( return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference -def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: +def nnx_extract_named_sharding(abstract_state: nnx.State) -> nnx.State: """Get named sharding from NNX abstract state. Args: abstract_state: NNX model abstract state created from nnx.get_abstract_model. Returns: - named sharding structure + A tree of raw NamedSharding objects (stripping out any nnx.Variable / Param + wrappers). This clean structure is expected by JAX compiler APIs (like JIT + out_shardings). Contrast with sharding.nnx_construct_named_sharding, which + retains wrappers for abstract tree zipping compatibility. """ # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we # get the existing sharding from the abstract_state. @@ -156,7 +159,7 @@ def create_nnx_sharded_model( if named_sharding is None: # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) # we get the sharding directly from it. - named_sharding = get_named_sharding_nnx(abstract_state) + named_sharding = nnx_extract_named_sharding(abstract_state) if mesh is None: mesh = abstract_model.mesh diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 6a2ba297d6..8fb41ff21e 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -51,7 +51,7 @@ from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging -from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx, sharding import numpy as np from orbax import checkpoint as ocp @@ -112,13 +112,13 @@ def _zero_pad_axis(arr, axis, extra): if extra == 0: return arr - sharding = getattr(arr, "sharding", None) + arr_sharding = getattr(arr, "sharding", None) pad_width = [(0, 0)] * arr.ndim - if isinstance(sharding, jax.sharding.NamedSharding): - spec = sharding.spec + if isinstance(arr_sharding, jax.sharding.NamedSharding): + spec = arr_sharding.spec partition = spec[axis] if axis < len(spec) else None - shards_along_axis = _partition_size(partition, sharding.mesh) + shards_along_axis = _partition_size(partition, arr_sharding.mesh) if shards_along_axis > 1: if extra % shards_along_axis != 0: raise ValueError( @@ -131,7 +131,7 @@ def _zero_pad_axis(arr, axis, extra): def _pad_local(x): return jnp.pad(x, pad_width) - return jax.shard_map(_pad_local, mesh=sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr) + return jax.shard_map(_pad_local, mesh=arr_sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr) pad_width[axis] = (0, extra) return jnp.pad(arr, pad_width) @@ -257,11 +257,11 @@ def _maybe_fuse(path, ckpt_node): # Determine the number of shards (TP degree) along the concatenated axis n_shards = 1 - sharding = getattr(wi_model, "sharding", None) - if isinstance(sharding, jax.sharding.NamedSharding): - spec = sharding.spec + wi_sharding = getattr(wi_model, "sharding", None) + if isinstance(wi_sharding, jax.sharding.NamedSharding): + spec = wi_sharding.spec partition = spec[axis] if axis < len(spec) else None - n_shards = _partition_size(partition, sharding.mesh) + n_shards = _partition_size(partition, wi_sharding.mesh) # Target size for a single half (wi_0 or wi_1) AFTER padding target_half_dim = wi_model.shape[-1] // 2 @@ -325,13 +325,13 @@ def _stored_shape_evenly_shardable(restore_arg, stored_shape): (each device receives only its local slice), avoiding the multi-GB replicated fanout that fully-replicated loading produces for large MoE weights. """ - sharding = restore_arg.sharding - if not isinstance(sharding, jax.sharding.NamedSharding): + restore_sharding = restore_arg.sharding + if not isinstance(restore_sharding, jax.sharding.NamedSharding): return False - spec = sharding.spec + spec = restore_sharding.spec for axis_idx, dim in enumerate(stored_shape): partition = spec[axis_idx] if axis_idx < len(spec) else None - if dim % _partition_size(partition, sharding.mesh) != 0: + if dim % _partition_size(partition, restore_sharding.mesh) != 0: return False return True @@ -581,7 +581,7 @@ def create_nnx_abstract_model( # wrap is unnecessary here. abs_model = nnx.eval_shape(_create_model) graphdef, abs_var_state = nnx.split(abs_model) - named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh) abstract_state = jax.tree.map( lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_var_state, diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 50115cae72..789f6420e9 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -13,9 +13,10 @@ # limitations under the License. # pylint: disable=line-too-long, disable=bare-except, consider-using-generator -""" Utils that are only interesting to MaxText and sharding related. """ +"""Utils that are only interesting to MaxText and sharding related.""" from flax import linen as nn, nnx +from flax.core.spmd import get_logical_axis_rules from collections.abc import Iterable @@ -179,6 +180,70 @@ def remove_size_one_mesh_axis(spec, mesh): return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced) +def get_nnx_var_named_sharding_with_scan_axis(v: nnx.Variable, mesh) -> nnx.Variable: + """Compute NamedSharding for an NNX variable, correctly handling the scan axis.""" + val = v.get_value() + if not hasattr(val, "shape"): + # `val` is either truly leafless (e.g. optax MaskedNode) or a composite + # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables). + # Replicated sharding is a safe default. + if jax.tree_util.tree_leaves(val): + replicated = NamedSharding(mesh, P()) + return v.replace(jax.tree.map(lambda _: replicated, val)) + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = P() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. + context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + local_rules_list = list(local_rules) if local_rules is not None else [] + context_rules_list = list(context_rules) if context_rules is not None else [] + rules = local_rules_list + context_rules_list + pspec = logical_to_mesh_axes(out_sharding, mesh, rules=rules) + else: + pspec = P(*out_sharding) + if mesh is not None: + pspec = remove_size_one_mesh_axis(pspec, mesh) + return v.replace(NamedSharding(mesh, pspec)) + + +def nnx_construct_named_sharding(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state with leaf values replaced with NamedSharding. + Note that it preserves the original nnx.Variable / Param wrapper nodes to maintain + type structure matching abs_var_state (necessary for multi-tree maps). Use + maxtext_utils_nnx.nnx_extract_named_sharding to retrieve clean raw NamedShardings. + """ + return jax.tree.map( + lambda x: get_nnx_var_named_sharding_with_scan_axis(x, mesh), + abs_var_state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + def logical_to_mesh_axes(logical_names, mesh, rules=None): """Remove size one mesh axes given logical names.""" tensor_spec = nn.logical_to_mesh_axes(logical_names, rules=rules) diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py index 10e2b8621f..87ba463963 100644 --- a/tests/unit/maxtext_utils_nnx_test.py +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Tests for the common MaxText NNX utilities """ +"""Tests for the common MaxText NNX utilities""" import unittest from dataclasses import dataclass from typing import Any @@ -104,7 +104,7 @@ def test_get_set_named_sharding_nnx(self): _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) # 2. Test extraction - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) # Verify kernel and bias match the P("data") annotations from TinyModel self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) @@ -136,7 +136,7 @@ def update_spec_fn(path, leaf_sharding): # 4. Verify named sharding is preserved after NNX merge (update) and split (state) model = self.tiny_model_init_fn() nnx.update(model, updated_abstract) - re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) + re_extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(nnx.state(model)) # Verify kernel and bias have expected sharding self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) @@ -148,7 +148,7 @@ def test_create_nnx_sharded_model(self): abstract_model = nnx.merge(graphdef, abstract_state) # 2. Modify shardings to trigger host offloading - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) # 3. Run the sharded creation @@ -165,7 +165,7 @@ def test_get_partition_spec_nnx(self): """Verifies extraction of PartitionSpecs from NamedShardings.""" # 1. Create abstract state and get sharding _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) - extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + extracted_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state) # 2. Execute extraction spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 2e90880a83..6857cabca6 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -1523,7 +1523,7 @@ class MockConfig: optimizer_memory_host_offload: bool = False parameter_memory_host_offload: bool = False param_scan_axis: int = 0 - logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]], ["model", ["model"]]]) class MockTrainState(nnx.Module): """Simulates a TrainState with params and optimizer state.""" @@ -1542,6 +1542,13 @@ def setUp(self): devices = jax.local_devices() self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) self.config = self.MockConfig() + # Stub remove_size_one_mesh_axis so that resolved specs are returned unreduced, + # allowing verification of naming resolution without being stripped by size-one mesh dims. + self._old_remove_size_one_mesh_axis = sharding.remove_size_one_mesh_axis + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def tearDown(self): + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis def nnx_init_trainstate_wrapper(self): """Wrapper to initialize the mock NNX model.""" @@ -1613,83 +1620,5 @@ def test_invalid_init_fn(self): maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) -class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): - """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. - - The helper resolves a NamedSharding for each NNX Variable and — unlike - flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at - `param_scan_axis` when scanned-layers metadata is present. - """ - - def setUp(self): - # Mesh needs to contain every axis name the tests reference in partition specs. - self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) - - def _build_state(self, **variables): - """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" - return nnx.State(variables) - - def _run(self, state): - return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) - - def test_scan_axis_inserted_at_param_scan_axis(self): - """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((3, 4, 8)), - out_sharding=(None, "fsdp"), - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - self.assertIsInstance(result_sharding, NamedSharding) - # 'layers' must be inserted at position 1 (param_scan_axis=1). - self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) - - def test_scan_axis_not_inserted_when_already_present(self): - """Guard against double-insertion when partition_name is already in out_sharding.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((2, 2, 2)), - out_sharding=("layers", None, "fsdp"), - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - # 'layers' must appear exactly once — the same PartitionSpec we started with. - self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) - - def test_masked_node_preserved_as_is(self): - """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" - masked = nnx.Variable(optax.MaskedNode()) - state = self._build_state(masked=masked) - out = self._run(state) - # The leaf must be the original Variable, not a NamedSharding wrapper. - self.assertIs(out["masked"], masked) - - def test_empty_out_sharding_yields_empty_pspec(self): - """A Variable without any sharding metadata should resolve to PartitionSpec().""" - with jax.set_mesh(self.mesh): - # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() - v = nnx.Param(jnp.zeros((4,))) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - self.assertIsInstance(result_sharding, NamedSharding) - self.assertEqual(result_sharding.spec, PartitionSpec()) - - def test_string_out_sharding_is_wrapped_into_tuple(self): - """A single-string out_sharding value should still produce a valid PartitionSpec.""" - with jax.set_mesh(self.mesh): - v = nnx.Param( - jnp.zeros((4,)), - out_sharding="fsdp", - **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, - ) - out = self._run(self._build_state(w=v)) - result_sharding = out["w"].get_value() - # The single string 'fsdp' is turned into a list, and 'layers' is prepended. - self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) - - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_nnx_test.py b/tests/unit/sharding_nnx_test.py index d182f6d18d..62bdcb4348 100644 --- a/tests/unit/sharding_nnx_test.py +++ b/tests/unit/sharding_nnx_test.py @@ -18,10 +18,12 @@ import unittest from flax import nnx +from flax.linen import partitioning as nn_partitioning import jax from jax.sharding import Mesh, NamedSharding, PartitionSpec from maxtext.common import train_state_nnx from maxtext.utils import sharding +import jax.numpy as jnp import numpy as np import optax @@ -39,6 +41,18 @@ def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 4, rngs=rngs) +def _create_2d_test_mesh(axis_names=("data", "model")): + devices = jax.local_devices() + num_devices = len(devices) + if num_devices >= 4: + mesh_devices = np.array(devices[:4]).reshape(2, 2) + elif num_devices >= 2: + mesh_devices = np.array(devices[:2]).reshape(2, 1) + else: + mesh_devices = np.array(devices[:1]).reshape(1, 1) + return Mesh(devices=mesh_devices, axis_names=axis_names) + + def _build_state_mesh_shardings(model, tx): """Build an nnx.State of NamedShardings mirroring the TrainStateNNX layout. @@ -48,9 +62,7 @@ def _build_state_mesh_shardings(model, tx): optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) state_obj = train_state_nnx.TrainStateNNX(model, optimizer) state = nnx.state(state_obj) - mesh = Mesh( - np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model") - ) + mesh = _create_2d_test_mesh() def _to_sharding(var): val = var.get_value() @@ -62,9 +74,7 @@ def _to_sharding(var): pspec = PartitionSpec("data", "model") return var.replace(NamedSharding(mesh, pspec)) - return jax.tree.map( - _to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + return jax.tree.map(_to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) class TestMaybeUpdateParamsShardingWithOptNNX(unittest.TestCase): @@ -76,12 +86,8 @@ def setUp(self): def test_dispatch_from_main_helper_when_pure_nnx(self): """maybe_update_params_sharding_with_opt should dispatch to the NNX variant.""" cfg = _Cfg(pure_nnx=True, shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - prev, updated = sharding.maybe_update_params_sharding_with_opt( - cfg, state_mesh_shardings - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, updated = sharding.maybe_update_params_sharding_with_opt(cfg, state_mesh_shardings) # prev is the param-only view (no rngs / non-Param nodes) self.assertIsInstance(prev, nnx.State) self.assertIn("linear", prev) @@ -91,15 +97,9 @@ def test_dispatch_from_main_helper_when_pure_nnx(self): def test_extract_param_only_skips_non_param_variables(self): """prev_params_shardings must contain Params only — RngKey/RngCount/OptVariable filtered out.""" cfg = _Cfg(shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) - leaves = jax.tree.leaves( - prev, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + leaves = jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) # Every surviving leaf is wrapped as an nnx.Param. self.assertTrue(all(isinstance(leaf, nnx.Param) for leaf in leaves)) # The model has linear.kernel and linear.bias — exactly two Param leaves. @@ -108,25 +108,17 @@ def test_extract_param_only_skips_non_param_variables(self): def test_returns_unchanged_when_shard_optimizer_over_data_false(self): """When shard_optimizer_over_data=False, the second return value must be the input object.""" cfg = _Cfg(shard_optimizer_over_data=False) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) - _, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) self.assertIs(updated, state_mesh_shardings) def test_zero1_propagates_mu_sharding_to_model_params(self): """Zero-1: model param shardings must be replaced with the optimizer mu shardings.""" cfg = _Cfg(shard_optimizer_over_data=True) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.adam(1e-3) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) # Mutate the optimizer mu leaves in place so the function picks up a distinct PartitionSpec. - mesh = Mesh( - np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model") - ) + mesh = _create_2d_test_mesh() target_pspec = PartitionSpec(("data", "model")) new_mu_sharding = NamedSharding(mesh, target_pspec) @@ -135,20 +127,14 @@ def test_zero1_propagates_mu_sharding_to_model_params(self): # After _build_state_mesh_shardings, every leaf's value is a NamedSharding (no .shape), # so we just override every Variable leaf in mu in place via set_value (modern API). mu_state = state_mesh_shardings.optimizer.opt_state[0]["mu"] - for var in jax.tree.leaves( - mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable) - ): + for var in jax.tree.leaves(mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable)): if isinstance(var, nnx.Variable): var.set_value(new_mu_sharding) - _, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) # All Param leaves under updated.model must now share the new mu sharding. - param_leaves = jax.tree.leaves( - updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + param_leaves = jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) param_leaves = [v for v in param_leaves if isinstance(v, nnx.Param)] self.assertGreater(len(param_leaves), 0) for leaf in param_leaves: @@ -157,13 +143,9 @@ def test_zero1_propagates_mu_sharding_to_model_params(self): def test_raises_when_no_adam_state_present(self): """Stateless optimizers (e.g., SGD) have no mu — function must raise NotImplementedError.""" cfg = _Cfg(shard_optimizer_over_data=True) - state_mesh_shardings = _build_state_mesh_shardings( - self.model, optax.sgd(1e-3) - ) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.sgd(1e-3)) with self.assertRaises(NotImplementedError): - sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) def test_chained_optimizer_recursion_finds_adam_mu(self): """A nested optax.chain(clip, adam) wraps mu under multiple containers — recursion must find it.""" @@ -172,24 +154,254 @@ def test_chained_optimizer_recursion_finds_adam_mu(self): state_mesh_shardings = _build_state_mesh_shardings(self.model, chained) # Should not raise; verify update happens (params replaced with mu shardings). - prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx( - cfg, state_mesh_shardings - ) + prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) self.assertIsInstance(prev, nnx.State) self.assertIsInstance(updated, nnx.State) # Same number of Param leaves before and after. - n_prev = len( - jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) + n_prev = len(jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable))) + n_after = len( + [ + v + for v in jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + if isinstance(v, nnx.Param) + ] ) - n_after = len([ - v - for v in jax.tree.leaves( - updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) - if isinstance(v, nnx.Param) - ]) self.assertEqual(n_prev, n_after) +class TestNnxConstructNamedSharding(unittest.TestCase): + """Unit tests for nnx_construct_named_sharding covering every branch. + + The helper resolves a NamedSharding for each NNX Variable inside an nnx.State and + — unlike flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. + self.mesh = _create_2d_test_mesh(axis_names=("fsdp", "stage")) + # In local test environments (e.g. single-device CPU), all mesh axes have size 1. + # We stub remove_size_one_mesh_axis to act as a no-op so that resolved physical PartitionSpecs + # are returned unreduced (e.g. retaining "fsdp", "stage", etc.), allowing us to verify naming + # resolution. The actual size-one axis removal is tested separately in TestGetNNXNamedShardingSizeOneAxes. + self._old_remove_size_one_mesh_axis = sharding.remove_size_one_mesh_axis + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def tearDown(self): + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return sharding.nnx_construct_named_sharding(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' resolves to physical axis 'stage' and is inserted at position 1 (param_scan_axis=1). + self.assertEqual(result_sharding.spec, PartitionSpec(None, "stage", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'stage' must appear exactly once — the same PartitionSpec we started with. + self.assertEqual(result_sharding.spec, PartitionSpec("stage", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. + self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + rules = (("layers", "stage"), ("fsdp", "fsdp")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' (resolving to 'stage') is prepended. + self.assertEqual(result_sharding.spec, PartitionSpec("stage", "fsdp")) + + def test_sequential_matching_first_match_wins(self): + """Multiple rules for the same logical axis are matched sequentially, first-match-wins.""" + # We define rules for 'embed' mapping to 'fsdp' (specific) then 'stage' (fallback) + rules = ( + ("embed", "fsdp"), + ("embed", "stage"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the first rule ('fsdp'), not the second ('stage'). + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + + def test_prevents_duplicate_physical_axes(self): + """If multiple dimensions map to the same physical axis, the subsequent ones are skipped (mapped to None).""" + # Setup rules where 'embed' maps to 'fsdp' and 'mlp' also maps to 'fsdp'. + rules = ( + ("embed", "fsdp"), + ("mlp", "fsdp"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed", "mlp"), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: Dim 0 matches 'embed' -> 'fsdp'. + # Dim 1 tries 'mlp' -> 'fsdp', but 'fsdp' is already assigned to Dim 0. + # So it skips the rule and falls back to matching nothing -> None. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp", None)) + + def test_fallback_to_next_physical_axis_when_duplicated(self): + """When a physical axis is already assigned, fallback priority rules should map to the next available physical option.""" + # Setup rules where 'embed' maps to 'fsdp', and 'mlp' maps to 'fsdp' (priority 1) or 'stage' (priority 2). + rules = ( + ("embed", "fsdp"), + ("mlp", "fsdp"), + ("mlp", "stage"), + ) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed", "mlp"), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: Dim 0 matches 'embed' -> 'fsdp'. + # Dim 1 tries 'mlp' -> 'fsdp', but 'fsdp' is already assigned to Dim 0. + # So it falls to the next item in the list -> 'stage'. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp", "stage")) + + def test_resolves_when_context_rules_is_none(self): + """When context_rules is None but local_rules are defined, resolution should succeed.""" + # Ensure get_logical_axis_rules() returns None (which is the default outside axis_rules) + # We define local rules on the variable metadata. + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + sharding_rules=(("embed", "fsdp"),), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the local rules even when context_rules is None. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + + def test_composite_pytree_variable_resolved_to_replicated_shardings(self): + """A Variable holding a composite pytree (e.g. tuple of arrays) is resolved to replicated NamedShardings.""" + with jax.set_mesh(self.mesh): + v = nnx.Variable((jnp.zeros((2, 2)), jnp.zeros((3, 3)))) + out = self._run(self._build_state(w=v)) + result = out["w"].get_value() + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + self.assertIsInstance(result[0], NamedSharding) + self.assertIsInstance(result[1], NamedSharding) + self.assertEqual(result[0].spec, PartitionSpec()) + self.assertEqual(result[1].spec, PartitionSpec()) + + def test_rules_merged_when_both_context_and_local_rules_present(self): + """When both local rules and context rules are present, they are concatenated in order of local then context.""" + # Local rules map 'embed' to 'stage'. Context rules map 'embed' to 'fsdp'. + # Because local rules come first, 'embed' should resolve to 'stage'. + context_rules = (("embed", "fsdp"),) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(context_rules): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + sharding_rules=(("embed", "stage"),), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertEqual(result_sharding.spec, PartitionSpec("stage")) + + def test_removes_size_one_mesh_axes_no_rules(self): + """When no rules are defined but mesh is present, size-1 physical axes in the spec are removed.""" + # Temporarily restore the original remove_size_one_mesh_axis function + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + try: + # 'fsdp' is a physical axis with size 1 in self.mesh. + # out_sharding directly specifies the physical axis 'fsdp'. + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding=("fsdp",), + ) + out_v = sharding.get_nnx_var_named_sharding_with_scan_axis(v, self.mesh) + result_sharding = out_v.get_value() + # 'fsdp' has size 1, so it gets reduced to None. + self.assertEqual(result_sharding.spec, PartitionSpec(None)) + finally: + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + def test_removes_size_one_mesh_axes(self): + """When remove_size_one_mesh_axis is active, physical axes with size 1 are removed (mapped to None).""" + # Temporarily restore the original remove_size_one_mesh_axis function + sharding.remove_size_one_mesh_axis = self._old_remove_size_one_mesh_axis + try: + # Setup axis rules mapping 'embed' to 'fsdp' and 'layers' to 'stage'. + # In this mesh, fsdp and stage have size 1. + rules = (("embed", "fsdp"), ("layers", "stage")) + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(rules): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed",), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + # Resolve sharding + out_v = sharding.get_nnx_var_named_sharding_with_scan_axis(v, self.mesh) + result_sharding = out_v.get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # Expected: P("fsdp", "stage") gets reduced to P(None, None) since both fsdp and stage have size 1. + self.assertEqual(result_sharding.spec, PartitionSpec(None, None)) + finally: + # Re-apply the stub to keep other tests working + sharding.remove_size_one_mesh_axis = lambda spec, mesh: spec + + if __name__ == "__main__": unittest.main()