diff --git a/vortex-cuda/src/layout.rs b/vortex-cuda/src/layout.rs index 33a011115cd..098d354fe12 100644 --- a/vortex-cuda/src/layout.rs +++ b/vortex-cuda/src/layout.rs @@ -33,7 +33,9 @@ use vortex_array::normalize::Operation; use vortex_array::serde::ArrayParts; use vortex_array::serde::SerializeOptions; use vortex_array::session::ArrayRegistry; +use vortex_array::stats::StatsSetRef; use vortex_buffer::Alignment; +use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_dtype::FieldMask; @@ -54,8 +56,6 @@ use vortex_layout::LayoutRef; use vortex_layout::LayoutStrategy; use vortex_layout::VTable; use vortex_layout::layouts::SharedArrayFuture; -use vortex_layout::layouts::zoned::lower_bound; -use vortex_layout::layouts::zoned::upper_bound; use vortex_layout::segments::SegmentId; use vortex_layout::segments::SegmentSinkRef; use vortex_layout::segments::SegmentSource; @@ -63,6 +63,10 @@ use vortex_layout::sequence::SendableSequentialStream; use vortex_layout::sequence::SequencePointer; use vortex_layout::vtable; use vortex_mask::Mask; +use vortex_scalar::Scalar; +use vortex_scalar::ScalarTruncation; +use vortex_scalar::lower_bound; +use vortex_scalar::upper_bound; use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; @@ -456,6 +460,22 @@ impl CudaFlatLayoutStrategy { } } +fn truncate_scalar_stat Option<(Scalar, bool)>>( + statistics: StatsSetRef<'_>, + stat: Stat, + truncation: F, +) { + if let Some(sv) = statistics.get(stat) { + if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) { + if truncated && let Some(v) = truncated_value.into_value() { + statistics.set(stat, Precision::Inexact(v)); + } + } else { + statistics.clear(stat) + } + } +} + #[async_trait] impl LayoutStrategy for CudaFlatLayoutStrategy { async fn write_stream( @@ -474,55 +494,42 @@ impl LayoutStrategy for CudaFlatLayoutStrategy { let (sequence_id, chunk) = chunk?; let row_count = chunk.len() as u64; - // Truncate variable-length statistics. match chunk.dtype() { - DType::Utf8(_) => { - if let Some(sv) = chunk.statistics().get(Stat::Min) { - let (value, truncated) = lower_bound( - sv.into_inner().as_utf8(), - options.max_variable_length_statistics_size, - ); - if truncated && let Some(v) = value.into_value() { - chunk.statistics().set(Stat::Min, Precision::Inexact(v)); - } - } - if let Some(sv) = chunk.statistics().get(Stat::Max) { - let (value, truncated) = upper_bound( - sv.into_inner().as_utf8(), - options.max_variable_length_statistics_size, - ); - if let Some(upper_bound) = value { - if truncated && let Some(v) = upper_bound.into_value() { - chunk.statistics().set(Stat::Max, Precision::Inexact(v)); - } - } else { - chunk.statistics().clear(Stat::Max) - } - } + DType::Utf8(n) => { + truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| { + lower_bound( + BufferString::from_scalar(v) + .vortex_expect("utf8 scalar must be a BufferString"), + self.max_variable_length_statistics_size, + *n, + ) + }); + truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| { + upper_bound( + BufferString::from_scalar(v) + .vortex_expect("utf8 scalar must be a BufferString"), + self.max_variable_length_statistics_size, + *n, + ) + }); } - DType::Binary(_) => { - if let Some(sv) = chunk.statistics().get(Stat::Min) { - let (value, truncated) = lower_bound( - sv.into_inner().as_binary(), - options.max_variable_length_statistics_size, - ); - if truncated && let Some(v) = value.into_value() { - chunk.statistics().set(Stat::Min, Precision::Inexact(v)); - } - } - if let Some(sv) = chunk.statistics().get(Stat::Max) { - let (value, truncated) = upper_bound( - sv.into_inner().as_binary(), - options.max_variable_length_statistics_size, - ); - if let Some(upper_bound) = value { - if truncated && let Some(v) = upper_bound.into_value() { - chunk.statistics().set(Stat::Max, Precision::Inexact(v)); - } - } else { - chunk.statistics().clear(Stat::Max) - } - } + DType::Binary(n) => { + truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| { + lower_bound( + ByteBuffer::from_scalar(v) + .vortex_expect("binary scalar must be a ByteBuffer"), + self.max_variable_length_statistics_size, + *n, + ) + }); + truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| { + upper_bound( + ByteBuffer::from_scalar(v) + .vortex_expect("binary scalar must be a ByteBuffer"), + self.max_variable_length_statistics_size, + *n, + ) + }); } _ => {} } diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 9724b7bec21..c6614258e2b 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -960,10 +960,6 @@ pub const vortex_layout::layouts::zoned::MAX_IS_TRUNCATED: &str pub const vortex_layout::layouts::zoned::MIN_IS_TRUNCATED: &str -pub fn vortex_layout::layouts::zoned::lower_bound(value: impl vortex_layout::layouts::zoned::builder::ScalarTruncation, max_length: usize) -> (vortex_scalar::scalar::Scalar, bool) - -pub fn vortex_layout::layouts::zoned::upper_bound(value: impl vortex_layout::layouts::zoned::builder::ScalarTruncation, max_length: usize) -> (core::option::Option, bool) - pub type vortex_layout::layouts::SharedArrayFuture = futures_util::future::future::shared::Shared>> pub mod vortex_layout::segments diff --git a/vortex-layout/src/layouts/flat/writer.rs b/vortex-layout/src/layouts/flat/writer.rs index d362774810f..33e6e734e71 100644 --- a/vortex-layout/src/layouts/flat/writer.rs +++ b/vortex-layout/src/layouts/flat/writer.rs @@ -12,18 +12,24 @@ use vortex_array::normalize::NormalizeOptions; use vortex_array::normalize::Operation; use vortex_array::serde::SerializeOptions; use vortex_array::session::ArrayRegistry; +use vortex_array::stats::StatsSetRef; +use vortex_buffer::BufferString; +use vortex_buffer::ByteBuffer; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_io::runtime::Handle; +use vortex_scalar::Scalar; +use vortex_scalar::ScalarTruncation; +use vortex_scalar::lower_bound; +use vortex_scalar::upper_bound; use crate::IntoLayout; use crate::LayoutRef; use crate::LayoutStrategy; use crate::layouts::flat::FlatLayout; use crate::layouts::flat::flat_layout_inline_array_node; -use crate::layouts::zoned::lower_bound; -use crate::layouts::zoned::upper_bound; use crate::segments::SegmentSinkRef; use crate::sequence::SendableSequentialStream; use crate::sequence::SequencePointer; @@ -69,6 +75,22 @@ impl FlatLayoutStrategy { } } +fn truncate_scalar_stat Option<(Scalar, bool)>>( + statistics: StatsSetRef<'_>, + stat: Stat, + truncation: F, +) { + if let Some(sv) = statistics.get(stat) { + if let Some((truncated_value, truncated)) = truncation(sv.into_inner()) { + if truncated && let Some(v) = truncated_value.into_value() { + statistics.set(stat, Precision::Inexact(v)); + } + } else { + statistics.clear(stat) + } + } +} + #[async_trait] impl LayoutStrategy for FlatLayoutStrategy { async fn write_stream( @@ -80,7 +102,6 @@ impl LayoutStrategy for FlatLayoutStrategy { _handle: Handle, ) -> VortexResult { let ctx = ctx.clone(); - let options = self.clone(); let Some(chunk) = stream.next().await else { vortex_bail!("flat layout needs a single chunk"); }; @@ -89,60 +110,46 @@ impl LayoutStrategy for FlatLayoutStrategy { let row_count = chunk.len() as u64; match chunk.dtype() { - DType::Utf8(_) => { - if let Some(sv) = chunk.statistics().get(Stat::Min) { - let (value, truncated) = lower_bound( - sv.into_inner().as_utf8(), - options.max_variable_length_statistics_size, - ); - if truncated && let Some(v) = value.into_value() { - chunk.statistics().set(Stat::Min, Precision::Inexact(v)); - } - } - - if let Some(sv) = chunk.statistics().get(Stat::Max) { - let (value, truncated) = upper_bound( - sv.into_inner().as_utf8(), - options.max_variable_length_statistics_size, - ); - if let Some(upper_bound) = value { - if truncated && let Some(v) = upper_bound.into_value() { - chunk.statistics().set(Stat::Max, Precision::Inexact(v)); - } - } else { - chunk.statistics().clear(Stat::Max) - } - } + DType::Utf8(n) => { + truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| { + lower_bound( + BufferString::from_scalar(v) + .vortex_expect("utf8 scalar must be a BufferString"), + self.max_variable_length_statistics_size, + *n, + ) + }); + truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| { + upper_bound( + BufferString::from_scalar(v) + .vortex_expect("utf8 scalar must be a BufferString"), + self.max_variable_length_statistics_size, + *n, + ) + }); } - DType::Binary(_) => { - if let Some(sv) = chunk.statistics().get(Stat::Min) { - let (value, truncated) = lower_bound( - sv.into_inner().as_binary(), - options.max_variable_length_statistics_size, - ); - if truncated && let Some(v) = value.into_value() { - chunk.statistics().set(Stat::Min, Precision::Inexact(v)); - } - } - - if let Some(sv) = chunk.statistics().get(Stat::Max) { - let (value, truncated) = upper_bound( - sv.into_inner().as_binary(), - options.max_variable_length_statistics_size, - ); - if let Some(upper_bound) = value { - if truncated && let Some(v) = upper_bound.into_value() { - chunk.statistics().set(Stat::Max, Precision::Inexact(v)); - } - } else { - chunk.statistics().clear(Stat::Max) - } - } + DType::Binary(n) => { + truncate_scalar_stat(chunk.statistics(), Stat::Min, |v| { + lower_bound( + ByteBuffer::from_scalar(v) + .vortex_expect("binary scalar must be a ByteBuffer"), + self.max_variable_length_statistics_size, + *n, + ) + }); + truncate_scalar_stat(chunk.statistics(), Stat::Max, |v| { + upper_bound( + ByteBuffer::from_scalar(v) + .vortex_expect("binary scalar must be a ByteBuffer"), + self.max_variable_length_statistics_size, + *n, + ) + }); } _ => {} } - let chunk = if let Some(allowed) = &options.allowed_encodings { + let chunk = if let Some(allowed) = &self.allowed_encodings { chunk.normalize(&mut NormalizeOptions { allowed, operation: Operation::Error, @@ -155,7 +162,7 @@ impl LayoutStrategy for FlatLayoutStrategy { &ctx, &SerializeOptions { offset: 0, - include_padding: options.include_padding, + include_padding: self.include_padding, }, )?; // there is at least the flatbuffer and the length diff --git a/vortex-layout/src/layouts/zoned/builder.rs b/vortex-layout/src/layouts/zoned/builder.rs index afe9dade461..715fdbed9fe 100644 --- a/vortex-layout/src/layouts/zoned/builder.rs +++ b/vortex-layout/src/layouts/zoned/builder.rs @@ -11,14 +11,16 @@ use vortex_array::builders::ArrayBuilder; use vortex_array::builders::BoolBuilder; use vortex_array::builders::builder_with_capacity; use vortex_array::expr::stats::Stat; +use vortex_buffer::BufferString; +use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_dtype::FieldName; use vortex_dtype::Nullability; use vortex_error::VortexResult; -use vortex_error::vortex_err; -use vortex_scalar::BinaryScalar; use vortex_scalar::Scalar; -use vortex_scalar::Utf8Scalar; +use vortex_scalar::ScalarTruncation; +use vortex_scalar::lower_bound; +use vortex_scalar::upper_bound; pub const MAX_IS_TRUNCATED: &str = "max_is_truncated"; pub const MIN_IS_TRUNCATED: &str = "min_is_truncated"; @@ -32,12 +34,12 @@ pub fn stats_builder_with_capacity( let values_builder = builder_with_capacity(dtype, capacity); match stat { Stat::Max => match dtype { - DType::Utf8(_) => Box::new(TruncatedMaxBinaryStatsBuilder::::new( + DType::Utf8(_) => Box::new(TruncatedMaxBinaryStatsBuilder::::new( values_builder, BoolBuilder::with_capacity(Nullability::NonNullable, capacity), max_length, )), - DType::Binary(_) => Box::new(TruncatedMaxBinaryStatsBuilder::::new( + DType::Binary(_) => Box::new(TruncatedMaxBinaryStatsBuilder::::new( values_builder, BoolBuilder::with_capacity(Nullability::NonNullable, capacity), max_length, @@ -45,12 +47,12 @@ pub fn stats_builder_with_capacity( _ => Box::new(StatNameArrayBuilder::new(stat, values_builder)), }, Stat::Min => match dtype { - DType::Utf8(_) => Box::new(TruncatedMinBinaryStatsBuilder::::new( + DType::Utf8(_) => Box::new(TruncatedMinBinaryStatsBuilder::::new( values_builder, BoolBuilder::with_capacity(Nullability::NonNullable, capacity), max_length, )), - DType::Binary(_) => Box::new(TruncatedMinBinaryStatsBuilder::::new( + DType::Binary(_) => Box::new(TruncatedMinBinaryStatsBuilder::::new( values_builder, BoolBuilder::with_capacity(Nullability::NonNullable, capacity), max_length, @@ -173,81 +175,16 @@ impl TruncatedMinBinaryStatsBuilder { } } -pub trait ScalarTruncation: Send + Sized { - fn from_scalar(value: &Scalar) -> VortexResult; - - fn len(&self) -> Option; - - fn into_scalar(self) -> Scalar; - - fn upper_bound(self, max_length: usize) -> Option; - - fn lower_bound(self, max_length: usize) -> Scalar; -} - -impl ScalarTruncation for BinaryScalar<'_> { - fn from_scalar(value: &Scalar) -> VortexResult { - value - .as_binary_opt() - .ok_or_else(|| vortex_err!("Expected binary scalar, found {}", value.dtype())) - } - - fn len(&self) -> Option { - self.len() - } - - fn into_scalar(self) -> Scalar { - self.value() - .cloned() - .map(|b| Scalar::binary(b, self.dtype().nullability())) - .unwrap_or_else(|| Scalar::null(self.dtype().clone())) - } - - fn upper_bound(self, max_length: usize) -> Option { - BinaryScalar::upper_bound(&self, max_length) - } - - fn lower_bound(self, max_length: usize) -> Scalar { - BinaryScalar::lower_bound(&self, max_length) - } -} - -impl ScalarTruncation for Utf8Scalar<'_> { - fn from_scalar(value: &Scalar) -> VortexResult { - value - .as_utf8_opt() - .ok_or_else(|| vortex_err!("Expected utf8 scalar, found {}", value.dtype())) - } - - fn len(&self) -> Option { - self.len() - } - - fn into_scalar(self) -> Scalar { - self.value() - .cloned() - .map(|b| Scalar::utf8(b, self.dtype().nullability())) - .unwrap_or_else(|| Scalar::null(self.dtype().clone())) - } - - fn upper_bound(self, max_length: usize) -> Option { - Utf8Scalar::upper_bound(&self, max_length) - } - - fn lower_bound(self, max_length: usize) -> Scalar { - Utf8Scalar::lower_bound(&self, max_length) - } -} - impl StatsArrayBuilder for TruncatedMaxBinaryStatsBuilder { fn stat(&self) -> Stat { Stat::Max } fn append_scalar(&mut self, value: Scalar) -> VortexResult<()> { - let (value, truncated) = upper_bound(T::from_scalar(&value)?, self.max_value_length); - - if let Some(upper_bound) = value { + let nullability = value.dtype().nullability(); + if let Some((upper_bound, truncated)) = + upper_bound(T::from_scalar(value)?, self.max_value_length, nullability) + { self.values.append_scalar(&upper_bound)?; self.is_truncated.append_value(truncated); } else { @@ -256,6 +193,7 @@ impl StatsArrayBuilder for TruncatedMaxBinaryStatsBuilder StatsArrayBuilder for TruncatedMinBinaryStatsBuilder VortexResult<()> { - let (value, truncated) = lower_bound(T::from_scalar(&value)?, self.max_value_length); - self.values.append_scalar(&value)?; - self.is_truncated.append_value(truncated); + let nullability = value.dtype().nullability(); + if let Some((lower_bound, truncated)) = + lower_bound(T::from_scalar(value)?, self.max_value_length, nullability) + { + self.values.append_scalar(&lower_bound)?; + self.is_truncated.append_value(truncated); + } else { + self.append_null() + } Ok(()) } + #[inline] fn append_null(&mut self) { ArrayBuilder::append_null(self.values.as_mut()); self.is_truncated.append_value(false); @@ -299,19 +244,3 @@ impl StatsArrayBuilder for TruncatedMinBinaryStatsBuilder (Scalar, bool) { - if value.len().unwrap_or(0) > max_length { - (value.lower_bound(max_length), true) - } else { - (value.into_scalar(), false) - } -} - -pub fn upper_bound(value: impl ScalarTruncation, max_length: usize) -> (Option, bool) { - if value.len().unwrap_or(0) > max_length { - (value.upper_bound(max_length), true) - } else { - (Some(value.into_scalar()), false) - } -} diff --git a/vortex-layout/src/layouts/zoned/mod.rs b/vortex-layout/src/layouts/zoned/mod.rs index 8cd63e3e429..c8e353cb955 100644 --- a/vortex-layout/src/layouts/zoned/mod.rs +++ b/vortex-layout/src/layouts/zoned/mod.rs @@ -10,8 +10,6 @@ use std::sync::Arc; pub use builder::MAX_IS_TRUNCATED; pub use builder::MIN_IS_TRUNCATED; -pub use builder::lower_bound; -pub use builder::upper_bound; use vortex_array::ArrayContext; use vortex_array::DeserializeMetadata; use vortex_array::SerializeMetadata; diff --git a/vortex-scalar/public-api.lock b/vortex-scalar/public-api.lock index 012bef04a4e..203377da3da 100644 --- a/vortex-scalar/public-api.lock +++ b/vortex-scalar/public-api.lock @@ -416,6 +416,18 @@ pub fn vortex_scalar::ScalarValue::as_primitive(&self) -> &vortex_scalar::PValue pub fn vortex_scalar::ScalarValue::as_utf8(&self) -> &vortex_buffer::string::BufferString +pub fn vortex_scalar::ScalarValue::into_binary(self) -> vortex_buffer::ByteBuffer + +pub fn vortex_scalar::ScalarValue::into_bool(self) -> bool + +pub fn vortex_scalar::ScalarValue::into_decimal(self) -> vortex_scalar::DecimalValue + +pub fn vortex_scalar::ScalarValue::into_list(self) -> alloc::vec::Vec> + +pub fn vortex_scalar::ScalarValue::into_primitive(self) -> vortex_scalar::PValue + +pub fn vortex_scalar::ScalarValue::into_utf8(self) -> vortex_buffer::string::BufferString + impl vortex_scalar::ScalarValue pub fn vortex_scalar::ScalarValue::default_value(dtype: &vortex_dtype::dtype::DType) -> core::option::Option @@ -632,12 +644,8 @@ pub fn vortex_scalar::BinaryScalar<'a>::is_empty(&self) -> core::option::Option< pub fn vortex_scalar::BinaryScalar<'a>::len(&self) -> core::option::Option -pub fn vortex_scalar::BinaryScalar<'a>::lower_bound(&self, max_length: usize) -> vortex_scalar::Scalar - pub fn vortex_scalar::BinaryScalar<'a>::try_new(dtype: &'a vortex_dtype::dtype::DType, value: core::option::Option<&'a vortex_scalar::ScalarValue>) -> vortex_error::VortexResult -pub fn vortex_scalar::BinaryScalar<'a>::upper_bound(&self, max_length: usize) -> core::option::Option - pub fn vortex_scalar::BinaryScalar<'a>::value(&self) -> core::option::Option<&'a vortex_buffer::ByteBuffer> impl core::cmp::Eq for vortex_scalar::BinaryScalar<'_> @@ -1726,12 +1734,8 @@ pub fn vortex_scalar::Utf8Scalar<'a>::is_empty(&self) -> core::option::Option::len(&self) -> core::option::Option -pub fn vortex_scalar::Utf8Scalar<'a>::lower_bound(&self, max_length: usize) -> vortex_scalar::Scalar - pub fn vortex_scalar::Utf8Scalar<'a>::try_new(dtype: &'a vortex_dtype::dtype::DType, value: core::option::Option<&'a vortex_scalar::ScalarValue>) -> vortex_error::VortexResult -pub fn vortex_scalar::Utf8Scalar<'a>::upper_bound(&self, max_length: usize) -> core::option::Option - pub fn vortex_scalar::Utf8Scalar<'a>::value(&self) -> core::option::Option<&'a vortex_buffer::string::BufferString> impl core::cmp::Ord for vortex_scalar::Utf8Scalar<'_> @@ -1770,6 +1774,42 @@ impl<'a> core::hash::Hash for vortex_scalar::Utf8Scalar<'a> pub fn vortex_scalar::Utf8Scalar<'a>::hash<__H: core::hash::Hasher>(&self, state: &mut __H) +pub trait vortex_scalar::ScalarTruncation: core::marker::Send + core::marker::Sized + +pub fn vortex_scalar::ScalarTruncation::from_scalar(value: vortex_scalar::Scalar) -> vortex_error::VortexResult> + +pub fn vortex_scalar::ScalarTruncation::into_scalar(self, nullability: vortex_dtype::nullability::Nullability) -> vortex_scalar::Scalar + +pub fn vortex_scalar::ScalarTruncation::len(&self) -> usize + +pub fn vortex_scalar::ScalarTruncation::lower_bound(self, max_length: usize) -> Self + +pub fn vortex_scalar::ScalarTruncation::upper_bound(self, max_length: usize) -> core::option::Option + +impl vortex_scalar::ScalarTruncation for vortex_buffer::ByteBuffer + +pub fn vortex_buffer::ByteBuffer::from_scalar(value: vortex_scalar::Scalar) -> vortex_error::VortexResult> + +pub fn vortex_buffer::ByteBuffer::into_scalar(self, nullability: vortex_dtype::nullability::Nullability) -> vortex_scalar::Scalar + +pub fn vortex_buffer::ByteBuffer::len(&self) -> usize + +pub fn vortex_buffer::ByteBuffer::lower_bound(self, max_length: usize) -> Self + +pub fn vortex_buffer::ByteBuffer::upper_bound(self, max_length: usize) -> core::option::Option + +impl vortex_scalar::ScalarTruncation for vortex_buffer::string::BufferString + +pub fn vortex_buffer::string::BufferString::from_scalar(value: vortex_scalar::Scalar) -> vortex_error::VortexResult> + +pub fn vortex_buffer::string::BufferString::into_scalar(self, nullability: vortex_dtype::nullability::Nullability) -> vortex_scalar::Scalar + +pub fn vortex_buffer::string::BufferString::len(&self) -> usize + +pub fn vortex_buffer::string::BufferString::lower_bound(self, max_length: usize) -> Self + +pub fn vortex_buffer::string::BufferString::upper_bound(self, max_length: usize) -> core::option::Option + pub trait vortex_scalar::StringLike: vortex_scalar::typed_view::utf8::private::Sealed + core::marker::Sized pub fn vortex_scalar::StringLike::increment(self) -> core::result::Result @@ -1781,3 +1821,7 @@ pub fn alloc::string::String::increment(self) -> core::result::Result core::result::Result + +pub fn vortex_scalar::lower_bound(value: core::option::Option, max_length: usize, nullability: vortex_dtype::nullability::Nullability) -> core::option::Option<(vortex_scalar::Scalar, bool)> + +pub fn vortex_scalar::upper_bound(value: core::option::Option, max_length: usize, nullability: vortex_dtype::nullability::Nullability) -> core::option::Option<(vortex_scalar::Scalar, bool)> diff --git a/vortex-scalar/src/downcast.rs b/vortex-scalar/src/downcast.rs index 3074f590ca6..f4420055d3e 100644 --- a/vortex-scalar/src/downcast.rs +++ b/vortex-scalar/src/downcast.rs @@ -3,15 +3,21 @@ //! Scalar downcasting methods to typed views. +use vortex_buffer::BufferString; +use vortex_buffer::ByteBuffer; use vortex_error::VortexExpect; +use vortex_error::vortex_panic; use crate::BinaryScalar; use crate::BoolScalar; use crate::DecimalScalar; +use crate::DecimalValue; use crate::ExtScalar; use crate::ListScalar; +use crate::PValue; use crate::PrimitiveScalar; use crate::Scalar; +use crate::ScalarValue; use crate::StructScalar; use crate::Utf8Scalar; @@ -142,3 +148,105 @@ impl Scalar { ExtScalar::try_new(self.dtype(), self.value()).ok() } } + +impl ScalarValue { + /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. + pub fn as_bool(&self) -> bool { + match self { + ScalarValue::Bool(b) => *b, + _ => vortex_panic!("ScalarValue is not a Bool"), + } + } + + /// Returns the primitive value, panicking if the value is not a + /// [`Primitive`][ScalarValue::Primitive]. + pub fn as_primitive(&self) -> &PValue { + match self { + ScalarValue::Primitive(p) => p, + _ => vortex_panic!("ScalarValue is not a Primitive"), + } + } + + /// Returns the decimal value, panicking if the value is not a + /// [`Decimal`][ScalarValue::Decimal]. + pub fn as_decimal(&self) -> &DecimalValue { + match self { + ScalarValue::Decimal(d) => d, + _ => vortex_panic!("ScalarValue is not a Decimal"), + } + } + + /// Returns the UTF-8 string value, panicking if the value is not a [`Utf8`][ScalarValue::Utf8]. + pub fn as_utf8(&self) -> &BufferString { + match self { + ScalarValue::Utf8(s) => s, + _ => vortex_panic!("ScalarValue is not a Utf8"), + } + } + + /// Returns the binary value, panicking if the value is not a [`Binary`][ScalarValue::Binary]. + pub fn as_binary(&self) -> &ByteBuffer { + match self { + ScalarValue::Binary(b) => b, + _ => vortex_panic!("ScalarValue is not a Binary"), + } + } + + /// Returns the list elements, panicking if the value is not a [`List`][ScalarValue::List]. + pub fn as_list(&self) -> &[Option] { + match self { + ScalarValue::List(elements) => elements, + _ => vortex_panic!("ScalarValue is not a List"), + } + } + + /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. + pub fn into_bool(self) -> bool { + match self { + ScalarValue::Bool(b) => b, + _ => vortex_panic!("ScalarValue is not a Bool"), + } + } + + /// Returns the primitive value, panicking if the value is not a + /// [`Primitive`][ScalarValue::Primitive]. + pub fn into_primitive(self) -> PValue { + match self { + ScalarValue::Primitive(p) => p, + _ => vortex_panic!("ScalarValue is not a Primitive"), + } + } + + /// Returns the decimal value, panicking if the value is not a + /// [`Decimal`][ScalarValue::Decimal]. + pub fn into_decimal(self) -> DecimalValue { + match self { + ScalarValue::Decimal(d) => d, + _ => vortex_panic!("ScalarValue is not a Decimal"), + } + } + + /// Returns the UTF-8 string value, panicking if the value is not a [`Utf8`][ScalarValue::Utf8]. + pub fn into_utf8(self) -> BufferString { + match self { + ScalarValue::Utf8(s) => s, + _ => vortex_panic!("ScalarValue is not a Utf8"), + } + } + + /// Returns the binary value, panicking if the value is not a [`Binary`][ScalarValue::Binary]. + pub fn into_binary(self) -> ByteBuffer { + match self { + ScalarValue::Binary(b) => b, + _ => vortex_panic!("ScalarValue is not a Binary"), + } + } + + /// Returns the list elements, panicking if the value is not a [`List`][ScalarValue::List]. + pub fn into_list(self) -> Vec> { + match self { + ScalarValue::List(elements) => elements, + _ => vortex_panic!("ScalarValue is not a List"), + } + } +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index ef956f32c23..c23d483dfd4 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -36,3 +36,6 @@ pub use typed_view::*; #[cfg(test)] mod tests; +mod truncation; + +pub use truncation::*; diff --git a/vortex-scalar/src/scalar_value.rs b/vortex-scalar/src/scalar_value.rs index 789f56823ef..7364d76ab7f 100644 --- a/vortex-scalar/src/scalar_value.rs +++ b/vortex-scalar/src/scalar_value.rs @@ -119,65 +119,6 @@ impl ScalarValue { } } -impl ScalarValue { - /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. - pub fn as_bool(&self) -> bool { - match self { - ScalarValue::Bool(b) => *b, - _ => vortex_panic!("ScalarValue is not a Bool"), - } - } - - /// Returns the primitive value, panicking if the value is not a - /// [`Primitive`][ScalarValue::Primitive]. - pub fn as_primitive(&self) -> &PValue { - match self { - ScalarValue::Primitive(p) => p, - _ => vortex_panic!("ScalarValue is not a Primitive"), - } - } - - /// Returns the decimal value, panicking if the value is not a - /// [`Decimal`][ScalarValue::Decimal]. - pub fn as_decimal(&self) -> &DecimalValue { - match self { - ScalarValue::Decimal(d) => d, - _ => vortex_panic!("ScalarValue is not a Decimal"), - } - } - - /// Returns the UTF-8 string value, panicking if the value is not a [`Utf8`][ScalarValue::Utf8]. - pub fn as_utf8(&self) -> &BufferString { - match self { - ScalarValue::Utf8(s) => s, - _ => vortex_panic!("ScalarValue is not a Utf8"), - } - } - - /// Returns the binary value, panicking if the value is not a [`Binary`][ScalarValue::Binary]. - pub fn as_binary(&self) -> &ByteBuffer { - match self { - ScalarValue::Binary(b) => b, - _ => vortex_panic!("ScalarValue is not a Binary"), - } - } - - /// Returns the list elements, panicking if the value is not a [`List`][ScalarValue::List]. - pub fn as_list(&self) -> &[Option] { - match self { - ScalarValue::List(elements) => elements, - _ => vortex_panic!("ScalarValue is not a List"), - } - } - - // pub fn as_extension(&self) -> &ExtScalarValueRef { - // match self { - // ScalarValue::Extension(e) => e, - // _ => vortex_panic!("ScalarValue is not an Extension"), - // } - // } -} - impl PartialOrd for ScalarValue { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { diff --git a/vortex-scalar/src/truncation.rs b/vortex-scalar/src/truncation.rs new file mode 100644 index 00000000000..b1876f0e330 --- /dev/null +++ b/vortex-scalar/src/truncation.rs @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Produce lower on upper bounds of scalars via truncation. + +use vortex_buffer::BufferString; +use vortex_buffer::ByteBuffer; +use vortex_dtype::Nullability; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::Scalar; +use crate::StringLike; + +/// A trait for truncating [`Scalar`]s to a given length in bytes. +#[allow(clippy::len_without_is_empty)] +pub trait ScalarTruncation: Send + Sized { + /// Unwrap a Scalar into a ScalarTruncation object + /// + /// # Errors + /// If the scalar doesn't match the truncations dtype. + fn from_scalar(value: Scalar) -> VortexResult>; + + /// The length of the value in bytes. + fn len(&self) -> usize; + + /// Convert the value into a [`Scalar`] with the given nullability. + fn into_scalar(self, nullability: Nullability) -> Scalar; + + /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater + /// than this. + /// + /// Returns `None` if the value is null or if constructing a greater value would overflow. + fn upper_bound(self, max_length: usize) -> Option; + + /// Construct a [`ByteBuffer`] at most `max_length` in size that's less than or equal to + /// ourselves. + fn lower_bound(self, max_length: usize) -> Self; +} + +impl ScalarTruncation for ByteBuffer { + fn from_scalar(value: Scalar) -> VortexResult> { + vortex_ensure!( + value.dtype().is_binary(), + "Expected binary scalar, got {}", + value.dtype() + ); + Ok(value.into_value().map(|b| b.into_binary())) + } + + fn len(&self) -> usize { + ByteBuffer::len(self) + } + + fn into_scalar(self, nullability: Nullability) -> Scalar { + Scalar::binary(self, nullability) + } + + fn upper_bound(self, max_length: usize) -> Option { + let sliced = self.slice(0..max_length); + let mut sliced_mut = sliced.into_mut(); + for b in sliced_mut.iter_mut().rev() { + let (incr, overflow) = b.overflowing_add(1); + *b = incr; + if !overflow { + return Some(sliced_mut.freeze()); + } + } + None + } + + fn lower_bound(self, max_length: usize) -> Self { + self.slice(0..max_length) + } +} + +impl ScalarTruncation for BufferString { + fn from_scalar(value: Scalar) -> VortexResult> { + vortex_ensure!( + value.dtype().is_utf8(), + "Expected utf8 scalar, got {}", + value.dtype() + ); + Ok(value.into_value().map(|b| b.into_utf8())) + } + + fn len(&self) -> usize { + self.inner().len() + } + + fn into_scalar(self, nullability: Nullability) -> Scalar { + Scalar::utf8(self, nullability) + } + + /// Constructs the next [`BufferString`] at most `max_length` bytes that's lexicographically greater + /// than this. + /// + /// Returns `None` if the value is null or if constructing a greater value would overflow. + fn upper_bound(self, max_length: usize) -> Option { + let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) + .rfind(|p| self.is_char_boundary(*p)) + .vortex_expect("Failed to find utf8 character boundary"); + + // SAFETY: we slice to a char boundary so the sliced range contains valid UTF-8. + let sliced = + unsafe { BufferString::new_unchecked(self.into_inner().slice(..utf8_split_pos)) }; + sliced.increment().ok() + } + + /// Construct a [`BufferString`] at most `max_length` in size that's less than or equal to + /// ourselves. + fn lower_bound(self, max_length: usize) -> Self { + // UTF-8 characters are at most 4 bytes. Since we know that `BufferString` is + // valid UTF-8, we must have a valid character boundary. + let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) + .rfind(|p| self.is_char_boundary(*p)) + .vortex_expect("Failed to find utf8 character boundary"); + + unsafe { BufferString::new_unchecked(self.into_inner().slice(..utf8_split_pos)) } + } +} + +/// Truncate the value to be less than max_length in bytes and be lexicographically smaller than the value itself +pub fn lower_bound( + value: Option, + max_length: usize, + nullability: Nullability, +) -> Option<(Scalar, bool)> { + let value = value?; + if value.len() > max_length { + Some((value.lower_bound(max_length).into_scalar(nullability), true)) + } else { + Some((value.into_scalar(nullability), false)) + } +} + +/// Truncate the value to be less than max_length in bytes and be lexicographically greater than the value itself +pub fn upper_bound( + value: Option, + max_length: usize, + nullability: Nullability, +) -> Option<(Scalar, bool)> { + let value = value?; + if value.len() > max_length { + value + .upper_bound(max_length) + .map(|v| (v.into_scalar(nullability), true)) + } else { + Some((value.into_scalar(nullability), false)) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::BufferString; + use vortex_buffer::ByteBuffer; + use vortex_buffer::buffer; + use vortex_dtype::Nullability; + + use crate::truncation::ScalarTruncation; + use crate::truncation::lower_bound; + use crate::truncation::upper_bound; + + #[test] + fn binary_lower_bound() { + let binary = buffer![0u8, 5, 47, 33, 129]; + let expected = buffer![0u8, 5]; + assert_eq!(binary.lower_bound(2), expected,); + } + + #[test] + fn binary_upper_bound() { + let binary = buffer![0u8, 5, 255, 234, 23]; + let expected = buffer![0u8, 6, 0]; + assert_eq!(binary.upper_bound(3).unwrap(), expected,); + } + + #[test] + fn binary_upper_bound_overflow() { + let binary = buffer![255u8, 255, 255]; + assert!(binary.upper_bound(2).is_none()); + } + + #[test] + fn binary_upper_bound_null() { + assert!(upper_bound(Option::::None, 10, Nullability::Nullable).is_none()); + } + + #[test] + fn binary_lower_bound_null() { + assert!(lower_bound(Option::::None, 10, Nullability::Nullable).is_none()); + } + + #[test] + fn utf8_lower_bound() { + let utf8 = BufferString::from("snowman⛄️snowman"); + let expected = BufferString::from("snowman"); + assert_eq!(utf8.lower_bound(9), expected); + } + + #[test] + fn utf8_upper_bound() { + let utf8 = BufferString::from("char🪩"); + let expected = BufferString::from("chas"); + assert_eq!(utf8.upper_bound(5).unwrap(), expected); + } + + #[test] + fn utf8_upper_bound_overflow() { + let utf8 = BufferString::from("🂑🂒🂓"); + assert!(utf8.upper_bound(2).is_none()); + } + + #[test] + fn utf8_upper_bound_null() { + assert!(upper_bound(Option::::None, 10, Nullability::Nullable).is_none()); + } + + #[test] + fn utf8_lower_bound_null() { + assert!(lower_bound(Option::::None, 10, Nullability::Nullable).is_none()); + } +} diff --git a/vortex-scalar/src/typed_view/binary.rs b/vortex-scalar/src/typed_view/binary.rs index b44a0621808..7cb54219445 100644 --- a/vortex-scalar/src/typed_view/binary.rs +++ b/vortex-scalar/src/typed_view/binary.rs @@ -90,38 +90,6 @@ impl<'a> BinaryScalar<'a> { self.value } - /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater - /// than this. - /// - /// Returns `None` if the value is null or if constructing a greater value would overflow. - pub fn upper_bound(&self, max_length: usize) -> Option { - let value = self.value()?; - let sliced = value.slice(0..max_length); - let mut sliced_mut = sliced.into_mut(); - for b in sliced_mut.iter_mut().rev() { - let (incr, overflow) = b.overflowing_add(1); - *b = incr; - if !overflow { - return Some(Scalar::binary( - sliced_mut.freeze(), - self.dtype().nullability(), - )); - } - } - None - } - - /// Construct a [`Scalar`] at most `max_length` in size that's less than or equal to - /// ourselves. - /// - /// Returns a null [`Scalar`] if the value is null. - pub fn lower_bound(&self, max_length: usize) -> Scalar { - match self.value() { - Some(value) => Scalar::binary(value.slice(0..max_length), self.dtype().nullability()), - None => Scalar::null(self.dtype().clone()), - } - } - /// Casts this scalar to the given `dtype`. pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Binary(..)) { @@ -163,26 +131,6 @@ mod tests { use crate::Scalar; use crate::ScalarValue; - #[test] - fn lower_bound() { - let binary = Scalar::binary(buffer![0u8, 5, 47, 33, 129], Nullability::NonNullable); - let expected = Scalar::binary(buffer![0u8, 5], Nullability::NonNullable); - assert_eq!(binary.as_binary().lower_bound(2), expected,); - } - - #[test] - fn upper_bound() { - let binary = Scalar::binary(buffer![0u8, 5, 255, 234, 23], Nullability::NonNullable); - let expected = Scalar::binary(buffer![0u8, 6, 0], Nullability::NonNullable); - assert_eq!(binary.as_binary().upper_bound(3).unwrap(), expected,); - } - - #[test] - fn upper_bound_overflow() { - let binary = Scalar::binary(buffer![255u8, 255, 255], Nullability::NonNullable); - assert!(binary.as_binary().upper_bound(2).is_none()); - } - #[rstest] #[case(&[1u8, 2, 3], &[1u8, 2, 3], true)] #[case(&[1u8, 2, 3], &[1u8, 2, 4], false)] @@ -316,20 +264,6 @@ mod tests { assert!(scalar.as_binary_opt().is_none()); } - #[test] - fn test_upper_bound_null() { - let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = null_binary.as_binary(); - assert!(scalar.upper_bound(10).is_none()); - } - - #[test] - fn test_lower_bound_null() { - let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = null_binary.as_binary(); - assert!(scalar.lower_bound(10).is_null()); - } - #[test] fn test_from_slice() { let data: &[u8] = &[1u8, 2, 3, 4]; diff --git a/vortex-scalar/src/typed_view/utf8.rs b/vortex-scalar/src/typed_view/utf8.rs index 165b2a45bd6..7513cbbcaeb 100644 --- a/vortex-scalar/src/typed_view/utf8.rs +++ b/vortex-scalar/src/typed_view/utf8.rs @@ -86,43 +86,6 @@ impl<'a> Utf8Scalar<'a> { self.value } - /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater - /// than this. - /// - /// Returns `None` if the value is null or if constructing a greater value would overflow. - pub fn upper_bound(&self, max_length: usize) -> Option { - let value = self.value()?; - let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) - .rfind(|p| value.is_char_boundary(*p)) - .vortex_expect("Failed to find utf8 character boundary"); - - // SAFETY: we slice to a char boundary so the sliced range contains valid UTF-8. - let sliced = unsafe { BufferString::new_unchecked(value.inner().slice(..utf8_split_pos)) }; - let incremented = sliced.increment().ok()?; - Some(Scalar::utf8(incremented, self.dtype().nullability())) - } - - /// Construct a [`Scalar`] at most `max_length` in size that's less than or equal to - /// ourselves. - /// - /// Returns a null [`Scalar`] if the value is null. - pub fn lower_bound(&self, max_length: usize) -> Scalar { - match self.value() { - Some(value) => { - // UTF-8 characters are at most 4 bytes. Since we know that `BufferString` is - // valid UTF-8, we must have a valid character boundary. - let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) - .rfind(|p| value.is_char_boundary(*p)) - .vortex_expect("Failed to find utf8 character boundary"); - - let sliced = - unsafe { BufferString::new_unchecked(value.inner().slice(0..utf8_split_pos)) }; - Scalar::utf8(sliced, self.dtype().nullability()) - } - None => Scalar::null(self.dtype().clone()), - } - } - /// Casts this scalar to the given `dtype`. pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Utf8(..)) { @@ -232,26 +195,6 @@ mod tests { use crate::Scalar; use crate::Utf8Scalar; - #[test] - fn lower_bound() { - let utf8 = Scalar::utf8("snowman⛄️snowman", Nullability::NonNullable); - let expected = Scalar::utf8("snowman", Nullability::NonNullable); - assert_eq!(utf8.as_utf8().lower_bound(9), expected,); - } - - #[test] - fn upper_bound() { - let utf8 = Scalar::utf8("char🪩", Nullability::NonNullable); - let expected = Scalar::utf8("chas", Nullability::NonNullable); - assert_eq!(utf8.as_utf8().upper_bound(5).unwrap(), expected,); - } - - #[test] - fn upper_bound_overflow() { - let utf8 = Scalar::utf8("🂑🂒🂓", Nullability::NonNullable); - assert!(utf8.as_utf8().upper_bound(2).is_none()); - } - #[rstest] #[case("hello", "hello", true)] #[case("hello", "world", false)] @@ -377,20 +320,6 @@ mod tests { assert!(scalar.as_utf8_opt().is_none()); } - #[test] - fn test_upper_bound_null() { - let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); - let scalar = null_utf8.as_utf8(); - assert!(scalar.upper_bound(10).is_none()); - } - - #[test] - fn test_lower_bound_null() { - let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); - let scalar = null_utf8.as_utf8(); - assert!(scalar.lower_bound(10).is_null()); - } - #[test] fn test_from_str() { let data = "hello world";