11import math
22import warnings
33from types import ModuleType
4- from typing import Any , cast
4+ from typing import Any , Literal , cast , get_args
55
66import hypothesis
77import hypothesis .extra .numpy as npst
@@ -1531,6 +1531,7 @@ def test_kind(self, xp: ModuleType, library: Backend):
15311531 res = isin (a , b , kind = "sort" )
15321532 xp_assert_equal (res , expected )
15331533
1534+ METHOD = Literal ["linear" , "inverted_cdf" , "averaged_inverted_cdf" ]
15341535
15351536@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no xp.take" )
15361537class TestQuantile :
@@ -1558,21 +1559,67 @@ def test_shape(self, xp: ModuleType):
15581559 assert quantile (a , q , axis = 1 , keepdims = True ).shape == (2 , 3 , 1 , 5 )
15591560 assert quantile (a , q , axis = 2 , keepdims = True ).shape == (2 , 3 , 4 , 1 )
15601561
1562+ @pytest .mark .parametrize ("with_nans" , ["no_nans" , "with_nans" ])
1563+ @pytest .mark .parametrize ("method" , get_args (METHOD ))
1564+ def test_against_numpy_1d (self , xp : ModuleType , with_nans : str , method : METHOD ):
1565+ rng = np .random .default_rng ()
1566+ a_np = rng .random (40 )
1567+ if with_nans == "with_nans" :
1568+ a_np [rng .random (a_np .shape ) < rng .random () * 0.5 ] = np .nan
1569+ q_np = np .asarray ([0 , * rng .random (2 ), 1 ])
1570+ a = xp .asarray (a_np )
1571+ q = xp .asarray (q_np )
1572+
1573+ actual = quantile (a , q , method = method )
1574+ expected = np .quantile (a_np , q_np , method = method )
1575+ expected = xp .asarray (expected )
1576+ xp_assert_close (actual , expected )
1577+
1578+ @pytest .mark .parametrize ("with_nans" , ["no_nans" , "with_nans" ])
1579+ @pytest .mark .parametrize ("method" , get_args (METHOD ))
15611580 @pytest .mark .parametrize ("keepdims" , [True , False ])
1562- def test_against_numpy (self , xp : ModuleType , keepdims : bool ):
1581+ def test_against_numpy_nd (self , xp : ModuleType , keepdims : bool ,
1582+ with_nans : str , method : METHOD ):
15631583 rng = np .random .default_rng ()
15641584 a_np = rng .random ((3 , 4 , 5 ))
1585+ if with_nans == "with_nans" :
1586+ a_np [rng .random (a_np .shape ) < rng .random ()] = np .nan
15651587 q_np = rng .random (2 )
15661588 a = xp .asarray (a_np )
15671589 q = xp .asarray (q_np )
15681590 for axis in [None , * range (a .ndim )]:
1569- actual = quantile (a , q , axis = axis , keepdims = keepdims )
1570- expected = np .quantile (a_np , q_np , axis = axis , keepdims = keepdims )
1591+ actual = quantile (a , q , axis = axis , keepdims = keepdims , method = method )
1592+ expected = np .quantile (
1593+ a_np , q_np , axis = axis , keepdims = keepdims , method = method
1594+ )
15711595 expected = xp .asarray (expected )
1572- xp_assert_close (actual , expected , atol = 1e-12 )
1596+ xp_assert_close (actual , expected )
1597+
1598+ @pytest .mark .parametrize ("nan_policy" , ["no_nans" , "propagate" ])
1599+ @pytest .mark .parametrize ("with_weights" , ["with_weights" , "no_weights" ])
1600+ def test_against_median (
1601+ self , xp : ModuleType , nan_policy : str , with_weights : str ,
1602+ ):
1603+ rng = np .random .default_rng ()
1604+ n = 40
1605+ a_np = rng .random (n )
1606+ w_np = rng .integers (0 , 2 , n ) if with_weights == "with_weights" else None
1607+ if nan_policy == "no_nans" :
1608+ nan_policy = "propagate"
1609+ else :
1610+ # from 0% to 50% of NaNs:
1611+ a_np [rng .random (n ) < rng .random (n ) * 0.5 ] = np .nan
1612+ m = "averaged_inverted_cdf"
1613+
1614+ np_median = np .nanmedian if nan_policy == "omit" else np .median
1615+ expected = np_median (a_np if w_np is None else a_np [w_np > 0 ])
1616+ a = xp .asarray (a_np )
1617+ w = xp .asarray (w_np ) if w_np is not None else None
1618+ actual = quantile (a , 0.5 , method = m , nan_policy = nan_policy , weights = w )
1619+ xp_assert_close (actual , xp .asarray (expected ))
15731620
15741621 @pytest .mark .parametrize ("keepdims" , [True , False ])
1575- @pytest .mark .parametrize ("nan_policy" , ["omit " , "no_nans " , "propagate " ])
1622+ @pytest .mark .parametrize ("nan_policy" , ["no_nans " , "propagate " , "omit " ])
15761623 @pytest .mark .parametrize ("q_np" , [0.5 , 0.0 , 1.0 , np .linspace (0 , 1 , num = 11 )])
15771624 def test_weighted_against_numpy (
15781625 self , xp : ModuleType , keepdims : bool , q_np : Array | float , nan_policy : str
@@ -1581,7 +1628,7 @@ def test_weighted_against_numpy(
15811628 pytest .xfail (reason = "NumPy 1.x does not support weights in quantile" )
15821629 rng = np .random .default_rng ()
15831630 n , d = 10 , 20
1584- a_np = rng .random ((n , d ))
1631+ a_2d = rng .random ((n , d ))
15851632 mask_nan = np .zeros ((n , d ), dtype = bool )
15861633 if nan_policy == "no_nans" :
15871634 nan_policy = "propagate"
@@ -1590,36 +1637,36 @@ def test_weighted_against_numpy(
15901637 mask_nan = rng .random ((n , d )) < rng .random ((n , 1 ))
15911638 # don't put nans in the first row:
15921639 mask_nan [:] = False
1593- a_np [mask_nan ] = np .nan
1640+ a_2d [mask_nan ] = np .nan
15941641
1595- a = xp .asarray (a_np , copy = True )
15961642 q = xp .asarray (q_np , copy = True )
1597- m = "inverted_cdf"
1643+ m : METHOD = "inverted_cdf"
15981644
15991645 np_quantile = np .quantile
16001646 if nan_policy == "omit" :
16011647 np_quantile = np .nanquantile
16021648
1603- for w_np , axis in [
1604- (rng .random (n ), 0 ),
1605- (rng .random (d ), 1 ),
1606- (rng .integers (0 , 2 , n ), 0 ),
1607- (rng .integers (0 , 2 , d ), 1 ),
1608- (rng .integers (0 , 2 , (n , d )), 0 ),
1609- (rng .integers (0 , 2 , (n , d )), 1 ),
1649+ for a_np , w_np , axis in [
1650+ (a_2d , rng .random (n ), 0 ),
1651+ (a_2d , rng .random (d ), 1 ),
1652+ (a_2d [0 ], rng .random (d ), None ),
1653+ (a_2d , rng .integers (0 , 3 , n ), 0 ),
1654+ (a_2d , rng .integers (0 , 2 , d ), 1 ),
1655+ (a_2d , rng .integers (0 , 2 , (n , d )), 0 ),
1656+ (a_2d , rng .integers (0 , 3 , (n , d )), 1 ),
16101657 ]:
16111658 with warnings .catch_warnings (record = True ) as warning :
16121659 divide_msg = "invalid value encountered in divide"
16131660 warnings .filterwarnings ("always" , divide_msg , RuntimeWarning )
16141661 nan_slice_msg = "All-NaN slice encountered"
16151662 warnings .filterwarnings ("ignore" , nan_slice_msg , RuntimeWarning )
16161663 try :
1617- expected = np_quantile ( # type: ignore[call-overload]
1664+ expected = np_quantile (
16181665 a_np ,
16191666 np .asarray (q_np ),
16201667 axis = axis ,
16211668 method = m ,
1622- weights = w_np ,
1669+ weights = w_np , # type: ignore[arg-type]
16231670 keepdims = keepdims ,
16241671 )
16251672 except IndexError :
@@ -1630,6 +1677,7 @@ def test_weighted_against_numpy(
16301677 continue
16311678 expected = xp .asarray (expected )
16321679
1680+ a = xp .asarray (a_np )
16331681 w = xp .asarray (w_np )
16341682 actual = quantile (
16351683 a ,
@@ -1640,19 +1688,7 @@ def test_weighted_against_numpy(
16401688 keepdims = keepdims ,
16411689 nan_policy = nan_policy ,
16421690 )
1643- xp_assert_close (actual , expected , atol = 1e-12 )
1644-
1645- def test_2d_axis (self , xp : ModuleType ):
1646- x = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
1647- actual = quantile (x , 0.5 , axis = 0 )
1648- expect = xp .asarray ([2.5 , 3.5 , 4.5 ], dtype = xp .float64 )
1649- xp_assert_close (actual , expect )
1650-
1651- def test_2d_axis_keepdims (self , xp : ModuleType ):
1652- x = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
1653- actual = quantile (x , 0.5 , axis = 0 , keepdims = True )
1654- expect = xp .asarray ([[2.5 , 3.5 , 4.5 ]], dtype = xp .float64 )
1655- xp_assert_close (actual , expect )
1691+ xp_assert_close (actual , expected )
16561692
16571693 def test_methods (self , xp : ModuleType ):
16581694 x = xp .asarray ([1 , 2 , 3 , 4 , 5 ])
0 commit comments