Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import sharding
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils
from maxtext.common.gcloud_stub import jetstream, is_decoupled
Expand Down Expand Up @@ -398,15 +399,15 @@ def _load_params_nnx(self, params, rng):
# axis metadata but no physical .sharding. Resolve logical to physical here so
# device_put actually reshards instead of being a no-op.
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
target_shardings = maxtext_utils.get_nnx_named_sharding_with_scan_axis(params_abs, self._mesh)
target_shardings = sharding.nnx_construct_named_sharding(params_abs, self._mesh)
params_state = jax.device_put(params, target_shardings)
# We only need a concrete `rest` (RNG vars) for nnx.merge. create_nnx_sharded_model
# builds the model with a jitted out_shardings so params are produced already
# sharded, avoiding a single-device allocation of the full model (an OOM risk for
# large models). self.model is abstract with no .sharding, so pass an explicit one.
_, full_abs = nnx.split(self.model)
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
full_sharding = maxtext_utils.get_nnx_named_sharding_with_scan_axis(full_abs, self._mesh)
full_sharding = sharding.nnx_construct_named_sharding(full_abs, self._mesh)
concrete_model = maxtext_utils_nnx.create_nnx_sharded_model(
self.model, self._create_model_fn, mesh=self._mesh, named_sharding=full_sharding
)
Expand Down
100 changes: 8 additions & 92 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
from typing import Sequence

from flax import nnx, linen as nn
from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules
from flax.linen import partitioning as nn_partitioning
from flax.training.train_state import TrainState

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
from jax.sharding import AxisType, Mesh, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load

Expand Down Expand Up @@ -1612,88 +1611,6 @@ def move(path, x):
)


def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
"""Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis.

Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also
inserts the partition_name axis at the correct scan_axis position for parameters created by
_create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a
3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension.

Args:
abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
mesh: JAX physical mesh.

Returns:
Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding.
"""

def _make_named_sharding(v):
val = v.get_value()
if not hasattr(val, "shape"):
# `val` is either truly leafless (e.g. optax MaskedNode) or a composite
# pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables —
# a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter
# we must emit a parallel tree of NamedSharding leaves so the downstream
# `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)`
# finds a real Sharding at every position. Replicated sharding is a safe
# default — AQT serve-mode QTensors are normally small (per-channel scale
# factors and packed int8 weights) and don't need axis-aware sharding.
if jax.tree_util.tree_leaves(val):
replicated = NamedSharding(mesh, PartitionSpec())
return v.replace(jax.tree.map(lambda _: replicated, val))
return v
metadata = v.get_metadata()
out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
if not out_sharding:
pspec = PartitionSpec()
else:
# Insert the scan axis for parameters created by _create_scanned_layers.
# _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the
# axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these.
if nnx.PARTITION_NAME in metadata:
partition_name = metadata[nnx.PARTITION_NAME]
# Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits
# param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode
# scan_axis=0 for non-Param types. stacked_rest non-Param variables have
# param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct.
scan_axis = metadata.get("param_scan_axis", 0)
out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
# Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames
# 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted
# the scan axis. Only insert if not already present.
if partition_name not in out_sharding:
out_sharding.insert(scan_axis, partition_name)
out_sharding = tuple(out_sharding)
# Convert logical axis names to physical mesh axes using current context rules.
context_rules = get_logical_axis_rules()
local_rules = metadata.get("sharding_rules", ())
if context_rules or local_rules:
rules = composite_rules(context_rules, local_rules)
raw_sharding = from_sharding_rules(out_sharding, rules)
mesh_axis_names = mesh.axis_names if mesh is not None else ()

# from_sharding_rules leaves a logical name with no matching rule unchanged, so a
# name missing from logical_axis_rules (e.g. concat_embed on the MTP kernel)
# reaches NamedSharding and is rejected as an unknown mesh axis. Map any such
# leftover name to None (replicated), matching Linen, whose logical_to_mesh_axes
# replicates unmatched names.
def _sanitize(x):
if isinstance(x, list):
x = tuple(x)
if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple):
return x
return None

sanitized_sharding = [_sanitize(x) for x in raw_sharding]
pspec = PartitionSpec(*sanitized_sharding)
else:
pspec = PartitionSpec(*out_sharding)
return v.replace(NamedSharding(mesh, pspec))

return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable))


