Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/inspect_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def print_structure(data_dict, output_file=""):
"""Utility to format and print sorted keys and shapes from a flattened dictionary."""
if output_file:
# Save command
save_lines = [f"# {" ".join(sys.orig_argv)}", ""]
save_lines = [f"# {' '.join(sys.orig_argv)}", ""]

for key in sorted(data_dict.keys(), key=natural_sort_key):
line = f"key: {key} | {data_dict[key]}"
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step):

if self.config.num_experts > 1:
moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0)
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}")
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}")

if self.config.mtp_num_layers > 0:
mtp_loss = scalars.get("learning/mtp_loss", 0.0)
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/models/deepseek4-284b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ num_experts_per_tok: 6
mlp_activations_limit: 10
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
attention_type: 'compressed'
Expand Down
69 changes: 69 additions & 0 deletions src/maxtext/configs/models/deepseek4-tiny.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny model config for DeepSeek V4 for CPU execution and testing

base_emb_dim: 64
base_num_query_heads: 4
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 64
base_moe_mlp_dim: 64
vocab_size: 129280
head_dim: 32
qk_rope_head_dim: 32

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 32
indexer_n_heads: 4
indexer_topk: 16

# Note: Layers (0,1) are not compressed.
# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4].
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 16
num_experts_per_tok: 4
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
attention: 'dot_product'
attention_type: 'compressed'
q_lora_rank: 16
o_groups: 4
o_lora_rank: 16
sliding_window_size: 32

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 4096
7 changes: 6 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class ProfilerType(str, Enum):
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek4-284b",
"deepseek4-tiny",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
Expand Down Expand Up @@ -3009,7 +3010,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
if (
self.routed_bias
and self.routed_bias_update_rate > 0.0
and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4)
):
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts:
raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.")
Expand Down
6 changes: 4 additions & 2 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
pre_bias_logits = output

if self.use_bias:
# Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state
# management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient
# here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks.
bias = jnp.asarray(self.bias[...], self.dtype)
output += bias
output += jax.lax.stop_gradient(bias)
return output, pre_bias_logits


Expand Down Expand Up @@ -2170,7 +2173,6 @@ def dense_matmul(
lb_loss = (
self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
)
# TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001
else:
lb_loss = None

Expand Down
46 changes: 20 additions & 26 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,32 @@

"""Quantization library."""

from dataclasses import dataclass
import functools
import json
import qwix.pallas as qpl
import re
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple

from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import calibration

import qwix
from qwix._src.core import dot_general_qt
from qwix._src.core import sparsity

from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2.flax import aqt_flax
from flax import nnx
import flax.linen as nn
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path, tree_unflatten

from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
from flax import nnx
# Support different packaging structures across environments even within
# the same Qwix version identifier (imports from _src.utils vs _src).
try:
from qwix._src.utils import flax_util
except ImportError:
from qwix._src import flax_util # pytype: disable=import-error
from maxtext.layers import nnx_wrappers

from maxtext.common.common_types import DType, Config
from maxtext.common.common_types import Config, DType
from maxtext.inference.kvcache import KVQuant
from maxtext.layers import nnx_wrappers
import qwix
from qwix._src.core import dot_general_qt
from qwix._src.core import sparsity
from qwix._src import flax_util
import qwix.pallas as qpl

# Params used to define mixed precision quantization configs
DEFAULT = "__default__" # default config
Expand Down Expand Up @@ -850,7 +841,10 @@ def maybe_quantize_model(model, config):
quantization_provider = get_qt_provider(config)
if quantization_provider:
if config.pure_nnx:
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
input_shape = (
config.micro_batch_size_to_train_on,
config.max_target_length,
)
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
Expand Down
18 changes: 17 additions & 1 deletion src/maxtext/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

# pylint: disable=bare-except, consider-using-generator, too-many-positional-arguments
""" Utils that are only interesting to MaxText. """
"""Utils that are only interesting to MaxText."""

import re
import jax
import jax.numpy as jnp

import optax
from flax import traverse_util
from optax.contrib._muon import muon
from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers

Expand Down Expand Up @@ -238,6 +239,21 @@ def get_optimizer(config, learning_rate_schedule, model=None):
lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)),
)

if getattr(config, "routed_bias", False):
bias_regex = re.compile(".*gate.*bias.*")

# Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the
# Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias)
# routes them to a standard AdamW optimizer *without* exposing a weight decay mask.
# To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero,
# we apply a global optax.set_to_zero() mask here.
def bias_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params.keys()}
return traverse_util.unflatten_dict(mask)

