Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
64f5304
array-api initial implementation
amalia-k510 Mar 23, 2026
5373050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
f016c39
updates in regards to the handler and some array_api handling fixes
amalia-k510 Apr 15, 2026
cca3ad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2026
c1bb155
Issues with jax test are fixed, introduced similar tests with pytorch
amalia-k510 Apr 15, 2026
c9a8f85
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 15, 2026
bd08c2e
pre-commit fixes
amalia-k510 Apr 15, 2026
b2e3f9b
mypy issues fix
amalia-k510 Apr 15, 2026
65e83a6
speed up fix
amalia-k510 Apr 15, 2026
2a1924d
Update src/fast_array_utils/stats/_mean_var.py
amalia-k510 Apr 17, 2026
9684213
Addressed the comments
amalia-k510 Apr 17, 2026
1e3c296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
63c5e16
chore: simplify
flying-sheep Apr 20, 2026
9c8466a
addressing comments about removing is_array_api_obj check
amalia-k510 Apr 27, 2026
86a7503
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
e28f176
import fix
amalia-k510 Apr 27, 2026
c704259
ignore comments update and mypy test
amalia-k510 Apr 27, 2026
ab6e200
Merge branch 'main' into array-api-implementation
amalia-k510 Apr 27, 2026
aaccbda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
7e5102a
main version
amalia-k510 Apr 27, 2026
eed16a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
567107a
residual ignore comments removed
amalia-k510 Apr 27, 2026
1e41d24
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
aab50f7
pyproject, jax optional dependencies
amalia-k510 Apr 27, 2026
91ef896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
0f62abc
comments addressed, mypy try again
amalia-k510 Apr 27, 2026
d642668
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
37a6634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
bab4392
mypy comment add
amalia-k510 Apr 27, 2026
d22f74b
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
3d4ee3a
types
flying-sheep Apr 27, 2026
c77a1bc
types for others
amalia-k510 Apr 29, 2026
6d3891e
types, missing parameters
amalia-k510 Apr 29, 2026
9e2a9fe
revert pyproject.toml
amalia-k510 Apr 29, 2026
57117d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
0415fd5
rework deps
flying-sheep Apr 30, 2026
3857917
Merge branch 'main' into array-api-implementation
flying-sheep Apr 30, 2026
592b7b7
fix deps
flying-sheep Apr 30, 2026
c890bc3
fix types
flying-sheep Apr 30, 2026
49283b0
fix cupy tests
flying-sheep Apr 30, 2026
f6463cc
fix disk array
flying-sheep Apr 30, 2026
474c969
fmt
flying-sheep Apr 30, 2026
5de4e5b
coverage
flying-sheep Apr 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
with:
python-version: ${{ matrix.env.python }}
- name: create environment
run: uvx hatch env create ${{ matrix.env.name }}
run: uvx hatch -v env create ${{ matrix.env.name }}
- name: run tests with coverage
run: |
uvx hatch run ${{ matrix.env.name }}:run-cov
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ repos:
- array-api-compat>=1.13
- dask>=2026.1
- h5py>=3.15
- jax>=0.10
- numba>=0.63
- packaging>=26
- pytest>=9
Expand Down
22 changes: 10 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
dynamic = [ "description", "readme", "version" ]
dependencies = [ "numpy>=2" ]
dynamic = [ "version" ]
dependencies = [ "array-api-compat", "numpy>=2" ]
optional-dependencies.accel = [ "numba>=0.57" ]
optional-dependencies.dask = [ "dask>=2023.6.1" ]
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
Expand All @@ -37,9 +37,10 @@ entry-points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest"
[dependency-groups]
test = [
"anndata",
"fast-array-utils[accel]",
"fast-array-utils[full]",
"jax",
"jaxlib",
"scikit-learn",
"zarr",
{ include-group = "test-min" },
]
doc = [
Expand All @@ -66,12 +67,8 @@ envs.docs.scripts.clean = "git clean -fdX docs"
envs.docs.scripts.open = "python -m webbrowser -t docs/_build/html/index.html"
envs.hatch-test.default-args = []
envs.hatch-test.dependency-groups = [ "test-min" ]
# TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released
envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape" ]
envs.hatch-test.env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
envs.hatch-test.overrides.matrix.extras.features = [
{ if = [ "full" ], value = "full" },
]
envs.hatch-test.overrides.matrix.extras.dependency-groups = [
{ if = [ "full" ], value = "test" },
]
Expand All @@ -85,9 +82,10 @@ envs.hatch-test.matrix = [
{ python = [ "3.14", "3.12" ], extras = [ "full", "min" ] },
{ python = [ "3.12" ], extras = [ "full" ], resolution = [ "lowest" ] },
]
metadata.hooks.docstring-description = {}
metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst"
metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ]
# TODO: re-activate incl. `dynamic = [ "description", "readme", ... ]` after https://github.com/pypa/hatch/issues/2252
# metadata.hooks.docstring-description = {}
# metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst"
# metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ]
version.source = "vcs"
version.raw-options = { local_scheme = "no-local-version" } # be able to publish dev version

Expand Down
10 changes: 8 additions & 2 deletions src/fast_array_utils/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,19 @@ def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...


@overload
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> A: ...
@overload
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...


