Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,32 @@ if(NOT EXECUTORCH_SELECT_OPS_YAML STREQUAL ""
)
list(APPEND _executorch_kernels executorch_selected_kernels)

# Auto-right-size the kernel registry unless the user has pinned
# MAX_KERNEL_NUM.
# "Pinned" means MAX_KERNEL_NUM is defined either as a cache entry (for
# example via -DMAX_KERNEL_NUM=N on the command line) or as a normal variable;
# in either case this whole block is skipped.
if(NOT DEFINED CACHE{MAX_KERNEL_NUM} AND NOT DEFINED MAX_KERNEL_NUM)
# Generate the sizing header from the selective-build operator list.
# NOTE(review): gen_selected_max_kernel_num() is defined outside this hunk.
# It presumably sets executorch_selected_kernels_max_kernel_num_include_dir
# and creates the executorch_selected_kernels_max_kernel_num_header target
# that are used below -- confirm against its definition.
gen_selected_max_kernel_num(
LIB_NAME "executorch_selected_kernels" OPLIST_YAMLS
${gen_selected_ops_output_yaml}
)
# Expose the generated header's directory to executorch_core only (PRIVATE),
# and order the header-generation target before executorch_core is built.
target_include_directories(
executorch_core
PRIVATE ${executorch_selected_kernels_max_kernel_num_include_dir}
)
add_dependencies(
executorch_core executorch_selected_kernels_max_kernel_num_header
)
# Apply the same include directory and dependency to the shared variant of
# the core library, but only when that target exists in this configuration.
if(TARGET executorch_core_shared)
target_include_directories(
executorch_core_shared
PRIVATE ${executorch_selected_kernels_max_kernel_num_include_dir}
)
add_dependencies(
executorch_core_shared
executorch_selected_kernels_max_kernel_num_header
)
endif()
endif()

