Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import inspect
import itertools
import warnings
from collections.abc import Callable
from functools import lru_cache, wraps
from types import EllipsisType
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Final,
Expand All @@ -35,10 +35,6 @@
_tensor_type_base,
)

if TYPE_CHECKING:
from collections.abc import Callable


_logger: Final = _log_utils.get_logger(__name__)

P = ParamSpec("P")
Expand Down Expand Up @@ -189,24 +185,27 @@ def _get_func_lineref(func: Callable[P, R]) -> str:


def dltyped( # noqa: C901, PLR0915
scope_provider: DLTypeScopeProvider | Literal["self"] | None = None,
scope_provider: DLTypeScopeProvider | Literal["self"] | Callable[P, dict[str, int]] | None = None,
*,
enabled: bool = not _constants.GLOBAL_DISABLE,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Apply type checking to the decorated function.

Args:
scope_provider: An optional scope provider to use for type checking, if None, no scope provider is used, if 'self'
is used, the first argument of the function is expected to be a DLTypeScopeProvider and the function must be a method.
scope_provider: An optional scope provider to use for type checking
None (default): no scope provider is used
'self': the first argument of the function is expected to be a DLTypeScopeProvider and the function must be a method.
Callable: the callable must match the signature of the decorated function, the arguments of the function are passed
on each invocation to the callable and the results are returned as the scope before evaluating any dimensions.
enabled: if set to false, perform no type checking.

Returns:
A wrapper function with type checking

"""

def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901
def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915
if _dependency_utilities.is_torch_scripting() or not enabled:
# jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking
return func
Expand Down Expand Up @@ -270,10 +269,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
):
ctx.tensor_shape_map = scope_provider.get_dltype_scope()
_logger.debug("Using unbound scope provider %s", ctx.tensor_shape_map)
elif isinstance(scope_provider, Callable):
ctx.tensor_shape_map = scope_provider(*args, **kwargs)
elif scope_provider is not None:
raise _errors.DLTypeScopeProviderError(
bad_scope_provider=scope_provider,
)
raise _errors.DLTypeScopeProviderError(bad_scope_provider=scope_provider)

for name in dltype_hints:
if name == return_key:
Expand Down
4 changes: 1 addition & 3 deletions dltype/_lib/_tensor_type_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def validate_tensor(
source_type = np.ndarray # pyright: ignore[reportPossiblyUnboundVariable]

return core_schema.with_info_after_validator_function(
validate_tensor,
schema=core_schema.is_instance_schema(source_type),
field_name=handler.field_name,
validate_tensor, schema=core_schema.is_instance_schema(source_type)
)

def check(
Expand Down
15 changes: 15 additions & 0 deletions dltype/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# content of conftest.py

from __future__ import annotations

import pytest

# Sentinel marker returned by get_closest_marker() when no "slow" marker is
# present, so the comparison in _by_slow_marker never sees None.
# NOTE(review): pytest's docs discourage constructing Mark directly — confirm
# this stays supported across pytest upgrades.
empty_mark = pytest.Mark("", args=(), kwargs={})


def _by_slow_marker(item: pytest.Item) -> bool:
return item.get_closest_marker("slow", default=empty_mark) != empty_mark


def pytest_collection_modifyitems(items: list[pytest.Item]) -> None:
    """Reorder the collected tests so ``slow``-marked tests run last.

    ``sorted`` is stable, so the relative order inside the slow and
    non-slow groups is preserved; the slice assignment mutates the
    collection list in place, as the pytest hook contract requires.
    """
    items[:] = sorted(items, key=_by_slow_marker)
33 changes: 26 additions & 7 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ def test_single_in_single_out(
class _TestBaseModel(BaseModel, frozen=True):
tensor: Annotated[torch.Tensor, dltype.TensorTypeBase("b c h w")]
tensor_2: Annotated[torch.Tensor, dltype.TensorTypeBase("b c h w")]
arg3: int = 0
arg4: Annotated[torch.Tensor, dltype.FloatTensor["..."]] | None = None


class _TestBaseModel2(BaseModel, frozen=True):
Expand Down Expand Up @@ -405,6 +407,12 @@ def test_numpy_mixed(tensor: NPFloatArrayT, expected: _RaisesInfo) -> None:
_RaisesInfo(),
id="int_tensor_2",
),
pytest.param(
dltype.IntTensor("b c h w"),
np_rand(1, 2, 3, 4),
_RaisesInfo(exception_type=dltype.DLTypeDtypeError),
id="int_tensor_3",
),
pytest.param(
dltype.FloatTensor("b c h w"),
torch.rand(1, 2, 3, 4).int(),
Expand All @@ -423,12 +431,6 @@ def test_numpy_mixed(tensor: NPFloatArrayT, expected: _RaisesInfo) -> None:
_RaisesInfo(),
id="float_tensor_3",
),
pytest.param(
dltype.IntTensor("b c h w"),
np_rand(1, 2, 3, 4),
_RaisesInfo(exception_type=dltype.DLTypeDtypeError),
id="int_tensor_2",
),
pytest.param(
dltype.DoubleTensor("b c h w"),
np_rand(1, 2, 3, 4).astype(np.float32),
Expand Down Expand Up @@ -526,6 +528,7 @@ def test_literal_shapes(
tensor_type.check(tensor)


@pytest.mark.slow
def test_onnx_export() -> None:
class _DummyModule(torch.nn.Module):
@dltype.dltyped()
Expand Down Expand Up @@ -563,6 +566,7 @@ def forward(
)


@pytest.mark.slow
def test_torch_compile() -> None:
class _DummyModule(torch.nn.Module):
@dltype.dltyped()
Expand Down Expand Up @@ -850,7 +854,7 @@ def func_with_anon_wildcard(
None,
func_with_mid_tensor_wildcard,
_RaisesInfo(),
id="mid_tensor_wildcard_2",
id="mid_tensor_wildcard_3",
),
pytest.param(
torch.rand(1, 2),
Expand Down Expand Up @@ -1598,6 +1602,7 @@ def func(
func((torch.zeros(1, 1, 3), torch.zeros(3, 2, 1), 1))


@pytest.mark.slow
def test_jax() -> None:
@dltype.dltyped()
def func(
Expand Down Expand Up @@ -1711,3 +1716,17 @@ def tuple_function(
) -> Self:
"""A function that takes a tensor and returns a tensor."""
return self


def test_lambda_scope() -> None:
    """A plain callable can act as the scope provider for ``dltype.dltyped``."""

    def _make_scope(tensor: torch.Tensor, num_cams: int) -> dict[str, int]:
        # Derive the symbolic dimensions from the actual call arguments.
        batch = tensor.shape[0] // num_cams
        return {"num_cams": num_cams, "batch": batch}

    @dltype.dltyped(_make_scope)
    def func(
        tensor: Annotated[torch.Tensor, dltype.FloatTensor["batch*num_cams width height"]],
        num_cams: int,
    ) -> None:
        assert tensor is not None

    # 3 cameras, batch of 1: first dim is batch*num_cams == 3.
    num_cams = 3
    func(torch.zeros((num_cams, 100, 100)), num_cams)
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ include = ["dltype"]
reportUnnecessaryTypeIgnoreComment = "error"
typeCheckingMode = "strict"

[tool.pytest.ini_options]
addopts = "--cov=dltype --cov-report lcov:lcov.info --cov-report html"
[tool.pytest.ini_options]
addopts = ["--cov=dltype", "--cov-report=lcov:lcov.info", "--cov-report=html", "--strict-markers"]
console_output_style = "count"
markers = ["slow: mark test as slow (run last)."]

[tool.ruff]
indent-width = 4
Expand Down
Loading
Loading