Skip to content

Commit b997ba3

Browse files
committed
additional tests
1 parent d3144e3 commit b997ba3

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

src/methods.c

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,21 +266,28 @@ astype_array(PyObject* m, PyObject* args) {
266266
dtype = PyArray_DescrFromType(NPY_DEFAULT_TYPE);
267267
} else {
268268
if (!PyArray_DescrConverter(dtype_spec, &dtype)) {
269-
Py_DECREF((PyObject*)array);
270269
return NULL;
271270
}
272271
}
273272

274-
int dt_equal = PyArray_EquivTypes(PyArray_DESCR(array), dtype);
275-
if (dt_equal && !PyArray_ISWRITEABLE(array)) {
273+
if (PyArray_EquivTypes(PyArray_DESCR(array), dtype)) {
276274
Py_DECREF(dtype);
277-
Py_INCREF(a);
278-
return a;
275+
if (PyArray_ISWRITEABLE(array)) {
276+
PyObject* result = PyArray_NewCopy(array, NPY_ANYORDER);
277+
if (!result) {
278+
return NULL;
279+
}
280+
PyArray_CLEARFLAGS((PyArrayObject *)result, NPY_ARRAY_WRITEABLE);
281+
return result;
282+
}
283+
else { // already immutable
284+
Py_INCREF(a);
285+
return a;
286+
}
279287
}
280-
// if not already an object and converting to an object
281-
if (!dt_equal && dtype->type_num == NPY_OBJECT) {
282-
PyArray_Descr* array_dt = PyArray_DESCR(array);
283-
char kind = array_dt->kind;
288+
// if converting to an object
289+
if (dtype->type_num == NPY_OBJECT) {
290+
char kind = PyArray_DESCR(array)->kind;
284291
if ((kind == 'M' || kind == 'm')) {
285292
PyObject* dt_year = PyObject_GetAttrString(m, "dt_year");
286293
int is_objectable = AK_is_objectable_dt64(array, dt_year);

test/test_astype_array.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,33 @@ def test_astype_array_d2(self) -> None:
9191
self.assertEqual(a2.dtype, np.dtype(np.float64))
9292
self.assertEqual(a2.shape, (3,))
9393
self.assertFalse(a2.flags.writeable)
94+
95+
96+
97+
def test_astype_array_d3(self) -> None:
98+
a1 = np.array([10, 20, 30], dtype=np.int64)
99+
a2 = astype_array(a1, np.int64)
100+
101+
self.assertEqual(a2.dtype, np.dtype(np.int64))
102+
self.assertEqual(a2.shape, (3,))
103+
self.assertFalse(a2.flags.writeable)
104+
105+
self.assertNotEqual(id(a1), id(a2))
106+
107+
def test_astype_array_e(self) -> None:
108+
a1 = np.array(['2021', '2024', '1997', '1984', '2000', '1999'], dtype='datetime64[ns]').reshape((2, 3))
109+
110+
a2 = astype_array(a1, np.object_)
111+
self.assertEqual(a2.dtype, np.dtype(np.object_))
112+
self.assertEqual(a2.shape, (2, 3))
113+
self.assertFalse(a2.flags.writeable)
114+
self.assertEqual(
115+
list(list(a) for a in a2),
116+
[[np.datetime64('2021-01-01T00:00:00.000000000'),
117+
np.datetime64('2024-01-01T00:00:00.000000000'),
118+
np.datetime64('1997-01-01T00:00:00.000000000')],
119+
[np.datetime64('1984-01-01T00:00:00.000000000'),
120+
np.datetime64('2000-01-01T00:00:00.000000000'),
121+
np.datetime64('1999-01-01T00:00:00.000000000')]]
122+
)
123+

0 commit comments

Comments
 (0)