Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "ingress/KernelBench/KernelBench"]
path = ingress/KernelBench/KernelBench
url = https://github.com/ScalingIntelligence/KernelBench.git
1 change: 1 addition & 0 deletions ingress/KernelBench/KernelBench
Submodule KernelBench added at 018c59
165 changes: 165 additions & 0 deletions ingress/KernelBench/convert-kernel-bench-to-mlir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#!/usr/bin/env python3

import sys
from pathlib import Path

from mlir import ir, passmanager
from lighthouse.ingress import torch as torch_ingress


kernels_as_pytorch_folder = Path(__file__).parent / "KernelBench" / "KernelBench"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this depends on where the git was cloned in the bash script, perhaps that last step (clone) could be done in this script as well?

Copy link
Contributor Author

@rolfmorel rolfmorel Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure.

Doing a git clone in either script feels unclean. I also don't like the idea of it being a submodule as that then seems to imply you have to clone KernelBench to do anything useful with lighthouse. It seems to me KernelBench will be just one source of ingress compute graphs of interest, with it potentially making sense to allow users/CI to opt-in to which paths they want to run tests with. What's the right mechanism for that? I am not sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KernelBench is NOT an ingress. Torch-MLIR is.

We now have three PRs that work with FX importer, none using the other. We should have one FX importer script that is used by others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The importer impasse has been resolved.

Whether the KernelBench submodule and converter script should live in this "ingress" directory is up to taste. I will defer to anyone who suggests a better path.


if not (kernels_as_pytorch_folder.exists() and kernels_as_pytorch_folder.is_dir()):
print(
"ERROR: KernelBench repo not found.\n"
"NOTE: Pull in dependency with: git submodule update "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It probably needs update --init when the dir is not present at all

Copy link
Contributor Author

@rolfmorel rolfmorel Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it yesterday without any directory there. That is, started out clean, I ran the script and got the error whereupon I copied the (--init-less) command and ran that. After that I could run the script without error.

Might depend on git version though. If someone knows or encounters that this command isn't sufficient in all cases, please let me know and I will amend.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looks like the error only occurs for the first time before overall submodule initialization.
Now, if I only remove the cloned KernelBench dir, then a simple submodule update seems sufficient.
Might be down to some git caching?

However, I'm able to reproduce the error in a freshly cloned repo.
Running git submodule update ingress/KernelBench/KernelBench returns:

Submodule path 'ingress/KernelBench/KernelBench' not initialized
Maybe you want to use 'update --init'?

Not sure about best practices here. But not necessarily a blocker as it's easily fixable by following git's error msg.

+ str(kernels_as_pytorch_folder.parent.relative_to(Path.cwd()))
+ "",
file=sys.stderr,
)
sys.exit(1)


kernels_as_pytorch_level1 = kernels_as_pytorch_folder / "level1"
kernels_as_pytorch_level2 = kernels_as_pytorch_folder / "level2"

kernels_as_mlir_folder = Path(__file__).parent / "cache"
kernels_as_mlir_level1 = kernels_as_mlir_folder / "level1"
kernels_as_mlir_level1.mkdir(parents=True, exist_ok=True)
kernels_as_mlir_level2 = kernels_as_mlir_folder / "level2"
kernels_as_mlir_level2.mkdir(parents=True, exist_ok=True)

