diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 289d21e4..4cdf255a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -67,7 +67,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array if xp is None: xp = array_namespace(x) - if 1 <= ndim <= 3 and ( + if 1 <= ndim <= 2 and ( is_numpy_namespace(xp) or is_jax_namespace(xp) or is_dask_namespace(xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ff050468..6631cc32 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -54,6 +54,8 @@ lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) +NestedFloatList = list[float] | list["NestedFloatList"] + class TestApplyWhere: @staticmethod @@ -291,7 +293,31 @@ def test_0D(self, xp: ModuleType): y = atleast_nd(x, ndim=5) xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1))) - def test_1D(self, xp: ModuleType): + @pytest.mark.parametrize( + ("input_shape", "ndim", "expected_shape"), + [ + ((1,), 0, (1,)), + ((5,), 1, (5,)), + ((2,), 2, (1, 2)), + ((3,), 3, (1, 1, 3)), + ((2,), 5, (1, 1, 1, 1, 2)), + ], + ) + def test_1D_shapes( + self, + input_shape: tuple[int], + ndim: int, + expected_shape: tuple[int], + xp: ModuleType, + ): + n = math.prod(input_shape) + x = xp.reshape(xp.asarray(list(range(n))), input_shape) + y = atleast_nd(x, ndim=ndim) + + assert y.shape == expected_shape + assert xp.sum(y) == int(n * (n - 1) / 2) + + def test_1D_values(self, xp: ModuleType): x = xp.asarray([0, 1]) y = atleast_nd(x, ndim=0) @@ -306,8 +332,32 @@ def test_1D(self, xp: ModuleType): y = atleast_nd(x, ndim=5) xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]])) - def test_2D(self, xp: ModuleType): - x = xp.asarray([[3.0]]) + @pytest.mark.parametrize( + ("input_shape", "ndim", "expected_shape"), + [ + ((2, 1), 0, (2, 1)), + ((5, 2), 1, (5, 2)), + ((2, 1), 2, (2, 1)), + ((3, 1), 3, (1, 3, 1)), + ((2, 8), 5, (1, 1, 1, 2, 8)), + ], + ) + def test_2D_shapes( + self, + input_shape: tuple[int], + ndim: int, + expected_shape: tuple[int], + xp: ModuleType, + ): + n = math.prod(input_shape) + x = xp.reshape(xp.asarray(list(range(n))), input_shape) + y = atleast_nd(x, ndim=ndim) + + assert y.shape == expected_shape + assert xp.sum(y) == int(n * (n - 1) / 2) + + def test_2D_values(self, xp: ModuleType): + x = xp.asarray([[3.0], [4.0]]) y = atleast_nd(x, ndim=0) xp_assert_equal(y, x) @@ -316,12 +366,36 @@ def test_2D(self, xp: ModuleType): xp_assert_equal(y, x) y = atleast_nd(x, ndim=3) - xp_assert_equal(y, 3 * xp.ones((1, 1, 1))) + xp_assert_equal(y, xp.asarray([[[3.0], [4.0]]])) y = atleast_nd(x, ndim=5) - xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) + xp_assert_equal(y, xp.asarray([[[[[3.0], [4.0]]]]])) + + @pytest.mark.parametrize( + ("input_shape", "ndim", "expected_shape"), + [ + ((2, 1, 1), 0, (2, 1, 1)), + ((1, 5, 2), 1, (1, 5, 2)), + ((2, 1, 1), 2, (2, 1, 1)), + ((1, 3, 1), 3, (1, 3, 1)), + ((2, 8, 1), 5, (1, 1, 2, 8, 1)), + ], + ) + def test_3D_shapes( + self, + input_shape: tuple[int], + ndim: int, + expected_shape: tuple[int], + xp: ModuleType, + ): + n = math.prod(input_shape) + x = xp.reshape(xp.asarray(list(range(n))), input_shape) + y = atleast_nd(x, ndim=ndim) + + assert y.shape == expected_shape + assert xp.sum(y) == int(n * (n - 1) / 2) - def test_3D(self, xp: ModuleType): + def test_3D_values(self, xp: ModuleType): x = xp.asarray([[[3.0], [2.0]]]) y = atleast_nd(x, ndim=0) @@ -336,8 +410,32 @@ def test_3D(self, xp: ModuleType): y = atleast_nd(x, ndim=5) xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]])) - def test_5D(self, xp: ModuleType): - x = xp.ones((1, 1, 1, 1, 1)) + @pytest.mark.parametrize( + ("input_shape", "ndim", "expected_shape"), + [ + ((2, 1, 1, 2, 1), 0, (2, 1, 1, 2, 1)), + ((1, 5, 2, 3, 2), 2, (1, 5, 2, 3, 2)), + ((2, 1, 1, 5, 2), 5, (2, 1, 1, 5, 2)), + ((1, 3, 1, 2, 1), 6, (1, 1, 3, 1, 2, 1)), + ((2, 8, 1, 9, 8), 9, (1, 1, 1, 1, 2, 8, 1, 9, 8)), + ], + ) + def test_5D_shapes( + self, + input_shape: tuple[int], + ndim: int, + expected_shape: tuple[int], + xp: ModuleType, + ): + n = math.prod(input_shape) + x = xp.reshape(xp.asarray(list(range(n))), input_shape) + y = atleast_nd(x, ndim=ndim) + + assert y.shape == expected_shape + assert xp.sum(y) == int(n * (n - 1) / 2) + + def test_5D_values(self, xp: ModuleType): + x = xp.asarray([[[[[3.0]], [[2.0]]]]]) y = atleast_nd(x, ndim=0) xp_assert_equal(y, x) @@ -349,19 +447,10 @@ def test_5D(self, xp: ModuleType): xp_assert_equal(y, x) y = atleast_nd(x, ndim=6) - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) + xp_assert_equal(y, xp.asarray([[[[[[3.0]], [[2.0]]]]]])) y = atleast_nd(x, ndim=9) - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) - - def test_device(self, xp: ModuleType, device: Device): - x = xp.asarray([1, 2, 3], device=device) - assert get_device(atleast_nd(x, ndim=2)) == device - - def test_xp(self, xp: ModuleType): - x = xp.asarray(1.0) - y = atleast_nd(x, ndim=1, xp=xp) - xp_assert_equal(y, xp.ones((1,))) + xp_assert_equal(y, xp.asarray([[[[[[[[[3.0]], [[2.0]]]]]]]]])) class TestBroadcastShapes: