From c72275816edddaad495848e96ca965057a89ce67 Mon Sep 17 00:00:00 2001 From: continuousml Date: Sat, 20 Jun 2026 22:35:18 -0700 Subject: [PATCH] Use resolved mesh size for context parallel sharding --- src/maxtext/configs/pyconfig_deprecated.py | 8 +++----- src/maxtext/configs/types.py | 15 +++++---------- src/maxtext/layers/attention_op.py | 2 +- src/maxtext/utils/train_utils.py | 7 ++++--- tests/utils/attention_test_util.py | 2 +- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 089b00217e..3e6329b7aa 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -249,13 +249,12 @@ def validate_keys(keys): keys["global_rampup_samples"], ) + context_parallel_size = get_context_parallel_size(keys) # TODO remove after b/435512699 resolved - if keys["context_parallel_size"] > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk": + if context_parallel_size > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk": raise ValueError("Currently load-balanced context parallelism is not supported for chunk attention.") - validate_context_parallel_strategy_ring( - keys["context_parallel_size"], keys["context_parallel_strategy"], keys["hardware"] - ) + validate_context_parallel_strategy_ring(context_parallel_size, keys["context_parallel_strategy"], keys["hardware"]) if keys["mtp_eval_target_module"] < 0: raise ValueError("mtp_eval_target_module cannot be negative. Set to 0 to disable evaluation.") @@ -817,7 +816,6 @@ def user_init(raw_keys): raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - raw_keys["context_parallel_size"] = get_context_parallel_size(raw_keys) raw_keys = create_parallelisms_list(raw_keys) raw_keys = set_and_validate_pipeline_config(raw_keys) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0d64347d60..e83e263c7b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2191,11 +2191,6 @@ class DerivedValues(BaseModel): description="Boolean flag indicating if pipeline parallelism is active across ICI or DCN.", ) - context_parallel_size: None | int = Field( - None, - description="The total size of context parallelism, derived from ICI and DCN values.", - ) - num_target_devices: None | int = Field( None, description="The number of devices computed from topology in train_compile or jax.devices() in train", @@ -2789,9 +2784,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"] self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"] - self.context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr( - self, f"dcn_{self.context_sharding}_parallelism", 1 - ) if self.pipeline_parallel_layers == -1: if self.decoder_block == DecoderBlockType.DEEPSEEK: moe_layers = self.num_decoder_layers - self.first_num_dense_layers @@ -3058,7 +3050,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": + context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr( + self, f"dcn_{self.context_sharding}_parallelism", 1 + ) + if context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( "Ring context parallelism strategy (context_parallel_strategy='ring') is only supported on GPUs." @@ -3068,7 +3063,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # because test code paths may load the same config but use a different reorder path. # Training's runtime path in max_utils.reorder_causal_load_balanced enforces this. if ( - self.context_parallel_size > 1 + context_parallel_size > 1 and "gpu" not in self.hardware and self.context_parallel_load_balance and self.context_parallel_reorder_strategy == ReorderStrategy.STRIPED diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index b3c3f296f4..063663fb2f 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1168,7 +1168,7 @@ def tpu_flash_attention( ) -> tuple[Array, Array]: """TPU Flash Attention.""" - cp_size = self.config.context_parallel_size + cp_size = self.mesh.shape.get(self.config.context_sharding, 1) load_balanced_context_parallel = self.config.context_parallel_load_balance # Transpose to ('batch', 'heads', 'length', 'kv') diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index eb429f5446..27adc56961 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -214,6 +214,7 @@ def setup_train_loop(config, recorder, devices=None): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) mesh = maxtext_utils.get_mesh_from_config(config, devices) + context_parallel_size = mesh.shape.get(config.context_sharding, 1) if config.pure_nnx: # Create abstract NNX model. _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) @@ -241,7 +242,7 @@ def create_train_state_fn(): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) rampup_manager = create_rampup_manager(config, checkpoint_manager) # Validate context parallelism with packing configuration - if config.context_parallel_size > 1 and config.packing: + if context_parallel_size > 1 and config.packing: if config.dataset_type == "synthetic": raise ValueError( "Context parallelism with sequence packing is not supported with synthetic data. " @@ -255,7 +256,7 @@ def create_train_state_fn(): # Apply reordering wrapper to data iterators if context parallelism is enabled with jax.set_mesh(mesh): - if config.context_parallel_size > 1 and config.context_parallel_load_balance: + if context_parallel_size > 1 and config.context_parallel_load_balance: # Determine load balancing reorder strategy based on whether packing is enabled if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO: @@ -264,7 +265,7 @@ def create_train_state_fn(): reorder_strategy = config.context_parallel_reorder_strategy reorder_fn = maxtext_utils.get_reorder_callable( - config.context_parallel_size, config.shard_mode, reorder_strategy, config.hardware + context_parallel_size, config.shard_mode, reorder_strategy, config.hardware ) data_iterator = map(reorder_fn, data_iterator) if eval_data_iterator: diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index 23188caf0c..29777d9549 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -187,7 +187,7 @@ def forward_with_context_expert_parallelism( """Get logits from attention under context/expert parallelism.""" # If load balanced cp, shuffle along seq dim for input # This corresponds to the pre-shuffle step in training - context_parallel_size = cfg_cp.context_parallel_size + context_parallel_size = mesh_cp.shape.get(cfg_cp.context_sharding, 1) # This helper is TPU-oriented and uses the TPU-compatible DUAL_CHUNK_SWAP reorder path. # It does not model GPU-specific packed/striped reorder behavior. if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: