Skip to content

Commit 034c064

Browse files
committed
linting: fix pyright
1 parent 98fe39f commit 034c064

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

src/array_api_extra/_lib/_quantile.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ def quantile( # numpydoc ignore=PR01,RT01
1919
device = get_device(a)
2020
floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q))
2121
a = xp.asarray(a, dtype=floating_dtype, device=device)
22-
q = xp.asarray(q, dtype=floating_dtype, device=device)
22+
p: Array = xp.asarray(q, dtype=floating_dtype, device=device)
2323

24-
if xp.any((q > 1) | (q < 0) | xp.isnan(q)):
24+
if xp.any((p > 1) | (p < 0) | xp.isnan(p)):
2525
raise ValueError("`q` values must be in the range [0, 1]")
2626

27-
q_scalar = q.ndim == 0
27+
q_scalar = p.ndim == 0
2828
if q_scalar:
29-
q = xp.reshape(q, (1,))
29+
p = xp.reshape(p, (1,))
3030

3131
axis_none = axis is None
3232
a_ndim = a.ndim
@@ -50,26 +50,26 @@ def quantile( # numpydoc ignore=PR01,RT01
5050
# The hard part will be dealing with 0-weights and NaNs
5151
# But maybe a proper use of searchsorted + left/right side will work?
5252

53-
res = _quantile_hf(a, q, float(n), axis, xp)
53+
res = _quantile_hf(a, p, float(n), axis, xp)
5454

5555
# reshaping to conform to doc/other libs' behavior
5656
if axis_none:
5757
if keepdims:
58-
res = xp.reshape(res, q.shape + (1,) * a_ndim)
58+
res = xp.reshape(res, p.shape + (1,) * a_ndim)
5959
else:
6060
res = xp.moveaxis(res, axis, 0)
6161
if keepdims:
6262
shape = list(a.shape)
6363
shape[axis] = 1
64-
shape = q.shape + tuple(shape)
64+
shape = p.shape + tuple(shape)
6565
res = xp.reshape(res, shape)
6666

6767
return res[0, ...] if q_scalar else res
6868

6969

70-
def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType):
71-
m = 1 - p
72-
jg = p*n + m - 1
70+
def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType):
71+
m = 1 - q
72+
jg = q*n + m - 1
7373

7474
j = jg // 1
7575
j = xp.clip(j, 0., n - 1)
@@ -78,11 +78,11 @@ def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType):
7878

7979
g = jg % 1
8080
g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest
81-
new_g_shape = [1] * y.ndim
81+
new_g_shape = [1] * a.ndim
8282
new_g_shape[axis] = g.shape[0]
8383
g = xp.reshape(g, tuple(new_g_shape))
8484

8585
return (
86-
(1 - g) * xp.take(y, xp.astype(j, xp.int64), axis=axis)
87-
+ g * xp.take(y, xp.astype(jp1, xp.int64), axis=axis)
86+
(1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis)
87+
+ g * xp.take(a, xp.astype(jp1, xp.int64), axis=axis)
8888
)

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1617,7 +1617,7 @@ def test_invalid_q(self, xp: ModuleType):
16171617
_ = quantile(x, -0.5)
16181618

16191619
def test_device(self, xp: ModuleType, device: Device):
1620-
if hasattr(device, 'type') and device.type == "meta":
1620+
if hasattr(device, 'type') and getattr(device, 'type') == "meta":
16211621
pytest.xfail("No Tensor.item() on meta device")
16221622
x = xp.asarray([1, 2, 3, 4, 5], device=device)
16231623
actual = quantile(x, 0.5)

0 commit comments

Comments
 (0)