Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,12 @@ def validate_keys(keys):
keys["global_rampup_samples"],
)

context_parallel_size = get_context_parallel_size(keys)
# TODO remove after b/435512699 resolved
if keys["context_parallel_size"] > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk":
if context_parallel_size > 1 and keys["context_parallel_load_balance"] and keys["attention_type"] == "chunk":
raise ValueError("Currently load-balanced context parallelism is not supported for chunk attention.")

validate_context_parallel_strategy_ring(
keys["context_parallel_size"], keys["context_parallel_strategy"], keys["hardware"]
)
validate_context_parallel_strategy_ring(context_parallel_size, keys["context_parallel_strategy"], keys["hardware"])

if keys["mtp_eval_target_module"] < 0:
raise ValueError("mtp_eval_target_module cannot be negative. Set to 0 to disable evaluation.")
Expand Down Expand Up @@ -817,7 +816,6 @@ def user_init(raw_keys):

raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
raw_keys["context_parallel_size"] = get_context_parallel_size(raw_keys)
raw_keys = create_parallelisms_list(raw_keys)
raw_keys = set_and_validate_pipeline_config(raw_keys)

Expand Down
15 changes: 5 additions & 10 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,11 +2191,6 @@ class DerivedValues(BaseModel):
description="Boolean flag indicating if pipeline parallelism is active across ICI or DCN.",
)

context_parallel_size: None | int = Field(
None,
description="The total size of context parallelism, derived from ICI and DCN values.",
)

num_target_devices: None | int = Field(
None,
description="The number of devices computed from topology in train_compile or jax.devices() in train",
Expand Down Expand Up @@ -2789,9 +2784,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"]

self.context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr(
self, f"dcn_{self.context_sharding}_parallelism", 1
)
if self.pipeline_parallel_layers == -1:
if self.decoder_block == DecoderBlockType.DEEPSEEK:
moe_layers = self.num_decoder_layers - self.first_num_dense_layers
Expand Down Expand Up @@ -3058,7 +3050,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr(
self, f"dcn_{self.context_sharding}_parallelism", 1
)
if context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
raise ValueError(
"Ring context parallelism strategy (context_parallel_strategy='ring') is only supported on GPUs."
Expand All @@ -3068,7 +3063,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# because test code paths may load the same config but use a different reorder path.
# Training's runtime path in max_utils.reorder_causal_load_balanced enforces this.
if (
self.context_parallel_size > 1
context_parallel_size > 1
and "gpu" not in self.hardware
and self.context_parallel_load_balance
and self.context_parallel_reorder_strategy == ReorderStrategy.STRIPED
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def tpu_flash_attention(
) -> tuple[Array, Array]:
"""TPU Flash Attention."""

cp_size = self.config.context_parallel_size
cp_size = self.mesh.shape.get(self.config.context_sharding, 1)
load_balanced_context_parallel = self.config.context_parallel_load_balance

# Transpose to ('batch', 'heads', 'length', 'kv')
Expand Down
7 changes: 4 additions & 3 deletions src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def setup_train_loop(config, recorder, devices=None):
is_training = True
init_rng = jax.random.PRNGKey(config.init_weights_seed)
mesh = maxtext_utils.get_mesh_from_config(config, devices)
context_parallel_size = mesh.shape.get(config.context_sharding, 1)
if config.pure_nnx:
# Create abstract NNX model.
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices)
Expand Down Expand Up @@ -241,7 +242,7 @@ def create_train_state_fn():
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
rampup_manager = create_rampup_manager(config, checkpoint_manager)
# Validate context parallelism with packing configuration
if config.context_parallel_size > 1 and config.packing:
if context_parallel_size > 1 and config.packing:
if config.dataset_type == "synthetic":
raise ValueError(
"Context parallelism with sequence packing is not supported with synthetic data. "
Expand All @@ -255,7 +256,7 @@ def create_train_state_fn():

# Apply reordering wrapper to data iterators if context parallelism is enabled
with jax.set_mesh(mesh):
if config.context_parallel_size > 1 and config.context_parallel_load_balance:
if context_parallel_size > 1 and config.context_parallel_load_balance:

# Determine load balancing reorder strategy based on whether packing is enabled
if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
Expand All @@ -264,7 +265,7 @@ def create_train_state_fn():
reorder_strategy = config.context_parallel_reorder_strategy

reorder_fn = maxtext_utils.get_reorder_callable(
config.context_parallel_size, config.shard_mode, reorder_strategy, config.hardware
context_parallel_size, config.shard_mode, reorder_strategy, config.hardware
)
data_iterator = map(reorder_fn, data_iterator)
if eval_data_iterator:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/attention_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def forward_with_context_expert_parallelism(
"""Get logits from attention under context/expert parallelism."""
# If load balanced cp, shuffle along seq dim for input
# This corresponds to the pre-shuffle step in training
context_parallel_size = cfg_cp.context_parallel_size
context_parallel_size = mesh_cp.shape.get(cfg_cp.context_sharding, 1)
# This helper is TPU-oriented and uses the TPU-compatible DUAL_CHUNK_SWAP reorder path.
# It does not model GPU-specific packed/striped reorder behavior.
if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance:
Expand Down
Loading