99
1010def quantile ( # numpydoc ignore=PR01,RT01
1111 a : Array ,
12- q : Array | float ,
12+ q : Array ,
1313 / ,
14- method : str = "linear" , # noqa: ARG001
14+ method : str = "linear" ,
1515 axis : int | None = None ,
1616 keepdims : bool = False ,
1717 * ,
18+ weights : Array | None = None ,
1819 xp : ModuleType ,
1920) -> Array :
2021 """See docstring in `array_api_extra._delegation.py`."""
2122 device = get_device (a )
22- floating_dtype = xp .float64 # xp.result_type(a, xp.asarray(q))
23- a = xp .asarray (a , dtype = floating_dtype , device = device )
2423 a_shape = list (a .shape )
25- p : Array = xp .asarray (q , dtype = floating_dtype , device = device )
2624
27- q_scalar = p .ndim == 0
25+ q_scalar = q .ndim == 0
2826 if q_scalar :
29- p = xp .reshape (p , (1 ,))
27+ q = xp .reshape (q , (1 ,))
3028
3129 axis_none = axis is None
3230 a_ndim = a .ndim
@@ -42,33 +40,41 @@ def quantile( # numpydoc ignore=PR01,RT01
4240 if n == 0 :
4341 a_shape [axis ] = 1
4442 n = 1
45- a = xp .full (tuple (a_shape ), xp .nan , dtype = floating_dtype , device = device )
43+ a = xp .full (tuple (a_shape ), xp .nan , dtype = a . dtype , device = device )
4644
47- a = xp .sort (a , axis = axis , stable = False )
45+ if weights is None :
46+ res = _quantile (a , q , float (n ), axis , method , xp )
47+ else :
48+ average = method == 'averaged_inverted_cdf'
49+ res = _weighted_quantile (a , q , weights , n , axis , average , xp )
4850 # to support weights, the main thing would be to
4951 # argsort a, and then use it to sort a and w.
5052 # The hard part will be dealing with 0-weights and NaNs
5153 # But maybe a proper use of searchsorted + left/right side will work?
5254
53- res = _quantile_hf (a , p , float (n ), axis , xp )
54-
5555 # reshaping to conform to doc/other libs' behavior
5656 if axis_none :
5757 if keepdims :
58- res = xp .reshape (res , p .shape + (1 ,) * a_ndim )
58+ res = xp .reshape (res , q .shape + (1 ,) * a_ndim )
5959 else :
6060 res = xp .moveaxis (res , axis , 0 )
6161 if keepdims :
6262 a_shape [axis ] = 1
63- res = xp .reshape (res , p .shape + tuple (a_shape ))
63+ res = xp .reshape (res , q .shape + tuple (a_shape ))
6464
6565 return res [0 , ...] if q_scalar else res
6666
6767
68- def _quantile_hf ( # numpydoc ignore=GL08
69- a : Array , q : Array , n : float , axis : int , xp : ModuleType
68+ def _quantile ( # numpydoc ignore=GL08
69+ a : Array , q : Array , n : float , axis : int , method : str , xp : ModuleType
7070) -> Array :
71- m = 1 - q
71+ a = xp .sort (a , axis = axis , stable = False )
72+
73+ if method == "linear" :
74+ m = 1 - q
75+ else : # method is "inverted_cdf" or "averaged_inverted_cdf"
76+ m = 0
77+
7278 jg = q * n + m - 1
7379
7480 j = jg // 1
@@ -77,6 +83,11 @@ def _quantile_hf( # numpydoc ignore=GL08
7783 # `j` and `jp1` are 1d arrays
7884
7985 g = jg % 1
86+ if method == 'inverted_cdf' :
87+ g = xp .astype ((g > 0 ), jg .dtype )
88+ elif method == 'averaged_inverted_cdf' :
89+ g = (1 + xp .astype ((g > 0 ), jg .dtype )) / 2
90+
8091 g = xp .where (j < 0 , 0 , g ) # equivalent to g[j < 0] = 0, but works with strictest
8192 new_g_shape = [1 ] * a .ndim
8293 new_g_shape [axis ] = g .shape [0 ]
@@ -85,3 +96,39 @@ def _quantile_hf( # numpydoc ignore=GL08
8596 return (1 - g ) * xp .take (a , xp .astype (j , xp .int64 ), axis = axis ) + g * xp .take (
8697 a , xp .astype (jp1 , xp .int64 ), axis = axis
8798 )
99+
100+
def _weighted_quantile(
    a: Array,
    q: Array,
    weights: Array,
    n: int,
    axis: int,
    average: bool,
    xp: ModuleType,
) -> Array:
    """
    Compute weighted quantiles of `a` along `axis`.

    Parameters
    ----------
    a : Array
        Input data.
        NOTE(review): only 1-D and 2-D inputs are handled, because the loop
        below slices along the leading axis only -- confirm that the caller
        never passes a higher-rank array.
    q : Array
        1-D array of probabilities.
    weights : Array
        Either 1-D weights along `axis` (shared by every slice), or an array
        laid out like `a`.
    n : int
        Number of observations along `axis`.
    axis : int
        Axis of `a` along which the quantiles are computed.
    average : bool
        Forwarded to `_weighted_quantile_sorted_1d`; selects the averaged
        variant of the estimator.
    xp : ModuleType
        Array namespace used for every array operation.

    Returns
    -------
    Array
        Shape ``q.shape`` when `a` is 1-D, otherwise ``(q.shape[0], d)`` where
        ``d`` is the length of the non-reduced axis.
    """
    a = xp.moveaxis(a, axis, -1)
    if weights.ndim > 1:
        # Keep the weights aligned with the data once the reduced axis is last.
        weights = xp.moveaxis(weights, axis, -1)
    sorter = xp.argsort(a, axis=-1, stable=False)

    if a.ndim == 1:
        # The weights have to follow the same permutation as the data.
        return _weighted_quantile_sorted_1d(
            xp.take(a, sorter), q, xp.take(weights, sorter), n, average, xp
        )

    (d,) = eager_shape(a, axis=0)
    if d == 0:
        # Nothing to stack: return an empty result of the documented shape.
        return xp.empty((q.shape[0], 0), dtype=a.dtype)

    columns = []
    for idx in range(d):
        order = sorter[idx, ...]
        w = weights if weights.ndim == 1 else weights[idx, ...]
        columns.append(
            _weighted_quantile_sorted_1d(
                xp.take(a[idx, ...], order), q, xp.take(w, order), n, average, xp
            )
        )
    # Stacking avoids in-place item assignment, which immutable array
    # backends do not support, and lets the result dtype follow the data.
    return xp.stack(columns, axis=1)
116+
117+
def _weighted_quantile_sorted_1d(a, q, w, n, average: bool, xp: ModuleType):
    """
    Compute weighted quantiles of a 1-D sample that is already sorted.

    Parameters
    ----------
    a : Array
        1-D observations, expected to be sorted (the caller argsorts them).
    q : Array
        Probabilities at which the quantiles are evaluated.
    w : Array
        1-D weights, in the same order as `a`.
    n : int
        Number of observations; lookup indices are clamped to ``n - 1``.
    average : bool
        If True, return the mean of the observations found by the left and
        right `searchsorted` lookups; otherwise return the left one.
    xp : ModuleType
        Array namespace used for every array operation.

    Returns
    -------
    Array
        One quantile per entry of `q`.
    """
    # `cumulative_sum` is the array API standard name; fall back to `cumsum`
    # for namespaces that only provide the NumPy spelling.
    cumulative = getattr(xp, "cumulative_sum", None) or xp.cumsum
    cw = cumulative(w)
    t = cw[-1] * q
    i = xp.searchsorted(cw, t)
    j = xp.searchsorted(cw, t, side="right")

    # Clamp with integer bounds: a float bound would promote the indices to a
    # floating dtype, which `take` rejects.
    last = int(n) - 1
    i = xp.clip(i, 0, last)
    j = xp.clip(j, 0, last)

    # Ignore leading `weights=0` observations when `q=0`
    # see https://github.com/scikit-learn/scikit-learn/pull/20528
    i = xp.where(q == 0.0, j, i)
    if not average:
        return xp.take(a, i)

    # Ignore trailing `weights=0` observations when `q=1`
    j = xp.where(q == 1.0, i, j)
    return (xp.take(a, i) + xp.take(a, j)) / 2
0 commit comments