From b2dfa4454963b1987359917e6686ae2c86e52434 Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Fri, 30 Jan 2026 13:59:42 +0100
Subject: [PATCH 1/6] Update .gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index adafbcd6..c2e66af8 100755
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ id_ed25519.pub
 *.safetensors
 *.model
 .cline_storage
+*.egg-info

From 377e020fc9489dd9b3c86cbf793a8d8e15450a69 Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Fri, 30 Jan 2026 17:06:50 +0100
Subject: [PATCH 2/6] Target flash attention backend to speedup golden values generation
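
The golden reference now pins torch.nn.functional.scaled_dot_product_attention
to the flash-attention backend instead of letting PyTorch fall back to the
unfused math path, which is much slower for the long sequences used here.
The fused kernel works on batched 4-D (batch, heads, seq, head_dim) tensors,
hence the unsqueeze(0)/squeeze(0) around the 3-D (heads, seq, d) tensors.
Minimal usage sketch (illustrative only, small shapes, not the reference code
itself):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # (batch, heads, seq, head_dim)
    q = torch.rand(1, 8, 256, 64, dtype=torch.bfloat16)
    k = torch.rand(1, 8, 256, 64, dtype=torch.bfloat16)
    v = torch.rand(1, 8, 256, 64, dtype=torch.bfloat16)

    # Inside the context manager only the flash-attention kernel may be used;
    # it raises if that backend cannot handle the inputs.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)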

---
 iron/operators/mha/reference.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/iron/operators/mha/reference.py b/iron/operators/mha/reference.py
index 3216ad78..97cb744f 100644
--- a/iron/operators/mha/reference.py
+++ b/iron/operators/mha/reference.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 import numpy as np
 from ml_dtypes import bfloat16
 
@@ -64,14 +66,16 @@ def generate_golden_reference(
 
     # MHA from PyTorch
    inv_scale = 1 / np.sqrt(K.shape[-1])
-    O = torch.nn.functional.scaled_dot_product_attention(
-        Q.to(torch.bfloat16),
-        K.to(torch.bfloat16),
-        V.to(torch.bfloat16),
-        dropout_p=0.0,
-        is_causal=True,
-        scale=inv_scale,
-    )
+
+    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        O = torch.nn.functional.scaled_dot_product_attention(
+            Q.to(torch.bfloat16).unsqueeze(0),
+            K.to(torch.bfloat16).unsqueeze(0),
+            V.to(torch.bfloat16).unsqueeze(0),
+            dropout_p=0.0,
+            is_causal=True,
+            scale=inv_scale,
+        ).squeeze(0)
 
     # Pad all tensors to multiple of 64
     Q = pad_to_multiple_of_64(Q, seq_dim=1, num_pipeline=num_pipeline)

From ac6d3cff3d1958d88c94736e00c45eb1706b1d13 Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Fri, 30 Jan 2026 17:07:57 +0100
Subject: [PATCH 3/6] Add support for GQA without repeat_interleave
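
Each query head now fetches its K/V tiles directly from the shared KV head,
using

    kv_head_idx = head_idx // (heads // num_KV_heads)

so the host no longer has to materialize per-query-head copies of K and V
with repeat_interleave, and the golden reference keeps the original
(num_kv_heads, S_kv, d) K/V tensors as kernel inputs. Illustrative sketch of
the mapping (plain PyTorch, for exposition only, not the kernel code):

    import torch

    heads, num_kv_heads, S, d = 8, 2, 16, 4
    group = heads // num_kv_heads
    k = torch.rand(num_kv_heads, S, d)

    # Old approach: expand K so every query head owns a (redundant) copy.
    k_expanded = k.repeat_interleave(group, dim=0)

    # New approach: each query head indexes the shared KV head it belongs to.
    for head_idx in range(heads):
        kv_head_idx = head_idx // group
        assert torch.equal(k_expanded[head_idx], k[kv_head_idx])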

---
 iron/operators/mha/design.py    | 16 +++++++++-------
 iron/operators/mha/op.py        |  4 ++--
 iron/operators/mha/reference.py | 11 +++++++----
 iron/operators/mha/test.py      | 33 ++++++++++++++++++++++++++-------
 4 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/iron/operators/mha/design.py b/iron/operators/mha/design.py
index d11e4ed4..12a0c5db 100644
--- a/iron/operators/mha/design.py
+++ b/iron/operators/mha/design.py
@@ -712,9 +712,9 @@ def batched_matmul_pv(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
     )
 
-    K_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    K_tiles = TensorTiler2D.group_tiler((num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
 
-    V_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    V_tiles = TensorTiler2D.group_tiler((num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
 
     O_tiles = TensorTiler2D.group_tiler(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
@@ -758,9 +758,9 @@ def legalize_tas(tas: TensorAccessSequence):
 
     if verbose:
         print(f"DMA Transfer Configuration: DRAM <-> Mem tile")
-        # print_tap_seq_info(Q_tiles, "Q")
-        # print_tap_seq_info(K_tiles, "K")
-        # print_tap_seq_info(V_tiles, "V")
+        print_tap_seq_info(Q_tiles, "Q")
+        print_tap_seq_info(K_tiles, "K")
+        print_tap_seq_info(V_tiles, "V")
         print_tap_seq_info(O_tiles, "O")
 
     # Runtime operations to move data to/from the AIE-array
@@ -788,6 +788,8 @@ def set_mha_rtps():
 
     for head_idx in range(heads):
 
+        kv_head_idx = head_idx // (heads // num_KV_heads)
+
         for q_block_idx in range(num_q_block_per_pipeline):
 
             # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete.
@@ -827,14 +829,14 @@ def set_mha_rtps():
             rt.fill(
                 inK.prod(),
                 K,
-                tap=K_tiles[head_idx],
+                tap=K_tiles[kv_head_idx],
                 placement=Tile(col=5, row=0),
                 task_group=tg,
             )
             rt.fill(
                 inV.prod(),
                 V,
-                tap=V_tiles[head_idx],
+                tap=V_tiles[kv_head_idx],
                 placement=Tile(col=6, row=0),
                 task_group=tg,
             )
diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py
index 58864519..d4017b2a 100644
--- a/iron/operators/mha/op.py
+++ b/iron/operators/mha/op.py
@@ -103,7 +103,7 @@ def set_up_artifacts(self):
         )
 
         xclbin_artifact = XclbinArtifact.new(
-            f"mha.xclbin",
+            f"{file_name_base}.xclbin",
             depends=[
                 mlir_artifact,
                 KernelArchiveArtifact.new(
@@ -139,7 +139,7 @@ def set_up_artifacts(self):
         )
 
         insts_artifact = InstsBinArtifact.new(
-            f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
+            f"{file_name_base}.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
         )
 
         self.xclbin_artifact = xclbin_artifact
diff --git a/iron/operators/mha/reference.py b/iron/operators/mha/reference.py
index 97cb744f..012f39ac 100644
--- a/iron/operators/mha/reference.py
+++ b/iron/operators/mha/reference.py
@@ -61,6 +61,9 @@ def generate_golden_reference(
     K = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
     V = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
 
+    K_original = K.clone()
+    V_original = V.clone()
+
     K = K.repeat_interleave(number_of_groups, dim=0)
     V = V.repeat_interleave(number_of_groups, dim=0)
 
@@ -79,13 +82,13 @@ def generate_golden_reference(
 
     # Pad all tensors to multiple of 64
     Q = pad_to_multiple_of_64(Q, seq_dim=1, num_pipeline=num_pipeline)
-    K = pad_to_multiple_of_64(K, seq_dim=1, num_pipeline=num_pipeline)
-    V = pad_to_multiple_of_64(V, seq_dim=1, num_pipeline=num_pipeline)
+    K_original = pad_to_multiple_of_64(K_original, seq_dim=1, num_pipeline=num_pipeline)
+    V_original = pad_to_multiple_of_64(V_original, seq_dim=1, num_pipeline=num_pipeline)
     O = pad_to_multiple_of_64(O, seq_dim=1, num_pipeline=num_pipeline)
 
     return {
         "Q": Q,
-        "K": K,
-        "V": V,
+        "K": K_original,
+        "V": V_original,
         "O": O,
     }
diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py
index 35c5087f..2342c151 100755
--- a/iron/operators/mha/test.py
+++ b/iron/operators/mha/test.py
@@ -13,8 +13,24 @@
 
 
 def generate_test_params(extensive=False):
-    params = [(16384, 64, 1, 8)]
-    names = ["mha"]
+    # (seq_len, head_dim, heads, number_of_pipeline, num_kv_heads)
+
+    names = []
+
+    params = [
+        (16384, 64, 1, 8, 0)
+    ]
+
+    if extensive:
+        params += [
+        (4096, 64, 8, 8, 4),
+        (4096, 64, 8, 8, 2),
+        (4096, 64, 8, 8, 0),
+    ]
+
+    for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params:
+        names += [f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"]
+
     return params, names
 
 
@@ -35,14 +51,17 @@ def generate_test_params(extensive=False):
     Latency=r"Latency \(us\): (?P[\d\.]+)",
     Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s",
 )
 
-@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params)
-def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context):
+@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines,num_kv_heads", all_params)
+def test_mha(seq_len: int, dim: int, num_heads: int, num_pipelines: int, num_kv_heads: int, aie_context):
+
+    print(f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}")
+
     golden_ref = generate_golden_reference(
         S_q=seq_len,
         S_kv=seq_len,
         d=dim,
         heads=num_heads,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         num_pipeline=num_pipelines,
     )
@@ -50,10 +69,10 @@ def generate_test_params(extensive=False):
         num_heads=num_heads,
         seq_len=seq_len,
         d=dim,
-        num_KV_heads=num_heads,
+        num_KV_heads=num_kv_heads,
         num_of_pipelines=num_pipelines,
         context=aie_context,
-    )
+    ) # VJUNG: TODO: Pass the verbose flag to the operator for debugging
 
     input_buffers = {
         "Q": golden_ref["Q"].flatten(),

From aaf3d0ba62f063ff469371e76cd995e42ba2a93e Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Fri, 30 Jan 2026 17:13:05 +0100
Subject: [PATCH 4/6] Lint

---
 iron/operators/mha/design.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/iron/operators/mha/design.py b/iron/operators/mha/design.py
index 12a0c5db..fd1ffdae 100644
--- a/iron/operators/mha/design.py
+++ b/iron/operators/mha/design.py
@@ -758,10 +758,10 @@ def legalize_tas(tas: TensorAccessSequence):
 
     if verbose:
         print(f"DMA Transfer Configuration: DRAM <-> Mem tile")
-        print_tap_seq_info(Q_tiles, "Q")
+        # print_tap_seq_info(Q_tiles, "Q")
         print_tap_seq_info(K_tiles, "K")
         print_tap_seq_info(V_tiles, "V")
-        print_tap_seq_info(O_tiles, "O")
+        # print_tap_seq_info(O_tiles, "O")
 
     # Runtime operations to move data to/from the AIE-array
     rt = Runtime()

From 543336cb574c24380ce4b4b009aaf3f01ea7ce81 Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Fri, 30 Jan 2026 17:16:01 +0100
Subject: [PATCH 5/6] Fix pre-push script and lint

---
 iron/operators/mha/design.py |  8 ++++++--
 iron/operators/mha/op.py     |  4 +++-
 iron/operators/mha/test.py   | 31 ++++++++++++++++++++-----------
 scripts/hooks/pre-push       |  2 +-
 4 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/iron/operators/mha/design.py b/iron/operators/mha/design.py
index fd1ffdae..519804d7 100644
--- a/iron/operators/mha/design.py
+++ b/iron/operators/mha/design.py
@@ -712,9 +712,13 @@ def batched_matmul_pv(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
     )
 
-    K_tiles = TensorTiler2D.group_tiler((num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    K_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )
 
-    V_tiles = TensorTiler2D.group_tiler((num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
+    V_tiles = TensorTiler2D.group_tiler(
+        (num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
+    )
 
     O_tiles = TensorTiler2D.group_tiler(
         (heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py
index d4017b2a..4f645be7 100644
--- a/iron/operators/mha/op.py
+++ b/iron/operators/mha/op.py
@@ -139,7 +139,9 @@ def set_up_artifacts(self):
         )
 
         insts_artifact = InstsBinArtifact.new(
-            f"{file_name_base}.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
+            f"{file_name_base}.bin",
+            depends=[mlir_artifact],
+            extra_flags=["--dynamic-objFifos"],
         )
 
         self.xclbin_artifact = xclbin_artifact
diff --git a/iron/operators/mha/test.py b/iron/operators/mha/test.py
index 2342c151..42cd1f48 100755
--- a/iron/operators/mha/test.py
+++ b/iron/operators/mha/test.py
@@ -17,19 +17,19 @@ def generate_test_params(extensive=False):
 
     names = []
 
-    params = [
-        (16384, 64, 1, 8, 0)
-    ]
+    params = [(16384, 64, 1, 8, 0)]
 
     if extensive:
         params += [
-        (4096, 64, 8, 8, 4),
-        (4096, 64, 8, 8, 2),
-        (4096, 64, 8, 8, 0),
-    ]
+            (4096, 64, 8, 8, 4),
+            (4096, 64, 8, 8, 2),
+            (4096, 64, 8, 8, 0),
+        ]
 
     for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params:
-        names += [f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"]
+        names += [
+            f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"
+        ]
 
     return params, names
 
@@ -52,9 +52,18 @@ def generate_test_params(extensive=False):
     Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s",
 )
 
 @pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines,num_kv_heads", all_params)
-def test_mha(seq_len: int, dim: int, num_heads: int, num_pipelines: int, num_kv_heads: int, aie_context):
+def test_mha(
+    seq_len: int,
+    dim: int,
+    num_heads: int,
+    num_pipelines: int,
+    num_kv_heads: int,
+    aie_context,
+):
 
-    print(f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}")
+    print(
+        f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}"
+    )
 
     golden_ref = generate_golden_reference(
@@ -72,7 +81,7 @@ def test_mha(
         num_KV_heads=num_kv_heads,
         num_of_pipelines=num_pipelines,
         context=aie_context,
-    ) # VJUNG: TODO: Pass the verbose flag to the operator for debugging
+    )  # VJUNG: TODO: Pass the verbose flag to the operator for debugging
 
     input_buffers = {
         "Q": golden_ref["Q"].flatten(),
diff --git a/scripts/hooks/pre-push b/scripts/hooks/pre-push
index b866a25f..b096c1f4 100644
--- a/scripts/hooks/pre-push
+++ b/scripts/hooks/pre-push
@@ -23,7 +23,7 @@ echo "Checking licenses with reuse..."
 if command -v reuse &> /dev/null; then
     if ! reuse lint; then
         echo "❌ License check failed"
-        echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./
+        echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./'
         FAILED=1
     else
         echo "✅ License check passed"

From 193a06584fd5cd123e2e89bcd4ea3606cf0bb076 Mon Sep 17 00:00:00 2001
From: Victor Jung
Date: Mon, 2 Feb 2026 13:44:22 +0100
Subject: [PATCH 6/6] Link verbose flag to verbose option in MHA MLIR generation
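
Running pytest with any verbosity now also enables verbose MLIR generation for
the MHA operator, while a plain run keeps it quiet: the aie_context fixture
turns request.config.option.verbose (the count of -v flags) into
AIEContext(mlir_verbose=...), and the operator reads it back with
getattr(self.context, "mlir_verbose", False) so contexts created without the
flag keep the previous behaviour. Illustrative invocations (paths as in this
repository):

    # quiet MLIR generation (default)
    pytest iron/operators/mha/test.py

    # verbose MLIR generation
    pytest -v iron/operators/mha/test.py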
--license="Apache-2.0" --recursive --skip-unrecognised ./' FAILED=1 else echo "✅ License check passed" From 193a06584fd5cd123e2e89bcd4ea3606cf0bb076 Mon Sep 17 00:00:00 2001 From: Victor Jung Date: Mon, 2 Feb 2026 13:44:22 +0100 Subject: [PATCH 6/6] Link verbose flag to verbose option in MHA MLIR generation --- conftest.py | 5 +++-- iron/common/aie_context.py | 3 ++- iron/operators/mha/op.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/conftest.py b/conftest.py index 2f4ab726..5d2d40fa 100644 --- a/conftest.py +++ b/conftest.py @@ -14,9 +14,10 @@ @pytest.fixture -def aie_context(): +def aie_context(request): """Create a fresh AIEContext for each test""" - return AIEContext() + verbose_mlir = request.config.option.verbose > 0 + return AIEContext(mlir_verbose=verbose_mlir) def pytest_addoption(parser): diff --git a/iron/common/aie_context.py b/iron/common/aie_context.py index 804499f6..702fb9f1 100644 --- a/iron/common/aie_context.py +++ b/iron/common/aie_context.py @@ -14,7 +14,7 @@ class AIEContext: """Context for managing AIE operator compilation and runtime state""" - def __init__(self, use_runlist=True): + def __init__(self, use_runlist=True, mlir_verbose=None): self.operators = [] self.static_data_pool = {} self.device_manager = AIEDeviceManager() @@ -24,6 +24,7 @@ def __init__(self, use_runlist=True): self.peano_dir = Path(aie.utils.config.peano_install_dir()) # Disable the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations for easier debugging (can tell which part of runlist execution failed) self.use_runlist = use_runlist + self.mlir_verbose = bool(mlir_verbose) self._runtime_prepared = False def register_operator(self, operator): diff --git a/iron/operators/mha/op.py b/iron/operators/mha/op.py index 4f645be7..5095c150 100644 --- a/iron/operators/mha/op.py +++ b/iron/operators/mha/op.py @@ -53,6 +53,7 @@ def set_up_artifacts(self): kv_heads = self.num_KV_heads if self.num_KV_heads > 0 else self.num_heads file_name_base = f"mha_{self.num_heads}h_{kv_heads}kv_{self.seq_len}s_{self.d}d" + mlir_verbose = getattr(self.context, "mlir_verbose", False) # Define source files mm_source = str(self.context.base_dir / "aie_kernels" / "aie2p" / "mm.cc") @@ -98,7 +99,7 @@ def set_up_artifacts(self): "number_of_pipelines": self.num_of_pipelines, "emulate_bf16_mmul_with_bfp16": True, "trace_size": 0, - "verbose": False, + "verbose": mlir_verbose, }, )