Skip to content

Commit 05ffb7b

Browse files
committed
linting: fix mypy
1 parent 034c064 commit 05ffb7b

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

src/array_api_extra/_lib/_quantile.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ def quantile( # numpydoc ignore=PR01,RT01
1414
keepdims: bool = False,
1515
*,
1616
xp: ModuleType,
17-
):
17+
) -> Array:
1818
"""See docstring in `array_api_extra._delegation.py`."""
1919
device = get_device(a)
2020
floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q))
2121
a = xp.asarray(a, dtype=floating_dtype, device=device)
22+
a_shape = list(a.shape)
2223
p: Array = xp.asarray(q, dtype=floating_dtype, device=device)
2324

2425
if xp.any((p > 1) | (p < 0) | xp.isnan(p)):
@@ -30,19 +31,19 @@ def quantile( # numpydoc ignore=PR01,RT01
3031

3132
axis_none = axis is None
3233
a_ndim = a.ndim
33-
if axis_none:
34+
if axis is None:
3435
a = xp.reshape(a, (-1,))
3536
axis = 0
36-
axis = int(axis)
37+
else:
38+
axis = int(axis)
3739

3840
n, = eager_shape(a, axis)
3941
# If data has length zero along `axis`, the result will be an array of NaNs just
4042
# as if the data had length 1 along axis and were filled with NaNs.
4143
if n == 0:
42-
shape = list(eager_shape(a))
43-
shape[axis] = 1
44+
a_shape[axis] = 1
4445
n = 1
45-
a = xp.full(shape, xp.nan, dtype=floating_dtype, device=device)
46+
a = xp.full(tuple(a_shape), xp.nan, dtype=floating_dtype, device=device)
4647

4748
a = xp.sort(a, axis=axis, stable=False)
4849
# to support weights, the main thing would be to
@@ -59,15 +60,13 @@ def quantile( # numpydoc ignore=PR01,RT01
5960
else:
6061
res = xp.moveaxis(res, axis, 0)
6162
if keepdims:
62-
shape = list(a.shape)
63-
shape[axis] = 1
64-
shape = p.shape + tuple(shape)
65-
res = xp.reshape(res, shape)
63+
a_shape[axis] = 1
64+
res = xp.reshape(res, p.shape + tuple(a_shape))
6665

6766
return res[0, ...] if q_scalar else res
6867

6968

70-
def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType):
69+
def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType) -> Array:
7170
m = 1 - q
7271
jg = q*n + m - 1
7372

0 commit comments

Comments
 (0)