def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True):
"""Calculates the abstract sharded state and memory placement for an NNX TrainState.

Expand Down Expand Up @@ -1724,26 +1641,25 @@ def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=Tru

with nn_partitioning.axis_rules(config.logical_axis_rules):
# Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply
# get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers
# nnx_construct_named_sharding which correctly inserts the stacked-layers
# axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally
# which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers,
# causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor.
# Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls
# var.shape for every variable when a global mesh is active, but masked optimizer
# state variables (e.g. from trainable_parameters_mask) have value=MaskedNode()
# which has no .shape and would raise AttributeError. We handle sharding
# ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not
# needed here.
# which has no .shape and would raise AttributeError. We handle sharding
# ourselves via nnx_construct_named_sharding, so auto-assignment is not needed here.
abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
_, abs_var_state = nnx.split(abs_model)
named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
abstract_state = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_var_state,
named_sharding_state,
)

state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)
state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state)

if is_training and config.shard_optimizer_over_data:
# Add data to sharding for optimizer state
Expand Down Expand Up @@ -1849,10 +1765,10 @@ def _nnx_cache_partition_specs(abstract_model, config, mesh):
way it does for the Linen helpers below.
"""
_, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...)
# get_nnx_named_sharding_with_scan_axis reads logical axis rules from the
# nnx_construct_named_sharding reads logical axis rules from the
# active flax partitioning context, so wrap.
with nn_partitioning.axis_rules(config.logical_axis_rules):
named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh)
named_state = sharding.nnx_construct_named_sharding(cache_state, mesh)
return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict())


Expand Down
11 changes: 7 additions & 4 deletions src/maxtext/utils/maxtext_utils_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utils for MaxText NNX. """
"""Utils for MaxText NNX."""

from functools import partial
from typing import Callable
Expand Down Expand Up @@ -51,14 +51,17 @@ def create_nnx_rngs(
return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference


def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State:
def nnx_extract_named_sharding(abstract_state: nnx.State) -> nnx.State:
"""Get named sharding from NNX abstract state.
Args:
abstract_state: NNX model abstract state created from nnx.get_abstract_model.
Returns:
named sharding structure
A tree of raw NamedSharding objects (stripping out any nnx.Variable / Param
wrappers). This clean structure is expected by JAX compiler APIs (like JIT
out_shardings). Contrast with sharding.nnx_construct_named_sharding, which
retains wrappers for abstract tree zipping compatibility.
"""
# Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we
# get the existing sharding from the abstract_state.
Expand Down Expand Up @@ -156,7 +159,7 @@ def create_nnx_sharded_model(
if named_sharding is None:
# The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding)
# we get the sharding directly from it.
named_sharding = get_named_sharding_nnx(abstract_state)
named_sharding = nnx_extract_named_sharding(abstract_state)

if mesh is None:
mesh = abstract_model.mesh
Expand Down
30 changes: 15 additions & 15 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import max_logging
from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx
from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx, sharding
import numpy as np
from orbax import checkpoint as ocp

Expand Down Expand Up @@ -112,13 +112,13 @@ def _zero_pad_axis(arr, axis, extra):
if extra == 0:
return arr

sharding = getattr(arr, "sharding", None)
arr_sharding = getattr(arr, "sharding", None)
pad_width = [(0, 0)] * arr.ndim

if isinstance(sharding, jax.sharding.NamedSharding):
spec = sharding.spec
if isinstance(arr_sharding, jax.sharding.NamedSharding):
spec = arr_sharding.spec
partition = spec[axis] if axis < len(spec) else None
shards_along_axis = _partition_size(partition, sharding.mesh)
shards_along_axis = _partition_size(partition, arr_sharding.mesh)
if shards_along_axis > 1:
if extra % shards_along_axis != 0:
raise ValueError(
Expand All @@ -131,7 +131,7 @@ def _zero_pad_axis(arr, axis, extra):
def _pad_local(x):
return jnp.pad(x, pad_width)

return jax.shard_map(_pad_local, mesh=sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr)
return jax.shard_map(_pad_local, mesh=arr_sharding.mesh, in_specs=spec, out_specs=spec, check_vma=False)(arr)

pad_width[axis] = (0, extra)
return jnp.pad(arr, pad_width)
Expand Down Expand Up @@ -257,11 +257,11 @@ def _maybe_fuse(path, ckpt_node):

# Determine the number of shards (TP degree) along the concatenated axis
n_shards = 1
sharding = getattr(wi_model, "sharding", None)
if isinstance(sharding, jax.sharding.NamedSharding):
spec = sharding.spec
wi_sharding = getattr(wi_model, "sharding", None)
if isinstance(wi_sharding, jax.sharding.NamedSharding):
spec = wi_sharding.spec
partition = spec[axis] if axis < len(spec) else None
n_shards = _partition_size(partition, sharding.mesh)
n_shards = _partition_size(partition, wi_sharding.mesh)

# Target size for a single half (wi_0 or wi_1) AFTER padding
target_half_dim = wi_model.shape[-1] // 2
Expand Down Expand Up @@ -325,13 +325,13 @@ def _stored_shape_evenly_shardable(restore_arg, stored_shape):
(each device receives only its local slice), avoiding the multi-GB replicated
fanout that fully-replicated loading produces for large MoE weights.
"""
sharding = restore_arg.sharding
if not isinstance(sharding, jax.sharding.NamedSharding):
restore_sharding = restore_arg.sharding
if not isinstance(restore_sharding, jax.sharding.NamedSharding):
return False
spec = sharding.spec
spec = restore_sharding.spec
for axis_idx, dim in enumerate(stored_shape):
partition = spec[axis_idx] if axis_idx < len(spec) else None
if dim % _partition_size(partition, sharding.mesh) != 0:
if dim % _partition_size(partition, restore_sharding.mesh) != 0:
return False
return True

