From e25c23c6a4d27135478c8c2f712946bff9122d4c Mon Sep 17 00:00:00 2001 From: Dipak Gaikwad Date: Wed, 17 Jun 2026 20:00:50 +0000 Subject: [PATCH] ffEnabled auxilillary loss free load balancing and sequence wise load balancing for Deepseek. Tested by running training loop with new tiny Deeepseek V4 model added as part of the commit, here are the logs for testing Without load balancing active logs : https://paste.googleplex.com/6421399878107136 with load balancing logs : https://paste.googleplex.com/6551357300539392 Here are the results actived for reducing the varience : 1 === DeepSeek V4 Load Balancing Variance Analysis (Step 0 vs Step 20) === 2 3 | Layer Index | Routing Type | Step 0 Var (Baseline) | Step 20 Var (Run A) | Step 20 Var (Run B) | Improvement (A vs B) | 4 |-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 5 | 0 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 6 | 1 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 7 | 2 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 8 | 3 | Top-K Routed | 7409.38 | 7509.25 | 3672.12 | 51.10% | 9 | 4 | Top-K Routed | 3158.38 | 3230.12 | 1216.00 | 62.35% | 10 | 5 | Top-K Routed | 5713.38 | 5772.75 | 2359.38 | 59.13% | 11 | 6 | Top-K Routed | 8295.25 | 8082.50 | 3674.12 | 54.54% | 12 | 7 | Top-K Routed | 4765.62 | 4614.62 | 1212.75 | 73.72% | 13 | 8 | Top-K Routed | 4960.75 | 4923.12 | 1663.50 | 66.21% | 14 | 9 | Top-K Routed | 3905.50 | 3816.25 | 1316.88 | 65.49% | 15 | 10 | Top-K Routed | 5057.00 | 4981.12 | 2257.75 | 54.67% | 16 | 11 | Top-K Routed | 10446.62 | 10381.62 | 5565.75 | 46.39% | 17 | 12 | Top-K Routed | 9538.50 | 9529.25 | 5319.12 | 44.18% | 18 | 13 | Top-K Routed | 7031.38 | 7131.25 | 3270.25 | 54.14% | 19 | 14 | Top-K Routed | 4852.00 | 4900.12 | 1906.88 | 61.09% | 20 | 15 | Top-K Routed | 9306.12 | 9342.88 | 4733.75 | 49.33% | 21 | 16 | Top-K Routed | 5811.25 | 5749.50 | 2110.88 | 63.29% | 22 | 17 | Top-K Routed | 6715.62 | 6874.25 | 2664.12 | 61.24% | 23 | 18 | Top-K Routed | 8145.50 | 7869.25 | 3383.75 | 57.00% | 24 | 19 | Top-K Routed | 6042.12 | 5908.62 | 2353.00 | 60.18% | 25 | 20 | Top-K Routed | 8559.88 | 8158.25 | 4333.38 | 46.88% | 26 | 21 | Top-K Routed | 11742.25 | 11943.62 | 7563.50 | 36.67% | 27 | 22 | Top-K Routed | 4959.62 | 5014.88 | 1998.62 | 60.15% | 28 | 23 | Top-K Routed | 7717.12 | 7751.88 | 3879.88 | 49.95% | 29 | 24 | Top-K Routed | 9017.75 | 9307.88 | 4702.75 | 49.48% | 30 | 25 | Top-K Routed | 14127.12 | 14111.25 | 8079.25 | 42.75% | 31 | 26 | Top-K Routed | 5074.25 | 5194.12 | 1675.50 | 67.74% | 32 | 27 | Top-K Routed | 11919.50 | 11204.38 | 6470.75 | 42.25% | 33 | 28 | Top-K Routed | 12241.75 | 12998.62 | 7624.12 | 41.35% | 34 | 29 | Top-K Routed | 9384.50 | 9005.00 | 5052.00 | 43.90% | 35 | 30 | Top-K Routed | 9698.62 | 9678.25 | 5231.75 | 45.94% | 36 | 31 | Top-K Routed | 12244.25 | 12392.75 | 7249.25 | 41.50% | 37 | 32 | Top-K Routed | 10030.00 | 9972.62 | 4755.50 | 52.31% | 38 | 33 | Top-K Routed | 7265.00 | 6973.62 | 3271.75 | 53.08% | 39 | 34 | Top-K Routed | 11945.50 | 11940.62 | 6076.88 | 49.11% | 40 | 35 | Top-K Routed | 12917.50 | 13740.00 | 7210.62 | 47.52% | 41 | 36 | Top-K Routed | 15011.62 | 15083.00 | 8870.62 | 41.19% | 42 | 37 | Top-K Routed | 10294.12 | 10176.25 | 5907.50 | 41.95% | 43 | 38 | Top-K Routed | 8928.62 | 9236.00 | 5136.62 | 44.38% | 44 | 39 | Top-K Routed | 15633.62 | 15171.00 | 9684.75 | 36.16% | 45 | 40 | Top-K Routed | 7687.75 | 7658.12 | 4521.25 | 40.96% | 46 | 41 | Top-K Routed | 12485.12 | 12270.38 | 6933.25 | 43.50% | 47 | 42 | Top-K Routed | 17641.25 | 17163.50 | 10974.12 | 36.06% | 48 |-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 49 | TOTAL/AVG | Top-K Only | 357681.12 | 356762.50 | 185883.62 | 47.90% | Raw data collected for this analysis: https://paste.googleplex.com/5060754624610304 https://paste.googleplex.com/5473518849490944 --- .../inspect_checkpoint.py | 2 +- src/maxtext/common/metric_logger.py | 2 +- src/maxtext/configs/models/deepseek4-284b.yml | 4 + src/maxtext/configs/models/deepseek4-tiny.yml | 69 ++++++++++ src/maxtext/configs/types.py | 7 +- src/maxtext/layers/moe.py | 6 +- src/maxtext/layers/quantizations.py | 115 ++++++++++------ src/maxtext/optimizers/optimizers.py | 18 ++- src/maxtext/trainers/pre_train/train.py | 86 +++++++++--- tests/unit/deepseek_routed_bias_test.py | 126 ++++++++++++++++++ tests/unit/metric_logger_test_coverage.py | 46 +++++++ tests/unit/optimizers_test.py | 26 +++- tests/unit/train_nnx_test.py | 20 ++- 13 files changed, 461 insertions(+), 66 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-tiny.yml create mode 100644 tests/unit/deepseek_routed_bias_test.py create mode 100644 tests/unit/metric_logger_test_coverage.py diff --git a/src/maxtext/checkpoint_conversion/inspect_checkpoint.py b/src/maxtext/checkpoint_conversion/inspect_checkpoint.py index c63f2e1161..7e9784e516 100644 --- a/src/maxtext/checkpoint_conversion/inspect_checkpoint.py +++ b/src/maxtext/checkpoint_conversion/inspect_checkpoint.py @@ -79,7 +79,7 @@ def print_structure(data_dict, output_file=""): """Utility to format and print sorted keys and shapes from a flattened dictionary.""" if output_file: # Save command - save_lines = [f"# {" ".join(sys.orig_argv)}", ""] + save_lines = [f"# {' '.join(sys.orig_argv)}", ""] for key in sorted(data_dict.keys(), key=natural_sort_key): line = f"key: {key} | {data_dict[key]}" diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..2f1a564c6d 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step): if self.config.num_experts > 1: moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0) - log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}") + log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}") if self.config.mtp_num_layers > 0: mtp_loss = scalars.get("learning/mtp_loss", 0.0) diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml index 598bbd9c1c..9d8d15777c 100644 --- a/src/maxtext/configs/models/deepseek4-284b.yml +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -49,6 +49,10 @@ num_experts_per_tok: 6 mlp_activations_limit: 10 shared_experts: 1 routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] # --- Attention configuration --- attention_type: 'compressed' diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..881043777b --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 +qk_rope_head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 98b24ff451..cee55d2d0a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -228,6 +228,7 @@ class ProfilerType(str, Enum): "deepseek3-tiny", "deepseek3.2-671b", "deepseek4-284b", + "deepseek4-tiny", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -3009,7 +3010,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK: + if ( + self.routed_bias + and self.routed_bias_update_rate > 0.0 + and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4) + ): raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.") if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts: raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 5925b0bb4e..4fa9980250 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -348,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. pre_bias_logits = output if self.use_bias: + # Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state + # management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient + # here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks. bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + output += jax.lax.stop_gradient(bias) return output, pre_bias_logits @@ -2170,7 +2173,6 @@ def dense_matmul( lb_loss = ( self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None ) - # TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001 else: lb_loss = None diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 54ebac07e1..ae61df9420 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -14,41 +14,32 @@ """Quantization library.""" +from dataclasses import dataclass import functools import json -import qwix.pallas as qpl import re -from typing import Tuple, Sequence, Callable -from dataclasses import dataclass +from typing import Callable, Sequence, Tuple -from aqt.jax.v2 import config as aqt_config from aqt.jax.v2 import aqt_tensor -from aqt.jax.v2.flax import aqt_flax -from aqt.jax.v2 import tiled_dot_general from aqt.jax.v2 import calibration - -import qwix -from qwix._src.core import dot_general_qt -from qwix._src.core import sparsity - +from aqt.jax.v2 import config as aqt_config +from aqt.jax.v2 import tiled_dot_general +from aqt.jax.v2.flax import aqt_flax +from flax import nnx +import flax.linen as nn +from flax.linen import fp8_ops +from flax.linen import initializers as flax_initializers import jax import jax.numpy as jnp from jax.tree_util import tree_flatten_with_path, tree_unflatten - -from flax.linen import fp8_ops -from flax.linen import initializers as flax_initializers -import flax.linen as nn -from flax import nnx -# Support different packaging structures across environments even within -# the same Qwix version identifier (imports from _src.utils vs _src). -try: - from qwix._src.utils import flax_util -except ImportError: - from qwix._src import flax_util # pytype: disable=import-error -from maxtext.layers import nnx_wrappers - -from maxtext.common.common_types import DType, Config +from maxtext.common.common_types import Config, DType from maxtext.inference.kvcache import KVQuant +from maxtext.layers import nnx_wrappers +import qwix +from qwix._src.core import dot_general_qt +from qwix._src.core import sparsity +from qwix._src import flax_util +import qwix.pallas as qpl # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -150,12 +141,18 @@ def _get_mixed_precision_cfg(self): return quant_dg, is_tiled, tiling_fn def _get_rhs_axis_metadata_wrapper( - self, mesh_axes: Tuple[str, ...] = (), is_tiled: bool = False, replicate_scale: bool = False + self, + mesh_axes: Tuple[str, ...] = (), + is_tiled: bool = False, + replicate_scale: bool = False, ): if self.quant_mode == aqt_flax.QuantMode.CONVERT: return None return functools.partial( - _rhs_axis_metadata_wrapper, mesh_axes=mesh_axes, is_tiled=is_tiled, replicate_scale=replicate_scale + _rhs_axis_metadata_wrapper, + mesh_axes=mesh_axes, + is_tiled=is_tiled, + replicate_scale=replicate_scale, ) def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): @@ -359,8 +356,20 @@ def __call__(self, eqn, *args, **kwargs): k = jnp.asarray(k, comp_dtype) x = jnp.asarray(x, comp_dtype) - x_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) - k_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value) + x_qdq = fp8_ops.in_qdq( + comp_dtype, + self.e4m3_dtype, + x, + self.input_scale.value, + self.input_amax_history.value, + ) + k_qdq = fp8_ops.in_qdq( + comp_dtype, + self.e4m3_dtype, + k, + self.kernel_scale.value, + self.kernel_amax_history.value, + ) y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision) @@ -386,6 +395,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): def _get_int8_quant_config(config): + """Get int8 quantization configuration.""" drhs_bits = None drhs_accumulator_dtype = None drhs_local_aqt = None @@ -553,7 +563,12 @@ def _get_aqt_fp8_default_config(config): def _get_aqt_fp8_quant_config(config): """get aqt for 8-bit floating point quantization configuration""" - cfg = aqt_config.config_v4(fwd_bits="e4m3", dlhs_bits=None, drhs_bits=None, fwd_accumulator_dtype=jnp.bfloat16) + cfg = aqt_config.config_v4( + fwd_bits="e4m3", + dlhs_bits=None, + drhs_bits=None, + fwd_accumulator_dtype=jnp.bfloat16, + ) return cfg @@ -572,7 +587,14 @@ def _dot_general_make(quant_cfg): def _get_default_mp_config(default=None): - default_config = {_W_BITS: None, _A_BITS: None, _W_SCALE: 1.0, _A_SCALE: 1.0, _TILE_SIZE: -1} + """Get default mixed precision configuration.""" + default_config = { + _W_BITS: None, + _A_BITS: None, + _W_SCALE: 1.0, + _A_SCALE: 1.0, + _TILE_SIZE: -1, + } if default: default_config.update(default) return default_config @@ -590,7 +612,10 @@ def _get_mixed_precision_quant_config(mixed_precision_config): if layer_name_re != DEFAULT: for k in quant_config: quant_config[k] = layer_quantization_config.get(k, default_mp_config[k]) - ret_config[layer_name_re] = [_dot_general_make(quant_config), quant_config["tile_size"]] + ret_config[layer_name_re] = [ + _dot_general_make(quant_config), + quant_config["tile_size"], + ] return ret_config @@ -850,7 +875,10 @@ def maybe_quantize_model(model, config): quantization_provider = get_qt_provider(config) if quantization_provider: if config.pure_nnx: - input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + input_shape = ( + config.micro_batch_size_to_train_on, + config.max_target_length, + ) dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) @@ -1012,19 +1040,25 @@ def _wrap(self, f, name=None): """ import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error - - fp8_recipe = self._recipe + from transformer_engine.common import recipe # pylint: disable=import-outside-toplevel # pytype: disable=import-error class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase): """Wrapper module for TransformerEngine quantization.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set( + self, + postfix: str = "", + _variable_collection: str | None = None, + _quantization_checkpoint_name: str | None = None, + _fp8_recipe: recipe.Recipe | None = None, + **_kwargs, + ): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, quantization_checkpoint_name="quantization", - fp8_recipe=fp8_recipe, + fp8_recipe=self._recipe, ) @nn.compact @@ -1039,9 +1073,12 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): """Placeholder for dot_general implementation in subclasses.""" import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error - def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): + def te_dot_general(generate_quantizer_set, x, kernel, dims, **_kwargs): contracting_dims, batch_dims = dims - assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot." + assert batch_dims == ( + (), + (), + ), "Batch dimensions must be empty for TransformerEngine dot." quantizer_set = generate_quantizer_set() return transformer_engine.jax.dense.dense( diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 9992d7674f..2880ae3a57 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -13,13 +13,14 @@ # limitations under the License. # pylint: disable=bare-except, consider-using-generator, too-many-positional-arguments -""" Utils that are only interesting to MaxText. """ +"""Utils that are only interesting to MaxText.""" import re import jax import jax.numpy as jnp import optax +from flax import traverse_util from optax.contrib._muon import muon from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers @@ -238,6 +239,21 @@ def get_optimizer(config, learning_rate_schedule, model=None): lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)), ) + if getattr(config, "routed_bias", False) and getattr(config, "routed_bias_update_rate", 0.0) > 0.0: + bias_regex = re.compile(".*gate.*bias.*") + + # Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the + # Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias) + # routes them to a standard AdamW optimizer *without* exposing a weight decay mask. + # To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero, + # we apply a global optax.set_to_zero() mask here. + def bias_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params} + return traverse_util.unflatten_dict(mask) + + base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn)) + return base_opt diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index e6f73928a0..f4e63fef46 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,7 +36,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding -from flax import linen as nn, nnx +from flax import linen as nn, nnx, traverse_util from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -290,12 +290,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. intermediate_outputs["logits"] = logits @@ -307,7 +301,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -423,9 +416,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -482,12 +475,47 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params: + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if ( + len(param_prefix) >= len(prefix) + and param_prefix[: len(prefix)] == prefix + and param_path[-2:] == ("gate", "bias") + ): + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, param_path)) + + old_val = ( + new_flat_params[param_path].value + if hasattr(new_flat_params[param_path], "value") + else new_flat_params[param_path] + ) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(old_val) + + new_val = old_val + jnp.array(update_val).transpose() + if hasattr(new_flat_params[param_path], "value"): + new_flat_params[param_path] = new_flat_params[param_path].replace(value=new_val) + else: + new_flat_params[param_path] = new_val + + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -508,9 +536,30 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + target = new_state.model + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for key in prefix: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + target = None + break + if target is None: + continue + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, prefix)) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -523,6 +572,7 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..11e7e9ec90 --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,126 @@ +"""Tests for DeepSeek routed bias updates.""" + +import unittest +import jax +import jax.numpy as jnp +import optax +from flax import nnx +from flax.training import train_state +from maxtext.common import train_state_nnx +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train + + +class DeepSeekRoutedBiasTest(unittest.TestCase): + + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + + def _make_dummy_data(self, batch=1, seq=16): + """Creates dummy input data for testing.""" + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + def _create_and_run_train_step(self, config_args): + """Initializes the model and runs a single training step.""" + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, "nnx") else __import__("flax.nnx", fromlist=["Rngs"]).Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) + + def _create_and_run_linen_train_step(self, config_args): + """Creates a Linen model and runs a single training step.""" + config = pyconfig.initialize(config_args) + model = models.transformer_as_linen(config, self.mesh, quant=None) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + rng = jax.random.PRNGKey(0) + variables = model.init(rng, data["inputs"], data["inputs_position"], data["inputs_segmentation"]) + ts = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=optax.sgd(0.01)) + new_state, metrics = pre_train.train_step( + model, + config, + state_mesh_shardings=None, + params_shardings=None, + state=ts, + data=data, + dropout_rng=jax.random.PRNGKey(0), + ) + return new_state, metrics + + def test_deepseek_v3_moe_routed_bias_linen(self): + """Proves that a DeepSeek V3 model with MoE layers successfully traverses the + Linen state tree and updates routed bias. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=0", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_linen_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertTrue(any(key.startswith("learning/moe_bias_before_norm") for key in metrics["scalar"])) + self.assertTrue(any(key.startswith("learning/moe_bias_update_norm") for key in metrics["scalar"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/metric_logger_test_coverage.py b/tests/unit/metric_logger_test_coverage.py new file mode 100644 index 0000000000..14e9031da6 --- /dev/null +++ b/tests/unit/metric_logger_test_coverage.py @@ -0,0 +1,46 @@ +"""Tests for MetricLogger coverage.""" + +import unittest +from unittest import mock +from maxtext.common.metric_logger import MetricLogger +from maxtext.configs import pyconfig + + +class MetricLoggerTest(unittest.TestCase): + + def test_log_train_metrics_moe_lb_loss(self): + config = pyconfig.initialize( + [ + "", + "src/maxtext/configs/base.yml", + "run_name=test_run", + "base_output_directory=/tmp/maxtext_output", + "num_experts=2", + "mtp_num_layers=0", + "base_moe_mlp_dim=64", + "base_mlp_dim=64", + "skip_jax_distributed_system=True", + ] + ) + + logger = MetricLogger(config, None) + metrics = { + "scalar": { + "learning/loss": 1.0, + "learning/lm_loss": 1.0, + "learning/total_weights": 1000, + "learning/moe_lb_loss": 0.000403, + "perf/step_time_seconds": 1.0, + "perf/per_device_tflops_per_sec": 1.0, + "perf/per_device_tokens_per_sec": 1.0, + } + } + with mock.patch("maxtext.common.metric_logger.max_logging.log") as mock_log: + logger._log_training_metrics(metrics, 1) # pylint: disable=protected-access + mock_log.assert_called() + called_args = mock_log.call_args[0][0] + self.assertIn("moe_lb_loss: 0.000403", called_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index b8eab1061e..ae96df2eaf 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -417,7 +417,7 @@ def test_get_optimizer_with_trainable_mask(self): config = pyconfig.initialize(argv) # Use a constant learning rate > 0 to ensure non-zero updates - def learning_rate_schedule(step): + def learning_rate_schedule(_): return 1.0 opt = optimizers.get_optimizer(config, learning_rate_schedule) @@ -454,7 +454,7 @@ def test_get_optimizer_without_trainable_mask(self): config = pyconfig.initialize(argv) # Use a constant learning rate > 0 to ensure non-zero updates - def learning_rate_schedule(step): + def learning_rate_schedule(_): return 1.0 opt = optimizers.get_optimizer(config, learning_rate_schedule) @@ -622,5 +622,27 @@ def __init__(self, rngs: nnx.Rngs): self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) +class TestGetOptimizerGlobalMask(unittest.TestCase): + """Tests that the global optimizer cleanly masks out the routed bias.""" + + def test_routed_bias_global_mask(self): + config = pyconfig.initialize( + ["", "src/maxtext/configs/base.yml", "routed_bias=True", "routed_bias_update_rate=0.001", "opt_type=sgd"] + ) + # We define a dummy params dict containing a routed bias and a regular weight. + # The routed bias must be completely ignored by the optimizer. + params = {"decoder": {"moe_layers": {"MoeBlock_0": {"gate": {"bias": jnp.array([1.0]), "kernel": jnp.array([1.0])}}}}} + grads = {"decoder": {"moe_layers": {"MoeBlock_0": {"gate": {"bias": jnp.array([0.5]), "kernel": jnp.array([0.5])}}}}} + # We use sgd because it's simple to test updates, but the mask logic applies + # cleanly to any base optimizer returned by get_optimizer. + opt = optimizers.get_optimizer(config, learning_rate_schedule=0.1) + opt_state = opt.init(params) + updates, _ = opt.update(grads, opt_state, params) + # The routed bias update should be exactly 0.0 (masked by set_to_zero) + self.assertEqual(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["bias"].item(), 0.0) + # The kernel should receive the SGD gradient update (-0.1 * 0.5) + self.assertTrue(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["kernel"].item() < 0.0) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 6467c6f196..ba2991ceae 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -63,6 +63,11 @@ class _Cfg: debug_sharding: bool = False weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 + decoder_block: str = "default" + + +class _DummyDecoder(nnx.Module): + pass class _TinyDecoder(nnx.Module): @@ -77,6 +82,7 @@ def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) # loss_fn shards activations against model.mesh, so the stub needs one. self.mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context")) + self.decoder = _DummyDecoder() def __call__( self, @@ -129,7 +135,6 @@ def test_returns_loss_and_full_aux_dict(self): "total_weights", "moe_lb_loss", "indexer_loss", - "moe_bias_updates", "mtp_loss", ): self.assertIn(key, aux) @@ -198,6 +203,19 @@ def test_train_step_with_gradient_clipping(self): self.assertIsInstance(new_state, nnx.State) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + def test_train_step_deepseek_aux_loss(self): + cfg, ts = _build_state() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + cfg.decoder_block = "deepseek" + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + # The robust trainer logic will correctly traverse and NOT crash, ignoring the hardcoded path + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + class TestEvalStepNNX(unittest.TestCase): """Cover the NNX branch of eval_step (lines 568-570)."""