Skip to content

Commit 7160bae

Browse files
committed
fix tests for numpy 1.x
1 parent 3611708 commit 7160bae

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/array_api_extra/_delegation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Literal
66

77
from ._lib import _funcs, _quantile
8+
from ._lib._backends import NUMPY_VERSION
89
from ._lib._utils._compat import (
910
array_namespace,
1011
is_cupy_namespace,
@@ -1047,7 +1048,6 @@ def quantile(
10471048
empirical cumulative distribution is simply replaced by its weighted
10481049
version, i.e.
10491050
:math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`.
1050-
Only ``method="inverted_cdf"`` supports weights.
10511051
10521052
References
10531053
----------
@@ -1125,14 +1125,15 @@ def quantile(
11251125
raise ValueError(msg)
11261126

11271127
# Delegate when possible.
1128-
if is_numpy_namespace(xp) and nan_policy == "propagate":
1128+
basic_case = method == "linear" and weights is None
1129+
np_2 = NUMPY_VERSION >= (2, 0)
1130+
if is_numpy_namespace(xp) and nan_policy == "propagate" and (basic_case or np_2):
11291131
# TODO: call nanquantile for nan_policy == "omit" once
11301132
# https://github.com/numpy/numpy/issues/29709 is fixed
11311133
return xp.quantile(
11321134
a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
11331135
)
11341136
# No delegation for dask: I couldn't make it work.
1135-
basic_case = method == "linear" and weights is None
11361137
jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp)
11371138
if basic_case and nan_policy == "propagate" and jax_or_cupy:
11381139
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
@@ -1141,8 +1142,8 @@ def quantile(
11411142
return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)
11421143

11431144
# Otherwise call our implementation (will sort data)
1144-
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11451145
return _quantile.quantile(
1146+
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11461147
a,
11471148
q_arr,
11481149
axis=axis,

tests/test_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,8 @@ def test_against_numpy(self, xp: ModuleType, keepdims: bool):
15771577
def test_weighted_against_numpy(
15781578
self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str
15791579
):
1580+
if NUMPY_VERSION < (2, 0):
1581+
pytest.xfail(reason="NumPy 1.x does not support weights in quantile")
15801582
rng = np.random.default_rng()
15811583
n, d = 10, 20
15821584
a_np = rng.random((n, d))

0 commit comments

Comments
 (0)