From 50207e2aa8ef52fbaac1d2ff7cc2017ee4eb10d2 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 23 Jun 2026 17:23:16 +0800 Subject: [PATCH 1/3] fix: add missing all-gather for nnx vocab tiling --- src/maxtext/utils/vocabulary_tiling.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index b17e318570..9633528d63 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -349,6 +349,15 @@ def _reshape(inputs, out_shape, out_sharding): # custom_vjp + lax.scan boundary, which fails for tied embeddings. graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...) + # all gather only the embedding table + head_params = all_gather_over_fsdp( + head_params, + nnx.get_partition_spec(head_params), + model.mesh, + config.logical_axis_rules, + config.shard_mode, + ) + def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk): local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True) chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode) From ce352a4b01f858bd6fd513c0635609464aedb409 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 23 Jun 2026 17:24:36 +0800 Subject: [PATCH 2/3] Update base.yml --- src/maxtext/configs/base.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 39fab5b076..350efbf8e0 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1196,9 +1196,9 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false -pure_nnx_decoder: false -pure_nnx: false +enable_nnx: true +pure_nnx_decoder: true +pure_nnx: true ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net From 18391f2eaf21c40761bb76063983436cf458e65e Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 23 Jun 2026 17:49:44 +0800 Subject: [PATCH 3/3] Revert "fix: add missing all-gather for nnx vocab tiling" This reverts commit 50207e2aa8ef52fbaac1d2ff7cc2017ee4eb10d2. --- src/maxtext/utils/vocabulary_tiling.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index 9633528d63..b17e318570 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -349,15 +349,6 @@ def _reshape(inputs, out_shape, out_sharding): # custom_vjp + lax.scan boundary, which fails for tied embeddings. graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...) - # all gather only the embedding table - head_params = all_gather_over_fsdp( - head_params, - nnx.get_partition_spec(head_params), - model.mesh, - config.logical_axis_rules, - config.shard_mode, - ) - def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk): local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True) chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)