@@ -3626,7 +3626,7 @@ A2DTuple_new(PyArrayObject* array,
36263626 Py_INCREF ((PyObject * )array );
36273627 a2dt -> array = array ;
36283628 a2dt -> num_rows = num_rows ;
3629- a2dt -> num_cols = num_cols ;
3629+ a2dt -> num_cols = num_cols ; // -1 for 1D array
36303630 a2dt -> pos = 0 ;
36313631 return (PyObject * )a2dt ;
36323632}
@@ -3649,19 +3649,35 @@ A2DTuple_iternext(A2DTupleObject *self) {
36493649 if (i < self -> num_rows ) {
36503650 npy_intp num_cols = self -> num_cols ;
36513651 PyArrayObject * array = self -> array ;
3652- PyObject * tuple = PyTuple_New (num_cols );
36533652 PyObject * item ;
3654- if (tuple == NULL ) {
3655- return NULL ;
3653+ PyObject * tuple ;
3654+
3655+ if (num_cols > -1 ) { // ndim == 2
3656+ tuple = PyTuple_New (num_cols );
3657+ if (tuple == NULL ) {
3658+ return NULL ;
3659+ }
3660+ for (npy_intp j = 0 ; j < num_cols ; ++ j ) {
3661+ // cannot assume array is contiguous
3662+ item = PyArray_ToScalar (PyArray_GETPTR2 (array , i , j ), array );
3663+ if (item == NULL ) {
3664+ Py_DECREF (tuple );
3665+ return NULL ;
3666+ }
3667+ PyTuple_SET_ITEM (tuple , j , item ); // steals reference to item
3668+ }
36563669 }
3657- for (npy_intp j = 0 ; j < num_cols ; ++ j ) {
3658- // cannot assume array is contiguous
3659- item = PyArray_ToScalar (PyArray_GETPTR2 (array , i , j ), array );
3670+ else { // ndim == 1
3671+ tuple = PyTuple_New (1 );
3672+ if (tuple == NULL ) {
3673+ return NULL ;
3674+ }
3675+ item = PyArray_ToScalar (PyArray_GETPTR1 (array , i ), array );
36603676 if (item == NULL ) {
36613677 Py_DECREF (tuple );
36623678 return NULL ;
36633679 }
3664- PyTuple_SET_ITEM (tuple , j , item ); // steals reference to item
3680+ PyTuple_SET_ITEM (tuple , 0 , item ); // steals reference to item
36653681 }
36663682 self -> pos ++ ;
36673683 return tuple ;
@@ -3700,10 +3716,19 @@ static PyTypeObject A2DTupleType = {
37003716static PyObject *
37013717array2d_tuple_iter (PyObject * Py_UNUSED (m ), PyObject * a )
37023718{
3703- AK_CHECK_NUMPY_ARRAY_2D (a );
3704- PyArrayObject * array = (PyArrayObject * )a ;
3719+ AK_CHECK_NUMPY_ARRAY (a );
3720+ PyArrayObject * array = (PyArrayObject * )a ;
3721+ int ndim = PyArray_NDIM (array );
3722+ if (ndim != 1 && ndim != 2 ) {
3723+ return PyErr_Format (PyExc_NotImplementedError ,
3724+ "Expected 1D or 2D array, not %i." ,
3725+ ndim );
3726+ }
37053727 npy_intp num_rows = PyArray_DIM (array , 0 );
3706- npy_intp num_cols = PyArray_DIM (array , 1 );
3728+ npy_intp num_cols = -1 ; // indicate 1d
3729+ if (ndim == 2 ) {
3730+ num_cols = PyArray_DIM (array , 1 );
3731+ }
37073732 return A2DTuple_new (array , num_rows , num_cols );
37083733}
37093734
0 commit comments