Skip to content

Commit 26804fe

Browse files
committed
linting: ruff
1 parent 1d8fef7 commit 26804fe

File tree

3 files changed

+91
-39
lines changed

3 files changed

+91
-39
lines changed

src/array_api_extra/_delegation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,8 @@ def quantile(
10551055
if xp is None:
10561056
xp = array_namespace(a)
10571057
if is_pydata_sparse_namespace(xp):
1058-
raise ValueError('no supported')
1058+
msg = "Sparse backend not supported"
1059+
raise ValueError(msg)
10591060

10601061
methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"}
10611062
if method not in methods:
@@ -1096,7 +1097,10 @@ def quantile(
10961097
msg = "Axis must be specified when shapes of `a` and ̀ weights` differ."
10971098
raise TypeError(msg)
10981099
if weights.shape != eager_shape(a, axis):
1099-
msg = "Shape of weights must be consistent with shape of a along specified axis."
1100+
msg = (
1101+
"Shape of weights must be consistent with shape"
1102+
" of a along specified axis."
1103+
)
11001104
raise ValueError(msg)
11011105
if axis is None and ndim == 2:
11021106
msg = "When weights are provided, axis must be specified when `a` is 2d"
@@ -1119,7 +1123,9 @@ def quantile(
11191123

11201124
# Delegate where possible.
11211125
if is_numpy_namespace(xp) and nan_policy == "propagate":
1122-
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights)
1126+
return xp.quantile(
1127+
a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
1128+
)
11231129
# No delegation for dask: I couldn't make it work
11241130
basic_case = method == "linear" and weights is None and nan_policy == "propagate"
11251131
if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp):
@@ -1130,6 +1136,12 @@ def quantile(
11301136
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11311137
# Otherwise call our implementation (will sort data)
11321138
return _quantile.quantile(
1133-
a, q_arr, axis=axis, method=method, keepdims=keepdims,
1134-
nan_policy=nan_policy, weights=weights, xp=xp
1139+
a,
1140+
q_arr,
1141+
axis=axis,
1142+
method=method,
1143+
keepdims=keepdims,
1144+
nan_policy=nan_policy,
1145+
weights=weights,
1146+
xp=xp,
11351147
)

src/array_api_extra/_lib/_quantile.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,17 @@ def quantile( # numpydoc ignore=PR01,RT01
4949
res = xp.moveaxis(res, axis, 0)
5050
else:
5151
weights_arr = xp.asarray(weights, dtype=xp.float64, device=device)
52-
average = method == 'averaged_inverted_cdf'
52+
average = method == "averaged_inverted_cdf"
5353
res = _weighted_quantile(
54-
a, q, weights_arr, n, axis, average,
55-
nan_policy=nan_policy, xp=xp, device=device
54+
a,
55+
q,
56+
weights_arr,
57+
n,
58+
axis,
59+
average,
60+
nan_policy=nan_policy,
61+
xp=xp,
62+
device=device,
5663
)
5764

5865
# reshaping to conform to doc/other libs' behavior
@@ -72,15 +79,17 @@ def _quantile( # numpydoc ignore=GL08
7279
a = xp.sort(a, axis=axis, stable=False)
7380
mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True)
7481
if xp.any(mask_nan):
75-
# propogate NaNs:
82+
# propagate NaNs:
7683
mask = xp.repeat(mask_nan, n, axis=axis)
7784
a = xp.where(mask, xp.nan, a)
7885
del mask
7986

80-
if method == "linear":
81-
m = 1 - q
82-
else: # method is "inverted_cdf" or "averaged_inverted_cdf"
83-
m = xp.asarray(0, dtype=q.dtype)
87+
m = (
88+
1 - q
89+
if method == "linear"
90+
# method is "inverted_cdf" or "averaged_inverted_cdf"
91+
else xp.asarray(0, dtype=q.dtype)
92+
)
8493

8594
jg = q * float(n) + m - 1
8695

@@ -90,9 +99,9 @@ def _quantile( # numpydoc ignore=GL08
9099
# `̀j` and `jp1` are 1d arrays
91100

92101
g = jg % 1
93-
if method == 'inverted_cdf':
102+
if method == "inverted_cdf":
94103
g = xp.astype((g > 0), jg.dtype)
95-
elif method == 'averaged_inverted_cdf':
104+
elif method == "averaged_inverted_cdf":
96105
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
97106

98107
g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with readonly
@@ -106,8 +115,15 @@ def _quantile( # numpydoc ignore=GL08
106115

107116

108117
def _weighted_quantile(
109-
a: Array, q: Array, weights: Array, n: int, axis: int, average: bool, nan_policy: str,
110-
xp: ModuleType, device: Device
118+
a: Array,
119+
q: Array,
120+
weights: Array,
121+
n: int,
122+
axis: int,
123+
average: bool,
124+
nan_policy: str,
125+
xp: ModuleType,
126+
device: Device,
111127
) -> Array:
112128
"""
113129
a is expected to be 1d or 2d.
@@ -122,37 +138,45 @@ def _weighted_quantile(
122138
w = xp.take(weights, sorter)
123139
return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)
124140

125-
d, = eager_shape(a, axis=0)
141+
(d,) = eager_shape(a, axis=0)
126142
res = []
127143
for idx in range(d):
128144
w = weights if weights.ndim == 1 else weights[idx, ...]
129145
w = xp.take(w, sorter[idx, ...])
130146
x = xp.take(a[idx, ...], sorter[idx, ...])
131-
res.append(_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device))
132-
res = xp.stack(res, axis=1)
133-
return res
147+
res.append(
148+
_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)
149+
)
150+
151+
return xp.stack(res, axis=1)
134152

135153

136154
def _weighted_quantile_sorted_1d(
137-
x: Array, q: Array, w: Array, n: int, average: bool, nan_policy: str,
138-
xp: ModuleType, device: Device
155+
x: Array,
156+
q: Array,
157+
w: Array,
158+
n: int,
159+
average: bool,
160+
nan_policy: str,
161+
xp: ModuleType,
162+
device: Device,
139163
) -> Array:
140164
if nan_policy == "omit":
141-
w = xp.where(xp.isnan(x), 0., w)
165+
w = xp.where(xp.isnan(x), 0.0, w)
142166
elif xp.any(xp.isnan(x)):
143167
return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device)
144168
cw = xp.cumulative_sum(w)
145169
t = cw[-1] * q
146-
i = xp.searchsorted(cw, t, side='left')
147-
j = xp.searchsorted(cw, t, side='right')
170+
i = xp.searchsorted(cw, t, side="left")
171+
j = xp.searchsorted(cw, t, side="right")
148172
i = xp.clip(i, 0, n - 1)
149173
j = xp.clip(j, 0, n - 1)
150174

151175
# Ignore leading `weights=0` observations when `q=0`
152176
# see https://github.com/scikit-learn/scikit-learn/pull/20528
153-
i = xp.where(q == 0., j, i)
177+
i = xp.where(q == 0.0, j, i)
154178
if average:
155179
# Ignore trailing `weights=0` observations when `q=1`
156-
j = xp.where(q == 1., i, j)
180+
j = xp.where(q == 1.0, i, j)
157181
return (xp.take(x, i) + xp.take(x, j)) / 2
158182
return xp.take(x, i)

tests/test_funcs.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,9 +1572,11 @@ def test_against_numpy(self, xp: ModuleType, keepdims: bool):
15721572
xp_assert_close(actual, expected, atol=1e-12)
15731573

15741574
@pytest.mark.parametrize("keepdims", [True, False])
1575-
@pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])#, #["omit"])#["no_nans", "propagate"])
1576-
@pytest.mark.parametrize("q_np", [0.5, 0., 1., np.linspace(0, 1, num=11)])
1577-
def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str):
1575+
@pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])
1576+
@pytest.mark.parametrize("q_np", [0.5, 0.0, 1.0, np.linspace(0, 1, num=11)])
1577+
def test_weighted_against_numpy(
1578+
self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str
1579+
):
15781580
rng = np.random.default_rng()
15791581
n, d = 10, 20
15801582
a_np = rng.random((n, d))
@@ -1590,7 +1592,7 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra
15901592

15911593
a = xp.asarray(a_np, copy=True)
15921594
q = xp.asarray(q_np, copy=True)
1593-
m = 'inverted_cdf'
1595+
m = "inverted_cdf"
15941596

15951597
np_quantile = np.quantile
15961598
if nan_policy == "omit":
@@ -1605,22 +1607,36 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra
16051607
(rng.integers(0, 2, (n, d)), 1),
16061608
]:
16071609
with warnings.catch_warnings(record=True) as warning:
1608-
warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning)
1609-
warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
1610+
divide_msg = "invalid value encountered in divide"
1611+
warnings.filterwarnings("always", divide_msg, RuntimeWarning)
1612+
nan_slice_msg = "All-NaN slice encountered"
1613+
warnings.filterwarnings("ignore", nan_slice_msg, RuntimeWarning)
16101614
try:
16111615
expected = np_quantile( # type: ignore[call-overload]
1612-
a_np, np.asarray(q_np),
1613-
axis=axis, method=m, weights=w_np, keepdims=keepdims
1616+
a_np,
1617+
np.asarray(q_np),
1618+
axis=axis,
1619+
method=m,
1620+
weights=w_np,
1621+
keepdims=keepdims,
16141622
)
16151623
except IndexError:
16161624
continue
1617-
if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile
1625+
if warning:
1626+
# this means some weights sum was 0
1627+
# in this case we skip calling xpx.quantile
16181628
continue
16191629
expected = xp.asarray(expected)
16201630

16211631
w = xp.asarray(w_np)
1622-
actual = quantile(
1623-
a, q, axis=axis, method=m, weights=w, keepdims=keepdims, nan_policy=nan_policy
1632+
actual = quantile(
1633+
a,
1634+
q,
1635+
axis=axis,
1636+
method=m,
1637+
weights=w,
1638+
keepdims=keepdims,
1639+
nan_policy=nan_policy,
16241640
)
16251641
xp_assert_close(actual, expected, atol=1e-12)
16261642

0 commit comments

Comments
 (0)