Skip to content

Commit 1d8fef7

Browse files
committed
linting: pyright & mypy
1 parent fa789fc commit 1d8fef7

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

src/array_api_extra/_delegation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,26 +1110,26 @@ def quantile(
11101110
)
11111111
device = get_device(a)
11121112
a = xp.asarray(a, dtype=dtype, device=device)
1113-
q = xp.asarray(q, dtype=dtype, device=device)
1113+
q_arr = xp.asarray(q, dtype=dtype, device=device)
11141114
# TODO: cast weights here? Assert weights are on the same device as `a`?
11151115

1116-
if xp.any((q > 1) | (q < 0) | xp.isnan(q)):
1116+
if xp.any((q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)):
11171117
msg = "`q` values must be in the range [0, 1]"
11181118
raise ValueError(msg)
11191119

11201120
# Delegate where possible.
11211121
if is_numpy_namespace(xp) and nan_policy == "propagate":
1122-
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights)
1122+
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights)
11231123
# No delegation for dask: I couldn't make it work
11241124
basic_case = method == "linear" and weights is None and nan_policy == "propagate"
11251125
if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp):
1126-
return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims)
1126+
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
11271127
if basic_case and is_torch_namespace(xp):
1128-
return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims)
1128+
return xp.quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)
11291129

11301130
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11311131
# Otherwise call our implementation (will sort data)
11321132
return _quantile.quantile(
1133-
a, q, axis=axis, method=method, keepdims=keepdims,
1133+
a, q_arr, axis=axis, method=method, keepdims=keepdims,
11341134
nan_policy=nan_policy, weights=weights, xp=xp
11351135
)

src/array_api_extra/_lib/_quantile.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ def quantile( # numpydoc ignore=PR01,RT01
4848
if not axis_none:
4949
res = xp.moveaxis(res, axis, 0)
5050
else:
51-
weights = xp.asarray(weights, dtype=xp.float64, device=device)
51+
weights_arr = xp.asarray(weights, dtype=xp.float64, device=device)
5252
average = method == 'averaged_inverted_cdf'
5353
res = _weighted_quantile(
54-
a, q, weights, n, axis, average,
54+
a, q, weights_arr, n, axis, average,
5555
nan_policy=nan_policy, xp=xp, device=device
5656
)
5757

@@ -80,7 +80,7 @@ def _quantile( # numpydoc ignore=GL08
8080
if method == "linear":
8181
m = 1 - q
8282
else: # method is "inverted_cdf" or "averaged_inverted_cdf"
83-
m = 0
83+
m = xp.asarray(0, dtype=q.dtype)
8484

8585
jg = q * float(n) + m - 1
8686

@@ -112,7 +112,6 @@ def _weighted_quantile(
112112
"""
113113
a is expected to be 1d or 2d.
114114
"""
115-
kwargs = dict(n=n, average=average, nan_policy=nan_policy, xp=xp, device=device)
116115
a = xp.moveaxis(a, axis, -1)
117116
if weights.ndim > 1:
118117
weights = xp.moveaxis(weights, axis, -1)
@@ -121,15 +120,15 @@ def _weighted_quantile(
121120
if a.ndim == 1:
122121
x = xp.take(a, sorter)
123122
w = xp.take(weights, sorter)
124-
return _weighted_quantile_sorted_1d(x, q, w, **kwargs)
123+
return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)
125124

126125
d, = eager_shape(a, axis=0)
127126
res = []
128127
for idx in range(d):
129128
w = weights if weights.ndim == 1 else weights[idx, ...]
130129
w = xp.take(w, sorter[idx, ...])
131130
x = xp.take(a[idx, ...], sorter[idx, ...])
132-
res.append(_weighted_quantile_sorted_1d(x, q, w, **kwargs))
131+
res.append(_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device))
133132
res = xp.stack(res, axis=1)
134133
return res
135134

tests/test_funcs.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,18 +1578,18 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra
15781578
rng = np.random.default_rng()
15791579
n, d = 10, 20
15801580
a_np = rng.random((n, d))
1581-
kwargs = dict(keepdims=keepdims)
15821581
mask_nan = np.zeros((n, d), dtype=bool)
1583-
if nan_policy != "no_nans":
1582+
if nan_policy == "no_nans":
1583+
nan_policy = "propagate"
1584+
else:
15841585
# from 0% to 100% of NaNs:
15851586
mask_nan = rng.random((n, d)) < rng.random((n, 1))
15861587
# don't put nans in the first row:
15871588
mask_nan[:] = False
15881589
a_np[mask_nan] = np.nan
1589-
kwargs['nan_policy'] = nan_policy
15901590

1591-
a = xp.asarray(a_np)
1592-
q = xp.asarray(np.copy(q_np))
1591+
a = xp.asarray(a_np, copy=True)
1592+
q = xp.asarray(q_np, copy=True)
15931593
m = 'inverted_cdf'
15941594

15951595
np_quantile = np.quantile
@@ -1604,23 +1604,24 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra
16041604
(rng.integers(0, 2, (n, d)), 0),
16051605
(rng.integers(0, 2, (n, d)), 1),
16061606
]:
1607-
print(w_np)
16081607
with warnings.catch_warnings(record=True) as warning:
16091608
warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning)
16101609
warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
16111610
try:
1612-
expected = np_quantile(a_np, q_np, axis=axis, method=m, weights=w_np, keepdims=keepdims)
1611+
expected = np_quantile( # type: ignore[call-overload]
1612+
a_np, np.asarray(q_np),
1613+
axis=axis, method=m, weights=w_np, keepdims=keepdims
1614+
)
16131615
except IndexError:
1614-
print('index error')
16151616
continue
16161617
if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile
1617-
print('warning')
16181618
continue
16191619
expected = xp.asarray(expected)
1620-
print("not skiped")
16211620

16221621
w = xp.asarray(w_np)
1623-
actual = quantile(a, q, axis=axis, method=m, weights=w, **kwargs)
1622+
actual = quantile(
1623+
a, q, axis=axis, method=m, weights=w, keepdims=keepdims, nan_policy=nan_policy
1624+
)
16241625
xp_assert_close(actual, expected, atol=1e-12)
16251626

16261627
def test_2d_axis(self, xp: ModuleType):

0 commit comments

Comments
 (0)