Skip to content

Commit 5167f37

Browse files
Aoti aot windows (#15711)
Add some support for the windows options in the partitioner. Adding CI with the cpp changes
1 parent cadb0db commit 5167f37

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

backends/cuda/cuda_backend.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import typing
1010
from enum import Enum
11+
from importlib import resources
1112

1213
from typing import Any, Dict, final, List, Optional, Set
1314

@@ -116,7 +117,7 @@ class CudaBackend(BackendDetails):
116117
"""
117118

118119
@staticmethod
119-
def preprocess(
120+
def preprocess( # noqa: C901
120121
edge_program: ExportedProgram,
121122
compile_specs: List[CompileSpec],
122123
) -> PreprocessResult:
@@ -162,6 +163,31 @@ def preprocess(
162163
"max_autotune_conv_backends": "TRITON",
163164
}
164165

166+
platform = "linux"
167+
shim_library_path = None
168+
for spec in compile_specs:
169+
if spec.key == "platform":
170+
platform = spec.value.decode("utf-8")
171+
if spec.key == "shim_library_path":
172+
shim_library_path = spec.value.decode("utf-8")
173+
174+
assert platform == "linux" or platform == "windows"
175+
if platform == "windows" and shim_library_path is None:
176+
lib_dir = resources.files("executorch").joinpath("data/lib")
177+
shim_library_path = str(lib_dir)
178+
if platform == "linux":
179+
assert shim_library_path is None
180+
181+
if platform == "windows":
182+
options.update(
183+
{
184+
"aot_inductor.cross_target_platform": "windows",
185+
"aot_inductor.aoti_shim_library": "aoti_cuda_shims",
186+
"aot_inductor.aoti_shim_library_path": shim_library_path,
187+
"aot_inductor.precompile_headers": False,
188+
}
189+
)
190+
165191
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
166192
[
167193
SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.

0 commit comments

Comments
 (0)