Skip to content

Commit 6ebe86e

Browse files
committed
refactored array2d_tuple_iter to handle 1D array
1 parent 0cd2dee commit 6ebe86e

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

src/_arraykit.c

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3626,7 +3626,7 @@ A2DTuple_new(PyArrayObject* array,
36263626
Py_INCREF((PyObject*)array);
36273627
a2dt->array = array;
36283628
a2dt->num_rows = num_rows;
3629-
a2dt->num_cols = num_cols;
3629+
a2dt->num_cols = num_cols; // -1 for 1D array
36303630
a2dt->pos = 0;
36313631
return (PyObject *)a2dt;
36323632
}
@@ -3649,19 +3649,35 @@ A2DTuple_iternext(A2DTupleObject *self) {
36493649
if (i < self->num_rows) {
36503650
npy_intp num_cols = self->num_cols;
36513651
PyArrayObject* array = self->array;
3652-
PyObject* tuple = PyTuple_New(num_cols);
36533652
PyObject* item;
3654-
if (tuple == NULL) {
3655-
return NULL;
3653+
PyObject* tuple;
3654+
3655+
if (num_cols > -1) { // ndim == 2
3656+
tuple = PyTuple_New(num_cols);
3657+
if (tuple == NULL) {
3658+
return NULL;
3659+
}
3660+
for (npy_intp j = 0; j < num_cols; ++j) {
3661+
// cannot assume array is contiguous
3662+
item = PyArray_ToScalar(PyArray_GETPTR2(array, i, j), array);
3663+
if (item == NULL) {
3664+
Py_DECREF(tuple);
3665+
return NULL;
3666+
}
3667+
PyTuple_SET_ITEM(tuple, j, item); // steals reference to item
3668+
}
36563669
}
3657-
for (npy_intp j = 0; j < num_cols; ++j) {
3658-
// cannot assume array is contiguous
3659-
item = PyArray_ToScalar(PyArray_GETPTR2(array, i, j), array);
3670+
else { // ndim == 1
3671+
tuple = PyTuple_New(1);
3672+
if (tuple == NULL) {
3673+
return NULL;
3674+
}
3675+
item = PyArray_ToScalar(PyArray_GETPTR1(array, i), array);
36603676
if (item == NULL) {
36613677
Py_DECREF(tuple);
36623678
return NULL;
36633679
}
3664-
PyTuple_SET_ITEM(tuple, j, item); // steals reference to item
3680+
PyTuple_SET_ITEM(tuple, 0, item); // steals reference to item
36653681
}
36663682
self->pos++;
36673683
return tuple;
@@ -3700,10 +3716,19 @@ static PyTypeObject A2DTupleType = {
37003716
static PyObject *
37013717
array2d_tuple_iter(PyObject *Py_UNUSED(m), PyObject *a)
37023718
{
3703-
AK_CHECK_NUMPY_ARRAY_2D(a);
3704-
PyArrayObject* array = (PyArrayObject *)a;
3719+
AK_CHECK_NUMPY_ARRAY(a);
3720+
PyArrayObject *array = (PyArrayObject *)a;
3721+
int ndim = PyArray_NDIM(array);
3722+
if (ndim != 1 && ndim != 2) {
3723+
return PyErr_Format(PyExc_NotImplementedError,
3724+
"Expected 1D or 2D array, not %i.",
3725+
ndim);
3726+
}
37053727
npy_intp num_rows = PyArray_DIM(array, 0);
3706-
npy_intp num_cols = PyArray_DIM(array, 1);
3728+
npy_intp num_cols = -1; // indicate 1d
3729+
if (ndim == 2) {
3730+
num_cols = PyArray_DIM(array, 1);
3731+
}
37073732
return A2DTuple_new(array, num_rows, num_cols);
37083733
}
37093734

0 commit comments

Comments
 (0)