Commit d316e79

Merge conflict error
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 9660fa0 commit d316e79

File tree

1 file changed: 0 additions, 5 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 5 deletions
@@ -195,14 +195,9 @@ def _tpu_flash_attention(
       block_q_dkv=min(q_max_block_size, query.shape[2]),
       block_kv_dkv=min(kv_max_block_size, key.shape[2]),
       block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-<<<<<<< Updated upstream
-      block_q_dq=min(q_max_block_size, query.shape[2]),
-      block_kv_dq=min(kv_max_block_size, query.shape[2]),
-=======
       block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
       block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
->>>>>>> Stashed changes
   )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
