diff --git a/einx/_src/frontend/impl/jax.py b/einx/_src/frontend/impl/jax.py
index 8a73aee..d85cf35 100644
--- a/einx/_src/frontend/impl/jax.py
+++ b/einx/_src/frontend/impl/jax.py
@@ -1,3 +1,6 @@
+from typing import ParamSpec, TypeVar, Concatenate, cast
+from collections.abc import Callable
+
 import einx._src.tracer as tracer
 import einx._src.adapter as adapter
 from ..api import api
@@ -42,7 +45,11 @@ def get_shape(tensor):
     return {"optimizations": optimizations, "compiler": tracer.compiler.python, "is_supported_tensor": is_supported_tensor, "get_shape": get_shape}
 
 
-def adapt_with_vmap(op, signature=None):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]:
     iskwarg = _make_iskwarg(op)
 
     jax = tracer.signature.jax()
@@ -55,7 +62,7 @@ def adapt_with_vmap(op, signature=None):
     op = adapter.namedtensor_calltensorfactory.op(op, expected_type=jax.numpy.ndarray)
     op = adapter.einx_from_namedtensor.op(op, iskwarg=iskwarg, el_op=signature, implicit_output="bijective")
 
-    return api(op, backend=types.SimpleNamespace(**_get_backend_kwargs()))
+    return cast(Callable[Concatenate[str, P], R], api(op, backend=types.SimpleNamespace(**_get_backend_kwargs())))
 
 
 adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("jax", "``jax.vmap``")
diff --git a/einx/_src/frontend/impl/mlx.py b/einx/_src/frontend/impl/mlx.py
index 46d4161..b259b05 100644
--- a/einx/_src/frontend/impl/mlx.py
+++ b/einx/_src/frontend/impl/mlx.py
@@ -1,3 +1,6 @@
+from typing import ParamSpec, TypeVar, Concatenate, cast
+from collections.abc import Callable
+
 import einx._src.tracer as tracer
 import einx._src.adapter as adapter
 from ..types import Tensor
@@ -33,7 +36,11 @@ def get_shape(tensor):
     return {"optimizations": optimizations, "compiler": tracer.compiler.python, "is_supported_tensor": is_supported_tensor, "get_shape": get_shape}
 
 
-def adapt_with_vmap(op, signature=None):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]:
     iskwarg = _make_iskwarg(op)
 
     mlx = tracer.signature.mlx()
@@ -46,7 +53,7 @@ def adapt_with_vmap(op, signature=None):
     op = adapter.namedtensor_calltensorfactory.op(op, expected_type=mlx.core.array)
     op = adapter.einx_from_namedtensor.op(op, iskwarg=iskwarg, el_op=signature, implicit_output="bijective")
 
-    return api(op, backend=types.SimpleNamespace(**_get_backend_kwargs()))
+    return cast(Callable[Concatenate[str, P], R], api(op, backend=types.SimpleNamespace(**_get_backend_kwargs())))
 
 
 adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("mlx", "``mlx.core.vmap``")
diff --git a/einx/_src/frontend/impl/torch.py b/einx/_src/frontend/impl/torch.py
index 1e6e9e0..3eddf5c 100644
--- a/einx/_src/frontend/impl/torch.py
+++ b/einx/_src/frontend/impl/torch.py
@@ -1,3 +1,6 @@
+from typing import ParamSpec, TypeVar, Concatenate, cast
+from collections.abc import Callable
+
 import einx._src.tracer as tracer
 import einx._src.adapter as adapter
 from ..api import api
@@ -62,7 +65,11 @@ def get_shape(tensor):
     return {"optimizations": optimizations, "compiler": tracer.compiler.python, "is_supported_tensor": is_supported_tensor, "get_shape": get_shape}
 
 
-def adapt_with_vmap(op, signature=None):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]:
     _raise_on_invalid_version()
 
     iskwarg = _make_iskwarg(op)
@@ -85,7 +92,7 @@ def adapt_with_vmap(op, signature=None):
 
     torch.compiler.allow_in_graph(op)
 
-    return op
+    return cast(Callable[Concatenate[str, P], R], op)
 
 
 adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("torch", "``torch.vmap``")