Skip to content

Commit 2b4cce1

Browse files
committed
updated register implementation, added tests
1 parent 98e8d8f commit 2b4cce1

File tree

6 files changed

+62
-22
lines changed

6 files changed

+62
-22
lines changed

doc/articles/block_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pickle
1010

1111
from arraykit import BlockIndex
12-
# from arraykit import ErrorInitBlocks
12+
# from arraykit import ErrorInitTypeBlocks
1313
from arraykit import shape_filter
1414
from arraykit import resolve_dtype
1515

src/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ._arraykit import __version__
66
from ._arraykit import ArrayGO as ArrayGO
77
from ._arraykit import BlockIndex as BlockIndex
8-
from ._arraykit import ErrorInitBlocks as ErrorInitBlocks
8+
from ._arraykit import ErrorInitTypeBlocks as ErrorInitTypeBlocks
99

1010
from ._arraykit import immutable_filter as immutable_filter
1111
from ._arraykit import mloc as mloc

src/__init__.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ _T = tp.TypeVar('_T')
66

77
__version__: str
88

9-
class ErrorInitBlocks:
9+
class ErrorInitTypeBlocks:
1010
def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None: ...
1111
def with_traceback(self, tb: Exception) -> Exception: ...
1212
def __setstate__(self) -> None: ...
@@ -27,10 +27,12 @@ class ArrayGO:
2727

2828
class BlockIndex:
2929
shape: tp.Tuple[int, int]
30-
dtype: tp.Optional[np.dtype]
30+
dtype: np.dtype
31+
rows: int
32+
columns: int
3133

3234
def __init__() -> None: ...
33-
def register(self, __value: object) -> None: ...
35+
def register(self, __value: np.ndarray) -> bool: ...
3436
def to_list(self,) -> tp.List[int]: ...
3537
def to_bytes(self,) -> bytes: ...
3638
def copy(self,) -> 'BlockIndex': ...

src/_arraykit.c

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4114,7 +4114,7 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
41144114
//------------------------------------------------------------------------------
41154115

41164116
static PyTypeObject BlockIndexType;
4117-
static PyObject *ErrorInitBlocks;
4117+
static PyObject *ErrorInitTypeBlocks;
41184118

41194119
// NOTE: we use platform size types here, which are appropriate for the values, but might pose issues if trying to pass pickles between 32 and 64 bit machines.
41204120
typedef struct BlockIndexRecord {
@@ -4777,18 +4777,18 @@ BlockIndex_repr(BlockIndexObject *self) {
47774777
dt);
47784778
}
47794779

