Skip to content

Commit 19fa6ea

Browse files
committed
WIP: adding support for weights
1 parent 89d8410 commit 19fa6ea

File tree

3 files changed

+81
-23
lines changed

3 files changed

+81
-23
lines changed

src/array_api_extra/_delegation.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def quantile(
905905
method: str = "linear",
906906
keepdims: bool = False,
907907
*,
908+
weights: Array | None = None,
908909
xp: ModuleType | None = None,
909910
) -> Array:
910911
"""
@@ -943,6 +944,16 @@ def quantile(
943944
the result as dimensions with size one. With this option, the
944945
result will broadcast correctly against the original array `a`.
945946
947+
weights : array_like, optional
948+
An array of weights associated with the values in `a`. Each value in
949+
`a` contributes to the quantile according to its associated weight.
950+
The weights array can either be 1-D (in which case its length must be
951+
the size of `a` along the given axis) or of the same shape as `a`.
952+
If `weights=None`, then all data in `a` are assumed to have a
953+
weight equal to one.
954+
Only `method="inverted_cdf"` or `method="averaged_inverted_cdf"`
955+
support weights. See the notes for more details.
956+
946957
xp : array_namespace, optional
947958
The standard-compatible namespace for `a` and `q`. Default: infer.
948959
@@ -1040,7 +1051,7 @@ def quantile(
10401051
"Sample quantiles in statistical packages,"
10411052
The American Statistician, 50(4), pp. 361-365, 1996
10421053
"""
1043-
methods = {"linear"}
1054+
methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"}
10441055

10451056
if method not in methods:
10461057
msg = f"`method` must be one of {methods}"
@@ -1084,12 +1095,12 @@ def quantile(
10841095

10851096
# Delegate where possible.
10861097
if is_numpy_namespace(xp):
1098+
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights)
1099+
# No delegation for dask: I couldn't make it work
1100+
basic_case = method == "linear" and weights is None
1101+
if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp):
10871102
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
1088-
# No delegating for dask: I couldn't make it work
1089-
is_linear = method == "linear"
1090-
if (is_linear and is_jax_namespace(xp)) or is_cupy_namespace(xp):
1091-
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
1092-
if is_linear and is_torch_namespace(xp):
1103+
if basic_case and is_torch_namespace(xp):
10931104
return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims)
10941105

10951106
# Otherwise call our implementation (will sort data)

src/array_api_extra/_lib/_quantile.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,22 @@
99

1010
def quantile( # numpydoc ignore=PR01,RT01
1111
a: Array,
12-
q: Array | float,
12+
q: Array,
1313
/,
14-
method: str = "linear", # noqa: ARG001
14+
method: str = "linear",
1515
axis: int | None = None,
1616
keepdims: bool = False,
1717
*,
18+
weights: Array | None = None,
1819
xp: ModuleType,
1920
) -> Array:
2021
"""See docstring in `array_api_extra._delegation.py`."""
2122
device = get_device(a)
22-
floating_dtype = xp.float64 # xp.result_type(a, xp.asarray(q))
23-
a = xp.asarray(a, dtype=floating_dtype, device=device)
2423
a_shape = list(a.shape)
25-
p: Array = xp.asarray(q, dtype=floating_dtype, device=device)
2624

27-
q_scalar = p.ndim == 0
25+
q_scalar = q.ndim == 0
2826
if q_scalar:
29-
p = xp.reshape(p, (1,))
27+
q = xp.reshape(q, (1,))
3028

3129
axis_none = axis is None
3230
a_ndim = a.ndim
@@ -42,33 +40,41 @@ def quantile( # numpydoc ignore=PR01,RT01
4240
if n == 0:
4341
a_shape[axis] = 1
4442
n = 1
45-
a = xp.full(tuple(a_shape), xp.nan, dtype=floating_dtype, device=device)
43+
a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device)
4644

47-
a = xp.sort(a, axis=axis, stable=False)
45+
if weights is None:
46+
res = _quantile(a, q, float(n), axis, method, xp)
47+
else:
48+
average = method == 'averaged_inverted_cdf'
49+
res = _weighted_quantile(a, q, weights, n, axis, average, xp)
4850
# to support weights, the main thing would be to
4951
# argsort a, and then use it to sort a and w.
5052
# The hard part will be dealing with 0-weights and NaNs
5153
# But maybe a proper use of searchsorted + left/right side will work?
5254