level1, level2 = Path("level1"), Path("level2")
ignore_list = [
level1 / "12_Matmul_with_diagonal_matrices_.py", # torch.operator "torch.aten.diag"
level1
/ "34_InstanceNorm.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (93898875033000)
level1
/ "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
level1
/ "89_cumsum.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
level1
/ "90_cumprod.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
level1
/ "91_cumsum_reverse.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
level1
/ "92_cumsum_exclusive.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
level1
/ "93_masked_cumsum.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
level1
/ "95_CrossEntropyLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
level1
/ "96_HuberLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
level1
/ "97_ScaledDotProductAttention.py", # AssertionError: Torch not compiled with CUDA enabled
level1
/ "99_TripletMarginLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
level2
/ "17_Conv2d_InstanceNorm_Divide.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
level2
/ "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
level2
/ "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "43_Conv3d_Max_LogSumExp_ReLU.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "45_Gemm_Sigmoid_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "52_Conv2d_Activation_BatchNorm.py", # failed to legalize operation 'torch.operator'
level2 / "55_Matmul_MaxPool_Sum_Scale.py", # MLIR file too big: 16G
level2 / "59_Matmul_Swish_Scaling.py", # MLIR file too big: 16G
level2 / "56_Matmul_Sigmoid_Sum.py", # MLIR file too big: 16G
level2 / "66_Matmul_Dropout_Softmax.py", # MLIR file too big: 4G
level2 / "68_Matmul_Min_Subtract.py", # MLIR file too big: 4G
level2 / "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm.py", # MLIR file too big: 1G
level2 / "33_Gemm_Scale_BatchNorm.py", # MLIR file too big: 1G
level2 / "88_Gemm_GroupNorm_Swish_Multiply_Swish.py", # MLIR file too big: 1G
level2 / "75_Gemm_GroupNorm_Min_BiasAdd.py", # MLIR file too big: 1G
level2 / "84_Gemm_BatchNorm_Scaling_Softmax.py", # MLIR file too big: 1G
level2 / "97_Matmul_BatchNorm_BiasAdd_Divide_Swish.py", # MLIR file too big: 1G
level2 / "62_Matmul_GroupNorm_LeakyReLU_Sum.py", # MLIR file too big: 1G
level2 / "30_Gemm_GroupNorm_Hardtanh.py", # MLIR file too big: 1G
level2 / "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py", # MLIR file too big: 1G
level2 / "29_Matmul_Mish_Mish.py", # MLIR file too big: 1G
level2 / "99_Matmul_GELU_Softmax.py", # MLIR file too big: 1G
level2 / "98_Matmul_AvgPool_GELU_Scale_Max.py", # MLIR file too big: 1G
level2 / "80_Gemm_Max_Subtract_GELU.py", # MLIR file too big: 1G
level2 / "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py", # MLIR file too big: 1G
level2 / "12_Gemm_Multiply_LeakyReLU.py", # MLIR file too big: 1G
level2 / "53_Gemm_Scaling_Hardtanh_GELU.py", # MLIR file too big: 1G
level2 / "9_Matmul_Subtract_Multiply_ReLU.py", # MLIR file too big: 1G
level2 / "70_Gemm_Sigmoid_Scaling_ResidualAdd.py", # MLIR file too big: 1G
level2 / "86_Matmul_Divide_GELU.py", # MLIR file too big: 1G
level2 / "63_Gemm_ReLU_Divide.py", # MLIR file too big: 1G
level2 / "76_Gemm_Add_ReLU.py", # MLIR file too big: 1G
level2 / "14_Gemm_Divide_Sum_Scaling.py", # MLIR file too big: 1G
level2 / "39_Gemm_Scale_BatchNorm.py", # MLIR file too big: 256M
level2 / "41_Gemm_BatchNorm_GELU_ReLU.py", # MLIR file too big: 256M
level2 / "40_Matmul_Scaling_ResidualAdd.py", # MLIR file too big: 256M
level2 / "37_Matmul_Swish_Sum_GroupNorm.py", # MLIR file too big: 64.3M
level2
/ "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU.py", # error: failed to legalize operation 'torch.constant.int'
level2
/ "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94312016449768)
level2
/ "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
]


ctx = ir.Context()
pm = passmanager.PassManager(context=ctx)
pm.add("linalg-specialize-generic-ops")

print("Output directory:", kernels_as_mlir_folder)
exitcode = 0
for pytorch_level, mlir_level in (
(kernels_as_pytorch_level1, kernels_as_mlir_level1),
(kernels_as_pytorch_level2, kernels_as_mlir_level2),
):
for kernel_pytorch_file in pytorch_level.iterdir():
level_and_kernel = (
Path(kernel_pytorch_file.parent.name) / kernel_pytorch_file.name
)
if level_and_kernel in ignore_list or not kernel_pytorch_file.is_file():
print(
f"Skipping: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}",
file=sys.stderr,
)
continue

kernel_name = kernel_pytorch_file.stem

kernel_as_mlir_path = mlir_level / (kernel_name + ".mlir")
if kernel_as_mlir_path.exists():
print(
f"Already in cache: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
)
continue
print(
f"Processing: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
)
mlir_kernel = torch_ingress.import_from_file(
kernel_pytorch_file, ir_context=ctx
)
assert isinstance(mlir_kernel, ir.Module)

try:
pm.run(mlir_kernel.operation) # cleanup
except Exception as e:
print(
f"ERROR: got the following error cleaning up '{kernel_name}'",
file=sys.stderr,
)
raise e

with kernel_as_mlir_path.open("w") as f:
print("// MLIR output after conversion and clean-up:", file=f)
print(mlir_kernel, file=f)