diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b8995b..9cb4c86 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: with: python-version: ${{ matrix.env.python }} - name: create environment - run: uvx hatch env create ${{ matrix.env.name }} + run: uvx hatch -v env create ${{ matrix.env.name }} - name: run tests with coverage run: | uvx hatch run ${{ matrix.env.name }}:run-cov diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 414a8e8..e1314a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,6 +32,7 @@ repos: - array-api-compat>=1.13 - dask>=2026.1 - h5py>=3.15 + - jax>=0.10 - numba>=0.63 - packaging>=26 - pytest>=9 diff --git a/pyproject.toml b/pyproject.toml index 5eb3105..3684e89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", ] -dynamic = [ "description", "readme", "version" ] -dependencies = [ "numpy>=2" ] +dynamic = [ "version" ] +dependencies = [ "array-api-compat", "numpy>=2" ] optional-dependencies.accel = [ "numba>=0.57" ] optional-dependencies.dask = [ "dask>=2023.6.1" ] optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ] @@ -37,9 +37,10 @@ entry-points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest" [dependency-groups] test = [ "anndata", - "fast-array-utils[accel]", + "fast-array-utils[full]", + "jax", + "jaxlib", "scikit-learn", - "zarr", { include-group = "test-min" }, ] doc = [ @@ -66,12 +67,8 @@ envs.docs.scripts.clean = "git clean -fdX docs" envs.docs.scripts.open = "python -m webbrowser -t docs/_build/html/index.html" envs.hatch-test.default-args = [] envs.hatch-test.dependency-groups = [ "test-min" ] -# TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released -envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ] +envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape" ] envs.hatch-test.env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed" -envs.hatch-test.overrides.matrix.extras.features = [ - { if = [ "full" ], value = "full" }, -] envs.hatch-test.overrides.matrix.extras.dependency-groups = [ { if = [ "full" ], value = "test" }, ] @@ -85,9 +82,10 @@ envs.hatch-test.matrix = [ { python = [ "3.14", "3.12" ], extras = [ "full", "min" ] }, { python = [ "3.12" ], extras = [ "full" ], resolution = [ "lowest" ] }, ] -metadata.hooks.docstring-description = {} -metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst" -metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ] +# TODO: re-activate incl. `dynamic = [ "description", "readme", ... ]` after https://github.com/pypa/hatch/issues/2252 +# metadata.hooks.docstring-description = {} +# metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst" +# metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ] version.source = "vcs" version.raw-options = { local_scheme = "no-local-version" } # be able to publish dev version diff --git a/src/fast_array_utils/conv/__init__.py b/src/fast_array_utils/conv/__init__.py index 6a34ca0..d00ffe2 100644 --- a/src/fast_array_utils/conv/__init__.py +++ b/src/fast_array_utils/conv/__init__.py @@ -38,13 +38,19 @@ def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ... +@overload +def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> A: ... +@overload +def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ... + + def to_dense( - x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix, + x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False, -) -> NDArray[Any] | types.DaskArray | types.CupyArray: +) -> NDArray[Any] | types.DaskArray | types.CupyArray | types.HasArrayNamespace: r"""Convert x to a dense array. If ``to_cpu_memory`` is :data:`False`, :class:`dask.array.Array`\ s and diff --git a/src/fast_array_utils/conv/_to_dense.py b/src/fast_array_utils/conv/_to_dense.py index 6b5de88..66d1580 100644 --- a/src/fast_array_utils/conv/_to_dense.py +++ b/src/fast_array_utils/conv/_to_dense.py @@ -21,7 +21,7 @@ # fallback’s arg0 type has to include types of registered functions @singledispatch def to_dense_( - x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace, /, *, order: Literal["K", "A", "C", "F"] = "K", @@ -39,6 +39,13 @@ def _to_dense_cs(x: types.spmatrix | types.sparray, /, *, order: Literal["K", "A return scipy.to_dense(x, order=sparse_order(x, order=order)) +@to_dense_.register(np.ndarray) +def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> np.ndarray: + # to bypass the _to_dense_array_api path + del to_cpu_memory + return np.asarray(x, order=order) + + @to_dense_.register(types.DaskArray) def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> NDArray[Any] | types.DaskArray: from . import to_dense @@ -69,6 +76,13 @@ def _to_dense_cupy(x: GpuArray, /, *, order: Literal["K", "A", "C", "F"] = "K", return x.get(order="A") if to_cpu_memory else x +@to_dense_.register(types.HasArrayNamespace) +def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray: + if to_cpu_memory: + return np.asarray(x, order=order) + return x + + def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.CSDataset, /, *, order: Literal["K", "A", "C", "F"]) -> Literal["C", "F"]: if TYPE_CHECKING: from scipy.sparse._base import _spbase diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 24712d8..5cfd439 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ... @overload def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None) -> types.DaskArray: ... +@overload +def is_constant[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None) -> bool | A: ... def is_constant( - x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray, + x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, -) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: +) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray | types.HasArrayNamespace: """Check whether values in array are constant. Parameters @@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ... @overload def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ... +@overload +def mean[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> A: ... def mean( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, -) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray: +) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: """Mean over both or one axis. Parameters @@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ... @overload def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ... - - +@overload +def mean_var[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[A, A]: ... def mean_var( - x: CpuArray | GpuArray | types.DaskArray, + x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, @@ -158,6 +162,7 @@ def mean_var( | tuple[NDArray[np.float64], NDArray[np.float64]] | tuple[types.CupyArray, types.CupyArray] | tuple[types.DaskArray, types.DaskArray] + | tuple[types.HasArrayNamespace, types.HasArrayNamespace] ): """Mean and variance over both or one axis. @@ -214,13 +219,13 @@ def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ... # https://github.com/scverse/fast-array-utils/issues/52 def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype: def _generic_op( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, - ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray | types.HasArrayNamespace: from ._generic_ops import generic_op assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}" @@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... @overload def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... +@overload +def min[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ... def min( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, @@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... @overload def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... +@overload +def max[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ... def max( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, @@ -359,14 +368,16 @@ def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, ke def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... @overload def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... +@overload +def sum[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> A: ... def sum( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, -) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: +) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: """Sum over both or one axis. Parameters diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py index 52a48a8..dd8a07b 100644 --- a/src/fast_array_utils/stats/_generic_ops.py +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -22,34 +22,32 @@ type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None -def _run_numpy_op( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, +@singledispatch +def generic_op( + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, + /, op: Ops, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: - arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))) - return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr + raise NotImplementedError # pragma: no cover -@singledispatch -def generic_op( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, +@generic_op.register(np.ndarray | types.H5Dataset | types.ZarrArray) +# register explicitly to avoid the array API path and performance slow down +def _generic_op_numpy_disk( + x: np.ndarray | DiskArray, /, op: Ops, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, -) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: +) -> NDArray[Any] | np.number[Any]: del keep_cupy_as_array - if TYPE_CHECKING: - # these are never passed to this fallback function, but `singledispatch` wants them - assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix) - # np supports these, but doesn’t know it. (TODO: test cupy) - assert not isinstance(x, types.ZarrArray | types.H5Dataset) - return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype)) + return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return] @generic_op.register(types.CupyArray | types.CupyCSMatrix) @@ -62,7 +60,8 @@ def _generic_op_cupy( dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> types.CupyArray | np.number[Any]: - arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype)) + arr = cast("types.CupyArray | types.CupyCOOMatrix", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))) + arr = arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() @@ -109,3 +108,22 @@ def _generic_op_dask( dtype = getattr(np, op)(np.zeros(1, dtype=x.dtype)).dtype return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) + + +@generic_op.register(types.HasArrayNamespace) +def _generic_op_array_api[A: types.HasArrayNamespace]( + x: A, + /, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> A: + """Handle arrays with native array API support.""" + del keep_cupy_as_array + + import array_api_compat + + xp = array_api_compat.array_namespace(x) + return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return] diff --git a/src/fast_array_utils/stats/_is_constant.py b/src/fast_array_utils/stats/_is_constant.py index 5119ab4..2a98e93 100644 --- a/src/fast_array_utils/stats/_is_constant.py +++ b/src/fast_array_utils/stats/_is_constant.py @@ -2,7 +2,7 @@ from __future__ import annotations from functools import partial, singledispatch -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import numba import numpy as np @@ -19,20 +19,20 @@ @singledispatch def is_constant_( - a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray, + a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, ) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover - raise NotImplementedError + raise NotImplementedError # pragma: no cover -@is_constant_.register(np.ndarray | types.CupyArray) +@is_constant_.register(np.ndarray | types.CupyArray | types.HasArrayNamespace) def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool] | types.CupyArray: # Should eventually support nd, not now. match axis: case None: - return bool((a == a.flat[0]).all()) + return bool((a == a.reshape(-1)[0]).all()) case 0: return _is_constant_rows(a.T) case 1: @@ -40,8 +40,8 @@ def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[ def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray: - b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape) - return cast("NDArray[np.bool]", (a == b).all(axis=1)) + # broadcasts without needing np.broadcast_to + return (a == a[:, 0:1]).all(axis=1) @is_constant_.register(types.CSBase) diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py index ba08164..9b8c868 100644 --- a/src/fast_array_utils/stats/_mean.py +++ b/src/fast_array_utils/stats/_mean.py @@ -18,12 +18,12 @@ def mean_( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, -) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray: +) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray | types.HasArrayNamespace: total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type] n = np.prod(x.shape) if axis is None else x.shape[axis] return total / n # type: ignore[no-any-return] diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py index 1e867f9..c704bc9 100644 --- a/src/fast_array_utils/stats/_mean_var.py +++ b/src/fast_array_utils/stats/_mean_var.py @@ -21,7 +21,7 @@ @no_type_check # mypy is extremely confused def mean_var_( - x: CpuArray | GpuArray | types.DaskArray, + x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, @@ -34,11 +34,18 @@ def mean_var_( ): from . import mean + if isinstance(x, np.ndarray | types.CSBase) or not isinstance(x, types.HasArrayNamespace): + xp = np + else: + import array_api_compat + + xp = array_api_compat.array_namespace(x) + if axis is not None and isinstance(x, types.CSBase): mean_, var = _sparse_mean_var(x, axis=axis) else: - mean_ = mean(x, axis=axis, dtype=np.float64) - mean_sq = mean(power(x, 2, dtype=np.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=np.float64) + mean_ = mean(x, axis=axis, dtype=xp.float64) + mean_sq = mean(power(x, 2, dtype=xp.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=xp.float64) var = mean_sq - mean_**2 if correction: # R convention == 1 (unbiased estimator) n = np.prod(x.shape) if axis is None else x.shape[axis] diff --git a/src/fast_array_utils/stats/_power.py b/src/fast_array_utils/stats/_power.py index 8387836..82653a9 100644 --- a/src/fast_array_utils/stats/_power.py +++ b/src/fast_array_utils/stats/_power.py @@ -15,7 +15,7 @@ from fast_array_utils.typing import CpuArray, GpuArray # All supported array types except for disk ones and CSDataset - type Array = CpuArray | GpuArray | types.DaskArray + type Array = CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr: @@ -26,9 +26,21 @@ def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr: @singledispatch def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array: - if TYPE_CHECKING: - assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix) - return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator] + raise NotImplementedError # pragma: no cover + + +@_power.register(np.ndarray | types.CupyArray) +def _power_numpy_cupy(x: np.ndarray, n: int, /, dtype: DTypeLike | None = None) -> np.ndarray: + # avoids slower xp.pow(xp.astype(...)) path + return x**n if dtype is None else np.power(x, n, dtype=dtype) + + +@_power.register(types.HasArrayNamespace) +def _power_array_api(x: types.HasArrayNamespace, n: int, /, dtype: DTypeLike | None = None) -> types.HasArrayNamespace: + import array_api_compat + + xp = array_api_compat.array_namespace(x) + return xp.pow(x, n) if dtype is None else xp.pow(xp.astype(x, dtype), n) # type: ignore[no-any-return] @_power.register(types.CSBase | types.CupyCSMatrix) diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py index be671dd..a5084da 100644 --- a/src/fast_array_utils/stats/_typing.py +++ b/src/fast_array_utils/stats/_typing.py @@ -28,8 +28,13 @@ class StatFunNoDtype(Protocol): __name__: str def __call__( - self, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False - ) -> types.DaskArray: ... + self, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, + /, + *, + axis: Literal[0, 1] | None = None, + keep_cupy_as_array: bool = False, + ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: ... class StatFunDtype(Protocol): @@ -37,13 +42,13 @@ class StatFunDtype(Protocol): def __call__( self, - x: CpuArray | GpuArray | DiskArray | types.DaskArray, + x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, - ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: ... + ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: ... NoDtypeOps = Literal["max", "min"] diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index c1fbf26..8da8972 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -4,7 +4,11 @@ from __future__ import annotations from importlib.util import find_spec -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from types import ModuleType __all__ = [ @@ -22,6 +26,7 @@ "DaskArray", "H5Dataset", "H5Group", + "HasArrayNamespace", "ZarrArray", "ZarrGroup", "coo_array", @@ -116,3 +121,23 @@ CSRDataset.__module__ = CSCDataset.__module__ = "anndata.abc" CSDataset = CSRDataset | CSCDataset """Anndata sparse out-of-core matrices.""" + + +@runtime_checkable +class HasArrayNamespace(Protocol): + """An array object compatible with the Python array API standard.""" + + @property + def ndim(self) -> int: + """The number of dimensions of the array.""" + + @property + def shape(self) -> tuple[int, ...]: + """The shape of the array.""" + + @property + def dtype(self) -> object: + """The data type of the array.""" + + def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: + """Get Array API namespace.""" diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..5cb4fc7 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from fast_array_utils import stats +from fast_array_utils.conv import to_dense + + +if TYPE_CHECKING: + from typing import Literal + + +pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed") + +if find_spec("jax"): + # enabling 64-bit precision in JAX as it defaults to 32-bit only + # problem as mean_var passes dtype= np.float64 internally, which crashes without this fix + import jax + + jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call] # noqa: FBT003 + + +@pytest.fixture +def jax_arr() -> jax.Array: + import jax.numpy as jnp + + return jnp.array([[1, 0], [2, 0], [3, 0]], dtype=jnp.float32) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +@pytest.mark.parametrize("func", ["sum", "min", "max", "mean"]) +def test_simple_stat(jax_arr: jax.Array, func: Literal["sum", "min", "max", "mean"], axis: Literal[0, 1] | None) -> None: + import jax.numpy as jnp + + result = getattr(stats, func)(jax_arr, axis=axis) + expected = getattr(jnp, func)(jax_arr, axis=axis) + + assert type(result) is type(expected) + if func == "mean": + assert jnp.allclose(result, expected) + else: + assert jnp.array_equal(result, expected) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_is_constant(axis: Literal[0, 1] | None) -> None: + import jax.numpy as jnp + + x = jnp.array( + [ + [0, 0, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 0], + ], + dtype=jnp.float32, + ) + result = stats.is_constant(x, axis=axis) + + if axis is None: + assert not result + elif axis == 0: + expected = jnp.array([True, True, False, False]) + assert type(result) is type(expected) + assert jnp.array_equal(result, expected) + else: + expected = jnp.array([False, False, True, True, False, True]) + assert type(result) is type(expected) + assert jnp.array_equal(result, expected) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_mean_var(subtests: pytest.Subtests, jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None: + import jax.numpy as jnp + + mean, var = stats.mean_var(jax_arr, axis=axis, correction=1) + + for name, result in dict(mean=mean, var=var).items(): + if name == "mean": + expected = jnp.mean(jax_arr, axis=axis) + else: + n = jax_arr.size if axis is None else jax_arr.shape[axis] + expected = jnp.var(jax_arr, axis=axis) * n / (n - 1) + + with subtests.test(name): + assert type(result) is type(expected) + assert jnp.allclose(result, expected) + + +@pytest.mark.parametrize("to_cpu_memory", [True, False], ids=["to_cpu_memory", "not_to_cpu_memory"]) +def test_to_dense(*, jax_arr: jax.Array, to_cpu_memory: bool) -> None: + import jax.numpy as jnp + + result = to_dense(jax_arr, to_cpu_memory=to_cpu_memory) + + if to_cpu_memory: + assert isinstance(result, np.ndarray) + else: + assert isinstance(result, jax.Array) + assert jnp.array_equal(result, jax_arr) diff --git a/tests/test_stats.py b/tests/test_stats.py index 334d1fd..8371250 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -103,9 +103,11 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: return np_arr -def to_np_dense_checked( - stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray -) -> NDArray[DTypeOut] | np.number[Any]: +def to_np_dense_checked[DT: DTypeOut]( + stat: NDArray[DT] | np.number[Any] | types.DaskArray | types.HasArrayNamespace, + axis: Literal[0, 1] | None, + arr: CpuArray | GpuArray | DiskArray | types.COOBase | types.DaskArray | types.HasArrayNamespace, +) -> NDArray[DT] | np.number[Any]: match axis, arr: case _, types.DaskArray(): assert isinstance(stat, types.DaskArray), type(stat) @@ -208,7 +210,7 @@ def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.D np_arr = rng.random((100, 100)) arr = array_type(np_arr) - result = to_np_dense_checked(func(arr, axis=axis), axis, arr) + result = to_np_dense_checked(func(arr, axis=axis), axis, arr) # type: ignore[arg-type] expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis) np.testing.assert_array_equal(result, expected) @@ -229,7 +231,7 @@ def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1] np_arr = np.array(data, dtype=np.float32) arr = array_type(np_arr) assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes" - stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute()) + stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute()) # type: ignore[union-attr] if isinstance(stat, types.CupyArray): stat = stat.get() np_func = getattr(np, func.__name__) @@ -321,6 +323,8 @@ def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], pbmc64k_redu arr = array_type(mat) mean_mat, var_mat = stats.mean_var(mat, axis=0, correction=1) + mean_arr: NDArray[Any] | np.number # actually just NDArray, and mypy should be able to infer. + var_arr: NDArray[Any] | np.number mean_arr, var_arr = (to_np_dense_checked(a, 0, arr) for a in stats.mean_var(arr, axis=0, correction=1)) rtol = 1.0e-5 if array_type.flags & Flags.Gpu else 1.0e-7 diff --git a/typings/cupyx/scipy/sparse/_coo.pyi b/typings/cupyx/scipy/sparse/_coo.pyi index 61a8887..e67e691 100644 --- a/typings/cupyx/scipy/sparse/_coo.pyi +++ b/typings/cupyx/scipy/sparse/_coo.pyi @@ -8,4 +8,4 @@ from ._base import spmatrix class coo_matrix(spmatrix): format: Literal["coo"] = "coo" - def get(self, stream: cupy.cuda.Stream | None = None) -> sps.spmatrix: ... + def get(self, stream: cupy.cuda.Stream | None = None) -> sps.coo_matrix: ... diff --git a/typings/cupyx/scipy/sparse/_csr.pyi b/typings/cupyx/scipy/sparse/_csr.pyi index f4893b1..e092517 100644 --- a/typings/cupyx/scipy/sparse/_csr.pyi +++ b/typings/cupyx/scipy/sparse/_csr.pyi @@ -8,4 +8,4 @@ from ._compressed import _compressed_sparse_matrix class csr_matrix(_compressed_sparse_matrix): format: Literal["csr"] = "csr" - def get(self, stream: cupy.cuda.Stream | None = None) -> sps.csc_matrix: ... + def get(self, stream: cupy.cuda.Stream | None = None) -> sps.csr_matrix: ...