From 9cc3bb7a255cff7eff4042124451144f23bd6a74 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 15:53:38 +0000 Subject: [PATCH 01/11] Experiment with reduced allocations in MaskValues iterators Signed-off-by: Adam Gutglick --- .../src/arrays/bool/compute/filter.rs | 29 +++++--- .../src/arrays/chunked/compute/filter.rs | 15 +++-- .../src/arrays/chunked/compute/mask.rs | 29 +++++--- .../arrays/filter/execute/fixed_size_list.rs | 67 ++++++++++++------- .../src/arrays/list/compute/filter.rs | 37 +++++----- .../src/arrays/varbin/compute/filter.rs | 62 +++++++++++------ vortex-compute/src/filter/slice.rs | 14 ++-- 7 files changed, 157 insertions(+), 96 deletions(-) diff --git a/vortex-array/src/arrays/bool/compute/filter.rs b/vortex-array/src/arrays/bool/compute/filter.rs index 82887d8799d..14d9e525010 100644 --- a/vortex-array/src/arrays/bool/compute/filter.rs +++ b/vortex-array/src/arrays/bool/compute/filter.rs @@ -7,7 +7,6 @@ use vortex_buffer::get_bit; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::ArrayRef; use crate::ExecutionCtx; @@ -32,19 +31,33 @@ impl FilterKernel for BoolVTable { .values() .vortex_expect("AllTrue and AllFalse are handled by filter fn"); - let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { - MaskIter::Indices(indices) => filter_indices( + let buffer = if mask_values.density() >= FILTER_SLICES_DENSITY_THRESHOLD { + filter_slices( &array.to_bit_buffer(), mask.true_count(), - indices.iter().copied(), - ), - MaskIter::Slices(slices) => filter_slices( + mask_values.bit_buffer().set_slices(), + ) + } else { + filter_indices( &array.to_bit_buffer(), mask.true_count(), - slices.iter().copied(), - ), + mask_values.bit_buffer().set_indices(), + ) }; + // let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { + // MaskIter::Indices(indices) => filter_indices( + // &array.to_bit_buffer(), + // mask.true_count(), + // indices.iter().copied(), + // ), + // MaskIter::Slices(slices) => filter_slices( + // &array.to_bit_buffer(), + // mask.true_count(), + // slices.iter().copied(), + // ), + // }; + Ok(Some(BoolArray::new(buffer, validity).into_array())) } } diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 2e31adcc0aa..273dcdeb343 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -5,7 +5,6 @@ use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::Array; use crate::ArrayRef; @@ -34,10 +33,16 @@ impl FilterKernel for ChunkedVTable { // Based on filter selectivity, we take the values between a range of slices, or // we take individual indices. - let chunks = match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - MaskIter::Indices(indices) => filter_indices(array, indices.iter().copied()), - MaskIter::Slices(slices) => filter_slices(array, slices.iter().copied()), - }?; + // let chunks = match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { + // MaskIter::Indices(indices) => filter_indices(array, indices.iter().copied()), + // MaskIter::Slices(slices) => filter_slices(array, slices.iter().copied()), + // }?; + + let chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + filter_slices(array, mask_values.bit_buffer().set_slices())? + } else { + filter_indices(array, mask_values.bit_buffer().set_indices())? + }; // SAFETY: Filter operation preserves the dtype of each chunk. // All filtered chunks maintain the same dtype as the original array. diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index 7c6aa7b87ea..522403e47f1 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -5,10 +5,9 @@ use itertools::Itertools as _; use vortex_buffer::BitBuffer; use vortex_buffer::BitBufferMut; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_mask::AllOr; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_scalar::Scalar; use super::filter::ChunkFilter; @@ -31,13 +30,21 @@ use crate::validity::Validity; impl MaskKernel for ChunkedVTable { fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { let new_dtype = array.dtype().as_nullable(); - let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - AllOr::All => unreachable!("handled in top-level mask"), - AllOr::None => unreachable!("handled in top-level mask"), - AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), - AllOr::Some(MaskIter::Slices(slices)) => { - mask_slices(array, slices.iter().cloned(), &new_dtype) - } + // let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { + // AllOr::All => unreachable!("handled in top-level mask"), + // AllOr::None => unreachable!("handled in top-level mask"), + // AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), + // AllOr::Some(MaskIter::Slices(slices)) => { + // mask_slices(array, slices.iter().cloned(), &new_dtype) + // } + // }?; + + let mask_values = mask.values().vortex_expect("handled in top-level mask"); + + let new_chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + mask_indices(array, mask_values.bit_buffer().set_indices(), &new_dtype) + } else { + mask_slices(array, mask_values.bit_buffer().set_slices(), &new_dtype) }?; debug_assert_eq!(new_chunks.len(), array.nchunks()); debug_assert_eq!( @@ -52,7 +59,7 @@ register_kernel!(MaskKernelAdapter(ChunkedVTable).lift()); fn mask_indices( array: &ChunkedArray, - indices: &[usize], + indices: impl Iterator, new_dtype: &DType, ) -> VortexResult> { let mut new_chunks = Vec::with_capacity(array.nchunks()); @@ -61,7 +68,7 @@ fn mask_indices( let chunk_offsets = array.chunk_offsets(); - for &set_index in indices { + for set_index in indices { let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets)?; if chunk_id != current_chunk_id { let chunk = array.chunk(current_chunk_id).clone(); diff --git a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs index f892f90921f..fc696c60908 100644 --- a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs +++ b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs @@ -3,9 +3,9 @@ use std::sync::Arc; +use itertools::Itertools; use vortex_error::VortexExpect; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::arrays::FixedSizeListArray; @@ -83,33 +83,48 @@ pub fn filter_fixed_size_list( fn compute_mask_for_fsl_elements(selection_mask: &MaskValues, list_size: usize) -> Mask { let expanded_len = selection_mask.len() * list_size; - // Use threshold_iter to choose the optimal representation based on density. - let expanded_slices = match selection_mask.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. - slices - .iter() - .map(|&(start, end)| (start * list_size, end * list_size)) - .collect() - } - MaskIter::Indices(indices) => { - // Expand a sparse mask (represented as indices) by duplicating each index `list_size` - // times. - // - // Note that in the worst case, it is possible that we create only a few slices with a - // small range (for example, when list_size <= 2). This could be further optimized, - // but we choose simplicity for now. - indices - .iter() - .map(|&idx| { - let start = idx * list_size; - let end = (idx + 1) * list_size; - (start, end) - }) - .collect() - } + let expanded_slices = if selection_mask.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + selection_mask + .bit_buffer() + .set_slices() + .map(|(start, end)| (start * list_size, end * list_size)) + .collect() + } else { + selection_mask + .bit_buffer() + .set_indices() + .tuple_windows() + .map(|(start, end)| (start * list_size, end * list_size)) + .collect() }; + // // Use threshold_iter to choose the optimal representation based on density. + // let expanded_slices = match selection_mask.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { + // MaskIter::Slices(slices) => { + // // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. + // slices + // .iter() + // .map(|&(start, end)| (start * list_size, end * list_size)) + // .collect() + // } + // MaskIter::Indices(indices) => { + // // Expand a sparse mask (represented as indices) by duplicating each index `list_size` + // // times. + // // + // // Note that in the worst case, it is possible that we create only a few slices with a + // // small range (for example, when list_size <= 2). This could be further optimized, + // // but we choose simplicity for now. + // indices + // .iter() + // .map(|&idx| { + // let start = idx * list_size; + // let end = (idx + 1) * list_size; + // (start, end) + // }) + // .collect() + // } + // }; + Mask::from_slices(expanded_len, expanded_slices) } diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index 0393c65a817..e7fec85f966 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -3,6 +3,7 @@ use std::sync::Arc; +use itertools::Itertools; use num_traits::Zero; use vortex_buffer::BitBufferMut; use vortex_buffer::Buffer; @@ -11,7 +12,6 @@ use vortex_dtype::IntegerPType; use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::ArrayRef; @@ -43,27 +43,24 @@ pub fn element_mask_from_offsets( let mut mask_builder = BitBufferMut::with_capacity(len); - match selection.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Dense iteration: process ranges of consecutive selected lists. - for &(start, end) in slices { - // Optimization: for dense ranges, we can process the elements mask more efficiently. - let elems_start = offsets[start].as_() - first_offset; - let elems_end = offsets[end].as_() - first_offset; + if selection.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + // Dense iteration: process ranges of consecutive selected lists. + for (start, end) in selection.bit_buffer().set_slices() { + // Optimization: for dense ranges, we can process the elements mask more efficiently. + let elems_start = offsets[start].as_() - first_offset; + let elems_end = offsets[end].as_() - first_offset; - // Process the entire range of elements at once. - process_element_range(elems_start, elems_end, &mut mask_builder); - } + // Process the entire range of elements at once. + process_element_range(elems_start, elems_end, &mut mask_builder); } - MaskIter::Indices(indices) => { - // Sparse iteration: process individual selected lists. - for &idx in indices { - let list_start = offsets[idx].as_() - first_offset; - let list_end = offsets[idx + 1].as_() - first_offset; - - // Process the elements for this list. - process_element_range(list_start, list_end, &mut mask_builder); - } + } else { + // Sparse iteration: process individual selected lists. + for (start, end) in selection.bit_buffer().set_indices().tuple_windows() { + let list_start = offsets[start].as_() - first_offset; + let list_end = offsets[end].as_() - first_offset; + + // Process the elements for this list. + process_element_range(list_start, list_end, &mut mask_builder); } } diff --git a/vortex-array/src/arrays/varbin/compute/filter.rs b/vortex-array/src/arrays/varbin/compute/filter.rs index 24d961e03c1..a9c78f01b73 100644 --- a/vortex-array/src/arrays/varbin/compute/filter.rs +++ b/vortex-array/src/arrays/varbin/compute/filter.rs @@ -12,7 +12,6 @@ use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::ArrayRef; use crate::ExecutionCtx; @@ -36,21 +35,39 @@ impl FilterKernel for VarBinVTable { } fn filter_select_var_bin(arr: &VarBinArray, mask: &Mask) -> VortexResult { - match mask + let mask_values = mask .values() - .vortex_expect("AllTrue and AllFalse are handled by filter fn") - .threshold_iter(0.5) - { - MaskIter::Indices(indices) => { - filter_select_var_bin_by_index(arr, indices, mask.true_count()) - } - MaskIter::Slices(slices) => filter_select_var_bin_by_slice(arr, slices, mask.true_count()), + .vortex_expect("AllTrue and AllFalse are handled by filter fn"); + + if mask_values.density() >= 0.5 { + filter_select_var_bin_by_slice( + arr, + mask_values.bit_buffer().set_slices(), + mask.true_count(), + ) + } else { + filter_select_var_bin_by_index( + arr, + mask_values.bit_buffer().set_indices(), + mask.true_count(), + ) } + + // match mask + // .values() + // .vortex_expect("AllTrue and AllFalse are handled by filter fn") + // .threshold_iter(0.5) + // { + // MaskIter::Indices(indices) => { + // filter_select_var_bin_by_index(arr, indices, mask.true_count()) + // } + // MaskIter::Slices(slices) => filter_select_var_bin_by_slice(arr, slices, mask.true_count()), + // } } fn filter_select_var_bin_by_slice( values: &VarBinArray, - mask_slices: &[(usize, usize)], + mask_slices: impl Iterator, selection_count: usize, ) -> VortexResult { let offsets = values.offsets().to_primitive(); @@ -70,7 +87,7 @@ fn filter_select_var_bin_by_slice_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask_slices: &[(usize, usize)], + mask_slices: impl Iterator, logical_validity: Mask, selection_count: usize, ) -> VortexResult @@ -81,15 +98,15 @@ where let mut builder = VarBinBuilder::::with_capacity(selection_count); match logical_validity.bit_buffer() { AllOr::All => { - mask_slices.iter().for_each(|(start, end)| { - update_non_nullable_slice(data, offsets, &mut builder, *start, *end) + mask_slices.for_each(|(start, end)| { + update_non_nullable_slice(data, offsets, &mut builder, start, end) }); } AllOr::None => { builder.append_n_nulls(selection_count); } AllOr::Some(validity) => { - for (start, end) in mask_slices.iter().copied() { + for (start, end) in mask_slices { let null_sl = validity.slice(start..end); if null_sl.true_count() == null_sl.len() { update_non_nullable_slice(data, offsets, &mut builder, start, end) @@ -148,7 +165,7 @@ fn update_non_nullable_slice( fn filter_select_var_bin_by_index( values: &VarBinArray, - mask_indices: &[usize], + mask_indices: impl Iterator, selection_count: usize, ) -> VortexResult { let offsets = values.offsets().to_primitive(); @@ -168,13 +185,13 @@ fn filter_select_var_bin_by_index_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask_indices: &[usize], + mask_indices: impl Iterator, // TODO(ngates): pass LogicalValidity instead validity: Validity, selection_count: usize, ) -> VortexResult { let mut builder = VarBinBuilder::::with_capacity(selection_count); - for idx in mask_indices.iter().copied() { + for idx in mask_indices { if validity.is_valid(idx)? { let (start, end) = ( offsets[idx].to_usize().ok_or_else(|| { @@ -219,7 +236,7 @@ mod test { ], DType::Utf8(NonNullable), ); - let buf = filter_select_var_bin_by_index(&arr, &[0, 2], 2).unwrap(); + let buf = filter_select_var_bin_by_index(&arr, [0, 2].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec!["hello", "filter"])); } @@ -237,7 +254,8 @@ mod test { DType::Utf8(NonNullable), ); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3), (4, 5)], 3).unwrap(); + let buf = + filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3), (4, 5)].into_iter(), 3).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec!["hello", "filter", "filter3"])); } @@ -262,7 +280,7 @@ mod test { ); let arr = VarBinArray::try_new(offsets, bytes, DType::Utf8(Nullable), validity).unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 3), (4, 6)], 5).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 3), (4, 6)].into_iter(), 5).unwrap(); assert_arrays_eq!( buf, @@ -287,7 +305,7 @@ mod test { let validity = Validity::Array(BoolArray::from_iter([false, true, true]).into_array()); let arr = VarBinArray::try_new(offsets, bytes, DType::Utf8(Nullable), validity).unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3)], 2).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3)].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec![None, Some("two")])); } @@ -304,7 +322,7 @@ mod test { ) .unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3)], 2).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3)].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec![None::<&str>, None])); } diff --git a/vortex-compute/src/filter/slice.rs b/vortex-compute/src/filter/slice.rs index c8f1e89a3bb..189e68587ff 100644 --- a/vortex-compute/src/filter/slice.rs +++ b/vortex-compute/src/filter/slice.rs @@ -16,7 +16,6 @@ use vortex_buffer::BitView; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::filter::Filter; @@ -46,9 +45,16 @@ impl Filter for &[T] { "Selection mask length must equal the buffer length" ); - match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - MaskIter::Indices(indices) => self.filter(indices), - MaskIter::Slices(slices) => self.filter(slices), + if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + // High density: use slices (contiguous ranges) + self.filter(mask_values.slices()) + } else { + // Low density: stream indices directly from bitmap without allocatingExpand commentComment on line R52Resolved + let mut out = BufferMut::::with_capacity(mask_values.true_count()); + for idx in mask_values.bit_buffer().set_indices() { + out.push(self[idx]); + } + out.freeze() } } } From f73da40c2f9894a007babc8003ceaa0a1a52cc51 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 16:33:49 +0000 Subject: [PATCH 02/11] Fix bug Signed-off-by: Adam Gutglick --- vortex-array/src/arrays/filter/execute/fixed_size_list.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs index fc696c60908..a5c7028349e 100644 --- a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs +++ b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs @@ -3,7 +3,6 @@ use std::sync::Arc; -use itertools::Itertools; use vortex_error::VortexExpect; use vortex_mask::Mask; use vortex_mask::MaskValues; @@ -93,8 +92,11 @@ fn compute_mask_for_fsl_elements(selection_mask: &MaskValues, list_size: usize) selection_mask .bit_buffer() .set_indices() - .tuple_windows() - .map(|(start, end)| (start * list_size, end * list_size)) + .map(|idx| { + let start = idx * list_size; + let end = (idx + 1) * list_size; + (start, end) + }) .collect() }; From 3782e7ded6d77c9ca17795a878471db061d0487c Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 18:48:44 +0000 Subject: [PATCH 03/11] Just getting benchmarks to run Signed-off-by: Adam Gutglick --- .../src/bitpacking/compute/filter.rs | 75 ++-- encodings/runend/src/compute/filter.rs | 2 +- encodings/runend/src/compute/take.rs | 12 +- encodings/sequence/src/compute/filter.rs | 2 +- encodings/sparse/src/lib.rs | 34 +- encodings/zstd/src/array.rs | 139 ++++---- .../src/arrays/bool/compute/filter.rs | 13 - .../src/arrays/chunked/compute/filter.rs | 7 - .../src/arrays/chunked/compute/mask.rs | 9 - .../arrays/filter/execute/fixed_size_list.rs | 27 -- .../src/arrays/list/compute/filter.rs | 4 +- .../src/arrays/primitive/array/top_value.rs | 24 +- .../src/arrays/varbin/compute/filter.rs | 11 - .../src/arrays/varbinview/compute/zip.rs | 78 ++--- vortex-array/src/compute/zip.rs | 40 ++- vortex-array/src/patches.rs | 30 +- vortex-compute/src/filter/bitbuffer.rs | 21 +- vortex-compute/src/filter/buffer.rs | 1 + vortex-compute/src/filter/slice.rs | 8 +- vortex-compute/src/filter/slice_mut.rs | 22 +- .../src/filter/vector/binaryview.rs | 2 + vortex-compute/src/filter/vector/bool.rs | 2 + vortex-compute/src/filter/vector/decimal.rs | 2 + vortex-compute/src/filter/vector/dvector.rs | 2 + .../src/filter/vector/fixed_size_list.rs | 62 ++-- vortex-compute/src/filter/vector/list.rs | 2 + vortex-compute/src/filter/vector/mod.rs | 16 + vortex-compute/src/filter/vector/null.rs | 26 ++ vortex-compute/src/filter/vector/primitive.rs | 2 + vortex-compute/src/filter/vector/pvector.rs | 2 + vortex-compute/src/filter/vector/struct_.rs | 2 + vortex-mask/src/intersect_by_rank.rs | 327 +++++++++--------- vortex-mask/src/lib.rs | 177 +++++----- vortex-mask/src/tests.rs | 253 +++++++------- 34 files changed, 734 insertions(+), 702 deletions(-) diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index e8d1385e4d8..00779310eaa 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -97,19 +97,20 @@ fn filter_primitive_without_patches( array: &BitPackedArray, selection: &Arc, ) -> VortexResult<(Buffer, Validity)> { - let values = filter_with_indices(array, selection.indices()); + let values = filter_with_indices(array, selection.bit_buffer().set_indices(), selection.len()); let validity = array.validity()?.filter(&Mask::Values(selection.clone()))?; Ok((values.freeze(), validity)) } -fn filter_with_indices( +fn filter_with_indices>( array: &BitPackedArray, - indices: &[usize], + indices: I, + indices_len: usize, ) -> BufferMut { let offset = array.offset() as usize; let bit_width = array.bit_width() as usize; - let mut values = BufferMut::with_capacity(indices.len()); + let mut values = BufferMut::with_capacity(indices_len); // Some re-usable memory to store per-chunk indices. let mut unpacked = [const { MaybeUninit::::uninit() }; 1024]; @@ -118,43 +119,39 @@ fn filter_with_indices( // Group the indices by the FastLanes chunk they belong to. let chunk_size = 128 * bit_width / size_of::(); - chunked_indices( - indices.iter().copied(), - offset, - |chunk_idx, indices_within_chunk| { - let packed = &packed_bytes[chunk_idx * chunk_size..][..chunk_size]; - - if indices_within_chunk.len() == 1024 { - // Unpack the entire chunk. - unsafe { - let values_len = values.len(); - values.set_len(values_len + 1024); - BitPacking::unchecked_unpack( - bit_width, - packed, - &mut values.as_mut_slice()[values_len..], - ); - } - } else if indices_within_chunk.len() > UNPACK_CHUNK_THRESHOLD { - // Unpack into a temporary chunk and then copy the values. - unsafe { - let dst: &mut [MaybeUninit] = &mut unpacked; - let dst: &mut [T] = std::mem::transmute(dst); - BitPacking::unchecked_unpack(bit_width, packed, dst); - } - values.extend_trusted( - indices_within_chunk - .iter() - .map(|&idx| unsafe { unpacked.get_unchecked(idx).assume_init() }), + chunked_indices(indices, offset, |chunk_idx, indices_within_chunk| { + let packed = &packed_bytes[chunk_idx * chunk_size..][..chunk_size]; + + if indices_within_chunk.len() == 1024 { + // Unpack the entire chunk. + unsafe { + let values_len = values.len(); + values.set_len(values_len + 1024); + BitPacking::unchecked_unpack( + bit_width, + packed, + &mut values.as_mut_slice()[values_len..], ); - } else { - // Otherwise, unpack each element individually. - values.extend_trusted(indices_within_chunk.iter().map(|&idx| unsafe { - BitPacking::unchecked_unpack_single(bit_width, packed, idx) - })); } - }, - ); + } else if indices_within_chunk.len() > UNPACK_CHUNK_THRESHOLD { + // Unpack into a temporary chunk and then copy the values. + unsafe { + let dst: &mut [MaybeUninit] = &mut unpacked; + let dst: &mut [T] = std::mem::transmute(dst); + BitPacking::unchecked_unpack(bit_width, packed, dst); + } + values.extend_trusted( + indices_within_chunk + .iter() + .map(|&idx| unsafe { unpacked.get_unchecked(idx).assume_init() }), + ); + } else { + // Otherwise, unpack each element individually. + values.extend_trusted(indices_within_chunk.iter().map(|&idx| unsafe { + BitPacking::unchecked_unpack_single(bit_width, packed, idx) + })); + } + }); values } diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 269d64b97f4..8af90e69301 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -41,7 +41,7 @@ impl FilterKernel for RunEndVTable { if runs_ratio < FILTER_TAKE_THRESHOLD || mask_values.true_count() < 25 { Ok(Some(take_indices_unchecked( array, - mask_values.indices(), + mask_values.bit_buffer().set_indices(), &Validity::NonNullable, )?)) } else { diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index d8bfdd3fc1b..f638e6261ec 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -49,14 +49,19 @@ impl TakeExecute for RunEndVTable { .collect::>>()? }); - take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some) + take_indices_unchecked( + array, + checked_indices.into_iter(), + primitive_indices.validity(), + ) + .map(Some) } } /// Perform a take operation on a RunEndArray by binary searching for each of the indices. -pub fn take_indices_unchecked>( +pub fn take_indices_unchecked, I: Iterator>( array: &RunEndArray, - indices: &[T], + indices: I, validity: &Validity, ) -> VortexResult { let ends = array.ends().to_primitive(); @@ -66,7 +71,6 @@ pub fn take_indices_unchecked>( let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| { let end_slices = ends.as_slice::(); let physical_indices_vec: Vec = indices - .iter() .map(|idx| idx.as_() + array.offset()) .map(|idx| { match ::from(idx) { diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs index b7064a388bf..a6684477a6d 100644 --- a/encodings/sequence/src/compute/filter.rs +++ b/encodings/sequence/src/compute/filter.rs @@ -37,7 +37,7 @@ fn filter_impl(mul: T, base: T, mask: &Mask, validity: Validity) .values() .vortex_expect("FilterKernel precondition: mask is Mask::Values"); let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); - buffer.extend(mask_values.indices().iter().map(|&idx| { + buffer.extend(mask_values.bit_buffer().set_indices().map(|idx| { let i = T::from_usize(idx).vortex_expect("all valid indices fit"); base + i * mul })); diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 749dfcf7a84..63f56ea705e 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -44,7 +44,6 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_mask::AllOr; use vortex_mask::Mask; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -311,23 +310,16 @@ impl SparseArray { } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) { // Array is dominated by NULL but has non-NULL values let non_null_values = filter(array, &mask)?; - let non_null_indices = match mask.indices() { - AllOr::All => { - // We already know that the mask is 90%+ false - unreachable!("Mask is mostly null") - } - AllOr::None => { - // we know there are some non-NULL values - unreachable!("Mask is mostly null but not all null") - } - AllOr::Some(values) => { - let buffer: Buffer = values - .iter() - .map(|&v| v.try_into().vortex_expect("indices must fit in u32")) - .collect(); - - buffer.into_array() - } + let non_null_indices = if let Some(mask_values) = mask.values() { + let buffer: Buffer = mask_values + .bit_buffer() + .set_indices() + .map(|v| v.try_into().vortex_expect("indices must fit in u32")) + .collect(); + + buffer.into_array() + } else { + unreachable!() }; return Ok(SparseArray::try_new( @@ -370,7 +362,11 @@ impl SparseArray { // All values are equal to the top value return Ok(fill_array); } - Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(), + Mask::Values(values) => values + .bit_buffer() + .set_indices() + .map(|v| v as u64) + .collect(), }; SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill) diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index 8c665c8882e..d8e8aeb5f2f 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -52,7 +52,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use vortex_mask::AllOr; use vortex_scalar::Scalar; use vortex_session::VortexSession; @@ -251,27 +250,25 @@ fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult VortexResult<(ByteBuffer, Vec)> { let mask = vbv.validity_mask()?; - let buffer_and_value_byte_indices = match mask.bit_buffer() { - AllOr::None => (Buffer::empty(), Vec::new()), - _ => { - let mut buffer = BufferMut::with_capacity( - usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer") - + mask.true_count() * size_of::(), - ); - let mut value_byte_indices = Vec::new(); - vbv.with_iterator(|iterator| { - // by flattening, we should omit nulls - for value in iterator.flatten() { - value_byte_indices.push(buffer.len()); - // here's where we write the string lengths - buffer - .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter()); - buffer.extend_from_slice(value); - } - Ok::<_, VortexError>(()) - })?; - (buffer.freeze(), value_byte_indices) - } + let buffer_and_value_byte_indices = if mask.all_false() { + (Buffer::empty(), Vec::new()) + } else { + let mut buffer = BufferMut::with_capacity( + usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer") + + mask.true_count() * size_of::(), + ); + let mut value_byte_indices = Vec::new(); + vbv.with_iterator(|iterator| { + // by flattening, we should omit nulls + for value in iterator.flatten() { + value_byte_indices.push(buffer.len()); + // here's where we write the string lengths + buffer.extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter()); + buffer.extend_from_slice(value); + } + Ok::<_, VortexError>(()) + })?; + (buffer.freeze(), value_byte_indices) }; Ok(buffer_and_value_byte_indices) } @@ -719,57 +716,59 @@ impl ZstdArray { Ok(primitive.into_array()) } DType::Binary(_) | DType::Utf8(_) => { - match slice_validity.to_mask(slice_n_rows).indices() { - AllOr::All => { - // the decompressed buffer is a bunch of interleaved u32 lengths - // and strings of those lengths, we need to reconstruct the - // views into those strings by passing through the buffer. - let valid_views = reconstruct_views(&decompressed).slice( - slice_value_idx_start - n_skipped_values - ..slice_value_idx_stop - n_skipped_values, - ); - - // SAFETY: we properly construct the views inside `reconstruct_views` - Ok(unsafe { - VarBinViewArray::new_unchecked( - valid_views, - Arc::from([decompressed]), - self.dtype.clone(), - slice_validity, - ) - } - .into_array()) + let mask = slice_validity.to_mask(slice_n_rows); + if mask.all_true() { + // the decompressed buffer is a bunch of interleaved u32 lengths + // and strings of those lengths, we need to reconstruct the + // views into those strings by passing through the buffer. + let valid_views = reconstruct_views(&decompressed).slice( + slice_value_idx_start - n_skipped_values + ..slice_value_idx_stop - n_skipped_values, + ); + + // SAFETY: we properly construct the views inside `reconstruct_views` + Ok(unsafe { + VarBinViewArray::new_unchecked( + valid_views, + Arc::from([decompressed]), + self.dtype.clone(), + slice_validity, + ) } - AllOr::None => Ok(ConstantArray::new( - Scalar::null(self.dtype.clone()), - slice_n_rows, + .into_array()) + } else if mask.all_false() { + Ok( + ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows) + .into_array(), ) - .into_array()), - AllOr::Some(valid_indices) => { - // the decompressed buffer is a bunch of interleaved u32 lengths - // and strings of those lengths, we need to reconstruct the - // views into those strings by passing through the buffer. - let valid_views = reconstruct_views(&decompressed).slice( - slice_value_idx_start - n_skipped_values - ..slice_value_idx_stop - n_skipped_values, - ); - - let mut views = BufferMut::::zeroed(slice_n_rows); - for (view, index) in valid_views.into_iter().zip_eq(valid_indices) { - views[*index] = view - } - - // SAFETY: we properly construct the views inside `reconstruct_views` - Ok(unsafe { - VarBinViewArray::new_unchecked( - views.freeze(), - Arc::from([decompressed]), - self.dtype.clone(), - slice_validity, - ) - } - .into_array()) + } else { + let mask_values = mask.values().unwrap(); + // the decompressed buffer is a bunch of interleaved u32 lengths + // and strings of those lengths, we need to reconstruct the + // views into those strings by passing through the buffer. + let valid_views = reconstruct_views(&decompressed).slice( + slice_value_idx_start - n_skipped_values + ..slice_value_idx_stop - n_skipped_values, + ); + + let mut views = BufferMut::::zeroed(slice_n_rows); + for (view, index) in valid_views + .into_iter() + .zip_eq(mask_values.bit_buffer().set_indices()) + { + views[index] = view + } + + // SAFETY: we properly construct the views inside `reconstruct_views` + Ok(unsafe { + VarBinViewArray::new_unchecked( + views.freeze(), + Arc::from([decompressed]), + self.dtype.clone(), + slice_validity, + ) } + .into_array()) } } _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype), diff --git a/vortex-array/src/arrays/bool/compute/filter.rs b/vortex-array/src/arrays/bool/compute/filter.rs index 14d9e525010..84ad36e109c 100644 --- a/vortex-array/src/arrays/bool/compute/filter.rs +++ b/vortex-array/src/arrays/bool/compute/filter.rs @@ -45,19 +45,6 @@ impl FilterKernel for BoolVTable { ) }; - // let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { - // MaskIter::Indices(indices) => filter_indices( - // &array.to_bit_buffer(), - // mask.true_count(), - // indices.iter().copied(), - // ), - // MaskIter::Slices(slices) => filter_slices( - // &array.to_bit_buffer(), - // mask.true_count(), - // slices.iter().copied(), - // ), - // }; - Ok(Some(BoolArray::new(buffer, validity).into_array())) } } diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 273dcdeb343..b1685779455 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -31,13 +31,6 @@ impl FilterKernel for ChunkedVTable { .values() .vortex_expect("AllTrue and AllFalse are handled by filter fn"); - // Based on filter selectivity, we take the values between a range of slices, or - // we take individual indices. - // let chunks = match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - // MaskIter::Indices(indices) => filter_indices(array, indices.iter().copied()), - // MaskIter::Slices(slices) => filter_slices(array, slices.iter().copied()), - // }?; - let chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { filter_slices(array, mask_values.bit_buffer().set_slices())? } else { diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index 522403e47f1..40f3a279570 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -30,15 +30,6 @@ use crate::validity::Validity; impl MaskKernel for ChunkedVTable { fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { let new_dtype = array.dtype().as_nullable(); - // let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - // AllOr::All => unreachable!("handled in top-level mask"), - // AllOr::None => unreachable!("handled in top-level mask"), - // AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), - // AllOr::Some(MaskIter::Slices(slices)) => { - // mask_slices(array, slices.iter().cloned(), &new_dtype) - // } - // }?; - let mask_values = mask.values().vortex_expect("handled in top-level mask"); let new_chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { diff --git a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs index a5c7028349e..39247025cbd 100644 --- a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs +++ b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs @@ -100,33 +100,6 @@ fn compute_mask_for_fsl_elements(selection_mask: &MaskValues, list_size: usize) .collect() }; - // // Use threshold_iter to choose the optimal representation based on density. - // let expanded_slices = match selection_mask.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - // MaskIter::Slices(slices) => { - // // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. - // slices - // .iter() - // .map(|&(start, end)| (start * list_size, end * list_size)) - // .collect() - // } - // MaskIter::Indices(indices) => { - // // Expand a sparse mask (represented as indices) by duplicating each index `list_size` - // // times. - // // - // // Note that in the worst case, it is possible that we create only a few slices with a - // // small range (for example, when list_size <= 2). This could be further optimized, - // // but we choose simplicity for now. - // indices - // .iter() - // .map(|&idx| { - // let start = idx * list_size; - // let end = (idx + 1) * list_size; - // (start, end) - // }) - // .collect() - // } - // }; - Mask::from_slices(expanded_len, expanded_slices) } diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index e7fec85f966..da5cb57f8be 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -124,8 +124,8 @@ impl FilterKernel for ListVTable { let mut offset = O::zero(); unsafe { new_offsets.push_unchecked(offset) }; - for idx in selection.indices() { - let size = offsets[idx + 1] - offsets[*idx]; + for idx in selection.bit_buffer().set_indices() { + let size = offsets[idx + 1] - offsets[idx]; offset += size; unsafe { new_offsets.push_unchecked(offset) }; } diff --git a/vortex-array/src/arrays/primitive/array/top_value.rs b/vortex-array/src/arrays/primitive/array/top_value.rs index 0b67de41027..1e115babe47 100644 --- a/vortex-array/src/arrays/primitive/array/top_value.rs +++ b/vortex-array/src/arrays/primitive/array/top_value.rs @@ -8,7 +8,6 @@ use vortex_dtype::NativePType; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_mask::AllOr; use vortex_mask::Mask; use vortex_scalar::PValue; use vortex_utils::aliases::hash_map::HashMap; @@ -41,20 +40,19 @@ where { let mut distinct_values: HashMap, usize, FxBuildHasher> = HashMap::with_hasher(FxBuildHasher); - match mask.indices() { - AllOr::All => { - for value in values.iter().copied() { - *distinct_values.entry(NativeValue(value)).or_insert(0) += 1; - } + + if let Some(mask_values) = mask.values() { + for i in mask_values.bit_buffer().set_indices() { + *distinct_values + .entry(NativeValue(unsafe { *values.get_unchecked(i) })) + .or_insert(0) += 1 } - AllOr::None => unreachable!("All invalid arrays should be handled earlier"), - AllOr::Some(idxs) => { - for &i in idxs { - *distinct_values - .entry(NativeValue(unsafe { *values.get_unchecked(i) })) - .or_insert(0) += 1 - } + } else if mask.all_true() { + for value in values.iter().copied() { + *distinct_values.entry(NativeValue(value)).or_insert(0) += 1; } + } else { + unreachable!("All invalid arrays should be handled earlier") } let (&top_value, &top_count) = distinct_values diff --git a/vortex-array/src/arrays/varbin/compute/filter.rs b/vortex-array/src/arrays/varbin/compute/filter.rs index a9c78f01b73..1c3e1eb15dd 100644 --- a/vortex-array/src/arrays/varbin/compute/filter.rs +++ b/vortex-array/src/arrays/varbin/compute/filter.rs @@ -52,17 +52,6 @@ fn filter_select_var_bin(arr: &VarBinArray, mask: &Mask) -> VortexResult { - // filter_select_var_bin_by_index(arr, indices, mask.true_count()) - // } - // MaskIter::Slices(slices) => filter_select_var_bin_by_slice(arr, slices, mask.true_count()), - // } } fn filter_select_var_bin_by_slice( diff --git a/vortex-array/src/arrays/varbinview/compute/zip.rs b/vortex-array/src/arrays/varbinview/compute/zip.rs index e04b3983b9c..2dac6186a1c 100644 --- a/vortex-array/src/arrays/varbinview/compute/zip.rs +++ b/vortex-array/src/arrays/varbinview/compute/zip.rs @@ -55,57 +55,57 @@ impl ZipKernel for VarBinViewVTable { let true_validity = if_true.validity_mask()?; let false_validity = if_false.validity_mask()?; - match mask.slices() { - AllOr::All => push_range( + if let Some(values) = mask.values() { + let mut pos = 0; + for (start, end) in values.bit_buffer().set_slices() { + if pos < start { + push_range( + if_false, + &false_lookup, + &false_validity, + pos..start, + &mut views_builder, + &mut validity_builder, + ); + } + push_range( + if_true, + &true_lookup, + &true_validity, + start..end, + &mut views_builder, + &mut validity_builder, + ); + pos = end; + } + if pos < len { + push_range( + if_false, + &false_lookup, + &false_validity, + pos..len, + &mut views_builder, + &mut validity_builder, + ); + } + } else if mask.all_true() { + push_range( if_true, &true_lookup, &true_validity, 0..len, &mut views_builder, &mut validity_builder, - ), - AllOr::None => push_range( + ) + } else { + push_range( if_false, &false_lookup, &false_validity, 0..len, &mut views_builder, &mut validity_builder, - ), - AllOr::Some(slices) => { - let mut pos = 0; - for (start, end) in slices { - if pos < *start { - push_range( - if_false, - &false_lookup, - &false_validity, - pos..*start, - &mut views_builder, - &mut validity_builder, - ); - } - push_range( - if_true, - &true_lookup, - &true_validity, - *start..*end, - &mut views_builder, - &mut validity_builder, - ); - pos = *end; - } - if pos < len { - push_range( - if_false, - &false_lookup, - &false_validity, - pos..len, - &mut views_builder, - &mut validity_builder, - ); - } - } + ) } let validity = validity_builder.finish_with_nullability(dtype.nullability()); diff --git a/vortex-array/src/compute/zip.rs b/vortex-array/src/compute/zip.rs index c0a85efde56..b7f84c6ece0 100644 --- a/vortex-array/src/compute/zip.rs +++ b/vortex-array/src/compute/zip.rs @@ -3,6 +3,8 @@ use vortex_dtype::DType; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_mask::AllOr; use vortex_mask::Mask; @@ -59,20 +61,34 @@ fn zip_impl_with_builder( mask: &Mask, mut builder: Box, ) -> VortexResult { - match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), - AllOr::Some(slices) => { - for (start, end) in slices { - builder.extend_from_array(&if_false.slice(builder.len()..*start)?); - builder.extend_from_array(&if_true.slice(*start..*end)?); - } - if builder.len() < if_false.len() { - builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); - } - Ok(builder.finish()) + if let Some(values) = mask.values() { + for (start, end) in values.bit_buffer().set_slices() { + builder.extend_from_array(&if_false.slice(builder.len()..start)?); + builder.extend_from_array(&if_true.slice(start..end)?); + } + if builder.len() < if_false.len() { + builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); } + Ok(builder.finish()) + } else if mask.all_true() { + Ok(if_true.to_array()) + } else { + Ok(if_false.to_array()) } + // match mask.slices() { + // AllOr::All => Ok(if_true.to_array()), + // AllOr::None => Ok(if_false.to_array()), + // AllOr::Some(slices) => { + // for (start, end) in slices { + // builder.extend_from_array(&if_false.slice(builder.len()..*start)?); + // builder.extend_from_array(&if_true.slice(*start..*end)?); + // } + // if builder.len() < if_false.len() { + // builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); + // } + // Ok(builder.finish()) + // } + // } } #[cfg(test)] diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index afa02def88c..60d4f339eb3 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -21,6 +21,7 @@ use vortex_dtype::match_each_integer_ptype; use vortex_dtype::match_each_native_ptype; use vortex_dtype::match_each_unsigned_integer_ptype; use vortex_error::VortexError; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; @@ -582,20 +583,21 @@ impl Patches { ); } - match mask.indices() { - AllOr::All => Ok(Some(self.clone())), - AllOr::None => Ok(None), - AllOr::Some(mask_indices) => { - let flat_indices = self.indices().to_primitive(); - match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| { - filter_patches_with_mask( - flat_indices.as_slice::(), - self.offset(), - self.values(), - mask_indices, - ) - }) - } + if mask.all_true() { + return Ok(Some(self.clone())); + } else if mask.all_false() { + Ok(None) + } else { + let mask_values = mask.values().vortex_expect("trust me"); + let flat_indices = self.indices().to_primitive(); + match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| { + filter_patches_with_mask( + flat_indices.as_slice::(), + self.offset(), + self.values(), + &mask_values.bit_buffer().set_indices().collect::>(), + ) + }) } } diff --git a/vortex-compute/src/filter/bitbuffer.rs b/vortex-compute/src/filter/bitbuffer.rs index 22f294233d7..522a4c5d4bd 100644 --- a/vortex-compute/src/filter/bitbuffer.rs +++ b/vortex-compute/src/filter/bitbuffer.rs @@ -38,7 +38,16 @@ impl Filter for &BitBuffer { "Selection mask length must equal the mask length" ); - self.filter(mask_values.indices()) + let bools = self.inner().as_slice(); + let bit_offset = self.offset(); + + BitBufferMut::from_iter( + mask_values + .bit_buffer() + .set_indices() + .map(|idx| get_bit(bools, bit_offset + idx)), + ) + .freeze() } } @@ -125,7 +134,15 @@ impl Filter for &mut BitBufferMut { ); // BitBufferMut filtering always uses indices for simplicity. - self.filter(mask_values.indices()) + let bools = self.inner().as_slice(); + let bit_offset = self.offset(); + + *self = BitBufferMut::from_iter( + mask_values + .bit_buffer() + .set_indices() + .map(|idx| get_bit(bools, bit_offset + idx)), + ); } } diff --git a/vortex-compute/src/filter/buffer.rs b/vortex-compute/src/filter/buffer.rs index a70c9e22908..d6331131f5e 100644 --- a/vortex-compute/src/filter/buffer.rs +++ b/vortex-compute/src/filter/buffer.rs @@ -106,6 +106,7 @@ impl Filter> for &Buffer { impl Filter for &mut BufferMut where for<'a> &'a mut [T]: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/slice.rs b/vortex-compute/src/filter/slice.rs index 189e68587ff..06c029511fc 100644 --- a/vortex-compute/src/filter/slice.rs +++ b/vortex-compute/src/filter/slice.rs @@ -47,7 +47,13 @@ impl Filter for &[T] { if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { // High density: use slices (contiguous ranges) - self.filter(mask_values.slices()) + + let mut out = BufferMut::::empty(); + for (start, end) in mask_values.bit_buffer().set_slices() { + out.extend_from_slice(&self[start..end]); + } + + out.freeze() } else { // Low density: stream indices directly from bitmap without allocatingExpand commentComment on line R52Resolved let mut out = BufferMut::::with_capacity(mask_values.true_count()); diff --git a/vortex-compute/src/filter/slice_mut.rs b/vortex-compute/src/filter/slice_mut.rs index fae1b8f68f4..15b8cceeb80 100644 --- a/vortex-compute/src/filter/slice_mut.rs +++ b/vortex-compute/src/filter/slice_mut.rs @@ -43,7 +43,27 @@ impl Filter for &mut [T] { // We choose to _always_ use slices here because iterating over indices will have strictly // more loop iterations than slices (more branches), and the overhead over batched // `ptr::copy(len)` is not that high. - self.filter(mask_values.slices()) + let mut write_pos = 0; + + // For each range in the selection, copy all of the elements to the current write position. + for (start, end) in mask_values.bit_buffer().set_slices() { + // Note that we could add an if statement here that checks `if start != write_pos`, but + // it's probably better to just avoid the branch misprediction. + let len = end - start; + + // SAFETY: Slices should be within bounds. + unsafe { + ptr::copy( + self.as_ptr().add(start), + self.as_mut_ptr().add(write_pos), + len, + ) + }; + + write_pos += len; + } + + &mut self[..write_pos] } } diff --git a/vortex-compute/src/filter/vector/binaryview.rs b/vortex-compute/src/filter/vector/binaryview.rs index 39da497bb34..f4773aac6ac 100644 --- a/vortex-compute/src/filter/vector/binaryview.rs +++ b/vortex-compute/src/filter/vector/binaryview.rs @@ -18,6 +18,7 @@ impl Filter for &BinaryViewVector where for<'a> &'a Mask: Filter, for<'a> &'a Buffer: Filter>, + M: ?Sized, { type Output = BinaryViewVector; @@ -34,6 +35,7 @@ impl Filter for &mut BinaryViewVectorMut where for<'a> &'a mut MaskMut: Filter, for<'a> &'a mut BufferMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/bool.rs b/vortex-compute/src/filter/vector/bool.rs index deec0dbcd25..34dfc1f25a8 100644 --- a/vortex-compute/src/filter/vector/bool.rs +++ b/vortex-compute/src/filter/vector/bool.rs @@ -16,6 +16,7 @@ impl Filter for &BoolVector where for<'a> &'a BitBuffer: Filter, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = BoolVector; @@ -34,6 +35,7 @@ impl Filter for &mut BoolVectorMut where for<'a> &'a mut BitBufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/decimal.rs b/vortex-compute/src/filter/vector/decimal.rs index 12263c03189..1fe2a9cd00a 100644 --- a/vortex-compute/src/filter/vector/decimal.rs +++ b/vortex-compute/src/filter/vector/decimal.rs @@ -21,6 +21,7 @@ where for<'a> &'a DVector: Filter>, for<'a> &'a DVector: Filter>, for<'a> &'a DVector: Filter>, + M: ?Sized, { type Output = DecimalVector; @@ -37,6 +38,7 @@ where for<'a> &'a mut DVectorMut: Filter, for<'a> &'a mut DVectorMut: Filter, for<'a> &'a mut DVectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/dvector.rs b/vortex-compute/src/filter/vector/dvector.rs index f0371af0497..2425dcde00b 100644 --- a/vortex-compute/src/filter/vector/dvector.rs +++ b/vortex-compute/src/filter/vector/dvector.rs @@ -17,6 +17,7 @@ impl Filter for &DVector where for<'a> &'a Buffer: Filter>, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = DVector; @@ -32,6 +33,7 @@ impl Filter for &mut DVectorMut where for<'a> &'a mut BufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/fixed_size_list.rs b/vortex-compute/src/filter/vector/fixed_size_list.rs index c9fd449738a..744e99c2908 100644 --- a/vortex-compute/src/filter/vector/fixed_size_list.rs +++ b/vortex-compute/src/filter/vector/fixed_size_list.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskMut; use vortex_vector::Vector; use vortex_vector::VectorMut; @@ -27,7 +26,8 @@ const MASK_EXPANSION_DENSITY_THRESHOLD: f64 = 0.05; impl Filter for &FixedSizeListVector where for<'a> &'a Mask: Filter, - for<'a> &'a Vector: Filter, + for<'a> &'a Vector: Filter<[(usize, usize)], Output = Vector>, + M: ?Sized, { type Output = FixedSizeListVector; @@ -40,7 +40,7 @@ where let elements_mask = compute_fsl_elements_mask(&filtered_validity, list_size as usize); // Filter the child elements vector. - self.elements().as_ref().filter(&elements_mask) + self.elements().as_ref().filter(elements_mask.as_slice()) } else { debug_assert!( self.elements().is_empty(), @@ -68,7 +68,8 @@ where impl Filter for &mut FixedSizeListVectorMut where for<'a> &'a mut MaskMut: Filter, - for<'a> &'a mut VectorMut: Filter, + for<'a> &'a mut VectorMut: Filter<[(usize, usize)], Output = ()>, + M: ?Sized, { type Output = (); @@ -93,7 +94,7 @@ where // SAFETY: The expanded mask has the correct length (`validity.len() * list_size`), // which maintains the invariant after filtering. unsafe { - self.elements_mut().filter(&elements_mask); + self.elements_mut().filter(elements_mask.as_slice()); } debug_assert_eq!( @@ -137,41 +138,32 @@ where /// `list_size` times. /// /// The output [`Mask`] is guaranteed to have a length equal to `selection_mask.len() * list_size`. -fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Mask { - let expanded_len = selection_mask.len() * list_size; +fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Vec<(usize, usize)> { + // let expanded_len = selection_mask.len() * list_size; let values = match selection_mask { - Mask::AllTrue(_) => return Mask::AllTrue(expanded_len), - Mask::AllFalse(_) => return Mask::AllFalse(expanded_len), + Mask::AllTrue(_) => return vec![(0, selection_mask.len() * list_size)], + Mask::AllFalse(_) => return vec![], Mask::Values(values) => values, }; - // Use threshold_iter to choose the optimal representation based on density. - let expanded_slices = match values.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. - slices - .iter() - .map(|&(start, end)| (start * list_size, end * list_size)) - .collect() - } - MaskIter::Indices(indices) => { - // Expand a sparse mask (represented as indices) by duplicating each index `list_size` - // times. - // - // Note that in the worst case, it is possible that we create only a few slices with a - // small range (for example, when list_size <= 2). This could be further optimized, - // but we choose simplicity for now. - indices - .iter() - .map(|&idx| { - let start = idx * list_size; - let end = (idx + 1) * list_size; - (start, end) - }) - .collect() - } + let expanded_slices = if values.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + values + .bit_buffer() + .set_slices() + .map(|(start, end)| (start * list_size, end * list_size)) + .collect() + } else { + values + .bit_buffer() + .set_indices() + .map(|idx| { + let start = idx * list_size; + let end = (idx + 1) * list_size; + (start, end) + }) + .collect() }; - Mask::from_slices(expanded_len, expanded_slices) + expanded_slices } diff --git a/vortex-compute/src/filter/vector/list.rs b/vortex-compute/src/filter/vector/list.rs index 3b9615c2f05..f87d43b2cd7 100644 --- a/vortex-compute/src/filter/vector/list.rs +++ b/vortex-compute/src/filter/vector/list.rs @@ -18,6 +18,7 @@ impl Filter for &ListViewVector where for<'a> &'a PrimitiveVector: Filter, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = ListViewVector; @@ -37,6 +38,7 @@ impl Filter for &mut ListViewVectorMut where for<'a> &'a mut PrimitiveVectorMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/mod.rs b/vortex-compute/src/filter/vector/mod.rs index 040b2e06b33..cf4ba9687c3 100644 --- a/vortex-compute/src/filter/vector/mod.rs +++ b/vortex-compute/src/filter/vector/mod.rs @@ -74,6 +74,14 @@ impl Filter for &mut VectorMut { } } +impl Filter<[(usize, usize)]> for &Vector { + type Output = Vector; + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + match_each_vector!(self, |v| { v.filter(selection).into() }) + } +} + impl Filter> for &Vector { type Output = Vector; @@ -89,3 +97,11 @@ impl Filter> for &mut VectorMut { match_each_vector_mut!(self, |v| { v.filter(selection) }) } } + +impl Filter<[(usize, usize)]> for &mut VectorMut { + type Output = (); + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + match_each_vector_mut!(self, |v| { v.filter(selection) }) + } +} diff --git a/vortex-compute/src/filter/vector/null.rs b/vortex-compute/src/filter/vector/null.rs index a4df6f4c519..e2a60879fc8 100644 --- a/vortex-compute/src/filter/vector/null.rs +++ b/vortex-compute/src/filter/vector/null.rs @@ -16,6 +16,32 @@ impl Filter for &NullVector { } } +impl Filter<[(usize, usize)]> for &NullVector { + type Output = NullVector; + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + NullVector::new( + selection + .iter() + .map(|(start, end)| start + end) + .sum::(), + ) + } +} + +impl Filter<[(usize, usize)]> for &mut NullVectorMut { + type Output = (); + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + *self = NullVectorMut::new( + selection + .iter() + .map(|(start, end)| start + end) + .sum::(), + ) + } +} + impl Filter> for &NullVector { type Output = NullVector; diff --git a/vortex-compute/src/filter/vector/primitive.rs b/vortex-compute/src/filter/vector/primitive.rs index 82b50c1adbd..799a85c4d01 100644 --- a/vortex-compute/src/filter/vector/primitive.rs +++ b/vortex-compute/src/filter/vector/primitive.rs @@ -26,6 +26,7 @@ where for<'a> &'a PVector: Filter>, for<'a> &'a PVector: Filter>, for<'a> &'a PVector: Filter>, + M: ?Sized, { type Output = PrimitiveVector; @@ -47,6 +48,7 @@ where for<'a> &'a mut PVectorMut: Filter, for<'a> &'a mut PVectorMut: Filter, for<'a> &'a mut PVectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/pvector.rs b/vortex-compute/src/filter/vector/pvector.rs index fa8e9b64635..1210271f5cb 100644 --- a/vortex-compute/src/filter/vector/pvector.rs +++ b/vortex-compute/src/filter/vector/pvector.rs @@ -17,6 +17,7 @@ impl Filter for &PVector where for<'a> &'a Buffer: Filter>, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = PVector; @@ -34,6 +35,7 @@ impl Filter for &mut PVectorMut where for<'a> &'a mut BufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/struct_.rs b/vortex-compute/src/filter/vector/struct_.rs index 67a057bf691..b9584049748 100644 --- a/vortex-compute/src/filter/vector/struct_.rs +++ b/vortex-compute/src/filter/vector/struct_.rs @@ -18,6 +18,7 @@ impl Filter for &StructVector where for<'a> &'a Mask: Filter, for<'a> &'a Vector: Filter, + M: ?Sized, { type Output = StructVector; @@ -40,6 +41,7 @@ impl Filter for &mut StructVectorMut where for<'a> &'a mut MaskMut: Filter, for<'a> &'a mut VectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index efce6f17dbd..d27345f00b6 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -1,7 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use crate::AllOr; +use vortex_error::VortexExpect; + use crate::Mask; impl Mask { @@ -29,171 +30,167 @@ impl Mask { pub fn intersect_by_rank(&self, mask: &Mask) -> Mask { assert_eq!(self.true_count(), mask.len()); - match (self.indices(), mask.indices()) { - (AllOr::All, _) => mask.clone(), - (_, AllOr::All) => self.clone(), - (AllOr::None, _) | (_, AllOr::None) => Self::new_false(self.len()), - - (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => { - Self::from_indices( - self.len(), - mask_indices - .iter() - .map(|idx| - // This is verified as safe because we know that the indices are less than the - // mask.len() and we known mask.len() <= self.len(), - // implied by `self.true_count() == mask.len()`. - unsafe{*self_indices.get_unchecked(*idx)}) - .collect(), - ) - } + if self.all_true() { + mask.clone() + } else if mask.all_true() { + self.clone() + } else if self.all_false() || mask.all_false() { + Self::new_false(self.len()) + } else { + let mask_values = mask.values().vortex_expect("msg"); + Self::from_iter( + mask_values + .bit_buffer() + .set_indices() + .map(|idx| self.value(idx)), + ) } } } -#[cfg(test)] -mod test { - use rstest::rstest; - use vortex_buffer::BitBuffer; - - use crate::Mask; - - #[test] - fn mask_bitand_all_as_bit_and() { - let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true, true, true])); - let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, true, false, true, true])); - assert_eq!( - this.intersect_by_rank(&mask), - Mask::from_indices(5, vec![1, 3, 4]) - ); - } - - #[test] - fn mask_bitand_all_true() { - let this = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, true, true, true])); - let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true])); - assert_eq!( - this.intersect_by_rank(&mask), - Mask::from_indices(5, vec![2, 3, 4]) - ); - } - - #[test] - fn mask_bitand_true() { - let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); - let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, true])); - assert_eq!( - this.intersect_by_rank(&mask), - Mask::from_indices(5, vec![0, 4]) - ); - } - - #[test] - fn mask_bitand_false() { - let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); - let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, false])); - assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![])); - } - - #[test] - fn mask_intersect_by_rank_all_false() { - let this = Mask::AllFalse(10); - let mask = Mask::AllFalse(0); - assert_eq!(this.intersect_by_rank(&mask), Mask::AllFalse(10)); - } - - #[rstest] - #[case::all_true_with_all_true( - Mask::new_true(5), - Mask::new_true(5), - vec![0, 1, 2, 3, 4] - )] - #[case::all_true_with_all_false( - Mask::new_true(5), - Mask::new_false(5), - vec![] - )] - #[case::all_false_with_any( - Mask::new_false(10), - Mask::new_true(0), - vec![] - )] - #[case::indices_with_all_true( - Mask::from_indices(10, vec![2, 5, 7, 9]), - Mask::new_true(4), - vec![2, 5, 7, 9] - )] - #[case::indices_with_all_false( - Mask::from_indices(10, vec![2, 5, 7, 9]), - Mask::new_false(4), - vec![] - )] - fn test_intersect_by_rank_special_cases( - #[case] base_mask: Mask, - #[case] rank_mask: Mask, - #[case] expected_indices: Vec, - ) { - let result = base_mask.intersect_by_rank(&rank_mask); - - match result.indices() { - crate::AllOr::All => assert_eq!(expected_indices.len(), result.len()), - crate::AllOr::None => assert!(expected_indices.is_empty()), - crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), - } - } - - #[test] - fn test_intersect_by_rank_example() { - // Example from the documentation - let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]); - let m2 = Mask::from_iter([false, false, true, false, true]); - let result = m1.intersect_by_rank(&m2); - let expected = Mask::from_iter([false, false, false, false, true, false, false, true]); - assert_eq!(result, expected); - } - - #[test] - #[should_panic] - fn test_intersect_by_rank_wrong_length() { - let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values - let m2 = Mask::new_true(5); // 5 true values - doesn't match - m1.intersect_by_rank(&m2); - } - - #[rstest] - #[case::single_element( - vec![3], - vec![true], - vec![3] - )] - #[case::single_element_masked( - vec![3], - vec![false], - vec![] - )] - #[case::alternating( - vec![0, 2, 4, 6, 8], - vec![true, false, true, false, true], - vec![0, 4, 8] - )] - #[case::consecutive( - vec![5, 6, 7, 8, 9], - vec![false, true, true, true, false], - vec![6, 7, 8] - )] - fn test_intersect_by_rank_patterns( - #[case] base_indices: Vec, - #[case] rank_pattern: Vec, - #[case] expected_indices: Vec, - ) { - let base = Mask::from_indices(10, base_indices); - let rank = Mask::from_iter(rank_pattern); - let result = base.intersect_by_rank(&rank); - - match result.indices() { - crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), - crate::AllOr::None => assert!(expected_indices.is_empty()), - _ => panic!("Unexpected result"), - } - } -} +// #[cfg(test)] +// mod test { +// use rstest::rstest; +// use vortex_buffer::BitBuffer; + +// use crate::Mask; + +// #[test] +// fn mask_bitand_all_as_bit_and() { +// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true, true, true])); +// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, true, false, true, true])); +// assert_eq!( +// this.intersect_by_rank(&mask), +// Mask::from_indices(5, vec![1, 3, 4]) +// ); +// } + +// #[test] +// fn mask_bitand_all_true() { +// let this = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, true, true, true])); +// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true])); +// assert_eq!( +// this.intersect_by_rank(&mask), +// Mask::from_indices(5, vec![2, 3, 4]) +// ); +// } + +// #[test] +// fn mask_bitand_true() { +// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); +// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, true])); +// assert_eq!( +// this.intersect_by_rank(&mask), +// Mask::from_indices(5, vec![0, 4]) +// ); +// } + +// #[test] +// fn mask_bitand_false() { +// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); +// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, false])); +// assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![])); +// } + +// #[test] +// fn mask_intersect_by_rank_all_false() { +// let this = Mask::AllFalse(10); +// let mask = Mask::AllFalse(0); +// assert_eq!(this.intersect_by_rank(&mask), Mask::AllFalse(10)); +// } + +// #[rstest] +// #[case::all_true_with_all_true( +// Mask::new_true(5), +// Mask::new_true(5), +// vec![0, 1, 2, 3, 4] +// )] +// #[case::all_true_with_all_false( +// Mask::new_true(5), +// Mask::new_false(5), +// vec![] +// )] +// #[case::all_false_with_any( +// Mask::new_false(10), +// Mask::new_true(0), +// vec![] +// )] +// #[case::indices_with_all_true( +// Mask::from_indices(10, vec![2, 5, 7, 9]), +// Mask::new_true(4), +// vec![2, 5, 7, 9] +// )] +// #[case::indices_with_all_false( +// Mask::from_indices(10, vec![2, 5, 7, 9]), +// Mask::new_false(4), +// vec![] +// )] +// fn test_intersect_by_rank_special_cases( +// #[case] base_mask: Mask, +// #[case] rank_mask: Mask, +// #[case] expected_indices: Vec, +// ) { +// let result = base_mask.intersect_by_rank(&rank_mask); + +// match result.indices() { +// crate::AllOr::All => assert_eq!(expected_indices.len(), result.len()), +// crate::AllOr::None => assert!(expected_indices.is_empty()), +// crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), +// } +// } + +// #[test] +// fn test_intersect_by_rank_example() { +// // Example from the documentation +// let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]); +// let m2 = Mask::from_iter([false, false, true, false, true]); +// let result = m1.intersect_by_rank(&m2); +// let expected = Mask::from_iter([false, false, false, false, true, false, false, true]); +// assert_eq!(result, expected); +// } + +// #[test] +// #[should_panic] +// fn test_intersect_by_rank_wrong_length() { +// let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values +// let m2 = Mask::new_true(5); // 5 true values - doesn't match +// m1.intersect_by_rank(&m2); +// } + +// #[rstest] +// #[case::single_element( +// vec![3], +// vec![true], +// vec![3] +// )] +// #[case::single_element_masked( +// vec![3], +// vec![false], +// vec![] +// )] +// #[case::alternating( +// vec![0, 2, 4, 6, 8], +// vec![true, false, true, false, true], +// vec![0, 4, 8] +// )] +// #[case::consecutive( +// vec![5, 6, 7, 8, 9], +// vec![false, true, true, true, false], +// vec![6, 7, 8] +// )] +// fn test_intersect_by_rank_patterns( +// #[case] base_indices: Vec, +// #[case] rank_pattern: Vec, +// #[case] expected_indices: Vec, +// ) { +// let base = Mask::from_indices(10, base_indices); +// let rank = Mask::from_iter(rank_pattern); +// let result = base.intersect_by_rank(&rank); + +// match result.indices() { +// crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), +// crate::AllOr::None => assert!(expected_indices.is_empty()), +// _ => panic!("Unexpected result"), +// } +// } +// } diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 1f87eb70f46..a880a426890 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -21,13 +21,13 @@ use std::fmt::Formatter; use std::ops::Bound; use std::ops::RangeBounds; use std::sync::Arc; -use std::sync::OnceLock; use itertools::Itertools; pub use mask_mut::*; use vortex_buffer::BitBuffer; use vortex_buffer::BitBufferMut; use vortex_buffer::set_bit_unchecked; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -130,10 +130,10 @@ pub struct MaskValues { // We cached the indices and slices representations, since it can be faster than iterating // the bit-mask over and over again. - #[cfg_attr(feature = "serde", serde(skip))] - indices: OnceLock>, - #[cfg_attr(feature = "serde", serde(skip))] - slices: OnceLock>, + // #[cfg_attr(feature = "serde", serde(skip))] + // indices: OnceLock>, + // #[cfg_attr(feature = "serde", serde(skip))] + // slices: OnceLock>, // Pre-computed values. true_count: usize, @@ -177,8 +177,8 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer, - indices: Default::default(), - slices: Default::default(), + // indices: Default::default(), + // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -208,8 +208,8 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: OnceLock::from(indices), - slices: Default::default(), + // indices: OnceLock::from(indices), + // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -237,8 +237,8 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: Default::default(), - slices: Default::default(), + // indices: Default::default(), + // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -271,8 +271,8 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: Default::default(), - slices: OnceLock::from(slices), + // indices: Default::default(), + // slices: OnceLock::from(slices), true_count, density: true_count as f64 / len as f64, })) @@ -411,15 +411,7 @@ impl Mask { match &self { Self::AllTrue(len) => (*len > 0).then_some(0), Self::AllFalse(_) => None, - Self::Values(values) => { - if let Some(indices) = values.indices.get() { - return indices.first().copied(); - } - if let Some(slices) = values.slices.get() { - return slices.first().map(|(start, _)| *start); - } - values.buffer.set_indices().next() - } + Self::Values(values) => values.buffer.set_indices().next(), } } @@ -435,7 +427,12 @@ impl Mask { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), // TODO(joe): optimize this function - Self::Values(values) => values.indices()[n], + Self::Values(values) => values + .bit_buffer() + .set_indices() + .take(n + 1) + .last() + .vortex_expect("validated within range"), } } @@ -499,34 +496,34 @@ impl Mask { } /// Return the indices representation of the mask. - #[inline] - pub fn indices(&self) -> AllOr<&[usize]> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.indices()), - } - } + // #[inline] + // pub fn indices(&self) -> AllOr<&[usize]> { + // match &self { + // Self::AllTrue(_) => AllOr::All, + // Self::AllFalse(_) => AllOr::None, + // Self::Values(values) => AllOr::Some(values.indices()), + // } + // } /// Return the slices representation of the mask. - #[inline] - pub fn slices(&self) -> AllOr<&[(usize, usize)]> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.slices()), - } - } + // #[inline] + // pub fn slices(&self) -> AllOr<&[(usize, usize)]> { + // match &self { + // Self::AllTrue(_) => AllOr::All, + // Self::AllFalse(_) => AllOr::None, + // Self::Values(values) => AllOr::Some(values.slices()), + // } + // } /// Return an iterator over either indices or slices of the mask based on a density threshold. - #[inline] - pub fn threshold_iter(&self, threshold: f64) -> AllOr> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)), - } - } + // #[inline] + // pub fn threshold_iter(&self, threshold: f64) -> AllOr> { + // match &self { + // Self::AllTrue(_) => AllOr::All, + // Self::AllFalse(_) => AllOr::None, + // Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)), + // } + // } /// Return [`MaskValues`] if the mask is not all true or all false. #[inline] @@ -673,54 +670,54 @@ impl MaskValues { self.buffer.value(index) } - /// Constructs an indices vector from one of the other representations. - pub fn indices(&self) -> &[usize] { - self.indices.get_or_init(|| { - if self.true_count == 0 { - return vec![]; - } - - if self.true_count == self.len() { - return (0..self.len()).collect(); - } - - if let Some(slices) = self.slices.get() { - let mut indices = Vec::with_capacity(self.true_count); - indices.extend(slices.iter().flat_map(|(start, end)| *start..*end)); - debug_assert!(indices.is_sorted()); - assert_eq!(indices.len(), self.true_count); - return indices; - } - - let mut indices = Vec::with_capacity(self.true_count); - indices.extend(self.buffer.set_indices()); - debug_assert!(indices.is_sorted()); - assert_eq!(indices.len(), self.true_count); - indices - }) - } + // /// Constructs an indices vector from one of the other representations. + // pub fn indices(&self) -> &[usize] { + // self.indices.get_or_init(|| { + // if self.true_count == 0 { + // return vec![]; + // } + + // if self.true_count == self.len() { + // return (0..self.len()).collect(); + // } + + // if let Some(slices) = self.slices.get() { + // let mut indices = Vec::with_capacity(self.true_count); + // indices.extend(slices.iter().flat_map(|(start, end)| *start..*end)); + // debug_assert!(indices.is_sorted()); + // assert_eq!(indices.len(), self.true_count); + // return indices; + // } + + // let mut indices = Vec::with_capacity(self.true_count); + // indices.extend(self.buffer.set_indices()); + // debug_assert!(indices.is_sorted()); + // assert_eq!(indices.len(), self.true_count); + // indices + // }) + // } /// Constructs a slices vector from one of the other representations. - #[inline] - pub fn slices(&self) -> &[(usize, usize)] { - self.slices.get_or_init(|| { - if self.true_count == self.len() { - return vec![(0, self.len())]; - } + // #[inline] + // pub fn slices(&self) -> &[(usize, usize)] { + // self.slices.get_or_init(|| { + // if self.true_count == self.len() { + // return vec![(0, self.len())]; + // } - self.buffer.set_slices().collect() - }) - } + // self.buffer.set_slices().collect() + // }) + // } /// Return an iterator over either indices or slices of the mask based on a density threshold. - #[inline] - pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> { - if self.density >= threshold { - MaskIter::Slices(self.slices()) - } else { - MaskIter::Indices(self.indices()) - } - } + // #[inline] + // pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> { + // if self.density >= threshold { + // MaskIter::Slices(self.slices()) + // } else { + // MaskIter::Indices(self.indices()) + // } + // } /// Extracts the internal [`BitBuffer`]. pub(crate) fn into_buffer(self) -> BitBuffer { diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs index a13827abbc3..9b757e37864 100644 --- a/vortex-mask/src/tests.rs +++ b/vortex-mask/src/tests.rs @@ -9,7 +9,6 @@ use vortex_buffer::BitBuffer; use crate::AllOr; use crate::Mask; -use crate::MaskIter; // Basic mask creation and properties tests #[test] @@ -18,8 +17,8 @@ fn mask_all_true() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 5); assert_eq!(mask.density(), 1.0); - assert_eq!(mask.indices(), AllOr::All); - assert_eq!(mask.slices(), AllOr::All); + // assert_eq!(mask.indices(), AllOr::All); + // assert_eq!(mask.slices(), AllOr::All); assert_eq!(mask.bit_buffer(), AllOr::All,); } @@ -29,8 +28,8 @@ fn mask_all_false() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 0); assert_eq!(mask.density(), 0.0); - assert_eq!(mask.indices(), AllOr::None); - assert_eq!(mask.slices(), AllOr::None); + // assert_eq!(mask.indices(), AllOr::None); + // assert_eq!(mask.slices(), AllOr::None); assert_eq!(mask.bit_buffer(), AllOr::None,); } @@ -46,8 +45,8 @@ fn mask_from() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 3); assert_eq!(mask.density(), 0.6); - assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..])); - assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..])); + // assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..])); + // assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..])); assert_eq!( mask.bit_buffer(), AllOr::Some(&BitBuffer::from_iter([true, false, true, true, false])) @@ -251,27 +250,27 @@ fn test_mask_values() { assert!(!values.value(1)); } -#[test] -fn test_mask_values_threshold_iter() { - let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); - let values = mask.values().unwrap(); +// #[test] +// fn test_mask_values_threshold_iter() { +// let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); +// let values = mask.values().unwrap(); - // With low threshold, should prefer indices - match values.threshold_iter(0.7) { - MaskIter::Indices(indices) => { - assert_eq!(indices, &[0, 2, 3]); - } - _ => panic!("Expected indices iterator"), - } +// // With low threshold, should prefer indices +// match values.threshold_iter(0.7) { +// MaskIter::Indices(indices) => { +// assert_eq!(indices, &[0, 2, 3]); +// } +// _ => panic!("Expected indices iterator"), +// } - // With high threshold, should prefer slices - match values.threshold_iter(0.5) { - MaskIter::Slices(slices) => { - assert_eq!(slices, &[(0, 1), (2, 4)]); - } - _ => panic!("Expected slices iterator"), - } -} +// // With high threshold, should prefer slices +// match values.threshold_iter(0.5) { +// MaskIter::Slices(slices) => { +// assert_eq!(slices, &[(0, 1), (2, 4)]); +// } +// _ => panic!("Expected slices iterator"), +// } +// } #[test] fn test_mask_values_is_empty() { @@ -476,65 +475,65 @@ fn test_mask_from_slices_overlapping() { Mask::from_slices(5, vec![(0, 3), (2, 4)]); // Overlapping ranges } -// Threshold iterator tests -#[test] -fn test_mask_threshold_iter() { - let all_true = Mask::new_true(5); - assert!(matches!(all_true.threshold_iter(0.5), AllOr::All)); +// // Threshold iterator tests +// #[test] +// fn test_mask_threshold_iter() { +// let all_true = Mask::new_true(5); +// assert!(matches!(all_true.threshold_iter(0.5), AllOr::All)); - let all_false = Mask::new_false(5); - assert!(matches!(all_false.threshold_iter(0.5), AllOr::None)); +// let all_false = Mask::new_false(5); +// assert!(matches!(all_false.threshold_iter(0.5), AllOr::None)); - let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); - if let AllOr::Some(MaskIter::Indices(indices)) = mask.threshold_iter(0.7) { - assert_eq!(indices, &[0, 2, 3]); - } else { - panic!("Expected indices iterator"); - } -} +// let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); +// if let AllOr::Some(MaskIter::Indices(indices)) = mask.threshold_iter(0.7) { +// assert_eq!(indices, &[0, 2, 3]); +// } else { +// panic!("Expected indices iterator"); +// } +// } // Caching tests -#[test] -fn test_mask_indices_caching() { - // Test that indices are properly cached - let mask = Mask::from_slices(10, vec![(0, 3), (5, 7), (9, 10)]); - - // First call should compute indices - let indices1 = mask.indices(); - // Second call should return cached value - let indices2 = mask.indices(); - - match (indices1, indices2) { - (AllOr::Some(i1), AllOr::Some(i2)) => { - assert_eq!(i1, i2); - assert_eq!(i1, &[0, 1, 2, 5, 6, 9]); - // Verify they're the same reference (cached) - assert!(std::ptr::eq(i1, i2)); - } - _ => panic!("Expected Some variant"), - } -} - -#[test] -fn test_mask_slices_caching() { - // Test that slices are properly cached - let mask = Mask::from_indices(10, vec![0, 1, 2, 5, 6, 9]); - - // First call should compute slices - let slices1 = mask.slices(); - // Second call should return cached value - let slices2 = mask.slices(); - - match (slices1, slices2) { - (AllOr::Some(s1), AllOr::Some(s2)) => { - assert_eq!(s1, s2); - assert_eq!(s1, &[(0, 3), (5, 7), (9, 10)]); - // Verify they're the same reference (cached) - assert!(std::ptr::eq(s1, s2)); - } - _ => panic!("Expected Some variant"), - } -} +// #[test] +// fn test_mask_indices_caching() { +// // Test that indices are properly cached +// let mask = Mask::from_slices(10, vec![(0, 3), (5, 7), (9, 10)]); + +// // First call should compute indices +// let indices1 = mask.indices(); +// // Second call should return cached value +// let indices2 = mask.indices(); + +// match (indices1, indices2) { +// (AllOr::Some(i1), AllOr::Some(i2)) => { +// assert_eq!(i1, i2); +// assert_eq!(i1, &[0, 1, 2, 5, 6, 9]); +// // Verify they're the same reference (cached) +// assert!(std::ptr::eq(i1, i2)); +// } +// _ => panic!("Expected Some variant"), +// } +// } + +// #[test] +// fn test_mask_slices_caching() { +// // Test that slices are properly cached +// let mask = Mask::from_indices(10, vec![0, 1, 2, 5, 6, 9]); + +// // First call should compute slices +// let slices1 = mask.slices(); +// // Second call should return cached value +// let slices2 = mask.slices(); + +// match (slices1, slices2) { +// (AllOr::Some(s1), AllOr::Some(s2)) => { +// assert_eq!(s1, s2); +// assert_eq!(s1, &[(0, 3), (5, 7), (9, 10)]); +// // Verify they're the same reference (cached) +// assert!(std::ptr::eq(s1, s2)); +// } +// _ => panic!("Expected Some variant"), +// } +// } // AllOr tests #[test] @@ -610,52 +609,52 @@ fn test_mask_properties( assert!((mask.density() - expected_density).abs() < 1e-10); } -#[rstest] -#[case::indices(vec![0, 2, 4], vec![(0, 1), (2, 3), (4, 5)])] -#[case::consecutive(vec![0, 1, 2], vec![(0, 3)])] -#[case::gap(vec![0, 1, 4, 5], vec![(0, 2), (4, 6)])] -#[case::single(vec![3], vec![(3, 4)])] -fn test_indices_to_slices_conversion( - #[case] indices: Vec, - #[case] expected_slices: Vec<(usize, usize)>, -) { - let mask = Mask::from_indices(10, indices.clone()); - - // Check indices - if let AllOr::Some(actual_indices) = mask.indices() { - assert_eq!(actual_indices, &indices[..]); - } else { - panic!("Expected Some variant for indices"); - } - - // Check slices - if let AllOr::Some(actual_slices) = mask.slices() { - assert_eq!(actual_slices, &expected_slices[..]); - } else { - panic!("Expected Some variant for slices"); - } -} - -#[rstest] -#[case::empty_intersection(vec![0, 2, 4], vec![1, 3, 5], vec![])] -#[case::full_intersection(vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5])] -#[case::partial_intersection(vec![0, 1, 2, 3], vec![2, 3, 4, 5], vec![2, 3])] -#[case::subset_left(vec![1, 2], vec![0, 1, 2, 3], vec![1, 2])] -#[case::subset_right(vec![0, 1, 2, 3], vec![1, 2], vec![1, 2])] -fn test_intersection_indices( - #[case] left: Vec, - #[case] right: Vec, - #[case] expected: Vec, -) { - let mask = Mask::from_intersection_indices(10, left.into_iter(), right.into_iter()); - - match mask.indices() { - AllOr::Some(indices) if expected.is_empty() => assert!(indices.is_empty()), - AllOr::Some(indices) => assert_eq!(indices, &expected[..]), - AllOr::None if expected.is_empty() => {} - AllOr::None | AllOr::All => panic!("Unexpected result for intersection"), - } -} +// #[rstest] +// #[case::indices(vec![0, 2, 4], vec![(0, 1), (2, 3), (4, 5)])] +// #[case::consecutive(vec![0, 1, 2], vec![(0, 3)])] +// #[case::gap(vec![0, 1, 4, 5], vec![(0, 2), (4, 6)])] +// #[case::single(vec![3], vec![(3, 4)])] +// fn test_indices_to_slices_conversion( +// #[case] indices: Vec, +// #[case] expected_slices: Vec<(usize, usize)>, +// ) { +// let mask = Mask::from_indices(10, indices.clone()); + +// // Check indices +// if let AllOr::Some(actual_indices) = mask.indices() { +// assert_eq!(actual_indices, &indices[..]); +// } else { +// panic!("Expected Some variant for indices"); +// } + +// // Check slices +// if let AllOr::Some(actual_slices) = mask.slices() { +// assert_eq!(actual_slices, &expected_slices[..]); +// } else { +// panic!("Expected Some variant for slices"); +// } +// } + +// #[rstest] +// #[case::empty_intersection(vec![0, 2, 4], vec![1, 3, 5], vec![])] +// #[case::full_intersection(vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5])] +// #[case::partial_intersection(vec![0, 1, 2, 3], vec![2, 3, 4, 5], vec![2, 3])] +// #[case::subset_left(vec![1, 2], vec![0, 1, 2, 3], vec![1, 2])] +// #[case::subset_right(vec![0, 1, 2, 3], vec![1, 2], vec![1, 2])] +// fn test_intersection_indices( +// #[case] left: Vec, +// #[case] right: Vec, +// #[case] expected: Vec, +// ) { +// let mask = Mask::from_intersection_indices(10, left.into_iter(), right.into_iter()); + +// match mask.indices() { +// AllOr::Some(indices) if expected.is_empty() => assert!(indices.is_empty()), +// AllOr::Some(indices) => assert_eq!(indices, &expected[..]), +// AllOr::None if expected.is_empty() => {} +// AllOr::None | AllOr::All => panic!("Unexpected result for intersection"), +// } +// } // Concat operation tests #[test] From af5e2f4e8b1b01c63526f3172ee5bbcd33b845c4 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 18:53:26 +0000 Subject: [PATCH 04/11] fix tests Signed-off-by: Adam Gutglick --- .../src/arrays/constant/compute/take.rs | 16 ++++---- vortex-array/src/arrays/dict/array.rs | 38 ++++++++++++------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 209fc5feb5c..b4bac8979d7 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -62,7 +62,6 @@ mod tests { use rstest::rstest; use vortex_buffer::buffer; use vortex_dtype::Nullability; - use vortex_mask::AllOr; use vortex_scalar::Scalar; use crate::Array; @@ -86,7 +85,6 @@ mod tests { .into_array(), ) .unwrap(); - let valid_indices: &[usize] = &[1usize]; assert_eq!( &array.dtype().with_nullability(Nullability::Nullable), taken.dtype() @@ -98,10 +96,14 @@ mod tests { Validity::from_iter([false, true, false]) ) ); - assert_eq!( - taken.validity_mask().unwrap().indices(), - AllOr::Some(valid_indices) - ); + let mask = taken.validity_mask().unwrap(); + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); + assert_eq!(indices, [1]); } #[test] @@ -118,7 +120,7 @@ mod tests { taken.to_primitive(), PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid) ); - assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All); + assert!(taken.validity_mask().unwrap().all_true()); } #[rstest] diff --git a/vortex-array/src/arrays/dict/array.rs b/vortex-array/src/arrays/dict/array.rs index de616417ce1..3b576f93bd2 100644 --- a/vortex-array/src/arrays/dict/array.rs +++ b/vortex-array/src/arrays/dict/array.rs @@ -243,8 +243,6 @@ mod test { use vortex_dtype::UnsignedPType; use vortex_error::VortexExpect; use vortex_error::VortexResult; - use vortex_error::vortex_panic; - use vortex_mask::AllOr; use crate::Array; use crate::ArrayRef; @@ -271,9 +269,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0, 2, 4]); } @@ -289,9 +290,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0]); } @@ -311,9 +315,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [2, 4]); } @@ -329,9 +336,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0, 2, 4]); } From f8780c7867c9f85f04edd39669660a30780f3158 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 20:30:23 +0000 Subject: [PATCH 05/11] some stuff Signed-off-by: Adam Gutglick --- .../src/bitpacking/compute/filter.rs | 6 +++- vortex-array/src/array/mod.rs | 2 ++ vortex-array/src/arrays/filter/vtable.rs | 1 + vortex-array/src/mask_future.rs | 1 + vortex-cuda/src/kernel/encodings/zstd.rs | 6 ++-- vortex-file/src/tests.rs | 5 +++ vortex-scan/src/selection.rs | 36 +++++++++++-------- 7 files changed, 39 insertions(+), 18 deletions(-) diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index 00779310eaa..bd5e401ee99 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -97,7 +97,11 @@ fn filter_primitive_without_patches( array: &BitPackedArray, selection: &Arc, ) -> VortexResult<(Buffer, Validity)> { - let values = filter_with_indices(array, selection.bit_buffer().set_indices(), selection.len()); + let values = filter_with_indices( + array, + selection.bit_buffer().set_indices(), + selection.bit_buffer().true_count(), + ); let validity = array.validity()?.filter(&Mask::Values(selection.clone()))?; Ok((values.freeze(), validity)) diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 981cb99967b..e05b0f937c3 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -466,12 +466,14 @@ impl Array for ArrayAdapter { } fn filter(&self, mask: Mask) -> VortexResult { + dbg!(self.encoding_id()); FilterArray::try_new(self.to_array(), mask)? .into_array() .optimize() } fn take(&self, indices: ArrayRef) -> VortexResult { + dbg!(self.encoding_id()); DictArray::try_new(indices, self.to_array())? .into_array() .optimize() diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index d2876e9e197..5060cdf97bc 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -109,6 +109,7 @@ impl VTable for FilterVTable { } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + dbg!(array.encoding_id()); if let Some(canonical) = execute_filter_fast_paths(array, ctx)? { return Ok(canonical); } diff --git a/vortex-array/src/mask_future.rs b/vortex-array/src/mask_future.rs index 7a3089a2169..136db9bef7a 100644 --- a/vortex-array/src/mask_future.rs +++ b/vortex-array/src/mask_future.rs @@ -28,6 +28,7 @@ impl MaskFuture { where F: Future> + Send + 'static, { + dbg!(len); Self { inner: fut .inspect(move |r| { diff --git a/vortex-cuda/src/kernel/encodings/zstd.rs b/vortex-cuda/src/kernel/encodings/zstd.rs index d4b68937047..06eecf79b1a 100644 --- a/vortex-cuda/src/kernel/encodings/zstd.rs +++ b/vortex-cuda/src/kernel/encodings/zstd.rs @@ -25,7 +25,7 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; -use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_nvcomp::sys::nvcompStatus_t; use vortex_nvcomp::zstd as nvcomp_zstd; use vortex_zstd::ZstdArray; @@ -282,8 +282,8 @@ async fn decode_zstd(array: ZstdArray, ctx: &mut CudaExecutionCtx) -> VortexResu let sliced_validity = validity.slice(slice_start..slice_stop)?; - match sliced_validity.to_mask(slice_stop - slice_start).indices() { - AllOr::All => { + match sliced_validity.to_mask(slice_stop - slice_start) { + Mask::AllTrue(_) => { let all_views = vortex_zstd::reconstruct_views(&host_buffer); let sliced_views = all_views.slice(slice_value_idx_start..slice_value_idx_stop); diff --git a/vortex-file/src/tests.rs b/vortex-file/src/tests.rs index 5a1eac809ba..58b12f4352d 100644 --- a/vortex-file/src/tests.rs +++ b/vortex-file/src/tests.rs @@ -835,6 +835,9 @@ async fn test_with_indices_and_with_row_filter_simple() { assert_eq!(actual_kept_array.len(), 0); + eprintln!("{}", file.footer().layout().display_tree()); + eprintln!("Finished 1"); + // test a few indices let kept_indices = [0u64, 3, 99, 100, 101, 399, 400, 401, 499]; @@ -850,6 +853,8 @@ async fn test_with_indices_and_with_row_filter_simple() { .unwrap() .to_struct(); + eprintln!("Finished 2"); + let actual_kept_numbers_array = actual_kept_array.unmasked_fields()[0].to_primitive(); let expected_kept_numbers: Buffer = kept_indices diff --git a/vortex-scan/src/selection.rs b/vortex-scan/src/selection.rs index d2fb46d4bb0..856369ca3fb 100644 --- a/vortex-scan/src/selection.rs +++ b/vortex-scan/src/selection.rs @@ -166,13 +166,21 @@ fn indices_range(range: &Range, row_indices: &[u64]) -> Option mod tests { use vortex_buffer::Buffer; + fn collect_indices(mask: &vortex_mask::Mask) -> Vec { + mask.values() + .unwrap() + .bit_buffer() + .set_indices() + .collect() + } + #[test] fn test_row_mask_all() { let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7])); let range = 1..8; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -181,7 +189,7 @@ mod tests { let range = 3..6; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2]); } #[test] @@ -190,7 +198,7 @@ mod tests { let range = 3..5; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]); + assert_eq!(collect_indices(row_mask.mask()), &[0]); } #[test] @@ -217,7 +225,7 @@ mod tests { let range = 0..5; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]); + assert_eq!(collect_indices(row_mask.mask()), &[0]); } #[cfg(feature = "roaring")] @@ -238,7 +246,7 @@ mod tests { let range = 1..8; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -253,7 +261,7 @@ mod tests { let range = 3..6; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2]); } #[test] @@ -299,7 +307,7 @@ mod tests { let row_mask = selection.row_mask(&range); // Should exclude indices 1, 3, 5, so we get 0, 2, 4, 6 - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -344,7 +352,7 @@ mod tests { let row_mask = selection.row_mask(&range); // Should exclude 5, 6, 7 (mapped to 0, 1, 2), keep 8, 9 (mapped to 3, 4) - assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]); + assert_eq!(collect_indices(row_mask.mask()), &[3, 4]); } #[test] @@ -377,7 +385,7 @@ mod tests { let range = 0..100; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 99]); } #[test] @@ -393,7 +401,7 @@ mod tests { // Should include 15-19 (mapped to 0-4) and 30-34 (mapped to 15-19) let expected: Vec = (0..5).chain(15..20).collect(); - assert_eq!(row_mask.mask().values().unwrap().indices(), &expected); + assert_eq!(collect_indices(row_mask.mask()), &expected); } #[test] @@ -443,8 +451,8 @@ mod tests { let roaring_mask = roaring_selection.row_mask(&range); assert_eq!( - buffer_mask.mask().values().unwrap().indices(), - roaring_mask.mask().values().unwrap().indices() + collect_indices(buffer_mask.mask()), + collect_indices(roaring_mask.mask()) ); } @@ -467,8 +475,8 @@ mod tests { let roaring_mask = roaring_selection.row_mask(&range); assert_eq!( - buffer_mask.mask().values().unwrap().indices(), - roaring_mask.mask().values().unwrap().indices() + collect_indices(buffer_mask.mask()), + collect_indices(roaring_mask.mask()) ); } } From f96c750b6b80e7276e551ddd7a16ab51c53f563c Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 12 Feb 2026 20:52:34 +0000 Subject: [PATCH 06/11] Fix bug and cleanup Signed-off-by: Adam Gutglick --- .../src/bitpacking/compute/filter.rs | 6 +- vortex-array/src/array/mod.rs | 2 - vortex-array/src/mask_future.rs | 1 - vortex-file/src/tests.rs | 5 - vortex-mask/src/intersect_by_rank.rs | 322 ++++++++++-------- 5 files changed, 178 insertions(+), 158 deletions(-) diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index bd5e401ee99..567bce2e78b 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -97,10 +97,12 @@ fn filter_primitive_without_patches( array: &BitPackedArray, selection: &Arc, ) -> VortexResult<(Buffer, Validity)> { + let selection_buffer = selection.bit_buffer(); + let values = filter_with_indices( array, - selection.bit_buffer().set_indices(), - selection.bit_buffer().true_count(), + selection_buffer.set_indices(), + selection_buffer.true_count(), ); let validity = array.validity()?.filter(&Mask::Values(selection.clone()))?; diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index e05b0f937c3..981cb99967b 100644 --- a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -466,14 +466,12 @@ impl Array for ArrayAdapter { } fn filter(&self, mask: Mask) -> VortexResult { - dbg!(self.encoding_id()); FilterArray::try_new(self.to_array(), mask)? .into_array() .optimize() } fn take(&self, indices: ArrayRef) -> VortexResult { - dbg!(self.encoding_id()); DictArray::try_new(indices, self.to_array())? .into_array() .optimize() diff --git a/vortex-array/src/mask_future.rs b/vortex-array/src/mask_future.rs index 136db9bef7a..7a3089a2169 100644 --- a/vortex-array/src/mask_future.rs +++ b/vortex-array/src/mask_future.rs @@ -28,7 +28,6 @@ impl MaskFuture { where F: Future> + Send + 'static, { - dbg!(len); Self { inner: fut .inspect(move |r| { diff --git a/vortex-file/src/tests.rs b/vortex-file/src/tests.rs index 58b12f4352d..5a1eac809ba 100644 --- a/vortex-file/src/tests.rs +++ b/vortex-file/src/tests.rs @@ -835,9 +835,6 @@ async fn test_with_indices_and_with_row_filter_simple() { assert_eq!(actual_kept_array.len(), 0); - eprintln!("{}", file.footer().layout().display_tree()); - eprintln!("Finished 1"); - // test a few indices let kept_indices = [0u64, 3, 99, 100, 101, 399, 400, 401, 499]; @@ -853,8 +850,6 @@ async fn test_with_indices_and_with_row_filter_simple() { .unwrap() .to_struct(); - eprintln!("Finished 2"); - let actual_kept_numbers_array = actual_kept_array.unmasked_fields()[0].to_primitive(); let expected_kept_numbers: Buffer = kept_indices diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index d27345f00b6..d21b7f19154 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -38,159 +38,185 @@ impl Mask { Self::new_false(self.len()) } else { let mask_values = mask.values().vortex_expect("msg"); - Self::from_iter( + let self_indices = self + .values() + .expect("") + .bit_buffer() + .set_indices() + .collect::>(); + + Self::from_indices( + self.len(), mask_values .bit_buffer() .set_indices() - .map(|idx| self.value(idx)), + .map(|idx| unsafe { *self_indices.get_unchecked(idx) }) + .collect(), ) } } } -// #[cfg(test)] -// mod test { -// use rstest::rstest; -// use vortex_buffer::BitBuffer; - -// use crate::Mask; - -// #[test] -// fn mask_bitand_all_as_bit_and() { -// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true, true, true])); -// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, true, false, true, true])); -// assert_eq!( -// this.intersect_by_rank(&mask), -// Mask::from_indices(5, vec![1, 3, 4]) -// ); -// } - -// #[test] -// fn mask_bitand_all_true() { -// let this = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, true, true, true])); -// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true])); -// assert_eq!( -// this.intersect_by_rank(&mask), -// Mask::from_indices(5, vec![2, 3, 4]) -// ); -// } - -// #[test] -// fn mask_bitand_true() { -// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); -// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, true])); -// assert_eq!( -// this.intersect_by_rank(&mask), -// Mask::from_indices(5, vec![0, 4]) -// ); -// } - -// #[test] -// fn mask_bitand_false() { -// let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); -// let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, false])); -// assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![])); -// } - -// #[test] -// fn mask_intersect_by_rank_all_false() { -// let this = Mask::AllFalse(10); -// let mask = Mask::AllFalse(0); -// assert_eq!(this.intersect_by_rank(&mask), Mask::AllFalse(10)); -// } - -// #[rstest] -// #[case::all_true_with_all_true( -// Mask::new_true(5), -// Mask::new_true(5), -// vec![0, 1, 2, 3, 4] -// )] -// #[case::all_true_with_all_false( -// Mask::new_true(5), -// Mask::new_false(5), -// vec![] -// )] -// #[case::all_false_with_any( -// Mask::new_false(10), -// Mask::new_true(0), -// vec![] -// )] -// #[case::indices_with_all_true( -// Mask::from_indices(10, vec![2, 5, 7, 9]), -// Mask::new_true(4), -// vec![2, 5, 7, 9] -// )] -// #[case::indices_with_all_false( -// Mask::from_indices(10, vec![2, 5, 7, 9]), -// Mask::new_false(4), -// vec![] -// )] -// fn test_intersect_by_rank_special_cases( -// #[case] base_mask: Mask, -// #[case] rank_mask: Mask, -// #[case] expected_indices: Vec, -// ) { -// let result = base_mask.intersect_by_rank(&rank_mask); - -// match result.indices() { -// crate::AllOr::All => assert_eq!(expected_indices.len(), result.len()), -// crate::AllOr::None => assert!(expected_indices.is_empty()), -// crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), -// } -// } - -// #[test] -// fn test_intersect_by_rank_example() { -// // Example from the documentation -// let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]); -// let m2 = Mask::from_iter([false, false, true, false, true]); -// let result = m1.intersect_by_rank(&m2); -// let expected = Mask::from_iter([false, false, false, false, true, false, false, true]); -// assert_eq!(result, expected); -// } - -// #[test] -// #[should_panic] -// fn test_intersect_by_rank_wrong_length() { -// let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values -// let m2 = Mask::new_true(5); // 5 true values - doesn't match -// m1.intersect_by_rank(&m2); -// } - -// #[rstest] -// #[case::single_element( -// vec![3], -// vec![true], -// vec![3] -// )] -// #[case::single_element_masked( -// vec![3], -// vec![false], -// vec![] -// )] -// #[case::alternating( -// vec![0, 2, 4, 6, 8], -// vec![true, false, true, false, true], -// vec![0, 4, 8] -// )] -// #[case::consecutive( -// vec![5, 6, 7, 8, 9], -// vec![false, true, true, true, false], -// vec![6, 7, 8] -// )] -// fn test_intersect_by_rank_patterns( -// #[case] base_indices: Vec, -// #[case] rank_pattern: Vec, -// #[case] expected_indices: Vec, -// ) { -// let base = Mask::from_indices(10, base_indices); -// let rank = Mask::from_iter(rank_pattern); -// let result = base.intersect_by_rank(&rank); - -// match result.indices() { -// crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), -// crate::AllOr::None => assert!(expected_indices.is_empty()), -// _ => panic!("Unexpected result"), -// } -// } -// } +#[cfg(test)] +mod test { + use rstest::rstest; + use vortex_buffer::BitBuffer; + + use crate::Mask; + + #[test] + fn mask_bitand_all_as_bit_and() { + let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true, true, true])); + let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, true, false, true, true])); + assert_eq!( + this.intersect_by_rank(&mask), + Mask::from_indices(5, vec![1, 3, 4]) + ); + } + + #[test] + fn mask_bitand_all_true() { + let this = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, true, true, true])); + let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, true, true])); + assert_eq!( + this.intersect_by_rank(&mask), + Mask::from_indices(5, vec![2, 3, 4]) + ); + } + + #[test] + fn mask_bitand_true() { + let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); + let mask = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, true])); + assert_eq!( + this.intersect_by_rank(&mask), + Mask::from_indices(5, vec![0, 4]) + ); + } + + #[test] + fn mask_bitand_false() { + let this = Mask::from_buffer(BitBuffer::from_iter(vec![true, false, false, true, true])); + let mask = Mask::from_buffer(BitBuffer::from_iter(vec![false, false, false])); + assert_eq!(this.intersect_by_rank(&mask), Mask::from_indices(5, vec![])); + } + + #[test] + fn mask_intersect_by_rank_all_false() { + let this = Mask::AllFalse(10); + let mask = Mask::AllFalse(0); + assert_eq!(this.intersect_by_rank(&mask), Mask::AllFalse(10)); + } + + #[rstest] + #[case::all_true_with_all_true( + Mask::new_true(5), + Mask::new_true(5), + vec![0, 1, 2, 3, 4] + )] + #[case::all_true_with_all_false( + Mask::new_true(5), + Mask::new_false(5), + vec![] + )] + #[case::all_false_with_any( + Mask::new_false(10), + Mask::new_true(0), + vec![] + )] + #[case::indices_with_all_true( + Mask::from_indices(10, vec![2, 5, 7, 9]), + Mask::new_true(4), + vec![2, 5, 7, 9] + )] + #[case::indices_with_all_false( + Mask::from_indices(10, vec![2, 5, 7, 9]), + Mask::new_false(4), + vec![] + )] + fn test_intersect_by_rank_special_cases( + #[case] base_mask: Mask, + #[case] rank_mask: Mask, + #[case] expected_indices: Vec, + ) { + let result = base_mask.intersect_by_rank(&rank_mask); + + match result { + Mask::AllTrue(n) => assert_eq!(expected_indices.len(), result.len()), + Mask::AllFalse(_) => assert!(expected_indices.is_empty()), + Mask::Values(mask_value) => { + assert_eq!( + mask_value.bit_buffer().set_indices().collect::>(), + &expected_indices[..] + ) + } + } + } + + #[test] + fn test_intersect_by_rank_example() { + // Example from the documentation + let m1 = Mask::from_iter([true, false, false, true, true, true, false, true]); + let m2 = Mask::from_iter([false, false, true, false, true]); + let result = m1.intersect_by_rank(&m2); + let expected = Mask::from_iter([false, false, false, false, true, false, false, true]); + assert_eq!(result, expected); + } + + #[test] + #[should_panic] + fn test_intersect_by_rank_wrong_length() { + let m1 = Mask::from_indices(10, vec![2, 5, 7]); // 3 true values + let m2 = Mask::new_true(5); // 5 true values - doesn't match + m1.intersect_by_rank(&m2); + } + + #[rstest] + #[case::single_element( + vec![3], + vec![true], + vec![3] + )] + #[case::single_element_masked( + vec![3], + vec![false], + vec![] + )] + #[case::alternating( + vec![0, 2, 4, 6, 8], + vec![true, false, true, false, true], + vec![0, 4, 8] + )] + #[case::consecutive( + vec![5, 6, 7, 8, 9], + vec![false, true, true, true, false], + vec![6, 7, 8] + )] + fn test_intersect_by_rank_patterns( + #[case] base_indices: Vec, + #[case] rank_pattern: Vec, + #[case] expected_indices: Vec, + ) { + let base = Mask::from_indices(10, base_indices); + let rank = Mask::from_iter(rank_pattern); + // let result = ; + + match base.intersect_by_rank(&rank) { + Mask::AllTrue(n) => unreachable!(), + Mask::AllFalse(_) => assert!(expected_indices.is_empty()), + Mask::Values(mask_values) => { + assert_eq!( + mask_values.bit_buffer().set_indices().collect::>(), + &expected_indices[..] + ) + } + } + + // let mask_values = result.values().unwrap(); + // assert_eq!(mask_values.true_count(), expected_indices.len()); + // // crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), + // // crate::AllOr::None => assert!(expected_indices.is_empty()), + // // _ => panic!("Unexpected result"), + // // } + } +} From a3203c7e064121a022d85c00b74dfcbe21986098 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Fri, 13 Feb 2026 12:03:30 +0000 Subject: [PATCH 07/11] Fix some things and clean up Signed-off-by: Adam Gutglick --- vortex-array/src/arrays/filter/vtable.rs | 1 - vortex-mask/src/intersect_by_rank.rs | 59 ++++++--------- vortex-mask/src/lib.rs | 94 ------------------------ vortex-scan/src/selection.rs | 6 +- 4 files changed, 25 insertions(+), 135 deletions(-) diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index 5060cdf97bc..d2876e9e197 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -109,7 +109,6 @@ impl VTable for FilterVTable { } fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { - dbg!(array.encoding_id()); if let Some(canonical) = execute_filter_fast_paths(array, ctx)? { return Ok(canonical); } diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index d21b7f19154..0e8ead98a38 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_error::VortexExpect; - use crate::Mask; impl Mask { @@ -30,29 +28,28 @@ impl Mask { pub fn intersect_by_rank(&self, mask: &Mask) -> Mask { assert_eq!(self.true_count(), mask.len()); - if self.all_true() { - mask.clone() - } else if mask.all_true() { - self.clone() - } else if self.all_false() || mask.all_false() { - Self::new_false(self.len()) - } else { - let mask_values = mask.values().vortex_expect("msg"); - let self_indices = self - .values() - .expect("") - .bit_buffer() - .set_indices() - .collect::>(); - - Self::from_indices( - self.len(), - mask_values - .bit_buffer() - .set_indices() - .map(|idx| unsafe { *self_indices.get_unchecked(idx) }) - .collect(), - ) + match (self, mask) { + (Mask::AllTrue(_), _) => mask.clone(), + (_, Mask::AllTrue(_)) => self.clone(), + (Mask::AllFalse(_), _) | (_, Mask::AllFalse(_)) => Self::new_false(self.len()), + (Mask::Values(self_values), Mask::Values(mask_values)) => { + let self_indices = self_values.bit_buffer().set_indices().collect::>(); + + Self::from_indices( + self.len(), + mask_values + .bit_buffer() + .set_indices() + .map(|idx| { + // SAFETY: + // This is verified as safe because we know that the indices are less than the + // mask.len() and we known mask.len() <= self.len(), + // implied by `self.true_count() == mask.len()`. + unsafe { *self_indices.get_unchecked(idx) } + }) + .collect(), + ) + } } } } @@ -142,7 +139,7 @@ mod test { let result = base_mask.intersect_by_rank(&rank_mask); match result { - Mask::AllTrue(n) => assert_eq!(expected_indices.len(), result.len()), + Mask::AllTrue(_) => assert_eq!(expected_indices.len(), result.len()), Mask::AllFalse(_) => assert!(expected_indices.is_empty()), Mask::Values(mask_value) => { assert_eq!( @@ -199,10 +196,9 @@ mod test { ) { let base = Mask::from_indices(10, base_indices); let rank = Mask::from_iter(rank_pattern); - // let result = ; match base.intersect_by_rank(&rank) { - Mask::AllTrue(n) => unreachable!(), + Mask::AllTrue(n) => assert_eq!(n, expected_indices.len()), Mask::AllFalse(_) => assert!(expected_indices.is_empty()), Mask::Values(mask_values) => { assert_eq!( @@ -211,12 +207,5 @@ mod test { ) } } - - // let mask_values = result.values().unwrap(); - // assert_eq!(mask_values.true_count(), expected_indices.len()); - // // crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), - // // crate::AllOr::None => assert!(expected_indices.is_empty()), - // // _ => panic!("Unexpected result"), - // // } } } diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index a880a426890..548da47f286 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -128,13 +128,6 @@ impl Default for Mask { pub struct MaskValues { buffer: BitBuffer, - // We cached the indices and slices representations, since it can be faster than iterating - // the bit-mask over and over again. - // #[cfg_attr(feature = "serde", serde(skip))] - // indices: OnceLock>, - // #[cfg_attr(feature = "serde", serde(skip))] - // slices: OnceLock>, - // Pre-computed values. true_count: usize, // i.e., the fraction of values that are true @@ -177,8 +170,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer, - // indices: Default::default(), - // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -208,8 +199,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - // indices: OnceLock::from(indices), - // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -237,8 +226,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - // indices: Default::default(), - // slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -271,8 +258,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - // indices: Default::default(), - // slices: OnceLock::from(slices), true_count, density: true_count as f64 / len as f64, })) @@ -495,36 +480,6 @@ impl Mask { } } - /// Return the indices representation of the mask. - // #[inline] - // pub fn indices(&self) -> AllOr<&[usize]> { - // match &self { - // Self::AllTrue(_) => AllOr::All, - // Self::AllFalse(_) => AllOr::None, - // Self::Values(values) => AllOr::Some(values.indices()), - // } - // } - - /// Return the slices representation of the mask. - // #[inline] - // pub fn slices(&self) -> AllOr<&[(usize, usize)]> { - // match &self { - // Self::AllTrue(_) => AllOr::All, - // Self::AllFalse(_) => AllOr::None, - // Self::Values(values) => AllOr::Some(values.slices()), - // } - // } - - /// Return an iterator over either indices or slices of the mask based on a density threshold. - // #[inline] - // pub fn threshold_iter(&self, threshold: f64) -> AllOr> { - // match &self { - // Self::AllTrue(_) => AllOr::All, - // Self::AllFalse(_) => AllOr::None, - // Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)), - // } - // } - /// Return [`MaskValues`] if the mask is not all true or all false. #[inline] pub fn values(&self) -> Option<&MaskValues> { @@ -670,55 +625,6 @@ impl MaskValues { self.buffer.value(index) } - // /// Constructs an indices vector from one of the other representations. - // pub fn indices(&self) -> &[usize] { - // self.indices.get_or_init(|| { - // if self.true_count == 0 { - // return vec![]; - // } - - // if self.true_count == self.len() { - // return (0..self.len()).collect(); - // } - - // if let Some(slices) = self.slices.get() { - // let mut indices = Vec::with_capacity(self.true_count); - // indices.extend(slices.iter().flat_map(|(start, end)| *start..*end)); - // debug_assert!(indices.is_sorted()); - // assert_eq!(indices.len(), self.true_count); - // return indices; - // } - - // let mut indices = Vec::with_capacity(self.true_count); - // indices.extend(self.buffer.set_indices()); - // debug_assert!(indices.is_sorted()); - // assert_eq!(indices.len(), self.true_count); - // indices - // }) - // } - - /// Constructs a slices vector from one of the other representations. - // #[inline] - // pub fn slices(&self) -> &[(usize, usize)] { - // self.slices.get_or_init(|| { - // if self.true_count == self.len() { - // return vec![(0, self.len())]; - // } - - // self.buffer.set_slices().collect() - // }) - // } - - /// Return an iterator over either indices or slices of the mask based on a density threshold. - // #[inline] - // pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> { - // if self.density >= threshold { - // MaskIter::Slices(self.slices()) - // } else { - // MaskIter::Indices(self.indices()) - // } - // } - /// Extracts the internal [`BitBuffer`]. pub(crate) fn into_buffer(self) -> BitBuffer { self.buffer diff --git a/vortex-scan/src/selection.rs b/vortex-scan/src/selection.rs index 856369ca3fb..3b647fc073d 100644 --- a/vortex-scan/src/selection.rs +++ b/vortex-scan/src/selection.rs @@ -167,11 +167,7 @@ mod tests { use vortex_buffer::Buffer; fn collect_indices(mask: &vortex_mask::Mask) -> Vec { - mask.values() - .unwrap() - .bit_buffer() - .set_indices() - .collect() + mask.values().unwrap().bit_buffer().set_indices().collect() } #[test] From b4f9040bb127d6fa6ba01e0e94562c9245e0dcb8 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Fri, 13 Feb 2026 12:09:55 +0000 Subject: [PATCH 08/11] bug fix Signed-off-by: Adam Gutglick --- vortex-array/src/arrays/list/compute/filter.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index da5cb57f8be..fe49ea3bb24 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -3,7 +3,6 @@ use std::sync::Arc; -use itertools::Itertools; use num_traits::Zero; use vortex_buffer::BitBufferMut; use vortex_buffer::Buffer; @@ -55,9 +54,9 @@ pub fn element_mask_from_offsets( } } else { // Sparse iteration: process individual selected lists. - for (start, end) in selection.bit_buffer().set_indices().tuple_windows() { - let list_start = offsets[start].as_() - first_offset; - let list_end = offsets[end].as_() - first_offset; + for idx in selection.bit_buffer().set_indices() { + let list_start = offsets[idx].as_() - first_offset; + let list_end = offsets[idx + 1].as_() - first_offset; // Process the elements for this list. process_element_range(list_start, list_end, &mut mask_builder); From 84ff88fcbc7f018ec212ac181ff35c593cd2a4a6 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Fri, 13 Feb 2026 14:26:14 +0000 Subject: [PATCH 09/11] clean this a bit Signed-off-by: Adam Gutglick --- vortex-array/src/patches.rs | 29 +++++++++---------- .../src/filter/vector/fixed_size_list.rs | 6 ++-- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 60d4f339eb3..28c1161725b 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -583,21 +583,20 @@ impl Patches { ); } - if mask.all_true() { - return Ok(Some(self.clone())); - } else if mask.all_false() { - Ok(None) - } else { - let mask_values = mask.values().vortex_expect("trust me"); - let flat_indices = self.indices().to_primitive(); - match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| { - filter_patches_with_mask( - flat_indices.as_slice::(), - self.offset(), - self.values(), - &mask_values.bit_buffer().set_indices().collect::>(), - ) - }) + match mask.bit_buffer() { + AllOr::All => Ok(Some(self.clone())), + AllOr::None => Ok(None), + AllOr::Some(mask_bits) => { + let flat_indices = self.indices().to_primitive(); + match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| { + filter_patches_with_mask( + flat_indices.as_slice::(), + self.offset(), + self.values(), + &mask_bits.set_indices().collect::>(), + ) + }) + } } } diff --git a/vortex-compute/src/filter/vector/fixed_size_list.rs b/vortex-compute/src/filter/vector/fixed_size_list.rs index 744e99c2908..f94a098e57a 100644 --- a/vortex-compute/src/filter/vector/fixed_size_list.rs +++ b/vortex-compute/src/filter/vector/fixed_size_list.rs @@ -147,7 +147,7 @@ fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Vec<(us Mask::Values(values) => values, }; - let expanded_slices = if values.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + if values.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { values .bit_buffer() .set_slices() @@ -163,7 +163,5 @@ fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Vec<(us (start, end) }) .collect() - }; - - expanded_slices + } } From 6cbbc22b9e2338880c35927251e145a2e11d54ea Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Fri, 13 Feb 2026 14:32:57 +0000 Subject: [PATCH 10/11] make clippy finally happy Signed-off-by: Adam Gutglick --- encodings/zstd/src/array.rs | 100 ++++++++++++++++---------------- vortex-array/src/compute/zip.rs | 14 ----- vortex-array/src/patches.rs | 1 - vortex-scan/src/selection.rs | 2 +- 4 files changed, 52 insertions(+), 65 deletions(-) diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index d8e8aeb5f2f..d712ba406da 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -52,6 +52,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_mask::AllOr; use vortex_scalar::Scalar; use vortex_session::VortexSession; @@ -717,58 +718,59 @@ impl ZstdArray { } DType::Binary(_) | DType::Utf8(_) => { let mask = slice_validity.to_mask(slice_n_rows); - if mask.all_true() { - // the decompressed buffer is a bunch of interleaved u32 lengths - // and strings of those lengths, we need to reconstruct the - // views into those strings by passing through the buffer. - let valid_views = reconstruct_views(&decompressed).slice( - slice_value_idx_start - n_skipped_values - ..slice_value_idx_stop - n_skipped_values, - ); - - // SAFETY: we properly construct the views inside `reconstruct_views` - Ok(unsafe { - VarBinViewArray::new_unchecked( - valid_views, - Arc::from([decompressed]), - self.dtype.clone(), - slice_validity, - ) + + match mask.bit_buffer() { + AllOr::All => { + // the decompressed buffer is a bunch of interleaved u32 lengths + // and strings of those lengths, we need to reconstruct the + // views into those strings by passing through the buffer. + let valid_views = reconstruct_views(&decompressed).slice( + slice_value_idx_start - n_skipped_values + ..slice_value_idx_stop - n_skipped_values, + ); + + // SAFETY: we properly construct the views inside `reconstruct_views` + Ok(unsafe { + VarBinViewArray::new_unchecked( + valid_views, + Arc::from([decompressed]), + self.dtype.clone(), + slice_validity, + ) + } + .into_array()) } - .into_array()) - } else if mask.all_false() { - Ok( - ConstantArray::new(Scalar::null(self.dtype.clone()), slice_n_rows) - .into_array(), + AllOr::None => Ok(ConstantArray::new( + Scalar::null(self.dtype.clone()), + slice_n_rows, ) - } else { - let mask_values = mask.values().unwrap(); - // the decompressed buffer is a bunch of interleaved u32 lengths - // and strings of those lengths, we need to reconstruct the - // views into those strings by passing through the buffer. - let valid_views = reconstruct_views(&decompressed).slice( - slice_value_idx_start - n_skipped_values - ..slice_value_idx_stop - n_skipped_values, - ); - - let mut views = BufferMut::::zeroed(slice_n_rows); - for (view, index) in valid_views - .into_iter() - .zip_eq(mask_values.bit_buffer().set_indices()) - { - views[index] = view - } - - // SAFETY: we properly construct the views inside `reconstruct_views` - Ok(unsafe { - VarBinViewArray::new_unchecked( - views.freeze(), - Arc::from([decompressed]), - self.dtype.clone(), - slice_validity, - ) + .into_array()), + AllOr::Some(mask_bits) => { + // the decompressed buffer is a bunch of interleaved u32 lengths + // and strings of those lengths, we need to reconstruct the + // views into those strings by passing through the buffer. + let valid_views = reconstruct_views(&decompressed).slice( + slice_value_idx_start - n_skipped_values + ..slice_value_idx_stop - n_skipped_values, + ); + + let mut views = BufferMut::::zeroed(slice_n_rows); + for (view, index) in valid_views.into_iter().zip_eq(mask_bits.set_indices()) + { + views[index] = view + } + + // SAFETY: we properly construct the views inside `reconstruct_views` + Ok(unsafe { + VarBinViewArray::new_unchecked( + views.freeze(), + Arc::from([decompressed]), + self.dtype.clone(), + slice_validity, + ) + } + .into_array()) } - .into_array()) } } _ => vortex_panic!("Unsupported dtype for Zstd array: {}", self.dtype), diff --git a/vortex-array/src/compute/zip.rs b/vortex-array/src/compute/zip.rs index b7f84c6ece0..f924d69681c 100644 --- a/vortex-array/src/compute/zip.rs +++ b/vortex-array/src/compute/zip.rs @@ -75,20 +75,6 @@ fn zip_impl_with_builder( } else { Ok(if_false.to_array()) } - // match mask.slices() { - // AllOr::All => Ok(if_true.to_array()), - // AllOr::None => Ok(if_false.to_array()), - // AllOr::Some(slices) => { - // for (start, end) in slices { - // builder.extend_from_array(&if_false.slice(builder.len()..*start)?); - // builder.extend_from_array(&if_true.slice(*start..*end)?); - // } - // if builder.len() < if_false.len() { - // builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); - // } - // Ok(builder.finish()) - // } - // } } #[cfg(test)] diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 28c1161725b..61bf16b3c22 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -21,7 +21,6 @@ use vortex_dtype::match_each_integer_ptype; use vortex_dtype::match_each_native_ptype; use vortex_dtype::match_each_unsigned_integer_ptype; use vortex_error::VortexError; -use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; diff --git a/vortex-scan/src/selection.rs b/vortex-scan/src/selection.rs index 3b647fc073d..66e0dbd2472 100644 --- a/vortex-scan/src/selection.rs +++ b/vortex-scan/src/selection.rs @@ -397,7 +397,7 @@ mod tests { // Should include 15-19 (mapped to 0-4) and 30-34 (mapped to 15-19) let expected: Vec = (0..5).chain(15..20).collect(); - assert_eq!(collect_indices(row_mask.mask()), &expected); + assert_eq!(collect_indices(row_mask.mask()), expected); } #[test] From ecd9f19a992d4bc50b7b326955bb5aaf304cb089 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Fri, 13 Feb 2026 15:32:28 +0000 Subject: [PATCH 11/11] More cleanup Signed-off-by: Adam Gutglick --- .../src/arrays/filter/execute/listview.rs | 4 ++-- vortex-array/src/arrays/filter/execute/mod.rs | 9 ++------- .../src/arrays/filter/execute/struct_.rs | 4 ++-- .../src/arrays/filter/execute/varbinview.rs | 18 +++++++++++++++--- vortex-array/src/compute/filter.rs | 16 ---------------- vortex-array/src/compute/zip.rs | 3 --- 6 files changed, 21 insertions(+), 33 deletions(-) diff --git a/vortex-array/src/arrays/filter/execute/listview.rs b/vortex-array/src/arrays/filter/execute/listview.rs index c9e9d369cff..9eff4849878 100644 --- a/vortex-array/src/arrays/filter/execute/listview.rs +++ b/vortex-array/src/arrays/filter/execute/listview.rs @@ -4,12 +4,12 @@ use std::sync::Arc; use vortex_error::VortexExpect; +use vortex_mask::Mask; use vortex_mask::MaskValues; use crate::arrays::ListViewArray; use crate::arrays::ListViewRebuildMode; use crate::arrays::filter::execute::filter_validity; -use crate::arrays::filter::execute::values_to_mask; use crate::vtable::ValidityHelper; // TODO(connor)[ListView]: Make use of this threshold after we start migrating operators. @@ -49,7 +49,7 @@ pub fn filter_listview(array: &ListViewArray, selection_mask: &Arc) ); // Simply filter the offsets and sizes arrays. - let mask_for_filter = values_to_mask(selection_mask); + let mask_for_filter = Mask::Values(selection_mask.clone()); let new_offsets = offsets .filter(mask_for_filter.clone()) .vortex_expect("ListViewArray offsets are guaranteed to support filter"); diff --git a/vortex-array/src/arrays/filter/execute/mod.rs b/vortex-array/src/arrays/filter/execute/mod.rs index 42bcb18c9d5..8a1c8e48fda 100644 --- a/vortex-array/src/arrays/filter/execute/mod.rs +++ b/vortex-array/src/arrays/filter/execute/mod.rs @@ -31,15 +31,10 @@ mod primitive; mod struct_; mod varbinview; -/// Reconstruct a [`Mask`] from an [`Arc`]. -fn values_to_mask(values: &Arc) -> Mask { - Mask::Values(values.clone()) -} - /// A helper function that lazily filters a [`Validity`] with selection mask values. fn filter_validity(validity: Validity, mask: &Arc) -> Validity { validity - .filter(&values_to_mask(mask)) + .filter(&Mask::Values(mask.clone())) .vortex_expect("Somehow unable to wrap filter around a validity array") } @@ -87,7 +82,7 @@ pub(super) fn execute_filter(canonical: Canonical, mask: &Arc) -> Ca Canonical::Extension(a) => { let filtered_storage = a .storage() - .filter(values_to_mask(mask)) + .filter(Mask::Values(mask.clone())) .vortex_expect("ExtensionArray storage type somehow could not be filtered"); Canonical::Extension(ExtensionArray::new(a.ext_dtype().clone(), filtered_storage)) } diff --git a/vortex-array/src/arrays/filter/execute/struct_.rs b/vortex-array/src/arrays/filter/execute/struct_.rs index dff28d27430..781e73b0f3d 100644 --- a/vortex-array/src/arrays/filter/execute/struct_.rs +++ b/vortex-array/src/arrays/filter/execute/struct_.rs @@ -4,18 +4,18 @@ use std::sync::Arc; use vortex_error::VortexExpect; +use vortex_mask::Mask; use vortex_mask::MaskValues; use crate::ArrayRef; use crate::arrays::StructArray; use crate::arrays::filter::execute::filter_validity; -use crate::arrays::filter::execute::values_to_mask; use crate::vtable::ValidityHelper; pub fn filter_struct(array: &StructArray, mask: &Arc) -> StructArray { let filtered_validity = filter_validity(array.validity().clone(), mask); - let mask_for_filter = values_to_mask(mask); + let mask_for_filter = Mask::Values(mask.clone()); let fields: Vec = array .unmasked_fields() .iter() diff --git a/vortex-array/src/arrays/filter/execute/varbinview.rs b/vortex-array/src/arrays/filter/execute/varbinview.rs index 60286586ba3..9e7ec9645e8 100644 --- a/vortex-array/src/arrays/filter/execute/varbinview.rs +++ b/vortex-array/src/arrays/filter/execute/varbinview.rs @@ -3,22 +3,34 @@ use std::sync::Arc; +use arrow_array::BooleanArray; use vortex_error::VortexExpect; +use vortex_error::VortexResult; use vortex_mask::MaskValues; +use crate::Array; +use crate::ArrayRef; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::arrays::filter::execute::values_to_mask; -use crate::compute::arrow_filter_fn; +use crate::arrow::FromArrowArray; +use crate::arrow::IntoArrowArray; pub fn filter_varbinview(array: &VarBinViewArray, mask: &Arc) -> VarBinViewArray { // Delegate to the Arrow implementation of filter over `VarBinView`. - arrow_filter_fn(array.as_ref(), &values_to_mask(mask)) + arrow_filter_fn(array.as_ref(), mask.as_ref()) .vortex_expect("VarBinViewArray is Arrow-compatible and supports arrow_filter_fn") .as_::() .clone() } +fn arrow_filter_fn(array: &dyn Array, values: &MaskValues) -> VortexResult { + let array_ref = array.to_array().into_arrow_preferred()?; + let mask_array = BooleanArray::new(values.bit_buffer().clone().into(), None); + let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?; + + ArrayRef::from_arrow(filtered.as_ref(), array.dtype().is_nullable()) +} + #[cfg(test)] mod test { use crate::arrays::VarBinViewArray; diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 737eacd94de..8f5960c1a51 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -3,7 +3,6 @@ // TODO(connor): REMOVE THIS FILE! -use arrow_array::BooleanArray; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -14,8 +13,6 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::ToCanonical; -use crate::arrow::FromArrowArray; -use crate::arrow::IntoArrowArray; use crate::builtins::ArrayBuiltins; /// Keep only the elements for which the corresponding mask value is true. @@ -67,16 +64,3 @@ impl dyn Array + '_ { Ok(array.to_bool().to_mask_fill_null_false()) } } - -pub fn arrow_filter_fn(array: &dyn Array, mask: &Mask) -> VortexResult { - let values = match &mask { - Mask::Values(values) => values, - Mask::AllTrue(_) | Mask::AllFalse(_) => unreachable!("check in filter invoke"), - }; - - let array_ref = array.to_array().into_arrow_preferred()?; - let mask_array = BooleanArray::new(values.bit_buffer().clone().into(), None); - let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?; - - ArrayRef::from_arrow(filtered.as_ref(), array.dtype().is_nullable()) -} diff --git a/vortex-array/src/compute/zip.rs b/vortex-array/src/compute/zip.rs index f924d69681c..915152043bb 100644 --- a/vortex-array/src/compute/zip.rs +++ b/vortex-array/src/compute/zip.rs @@ -3,9 +3,6 @@ use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_mask::AllOr; use vortex_mask::Mask; use crate::Array;