install(
TARGETS executorch_selected_kernels
EXPORT ExecuTorchTargets
Expand Down
174 changes: 174 additions & 0 deletions codegen/tools/gen_max_kernel_num.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Compute a right-sized MAX_KERNEL_NUM for the ExecuTorch operator registry from
one or more selected_operators.yaml files (produced by gen_oplist.py) and emit
it as a C header.

Total = sum of (op, kernel_key) variants across all input YAMLs
+ prim ops always registered by kernels/prim_ops/register_prim_ops.cpp.

See runtime/kernel/operator_registry.cpp for how the emitted header is
consumed and the full precedence order. Users that register kernels outside
the selective-build YAML should pin the registry explicitly with
-DMAX_KERNEL_NUM=N.
"""

import argparse
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml


# Header emitted when a concrete kernel count is available. The `{count}`
# placeholder is filled in through `HEADER_TEMPLATE.format(count=...)` by
# gen_max_kernel_num().
HEADER_TEMPLATE = """\
// @generated by executorch/codegen/tools/gen_max_kernel_num.py. Do not edit.
#pragma once
#define EXECUTORCH_SELECTED_MAX_KERNEL_NUM {count}
"""

# When a YAML opts into include_all_operators, we still need to write *some*
# output file to keep CMake's add_custom_command contract honest, but without
# defining EXECUTORCH_SELECTED_MAX_KERNEL_NUM so that operator_registry.cpp
# falls through to its compile-time default.
# This header has no placeholders and is written verbatim.
OPT_OUT_HEADER = """\
// @generated by executorch/codegen/tools/gen_max_kernel_num.py. Do not edit.
#pragma once
// Selective build opted into all operators; registry uses compile-time default.
"""

# Locates the `static Kernel prim_ops[] = { ... };` array literal. The count
# lives in the array itself (`kernel_span` uses `sizeof(prim_ops)/sizeof(Kernel)`
# at compile time), so we just bracket-match the array body and count Kernel(
# entries inside it, ignoring the rest of the file.
#
# How the match is bounded: group 1 is a non-greedy `.*?` (DOTALL lets it span
# newlines) that stops at the first `};` sitting at the very start of a line
# (MULTILINE makes `^` match after each newline). A closing `};` that is
# indented will therefore not terminate the match.
PRIM_OPS_ARRAY_RE = re.compile(
    r"static\s+Kernel\s+prim_ops\s*\[\s*\]\s*=\s*\{(.*?)^\};",
    re.DOTALL | re.MULTILINE,
)
# Matches the start of one `Kernel(` constructor call. The pattern is purely
# textual, so a `Kernel(` that appears inside a comment within the array body
# is matched as well.
PRIM_OPS_KERNEL_RE = re.compile(r"\bKernel\s*\(")


def _count_prim_ops(prim_ops_source: Path) -> int:
    """Count the ``Kernel(...)`` entries in the ``prim_ops[]`` array literal.

    Args:
        prim_ops_source: Path to the C++ source file that defines
            ``static Kernel prim_ops[] = { ... };``.

    Returns:
        The number of ``Kernel(`` occurrences inside the array body.

    Raises:
        RuntimeError: If the array literal cannot be located, or if it is
            located but contains no ``Kernel(`` entries.
    """
    # Decode explicitly as UTF-8 instead of the locale's default encoding so
    # the result does not depend on the build host. Only ASCII tokens are
    # matched below, so undecodable bytes are replaced rather than allowed to
    # abort the build.
    source = prim_ops_source.read_text(encoding="utf-8", errors="replace")

    match = PRIM_OPS_ARRAY_RE.search(source)
    if match is None:
        raise RuntimeError(
            f"Failed to locate `static Kernel prim_ops[] = {{ ... }};` in "
            f"{prim_ops_source}. The array may have been renamed; update "
            "PRIM_OPS_ARRAY_RE in gen_max_kernel_num.py."
        )

    # Group 1 is the array body; count the constructor calls inside it only.
    count = len(PRIM_OPS_KERNEL_RE.findall(match.group(1)))
    if count == 0:
        # An empty array most likely means the layout changed and the regexes
        # no longer describe it, so fail loudly instead of under-sizing.
        raise RuntimeError(
            f"Found `prim_ops[]` in {prim_ops_source} but it contains zero "
            "Kernel(...) entries. The array layout may have changed."
        )
    return count


def _count_yaml_kernels(yaml_path: Path) -> Optional[int]:
    """Count the kernels that one selected-operators YAML will register.

    Counting rules:

    * Each entry of ``et_kernel_metadata`` contributes the length of its
      variant list, or 1 when the value is not a non-empty list.
    * Each key of ``operators`` that has no ``et_kernel_metadata`` entry
      contributes 1 (one default kernel).

    Args:
        yaml_path: Path to the YAML file to inspect. The top-level document
            is expected to be a mapping (or empty).

    Returns:
        The kernel count, or ``None`` if the YAML sets
        ``include_all_operators`` or any operator sets
        ``include_all_overloads`` (callers should skip the auto-size header
        in that case).
    """
    # Open in binary mode so PyYAML performs its own encoding detection
    # rather than decoding with whatever the host locale happens to be.
    with open(yaml_path, "rb") as f:
        data = yaml.safe_load(f) or {}

    if data.get("include_all_operators"):
        return None

    operators: Dict[str, Dict[str, Any]] = data.get("operators") or {}
    if any(
        isinstance(op_info, dict) and op_info.get("include_all_overloads")
        for op_info in operators.values()
    ):
        return None

    et_kernel_metadata: Dict[str, List[str]] = data.get("et_kernel_metadata") or {}

    # One kernel per listed variant; an entry without a usable variant list
    # still counts as a single kernel.
    count = sum(
        len(variants) if isinstance(variants, list) and variants else 1
        for variants in et_kernel_metadata.values()
    )

    # Operators listed but missing from et_kernel_metadata still register one
    # default kernel each.
    count += sum(1 for op_name in operators if op_name not in et_kernel_metadata)

    return count


def _write_if_different(path: Path, content: str) -> None:
    """Write ``content`` to ``path`` unless the file already holds exactly it.

    Missing parent directories are created before writing. When the existing
    file content equals ``content`` the file is left untouched.
    """
    already_current = path.exists() and path.read_text() == content
    if not already_current:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content)


def gen_max_kernel_num(
    oplist_yamls: List[Path],
    prim_ops_source: Path,
    output_path: Path,
) -> Optional[int]:
    """Compute the total kernel count and emit the matching C header.

    The YAMLs are processed in order. As soon as one of them opts into all
    operators, the opt-out header is written and processing stops, so the
    remaining YAMLs and the prim-ops source are not read.

    Args:
        oplist_yamls: Selected-operator YAML files to sum over.
        prim_ops_source: C++ source containing the ``prim_ops[]`` array.
        output_path: Destination of the generated header.

    Returns:
        The total written into the header, or ``None`` when the opt-out
        header was emitted instead.
    """
    running_total = 0
    for oplist in oplist_yamls:
        per_yaml = _count_yaml_kernels(oplist)
        if per_yaml is not None:
            running_total += per_yaml
            continue

        # This YAML cannot be sized; fall back to the opt-out header.
        print(
            f"gen_max_kernel_num: {oplist} opts into all operators; "
            "emitting opt-out header (registry will use default size).",
            file=sys.stderr,
        )
        _write_if_different(output_path, OPT_OUT_HEADER)
        return None

    # Prim ops are added on top of the YAML-selected kernels.
    running_total += _count_prim_ops(prim_ops_source)

    header_text = HEADER_TEMPLATE.format(count=running_total)
    _write_if_different(output_path, header_text)
    return running_total


def main(argv: List[str]) -> None:
    """Parse ``argv`` and generate the MAX_KERNEL_NUM header.

    Every option is required and is accepted in both dashed and underscored
    spelling. A confirmation line is printed only when a concrete count was
    written (not when the opt-out header was emitted).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # (accepted spellings, extra keyword arguments) for each required option.
    option_specs = (
        (
            ("--oplist-yaml", "--oplist_yaml"),
            {
                "action": "append",
                "help": "Path to a selected_operators.yaml. May be repeated.",
            },
        ),
        (
            ("--prim-ops-source", "--prim_ops_source"),
            {"help": "Path to kernels/prim_ops/register_prim_ops.cpp."},
        ),
        (
            ("--output-path", "--output_path"),
            {"help": "Path to the header file to emit."},
        ),
    )
    for spellings, extra in option_specs:
        parser.add_argument(*spellings, required=True, **extra)

    parsed = parser.parse_args(argv)

    emitted = gen_max_kernel_num(
        oplist_yamls=list(map(Path, parsed.oplist_yaml)),
        prim_ops_source=Path(parsed.prim_ops_source),
        output_path=Path(parsed.output_path),
    )
    if emitted is None:
        return
    print(f"gen_max_kernel_num: wrote {parsed.output_path} (count={emitted})")


if __name__ == "__main__":
    main(sys.argv[1:])
Loading
Loading