@@ -49,10 +49,17 @@ def quantile( # numpydoc ignore=PR01,RT01
4949 res = xp .moveaxis (res , axis , 0 )
5050 else :
5151 weights_arr = xp .asarray (weights , dtype = xp .float64 , device = device )
52- average = method == ' averaged_inverted_cdf'
52+ average = method == " averaged_inverted_cdf"
5353 res = _weighted_quantile (
54- a , q , weights_arr , n , axis , average ,
55- nan_policy = nan_policy , xp = xp , device = device
54+ a ,
55+ q ,
56+ weights_arr ,
57+ n ,
58+ axis ,
59+ average ,
60+ nan_policy = nan_policy ,
61+ xp = xp ,
62+ device = device ,
5663 )
5764
5865 # reshaping to conform to doc/other libs' behavior
@@ -72,15 +79,17 @@ def _quantile( # numpydoc ignore=GL08
7279 a = xp .sort (a , axis = axis , stable = False )
7380 mask_nan = xp .any (xp .isnan (a ), axis = axis , keepdims = True )
7481 if xp .any (mask_nan ):
75- # propogate NaNs:
82+ # propagate NaNs:
7683 mask = xp .repeat (mask_nan , n , axis = axis )
7784 a = xp .where (mask , xp .nan , a )
7885 del mask
7986
80- if method == "linear" :
81- m = 1 - q
82- else : # method is "inverted_cdf" or "averaged_inverted_cdf"
83- m = xp .asarray (0 , dtype = q .dtype )
87+ m = (
88+ 1 - q
89+ if method == "linear"
90+ # method is "inverted_cdf" or "averaged_inverted_cdf"
91+ else xp .asarray (0 , dtype = q .dtype )
92+ )
8493
8594 jg = q * float (n ) + m - 1
8695
@@ -90,9 +99,9 @@ def _quantile( # numpydoc ignore=GL08
9099 # `̀j` and `jp1` are 1d arrays
91100
92101 g = jg % 1
93- if method == ' inverted_cdf' :
102+ if method == " inverted_cdf" :
94103 g = xp .astype ((g > 0 ), jg .dtype )
95- elif method == ' averaged_inverted_cdf' :
104+ elif method == " averaged_inverted_cdf" :
96105 g = (1 + xp .astype ((g > 0 ), jg .dtype )) / 2
97106
98107 g = xp .where (j < 0 , 0 , g ) # equivalent to g[j < 0] = 0, but works with readonly
@@ -106,8 +115,15 @@ def _quantile( # numpydoc ignore=GL08
106115
107116
108117def _weighted_quantile (
109- a : Array , q : Array , weights : Array , n : int , axis : int , average : bool , nan_policy : str ,
110- xp : ModuleType , device : Device
118+ a : Array ,
119+ q : Array ,
120+ weights : Array ,
121+ n : int ,
122+ axis : int ,
123+ average : bool ,
124+ nan_policy : str ,
125+ xp : ModuleType ,
126+ device : Device ,
111127) -> Array :
112128 """
113129 a is expected to be 1d or 2d.
@@ -122,37 +138,45 @@ def _weighted_quantile(
122138 w = xp .take (weights , sorter )
123139 return _weighted_quantile_sorted_1d (x , q , w , n , average , nan_policy , xp , device )
124140
125- d , = eager_shape (a , axis = 0 )
141+ ( d ,) = eager_shape (a , axis = 0 )
126142 res = []
127143 for idx in range (d ):
128144 w = weights if weights .ndim == 1 else weights [idx , ...]
129145 w = xp .take (w , sorter [idx , ...])
130146 x = xp .take (a [idx , ...], sorter [idx , ...])
131- res .append (_weighted_quantile_sorted_1d (x , q , w , n , average , nan_policy , xp , device ))
132- res = xp .stack (res , axis = 1 )
133- return res
147+ res .append (
148+ _weighted_quantile_sorted_1d (x , q , w , n , average , nan_policy , xp , device )
149+ )
150+
151+ return xp .stack (res , axis = 1 )
134152
135153
136154def _weighted_quantile_sorted_1d (
137- x : Array , q : Array , w : Array , n : int , average : bool , nan_policy : str ,
138- xp : ModuleType , device : Device
155+ x : Array ,
156+ q : Array ,
157+ w : Array ,
158+ n : int ,
159+ average : bool ,
160+ nan_policy : str ,
161+ xp : ModuleType ,
162+ device : Device ,
139163) -> Array :
140164 if nan_policy == "omit" :
141- w = xp .where (xp .isnan (x ), 0. , w )
165+ w = xp .where (xp .isnan (x ), 0.0 , w )
142166 elif xp .any (xp .isnan (x )):
143167 return xp .full (q .shape , xp .nan , dtype = x .dtype , device = device )
144168 cw = xp .cumulative_sum (w )
145169 t = cw [- 1 ] * q
146- i = xp .searchsorted (cw , t , side = ' left' )
147- j = xp .searchsorted (cw , t , side = ' right' )
170+ i = xp .searchsorted (cw , t , side = " left" )
171+ j = xp .searchsorted (cw , t , side = " right" )
148172 i = xp .clip (i , 0 , n - 1 )
149173 j = xp .clip (j , 0 , n - 1 )
150174
151175 # Ignore leading `weights=0` observations when `q=0`
152176 # see https://github.com/scikit-learn/scikit-learn/pull/20528
153- i = xp .where (q == 0. , j , i )
177+ i = xp .where (q == 0.0 , j , i )
154178 if average :
155179 # Ignore trailing `weights=0` observations when `q=1`
156- j = xp .where (q == 1. , i , j )
180+ j = xp .where (q == 1.0 , i , j )
157181 return (xp .take (x , i ) + xp .take (x , j )) / 2
158182 return xp .take (x , i )
0 commit comments