diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 447dbc56b0..ebde848571 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -741,11 +741,14 @@ def save( - If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes parameter takes precedence. - kwargs: Additional format-specific kwargs. ``partitioners=`` and - ``compile_specs=`` are only used with ``output_format="executorch"``; - otherwise they are ignored with a warning. Pass - ``compile_specs=[CompileSpec("target_device", b"cuda:")]`` to - override the default target device (``cuda:0``). + kwargs: Additional format-specific kwargs. ``partitioners=``, + ``compile_specs=``, and ``backend_config=`` are only used with + ``output_format="executorch"``; otherwise they are ignored with a + warning. Pass ``compile_specs=[CompileSpec("target_device", + b"cuda:")]`` to override the default target device (``cuda:0``). + ``backend_config=`` takes an ``Optional[ExecutorchBackendConfig]`` + and is forwarded to ``to_executorch(config=...)`` to customize + ExecuTorch lowering (e.g. memory planning or device placement). """ if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module @@ -772,6 +775,7 @@ def save( executorch_partitioners = kwargs.pop("partitioners", None) executorch_compile_specs = kwargs.pop("compile_specs", None) + executorch_backend_config = kwargs.pop("backend_config", None) if output_format not in accepted_formats: raise ValueError( @@ -898,6 +902,11 @@ def _extract_tensor(obj: Any) -> Any: "compile_specs= is only used with output_format='executorch' and will " f"be ignored for output_format='{output_format}'." ) + if executorch_backend_config and output_format != "executorch": + logger.warning( + "backend_config= is only used with output_format='executorch' and will " + f"be ignored for output_format='{output_format}'." + ) if output_format == "aot_inductor" and platform.system() != "Linux": raise ValueError( f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format" @@ -966,6 +975,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + backend_config=executorch_backend_config, ) else: raise RuntimeError( @@ -1031,6 +1041,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + backend_config=executorch_backend_config, ) else: raise RuntimeError( @@ -1117,6 +1128,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + backend_config=executorch_backend_config, ) else: raise RuntimeError( @@ -1379,7 +1391,7 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None partitioner=partitioners, compile_config=get_edge_compile_config(), ) - executorch_program = edge_program.to_executorch() + executorch_program = edge_program.to_executorch(config=kwargs.get("backend_config")) with open(file_path, "wb") as f: executorch_program.write_to_file(f)