|
29 | 29 | one_hot, |
30 | 30 | pad, |
31 | 31 | partition, |
| 32 | + quantile, |
32 | 33 | setdiff1d, |
33 | 34 | sinc, |
34 | 35 | ) |
@@ -1529,3 +1530,101 @@ def test_kind(self, xp: ModuleType, library: Backend): |
1529 | 1530 | expected = xp.asarray([False, True, False, True]) |
1530 | 1531 | res = isin(a, b, kind="sort") |
1531 | 1532 | xp_assert_equal(res, expected) |
| 1533 | + |
| 1534 | + |
| 1535 | +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take") |
| 1536 | +class TestQuantile: |
| 1537 | + def test_basic(self, xp: ModuleType): |
| 1538 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1539 | + actual = quantile(x, 0.5) |
| 1540 | + expect = xp.asarray(3.0, dtype=xp.float64) |
| 1541 | + xp_assert_close(actual, expect) |
| 1542 | + |
| 1543 | + def test_multiple_quantiles(self, xp: ModuleType): |
| 1544 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1545 | + actual = quantile(x, xp.asarray([0.25, 0.5, 0.75])) |
| 1546 | + expect = xp.asarray([2.0, 3.0, 4.0], dtype=xp.float64) |
| 1547 | + xp_assert_close(actual, expect) |
| 1548 | + |
| 1549 | + def test_shape(self, xp: ModuleType): |
| 1550 | + a = xp.asarray(np.random.rand(3, 4, 5)) |
| 1551 | + q = xp.asarray(np.random.rand(2)) |
| 1552 | + assert quantile(a, q, axis=0).shape == (2, 4, 5) |
| 1553 | + assert quantile(a, q, axis=1).shape == (2, 3, 5) |
| 1554 | + assert quantile(a, q, axis=2).shape == (2, 3, 4) |
| 1555 | + |
| 1556 | + assert quantile(a, q, axis=0, keepdims=True).shape == (2, 1, 4, 5) |
| 1557 | + assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5) |
| 1558 | + assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1) |
| 1559 | + |
| 1560 | + def test_against_numpy(self, xp: ModuleType): |
| 1561 | + a_np = np.random.rand(3, 4, 5) |
| 1562 | + q_np = np.random.rand(2) |
| 1563 | + a = xp.asarray(a_np) |
| 1564 | + q = xp.asarray(q_np) |
| 1565 | + for keepdims in [False, True]: |
| 1566 | + for axis in [None, *range(a.ndim)]: |
| 1567 | + actual = quantile(a, q, axis=axis, keepdims=keepdims) |
| 1568 | + expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims) |
| 1569 | + expected = xp.asarray(expected, dtype=xp.float64) |
| 1570 | + xp_assert_close(actual, expected, atol=1e-12) |
| 1571 | + |
| 1572 | + def test_2d_axis(self, xp: ModuleType): |
| 1573 | + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) |
| 1574 | + actual = quantile(x, 0.5, axis=0) |
| 1575 | + expect = xp.asarray([2.5, 3.5, 4.5], dtype=xp.float64) |
| 1576 | + xp_assert_close(actual, expect) |
| 1577 | + |
| 1578 | + def test_2d_axis_keepdims(self, xp: ModuleType): |
| 1579 | + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) |
| 1580 | + actual = quantile(x, 0.5, axis=0, keepdims=True) |
| 1581 | + expect = xp.asarray([[2.5, 3.5, 4.5]], dtype=xp.float64) |
| 1582 | + xp_assert_close(actual, expect) |
| 1583 | + |
| 1584 | + def test_methods(self, xp: ModuleType): |
| 1585 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1586 | + methods = ["linear"] #"hazen", "weibull"] |
| 1587 | + for method in methods: |
| 1588 | + actual = quantile(x, 0.5, method=method) |
| 1589 | + # All methods should give reasonable results |
| 1590 | + assert 2.5 <= float(actual) <= 3.5 |
| 1591 | + |
| 1592 | + def test_edge_cases(self, xp: ModuleType): |
| 1593 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1594 | + # q = 0 should give minimum |
| 1595 | + actual = quantile(x, 0.0) |
| 1596 | + expect = xp.asarray(1.0, dtype=xp.float64) |
| 1597 | + xp_assert_close(actual, expect) |
| 1598 | + |
| 1599 | + # q = 1 should give maximum |
| 1600 | + actual = quantile(x, 1.0) |
| 1601 | + expect = xp.asarray(5.0, dtype=xp.float64) |
| 1602 | + xp_assert_close(actual, expect) |
| 1603 | + |
| 1604 | + def test_invalid_q(self, xp: ModuleType): |
| 1605 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1606 | + _ = quantile(x, 1.0) |
| 1607 | + # ^ FIXME: here just to make this test fail for sparse backend |
| 1608 | + # q > 1 should raise |
| 1609 | + with pytest.raises( |
| 1610 | + ValueError, match=r"`q` values must be in the range \[0, 1\]" |
| 1611 | + ): |
| 1612 | + _ = quantile(x, 1.5) |
| 1613 | + # q < 0 should raise |
| 1614 | + with pytest.raises( |
| 1615 | + ValueError, match=r"`q` values must be in the range \[0, 1\]" |
| 1616 | + ): |
| 1617 | + _ = quantile(x, -0.5) |
| 1618 | + |
| 1619 | + def test_device(self, xp: ModuleType, device: Device): |
| 1620 | + if hasattr(device, 'type') and device.type == "meta": |
| 1621 | + pytest.xfail("No Tensor.item() on meta device") |
| 1622 | + x = xp.asarray([1, 2, 3, 4, 5], device=device) |
| 1623 | + actual = quantile(x, 0.5) |
| 1624 | + assert get_device(actual) == device |
| 1625 | + |
| 1626 | + def test_xp(self, xp: ModuleType): |
| 1627 | + x = xp.asarray([1, 2, 3, 4, 5]) |
| 1628 | + actual = quantile(x, 0.5, xp=xp) |
| 1629 | + expect = xp.asarray(3.0, dtype=xp.float64) |
| 1630 | + xp_assert_close(actual, expect) |
0 commit comments