1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ id_ed25519.pub
*.safetensors
*.model
.cline_storage
*.egg-info
20 changes: 13 additions & 7 deletions iron/operators/mha/design.py
@@ -712,9 +712,13 @@ def batched_matmul_pv(
(heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
)

K_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
K_tiles = TensorTiler2D.group_tiler(
(num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
)

V_tiles = TensorTiler2D.group_tiler((heads * S_kv_pad, d), (S_kv_pad, d), (1, 1))
V_tiles = TensorTiler2D.group_tiler(
(num_KV_heads * S_kv_pad, d), (S_kv_pad, d), (1, 1)
)

O_tiles = TensorTiler2D.group_tiler(
(heads * S_q_pad, d), (number_of_pipelines_join_distribute * B_q, d), (1, 1)
@@ -759,9 +763,9 @@ def legalize_tas(tas: TensorAccessSequence):
if verbose:
print(f"DMA Transfer Configuration: DRAM <-> Mem tile")
# print_tap_seq_info(Q_tiles, "Q")
# print_tap_seq_info(K_tiles, "K")
# print_tap_seq_info(V_tiles, "V")
print_tap_seq_info(O_tiles, "O")
print_tap_seq_info(K_tiles, "K")
print_tap_seq_info(V_tiles, "V")
# print_tap_seq_info(O_tiles, "O")

# Runtime operations to move data to/from the AIE-array
rt = Runtime()
@@ -788,6 +792,8 @@ def set_mha_rtps():

for head_idx in range(heads):

kv_head_idx = head_idx // (heads // num_KV_heads)

for q_block_idx in range(num_q_block_per_pipeline):

# Initialize a group for parallel drain tasks, with fill resources free'd when drains complete.
@@ -827,14 +833,14 @@ def set_mha_rtps():
rt.fill(
inK.prod(),
K,
tap=K_tiles[head_idx],
tap=K_tiles[kv_head_idx],
placement=Tile(col=5, row=0),
task_group=tg,
)
rt.fill(
inV.prod(),
V,
tap=V_tiles[head_idx],
tap=V_tiles[kv_head_idx],
placement=Tile(col=6, row=0),
task_group=tg,
)
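design.py now sizes the K/V tilers by num_KV_heads instead of heads and maps each query head onto its shared KV head with integer division. A minimal sketch of that mapping (illustrative values, not taken from the operator; it assumes heads is divisible by num_KV_heads, as the formula in the diff requires):

```python
# Grouped-query attention head mapping, as used in the runtime fill loop above.
heads = 8          # query heads (illustrative)
num_KV_heads = 2   # shared K/V heads (illustrative)

group_size = heads // num_KV_heads   # query heads per KV head -> 4
for head_idx in range(heads):
    kv_head_idx = head_idx // group_size
    print(f"Q head {head_idx} reads K_tiles[{kv_head_idx}] / V_tiles[{kv_head_idx}]")
# With num_KV_heads == heads the mapping is the identity, so plain MHA takes the same path.
```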
6 changes: 4 additions & 2 deletions iron/operators/mha/op.py
@@ -103,7 +103,7 @@ def set_up_artifacts(self):
)

xclbin_artifact = XclbinArtifact.new(
f"mha.xclbin",
f"{file_name_base}.xclbin",
depends=[
mlir_artifact,
KernelArchiveArtifact.new(
@@ -139,7 +139,9 @@ def set_up_artifacts(self):
)

insts_artifact = InstsBinArtifact.new(
f"mha.bin", depends=[mlir_artifact], extra_flags=["--dynamic-objFifos"]
f"{file_name_base}.bin",
depends=[mlir_artifact],
extra_flags=["--dynamic-objFifos"],
)

self.xclbin_artifact = xclbin_artifact
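op.py now derives the xclbin and instruction-binary names from file_name_base instead of the hard-coded "mha" prefix. file_name_base itself is defined elsewhere in op.py and is not shown in this diff; the sketch below is only a hypothetical illustration of why a per-configuration base name matters once several parameterizations (e.g., different num_KV_heads) are built side by side.

```python
# Hypothetical helper (not the actual file_name_base used by AIEMHA): encode the
# configuration in the artifact name so builds for different parameterizations
# do not overwrite each other's mha.xclbin / mha.bin.
def make_file_name_base(seq_len: int, d: int, num_heads: int,
                        num_pipelines: int, num_kv_heads: int) -> str:
    return f"mha_{seq_len}_{d}_{num_heads}_{num_pipelines}_{num_kv_heads}"

print(make_file_name_base(4096, 64, 8, 8, 2))  # mha_4096_64_8_8_2
```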
31 changes: 19 additions & 12 deletions iron/operators/mha/reference.py
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

import numpy as np
from ml_dtypes import bfloat16

@@ -59,29 +61,34 @@ def generate_golden_reference(
K = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range
V = torch.rand(num_kv_heads, S_kv, d, dtype=torch.bfloat16) * val_range

K_original = K.clone()
V_original = V.clone()

K = K.repeat_interleave(number_of_groups, dim=0)
V = V.repeat_interleave(number_of_groups, dim=0)

# MHA from PyTorch
inv_scale = 1 / np.sqrt(K.shape[-1])
O = torch.nn.functional.scaled_dot_product_attention(
Q.to(torch.bfloat16),
K.to(torch.bfloat16),
V.to(torch.bfloat16),
dropout_p=0.0,
is_causal=True,
scale=inv_scale,
)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
O = torch.nn.functional.scaled_dot_product_attention(
Q.to(torch.bfloat16).unsqueeze(0),
K.to(torch.bfloat16).unsqueeze(0),
V.to(torch.bfloat16).unsqueeze(0),
dropout_p=0.0,
is_causal=True,
scale=inv_scale,
).squeeze(0)

# Pad all tensors to multiple of 64
Q = pad_to_multiple_of_64(Q, seq_dim=1, num_pipeline=num_pipeline)
K = pad_to_multiple_of_64(K, seq_dim=1, num_pipeline=num_pipeline)
V = pad_to_multiple_of_64(V, seq_dim=1, num_pipeline=num_pipeline)
K_original = pad_to_multiple_of_64(K_original, seq_dim=1, num_pipeline=num_pipeline)
V_original = pad_to_multiple_of_64(V_original, seq_dim=1, num_pipeline=num_pipeline)
O = pad_to_multiple_of_64(O, seq_dim=1, num_pipeline=num_pipeline)

return {
"Q": Q,
"K": K,
"V": V,
"K": K_original,
"V": V_original,
"O": O,
}
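The golden reference now draws K/V with num_kv_heads heads, keeps the compact copies (K_original/V_original) for the device buffers, and expands them with repeat_interleave only for the PyTorch computation, which is pinned to the flash-attention SDP backend. A self-contained sketch of the same idea (shapes and values are illustrative; the flash-backend context manager is omitted here because it needs a supported device/dtype combination):

```python
import math
import torch

heads, num_kv_heads, S, d = 8, 2, 128, 64
groups = heads // num_kv_heads   # query heads per KV head

Q = torch.rand(heads, S, d, dtype=torch.bfloat16)
K = torch.rand(num_kv_heads, S, d, dtype=torch.bfloat16)  # compact KV, as handed to the device
V = torch.rand(num_kv_heads, S, d, dtype=torch.bfloat16)

# Expand KV only for the reference computation; the device sees the compact buffers.
K_expanded = K.repeat_interleave(groups, dim=0)  # (heads, S, d)
V_expanded = V.repeat_interleave(groups, dim=0)

O = torch.nn.functional.scaled_dot_product_attention(
    Q, K_expanded, V_expanded,
    is_causal=True,
    scale=1 / math.sqrt(d),
)
print(O.shape)  # torch.Size([8, 128, 64])
```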
42 changes: 35 additions & 7 deletions iron/operators/mha/test.py
@@ -13,8 +13,24 @@


def generate_test_params(extensive=False):
params = [(16384, 64, 1, 8)]
names = ["mha"]
# (seq_len, head_dim, heads, number_of_pipeline, num_kv_heads)

names = []

params = [(16384, 64, 1, 8, 0)]

if extensive:
params += [
(4096, 64, 8, 8, 4),
(4096, 64, 8, 8, 2),
(4096, 64, 8, 8, 0),
]

for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params:
names += [
f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"
]

return params, names


@@ -35,25 +51,37 @@ def generate_test_params(extensive=False):
Latency=r"Latency \(us\): (?P<value>[\d\.]+)",
Bandwidth=r"Effective Bandwidth: (?P<value>[\d\.e\+-]+) GB/s",
)
@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines", all_params)
def test_mha(seq_len, dim, num_heads, num_pipelines, aie_context):
@pytest.mark.parametrize("seq_len,dim,num_heads,num_pipelines,num_kv_heads", all_params)
def test_mha(
seq_len: int,
dim: int,
num_heads: int,
num_pipelines: int,
num_kv_heads: int,
aie_context,
):

print(
f"\nTest configuration: seq_len={seq_len}, dim={dim}, num_heads={num_heads}, num_pipelines={num_pipelines}, num_kv_heads={num_kv_heads}"
)

golden_ref = generate_golden_reference(
S_q=seq_len,
S_kv=seq_len,
d=dim,
heads=num_heads,
num_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
num_pipeline=num_pipelines,
)

operator = AIEMHA(
num_heads=num_heads,
seq_len=seq_len,
d=dim,
num_KV_heads=num_heads,
num_KV_heads=num_kv_heads,
num_of_pipelines=num_pipelines,
context=aie_context,
)
) # VJUNG: TODO: Pass the verbose flag to the operator for debugging

input_buffers = {
"Q": golden_ref["Q"].flatten(),
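generate_test_params now returns five-element tuples (seq_len, head_dim, heads, number_of_pipeline, num_kv_heads) plus matching test names, so each GQA configuration shows up under its own ID in the pytest report. A short sketch of the IDs the function would produce for the tuples listed in this diff:

```python
# Reproduces only the naming scheme from generate_test_params(); the tuples are
# the ones shown above.
params = [
    (16384, 64, 1, 8, 0),   # default configuration
    (4096, 64, 8, 8, 4),    # extensive-mode configurations
    (4096, 64, 8, 8, 2),
    (4096, 64, 8, 8, 0),
]
names = [
    f"mha_{seq_len}_{head_dim}_{heads}_{number_of_pipeline}_{num_kv_heads}"
    for seq_len, head_dim, heads, number_of_pipeline, num_kv_heads in params
]
print(names[0])  # mha_16384_64_1_8_0
```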
2 changes: 1 addition & 1 deletion scripts/hooks/pre-push
@@ -23,7 +23,7 @@ echo "Checking licenses with reuse..."
if command -v reuse &> /dev/null; then
if ! reuse lint; then
echo "❌ License check failed"
echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./
echo ' Run: reuse annotate --template ApacheAMD --copyright-prefix spdx-string-c --copyright "Advanced Micro Devices, Inc. All rights reserved." --license="Apache-2.0" --recursive --skip-unrecognised ./'
FAILED=1
else
echo "✅ License check passed"