Expand Down Expand Up @@ -581,7 +581,7 @@ def create_nnx_abstract_model(
# wrap is unnecessary here.
abs_model = nnx.eval_shape(_create_model)
graphdef, abs_var_state = nnx.split(abs_model)
named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
abstract_state = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_var_state,
Expand Down
67 changes: 66 additions & 1 deletion src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

# pylint: disable=line-too-long, disable=bare-except, consider-using-generator
""" Utils that are only interesting to MaxText and sharding related. """
"""Utils that are only interesting to MaxText and sharding related."""

from flax import linen as nn, nnx
from flax.core.spmd import get_logical_axis_rules

from collections.abc import Iterable

Expand Down Expand Up @@ -179,6 +180,70 @@ def remove_size_one_mesh_axis(spec, mesh):
return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced)


def get_nnx_var_named_sharding_with_scan_axis(v: nnx.Variable, mesh) -> nnx.Variable:
"""Compute NamedSharding for an NNX variable, correctly handling the scan axis."""
val = v.get_value()
if not hasattr(val, "shape"):
# `val` is either truly leafless (e.g. optax MaskedNode) or a composite
# pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables).
# Replicated sharding is a safe default.
if jax.tree_util.tree_leaves(val):
replicated = NamedSharding(mesh, P())
return v.replace(jax.tree.map(lambda _: replicated, val))
return v
metadata = v.get_metadata()
out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
if not out_sharding:
pspec = P()
else:
# Insert the scan axis for parameters created by _create_scanned_layers.
if nnx.PARTITION_NAME in metadata:
partition_name = metadata[nnx.PARTITION_NAME]
scan_axis = metadata.get("param_scan_axis", 0)
out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
if partition_name not in out_sharding:
out_sharding.insert(scan_axis, partition_name)
out_sharding = tuple(out_sharding)
# Convert logical axis names to physical mesh axes using current context rules.
context_rules = get_logical_axis_rules()
local_rules = metadata.get("sharding_rules", ())
if context_rules or local_rules:
local_rules_list = list(local_rules) if local_rules is not None else []
context_rules_list = list(context_rules) if context_rules is not None else []
rules = local_rules_list + context_rules_list
pspec = logical_to_mesh_axes(out_sharding, mesh, rules=rules)
else:
pspec = P(*out_sharding)
if mesh is not None:
pspec = remove_size_one_mesh_axis(pspec, mesh)
return v.replace(NamedSharding(mesh, pspec))


def nnx_construct_named_sharding(abs_var_state: nnx.State, mesh) -> nnx.State:
"""Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis.

Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also
inserts the partition_name axis at the correct scan_axis position for parameters created by
_create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a
3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension.

Args:
abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
mesh: JAX physical mesh.

Returns:
Same tree structure as abs_var_state with leaf values replaced with NamedSharding.
Note that it preserves the original nnx.Variable / Param wrapper nodes to maintain
type structure matching abs_var_state (necessary for multi-tree maps). Use
maxtext_utils_nnx.nnx_extract_named_sharding to retrieve clean raw NamedShardings.
"""
return jax.tree.map(
lambda x: get_nnx_var_named_sharding_with_scan_axis(x, mesh),
abs_var_state,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)


def logical_to_mesh_axes(logical_names, mesh, rules=None):
"""Remove size one mesh axes given logical names."""
tensor_spec = nn.logical_to_mesh_axes(logical_names, rules=rules)
Expand Down
Loading
Loading