base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn))

return base_opt


Expand Down
86 changes: 68 additions & 18 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import jax.numpy as jnp
from jax.sharding import NamedSharding

from flax import linen as nn, nnx
from flax import linen as nn, nnx, traverse_util
from flax.linen import partitioning as nn_partitioning
from flax.nnx import variablelib

Expand Down Expand Up @@ -290,12 +290,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
else:
max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.")

# get MoE routed bias term updates
moe_bias_updates = None
if config.routed_bias and config.routed_bias_update_rate > 0.0:
nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates")
moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None)

# Add the model's primary output to the intermediates dict so it can be used
# by the acceptance rate calculation in eval_step.
intermediate_outputs["logits"] = logits
Expand All @@ -307,7 +301,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
"total_weights": total_weights,
"moe_lb_loss": moe_lb_loss,
"indexer_loss": indexer_loss,
"moe_bias_updates": moe_bias_updates,
"mtp_loss": mtp_loss,
"batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None),
}
Expand Down Expand Up @@ -423,9 +416,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data):
moe_lb_loss = aux["moe_lb_loss"]
indexer_loss = aux.get("indexer_loss", 0.0)
z_loss = aux.get("z_loss", 0.0)
moe_bias_updates = aux.get("moe_bias_updates")
mtp_loss = aux.get("mtp_loss", 0.0)
new_opt_state = None
bias_metrics = {}

if isinstance(model, nn.Module):
if config.gradient_clipping_threshold > 0:
Expand Down Expand Up @@ -482,12 +475,47 @@ def move(path, value):
else:
new_state = state.apply_gradients(grads=full_grads)

# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
# Updates the shape to be aligned with state.
moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)
# Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family.
# We dynamically traverse the PyTree to apply updates because the topology varies drastically:
# 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers.
# 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely).
# 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`.
# Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths.
if config.routed_bias and config.routed_bias_update_rate > 0.0:
flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {}))
flat_params = traverse_util.flatten_dict(new_state.params)
new_flat_params = dict(flat_params)

for path, update in flat_intermediates.items():
if path[-1] != "moe_bias_updates":
continue
prefix = path[1:-1] if path[0] == "intermediates" else path[:-1]
for param_path in flat_params:
param_prefix = param_path[1:] if param_path[0] == "params" else param_path
if (
len(param_prefix) >= len(prefix)
and param_prefix[: len(prefix)] == prefix
and param_path[-2:] == ("gate", "bias")
):
update_val = update[0] if isinstance(update, (tuple, list)) else update
name_prefix = "-".join(map(str, param_path))

old_val = (
new_flat_params[param_path].value
if hasattr(new_flat_params[param_path], "value")
else new_flat_params[param_path]
)
bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(old_val)

new_val = old_val + jnp.array(update_val).transpose()
if hasattr(new_flat_params[param_path], "value"):
new_flat_params[param_path] = new_flat_params[param_path].replace(value=new_val)
else:
new_flat_params[param_path] = new_val

bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val))

new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params))
else:
if config.gradient_clipping_threshold > 0:
grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold)
Expand All @@ -508,9 +536,30 @@ def move(path, value):
new_state = state

# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias
target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose()
if config.routed_bias and config.routed_bias_update_rate > 0.0:
flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {}))
for path, update in flat_intermediates.items():
if path[-1] != "moe_bias_updates":
continue
target = new_state.model
prefix = path[1:-1] if path[0] == "intermediates" else path[:-1]
for key in prefix:
if hasattr(target, key):
target = getattr(target, key)
elif isinstance(target, dict) and key in target:
target = target[key]
else:
target = None
break
if target is None:
continue
for _, node in nnx.iter_graph(target):
if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None:
update_val = update[0] if isinstance(update, (tuple, list)) else update
name_prefix = "-".join(map(str, prefix))
bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(node.bias.value)
node.bias.value = node.bias.value + jnp.array(update_val).transpose()
bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val))

lm_loss = xent_sum / (total_weights + EPS)
scalar_metrics = {
Expand All @@ -523,6 +572,7 @@ def move(path, value):
"learning/mtp_loss": mtp_loss,
"learning/total_weights": total_weights,
}
scalar_metrics.update(bias_metrics)
if config.use_qk_clip:
if isinstance(model, nn.Module):
new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
Expand Down
Loading
Loading