Skip to content

Commit 0b2cb9b

Browse files
committed
working on coverage
1 parent 7160bae commit 0b2cb9b

File tree

3 files changed

+84
-47
lines changed

3 files changed

+84
-47
lines changed

src/array_api_extra/_delegation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,7 @@ def quantile(
948948
949949
nan_policy : str, optional
950950
'propagate' (default) or 'omit'.
951+
'omit' is support only when `weights` are provided.
951952
952953
weights : array_like, optional
953954
An array of weights associated with the values in `a`. Each value in
@@ -1125,19 +1126,26 @@ def quantile(
11251126
raise ValueError(msg)
11261127

11271128
# Delegate when possible.
1129+
# Note: No delegation for dask: I couldn't make it work.
11281130
basic_case = method == "linear" and weights is None
1131+
11291132
np_2 = NUMPY_VERSION >= (2, 0)
1130-
if is_numpy_namespace(xp) and nan_policy == "propagate" and (basic_case or np_2):
1133+
np_handles_weights = np_2 and nan_policy == "propagate" and method == "inverted_cdf"
1134+
if weights is None:
1135+
if is_numpy_namespace(xp) and (basic_case or np_2):
1136+
quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile
1137+
return quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
1138+
elif is_numpy_namespace(xp) and np_handles_weights:
11311139
# TODO: call nanquantile for nan_policy == "omit" once
11321140
# https://github.com/numpy/numpy/issues/29709 is fixed
11331141
return xp.quantile(
11341142
a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
11351143
)
1136-
# No delegation for dask: I couldn't make it work.
1144+
11371145
jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp)
1138-
if basic_case and nan_policy == "propagate" and jax_or_cupy:
1146+
if jax_or_cupy and basic_case and nan_policy == "propagate":
11391147
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
1140-
if basic_case and is_torch_namespace(xp):
1148+
if is_torch_namespace(xp) and basic_case:
11411149
quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile
11421150
return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)
11431151

src/array_api_extra/_lib/_quantile.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ def quantile( # numpydoc ignore=PR01,RT01
3636
axis = int(axis)
3737

3838
(n,) = eager_shape(a, axis)
39-
# If data has length zero along `axis`, the result will be an array of NaNs just
40-
# as if the data had length 1 along axis and were filled with NaNs.
41-
if n == 0:
42-
a_shape[axis] = 1
43-
n = 1
44-
a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device)
4539

4640
if weights is None:
4741
res = _quantile(a, q, n, axis, method, xp)
@@ -93,12 +87,7 @@ def _quantile( # numpydoc ignore=PR01,RT01
9387
)
9488

9589
jg = q * float(n) + m - 1
96-
9790
j = jg // 1
98-
j = xp.clip(j, 0.0, float(n - 1))
99-
jp1 = xp.clip(j + 1, 0.0, float(n - 1))
100-
# `̀j` and `jp1` are 1d arrays
101-
10291
g = jg % 1
10392
if method == "inverted_cdf":
10493
g = xp.astype((g > 0), jg.dtype)
@@ -110,6 +99,10 @@ def _quantile( # numpydoc ignore=PR01,RT01
11099
new_g_shape[axis] = g.shape[0]
111100
g = xp.reshape(g, tuple(new_g_shape))
112101

102+
j = xp.clip(j, 0.0, float(n - 1))
103+
jp1 = xp.clip(j + 1, 0.0, float(n - 1))
104+
# `̀j` and `jp1` are 1d arrays
105+
113106
return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take(
114107
a, xp.astype(jp1, xp.int64), axis=axis
115108
)

tests/test_funcs.py

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import warnings
33
from types import ModuleType
4-
from typing import Any, cast
4+
from typing import Any, Literal, cast, get_args
55

66
import hypothesis
77
import hypothesis.extra.numpy as npst
@@ -1531,6 +1531,7 @@ def test_kind(self, xp: ModuleType, library: Backend):
15311531
res = isin(a, b, kind="sort")
15321532
xp_assert_equal(res, expected)
15331533

1534+
METHOD = Literal["linear", "inverted_cdf", "averaged_inverted_cdf"]
15341535

15351536
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take")
15361537
class TestQuantile:
@@ -1558,21 +1559,67 @@ def test_shape(self, xp: ModuleType):
15581559
assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5)
15591560
assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1)
15601561

