Skip to content

Commit 0c204ab

Browse files
committed
Document attention behavior
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 0dda0cf commit 0c204ab

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

docs/attention_blocks_flowchart.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Attention block sizes
2+
3+
## Description
4+
- "block_q": Block sizes (HBM TO VMEM and VREG) to tile along Q sequence in forward pass
5+
- "block_kv_compute" : Sub Block size (VMEM to VREG) of "block_kv" where compute is performed in forward pass. It must be factor or same as "block_kv"
6+
- "block_kv" : Block sizes (HBM TO VMEM) to tile along KV sequence in forward pass
7+
- "block_q_dkv" : Block sizes along Q sequence in backward pass with fused kernel to compute gradient of q, k , v. It must be factor or same as block_q
8+
- "block_kv_dkv" : Block sizes along KV sequence in backward pass. It must be factor or same as block_kv
9+
- "block_kv_dkv_compute" : Sub Block Sizes of block_kv_dkv, must be factor or same as "block_kv_dkv"
10+
- "block_q_dq" : Block sizes along Q sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_q"
11+
- "block_kv_dq" : Block sizes along KV to tiline on KV sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_kv"
12+
- "use_fused_bwd_kernel" : This means fused bwd kernel is used where DQ, DK, DV are computed in single kernel. It usually more perfomant but comes with slight HBM memory overhead.
13+
14+
## Flowchart
15+
16+
Maxdiffusion automatically adheres to this flowchart to ensure working, and there is a log that will inform you on the modifications that maxdiffusion makes to the specified block sizes.
17+
18+
![alt text](attention_blocks_flowchart.png)
19+
20+
> "tokamax_flash" uses the splash attention implementation in [tokamax-repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py) This kernel only supports fused backward pass where gradients for q,k,v are computed in a single kernel so "block_q_dq" and "block_kv_dq" are not used
21+
22+
## How block sizes matter for perfomance and accuracy
23+
24+
Block sizes key to saturating HBM bandwidth and ensuring maximum possible overlap of computation on cores with HBM use and VMEM to VREG. It is highly reccomended to tune them.
25+
26+
Block sizes also have an effect on the sequence length. Sequence length is multiple of resolution and number of frames (video), along with VAE scale down factors and patchifying ratios. This sequence length or shard of this sequence length needs to be multiple of the block sizes specified. Therefore maxdiffusion pads the sequence lengths to the nearest multiple of the block sizes. It is advisable to choose block sizes which are factor of sequence length, atleast for the Q block sizes.
27+
28+
> In cross attention Image or Video tokens are attending to text tokens sequence length of text tokens is really small and potentially smaller than specified block size so KV block sizes are overwritten to safe values in cross attention
29+
> KV block sizes must be multiple of 128 since the size of register is 8x128 and in the attention KV sequence dim lies on 128 as K is transposed.
229 KB
Loading

0 commit comments

Comments
 (0)