Skip to content

Commit d3144e3

Browse files
committed
improved handling of 2d arrays
1 parent cd72e5c commit d3144e3

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

src/methods.c

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ astype_array(PyObject* m, PyObject* args) {
259259
&dtype_spec)) {
260260
return NULL;
261261
}
262-
// AK_CHECK_NUMPY_ARRAY(a);
263262
PyArrayObject* array = (PyArrayObject*)a;
264263

265264
PyArray_Descr* dtype = NULL;
@@ -294,16 +293,25 @@ astype_array(PyObject* m, PyObject* args) {
294293
return NULL;
295294
}
296295
PyObject** data = (PyObject**)PyArray_DATA((PyArrayObject*)result);
297-
npy_intp size = PyArray_SIZE(array);
298296

299-
for (npy_intp i = 0; i < size; ++i) {
300-
PyObject* item = PyArray_Scalar(PyArray_GETPTR1(array, i), array_dt, a);
297+
PyArrayIterObject* it = (PyArrayIterObject*)PyArray_IterNew(a);
298+
if (!it) {
299+
Py_DECREF(result);
300+
return NULL;
301+
}
302+
303+
npy_intp i = 0;
304+
while (it->index < it->size) {
305+
PyObject* item = PyArray_ToScalar(it->dataptr, array);
301306
if (!item) {
302307
Py_DECREF(result);
308+
Py_DECREF(it);
303309
return NULL;
304310
}
305-
data[i] = item;
311+
data[i++] = item;
312+
PyArray_ITER_NEXT(it);
306313
}
314+
Py_DECREF(it);
307315
PyArray_CLEARFLAGS((PyArrayObject *)result, NPY_ARRAY_WRITEABLE);
308316
return result;
309317
}

test/test_astype_array.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,40 @@ def test_astype_array_b3(self) -> None:
5454
a2 = astype_array(a1, np.object_)
5555
self.assertEqual(a2.dtype, np.dtype(np.object_))
5656
self.assertFalse(a2.flags.writeable)
57-
import ipdb; ipdb.set_trace()
5857
self.assertEqual(
5958
list(list(a) for a in a2),
6059
[[np.datetime64('2021'), np.datetime64('2024')], [np.datetime64('1984'), np.datetime64('1642')]])
60+
61+
def test_astype_array_b4(self) -> None:
62+
a1 = np.array(['2021', '2024', '1532', '1984', '1642', '899'], dtype=np.datetime64).reshape((2, 3))
63+
64+
a2 = astype_array(a1, np.object_)
65+
self.assertEqual(a2.dtype, np.dtype(np.object_))
66+
self.assertEqual(a2.shape, (2, 3))
67+
self.assertFalse(a2.flags.writeable)
68+
self.assertEqual(
69+
list(list(a) for a in a2),
70+
[[np.datetime64('2021'), np.datetime64('2024'), np.datetime64('1532')],
71+
[np.datetime64('1984'), np.datetime64('1642'), np.datetime64('899')]])
72+
73+
def test_astype_array_c(self) -> None:
74+
with self.assertRaises(TypeError):
75+
_ = astype_array([3, 4, 5], np.int64)
76+
77+
78+
def test_astype_array_d1(self) -> None:
79+
a1 = np.array([10, 20, 30], dtype=np.int64)
80+
a2 = astype_array(a1)
81+
82+
self.assertEqual(a2.dtype, np.dtype(np.float64))
83+
self.assertEqual(a2.shape, (3,))
84+
self.assertFalse(a2.flags.writeable)
85+
86+
87+
def test_astype_array_d2(self) -> None:
88+
a1 = np.array([10, 20, 30], dtype=np.int64)
89+
a2 = astype_array(a1, None)
90+
91+
self.assertEqual(a2.dtype, np.dtype(np.float64))
92+
self.assertEqual(a2.shape, (3,))
93+
self.assertFalse(a2.flags.writeable)

0 commit comments

Comments
 (0)