
Commit 19c18fa

Test the attention type and automatically modify the flash BlockSizes object when 'tokamax_flash' is requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
Parent: 90dc255

2 files changed: +35 −24 lines

src/maxdiffusion/max_utils.py

Lines changed: 19 additions & 10 deletions
```diff
@@ -501,17 +501,26 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
-    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
+    attention_is_tokamax = "tokamax" in config.attention_kernel
+    user_block_sizes: Dict[str, int] = config.flash_block_sizes
+    if attention_is_tokamax:
+      max_logging.log("Tokamax kernel specified. Note: Tokamax only supports the fused backward kernel, "
+                      "hence the following flash block properties will be ignored: "
+                      f"block_q: {user_block_sizes['block_q']}, "
+                      f"block_q_dq: {user_block_sizes.get('block_q_dq')}, "
+                      f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}, "
+                      f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
+                      )
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
-        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
+        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"],
+        block_kv_compute=user_block_sizes["block_kv_compute"],
+        block_kv=user_block_sizes["block_kv"],
+        block_q_dkv=user_block_sizes["block_q_dkv"],
+        block_kv_dkv=user_block_sizes["block_kv_dkv"],
+        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
+        block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
+        block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
```
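As a sanity check on the behavior above, here is a minimal standalone sketch (not part of the commit) of the resolution logic; the `SimpleNamespace` config and the concrete block values are illustrative assumptions standing in for the real `pyconfig` object:

```python
# Hypothetical stand-in for the real pyconfig object; values are examples only.
from types import SimpleNamespace

config = SimpleNamespace(
    attention_kernel="tokamax_flash",
    flash_block_sizes={
        "block_q": 512,  # ignored for tokamax kernels, as the new log message warns
        "block_kv_compute": 128,
        "block_kv": 128,
        "block_q_dkv": 256,
        "block_kv_dkv": 128,
        "block_kv_dkv_compute": 128,
    },
)

attention_is_tokamax = "tokamax" in config.attention_kernel
user = config.flash_block_sizes

# Mirrors the new BlockSizes arguments: with a tokamax kernel, block_q falls
# back to block_q_dkv (or block_kv), the dq-specific blocks are dropped, and
# the fused backward kernel is forced on.
block_q = user.get("block_q_dkv", user["block_kv"]) if attention_is_tokamax else user["block_q"]
block_q_dq = None if attention_is_tokamax else user.get("block_q_dq")
use_fused_bwd_kernel = True if attention_is_tokamax else user.get("use_fused_bwd_kernel")

assert (block_q, block_q_dq, use_fused_bwd_kernel) == (256, None, True)
```

Running the snippet confirms that a tokamax kernel overrides the user-specified `block_q` and backward-pass settings, which is exactly what the new log message warns about.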

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 16 additions & 14 deletions
```diff
@@ -133,7 +133,8 @@ def test_wan_time_text_embedding(self):
     assert timestep_proj.shape == (batch_size, time_proj_dim)
     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
-  def test_wan_block(self):
+  @pytest.mark.parametrize("attention", ["flash", "tokamax_flash"])
+  def test_wan_block(self, attention):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     pyconfig.initialize(
@@ -179,19 +180,20 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      wan_block = WanTransformerBlock(
-          rngs=rngs,
-          dim=dim,
-          ffn_dim=ffn_dim,
-          num_heads=num_heads,
-          qk_norm=qk_norm,
-          cross_attn_norm=cross_attn_norm,
-          eps=eps,
-          attention="flash",
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
-      )
+
+    wan_block = WanTransformerBlock(
+        rngs=rngs,
+        dim=dim,
+        ffn_dim=ffn_dim,
+        num_heads=num_heads,
+        qk_norm=qk_norm,
+        cross_attn_norm=cross_attn_norm,
+        eps=eps,
+        attention=attention,
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+    )
+    with mesh:
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
```
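With the parametrization above, pytest now collects `test_wan_block` twice, once per kernel name. A minimal standalone sketch of the same pattern (the test name here is hypothetical, not from the commit):

```python
import pytest

# pytest generates one test case per value, reported as
# test_block_accepts_kernel[flash] and test_block_accepts_kernel[tokamax_flash].
@pytest.mark.parametrize("attention", ["flash", "tokamax_flash"])
def test_block_accepts_kernel(attention):
    assert attention in ("flash", "tokamax_flash")
```

A single invocation such as `pytest -k test_wan_block` therefore exercises both the splash-attention path and the Tokamax path.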
