Skip to content

Commit c8b7564

Browse files
committed
additional testing
1 parent 4face05 commit c8b7564

File tree

3 files changed

+80
-10
lines changed

3 files changed

+80
-10
lines changed

doc/articles/block_index.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self, arrays: tp.Iterable[np.ndarray]):
7575

7676
self.selector_int_array = np.arange(0, len(self.bi), 2)
7777
self.selector_int_list = list(range(0, len(self.bi), 2))
78+
self.selector_bool_array = (np.arange(len(self.bi)) % 2) == 0
7879

7980
#-------------------------------------------------------------------------------
8081
class BlockIndexLoad(ArrayProcessor):
@@ -191,6 +192,23 @@ def __call__(self):
191192

192193

193194

195+
class BlockIndexIterBoolArray(ArrayProcessor):
196+
NAME = 'BlockIndex: iter by bool array'
197+
SORT = 7
198+
199+
def __call__(self):
200+
_ = list(self.bi.iter_select(self.selector_bool_array))
201+
202+
class TupleIndexIterBoolArray(ArrayProcessor):
203+
NAME = 'TupleIndex: iter by bool array'
204+
SORT = 17
205+
206+
def __call__(self):
207+
ti = self.ti
208+
_ = [ti[i] for i in self.selector_bool_array if i]
209+
210+
211+
194212

195213
#-------------------------------------------------------------------------------
196214
NUMBER = 2
@@ -360,6 +378,9 @@ def get_versions() -> str:
360378
TupleIndexIterIntArray,
361379
BlockIndexIterIntList,
362380
TupleIndexIterIntList,
381+
BlockIndexIterBoolArray,
382+
TupleIndexIterBoolArray,
383+
363384
)
364385