1562+
@pytest.mark.parametrize("with_nans", ["no_nans", "with_nans"])
1563+
@pytest.mark.parametrize("method", get_args(METHOD))
1564+
def test_against_numpy_1d(self, xp: ModuleType, with_nans: str, method: METHOD):
1565+
rng = np.random.default_rng()
1566+
a_np = rng.random(40)
1567+
if with_nans == "with_nans":
1568+
a_np[rng.random(a_np.shape) < rng.random() * 0.5] = np.nan
1569+
q_np = np.asarray([0, *rng.random(2), 1])
1570+
a = xp.asarray(a_np)
1571+
q = xp.asarray(q_np)
1572+
1573+
actual = quantile(a, q, method=method)
1574+
expected = np.quantile(a_np, q_np, method=method)
1575+
expected = xp.asarray(expected)
1576+
xp_assert_close(actual, expected)
1577+
1578+
@pytest.mark.parametrize("with_nans", ["no_nans", "with_nans"])
1579+
@pytest.mark.parametrize("method", get_args(METHOD))
15611580
@pytest.mark.parametrize("keepdims", [True, False])
1562-
def test_against_numpy(self, xp: ModuleType, keepdims: bool):
1581+
def test_against_numpy_nd(self, xp: ModuleType, keepdims: bool,
1582+
with_nans: str, method: METHOD):
15631583
rng = np.random.default_rng()
15641584
a_np = rng.random((3, 4, 5))
1585+
if with_nans == "with_nans":
1586+
a_np[rng.random(a_np.shape) < rng.random()] = np.nan
15651587
q_np = rng.random(2)
15661588
a = xp.asarray(a_np)
15671589
q = xp.asarray(q_np)
15681590
for axis in [None, *range(a.ndim)]:
1569-
actual = quantile(a, q, axis=axis, keepdims=keepdims)
1570-
expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims)
1591+
actual = quantile(a, q, axis=axis, keepdims=keepdims, method=method)
1592+
expected = np.quantile(
1593+
a_np, q_np, axis=axis, keepdims=keepdims, method=method
1594+
)
15711595
expected = xp.asarray(expected)
1572-
xp_assert_close(actual, expected, atol=1e-12)
1596+
xp_assert_close(actual, expected)
1597+
1598+
@pytest.mark.parametrize("nan_policy", ["no_nans", "propagate"])
1599+
@pytest.mark.parametrize("with_weights", ["with_weights", "no_weights"])
1600+
def test_against_median(
1601+
self, xp: ModuleType, nan_policy: str, with_weights: str,
1602+
):
1603+
rng = np.random.default_rng()
1604+
n = 40
1605+
a_np = rng.random(n)
1606+
w_np = rng.integers(0, 2, n) if with_weights == "with_weights" else None
1607+
if nan_policy == "no_nans":
1608+
nan_policy = "propagate"
1609+
else:
1610+
# from 0% to 50% of NaNs:
1611+
a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan
1612+
m = "averaged_inverted_cdf"
1613+
1614+
np_median = np.nanmedian if nan_policy == "omit" else np.median
1615+
expected = np_median(a_np if w_np is None else a_np[w_np > 0])
1616+
a = xp.asarray(a_np)
1617+
w = xp.asarray(w_np) if w_np is not None else None
1618+
actual = quantile(a, 0.5, method=m, nan_policy=nan_policy, weights=w)
1619+
xp_assert_close(actual, xp.asarray(expected))
15731620

