Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions backends/xnnpack/serialization/xnnpack_graph_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,11 @@ def pretty_print_xnngraph(xnnpack_graph_json: str, filename: Optional[str] = Non
_delegate_instance_id = 0


_cached_schema_bytes: Optional[bytes] = None


def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
global _delegate_instance_id
global _delegate_instance_id, _cached_schema_bytes
sanity_check_xnngraph_dataclass(xnnpack_graph)
xnnpack_graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder)

Expand All @@ -316,11 +319,13 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
with tempfile.TemporaryDirectory() as d:
schema_path = os.path.join(d, "schema.fbs")
with open(schema_path, "wb") as schema_file:
schema_file.write(
_resources.files(serialization_package)
.joinpath("schema.fbs")
.read_bytes()
)
if _cached_schema_bytes is None:
_cached_schema_bytes = (
_resources.files(serialization_package)
.joinpath("schema.fbs")
.read_bytes()
)
schema_file.write(_cached_schema_bytes)
json_path = os.path.join(d, "schema.json")
with open(json_path, "wb") as json_file:
json_file.write(xnnpack_graph_json.encode("ascii"))
Expand Down
59 changes: 37 additions & 22 deletions exir/_serialize/_flatbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

import contextlib
import importlib.resources
import os
import re
Expand Down Expand Up @@ -240,35 +241,49 @@ class _FlatbufferResult:
# Name of an optional resource containing the `flatc` executable.
_FLATC_RESOURCE_NAME: str = "flatbuffers-flatc"

# Cached flatc binary path. In PAR files, importlib.resources.as_file()
# extracts the binary to a temp file on each call. With 200+ XNNPACK
# partitions this adds ~30 min of overhead. Caching avoids re-extraction.
_flatc_cached_path: Optional[str] = None
_flatc_exit_stack: Optional[contextlib.ExitStack] = None

def _run_flatc(args: Sequence[str]) -> None:
"""Runs the `flatc` command with the provided args.

If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable.
Otherwise, expects the `flatc` tool to be available on the system path.
"""
def _get_flatc_path() -> str:
"""Returns the path to the flatc executable, caching the result."""
global _flatc_cached_path, _flatc_exit_stack
if _flatc_cached_path is not None:
return _flatc_cached_path

flatc_resource = importlib.resources.files(__package__).joinpath(
_FLATC_RESOURCE_NAME
)
if flatc_resource.is_file():
# Use the provided flatc binary.
with importlib.resources.as_file(flatc_resource) as flatc_path:
# Ensure the binary has execute permissions (needed for PAR files)
try:
current_mode = flatc_path.stat().st_mode
if not (current_mode & stat.S_IXUSR):
flatc_path.chmod(
current_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
)
except OSError:
pass
subprocess.run([flatc_path] + list(args), check=True)
_flatc_exit_stack = contextlib.ExitStack()
flatc_path = _flatc_exit_stack.enter_context(
importlib.resources.as_file(flatc_resource)
)
try:
current_mode = flatc_path.stat().st_mode
if not (current_mode & stat.S_IXUSR):
flatc_path.chmod(
current_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
)
except OSError:
pass
_flatc_cached_path = str(flatc_path)
else:
# Expect the `flatc` tool to be on the system path or set as an env var.
flatc_path = os.getenv("FLATC_EXECUTABLE")
if not flatc_path:
flatc_path = "flatc"
subprocess.run([flatc_path] + list(args), check=True)
_flatc_cached_path = os.getenv("FLATC_EXECUTABLE", "flatc")

return _flatc_cached_path


def _run_flatc(args: Sequence[str]) -> None:
"""Runs the `flatc` command with the provided args.

If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable.
Otherwise, expects the `flatc` tool to be available on the system path.
"""
subprocess.run([_get_flatc_path()] + list(args), check=True)


def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:
Expand Down
Loading