365386
CLS_FF = (

src/_arraykit.c

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,7 +4146,7 @@ typedef struct BIIterObject {
41464146
PyObject_VAR_HEAD
41474147
BlockIndexObject *bi;
41484148
int8_t reversed;
4149-
Py_ssize_t index; // current index state, mutated in-place
4149+
Py_ssize_t pos; // current index state, mutated in-place
41504150
} BIIterObject;
41514151

41524152
static PyObject *
@@ -4158,7 +4158,7 @@ BIIter_new(BlockIndexObject *bi, int8_t reversed) {
41584158
Py_INCREF(bi);
41594159
bii->bi = bi;
41604160
bii->reversed = reversed;
4161-
bii->index = 0;
4161+
bii->pos = 0;
41624162
return (PyObject *)bii;
41634163
}
41644164

@@ -4178,13 +4178,13 @@ static PyObject *
41784178
BIIter_iternext(BIIterObject *self) {
41794179
Py_ssize_t i;
41804180
if (self->reversed) {
4181-
i = self->bi->bir_count - ++self->index;
4181+
i = self->bi->bir_count - ++self->pos;
41824182
if (i < 0) {
41834183
return NULL;
41844184
}
41854185
}
41864186
else {
4187-
i = self->index++;
4187+
i = self->pos++;
41884188
}
41894189
if (self->bi->bir_count <= i) {
41904190
return NULL;
@@ -4199,8 +4199,8 @@ BIIter_reversed(BIIterObject *self) {
41994199

42004200
static PyObject *
42014201
BIIter_length_hint(BIIterObject *self) {
4202-
// this works for reversed as we use self-> index to subtract from length
4203-
Py_ssize_t len = Py_MAX(0, self->bi->bir_count - self->index);
4202+
// this works for reversed as we use self->pos to subtract from length
4203+
Py_ssize_t len = Py_MAX(0, self->bi->bir_count - self->pos);
42044204
return PyLong_FromSsize_t(len);
42054205
}
42064206

@@ -4564,13 +4564,15 @@ BIIterSelector_new(BlockIndexObject *bi,
45644564
PyErr_SetString(PyExc_TypeError, "Slices cannot be used as selectors for this type of iterator");
45654565
return NULL;
45664566
}
4567-
if (PySlice_GetIndicesEx(selector, bi->bir_count, &pos, &stop, &step, &len)) {
4567+
if (PySlice_Unpack(selector, &pos, &stop, &step)) {
45684568
return NULL;
45694569
}
4570+
len = PySlice_AdjustIndices(bi->bir_count, &pos, &stop, step);
45704571
if (reversed) {
45714572
pos += (step * (len - 1));
45724573
step *= -1;
45734574
}
4575+
// AK_DEBUG_MSG_OBJ("resolved slice", Py_BuildValue("nnnn", pos, stop, step, len));
45744576
}
45754577
else if (PyList_CheckExact(selector)) {
45764578
if (kind == BIIS_UNKNOWN) {

test/test_block_index.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,40 @@ def test_block_index_iter_select_slice_c(self) -> None:
409409
[(2, 1), (2, 0), (1, 0), (0, 1)]
410410
)
411411

412+
413+
def test_block_index_iter_select_slice_d(self) -> None:
414+
bi1 = BlockIndex()
415+
bi1.register(np.arange(6).reshape(2,3))
416+
bi1.register(np.arange(2))
417+
418+
self.assertEqual(list(bi1.iter_select(slice(None))),
419+
[(0, 0), (0, 1), (0, 2), (1, 0)]
420+
)
421+
self.assertEqual(list(bi1.iter_select(slice(20, 24))),
422+
[]
423+
)
424+
self.assertEqual(list(bi1.iter_select(slice(0, 100, 10))),
425+
[(0, 0)]
426+
)
427+
self.assertEqual(list(bi1.iter_select(slice(0, 100, 3))),
428+
[(0, 0), (1, 0)]
429+
)
430+
431+
def test_block_index_iter_select_slice_e(self) -> None:
432+
bi1 = BlockIndex()
433+
bi1.register(np.arange(12).reshape(2,6))
434+
bi1.register(np.arange(12).reshape(2,6))
435+
436+
self.assertEqual(list(bi1.iter_select(slice(11, None, -3))),
437+
[(1, 5), (1, 2), (0, 5), (0, 2)]
438+
)
439+
self.assertEqual(list(bi1.iter_select(slice(11, None, -4))),
440+
[(1, 5), (1, 1), (0, 3)]
441+
)
442+
443+
412444
#---------------------------------------------------------------------------
413-
def test_block_index_iter_select_slice_a(self) -> None:
445+
def test_block_index_iter_select_boolean_a(self) -> None:
414446
bi1 = BlockIndex()
415447
bi1.register(np.arange(4).reshape(2,2))
416448
bi1.register(np.arange(2))
@@ -428,10 +460,25 @@ def test_block_index_iter_select_slice_a(self) -> None:
428460
[(0, 0), (2, 4)]
429461
)
430462

431-
def test_block_index_iter_select_slice_b(self) -> None:
463+
def test_block_index_iter_select_boolean_b(self) -> None:
432464
bi1 = BlockIndex()
433465
bi1.register(np.arange(4).reshape(2,2))
434466
bi1.register(np.arange(2))
435467

436468
with self.assertRaises(TypeError):
437-
bi1.iter_select(np.array([False, True]))
469+
bi1.iter_select(np.array([False, True]))
470+
471+
with self.assertRaises(TypeError):
472+
bi1.iter_select(np.full(20, True))
473+
474+
475+
def test_block_index_iter_select_boolean_c(self) -> None:
476+
bi1 = BlockIndex()
477+
bi1.register(np.arange(4).reshape(2,2))
478+
bi1.register(np.arange(2))
479+
self.assertEqual(list(bi1.iter_select(np.full(len(bi1), False))),
480+
[]
481+
)
482+
self.assertEqual(list(bi1.iter_select(np.full(len(bi1), True))),
483+
[(0, 0), (0, 1), (1, 0)]
484+
)

0 commit comments

Comments
 (0)