Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions helion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from triton import cdiv
from triton import next_power_of_2

from . import _compat as _compat_module # noqa: F401 # side-effect import
from . import _logging
from . import exc
from . import language
Expand Down
180 changes: 180 additions & 0 deletions helion/_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,193 @@
from __future__ import annotations

import contextlib
import functools
from typing import Any
from typing import Callable
from typing import cast

import torch
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import triton_type
import triton
from triton.backends.compiler import BaseBackend
from triton.backends.compiler import GPUTarget
import triton.language as tl
import triton.runtime.jit as triton_jit

NativeSpecializeImpl = Callable[
[type[BaseBackend], object, bool, bool, bool], tuple[object, ...]
]
CreateSpecializeImpl = Callable[
[Callable[..., object]], Callable[..., tuple[object, ...]]
]


def _make_specialize_impl_wrapper(
*,
native_impl: NativeSpecializeImpl | None = None,
create_factory: CreateSpecializeImpl | None = None,
) -> Callable[..., object]:
if native_impl is None:
native_impl = cast(
"NativeSpecializeImpl | None",
getattr(triton_jit, "native_specialize_impl", None),
)
if native_impl is None and create_factory is None:
raise AttributeError("native_specialize_impl unavailable")

def specialize_impl_wrapper(
*args: object,
**kwargs: object,
) -> object:
specialize_extra = cast(
"Callable[..., object] | None",
kwargs.pop("specialize_extra", None),
)
kwargs.pop("specialize_zero_one", None)
backend_param = kwargs.pop("backend", None)
args_list: list[object] = list(args)
backend_type: type[BaseBackend]
if backend_param is None and args_list:
first = args_list[0]
if isinstance(first, type) and issubclass(first, BaseBackend):
backend_type = first
args_list.pop(0)
elif isinstance(first, BaseBackend):
backend_type = type(first)
args_list.pop(0)
else:
backend_type = BaseBackend
elif isinstance(backend_param, type) and issubclass(backend_param, BaseBackend):
backend_type = backend_param
elif isinstance(backend_param, BaseBackend):
backend_type = type(backend_param)
else:
backend_type = BaseBackend

arg = kwargs.pop("arg", None)
if arg is None:
if args_list:
arg = args_list.pop(0)
else:
raise TypeError("specialize_impl() missing positional argument 'arg'")

def _pop_flag(
key: str,
*,
alt_keys: tuple[str, ...] = (),
default: bool | None = None,
) -> bool:
value = kwargs.pop(key, None)
if value is None:
for alt in alt_keys:
value = kwargs.pop(alt, None)
if value is not None:
break
if value is None:
if args_list:
value = args_list.pop(0)
elif default is not None:
value = default
else:
raise TypeError(f"specialize_impl() missing argument '{key}'")
return bool(value)

is_const = _pop_flag("is_const")
specialize_value = _pop_flag(
"specialize_value",
alt_keys=("specialize",),
default=True,
)
align = _pop_flag("align", default=True)

if native_impl is not None:
result = native_impl(
backend_type,
arg,
is_const,
specialize_value,
align,
)
if specialize_extra is not None:
with contextlib.suppress(Exception):
specialize_extra(arg)
else:
assert create_factory is not None

def _call_specialize_extra(
extra_arg: object,
kind: object,
*,
align: bool = True,
) -> object:
if specialize_extra is None:
return None
try:
return specialize_extra(extra_arg)
except TypeError:
try:
return specialize_extra(extra_arg, kind, align=align)
except Exception:
return None
except Exception:
return None

impl = create_factory(_call_specialize_extra)
result = impl(
arg,
is_const=is_const,
specialize_value=specialize_value,
align=align,
)
return result

return specialize_impl_wrapper


def _ensure_triton_specialize_impl_alias() -> None:
if hasattr(triton_jit, "specialize_impl"):
return
if hasattr(triton_jit, "native_specialize_impl"):
module: Any = triton_jit
module.specialize_impl = _make_specialize_impl_wrapper() # type: ignore[assignment]
return
if hasattr(triton_jit, "create_specialize_impl"):
module: Any = triton_jit
module.specialize_impl = _make_specialize_impl_wrapper(
create_factory=triton_jit.create_specialize_impl,
) # type: ignore[assignment]


_ensure_triton_specialize_impl_alias()


def _ensure_backend_specialization_alias() -> None:
if hasattr(BaseBackend, "get_arg_specialization"):
return
if hasattr(BaseBackend, "get_tensor_specialization"):
BaseBackend.get_arg_specialization = BaseBackend.get_tensor_specialization # type: ignore[attr-defined]


_ensure_backend_specialization_alias()


@functools.cache
def get_triton_find_paths_if() -> Callable[..., object]:
if hasattr(triton_jit, "find_paths_if"):
return triton_jit.find_paths_if
if hasattr(triton_jit, "_find_paths_if"):
return triton_jit._find_paths_if # type: ignore[attr-defined]
raise AttributeError("Unable to locate Triton find_paths_if helper")


@functools.cache
def get_triton_iterable_path() -> Callable[..., object]:
if hasattr(triton_jit, "get_iterable_path"):
return triton_jit.get_iterable_path
if hasattr(triton_jit, "_get_iterable_path"):
return triton_jit._get_iterable_path # type: ignore[attr-defined]
raise AttributeError("Unable to locate Triton get_iterable_path helper")


def supports_tensor_descriptor() -> bool:
Expand Down
Loading