
Commit 19fb249

Address comments
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 5498223 commit 19fb249

File tree

3 files changed, +15 -4 lines changed


src/maxdiffusion/configs/base_2_base.yml

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
+mask_padding_tokens: True # Whether to mask padding tokens in the attention computation.
+attention_sharding_uniform: True # Apply the same sequence sharding rules to q in both self-attention and cross-attention.
+
 flash_block_sizes: {}
 # to override default block sizes for flash attention
 # flash_block_sizes:
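
To make the new flag concrete, here is a minimal JAX sketch of padding-token masking via segment ids. This illustrates the idea only, not maxdiffusion's splash-attention code; the helper names and the (batch, seq, dim) shapes are assumptions. Real tokens get segment id 1 and padding gets 0, and attention scores are kept only where the query and key are both real tokens of the same segment.

import jax
import jax.numpy as jnp

def segment_ids_from_lengths(lengths, max_len):
    # lengths: (batch,) count of real tokens per example; the rest is padding.
    positions = jnp.arange(max_len)[None, :]                  # (1, max_len)
    return (positions < lengths[:, None]).astype(jnp.int32)   # (batch, max_len): 1 = real, 0 = pad

def masked_attention(q, k, v, q_seg, kv_seg):
    # q: (batch, q_len, dim), k/v: (batch, kv_len, dim)
    scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    # Keep a score only when query and key share a nonzero segment id.
    keep = (q_seg[:, :, None] == kv_seg[:, None, :]) & (kv_seg[:, None, :] > 0)
    scores = jnp.where(keep, scores, -1e9)
    return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), v)

With mask_padding_tokens set to False, the kernel skips building this mask entirely, which is the faster but potentially lower-quality option described in the base_wan_14b.yml comments below.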

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 11 additions & 3 deletions

@@ -61,8 +61,16 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 0
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+
+# If mask_padding_tokens is True, we pass segment ids to splash attention so that queries do not attend to padding tokens.
+# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware such as Ironwood.
+# However, when padding tokens make up a significant share of the sequence, skipping the mask hurts quality, so this should stay True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self-attention and cross-attention.
+# 2. attention_sharding_uniform = False : heads are sharded across devices for self-attention, while the sequence is sharded
+#    for the cross-attention q.
+attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {
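
As a rough illustration of the two strategies described above (a sketch with assumed mesh-axis names and an assumed (batch, heads, seq, head_dim) layout for q, not maxdiffusion's actual sharding annotations), the difference is which axis of q carries the sharded mesh dimension:

from jax.sharding import PartitionSpec as P

# q laid out as (batch, heads, seq, head_dim); 'data' and 'fsdp' are assumed mesh axes.

# attention_sharding_uniform = True: the sequence axis of q is sharded the same
# way for self-attention and cross-attention.
uniform_self_q  = P("data", None, "fsdp", None)
uniform_cross_q = P("data", None, "fsdp", None)

# attention_sharding_uniform = False: shard heads for self-attention, but keep
# sequence sharding for the cross-attention q.
nonuniform_self_q  = P("data", "fsdp", None, None)
nonuniform_cross_q = P("data", None, "fsdp", None)
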
@@ -168,7 +176,7 @@ logical_axis_rules: [
 ['norm', 'tensor'],
 ['conv_batch', ['data','fsdp']],
 ['out_channels', 'tensor'],
-['conv_in', 'fsdp'],
+#['conv_in', 'fsdp'],
 ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
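
For reference, each entry in logical_axis_rules maps a logical axis name to mesh axes, and a name with no matching rule is left unsharded, i.e. replicated. A toy version of that lookup (not flax's actual rule-resolution API; the function and variable names are illustrative) shows the effect of commenting out ['conv_in', 'fsdp']:

from jax.sharding import PartitionSpec as P

def spec_from_rules(logical_axes, rules):
    # Toy rule resolution: look up each logical axis name; unmatched names stay None (replicated).
    lookup = {name: mesh_axes for name, mesh_axes in rules}
    return P(*(lookup.get(name) for name in logical_axes))

rules = [
    ("conv_batch", ("data", "fsdp")),
    ("out_channels", "tensor"),
    # ("conv_in", "fsdp"),   # commented out, as in this commit
    ("conv_out", "fsdp"),
]

print(spec_from_rules(("conv_in", "conv_out"), rules))
# -> PartitionSpec(None, 'fsdp'): conv_in weights are now replicated rather than fsdp-sharded.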

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -587,7 +587,7 @@ def __call__(
 width=width,
 num_frames=num_frames,
 num_channels_latents=num_channel_latents,
-) # # fusion.18
+)

 data_sharding = NamedSharding(self.mesh, P())
 # Using global_batch_size_to_train_on so not to create more config variables
