Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,11 +741,14 @@ def save(

- If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes
parameter takes precedence.
kwargs: Additional format-specific kwargs. ``partitioners=`` and
``compile_specs=`` are only used with ``output_format="executorch"``;
otherwise they are ignored with a warning. Pass
``compile_specs=[CompileSpec("target_device", b"cuda:<i>")]`` to
override the default target device (``cuda:0``).
kwargs: Additional format-specific kwargs. ``partitioners=``,
``compile_specs=``, and ``backend_config=`` are only used with
``output_format="executorch"``; otherwise they are ignored with a
warning. Pass ``compile_specs=[CompileSpec("target_device",
b"cuda:<i>")]`` to override the default target device (``cuda:0``).
``backend_config=`` takes an ``Optional[ExecutorchBackendConfig]``
and is forwarded to ``to_executorch(config=...)`` to customize
ExecuTorch lowering (e.g. memory planning or device placement).
"""
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
Expand All @@ -772,6 +775,7 @@ def save(

executorch_partitioners = kwargs.pop("partitioners", None)
executorch_compile_specs = kwargs.pop("compile_specs", None)
executorch_backend_config = kwargs.pop("backend_config", None)

if output_format not in accepted_formats:
raise ValueError(
Expand Down Expand Up @@ -898,6 +902,11 @@ def _extract_tensor(obj: Any) -> Any:
"compile_specs= is only used with output_format='executorch' and will "
f"be ignored for output_format='{output_format}'."
)
if executorch_backend_config and output_format != "executorch":
logger.warning(
"backend_config= is only used with output_format='executorch' and will "
f"be ignored for output_format='{output_format}'."
)
if output_format == "aot_inductor" and platform.system() != "Linux":
raise ValueError(
f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format"
Expand Down Expand Up @@ -966,6 +975,7 @@ def _extract_tensor(obj: Any) -> Any:
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
backend_config=executorch_backend_config,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -1031,6 +1041,7 @@ def _extract_tensor(obj: Any) -> Any:
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
backend_config=executorch_backend_config,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -1117,6 +1128,7 @@ def _extract_tensor(obj: Any) -> Any:
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
backend_config=executorch_backend_config,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -1379,7 +1391,7 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None
partitioner=partitioners,
compile_config=get_edge_compile_config(),
)
executorch_program = edge_program.to_executorch()
executorch_program = edge_program.to_executorch(config=kwargs.get("backend_config"))
with open(file_path, "wb") as f:
executorch_program.write_to_file(f)

Expand Down
Loading