diff --git a/encodings/alp/public-api.lock b/encodings/alp/public-api.lock index 8416610a039..dd32a355106 100644 --- a/encodings/alp/public-api.lock +++ b/encodings/alp/public-api.lock @@ -160,14 +160,14 @@ impl vortex_array::arrays::slice::SliceKernel for vortex_alp::ALPRDVTable pub fn vortex_alp::ALPRDVTable::slice(array: &Self::Array, range: core::ops::range::Range, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_alp::ALPRDVTable - -pub fn vortex_alp::ALPRDVTable::mask(&self, array: &vortex_alp::ALPRDArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_alp::ALPRDVTable pub fn vortex_alp::ALPRDVTable::cast(array: &vortex_alp::ALPRDArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_alp::ALPRDVTable + +pub fn vortex_alp::ALPRDVTable::mask(array: &vortex_alp::ALPRDArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_alp::ALPRDVTable pub type vortex_alp::ALPRDVTable::Array = vortex_alp::ALPRDArray @@ -256,10 +256,6 @@ impl vortex_array::compute::between::BetweenKernel for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::between(&self, array: &vortex_alp::ALPArray, lower: &dyn vortex_array::array::Array, upper: &dyn vortex_array::array::Array, options: &vortex_array::compute::between::BetweenOptions) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_alp::ALPVTable - -pub fn vortex_alp::ALPVTable::mask(&self, array: &vortex_alp::ALPArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::nan_count::NaNCountKernel for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::nan_count(&self, array: &vortex_alp::ALPArray) -> vortex_error::VortexResult @@ -272,6 +268,14 @@ impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_alp::ALPVTab pub fn vortex_alp::ALPVTable::cast(array: &vortex_alp::ALPArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskKernel for vortex_alp::ALPVTable + +pub fn vortex_alp::ALPVTable::mask(array: &vortex_alp::ALPArray, mask: &vortex_array::array::ArrayRef, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_alp::ALPVTable + +pub fn vortex_alp::ALPVTable::mask(array: &vortex_alp::ALPArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_alp::ALPVTable pub type vortex_alp::ALPVTable::Array = vortex_alp::ALPArray diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs index f08b6499a20..3cf6baa7b6c 100644 --- a/encodings/alp/src/alp/compute/mask.rs +++ b/encodings/alp/src/alp/compute/mask.rs @@ -2,38 +2,48 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ALPArray; use crate::ALPVTable; +impl MaskReduce for ALPVTable { + fn mask(array: &ALPArray, mask: &ArrayRef) -> VortexResult> { + // Masking sparse patches requires reading indices, fall back to kernel. + if array.patches().is_some() { + return Ok(None); + } + let masked_encoded = array.encoded().clone().mask(mask.clone())?; + Ok(Some( + ALPArray::new(masked_encoded, array.exponents(), None).to_array(), + )) + } +} + impl MaskKernel for ALPVTable { - fn mask(&self, array: &ALPArray, filter_mask: &Mask) -> VortexResult { - let masked_encoded = mask(array.encoded(), filter_mask)?; + fn mask( + array: &ALPArray, + mask: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let vortex_mask = Validity::Array(mask.not()?).to_mask(array.len()); + let masked_encoded = array.encoded().clone().mask(mask.clone())?; let masked_patches = array .patches() - .map(|p| p.mask(filter_mask)) + .map(|p| p.mask(&vortex_mask)) .transpose()? - .flatten() - .map(|patches| { - patches.cast_values( - &array - .dtype() - .with_nullability(masked_encoded.dtype().nullability()), - ) - }) - .transpose()?; - Ok(ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array()) + .flatten(); + Ok(Some( + ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array(), + )) } } -register_kernel!(MaskKernelAdapter(ALPVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; @@ -58,4 +68,17 @@ mod test { let alp = alp_encode(&array.to_primitive(), None).unwrap(); test_mask_conformance(alp.as_ref()); } + + #[test] + fn test_mask_alp_with_patches() { + use std::f64::consts::PI; + // PI doesn't encode cleanly with ALP, so it creates patches. + let values: Vec = (0..100) + .map(|i| if i % 4 == 3 { PI } else { 1.0 }) + .collect(); + let array = PrimitiveArray::from_iter(values); + let alp = alp_encode(&array, None).unwrap(); + assert!(alp.patches().is_some(), "expected patches"); + test_mask_conformance(alp.as_ref()); + } } diff --git a/encodings/alp/src/alp/rules.rs b/encodings/alp/src/alp/rules.rs index 57d177f99f8..ccd3b69a279 100644 --- a/encodings/alp/src/alp/rules.rs +++ b/encodings/alp/src/alp/rules.rs @@ -5,6 +5,8 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskExecuteAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use vortex_array::optimizer::rules::ParentRuleSet; @@ -14,9 +16,12 @@ use crate::ALPVTable; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&CompareExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)), + ParentKernelSet::lift(&MaskExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)), ]); -pub(super) const RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable))]); +pub(super) const RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ALPVTable)), +]); diff --git a/encodings/alp/src/alp_rd/compute/mask.rs b/encodings/alp/src/alp_rd/compute/mask.rs index dcda1e1e2e2..7000dc4f8d7 100644 --- a/encodings/alp/src/alp_rd/compute/mask.rs +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -3,32 +3,36 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::arrays::ScalarFnArrayExt; +use vortex_array::compute::MaskReduce; +use vortex_array::expr::EmptyOptions; +use vortex_array::expr::Mask as MaskExpr; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ALPRDArray; use crate::ALPRDVTable; -impl MaskKernel for ALPRDVTable { - fn mask(&self, array: &ALPRDArray, filter_mask: &Mask) -> VortexResult { - Ok(ALPRDArray::try_new( - array.dtype().as_nullable(), - mask(array.left_parts(), filter_mask)?, - array.left_parts_dictionary().clone(), - array.right_parts().clone(), - array.right_bit_width(), - array.left_parts_patches().cloned(), - )? - .into_array()) +impl MaskReduce for ALPRDVTable { + fn mask(array: &ALPRDArray, mask: &ArrayRef) -> VortexResult> { + let masked_left_parts = MaskExpr.try_new_array( + array.left_parts().len(), + EmptyOptions, + [array.left_parts().clone(), mask.clone()], + )?; + Ok(Some( + ALPRDArray::try_new( + array.dtype().as_nullable(), + masked_left_parts, + array.left_parts_dictionary().clone(), + array.right_parts().clone(), + array.right_bit_width(), + array.left_parts_patches().cloned(), + )? + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ALPRDVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/alp/src/alp_rd/rules.rs b/encodings/alp/src/alp_rd/rules.rs index a7280e0bf8f..ed048e10829 100644 --- a/encodings/alp/src/alp_rd/rules.rs +++ b/encodings/alp/src/alp_rd/rules.rs @@ -2,9 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::alp_rd::ALPRDVTable; -pub(crate) static RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable))]); +pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ALPRDVTable)), +]); diff --git a/encodings/bytebool/public-api.lock b/encodings/bytebool/public-api.lock index 7bbcbea9f78..d6efe219caa 100644 --- a/encodings/bytebool/public-api.lock +++ b/encodings/bytebool/public-api.lock @@ -68,14 +68,14 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_bytebool::ByteBoolVTabl pub fn vortex_bytebool::ByteBoolVTable::slice(array: &vortex_bytebool::ByteBoolArray, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_bytebool::ByteBoolVTable - -pub fn vortex_bytebool::ByteBoolVTable::mask(&self, array: &vortex_bytebool::ByteBoolArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_bytebool::ByteBoolVTable pub fn vortex_bytebool::ByteBoolVTable::cast(array: &vortex_bytebool::ByteBoolArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_bytebool::ByteBoolVTable + +pub fn vortex_bytebool::ByteBoolVTable::mask(array: &vortex_bytebool::ByteBoolArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_bytebool::ByteBoolVTable pub type vortex_bytebool::ByteBoolVTable::Array = vortex_bytebool::ByteBoolArray diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index b40e3722ee0..d627de6f5d4 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -9,14 +9,12 @@ use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::TakeExecute; use vortex_array::compute::CastReduce; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; use vortex_dtype::DType; use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; -use vortex_mask::Mask; use super::ByteBoolArray; use super::ByteBoolVTable; @@ -44,14 +42,21 @@ impl CastReduce for ByteBoolVTable { } } -impl MaskKernel for ByteBoolVTable { - fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult { - Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array()) +impl MaskReduce for ByteBoolVTable { + fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult> { + Ok(Some( + ByteBoolArray::new( + array.buffer().clone(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + ) + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift()); - impl TakeExecute for ByteBoolVTable { fn take( array: &ByteBoolArray, diff --git a/encodings/bytebool/src/rules.rs b/encodings/bytebool/src/rules.rs index 52e9e32cefa..284989a3621 100644 --- a/encodings/bytebool/src/rules.rs +++ b/encodings/bytebool/src/rules.rs @@ -3,11 +3,13 @@ use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::ByteBoolVTable; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)), ParentRuleSet::lift(&CastReduceAdaptor(ByteBoolVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ByteBoolVTable)), + ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)), ]); diff --git a/encodings/datetime-parts/public-api.lock b/encodings/datetime-parts/public-api.lock index b43cbafad7a..a82170a65c9 100644 --- a/encodings/datetime-parts/public-api.lock +++ b/encodings/datetime-parts/public-api.lock @@ -6,6 +6,8 @@ impl vortex_datetime_parts::DateTimePartsArray pub fn vortex_datetime_parts::DateTimePartsArray::days(&self) -> &vortex_array::array::ArrayRef +pub fn vortex_datetime_parts::DateTimePartsArray::into_parts(self) -> vortex_datetime_parts::DateTimePartsArrayParts + pub fn vortex_datetime_parts::DateTimePartsArray::seconds(&self) -> &vortex_array::array::ArrayRef pub fn vortex_datetime_parts::DateTimePartsArray::subseconds(&self) -> &vortex_array::array::ArrayRef @@ -44,6 +46,24 @@ impl vortex_array::array::IntoArray for vortex_datetime_parts::DateTimePartsArra pub fn vortex_datetime_parts::DateTimePartsArray::into_array(self) -> vortex_array::array::ArrayRef +pub struct vortex_datetime_parts::DateTimePartsArrayParts + +pub vortex_datetime_parts::DateTimePartsArrayParts::days: vortex_array::array::ArrayRef + +pub vortex_datetime_parts::DateTimePartsArrayParts::dtype: vortex_dtype::dtype::DType + +pub vortex_datetime_parts::DateTimePartsArrayParts::seconds: vortex_array::array::ArrayRef + +pub vortex_datetime_parts::DateTimePartsArrayParts::subseconds: vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_datetime_parts::DateTimePartsArrayParts + +pub fn vortex_datetime_parts::DateTimePartsArrayParts::clone(&self) -> vortex_datetime_parts::DateTimePartsArrayParts + +impl core::fmt::Debug for vortex_datetime_parts::DateTimePartsArrayParts + +pub fn vortex_datetime_parts::DateTimePartsArrayParts::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + #[repr(C)] pub struct vortex_datetime_parts::DateTimePartsMetadata pub vortex_datetime_parts::DateTimePartsMetadata::days_ptype: i32 @@ -118,10 +138,6 @@ impl vortex_array::compute::is_constant::IsConstantKernel for vortex_datetime_pa pub fn vortex_datetime_parts::DateTimePartsVTable::is_constant(&self, array: &vortex_datetime_parts::DateTimePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_datetime_parts::DateTimePartsVTable - -pub fn vortex_datetime_parts::DateTimePartsVTable::mask(&self, array: &vortex_datetime_parts::DateTimePartsArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_datetime_parts::DateTimePartsVTable pub fn vortex_datetime_parts::DateTimePartsVTable::compare(lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> @@ -130,6 +146,10 @@ impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_datetime_par pub fn vortex_datetime_parts::DateTimePartsVTable::cast(array: &vortex_datetime_parts::DateTimePartsArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_datetime_parts::DateTimePartsVTable + +pub fn vortex_datetime_parts::DateTimePartsVTable::mask(array: &vortex_datetime_parts::DateTimePartsArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_datetime_parts::DateTimePartsVTable pub type vortex_datetime_parts::DateTimePartsVTable::Array = vortex_datetime_parts::DateTimePartsArray diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 7c576e5fa7e..68f1c625732 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -190,6 +190,14 @@ pub struct DateTimePartsArray { stats_set: ArrayStats, } +#[derive(Clone, Debug)] +pub struct DateTimePartsArrayParts { + pub dtype: DType, + pub days: ArrayRef, + pub seconds: ArrayRef, + pub subseconds: ArrayRef, +} + #[derive(Debug)] pub struct DateTimePartsVTable; @@ -252,6 +260,15 @@ impl DateTimePartsArray { } } + pub fn into_parts(self) -> DateTimePartsArrayParts { + DateTimePartsArrayParts { + dtype: self.dtype, + days: self.days, + seconds: self.seconds, + subseconds: self.subseconds, + } + } + pub fn days(&self) -> &ArrayRef { &self.days } diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs index ba9ba4b9acf..f15ba8ce2ca 100644 --- a/encodings/datetime-parts/src/compute/mask.rs +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -1,41 +1,28 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_array::Array; use vortex_array::ArrayRef; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::IntoArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::compute::MaskReduce; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::DateTimePartsArray; +use crate::DateTimePartsArrayParts; use crate::DateTimePartsVTable; -impl MaskKernel for DateTimePartsVTable { - fn mask(&self, array: &DateTimePartsArray, mask_array: &Mask) -> VortexResult { - // DateTimePartsArray has specific constraints: - // - days nullability must match the dtype - // - seconds and subseconds must always be non-nullable - // - // When masking, we can't make seconds/subseconds nullable. - // Instead, we'll keep the same values but the overall array becomes nullable - // through the days component. - - let masked_days = mask(array.days(), mask_array)?; - assert!(masked_days.dtype().is_nullable()); - - // Keep seconds and subseconds unchanged since they must remain non-nullable - let seconds = array.seconds().clone(); - let subseconds = array.subseconds().clone(); - - // Update the dtype to reflect the new nullability of days - let new_dtype = array.dtype().as_nullable(); - - DateTimePartsArray::try_new(new_dtype, masked_days, seconds, subseconds) - .map(|a| a.to_array()) +impl MaskReduce for DateTimePartsVTable { + fn mask(array: &DateTimePartsArray, mask: &ArrayRef) -> VortexResult> { + let DateTimePartsArrayParts { + dtype, + days, + seconds, + subseconds, + } = array.clone().into_parts(); + let masked_days = days.mask(mask.clone())?; + Ok(Some( + DateTimePartsArray::try_new(dtype.as_nullable(), masked_days, seconds, subseconds)? + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(DateTimePartsVTable).lift()); diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 436636872e2..e1d5e855236 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -14,6 +14,7 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::expr::Between; use vortex_array::expr::Binary; use vortex_array::optimizer::ArrayOptimizer; @@ -33,6 +34,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSe ParentRuleSet::lift(&DTPComparisonPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(DateTimePartsVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(DateTimePartsVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DateTimePartsVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DateTimePartsVTable)), ]); diff --git a/encodings/decimal-byte-parts/public-api.lock b/encodings/decimal-byte-parts/public-api.lock index a2117734a75..d9691a7fd44 100644 --- a/encodings/decimal-byte-parts/public-api.lock +++ b/encodings/decimal-byte-parts/public-api.lock @@ -68,10 +68,6 @@ impl vortex_array::compute::is_constant::IsConstantKernel for vortex_decimal_byt pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::is_constant(&self, array: &vortex_decimal_byte_parts::DecimalBytePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable - -pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::mask(&self, array: &vortex_decimal_byte_parts::DecimalBytePartsArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::compare(lhs: &Self::Array, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> @@ -80,6 +76,10 @@ impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_decimal_byte pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::cast(array: &vortex_decimal_byte_parts::DecimalBytePartsArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_decimal_byte_parts::DecimalBytePartsVTable + +pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::mask(array: &vortex_decimal_byte_parts::DecimalBytePartsArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_decimal_byte_parts::DecimalBytePartsVTable pub type vortex_decimal_byte_parts::DecimalBytePartsVTable::Array = vortex_decimal_byte_parts::DecimalBytePartsArray diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs index 04a91c5de17..80b69f5442d 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs @@ -2,21 +2,25 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::IntoArray; +use vortex_array::arrays::ScalarFnArrayExt; +use vortex_array::compute::MaskReduce; +use vortex_array::expr::EmptyOptions; +use vortex_array::expr::Mask as MaskExpr; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::DecimalBytePartsArray; use crate::DecimalBytePartsVTable; -impl MaskKernel for DecimalBytePartsVTable { - fn mask(&self, array: &DecimalBytePartsArray, mask_array: &Mask) -> VortexResult { - let masked = mask(&array.msp, mask_array)?; - DecimalBytePartsArray::try_new(masked, *array.decimal_dtype()).map(|a| a.to_array()) +impl MaskReduce for DecimalBytePartsVTable { + fn mask(array: &DecimalBytePartsArray, mask: &ArrayRef) -> VortexResult> { + let masked_msp = MaskExpr.try_new_array( + array.msp.len(), + EmptyOptions, + [array.msp.clone(), mask.clone()], + )?; + Ok(Some( + DecimalBytePartsArray::try_new(masked_msp, *array.decimal_dtype())?.into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(DecimalBytePartsVTable).lift()); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs index 7c851149323..7762b50b0ec 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::FilterReduceAdaptor; use vortex_array::arrays::FilterVTable; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ArrayParentReduceRule; use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; @@ -20,6 +21,7 @@ pub(super) const PARENT_RULES: ParentRuleSet = ParentRul ParentRuleSet::lift(&DecimalBytePartsFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(DecimalBytePartsVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(DecimalBytePartsVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DecimalBytePartsVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DecimalBytePartsVTable)), ]); diff --git a/encodings/zigzag/public-api.lock b/encodings/zigzag/public-api.lock index 923a6403464..cd576273088 100644 --- a/encodings/zigzag/public-api.lock +++ b/encodings/zigzag/public-api.lock @@ -60,14 +60,14 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_zigzag::ZigZagVTable pub fn vortex_zigzag::ZigZagVTable::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::mask::MaskKernel for vortex_zigzag::ZigZagVTable - -pub fn vortex_zigzag::ZigZagVTable::mask(&self, array: &vortex_zigzag::ZigZagArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_zigzag::ZigZagVTable pub fn vortex_zigzag::ZigZagVTable::cast(array: &vortex_zigzag::ZigZagArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::mask::kernel::MaskReduce for vortex_zigzag::ZigZagVTable + +pub fn vortex_zigzag::ZigZagVTable::mask(array: &vortex_zigzag::ZigZagArray, mask: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::VTable for vortex_zigzag::ZigZagVTable pub type vortex_zigzag::ZigZagVTable::Array = vortex_zigzag::ZigZagArray diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 562af093215..d61433d856b 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -8,11 +8,11 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; +use vortex_array::arrays::ScalarFnArrayExt; use vortex_array::arrays::TakeExecute; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::compute::MaskReduce; +use vortex_array::expr::EmptyOptions; +use vortex_array::expr::Mask as MaskExpr; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -37,15 +37,17 @@ impl TakeExecute for ZigZagVTable { } } -impl MaskKernel for ZigZagVTable { - fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult { - let encoded = mask(array.encoded(), filter_mask)?; - Ok(ZigZagArray::try_new(encoded)?.into_array()) +impl MaskReduce for ZigZagVTable { + fn mask(array: &ZigZagArray, mask: &ArrayRef) -> VortexResult> { + let masked_encoded = MaskExpr.try_new_array( + array.encoded().len(), + EmptyOptions, + [array.encoded().clone(), mask.clone()], + )?; + Ok(Some(ZigZagArray::try_new(masked_encoded)?.into_array())) } } -register_kernel!(MaskKernelAdapter(ZigZagVTable).lift()); - pub(crate) trait ZigZagEncoded { type Int: zigzag::ZigZag; } diff --git a/encodings/zigzag/src/rules.rs b/encodings/zigzag/src/rules.rs index 0b684b0c01d..0cdc6962976 100644 --- a/encodings/zigzag/src/rules.rs +++ b/encodings/zigzag/src/rules.rs @@ -4,6 +4,7 @@ use vortex_array::arrays::FilterReduceAdaptor; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::ZigZagVTable; @@ -11,5 +12,6 @@ use crate::ZigZagVTable; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(ZigZagVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(ZigZagVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ZigZagVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ZigZagVTable)), ]); diff --git a/fuzz/src/array/mask.rs b/fuzz/src/array/mask.rs index 5d6bc7f28d4..c1d40beeb01 100644 --- a/fuzz/src/array/mask.rs +++ b/fuzz/src/array/mask.rs @@ -5,6 +5,7 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::ToCanonical; use vortex_array::arrays::BoolArray; use vortex_array::arrays::DecimalArray; use vortex_array::arrays::ExtensionArray; @@ -13,12 +14,39 @@ use vortex_array::arrays::ListViewArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::VarBinViewArray; +use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; +use vortex_dtype::Nullability; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_mask::AllOr; use vortex_mask::Mask; +/// Set to false any entries for which the mask is true. +/// +/// The result is always nullable. The result has the same length as self. +#[inline] +pub fn mask_validity(validity: &Validity, mask: &Mask) -> Validity { + match mask.bit_buffer() { + AllOr::All => Validity::AllInvalid, + AllOr::None => validity.clone().into_nullable(), + AllOr::Some(make_invalid) => match validity { + Validity::NonNullable | Validity::AllValid => { + Validity::from_bit_buffer(!make_invalid, Nullability::Nullable) + } + Validity::AllInvalid => Validity::AllInvalid, + Validity::Array(is_valid) => { + let is_valid = is_valid.to_bool(); + Validity::from_bit_buffer( + is_valid.to_bit_buffer() & !make_invalid, + Nullability::Nullable, + ) + } + }, + } +} + /// Apply mask on the canonical form of the array to get a consistent baseline. /// This implementation manually applies the mask to each canonical type /// without using the mask_fn method, to serve as an independent baseline for testing. @@ -29,11 +57,11 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); BoolArray::new(array.to_bit_buffer(), new_validity).into_array() } Canonical::Primitive(array) => { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); PrimitiveArray::from_buffer_handle( array.buffer_handle().clone(), array.ptype(), @@ -42,14 +70,14 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); match_each_decimal_value_type!(array.values_type(), |D| { DecimalArray::new(array.buffer::(), array.decimal_dtype(), new_validity) .into_array() }) } Canonical::VarBinView(array) => { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); VarBinViewArray::new_handle( array.views_handle().clone(), array.buffers().clone(), @@ -59,7 +87,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); // SAFETY: Since we are only masking the validity and everything else comes from an // already valid `ListViewArray`, all of the invariants are still upheld. @@ -75,7 +103,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); FixedSizeListArray::new( array.elements().clone(), array.list_size(), @@ -85,7 +113,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { - let new_validity = array.validity().mask(mask); + let new_validity = mask_validity(array.validity(), mask); StructArray::try_new_with_dtype( array.unmasked_fields().clone(), array.struct_fields().clone(), diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index 2c4bf4aaa39..e640c78f148 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -555,7 +555,6 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::compare; - use vortex_array::compute::mask; use vortex_array::compute::min_max; use vortex_array::compute::sum; let FuzzArrayAction { array, actions } = fuzz_action; @@ -644,7 +643,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz assert_array_eq(&expected.array(), ¤t_array, i)?; } Action::Mask(mask_val) => { - current_array = mask(¤t_array, &mask_val) + current_array = current_array + .mask(mask_val.into_array()) .vortex_expect("mask operation should succeed in fuzz test"); assert_array_eq(&expected.array(), ¤t_array, i)?; } diff --git a/vortex-array/benches/dict_mask.rs b/vortex-array/benches/dict_mask.rs index 2dff5768431..41a3d0a355c 100644 --- a/vortex-array/benches/dict_mask.rs +++ b/vortex-array/benches/dict_mask.rs @@ -8,11 +8,14 @@ use rand::Rng; use rand::SeedableRng; use rand::rngs::StdRng; use vortex_array::IntoArray; +use vortex_array::RecursiveCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::DictArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::compute::mask; use vortex_array::compute::warm_up_vtables; use vortex_mask::Mask; +use vortex_session::VortexSession; fn main() { warm_up_vtables(); @@ -59,7 +62,14 @@ fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f6 let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array(); let array = DictArray::try_new(codes, values).unwrap().into_array(); let filter_mask = filter_mask(len, fraction_masked, &mut rng); + let session = VortexSession::empty(); bencher .with_inputs(|| (&array, &filter_mask)) - .bench_refs(|(array, filter_mask)| mask(array.as_ref(), filter_mask).unwrap()); + .bench_refs(|(array, filter_mask)| { + let mut ctx = session.create_execution_ctx(); + mask(*array, filter_mask) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 950242c2e61..807ecf330ef 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -174,6 +174,10 @@ impl vortex_array::expr::LikeReduce for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::like(array: &vortex_array::arrays::DictArray, pattern: &dyn vortex_array::Array, options: vortex_array::expr::LikeOptions) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::mask(array: &vortex_array::arrays::DictArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::array_eq(array: &vortex_array::arrays::DictArray, other: &vortex_array::arrays::DictArray, precision: vortex_array::Precision) -> bool @@ -456,10 +460,6 @@ pub fn vortex_array::arrays::BoolVTable::is_sorted(&self, array: &vortex_array:: pub fn vortex_array::arrays::BoolVTable::is_strict_sorted(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::BoolVTable - -pub fn vortex_array::arrays::BoolVTable::mask(&self, array: &vortex_array::arrays::BoolArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::BoolVTable pub fn vortex_array::arrays::BoolVTable::min_max(&self, array: &vortex_array::arrays::BoolArray) -> vortex_error::VortexResult> @@ -476,6 +476,10 @@ impl vortex_array::expr::FillNullKernel for vortex_array::arrays::BoolVTable pub fn vortex_array::arrays::BoolVTable::fill_null(array: &vortex_array::arrays::BoolArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::BoolVTable + +pub fn vortex_array::arrays::BoolVTable::mask(array: &vortex_array::arrays::BoolArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::BoolMaskedValidityRule pub type vortex_array::arrays::BoolMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -630,10 +634,6 @@ pub fn vortex_array::arrays::ChunkedVTable::is_sorted(&self, array: &vortex_arra pub fn vortex_array::arrays::ChunkedVTable::is_strict_sorted(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::mask(&self, array: &vortex_array::arrays::ChunkedArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::min_max(&self, array: &vortex_array::arrays::ChunkedArray) -> vortex_error::VortexResult> @@ -650,6 +650,10 @@ impl vortex_array::expr::FillNullReduce for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::fill_null(array: &vortex_array::arrays::ChunkedArray, fill_value: &vortex_scalar::scalar::Scalar) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskKernel for vortex_array::arrays::ChunkedVTable + +pub fn vortex_array::arrays::ChunkedVTable::mask(array: &vortex_array::arrays::ChunkedArray, mask: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::ZipReduce for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::zip(if_true: &vortex_array::arrays::ChunkedArray, if_false: &dyn vortex_array::Array, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> @@ -786,10 +790,6 @@ impl vortex_array::arrays::TakeReduce for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::take(array: &vortex_array::arrays::ConstantArray, indices: &dyn vortex_array::Array) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ConstantVTable - -pub fn vortex_array::arrays::ConstantVTable::mask(&self, array: &vortex_array::arrays::ConstantArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::min_max(&self, array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> @@ -1018,6 +1018,10 @@ impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DecimalVTable pub fn vortex_array::arrays::DecimalVTable::fill_null(array: &vortex_array::arrays::DecimalArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DecimalVTable + +pub fn vortex_array::arrays::DecimalVTable::mask(array: &vortex_array::arrays::DecimalArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::DecimalMaskedValidityRule pub type vortex_array::arrays::DecimalMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -1228,6 +1232,10 @@ impl vortex_array::expr::LikeReduce for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::like(array: &vortex_array::arrays::DictArray, pattern: &dyn vortex_array::Array, options: vortex_array::expr::LikeOptions) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::mask(array: &vortex_array::arrays::DictArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::array_eq(array: &vortex_array::arrays::DictArray, other: &vortex_array::arrays::DictArray, precision: vortex_array::Precision) -> bool @@ -1398,10 +1406,6 @@ pub fn vortex_array::arrays::ExtensionVTable::is_sorted(&self, array: &vortex_ar pub fn vortex_array::arrays::ExtensionVTable::is_strict_sorted(&self, array: &vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::mask(&self, array: &vortex_array::arrays::ExtensionArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::min_max(&self, array: &vortex_array::arrays::ExtensionArray) -> vortex_error::VortexResult> @@ -1418,6 +1422,10 @@ impl vortex_array::expr::CompareKernel for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ExtensionVTable + +pub fn vortex_array::arrays::ExtensionVTable::mask(array: &vortex_array::arrays::ExtensionArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::array_eq(array: &vortex_array::arrays::ExtensionArray, other: &vortex_array::arrays::ExtensionArray, precision: vortex_array::Precision) -> bool @@ -1714,10 +1722,6 @@ pub fn vortex_array::arrays::FixedSizeListVTable::is_sorted(&self, _array: &vort pub fn vortex_array::arrays::FixedSizeListVTable::is_strict_sorted(&self, _array: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::FixedSizeListVTable - -pub fn vortex_array::arrays::FixedSizeListVTable::mask(&self, array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::FixedSizeListVTable pub fn vortex_array::arrays::FixedSizeListVTable::min_max(&self, _array: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult> @@ -1726,6 +1730,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::FixedSizeListVTabl pub fn vortex_array::arrays::FixedSizeListVTable::cast(array: &vortex_array::arrays::FixedSizeListArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::FixedSizeListVTable + +pub fn vortex_array::arrays::FixedSizeListVTable::mask(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::FixedSizeListVTable pub fn vortex_array::arrays::FixedSizeListVTable::array_eq(array: &vortex_array::arrays::FixedSizeListArray, other: &vortex_array::arrays::FixedSizeListArray, precision: vortex_array::Precision) -> bool @@ -1914,10 +1922,6 @@ pub fn vortex_array::arrays::ListVTable::is_sorted(&self, _array: &vortex_array: pub fn vortex_array::arrays::ListVTable::is_strict_sorted(&self, _array: &vortex_array::arrays::ListArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ListVTable - -pub fn vortex_array::arrays::ListVTable::mask(&self, array: &vortex_array::arrays::ListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ListVTable pub fn vortex_array::arrays::ListVTable::min_max(&self, _array: &vortex_array::arrays::ListArray) -> vortex_error::VortexResult> @@ -1926,6 +1930,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::ListVTable pub fn vortex_array::arrays::ListVTable::cast(array: &vortex_array::arrays::ListArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListVTable + +pub fn vortex_array::arrays::ListVTable::mask(array: &vortex_array::arrays::ListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::ListVTable pub fn vortex_array::arrays::ListVTable::array_eq(array: &vortex_array::arrays::ListArray, other: &vortex_array::arrays::ListArray, precision: vortex_array::Precision) -> bool @@ -2096,10 +2104,6 @@ pub fn vortex_array::arrays::ListViewVTable::is_sorted(&self, _array: &vortex_ar pub fn vortex_array::arrays::ListViewVTable::is_strict_sorted(&self, _array: &vortex_array::arrays::ListViewArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ListViewVTable - -pub fn vortex_array::arrays::ListViewVTable::mask(&self, array: &vortex_array::arrays::ListViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ListViewVTable pub fn vortex_array::arrays::ListViewVTable::min_max(&self, _array: &vortex_array::arrays::ListViewArray) -> vortex_error::VortexResult> @@ -2108,6 +2112,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::ListViewVTable pub fn vortex_array::arrays::ListViewVTable::cast(array: &vortex_array::arrays::ListViewArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListViewVTable + +pub fn vortex_array::arrays::ListViewVTable::mask(array: &vortex_array::arrays::ListViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::ListViewVTable pub fn vortex_array::arrays::ListViewVTable::array_eq(array: &vortex_array::arrays::ListViewArray, other: &vortex_array::arrays::ListViewArray, precision: vortex_array::Precision) -> bool @@ -2228,9 +2236,9 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::MaskedVTable pub fn vortex_array::arrays::MaskedVTable::take(array: &vortex_array::arrays::MaskedArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::MaskedVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::MaskedVTable -pub fn vortex_array::arrays::MaskedVTable::mask(&self, array: &vortex_array::arrays::MaskedArray, mask_arg: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::MaskedVTable::mask(array: &vortex_array::arrays::MaskedArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::MaskedVTable @@ -2418,10 +2426,6 @@ impl vortex_array::arrays::TakeReduce for vortex_array::arrays::NullVTable pub fn vortex_array::arrays::NullVTable::take(array: &vortex_array::arrays::NullArray, indices: &dyn vortex_array::Array) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::NullVTable - -pub fn vortex_array::arrays::NullVTable::mask(&self, array: &vortex_array::arrays::NullArray, _mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::NullVTable pub fn vortex_array::arrays::NullVTable::min_max(&self, _array: &vortex_array::arrays::NullArray) -> vortex_error::VortexResult> @@ -2430,6 +2434,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::NullVTable pub fn vortex_array::arrays::NullVTable::cast(array: &vortex_array::arrays::NullArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::NullVTable + +pub fn vortex_array::arrays::NullVTable::mask(array: &vortex_array::arrays::NullArray, _mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::NullVTable pub fn vortex_array::arrays::NullVTable::array_eq(array: &vortex_array::arrays::NullArray, other: &vortex_array::arrays::NullArray, _precision: vortex_array::Precision) -> bool @@ -2652,10 +2660,6 @@ pub fn vortex_array::arrays::PrimitiveVTable::is_sorted(&self, array: &vortex_ar pub fn vortex_array::arrays::PrimitiveVTable::is_strict_sorted(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::PrimitiveVTable - -pub fn vortex_array::arrays::PrimitiveVTable::mask(&self, array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::PrimitiveVTable pub fn vortex_array::arrays::PrimitiveVTable::min_max(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult> @@ -2676,6 +2680,10 @@ impl vortex_array::expr::FillNullKernel for vortex_array::arrays::PrimitiveVTabl pub fn vortex_array::arrays::PrimitiveVTable::fill_null(array: &vortex_array::arrays::PrimitiveArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::PrimitiveVTable + +pub fn vortex_array::arrays::PrimitiveVTable::mask(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::arrays::PrimitiveMaskedValidityRule pub type vortex_array::arrays::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::MaskedVTable @@ -3284,10 +3292,6 @@ impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::StructVTa pub fn vortex_array::arrays::StructVTable::is_constant(&self, array: &vortex_array::arrays::StructArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::StructVTable - -pub fn vortex_array::arrays::StructVTable::mask(&self, array: &vortex_array::arrays::StructArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::StructVTable pub fn vortex_array::arrays::StructVTable::min_max(&self, _array: &vortex_array::arrays::StructArray) -> vortex_error::VortexResult> @@ -3296,6 +3300,10 @@ impl vortex_array::expr::CastKernel for vortex_array::arrays::StructVTable pub fn vortex_array::arrays::StructVTable::cast(array: &vortex_array::arrays::StructArray, dtype: &vortex_dtype::dtype::DType, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::StructVTable + +pub fn vortex_array::arrays::StructVTable::mask(array: &vortex_array::arrays::StructArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::expr::ZipKernel for vortex_array::arrays::StructVTable pub fn vortex_array::arrays::StructVTable::zip(if_true: &vortex_array::arrays::StructArray, if_false: &dyn vortex_array::Array, mask: &vortex_mask::Mask, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -3614,10 +3622,6 @@ pub fn vortex_array::arrays::VarBinVTable::is_sorted(&self, array: &vortex_array pub fn vortex_array::arrays::VarBinVTable::is_strict_sorted(&self, array: &vortex_array::arrays::VarBinArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::VarBinVTable - -pub fn vortex_array::arrays::VarBinVTable::mask(&self, array: &vortex_array::arrays::VarBinArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::min_max(&self, array: &vortex_array::arrays::VarBinArray) -> vortex_error::VortexResult> @@ -3630,6 +3634,10 @@ impl vortex_array::expr::CompareKernel for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinVTable + +pub fn vortex_array::arrays::VarBinVTable::mask(array: &vortex_array::arrays::VarBinArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::array_eq(array: &vortex_array::arrays::VarBinArray, other: &vortex_array::arrays::VarBinArray, precision: vortex_array::Precision) -> bool @@ -3834,10 +3842,6 @@ pub fn vortex_array::arrays::VarBinViewVTable::is_sorted(&self, array: &vortex_a pub fn vortex_array::arrays::VarBinViewVTable::is_strict_sorted(&self, array: &vortex_array::arrays::VarBinViewArray) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::VarBinViewVTable - -pub fn vortex_array::arrays::VarBinViewVTable::mask(&self, array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult - impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::VarBinViewVTable pub fn vortex_array::arrays::VarBinViewVTable::min_max(&self, array: &vortex_array::arrays::VarBinViewArray) -> vortex_error::VortexResult> @@ -3846,6 +3850,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::VarBinViewVTable pub fn vortex_array::arrays::VarBinViewVTable::cast(array: &vortex_array::arrays::VarBinViewArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinViewVTable + +pub fn vortex_array::arrays::VarBinViewVTable::mask(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + impl vortex_array::expr::ZipKernel for vortex_array::arrays::VarBinViewVTable pub fn vortex_array::arrays::VarBinViewVTable::zip(if_true: &vortex_array::arrays::VarBinViewArray, if_false: &dyn vortex_array::Array, mask: &vortex_mask::Mask, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -5964,23 +5972,37 @@ pub struct vortex_array::compute::ListContainsKernelRef(_) impl inventory::Collect for vortex_array::compute::ListContainsKernelRef -pub struct vortex_array::compute::MaskKernelAdapter(pub V) +pub struct vortex_array::compute::MaskExecuteAdaptor(pub V) + +impl core::default::Default for vortex_array::expr::MaskExecuteAdaptor + +pub fn vortex_array::expr::MaskExecuteAdaptor::default() -> vortex_array::expr::MaskExecuteAdaptor -impl vortex_array::compute::MaskKernelAdapter +impl core::fmt::Debug for vortex_array::expr::MaskExecuteAdaptor -pub const fn vortex_array::compute::MaskKernelAdapter::lift(&'static self) -> vortex_array::compute::MaskKernelRef +pub fn vortex_array::expr::MaskExecuteAdaptor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl core::fmt::Debug for vortex_array::compute::MaskKernelAdapter +impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::MaskExecuteAdaptor where V: vortex_array::expr::MaskKernel -pub fn vortex_array::compute::MaskKernelAdapter::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub type vortex_array::expr::MaskExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn -impl vortex_array::compute::Kernel for vortex_array::compute::MaskKernelAdapter +pub fn vortex_array::expr::MaskExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -pub fn vortex_array::compute::MaskKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> +pub struct vortex_array::compute::MaskReduceAdaptor(pub V) -pub struct vortex_array::compute::MaskKernelRef(_) +impl core::default::Default for vortex_array::expr::MaskReduceAdaptor -impl inventory::Collect for vortex_array::compute::MaskKernelRef +pub fn vortex_array::expr::MaskReduceAdaptor::default() -> vortex_array::expr::MaskReduceAdaptor + +impl core::fmt::Debug for vortex_array::expr::MaskReduceAdaptor + +pub fn vortex_array::expr::MaskReduceAdaptor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::expr::MaskReduceAdaptor where V: vortex_array::expr::MaskReduce + +pub type vortex_array::expr::MaskReduceAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::MaskReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize) -> vortex_error::VortexResult> pub struct vortex_array::compute::MinMax @@ -6418,10 +6440,6 @@ impl::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -impl vortex_array::compute::Kernel for vortex_array::compute::MaskKernelAdapter - -pub fn vortex_array::compute::MaskKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - impl vortex_array::compute::Kernel for vortex_array::compute::MinMaxKernelAdapter pub fn vortex_array::compute::MinMaxKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> @@ -6440,59 +6458,67 @@ pub fn vortex_array::compute::ListContainsKernel::list_contains(&self, list: &dy pub trait vortex_array::compute::MaskKernel: vortex_array::vtable::VTable -pub fn vortex_array::compute::MaskKernel::mask(&self, array: &Self::Array, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::compute::MaskKernel::mask(array: &Self::Array, mask: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::BoolVTable +impl vortex_array::expr::MaskKernel for vortex_array::arrays::ChunkedVTable -pub fn vortex_array::arrays::BoolVTable::mask(&self, array: &vortex_array::arrays::BoolArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::ChunkedVTable::mask(array: &vortex_array::arrays::ChunkedArray, mask: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ChunkedVTable +pub trait vortex_array::compute::MaskReduce: vortex_array::vtable::VTable -pub fn vortex_array::arrays::ChunkedVTable::mask(&self, array: &vortex_array::arrays::ChunkedArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::compute::MaskReduce::mask(array: &Self::Array, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ConstantVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::BoolVTable -pub fn vortex_array::arrays::ConstantVTable::mask(&self, array: &vortex_array::arrays::ConstantArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::BoolVTable::mask(array: &vortex_array::arrays::BoolArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ExtensionVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DecimalVTable -pub fn vortex_array::arrays::ExtensionVTable::mask(&self, array: &vortex_array::arrays::ExtensionArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::DecimalVTable::mask(array: &vortex_array::arrays::DecimalArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::FixedSizeListVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::FixedSizeListVTable::mask(&self, array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::DictVTable::mask(array: &vortex_array::arrays::DictArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ListVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ExtensionVTable -pub fn vortex_array::arrays::ListVTable::mask(&self, array: &vortex_array::arrays::ListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::ExtensionVTable::mask(array: &vortex_array::arrays::ExtensionArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::ListViewVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::FixedSizeListVTable -pub fn vortex_array::arrays::ListViewVTable::mask(&self, array: &vortex_array::arrays::ListViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::FixedSizeListVTable::mask(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::MaskedVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListVTable -pub fn vortex_array::arrays::MaskedVTable::mask(&self, array: &vortex_array::arrays::MaskedArray, mask_arg: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::ListVTable::mask(array: &vortex_array::arrays::ListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::NullVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListViewVTable -pub fn vortex_array::arrays::NullVTable::mask(&self, array: &vortex_array::arrays::NullArray, _mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::ListViewVTable::mask(array: &vortex_array::arrays::ListViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::PrimitiveVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::MaskedVTable -pub fn vortex_array::arrays::PrimitiveVTable::mask(&self, array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::MaskedVTable::mask(array: &vortex_array::arrays::MaskedArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::StructVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::NullVTable -pub fn vortex_array::arrays::StructVTable::mask(&self, array: &vortex_array::arrays::StructArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::NullVTable::mask(array: &vortex_array::arrays::NullArray, _mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::VarBinVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::PrimitiveVTable -pub fn vortex_array::arrays::VarBinVTable::mask(&self, array: &vortex_array::arrays::VarBinArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::PrimitiveVTable::mask(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> -impl vortex_array::compute::MaskKernel for vortex_array::arrays::VarBinViewVTable +impl vortex_array::expr::MaskReduce for vortex_array::arrays::StructVTable -pub fn vortex_array::arrays::VarBinViewVTable::mask(&self, array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult +pub fn vortex_array::arrays::StructVTable::mask(array: &vortex_array::arrays::StructArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinVTable + +pub fn vortex_array::arrays::VarBinVTable::mask(array: &vortex_array::arrays::VarBinArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinViewVTable + +pub fn vortex_array::arrays::VarBinViewVTable::mask(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> pub trait vortex_array::compute::MinMaxKernel: vortex_array::vtable::VTable @@ -8644,6 +8670,38 @@ pub fn vortex_array::expr::Mask::simplify(&self, _options: &Self::Options, expr: pub fn vortex_array::expr::Mask::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +pub struct vortex_array::expr::MaskExecuteAdaptor(pub V) + +impl core::default::Default for vortex_array::expr::MaskExecuteAdaptor + +pub fn vortex_array::expr::MaskExecuteAdaptor::default() -> vortex_array::expr::MaskExecuteAdaptor + +impl core::fmt::Debug for vortex_array::expr::MaskExecuteAdaptor + +pub fn vortex_array::expr::MaskExecuteAdaptor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::MaskExecuteAdaptor where V: vortex_array::expr::MaskKernel + +pub type vortex_array::expr::MaskExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::MaskExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +pub struct vortex_array::expr::MaskReduceAdaptor(pub V) + +impl core::default::Default for vortex_array::expr::MaskReduceAdaptor + +pub fn vortex_array::expr::MaskReduceAdaptor::default() -> vortex_array::expr::MaskReduceAdaptor + +impl core::fmt::Debug for vortex_array::expr::MaskReduceAdaptor + +pub fn vortex_array::expr::MaskReduceAdaptor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::expr::MaskReduceAdaptor where V: vortex_array::expr::MaskReduce + +pub type vortex_array::expr::MaskReduceAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::MaskReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize) -> vortex_error::VortexResult> + pub struct vortex_array::expr::Merge impl vortex_array::expr::VTable for vortex_array::expr::Merge @@ -9236,6 +9294,70 @@ impl vortex_array::expr::LikeReduce for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::like(array: &vortex_array::arrays::DictArray, pattern: &dyn vortex_array::Array, options: vortex_array::expr::LikeOptions) -> vortex_error::VortexResult> +pub trait vortex_array::expr::MaskKernel: vortex_array::vtable::VTable + +pub fn vortex_array::expr::MaskKernel::mask(array: &Self::Array, mask: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskKernel for vortex_array::arrays::ChunkedVTable + +pub fn vortex_array::arrays::ChunkedVTable::mask(array: &vortex_array::arrays::ChunkedArray, mask: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +pub trait vortex_array::expr::MaskReduce: vortex_array::vtable::VTable + +pub fn vortex_array::expr::MaskReduce::mask(array: &Self::Array, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::BoolVTable + +pub fn vortex_array::arrays::BoolVTable::mask(array: &vortex_array::arrays::BoolArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DecimalVTable + +pub fn vortex_array::arrays::DecimalVTable::mask(array: &vortex_array::arrays::DecimalArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::mask(array: &vortex_array::arrays::DictArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ExtensionVTable + +pub fn vortex_array::arrays::ExtensionVTable::mask(array: &vortex_array::arrays::ExtensionArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::FixedSizeListVTable + +pub fn vortex_array::arrays::FixedSizeListVTable::mask(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListVTable + +pub fn vortex_array::arrays::ListVTable::mask(array: &vortex_array::arrays::ListArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::ListViewVTable + +pub fn vortex_array::arrays::ListViewVTable::mask(array: &vortex_array::arrays::ListViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::MaskedVTable + +pub fn vortex_array::arrays::MaskedVTable::mask(array: &vortex_array::arrays::MaskedArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::NullVTable + +pub fn vortex_array::arrays::NullVTable::mask(array: &vortex_array::arrays::NullArray, _mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::PrimitiveVTable + +pub fn vortex_array::arrays::PrimitiveVTable::mask(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::StructVTable + +pub fn vortex_array::arrays::StructVTable::mask(array: &vortex_array::arrays::StructArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinVTable + +pub fn vortex_array::arrays::VarBinVTable::mask(array: &vortex_array::arrays::VarBinArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + +impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinViewVTable + +pub fn vortex_array::arrays::VarBinViewVTable::mask(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_array::ArrayRef) -> vortex_error::VortexResult> + pub trait vortex_array::expr::NotKernel: vortex_array::vtable::VTable pub fn vortex_array::expr::NotKernel::invert(array: &Self::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -10054,6 +10176,12 @@ pub type vortex_array::expr::LikeExecuteAdaptor::Parent = vortex_array::array pub fn vortex_array::expr::LikeExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Like>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::MaskExecuteAdaptor where V: vortex_array::expr::MaskKernel + +pub type vortex_array::expr::MaskExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::MaskExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::NotExecuteAdaptor where V: vortex_array::expr::NotKernel pub type vortex_array::expr::NotExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn @@ -10242,6 +10370,12 @@ pub type vortex_array::expr::LikeReduceAdaptor::Parent = vortex_array::arrays pub fn vortex_array::expr::LikeReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Like>, child_idx: usize) -> vortex_error::VortexResult> +impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::expr::MaskReduceAdaptor where V: vortex_array::expr::MaskReduce + +pub type vortex_array::expr::MaskReduceAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::MaskReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Mask>, child_idx: usize) -> vortex_error::VortexResult> + impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::expr::NotReduceAdaptor where V: vortex_array::expr::NotReduce pub type vortex_array::expr::NotReduceAdaptor::Parent = vortex_array::arrays::ExactScalarFn @@ -10956,10 +11090,10 @@ pub fn vortex_array::validity::Validity::is_null(&self, index: usize) -> vortex_ pub fn vortex_array::validity::Validity::is_valid(&self, index: usize) -> vortex_error::VortexResult -pub fn vortex_array::validity::Validity::mask(&self, mask: &vortex_mask::Mask) -> Self - pub fn vortex_array::validity::Validity::maybe_len(&self) -> core::option::Option +pub fn vortex_array::validity::Validity::not(&self) -> vortex_error::VortexResult + pub fn vortex_array::validity::Validity::nullability(&self) -> vortex_dtype::nullability::Nullability pub fn vortex_array::validity::Validity::patch(self, len: usize, indices_offset: usize, indices: &dyn vortex_array::Array, patches: &vortex_array::validity::Validity) -> vortex_error::VortexResult diff --git a/vortex-array/src/arrays/bool/compute/mask.rs b/vortex-array/src/arrays/bool/compute/mask.rs index b013cee08fe..65175438d24 100644 --- a/vortex-array/src/arrays/bool/compute/mask.rs +++ b/vortex-array/src/arrays/bool/compute/mask.rs @@ -2,25 +2,30 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for BoolVTable { - fn mask(&self, array: &BoolArray, mask: &Mask) -> VortexResult { - Ok(BoolArray::new(array.to_bit_buffer(), array.validity().mask(mask)).into_array()) +impl MaskReduce for BoolVTable { + fn mask(array: &BoolArray, mask: &ArrayRef) -> VortexResult> { + Ok(Some( + BoolArray::new( + array.to_bit_buffer(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + ) + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(BoolVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/bool/compute/rules.rs b/vortex-array/src/arrays/bool/compute/rules.rs index ab3f5ef50b9..14a7295c4ba 100644 --- a/vortex-array/src/arrays/bool/compute/rules.rs +++ b/vortex-array/src/arrays/bool/compute/rules.rs @@ -11,6 +11,7 @@ use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; @@ -18,6 +19,7 @@ use crate::vtable::ValidityHelper; pub(crate) const RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&BoolMaskedValidityRule), ParentRuleSet::lift(&CastReduceAdaptor(BoolVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(BoolVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(BoolVTable)), ]); diff --git a/vortex-array/src/arrays/chunked/compute/kernel.rs b/vortex-array/src/arrays/chunked/compute/kernel.rs index e933961d816..84f8bf97557 100644 --- a/vortex-array/src/arrays/chunked/compute/kernel.rs +++ b/vortex-array/src/arrays/chunked/compute/kernel.rs @@ -5,10 +5,12 @@ use crate::arrays::ChunkedVTable; use crate::arrays::FilterExecuteAdaptor; use crate::arrays::SliceExecuteAdaptor; use crate::arrays::TakeExecuteAdaptor; +use crate::compute::MaskExecuteAdaptor; use crate::kernel::ParentKernelSet; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ - ParentKernelSet::lift(&SliceExecuteAdaptor(ChunkedVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ChunkedVTable)), + ParentKernelSet::lift(&MaskExecuteAdaptor(ChunkedVTable)), + ParentKernelSet::lift(&SliceExecuteAdaptor(ChunkedVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(ChunkedVTable)), ]); diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index 7c6aa7b87ea..88a8f0898d5 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -1,153 +1,43 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use itertools::Itertools as _; -use vortex_buffer::BitBuffer; -use vortex_buffer::BitBufferMut; -use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_mask::AllOr; -use vortex_mask::Mask; -use vortex_mask::MaskIter; -use vortex_scalar::Scalar; -use super::filter::ChunkFilter; -use super::filter::chunk_filters; -use super::filter::find_chunk_idx; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; -use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; use crate::arrays::ChunkedVTable; -use crate::arrays::ConstantArray; -use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD; -use crate::builtins::ArrayBuiltins; +use crate::arrays::ScalarFnArrayExt; use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::compute::mask; -use crate::register_kernel; -use crate::validity::Validity; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; impl MaskKernel for ChunkedVTable { - fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { - let new_dtype = array.dtype().as_nullable(); - let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - AllOr::All => unreachable!("handled in top-level mask"), - AllOr::None => unreachable!("handled in top-level mask"), - AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), - AllOr::Some(MaskIter::Slices(slices)) => { - mask_slices(array, slices.iter().cloned(), &new_dtype) - } - }?; - debug_assert_eq!(new_chunks.len(), array.nchunks()); - debug_assert_eq!( - new_chunks.iter().map(|x| x.len()).sum::(), - array.len() - ); - ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array()) + fn mask( + array: &ChunkedArray, + mask: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let chunk_offsets = array.chunk_offsets(); + let new_chunks: Vec = array + .chunks() + .iter() + .enumerate() + .map(|(i, chunk)| { + let start: usize = chunk_offsets[i].try_into()?; + let end: usize = chunk_offsets[i + 1].try_into()?; + let chunk_mask = mask.slice(start..end)?; + MaskExpr.try_new_array(chunk.len(), EmptyOptions, [chunk.clone(), chunk_mask]) + }) + .collect::>()?; + + Ok(Some( + ChunkedArray::try_new(new_chunks, array.dtype().as_nullable())?.into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ChunkedVTable).lift()); - -fn mask_indices( - array: &ChunkedArray, - indices: &[usize], - new_dtype: &DType, -) -> VortexResult> { - let mut new_chunks = Vec::with_capacity(array.nchunks()); - let mut current_chunk_id = 0; - let mut chunk_indices = Vec::::new(); - - let chunk_offsets = array.chunk_offsets(); - - for &set_index in indices { - let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets)?; - if chunk_id != current_chunk_id { - let chunk = array.chunk(current_chunk_id).clone(); - let chunk_len = chunk.len(); - // chunk_indices contains indices to null out, but chunk.mask() expects - // mask=true to mean "retain". So we create a mask with bits set at indices - // to null, then invert it to get mask=true at indices to retain. - let mask = BoolArray::new( - !BitBuffer::from_indices(chunk_len, &chunk_indices), - Validity::NonNullable, - ) - .into_array(); - let masked_chunk = chunk.mask(mask)?; - // Advance the chunk forward, reset the chunk indices buffer. - chunk_indices = Vec::new(); - new_chunks.push(masked_chunk); - current_chunk_id += 1; - - while current_chunk_id < chunk_id { - // Chunks that are not affected by the mask, must still be casted to the correct dtype. - let chunk = array.chunk(current_chunk_id).cast(new_dtype.clone())?; - new_chunks.push(chunk); - current_chunk_id += 1; - } - } - - chunk_indices.push(index); - } - - if !chunk_indices.is_empty() { - let chunk = array.chunk(current_chunk_id).clone(); - let chunk_len = chunk.len(); - // Same inversion as above: invert the mask so mask=true means "retain" - let masked_chunk = chunk.mask( - BoolArray::new( - !BitBufferMut::from_indices(chunk_len, &chunk_indices).freeze(), - Validity::NonNullable, - ) - .into_array(), - )?; - new_chunks.push(masked_chunk); - current_chunk_id += 1; - } - - while current_chunk_id < array.nchunks() { - let chunk = array.chunk(current_chunk_id); - new_chunks.push(chunk.cast(new_dtype.clone())?); - current_chunk_id += 1; - } - - Ok(new_chunks) -} - -fn mask_slices( - array: &ChunkedArray, - slices: impl Iterator, - new_dtype: &DType, -) -> VortexResult> { - let chunked_filters = chunk_filters(array, slices)?; - - array - .chunks() - .iter() - .zip_eq(chunked_filters) - .map(|(chunk, chunk_filter)| -> VortexResult { - match chunk_filter { - ChunkFilter::All => { - // entire chunk is masked out - Ok( - ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()) - .into_array(), - ) - } - ChunkFilter::None => { - // entire chunk is not affected by mask - chunk.cast(new_dtype.clone()) - } - ChunkFilter::Slices(slices) => { - // Slices of indices that must be set to null - mask(chunk, &Mask::from_slices(chunk.len(), slices)) - } - } - }) - .process_results(|iter| iter.collect::>()) -} - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/constant/compute/mask.rs b/vortex-array/src/arrays/constant/compute/mask.rs deleted file mode 100644 index a1ad217e2b7..00000000000 --- a/vortex-array/src/arrays/constant/compute/mask.rs +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_dtype::Nullability; -use vortex_error::VortexResult; -use vortex_mask::Mask; - -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::arrays::MaskedArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; -use crate::validity::Validity; - -impl MaskKernel for ConstantVTable { - fn mask(&self, array: &ConstantArray, mask: &Mask) -> VortexResult { - MaskedArray::try_new( - array.to_array(), - Validity::from_mask(!mask, Nullability::Nullable), - ) - .map(|a| a.into_array()) - } -} - -register_kernel!(MaskKernelAdapter(ConstantVTable).lift()); - -#[cfg(test)] -mod test { - use crate::arrays::ConstantArray; - use crate::compute::conformance::mask::test_mask_conformance; - - #[test] - fn test_mask_constant() { - let array = ConstantArray::new(std::f64::consts::PI, 15); - test_mask_conformance(array.as_ref()); - } -} diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 2384df41a88..6793937dd6a 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -4,7 +4,6 @@ mod cast; mod fill_null; mod filter; -mod mask; mod min_max; mod not; pub(crate) mod rules; diff --git a/vortex-array/src/arrays/decimal/compute/mask.rs b/vortex-array/src/arrays/decimal/compute/mask.rs new file mode 100644 index 00000000000..35dde67a102 --- /dev/null +++ b/vortex-array/src/arrays/decimal/compute/mask.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::match_each_decimal_value_type; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::DecimalArray; +use crate::arrays::DecimalVTable; +use crate::compute::MaskReduce; +use crate::validity::Validity; +use crate::vtable::ValidityHelper; + +impl MaskReduce for DecimalVTable { + fn mask(array: &DecimalArray, mask: &ArrayRef) -> VortexResult> { + Ok(Some(match_each_decimal_value_type!( + array.values_type(), + |D| { + // SAFETY: masking the validity does not affect the invariants + unsafe { + DecimalArray::new_unchecked( + array.buffer::(), + array.decimal_dtype(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + ) + } + .into_array() + } + ))) + } +} diff --git a/vortex-array/src/arrays/decimal/compute/mod.rs b/vortex-array/src/arrays/decimal/compute/mod.rs index 2af9e6bbdfc..11a509b240d 100644 --- a/vortex-array/src/arrays/decimal/compute/mod.rs +++ b/vortex-array/src/arrays/decimal/compute/mod.rs @@ -6,6 +6,7 @@ mod cast; mod fill_null; mod is_constant; mod is_sorted; +mod mask; mod min_max; pub mod rules; mod sum; diff --git a/vortex-array/src/arrays/decimal/compute/rules.rs b/vortex-array/src/arrays/decimal/compute/rules.rs index 70130bf3e16..671b892f8f4 100644 --- a/vortex-array/src/arrays/decimal/compute/rules.rs +++ b/vortex-array/src/arrays/decimal/compute/rules.rs @@ -14,12 +14,14 @@ use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduce; use crate::arrays::SliceReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&DecimalMaskedValidityRule), + ParentRuleSet::lift(&MaskReduceAdaptor(DecimalVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DecimalVTable)), ]); diff --git a/vortex-array/src/arrays/dict/compute/mask.rs b/vortex-array/src/arrays/dict/compute/mask.rs new file mode 100644 index 00000000000..e6f9aca39a5 --- /dev/null +++ b/vortex-array/src/arrays/dict/compute/mask.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::DictArray; +use crate::arrays::DictVTable; +use crate::arrays::ScalarFnArrayExt; +use crate::compute::MaskReduce; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; + +impl MaskReduce for DictVTable { + fn mask(array: &DictArray, mask: &ArrayRef) -> VortexResult> { + let masked_codes = MaskExpr.try_new_array( + array.codes().len(), + EmptyOptions, + [array.codes().clone(), mask.clone()], + )?; + // SAFETY: masking codes doesn't change dict invariants + Ok(Some(unsafe { + DictArray::new_unchecked(masked_codes, array.values().clone()).into_array() + })) + } +} diff --git a/vortex-array/src/arrays/dict/compute/min_max.rs b/vortex-array/src/arrays/dict/compute/min_max.rs index 32a041784fc..8eaaa56d766 100644 --- a/vortex-array/src/arrays/dict/compute/min_max.rs +++ b/vortex-array/src/arrays/dict/compute/min_max.rs @@ -2,17 +2,19 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use super::DictArray; use super::DictVTable; use crate::Array as _; +use crate::IntoArray; +use crate::arrays::BoolArray; +use crate::builtins::ArrayBuiltins; use crate::compute::MinMaxKernel; use crate::compute::MinMaxKernelAdapter; use crate::compute::MinMaxResult; -use crate::compute::mask; use crate::compute::min_max; use crate::register_kernel; +use crate::validity::Validity; impl MinMaxKernel for DictVTable { fn min_max(&self, array: &DictArray) -> VortexResult> { @@ -27,8 +29,13 @@ impl MinMaxKernel for DictVTable { } // Slow path: compute which values are unreferenced and mask them out - let unreferenced_mask = Mask::from_buffer(array.compute_referenced_values_mask(false)?); - min_max(&mask(array.values(), &unreferenced_mask)?) + let unreferenced_mask = BoolArray::new( + array.compute_referenced_values_mask(true)?, + Validity::NonNullable, + ) + .into_array(); + + min_max(&array.values().clone().mask(unreferenced_mask)?) } } diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 056b151ec06..97cd46fced7 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -7,6 +7,7 @@ mod fill_null; mod is_constant; mod is_sorted; mod like; +mod mask; mod min_max; pub(crate) mod rules; mod slice; diff --git a/vortex-array/src/arrays/dict/compute/rules.rs b/vortex-array/src/arrays/dict/compute/rules.rs index 583e52db2b4..2ee25fc5fab 100644 --- a/vortex-array/src/arrays/dict/compute/rules.rs +++ b/vortex-array/src/arrays/dict/compute/rules.rs @@ -20,6 +20,7 @@ use crate::builtins::ArrayBuiltins; use crate::compute::CastReduceAdaptor; use crate::expr::Cast; use crate::expr::LikeReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::expr::Pack; use crate::optimizer::ArrayOptimizer; use crate::optimizer::rules::ArrayParentReduceRule; @@ -28,6 +29,7 @@ use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(DictVTable)), ParentRuleSet::lift(&CastReduceAdaptor(DictVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DictVTable)), ParentRuleSet::lift(&LikeReduceAdaptor(DictVTable)), ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule), ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule), diff --git a/vortex-array/src/arrays/extension/compute/mask.rs b/vortex-array/src/arrays/extension/compute/mask.rs index 1b44447b9dc..4b88575e6e1 100644 --- a/vortex-array/src/arrays/extension/compute/mask.rs +++ b/vortex-array/src/arrays/extension/compute/mask.rs @@ -2,32 +2,31 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; -use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::compute::mask; -use crate::register_kernel; +use crate::arrays::ScalarFnArrayExt; +use crate::compute::MaskReduce; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; -impl MaskKernel for ExtensionVTable { - fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult { - // Use compute::mask directly since mask_array has compute::mask semantics (true=null) - let masked_storage = mask(array.storage(), mask_array)?; - assert!(masked_storage.dtype().is_nullable()); - - Ok(ExtensionArray::new( - array - .ext_dtype() - .with_nullability(masked_storage.dtype().nullability()), - masked_storage, - ) - .into_array()) +impl MaskReduce for ExtensionVTable { + fn mask(array: &ExtensionArray, mask: &ArrayRef) -> VortexResult> { + let masked_storage = MaskExpr.try_new_array( + array.storage().len(), + EmptyOptions, + [array.storage().clone(), mask.clone()], + )?; + Ok(Some( + ExtensionArray::new( + array + .ext_dtype() + .with_nullability(masked_storage.dtype().nullability()), + masked_storage, + ) + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(ExtensionVTable).lift()); diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index f18687d570a..ee3b883a9f7 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -12,6 +12,7 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::FilterVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -19,6 +20,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::n ParentRuleSet::lift(&ExtensionFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(ExtensionVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(ExtensionVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ExtensionVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ExtensionVTable)), ]); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/mask.rs b/vortex-array/src/arrays/fixed_size_list/compute/mask.rs index 26232addc37..2aba6829fd1 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/mask.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/mask.rs @@ -2,34 +2,31 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -/// Mask implementation for [`FixedSizeListArray`]. -/// -/// Applies a validity mask to the array without modifying the underlying element data. -impl MaskKernel for FixedSizeListVTable { - fn mask(&self, array: &FixedSizeListArray, mask: &Mask) -> VortexResult { - // SAFETY: The only thing that changes here is the validity mask, which will have the same - // length. So as long as the original array is valid, this is also valid. - Ok(unsafe { - FixedSizeListArray::new_unchecked( - array.elements().clone(), - array.list_size(), - array.validity().mask(mask), - array.len(), - ) - } - .into_array()) +impl MaskReduce for FixedSizeListVTable { + fn mask(array: &FixedSizeListArray, mask: &ArrayRef) -> VortexResult> { + // SAFETY: masking the validity does not affect the invariants + Ok(Some( + unsafe { + FixedSizeListArray::new_unchecked( + array.elements().clone(), + array.list_size(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + array.len(), + ) + } + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(FixedSizeListVTable).lift()); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs index 5ba1e777ec2..af3bc242a14 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FixedSizeListVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(FixedSizeListVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(FixedSizeListVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(FixedSizeListVTable)), ]); diff --git a/vortex-array/src/arrays/list/compute/mask.rs b/vortex-array/src/arrays/list/compute/mask.rs index 877386133c1..59cca6ad2d2 100644 --- a/vortex-array/src/arrays/list/compute/mask.rs +++ b/vortex-array/src/arrays/list/compute/mask.rs @@ -2,26 +2,25 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ListArray; use crate::arrays::ListVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for ListVTable { - fn mask(&self, array: &ListArray, mask: &Mask) -> VortexResult { +impl MaskReduce for ListVTable { + fn mask(array: &ListArray, mask: &ArrayRef) -> VortexResult> { ListArray::try_new( array.elements().clone(), array.offsets().clone(), - array.validity().mask(mask), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, ) - .map(|a| a.into_array()) + .map(|a| Some(a.into_array())) } } - -register_kernel!(MaskKernelAdapter(ListVTable).lift()); diff --git a/vortex-array/src/arrays/list/compute/rules.rs b/vortex-array/src/arrays/list/compute/rules.rs index 70900ea11e8..8ea7043fd1f 100644 --- a/vortex-array/src/arrays/list/compute/rules.rs +++ b/vortex-array/src/arrays/list/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::ListVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(ListVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ListVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ListVTable)), ]); diff --git a/vortex-array/src/arrays/listview/compute/mask.rs b/vortex-array/src/arrays/listview/compute/mask.rs index e34ec88527f..02a818d0491 100644 --- a/vortex-array/src/arrays/listview/compute/mask.rs +++ b/vortex-array/src/arrays/listview/compute/mask.rs @@ -2,32 +2,32 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for ListViewVTable { - fn mask(&self, array: &ListViewArray, mask: &Mask) -> VortexResult { - // SAFETY: Since we are only masking the validity and everything else comes from an already - // valid `ListViewArray`, all of the invariants are still upheld. - Ok(unsafe { - ListViewArray::new_unchecked( - array.elements().clone(), - array.offsets().clone(), - array.sizes().clone(), - array.validity().mask(mask), - ) - .with_zero_copy_to_list(array.is_zero_copy_to_list()) - } - .into_array()) +impl MaskReduce for ListViewVTable { + fn mask(array: &ListViewArray, mask: &ArrayRef) -> VortexResult> { + // SAFETY: masking the validity does not affect the invariants + Ok(Some( + unsafe { + ListViewArray::new_unchecked( + array.elements().clone(), + array.offsets().clone(), + array.sizes().clone(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + ) + .with_zero_copy_to_list(array.is_zero_copy_to_list()) + } + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(ListViewVTable).lift()); diff --git a/vortex-array/src/arrays/listview/compute/rules.rs b/vortex-array/src/arrays/listview/compute/rules.rs index 24c25fa1b3e..a293ae4655f 100644 --- a/vortex-array/src/arrays/listview/compute/rules.rs +++ b/vortex-array/src/arrays/listview/compute/rules.rs @@ -12,6 +12,7 @@ use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; @@ -19,6 +20,7 @@ use crate::vtable::ValidityHelper; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&ListViewFilterPushDown), ParentRuleSet::lift(&CastReduceAdaptor(ListViewVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ListViewVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ListViewVTable)), ]); diff --git a/vortex-array/src/arrays/masked/compute/mask.rs b/vortex-array/src/arrays/masked/compute/mask.rs index a8d3c3eacd2..d1b1f35287f 100644 --- a/vortex-array/src/arrays/masked/compute/mask.rs +++ b/vortex-array/src/arrays/masked/compute/mask.rs @@ -2,29 +2,34 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask as MaskType; use crate::ArrayRef; -use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::arrays::ScalarFnArrayExt; +use crate::compute::MaskReduce; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for MaskedVTable { - fn mask(&self, array: &MaskedArray, mask_arg: &MaskType) -> VortexResult { - // Combine the mask with the existing validity - // The child remains unchanged (no nulls), only validity is updated - let combined_validity = array.validity().mask(mask_arg); - - Ok(MaskedArray::try_new(array.child.clone(), combined_validity)?.into_array()) +impl MaskReduce for MaskedVTable { + fn mask(array: &MaskedArray, mask: &ArrayRef) -> VortexResult> { + // AND the existing validity mask with the new mask and push into child. + let combined_mask = array + .validity() + .clone() + .and(Validity::Array(mask.clone()))? + .to_array(array.len()); + let masked_child = MaskExpr.try_new_array( + array.child.len(), + EmptyOptions, + [array.child.clone(), combined_mask], + )?; + Ok(Some(masked_child)) } } -register_kernel!(MaskKernelAdapter(MaskedVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-array/src/arrays/masked/compute/rules.rs b/vortex-array/src/arrays/masked/compute/rules.rs index 597f1844bb4..b4b7a527259 100644 --- a/vortex-array/src/arrays/masked/compute/rules.rs +++ b/vortex-array/src/arrays/masked/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduceAdaptor; +use crate::compute::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(MaskedVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(MaskedVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(MaskedVTable)), ]); diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index 00e803de64e..496aad3a58a 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -153,8 +153,11 @@ fn mask_validity_extension( // For extension arrays, we need to mask the underlying storage let storage = array.storage().clone().execute::(ctx)?; let masked_storage = mask_validity_canonical(storage, mask, ctx)?; + let masked_storage = masked_storage.into_array(); Ok(ExtensionArray::new( - array.ext_dtype().clone(), - masked_storage.into_array(), + array + .ext_dtype() + .with_nullability(masked_storage.dtype().nullability()), + masked_storage, )) } diff --git a/vortex-array/src/arrays/null/compute/mask.rs b/vortex-array/src/arrays/null/compute/mask.rs index c513efc3bd0..37a9f506604 100644 --- a/vortex-array/src/arrays/null/compute/mask.rs +++ b/vortex-array/src/arrays/null/compute/mask.rs @@ -2,19 +2,15 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::arrays::NullArray; use crate::arrays::NullVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; -impl MaskKernel for NullVTable { - fn mask(&self, array: &NullArray, _mask: &Mask) -> VortexResult { - Ok(array.to_array()) +impl MaskReduce for NullVTable { + fn mask(array: &NullArray, _mask: &ArrayRef) -> VortexResult> { + // Null array is already all nulls, masking has no effect. + Ok(Some(array.to_array())) } } - -register_kernel!(MaskKernelAdapter(NullVTable).lift()); diff --git a/vortex-array/src/arrays/null/compute/rules.rs b/vortex-array/src/arrays/null/compute/rules.rs index 7fd25840dfd..645422884c2 100644 --- a/vortex-array/src/arrays/null/compute/rules.rs +++ b/vortex-array/src/arrays/null/compute/rules.rs @@ -6,11 +6,13 @@ use crate::arrays::NullVTable; use crate::arrays::SliceReduceAdaptor; use crate::arrays::TakeReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(NullVTable)), ParentRuleSet::lift(&CastReduceAdaptor(NullVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(NullVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(NullVTable)), ParentRuleSet::lift(&TakeReduceAdaptor(NullVTable)), ]); diff --git a/vortex-array/src/arrays/primitive/compute/mask.rs b/vortex-array/src/arrays/primitive/compute/mask.rs index 545d0847618..3adbca37a13 100644 --- a/vortex-array/src/arrays/primitive/compute/mask.rs +++ b/vortex-array/src/arrays/primitive/compute/mask.rs @@ -2,35 +2,32 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::PrimitiveVTable; use crate::arrays::primitive::PrimitiveArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for PrimitiveVTable { - fn mask(&self, array: &PrimitiveArray, mask: &Mask) -> VortexResult { - let validity = array.validity().mask(mask); - +impl MaskReduce for PrimitiveVTable { + fn mask(array: &PrimitiveArray, mask: &ArrayRef) -> VortexResult> { // SAFETY: validity and data buffer still have same length - Ok(unsafe { + Ok(Some(unsafe { PrimitiveArray::new_unchecked_from_handle( array.buffer_handle().clone(), array.ptype(), - validity, + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, ) .into_array() - }) + })) } } -register_kernel!(MaskKernelAdapter(PrimitiveVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/primitive/compute/rules.rs b/vortex-array/src/arrays/primitive/compute/rules.rs index 9bf20fd5cf6..c32e8e52dd5 100644 --- a/vortex-array/src/arrays/primitive/compute/rules.rs +++ b/vortex-array/src/arrays/primitive/compute/rules.rs @@ -11,12 +11,14 @@ use crate::arrays::MaskedVTable; use crate::arrays::PrimitiveArray; use crate::arrays::PrimitiveVTable; use crate::arrays::SliceReduceAdaptor; +use crate::compute::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; pub(crate) const RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&PrimitiveMaskedValidityRule), + ParentRuleSet::lift(&MaskReduceAdaptor(PrimitiveVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(PrimitiveVTable)), ]); diff --git a/vortex-array/src/arrays/struct_/compute/mask.rs b/vortex-array/src/arrays/struct_/compute/mask.rs index 75f40f81b17..80f57614933 100644 --- a/vortex-array/src/arrays/struct_/compute/mask.rs +++ b/vortex-array/src/arrays/struct_/compute/mask.rs @@ -2,28 +2,26 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for StructVTable { - fn mask(&self, array: &StructArray, filter_mask: &Mask) -> VortexResult { - let validity = array.validity().mask(filter_mask); - +impl MaskReduce for StructVTable { + fn mask(array: &StructArray, mask: &ArrayRef) -> VortexResult> { StructArray::try_new_with_dtype( array.unmasked_fields().clone(), array.struct_fields().clone(), array.len(), - validity, + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, ) - .map(|a| a.into_array()) + .map(|a| Some(a.into_array())) } } -register_kernel!(MaskKernelAdapter(StructVTable).lift()); diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index ff2df148c23..940d5a54c78 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -15,6 +15,7 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::builtins::ArrayBuiltins; +use crate::compute::MaskReduceAdaptor; use crate::expr::Cast; use crate::expr::EmptyOptions; use crate::expr::GetItem; @@ -27,6 +28,7 @@ use crate::vtable::ValidityHelper; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&StructCastPushDownRule), ParentRuleSet::lift(&StructGetItemRule), + ParentRuleSet::lift(&MaskReduceAdaptor(StructVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(StructVTable)), ]); diff --git a/vortex-array/src/arrays/varbin/compute/mask.rs b/vortex-array/src/arrays/varbin/compute/mask.rs index 5119ebbeb82..57cf8c89e36 100644 --- a/vortex-array/src/arrays/varbin/compute/mask.rs +++ b/vortex-array/src/arrays/varbin/compute/mask.rs @@ -2,31 +2,32 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::VarBinVTable; use crate::arrays::varbin::VarBinArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for VarBinVTable { - fn mask(&self, array: &VarBinArray, mask: &Mask) -> VortexResult { - Ok(VarBinArray::try_new( - array.offsets().clone(), - array.bytes().clone(), - array.dtype().as_nullable(), - array.validity().mask(mask), - )? - .into_array()) +impl MaskReduce for VarBinVTable { + fn mask(array: &VarBinArray, mask: &ArrayRef) -> VortexResult> { + Ok(Some( + VarBinArray::try_new( + array.offsets().clone(), + array.bytes().clone(), + array.dtype().as_nullable(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + )? + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(VarBinVTable).lift()); - #[cfg(test)] mod test { use vortex_dtype::DType; diff --git a/vortex-array/src/arrays/varbin/compute/rules.rs b/vortex-array/src/arrays/varbin/compute/rules.rs index df9f7f7913e..478d97e4746 100644 --- a/vortex-array/src/arrays/varbin/compute/rules.rs +++ b/vortex-array/src/arrays/varbin/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::VarBinVTable; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(VarBinVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(VarBinVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(VarBinVTable)), ]); diff --git a/vortex-array/src/arrays/varbinview/compute/mask.rs b/vortex-array/src/arrays/varbinview/compute/mask.rs index dd32dac5c8b..066b392f657 100644 --- a/vortex-array/src/arrays/varbinview/compute/mask.rs +++ b/vortex-array/src/arrays/varbinview/compute/mask.rs @@ -2,34 +2,35 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for VarBinViewVTable { - fn mask(&self, array: &VarBinViewArray, mask: &Mask) -> VortexResult { +impl MaskReduce for VarBinViewVTable { + fn mask(array: &VarBinViewArray, mask: &ArrayRef) -> VortexResult> { // SAFETY: masking the validity does not affect the invariants unsafe { - Ok(VarBinViewArray::new_handle_unchecked( - array.views_handle().clone(), - array.buffers().clone(), - array.dtype().as_nullable(), - array.validity().mask(mask), - ) - .into_array()) + Ok(Some( + VarBinViewArray::new_handle_unchecked( + array.views_handle().clone(), + array.buffers().clone(), + array.dtype().as_nullable(), + array + .validity() + .clone() + .and(Validity::Array(mask.clone()))?, + ) + .into_array(), + )) } } } -register_kernel!(MaskKernelAdapter(VarBinViewVTable).lift()); - #[cfg(test)] mod tests { use crate::arrays::VarBinViewArray; diff --git a/vortex-array/src/arrays/varbinview/compute/rules.rs b/vortex-array/src/arrays/varbinview/compute/rules.rs index 9b1900ef277..0bbd12d1540 100644 --- a/vortex-array/src/arrays/varbinview/compute/rules.rs +++ b/vortex-array/src/arrays/varbinview/compute/rules.rs @@ -3,9 +3,11 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::VarBinViewVTable; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(VarBinViewVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(VarBinViewVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(VarBinViewVTable)), ]); diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs index b92aa0ab3b8..cd9710b9536 100644 --- a/vortex-array/src/compute/mask.rs +++ b/vortex-array/src/compute/mask.rs @@ -1,15 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; +use std::ops::Not; -use arcref::ArcRef; -use arrow_array::BooleanArray; -use vortex_dtype::DType; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_mask::Mask; use vortex_scalar::Scalar; @@ -17,184 +11,24 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; -use crate::arrow::FromArrowArray; -use crate::arrow::IntoArrowArray; use crate::builtins::ArrayBuiltins; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; -use crate::vtable::VTable; - -static MASK_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - MASK_FN.kernels().len() -} /// Replace values with null where the mask is true. /// /// The returned array is nullable but otherwise has the same dtype and length as `array`. /// -/// # Examples -/// -/// ``` -/// use vortex_array::IntoArray; -/// use vortex_array::arrays::{BoolArray, PrimitiveArray}; -/// use vortex_array::compute::{ mask}; -/// use vortex_error::VortexResult; -/// use vortex_mask::Mask; -/// use vortex_scalar::Scalar; -/// -/// # fn main() -> VortexResult<()> { -/// let array = -/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]); -/// let mask_array = Mask::from_iter([true, false, false, false, true]); -/// -/// let masked = mask(array.as_ref(), &mask_array)?; -/// assert_eq!(masked.len(), 5); -/// assert!(!masked.is_valid(0).unwrap()); -/// assert!(!masked.is_valid(1).unwrap()); -/// assert_eq!(masked.scalar_at(2)?, Scalar::from(Some(1))); -/// assert!(!masked.is_valid(3).unwrap()); -/// assert!(!masked.is_valid(4).unwrap()); -/// # Ok(()) -/// # } -/// ``` -/// +/// This function returns a lazy `ScalarFnArray` wrapping the [`Mask`](crate::expr::mask::Mask) +/// expression that defers the actual masking operation until execution time. The mask is inverted +/// (true=mask-out becomes true=keep) and passed as a boolean child to the expression. pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult { - MASK_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), mask.into()], - options: &(), - })? - .unwrap_array() -} - -pub struct MaskKernelRef(ArcRef); -inventory::collect!(MaskKernelRef); - -pub trait MaskKernel: VTable { - /// Replace masked values with null in array. - fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult; -} - -#[derive(Debug)] -pub struct MaskKernelAdapter(pub V); - -impl MaskKernelAdapter { - pub const fn lift(&'static self) -> MaskKernelRef { - MaskKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for MaskKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = MaskArgs::try_from(args)?; - let Some(array) = inputs.array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::mask(&self.0, array, inputs.mask)?.into())) - } -} - -struct MaskFn; - -impl ComputeFnVTable for MaskFn { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let MaskArgs { array, mask } = MaskArgs::try_from(args)?; - - let mask_true_count = mask.true_count(); - if mask_true_count == 0 { - // Fast-path for empty mask - return Ok(array.to_array().cast(array.dtype().as_nullable())?.into()); - } - - if mask_true_count == mask.len() { - // Fast-path for full mask. - return Ok( - ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len()) - .into_array() - .into(), - ); - } - - // Do nothing if the array is already all nulls. - if array.all_invalid()? { - return Ok(array.to_array().into()); - } - - for kernel in kernels { - if let Some(output) = kernel.invoke(args)? { - return Ok(output); - } - } - - // Fallback: implement using Arrow kernels. - tracing::debug!("No mask implementation found for {}", array.encoding_id()); - - let array_ref = array.to_array().into_arrow_preferred()?; - let mask = BooleanArray::new(mask.to_bit_buffer().into(), None); - - let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?; - - Ok(ArrayRef::from_arrow(masked.as_ref(), true)?.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let MaskArgs { array, .. } = MaskArgs::try_from(args)?; - Ok(array.dtype().as_nullable()) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let MaskArgs { array, mask } = MaskArgs::try_from(args)?; - - if mask.len() != array.len() { - vortex_bail!( - "mask.len() is {}, does not equal array.len() of {}", - mask.len(), - array.len() - ); - } - - Ok(mask.len()) - } - - fn is_elementwise(&self) -> bool { - true - } -} - -struct MaskArgs<'a> { - array: &'a dyn Array, - mask: &'a Mask, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Mask function requires 2 arguments"); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?; - let mask = value.inputs[1] - .mask() - .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?; - - Ok(MaskArgs { array, mask }) + let mask = mask.not(); + match mask { + Mask::AllTrue(_) => array.to_array().cast(array.dtype().as_nullable()), + Mask::AllFalse(_) => Ok(ConstantArray::new( + Scalar::null(array.dtype().as_nullable()), + array.len(), + ) + .into_array()), + Mask::Values(val) => array.to_array().mask(val.into_array()), } } diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 87298c1e6e5..049c03af710 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -54,6 +54,10 @@ pub use crate::expr::FillNullExecuteAdaptor; pub use crate::expr::FillNullKernel; pub use crate::expr::FillNullReduce; pub use crate::expr::FillNullReduceAdaptor; +pub use crate::expr::MaskExecuteAdaptor; +pub use crate::expr::MaskKernel; +pub use crate::expr::MaskReduce; +pub use crate::expr::MaskReduceAdaptor; pub use crate::expr::NotExecuteAdaptor; pub use crate::expr::NotKernel; pub use crate::expr::NotReduce; @@ -97,7 +101,6 @@ pub fn warm_up_vtables() { is_constant::warm_up_vtable(); is_sorted::warm_up_vtable(); list_contains::warm_up_vtable(); - mask::warm_up_vtable(); min_max::warm_up_vtable(); nan_count::warm_up_vtable(); sum::warm_up_vtable(); diff --git a/vortex-array/src/expr/exprs/get_item.rs b/vortex-array/src/expr/exprs/get_item.rs index ca0d93740f5..76d88a5b23e 100644 --- a/vortex-array/src/expr/exprs/get_item.rs +++ b/vortex-array/src/expr/exprs/get_item.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Formatter; -use std::ops::Not; use prost::Message; use vortex_dtype::DType; @@ -17,8 +16,8 @@ use vortex_session::VortexSession; use crate::ArrayRef; use crate::arrays::StructArray; +use crate::builtins::ArrayBuiltins; use crate::builtins::ExprBuiltins; -use crate::compute::mask; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -116,7 +115,7 @@ impl VTable for GetItem { match input.dtype().nullability() { Nullability::NonNullable => Ok(field), - Nullability::Nullable => mask(&field, &input.validity_mask()?.not()), + Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())), }? .execute(args.ctx) } diff --git a/vortex-array/src/expr/exprs/mask/kernel.rs b/vortex-array/src/expr/exprs/mask/kernel.rs new file mode 100644 index 00000000000..427241eab23 --- /dev/null +++ b/vortex-array/src/expr/exprs/mask/kernel.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::BoolVTable; +use crate::arrays::ExactScalarFn; +use crate::arrays::ScalarFnArrayView; +use crate::expr::Mask as MaskExpr; +use crate::kernel::ExecuteParentKernel; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::vtable::VTable; + +/// Mask an array without reading buffers. +/// +/// This trait is for mask implementations that can operate purely on array metadata and +/// structure without needing to read or execute on the underlying buffers. Implementations +/// should return `None` if masking requires buffer access. +/// +/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out. +/// +/// # Preconditions +/// +/// The mask is guaranteed to have the same length as the array. Trivial cases +/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch. +pub trait MaskReduce: VTable { + fn mask(array: &Self::Array, mask: &ArrayRef) -> VortexResult>; +} + +/// Mask an array, potentially reading buffers. +/// +/// Unlike [`MaskReduce`], this trait is for mask implementations that may need to read +/// and execute on the underlying buffers to produce the masked result. +/// +/// The `mask` parameter is a boolean array where true=keep/valid, false=null-out. +/// +/// # Preconditions +/// +/// The mask is guaranteed to have the same length as the array. Trivial cases +/// (`AllValid`, `AllInvalid`, `NonNullable`) are handled by the caller before dispatch. +pub trait MaskKernel: VTable { + fn mask( + array: &Self::Array, + mask: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; +} + +/// Adaptor that wraps a [`MaskReduce`] impl as an [`ArrayParentReduceRule`]. +#[derive(Default, Debug)] +pub struct MaskReduceAdaptor(pub V); + +impl ArrayParentReduceRule for MaskReduceAdaptor +where + V: MaskReduce, +{ + type Parent = ExactScalarFn; + + fn reduce_parent( + &self, + array: &V::Array, + parent: ScalarFnArrayView<'_, MaskExpr>, + child_idx: usize, + ) -> VortexResult> { + // Only reduce the input child (index 0), not the mask child (index 1). + if child_idx != 0 { + return Ok(None); + } + // The mask child (child 1) is a non-nullable BoolArray where true=keep. + // If it's not yet a BoolArray, we can't reduce without execution. + let mask_child = parent + .nth_child(1) + .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?; + if mask_child.as_opt::().is_none() { + return Ok(None); + }; + ::mask(array, &mask_child) + } +} + +/// Adaptor that wraps a [`MaskKernel`] impl as an [`ExecuteParentKernel`]. +#[derive(Default, Debug)] +pub struct MaskExecuteAdaptor(pub V); + +impl ExecuteParentKernel for MaskExecuteAdaptor +where + V: MaskKernel, +{ + type Parent = ExactScalarFn; + + fn execute_parent( + &self, + array: &V::Array, + parent: ScalarFnArrayView<'_, MaskExpr>, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Only execute the input child (index 0), not the mask child (index 1). + if child_idx != 0 { + return Ok(None); + } + let mask_child = parent + .nth_child(1) + .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?; + ::mask(array, &mask_child, ctx) + } +} diff --git a/vortex-array/src/expr/exprs/mask.rs b/vortex-array/src/expr/exprs/mask/mod.rs similarity index 72% rename from vortex-array/src/expr/exprs/mask.rs rename to vortex-array/src/expr/exprs/mask/mod.rs index d93618829d1..e427111fb0e 100644 --- a/vortex-array/src/expr/exprs/mask.rs +++ b/vortex-array/src/expr/exprs/mask/mod.rs @@ -1,9 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod kernel; use std::fmt::Formatter; -use std::ops::Not; +pub use kernel::*; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_error::VortexExpect; @@ -14,8 +15,13 @@ use vortex_scalar::Scalar; use vortex_session::VortexSession; use crate::ArrayRef; +use crate::Canonical; +use crate::IntoArray; use crate::arrays::BoolArray; -use crate::compute; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; +use crate::arrays::mask_validity_canonical; +use crate::builtins::ArrayBuiltins; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -94,9 +100,11 @@ impl VTable for Mask { .try_into() .map_err(|_| vortex_err!("Wrong arg count"))?; - let mask_bool = mask_array.execute::(args.ctx)?; - let inverted = mask_bool.to_bit_buffer().not(); - compute::mask(&input, &vortex_mask::Mask::from(inverted))?.execute(args.ctx) + if let Some(result) = execute_constant(&input, &mask_array)? { + return Ok(result); + } + + execute_canonical(input, mask_array, args.ctx) } fn simplify( @@ -136,6 +144,47 @@ impl VTable for Mask { } } +/// Try to handle masking when at least one of the input or mask is a constant array. +/// +/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not. +fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult> { + let len = input.len(); + + if let Some(constant_mask) = mask_array.as_opt::() { + let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false); + return if mask_value { + input.cast(input.dtype().as_nullable()).map(Some) + } else { + Ok(Some( + ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(), + )) + }; + } + + if let Some(constant_input) = input.as_opt::() + && constant_input.scalar().is_null() + { + return Ok(Some( + ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(), + )); + } + + Ok(None) +} + +/// Execute the mask by materializing both inputs to their canonical forms. +fn execute_canonical( + input: ArrayRef, + mask_array: ArrayRef, + ctx: &mut crate::executor::ExecutionCtx, +) -> VortexResult { + let mask_bool = mask_array.execute::(ctx)?; + let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer()); + + let canonical = input.execute::(ctx)?; + Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array()) +} + /// Creates a mask expression that applies the given boolean mask to the input array. pub fn mask(array: Expression, mask: Expression) -> Expression { Mask.new_expr(EmptyOptions, [array, mask]) diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 09f148166c8..b76a4f395c8 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -32,6 +32,7 @@ use crate::builtins::ArrayBuiltins; use crate::compute::sum; use crate::expr::Binary; use crate::expr::Operator; +use crate::optimizer::ArrayOptimizer; use crate::patches::Patches; /// Validity information for an array @@ -197,6 +198,16 @@ impl Validity { } } + // Invert the validity + pub fn not(&self) -> VortexResult { + match self { + Validity::NonNullable => Ok(Validity::NonNullable), + Validity::AllValid => Ok(Validity::AllInvalid), + Validity::AllInvalid => Ok(Validity::AllValid), + Validity::Array(arr) => Ok(Validity::Array(arr.not()?)), + } + } + /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which /// the mask is true. /// @@ -221,30 +232,6 @@ impl Validity { } } - /// Set to false any entries for which the mask is true. - /// - /// The result is always nullable. The result has the same length as self. - #[inline] - pub fn mask(&self, mask: &Mask) -> Self { - match mask.bit_buffer() { - AllOr::All => Validity::AllInvalid, - AllOr::None => self.clone().into_nullable(), - AllOr::Some(make_invalid) => match self { - Validity::NonNullable | Validity::AllValid => { - Validity::from_bit_buffer(!make_invalid, Nullability::Nullable) - } - Validity::AllInvalid => Validity::AllInvalid, - Validity::Array(is_valid) => { - let is_valid = is_valid.to_bool(); - Validity::from_bit_buffer( - is_valid.to_bit_buffer() & !make_invalid, - Nullability::Nullable, - ) - } - }, - } - } - #[inline] pub fn to_mask(&self, length: usize) -> Mask { match self { @@ -281,9 +268,11 @@ impl Validity { | (Validity::AllValid, Validity::NonNullable) | (Validity::AllValid, Validity::AllValid) => Validity::AllValid, // Here we actually have to do some work - (Validity::Array(lhs), Validity::Array(rhs)) => { - Validity::Array(Binary.try_new_array(lhs.len(), Operator::And, [lhs, rhs])?) - } + (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array( + Binary + .try_new_array(lhs.len(), Operator::And, [lhs, rhs])? + .optimize()?, + ), }) } @@ -665,12 +654,4 @@ mod tests { ) { assert_eq!(validity.take(&indices).unwrap(), expected); } - - #[test] - fn mask_non_nullable() { - assert_eq!( - Validity::AllValid, - Validity::NonNullable.mask(&Mask::AllFalse(2)) - ) - } } diff --git a/vortex-duckdb/src/exporter/struct_.rs b/vortex-duckdb/src/exporter/struct_.rs index 140a4ded143..eda7024106b 100644 --- a/vortex-duckdb/src/exporter/struct_.rs +++ b/vortex-duckdb/src/exporter/struct_.rs @@ -1,16 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::ops::Not; - use vortex::array::ExecutionCtx; use vortex::array::IntoArray; +use vortex::array::arrays::BoolArray; use vortex::array::arrays::StructArray; use vortex::array::arrays::StructArrayParts; -use vortex::array::optimizer::ArrayOptimizer; -use vortex::compute::mask; +use vortex::array::builtins::ArrayBuiltins; use vortex::error::VortexResult; -use vortex::mask::Mask; use crate::LogicalType; use crate::duckdb::Vector; @@ -36,9 +33,9 @@ pub(crate) fn new_exporter( fields, .. } = array.into_parts(); - let validity = validity.to_array(len).execute::(ctx)?; + let validity = validity.to_array(len).execute::(ctx)?; - if validity.all_false() { + if validity.to_bit_buffer().true_count() == 0 { return Ok(all_invalid::new_exporter( len, &LogicalType::try_from(struct_fields)?, @@ -48,20 +45,16 @@ pub(crate) fn new_exporter( let children = fields .iter() .map(|child| { - if matches!(validity, Mask::Values(_)) { + if validity.to_bit_buffer().true_count() != validity.len() { // TODO(joe): use new mask. - new_array_exporter( - mask(child, &validity.clone().not())?.optimize()?, - cache, - ctx, - ) + new_array_exporter(child.clone().mask(validity.to_array())?, cache, ctx) } else { new_array_exporter(child.clone().into_array(), cache, ctx) } }) .collect::>>()?; Ok(validity::new_exporter( - validity, + validity.to_mask(), Box::new(StructExporter { children }), )) } diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs index acc9f07f1b8..27c9f2987c2 100644 --- a/vortex-layout/src/layouts/struct_/reader.rs +++ b/vortex-layout/src/layouts/struct_/reader.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::collections::BTreeSet; -use std::ops::Not; use std::ops::Range; use std::sync::Arc; @@ -13,6 +12,7 @@ use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::ToCanonical; use vortex_array::arrays::StructArray; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::expr::ExactExpr; use vortex_array::expr::Expression; use vortex_array::expr::Merge; @@ -343,7 +343,6 @@ impl LayoutReader for StructReader { Ok(Box::pin(async move { if let Some(validity_fut) = validity_fut { let (array, validity) = try_join!(projected, validity_fut)?; - let mask = Mask::from_buffer(validity.to_bool().to_bit_buffer().not()); // If root expression was a pack, then we apply the validity to each child field if is_pack_merge { @@ -351,7 +350,7 @@ impl LayoutReader for StructReader { let masked_fields: Vec = struct_array .unmasked_fields() .iter() - .map(|a| vortex_array::compute::mask(a.as_ref(), &mask)) + .map(|a| a.clone().mask(validity.clone())) .try_collect()?; Ok(StructArray::try_new( @@ -364,7 +363,7 @@ impl LayoutReader for StructReader { } else { // If the root expression was not a pack or merge, e.g. if it's something like // a get_item, then we apply the validity directly to the result - vortex_array::compute::mask(array.as_ref(), &mask) + array.mask(validity) } } else { projected.await diff --git a/vortex-layout/src/reader.rs b/vortex-layout/src/reader.rs index e26e33372e5..2072c5c209c 100644 --- a/vortex-layout/src/reader.rs +++ b/vortex-layout/src/reader.rs @@ -9,7 +9,9 @@ use futures::future::BoxFuture; use futures::try_join; use once_cell::sync::OnceCell; use vortex_array::ArrayRef; +use vortex_array::IntoArray; use vortex_array::MaskFuture; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::expr::Expression; use vortex_dtype::DType; use vortex_dtype::FieldMask; @@ -98,7 +100,7 @@ impl ArrayFutureExt for ArrayFuture { fn masked(self, mask: MaskFuture) -> Self { Box::pin(async move { let (array, mask) = try_join!(self, mask)?; - vortex_array::compute::mask(array.as_ref(), &mask) + array.mask(mask.into_array()) }) } } diff --git a/vortex-test/e2e/src/lib.rs b/vortex-test/e2e/src/lib.rs index 2e8d3c77be0..765d4de32bd 100644 --- a/vortex-test/e2e/src/lib.rs +++ b/vortex-test/e2e/src/lib.rs @@ -28,7 +28,7 @@ mod tests { #[cfg(feature = "unstable_encodings")] const EXPECTED_SIZE: usize = 216188; #[cfg(not(feature = "unstable_encodings"))] - const EXPECTED_SIZE: usize = 216156; + const EXPECTED_SIZE: usize = 216188; let futures: Vec<_> = (0..5) .map(|_| { let array = array.clone();