def to_dense(
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
/,
*,
order: Literal["K", "A", "C", "F"] = "K",
to_cpu_memory: bool = False,
) -> NDArray[Any] | types.DaskArray | types.CupyArray:
) -> NDArray[Any] | types.DaskArray | types.CupyArray | types.HasArrayNamespace:
r"""Convert x to a dense array.

If ``to_cpu_memory`` is :data:`False`, :class:`dask.array.Array`\ s and
Expand Down
16 changes: 15 additions & 1 deletion src/fast_array_utils/conv/_to_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# fallback’s arg0 type has to include types of registered functions
@singledispatch
def to_dense_(
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
/,
*,
order: Literal["K", "A", "C", "F"] = "K",
Expand All @@ -39,6 +39,13 @@ def _to_dense_cs(x: types.spmatrix | types.sparray, /, *, order: Literal["K", "A
return scipy.to_dense(x, order=sparse_order(x, order=order))


@to_dense_.register(np.ndarray)
def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> np.ndarray:
# to bypass the _to_dense_array_api path
del to_cpu_memory
return np.asarray(x, order=order)


@to_dense_.register(types.DaskArray)
def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> NDArray[Any] | types.DaskArray:
from . import to_dense
Expand Down Expand Up @@ -69,6 +76,13 @@ def _to_dense_cupy(x: GpuArray, /, *, order: Literal["K", "A", "C", "F"] = "K",
return x.get(order="A") if to_cpu_memory else x


@to_dense_.register(types.HasArrayNamespace)
def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray:
if to_cpu_memory:
return np.asarray(x, order=order)
return x


def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.CSDataset, /, *, order: Literal["K", "A", "C", "F"]) -> Literal["C", "F"]:
if TYPE_CHECKING:
from scipy.sparse._base import _spbase
Expand Down
37 changes: 24 additions & 13 deletions src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND
def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ...
@overload
def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None) -> types.DaskArray: ...
@overload
def is_constant[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None) -> bool | A: ...


def is_constant(
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray:
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
"""Check whether values in array are constant.

Parameters
Expand Down Expand Up @@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike |
def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ...
@overload
def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
@overload
def mean[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> A: ...


def mean(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray:
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
"""Mean over both or one axis.

Parameters
Expand Down Expand Up @@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ...
@overload
def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ...


@overload
def mean_var[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[A, A]: ...
def mean_var(
x: CpuArray | GpuArray | types.DaskArray,
x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
Expand All @@ -158,6 +162,7 @@ def mean_var(
| tuple[NDArray[np.float64], NDArray[np.float64]]
| tuple[types.CupyArray, types.CupyArray]
| tuple[types.DaskArray, types.DaskArray]
| tuple[types.HasArrayNamespace, types.HasArrayNamespace]
):
"""Mean and variance over both or one axis.

Expand Down Expand Up @@ -214,13 +219,13 @@ def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ...
# https://github.com/scverse/fast-array-utils/issues/52
def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype:
def _generic_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
from ._generic_ops import generic_op

assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"
Expand Down Expand Up @@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
@overload
def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
@overload
def min[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
def min(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
Expand Down Expand Up @@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
@overload
def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
@overload
def max[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
def max(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
Expand Down Expand Up @@ -359,14 +368,16 @@ def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, ke
def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ...
@overload
def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
@overload
def sum[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> A: ...
def sum(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
"""Sum over both or one axis.

Parameters
Expand Down
48 changes: 33 additions & 15 deletions src/fast_array_utils/stats/_generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,32 @@
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None


def _run_numpy_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
@singledispatch
def generic_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
raise NotImplementedError # pragma: no cover


@singledispatch
def generic_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
@generic_op.register(np.ndarray | types.H5Dataset | types.ZarrArray)
# register explicitly to avoid the array API path and a performance slowdown
def _generic_op_numpy_disk(
x: np.ndarray | DiskArray,
/,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
) -> NDArray[Any] | np.number[Any]:
del keep_cupy_as_array
if TYPE_CHECKING:
# these are never passed to this fallback function, but `singledispatch` wants them
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
# np supports these, but doesn’t know it. (TODO: test cupy)
assert not isinstance(x, types.ZarrArray | types.H5Dataset)
return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype))
return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return]


@generic_op.register(types.CupyArray | types.CupyCSMatrix)
Expand All @@ -62,7 +60,8 @@ def _generic_op_cupy(
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype))
arr = cast("types.CupyArray | types.CupyCOOMatrix", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
arr = arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze()


Expand Down Expand Up @@ -109,3 +108,22 @@ def _generic_op_dask(
dtype = getattr(np, op)(np.zeros(1, dtype=x.dtype)).dtype

return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array)


@generic_op.register(types.HasArrayNamespace)
def _generic_op_array_api[A: types.HasArrayNamespace](
x: A,
/,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> A:
"""Handle arrays with native array API support."""
del keep_cupy_as_array

import array_api_compat

xp = array_api_compat.array_namespace(x)
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return]
14 changes: 7 additions & 7 deletions src/fast_array_utils/stats/_is_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

import numba
import numpy as np
Expand All @@ -19,29 +19,29 @@

@singledispatch
def is_constant_(
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
raise NotImplementedError
raise NotImplementedError # pragma: no cover


@is_constant_.register(np.ndarray | types.CupyArray)
@is_constant_.register(np.ndarray | types.CupyArray | types.HasArrayNamespace)
def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool] | types.CupyArray:
# Should eventually support nd, not now.
match axis:
case None:
return bool((a == a.flat[0]).all())
return bool((a == a.reshape(-1)[0]).all())
case 0:
return _is_constant_rows(a.T)
case 1:
return _is_constant_rows(a)


def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray:
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
return cast("NDArray[np.bool]", (a == b).all(axis=1))
# broadcasts without needing np.broadcast_to
return (a == a[:, 0:1]).all(axis=1)


@is_constant_.register(types.CSBase)
Expand Down
4 changes: 2 additions & 2 deletions src/fast_array_utils/stats/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@


def mean_(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type]
n = np.prod(x.shape) if axis is None else x.shape[axis]
return total / n # type: ignore[no-any-return]
Loading
Loading