Skip to content

Commit 98fe39f

Browse files
committed
draft version with some tests that are passing
1 parent dc7a1e5 commit 98fe39f

File tree

4 files changed

+132
-10
lines changed

4 files changed

+132
-10
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
one_hot,
1212
pad,
1313
partition,
14+
quantile,
1415
sinc,
1516
)
1617
from ._lib._at import at
@@ -48,6 +49,7 @@
4849
"one_hot",
4950
"pad",
5051
"partition",
52+
"quantile",
5153
"setdiff1d",
5254
"sinc",
5355
]

src/array_api_extra/_delegation.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from types import ModuleType
55
from typing import Literal
66

7-
from ._lib import _funcs
7+
from ._lib import _funcs, _quantile
88
from ._lib._utils._compat import (
99
array_namespace,
1010
is_cupy_namespace,
@@ -930,18 +930,33 @@ def quantile(
930930
if ndim < 1:
931931
msg = "`a` must be at least 1-dimensional"
932932
raise TypeError(msg)
933-
if (axis >= ndim) or (axis < -ndim):
933+
if axis is not None and ((axis >= ndim) or (axis < -ndim)):
934934
message = "`axis` is not compatible with the dimension of `a`."
935935
raise ValueError(message)
936936

937+
# Array API states: Mixed integer and floating-point type promotion rules
938+
# are not specified because behavior varies between implementations.
939+
# => We choose to do:
940+
dtype = (
941+
xp.float64 if xp.isdtype(a.dtype, 'integral')
942+
else xp.result_type(a, xp.asarray(q)) # both a and q are floats
943+
)
944+
device = get_device(a)
945+
a = xp.asarray(a, dtype=dtype, device=device)
946+
q = xp.asarray(q, dtype=dtype, device=device)
947+
948+
if xp.any((q > 1) | (q < 0) | xp.isnan(q)):
949+
raise ValueError("`q` values must be in the range [0, 1]")
950+
937951
# Delegate where possible.
938-
if is_numpy_namespace(xp) or is_dask_namespace(xp):
952+
if is_numpy_namespace(xp):
939953
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
954+
# No delegating for dask: I couldn't make it work
940955
is_linear = method == "linear"
941956
if (is_linear and is_jax_namespace(xp)) or is_cupy_namespace(xp):
942957
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
943958
if is_linear and is_torch_namespace(xp):
944959
return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims)
945960

946961
# Otherwise call our implementation (will sort data)
947-
return _funcs.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp)
962+
return _quantile.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp)

src/array_api_extra/_lib/_quantile.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def quantile( # numpydoc ignore=PR01,RT01
1717
):
1818
"""See docstring in `array_api_extra._delegation.py`."""
1919
device = get_device(a)
20-
floating_dtype = xp.result_type(a, xp.asarray(q))
20+
floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q))
2121
a = xp.asarray(a, dtype=floating_dtype, device=device)
2222
q = xp.asarray(q, dtype=floating_dtype, device=device)
2323

