Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions mlx/backend/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,28 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
"AsStrided must be used with row contiguous arrays only.");
}

// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
auto [no_bsx_size, row_contiguous, col_contiguous] =
check_contiguity(shape_, strides_);

int64_t l = 0, h = 0;
bool has_negative_stride = false;
for (int i = 0; i < strides_.size(); i++) {
auto delta = strides_[i] * (shape_[i] - 1);
if (strides_[i] >= 0) {
h += delta;
} else {
l += delta;
has_negative_stride |= shape_[i] > 1;
}
}
size_t data_size = out.size() == 0 ? 0 : (h - l) + 1;

auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.contiguous =
out.size() == 0 || (!has_negative_stride && no_bsx_size == data_size);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;

// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();

return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}

Expand Down
30 changes: 30 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2737,11 +2737,41 @@ TEST_CASE("test as_strided op") {
auto x = arange(10);
auto y = as_strided(x, {3, 3}, {1, 1}, 0);
auto expected = array({0, 1, 2, 1, 2, 3, 2, 3, 4}, {3, 3});
eval(y);
CHECK(array_equal(y, expected).item<bool>());
CHECK_EQ(y.data_size(), 5);
CHECK_FALSE(y.flags().contiguous);

y = as_strided(x, {3, 3}, {0, 3}, 0);
expected = array({0, 3, 6, 0, 3, 6, 0, 3, 6}, {3, 3});
eval(y);
CHECK(array_equal(y, expected).item<bool>());
CHECK_EQ(y.data_size(), 7);
CHECK_FALSE(y.flags().contiguous);

x = arange(24);
y = as_strided(x, {2, 3, 4}, {3, 1, 6}, 0);
expected = array(
{0, 6, 12, 18, 1, 7, 13, 19, 2, 8, 14, 20,
3, 9, 15, 21, 4, 10, 16, 22, 5, 11, 17, 23},
{2, 3, 4});
eval(y);
CHECK(array_equal(y, expected).item<bool>());
CHECK_EQ(y.data_size(), 24);
CHECK(y.flags().contiguous);
CHECK_FALSE(y.flags().row_contiguous);
CHECK_FALSE(y.flags().col_contiguous);

auto z = astype(y, float32);
CHECK(array_equal(z, astype(expected, float32)).item<bool>());

x = arange(10);
y = as_strided(x, {10}, {-1}, 9);
expected = array({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, {10});
eval(y);
CHECK(array_equal(y, expected).item<bool>());
CHECK_EQ(y.data_size(), 10);
CHECK_FALSE(y.flags().contiguous);

x = reshape(x, {2, 5}); // 0 1 2 3 ...
x = transpose(x, {1, 0}); // 0 5 1 6 2 7 ...
Expand Down
Loading