15741621
@pytest.mark.parametrize("keepdims", [True, False])
1575-
@pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])
1622+
@pytest.mark.parametrize("nan_policy", ["no_nans", "propagate", "omit"])
15761623
@pytest.mark.parametrize("q_np", [0.5, 0.0, 1.0, np.linspace(0, 1, num=11)])
15771624
def test_weighted_against_numpy(
15781625
self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str
@@ -1581,7 +1628,7 @@ def test_weighted_against_numpy(
15811628
pytest.xfail(reason="NumPy 1.x does not support weights in quantile")
15821629
rng = np.random.default_rng()
15831630
n, d = 10, 20
1584-
a_np = rng.random((n, d))
1631+
a_2d = rng.random((n, d))
15851632
mask_nan = np.zeros((n, d), dtype=bool)
15861633
if nan_policy == "no_nans":
15871634
nan_policy = "propagate"
@@ -1590,36 +1637,36 @@ def test_weighted_against_numpy(
15901637
mask_nan = rng.random((n, d)) < rng.random((n, 1))
15911638
# don't put nans in the first row:
15921639
mask_nan[:] = False
1593-
a_np[mask_nan] = np.nan
1640+
a_2d[mask_nan] = np.nan
15941641

1595-
a = xp.asarray(a_np, copy=True)
15961642
q = xp.asarray(q_np, copy=True)
1597-
m = "inverted_cdf"
1643+
m: METHOD = "inverted_cdf"
15981644

15991645
np_quantile = np.quantile
16001646
if nan_policy == "omit":
16011647
np_quantile = np.nanquantile
16021648

1603-
for w_np, axis in [
1604-
(rng.random(n), 0),
1605-
(rng.random(d), 1),
1606-
(rng.integers(0, 2, n), 0),
1607-
(rng.integers(0, 2, d), 1),
1608-
(rng.integers(0, 2, (n, d)), 0),
1609-
(rng.integers(0, 2, (n, d)), 1),
1649+
for a_np, w_np, axis in [
1650+
(a_2d, rng.random(n), 0),
1651+
(a_2d, rng.random(d), 1),
1652+
(a_2d[0], rng.random(d), None),
1653+
(a_2d, rng.integers(0, 3, n), 0),
1654+
(a_2d, rng.integers(0, 2, d), 1),
1655+
(a_2d, rng.integers(0, 2, (n, d)), 0),
1656+
(a_2d, rng.integers(0, 3, (n, d)), 1),
16101657
]:
16111658
with warnings.catch_warnings(record=True) as warning:
16121659
divide_msg = "invalid value encountered in divide"
16131660
warnings.filterwarnings("always", divide_msg, RuntimeWarning)
16141661
nan_slice_msg = "All-NaN slice encountered"
16151662
warnings.filterwarnings("ignore", nan_slice_msg, RuntimeWarning)
16161663
try:
1617-
expected = np_quantile( # type: ignore[call-overload]
1664+
expected = np_quantile(
16181665
a_np,
16191666
np.asarray(q_np),
16201667
axis=axis,
16211668
method=m,
1622-
weights=w_np,
1669+
weights=w_np, # type: ignore[arg-type]
16231670
keepdims=keepdims,
16241671
)
16251672
except IndexError:
@@ -1630,6 +1677,7 @@ def test_weighted_against_numpy(
16301677
continue
16311678
expected = xp.asarray(expected)
16321679

1680+
a = xp.asarray(a_np)
16331681
w = xp.asarray(w_np)
16341682
actual = quantile(
16351683
a,
@@ -1640,19 +1688,7 @@ def test_weighted_against_numpy(
16401688
keepdims=keepdims,
16411689
nan_policy=nan_policy,
16421690
)
1643-
xp_assert_close(actual, expected, atol=1e-12)
1644-
1645-
def test_2d_axis(self, xp: ModuleType):
1646-
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1647-
actual = quantile(x, 0.5, axis=0)
1648-
expect = xp.asarray([2.5, 3.5, 4.5], dtype=xp.float64)
1649-
xp_assert_close(actual, expect)
1650-
1651-
def test_2d_axis_keepdims(self, xp: ModuleType):
1652-
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1653-
actual = quantile(x, 0.5, axis=0, keepdims=True)
1654-
expect = xp.asarray([[2.5, 3.5, 4.5]], dtype=xp.float64)
1655-
xp_assert_close(actual, expect)
1691+
xp_assert_close(actual, expected)
16561692

16571693
def test_methods(self, xp: ModuleType):
16581694
x = xp.asarray([1, 2, 3, 4, 5])

0 commit comments

Comments
 (0)