53-
res = _quantile_hf(a, p, float(n), axis, xp)
54-
5555
# reshaping to conform to doc/other libs' behavior
5656
if axis_none:
5757
if keepdims:
58-
res = xp.reshape(res, p.shape + (1,) * a_ndim)
58+
res = xp.reshape(res, q.shape + (1,) * a_ndim)
5959
else:
6060
res = xp.moveaxis(res, axis, 0)
6161
if keepdims:
6262
a_shape[axis] = 1
63-
res = xp.reshape(res, p.shape + tuple(a_shape))
63+
res = xp.reshape(res, q.shape + tuple(a_shape))
6464

6565
return res[0, ...] if q_scalar else res
6666

6767

68-
def _quantile_hf( # numpydoc ignore=GL08
69-
a: Array, q: Array, n: float, axis: int, xp: ModuleType
68+
def _quantile( # numpydoc ignore=GL08
69+
a: Array, q: Array, n: float, axis: int, method: str, xp: ModuleType
7070
) -> Array:
71-
m = 1 - q
71+
a = xp.sort(a, axis=axis, stable=False)
72+
73+
if method == "linear":
74+
m = 1 - q
75+
else: # method is "inverted_cdf" or "averaged_inverted_cdf"
76+
m = 0
77+
7278
jg = q * n + m - 1
7379

7480
j = jg // 1
@@ -77,6 +83,11 @@ def _quantile_hf( # numpydoc ignore=GL08
7783
# `j` and `jp1` are 1d arrays
7884

7985
g = jg % 1
86+
if method == 'inverted_cdf':
87+
g = xp.astype((g > 0), jg.dtype)
88+
elif method == 'averaged_inverted_cdf':
89+
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
90+
8091
g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest
8192
new_g_shape = [1] * a.ndim
8293
new_g_shape[axis] = g.shape[0]
@@ -85,3 +96,39 @@ def _quantile_hf( # numpydoc ignore=GL08
8596
return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take(
8697
a, xp.astype(jp1, xp.int64), axis=axis
8798
)
99+
100+
101+
def _weighted_quantile(a: Array, q: Array, weights: Array, n: int, axis, average: bool, xp: ModuleType):
102+
a = xp.moveaxis(a, axis, -1)
103+
sorter = xp.argsort(a, axis=-1, stable=False)
104+
a = xp.take_along_axis(a, sorter, axis=-1)
105+
106+
if a.ndim == 1:
107+
return _weighted_quantile_sorted_1d(a, q, weights, n, )
108+
109+
d, = eager_shape(a, axis=0)
110+
res = xp.empty((q.shape[0], d))
111+
for idx in range(d):
112+
w = weights if weights.ndim == 1 else weights[idx, ...]
113+
w = xp.take(w, sorter[idx, ...])
114+
res[..., idx] = _weighted_quantile_sorted_1d(a[idx, ...], q, w, n, average)
115+
return res
116+
117+
118+
def _weighted_quantile_sorted_1d(a, q, w, n, average: bool, xp: ModuleType):
119+
cw = xp.cumsum(w)
120+
t = cw[-1] * q
121+
i = xp.searchsorted(cw, t)
122+
j = xp.searchsorted(cw, t, side='right')
123+
i = xp.minimum(i, float(n - 1))
124+
j = xp.minimum(j, float(n - 1))
125+
126+
# Ignore leading `weights=0` observations when `q=0`
127+
# see https://github.com/scikit-learn/scikit-learn/pull/20528
128+
i = xp.where(q == 0., j, i)
129+
if average:
130+
# Ignore trailing `weights=0` observations when `q=1`
131+
j = xp.where(q == 1., i, j)
132+
return (xp.take(a, i) + xp.take(a, j)) / 2
133+
else:
134+
return xp.take(a, i)

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ def test_2d_axis_keepdims(self, xp: ModuleType):
15851585

15861586
def test_methods(self, xp: ModuleType):
15871587
x = xp.asarray([1, 2, 3, 4, 5])
1588-
methods = ["linear"] # "hazen", "weibull"]
1588+
methods = ["linear", "inverted_cdf", "averaged_inverted_cdf"]
15891589
for method in methods:
15901590
actual = quantile(x, 0.5, method=method)
15911591
# All methods should give reasonable results

0 commit comments

Comments
 (0)