44
55from ._utils ._compat import device as get_device
66from ._utils ._helpers import eager_shape
7- from ._utils ._typing import Array
7+ from ._utils ._typing import Array , Device
88
99
1010def quantile ( # numpydoc ignore=PR01,RT01
@@ -14,6 +14,7 @@ def quantile( # numpydoc ignore=PR01,RT01
1414 method : str = "linear" ,
1515 axis : int | None = None ,
1616 keepdims : bool = False ,
17+ nan_policy : str = "propagate" ,
1718 * ,
1819 weights : Array | None = None ,
1920 xp : ModuleType ,
@@ -43,43 +44,49 @@ def quantile( # numpydoc ignore=PR01,RT01
4344 a = xp .full (tuple (a_shape ), xp .nan , dtype = a .dtype , device = device )
4445
4546 if weights is None :
46- res = _quantile (a , q , float (n ), axis , method , xp )
47+ res = _quantile (a , q , n , axis , method , xp )
48+ if not axis_none :
49+ res = xp .moveaxis (res , axis , 0 )
4750 else :
51+ weights = xp .asarray (weights , dtype = xp .float64 , device = device )
4852 average = method == 'averaged_inverted_cdf'
49- res = _weighted_quantile (a , q , weights , n , axis , average , xp )
50- # to support weights, the main thing would be to
51- # argsort a, and then use it to sort a and w.
52- # The hard part will be dealing with 0-weights and NaNs
53- # But maybe a proper use of searchsorted + left/right side will work?
53+ res = _weighted_quantile (
54+ a , q , weights , n , axis , average ,
55+ nan_policy = nan_policy , xp = xp , device = device
56+ )
5457
5558 # reshaping to conform to doc/other libs' behavior
5659 if axis_none :
5760 if keepdims :
5861 res = xp .reshape (res , q .shape + (1 ,) * a_ndim )
59- else :
60- res = xp .moveaxis (res , axis , 0 )
61- if keepdims :
62- a_shape [axis ] = 1
63- res = xp .reshape (res , q .shape + tuple (a_shape ))
62+ elif keepdims :
63+ a_shape [axis ] = 1
64+ res = xp .reshape (res , q .shape + tuple (a_shape ))
6465
6566 return res [0 , ...] if q_scalar else res
6667
6768
6869def _quantile ( # numpydoc ignore=GL08
69- a : Array , q : Array , n : float , axis : int , method : str , xp : ModuleType
70+ a : Array , q : Array , n : int , axis : int , method : str , xp : ModuleType
7071) -> Array :
7172 a = xp .sort (a , axis = axis , stable = False )
73+ mask_nan = xp .any (xp .isnan (a ), axis = axis , keepdims = True )
74+ if xp .any (mask_nan ):
75+ # propogate NaNs:
76+ mask = xp .repeat (mask_nan , n , axis = axis )
77+ a = xp .where (mask , xp .nan , a )
78+ del mask
7279
7380 if method == "linear" :
74- m = 1 - q
81+ m = 1 - q
7582 else : # method is "inverted_cdf" or "averaged_inverted_cdf"
7683 m = 0
7784
78- jg = q * n + m - 1
85+ jg = q * float ( n ) + m - 1
7986
8087 j = jg // 1
81- j = xp .clip (j , 0.0 , n - 1 )
82- jp1 = xp .clip (j + 1 , 0.0 , n - 1 )
88+ j = xp .clip (j , 0.0 , float ( n - 1 ) )
89+ jp1 = xp .clip (j + 1 , 0.0 , float ( n - 1 ) )
8390 # `̀j` and `jp1` are 1d arrays
8491
8592 g = jg % 1
@@ -88,7 +95,7 @@ def _quantile( # numpydoc ignore=GL08
8895 elif method == 'averaged_inverted_cdf' :
8996 g = (1 + xp .astype ((g > 0 ), jg .dtype )) / 2
9097
91- g = xp .where (j < 0 , 0 , g ) # equivalent to g[j < 0] = 0, but works with strictest
98+ g = xp .where (j < 0 , 0 , g ) # equivalent to g[j < 0] = 0, but works with readonly
9299 new_g_shape = [1 ] * a .ndim
93100 new_g_shape [axis ] = g .shape [0 ]
94101 g = xp .reshape (g , tuple (new_g_shape ))
@@ -98,37 +105,55 @@ def _quantile( # numpydoc ignore=GL08
98105 )
99106
100107
101- def _weighted_quantile (a : Array , q : Array , weights : Array , n : int , axis , average : bool , xp : ModuleType ):
108+ def _weighted_quantile (
109+ a : Array , q : Array , weights : Array , n : int , axis : int , average : bool , nan_policy : str ,
110+ xp : ModuleType , device : Device
111+ ) -> Array :
112+ """
113+ a is expected to be 1d or 2d.
114+ """
115+ kwargs = dict (n = n , average = average , nan_policy = nan_policy , xp = xp , device = device )
102116 a = xp .moveaxis (a , axis , - 1 )
117+ if weights .ndim > 1 :
118+ weights = xp .moveaxis (weights , axis , - 1 )
103119 sorter = xp .argsort (a , axis = - 1 , stable = False )
104- a = xp .take_along_axis (a , sorter , axis = - 1 )
105120
106121 if a .ndim == 1 :
107- return _weighted_quantile_sorted_1d (a , q , weights , n , )
122+ x = xp .take (a , sorter )
123+ w = xp .take (weights , sorter )
124+ return _weighted_quantile_sorted_1d (x , q , w , ** kwargs )
108125
109126 d , = eager_shape (a , axis = 0 )
110- res = xp . empty (( q . shape [ 0 ], d ))
127+ res = []
111128 for idx in range (d ):
112129 w = weights if weights .ndim == 1 else weights [idx , ...]
113130 w = xp .take (w , sorter [idx , ...])
114- res [..., idx ] = _weighted_quantile_sorted_1d (a [idx , ...], q , w , n , average )
131+ x = xp .take (a [idx , ...], sorter [idx , ...])
132+ res .append (_weighted_quantile_sorted_1d (x , q , w , ** kwargs ))
133+ res = xp .stack (res , axis = 1 )
115134 return res
116135
117136
118- def _weighted_quantile_sorted_1d (a , q , w , n , average : bool , xp : ModuleType ):
119- cw = xp .cumsum (w )
137+ def _weighted_quantile_sorted_1d (
138+ x : Array , q : Array , w : Array , n : int , average : bool , nan_policy : str ,
139+ xp : ModuleType , device : Device
140+ ) -> Array :
141+ if nan_policy == "omit" :
142+ w = xp .where (xp .isnan (x ), 0. , w )
143+ elif xp .any (xp .isnan (x )):
144+ return xp .full (q .shape , xp .nan , dtype = x .dtype , device = device )
145+ cw = xp .cumulative_sum (w )
120146 t = cw [- 1 ] * q
121- i = xp .searchsorted (cw , t )
147+ i = xp .searchsorted (cw , t , side = 'left' )
122148 j = xp .searchsorted (cw , t , side = 'right' )
123- i = xp .minimum (i , float ( n - 1 ) )
124- j = xp .minimum (j , float ( n - 1 ) )
149+ i = xp .clip (i , 0 , n - 1 )
150+ j = xp .clip (j , 0 , n - 1 )
125151
126152 # Ignore leading `weights=0` observations when `q=0`
127153 # see https://github.com/scikit-learn/scikit-learn/pull/20528
128- i = xp .where (q == 0. , j , i )
154+ i = xp .where (q == 0. , j , i )
129155 if average :
130156 # Ignore trailing `weights=0` observations when `q=1`
131157 j = xp .where (q == 1. , i , j )
132- return (xp .take (a , i ) + xp .take (a , j )) / 2
133- else :
134- return xp .take (a , i )
158+ return (xp .take (x , i ) + xp .take (x , j )) / 2
159+ return xp .take (x , i )
0 commit comments