4780-
// Returns NULL on error, None otherwise. This checks and raises on non-array inputs, dimensions other than 1 or 2.
4780+
// Returns NULL on error, True if the block should be retained, False if the block has zero columns and should not be retained. This checks and raises on non-array inputs, dimensions other than 1 or 2, and mis-aligned columns.
47814781
static PyObject *
47824782
BlockIndex_register(BlockIndexObject *self, PyObject *value) {
47834783
if (!PyArray_Check(value)) {
4784-
PyErr_Format(ErrorInitBlocks, "Found non-array block: %R", value);
4784+
PyErr_Format(ErrorInitTypeBlocks, "Found non-array block: %R", value);
47854785
return NULL;
47864786
}
47874787
PyArrayObject *a = (PyArrayObject *)value;
47884788
int ndim = PyArray_NDIM(a);
47894789

47904790
if (ndim < 1 || ndim > 2) {
4791-
PyErr_Format(ErrorInitBlocks, "Array block has invalid dimensions: %i", ndim);
4791+
PyErr_Format(ErrorInitTypeBlocks, "Array block has invalid dimensions: %i", ndim);
47924792
return NULL;
47934793
}
47944794
Py_ssize_t increment = ndim == 1 ? 1 : PyArray_DIM(a, 1);
@@ -4799,13 +4799,17 @@ BlockIndex_register(BlockIndexObject *self, PyObject *value) {
47994799
self->row_count = alignment;
48004800
}
48014801
else if (self->row_count != alignment) {
4802-
PyErr_Format(ErrorInitBlocks,
4802+
PyErr_Format(ErrorInitTypeBlocks,
48034803
"Array block has unaligned row count: found %i, expected %i",
48044804
alignment,
48054805
self->row_count);
48064806
return NULL;
48074807
}
48084808

4809+
if (increment == 0) {
4810+
Py_RETURN_FALSE;
4811+
}
4812+
48094813
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
48104814
if (self->dtype == NULL) {
48114815
Py_INCREF((PyObject*)dt);
@@ -4829,7 +4833,7 @@ BlockIndex_register(BlockIndexObject *self, PyObject *value) {
48294833
self->bir_count++;
48304834
}
48314835
self->block_count++;
4832-
Py_RETURN_NONE;
4836+
Py_RETURN_TRUE;
48334837
}
48344838

48354839

@@ -4942,6 +4946,7 @@ BlockIndex_iter(BlockIndexObject* self) {
49424946

49434947
static PyObject *
49444948
BlockIndex_shape_getter(BlockIndexObject *self, void* Py_UNUSED(closure)){
4949+
// NOTE: this could be cached
49454950
return Py_BuildValue("nn", self->row_count, self->bir_count);
49464951
}
49474952

@@ -5395,12 +5400,12 @@ PyInit__arraykit(void)
53955400
{
53965401
import_array();
53975402

5398-
ErrorInitBlocks = PyErr_NewExceptionWithDoc(
5399-
"arraykit.ErrorInitBlocks",
5403+
ErrorInitTypeBlocks = PyErr_NewExceptionWithDoc(
5404+
"arraykit.ErrorInitTypeBlocks",
54005405
"RuntimeError error in block initialization.",
54015406
PyExc_RuntimeError,
54025407
NULL);
5403-
if (ErrorInitBlocks == NULL) {
5408+
if (ErrorInitTypeBlocks == NULL) {
54045409
return NULL;
54055410
}
54065411

@@ -5426,7 +5431,7 @@ PyInit__arraykit(void)
54265431
PyModule_AddObject(m, "BlockIndex", (PyObject *) &BlockIndexType) ||
54275432
PyModule_AddObject(m, "ArrayGO", (PyObject *) &ArrayGOType) ||
54285433
PyModule_AddObject(m, "deepcopy", deepcopy) ||
5429-
PyModule_AddObject(m, "ErrorInitBlocks", ErrorInitBlocks)
5434+
PyModule_AddObject(m, "ErrorInitTypeBlocks", ErrorInitTypeBlocks)
54305435
){
54315436
Py_DECREF(deepcopy);
54325437
Py_XDECREF(m);

test/test_block_index.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from arraykit import BlockIndex
10-
from arraykit import ErrorInitBlocks
10+
from arraykit import ErrorInitTypeBlocks
1111

1212

1313
class TestUnit(unittest.TestCase):
@@ -53,19 +53,19 @@ def test_block_index_init_d(self) -> None:
5353

5454
def test_block_index_register_a(self) -> None:
5555
bi1 = BlockIndex()
56-
with self.assertRaises(ErrorInitBlocks):
56+
with self.assertRaises(ErrorInitTypeBlocks):
5757
bi1.register('foo')
5858

59-
with self.assertRaises(ErrorInitBlocks):
59+
with self.assertRaises(ErrorInitTypeBlocks):
6060
bi1.register(3.5)
6161

6262
def test_block_index_register_b(self) -> None:
6363

6464
bi1 = BlockIndex()
65-
with self.assertRaises(ErrorInitBlocks):
65+
with self.assertRaises(ErrorInitTypeBlocks):
6666
bi1.register(np.array(0))
6767

68-
with self.assertRaises(ErrorInitBlocks):
68+
with self.assertRaises(ErrorInitTypeBlocks):
6969
bi1.register(np.arange(12).reshape(2,3,2))
7070

7171

@@ -96,7 +96,7 @@ def test_block_index_register_d(self) -> None:
9696
def test_block_index_register_e(self) -> None:
9797
bi1 = BlockIndex()
9898
bi1.register(np.arange(2))
99-
with self.assertRaises(ErrorInitBlocks):
99+
with self.assertRaises(ErrorInitTypeBlocks):
100100
bi1.register(np.arange(12).reshape(3,4))
101101

102102

@@ -108,6 +108,39 @@ def test_block_index_register_f(self) -> None:
108108
self.assertEqual(bi1.columns, 10_000)
109109

110110

111+
def test_block_index_register_g(self) -> None:
112+
bi1 = BlockIndex()
113+
a1 = np.array(()).reshape(4, 0)
114+
self.assertFalse(bi1.register(a1))
115+
self.assertEqual(bi1.shape, (4, 0))
116+
# as no dtype has been registered, we will get default float
117+
self.assertEqual(bi1.dtype, np.dtype(float))
118+
119+
a2 = np.arange(8).reshape(4, 2).astype(bool)
120+
self.assertTrue(bi1.register(a2))
121+
self.assertEqual(bi1.shape, (4, 2))
122+
self.assertEqual(bi1.dtype, np.dtype(bool))
123+
124+
125+
def test_block_index_register_h(self) -> None:
126+
bi1 = BlockIndex()
127+
a1 = np.array(()).reshape(0, 4).astype(bool)
128+
self.assertTrue(bi1.register(a1))
129+
self.assertEqual(bi1.shape, (0, 4))
130+
self.assertEqual(bi1.dtype, np.dtype(bool))
131+
132+
a2 = np.array(()).reshape(0, 0).astype(float)
133+
self.assertFalse(bi1.register(a2))
134+
self.assertEqual(bi1.shape, (0, 4))
135+
# dtype is still bool
136+
self.assertEqual(bi1.dtype, np.dtype(bool))
137+
138+
a3 = np.array(()).reshape(0, 3).astype(int)
139+
self.assertTrue(bi1.register(a3))
140+
self.assertEqual(bi1.shape, (0, 7))
141+
self.assertEqual(bi1.dtype, np.dtype(object))
142+
143+
111144
#---------------------------------------------------------------------------
112145

113146
def test_block_index_to_bytes_a(self) -> None:

test/test_pyi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def from_module(cls, module):
3030
continue
3131
obj = getattr(module, name)
3232
if isinstance(obj, type): # a class
33-
if name == ak.ErrorInitBlocks.__name__:
33+
if name == ak.ErrorInitTypeBlocks.__name__:
3434
# skip as there is Python version variability
3535
continue
3636
classes[name] = []

0 commit comments

Comments
 (0)