diff --git a/.gitignore b/.gitignore
index adafbcd6..c2e66af8 100755
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ id_ed25519.pub
 *.safetensors
 *.model
 .cline_storage
+*.egg-info
diff --git a/iron/operators/mha/design.py b/iron/operators/mha/design.py
index d11e4ed4..519804d7 100644
--- a/iron/operators/mha/design.py
+++ b/iron/operators/mha/design.py
@@ -712,9 +712,13 @@ def batched_matmul_pv(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
     )
 
-    K_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    K_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )
 
-    V_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    V_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )
 
     O_tiles = TensorTiler2D.group_tiler(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
@@ -759,9 +763,9 @@ def legalize_tas(tas: TensorAccessSequence):
     if verbose:
         print(f"DMA Transfer Configuration: DRAM <-> Mem tile")
         # print_tap_seq_info(Q_tiles, "Q")
-        # print_tap_seq_info(K_tiles, "K")
-        # print_tap_seq_info(V_tiles, "V")
-        print_tap_seq_info(O_tiles, "O")
+        print_tap_seq_info(K_tiles, "K")
+        print_tap_seq_info(V_tiles, "V")
+        # print_tap_seq_info(O_tiles, "O")
 
     # Runtime operations to move data to/from the AIE-array
     rt = Runtime()
@@ -788,6 +792,8 @@ def set_mha_rtps():
 
         for head_idx in range(heads):
 
+            kv_head_idx = head_idx // (heads // num_KV_heads)
+
             for q_block_idx in range(num_q_block_per_pipeline):
 
                 # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete.
@@ -827,14 +833,14 @@ def set_mha_rtps():
                 rt.fill(
                     inK.prod(),
                     K,
-                    tap=K_tiles[head_idx],
+                    tap=K_tiles[kv_head_idx],
                     placement=Tile(col=5, row=0),
                     task_group=tg,
                 )
                 rt.fill(
                     inV.prod(),
                     V,
-                    tap=V_tiles[head_idx],
+                    tap=V_tiles[kv_head_idx],
                     placement=Tile(col=6, row=0),
                     task_group=tg,
                 )
diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py
index 58864519..4f645be7 100644
--- a/iron/operators/mha/op.py
+++ b/iron/operators/mha/op.py
@@ -103,7 +103,7 @@ def set_up_artifacts(self):
         )
 
         xclbin_artifact = XclbinArtifact.new(
-            f"mha.xclbin",
+            f"{file_name_base}.xclbin",
             depends=[
                 mlir_artifact,
                 KernelArchiveArtifact.new(
@@ -139,7 +139,9 @@ def set_up_artifacts(self):
         )
 
         insts_artifact = InstsBinArtifact.new(
-            f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
+            f"{file_name_base}.bin",
+            depends=[mlir_artifact],
+            extra_flags=["--dynamic-objFifos"],
         )
 
         self.xclbin_artifact = xclbin_artifact
diff --git a/iron/operators/mha/reference.py b/iron/operators/mha/reference.py
index 3216ad78..012f39ac 100644
--- a/iron/operators/mha/reference.py
+++ b/iron/operators/mha/reference.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 import numpy as np
 from ml_dtypes import bfloat16
 
@@ -59,29 +61,34 @@ def generate_golden_reference(
     K = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
     V = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
 
+    K_original = K.clone()
+    V_original = V.clone()
+
     K = K.repeat_interleave(number_of_groups, dim=0)
     V = V.repeat_interleave(number_of_groups, dim=0)
 
     # MHA from PyTorch
     inv_scale = 1 / np.sqrt(K.shape[-1])
-    O = torch.nn.functional.scaled_dot_product_attention(
-        Q.to(torch.bfloat16),
-        K.to(torch.bfloat16),
-        V.to(torch.bfloat16),
-        dropout_p=0.0,
-        is_causal=True,
-        scale=inv_scale,
-    )
+
+    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        O = torch.nn.functional.scaled_dot_product_attention(
+            Q.to(torch.bfloat16).unsqueeze(0),
+            K.to(torch.bfloat16).unsqueeze(0),
+            V.to(torch.bfloat16).unsqueeze(0),
+            dropout_p=0.0,
+            is_causal=True,
+            scale=inv_scale,
+        ).squeeze(0)
 
     # Pad all tensors to multiple of 64
     Q = pad_to_multiple_of_64(Q, seq_dim=1, num_pipeline=num_pipeline)
-    K = pad_to_multiple_of_64(K, seq_dim=1, num_pipeline=num_pipeline)
-    V = pad_to_multiple_of_64(V, seq_dim=1, num_pipeline=num_pipeline)
+    K_original = pad_to_multiple_of_64(K_original, seq_dim=1, num_pipeline=num_pipeline)
+    V_original = pad_to_multiple_of_64(V_original, seq_dim=1, num_pipeline=num_pipeline)
     O = pad_to_multiple_of_64(O, seq_dim=1, num_pipeline=num_pipeline)
 
     return {
         "Q": Q,
-        "K": K,
-        "V": V,
+        "K": K_original,
+        "V": V_original,
         "O": O,
     }
diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py
index 35c5087f..42cd1f48 100755
--- a/iron/operators/mha/test.py
+++ b/iron/operators/mha/test.py
@@ -13,8 +13,24 @@
 
 
 def generate_test_params(extensive=False):
-    params = [(16384, 64, 1, 8)]
-    names = ["mha"]
+    # (seq_len, head_dim, heads, number_of_pipeline, num_kv_heads)
+
+    names = []
+
+    params = [(16384, 64, 1, 8, 0)]
+
+    if extensive:
+        params += [
+            (4096, 64, 8, 8, 4),
+            (4096, 64, 8, 8, 2),
+            (4096, 64, 8, 8, 0),
+        ]
+
+    for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params:
+        names += [
+            f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"
+        ]
+
     return params, names
 
 
@@ -35,14 +51,26 @@ def generate_test_params(extensive=False):
     Latency=r"Latency \(us\): (?P<latency>[\d\.]+)",
     Bandwidth=r"Effective Bandwidth: (?P<bandwidth>[\d\.e\+-]+) GB/s",
 )
-@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params)
-def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context):
+@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines,num_kv_heads", all_params)
+def test_mha(
+    seq_len: int,
+    dim: int,
+    num_heads: int,
+    num_pipelines: int,
+    num_kv_heads: int,
+    aie_context,
+):
+
+    print(
+        f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}"
+    )
+
     golden_ref = generate_golden_reference(
         S_q=seq_len,
         S_kv=seq_len,
         d=dim,
         heads=num_heads,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         num_pipeline=num_pipelines,
     )
 
@@ -50,10 +78,10 @@ def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context):
         num_heads=num_heads,
         seq_len=seq_len,
         d=dim,
-        num_KV_heads=num_heads,
+        num_KV_heads=num_kv_heads,
         num_of_pipelines=num_pipelines,
         context=aie_context,
-    )
+    )  # VJUNG: TODO: Pass the verbose flag to the operator for debugging
 
     input_buffers = {
         "Q": golden_ref["Q"].flatten(),
diff --git a/scripts/hooks/pre-push b/scripts/hooks/pre-push
index b866a25f..b096c1f4 100644
--- a/scripts/hooks/pre-push
+++ b/scripts/hooks/pre-push
@@ -23,7 +23,7 @@ echo "Checking licenses with reuse..."
 if command -v reuse &> /dev/null; then
     if ! reuse lint; then
         echo "❌ License check failed"
-        echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./
+        echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./'
         FAILED=1
     else
         echo "✅ License check passed"
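
A minimal sketch (not part of the patch, using illustrative values heads=8 and num_KV_heads=4) of the query-head to KV-head mapping that design.py now applies, checked against the repeat_interleave expansion used in reference.py:

# Hypothetical standalone check; heads / num_KV_heads values are assumptions for illustration.
heads = 8
num_KV_heads = 4
number_of_groups = heads // num_KV_heads

# reference.py expands K/V with repeat_interleave(number_of_groups, dim=0),
# so expanded query head h reads from original KV head h // number_of_groups.
expanded_to_kv = [kv for kv in range(num_KV_heads) for _ in range(number_of_groups)]

for head_idx in range(heads):
    kv_head_idx = head_idx // (heads // num_KV_heads)  # mapping added in design.py
    assert kv_head_idx == expanded_to_kv[head_idx]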