diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 2cda88a311..cbc90ed27e 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -19,27 +19,28 @@ void AsStrided::eval(const std::vector& inputs, array& out) { "AsStrided must be used with row contiguous arrays only."); } - // Compute the flags given the shape and strides - bool row_contiguous = true, col_contiguous = true; - size_t r = 1, c = 1; - for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { - row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); - col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); - r *= shape_[i]; - c *= shape_[j]; + auto [no_bsx_size, row_contiguous, col_contiguous] = + check_contiguity(shape_, strides_); + + int64_t l = 0, h = 0; + bool has_negative_stride = false; + for (int i = 0; i < strides_.size(); i++) { + auto delta = strides_[i] * (shape_[i] - 1); + if (strides_[i] >= 0) { + h += delta; + } else { + l += delta; + has_negative_stride |= shape_[i] > 1; + } } + size_t data_size = out.size() == 0 ? 0 : (h - l) + 1; + auto flags = in.flags(); - // TODO: Compute the contiguous flag in a better way cause now we are - // unnecessarily strict. - flags.contiguous = row_contiguous || col_contiguous; + flags.contiguous = + out.size() == 0 || (!has_negative_stride && no_bsx_size == data_size); flags.row_contiguous = row_contiguous; flags.col_contiguous = col_contiguous; - // There is no easy way to compute the actual data size so we use out.size(). - // The contiguous flag will almost certainly not be set so no code should - // rely on data_size anyway. - size_t data_size = out.size(); - return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 1d05cb0c15..500195093e 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2737,11 +2737,41 @@ TEST_CASE("test as_strided op") { auto x = arange(10); auto y = as_strided(x, {3, 3}, {1, 1}, 0); auto expected = array({0, 1, 2, 1, 2, 3, 2, 3, 4}, {3, 3}); + eval(y); CHECK(array_equal(y, expected).item()); + CHECK_EQ(y.data_size(), 5); + CHECK_FALSE(y.flags().contiguous); y = as_strided(x, {3, 3}, {0, 3}, 0); expected = array({0, 3, 6, 0, 3, 6, 0, 3, 6}, {3, 3}); + eval(y); + CHECK(array_equal(y, expected).item()); + CHECK_EQ(y.data_size(), 7); + CHECK_FALSE(y.flags().contiguous); + + x = arange(24); + y = as_strided(x, {2, 3, 4}, {3, 1, 6}, 0); + expected = array( + {0, 6, 12, 18, 1, 7, 13, 19, 2, 8, 14, 20, + 3, 9, 15, 21, 4, 10, 16, 22, 5, 11, 17, 23}, + {2, 3, 4}); + eval(y); + CHECK(array_equal(y, expected).item()); + CHECK_EQ(y.data_size(), 24); + CHECK(y.flags().contiguous); + CHECK_FALSE(y.flags().row_contiguous); + CHECK_FALSE(y.flags().col_contiguous); + + auto z = astype(y, float32); + CHECK(array_equal(z, astype(expected, float32)).item()); + + x = arange(10); + y = as_strided(x, {10}, {-1}, 9); + expected = array({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, {10}); + eval(y); CHECK(array_equal(y, expected).item()); + CHECK_EQ(y.data_size(), 10); + CHECK_FALSE(y.flags().contiguous); x = reshape(x, {2, 5}); // 0 1 2 3 ... x = transpose(x, {1, 0}); // 0 5 1 6 2 7 ...