diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index eb429f5446..e6ec03cabb 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -247,11 +247,17 @@ def create_train_state_fn(): "Context parallelism with sequence packing is not supported with synthetic data. " "Please disable sequence packing (set packing=False)." ) - if config.context_parallel_strategy != "ring": + context_parallel_strategy = config.context_parallel_strategy.lower() + if context_parallel_strategy not in ("all_gather", "ring"): raise ValueError( - "Context parallelism with 'all_gather' strategy cannot be used with sequence packing. " - "Please use 'ring' strategy instead." + "Context parallelism with sequence packing supports context_parallel_strategy='all_gather' or 'ring'." ) + if ( + config.hardware in ("gpu", "gpu_multiprocess") + and config.attention == "cudnn_flash_te" + and not (context_parallel_strategy == "ring" and config.context_parallel_load_balance) + ): + raise ValueError("Packing is only supported for load balanced ring attention with context parallelism.") # Apply reordering wrapper to data iterators if context parallelism is enabled with jax.set_mesh(mesh): @@ -259,7 +265,11 @@ def create_train_state_fn(): # Determine load balancing reorder strategy based on whether packing is enabled if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO: - reorder_strategy = ReorderStrategy.STRIPED if config.packing else ReorderStrategy.DUAL_CHUNK_SWAP + reorder_strategy = ( + ReorderStrategy.STRIPED + if config.packing and config.context_parallel_strategy.lower() == "ring" + else ReorderStrategy.DUAL_CHUNK_SWAP + ) else: reorder_strategy = config.context_parallel_reorder_strategy diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 3fa8391833..908a2b989c 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -865,6 +865,82 @@ def test_tpu_flash_attention_context_parallel( f" ici_expert_parallelism={ici_expert_parallelism}.", ) + @parameterized.named_parameters( + {"testcase_name": "no_load_balance", "context_parallel_load_balance": False}, + {"testcase_name": "load_balance", "context_parallel_load_balance": True}, + ) + @pytest.mark.tpu_only + def test_tpu_flash_attention_packed_all_gather_context_parallel(self, context_parallel_load_balance): + """Test equivalence between packed dot_product and packed flash attention + all-gather context parallelism.""" + lnx = jax.random.normal( + self.rng, + shape=(self.global_batch_size, self.max_target_length, self.embed_dim), + dtype=self.dtype, + ) + tokens_per_segment = self.max_target_length // 4 + segment_ids = jnp.repeat(jnp.arange(1, 5, dtype=jnp.int32), tokens_per_segment) + positions = jnp.tile(jnp.arange(tokens_per_segment, dtype=jnp.int32), 4) + decoder_segment_ids = jnp.broadcast_to(segment_ids, (self.global_batch_size, self.max_target_length)) + decoder_positions = jnp.broadcast_to(positions, (self.global_batch_size, self.max_target_length)) + mha_generic_output, _ = self._attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + generic_state = nnx.state(self._attention_as_mha_generic) + + cfg_cp = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + ici_context_parallelism=4, + context_parallel_strategy="all_gather", + context_parallel_load_balance=context_parallel_load_balance, + packing=True, + ) + devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes) + attention_as_mha_flash_cp = Attention( + config=cfg_cp, + num_query_heads=cfg_cp.num_query_heads, + num_kv_heads=cfg_cp.num_kv_heads, + head_dim=cfg_cp.head_dim, + max_target_length=cfg_cp.max_target_length, + max_prefill_predict_length=cfg_cp.max_prefill_predict_length, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + mesh=mesh_cp, + attention_kernel="flash", + dtype=self.dtype, + dropout_rate=cfg_cp.dropout_rate, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mha_flash_cp, generic_state) + + mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mha_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ) + + self.assertTrue( + jax.numpy.allclose( + jax.device_get(mha_generic_output), + jax.device_get(mha_generic_flash_cp_output), + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), + msg="Logits from packed generic dot product and packed flash attention + all-gather context parallelism " + f"are not close. context_parallel_load_balance={context_parallel_load_balance}.", + ) + @pytest.mark.tpu_only def test_dot_product_cache_axis_order(self): all_axis_orders = tuple(itertools.permutations(range(4)))