diff --git a/src/maxtext/checkpoint_conversion/inspect_checkpoint.py b/src/maxtext/checkpoint_conversion/inspect_checkpoint.py index c63f2e1161..7e9784e516 100644 --- a/src/maxtext/checkpoint_conversion/inspect_checkpoint.py +++ b/src/maxtext/checkpoint_conversion/inspect_checkpoint.py @@ -79,7 +79,7 @@ def print_structure(data_dict, output_file=""): """Utility to format and print sorted keys and shapes from a flattened dictionary.""" if output_file: # Save command - save_lines = [f"# {" ".join(sys.orig_argv)}", ""] + save_lines = [f"# {' '.join(sys.orig_argv)}", ""] for key in sorted(data_dict.keys(), key=natural_sort_key): line = f"key: {key} | {data_dict[key]}" diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..2f1a564c6d 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step): if self.config.num_experts > 1: moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0) - log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}") + log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}") if self.config.mtp_num_layers > 0: mtp_loss = scalars.get("learning/mtp_loss", 0.0) diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml index 598bbd9c1c..9d8d15777c 100644 --- a/src/maxtext/configs/models/deepseek4-284b.yml +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -49,6 +49,10 @@ num_experts_per_tok: 6 mlp_activations_limit: 10 shared_experts: 1 routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] # --- Attention configuration --- attention_type: 'compressed' diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..881043777b --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 +qk_rope_head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 98b24ff451..cee55d2d0a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -228,6 +228,7 @@ class ProfilerType(str, Enum): "deepseek3-tiny", "deepseek3.2-671b", "deepseek4-284b", + "deepseek4-tiny", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -3009,7 +3010,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK: + if ( + self.routed_bias + and self.routed_bias_update_rate > 0.0 + and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4) + ): raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.") if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts: raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 5925b0bb4e..4fa9980250 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -348,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. pre_bias_logits = output if self.use_bias: + # Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state + # management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient + # here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks. bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + output += jax.lax.stop_gradient(bias) return output, pre_bias_logits @@ -2170,7 +2173,6 @@ def dense_matmul( lb_loss = ( self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None ) - # TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001 else: lb_loss = None diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 54ebac07e1..ae61df9420 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -14,41 +14,32 @@ """Quantization library.""" +from dataclasses import dataclass import functools import json -import qwix.pallas as qpl import re -from typing import Tuple, Sequence, Callable -from dataclasses import dataclass +from typing import Callable, Sequence, Tuple -from aqt.jax.v2 import config as aqt_config from aqt.jax.v2 import aqt_tensor -from aqt.jax.v2.flax import aqt_flax -from aqt.jax.v2 import tiled_dot_general from aqt.jax.v2 import calibration - -import qwix -from qwix._src.core import dot_general_qt -from qwix._src.core import sparsity - +from aqt.jax.v2 import config as aqt_config +from aqt.jax.v2 import tiled_dot_general +from aqt.jax.v2.flax import aqt_flax +from flax import nnx +import flax.linen as nn +from flax.linen import fp8_ops +from flax.linen import initializers as flax_initializers import jax import jax.numpy as jnp from jax.tree_util import tree_flatten_with_path, tree_unflatten - -from flax.linen import fp8_ops -from flax.linen import initializers as flax_initializers -import flax.linen as nn -from flax import nnx -# Support different packaging structures across environments even within -# the same Qwix version identifier (imports from _src.utils vs _src). -try: - from qwix._src.utils import flax_util -except ImportError: - from qwix._src import flax_util # pytype: disable=import-error -from maxtext.layers import nnx_wrappers - -from maxtext.common.common_types import DType, Config +from maxtext.common.common_types import Config, DType from maxtext.inference.kvcache import KVQuant +from maxtext.layers import nnx_wrappers +import qwix +from qwix._src.core import dot_general_qt +from qwix._src.core import sparsity +from qwix._src import flax_util +import qwix.pallas as qpl # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -150,12 +141,18 @@ def _get_mixed_precision_cfg(self): return quant_dg, is_tiled, tiling_fn def _get_rhs_axis_metadata_wrapper( - self, mesh_axes: Tuple[str, ...] = (), is_tiled: bool = False, replicate_scale: bool = False + self, + mesh_axes: Tuple[str, ...] = (), + is_tiled: bool = False, + replicate_scale: bool = False, ): if self.quant_mode == aqt_flax.QuantMode.CONVERT: return None return functools.partial( - _rhs_axis_metadata_wrapper, mesh_axes=mesh_axes, is_tiled=is_tiled, replicate_scale=replicate_scale + _rhs_axis_metadata_wrapper, + mesh_axes=mesh_axes, + is_tiled=is_tiled, + replicate_scale=replicate_scale, ) def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): @@ -359,8 +356,20 @@ def __call__(self, eqn, *args, **kwargs): k = jnp.asarray(k, comp_dtype) x = jnp.asarray(x, comp_dtype) - x_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) - k_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value) + x_qdq = fp8_ops.in_qdq( + comp_dtype, + self.e4m3_dtype, + x, + self.input_scale.value, + self.input_amax_history.value, + ) + k_qdq = fp8_ops.in_qdq( + comp_dtype, + self.e4m3_dtype, + k, + self.kernel_scale.value, + self.kernel_amax_history.value, + ) y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision) @@ -386,6 +395,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): def _get_int8_quant_config(config): + """Get int8 quantization configuration.""" drhs_bits = None drhs_accumulator_dtype = None drhs_local_aqt = None @@ -553,7 +563,12 @@ def _get_aqt_fp8_default_config(config): def _get_aqt_fp8_quant_config(config): """get aqt for 8-bit floating point quantization configuration""" - cfg = aqt_config.config_v4(fwd_bits="e4m3", dlhs_bits=None, drhs_bits=None, fwd_accumulator_dtype=jnp.bfloat16) + cfg = aqt_config.config_v4( + fwd_bits="e4m3", + dlhs_bits=None, + drhs_bits=None, + fwd_accumulator_dtype=jnp.bfloat16, + ) return cfg @@ -572,7 +587,14 @@ def _dot_general_make(quant_cfg): def _get_default_mp_config(default=None): - default_config = {_W_BITS: None, _A_BITS: None, _W_SCALE: 1.0, _A_SCALE: 1.0, _TILE_SIZE: -1} + """Get default mixed precision configuration.""" + default_config = { + _W_BITS: None, + _A_BITS: None, + _W_SCALE: 1.0, + _A_SCALE: 1.0, + _TILE_SIZE: -1, + } if default: default_config.update(default) return default_config @@ -590,7 +612,10 @@ def _get_mixed_precision_quant_config(mixed_precision_config): if layer_name_re != DEFAULT: for k in quant_config: quant_config[k] = layer_quantization_config.get(k, default_mp_config[k]) - ret_config[layer_name_re] = [_dot_general_make(quant_config), quant_config["tile_size"]] + ret_config[layer_name_re] = [ + _dot_general_make(quant_config), + quant_config["tile_size"], + ] return ret_config @@ -850,7 +875,10 @@ def maybe_quantize_model(model, config): quantization_provider = get_qt_provider(config) if quantization_provider: if config.pure_nnx: - input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + input_shape = ( + config.micro_batch_size_to_train_on, + config.max_target_length, + ) dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) @@ -1012,19 +1040,25 @@ def _wrap(self, f, name=None): """ import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error - - fp8_recipe = self._recipe + from transformer_engine.common import recipe # pylint: disable=import-outside-toplevel # pytype: disable=import-error class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase): """Wrapper module for TransformerEngine quantization.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set( + self, + postfix: str = "", + _variable_collection: str | None = None, + _quantization_checkpoint_name: str | None = None, + _fp8_recipe: recipe.Recipe | None = None, + **_kwargs, + ): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, quantization_checkpoint_name="quantization", - fp8_recipe=fp8_recipe, + fp8_recipe=self._recipe, ) @nn.compact @@ -1039,9 +1073,12 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): """Placeholder for dot_general implementation in subclasses.""" import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error - def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): + def te_dot_general(generate_quantizer_set, x, kernel, dims, **_kwargs): contracting_dims, batch_dims = dims - assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot." + assert batch_dims == ( + (), + (), + ), "Batch dimensions must be empty for TransformerEngine dot." quantizer_set = generate_quantizer_set() return transformer_engine.jax.dense.dense( diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 9992d7674f..2880ae3a57 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -13,13 +13,14 @@ # limitations under the License. # pylint: disable=bare-except, consider-using-generator, too-many-positional-arguments -""" Utils that are only interesting to MaxText. """ +"""Utils that are only interesting to MaxText.""" import re import jax import jax.numpy as jnp import optax +from flax import traverse_util from optax.contrib._muon import muon from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers @@ -238,6 +239,21 @@ def get_optimizer(config, learning_rate_schedule, model=None): lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)), ) + if getattr(config, "routed_bias", False) and getattr(config, "routed_bias_update_rate", 0.0) > 0.0: + bias_regex = re.compile(".*gate.*bias.*") + + # Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the + # Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias) + # routes them to a standard AdamW optimizer *without* exposing a weight decay mask. + # To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero, + # we apply a global optax.set_to_zero() mask here. + def bias_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params} + return traverse_util.unflatten_dict(mask) + + base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn)) + return base_opt diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index e6f73928a0..f4e63fef46 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,7 +36,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding -from flax import linen as nn, nnx +from flax import linen as nn, nnx, traverse_util from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -290,12 +290,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. intermediate_outputs["logits"] = logits @@ -307,7 +301,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -423,9 +416,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -482,12 +475,47 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params: + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if ( + len(param_prefix) >= len(prefix) + and param_prefix[: len(prefix)] == prefix + and param_path[-2:] == ("gate", "bias") + ): + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, param_path)) + + old_val = ( + new_flat_params[param_path].value + if hasattr(new_flat_params[param_path], "value") + else new_flat_params[param_path] + ) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(old_val) + + new_val = old_val + jnp.array(update_val).transpose() + if hasattr(new_flat_params[param_path], "value"): + new_flat_params[param_path] = new_flat_params[param_path].replace(value=new_val) + else: + new_flat_params[param_path] = new_val + + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -508,9 +536,30 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + target = new_state.model + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for key in prefix: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + target = None + break + if target is None: + continue + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, prefix)) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -523,6 +572,7 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..11e7e9ec90 --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,126 @@ +"""Tests for DeepSeek routed bias updates.""" + +import unittest +import jax +import jax.numpy as jnp +import optax +from flax import nnx +from flax.training import train_state +from maxtext.common import train_state_nnx +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train + + +class DeepSeekRoutedBiasTest(unittest.TestCase): + + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + + def _make_dummy_data(self, batch=1, seq=16): + """Creates dummy input data for testing.""" + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + def _create_and_run_train_step(self, config_args): + """Initializes the model and runs a single training step.""" + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, "nnx") else __import__("flax.nnx", fromlist=["Rngs"]).Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) + + def _create_and_run_linen_train_step(self, config_args): + """Creates a Linen model and runs a single training step.""" + config = pyconfig.initialize(config_args) + model = models.transformer_as_linen(config, self.mesh, quant=None) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + rng = jax.random.PRNGKey(0) + variables = model.init(rng, data["inputs"], data["inputs_position"], data["inputs_segmentation"]) + ts = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=optax.sgd(0.01)) + new_state, metrics = pre_train.train_step( + model, + config, + state_mesh_shardings=None, + params_shardings=None, + state=ts, + data=data, + dropout_rng=jax.random.PRNGKey(0), + ) + return new_state, metrics + + def test_deepseek_v3_moe_routed_bias_linen(self): + """Proves that a DeepSeek V3 model with MoE layers successfully traverses the + Linen state tree and updates routed bias. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=0", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_linen_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertTrue(any(key.startswith("learning/moe_bias_before_norm") for key in metrics["scalar"])) + self.assertTrue(any(key.startswith("learning/moe_bias_update_norm") for key in metrics["scalar"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/metric_logger_test_coverage.py b/tests/unit/metric_logger_test_coverage.py new file mode 100644 index 0000000000..14e9031da6 --- /dev/null +++ b/tests/unit/metric_logger_test_coverage.py @@ -0,0 +1,46 @@ +"""Tests for MetricLogger coverage.""" + +import unittest +from unittest import mock +from maxtext.common.metric_logger import MetricLogger +from maxtext.configs import pyconfig + + +class MetricLoggerTest(unittest.TestCase): + + def test_log_train_metrics_moe_lb_loss(self): + config = pyconfig.initialize( + [ + "", + "src/maxtext/configs/base.yml", + "run_name=test_run", + "base_output_directory=/tmp/maxtext_output", + "num_experts=2", + "mtp_num_layers=0", + "base_moe_mlp_dim=64", + "base_mlp_dim=64", + "skip_jax_distributed_system=True", + ] + ) + + logger = MetricLogger(config, None) + metrics = { + "scalar": { + "learning/loss": 1.0, + "learning/lm_loss": 1.0, + "learning/total_weights": 1000, + "learning/moe_lb_loss": 0.000403, + "perf/step_time_seconds": 1.0, + "perf/per_device_tflops_per_sec": 1.0, + "perf/per_device_tokens_per_sec": 1.0, + } + } + with mock.patch("maxtext.common.metric_logger.max_logging.log") as mock_log: + logger._log_training_metrics(metrics, 1) # pylint: disable=protected-access + mock_log.assert_called() + called_args = mock_log.call_args[0][0] + self.assertIn("moe_lb_loss: 0.000403", called_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index b8eab1061e..ae96df2eaf 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -417,7 +417,7 @@ def test_get_optimizer_with_trainable_mask(self): config = pyconfig.initialize(argv) # Use a constant learning rate > 0 to ensure non-zero updates - def learning_rate_schedule(step): + def learning_rate_schedule(_): return 1.0 opt = optimizers.get_optimizer(config, learning_rate_schedule) @@ -454,7 +454,7 @@ def test_get_optimizer_without_trainable_mask(self): config = pyconfig.initialize(argv) # Use a constant learning rate > 0 to ensure non-zero updates - def learning_rate_schedule(step): + def learning_rate_schedule(_): return 1.0 opt = optimizers.get_optimizer(config, learning_rate_schedule) @@ -622,5 +622,27 @@ def __init__(self, rngs: nnx.Rngs): self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) +class TestGetOptimizerGlobalMask(unittest.TestCase): + """Tests that the global optimizer cleanly masks out the routed bias.""" + + def test_routed_bias_global_mask(self): + config = pyconfig.initialize( + ["", "src/maxtext/configs/base.yml", "routed_bias=True", "routed_bias_update_rate=0.001", "opt_type=sgd"] + ) + # We define a dummy params dict containing a routed bias and a regular weight. + # The routed bias must be completely ignored by the optimizer. + params = {"decoder": {"moe_layers": {"MoeBlock_0": {"gate": {"bias": jnp.array([1.0]), "kernel": jnp.array([1.0])}}}}} + grads = {"decoder": {"moe_layers": {"MoeBlock_0": {"gate": {"bias": jnp.array([0.5]), "kernel": jnp.array([0.5])}}}}} + # We use sgd because it's simple to test updates, but the mask logic applies + # cleanly to any base optimizer returned by get_optimizer. + opt = optimizers.get_optimizer(config, learning_rate_schedule=0.1) + opt_state = opt.init(params) + updates, _ = opt.update(grads, opt_state, params) + # The routed bias update should be exactly 0.0 (masked by set_to_zero) + self.assertEqual(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["bias"].item(), 0.0) + # The kernel should receive the SGD gradient update (-0.1 * 0.5) + self.assertTrue(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["kernel"].item() < 0.0) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 6467c6f196..ba2991ceae 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -63,6 +63,11 @@ class _Cfg: debug_sharding: bool = False weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 + decoder_block: str = "default" + + +class _DummyDecoder(nnx.Module): + pass class _TinyDecoder(nnx.Module): @@ -77,6 +82,7 @@ def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) # loss_fn shards activations against model.mesh, so the stub needs one. self.mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context")) + self.decoder = _DummyDecoder() def __call__( self, @@ -129,7 +135,6 @@ def test_returns_loss_and_full_aux_dict(self): "total_weights", "moe_lb_loss", "indexer_loss", - "moe_bias_updates", "mtp_loss", ): self.assertIn(key, aux) @@ -198,6 +203,19 @@ def test_train_step_with_gradient_clipping(self): self.assertIsInstance(new_state, nnx.State) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + def test_train_step_deepseek_aux_loss(self): + cfg, ts = _build_state() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + cfg.decoder_block = "deepseek" + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + # The robust trainer logic will correctly traverse and NOT crash, ignoring the hardcoded path + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + class TestEvalStepNNX(unittest.TestCase): """Cover the NNX branch of eval_step (lines 568-570)."""