File tree Expand file tree Collapse file tree 1 file changed +0
-5
lines changed Expand file tree Collapse file tree 1 file changed +0
-5
lines changed Original file line number Diff line number Diff line change @@ -195,14 +195,9 @@ def _tpu_flash_attention(
195195 block_q_dkv = min (q_max_block_size , query .shape [2 ]),
196196 block_kv_dkv = min (kv_max_block_size , key .shape [2 ]),
197197 block_kv_dkv_compute = min (kv_max_block_size , query .shape [2 ]),
198- << << << < Updated upstream
199- block_q_dq = min (q_max_block_size , query .shape [2 ]),
200- block_kv_dq = min (kv_max_block_size , query .shape [2 ]),
201- == == == =
202198 block_q_dq = None if attention_kernel == "tokamax_flash" else min (q_max_block_size , query .shape [2 ]),
203199 block_kv_dq = None if attention_kernel == "tokamax_flash" else min (kv_max_block_size , query .shape [2 ]),
204200 use_fused_bwd_kernel = True if attention_kernel == "tokamax_flash" else False ,
205- > >> >> >> Stashed changes
206201 )
207202 num_fsdp_shards = mesh .shape ["fsdp" ]
208203 query = _reshape_data_for_flash (query , heads )
You can’t perform that action at this time.
0 commit comments