Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,19 +247,29 @@ def create_train_state_fn():
"Context parallelism with sequence packing is not supported with synthetic data. "
"Please disable sequence packing (set packing=False)."
)
if config.context_parallel_strategy != "ring":
context_parallel_strategy = config.context_parallel_strategy.lower()
if context_parallel_strategy not in ("all_gather", "ring"):
raise ValueError(
"Context parallelism with 'all_gather' strategy cannot be used with sequence packing. "
"Please use 'ring' strategy instead."
"Context parallelism with sequence packing supports context_parallel_strategy='all_gather' or 'ring'."
)
if (
config.hardware in ("gpu", "gpu_multiprocess")
and config.attention == "cudnn_flash_te"
and not (context_parallel_strategy == "ring" and config.context_parallel_load_balance)
):
raise ValueError("Packing is only supported for load balanced ring attention with context parallelism.")

# Apply reordering wrapper to data iterators if context parallelism is enabled
with jax.set_mesh(mesh):
if config.context_parallel_size > 1 and config.context_parallel_load_balance:

# Determine load balancing reorder strategy based on whether packing is enabled
if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
reorder_strategy = ReorderStrategy.STRIPED if config.packing else ReorderStrategy.DUAL_CHUNK_SWAP
reorder_strategy = (
ReorderStrategy.STRIPED
if config.packing and config.context_parallel_strategy.lower() == "ring"
else ReorderStrategy.DUAL_CHUNK_SWAP
)
else:
reorder_strategy = config.context_parallel_reorder_strategy

Expand Down
76 changes: 76 additions & 0 deletions tests/unit/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,82 @@ def test_tpu_flash_attention_context_parallel(
f" ici_expert_parallelism={ici_expert_parallelism}.",
)

@parameterized.named_parameters(
    {"testcase_name": "no_load_balance", "context_parallel_load_balance": False},
    {"testcase_name": "load_balance", "context_parallel_load_balance": True},
)
@pytest.mark.tpu_only
def test_tpu_flash_attention_packed_all_gather_context_parallel(self, context_parallel_load_balance):
  """Compare packed flash attention under all-gather context parallelism with packed dot-product attention.

  Every batch row is packed with four equal-length segments (ids 1..4) whose
  positions restart at 0 in each segment. The reference output comes from
  `self._attention_as_mha_generic`; its state is then copied into a "flash"
  attention module built on a mesh with `ici_context_parallelism=4`,
  `context_parallel_strategy="all_gather"` and `packing=True`, and the two
  outputs must agree within rtol/atol of 1e-1.

  Args:
    context_parallel_load_balance: Forwarded to the config as
      `context_parallel_load_balance`; the test runs once with each value.
  """
  num_segments = 4
  batch_seq_shape = (self.global_batch_size, self.max_target_length)

  # The same random activations are used as both the query and key/value inputs.
  inputs = jax.random.normal(
      self.rng,
      shape=(*batch_seq_shape, self.embed_dim),
      dtype=self.dtype,
  )

  # Build one packed row, then share it across the batch.
  # NOTE(review): broadcast_to needs segment_len * num_segments == max_target_length,
  # i.e. max_target_length divisible by 4 -- confirm against the test config.
  segment_len = self.max_target_length // num_segments
  row_segment_ids = jnp.repeat(jnp.arange(1, num_segments + 1, dtype=jnp.int32), segment_len)
  row_positions = jnp.tile(jnp.arange(segment_len, dtype=jnp.int32), num_segments)
  decoder_segment_ids = jnp.broadcast_to(row_segment_ids, batch_seq_shape)
  decoder_positions = jnp.broadcast_to(row_positions, batch_seq_shape)

  # Reference: the generic attention module on the packed inputs.
  reference_output, _ = self._attention_as_mha_generic(
      inputs,
      inputs,
      decoder_segment_ids=decoder_segment_ids,
      inputs_positions=decoder_positions,
      deterministic=True,
      model_mode=MODEL_MODE_TRAIN,
  )
  reference_state = nnx.state(self._attention_as_mha_generic)

  # Config and mesh for the context-parallel run under test.
  cp_overrides = {
      "ici_context_parallelism": 4,
      "context_parallel_strategy": "all_gather",
      "context_parallel_load_balance": context_parallel_load_balance,
      "packing": True,
  }
  cp_config = pyconfig.initialize(
      [sys.argv[0], get_test_config_path()],
      **self.config_arguments,
      **cp_overrides,
  )
  cp_mesh = Mesh(maxtext_utils.create_device_mesh(cp_config), cp_config.mesh_axes)

  flash_cp_attention = Attention(
      config=cp_config,
      num_query_heads=cp_config.num_query_heads,
      num_kv_heads=cp_config.num_kv_heads,
      head_dim=cp_config.head_dim,
      max_target_length=cp_config.max_target_length,
      max_prefill_predict_length=cp_config.max_prefill_predict_length,
      inputs_q_shape=inputs.shape,
      inputs_kv_shape=inputs.shape,
      mesh=cp_mesh,
      attention_kernel="flash",
      dtype=self.dtype,
      dropout_rate=cp_config.dropout_rate,
      model_mode=MODEL_MODE_PREFILL,
      rngs=self.nnx_rng,
  )
  # Copy the reference module's state so only the kernel and sharding differ.
  nnx.update(flash_cp_attention, reference_state)

  flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
      cp_config,
      cp_mesh,
      flash_cp_attention,
      inputs,
      decoder_segment_ids,
      decoder_positions,
  )

  outputs_match = jnp.allclose(
      jax.device_get(reference_output),
      jax.device_get(flash_cp_output),
      rtol=1e-01,
      atol=1e-01,
      equal_nan=False,
  )
  self.assertTrue(
      outputs_match,
      msg="Logits from packed generic dot product and packed flash attention + all-gather context parallelism "
      f"are not close. context_parallel_load_balance={context_parallel_load_balance}.",
  )

@pytest.mark.tpu_only
def test_dot_product_cache_axis_order(self):
all_axis_orders = tuple(itertools.permutations(range(4)))
Expand Down
Loading