Skip to content

Commit fa789fc

Browse files
committed
Weighted quantile; nan-policy; everything mostly works
1 parent 19fa6ea commit fa789fc

File tree

3 files changed

+155
-52
lines changed

3 files changed

+155
-52
lines changed

src/array_api_extra/_delegation.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ def quantile(
904904
axis: int | None = None,
905905
method: str = "linear",
906906
keepdims: bool = False,
907+
nan_policy: str = "propagate",
907908
*,
908909
weights: Array | None = None,
909910
xp: ModuleType | None = None,
@@ -1051,16 +1052,22 @@ def quantile(
10511052
"Sample quantiles in statistical packages,"
10521053
The American Statistician, 50(4), pp. 361-365, 1996
10531054
"""
1054-
methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"}
1055+
if xp is None:
1056+
xp = array_namespace(a)
1057+
if is_pydata_sparse_namespace(xp):
1058+
raise ValueError('no supported')
10551059

1060+
methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"}
10561061
if method not in methods:
10571062
msg = f"`method` must be one of {methods}"
10581063
raise ValueError(msg)
1064+
nan_policies = {"propagate", "omit"}
1065+
if nan_policy not in nan_policies:
1066+
msg = f"`nan_policy` must be one of {nan_policies}"
1067+
raise ValueError(msg)
10591068
if keepdims not in {True, False}:
10601069
msg = "If specified, `keepdims` must be True or False."
10611070
raise ValueError(msg)
1062-
if xp is None:
1063-
xp = array_namespace(a)
10641071

10651072
a = xp.asarray(a)
10661073
if not xp.isdtype(a.dtype, ("integral", "real floating")):
@@ -1071,15 +1078,31 @@ def quantile(
10711078
raise ValueError(msg)
10721079
ndim = a.ndim
10731080
if ndim < 1:
1074-
msg = "`a` must be at least 1-dimensional"
1081+
msg = "`a` must be at least 1-dimensional."
10751082
raise TypeError(msg)
10761083
if axis is not None and ((axis >= ndim) or (axis < -ndim)):
10771084
msg = "`axis` is not compatible with the dimension of `a`."
10781085
raise ValueError(msg)
1079-
1080-
# Array API states: Mixed integer and floating-point type promotion rules
1081-
# are not specified because behavior varies between implementations.
1082-
# We chose to align with numpy (see docstring):
1086+
if weights is None:
1087+
if nan_policy != "propagate":
1088+
msg = ""
1089+
raise ValueError(msg)
1090+
else:
1091+
if ndim > 2:
1092+
msg = "When weights are provided, dimension of `a` must be 1 or 2."
1093+
raise ValueError(msg)
1094+
if a.shape != weights.shape:
1095+
if axis is None:
1096+
msg = "Axis must be specified when shapes of `a` and ̀ weights` differ."
1097+
raise TypeError(msg)
1098+
if weights.shape != eager_shape(a, axis):
1099+
msg = "Shape of weights must be consistent with shape of a along specified axis."
1100+
raise ValueError(msg)
1101+
if axis is None and ndim == 2:
1102+
msg = "When weights are provided, axis must be specified when `a` is 2d"
1103+
raise ValueError(msg)
1104+
1105+
# Align result dtype with what numpy does:
10831106
dtype = xp.result_type(
10841107
xp.float64 if xp.isdtype(a.dtype, "integral") else a,
10851108
xp.asarray(q),
@@ -1088,20 +1111,25 @@ def quantile(
10881111
device = get_device(a)
10891112
a = xp.asarray(a, dtype=dtype, device=device)
10901113
q = xp.asarray(q, dtype=dtype, device=device)
1114+
# TODO: cast weights here? Assert weights are on the same device as `a`?
10911115

10921116
if xp.any((q > 1) | (q < 0) | xp.isnan(q)):
10931117
msg = "`q` values must be in the range [0, 1]"
10941118
raise ValueError(msg)
10951119

10961120
# Delegate where possible.
1097-
if is_numpy_namespace(xp):
1121+
if is_numpy_namespace(xp) and nan_policy == "propagate":
10981122
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights)
10991123
# No delegation for dask: I couldn't make it work
1100-
basic_case = method == "linear" and weights is None
1124+
basic_case = method == "linear" and weights is None and nan_policy == "propagate"
11011125
if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp):
11021126
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
11031127
if basic_case and is_torch_namespace(xp):
11041128
return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims)
11051129

1130+
# XXX: I'm not sure we want to support dask, it seems utterly slow...
11061131
# Otherwise call our implementation (will sort data)
1107-
return _quantile.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp)
1132+
return _quantile.quantile(
1133+
a, q, axis=axis, method=method, keepdims=keepdims,
1134+
nan_policy=nan_policy, weights=weights, xp=xp
1135+
)

src/array_api_extra/_lib/_quantile.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._utils._compat import device as get_device
66
from ._utils._helpers import eager_shape
7-
from ._utils._typing import Array
7+
from ._utils._typing import Array, Device
88

99

1010
def quantile( # numpydoc ignore=PR01,RT01
@@ -14,6 +14,7 @@ def quantile( # numpydoc ignore=PR01,RT01
1414
method: str = "linear",
1515
axis: int | None = None,
1616
keepdims: bool = False,
17+
nan_policy: str = "propagate",
1718
*,
1819
weights: Array | None = None,
1920
xp: ModuleType,
@@ -43,43 +44,49 @@ def quantile( # numpydoc ignore=PR01,RT01
4344
a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device)
4445

4546
if weights is None:
46-
res = _quantile(a, q, float(n), axis, method, xp)
47+
res = _quantile(a, q, n, axis, method, xp)
48+
if not axis_none:
49+
res = xp.moveaxis(res, axis, 0)
4750
else:
51+
weights = xp.asarray(weights, dtype=xp.float64, device=device)
4852
average = method == 'averaged_inverted_cdf'
49-
res = _weighted_quantile(a, q, weights, n, axis, average, xp)
50-
# to support weights, the main thing would be to
51-
# argsort a, and then use it to sort a and w.
52-
# The hard part will be dealing with 0-weights and NaNs
53-
# But maybe a proper use of searchsorted + left/right side will work?
53+
res = _weighted_quantile(
54+
a, q, weights, n, axis, average,
55+
nan_policy=nan_policy, xp=xp, device=device
56+
)
5457

5558
# reshaping to conform to doc/other libs' behavior
5659
if axis_none:
5760
if keepdims:
5861
res = xp.reshape(res, q.shape + (1,) * a_ndim)
59-
else:
60-
res = xp.moveaxis(res, axis, 0)
61-
if keepdims:
62-
a_shape[axis] = 1
63-
res = xp.reshape(res, q.shape + tuple(a_shape))
62+
elif keepdims:
63+
a_shape[axis] = 1
64+
res = xp.reshape(res, q.shape + tuple(a_shape))
6465

6566
return res[0, ...] if q_scalar else res
6667

6768

6869
def _quantile( # numpydoc ignore=GL08
69-
a: Array, q: Array, n: float, axis: int, method: str, xp: ModuleType
70+
a: Array, q: Array, n: int, axis: int, method: str, xp: ModuleType
7071
) -> Array:
7172
a = xp.sort(a, axis=axis, stable=False)
73+
mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True)
74+
if xp.any(mask_nan):
75+
# propagate NaNs:
76+
mask = xp.repeat(mask_nan, n, axis=axis)
77+
a = xp.where(mask, xp.nan, a)
78+
del mask
7279

7380
if method == "linear":
74-
m = 1 - q
81+
m = 1 - q
7582
else: # method is "inverted_cdf" or "averaged_inverted_cdf"
7683
m = 0
7784

78-
jg = q * n + m - 1
85+
jg = q * float(n) + m - 1
7986

8087
j = jg // 1
81-
j = xp.clip(j, 0.0, n - 1)
82-
jp1 = xp.clip(j + 1, 0.0, n - 1)
88+
j = xp.clip(j, 0.0, float(n - 1))
89+
jp1 = xp.clip(j + 1, 0.0, float(n - 1))
8390
# `j` and `jp1` are 1d arrays
8491

8592
g = jg % 1
@@ -88,7 +95,7 @@ def _quantile( # numpydoc ignore=GL08
8895
elif method == 'averaged_inverted_cdf':
8996
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
9097

91-
g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest
98+
g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with readonly
9299
new_g_shape = [1] * a.ndim
93100
new_g_shape[axis] = g.shape[0]
94101
g = xp.reshape(g, tuple(new_g_shape))
@@ -98,37 +105,55 @@ def _quantile( # numpydoc ignore=GL08
98105
)
99106

100107

101-
def _weighted_quantile(a: Array, q: Array, weights: Array, n: int, axis, average: bool, xp: ModuleType):
108+
def _weighted_quantile(
109+
a: Array, q: Array, weights: Array, n: int, axis: int, average: bool, nan_policy: str,
110+
xp: ModuleType, device: Device
111+
) -> Array:
112+
"""
113+
a is expected to be 1d or 2d.
114+
"""
115+
kwargs = dict(n=n, average=average, nan_policy=nan_policy, xp=xp, device=device)
102116
a = xp.moveaxis(a, axis, -1)
117+
if weights.ndim > 1:
118+
weights = xp.moveaxis(weights, axis, -1)
103119
sorter = xp.argsort(a, axis=-1, stable=False)
104-
a = xp.take_along_axis(a, sorter, axis=-1)
105120

106121
if a.ndim == 1:
107-
return _weighted_quantile_sorted_1d(a, q, weights, n, )
122+
x = xp.take(a, sorter)
123+
w = xp.take(weights, sorter)
124+
return _weighted_quantile_sorted_1d(x, q, w, **kwargs)
108125

109126
d, = eager_shape(a, axis=0)
110-
res = xp.empty((q.shape[0], d))
127+
res = []
111128
for idx in range(d):
112129
w = weights if weights.ndim == 1 else weights[idx, ...]
113130
w = xp.take(w, sorter[idx, ...])
114-
res[..., idx] = _weighted_quantile_sorted_1d(a[idx, ...], q, w, n, average)
131+
x = xp.take(a[idx, ...], sorter[idx, ...])
132+
res.append(_weighted_quantile_sorted_1d(x, q, w, **kwargs))
133+
res = xp.stack(res, axis=1)
115134
return res
116135

117136

118-
def _weighted_quantile_sorted_1d(a, q, w, n, average: bool, xp: ModuleType):
119-
cw = xp.cumsum(w)
137+
def _weighted_quantile_sorted_1d(
138+
x: Array, q: Array, w: Array, n: int, average: bool, nan_policy: str,
139+
xp: ModuleType, device: Device
140+
) -> Array:
141+
if nan_policy == "omit":
142+
w = xp.where(xp.isnan(x), 0., w)
143+
elif xp.any(xp.isnan(x)):
144+
return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device)
145+
cw = xp.cumulative_sum(w)
120146
t = cw[-1] * q
121-
i = xp.searchsorted(cw, t)
147+
i = xp.searchsorted(cw, t, side='left')
122148
j = xp.searchsorted(cw, t, side='right')
123-
i = xp.minimum(i, float(n - 1))
124-
j = xp.minimum(j, float(n - 1))
149+
i = xp.clip(i, 0, n - 1)
150+
j = xp.clip(j, 0, n - 1)
125151

126152
# Ignore leading `weights=0` observations when `q=0`
127153
# see https://github.com/scikit-learn/scikit-learn/pull/20528
128-
i = xp.where(q == 0., j, i)
154+
i = xp.where(q == 0., j, i)
129155
if average:
130156
# Ignore trailing `weights=0` observations when `q=1`
131157
j = xp.where(q == 1., i, j)
132-
return (xp.take(a, i) + xp.take(a, j)) / 2
133-
else:
134-
return xp.take(a, i)
158+
return (xp.take(x, i) + xp.take(x, j)) / 2
159+
return xp.take(x, i)

tests/test_funcs.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,18 +1558,70 @@ def test_shape(self, xp: ModuleType):
15581558
assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5)
15591559
assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1)
15601560

1561-
def test_against_numpy(self, xp: ModuleType):
1561+
@pytest.mark.parametrize("keepdims", [True, False])
1562+
def test_against_numpy(self, xp: ModuleType, keepdims: bool):
15621563
rng = np.random.default_rng()
15631564
a_np = rng.random((3, 4, 5))
15641565
q_np = rng.random(2)
15651566
a = xp.asarray(a_np)
15661567
q = xp.asarray(q_np)
1567-
for keepdims in [False, True]:
1568-
for axis in [None, *range(a.ndim)]:
1569-
actual = quantile(a, q, axis=axis, keepdims=keepdims)
1570-
expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims)
1571-
expected = xp.asarray(expected, dtype=xp.float64)
1572-
xp_assert_close(actual, expected, atol=1e-12)
1568+
for axis in [None, *range(a.ndim)]:
1569+
actual = quantile(a, q, axis=axis, keepdims=keepdims)
1570+
expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims)
1571+
expected = xp.asarray(expected)
1572+
xp_assert_close(actual, expected, atol=1e-12)
1573+
1574+
@pytest.mark.parametrize("keepdims", [True, False])
1575+
@pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])#, #["omit"])#["no_nans", "propagate"])
1576+
@pytest.mark.parametrize("q_np", [0.5, 0., 1., np.linspace(0, 1, num=11)])
1577+
def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str):
1578+
rng = np.random.default_rng()
1579+
n, d = 10, 20
1580+
a_np = rng.random((n, d))
1581+
kwargs = dict(keepdims=keepdims)
1582+
mask_nan = np.zeros((n, d), dtype=bool)
1583+
if nan_policy != "no_nans":
1584+
# from 0% to 100% of NaNs:
1585+
mask_nan = rng.random((n, d)) < rng.random((n, 1))
1586+
# don't put nans in the first row:
1587+
mask_nan[:] = False
1588+
a_np[mask_nan] = np.nan
1589+
kwargs['nan_policy'] = nan_policy
1590+
1591+
a = xp.asarray(a_np)
1592+
q = xp.asarray(np.copy(q_np))
1593+
m = 'inverted_cdf'
1594+
1595+
np_quantile = np.quantile
1596+
if nan_policy == "omit":
1597+
np_quantile = np.nanquantile
1598+
1599+
for w_np, axis in [
1600+
(rng.random(n), 0),
1601+
(rng.random(d), 1),
1602+
(rng.integers(0, 2, n), 0),
1603+
(rng.integers(0, 2, d), 1),
1604+
(rng.integers(0, 2, (n, d)), 0),
1605+
(rng.integers(0, 2, (n, d)), 1),
1606+
]:
1607+
print(w_np)
1608+
with warnings.catch_warnings(record=True) as warning:
1609+
warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning)
1610+
warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
1611+
try:
1612+
expected = np_quantile(a_np, q_np, axis=axis, method=m, weights=w_np, keepdims=keepdims)
1613+
except IndexError:
1614+
print('index error')
1615+
continue
1616+
if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile
1617+
print('warning')
1618+
continue
1619+
expected = xp.asarray(expected)
1620+
print("not skiped")
1621+
1622+
w = xp.asarray(w_np)
1623+
actual = quantile(a, q, axis=axis, method=m, weights=w, **kwargs)
1624+
xp_assert_close(actual, expected, atol=1e-12)
15731625

15741626
def test_2d_axis(self, xp: ModuleType):
15751627
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
@@ -1605,8 +1657,6 @@ def test_edge_cases(self, xp: ModuleType):
16051657

16061658
def test_invalid_q(self, xp: ModuleType):
16071659
x = xp.asarray([1, 2, 3, 4, 5])
1608-
_ = quantile(x, 1.0)
1609-
# ^ FIXME: here just to make this test fail for sparse backend
16101660
# q > 1 should raise
16111661
with pytest.raises(
16121662
ValueError, match=r"`q` values must be in the range \[0, 1\]"

0 commit comments

Comments
 (0)