@@ -19,14 +19,14 @@ def quantile( # numpydoc ignore=PR01,RT01
1919 device = get_device (a )
2020 floating_dtype = xp .float64 #xp.result_type(a, xp.asarray(q))
2121 a = xp .asarray (a , dtype = floating_dtype , device = device )
22- q = xp .asarray (q , dtype = floating_dtype , device = device )
22+ p : Array = xp .asarray (q , dtype = floating_dtype , device = device )
2323
24- if xp .any ((q > 1 ) | (q < 0 ) | xp .isnan (q )):
24+ if xp .any ((p > 1 ) | (p < 0 ) | xp .isnan (p )):
2525 raise ValueError ("`q` values must be in the range [0, 1]" )
2626
27- q_scalar = q .ndim == 0
27+ q_scalar = p .ndim == 0
2828 if q_scalar :
29- q = xp .reshape (q , (1 ,))
29+ p = xp .reshape (p , (1 ,))
3030
3131 axis_none = axis is None
3232 a_ndim = a .ndim
@@ -50,26 +50,26 @@ def quantile( # numpydoc ignore=PR01,RT01
5050 # The hard part will be dealing with 0-weights and NaNs
5151 # But maybe a proper use of searchsorted + left/right side will work?
5252
53- res = _quantile_hf (a , q , float (n ), axis , xp )
53+ res = _quantile_hf (a , p , float (n ), axis , xp )
5454
5555 # reshaping to conform to doc/other libs' behavior
5656 if axis_none :
5757 if keepdims :
58- res = xp .reshape (res , q .shape + (1 ,) * a_ndim )
58+ res = xp .reshape (res , p .shape + (1 ,) * a_ndim )
5959 else :
6060 res = xp .moveaxis (res , axis , 0 )
6161 if keepdims :
6262 shape = list (a .shape )
6363 shape [axis ] = 1
64- shape = q .shape + tuple (shape )
64+ shape = p .shape + tuple (shape )
6565 res = xp .reshape (res , shape )
6666
6767 return res [0 , ...] if q_scalar else res
6868
6969
70- def _quantile_hf (y : Array , p : Array , n : int , axis : int , xp : ModuleType ):
71- m = 1 - p
72- jg = p * n + m - 1
70+ def _quantile_hf (a : Array , q : Array , n : float , axis : int , xp : ModuleType ):
71+ m = 1 - q
72+ jg = q * n + m - 1
7373
7474 j = jg // 1
7575 j = xp .clip (j , 0. , n - 1 )
@@ -78,11 +78,11 @@ def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType):
7878
7979 g = jg % 1
8080 g = xp .where (j < 0 , 0 , g ) # equiv to g[j < 0] = 0, but work with strictest
81- new_g_shape = [1 ] * y .ndim
81+ new_g_shape = [1 ] * a .ndim
8282 new_g_shape [axis ] = g .shape [0 ]
8383 g = xp .reshape (g , tuple (new_g_shape ))
8484
8585 return (
86- (1 - g ) * xp .take (y , xp .astype (j , xp .int64 ), axis = axis )
87- + g * xp .take (y , xp .astype (jp1 , xp .int64 ), axis = axis )
86+ (1 - g ) * xp .take (a , xp .astype (j , xp .int64 ), axis = axis )
87+ + g * xp .take (a , xp .astype (jp1 , xp .int64 ), axis = axis )
8888 )
0 commit comments