@@ -501,17 +501,26 @@ def get_flash_block_sizes(config):
501501 """Create custom flash attention BlockSizes."""
502502 flash_block_sizes = None
503503 if len (config .flash_block_sizes .keys ()) > 0 :
504- use_fused_bwd_kernel = config .flash_block_sizes .get ("use_fused_bwd_kernel" , False )
+    attention_is_tokamax = "tokamax" in config.attention_kernel
+    user_block_sizes: Dict[str, int] = config.flash_block_sizes
+    if attention_is_tokamax:
+      max_logging.log("Tokamax kernel specified. Note: Tokamax only supports the fused backward kernel, "
+          "so the following flash block properties from the config will be ignored: "
+          f"block_q: {user_block_sizes['block_q']}, "
+          f"block_q_dq: {user_block_sizes.get('block_q_dq')}, "
+          f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}, "
+          f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
+      )
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
-        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
+        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"],
+        block_kv_compute=user_block_sizes["block_kv_compute"],
+        block_kv=user_block_sizes["block_kv"],
+        block_q_dkv=user_block_sizes["block_q_dkv"],
+        block_kv_dkv=user_block_sizes["block_kv_dkv"],
+        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
+        block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
+        block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
 
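For reviewers, here is a minimal sketch of what the tokamax branch produces. The `BlockSizes` dataclass and the literal block values below are hypothetical stand-ins (the real class comes from `splash_attention_kernel`, and `value_or_none(d, k)` is assumed to behave like `d.get(k)`); the sketch only illustrates which fields the tokamax path overrides.

```python
# Hypothetical sketch of the tokamax branch above. BlockSizes here is a
# simplified stand-in for splash_attention_kernel.BlockSizes, and the block
# values are made up; value_or_none(d, k) is assumed to behave like d.get(k).
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class BlockSizes:
  block_q: int
  block_kv_compute: int
  block_kv: int
  block_q_dkv: int
  block_kv_dkv: int
  block_kv_dkv_compute: int
  block_q_dq: Optional[int] = None
  block_kv_dq: Optional[int] = None
  use_fused_bwd_kernel: Optional[bool] = None


user_block_sizes: Dict[str, int] = {
    "block_q": 512,
    "block_kv_compute": 128,
    "block_kv": 128,
    "block_q_dkv": 256,
    "block_kv_dkv": 128,
    "block_kv_dkv_compute": 128,
    "block_q_dq": 256,   # ignored on the tokamax path
    "block_kv_dq": 128,  # ignored on the tokamax path
}
attention_is_tokamax = True  # e.g. "tokamax" in config.attention_kernel

sizes = BlockSizes(
    # The tokamax path reuses block_q_dkv (falling back to block_kv) as block_q.
    block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"])
    if attention_is_tokamax
    else user_block_sizes["block_q"],
    block_kv_compute=user_block_sizes["block_kv_compute"],
    block_kv=user_block_sizes["block_kv"],
    block_q_dkv=user_block_sizes["block_q_dkv"],
    block_kv_dkv=user_block_sizes["block_kv_dkv"],
    block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
    block_q_dq=None if attention_is_tokamax else user_block_sizes.get("block_q_dq"),
    block_kv_dq=None if attention_is_tokamax else user_block_sizes.get("block_kv_dq"),
    use_fused_bwd_kernel=True if attention_is_tokamax else user_block_sizes.get("use_fused_bwd_kernel"),
)

assert sizes.block_q == 256               # taken from block_q_dkv, not block_q
assert sizes.block_q_dq is None           # dq blocks dropped: fused backward only
assert sizes.use_fused_bwd_kernel is True
```

The assertions encode the intended invariant: on the tokamax path the per-pass dq block sizes are dropped and the fused backward kernel is always enabled, regardless of what the user configured.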