@@ -29,12 +29,13 @@ def quantile( # numpydoc ignore=PR01,RT01
2929
q = xp.reshape(q, (1,))
3030

3131
axis_none = axis is None
32+
a_ndim = a.ndim
3233
if axis_none:
3334
a = xp.reshape(a, (-1,))
3435
axis = 0
3536
axis = int(axis)
3637

37-
n = eager_shape(a, axis)
38+
n, = eager_shape(a, axis)
3839
# If data has length zero along `axis`, the result will be an array of NaNs just
3940
# as if the data had length 1 along axis and were filled with NaNs.
4041
if n == 0:
@@ -49,12 +50,12 @@ def quantile( # numpydoc ignore=PR01,RT01
4950
# The hard part will be dealing with 0-weights and NaNs
5051
# But maybe a proper use of searchsorted + left/right side will work?
5152

52-
res = _quantile_hf(a, q, n, axis, xp)
53+
res = _quantile_hf(a, q, float(n), axis, xp)
5354

5455
# reshaping to conform to doc/other libs' behavior
5556
if axis_none:
5657
if keepdims:
57-
res = xp.reshape(res, q.shape + (1,) * a.ndim)
58+
res = xp.reshape(res, q.shape + (1,) * a_ndim)
5859
else:
5960
res = xp.moveaxis(res, axis, 0)
6061
if keepdims:
@@ -69,13 +70,18 @@ def quantile( # numpydoc ignore=PR01,RT01
6970
def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType):
7071
m = 1 - p
7172
jg = p*n + m - 1
73+
7274
j = jg // 1
73-
g = jg % 1
74-
g[j < 0] = 0
7575
j = xp.clip(j, 0., n - 1)
7676
jp1 = xp.clip(j + 1, 0., n - 1)
7777
# `̀j` and `jp1` are 1d arrays
7878

79+
g = jg % 1
80+
g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest
81+
new_g_shape = [1] * y.ndim
82+
new_g_shape[axis] = g.shape[0]
83+
g = xp.reshape(g, tuple(new_g_shape))
84+
7985
return (
8086
(1 - g) * xp.take(y, xp.astype(j, xp.int64), axis=axis)
8187
+ g * xp.take(y, xp.astype(jp1, xp.int64), axis=axis)

tests/test_funcs.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
one_hot,
3030
pad,
3131
partition,
32+
quantile,
3233
setdiff1d,
3334
sinc,
3435
)
@@ -1529,3 +1530,101 @@ def test_kind(self, xp: ModuleType, library: Backend):
15291530
expected = xp.asarray([False, True, False, True])
15301531
res = isin(a, b, kind="sort")
15311532
xp_assert_equal(res, expected)
1533+
1534+
1535+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take")
1536+
class TestQuantile:
1537+
def test_basic(self, xp: ModuleType):
1538+
x = xp.asarray([1, 2, 3, 4, 5])
1539+
actual = quantile(x, 0.5)
1540+
expect = xp.asarray(3.0, dtype=xp.float64)
1541+
xp_assert_close(actual, expect)
1542+
1543+
def test_multiple_quantiles(self, xp: ModuleType):
1544+
x = xp.asarray([1, 2, 3, 4, 5])
1545+
actual = quantile(x, xp.asarray([0.25, 0.5, 0.75]))
1546+
expect = xp.asarray([2.0, 3.0, 4.0], dtype=xp.float64)
1547+
xp_assert_close(actual, expect)
1548+
1549+
def test_shape(self, xp: ModuleType):
1550+
a = xp.asarray(np.random.rand(3, 4, 5))
1551+
q = xp.asarray(np.random.rand(2))
1552+
assert quantile(a, q, axis=0).shape == (2, 4, 5)
1553+
assert quantile(a, q, axis=1).shape == (2, 3, 5)
1554+
assert quantile(a, q, axis=2).shape == (2, 3, 4)
1555+
1556+
assert quantile(a, q, axis=0, keepdims=True).shape == (2, 1, 4, 5)
1557+
assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5)
1558+
assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1)
1559+
1560+
def test_against_numpy(self, xp: ModuleType):
1561+
a_np = np.random.rand(3, 4, 5)
1562+
q_np = np.random.rand(2)
1563+
a = xp.asarray(a_np)
1564+
q = xp.asarray(q_np)
1565+
for keepdims in [False, True]:
1566+
for axis in [None, *range(a.ndim)]:
1567+
actual = quantile(a, q, axis=axis, keepdims=keepdims)
1568+
expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims)
1569+
expected = xp.asarray(expected, dtype=xp.float64)
1570+
xp_assert_close(actual, expected, atol=1e-12)
1571+
1572+
def test_2d_axis(self, xp: ModuleType):
1573+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1574+
actual = quantile(x, 0.5, axis=0)
1575+
expect = xp.asarray([2.5, 3.5, 4.5], dtype=xp.float64)
1576+
xp_assert_close(actual, expect)
1577+
1578+
def test_2d_axis_keepdims(self, xp: ModuleType):
1579+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1580+
actual = quantile(x, 0.5, axis=0, keepdims=True)
1581+
expect = xp.asarray([[2.5, 3.5, 4.5]], dtype=xp.float64)
1582+
xp_assert_close(actual, expect)
1583+
1584+
def test_methods(self, xp: ModuleType):
1585+
x = xp.asarray([1, 2, 3, 4, 5])
1586+
methods = ["linear"] #"hazen", "weibull"]
1587+
for method in methods:
1588+
actual = quantile(x, 0.5, method=method)
1589+
# All methods should give reasonable results
1590+
assert 2.5 <= float(actual) <= 3.5
1591+
1592+
def test_edge_cases(self, xp: ModuleType):
1593+
x = xp.asarray([1, 2, 3, 4, 5])
1594+
# q = 0 should give minimum
1595+
actual = quantile(x, 0.0)
1596+
expect = xp.asarray(1.0, dtype=xp.float64)
1597+
xp_assert_close(actual, expect)
1598+
1599+
# q = 1 should give maximum
1600+
actual = quantile(x, 1.0)
1601+
expect = xp.asarray(5.0, dtype=xp.float64)
1602+
xp_assert_close(actual, expect)
1603+
1604+
def test_invalid_q(self, xp: ModuleType):
1605+
x = xp.asarray([1, 2, 3, 4, 5])
1606+
_ = quantile(x, 1.0)
1607+
# ^ FIXME: here just to make this test fail for sparse backend
1608+
# q > 1 should raise
1609+
with pytest.raises(
1610+
ValueError, match=r"`q` values must be in the range \[0, 1\]"
1611+
):
1612+
_ = quantile(x, 1.5)
1613+
# q < 0 should raise
1614+
with pytest.raises(
1615+
ValueError, match=r"`q` values must be in the range \[0, 1\]"
1616+
):
1617+
_ = quantile(x, -0.5)
1618+
1619+
def test_device(self, xp: ModuleType, device: Device):
1620+
if hasattr(device, 'type') and device.type == "meta":
1621+
pytest.xfail("No Tensor.item() on meta device")
1622+
x = xp.asarray([1, 2, 3, 4, 5], device=device)
1623+
actual = quantile(x, 0.5)
1624+
assert get_device(actual) == device
1625+
1626+
def test_xp(self, xp: ModuleType):
1627+
x = xp.asarray([1, 2, 3, 4, 5])
1628+
actual = quantile(x, 0.5, xp=xp)
1629+
expect = xp.asarray(3.0, dtype=xp.float64)
1630+
xp_assert_close(actual, expect)

0 commit comments

Comments
 (0)