From 900898db2e281055e61c1410b215af7215655c8a Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 11 Feb 2026 17:49:10 -0500 Subject: [PATCH 1/2] extension scalars Signed-off-by: Connor Tsui --- Cargo.lock | 1 + vortex-array/src/stats/flatbuffers.rs | 11 +- vortex-dtype/src/datetime/timestamp.rs | 7 + vortex-dtype/src/extension/mod.rs | 38 +- vortex-scalar/Cargo.toml | 1 + vortex-scalar/src/arrow.rs | 44 +- vortex-scalar/src/cast.rs | 13 + vortex-scalar/src/constructor.rs | 64 ++- vortex-scalar/src/display.rs | 128 ++--- vortex-scalar/src/downcast.rs | 214 ++++++++ vortex-scalar/src/extension/datetime/date.rs | 67 +++ vortex-scalar/src/extension/datetime/mod.rs | 70 +++ vortex-scalar/src/extension/datetime/time.rs | 117 +++++ .../src/extension/datetime/timestamp.rs | 97 ++++ vortex-scalar/src/extension/matcher.rs | 55 +++ vortex-scalar/src/extension/mod.rs | 29 ++ vortex-scalar/src/extension/scalar_value.rs | 377 +++++++++++++++ vortex-scalar/src/extension/tests/apples.rs | 96 ++++ vortex-scalar/src/extension/tests/cast.rs | 91 ++++ .../src/extension/tests/grapheme_utf8.rs | 258 ++++++++++ vortex-scalar/src/extension/tests/mod.rs | 7 + .../src/extension/tests/trivial_ext.rs | 159 ++++++ vortex-scalar/src/extension/vtable.rs | 91 ++++ vortex-scalar/src/lib.rs | 2 + vortex-scalar/src/proto.rs | 455 ++---------------- vortex-scalar/src/scalar.rs | 122 +++-- vortex-scalar/src/scalar_value.rs | 124 +++-- vortex-scalar/src/session.rs | 66 +++ vortex-scalar/src/tests/casting.rs | 171 +++++-- vortex-scalar/src/tests/mod.rs | 10 +- vortex-scalar/src/tests/proto.rs | 407 ++++++++++++++++ vortex-scalar/src/typed_view/ext.rs | 144 ++++++ vortex-scalar/src/typed_view/extension/mod.rs | 152 ------ .../src/typed_view/extension/tests.rs | 293 ----------- vortex-scalar/src/typed_view/mod.rs | 4 +- .../src/typed_view/primitive/scalar.rs | 1 + 36 files changed, 2860 insertions(+), 1126 deletions(-) create mode 100644 vortex-scalar/src/extension/datetime/date.rs create mode 100644 vortex-scalar/src/extension/datetime/mod.rs create mode 100644 vortex-scalar/src/extension/datetime/time.rs create mode 100644 vortex-scalar/src/extension/datetime/timestamp.rs create mode 100644 vortex-scalar/src/extension/matcher.rs create mode 100644 vortex-scalar/src/extension/mod.rs create mode 100644 vortex-scalar/src/extension/scalar_value.rs create mode 100644 vortex-scalar/src/extension/tests/apples.rs create mode 100644 vortex-scalar/src/extension/tests/cast.rs create mode 100644 vortex-scalar/src/extension/tests/grapheme_utf8.rs create mode 100644 vortex-scalar/src/extension/tests/mod.rs create mode 100644 vortex-scalar/src/extension/tests/trivial_ext.rs create mode 100644 vortex-scalar/src/extension/vtable.rs create mode 100644 vortex-scalar/src/session.rs create mode 100644 vortex-scalar/src/tests/proto.rs create mode 100644 vortex-scalar/src/typed_view/ext.rs delete mode 100644 vortex-scalar/src/typed_view/extension/mod.rs delete mode 100644 vortex-scalar/src/typed_view/extension/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 50ed7e89ee7..f8da56a51df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10666,6 +10666,7 @@ dependencies = [ "arrow-array", "bytes", "itertools 0.14.0", + "jiff", "num-traits", "paste", "prost 0.14.3", diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index 67f82fcb7c3..0d0594e9390 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -11,6 +11,7 @@ use vortex_error::vortex_bail; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; use vortex_scalar::ScalarValue; +use vortex_session::VortexSession; use crate::expr::stats::Precision; use crate::expr::stats::Stat; @@ -113,6 +114,7 @@ impl StatsSet { pub fn from_flatbuffer<'a>( fb: &fba::ArrayStats<'a>, array_dtype: &DType, + session: &VortexSession, ) -> VortexResult { let mut stats_set = StatsSet::default(); @@ -142,7 +144,8 @@ impl StatsSet { if let Some(max) = fb.max() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; @@ -161,7 +164,8 @@ impl StatsSet { if let Some(min) = fb.min() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; @@ -193,7 +197,8 @@ impl StatsSet { if let Some(sum) = fb.sum() && let Some(stat_dtype) = stat_dtype { - let value = ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype)?; + let value = + ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype, session)?; let Some(value) = value else { continue; }; diff --git a/vortex-dtype/src/datetime/timestamp.rs b/vortex-dtype/src/datetime/timestamp.rs index e454b80adbe..f104f8d7b57 100644 --- a/vortex-dtype/src/datetime/timestamp.rs +++ b/vortex-dtype/src/datetime/timestamp.rs @@ -31,6 +31,7 @@ impl Timestamp { Self::new_with_tz(time_unit, None, nullability) } + // TODO(connor): This should probably be deprecated in favor of `new_with_options`. /// Creates a new Timestamp extension dtype with the given time unit, timezone, and nullability. pub fn new_with_tz( time_unit: TimeUnit, @@ -46,6 +47,12 @@ impl Timestamp { ) .vortex_expect("failed to create timestamp dtype") } + + /// Creates a new Timestamp extension dtype with the given options and nullability. + pub fn new_with_options(options: TimestampOptions, nullability: Nullability) -> ExtDType { + ExtDType::try_new(options, DType::Primitive(PType::I64, nullability)) + .vortex_expect("failed to create timestamp dtype") + } } /// Options for the Timestamp DType. diff --git a/vortex-dtype/src/extension/mod.rs b/vortex-dtype/src/extension/mod.rs index d6d955346c5..95c8fa4e644 100644 --- a/vortex-dtype/src/extension/mod.rs +++ b/vortex-dtype/src/extension/mod.rs @@ -29,7 +29,7 @@ use crate::Nullability; pub type ExtID = ArcRef; /// An extension data type. -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ExtDType(Arc>); // Convenience impls for zero-sized VTables @@ -60,6 +60,11 @@ impl ExtDType { self.0.id() } + /// Returns the vtable of the extension type. + pub fn vtable(&self) -> &V { + &self.0.vtable + } + /// Returns the metadata of the extension type. pub fn metadata(&self) -> &V::Metadata { &self.0.metadata @@ -70,6 +75,18 @@ impl ExtDType { &self.0.storage_dtype } + /// Returns the nullability of the storage dtype. + #[inline] + pub fn nullability(&self) -> Nullability { + self.storage_dtype().nullability() + } + + /// Returns true if the storage dtype is nullable. + #[inline] + pub fn is_nullable(&self) -> bool { + self.nullability().is_nullable() + } + /// Erase the concrete type information, returning a type-erased extension dtype. pub fn erased(self) -> ExtDTypeRef { ExtDTypeRef(self.0) @@ -135,9 +152,14 @@ impl ExtDTypeRef { self.0.storage_dtype() } + /// Returns the nullability of the storage dtype. + pub fn nullability(&self) -> Nullability { + self.storage_dtype().nullability() + } + /// Returns a new ExtDTypeRef with the given nullability. pub fn with_nullability(&self, nullability: Nullability) -> Self { - if self.storage_dtype().nullability() == nullability { + if self.nullability() == nullability { self.clone() } else { self.0.with_nullability(nullability) @@ -211,7 +233,7 @@ impl ExtDTypeRef { /// Wrapper for type-erased extension dtype metadata. pub struct ExtDTypeMetadata<'a> { - pub(super) ext_dtype: &'a ExtDTypeRef, + ext_dtype: &'a ExtDTypeRef, } impl ExtDTypeMetadata<'_> { @@ -249,7 +271,7 @@ impl Hash for ExtDTypeMetadata<'_> { } /// An object-safe trait encapsulating the behavior for extension DTypes. -trait ExtDTypeImpl: 'static + Send + Sync + private::Sealed { +trait ExtDTypeImpl: 'static + Send + Sync { fn as_any(&self) -> &dyn Any; fn id(&self) -> ExtID; fn storage_dtype(&self) -> &DType; @@ -262,6 +284,7 @@ trait ExtDTypeImpl: 'static + Send + Sync + private::Sealed { fn with_nullability(&self, nullability: Nullability) -> ExtDTypeRef; } +#[derive(Debug, Hash, PartialEq, Eq)] struct ExtDTypeAdapter { vtable: V, metadata: V::Metadata, @@ -314,10 +337,3 @@ impl ExtDTypeImpl for ExtDTypeAdapter { .vortex_expect("Extension DType {} incorrect fails validation with the same storage type but different nullability").erased() } } - -mod private { - use super::ExtDTypeAdapter; - - pub trait Sealed {} - impl Sealed for ExtDTypeAdapter {} -} diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index d69c03f8679..dd52ccbc0ec 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -21,6 +21,7 @@ arbitrary = { workspace = true, optional = true } arrow-array = { workspace = true } bytes = { workspace = true } itertools = { workspace = true } +jiff = { workspace = true } num-traits = { workspace = true } paste = { workspace = true } prost = { workspace = true } diff --git a/vortex-scalar/src/arrow.rs b/vortex-scalar/src/arrow.rs index a47958796b5..cdd8f06f9a9 100644 --- a/vortex-scalar/src/arrow.rs +++ b/vortex-scalar/src/arrow.rs @@ -202,9 +202,6 @@ mod tests { use vortex_dtype::datetime::TimeUnit; use vortex_dtype::datetime::Timestamp; use vortex_dtype::datetime::TimestampOptions; - use vortex_dtype::extension::ExtDTypeVTable; - use vortex_error::VortexResult; - use vortex_error::vortex_bail; use crate::DecimalValue; use crate::Scalar; @@ -443,45 +440,6 @@ mod tests { Arc::::try_from(&list_scalar).unwrap(); } - #[test] - #[should_panic(expected = "Cannot convert extension scalar")] - fn test_non_temporal_extension_to_arrow_todo() { - use vortex_dtype::ExtID; - - #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] - struct SomeExt; - impl ExtDTypeVTable for SomeExt { - type Metadata = String; - - fn id(&self) -> ExtID { - ExtID::new_ref("some_ext") - } - - fn serialize(&self, _options: &Self::Metadata) -> VortexResult> { - vortex_bail!("not implemented") - } - - fn deserialize(&self, _data: &[u8]) -> VortexResult { - vortex_bail!("not implemented") - } - - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { - Ok(()) - } - } - - let scalar = Scalar::extension::( - "".into(), - Scalar::primitive(42i32, Nullability::NonNullable), - ); - - Arc::::try_from(&scalar).unwrap(); - } - #[rstest] #[case(TimeUnit::Nanoseconds, PType::I64, 123456789i64)] #[case(TimeUnit::Microseconds, PType::I64, 123456789i64)] @@ -553,7 +511,7 @@ mod tests { #[rstest] #[case(TimeUnit::Nanoseconds, "UTC", 1234567890000000000i64)] #[case(TimeUnit::Microseconds, "EST", 1234567890000000i64)] - #[case(TimeUnit::Milliseconds, "ABC", 1234567890000i64)] + #[case(TimeUnit::Milliseconds, "CET", 1234567890000i64)] #[case(TimeUnit::Seconds, "UTC", 1234567890i64)] fn test_temporal_timestamp_tz_to_arrow( #[case] time_unit: TimeUnit, diff --git a/vortex-scalar/src/cast.rs b/vortex-scalar/src/cast.rs index 97902819852..6318860593f 100644 --- a/vortex-scalar/src/cast.rs +++ b/vortex-scalar/src/cast.rs @@ -45,6 +45,19 @@ impl Scalar { // If the target is an extension type, then we want to cast to its storage type. if let Some(ext_dtype) = target_dtype.as_extension_opt() { let cast_storage_scalar_value = self.cast(ext_dtype.storage_dtype())?.into_value(); + + // NEED A SESSION! + // let ext_scalar_registry = session.scalars(); + // let dyn_scalar_vtable = ext_scalar_registry + // .registry() + // .find(&ext_dtype.id()) + // .ok_or_else(|| { + // vortex_err!( + // "extension type scalar {} did not exist in the registry", + // ext_dtype.id() + // ) + // })?; + return Scalar::try_new(target_dtype.clone(), cast_storage_scalar_value); } diff --git a/vortex-scalar/src/constructor.rs b/vortex-scalar/src/constructor.rs index e8dcd68777e..48a983a5f24 100644 --- a/vortex-scalar/src/constructor.rs +++ b/vortex-scalar/src/constructor.rs @@ -14,14 +14,20 @@ use vortex_dtype::ExtDTypeRef; use vortex_dtype::NativePType; use vortex_dtype::Nullability; use vortex_dtype::PType; -use vortex_dtype::extension::ExtDTypeVTable; use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_session::VortexSession; use crate::DecimalValue; use crate::PValue; use crate::Scalar; use crate::ScalarValue; +use crate::extension::ExtScalarVTable; +use crate::extension::ExtScalarValue; +use crate::session::ScalarSessionExt; // TODO(connor): Really, we want `try_` constructors that return errors instead of just panic. impl Scalar { @@ -170,27 +176,57 @@ impl Scalar { .vortex_expect("unable to construct a list `Scalar`") } - /// Creates a new extension scalar wrapping the given storage value. - pub fn extension(options: V::Metadata, value: Scalar) -> Self { - let ext_dtype = ExtDType::::try_new(options, value.dtype().clone()) - .vortex_expect("Failed to create extension dtype"); - Self::try_new(DType::Extension(ext_dtype.erased()), value.into_value()) - .vortex_expect("unable to construct an extension `Scalar`") - } - + // TODO(connor): This needs to return a `VortexResult` instead. /// Creates a new extension scalar wrapping the given storage value. /// /// # Panics /// - /// Panics if the storage dtype of `ext_dtype` does not match `value`'s dtype. - pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self { - assert_eq!(ext_dtype.storage_dtype(), value.dtype()); - Self::try_new(DType::Extension(ext_dtype), value.into_value()) + /// Panics if the storage dtype is incompatible with the extension type, or if the storage + /// value fails validation. + pub fn extension(metadata: V::Metadata, value: Scalar) -> Self { + let ext_dtype = ExtDType::::try_new(metadata, value.dtype().clone()) + .vortex_expect("Failed to create extension dtype"); + let storage_value = value.into_value(); + + let ext_value = storage_value.map(|sv| { + let owned = ExtScalarValue::::try_new(ext_dtype.clone(), sv) + .vortex_expect("unable to construct an extension `Scalar`"); + ScalarValue::Extension(owned.erased()) + }); + + Self::try_new(DType::Extension(ext_dtype.erased()), ext_value) .vortex_expect("unable to construct an extension `Scalar`") } + + /// TODO docs. + pub fn extension_ref( + ext_dtype: ExtDTypeRef, + value: Scalar, + session: &VortexSession, + ) -> VortexResult { + let (storage_dtype, storage_value) = value.into_parts(); + Self::extension_ref_from_value(ext_dtype, &storage_dtype, storage_value, session) + } + + /// TODO docs. + pub fn extension_ref_from_value( + ext_dtype: ExtDTypeRef, + storage_dtype: &DType, + storage_value: Option, + session: &VortexSession, + ) -> VortexResult { + vortex_ensure_eq!(ext_dtype.storage_dtype(), storage_dtype); + + let ext_value = Self::extension_value(&ext_dtype, storage_value, session)?; + + Ok( + // SAFETY: `create_ext_scalar_value_ref` validates that the scalar value is compatible. + unsafe { Scalar::new_unchecked(DType::Extension(ext_dtype), ext_value) }, + ) + } } -/// A helper enum for creating a [`ListScalar`]. +/// A helper enum for creating a list scalar. enum ListKind { /// Variable-length list. Variable, diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs index df6e5b1ea1c..f106e77695e 100644 --- a/vortex-scalar/src/display.rs +++ b/vortex-scalar/src/display.rs @@ -42,6 +42,7 @@ mod tests { use vortex_dtype::datetime::Time; use vortex_dtype::datetime::TimeUnit; use vortex_dtype::datetime::Timestamp; + use vortex_dtype::datetime::TimestampOptions; use crate::PValue; use crate::Scalar; @@ -197,18 +198,17 @@ mod tests { DType::Extension(Time::new(TimeUnit::Seconds, Nullable).erased()) } + let storage_dtype = dtype().as_extension().storage_dtype().clone(); assert_eq!(format!("{}", Scalar::null(dtype())), "null"); - assert_eq!( - format!( - "{}", - Scalar::new( - dtype(), - Some(ScalarValue::Primitive(PValue::I32(3 * MINUTES + 25))) - ) - ), - "00:03:25" + let storage_scalar = Scalar::new( + storage_dtype, + Some(ScalarValue::Primitive(PValue::I32(3 * MINUTES + 25))), ); + let ext_scalar = Scalar::extension::