1010
1111
1212def quantile ( # numpydoc ignore=PR01,RT01
13- x ,
14- p ,
13+ a : Array ,
14+ q : Array | float ,
1515 / ,
1616 method : str = 'linear' , # noqa: ARG001
1717 axis : int | None = None ,
@@ -20,107 +20,67 @@ def quantile( # numpydoc ignore=PR01,RT01
2020 xp : ModuleType ,
2121):
2222 """See docstring in `array_api_extra._delegation.py`."""
23- # Input validation / standardization
24- temp = _quantile_iv (x , p , axis , keepdims )
25- y , p , axis , keepdims , n , axis_none , ndim = temp
23+ device = get_device (a )
24+ floating_dtype = xp .result_type (a , xp .asarray (q ))
25+ a = xp .asarray (a , dtype = floating_dtype , device = device )
26+ q = xp .asarray (q , dtype = floating_dtype , device = device )
2627
27- res = _quantile_hf (y , p , n , xp )
28+ if xp .any ((q > 1 ) | (q < 0 ) | xp .isnan (q )):
29+ raise ValueError ("`q` values must be in the range [0, 1]" )
2830
29- # Reshape per axis/keepdims
30- if axis_none and keepdims :
31- shape = (1 ,)* (ndim - 1 ) + res .shape
32- res = xp .reshape (res , shape )
33- axis = - 1
34-
35- res = xp .moveaxis (res , - 1 , axis )
36-
37- if not keepdims :
38- res = xp .squeeze (res , axis = axis )
39-
40- return res [()] if res .ndim == 0 else res
41-
42-
43- def _quantile_iv (
44- x : Array ,
45- p : Array ,
46- axis : int | None ,
47- keepdims : bool ,
48- xp : ModuleType
49- ):
50-
51- if not xp .isdtype (xp .asarray (x ).dtype , ('integral' , 'real floating' )):
52- raise ValueError ("`x` must have real dtype." )
53-
54- if not xp .isdtype (xp .asarray (p ).dtype , 'real floating' ):
55- raise ValueError ("`p` must have real floating dtype." )
56-
57- p_mask = (p > 1 ) | (p < 0 ) | xp .isnan (p )
58- if xp .any (p_mask ):
59- raise ValueError ("`p` values must be in the range [0, 1]" )
60-
61- device = get_device (x )
62- floating_dtype = xp .result_type (x , p )
63- x = xp .asarray (x , dtype = floating_dtype , device = device )
64- p = xp .asarray (p , dtype = floating_dtype , device = device )
65- dtype = x .dtype
31+ q_scalar = q .ndim == 0
32+ if q_scalar :
33+ q = xp .reshape (q , (1 ,))
6634
6735 axis_none = axis is None
68- ndim = max (x .ndim , p .ndim )
6936 if axis_none :
70- x = xp .reshape (x , (- 1 ,))
71- p = xp .reshape (p , (- 1 ,))
37+ a = xp .reshape (a , (- 1 ,))
7238 axis = 0
73- elif np .iterable (axis ) or int (axis ) != axis :
74- message = "`axis` must be an integer or None."
75- raise ValueError (message )
76- elif (axis >= ndim ) or (axis < - ndim ):
77- message = "`axis` is not compatible with the shapes of the inputs."
78- raise ValueError (message )
7939 axis = int (axis )
8040
81- if keepdims not in {None , True , False }:
82- message = "If specified, `keepdims` must be True or False."
83- raise ValueError (message )
84-
41+ n = eager_shape (a , axis )
8542 # If data has length zero along `axis`, the result will be an array of NaNs just
8643 # as if the data had length 1 along axis and were filled with NaNs.
87- n = eager_shape (x , axis )
8844 if n == 0 :
89- shape = eager_shape (x )
45+ shape = list ( eager_shape (a ) )
9046 shape [axis ] = 1
9147 n = 1
92- x = xp .full (shape , xp .nan , dtype = dtype , device = device )
93-
94- y = xp .sort (x , axis = axis , stable = False )
95- # FIXME: I still need to look into the broadcasting:
96- y , p = _broadcast_arrays ((y , p ), axis = axis )
48+ a = xp .full (shape , xp .nan , dtype = floating_dtype , device = device )
9749
98- p_shape = eager_shape ( p )
99- if ( keepdims is False ) and ( p_shape [ axis ] != 1 ):
100- message = "`keepdims` may be False only if the length of `p` along `axis` is 1."
101- raise ValueError ( message )
102- keepdims = ( p_shape [ axis ] != 1 ) if keepdims is None else keepdims
50+ a = xp . sort ( a , axis = axis , stable = False )
51+ # to support weights, the main thing would be to
52+ # argsort a, and then use it to sort a and w.
53+ # The hard part will be dealing with 0-weights and NaNs
54+ # But maybe a proper use of searchsorted + left/right side will work?
10355
104- y = xp .moveaxis (y , axis , - 1 )
105- p = xp .moveaxis (p , axis , - 1 )
56+ res = _quantile_hf (a , q , n , axis , xp )
10657
107- nans = xp .isnan (y )
108- nan_out = xp .any (nans , axis = - 1 )
109- if xp .any (nan_out ):
110- y = xp .asarray (y , copy = True ) # ensure writable
111- y = at (y , nan_out ).set (xp .nan )
58+ # reshaping to conform to doc/other libs' behavior
59+ if axis_none :
60+ if keepdims :
61+ res = xp .reshape (res , q .shape + (1 ,) * a .ndim )
62+ else :
63+ res = xp .moveaxis (res , axis , 0 )
64+ if keepdims :
65+ shape = list (a .shape )
66+ shape [axis ] = 1
67+ shape = q .shape + tuple (shape )
68+ res = xp .reshape (res , shape )
11269
113- return y , p , axis , keepdims , n , axis_none , ndim , xp
70+ return res [ 0 , ...] if q_scalar else res
11471
11572
116- def _quantile_hf (y , p , n , xp ):
73+ def _quantile_hf (y : Array , p : Array , n : int , axis : int , xp : ModuleType ):
11774 m = 1 - p
11875 jg = p * n + m - 1
11976 j = jg // 1
12077 g = jg % 1
12178 g [j < 0 ] = 0
12279 j = xp .clip (j , 0. , n - 1 )
12380 jp1 = xp .clip (j + 1 , 0. , n - 1 )
81+ # `̀j` and `jp1` are 1d arrays
12482
125- return ((1 - g ) * xp .take_along_axis (y , xp .astype (j , xp .int64 ), axis = - 1 )
126- + g * xp .take_along_axis (y , xp .astype (jp1 , xp .int64 ), axis = - 1 ))
83+ return (
84+ (1 - g ) * xp .take (y , xp .astype (j , xp .int64 ), axis = axis )
85+ + g * xp .take (y , xp .astype (jp1 , xp .int64 ), axis = axis )
86+ )
0 commit comments