From aaae13f55f6d818f53d5cd69017f16e787c4ba75 Mon Sep 17 00:00:00 2001 From: Will Manning Date: Thu, 2 Apr 2026 14:12:56 -0400 Subject: [PATCH 01/13] TurboQuant encoding for Vectors (#7167) Lossy quantization for vector data (e.g., embeddings) based on TurboQuant (https://arxiv.org/abs/2504.19874). Supports both MSE-optimal and inner-product-optimal (Prod with QJL correction) variants at 1-8 bits per coordinate. Key components: - Single TurboQuant array encoding with optional QJL correction fields, storing quantized codes, norms, centroids, and rotation signs as children. - Structured Random Hadamard Transform (SRHT) for O(d log d) rotation, fully self-contained with no external linear algebra library. - Max-Lloyd centroid computation on Beta(d/2, d/2) distribution. - Approximate cosine similarity and dot product compute directly on quantized arrays without full decompression. - Pluggable TurboQuantScheme for BtrBlocks, exposed via WriteStrategyBuilder::with_vector_quantization(). - Benchmarks covering common embedding dimensions (128, 768, 1024, 1536). Also refactors CompressingStrategy to a single constructor, and adds vortex_tensor::initialize() for session registration of tensor types, encodings, and scalar functions. 
Co-Authored-By: Claude Opus 4.6 (1M context) Co-Authored-By: Will Manning Signed-off-by: Connor Tsui --- Cargo.lock | 9 + _typos.toml | 2 +- vortex-btrblocks/Cargo.toml | 3 +- vortex-btrblocks/src/builder.rs | 19 +- vortex-file/src/strategy.rs | 4 + vortex-tensor/Cargo.toml | 7 +- vortex-tensor/public-api.lock | 206 +++- vortex-tensor/src/encodings/mod.rs | 3 +- .../src/encodings/turboquant/array.rs | 256 +++++ .../src/encodings/turboquant/centroids.rs | 311 +++++ .../src/encodings/turboquant/compress.rs | 355 ++++++ .../turboquant/compute/cosine_similarity.rs | 151 +++ .../src/encodings/turboquant/compute/mod.rs | 10 + .../src/encodings/turboquant/compute/ops.rs | 28 + .../src/encodings/turboquant/compute/rules.rs | 15 + .../src/encodings/turboquant/compute/slice.rs | 50 + .../src/encodings/turboquant/compute/take.rs | 51 + .../src/encodings/turboquant/decompress.rs | 156 +++ vortex-tensor/src/encodings/turboquant/mod.rs | 1018 +++++++++++++++++ .../src/encodings/turboquant/rotation.rs | 379 ++++++ .../src/encodings/turboquant/scheme.rs | 213 ++++ .../src/encodings/turboquant/vtable.rs | 232 ++++ vortex-tensor/src/lib.rs | 7 +- .../src/scalar_fns/cosine_similarity.rs | 23 +- vortex-tensor/src/scalar_fns/inner_product.rs | 14 +- vortex-tensor/src/scalar_fns/l2_norm.rs | 15 +- vortex/Cargo.toml | 3 + vortex/benches/single_encoding_throughput.rs | 126 +- 28 files changed, 3652 insertions(+), 14 deletions(-) create mode 100644 vortex-tensor/src/encodings/turboquant/array.rs create mode 100644 vortex-tensor/src/encodings/turboquant/centroids.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compress.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compute/mod.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compute/ops.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compute/rules.rs create mode 100644 
vortex-tensor/src/encodings/turboquant/compute/slice.rs create mode 100644 vortex-tensor/src/encodings/turboquant/compute/take.rs create mode 100644 vortex-tensor/src/encodings/turboquant/decompress.rs create mode 100644 vortex-tensor/src/encodings/turboquant/mod.rs create mode 100644 vortex-tensor/src/encodings/turboquant/rotation.rs create mode 100644 vortex-tensor/src/encodings/turboquant/scheme.rs create mode 100644 vortex-tensor/src/encodings/turboquant/vtable.rs diff --git a/Cargo.lock b/Cargo.lock index 61c277226dd..0c7b686cdf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10067,7 +10067,9 @@ dependencies = [ "fastlanes", "mimalloc", "parquet 58.0.0", + "paste", "rand 0.10.0", + "rand_distr 0.6.0", "serde_json", "tokio", "tracing", @@ -10258,6 +10260,7 @@ dependencies = [ "vortex-runend", "vortex-sequence", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10960,14 +10963,20 @@ dependencies = [ name = "vortex-tensor" version = "0.1.0" dependencies = [ + "half", "itertools 0.14.0", "num-traits", "prost 0.14.3", + "rand 0.10.0", + "rand_distr 0.6.0", "rstest", "vortex-array", "vortex-buffer", + "vortex-compressor", "vortex-error", + "vortex-fastlanes", "vortex-session", + "vortex-utils", ] [[package]] diff --git a/_typos.toml b/_typos.toml index e9cf23d68b7..62c3b0d6358 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,5 +1,5 @@ [default] -extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui"] +extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui", "wht", "WHT"] # We support a few common special comments to tell the checker to ignore sections of code extend-ignore-re = [ "(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line diff --git a/vortex-btrblocks/Cargo.toml b/vortex-btrblocks/Cargo.toml index 9bbd2430f09..8906fd24d2e 100644 --- a/vortex-btrblocks/Cargo.toml +++ b/vortex-btrblocks/Cargo.toml @@ -35,6 +35,7 @@ vortex-pco = { workspace = true, optional = true } 
vortex-runend = { workspace = true } vortex-sequence = { workspace = true } vortex-sparse = { workspace = true } +vortex-tensor = { workspace = true, optional = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -47,7 +48,7 @@ vortex-array = { workspace = true, features = ["_test-harness"] } [features] # This feature enabled unstable encodings for which we don't guarantee stability. -unstable_encodings = ["vortex-zstd?/unstable_encodings"] +unstable_encodings = ["dep:vortex-tensor", "vortex-zstd?/unstable_encodings"] pco = ["dep:pco", "dep:vortex-pco"] zstd = ["dep:vortex-zstd"] diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 3ff8e872b19..2a390e89504 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -120,7 +120,7 @@ impl BtrBlocksCompressorBuilder { /// Adds compact encoding schemes (Zstd for strings, Pco for numerics). /// /// This provides better compression ratios than the default, especially for floating-point - /// heavy datasets. Requires the `zstd` feature. When the `pco` feature is also enabled, /// Pco schemes for integers and floats are included. /// /// # Panics @@ -138,6 +138,23 @@ impl BtrBlocksCompressorBuilder { builder } + /// Adds the TurboQuant lossy vector quantization scheme. + /// + /// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm with + /// QJL correction for unbiased inner product estimation. + /// + /// # Panics + /// + /// Panics if the TurboQuant scheme is already present.
+ /// + /// [`Vector`]: vortex_tensor::vector::Vector + /// [`FixedShapeTensor`]: vortex_tensor::fixed_shape::FixedShapeTensor + #[cfg(feature = "unstable_encodings")] + pub fn with_turboquant(self) -> Self { + use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; + self.with_new_scheme(&TURBOQUANT_SCHEME) + } + /// Excludes schemes without CUDA kernel support and adds Zstd for string compression. /// /// With the `unstable_encodings` feature, buffer-level Zstd compression is used which diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index 197efd9583f..9af6c1e9402 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -56,6 +56,8 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use vortex_sequence::Sequence; use vortex_sparse::Sparse; +#[cfg(feature = "unstable_encodings")] +use vortex_tensor::encodings::turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; use vortex_zigzag::ZigZag; #[cfg(feature = "zstd")] @@ -104,6 +106,8 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { session.register(RunEnd); session.register(Sequence); session.register(Sparse); + #[cfg(feature = "unstable_encodings")] + session.register(TurboQuant); session.register(ZigZag); #[cfg(feature = "zstd")] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index f0b6670cc51..9f94a0c2d3d 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -19,13 +19,18 @@ workspace = true [dependencies] vortex-array = { workspace = true } vortex-buffer = { workspace = true } +vortex-compressor = { workspace = true } vortex-error = { workspace = true } +vortex-fastlanes = { workspace = true } vortex-session = { workspace = true } +vortex-utils = { workspace = true } +half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +rand = { workspace = true } [dev-dependencies] +rand_distr = { workspace = true } rstest = { workspace = 
true } -vortex-buffer = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index cea02b69e38..fbbbe0ace6c 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -2,6 +2,210 @@ pub mod vortex_tensor pub mod vortex_tensor::encodings +pub mod vortex_tensor::encodings::turboquant + +pub mod vortex_tensor::encodings::turboquant::scheme + +pub struct vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +impl core::cmp::Eq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +impl core::cmp::PartialEq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme) -> bool + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::compress(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn 
vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool + +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::scheme_name(&self) -> &'static str + +pub static vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME: vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme + +pub struct vortex_tensor::encodings::turboquant::QjlCorrection + +impl vortex_tensor::encodings::turboquant::QjlCorrection + +pub fn vortex_tensor::encodings::turboquant::QjlCorrection::residual_norms(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::QjlCorrection::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::QjlCorrection::signs(&self) -> &vortex_array::array::erased::ArrayRef + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::QjlCorrection + +pub fn vortex_tensor::encodings::turboquant::QjlCorrection::clone(&self) -> vortex_tensor::encodings::turboquant::QjlCorrection + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::QjlCorrection + +pub fn vortex_tensor::encodings::turboquant::QjlCorrection::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_tensor::encodings::turboquant::TurboQuant + +impl vortex_tensor::encodings::turboquant::TurboQuant + +pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::clone(&self) -> 
vortex_tensor::encodings::turboquant::TurboQuant + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquant::TurboQuant + +pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData + +pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_array::metadata::ProstMetadata + +pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant + +pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::array_eq(array: &vortex_tensor::encodings::turboquant::TurboQuantData, other: &vortex_tensor::encodings::turboquant::TurboQuantData, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::array_hash(array: &vortex_tensor::encodings::turboquant::TurboQuantData, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer_name(_array: vortex_array::array::view::ArrayView<'_, Self>, _idx: usize) -> core::option::Option + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(bytes: &[u8], _dtype: 
&vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::dtype(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::dtype::DType + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute(array: vortex_array::array::typed::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::id(&self) -> vortex_array::array::ArrayId + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::len(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> usize + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::metadata(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::nbuffers(_array: vortex_array::array::view::ArrayView<'_, Self>) -> usize + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::reduce_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slot_name(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slots(array: vortex_array::array::view::ArrayView<'_, Self>) -> &[core::option::Option] + +pub fn 
vortex_tensor::encodings::turboquant::TurboQuant::stats(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::stats::array::ArrayStats + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::vtable(_array: &Self::ArrayData) -> &Self + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::with_slots(array: &mut vortex_tensor::encodings::turboquant::TurboQuantData, slots: alloc::vec::Vec>) -> vortex_error::VortexResult<()> + +impl vortex_array::array::vtable::operations::OperationsVTable for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +impl vortex_array::array::vtable::validity::ValidityChild for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::validity_child(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::array::erased::ArrayRef + +impl vortex_array::arrays::dict::take::TakeExecute for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::take(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, indices: &vortex_array::array::erased::ArrayRef, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::arrays::slice::SliceReduce for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slice(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, range: core::ops::range::Range) -> vortex_error::VortexResult> + +pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub 
vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8 + +pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantConfig + +impl core::default::Default for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::default() -> Self + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_tensor::encodings::turboquant::TurboQuantData + +impl vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) -> u8 + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::centroids(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::codes(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32 + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::has_qjl(&self) -> bool + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::norms(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32 + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::qjl(&self) -> core::option::Option + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new_mse(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, 
norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef, dimension: u32, bit_width: u8) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new_qjl(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef, qjl: vortex_tensor::encodings::turboquant::QjlCorrection, dimension: u32, bit_width: u8) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantData + +impl core::convert::From for vortex_array::array::erased::ArrayRef + +pub fn vortex_array::array::erased::ArrayRef::from(value: vortex_tensor::encodings::turboquant::TurboQuantData) -> vortex_array::array::erased::ArrayRef + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) -> vortex_array::array::erased::ArrayRef + +pub const vortex_tensor::encodings::turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str + +pub const vortex_tensor::encodings::turboquant::VECTOR_EXT_ID: &str + +pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession) + +pub fn vortex_tensor::encodings::turboquant::turboquant_encode_mse(fsl: &vortex_array::arrays::fixed_size_list::vtable::FixedSizeListArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig) -> 
vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::turboquant_encode_qjl(fsl: &vortex_array::arrays::fixed_size_list::vtable::FixedSizeListArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig) -> vortex_error::VortexResult + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor @@ -180,7 +384,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 090151e9226..7c75269b632 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -7,5 +7,4 @@ // pub mod norm; // Unit-normalized vectors. // pub mod spherical; // Spherical transform on unit-normalized vectors. 
-// TODO(will): -// pub mod turboquant; +pub mod turboquant; diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array.rs new file mode 100644 index 00000000000..89c600853b4 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array.rs @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant array definition: stores quantized coordinate codes, norms, +//! centroids (codebook), rotation signs, and optional QJL correction fields. + +use vortex_array::ArrayId; +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// Encoding marker type for TurboQuant. +#[derive(Clone, Debug)] +pub struct TurboQuant; + +impl TurboQuant { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); +} + +vtable!(TurboQuant, TurboQuant, TurboQuantData); + +/// Protobuf metadata for TurboQuant encoding. +#[derive(Clone, prost::Message)] +pub struct TurboQuantMetadata { + /// Vector dimension d. + #[prost(uint32, tag = "1")] + pub dimension: u32, + /// MSE bits per coordinate (1-8). + #[prost(uint32, tag = "2")] + pub bit_width: u32, + /// Whether QJL correction children are present. + #[prost(bool, tag = "3")] + pub has_qjl: bool, +} + +/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased +/// inner product estimation. When present, adds 3 additional children. +#[derive(Clone, Debug)] +pub struct QjlCorrection { + /// Sign bits: `BoolArray`, length `num_rows * padded_dim`. + pub(crate) signs: ArrayRef, + /// Residual norms: `PrimitiveArray`, length `num_rows`. + pub(crate) residual_norms: ArrayRef, + /// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order). 
+ pub(crate) rotation_signs: ArrayRef, +} + +impl QjlCorrection { + /// The QJL sign bits. + pub fn signs(&self) -> &ArrayRef { + &self.signs + } + + /// The residual norms. + pub fn residual_norms(&self) -> &ArrayRef { + &self.residual_norms + } + + /// The QJL rotation signs (BoolArray, inverse application order). + pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } +} + +/// Slot positions for TurboQuantArray children. +#[repr(usize)] +#[derive(Clone, Copy, Debug)] +pub(crate) enum Slot { + Codes = 0, + Norms = 1, + Centroids = 2, + RotationSigns = 3, + QjlSigns = 4, + QjlResidualNorms = 5, + QjlRotationSigns = 6, +} + +impl Slot { + pub(crate) const COUNT: usize = 7; + + pub(crate) fn name(self) -> &'static str { + match self { + Self::Codes => "codes", + Self::Norms => "norms", + Self::Centroids => "centroids", + Self::RotationSigns => "rotation_signs", + Self::QjlSigns => "qjl_signs", + Self::QjlResidualNorms => "qjl_residual_norms", + Self::QjlRotationSigns => "qjl_rotation_signs", + } + } + + pub(crate) fn from_index(idx: usize) -> Self { + match idx { + 0 => Self::Codes, + 1 => Self::Norms, + 2 => Self::Centroids, + 3 => Self::RotationSigns, + 4 => Self::QjlSigns, + 5 => Self::QjlResidualNorms, + 6 => Self::QjlRotationSigns, + _ => vortex_error::vortex_panic!("invalid slot index {idx}"), + } + } +} + +/// TurboQuant array. 
+/// +/// Slots (always present): +/// - 0: `codes` — `FixedSizeListArray` (quantized indices, list_size=padded_dim) +/// - 1: `norms` — `PrimitiveArray` (one per vector row) +/// - 2: `centroids` — `PrimitiveArray` (codebook, length 2^bit_width) +/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order) +/// +/// Optional QJL slots (None when MSE-only): +/// - 4: `qjl_signs` — `FixedSizeListArray` (num_rows * padded_dim, 1-bit) +/// - 5: `qjl_residual_norms` — `PrimitiveArray` (one per row) +/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation) +#[derive(Clone, Debug)] +pub struct TurboQuantData { + pub(crate) dtype: DType, + pub(crate) slots: Vec>, + pub(crate) dimension: u32, + pub(crate) bit_width: u8, + pub(crate) stats_set: ArrayStats, +} + +impl TurboQuantData { + /// Build a TurboQuant array with MSE-only encoding (no QJL correction). + #[allow(clippy::too_many_arguments)] + pub fn try_new_mse( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + Ok(Self { + dtype, + slots, + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// Build a TurboQuant array with QJL correction (MSE + QJL). 
+ #[allow(clippy::too_many_arguments)] + pub fn try_new_qjl( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + qjl: QjlCorrection, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + slots[Slot::QjlSigns as usize] = Some(qjl.signs); + slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); + slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); + Ok(Self { + dtype, + slots, + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// The vector dimension d. + pub fn dimension(&self) -> u32 { + self.dimension + } + + /// MSE bits per coordinate. + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension (next power of 2 >= dimension). + pub fn padded_dim(&self) -> u32 { + self.dimension.next_power_of_two() + } + + /// Whether QJL correction is present. + pub fn has_qjl(&self) -> bool { + self.slots[Slot::QjlSigns as usize].is_some() + } + + fn slot(&self, idx: usize) -> &ArrayRef { + self.slots[idx] + .as_ref() + .vortex_expect("required slot is None") + } + + /// The quantized codes child (FixedSizeListArray). + pub fn codes(&self) -> &ArrayRef { + self.slot(Slot::Codes as usize) + } + + /// The norms child (`PrimitiveArray`). + pub fn norms(&self) -> &ArrayRef { + self.slot(Slot::Norms as usize) + } + + /// The centroids (codebook) child (`PrimitiveArray`). + pub fn centroids(&self) -> &ArrayRef { + self.slot(Slot::Centroids as usize) + } + + /// The MSE rotation signs child (BitPackedArray, length 3 * padded_dim). 
+ pub fn rotation_signs(&self) -> &ArrayRef { + self.slot(Slot::RotationSigns as usize) + } + + /// The optional QJL correction fields, reconstructed from slots. + pub fn qjl(&self) -> Option { + Some(QjlCorrection { + signs: self.slots[Slot::QjlSigns as usize].clone()?, + residual_norms: self.slots[Slot::QjlResidualNorms as usize].clone()?, + rotation_signs: self.slots[Slot::QjlRotationSigns as usize].clone()?, + }) + } + + /// Set the QJL correction fields on this array. + pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) { + self.slots[Slot::QjlSigns as usize] = Some(qjl.signs); + self.slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); + self.slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs new file mode 100644 index 00000000000..85ea39fcc9e --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates +//! after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly +//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, +//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids +//! that minimize MSE for this distribution. + +use std::sync::LazyLock; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_utils::aliases::dash_map::DashMap; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Max-Lloyd convergence threshold. 
+const CONVERGENCE_EPSILON: f64 = 1e-12;
+
+/// Maximum iterations for Max-Lloyd algorithm.
+const MAX_ITERATIONS: usize = 200;
+
+/// Global centroid cache keyed by (dimension, bit_width).
+static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);
+
+/// Get or compute cached centroids for the given dimension and bit width.
+///
+/// Returns `2^bit_width` centroids sorted in ascending order, representing
+/// optimal scalar quantization levels for the coordinate distribution after
+/// random rotation in `dimension`-dimensional space.
+pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
+    if !(1..=8).contains(&bit_width) {
+        vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}");
+    }
+    if dimension < 3 {
+        vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}");
+    }
+
+    if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) {
+        return Ok(centroids.clone());
+    }
+
+    let centroids = max_lloyd_centroids(dimension, bit_width);
+    CENTROID_CACHE.insert((dimension, bit_width), centroids.clone());
+    Ok(centroids)
+}
+
+/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`.
+///
+/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd)
+/// or a half-integer (when `d` is even). This type makes that invariant explicit and
+/// avoids floating-point comparison in the hot path.
+#[derive(Clone, Copy, Debug)]
+struct HalfIntExponent {
+    int_part: i32,
+    has_half: bool,
+}
+
+impl HalfIntExponent {
+    /// Compute `(numerator) / 2` as a half-integer exponent.
+    ///
+    /// `numerator` is `d - 3` where `d` is the dimension; `get_centroids` enforces
+    /// `d >= 3`, so `numerator` is non-negative for all reachable inputs, but the
+    /// conversion below remains correct for negative values as well.
+    fn from_numerator(numerator: i32) -> Self {
+        // Floor division (round toward negative infinity) is required here, not
+        // Rust's truncating `/`. `div_euclid`/`rem_euclid` implement exactly that:
+        // e.g. numerator = -1 yields int_part = -1, has_half = true, which
+        // represents -1 + 0.5 = -0.5.
+        let int_part = numerator.div_euclid(2);
+        let has_half = numerator.rem_euclid(2) != 0;
+        Self { int_part, has_half }
+    }
+}
+
+/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm.
+///
+/// Operates on the marginal distribution of a single coordinate of a randomly
+/// rotated unit vector in d dimensions. The PDF is:
+/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
+/// where `C_d` is the normalizing constant.
+#[allow(clippy::cast_possible_truncation)] // f64→f32 centroid values are intentional
+fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
+    let num_centroids = 1usize << bit_width;
+
+    // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2.
+    let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3);
+
+    // Initialize centroids uniformly on [-1, 1].
+    let mut centroids: Vec<f64> = (0..num_centroids)
+        .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64))
+        .collect();
+
+    let mut boundaries: Vec<f64> = vec![0.0; num_centroids + 1];
+    for _ in 0..MAX_ITERATIONS {
+        // Compute decision boundaries (midpoints between adjacent centroids).
+        boundaries[0] = -1.0;
+        for idx in 0..num_centroids - 1 {
+            boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0;
+        }
+        boundaries[num_centroids] = 1.0;
+
+        // Update each centroid to the conditional mean within its Voronoi cell.
+ let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = conditional_mean(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` +/// on [-1, 1]. +fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents +/// that arise from `(d-3)/2`. This is significantly faster than the general +/// `powf` which goes through `exp(exponent * ln(base))`. +#[inline] +fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + if exponent.has_half { + // Half-integer exponent: base^(int_part) * sqrt(base). + base.powi(exponent.int_part) * base.sqrt() + } else { + // Integer exponent: use powi directly. + base.powi(exponent.int_part) + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). 
+///
+/// For `k` centroids, returns `k-1` boundaries. With [`find_nearest_centroid`],
+/// a value `<= boundaries[0]` maps to centroid 0, a value in
+/// `(boundaries[i-1], boundaries[i]]` maps to centroid `i`, and a value
+/// `> boundaries[k-2]` maps to centroid `k-1`.
+pub fn compute_boundaries(centroids: &[f32]) -> Vec<f32> {
+    centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect()
+}
+
+/// Find the index of the nearest centroid using precomputed decision boundaries.
+///
+/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding
+/// centroids. Uses binary search on the midpoints, avoiding distance comparisons
+/// in the inner loop.
+#[inline]
+#[allow(clippy::cast_possible_truncation)] // bounded by num_centroids <= 256
+pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 {
+    debug_assert!(
+        boundaries.windows(2).all(|w| w[0] <= w[1]),
+        "boundaries must be sorted"
+    );
+
+    boundaries.partition_point(|&b| b < value) as u8
+}
+
+#[cfg(test)]
+#[allow(clippy::cast_possible_truncation)]
+mod tests {
+    use rstest::rstest;
+    use vortex_error::VortexResult;
+
+    use super::*;
+
+    #[rstest]
+    #[case(128, 1, 2)]
+    #[case(128, 2, 4)]
+    #[case(128, 3, 8)]
+    #[case(128, 4, 16)]
+    #[case(768, 2, 4)]
+    #[case(1536, 3, 8)]
+    fn centroids_have_correct_count(
+        #[case] dim: u32,
+        #[case] bits: u8,
+        #[case] expected: usize,
+    ) -> VortexResult<()> {
+        let centroids = get_centroids(dim, bits)?;
+        assert_eq!(centroids.len(), expected);
+        Ok(())
+    }
+
+    #[rstest]
+    #[case(128, 1)]
+    #[case(128, 2)]
+    #[case(128, 3)]
+    #[case(128, 4)]
+    #[case(768, 2)]
+    fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
+        let centroids = get_centroids(dim, bits)?;
+        for window in centroids.windows(2) {
+            assert!(
+                window[0] < window[1],
+                "centroids not sorted: {:?}",
+                centroids
+            );
+        }
+        Ok(())
+    }
+
+    #[rstest]
+    #[case(128, 1)]
+    #[case(128, 2)]
+    #[case(256, 2)]
+    #[case(768, 2)]
+    fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8)
-> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for &val in ¢roids { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { + let c1 = get_centroids(128, 2)?; + let c2 = get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = get_centroids(128, 2)?; + let boundaries = compute_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(get_centroids(128, 0).is_err()); + assert!(get_centroids(128, 9).is_err()); + assert!(get_centroids(1, 2).is_err()); + assert!(get_centroids(2, 2).is_err()); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs new file mode 100644 index 00000000000..756b17cdada --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encoding (quantization) logic. 
+ +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_fastlanes::bitpack_compress::bitpack_encode; + +use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::centroids::compute_boundaries; +use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::encodings::turboquant::centroids::get_centroids; +use crate::encodings::turboquant::rotation::RotationMatrix; + +/// Configuration for TurboQuant encoding. +#[derive(Clone, Debug)] +pub struct TurboQuantConfig { + /// Bits per coordinate. + /// + /// For MSE encoding: 1-8. + /// For QJL encoding: 2-9 (the MSE component uses `bit_width - 1`). + pub bit_width: u8, + /// Optional seed for the rotation matrix. If None, the default seed is used. + pub seed: Option, +} + +impl Default for TurboQuantConfig { + fn default() -> Self { + Self { + bit_width: 5, + seed: Some(42), + } + } +} + +/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray. +#[allow(clippy::cast_possible_truncation)] +fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult { + let elements = fsl.elements(); + let primitive = elements.to_canonical()?.into_primitive(); + let ptype = primitive.ptype(); + + match ptype { + PType::F16 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(primitive), + PType::F64 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| v as f32) + .collect()), + _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"), + } +} + +/// Compute the L2 norm of a vector. 
+#[inline] +fn l2_norm(x: &[f32]) -> f32 { + x.iter().map(|&v| v * v).sum::().sqrt() +} + +/// Shared intermediate results from the MSE quantization loop. +struct MseQuantizationResult { + rotation: RotationMatrix, + f32_elements: PrimitiveArray, + centroids: Vec, + all_indices: BufferMut, + norms: BufferMut, + padded_dim: usize, +} + +/// Core quantization: extract f32 elements, build rotation, normalize/rotate/quantize all rows. +#[allow(clippy::cast_possible_truncation)] +fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + seed: u64, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size() as usize; + let num_rows = fsl.len(); + + let rotation = RotationMatrix::try_new(seed, dimension)?; + let padded_dim = rotation.padded_dim(); + + let f32_elements = extract_f32_elements(fsl)?; + + let centroids = get_centroids(padded_dim as u32, bit_width)?; + let boundaries = compute_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut norms = BufferMut::::with_capacity(num_rows); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice::(); + for row in 0..num_rows { + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let norm = l2_norm(x); + norms.push(norm); + + if norm > 0.0 { + let inv_norm = 1.0 / norm; + for (dst, &src) in padded[..dimension].iter_mut().zip(x.iter()) { + *dst = src * inv_norm; + } + } else { + padded[..dimension].fill(0.0); + } + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + Ok(MseQuantizationResult { + rotation, + f32_elements, + centroids, + all_indices, + norms, + padded_dim, + }) +} + +/// Build a `TurboQuantArray` (MSE-only) from quantization results. 
+#[allow(clippy::cast_possible_truncation)] +fn build_turboquant_mse( + fsl: &FixedSizeListArray, + core: MseQuantizationResult, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size(); + + let num_rows = fsl.len(); + let padded_dim = core.padded_dim; + let codes_elements = + PrimitiveArray::new::(core.all_indices.freeze(), Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )? + .into_array(); + let norms_array = + PrimitiveArray::new::(core.norms.freeze(), Validity::NonNullable).into_array(); + + // TODO(perf): `get_centroids` returns Vec; could avoid the copy by + // supporting Buffer::from(Vec) or caching as Buffer directly. + let mut centroids_buf = BufferMut::::with_capacity(core.centroids.len()); + centroids_buf.extend_from_slice(&core.centroids); + let centroids_array = + PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable).into_array(); + + let rotation_signs = bitpack_rotation_signs(&core.rotation)?; + + TurboQuantData::try_new_mse( + fsl.dtype().clone(), + codes, + norms_array, + centroids_array, + rotation_signs, + dimension, + bit_width, + ) +} + +/// Encode a FixedSizeListArray into a MSE-only `TurboQuantArray`. +/// +/// The input must be non-nullable. TurboQuant is a lossy encoding that does not +/// preserve null positions; callers must handle validity externally. 
+pub fn turboquant_encode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 1 && config.bit_width <= 8, + "MSE bit_width must be 1-8, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let core = turboquant_quantize_core(fsl, seed, config.bit_width)?; + + Ok(build_turboquant_mse(fsl, core, config.bit_width)?.into_array()) +} + +/// Encode a FixedSizeListArray into a `TurboQuantArray` with QJL correction. +/// +/// The QJL variant uses `bit_width - 1` MSE bits plus 1 bit of QJL residual +/// correction, giving unbiased inner product estimation. The input must be +/// non-nullable. 
+#[allow(clippy::cast_possible_truncation)] +pub fn turboquant_encode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 2 && config.bit_width <= 9, + "QJL bit_width must be 2-9, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let dim = dimension as usize; + let mse_bit_width = config.bit_width - 1; + + let core = turboquant_quantize_core(fsl, seed, mse_bit_width)?; + let padded_dim = core.padded_dim; + + // QJL uses a different rotation than the MSE stage to ensure statistical + // independence between the quantization noise and the sign projection. + let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?; + + let num_rows = fsl.len(); + let mut residual_norms_buf = BufferMut::::with_capacity(num_rows); + let mut qjl_sign_u8 = BufferMut::::with_capacity(num_rows * padded_dim); + + let mut dequantized_rotated = vec![0.0f32; padded_dim]; + let mut dequantized = vec![0.0f32; padded_dim]; + let mut residual = vec![0.0f32; padded_dim]; + let mut projected = vec![0.0f32; padded_dim]; + + // Compute QJL residuals using precomputed indices and norms from the core. + { + let f32_slice = core.f32_elements.as_slice::(); + let indices_slice: &[u8] = &core.all_indices; + let norms_slice: &[f32] = &core.norms; + + for row in 0..num_rows { + let x = &f32_slice[row * dim..(row + 1) * dim]; + let norm = norms_slice[row]; + + // Dequantize from precomputed indices. 
+ let row_indices = &indices_slice[row * padded_dim..(row + 1) * padded_dim]; + for j in 0..padded_dim { + dequantized_rotated[j] = core.centroids[row_indices[j] as usize]; + } + + core.rotation + .inverse_rotate(&dequantized_rotated, &mut dequantized); + if norm > 0.0 { + for val in dequantized[..dim].iter_mut() { + *val *= norm; + } + } + + // Compute residual: r = x_padded - x̂. + // For positions 0..dim: r[j] = x[j] - dequantized[j]. + // For pad positions dim..padded_dim: the original was zero-padded, + // so r[j] = 0 - dequantized[j]. These pad artifacts are nonzero + // because the SRHT mixes quantization error into the padded region. + // Omitting them would corrupt the QJL signs for non-power-of-2 dims. + for j in 0..dim { + residual[j] = x[j] - dequantized[j]; + } + for j in dim..padded_dim { + residual[j] = -dequantized[j]; + } + // The residual norm for QJL scaling is over the dim-dimensional + // subspace only — pad artifacts don't contribute to reconstruction + // error in the output space. The pad positions are still included + // in the sign projection to avoid corrupting the SRHT mixing. + let residual_norm = l2_norm(&residual[..dim]); + residual_norms_buf.push(residual_norm); + + // QJL: sign(S · r). + if residual_norm > 0.0 { + qjl_rotation.rotate(&residual, &mut projected); + } else { + projected.fill(0.0); + } + + for j in 0..padded_dim { + qjl_sign_u8.push(if projected[j] >= 0.0 { 1u8 } else { 0u8 }); + } + } + } + + // Build the MSE part. + let mut array = build_turboquant_mse(fsl, core, mse_bit_width)?; + + // Attach QJL correction. 
+ let residual_norms_array = + PrimitiveArray::new::(residual_norms_buf.freeze(), Validity::NonNullable); + let qjl_signs_elements = PrimitiveArray::new::(qjl_sign_u8.freeze(), Validity::NonNullable); + let qjl_signs = FixedSizeListArray::try_new( + qjl_signs_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )?; + let qjl_rotation_signs = bitpack_rotation_signs(&qjl_rotation)?; + + array.set_qjl(crate::encodings::turboquant::array::QjlCorrection { + signs: qjl_signs.into_array(), + residual_norms: residual_norms_array.into_array(), + rotation_signs: qjl_rotation_signs, + }); + + Ok(array.into_array()) +} + +/// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. +/// +/// The rotation matrix's 3 × padded_dim sign values are exported as 0/1 u8 +/// values in inverse application order, then bitpacked to 1 bit per sign. +/// On decode, FastLanes SIMD-unpacks back to `&[u8]` of 0/1 values. +fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult { + let signs_u8 = rotation.export_inverse_signs_u8(); + let mut buf = BufferMut::::with_capacity(signs_u8.len()); + buf.extend_from_slice(&signs_u8); + let prim = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + Ok(bitpack_encode(&prim, 1, None)?.into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs new file mode 100644 index 00000000000..2666270f6e8 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Approximate cosine similarity in the quantized domain. +//! +//! Since the SRHT is orthogonal, inner products are preserved in the rotated +//! domain. For two vectors from the same TurboQuant column (same rotation and +//! 
centroids), we can compute the dot product of their quantized representations +//! without full decompression: +//! +//! ```text +//! cos_approx(a, b) = sum(centroids[code_a[j]] × centroids[code_b[j]]) +//! ``` +//! +//! where `code_a` and `code_b` are the quantized coordinate indices of the +//! unit-norm rotated vectors `â_rot` and `b̂_rot`. +//! +//! # Bias and error bounds +//! +//! This estimate is **biased** — it uses only the MSE-quantized codes and does +//! not incorporate the QJL residual correction. The MSE quantizer minimizes +//! reconstruction error but does not guarantee unbiased inner products; the +//! discrete centroid grid introduces systematic bias in the dot product. +//! +//! The TurboQuant paper's Theorem 2 shows that unbiased inner product estimation +//! requires the full QJL correction term, which involves decoding the per-row +//! QJL signs and computing cross-terms — nearly as expensive as full decompression. +//! +//! The approximation error is bounded by the MSE quantization distortion. For +//! unit-norm vectors quantized at `b` bits, the per-coordinate MSE is bounded by +//! `(√3 · π / 2) / 4^b` (Theorem 1). The inner product error scales with this +//! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001. +//! +//! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is +//! usually sufficient — the relative ordering of cosine similarities is preserved +//! even if the absolute values have bounded error. 
+ +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::encodings::turboquant::TurboQuant; + +/// Shared helper: read codes, norms, and centroids from two TurboQuant arrays, +/// then compute per-row quantized unit-norm dot products. +/// +/// Both arrays must have the same dimension (vector length) and row count. +/// They may have different codebooks (e.g., different bit widths), in which +/// case each array's own centroids are used for its code lookups. +/// +/// Returns `(norms_a, norms_b, unit_dots)` where `unit_dots[i]` is the dot product +/// of the unit-norm quantized vectors for row i. +fn quantized_unit_dots( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult<(Vec, Vec, Vec)> { + vortex_ensure!( + lhs.dimension() == rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions, got {} and {}", + lhs.dimension(), + rhs.dimension() + ); + + let pd = lhs.padded_dim() as usize; + let num_rows = lhs.norms().len(); + + let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?; + let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?; + let na = lhs_norms.as_slice::(); + let nb = rhs_norms.as_slice::(); + + let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?; + let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?; + let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive(); + let rhs_codes = rhs_codes_fsl.elements().to_canonical()?.into_primitive(); + let ca = lhs_codes.as_slice::(); + let cb = rhs_codes.as_slice::(); + + // Read centroids from both arrays — they may have different codebooks + // (e.g., different bit widths). 
+ let lhs_centroids: PrimitiveArray = lhs.centroids().clone().execute(ctx)?; + let rhs_centroids: PrimitiveArray = rhs.centroids().clone().execute(ctx)?; + let cl = lhs_centroids.as_slice::(); + let cr = rhs_centroids.as_slice::(); + + let mut dots = Vec::with_capacity(num_rows); + for row in 0..num_rows { + let row_ca = &ca[row * pd..(row + 1) * pd]; + let row_cb = &cb[row * pd..(row + 1) * pd]; + let dot: f32 = row_ca + .iter() + .zip(row_cb.iter()) + .map(|(&a, &b)| cl[a as usize] * cr[b as usize]) + .sum(); + dots.push(dot); + } + + Ok((na.to_vec(), nb.to_vec(), dots)) +} + +/// Compute approximate cosine similarity for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. +pub fn cosine_similarity_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_rows = lhs.norms().len(); + let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + if na[row] == 0.0 || nb[row] == 0.0 { + result.push(0.0); + } else { + // Unit-norm dot product IS the cosine similarity. + result.push(dots[row]); + } + } + + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) +} + +/// Compute approximate dot product for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. +/// +/// `dot_product(a, b) ≈ ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` +pub fn dot_product_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_rows = lhs.norms().len(); + let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + // Scale the unit-norm dot product by both norms to get the actual dot product. 
+ result.push(na[row] * nb[row] * dots[row]); + } + + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/mod.rs b/vortex-tensor/src/encodings/turboquant/compute/mod.rs new file mode 100644 index 00000000000..67b4d3efb7f --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compute pushdown implementations for TurboQuant. + +pub(crate) mod cosine_similarity; +mod ops; +pub(crate) mod rules; +mod slice; +mod take; diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs new file mode 100644 index 00000000000..89f25c61018 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/ops.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_array::scalar::Scalar; +use vortex_array::vtable::OperationsVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::encodings::turboquant::array::TurboQuant; + +impl OperationsVTable for TurboQuant { + fn scalar_at( + array: ArrayView<'_, TurboQuant>, + index: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + // Slice to single row, decompress that one row. + let Some(sliced) = ::slice(array, index..index + 1)? 
else { + vortex_bail!("slice returned None for index {index}") + }; + let decoded = sliced.execute::(ctx)?; + decoded.scalar_at(0) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/rules.rs b/vortex-tensor/src/encodings/turboquant/compute/rules.rs new file mode 100644 index 00000000000..d482994f720 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/rules.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::dict::TakeExecuteAdaptor; +use vortex_array::arrays::slice::SliceReduceAdaptor; +use vortex_array::kernel::ParentKernelSet; +use vortex_array::optimizer::rules::ParentRuleSet; + +use crate::encodings::turboquant::array::TurboQuant; + +pub(crate) static RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); + +pub(crate) static PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(TurboQuant))]); diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs new file mode 100644 index 00000000000..acd4f1a42ee --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::Range; + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::IntoArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::array::QjlCorrection; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; + +impl SliceReduce for TurboQuant { + fn slice( + array: ArrayView<'_, TurboQuant>, + range: Range, + ) -> VortexResult> { + let sliced_codes = array.codes().slice(range.clone())?; + let sliced_norms = 
array.norms().slice(range.clone())?; + + let sliced_qjl = array + .qjl() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.slice(range.clone())?, + residual_norms: qjl.residual_norms.slice(range.clone())?, + rotation_signs: qjl.rotation_signs, + }) + }) + .transpose()?; + + let mut result = TurboQuantData::try_new_mse( + array.dtype.clone(), + sliced_codes, + sliced_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + array.dimension, + array.bit_width, + )?; + if let Some(qjl) = sliced_qjl { + result.set_qjl(qjl); + } + + Ok(Some(result.into_array())) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs new file mode 100644 index 00000000000..2779a907375 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::dict::TakeExecute; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::array::QjlCorrection; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; + +impl TakeExecute for TurboQuant { + fn take( + array: ArrayView<'_, TurboQuant>, + indices: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // FSL children handle per-row take natively. 
+ let taken_codes = array.codes().take(indices.clone())?; + let taken_norms = array.norms().take(indices.clone())?; + + let taken_qjl = array + .qjl() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.take(indices.clone())?, + residual_norms: qjl.residual_norms.take(indices.clone())?, + rotation_signs: qjl.rotation_signs, + }) + }) + .transpose()?; + + let mut result = TurboQuantData::try_new_mse( + array.dtype.clone(), + taken_codes, + taken_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + array.dimension, + array.bit_width, + )?; + if let Some(qjl) = taken_qjl { + result.set_qjl(qjl); + } + + Ok(Some(result.into_array())) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs new file mode 100644 index 00000000000..5f5c68ce802 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decoding (dequantization) logic. + +use vortex_array::Array; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::rotation::RotationMatrix; + +/// QJL correction scale factor: `sqrt(π/2) / padded_dim`. +/// +/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform) +/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations. +#[inline] +fn qjl_correction_scale(padded_dim: usize) -> f32 { + (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32) +} + +/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats. 
+/// +/// Reads stored centroids and rotation signs from the array's children, +/// avoiding any recomputation. If QJL correction is present, applies +/// the residual correction after MSE decoding. +pub fn execute_decompress( + array: Array, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dim = array.dimension() as usize; + let padded_dim = array.padded_dim() as usize; + let num_rows = array.norms().len(); + + if num_rows == 0 { + let elements = PrimitiveArray::empty::(array.dtype.nullability()); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + 0, + )? + .into_array()); + } + + // Read stored centroids — no recomputation. + let centroids_prim = array.centroids().clone().execute::(ctx)?; + let centroids = centroids_prim.as_slice::(); + + // FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values, + // then we expand to u32 XOR masks once (amortized over all rows). This enables + // branchless XOR-based sign application in the per-row SRHT hot loop. + let signs_prim = array + .rotation_signs() + .clone() + .execute::(ctx)?; + let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim)?; + + // Unpack codes from FixedSizeListArray → flat u8 elements. + let codes_fsl = array.codes().clone().execute::(ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let indices = codes_prim.as_slice::(); + + let norms_prim = array.norms().clone().execute::(ctx)?; + let norms = norms_prim.as_slice::(); + + // MSE decode: dequantize → inverse rotate → scale by norm. 
+ let mut mse_output = BufferMut::::with_capacity(num_rows * dim); + let mut dequantized = vec![0.0f32; padded_dim]; + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; + let norm = norms[row]; + + for idx in 0..padded_dim { + dequantized[idx] = centroids[row_indices[idx] as usize]; + } + + rotation.inverse_rotate(&dequantized, &mut unrotated); + + for idx in 0..dim { + unrotated[idx] *= norm; + } + + mse_output.extend_from_slice(&unrotated[..dim]); + } + + // If no QJL correction, we're done. + let Some(qjl) = array.qjl() else { + let elements = PrimitiveArray::new::(mse_output.freeze(), Validity::NonNullable); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()); + }; + + // Apply QJL residual correction. + // Unpack QJL signs from FixedSizeListArray → flat u8 0/1 values. + let qjl_signs_fsl = qjl.signs.clone().execute::(ctx)?; + let qjl_signs_prim = qjl_signs_fsl.elements().to_canonical()?.into_primitive(); + let qjl_signs_u8 = qjl_signs_prim.as_slice::(); + + let residual_norms_prim = qjl.residual_norms.clone().execute::(ctx)?; + let residual_norms = residual_norms_prim.as_slice::(); + + let qjl_rot_signs_prim = qjl.rotation_signs.execute::(ctx)?; + let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::(), dim)?; + + let qjl_scale = qjl_correction_scale(padded_dim); + let mse_elements = mse_output.as_ref(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut qjl_signs_vec = vec![0.0f32; padded_dim]; + let mut qjl_projected = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let mse_row = &mse_elements[row * dim..(row + 1) * dim]; + let residual_norm = residual_norms[row]; + + // Branchless u8 0/1 → f32 ±1.0 via XOR on the IEEE 754 sign bit. + // 1.0f32 = 0x3F800000; flipping the sign bit gives -1.0 = 0xBF800000. 
+ // For sign=0 (negative): mask = 0x80000000, 1.0 XOR mask = -1.0. + // For sign=1 (positive): mask = 0x00000000, 1.0 XOR mask = 1.0. + let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim]; + for (dst, &sign) in qjl_signs_vec.iter_mut().zip(row_signs.iter()) { + let mask = ((sign as u32) ^ 1) << 31; + *dst = f32::from_bits(0x3F80_0000 ^ mask); + } + + qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected); + let scale = qjl_scale * residual_norm; + + for idx in 0..dim { + output.push(mse_row[idx] + scale * qjl_projected[idx]); + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs new file mode 100644 index 00000000000..b537db29029 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -0,0 +1,1018 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization encoding for Vortex. +//! +//! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of +//! high-dimensional vector data. The encoding operates on `FixedSizeList` arrays of floats +//! (the storage format of `Vector` and `FixedShapeTensor` extension types). +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! +//! # Variants +//! +//! - **MSE** (`TurboQuantVariant::Mse`): Minimizes mean-squared reconstruction error +//! (1-8 bits per coordinate). +//! - **Prod** (`TurboQuantVariant::Prod`): Preserves inner products with an unbiased +//! estimator (uses `b-1` bits for MSE + 1-bit QJL residual correction, 2-9 bits). +//! At `b=9`, the MSE codes are raw int8 values suitable for direct use with +//! tensor core int8 GEMM kernels. +//! +//! # Theoretical error bounds +//! +//! 
For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 +//! guarantees normalized MSE distortion: +//! +//! > `E[||x - x̂||² / ||x||²] ≤ (√3 · π / 2) / 4^b` +//! +//! | Bits | MSE bound | Quality | +//! |------|------------|-------------------| +//! | 1 | 6.80e-01 | Poor | +//! | 2 | 1.70e-01 | Usable for ANN | +//! | 3 | 4.25e-02 | Good | +//! | 4 | 1.06e-02 | Very good | +//! | 5 | 2.66e-03 | Excellent | +//! | 6 | 6.64e-04 | Near-lossless | +//! | 7 | 1.66e-04 | Near-lossless | +//! | 8 | 4.15e-05 | Near-lossless | +//! +//! # Compression ratios +//! +//! Each vector is stored as `padded_dim × bit_width / 8` bytes of quantized codes plus a +//! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the +//! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. +//! +//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | +//! |------|--------|------|-----------|----------|--------| +//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | +//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | +//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | +//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | +//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | +//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | +//! +//! # Example +//! +//! ``` +//! use vortex_array::IntoArray; +//! use vortex_array::arrays::FixedSizeListArray; +//! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::validity::Validity; +//! use vortex_buffer::BufferMut; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_mse}; +//! +//! // Create a FixedSizeListArray of 100 random 128-d vectors. +//! let num_rows = 100; +//! let dim = 128; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim); +//! for i in 0..(num_rows * dim) { +//! buf.push((i as f32 * 0.001).sin()); +//! } +//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); +//! let fsl = FixedSizeListArray::try_new( +//! 
elements.into_array(), dim as u32, Validity::NonNullable, num_rows, +//! ).unwrap(); +//! +//! // Quantize at 2 bits per coordinate using MSE-optimal encoding. +//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; +//! let encoded = turboquant_encode_mse(&fsl, &config).unwrap(); +//! +//! // Verify compression: 100 vectors × 128 dims × 4 bytes = 51200 bytes input. +//! assert!(encoded.nbytes() < 51200); +//! ``` + +pub use array::QjlCorrection; +pub use array::TurboQuant; +pub use array::TurboQuantData; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode_mse; +pub use compress::turboquant_encode_qjl; + +mod array; +pub(crate) mod centroids; +mod compress; +pub(crate) mod compute; +pub(crate) mod decompress; +pub(crate) mod rotation; +pub mod scheme; +mod vtable; + +/// Extension ID for the `Vector` type from `vortex-tensor`. +pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector"; + +/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`. +pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor"; + +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; + +/// Initialize the TurboQuant encoding in the given session. 
+pub fn initialize(session: &mut VortexSession) { + session.arrays().register(TurboQuant); +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +mod tests { + use std::sync::LazyLock; + + use rand::RngExt; + use rand::SeedableRng; + use rand::rngs::StdRng; + use rand_distr::Distribution; + use rand_distr::Normal; + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::session::ArraySession; + use vortex_array::validity::Validity; + use vortex_buffer::BufferMut; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::encodings::turboquant::TurboQuant; + use crate::encodings::turboquant::TurboQuantConfig; + use crate::encodings::turboquant::rotation::RotationMatrix; + use crate::encodings::turboquant::turboquant_encode_mse; + use crate::encodings::turboquant::turboquant_encode_qjl; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + /// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). 
+ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + ) + .unwrap() + } + + fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) + } + + fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 + } + + /// Encode and decode, returning (original, decoded) flat f32 slices. 
+ fn encode_decode( + fsl: &FixedSizeListArray, + encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult, + ) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let encoded = encode_fn(fsl)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + let decoded_elements: Vec = { + let prim = decoded.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) + } + + fn encode_decode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_mse(fsl, &config)) + } + + fn encode_decode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_qjl(fsl, &config)) + } + + // ----------------------------------------------------------------------- + // MSE encoding tests + // ----------------------------------------------------------------------- + + #[rstest] + #[case(32, 1)] + #[case(32, 2)] + #[case(32, 3)] + #[case(32, 4)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 2)] + fn roundtrip_mse(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 2)] + #[case(256, 4)] + fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + 
bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + #[rstest] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 6)] + #[case(256, 8)] + fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode_mse(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) + } + + #[test] + fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL encoding tests + // 
----------------------------------------------------------------------- + + #[rstest] + #[case(32, 2)] + #[case(32, 3)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(456), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + /// Compute the mean signed relative error of QJL inner product estimation + /// over random query/vector pairs. + fn qjl_mean_signed_relative_error( + original: &[f32], + decoded: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let num_pairs = 500; + let mut rng = StdRng::seed_from_u64(0); + let mut signed_errors = Vec::with_capacity(num_pairs); + + for _ in 0..num_pairs { + let qi = rng.random_range(0..num_rows); + let xi = rng.random_range(0..num_rows); + if qi == xi { + continue; + } + + let query = &original[qi * dim..(qi + 1) * dim]; + let orig_vec = &original[xi * dim..(xi + 1) * dim]; + let quant_vec = &decoded[xi * dim..(xi + 1) * dim]; + + let true_ip: f32 = query.iter().zip(orig_vec).map(|(&a, &b)| a * b).sum(); + let quant_ip: f32 = query.iter().zip(quant_vec).map(|(&a, &b)| a * b).sum(); + + if true_ip.abs() > 1e-6 { + signed_errors.push((quant_ip - true_ip) / true_ip.abs()); + } + } + + if signed_errors.is_empty() { + return 0.0; + } + + signed_errors.iter().sum::() / signed_errors.len() as f32 + } + + #[rstest] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + #[case(768, 4)] + fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 100; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = 
encode_decode_qjl(&fsl, &config)?; + + let mean_rel_error = qjl_mean_signed_relative_error(&original, &decoded, dim, num_rows); + + // Known limitation: non-power-of-2 dims have elevated QJL bias (~23% vs + // ~11%) due to distribution mismatch between the SRHT zero-padded coordinate + // distribution and the analytical (1-x^2)^((d-3)/2) model used for centroids. + // Investigated approaches: + // - Random permutation of zeros: no effect (issue is distribution shape) + // - MC empirical centroids: fixes QJL bias but regresses MSE quality + // - Analytical centroids with dim instead of padded_dim: mixed results + // The principled fix requires jointly correcting centroids and QJL scale + // factor for the actual SRHT zero-padded distribution. + let threshold = if dim.is_power_of_two() { 0.15 } else { 0.25 }; + assert!( + mean_rel_error.abs() < threshold, + "QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width} \ + (threshold={threshold})" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge cases + // ----------------------------------------------------------------------- + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_mse_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_qjl_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + 
assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(1)] + #[case(2)] + fn mse_rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + assert!(turboquant_encode_mse(&fsl, &config).is_err()); + } + + #[rstest] + #[case(1)] + #[case(2)] + fn qjl_rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(0), + }; + assert!(turboquant_encode_qjl(&fsl, &config).is_err()); + } + + fn make_fsl_small(dim: usize) -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(dim); + for i in 0..dim { + buf.push(i as f32 + 1.0); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new(elements.into_array(), dim as u32, Validity::NonNullable, 1) + .unwrap() + } + + // ----------------------------------------------------------------------- + // Verification tests for stored metadata + // ----------------------------------------------------------------------- + + /// Verify that the centroids stored in the MSE array match what get_centroids() computes. 
+ #[test] + fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::encodings::turboquant::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) + } + + /// Verify that stored rotation signs produce identical decode to seed-based decode. + /// + /// Encodes the same data twice: once with the new path (stored signs), and + /// once by manually recomputing the rotation from the seed. Both should + /// produce identical output. + #[test] + fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Decode via the stored-signs path (normal decode). + let mut ctx = SESSION.create_execution_ctx(); + let decoded_fsl = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. 
+ let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let expected_u8 = rot_from_seed.export_inverse_signs_u8(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + let stored_u8 = stored_signs.as_slice::(); + + assert_eq!(expected_u8.len(), stored_u8.len()); + for i in 0..expected_u8.len() { + assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL-specific quality tests + // ----------------------------------------------------------------------- + + /// Verify that QJL's MSE component (at bit_width-1) satisfies the theoretical bound. + #[rstest] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 3)] + fn qjl_mse_within_theoretical_bound( + #[case] dim: usize, + #[case] bit_width: u8, + ) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + // QJL at b bits uses (b-1)-bit MSE plus a correction term. + // The MSE should be at most the (b-1)-bit theoretical bound, though + // in practice the QJL correction often improves it further. + let mse_bound = theoretical_mse_bound(bit_width - 1); + assert!( + normalized_mse < mse_bound, + "QJL MSE {normalized_mse:.6} exceeds (b-1)-bit bound {mse_bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + /// Verify that high-bitwidth QJL (8-9 bits) achieves very low distortion. 
+ #[rstest] + #[case(128, 8)] + #[case(128, 9)] + fn high_bitwidth_qjl_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + // Compare against 4-bit QJL as reference ceiling. + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(789), + }; + let (original_4, decoded_4) = encode_decode_qjl(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 4-bit ({mse_4bit:.6})" + ); + assert!( + mse < 0.01, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 1%" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge case and input format tests + // ----------------------------------------------------------------------- + + /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). + #[test] + fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 → 0 * anything = 0). 
+ for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) + } + + /// Verify that f64 input is accepted and encoded (converted to f32 internally). + #[test] + fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64→f32 conversion). + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension(), dim as u32); + Ok(()) + } + + /// Verify serde roundtrip: serialize MSE array metadata + children, then rebuild. + #[test] + fn mse_serde_roundtrip() -> VortexResult<()> { + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children. + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 4); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize and rebuild. 
+ let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + // Verify metadata fields survived roundtrip. + assert_eq!(deserialized.dimension, encoded.dimension()); + assert_eq!(deserialized.bit_width, encoded.bit_width() as u32); + assert_eq!(deserialized.has_qjl, encoded.has_qjl()); + + // Verify the rebuilt array decodes identically. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild from children (simulating deserialization). + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_mse( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + /// Verify serde roundtrip for QJL: serialize metadata + children, then rebuild. + #[test] + fn qjl_serde_roundtrip() -> VortexResult<()> { + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children — QJL has 7 (4 MSE + 3 QJL). + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 7); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize metadata. 
+ let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + assert!(deserialized.has_qjl); + assert_eq!(deserialized.dimension, encoded.dimension()); + + // Verify decode: original vs rebuilt from children. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild with QJL children. + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_qjl( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + crate::encodings::turboquant::array::QjlCorrection { + signs: children[4].clone(), + residual_norms: children[5].clone(), + rotation_signs: children[6].clone(), + }, + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Compute pushdown tests + // ----------------------------------------------------------------------- + + #[test] + fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(5..10)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + // Slice then decompress. 
+ let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn slice_qjl_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(3..8)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + let sliced = encoded.slice(3..8)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 64, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) + } + + #[test] + fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = encoded.as_opt::().unwrap(); + + // Stored norms should match the actual L2 norms of the input. 
+ let norms_prim = tq.norms().to_canonical()?.into_primitive(); + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + Ok(()) + } + + #[test] + fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + use vortex_array::arrays::FixedSizeListArray; + + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = encoded.as_opt::().unwrap(); + + // Compute exact cosine similarity from original data. + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + + // Read quantized codes, norms, and centroids for approximate computation. 
+ let mut ctx = SESSION.create_execution_ctx(); + let pd = tq.padded_dim() as usize; + let norms_prim = tq.norms().clone().execute::(&mut ctx)?; + let norms = norms_prim.as_slice::(); + let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let all_codes = codes_prim.as_slice::(); + let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; + let centroid_vals = centroids_prim.as_slice::(); + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + // Approximate cosine similarity in quantized domain. + let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { + 0.0 + } else { + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum::() + }; + + // 4-bit quantization: expect reasonable accuracy. + let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/rotation.rs b/vortex-tensor/src/encodings/turboquant/rotation.rs new file mode 100644 index 00000000000..2f654349778 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/rotation.rs @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Deterministic random rotation for TurboQuant. +//! 
+//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation
+//! instead of a full d×d matrix multiply. The SRHT applies the sequence
+//! H · D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard Transform (WHT) and Dₖ
+//! are random diagonal ±1 sign matrices. Three rounds of HD provide sufficient
+//! randomness for near-uniform distribution on the sphere.
+//!
+//! For dimensions that are not powers of 2, the input is zero-padded to the
+//! next power of 2 before the transform and truncated afterward.
+//!
+//! # Sign representation
+//!
+//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op)
+//! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application
+//! function uses integer XOR instead of floating-point multiply, which avoids
+//! FP dependency chains and auto-vectorizes into `vpxor`/`veor`.
+
+use rand::RngExt;
+use rand::SeedableRng;
+use rand::rngs::StdRng;
+use vortex_error::VortexResult;
+use vortex_error::vortex_ensure;
+
+/// IEEE 754 sign bit mask for f32.
+const F32_SIGN_BIT: u32 = 0x8000_0000;
+
+/// A structured random Hadamard transform for O(d log d) pseudo-random rotation.
+pub struct RotationMatrix {
+    /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
+    /// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
+    sign_masks: [Vec<u32>; 3],
+    /// The padded dimension (next power of 2 >= dimension).
+    padded_dim: usize,
+    /// Normalization factor: 1/(padded_dim * sqrt(padded_dim)), applied once at the end.
+    norm_factor: f32,
+}
+
+impl RotationMatrix {
+    /// Create a new SRHT rotation from a deterministic seed.
+    pub fn try_new(seed: u64, dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        let mut rng = StdRng::seed_from_u64(seed);
+
+        let sign_masks = std::array::from_fn(|_| gen_random_sign_masks(&mut rng, padded_dim));
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+
+    /// Apply forward rotation: `output = SRHT(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`. The caller
+    /// is responsible for zero-padding input beyond `dim` positions.
+    pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_srht(output);
+    }
+
+    /// Apply inverse rotation: `output = SRHT⁻¹(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`.
+    pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_inverse_srht(output);
+    }
+
+    /// Returns the padded dimension (next power of 2 >= dim).
+    ///
+    /// All rotate/inverse_rotate buffers must be this length.
+    pub fn padded_dim(&self) -> usize {
+        self.padded_dim
+    }
+
+    /// Apply the SRHT: H · D₃ · H · D₂ · H · D₁ · x, with normalization.
+    fn apply_srht(&self, buf: &mut [f32]) {
+        apply_signs_xor(buf, &self.sign_masks[0]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[1]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[2]);
+        walsh_hadamard_transform(buf);
+
+        let norm = self.norm_factor;
+        buf.iter_mut().for_each(|val| *val *= norm);
+    }
+
+    /// Apply the inverse SRHT.
+ /// + /// Forward is: norm · H · D₃ · H · D₂ · H · D₁ + /// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H + fn apply_inverse_srht(&self, buf: &mut [f32]) { + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[2]); + + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[1]); + + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[0]); + + let norm = self.norm_factor; + buf.iter_mut().for_each(|val| *val *= norm); + } + + /// Export the 3 sign vectors as a flat `Vec` of 0/1 values in inverse + /// application order `[D₃ | D₂ | D₁]`. + /// + /// Convention: `1` = positive (+1), `0` = negative (-1). + /// The output has length `3 * padded_dim` and is suitable for bitpacking + /// via FastLanes `bitpack_encode(..., 1, None)`. + pub fn export_inverse_signs_u8(&self) -> Vec { + let total = 3 * self.padded_dim; + let mut out = Vec::with_capacity(total); + + // Store in inverse order: sign_masks[2] (D₃), sign_masks[1] (D₂), sign_masks[0] (D₁) + for sign_idx in [2, 1, 0] { + for &mask in &self.sign_masks[sign_idx] { + out.push(if mask == 0 { 1u8 } else { 0u8 }); + } + } + out + } + + /// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values. + /// + /// The input must have length `3 * padded_dim` with signs in inverse + /// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]). + /// Convention: `1` = positive, `0` = negative. + /// + /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the + /// stored `BitPackedArray` into `&[u8]`, which is passed here. 
+    pub fn from_u8_slice(signs_u8: &[u8], dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        vortex_ensure!(
+            signs_u8.len() == 3 * padded_dim,
+            "Expected {} sign bytes, got {}",
+            3 * padded_dim,
+            signs_u8.len()
+        );
+
+        // Reconstruct in storage order (inverse): [D₃, D₂, D₁] → sign_masks[2], [1], [0]
+        let mut sign_masks: [Vec<u32>; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim));
+
+        for (round, sign_idx) in [2, 1, 0].into_iter().enumerate() {
+            let offset = round * padded_dim;
+            sign_masks[sign_idx] = signs_u8[offset..offset + padded_dim]
+                .iter()
+                .map(|&v| if v != 0 { 0u32 } else { F32_SIGN_BIT })
+                .collect();
+        }
+
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+}
+
+/// Generate a vector of random XOR sign masks.
+fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
+    (0..len)
+        .map(|_| {
+            if rng.random_bool(0.5) {
+                0u32 // +1: no-op
+            } else {
+                F32_SIGN_BIT // -1: flip sign bit
+            }
+        })
+        .collect()
+}
+
+/// Apply sign masks via XOR on the IEEE 754 sign bit.
+///
+/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM).
+/// Equivalent to multiplying each element by ±1.0, but avoids FP dependency chains.
+#[inline]
+fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
+    for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
+        *val = f32::from_bits(val.to_bits() ^ mask);
+    }
+}
+
+/// In-place Walsh-Hadamard Transform (unnormalized, iterative).
+///
+/// Input length must be a power of 2. Runs in O(n log n).
+///
+/// Each stage processes the buffer in `stride`-element chunks, splitting each
+/// chunk into equal-length (lo, hi) halves for the butterfly. The equal-length
+/// split lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD.
+fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`. +/// +/// Separate function so LLVM can see the slice lengths match and auto-vectorize. +#[inline(always)] +fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let r1 = RotationMatrix::try_new(42, 64)?; + let r2 = RotationMatrix::try_new(42, 64)?; + let pd = r1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..64 { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + r1.rotate(&input, &mut out1); + r2.rotate(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + /// Verify roundtrip is exact to f32 precision across many dimensions, + /// including non-power-of-two dimensions that require padding. 
+ #[rstest] + #[case(32)] + #[case(64)] + #[case(100)] + #[case(128)] + #[case(256)] + #[case(512)] + #[case(768)] + #[case(1024)] + fn roundtrip_exact(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut rotated = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + rot.rotate(&input, &mut rotated); + rot.inverse_rotate(&rotated, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions. + #[rstest] + #[case(128)] + #[case(768)] + fn preserves_norm(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(7, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut rotated = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut rotated); + let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - rotated_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + rotated_norm, + (input_norm - rotated_norm).abs() / input_norm + ); + Ok(()) + } + + /// Verify that export → from_u8_slice produces identical rotation output. 
+ #[rstest] + #[case(64)] + #[case(128)] + #[case(768)] + fn sign_export_import_roundtrip(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let signs_u8 = rot.export_inverse_signs_u8(); + let rot2 = RotationMatrix::from_u8_slice(&signs_u8, dim)?; + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut out1); + rot2.rotate(&input, &mut out2); + assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + + rot.inverse_rotate(&out1, &mut out2); + let mut out3 = vec![0.0f32; padded_dim]; + rot2.inverse_rotate(&out1, &mut out3); + assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs new file mode 100644 index 00000000000..6db642ae25f --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant compression scheme for the pluggable compressor. 
+ +use vortex_array::ArrayRef; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_compressor::CascadingCompressor; +use vortex_compressor::ctx::CompressorContext; +use vortex_compressor::scheme::Scheme; +use vortex_compressor::stats::ArrayAndStats; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; + +use super::FIXED_SHAPE_TENSOR_EXT_ID; +use super::TurboQuantConfig; +use super::VECTOR_EXT_ID; +use super::turboquant_encode_qjl; + +/// TurboQuant compression scheme for tensor extension types. +/// +/// Applies lossy vector quantization to `Vector` and `FixedShapeTensor` extension +/// arrays using the TurboQuant algorithm with QJL correction for unbiased inner +/// product estimation. +/// +/// Register this scheme with the compressor builder via `with_scheme`: +/// ```ignore +/// use vortex_btrblocks::BtrBlocksCompressorBuilder; +/// use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; +/// +/// let compressor = BtrBlocksCompressorBuilder::default() +/// .with_scheme(&TURBOQUANT_SCHEME) +/// .build(); +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct TurboQuantScheme; + +/// Static instance for registration with `BtrBlocksCompressorBuilder::with_scheme`. 
+pub static TURBOQUANT_SCHEME: TurboQuantScheme = TurboQuantScheme; + +impl Scheme for TurboQuantScheme { + fn scheme_name(&self) -> &'static str { + "vortex.tensor.turboquant" + } + + fn matches(&self, canonical: &Canonical) -> bool { + let Canonical::Extension(ext) = canonical else { + return false; + }; + + get_tensor_element_ptype_and_length(ext.dtype()).is_ok() + } + + fn expected_compression_ratio( + &self, + _compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + let dtype = data.array().dtype(); + let len = data.array().len(); + let (element_ptype, dimensions) = get_tensor_element_ptype_and_length(dtype)?; + Ok(estimate_compression_ratio( + element_ptype.bit_width(), + dimensions, + len, + )) + } + + fn compress( + &self, + _compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + let array = data.array().clone(); + let ext_array = array.to_canonical()?.into_extension(); + let storage = ext_array.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + + let config = TurboQuantConfig::default(); + let encoded = turboquant_encode_qjl(&fsl, &config)?; + + Ok(ExtensionArray::new(ext_array.ext_dtype().clone(), encoded).into_array()) + } +} + +/// Estimate the compression ratio for TurboQuant QJL encoding with the default config. +/// +/// Uses the default [`TurboQuantConfig`] (5-bit QJL = 4-bit MSE + 1-bit QJL signs). +fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 { + let config = TurboQuantConfig::default(); + let padded_dim = dimensions.next_power_of_two() as usize; + + // Per-vector: MSE codes + QJL signs per padded coordinate, + // plus two f32 values (norm and QJL residual norm). 
+ let compressed_bits_per_vector = 2 * 32 // norm + residual_norm are always f32 + + (config.bit_width as usize) * padded_dim; // MSE codes + QJL sign bits + + // Shared overhead: codebook centroids (2^mse_bit_width f32 values) and + // rotation signs (3 * padded_dim bits each for MSE and QJL rotations). + let mse_bit_width = config.bit_width - 1; // QJL uses bit_width-1 for MSE + let num_centroids = 1usize << mse_bit_width; + let overhead_bits = num_centroids * 32 // centroids are always f32 + + 2 * 3 * padded_dim; // MSE + QJL rotation signs, 1 bit each + + let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits; + let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize; + uncompressed_size_bits as f64 / compressed_size_bits as f64 +} + +fn get_tensor_element_ptype_and_length(dtype: &DType) -> VortexResult<(PType, u32)> { + let ext_id = dtype.as_extension().id(); + let is_tensor = dtype.is_extension() + && (ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID); + vortex_ensure!(is_tensor, "expected tensor extension dtype, got {}", dtype); + + let storage_dtype = dtype.as_extension().storage_dtype(); + let (element_dtype, fsl_len) = match storage_dtype { + DType::FixedSizeList(element_dtype, list_size, _) => (element_dtype, list_size), + _ => vortex_bail!( + "expected FixedSizeList storage dtype, got {}", + storage_dtype + ), + }; + + // TurboQuant requires dimension >= 3: the marginal coordinate distribution + // (1 - x^2)^((d-3)/2) has a singularity at d=2 (arcsine distribution) that + // causes NaN in the Max-Lloyd centroid computation. 
+ vortex_ensure!( + *fsl_len >= 3, + "TurboQuant requires dimension >= 3, got {}", + fsl_len + ); + + if let &DType::Primitive(ptype, Nullability::NonNullable) = element_dtype.as_ref() { + Ok((ptype, *fsl_len)) + } else { + vortex_bail!( + "expected non-nullable primitive element type, got {}", + element_dtype + ); + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + /// Verify compression ratio for typical embedding dimensions. + /// + /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. + /// f32 input at 1024-d (no padding) should give higher ratio since no waste. + #[rstest] + #[case::f32_768d(32, 768, 1000, 3.5, 6.0)] + #[case::f32_1024d(32, 1024, 1000, 4.5, 7.0)] + #[case::f32_1536d(32, 1536, 1000, 3.0, 6.0)] + #[case::f32_128d(32, 128, 1000, 4.0, 6.0)] + #[case::f64_768d(64, 768, 1000, 7.0, 12.0)] + #[case::f16_768d(16, 768, 1000, 1.5, 3.5)] + fn compression_ratio_in_expected_range( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + #[case] min_ratio: f64, + #[case] max_ratio: f64, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > min_ratio && ratio < max_ratio, + "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \ + {bits_per_element}-bit elements, dim={dim}, n={num_vectors}" + ); + } + + /// Compression ratio must always be > 1 for reasonable inputs, + /// otherwise TurboQuant makes things bigger and should not be selected. 
+ #[rstest] + #[case(32, 128, 100)] + #[case(32, 768, 10)] + #[case(64, 256, 50)] + fn ratio_always_greater_than_one( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > 1.0, + "ratio {ratio:.4} <= 1.0 for {bits_per_element}-bit, dim={dim}, n={num_vectors}" + ); + } + + /// Power-of-2 dimensions should have better ratios than their non-power-of-2 + /// predecessors due to no padding waste. + #[test] + fn power_of_two_has_better_ratio() { + let ratio_768 = estimate_compression_ratio(32, 768, 1000); + let ratio_1024 = estimate_compression_ratio(32, 1024, 1000); + assert!( + ratio_1024 > ratio_768, + "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})" + ); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs new file mode 100644 index 00000000000..f309ca8bc7c --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant encoding. 
+ +use std::hash::Hash; +use std::sync::Arc; + +use vortex_array::Array; +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayId; +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::DeserializeMetadata; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::ProstMetadata; +use vortex_array::SerializeMetadata; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::serde::ArrayChildren; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::array::Slot; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::array::TurboQuantMetadata; +use crate::encodings::turboquant::decompress::execute_decompress; + +impl VTable for TurboQuant { + type ArrayData = TurboQuantData; + type Metadata = ProstMetadata; + type OperationsVTable = TurboQuant; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::ArrayData) -> &Self { + &TurboQuant + } + + fn id(&self) -> ArrayId { + Self::ID + } + + fn len(array: &TurboQuantData) -> usize { + array.norms().len() + } + + fn dtype(array: &TurboQuantData) -> &DType { + &array.dtype + } + + fn stats(array: &TurboQuantData) -> &ArrayStats { + &array.stats_set + } + + fn array_hash( + array: &TurboQuantData, + state: &mut H, + precision: Precision, + ) { + array.dtype.hash(state); + array.dimension.hash(state); + array.bit_width.hash(state); + for slot in &array.slots { + slot.is_some().hash(state); + if let Some(child) = slot { + 
child.array_hash(state, precision); + } + } + } + + fn array_eq(array: &TurboQuantData, other: &TurboQuantData, precision: Precision) -> bool { + array.dtype == other.dtype + && array.dimension == other.dimension + && array.bit_width == other.bit_width + && array.slots.len() == other.slots.len() + && array + .slots + .iter() + .zip(other.slots.iter()) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.array_eq(b, precision), + (None, None) => true, + _ => false, + }) + } + + fn nbuffers(_array: ArrayView) -> usize { + 0 + } + + fn buffer(_array: ArrayView, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: ArrayView, _idx: usize) -> Option { + None + } + + fn slots(array: ArrayView<'_, Self>) -> &[Option] { + &array.data().slots + } + + fn slot_name(_array: ArrayView, idx: usize) -> String { + Slot::from_index(idx).name().to_string() + } + + fn with_slots(array: &mut TurboQuantData, slots: Vec>) -> VortexResult<()> { + vortex_ensure!( + slots.len() == Slot::COUNT, + "TurboQuantArray expects {} slots, got {}", + Slot::COUNT, + slots.len() + ); + array.slots = slots; + Ok(()) + } + + fn metadata(array: ArrayView) -> VortexResult { + Ok(ProstMetadata(TurboQuantMetadata { + dimension: array.dimension, + bit_width: array.bit_width as u32, + has_qjl: array.has_qjl(), + })) + } + + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(metadata.serialize())) + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(ProstMetadata( + as DeserializeMetadata>::deserialize(bytes)?, + )) + } + + #[allow(clippy::cast_possible_truncation)] + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + let bit_width = u8::try_from(metadata.bit_width)?; + let padded_dim = 
metadata.dimension.next_power_of_two() as usize; + let num_centroids = 1usize << bit_width; + + let u8_nn = DType::Primitive(PType::U8, Nullability::NonNullable); + let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); + let codes_dtype = DType::FixedSizeList( + Arc::new(u8_nn.clone()), + padded_dim as u32, + Nullability::NonNullable, + ); + let codes = children.get(0, &codes_dtype, len)?; + + let norms = children.get(1, &f32_nn, len)?; + let centroids = children.get(2, &f32_nn, num_centroids)?; + + let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; + + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + + if metadata.has_qjl { + let qjl_signs_dtype = + DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); + slots[Slot::QjlSigns as usize] = Some(children.get(4, &qjl_signs_dtype, len)?); + slots[Slot::QjlResidualNorms as usize] = Some(children.get(5, &f32_nn, len)?); + slots[Slot::QjlRotationSigns as usize] = + Some(children.get(6, &signs_dtype, 3 * padded_dim)?); + } + + Ok(TurboQuantData { + dtype: dtype.clone(), + slots, + dimension: metadata.dimension, + bit_width, + stats_set: Default::default(), + }) + } + + fn reduce_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + crate::encodings::turboquant::compute::rules::RULES.evaluate(array, parent, child_idx) + } + + fn execute_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + crate::encodings::turboquant::compute::rules::PARENT_KERNELS + .execute(array, parent, child_idx, ctx) + } + + fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { + 
Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) + } +} + +impl ValidityChild for TurboQuant { + fn validity_child(array: &TurboQuantData) -> &ArrayRef { + array.codes() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 6b55389d8c9..e17ec4c88f0 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -7,8 +7,10 @@ use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; +use vortex_array::session::ArraySessionExt; use vortex_session::VortexSession; +use crate::encodings::turboquant::TurboQuant; use crate::fixed_shape::FixedShapeTensor; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::inner_product::InnerProduct; @@ -25,10 +27,13 @@ pub mod encodings; mod utils; -/// Registers the tensor extension dtypes and scalar functions with the given session. +/// Initialize the Vortex tensor library with a Vortex session. pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); session.dtypes().register(FixedShapeTensor); + + session.arrays().register(TurboQuant); + session.scalar_fns().register(CosineSimilarity); session.scalar_fns().register(InnerProduct); session.scalar_fns().register(L2Norm); diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 22c51189380..9f9f440c4f4 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,6 +28,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::inner_product::InnerProduct; @@ -142,13 +144,28 @@ impl ScalarFnVTable for CosineSimilarity { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> 
VortexResult { - let lhs = args.get(0)?.execute::(ctx)?.into_array(); - let rhs = args.get(1)?.execute::(ctx)?.into_array(); + let lhs = args.get(0)?.execute::(ctx)?; + let rhs = args.get(1)?.execute::(ctx)?; let len = args.row_count(); // Compute combined validity. - let validity = lhs.validity()?.and(rhs.validity()?)?; + let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?; + + // TurboQuant approximate path: compute cosine similarity in quantized domain. + if *options == ApproxOptions::Approximate { + let lhs_storage = lhs.data().storage_array(); + let rhs_storage = rhs.data().storage_array(); + if let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_storage.as_opt::(), + rhs_storage.as_opt::(), + ) { + return cosine_similarity::cosine_similarity_quantized_column(lhs_tq, rhs_tq, ctx); + } + } + + let lhs = lhs.into_array(); + let rhs = rhs.into_array(); // Compute inner product and norms as columnar operations, and propagate the options. let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?; diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index d142649600d..044b76b164c 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -29,6 +29,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -137,7 +139,7 @@ impl ScalarFnVTable for InnerProduct { fn execute( &self, - _options: &Self::Options, + options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { @@ -160,6 +162,16 @@ impl ScalarFnVTable for InnerProduct { let lhs_storage = lhs.data().storage_array(); let rhs_storage = rhs.data().storage_array(); + // TurboQuant approximate 
path: norm_a * norm_b * quantized unit-norm dot. + if *options == ApproxOptions::Approximate + && let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_storage.as_opt::(), + rhs_storage.as_opt::(), + ) + { + return cosine_similarity::dot_product_quantized_column(lhs_tq, rhs_tq, ctx); + } + let lhs_flat = extract_flat_elements(lhs_storage, list_size, ctx)?; let rhs_flat = extract_flat_elements(rhs_storage, list_size, ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index ed29cc776b7..5992bf4fdfc 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -12,6 +12,7 @@ use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; @@ -28,6 +29,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -128,11 +130,22 @@ impl ScalarFnVTable for L2Norm { let row_count = args.row_count(); let validity = input.as_ref().validity()?; - // Get list size (dimensions) from the dtype (validated by `return_dtype`). + // Get element ptype and list size from the dtype (validated by `return_dtype`). let ext = input.dtype().as_extension(); let list_size = extension_list_size(ext)? as usize; + let target_ptype = extension_element_ptype(ext)?; let storage = input.data().storage_array(); + + // TurboQuant stores exact precomputed norms -- no decompression needed. + // Norms are currently stored as f32; cast to the target dtype if needed + // (e.g., if the input extension has f64 elements). 
+ if let Some(tq) = storage.as_opt::() { + let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?; + let target_dtype = DType::Primitive(target_ptype, input.dtype().nullability()); + return norms.into_array().cast(target_dtype); + } + let flat = extract_flat_elements(storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index 896ec139251..f042f568c11 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -56,12 +56,15 @@ divan = { workspace = true } fastlanes = { workspace = true } mimalloc = { workspace = true } parquet = { workspace = true } +paste = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { workspace = true } vortex = { path = ".", features = ["tokio"] } +vortex-tensor = { workspace = true } [features] default = ["files", "zstd"] diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 8a4cfa53f59..abac0e36fed 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -17,7 +17,6 @@ use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use vortex::array::IntoArray; use vortex::array::ToCanonical; -use vortex::array::VortexSessionExecute; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::builders::dict::dict_encode; @@ -39,6 +38,7 @@ use vortex::encodings::sequence::sequence_encode; use vortex::encodings::zigzag::zigzag_encode; use vortex::encodings::zstd::Zstd; use vortex::encodings::zstd::ZstdData; +use vortex_array::VortexSessionExecute; use vortex_sequence::Sequence; use vortex_session::VortexSession; @@ -426,3 +426,127 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| &compressed) .bench_refs(|a| a.to_canonical()); } + 
+// TurboQuant vector quantization benchmarks. +#[cfg(feature = "unstable_encodings")] +mod turboquant_benches { + use divan::Bencher; + use paste::paste; + use rand::SeedableRng; + use rand::rngs::StdRng; + use vortex::array::IntoArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::arrays::PrimitiveArray; + use vortex::array::validity::Validity; + use vortex_array::VortexSessionExecute; + use vortex_buffer::BufferMut; + use vortex_tensor::encodings::turboquant::TurboQuantConfig; + use vortex_tensor::encodings::turboquant::turboquant_encode_mse; + use vortex_tensor::encodings::turboquant::turboquant_encode_qjl; + + use super::SESSION; + use super::with_byte_counter; + + const NUM_VECTORS: usize = 1_000; + + /// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d. + /// standard normal components. This is a conservative test distribution: real + /// neural network embeddings typically have structure (clustered, anisotropic) + /// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a + /// worst-case baseline for TurboQuant. + fn setup_vector_fsl(dim: usize) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(42); + let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(NUM_VECTORS * dim); + for _ in 0..(NUM_VECTORS * dim) { + buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + NUM_VECTORS, + ) + .unwrap() + } + + fn turboquant_config(bit_width: u8) -> TurboQuantConfig { + TurboQuantConfig { + bit_width, + seed: Some(123), + } + } + + macro_rules! turboquant_bench { + (compress, $dim:literal, $bits:literal, $name:ident) => { + paste! 
{ + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_mse(a, &config).unwrap()); + } + + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_qjl(a, &config).unwrap()); + } + } + }; + (decompress, $dim:literal, $bits:literal, $name:ident) => { + paste! { + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_mse(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_qjl(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + } + }; + } + + turboquant_bench!(compress, 128, 4, 
bench_tq_compress_128_4); + turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); + turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4); + turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4); + turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); + turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); + turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); + turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); + turboquant_bench!(compress, 1024, 8, bench_tq_compress_1024_8); + turboquant_bench!(decompress, 1024, 8, bench_tq_decompress_1024_8); +} From 4c90908e43152fe8f3b038003aa3e6d269449454 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 14:00:39 -0400 Subject: [PATCH 02/13] remove QJL We are going to implement this later as a separate encoding (if we decide to implement it at all because word on the street is that the MSE + QJL is not actually better than MSE on its own). Signed-off-by: Connor Tsui --- vortex-btrblocks/src/builder.rs | 6 +- .../src/encodings/turboquant/array.rs | 120 +--- .../src/encodings/turboquant/compress.rs | 162 +----- .../turboquant/compute/cosine_similarity.rs | 11 +- .../src/encodings/turboquant/compute/slice.rs | 19 +- .../src/encodings/turboquant/compute/take.rs | 17 +- .../src/encodings/turboquant/decompress.rs | 75 +-- vortex-tensor/src/encodings/turboquant/mod.rs | 523 ++++-------------- .../src/encodings/turboquant/scheme.rs | 39 +- .../src/encodings/turboquant/vtable.rs | 30 +- vortex/benches/single_encoding_throughput.rs | 40 +- 11 files changed, 172 insertions(+), 870 deletions(-) diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 2a390e89504..4e197feb5a2 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -120,7 +120,7 @@ impl BtrBlocksCompressorBuilder { /// Adds compact encoding schemes (Zstd for strings, Pco for numerics). 
/// /// This provides better compression ratios than the default, especially for floating-point - /// heavy datasets. Requires the `zstd` feature. When the `pco` rfeature is also enabled, + /// heavy datasets. Requires the `zstd` feature. When the `pco` feature is also enabled, /// Pco schemes for integers and floats are included. /// /// # Panics @@ -140,8 +140,8 @@ impl BtrBlocksCompressorBuilder { /// Adds the TurboQuant lossy vector quantization scheme. /// - /// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm with - /// QJL correction for unbiased inner product estimation. + /// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm + /// with MSE-optimal scalar quantization. /// /// # Panics /// diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array.rs index 89c600853b4..3172f28f808 100644 --- a/vortex-tensor/src/encodings/turboquant/array.rs +++ b/vortex-tensor/src/encodings/turboquant/array.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors //! TurboQuant array definition: stores quantized coordinate codes, norms, -//! centroids (codebook), rotation signs, and optional QJL correction fields. +//! centroids (codebook), and rotation signs. use vortex_array::ArrayId; use vortex_array::ArrayRef; @@ -32,38 +32,6 @@ pub struct TurboQuantMetadata { /// MSE bits per coordinate (1-8). #[prost(uint32, tag = "2")] pub bit_width: u32, - /// Whether QJL correction children are present. - #[prost(bool, tag = "3")] - pub has_qjl: bool, -} - -/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased -/// inner product estimation. When present, adds 3 additional children. -#[derive(Clone, Debug)] -pub struct QjlCorrection { - /// Sign bits: `BoolArray`, length `num_rows * padded_dim`. - pub(crate) signs: ArrayRef, - /// Residual norms: `PrimitiveArray`, length `num_rows`. 
- pub(crate) residual_norms: ArrayRef, - /// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order). - pub(crate) rotation_signs: ArrayRef, -} - -impl QjlCorrection { - /// The QJL sign bits. - pub fn signs(&self) -> &ArrayRef { - &self.signs - } - - /// The residual norms. - pub fn residual_norms(&self) -> &ArrayRef { - &self.residual_norms - } - - /// The QJL rotation signs (BoolArray, inverse application order). - pub fn rotation_signs(&self) -> &ArrayRef { - &self.rotation_signs - } } /// Slot positions for TurboQuantArray children. @@ -74,13 +42,10 @@ pub(crate) enum Slot { Norms = 1, Centroids = 2, RotationSigns = 3, - QjlSigns = 4, - QjlResidualNorms = 5, - QjlRotationSigns = 6, } impl Slot { - pub(crate) const COUNT: usize = 7; + pub(crate) const COUNT: usize = 4; pub(crate) fn name(self) -> &'static str { match self { @@ -88,9 +53,6 @@ impl Slot { Self::Norms => "norms", Self::Centroids => "centroids", Self::RotationSigns => "rotation_signs", - Self::QjlSigns => "qjl_signs", - Self::QjlResidualNorms => "qjl_residual_norms", - Self::QjlRotationSigns => "qjl_rotation_signs", } } @@ -100,9 +62,6 @@ impl Slot { 1 => Self::Norms, 2 => Self::Centroids, 3 => Self::RotationSigns, - 4 => Self::QjlSigns, - 5 => Self::QjlResidualNorms, - 6 => Self::QjlRotationSigns, _ => vortex_error::vortex_panic!("invalid slot index {idx}"), } } @@ -110,16 +69,11 @@ impl Slot { /// TurboQuant array. 
/// -/// Slots (always present): -/// - 0: `codes` — `FixedSizeListArray` (quantized indices, list_size=padded_dim) -/// - 1: `norms` — `PrimitiveArray` (one per vector row) -/// - 2: `centroids` — `PrimitiveArray` (codebook, length 2^bit_width) -/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order) -/// -/// Optional QJL slots (None when MSE-only): -/// - 4: `qjl_signs` — `FixedSizeListArray` (num_rows * padded_dim, 1-bit) -/// - 5: `qjl_residual_norms` — `PrimitiveArray` (one per row) -/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation) +/// Slots: +/// - 0: `codes` -- `FixedSizeListArray` (quantized indices, list_size=padded_dim). +/// - 1: `norms` -- `PrimitiveArray` (one per vector row). +/// - 2: `centroids` -- `PrimitiveArray` (codebook, length 2^bit_width). +/// - 3: `rotation_signs` -- `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order). #[derive(Clone, Debug)] pub struct TurboQuantData { pub(crate) dtype: DType, @@ -130,9 +84,9 @@ pub struct TurboQuantData { } impl TurboQuantData { - /// Build a TurboQuant array with MSE-only encoding (no QJL correction). + /// Build a TurboQuant array. #[allow(clippy::too_many_arguments)] - pub fn try_new_mse( + pub fn try_new( dtype: DType, codes: ArrayRef, norms: ArrayRef, @@ -143,7 +97,7 @@ impl TurboQuantData { ) -> VortexResult { vortex_ensure!( (1..=8).contains(&bit_width), - "MSE bit_width must be 1-8, got {bit_width}" + "bit_width must be 1-8, got {bit_width}" ); let mut slots = vec![None; Slot::COUNT]; slots[Slot::Codes as usize] = Some(codes); @@ -159,39 +113,6 @@ impl TurboQuantData { }) } - /// Build a TurboQuant array with QJL correction (MSE + QJL). 
- #[allow(clippy::too_many_arguments)] - pub fn try_new_qjl( - dtype: DType, - codes: ArrayRef, - norms: ArrayRef, - centroids: ArrayRef, - rotation_signs: ArrayRef, - qjl: QjlCorrection, - dimension: u32, - bit_width: u8, - ) -> VortexResult { - vortex_ensure!( - (1..=8).contains(&bit_width), - "MSE bit_width must be 1-8, got {bit_width}" - ); - let mut slots = vec![None; Slot::COUNT]; - slots[Slot::Codes as usize] = Some(codes); - slots[Slot::Norms as usize] = Some(norms); - slots[Slot::Centroids as usize] = Some(centroids); - slots[Slot::RotationSigns as usize] = Some(rotation_signs); - slots[Slot::QjlSigns as usize] = Some(qjl.signs); - slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); - slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); - Ok(Self { - dtype, - slots, - dimension, - bit_width, - stats_set: Default::default(), - }) - } - /// The vector dimension d. pub fn dimension(&self) -> u32 { self.dimension @@ -207,11 +128,6 @@ impl TurboQuantData { self.dimension.next_power_of_two() } - /// Whether QJL correction is present. - pub fn has_qjl(&self) -> bool { - self.slots[Slot::QjlSigns as usize].is_some() - } - fn slot(&self, idx: usize) -> &ArrayRef { self.slots[idx] .as_ref() @@ -237,20 +153,4 @@ impl TurboQuantData { pub fn rotation_signs(&self) -> &ArrayRef { self.slot(Slot::RotationSigns as usize) } - - /// The optional QJL correction fields, reconstructed from slots. - pub fn qjl(&self) -> Option { - Some(QjlCorrection { - signs: self.slots[Slot::QjlSigns as usize].clone()?, - residual_norms: self.slots[Slot::QjlResidualNorms as usize].clone()?, - rotation_signs: self.slots[Slot::QjlRotationSigns as usize].clone()?, - }) - } - - /// Set the QJL correction fields on this array. 
- pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) { - self.slots[Slot::QjlSigns as usize] = Some(qjl.signs); - self.slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); - self.slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); - } } diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 756b17cdada..53f7046bad2 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -25,10 +25,7 @@ use crate::encodings::turboquant::rotation::RotationMatrix; /// Configuration for TurboQuant encoding. #[derive(Clone, Debug)] pub struct TurboQuantConfig { - /// Bits per coordinate. - /// - /// For MSE encoding: 1-8. - /// For QJL encoding: 2-9 (the MSE component uses `bit_width - 1`). + /// Bits per coordinate (1-8). pub bit_width: u8, /// Optional seed for the rotation matrix. If None, the default seed is used. pub seed: Option, @@ -37,7 +34,7 @@ pub struct TurboQuantConfig { impl Default for TurboQuantConfig { fn default() -> Self { Self { - bit_width: 5, + bit_width: 4, seed: Some(42), } } @@ -73,9 +70,8 @@ fn l2_norm(x: &[f32]) -> f32 { } /// Shared intermediate results from the MSE quantization loop. -struct MseQuantizationResult { +struct QuantizationResult { rotation: RotationMatrix, - f32_elements: PrimitiveArray, centroids: Vec, all_indices: BufferMut, norms: BufferMut, @@ -88,7 +84,7 @@ fn turboquant_quantize_core( fsl: &FixedSizeListArray, seed: u64, bit_width: u8, -) -> VortexResult { +) -> VortexResult { let dimension = fsl.list_size() as usize; let num_rows = fsl.len(); @@ -126,9 +122,8 @@ fn turboquant_quantize_core( } } - Ok(MseQuantizationResult { + Ok(QuantizationResult { rotation, - f32_elements, centroids, all_indices, norms, @@ -136,11 +131,11 @@ fn turboquant_quantize_core( }) } -/// Build a `TurboQuantArray` (MSE-only) from quantization results. 
+/// Build a `TurboQuantArray` from quantization results. #[allow(clippy::cast_possible_truncation)] -fn build_turboquant_mse( +fn build_turboquant( fsl: &FixedSizeListArray, - core: MseQuantizationResult, + core: QuantizationResult, bit_width: u8, ) -> VortexResult { let dimension = fsl.list_size(); @@ -168,7 +163,7 @@ fn build_turboquant_mse( let rotation_signs = bitpack_rotation_signs(&core.rotation)?; - TurboQuantData::try_new_mse( + TurboQuantData::try_new( fsl.dtype().clone(), codes, norms_array, @@ -179,11 +174,11 @@ fn build_turboquant_mse( ) } -/// Encode a FixedSizeListArray into a MSE-only `TurboQuantArray`. +/// Encode a FixedSizeListArray into a `TurboQuantArray`. /// /// The input must be non-nullable. TurboQuant is a lossy encoding that does not /// preserve null positions; callers must handle validity externally. -pub fn turboquant_encode_mse( +pub fn turboquant_encode( fsl: &FixedSizeListArray, config: &TurboQuantConfig, ) -> VortexResult { @@ -193,7 +188,7 @@ pub fn turboquant_encode_mse( ); vortex_ensure!( config.bit_width >= 1 && config.bit_width <= 8, - "MSE bit_width must be 1-8, got {}", + "bit_width must be 1-8, got {}", config.bit_width ); let dimension = fsl.list_size(); @@ -209,141 +204,12 @@ pub fn turboquant_encode_mse( let seed = config.seed.unwrap_or(42); let core = turboquant_quantize_core(fsl, seed, config.bit_width)?; - Ok(build_turboquant_mse(fsl, core, config.bit_width)?.into_array()) -} - -/// Encode a FixedSizeListArray into a `TurboQuantArray` with QJL correction. -/// -/// The QJL variant uses `bit_width - 1` MSE bits plus 1 bit of QJL residual -/// correction, giving unbiased inner product estimation. The input must be -/// non-nullable. 
-#[allow(clippy::cast_possible_truncation)] -pub fn turboquant_encode_qjl( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, -) -> VortexResult { - vortex_ensure!( - fsl.dtype().nullability() == Nullability::NonNullable, - "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" - ); - vortex_ensure!( - config.bit_width >= 2 && config.bit_width <= 9, - "QJL bit_width must be 2-9, got {}", - config.bit_width - ); - let dimension = fsl.list_size(); - vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" - ); - - if fsl.is_empty() { - return Ok(fsl.clone().into_array()); - } - - let seed = config.seed.unwrap_or(42); - let dim = dimension as usize; - let mse_bit_width = config.bit_width - 1; - - let core = turboquant_quantize_core(fsl, seed, mse_bit_width)?; - let padded_dim = core.padded_dim; - - // QJL uses a different rotation than the MSE stage to ensure statistical - // independence between the quantization noise and the sign projection. - let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?; - - let num_rows = fsl.len(); - let mut residual_norms_buf = BufferMut::::with_capacity(num_rows); - let mut qjl_sign_u8 = BufferMut::::with_capacity(num_rows * padded_dim); - - let mut dequantized_rotated = vec![0.0f32; padded_dim]; - let mut dequantized = vec![0.0f32; padded_dim]; - let mut residual = vec![0.0f32; padded_dim]; - let mut projected = vec![0.0f32; padded_dim]; - - // Compute QJL residuals using precomputed indices and norms from the core. - { - let f32_slice = core.f32_elements.as_slice::(); - let indices_slice: &[u8] = &core.all_indices; - let norms_slice: &[f32] = &core.norms; - - for row in 0..num_rows { - let x = &f32_slice[row * dim..(row + 1) * dim]; - let norm = norms_slice[row]; - - // Dequantize from precomputed indices. 
- let row_indices = &indices_slice[row * padded_dim..(row + 1) * padded_dim]; - for j in 0..padded_dim { - dequantized_rotated[j] = core.centroids[row_indices[j] as usize]; - } - - core.rotation - .inverse_rotate(&dequantized_rotated, &mut dequantized); - if norm > 0.0 { - for val in dequantized[..dim].iter_mut() { - *val *= norm; - } - } - - // Compute residual: r = x_padded - x̂. - // For positions 0..dim: r[j] = x[j] - dequantized[j]. - // For pad positions dim..padded_dim: the original was zero-padded, - // so r[j] = 0 - dequantized[j]. These pad artifacts are nonzero - // because the SRHT mixes quantization error into the padded region. - // Omitting them would corrupt the QJL signs for non-power-of-2 dims. - for j in 0..dim { - residual[j] = x[j] - dequantized[j]; - } - for j in dim..padded_dim { - residual[j] = -dequantized[j]; - } - // The residual norm for QJL scaling is over the dim-dimensional - // subspace only — pad artifacts don't contribute to reconstruction - // error in the output space. The pad positions are still included - // in the sign projection to avoid corrupting the SRHT mixing. - let residual_norm = l2_norm(&residual[..dim]); - residual_norms_buf.push(residual_norm); - - // QJL: sign(S · r). - if residual_norm > 0.0 { - qjl_rotation.rotate(&residual, &mut projected); - } else { - projected.fill(0.0); - } - - for j in 0..padded_dim { - qjl_sign_u8.push(if projected[j] >= 0.0 { 1u8 } else { 0u8 }); - } - } - } - - // Build the MSE part. - let mut array = build_turboquant_mse(fsl, core, mse_bit_width)?; - - // Attach QJL correction. 
- let residual_norms_array = - PrimitiveArray::new::(residual_norms_buf.freeze(), Validity::NonNullable); - let qjl_signs_elements = PrimitiveArray::new::(qjl_sign_u8.freeze(), Validity::NonNullable); - let qjl_signs = FixedSizeListArray::try_new( - qjl_signs_elements.into_array(), - padded_dim as u32, - Validity::NonNullable, - num_rows, - )?; - let qjl_rotation_signs = bitpack_rotation_signs(&qjl_rotation)?; - - array.set_qjl(crate::encodings::turboquant::array::QjlCorrection { - signs: qjl_signs.into_array(), - residual_norms: residual_norms_array.into_array(), - rotation_signs: qjl_rotation_signs, - }); - - Ok(array.into_array()) + Ok(build_turboquant(fsl, core, config.bit_width)?.into_array()) } /// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. /// -/// The rotation matrix's 3 × padded_dim sign values are exported as 0/1 u8 +/// The rotation matrix's 3 x padded_dim sign values are exported as 0/1 u8 /// values in inverse application order, then bitpacked to 1 bit per sign. /// On decode, FastLanes SIMD-unpacks back to `&[u8]` of 0/1 values. fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult { diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs index 2666270f6e8..98935e5fb4e 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -17,14 +17,9 @@ //! //! # Bias and error bounds //! -//! This estimate is **biased** — it uses only the MSE-quantized codes and does -//! not incorporate the QJL residual correction. The MSE quantizer minimizes -//! reconstruction error but does not guarantee unbiased inner products; the -//! discrete centroid grid introduces systematic bias in the dot product. -//! -//! The TurboQuant paper's Theorem 2 shows that unbiased inner product estimation -//! 
requires the full QJL correction term, which involves decoding the per-row -//! QJL signs and computing cross-terms — nearly as expensive as full decompression. +//! This estimate is **biased**. The MSE quantizer minimizes reconstruction error +//! but does not guarantee unbiased inner products; the discrete centroid grid +//! introduces systematic bias in the dot product. //! //! The approximation error is bounded by the MSE quantization distortion. For //! unit-norm vectors quantized at `b` bits, the per-coordinate MSE is bounded by diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs index acd4f1a42ee..19e1a9e0f91 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/slice.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -9,7 +9,6 @@ use vortex_array::IntoArray; use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; -use crate::encodings::turboquant::array::QjlCorrection; use crate::encodings::turboquant::array::TurboQuant; use crate::encodings::turboquant::array::TurboQuantData; @@ -19,20 +18,9 @@ impl SliceReduce for TurboQuant { range: Range, ) -> VortexResult> { let sliced_codes = array.codes().slice(range.clone())?; - let sliced_norms = array.norms().slice(range.clone())?; + let sliced_norms = array.norms().slice(range)?; - let sliced_qjl = array - .qjl() - .map(|qjl| -> VortexResult { - Ok(QjlCorrection { - signs: qjl.signs.slice(range.clone())?, - residual_norms: qjl.residual_norms.slice(range.clone())?, - rotation_signs: qjl.rotation_signs, - }) - }) - .transpose()?; - - let mut result = TurboQuantData::try_new_mse( + let result = TurboQuantData::try_new( array.dtype.clone(), sliced_codes, sliced_norms, @@ -41,9 +29,6 @@ impl SliceReduce for TurboQuant { array.dimension, array.bit_width, )?; - if let Some(qjl) = sliced_qjl { - result.set_qjl(qjl); - } Ok(Some(result.into_array())) } diff --git 
a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs index 2779a907375..638b493d3a6 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/take.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -8,7 +8,6 @@ use vortex_array::IntoArray; use vortex_array::arrays::dict::TakeExecute; use vortex_error::VortexResult; -use crate::encodings::turboquant::array::QjlCorrection; use crate::encodings::turboquant::array::TurboQuant; use crate::encodings::turboquant::array::TurboQuantData; @@ -22,18 +21,7 @@ impl TakeExecute for TurboQuant { let taken_codes = array.codes().take(indices.clone())?; let taken_norms = array.norms().take(indices.clone())?; - let taken_qjl = array - .qjl() - .map(|qjl| -> VortexResult { - Ok(QjlCorrection { - signs: qjl.signs.take(indices.clone())?, - residual_norms: qjl.residual_norms.take(indices.clone())?, - rotation_signs: qjl.rotation_signs, - }) - }) - .transpose()?; - - let mut result = TurboQuantData::try_new_mse( + let result = TurboQuantData::try_new( array.dtype.clone(), taken_codes, taken_norms, @@ -42,9 +30,6 @@ impl TakeExecute for TurboQuant { array.dimension, array.bit_width, )?; - if let Some(qjl) = taken_qjl { - result.set_qjl(qjl); - } Ok(Some(result.into_array())) } diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 5f5c68ce802..0f5d7c1942a 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -16,20 +16,10 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::rotation::RotationMatrix; -/// QJL correction scale factor: `sqrt(π/2) / padded_dim`. -/// -/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform) -/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations. 
-#[inline] -fn qjl_correction_scale(padded_dim: usize) -> f32 { - (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32) -} - /// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats. /// /// Reads stored centroids and rotation signs from the array's children, -/// avoiding any recomputation. If QJL correction is present, applies -/// the residual correction after MSE decoding. +/// avoiding any recomputation. pub fn execute_decompress( array: Array, ctx: &mut ExecutionCtx, @@ -49,7 +39,7 @@ pub fn execute_decompress( .into_array()); } - // Read stored centroids — no recomputation. + // Read stored centroids -- no recomputation. let centroids_prim = array.centroids().clone().execute::(ctx)?; let centroids = centroids_prim.as_slice::(); @@ -62,7 +52,7 @@ pub fn execute_decompress( .execute::(ctx)?; let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim)?; - // Unpack codes from FixedSizeListArray → flat u8 elements. + // Unpack codes from FixedSizeListArray -> flat u8 elements. let codes_fsl = array.codes().clone().execute::(ctx)?; let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); let indices = codes_prim.as_slice::(); @@ -70,8 +60,8 @@ pub fn execute_decompress( let norms_prim = array.norms().clone().execute::(ctx)?; let norms = norms_prim.as_slice::(); - // MSE decode: dequantize → inverse rotate → scale by norm. - let mut mse_output = BufferMut::::with_capacity(num_rows * dim); + // MSE decode: dequantize -> inverse rotate -> scale by norm. + let mut output = BufferMut::::with_capacity(num_rows * dim); let mut dequantized = vec![0.0f32; padded_dim]; let mut unrotated = vec![0.0f32; padded_dim]; @@ -89,60 +79,7 @@ pub fn execute_decompress( unrotated[idx] *= norm; } - mse_output.extend_from_slice(&unrotated[..dim]); - } - - // If no QJL correction, we're done. 
- let Some(qjl) = array.qjl() else { - let elements = PrimitiveArray::new::(mse_output.freeze(), Validity::NonNullable); - return Ok(FixedSizeListArray::try_new( - elements.into_array(), - array.dimension(), - Validity::NonNullable, - num_rows, - )? - .into_array()); - }; - - // Apply QJL residual correction. - // Unpack QJL signs from FixedSizeListArray → flat u8 0/1 values. - let qjl_signs_fsl = qjl.signs.clone().execute::(ctx)?; - let qjl_signs_prim = qjl_signs_fsl.elements().to_canonical()?.into_primitive(); - let qjl_signs_u8 = qjl_signs_prim.as_slice::(); - - let residual_norms_prim = qjl.residual_norms.clone().execute::(ctx)?; - let residual_norms = residual_norms_prim.as_slice::(); - - let qjl_rot_signs_prim = qjl.rotation_signs.execute::(ctx)?; - let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::(), dim)?; - - let qjl_scale = qjl_correction_scale(padded_dim); - let mse_elements = mse_output.as_ref(); - - let mut output = BufferMut::::with_capacity(num_rows * dim); - let mut qjl_signs_vec = vec![0.0f32; padded_dim]; - let mut qjl_projected = vec![0.0f32; padded_dim]; - - for row in 0..num_rows { - let mse_row = &mse_elements[row * dim..(row + 1) * dim]; - let residual_norm = residual_norms[row]; - - // Branchless u8 0/1 → f32 ±1.0 via XOR on the IEEE 754 sign bit. - // 1.0f32 = 0x3F800000; flipping the sign bit gives -1.0 = 0xBF800000. - // For sign=0 (negative): mask = 0x80000000, 1.0 XOR mask = -1.0. - // For sign=1 (positive): mask = 0x00000000, 1.0 XOR mask = 1.0. 
- let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim]; - for (dst, &sign) in qjl_signs_vec.iter_mut().zip(row_signs.iter()) { - let mask = ((sign as u32) ^ 1) << 31; - *dst = f32::from_bits(0x3F80_0000 ^ mask); - } - - qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected); - let scale = qjl_scale * residual_norm; - - for idx in 0..dim { - output.push(mse_row[idx] + scale * qjl_projected[idx]); - } + output.extend_from_slice(&unrotated[..dim]); } let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index b537db29029..d186d315b6e 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -9,21 +9,17 @@ //! //! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 //! -//! # Variants +//! # Overview //! -//! - **MSE** (`TurboQuantVariant::Mse`): Minimizes mean-squared reconstruction error -//! (1-8 bits per coordinate). -//! - **Prod** (`TurboQuantVariant::Prod`): Preserves inner products with an unbiased -//! estimator (uses `b-1` bits for MSE + 1-bit QJL residual correction, 2-9 bits). -//! At `b=9`, the MSE codes are raw int8 values suitable for direct use with -//! tensor core int8 GEMM kernels. +//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) +//! using MSE-optimal scalar quantization with an SRHT rotation for coordinate independence. //! //! # Theoretical error bounds //! //! For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 //! guarantees normalized MSE distortion: //! -//! > `E[||x - x̂||² / ||x||²] ≤ (√3 · π / 2) / 4^b` +//! > `E[||x - x_hat||^2 / ||x||^2] <= (sqrt(3) * pi / 2) / 4^b` //! //! | Bits | MSE bound | Quality | //! |------|------------|-------------------| @@ -38,7 +34,7 @@ //! //! # Compression ratios //! -//! 
Each vector is stored as `padded_dim × bit_width / 8` bytes of quantized codes plus a +//! Each vector is stored as `padded_dim * bit_width / 8` bytes of quantized codes plus a //! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the //! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. //! @@ -59,7 +55,7 @@ //! use vortex_array::arrays::PrimitiveArray; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_mse}; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; //! //! // Create a FixedSizeListArray of 100 random 128-d vectors. //! let num_rows = 100; @@ -73,20 +69,18 @@ //! elements.into_array(), dim as u32, Validity::NonNullable, num_rows, //! ).unwrap(); //! -//! // Quantize at 2 bits per coordinate using MSE-optimal encoding. +//! // Quantize at 2 bits per coordinate. //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; -//! let encoded = turboquant_encode_mse(&fsl, &config).unwrap(); +//! let encoded = turboquant_encode(&fsl, &config).unwrap(); //! -//! // Verify compression: 100 vectors × 128 dims × 4 bytes = 51200 bytes input. +//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! assert!(encoded.nbytes() < 51200); //! 
``` -pub use array::QjlCorrection; pub use array::TurboQuant; pub use array::TurboQuantData; pub use compress::TurboQuantConfig; -pub use compress::turboquant_encode_mse; -pub use compress::turboquant_encode_qjl; +pub use compress::turboquant_encode; mod array; pub(crate) mod centroids; @@ -116,7 +110,6 @@ pub fn initialize(session: &mut VortexSession) { mod tests { use std::sync::LazyLock; - use rand::RngExt; use rand::SeedableRng; use rand::rngs::StdRng; use rand_distr::Distribution; @@ -136,8 +129,7 @@ mod tests { use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::rotation::RotationMatrix; - use crate::encodings::turboquant::turboquant_encode_mse; - use crate::encodings::turboquant::turboquant_encode_qjl; + use crate::encodings::turboquant::turboquant_encode; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); @@ -194,13 +186,14 @@ mod tests { /// Encode and decode, returning (original, decoded) flat f32 slices. 
fn encode_decode( fsl: &FixedSizeListArray, - encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult, + config: &TurboQuantConfig, ) -> VortexResult<(Vec, Vec)> { let original: Vec = { let prim = fsl.elements().to_canonical().unwrap().into_primitive(); prim.as_slice::().to_vec() }; - let encoded = encode_fn(fsl)?; + let config = config.clone(); + let encoded = turboquant_encode(fsl, &config)?; let mut ctx = SESSION.create_execution_ctx(); let decoded = encoded.execute::(&mut ctx)?; let decoded_elements: Vec = { @@ -210,24 +203,8 @@ mod tests { Ok((original, decoded_elements)) } - fn encode_decode_mse( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, - ) -> VortexResult<(Vec, Vec)> { - let config = config.clone(); - encode_decode(fsl, |fsl| turboquant_encode_mse(fsl, &config)) - } - - fn encode_decode_qjl( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, - ) -> VortexResult<(Vec, Vec)> { - let config = config.clone(); - encode_decode(fsl, |fsl| turboquant_encode_qjl(fsl, &config)) - } - // ----------------------------------------------------------------------- - // MSE encoding tests + // Roundtrip tests // ----------------------------------------------------------------------- #[rstest] @@ -240,17 +217,21 @@ mod tests { #[case(128, 6)] #[case(128, 8)] #[case(256, 2)] - fn roundtrip_mse(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { let fsl = make_fsl(10, dim, 42); let config = TurboQuantConfig { bit_width, seed: Some(123), }; - let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let (original, decoded) = encode_decode(&fsl, &config)?; assert_eq!(decoded.len(), original.len()); Ok(()) } + // ----------------------------------------------------------------------- + // MSE quality tests + // ----------------------------------------------------------------------- + #[rstest] #[case(128, 1)] #[case(128, 2)] @@ -265,14 +246,15 @@ mod tests { 
bit_width, seed: Some(123), }; - let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let (original, decoded) = encode_decode(&fsl, &config)?; let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); let bound = theoretical_mse_bound(bit_width); assert!( normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} for dim={dim}, bits={bit_width}", + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ + for dim={dim}, bits={bit_width}", ); Ok(()) } @@ -290,14 +272,14 @@ mod tests { bit_width: 4, seed: Some(123), }; - let (original_4, decoded_4) = encode_decode_mse(&fsl, &config_4bit)?; + let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); let config = TurboQuantConfig { bit_width, seed: Some(123), }; - let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let (original, decoded) = encode_decode(&fsl, &config)?; let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); assert!( @@ -320,7 +302,7 @@ mod tests { bit_width, seed: Some(123), }; - let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let (original, decoded) = encode_decode(&fsl, &config)?; let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); assert!( mse <= prev_mse * 1.01, @@ -331,106 +313,6 @@ mod tests { Ok(()) } - // ----------------------------------------------------------------------- - // QJL encoding tests - // ----------------------------------------------------------------------- - - #[rstest] - #[case(32, 2)] - #[case(32, 3)] - #[case(128, 2)] - #[case(128, 4)] - #[case(128, 6)] - #[case(128, 8)] - #[case(128, 9)] - #[case(768, 3)] - fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let fsl = make_fsl(10, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(456), - }; - let (original, decoded) = encode_decode_qjl(&fsl, 
&config)?; - assert_eq!(decoded.len(), original.len()); - Ok(()) - } - - /// Compute the mean signed relative error of QJL inner product estimation - /// over random query/vector pairs. - fn qjl_mean_signed_relative_error( - original: &[f32], - decoded: &[f32], - dim: usize, - num_rows: usize, - ) -> f32 { - let num_pairs = 500; - let mut rng = StdRng::seed_from_u64(0); - let mut signed_errors = Vec::with_capacity(num_pairs); - - for _ in 0..num_pairs { - let qi = rng.random_range(0..num_rows); - let xi = rng.random_range(0..num_rows); - if qi == xi { - continue; - } - - let query = &original[qi * dim..(qi + 1) * dim]; - let orig_vec = &original[xi * dim..(xi + 1) * dim]; - let quant_vec = &decoded[xi * dim..(xi + 1) * dim]; - - let true_ip: f32 = query.iter().zip(orig_vec).map(|(&a, &b)| a * b).sum(); - let quant_ip: f32 = query.iter().zip(quant_vec).map(|(&a, &b)| a * b).sum(); - - if true_ip.abs() > 1e-6 { - signed_errors.push((quant_ip - true_ip) / true_ip.abs()); - } - } - - if signed_errors.is_empty() { - return 0.0; - } - - signed_errors.iter().sum::() / signed_errors.len() as f32 - } - - #[rstest] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 4)] - #[case(128, 6)] - #[case(128, 8)] - #[case(128, 9)] - #[case(768, 3)] - #[case(768, 4)] - fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 100; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(789), - }; - let (original, decoded) = encode_decode_qjl(&fsl, &config)?; - - let mean_rel_error = qjl_mean_signed_relative_error(&original, &decoded, dim, num_rows); - - // Known limitation: non-power-of-2 dims have elevated QJL bias (~23% vs - // ~11%) due to distribution mismatch between the SRHT zero-padded coordinate - // distribution and the analytical (1-x^2)^((d-3)/2) model used for centroids. 
- // Investigated approaches: - // - Random permutation of zeros: no effect (issue is distribution shape) - // - MC empirical centroids: fixes QJL bias but regresses MSE quality - // - Analytical centroids with dim instead of padded_dim: mixed results - // The principled fix requires jointly correcting centroids and QJL scale - // factor for the actual SRHT zero-padded distribution. - let threshold = if dim.is_power_of_two() { 0.15 } else { 0.25 }; - assert!( - mean_rel_error.abs() < threshold, - "QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width} \ - (threshold={threshold})" - ); - Ok(()) - } - // ----------------------------------------------------------------------- // Edge cases // ----------------------------------------------------------------------- @@ -438,29 +320,13 @@ mod tests { #[rstest] #[case(0)] #[case(1)] - fn roundtrip_mse_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { let fsl = make_fsl(num_rows, 128, 42); let config = TurboQuantConfig { bit_width: 2, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; - let mut ctx = SESSION.create_execution_ctx(); - let decoded = encoded.execute::(&mut ctx)?; - assert_eq!(decoded.len(), num_rows); - Ok(()) - } - - #[rstest] - #[case(0)] - #[case(1)] - fn roundtrip_qjl_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { - let fsl = make_fsl(num_rows, 128, 42); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(456), - }; - let encoded = turboquant_encode_qjl(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let mut ctx = SESSION.create_execution_ctx(); let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); @@ -470,25 +336,13 @@ mod tests { #[rstest] #[case(1)] #[case(2)] - fn mse_rejects_dimension_below_3(#[case] dim: usize) { + fn rejects_dimension_below_3(#[case] dim: usize) { let fsl = make_fsl_small(dim); let 
config = TurboQuantConfig { bit_width: 2, seed: Some(0), }; - assert!(turboquant_encode_mse(&fsl, &config).is_err()); - } - - #[rstest] - #[case(1)] - #[case(2)] - fn qjl_rejects_dimension_below_3(#[case] dim: usize) { - let fsl = make_fsl_small(dim); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(0), - }; - assert!(turboquant_encode_qjl(&fsl, &config).is_err()); + assert!(turboquant_encode(&fsl, &config).is_err()); } fn make_fsl_small(dim: usize) -> FixedSizeListArray { @@ -501,11 +355,70 @@ mod tests { .unwrap() } + /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). + #[test] + fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 -> 0 * anything = 0). + for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) + } + + /// Verify that f64 input is accepted and encoded (converted to f32 internally). 
+ #[test] + fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64->f32 conversion). + let encoded = turboquant_encode(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension(), dim as u32); + Ok(()) + } + // ----------------------------------------------------------------------- // Verification tests for stored metadata // ----------------------------------------------------------------------- - /// Verify that the centroids stored in the MSE array match what get_centroids() computes. + /// Verify that the centroids stored in the array match what `get_centroids()` computes. #[test] fn stored_centroids_match_computed() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); @@ -513,7 +426,7 @@ mod tests { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let encoded = encoded.as_opt::().unwrap(); let mut ctx = SESSION.create_execution_ctx(); @@ -534,10 +447,6 @@ mod tests { } /// Verify that stored rotation signs produce identical decode to seed-based decode. - /// - /// Encodes the same data twice: once with the new path (stored signs), and - /// once by manually recomputing the rotation from the seed. Both should - /// produce identical output. 
#[test] fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { let fsl = make_fsl(20, 128, 42); @@ -545,7 +454,7 @@ mod tests { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let encoded = encoded.as_opt::().unwrap(); // Decode via the stored-signs path (normal decode). @@ -577,140 +486,11 @@ mod tests { } // ----------------------------------------------------------------------- - // QJL-specific quality tests + // Serde roundtrip // ----------------------------------------------------------------------- - /// Verify that QJL's MSE component (at bit_width-1) satisfies the theoretical bound. - #[rstest] - #[case(128, 3)] - #[case(128, 4)] - #[case(256, 3)] - fn qjl_mse_within_theoretical_bound( - #[case] dim: usize, - #[case] bit_width: u8, - ) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(789), - }; - let (original, decoded) = encode_decode_qjl(&fsl, &config)?; - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - - // QJL at b bits uses (b-1)-bit MSE plus a correction term. - // The MSE should be at most the (b-1)-bit theoretical bound, though - // in practice the QJL correction often improves it further. - let mse_bound = theoretical_mse_bound(bit_width - 1); - assert!( - normalized_mse < mse_bound, - "QJL MSE {normalized_mse:.6} exceeds (b-1)-bit bound {mse_bound:.6} \ - for dim={dim}, bits={bit_width}", - ); - Ok(()) - } - - /// Verify that high-bitwidth QJL (8-9 bits) achieves very low distortion. - #[rstest] - #[case(128, 8)] - #[case(128, 9)] - fn high_bitwidth_qjl_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - - // Compare against 4-bit QJL as reference ceiling. 
- let config_4bit = TurboQuantConfig { - bit_width: 4, - seed: Some(789), - }; - let (original_4, decoded_4) = encode_decode_qjl(&fsl, &config_4bit)?; - let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); - - let config = TurboQuantConfig { - bit_width, - seed: Some(789), - }; - let (original, decoded) = encode_decode_qjl(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - - assert!( - mse < mse_4bit, - "{bit_width}-bit QJL MSE ({mse:.6}) should be < 4-bit ({mse_4bit:.6})" - ); - assert!( - mse < 0.01, - "{bit_width}-bit QJL MSE ({mse:.6}) should be < 1%" - ); - Ok(()) - } - - // ----------------------------------------------------------------------- - // Edge case and input format tests - // ----------------------------------------------------------------------- - - /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). - #[test] - fn all_zero_vectors_roundtrip() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let buf = BufferMut::::full(0.0f32, num_rows * dim); - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - num_rows, - )?; - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - }; - let (original, decoded) = encode_decode_mse(&fsl, &config)?; - // All-zero vectors should decode to all-zero (norm=0 → 0 * anything = 0). - for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { - assert_eq!(o, 0.0, "original[{i}] not zero"); - assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); - } - Ok(()) - } - - /// Verify that f64 input is accepted and encoded (converted to f32 internally). 
- #[test] - fn f64_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 64; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - num_rows, - )?; - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - }; - // Verify encoding succeeds with f64 input (f64→f32 conversion). - let encoded = turboquant_encode_mse(&fsl, &config)?; - let encoded = encoded.as_opt::().unwrap(); - assert_eq!(encoded.norms().len(), num_rows); - assert_eq!(encoded.dimension(), dim as u32); - Ok(()) - } - - /// Verify serde roundtrip: serialize MSE array metadata + children, then rebuild. #[test] - fn mse_serde_roundtrip() -> VortexResult<()> { + fn serde_roundtrip() -> VortexResult<()> { use vortex_array::vtable::VTable; let fsl = make_fsl(10, 128, 42); @@ -718,7 +498,7 @@ mod tests { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let encoded = encoded.as_opt::().unwrap(); // Serialize metadata. @@ -745,7 +525,6 @@ mod tests { // Verify metadata fields survived roundtrip. assert_eq!(deserialized.dimension, encoded.dimension()); assert_eq!(deserialized.bit_width, encoded.bit_width() as u32); - assert_eq!(deserialized.has_qjl, encoded.has_qjl()); // Verify the rebuilt array decodes identically. let mut ctx = SESSION.create_execution_ctx(); @@ -756,84 +535,12 @@ mod tests { let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); // Rebuild from children (simulating deserialization). 
- let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_mse( - encoded.dtype().clone(), - children[0].clone(), - children[1].clone(), - children[2].clone(), - children[3].clone(), - deserialized.dimension, - deserialized.bit_width as u8, - )?; - let decoded_rebuilt = rebuilt - .into_array() - .execute::(&mut ctx)?; - let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); - - assert_eq!( - original_elements.as_slice::(), - rebuilt_elements.as_slice::() - ); - Ok(()) - } - - /// Verify serde roundtrip for QJL: serialize metadata + children, then rebuild. - #[test] - fn qjl_serde_roundtrip() -> VortexResult<()> { - use vortex_array::vtable::VTable; - - let fsl = make_fsl(10, 128, 42); - let config = TurboQuantConfig { - bit_width: 4, - seed: Some(456), - }; - let encoded = turboquant_encode_qjl(&fsl, &config)?; - let encoded = encoded.as_opt::().unwrap(); - - // Serialize metadata. - let metadata = ::metadata(encoded)?; - let serialized = - ::serialize(metadata)?.expect("metadata should serialize"); - - // Collect children — QJL has 7 (4 MSE + 3 QJL). - let nchildren = ::nchildren(encoded); - assert_eq!(nchildren, 7); - let children: Vec = (0..nchildren) - .map(|i| ::child(encoded, i)) - .collect(); - - // Deserialize metadata. - let deserialized = ::deserialize( - &serialized, - encoded.dtype(), - encoded.len(), - &[], - &SESSION, - )?; - - assert!(deserialized.has_qjl); - assert_eq!(deserialized.dimension, encoded.dimension()); - - // Verify decode: original vs rebuilt from children. - let mut ctx = SESSION.create_execution_ctx(); - let decoded_original = encoded - .array() - .clone() - .execute::(&mut ctx)?; - let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); - - // Rebuild with QJL children. 
- let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_qjl( + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new( encoded.dtype().clone(), children[0].clone(), children[1].clone(), children[2].clone(), children[3].clone(), - crate::encodings::turboquant::array::QjlCorrection { - signs: children[4].clone(), - residual_norms: children[5].clone(), - rotation_signs: children[6].clone(), - }, deserialized.dimension, deserialized.bit_width as u8, )?; @@ -860,7 +567,7 @@ mod tests { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); @@ -881,32 +588,6 @@ mod tests { Ok(()) } - #[test] - fn slice_qjl_preserves_data() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let config = TurboQuantConfig { - bit_width: 4, - seed: Some(456), - }; - let encoded = turboquant_encode_qjl(&fsl, &config)?; - - let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - let expected = full_decoded.slice(3..8)?; - let expected_prim = expected.to_canonical()?.into_fixed_size_list(); - let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); - - let sliced = encoded.slice(3..8)?; - let sliced_decoded = sliced.execute::(&mut ctx)?; - let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); - - assert_eq!( - expected_elements.as_slice::(), - actual_elements.as_slice::() - ); - Ok(()) - } - #[test] fn scalar_at_matches_decompress() -> VortexResult<()> { let fsl = make_fsl(10, 64, 42); @@ -914,7 +595,7 @@ mod tests { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let mut ctx = SESSION.create_execution_ctx(); let full_decoded = encoded.clone().execute::(&mut ctx)?; @@ -934,7 +615,7 @@ mod tests 
{ bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let tq = encoded.as_opt::().unwrap(); // Stored norms should match the actual L2 norms of the input. @@ -958,14 +639,12 @@ mod tests { #[test] fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { - use vortex_array::arrays::FixedSizeListArray; - let fsl = make_fsl(20, 128, 42); let config = TurboQuantConfig { bit_width: 4, seed: Some(123), }; - let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; let tq = encoded.as_opt::().unwrap(); // Compute exact cosine similarity from original data. diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index 6db642ae25f..cf92c2ce4c4 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -21,13 +21,12 @@ use vortex_error::vortex_ensure; use super::FIXED_SHAPE_TENSOR_EXT_ID; use super::TurboQuantConfig; use super::VECTOR_EXT_ID; -use super::turboquant_encode_qjl; +use super::turboquant_encode; /// TurboQuant compression scheme for tensor extension types. /// /// Applies lossy vector quantization to `Vector` and `FixedShapeTensor` extension -/// arrays using the TurboQuant algorithm with QJL correction for unbiased inner -/// product estimation. +/// arrays using the TurboQuant algorithm with MSE-optimal encoding. 
/// /// Register this scheme with the compressor builder via `with_scheme`: /// ```ignore @@ -85,30 +84,26 @@ impl Scheme for TurboQuantScheme { let fsl = storage.to_canonical()?.into_fixed_size_list(); let config = TurboQuantConfig::default(); - let encoded = turboquant_encode_qjl(&fsl, &config)?; + let encoded = turboquant_encode(&fsl, &config)?; Ok(ExtensionArray::new(ext_array.ext_dtype().clone(), encoded).into_array()) } } -/// Estimate the compression ratio for TurboQuant QJL encoding with the default config. -/// -/// Uses the default [`TurboQuantConfig`] (5-bit QJL = 4-bit MSE + 1-bit QJL signs). +/// Estimate the compression ratio for TurboQuant MSE encoding with the default config. fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 { let config = TurboQuantConfig::default(); let padded_dim = dimensions.next_power_of_two() as usize; - // Per-vector: MSE codes + QJL signs per padded coordinate, - // plus two f32 values (norm and QJL residual norm). - let compressed_bits_per_vector = 2 * 32 // norm + residual_norm are always f32 - + (config.bit_width as usize) * padded_dim; // MSE codes + QJL sign bits + // Per-vector: MSE codes per padded coordinate, plus one f32 norm. + let compressed_bits_per_vector = 32 // norm is always f32 + + (config.bit_width as usize) * padded_dim; // MSE codes - // Shared overhead: codebook centroids (2^mse_bit_width f32 values) and - // rotation signs (3 * padded_dim bits each for MSE and QJL rotations). - let mse_bit_width = config.bit_width - 1; // QJL uses bit_width-1 for MSE - let num_centroids = 1usize << mse_bit_width; + // Shared overhead: codebook centroids (2^bit_width f32 values) and + // rotation signs (3 * padded_dim bits). 
+ let num_centroids = 1usize << config.bit_width; let overhead_bits = num_centroids * 32 // centroids are always f32 - + 2 * 3 * padded_dim; // MSE + QJL rotation signs, 1 bit each + + 3 * padded_dim; // rotation signs, 1 bit each let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits; let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize; @@ -160,12 +155,12 @@ mod tests { /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. /// f32 input at 1024-d (no padding) should give higher ratio since no waste. #[rstest] - #[case::f32_768d(32, 768, 1000, 3.5, 6.0)] - #[case::f32_1024d(32, 1024, 1000, 4.5, 7.0)] - #[case::f32_1536d(32, 1536, 1000, 3.0, 6.0)] - #[case::f32_128d(32, 128, 1000, 4.0, 6.0)] - #[case::f64_768d(64, 768, 1000, 7.0, 12.0)] - #[case::f16_768d(16, 768, 1000, 1.5, 3.5)] + #[case::f32_768d(32, 768, 1000, 3.5, 8.0)] + #[case::f32_1024d(32, 1024, 1000, 5.0, 9.0)] + #[case::f32_1536d(32, 1536, 1000, 3.0, 8.0)] + #[case::f32_128d(32, 128, 1000, 4.0, 8.0)] + #[case::f64_768d(64, 768, 1000, 7.0, 16.0)] + #[case::f16_768d(16, 768, 1000, 1.5, 4.5)] fn compression_ratio_in_expected_range( #[case] bits_per_element: usize, #[case] dim: u32, diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index f309ca8bc7c..b3fc365e8db 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -131,7 +131,6 @@ impl VTable for TurboQuant { Ok(ProstMetadata(TurboQuantMetadata { dimension: array.dimension, bit_width: array.bit_width as u32, - has_qjl: array.has_qjl(), })) } @@ -165,11 +164,8 @@ impl VTable for TurboQuant { let u8_nn = DType::Primitive(PType::U8, Nullability::NonNullable); let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); - let codes_dtype = DType::FixedSizeList( - Arc::new(u8_nn.clone()), - padded_dim as u32, - Nullability::NonNullable, - ); + let 
codes_dtype = + DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); let codes = children.get(0, &codes_dtype, len)?; let norms = children.get(1, &f32_nn, len)?; @@ -178,24 +174,14 @@ impl VTable for TurboQuant { let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; - let mut slots = vec![None; Slot::COUNT]; - slots[Slot::Codes as usize] = Some(codes); - slots[Slot::Norms as usize] = Some(norms); - slots[Slot::Centroids as usize] = Some(centroids); - slots[Slot::RotationSigns as usize] = Some(rotation_signs); - - if metadata.has_qjl { - let qjl_signs_dtype = - DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); - slots[Slot::QjlSigns as usize] = Some(children.get(4, &qjl_signs_dtype, len)?); - slots[Slot::QjlResidualNorms as usize] = Some(children.get(5, &f32_nn, len)?); - slots[Slot::QjlRotationSigns as usize] = - Some(children.get(6, &signs_dtype, 3 * padded_dim)?); - } - Ok(TurboQuantData { dtype: dtype.clone(), - slots, + slots: vec![ + Some(codes), + Some(norms), + Some(centroids), + Some(rotation_signs), + ], dimension: metadata.dimension, bit_width, stats_set: Default::default(), diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index abac0e36fed..ba59b7a5b8f 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -441,8 +441,7 @@ mod turboquant_benches { use vortex_array::VortexSessionExecute; use vortex_buffer::BufferMut; use vortex_tensor::encodings::turboquant::TurboQuantConfig; - use vortex_tensor::encodings::turboquant::turboquant_encode_mse; - use vortex_tensor::encodings::turboquant::turboquant_encode_qjl; + use vortex_tensor::encodings::turboquant::turboquant_encode; use super::SESSION; use super::with_byte_counter; @@ -483,48 +482,23 @@ mod turboquant_benches { macro_rules! 
turboquant_bench { (compress, $dim:literal, $bits:literal, $name:ident) => { paste! { - #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] - fn [<$name _mse>](bencher: Bencher) { + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { let fsl = setup_vector_fsl($dim); let config = turboquant_config($bits); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| &fsl) - .bench_refs(|a| turboquant_encode_mse(a, &config).unwrap()); - } - - #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] - fn [<$name _qjl>](bencher: Bencher) { - let fsl = setup_vector_fsl($dim); - let config = turboquant_config($bits); - with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) - .with_inputs(|| &fsl) - .bench_refs(|a| turboquant_encode_qjl(a, &config).unwrap()); + .bench_refs(|a| turboquant_encode(a, &config).unwrap()); } } }; (decompress, $dim:literal, $bits:literal, $name:ident) => { paste! 
{ - #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] - fn [<$name _mse>](bencher: Bencher) { - let fsl = setup_vector_fsl($dim); - let config = turboquant_config($bits); - let compressed = turboquant_encode_mse(&fsl, &config).unwrap(); - with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) - .with_inputs(|| &compressed) - .bench_refs(|a| { - let mut ctx = SESSION.create_execution_ctx(); - a.clone() - .into_array() - .execute::(&mut ctx) - .unwrap() - }); - } - - #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] - fn [<$name _qjl>](bencher: Bencher) { + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { let fsl = setup_vector_fsl($dim); let config = turboquant_config($bits); - let compressed = turboquant_encode_qjl(&fsl, &config).unwrap(); + let compressed = turboquant_encode(&fsl, &config).unwrap(); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| &compressed) .bench_refs(|a| { From cc0ae6c0fbfa2c9585c2a63cf8d5c0fc59f95187 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 14:38:35 -0400 Subject: [PATCH 03/13] TurboQuant is only an encoding for `Vector`s It doesn't really make a lot of sense for us to define this as an encoding for `FixedSizeList`. 
Signed-off-by: Connor Tsui --- vortex-btrblocks/src/builder.rs | 1 - .../src/encodings/turboquant/array.rs | 143 +++++++++++++++- .../src/encodings/turboquant/compress.rs | 44 ++++- .../src/encodings/turboquant/compute/ops.rs | 4 +- .../src/encodings/turboquant/decompress.rs | 24 +-- vortex-tensor/src/encodings/turboquant/mod.rs | 152 ++++++++++++++---- .../src/encodings/turboquant/scheme.rs | 8 +- .../src/scalar_fns/cosine_similarity.rs | 29 ++-- vortex-tensor/src/scalar_fns/inner_product.rs | 27 ++-- vortex-tensor/src/scalar_fns/l2_norm.rs | 25 +-- vortex-tensor/src/scalar_fns/mod.rs | 14 ++ vortex/benches/single_encoding_throughput.rs | 30 ++-- 12 files changed, 382 insertions(+), 119 deletions(-) diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 4e197feb5a2..5728c606d1e 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -148,7 +148,6 @@ impl BtrBlocksCompressorBuilder { /// Panics if the TurboQuant scheme is already present. /// /// [`Vector`]: vortex_tensor::vector::Vector - /// [`FixedShapeTensor`]: vortex_tensor::fixed_shape::FixedShapeTensor #[cfg(feature = "unstable_encodings")] pub fn with_turboquant(self) -> Self { use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array.rs index 3172f28f808..c782d66641c 100644 --- a/vortex-tensor/src/encodings/turboquant/array.rs +++ b/vortex-tensor/src/encodings/turboquant/array.rs @@ -13,6 +13,8 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use crate::vector::Vector; + /// Encoding marker type for TurboQuant. #[derive(Clone, Debug)] pub struct TurboQuant; @@ -84,7 +86,17 @@ pub struct TurboQuantData { } impl TurboQuantData { - /// Build a TurboQuant array. + /// Build a TurboQuant array with validation. + /// + /// The `dtype` must be a [`Vector`] extension type. 
TurboQuant encodes the extension + /// type directly, not its `FixedSizeList` storage. + /// + /// # Errors + /// + /// Returns an error if the provided components do not satisfy the invariants documented + /// in [`new_unchecked`](Self::new_unchecked). + /// + /// [`Vector`]: crate::vector::Vector #[allow(clippy::too_many_arguments)] pub fn try_new( dtype: DType, @@ -95,22 +107,139 @@ impl TurboQuantData { dimension: u32, bit_width: u8, ) -> VortexResult { - vortex_ensure!( - (1..=8).contains(&bit_width), - "bit_width must be 1-8, got {bit_width}" - ); + Self::validate( + &dtype, + &codes, + &norms, + ¢roids, + &rotation_signs, + dimension, + bit_width, + )?; + + // SAFETY: we validate that the inputs are valid above. + Ok(unsafe { + Self::new_unchecked( + dtype, + codes, + norms, + centroids, + rotation_signs, + dimension, + bit_width, + ) + }) + } + + /// Build a TurboQuant array without validation. + /// + /// * `dtype` must be a [`Vector`] extension type. + /// * `codes` must be a `FixedSizeListArray` with `list_size == padded_dim`. + /// * `norms` must be a `PrimitiveArray` with one element per row. + /// * `centroids` must be a `PrimitiveArray` with `2^bit_width` elements. + /// * `rotation_signs` must contain `3 * padded_dim` sign values. + /// * `bit_width` must be 1-8. + /// * `codes.len() == norms.len()`. + /// + /// # Safety + /// + /// The caller must ensure the inputs satisfy the invariants listed above. Violating them + /// may produce incorrect results during decompression. 
+ /// + /// [`Vector`]: crate::vector::Vector + #[allow(clippy::too_many_arguments)] + pub unsafe fn new_unchecked( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + dimension: u32, + bit_width: u8, + ) -> Self { + #[cfg(debug_assertions)] + Self::validate( + &dtype, + &codes, + &norms, + ¢roids, + &rotation_signs, + dimension, + bit_width, + ) + .vortex_expect("[Debug Assertion]: Invalid TurboQuantData parameters"); + let mut slots = vec![None; Slot::COUNT]; slots[Slot::Codes as usize] = Some(codes); slots[Slot::Norms as usize] = Some(norms); slots[Slot::Centroids as usize] = Some(centroids); slots[Slot::RotationSigns as usize] = Some(rotation_signs); - Ok(Self { + Self { dtype, slots, dimension, bit_width, stats_set: Default::default(), - }) + } + } + + /// Validates the components that would be used to create a `TurboQuantData`. + /// + /// This function checks all the invariants required by [`new_unchecked`](Self::new_unchecked). + #[allow(clippy::too_many_arguments)] + pub fn validate( + dtype: &DType, + codes: &ArrayRef, + norms: &ArrayRef, + centroids: &ArrayRef, + rotation_signs: &ArrayRef, + dimension: u32, + bit_width: u8, + ) -> VortexResult<()> { + vortex_ensure!( + (1..=8).contains(&bit_width), + "bit_width must be 1-8, got {bit_width}" + ); + vortex_ensure!( + dtype + .as_extension_opt() + .is_some_and(|ext| ext.is::()), + "TurboQuant dtype must be a Vector extension type, got {dtype}" + ); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + let num_rows = norms.len(); + vortex_ensure!( + codes.len() == num_rows, + "codes length {} does not match norms length {num_rows}", + codes.len() + ); + + let expected_centroids = 1usize << bit_width; + // Allow empty centroids for zero-row arrays. 
+ if num_rows > 0 { + vortex_ensure!( + centroids.len() == expected_centroids, + "centroids length {} does not match expected 2^{bit_width} = {expected_centroids}", + centroids.len() + ); + } + + let padded_dim = dimension.next_power_of_two() as usize; + // Allow empty rotation signs for zero-row arrays. + if num_rows > 0 { + vortex_ensure!( + rotation_signs.len() == 3 * padded_dim, + "rotation_signs length {} does not match expected 3 * {padded_dim} = {}", + rotation_signs.len(), + 3 * padded_dim + ); + } + + Ok(()) } /// The vector dimension d. diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 53f7046bad2..ff833e05cc2 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -5,8 +5,10 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::validity::Validity; @@ -137,6 +139,7 @@ fn build_turboquant( fsl: &FixedSizeListArray, core: QuantizationResult, bit_width: u8, + ext_dtype: DType, ) -> VortexResult { let dimension = fsl.list_size(); @@ -164,7 +167,7 @@ fn build_turboquant( let rotation_signs = bitpack_rotation_signs(&core.rotation)?; TurboQuantData::try_new( - fsl.dtype().clone(), + ext_dtype, codes, norms_array, centroids_array, @@ -174,14 +177,20 @@ fn build_turboquant( ) } -/// Encode a FixedSizeListArray into a `TurboQuantArray`. +/// Encode a [`Vector`] extension array into a `TurboQuantArray`. /// -/// The input must be non-nullable. TurboQuant is a lossy encoding that does not -/// preserve null positions; callers must handle validity externally. +/// The input must be a non-nullable [`Vector`] extension array. 
TurboQuant is a lossy encoding +/// that does not preserve null positions; callers must handle validity externally. +/// +/// [`Vector`]: crate::vector::Vector pub fn turboquant_encode( - fsl: &FixedSizeListArray, + ext: &ExtensionArray, config: &TurboQuantConfig, ) -> VortexResult { + let ext_dtype = ext.dtype().clone(); + let storage = ext.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + vortex_ensure!( fsl.dtype().nullability() == Nullability::NonNullable, "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" @@ -198,13 +207,32 @@ pub fn turboquant_encode( ); if fsl.is_empty() { - return Ok(fsl.clone().into_array()); + let padded_dim = dimension.next_power_of_two(); + let empty_codes = FixedSizeListArray::try_new( + PrimitiveArray::empty::(Nullability::NonNullable).into_array(), + padded_dim, + Validity::NonNullable, + 0, + )?; + let empty_norms = PrimitiveArray::empty::(Nullability::NonNullable); + let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); + let empty_signs = PrimitiveArray::empty::(Nullability::NonNullable); + return Ok(TurboQuantData::try_new( + ext_dtype, + empty_codes.into_array(), + empty_norms.into_array(), + empty_centroids.into_array(), + empty_signs.into_array(), + dimension, + config.bit_width, + )? + .into_array()); } let seed = config.seed.unwrap_or(42); - let core = turboquant_quantize_core(fsl, seed, config.bit_width)?; + let core = turboquant_quantize_core(&fsl, seed, config.bit_width)?; - Ok(build_turboquant(fsl, core, config.bit_width)?.into_array()) + Ok(build_turboquant(&fsl, core, config.bit_width, ext_dtype)?.into_array()) } /// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. 
diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs index 89f25c61018..5309669ed53 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/ops.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/ops.rs @@ -3,7 +3,7 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; -use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::slice::SliceReduce; use vortex_array::scalar::Scalar; use vortex_array::vtable::OperationsVTable; @@ -22,7 +22,7 @@ impl OperationsVTable for TurboQuant { let Some(sliced) = ::slice(array, index..index + 1)? else { vortex_bail!("slice returned None for index {index}") }; - let decoded = sliced.execute::(ctx)?; + let decoded = sliced.execute::(ctx)?; decoded.scalar_at(0) } } diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 0f5d7c1942a..5897784c306 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -7,6 +7,7 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::validity::Validity; @@ -16,10 +17,12 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::rotation::RotationMatrix; -/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats. +/// Decompress a `TurboQuantArray` into a [`Vector`] extension array. /// -/// Reads stored centroids and rotation signs from the array's children, -/// avoiding any recomputation. +/// The returned array is an [`ExtensionArray`] with the original Vector dtype wrapping a +/// `FixedSizeListArray` of f32 elements. 
+/// +/// [`Vector`]: crate::vector::Vector pub fn execute_decompress( array: Array, ctx: &mut ExecutionCtx, @@ -27,16 +30,17 @@ pub fn execute_decompress( let dim = array.dimension() as usize; let padded_dim = array.padded_dim() as usize; let num_rows = array.norms().len(); + let ext_dtype = array.dtype.as_extension().clone(); if num_rows == 0 { - let elements = PrimitiveArray::empty::(array.dtype.nullability()); - return Ok(FixedSizeListArray::try_new( + let elements = PrimitiveArray::empty::(ext_dtype.storage_dtype().nullability()); + let fsl = FixedSizeListArray::try_new( elements.into_array(), array.dimension(), Validity::NonNullable, 0, - )? - .into_array()); + )?; + return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); } // Read stored centroids -- no recomputation. @@ -83,11 +87,11 @@ pub fn execute_decompress( } let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - Ok(FixedSizeListArray::try_new( + let fsl = FixedSizeListArray::try_new( elements.into_array(), array.dimension(), Validity::NonNullable, num_rows, - )? - .into_array()) + )?; + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index d186d315b6e..94a3e24bff1 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -4,10 +4,11 @@ //! TurboQuant vector quantization encoding for Vortex. //! //! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of -//! high-dimensional vector data. The encoding operates on `FixedSizeList` arrays of floats -//! (the storage format of `Vector` and `FixedShapeTensor` extension types). +//! high-dimensional vector data. The encoding operates on [`Vector`] extension arrays, +//! compressing their `FixedSizeList` storage into quantized codes with an SRHT rotation. //! //! 
[arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! [`Vector`]: crate::vector::Vector //! //! # Overview //! @@ -51,27 +52,34 @@ //! //! ``` //! use vortex_array::IntoArray; +//! use vortex_array::arrays::ExtensionArray; //! use vortex_array::arrays::FixedSizeListArray; //! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::dtype::extension::ExtDType; +//! use vortex_array::extension::EmptyMetadata; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; //! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; +//! use vortex_tensor::vector::Vector; //! -//! // Create a FixedSizeListArray of 100 random 128-d vectors. +//! // Create a Vector extension array of 100 random 128-d vectors. //! let num_rows = 100; -//! let dim = 128; -//! let mut buf = BufferMut::::with_capacity(num_rows * dim); -//! for i in 0..(num_rows * dim) { +//! let dim = 128u32; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim as usize); +//! for i in 0..(num_rows * dim as usize) { //! buf.push((i as f32 * 0.001).sin()); //! } //! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); //! let fsl = FixedSizeListArray::try_new( -//! elements.into_array(), dim as u32, Validity::NonNullable, num_rows, +//! elements.into_array(), dim, Validity::NonNullable, num_rows, //! ).unwrap(); +//! let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) +//! .unwrap().erased(); +//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array()); //! //! // Quantize at 2 bits per coordinate. //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; -//! let encoded = turboquant_encode(&fsl, &config).unwrap(); +//! let encoded = turboquant_encode(&ext, &config).unwrap(); //! //! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! 
assert!(encoded.nbytes() < 51200); @@ -118,8 +126,11 @@ mod tests { use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; + use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; @@ -130,6 +141,7 @@ mod tests { use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::rotation::RotationMatrix; use crate::encodings::turboquant::turboquant_encode; + use crate::vector::Vector; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); @@ -154,6 +166,14 @@ mod tests { .unwrap() } + /// Wrap a `FixedSizeListArray` in a `Vector` extension array. + fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.clone().into_array()) + } + fn theoretical_mse_bound(bit_width: u8) -> f32 { let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) @@ -192,12 +212,22 @@ mod tests { let prim = fsl.elements().to_canonical().unwrap().into_primitive(); prim.as_slice::().to_vec() }; + let ext = make_vector_ext(fsl); let config = config.clone(); - let encoded = turboquant_encode(fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let mut ctx = SESSION.create_execution_ctx(); - let decoded = encoded.execute::(&mut ctx)?; + let decoded_ext = encoded.execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical() + .unwrap() + .into_fixed_size_list(); let decoded_elements: Vec = { - let prim = decoded.elements().to_canonical().unwrap().into_primitive(); + let prim = decoded_fsl + 
.elements() + .to_canonical() + .unwrap() + .into_primitive(); prim.as_slice::().to_vec() }; Ok((original, decoded_elements)) @@ -322,13 +352,14 @@ mod tests { #[case(1)] fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { let fsl = make_fsl(num_rows, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 2, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let mut ctx = SESSION.create_execution_ctx(); - let decoded = encoded.execute::(&mut ctx)?; + let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); Ok(()) } @@ -338,11 +369,12 @@ mod tests { #[case(2)] fn rejects_dimension_below_3(#[case] dim: usize) { let fsl = make_fsl_small(dim); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 2, seed: Some(0), }; - assert!(turboquant_encode(&fsl, &config).is_err()); + assert!(turboquant_encode(&ext, &config).is_err()); } fn make_fsl_small(dim: usize) -> FixedSizeListArray { @@ -402,12 +434,13 @@ mod tests { num_rows, )?; + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(42), }; // Verify encoding succeeds with f64 input (f64->f32 conversion). 
- let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let encoded = encoded.as_opt::().unwrap(); assert_eq!(encoded.norms().len(), num_rows); assert_eq!(encoded.dimension(), dim as u32); @@ -422,11 +455,12 @@ mod tests { #[test] fn stored_centroids_match_computed() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let encoded = encoded.as_opt::().unwrap(); let mut ctx = SESSION.create_execution_ctx(); @@ -450,19 +484,24 @@ mod tests { #[test] fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let encoded = encoded.as_opt::().unwrap(); // Decode via the stored-signs path (normal decode). let mut ctx = SESSION.create_execution_ctx(); - let decoded_fsl = encoded + let decoded_ext = encoded .array() .clone() - .execute::(&mut ctx)?; + .execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical()? + .into_fixed_size_list(); let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); let decoded_slice = decoded.as_slice::(); @@ -494,11 +533,12 @@ mod tests { use vortex_array::vtable::VTable; let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let encoded = encoded.as_opt::().unwrap(); // Serialize metadata. 
@@ -531,8 +571,12 @@ mod tests { let decoded_original = encoded .array() .clone() - .execute::(&mut ctx)?; - let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + .execute::(&mut ctx)?; + let original_fsl = decoded_original + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let original_elements = original_fsl.elements().to_canonical()?.into_primitive(); // Rebuild from children (simulating deserialization). let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new( @@ -544,10 +588,12 @@ mod tests { deserialized.dimension, deserialized.bit_width as u8, )?; - let decoded_rebuilt = rebuilt - .into_array() - .execute::(&mut ctx)?; - let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + let decoded_rebuilt = rebuilt.into_array().execute::(&mut ctx)?; + let rebuilt_fsl = decoded_rebuilt + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let rebuilt_elements = rebuilt_fsl.elements().to_canonical()?.into_primitive(); assert_eq!( original_elements.as_slice::(), @@ -563,23 +609,32 @@ mod tests { #[test] fn slice_preserves_data() -> VortexResult<()> { let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - let expected = full_decoded.slice(5..10)?; + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let full_fsl = full_decoded + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let expected = full_fsl.slice(5..10)?; let expected_prim = expected.to_canonical()?.into_fixed_size_list(); let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); // Slice then decompress. 
let sliced = encoded.slice(5..10)?; - let sliced_decoded = sliced.execute::(&mut ctx)?; - let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + let sliced_decoded = sliced.execute::(&mut ctx)?; + let sliced_fsl = sliced_decoded + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let actual_elements = sliced_fsl.elements().to_canonical()?.into_primitive(); assert_eq!( expected_elements.as_slice::(), @@ -591,14 +646,15 @@ mod tests { #[test] fn scalar_at_matches_decompress() -> VortexResult<()> { let fsl = make_fsl(10, 64, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; + let full_decoded = encoded.clone().execute::(&mut ctx)?; for i in [0, 1, 5, 9] { let expected = full_decoded.scalar_at(i)?; @@ -611,11 +667,12 @@ mod tests { #[test] fn l2_norm_readthrough() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 3, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let tq = encoded.as_opt::().unwrap(); // Stored norms should match the actual L2 norms of the input. @@ -640,11 +697,12 @@ mod tests { #[test] fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 4, seed: Some(123), }; - let encoded = turboquant_encode(&fsl, &config)?; + let encoded = turboquant_encode(&ext, &config)?; let tq = encoded.as_opt::().unwrap(); // Compute exact cosine similarity from original data. @@ -694,4 +752,28 @@ mod tests { } Ok(()) } + + /// Verify that the encoded array's dtype is a Vector extension type. 
+ #[test] + fn encoded_dtype_is_vector_extension() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode(&ext, &config)?; + + // The encoded TurboQuant array should claim a Vector extension dtype. + assert!( + encoded.dtype().is_extension(), + "TurboQuant dtype should be an extension type, got {}", + encoded.dtype() + ); + assert!( + encoded.dtype().as_extension().is::(), + "TurboQuant dtype should be a Vector extension type" + ); + Ok(()) + } } diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index cf92c2ce4c4..e357d8a4005 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -5,8 +5,6 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -80,13 +78,9 @@ impl Scheme for TurboQuantScheme { ) -> VortexResult { let array = data.array().clone(); let ext_array = array.to_canonical()?.into_extension(); - let storage = ext_array.storage_array(); - let fsl = storage.to_canonical()?.into_fixed_size_list(); let config = TurboQuantConfig::default(); - let encoded = turboquant_encode(&fsl, &config)?; - - Ok(ExtensionArray::new(ext_array.ext_dtype().clone(), encoded).into_array()) + turboquant_encode(&ext_array, &config) } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 9f9f440c4f4..f466ab91aff 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -144,26 +144,27 @@ impl ScalarFnVTable for CosineSimilarity { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let 
lhs = args.get(0)?.execute::(ctx)?; - let rhs = args.get(1)?.execute::(ctx)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let len = args.row_count(); + // TurboQuant approximate path: check encoding before executing. + if options.is_approx() + && let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_ref.as_opt::(), + rhs_ref.as_opt::(), + ) + { + return cosine_similarity::cosine_similarity_quantized_column(lhs_tq, rhs_tq, ctx); + } + + let lhs = lhs_ref.execute::(ctx)?; + let rhs = rhs_ref.execute::(ctx)?; + // Compute combined validity. let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?; - // TurboQuant approximate path: compute cosine similarity in quantized domain. - if *options == ApproxOptions::Approximate { - let lhs_storage = lhs.data().storage_array(); - let rhs_storage = rhs.data().storage_array(); - if let (Some(lhs_tq), Some(rhs_tq)) = ( - lhs_storage.as_opt::(), - rhs_storage.as_opt::(), - ) { - return cosine_similarity::cosine_similarity_quantized_column(lhs_tq, rhs_tq, ctx); - } - } - let lhs = lhs.into_array(); let rhs = rhs.into_array(); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 044b76b164c..4e1a3805d5a 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -143,11 +143,24 @@ impl ScalarFnVTable for InnerProduct { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let lhs: ExtensionArray = args.get(0)?.execute(ctx)?; - let rhs: ExtensionArray = args.get(1)?.execute(ctx)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let row_count = args.row_count(); + // TurboQuant approximate path: check encoding before executing. 
+ if options.is_approx() + && let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_ref.as_opt::(), + rhs_ref.as_opt::(), + ) + { + return cosine_similarity::dot_product_quantized_column(lhs_tq, rhs_tq, ctx); + } + + let lhs: ExtensionArray = lhs_ref.execute(ctx)?; + let rhs: ExtensionArray = rhs_ref.execute(ctx)?; + // Compute combined validity. let rhs_validity = rhs.as_ref().validity()?; let validity = lhs.as_ref().validity()?.and(rhs_validity)?; @@ -162,16 +175,6 @@ impl ScalarFnVTable for InnerProduct { let lhs_storage = lhs.data().storage_array(); let rhs_storage = rhs.data().storage_array(); - // TurboQuant approximate path: norm_a * norm_b * quantized unit-norm dot. - if *options == ApproxOptions::Approximate - && let (Some(lhs_tq), Some(rhs_tq)) = ( - lhs_storage.as_opt::(), - rhs_storage.as_opt::(), - ) - { - return cosine_similarity::dot_product_quantized_column(lhs_tq, rhs_tq, ctx); - } - let lhs_flat = extract_flat_elements(lhs_storage, list_size, ctx)?; let rhs_flat = extract_flat_elements(rhs_storage, list_size, ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 5992bf4fdfc..f573386768b 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -125,27 +125,28 @@ impl ScalarFnVTable for L2Norm { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let input: ExtensionArray = args.get(0)?.execute(ctx)?; - + let input_ref = args.get(0)?; let row_count = args.row_count(); - let validity = input.as_ref().validity()?; - - // Get element ptype and list size from the dtype (validated by `return_dtype`). - let ext = input.dtype().as_extension(); - let list_size = extension_list_size(ext)? as usize; - let target_ptype = extension_element_ptype(ext)?; - - let storage = input.data().storage_array(); // TurboQuant stores exact precomputed norms -- no decompression needed. 
// Norms are currently stored as f32; cast to the target dtype if needed // (e.g., if the input extension has f64 elements). - if let Some(tq) = storage.as_opt::() { + if let Some(tq) = input_ref.as_opt::() { + let ext = input_ref.dtype().as_extension(); + let target_ptype = extension_element_ptype(ext)?; let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?; - let target_dtype = DType::Primitive(target_ptype, input.dtype().nullability()); + let target_dtype = DType::Primitive(target_ptype, input_ref.dtype().nullability()); return norms.into_array().cast(target_dtype); } + let input: ExtensionArray = input_ref.execute(ctx)?; + let validity = input.as_ref().validity()?; + + // Get element ptype and list size from the dtype (validated by `return_dtype`). + let ext = input.dtype().as_extension(); + let list_size = extension_list_size(ext)? as usize; + + let storage = input.data().storage_array(); let flat = extract_flat_elements(storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index b10fd335420..8fb1883b706 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -12,11 +12,25 @@ pub mod l2_norm; /// Options for tensor-related expressions that might have error. #[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] pub enum ApproxOptions { + /// Computes the exact result. #[default] Exact, + /// Allows approximate results. Approximate, } +impl ApproxOptions { + /// Returns `true` if the option is [`Exact`](Self::Exact). + pub fn is_exact(&self) -> bool { + matches!(self, Self::Exact) + } + + /// Returns `true` if the option is [`Approximate`](Self::Approximate). 
+ pub fn is_approx(&self) -> bool { + matches!(self, Self::Approximate) + } +} + impl fmt::Display for ApproxOptions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index ba59b7a5b8f..39ff2177544 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -435,25 +435,29 @@ mod turboquant_benches { use rand::SeedableRng; use rand::rngs::StdRng; use vortex::array::IntoArray; + use vortex::array::arrays::ExtensionArray; use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; + use vortex::array::dtype::extension::ExtDType; + use vortex::array::extension::EmptyMetadata; use vortex::array::validity::Validity; use vortex_array::VortexSessionExecute; use vortex_buffer::BufferMut; use vortex_tensor::encodings::turboquant::TurboQuantConfig; use vortex_tensor::encodings::turboquant::turboquant_encode; + use vortex_tensor::vector::Vector; use super::SESSION; use super::with_byte_counter; const NUM_VECTORS: usize = 1_000; - /// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d. - /// standard normal components. This is a conservative test distribution: real - /// neural network embeddings typically have structure (clustered, anisotropic) + /// Generate `num_vectors` random f32 Vector extension arrays of the given dimension + /// using i.i.d. standard normal components. This is a conservative test distribution: + /// real neural network embeddings typically have structure (clustered, anisotropic) /// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a /// worst-case baseline for TurboQuant. 
- fn setup_vector_fsl(dim: usize) -> FixedSizeListArray { + fn setup_vector_ext(dim: usize) -> ExtensionArray { let mut rng = StdRng::seed_from_u64(42); let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); @@ -463,13 +467,17 @@ mod turboquant_benches { } let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( + let fsl = FixedSizeListArray::try_new( elements.into_array(), dim as u32, Validity::NonNullable, NUM_VECTORS, ) - .unwrap() + .unwrap(); + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.into_array()) } fn turboquant_config(bit_width: u8) -> TurboQuantConfig { @@ -484,10 +492,10 @@ mod turboquant_benches { paste! { #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] fn $name(bencher: Bencher) { - let fsl = setup_vector_fsl($dim); + let ext = setup_vector_ext($dim); let config = turboquant_config($bits); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) - .with_inputs(|| &fsl) + .with_inputs(|| &ext) .bench_refs(|a| turboquant_encode(a, &config).unwrap()); } } @@ -496,16 +504,16 @@ mod turboquant_benches { paste! 
{ #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] fn $name(bencher: Bencher) { - let fsl = setup_vector_fsl($dim); + let ext = setup_vector_ext($dim); let config = turboquant_config($bits); - let compressed = turboquant_encode(&fsl, &config).unwrap(); + let compressed = turboquant_encode(&ext, &config).unwrap(); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| &compressed) .bench_refs(|a| { let mut ctx = SESSION.create_execution_ctx(); a.clone() .into_array() - .execute::(&mut ctx) + .execute::(&mut ctx) .unwrap() }); } From 45131a54d727ab47118afe08ab12bd0284a30323 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 15:47:24 -0400 Subject: [PATCH 04/13] Refactor TurboQuant encoding API and tests - Use ExecutionCtx in TurboQuant compress path and import ExecutionCtx - Extend dtype imports with Nullability and PType to support extension types - Wire in extension utilities: extension_element_ptype and extension_list_size for vector extensions - Remove dimension and bit_width from slice/take compute calls to rely on metadata - Update TurboQuant mod docs to mention VortexSessionExecute - Change scheme.compress to use the provided compressor argument (not _compressor) - Add an extensive TurboQuant test suite (roundtrip, MSE bounds, edge cases, f64 input, serde roundtrip, and dtype checks) - Align vtable imports to new metadata handling (remove unused DeserializeMetadata/SerializeMetadata references) Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 40 +- .../src/encodings/turboquant/array.rs | 286 +++++--- .../src/encodings/turboquant/compress.rs | 73 +- .../src/encodings/turboquant/compute/slice.rs | 2 - .../src/encodings/turboquant/compute/take.rs | 2 - vortex-tensor/src/encodings/turboquant/mod.rs | 671 +---------------- .../src/encodings/turboquant/scheme.rs | 4 +- .../src/encodings/turboquant/tests.rs | 675 ++++++++++++++++++ 
.../src/encodings/turboquant/vtable.rs | 91 ++- vortex/benches/single_encoding_throughput.rs | 8 +- 10 files changed, 993 insertions(+), 859 deletions(-) create mode 100644 vortex-tensor/src/encodings/turboquant/tests.rs diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index fbbbe0ace6c..f455a994295 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -28,7 +28,7 @@ impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant: impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::compress(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult @@ -38,24 +38,6 @@ pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::scheme_na pub static vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME: vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme -pub struct vortex_tensor::encodings::turboquant::QjlCorrection - -impl vortex_tensor::encodings::turboquant::QjlCorrection - -pub fn vortex_tensor::encodings::turboquant::QjlCorrection::residual_norms(&self) -> &vortex_array::array::erased::ArrayRef - -pub 
fn vortex_tensor::encodings::turboquant::QjlCorrection::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef - -pub fn vortex_tensor::encodings::turboquant::QjlCorrection::signs(&self) -> &vortex_array::array::erased::ArrayRef - -impl core::clone::Clone for vortex_tensor::encodings::turboquant::QjlCorrection - -pub fn vortex_tensor::encodings::turboquant::QjlCorrection::clone(&self) -> vortex_tensor::encodings::turboquant::QjlCorrection - -impl core::fmt::Debug for vortex_tensor::encodings::turboquant::QjlCorrection - -pub fn vortex_tensor::encodings::turboquant::QjlCorrection::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - pub struct vortex_tensor::encodings::turboquant::TurboQuant impl vortex_tensor::encodings::turboquant::TurboQuant @@ -74,7 +56,7 @@ impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquan pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData -pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_array::metadata::ProstMetadata +pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::array::TurboQuantMetadata pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant @@ -166,19 +148,17 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::codes(&self) -> &vo pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32 -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::has_qjl(&self) -> bool +pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self pub fn 
vortex_tensor::encodings::turboquant::TurboQuantData::norms(&self) -> &vortex_array::array::erased::ArrayRef pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32 -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::qjl(&self) -> core::option::Option - pub fn vortex_tensor::encodings::turboquant::TurboQuantData::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new_mse(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef, dimension: u32, bit_width: u8) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new_qjl(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef, qjl: vortex_tensor::encodings::turboquant::QjlCorrection, dimension: u32, bit_width: u8) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()> impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData @@ -202,9 +182,7 @@ pub const 
vortex_tensor::encodings::turboquant::VECTOR_EXT_ID: &str pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession) -pub fn vortex_tensor::encodings::turboquant::turboquant_encode_mse(fsl: &vortex_array::arrays::fixed_size_list::vtable::FixedSizeListArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::turboquant_encode_qjl(fsl: &vortex_array::arrays::fixed_size_list::vtable::FixedSizeListArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: &vortex_array::arrays::extension::vtable::ExtensionArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub mod vortex_tensor::fixed_shape @@ -440,6 +418,12 @@ pub vortex_tensor::scalar_fns::ApproxOptions::Approximate pub vortex_tensor::scalar_fns::ApproxOptions::Exact +impl vortex_tensor::scalar_fns::ApproxOptions + +pub fn vortex_tensor::scalar_fns::ApproxOptions::is_approx(&self) -> bool + +pub fn vortex_tensor::scalar_fns::ApproxOptions::is_exact(&self) -> bool + impl core::clone::Clone for vortex_tensor::scalar_fns::ApproxOptions pub fn vortex_tensor::scalar_fns::ApproxOptions::clone(&self) -> vortex_tensor::scalar_fns::ApproxOptions diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array.rs index c782d66641c..e1eebce6e0f 100644 --- a/vortex-tensor/src/encodings/turboquant/array.rs +++ b/vortex-tensor/src/encodings/turboquant/array.rs @@ -7,12 +7,16 @@ use vortex_array::ArrayId; use vortex_array::ArrayRef; use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; use vortex_array::stats::ArrayStats; use vortex_array::vtable; use vortex_error::VortexExpect; use 
vortex_error::VortexResult; use vortex_error::vortex_ensure; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; use crate::vector::Vector; /// Encoding marker type for TurboQuant. @@ -25,15 +29,14 @@ impl TurboQuant { vtable!(TurboQuant, TurboQuant, TurboQuantData); -/// Protobuf metadata for TurboQuant encoding. -#[derive(Clone, prost::Message)] +/// Serialized metadata for TurboQuant encoding: a single byte holding the `bit_width` (0-8). +/// +/// All other fields (dimension, element type) are derived from the dtype and children. +/// A `bit_width` of 0 indicates a degenerate empty array. +#[derive(Clone, Debug)] pub struct TurboQuantMetadata { - /// Vector dimension d. - #[prost(uint32, tag = "1")] - pub dimension: u32, - /// MSE bits per coordinate (1-8). - #[prost(uint32, tag = "2")] - pub bit_width: u32, + /// MSE bits per coordinate (0 for degenerate empty arrays, 1-8 otherwise). + pub bit_width: u8, } /// Slot positions for TurboQuantArray children. @@ -69,105 +72,116 @@ impl Slot { } } -/// TurboQuant array. +/// TurboQuant array data. +/// +/// TurboQuant is a lossy vector quantization encoding for [`Vector`] extension arrays. +/// It stores quantized coordinate codes and per-vector norms, along with shared codebook +/// centroids and SRHT rotation signs. See the [module docs](super) for algorithmic details. /// -/// Slots: -/// - 0: `codes` -- `FixedSizeListArray` (quantized indices, list_size=padded_dim). -/// - 1: `norms` -- `PrimitiveArray` (one per vector row). -/// - 2: `centroids` -- `PrimitiveArray` (codebook, length 2^bit_width). -/// - 3: `rotation_signs` -- `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order). +/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. +/// +/// [`Vector`]: crate::vector::Vector #[derive(Clone, Debug)] pub struct TurboQuantData { + /// The [`Vector`] extension dtype that this array encodes. 
The storage dtype within the + /// extension determines the element type (f16, f32, or f64) and the list size (dimension). + /// + /// [`Vector`]: crate::vector::Vector pub(crate) dtype: DType, + + /// Child arrays stored as optional slots. See [`Slot`] for positions: + /// + /// - [`Codes`](Slot::Codes): `FixedSizeListArray` with `list_size == padded_dim`. Each row + /// holds one u8 centroid index per padded coordinate. The cascade compressor handles packing + /// to the actual `bit_width` on disk. The validity of the entire array is stored with this. + /// + /// - [`Norms`](Slot::Norms): Per-vector L2 norms, one per row. The dtype matches the element + /// type of the Vector (e.g., f64 norms for f64 vectors). Exact norms are stored during + /// compression, enabling O(1) L2 norm readthrough without decompression. + /// + /// - [`Centroids`](Slot::Centroids): `PrimitiveArray` codebook with `2^bit_width` entries + /// that is shared across all rows. We always store these as f32 regardless of the input + /// element type because quantization itself introduces far more error than f32 precision + /// loss, and f16 inputs can be upcast to f32 before quantization. + /// + /// - [`RotationSigns`](Slot::RotationSigns): `BitPackedArray` of `3 * padded_dim` 1-bit sign + /// values for the 3-round SRHT rotation, stored in inverse application order, and shared + /// across all rows. pub(crate) slots: Vec>, + + /// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size. + /// Stored as a convenience field to avoid repeatedly extracting it from `dtype`. + /// Non-power-of-2 dimensions are zero-padded to [`padded_dim`](Self::padded_dim) for the + /// Walsh-Hadamard transform. pub(crate) dimension: u32, + + /// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`. + /// Zero for degenerate empty arrays. pub(crate) bit_width: u8, + + /// The stats for this array. 
pub(crate) stats_set: ArrayStats, } impl TurboQuantData { /// Build a TurboQuant array with validation. /// - /// The `dtype` must be a [`Vector`] extension type. TurboQuant encodes the extension - /// type directly, not its `FixedSizeList` storage. + /// The `dimension` and `bit_width` are derived from the inputs: + /// - `dimension` from the `dtype`'s `FixedSizeList` storage list size. + /// - `bit_width` from `log2(centroids.len())` (0 for degenerate empty arrays). /// /// # Errors /// /// Returns an error if the provided components do not satisfy the invariants documented /// in [`new_unchecked`](Self::new_unchecked). - /// - /// [`Vector`]: crate::vector::Vector - #[allow(clippy::too_many_arguments)] pub fn try_new( dtype: DType, codes: ArrayRef, norms: ArrayRef, centroids: ArrayRef, rotation_signs: ArrayRef, - dimension: u32, - bit_width: u8, ) -> VortexResult { - Self::validate( - &dtype, - &codes, - &norms, - ¢roids, - &rotation_signs, - dimension, - bit_width, - )?; + Self::validate(&dtype, &codes, &norms, ¢roids, &rotation_signs)?; // SAFETY: we validate that the inputs are valid above. - Ok(unsafe { - Self::new_unchecked( - dtype, - codes, - norms, - centroids, - rotation_signs, - dimension, - bit_width, - ) - }) + Ok(unsafe { Self::new_unchecked(dtype, codes, norms, centroids, rotation_signs) }) } /// Build a TurboQuant array without validation. /// - /// * `dtype` must be a [`Vector`] extension type. - /// * `codes` must be a `FixedSizeListArray` with `list_size == padded_dim`. - /// * `norms` must be a `PrimitiveArray` with one element per row. - /// * `centroids` must be a `PrimitiveArray` with `2^bit_width` elements. - /// * `rotation_signs` must contain `3 * padded_dim` sign values. - /// * `bit_width` must be 1-8. - /// * `codes.len() == norms.len()`. - /// /// # Safety /// - /// The caller must ensure the inputs satisfy the invariants listed above. Violating them - /// may produce incorrect results during decompression. 
+ /// The caller must ensure: /// - /// [`Vector`]: crate::vector::Vector - #[allow(clippy::too_many_arguments)] + /// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size + /// is >= 3. + /// - `codes` is a `FixedSizeListArray` with `list_size == padded_dim` and + /// `codes.len() == norms.len()`. + /// - `norms` is a non-nullable primitive array whose ptype matches the element type of the + /// Vector's storage dtype. + /// - `centroids` is a non-nullable `PrimitiveArray` whose length is a power of 2 in + /// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays. + /// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays. + /// - For degenerate (empty) arrays: all children must be empty. + /// + /// Violating these invariants may produce incorrect results during decompression. pub unsafe fn new_unchecked( dtype: DType, codes: ArrayRef, norms: ArrayRef, centroids: ArrayRef, rotation_signs: ArrayRef, - dimension: u32, - bit_width: u8, ) -> Self { #[cfg(debug_assertions)] - Self::validate( - &dtype, - &codes, - &norms, - ¢roids, - &rotation_signs, - dimension, - bit_width, - ) - .vortex_expect("[Debug Assertion]: Invalid TurboQuantData parameters"); + Self::validate(&dtype, &codes, &norms, ¢roids, &rotation_signs) + .vortex_expect("[Debug Assertion]: Invalid TurboQuantData parameters"); + + let dimension = dtype + .as_extension_opt() + .and_then(|ext| extension_list_size(ext).ok()) + .vortex_expect("dtype must be a Vector extension type with FixedSizeList storage"); + + let bit_width = derive_bit_width(¢roids); let mut slots = vec![None; Slot::COUNT]; slots[Slot::Codes as usize] = Some(codes); @@ -186,73 +200,121 @@ impl TurboQuantData { /// Validates the components that would be used to create a `TurboQuantData`. /// /// This function checks all the invariants required by [`new_unchecked`](Self::new_unchecked). 
- #[allow(clippy::too_many_arguments)] pub fn validate( dtype: &DType, codes: &ArrayRef, norms: &ArrayRef, centroids: &ArrayRef, rotation_signs: &ArrayRef, - dimension: u32, - bit_width: u8, ) -> VortexResult<()> { + // Dtype must be a Vector extension type. + let ext = dtype + .as_extension_opt() + .filter(|e| e.is::()) + .ok_or_else(|| { + vortex_error::vortex_err!( + "TurboQuant dtype must be a Vector extension type, got {dtype}" + ) + })?; + + // Dimension is derived from the storage dtype's list size and must be >= 3. + let dimension = extension_list_size(ext)?; + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + let num_rows = norms.len(); + + // Degenerate (empty) case: all children must be empty, bit_width is 0. + if num_rows == 0 { + vortex_ensure!( + codes.is_empty(), + "degenerate TurboQuant must have empty codes, got length {}", + codes.len() + ); + vortex_ensure!( + centroids.is_empty(), + "degenerate TurboQuant must have empty centroids, got length {}", + centroids.len() + ); + vortex_ensure!( + rotation_signs.is_empty(), + "degenerate TurboQuant must have empty rotation_signs, got length {}", + rotation_signs.len() + ); + return Ok(()); + } + + // Non-degenerate: derive and validate bit_width from centroids. + let num_centroids = centroids.len(); + vortex_ensure!( + num_centroids.is_power_of_two() && (2..=256).contains(&num_centroids), + "centroids length must be a power of 2 in [2, 256], got {num_centroids}" + ); + + // Guaranteed to be 1-8 by the preceding power-of-2 and range checks. + #[allow(clippy::cast_possible_truncation)] + let bit_width = num_centroids.trailing_zeros() as u8; vortex_ensure!( (1..=8).contains(&bit_width), - "bit_width must be 1-8, got {bit_width}" + "derived bit_width must be 1-8, got {bit_width}" ); + + // Norms dtype must match the element ptype of the Vector. 
+ let element_ptype = extension_element_ptype(ext)?; + let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); vortex_ensure!( - dtype - .as_extension_opt() - .is_some_and(|ext| ext.is::()), - "TurboQuant dtype must be a Vector extension type, got {dtype}" + *norms.dtype() == expected_norms_dtype, + "norms dtype {} does not match expected {expected_norms_dtype} \ + (must match Vector element type)", + norms.dtype() ); + + // Centroids are always f32 regardless of element type. + let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" + *centroids.dtype() == f32_nn, + "centroids dtype {} must be non-nullable f32", + centroids.dtype() ); - let num_rows = norms.len(); + // Row count consistency. vortex_ensure!( codes.len() == num_rows, "codes length {} does not match norms length {num_rows}", codes.len() ); - let expected_centroids = 1usize << bit_width; - // Allow empty centroids for zero-row arrays. - if num_rows > 0 { - vortex_ensure!( - centroids.len() == expected_centroids, - "centroids length {} does not match expected 2^{bit_width} = {expected_centroids}", - centroids.len() - ); - } - + // Rotation signs count must be 3 * padded_dim. let padded_dim = dimension.next_power_of_two() as usize; - // Allow empty rotation signs for zero-row arrays. - if num_rows > 0 { - vortex_ensure!( - rotation_signs.len() == 3 * padded_dim, - "rotation_signs length {} does not match expected 3 * {padded_dim} = {}", - rotation_signs.len(), - 3 * padded_dim - ); - } + vortex_ensure!( + rotation_signs.len() == 3 * padded_dim, + "rotation_signs length {} does not match expected 3 * {padded_dim} = {}", + rotation_signs.len(), + 3 * padded_dim + ); Ok(()) } - /// The vector dimension d. + /// The vector dimension `d`, as stored in the [`Vector`] extension dtype's + /// `FixedSizeList` storage. 
+ /// + /// [`Vector`]: crate::vector::Vector pub fn dimension(&self) -> u32 { self.dimension } - /// MSE bits per coordinate. + /// MSE bits per coordinate (1-8 for non-empty arrays, 0 for degenerate empty arrays). pub fn bit_width(&self) -> u8 { self.bit_width } - /// Padded dimension (next power of 2 >= dimension). + /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)). + /// + /// The SRHT rotation requires power-of-2 input, so non-power-of-2 dimensions are + /// zero-padded to this value. pub fn padded_dim(&self) -> u32 { self.dimension.next_power_of_two() } @@ -263,23 +325,43 @@ impl TurboQuantData { .vortex_expect("required slot is None") } - /// The quantized codes child (FixedSizeListArray). + /// The quantized codes child (`FixedSizeListArray`, one row per vector). pub fn codes(&self) -> &ArrayRef { self.slot(Slot::Codes as usize) } - /// The norms child (`PrimitiveArray`). + /// Per-vector L2 norms. The dtype matches the Vector's element type (f16, f32, or f64). pub fn norms(&self) -> &ArrayRef { self.slot(Slot::Norms as usize) } - /// The centroids (codebook) child (`PrimitiveArray`). + /// The codebook centroids (`PrimitiveArray`, length `2^bit_width`). + /// + /// Always f32 regardless of input element type: quantization noise dominates f32 + /// precision loss, and f16 inputs are upcast before quantization anyway. pub fn centroids(&self) -> &ArrayRef { self.slot(Slot::Centroids as usize) } - /// The MSE rotation signs child (BitPackedArray, length 3 * padded_dim). + /// The SRHT rotation signs (`BitPackedArray`, `3 * padded_dim` 1-bit values). + /// + /// Stored in inverse application order for efficient decode. pub fn rotation_signs(&self) -> &ArrayRef { self.slot(Slot::RotationSigns as usize) } } + +/// Derive `bit_width` from the centroids array length. +/// +/// Returns 0 for empty centroids (degenerate array), otherwise `log2(centroids.len())`. 
+fn derive_bit_width(centroids: &ArrayRef) -> u8 { + if centroids.is_empty() { + 0 + } else { + // Guaranteed to be 0-8 by validate(). + #[allow(clippy::cast_possible_truncation)] + { + centroids.len().trailing_zeros() as u8 + } + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index ff833e05cc2..8f5b02222f8 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -4,6 +4,7 @@ //! TurboQuant encoding (quantization) logic. use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; @@ -11,6 +12,7 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; @@ -23,6 +25,8 @@ use crate::encodings::turboquant::centroids::compute_boundaries; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; use crate::encodings::turboquant::rotation::RotationMatrix; +use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::l2_norm::L2Norm; /// Configuration for TurboQuant encoding. #[derive(Clone, Debug)] @@ -42,7 +46,9 @@ impl Default for TurboQuantConfig { } } -/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray. +/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization. +/// +/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated. 
#[allow(clippy::cast_possible_truncation)] fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult { let elements = fsl.elements(); @@ -65,31 +71,46 @@ fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult f32 { - x.iter().map(|&v| v * v).sum::().sqrt() -} - -/// Shared intermediate results from the MSE quantization loop. +/// Shared intermediate results from the quantization loop. struct QuantizationResult { rotation: RotationMatrix, centroids: Vec, all_indices: BufferMut, - norms: BufferMut, + /// Native-precision norms (matching the Vector element type). + norms_array: ArrayRef, padded_dim: usize, } -/// Core quantization: extract f32 elements, build rotation, normalize/rotate/quantize all rows. +/// Core quantization: compute norms via [`L2Norm`], extract f32 elements, then +/// normalize/rotate/quantize all rows. +/// +/// Norms are computed in the native element precision via the [`L2Norm`] scalar function. +/// The rotation and centroid lookup happen in f32. #[allow(clippy::cast_possible_truncation)] fn turboquant_quantize_core( + ext: &ExtensionArray, fsl: &FixedSizeListArray, seed: u64, bit_width: u8, + ctx: &mut ExecutionCtx, ) -> VortexResult { let dimension = fsl.list_size() as usize; let num_rows = fsl.len(); + // Compute native-precision norms via the L2Norm scalar fn. + let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?; + let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; + let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive(); + + // Extract f32 norms for the internal quantization loop. 
+ let f32_norms: Vec = match_each_float_ptype!(norms_prim.ptype(), |T| { + norms_prim + .as_slice::() + .iter() + .map(|&v| num_traits::ToPrimitive::to_f32(&v).unwrap_or(0.0)) + .collect() + }); + let rotation = RotationMatrix::try_new(seed, dimension)?; let padded_dim = rotation.padded_dim(); @@ -99,15 +120,13 @@ fn turboquant_quantize_core( let boundaries = compute_boundaries(¢roids); let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); - let mut norms = BufferMut::::with_capacity(num_rows); let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; let f32_slice = f32_elements.as_slice::(); for row in 0..num_rows { let x = &f32_slice[row * dimension..(row + 1) * dimension]; - let norm = l2_norm(x); - norms.push(norm); + let norm = f32_norms[row]; if norm > 0.0 { let inv_norm = 1.0 / norm; @@ -128,7 +147,7 @@ fn turboquant_quantize_core( rotation, centroids, all_indices, - norms, + norms_array, padded_dim, }) } @@ -138,11 +157,8 @@ fn turboquant_quantize_core( fn build_turboquant( fsl: &FixedSizeListArray, core: QuantizationResult, - bit_width: u8, ext_dtype: DType, ) -> VortexResult { - let dimension = fsl.list_size(); - let num_rows = fsl.len(); let padded_dim = core.padded_dim; let codes_elements = @@ -154,8 +170,6 @@ fn build_turboquant( num_rows, )? .into_array(); - let norms_array = - PrimitiveArray::new::(core.norms.freeze(), Validity::NonNullable).into_array(); // TODO(perf): `get_centroids` returns Vec; could avoid the copy by // supporting Buffer::from(Vec) or caching as Buffer directly. 
@@ -169,11 +183,9 @@ fn build_turboquant( TurboQuantData::try_new( ext_dtype, codes, - norms_array, + core.norms_array, centroids_array, rotation_signs, - dimension, - bit_width, ) } @@ -186,6 +198,7 @@ fn build_turboquant( pub fn turboquant_encode( ext: &ExtensionArray, config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, ) -> VortexResult { let ext_dtype = ext.dtype().clone(); let storage = ext.storage_array(); @@ -214,25 +227,29 @@ pub fn turboquant_encode( Validity::NonNullable, 0, )?; - let empty_norms = PrimitiveArray::empty::(Nullability::NonNullable); + + // Norms dtype matches the element type. + let element_ptype = fsl.elements().dtype().as_ptype(); + let empty_norms: ArrayRef = match_each_float_ptype!(element_ptype, |T| { + PrimitiveArray::empty::(Nullability::NonNullable).into_array() + }); + let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); let empty_signs = PrimitiveArray::empty::(Nullability::NonNullable); return Ok(TurboQuantData::try_new( ext_dtype, empty_codes.into_array(), - empty_norms.into_array(), + empty_norms, empty_centroids.into_array(), empty_signs.into_array(), - dimension, - config.bit_width, )? .into_array()); } let seed = config.seed.unwrap_or(42); - let core = turboquant_quantize_core(&fsl, seed, config.bit_width)?; + let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, ctx)?; - Ok(build_turboquant(&fsl, core, config.bit_width, ext_dtype)?.into_array()) + Ok(build_turboquant(&fsl, core, ext_dtype)?.into_array()) } /// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. 
diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs index 19e1a9e0f91..86768b949d6 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/slice.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -26,8 +26,6 @@ impl SliceReduce for TurboQuant { sliced_norms, array.centroids().clone(), array.rotation_signs().clone(), - array.dimension, - array.bit_width, )?; Ok(Some(result.into_array())) diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs index 638b493d3a6..e3b5e866e57 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/take.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -27,8 +27,6 @@ impl TakeExecute for TurboQuant { taken_norms, array.centroids().clone(), array.rotation_signs().clone(), - array.dimension, - array.bit_width, )?; Ok(Some(result.into_array())) diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 94a3e24bff1..760dce16f2e 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -52,6 +52,7 @@ //! //! ``` //! use vortex_array::IntoArray; +//! use vortex_array::VortexSessionExecute; //! use vortex_array::arrays::ExtensionArray; //! use vortex_array::arrays::FixedSizeListArray; //! use vortex_array::arrays::PrimitiveArray; @@ -59,6 +60,8 @@ //! use vortex_array::extension::EmptyMetadata; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; +//! use vortex_array::session::ArraySession; +//! use vortex_session::VortexSession; //! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; //! use vortex_tensor::vector::Vector; //! @@ -79,7 +82,9 @@ //! //! // Quantize at 2 bits per coordinate. //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; -//! 
let encoded = turboquant_encode(&ext, &config).unwrap(); +//! let session = VortexSession::empty().with::(); +//! let mut ctx = session.create_execution_ctx(); +//! let encoded = turboquant_encode(&ext, &config, &mut ctx).unwrap(); //! //! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! assert!(encoded.nbytes() < 51200); @@ -114,666 +119,4 @@ pub fn initialize(session: &mut VortexSession) { } #[cfg(test)] -#[allow(clippy::cast_possible_truncation)] -mod tests { - use std::sync::LazyLock; - - use rand::SeedableRng; - use rand::rngs::StdRng; - use rand_distr::Distribution; - use rand_distr::Normal; - use rstest::rstest; - use vortex_array::ArrayRef; - use vortex_array::IntoArray; - use vortex_array::VortexSessionExecute; - use vortex_array::arrays::ExtensionArray; - use vortex_array::arrays::FixedSizeListArray; - use vortex_array::arrays::PrimitiveArray; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::session::ArraySession; - use vortex_array::validity::Validity; - use vortex_buffer::BufferMut; - use vortex_error::VortexResult; - use vortex_session::VortexSession; - - use crate::encodings::turboquant::TurboQuant; - use crate::encodings::turboquant::TurboQuantConfig; - use crate::encodings::turboquant::rotation::RotationMatrix; - use crate::encodings::turboquant::turboquant_encode; - use crate::vector::Vector; - - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - - /// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). 
- fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { - let mut rng = StdRng::seed_from_u64(seed); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - num_rows, - ) - .unwrap() - } - - /// Wrap a `FixedSizeListArray` in a `Vector` extension array. - fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.clone().into_array()) - } - - fn theoretical_mse_bound(bit_width: u8) -> f32 { - let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; - sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) - } - - fn per_vector_normalized_mse( - original: &[f32], - reconstructed: &[f32], - dim: usize, - num_rows: usize, - ) -> f32 { - let mut total = 0.0f32; - for row in 0..num_rows { - let orig = &original[row * dim..(row + 1) * dim]; - let recon = &reconstructed[row * dim..(row + 1) * dim]; - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - if norm_sq < 1e-10 { - continue; - } - let err_sq: f32 = orig - .iter() - .zip(recon.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - total += err_sq / norm_sq; - } - total / num_rows as f32 - } - - /// Encode and decode, returning (original, decoded) flat f32 slices. 
- fn encode_decode( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, - ) -> VortexResult<(Vec, Vec)> { - let original: Vec = { - let prim = fsl.elements().to_canonical().unwrap().into_primitive(); - prim.as_slice::().to_vec() - }; - let ext = make_vector_ext(fsl); - let config = config.clone(); - let encoded = turboquant_encode(&ext, &config)?; - let mut ctx = SESSION.create_execution_ctx(); - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .to_canonical() - .unwrap() - .into_fixed_size_list(); - let decoded_elements: Vec = { - let prim = decoded_fsl - .elements() - .to_canonical() - .unwrap() - .into_primitive(); - prim.as_slice::().to_vec() - }; - Ok((original, decoded_elements)) - } - - // ----------------------------------------------------------------------- - // Roundtrip tests - // ----------------------------------------------------------------------- - - #[rstest] - #[case(32, 1)] - #[case(32, 2)] - #[case(32, 3)] - #[case(32, 4)] - #[case(128, 2)] - #[case(128, 4)] - #[case(128, 6)] - #[case(128, 8)] - #[case(256, 2)] - fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let fsl = make_fsl(10, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - assert_eq!(decoded.len(), original.len()); - Ok(()) - } - - // ----------------------------------------------------------------------- - // MSE quality tests - // ----------------------------------------------------------------------- - - #[rstest] - #[case(128, 1)] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 4)] - #[case(256, 2)] - #[case(256, 4)] - fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - }; - let (original, decoded) = encode_decode(&fsl, 
&config)?; - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - let bound = theoretical_mse_bound(bit_width); - - assert!( - normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ - for dim={dim}, bits={bit_width}", - ); - Ok(()) - } - - #[rstest] - #[case(128, 6)] - #[case(128, 8)] - #[case(256, 6)] - #[case(256, 8)] - fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - - let config_4bit = TurboQuantConfig { - bit_width: 4, - seed: Some(123), - }; - let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; - let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); - - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - - assert!( - mse < mse_4bit, - "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" - ); - assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); - Ok(()) - } - - #[test] - fn mse_decreases_with_bits() -> VortexResult<()> { - let dim = 128; - let num_rows = 50; - let fsl = make_fsl(num_rows, dim, 99); - - let mut prev_mse = f32::MAX; - for bit_width in 1..=8u8 { - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - assert!( - mse <= prev_mse * 1.01, - "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" - ); - prev_mse = mse; - } - Ok(()) - } - - // ----------------------------------------------------------------------- - // Edge cases - // ----------------------------------------------------------------------- - - #[rstest] - #[case(0)] - #[case(1)] - fn roundtrip_edge_cases(#[case] num_rows: 
usize) -> VortexResult<()> { - let fsl = make_fsl(num_rows, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let mut ctx = SESSION.create_execution_ctx(); - let decoded = encoded.execute::(&mut ctx)?; - assert_eq!(decoded.len(), num_rows); - Ok(()) - } - - #[rstest] - #[case(1)] - #[case(2)] - fn rejects_dimension_below_3(#[case] dim: usize) { - let fsl = make_fsl_small(dim); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(0), - }; - assert!(turboquant_encode(&ext, &config).is_err()); - } - - fn make_fsl_small(dim: usize) -> FixedSizeListArray { - let mut buf = BufferMut::::with_capacity(dim); - for i in 0..dim { - buf.push(i as f32 + 1.0); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new(elements.into_array(), dim as u32, Validity::NonNullable, 1) - .unwrap() - } - - /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). - #[test] - fn all_zero_vectors_roundtrip() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let buf = BufferMut::::full(0.0f32, num_rows * dim); - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - num_rows, - )?; - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - // All-zero vectors should decode to all-zero (norm=0 -> 0 * anything = 0). - for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { - assert_eq!(o, 0.0, "original[{i}] not zero"); - assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); - } - Ok(()) - } - - /// Verify that f64 input is accepted and encoded (converted to f32 internally). 
- #[test] - fn f64_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 64; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - }; - // Verify encoding succeeds with f64 input (f64->f32 conversion). - let encoded = turboquant_encode(&ext, &config)?; - let encoded = encoded.as_opt::().unwrap(); - assert_eq!(encoded.norms().len(), num_rows); - assert_eq!(encoded.dimension(), dim as u32); - Ok(()) - } - - // ----------------------------------------------------------------------- - // Verification tests for stored metadata - // ----------------------------------------------------------------------- - - /// Verify that the centroids stored in the array match what `get_centroids()` computes. 
- #[test] - fn stored_centroids_match_computed() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let encoded = encoded.as_opt::().unwrap(); - - let mut ctx = SESSION.create_execution_ctx(); - let stored_centroids_prim = encoded - .centroids() - .clone() - .execute::(&mut ctx)?; - let stored = stored_centroids_prim.as_slice::(); - - let padded_dim = encoded.padded_dim(); - let computed = crate::encodings::turboquant::centroids::get_centroids(padded_dim, 3)?; - - assert_eq!(stored.len(), computed.len()); - for i in 0..stored.len() { - assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); - } - Ok(()) - } - - /// Verify that stored rotation signs produce identical decode to seed-based decode. - #[test] - fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let encoded = encoded.as_opt::().unwrap(); - - // Decode via the stored-signs path (normal decode). - let mut ctx = SESSION.create_execution_ctx(); - let decoded_ext = encoded - .array() - .clone() - .execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .to_canonical()? - .into_fixed_size_list(); - let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); - let decoded_slice = decoded.as_slice::(); - - // Verify stored signs match seed-derived signs. 
- let rot_from_seed = RotationMatrix::try_new(123, 128)?; - let expected_u8 = rot_from_seed.export_inverse_signs_u8(); - let stored_signs = encoded - .rotation_signs() - .clone() - .execute::(&mut ctx)?; - let stored_u8 = stored_signs.as_slice::(); - - assert_eq!(expected_u8.len(), stored_u8.len()); - for i in 0..expected_u8.len() { - assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); - } - - // Also verify decode output is non-empty and has expected size. - assert_eq!(decoded_slice.len(), 20 * 128); - Ok(()) - } - - // ----------------------------------------------------------------------- - // Serde roundtrip - // ----------------------------------------------------------------------- - - #[test] - fn serde_roundtrip() -> VortexResult<()> { - use vortex_array::vtable::VTable; - - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let encoded = encoded.as_opt::().unwrap(); - - // Serialize metadata. - let metadata = ::metadata(encoded)?; - let serialized = - ::serialize(metadata)?.expect("metadata should serialize"); - - // Collect children. - let nchildren = ::nchildren(encoded); - assert_eq!(nchildren, 4); - let children: Vec = (0..nchildren) - .map(|i| ::child(encoded, i)) - .collect(); - - // Deserialize and rebuild. - let deserialized = ::deserialize( - &serialized, - encoded.dtype(), - encoded.len(), - &[], - &SESSION, - )?; - - // Verify metadata fields survived roundtrip. - assert_eq!(deserialized.dimension, encoded.dimension()); - assert_eq!(deserialized.bit_width, encoded.bit_width() as u32); - - // Verify the rebuilt array decodes identically. - let mut ctx = SESSION.create_execution_ctx(); - let decoded_original = encoded - .array() - .clone() - .execute::(&mut ctx)?; - let original_fsl = decoded_original - .storage_array() - .to_canonical()? 
- .into_fixed_size_list(); - let original_elements = original_fsl.elements().to_canonical()?.into_primitive(); - - // Rebuild from children (simulating deserialization). - let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new( - encoded.dtype().clone(), - children[0].clone(), - children[1].clone(), - children[2].clone(), - children[3].clone(), - deserialized.dimension, - deserialized.bit_width as u8, - )?; - let decoded_rebuilt = rebuilt.into_array().execute::(&mut ctx)?; - let rebuilt_fsl = decoded_rebuilt - .storage_array() - .to_canonical()? - .into_fixed_size_list(); - let rebuilt_elements = rebuilt_fsl.elements().to_canonical()?.into_primitive(); - - assert_eq!( - original_elements.as_slice::(), - rebuilt_elements.as_slice::() - ); - Ok(()) - } - - // ----------------------------------------------------------------------- - // Compute pushdown tests - // ----------------------------------------------------------------------- - - #[test] - fn slice_preserves_data() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - - // Full decompress then slice. - let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - let full_fsl = full_decoded - .storage_array() - .to_canonical()? - .into_fixed_size_list(); - let expected = full_fsl.slice(5..10)?; - let expected_prim = expected.to_canonical()?.into_fixed_size_list(); - let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); - - // Slice then decompress. - let sliced = encoded.slice(5..10)?; - let sliced_decoded = sliced.execute::(&mut ctx)?; - let sliced_fsl = sliced_decoded - .storage_array() - .to_canonical()? 
- .into_fixed_size_list(); - let actual_elements = sliced_fsl.elements().to_canonical()?.into_primitive(); - - assert_eq!( - expected_elements.as_slice::(), - actual_elements.as_slice::() - ); - Ok(()) - } - - #[test] - fn scalar_at_matches_decompress() -> VortexResult<()> { - let fsl = make_fsl(10, 64, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - - let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - - for i in [0, 1, 5, 9] { - let expected = full_decoded.scalar_at(i)?; - let actual = encoded.scalar_at(i)?; - assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); - } - Ok(()) - } - - #[test] - fn l2_norm_readthrough() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let tq = encoded.as_opt::().unwrap(); - - // Stored norms should match the actual L2 norms of the input. 
- let norms_prim = tq.norms().to_canonical()?.into_primitive(); - let stored_norms = norms_prim.as_slice::(); - - let input_prim = fsl.elements().to_canonical()?.into_primitive(); - let input_f32 = input_prim.as_slice::(); - for row in 0..10 { - let vec = &input_f32[row * 128..(row + 1) * 128]; - let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); - assert!( - (stored_norms[row] - actual_norm).abs() < 1e-5, - "norm mismatch at row {row}: stored={}, actual={}", - stored_norms[row], - actual_norm - ); - } - Ok(()) - } - - #[test] - fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 4, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - let tq = encoded.as_opt::().unwrap(); - - // Compute exact cosine similarity from original data. - let input_prim = fsl.elements().to_canonical()?.into_primitive(); - let input_f32 = input_prim.as_slice::(); - - // Read quantized codes, norms, and centroids for approximate computation. 
- let mut ctx = SESSION.create_execution_ctx(); - let pd = tq.padded_dim() as usize; - let norms_prim = tq.norms().clone().execute::(&mut ctx)?; - let norms = norms_prim.as_slice::(); - let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; - let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); - let all_codes = codes_prim.as_slice::(); - let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; - let centroid_vals = centroids_prim.as_slice::(); - - for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { - let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; - let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; - - let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); - let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); - let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); - let exact_cos = dot / (norm_a * norm_b); - - // Approximate cosine similarity in quantized domain. - let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { - 0.0 - } else { - let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; - let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; - codes_a - .iter() - .zip(codes_b.iter()) - .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) - .sum::() - }; - - // 4-bit quantization: expect reasonable accuracy. - let error = (exact_cos - approx_cos).abs(); - assert!( - error < 0.15, - "cosine similarity error too large for ({row_a}, {row_b}): \ - exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" - ); - } - Ok(()) - } - - /// Verify that the encoded array's dtype is a Vector extension type. - #[test] - fn encoded_dtype_is_vector_extension() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let encoded = turboquant_encode(&ext, &config)?; - - // The encoded TurboQuant array should claim a Vector extension dtype. 
- assert!( - encoded.dtype().is_extension(), - "TurboQuant dtype should be an extension type, got {}", - encoded.dtype() - ); - assert!( - encoded.dtype().as_extension().is::(), - "TurboQuant dtype should be a Vector extension type" - ); - Ok(()) - } -} +mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index e357d8a4005..de5fc8ee88a 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -72,7 +72,7 @@ impl Scheme for TurboQuantScheme { fn compress( &self, - _compressor: &CascadingCompressor, + compressor: &CascadingCompressor, data: &mut ArrayAndStats, _ctx: CompressorContext, ) -> VortexResult { @@ -80,7 +80,7 @@ impl Scheme for TurboQuantScheme { let ext_array = array.to_canonical()?.into_extension(); let config = TurboQuantConfig::default(); - turboquant_encode(&ext_array, &config) + turboquant_encode(&ext_array, &config, &mut compressor.execution_ctx()) } } diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs new file mode 100644 index 00000000000..5542e4d8278 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -0,0 +1,675 @@ +use std::sync::LazyLock; + +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand_distr::Distribution; +use rand_distr::Normal; +use rstest::rstest; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::TurboQuant; +use 
crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::rotation::RotationMatrix; +use crate::encodings::turboquant::turboquant_encode; +use crate::vector::Vector; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). +fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + ) + .unwrap() +} + +/// Wrap a `FixedSizeListArray` in a `Vector` extension array. +fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.clone().into_array()) +} + +fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) +} + +fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, +) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 +} + +/// Encode and decode, returning (original, 
decoded) flat f32 slices. +fn encode_decode( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let ext = make_vector_ext(fsl); + let config = config.clone(); + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let decoded_ext = encoded.execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical() + .unwrap() + .into_fixed_size_list(); + let decoded_elements: Vec = { + let prim = decoded_fsl + .elements() + .to_canonical() + .unwrap() + .into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) +} + +// ----------------------------------------------------------------------- +// Roundtrip tests +// ----------------------------------------------------------------------- + +#[rstest] +#[case(32, 1)] +#[case(32, 2)] +#[case(32, 3)] +#[case(32, 4)] +#[case(128, 2)] +#[case(128, 4)] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 2)] +fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) +} + +// ----------------------------------------------------------------------- +// MSE quality tests +// ----------------------------------------------------------------------- + +#[rstest] +#[case(128, 1)] +#[case(128, 2)] +#[case(128, 3)] +#[case(128, 4)] +#[case(256, 2)] +#[case(256, 4)] +fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, 
&config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) +} + +#[rstest] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 6)] +#[case(256, 8)] +fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) +} + +#[test] +fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) +} + +// ----------------------------------------------------------------------- +// Edge cases +// ----------------------------------------------------------------------- + +#[rstest] +#[case(0)] +#[case(1)] +fn roundtrip_edge_cases(#[case] num_rows: usize) -> 
VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) +} + +#[rstest] +#[case(1)] +#[case(2)] +fn rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + let mut ctx = SESSION.create_execution_ctx(); + assert!(turboquant_encode(&ext, &config, &mut ctx).is_err()); +} + +fn make_fsl_small(dim: usize) -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(dim); + for i in 0..dim { + buf.push(i as f32 + 1.0); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + 1, + ) + .unwrap() +} + +/// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). +#[test] +fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 -> 0 * anything = 0). 
+ for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) +} + +/// Verify that f64 input is accepted and encoded (converted to f32 internally). +#[test] +fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + )?; + + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64->f32 conversion). + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension() as usize, dim); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Verification tests for stored metadata +// ----------------------------------------------------------------------- + +/// Verify that the centroids stored in the array match what `get_centroids()` computes. 
+#[test] +fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::encodings::turboquant::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) +} + +/// Verify that stored rotation signs produce identical decode to seed-based decode. +#[test] +fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + + // Decode via the stored-signs path (normal decode). + let mut ctx = SESSION.create_execution_ctx(); + let decoded_ext = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. 
+ let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let expected_u8 = rot_from_seed.export_inverse_signs_u8(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + let stored_u8 = stored_signs.as_slice::(); + + assert_eq!(expected_u8.len(), stored_u8.len()); + for i in 0..expected_u8.len() { + assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Serde roundtrip +// ----------------------------------------------------------------------- + +#[test] +fn serde_roundtrip() -> VortexResult<()> { + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children. + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 4); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize and rebuild. + let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + // Verify metadata fields survived roundtrip. + assert_eq!(deserialized.bit_width, encoded.bit_width()); + + // Verify the rebuilt array decodes identically. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let original_fsl = decoded_original + .storage_array() + .to_canonical()? 
+ .into_fixed_size_list(); + let original_elements = original_fsl.elements().to_canonical()?.into_primitive(); + + // Rebuild from children (simulating deserialization). + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + )?; + let decoded_rebuilt = rebuilt.into_array().execute::(&mut ctx)?; + let rebuilt_fsl = decoded_rebuilt + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let rebuilt_elements = rebuilt_fsl.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Compute pushdown tests +// ----------------------------------------------------------------------- + +#[test] +fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let full_fsl = full_decoded + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let expected = full_fsl.slice(5..10)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + // Slice then decompress. + let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let sliced_fsl = sliced_decoded + .storage_array() + .to_canonical()? 
+ .into_fixed_size_list(); + let actual_elements = sliced_fsl.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) +} + +#[test] +fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 64, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) +} + +#[test] +fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + // Stored norms should match the actual L2 norms of the input. 
+ let norms_prim = tq.norms().to_canonical()?.into_primitive(); + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + Ok(()) +} + +#[test] +fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + // Compute exact cosine similarity from original data. + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + + // Read quantized codes, norms, and centroids for approximate computation. 
+ let mut ctx = SESSION.create_execution_ctx(); + let pd = tq.padded_dim() as usize; + let norms_prim = tq.norms().clone().execute::(&mut ctx)?; + let norms = norms_prim.as_slice::(); + let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let all_codes = codes_prim.as_slice::(); + let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; + let centroid_vals = centroids_prim.as_slice::(); + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + // Approximate cosine similarity in quantized domain. + let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { + 0.0 + } else { + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum::() + }; + + // 4-bit quantization: expect reasonable accuracy. + let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) +} + +/// Verify that the encoded array's dtype is a Vector extension type. 
+#[test] +fn encoded_dtype_is_vector_extension() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // The encoded TurboQuant array should claim a Vector extension dtype. + assert!( + encoded.dtype().is_extension(), + "TurboQuant dtype should be an extension type, got {}", + encoded.dtype() + ); + assert!( + encoded.dtype().as_extension().is::(), + "TurboQuant dtype should be a Vector extension type" + ); + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index b3fc365e8db..0d8d4ce8e1d 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -12,12 +12,9 @@ use vortex_array::ArrayHash; use vortex_array::ArrayId; use vortex_array::ArrayRef; use vortex_array::ArrayView; -use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; use vortex_array::ExecutionResult; use vortex_array::Precision; -use vortex_array::ProstMetadata; -use vortex_array::SerializeMetadata; use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -29,6 +26,7 @@ use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use vortex_error::vortex_panic; use vortex_session::VortexSession; @@ -36,11 +34,15 @@ use crate::encodings::turboquant::array::Slot; use crate::encodings::turboquant::array::TurboQuant; use crate::encodings::turboquant::array::TurboQuantData; use crate::encodings::turboquant::array::TurboQuantMetadata; +use crate::encodings::turboquant::compute::rules::PARENT_KERNELS; +use 
crate::encodings::turboquant::compute::rules::RULES; use crate::encodings::turboquant::decompress::execute_decompress; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; impl VTable for TurboQuant { type ArrayData = TurboQuantData; - type Metadata = ProstMetadata; + type Metadata = TurboQuantMetadata; type OperationsVTable = TurboQuant; type ValidityVTable = ValidityVTableFromChild; @@ -128,14 +130,13 @@ impl VTable for TurboQuant { } fn metadata(array: ArrayView) -> VortexResult { - Ok(ProstMetadata(TurboQuantMetadata { - dimension: array.dimension, - bit_width: array.bit_width as u32, - })) + Ok(TurboQuantMetadata { + bit_width: array.bit_width, + }) } fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(metadata.serialize())) + Ok(Some(vec![metadata.bit_width])) } fn deserialize( @@ -145,9 +146,21 @@ impl VTable for TurboQuant { _buffers: &[BufferHandle], _session: &VortexSession, ) -> VortexResult { - Ok(ProstMetadata( - as DeserializeMetadata>::deserialize(bytes)?, - )) + vortex_ensure_eq!( + bytes.len(), + 1, + "TurboQuant metadata must be exactly 1 byte, got {}", + bytes.len() + ); + vortex_ensure!( + bytes[0] <= 8, + "bit_width is expected to be between 0 and 8, got {}", + bytes[0] + ); + + Ok(TurboQuantMetadata { + bit_width: bytes[0], + }) } #[allow(clippy::cast_possible_truncation)] @@ -158,31 +171,52 @@ impl VTable for TurboQuant { _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { - let bit_width = u8::try_from(metadata.bit_width)?; - let padded_dim = metadata.dimension.next_power_of_two() as usize; - let num_centroids = 1usize << bit_width; + let bit_width = metadata.bit_width; - let u8_nn = DType::Primitive(PType::U8, Nullability::NonNullable); - let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); - let codes_dtype = - DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); - let codes = children.get(0, &codes_dtype, len)?; + // 
Derive dimension and element ptype from the Vector extension dtype. + let ext = dtype.as_extension(); + let dimension = extension_list_size(ext)?; - let norms = children.get(1, &f32_nn, len)?; - let centroids = children.get(2, &f32_nn, num_centroids)?; + let element_ptype = extension_element_ptype(ext)?; + let element_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let padded_dim = dimension.next_power_of_two() as usize; + + // Get the codes array (indices into the codebook). + let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable); + let codes_dtype = DType::FixedSizeList( + Arc::new(codes_ptype), + padded_dim as u32, + dtype.nullability(), + ); + let codes_array = children.get(0, &codes_dtype, len)?; + + // Get the L2 norms array. + let norms_array = children.get(1, &element_dtype, len)?; + + // Get the centroids array (codebook). + let num_centroids = if bit_width == 0 { + 0 // A degenerate TQ array. + } else { + 1usize << bit_width + }; + let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let centroids = children.get(2, ¢roids_dtype, num_centroids)?; + + // Get the rotation array. 
+ let signs_len = if len == 0 { 0 } else { 3 * padded_dim }; let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); - let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; + let rotation_signs = children.get(3, &signs_dtype, signs_len)?; Ok(TurboQuantData { dtype: dtype.clone(), slots: vec![ - Some(codes), - Some(norms), + Some(codes_array), + Some(norms_array), Some(centroids), Some(rotation_signs), ], - dimension: metadata.dimension, + dimension, bit_width, stats_set: Default::default(), }) @@ -193,7 +227,7 @@ impl VTable for TurboQuant { parent: &ArrayRef, child_idx: usize, ) -> VortexResult> { - crate::encodings::turboquant::compute::rules::RULES.evaluate(array, parent, child_idx) + RULES.evaluate(array, parent, child_idx) } fn execute_parent( @@ -202,8 +236,7 @@ impl VTable for TurboQuant { child_idx: usize, ctx: &mut ExecutionCtx, ) -> VortexResult> { - crate::encodings::turboquant::compute::rules::PARENT_KERNELS - .execute(array, parent, child_idx, ctx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 39ff2177544..b5531beb767 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -496,7 +496,10 @@ mod turboquant_benches { let config = turboquant_config($bits); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| &ext) - .bench_refs(|a| turboquant_encode(a, &config).unwrap()); + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + turboquant_encode(a, &config, &mut ctx).unwrap() + }); } } }; @@ -506,7 +509,8 @@ mod turboquant_benches { fn $name(bencher: Bencher) { let ext = setup_vector_ext($dim); let config = turboquant_config($bits); - let compressed = turboquant_encode(&ext, &config).unwrap(); + let mut ctx = SESSION.create_execution_ctx(); + let 
compressed = turboquant_encode(&ext, &config, &mut ctx).unwrap(); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| &compressed) .bench_refs(|a| { From f47a913e78dbd226e4928dc72ae81a6fdf886b8f Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 16:21:35 -0400 Subject: [PATCH 05/13] restructure modules for turboquant Signed-off-by: Connor Tsui --- vortex-btrblocks/src/builder.rs | 4 +- .../turboquant/{ => array}/centroids.rs | 0 .../turboquant/{array.rs => array/data.rs} | 77 +------------------ .../encodings/turboquant/array/metadata.rs | 12 +++ .../src/encodings/turboquant/array/mod.rs | 14 ++++ .../turboquant/{ => array}/rotation.rs | 0 .../turboquant/{ => array}/scheme.rs | 77 +++++-------------- .../src/encodings/turboquant/array/slots.rs | 35 +++++++++ .../src/encodings/turboquant/compress.rs | 10 +-- .../src/encodings/turboquant/compute/ops.rs | 2 +- .../src/encodings/turboquant/compute/rules.rs | 2 +- .../src/encodings/turboquant/compute/slice.rs | 4 +- .../src/encodings/turboquant/compute/take.rs | 4 +- .../src/encodings/turboquant/decompress.rs | 2 +- vortex-tensor/src/encodings/turboquant/mod.rs | 36 ++++----- .../src/encodings/turboquant/tests.rs | 6 +- .../src/encodings/turboquant/vtable.rs | 48 ++++++++++-- 17 files changed, 159 insertions(+), 174 deletions(-) rename vortex-tensor/src/encodings/turboquant/{ => array}/centroids.rs (100%) rename vortex-tensor/src/encodings/turboquant/{array.rs => array/data.rs} (83%) create mode 100644 vortex-tensor/src/encodings/turboquant/array/metadata.rs create mode 100644 vortex-tensor/src/encodings/turboquant/array/mod.rs rename vortex-tensor/src/encodings/turboquant/{ => array}/rotation.rs (100%) rename vortex-tensor/src/encodings/turboquant/{ => array}/scheme.rs (68%) create mode 100644 vortex-tensor/src/encodings/turboquant/array/slots.rs diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 5728c606d1e..ee1707c8961 100644 --- 
a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -150,8 +150,8 @@ impl BtrBlocksCompressorBuilder { /// [`Vector`]: vortex_tensor::vector::Vector #[cfg(feature = "unstable_encodings")] pub fn with_turboquant(self) -> Self { - use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; - self.with_new_scheme(&TURBOQUANT_SCHEME) + use vortex_tensor::encodings::turboquant::TurboQuantScheme; + self.with_new_scheme(&TurboQuantScheme) } /// Excludes schemes without CUDA kernel support and adds Zstd for string compression. diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/array/centroids.rs similarity index 100% rename from vortex-tensor/src/encodings/turboquant/centroids.rs rename to vortex-tensor/src/encodings/turboquant/array/centroids.rs diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs similarity index 83% rename from vortex-tensor/src/encodings/turboquant/array.rs rename to vortex-tensor/src/encodings/turboquant/array/data.rs index e1eebce6e0f..8f2a95f71c4 100644 --- a/vortex-tensor/src/encodings/turboquant/array.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -1,76 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! TurboQuant array definition: stores quantized coordinate codes, norms, -//! centroids (codebook), and rotation signs. 
- -use vortex_array::ArrayId; use vortex_array::ArrayRef; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::stats::ArrayStats; -use vortex_array::vtable; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use crate::encodings::turboquant::array::slots::Slot; +use crate::encodings::turboquant::vtable::TurboQuant; use crate::utils::extension_element_ptype; use crate::utils::extension_list_size; -use crate::vector::Vector; - -/// Encoding marker type for TurboQuant. -#[derive(Clone, Debug)] -pub struct TurboQuant; - -impl TurboQuant { - pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); -} - -vtable!(TurboQuant, TurboQuant, TurboQuantData); - -/// Serialized metadata for TurboQuant encoding: a single byte holding the `bit_width` (0-8). -/// -/// All other fields (dimension, element type) are derived from the dtype and children. -/// A `bit_width` of 0 indicates a degenerate empty array. -#[derive(Clone, Debug)] -pub struct TurboQuantMetadata { - /// MSE bits per coordinate (0 for degenerate empty arrays, 1-8 otherwise). - pub bit_width: u8, -} - -/// Slot positions for TurboQuantArray children. -#[repr(usize)] -#[derive(Clone, Copy, Debug)] -pub(crate) enum Slot { - Codes = 0, - Norms = 1, - Centroids = 2, - RotationSigns = 3, -} - -impl Slot { - pub(crate) const COUNT: usize = 4; - - pub(crate) fn name(self) -> &'static str { - match self { - Self::Codes => "codes", - Self::Norms => "norms", - Self::Centroids => "centroids", - Self::RotationSigns => "rotation_signs", - } - } - - pub(crate) fn from_index(idx: usize) -> Self { - match idx { - 0 => Self::Codes, - 1 => Self::Norms, - 2 => Self::Centroids, - 3 => Self::RotationSigns, - _ => vortex_error::vortex_panic!("invalid slot index {idx}"), - } - } -} /// TurboQuant array data. 
/// @@ -207,22 +150,8 @@ impl TurboQuantData { centroids: &ArrayRef, rotation_signs: &ArrayRef, ) -> VortexResult<()> { - // Dtype must be a Vector extension type. - let ext = dtype - .as_extension_opt() - .filter(|e| e.is::()) - .ok_or_else(|| { - vortex_error::vortex_err!( - "TurboQuant dtype must be a Vector extension type, got {dtype}" - ) - })?; - - // Dimension is derived from the storage dtype's list size and must be >= 3. + let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; - vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" - ); let num_rows = norms.len(); diff --git a/vortex-tensor/src/encodings/turboquant/array/metadata.rs b/vortex-tensor/src/encodings/turboquant/array/metadata.rs new file mode 100644 index 00000000000..2fead1db835 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/metadata.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +/// Serialized metadata for TurboQuant encoding: a single byte holding the `bit_width` (0-8). +/// +/// All other fields (dimension, element type) are derived from the dtype and children. +/// A `bit_width` of 0 indicates a degenerate empty array. +#[derive(Clone, Debug)] +pub struct TurboQuantMetadata { + /// MSE bits per coordinate (0 for degenerate empty arrays, 1-8 otherwise). + pub bit_width: u8, +} diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs new file mode 100644 index 00000000000..ba503ab6672 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant array definition: stores quantized coordinate codes, norms, centroids (codebook), +//! and rotation signs. 
+ +pub(crate) mod data; +pub(crate) mod metadata; +pub(crate) mod slots; + +pub(crate) mod scheme; + +pub(crate) mod centroids; +pub(crate) mod rotation; diff --git a/vortex-tensor/src/encodings/turboquant/rotation.rs b/vortex-tensor/src/encodings/turboquant/array/rotation.rs similarity index 100% rename from vortex-tensor/src/encodings/turboquant/rotation.rs rename to vortex-tensor/src/encodings/turboquant/array/rotation.rs diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/array/scheme.rs similarity index 68% rename from vortex-tensor/src/encodings/turboquant/scheme.rs rename to vortex-tensor/src/encodings/turboquant/array/scheme.rs index de5fc8ee88a..74e380f0e67 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/array/scheme.rs @@ -5,42 +5,37 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; -use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; -use super::FIXED_SHAPE_TENSOR_EXT_ID; -use super::TurboQuantConfig; -use super::VECTOR_EXT_ID; -use super::turboquant_encode; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::turboquant_encode; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; -/// TurboQuant compression scheme for tensor extension types. +/// TurboQuant compression scheme for [`Vector`] extension types. /// -/// Applies lossy vector quantization to `Vector` and `FixedShapeTensor` extension -/// arrays using the TurboQuant algorithm with MSE-optimal encoding. 
+/// Applies lossy vector quantization to [`Vector`] extension arrays using the TurboQuant +/// algorithm with MSE-optimal encoding. /// /// Register this scheme with the compressor builder via `with_scheme`: /// ```ignore /// use vortex_btrblocks::BtrBlocksCompressorBuilder; -/// use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; +/// use vortex_tensor::encodings::turboquant::TurboQuantScheme; /// /// let compressor = BtrBlocksCompressorBuilder::default() -/// .with_scheme(&TURBOQUANT_SCHEME) +/// .with_scheme(&TurboQuantScheme) /// .build(); /// ``` +/// +/// [`Vector`]: crate::vector::Vector #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct TurboQuantScheme; -/// Static instance for registration with `BtrBlocksCompressorBuilder::with_scheme`. -pub static TURBOQUANT_SCHEME: TurboQuantScheme = TurboQuantScheme; - impl Scheme for TurboQuantScheme { fn scheme_name(&self) -> &'static str { "vortex.tensor.turboquant" @@ -51,7 +46,7 @@ impl Scheme for TurboQuantScheme { return false; }; - get_tensor_element_ptype_and_length(ext.dtype()).is_ok() + TurboQuant::validate_dtype(ext.dtype()).is_ok() } fn expected_compression_ratio( @@ -62,10 +57,14 @@ impl Scheme for TurboQuantScheme { ) -> VortexResult { let dtype = data.array().dtype(); let len = data.array().len(); - let (element_ptype, dimensions) = get_tensor_element_ptype_and_length(dtype)?; + + let ext = TurboQuant::validate_dtype(dtype)?; + let element_ptype = extension_element_ptype(ext)?; + let dimension = extension_list_size(ext)?; + Ok(estimate_compression_ratio( element_ptype.bit_width(), - dimensions, + dimension, len, )) } @@ -76,8 +75,8 @@ impl Scheme for TurboQuantScheme { data: &mut ArrayAndStats, _ctx: CompressorContext, ) -> VortexResult { - let array = data.array().clone(); - let ext_array = array.to_canonical()?.into_extension(); + // TODO(connor): Fix this once we ensure that the data array is always canonical. 
+ let ext_array = data.array().to_canonical()?.into_extension(); let config = TurboQuantConfig::default(); turboquant_encode(&ext_array, &config, &mut compressor.execution_ctx()) @@ -104,40 +103,6 @@ fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vect uncompressed_size_bits as f64 / compressed_size_bits as f64 } -fn get_tensor_element_ptype_and_length(dtype: &DType) -> VortexResult<(PType, u32)> { - let ext_id = dtype.as_extension().id(); - let is_tensor = dtype.is_extension() - && (ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID); - vortex_ensure!(is_tensor, "expected tensor extension dtype, got {}", dtype); - - let storage_dtype = dtype.as_extension().storage_dtype(); - let (element_dtype, fsl_len) = match storage_dtype { - DType::FixedSizeList(element_dtype, list_size, _) => (element_dtype, list_size), - _ => vortex_bail!( - "expected FixedSizeList storage dtype, got {}", - storage_dtype - ), - }; - - // TurboQuant requires dimension >= 3: the marginal coordinate distribution - // (1 - x^2)^((d-3)/2) has a singularity at d=2 (arcsine distribution) that - // causes NaN in the Max-Lloyd centroid computation. - vortex_ensure!( - *fsl_len >= 3, - "TurboQuant requires dimension >= 3, got {}", - fsl_len - ); - - if let &DType::Primitive(ptype, Nullability::NonNullable) = element_dtype.as_ref() { - Ok((ptype, *fsl_len)) - } else { - vortex_bail!( - "expected non-nullable primitive element type, got {}", - element_dtype - ); - } -} - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-tensor/src/encodings/turboquant/array/slots.rs b/vortex-tensor/src/encodings/turboquant/array/slots.rs new file mode 100644 index 00000000000..ff59db447d3 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/slots.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +/// Slot positions for TurboQuantArray children. 
+#[repr(usize)] +#[derive(Clone, Copy, Debug)] +pub(crate) enum Slot { + Codes = 0, + Norms = 1, + Centroids = 2, + RotationSigns = 3, +} + +impl Slot { + pub(crate) const COUNT: usize = 4; + + pub(crate) fn name(self) -> &'static str { + match self { + Self::Codes => "codes", + Self::Norms => "norms", + Self::Centroids => "centroids", + Self::RotationSigns => "rotation_signs", + } + } + + pub(crate) fn from_index(idx: usize) -> Self { + match idx { + 0 => Self::Codes, + 1 => Self::Norms, + 2 => Self::Centroids, + 3 => Self::RotationSigns, + _ => vortex_error::vortex_panic!("invalid slot index {idx}"), + } + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 8f5b02222f8..d97000b3418 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -20,11 +20,11 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_fastlanes::bitpack_compress::bitpack_encode; -use crate::encodings::turboquant::array::TurboQuantData; -use crate::encodings::turboquant::centroids::compute_boundaries; -use crate::encodings::turboquant::centroids::find_nearest_centroid; -use crate::encodings::turboquant::centroids::get_centroids; -use crate::encodings::turboquant::rotation::RotationMatrix; +use crate::encodings::turboquant::TurboQuantData; +use crate::encodings::turboquant::array::centroids::compute_boundaries; +use crate::encodings::turboquant::array::centroids::find_nearest_centroid; +use crate::encodings::turboquant::array::centroids::get_centroids; +use crate::encodings::turboquant::array::rotation::RotationMatrix; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs index 5309669ed53..4999816319b 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/ops.rs +++ 
b/vortex-tensor/src/encodings/turboquant/compute/ops.rs @@ -10,7 +10,7 @@ use vortex_array::vtable::OperationsVTable; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::TurboQuant; impl OperationsVTable for TurboQuant { fn scalar_at( diff --git a/vortex-tensor/src/encodings/turboquant/compute/rules.rs b/vortex-tensor/src/encodings/turboquant/compute/rules.rs index d482994f720..39919a8c1ec 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/rules.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/rules.rs @@ -6,7 +6,7 @@ use vortex_array::arrays::slice::SliceReduceAdaptor; use vortex_array::kernel::ParentKernelSet; use vortex_array::optimizer::rules::ParentRuleSet; -use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::TurboQuant; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs index 86768b949d6..8c6805f24ef 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/slice.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -9,8 +9,8 @@ use vortex_array::IntoArray; use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; -use crate::encodings::turboquant::array::TurboQuant; -use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::TurboQuantData; impl SliceReduce for TurboQuant { fn slice( diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs index e3b5e866e57..7b52baf804e 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/take.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -8,8 +8,8 @@ use 
vortex_array::IntoArray; use vortex_array::arrays::dict::TakeExecute; use vortex_error::VortexResult; -use crate::encodings::turboquant::array::TurboQuant; -use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::TurboQuantData; impl TakeExecute for TurboQuant { fn take( diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 5897784c306..bf0f1fa21cb 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -15,7 +15,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::rotation::RotationMatrix; +use crate::encodings::turboquant::array::rotation::RotationMatrix; /// Decompress a `TurboQuantArray` into a [`Vector`] extension array. /// diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 760dce16f2e..ab311799381 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -90,26 +90,6 @@ //! assert!(encoded.nbytes() < 51200); //! ``` -pub use array::TurboQuant; -pub use array::TurboQuantData; -pub use compress::TurboQuantConfig; -pub use compress::turboquant_encode; - -mod array; -pub(crate) mod centroids; -mod compress; -pub(crate) mod compute; -pub(crate) mod decompress; -pub(crate) mod rotation; -pub mod scheme; -mod vtable; - -/// Extension ID for the `Vector` type from `vortex-tensor`. -pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector"; - -/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`. 
-pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor"; - use vortex_array::session::ArraySessionExt; use vortex_session::VortexSession; @@ -118,5 +98,21 @@ pub fn initialize(session: &mut VortexSession) { session.arrays().register(TurboQuant); } +mod array; +pub use array::data::TurboQuantData; +pub use array::metadata::TurboQuantMetadata; +pub use array::scheme::TurboQuantScheme; + +pub(crate) mod compute; + +mod vtable; +pub use vtable::TurboQuant; + +mod compress; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode; + +mod decompress; + #[cfg(test)] mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index 5542e4d8278..afdf9c46265 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -21,7 +21,7 @@ use vortex_session::VortexSession; use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::rotation::RotationMatrix; +use crate::encodings::turboquant::array::rotation::RotationMatrix; use crate::encodings::turboquant::turboquant_encode; use crate::vector::Vector; @@ -365,7 +365,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { let stored = stored_centroids_prim.as_slice::(); let padded_dim = encoded.padded_dim(); - let computed = crate::encodings::turboquant::centroids::get_centroids(padded_dim, 3)?; + let computed = crate::encodings::turboquant::array::centroids::get_centroids(padded_dim, 3)?; assert_eq!(stored.len(), computed.len()); for i in 0..stored.len() { @@ -474,7 +474,7 @@ fn serde_roundtrip() -> VortexResult<()> { let original_elements = original_fsl.elements().to_canonical()?.into_primitive(); // Rebuild from children (simulating deserialization). 
- let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new( + let rebuilt = crate::encodings::turboquant::TurboQuantData::try_new( encoded.dtype().clone(), children[0].clone(), children[1].clone(), diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index 0d8d4ce8e1d..746135c0d77 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -19,26 +19,61 @@ use vortex_array::buffer::BufferHandle; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDTypeRef; use vortex_array::serde::ArrayChildren; use vortex_array::stats::ArrayStats; +use vortex_array::vtable; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; use vortex_array::vtable::ValidityVTableFromChild; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_session::VortexSession; -use crate::encodings::turboquant::array::Slot; -use crate::encodings::turboquant::array::TurboQuant; -use crate::encodings::turboquant::array::TurboQuantData; -use crate::encodings::turboquant::array::TurboQuantMetadata; +use crate::encodings::turboquant::TurboQuantData; +use crate::encodings::turboquant::TurboQuantMetadata; +use crate::encodings::turboquant::array::slots::Slot; use crate::encodings::turboquant::compute::rules::PARENT_KERNELS; use crate::encodings::turboquant::compute::rules::RULES; use crate::encodings::turboquant::decompress::execute_decompress; use crate::utils::extension_element_ptype; use crate::utils::extension_list_size; +use crate::vector::Vector; + +/// Encoding marker type for TurboQuant. 
+#[derive(Clone, Debug)] +pub struct TurboQuant; + +impl TurboQuant { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); + + /// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with + /// dimension >= 3. + /// + /// Returns the validated [`ExtDTypeRef`] on success, which can be used to extract the + /// element ptype and list size. + pub fn validate_dtype(dtype: &DType) -> VortexResult<&ExtDTypeRef> { + let ext = dtype + .as_extension_opt() + .filter(|e| e.is::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimension = extension_list_size(ext)?; + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + Ok(ext) + } +} + +vtable!(TurboQuant, TurboQuant, TurboQuantData); impl VTable for TurboQuant { type ArrayData = TurboQuantData; @@ -173,10 +208,9 @@ impl VTable for TurboQuant { ) -> VortexResult { let bit_width = metadata.bit_width; - // Derive dimension and element ptype from the Vector extension dtype. - let ext = dtype.as_extension(); + // Validate and derive dimension and element ptype from the Vector extension dtype. 
+ let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; - let element_ptype = extension_element_ptype(ext)?; let element_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); From 9d3f807e759ac359a613e639ceb5facf986ad7f8 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 16:33:15 -0400 Subject: [PATCH 06/13] even more cleanup Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/array/data.rs | 121 ++++++++---------- .../src/encodings/turboquant/array/mod.rs | 4 +- 2 files changed, 57 insertions(+), 68 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs index 8f2a95f71c4..26064e64c3c 100644 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -9,6 +9,7 @@ use vortex_array::stats::ArrayStats; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use crate::encodings::turboquant::array::slots::Slot; use crate::encodings::turboquant::vtable::TurboQuant; @@ -17,22 +18,22 @@ use crate::utils::extension_list_size; /// TurboQuant array data. /// -/// TurboQuant is a lossy vector quantization encoding for [`Vector`] extension arrays. -/// It stores quantized coordinate codes and per-vector norms, along with shared codebook -/// centroids and SRHT rotation signs. See the [module docs](super) for algorithmic details. +/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector) +/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared +/// codebook centroids and SRHT rotation signs. /// -/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. +/// See the [module docs](super) for algorithmic details. 
/// -/// [`Vector`]: crate::vector::Vector +/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. #[derive(Clone, Debug)] pub struct TurboQuantData { - /// The [`Vector`] extension dtype that this array encodes. The storage dtype within the - /// extension determines the element type (f16, f32, or f64) and the list size (dimension). + /// The [`Vector`](crate::vector::Vector) extension dtype that this array encodes. /// - /// [`Vector`]: crate::vector::Vector + /// The storage dtype within the extension determines the element type (f16, f32, or f64) and + /// the list size (dimension). pub(crate) dtype: DType, - /// Child arrays stored as optional slots. See [`Slot`] for positions: + /// Child arrays stored as slots. See [`Slot`] for positions: /// /// - [`Codes`](Slot::Codes): `FixedSizeListArray` with `list_size == padded_dim`. Each row /// holds one u8 centroid index per padded coordinate. The cascade compressor handles packing @@ -53,13 +54,13 @@ pub struct TurboQuantData { pub(crate) slots: Vec>, /// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size. + /// /// Stored as a convenience field to avoid repeatedly extracting it from `dtype`. - /// Non-power-of-2 dimensions are zero-padded to [`padded_dim`](Self::padded_dim) for the - /// Walsh-Hadamard transform. pub(crate) dimension: u32, /// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`. - /// Zero for degenerate empty arrays. + /// + /// This is 0 for degenerate empty arrays. pub(crate) bit_width: u8, /// The stats for this array. @@ -100,8 +101,8 @@ impl TurboQuantData { /// is >= 3. /// - `codes` is a `FixedSizeListArray` with `list_size == padded_dim` and /// `codes.len() == norms.len()`. - /// - `norms` is a non-nullable primitive array whose ptype matches the element type of the - /// Vector's storage dtype. 
+ /// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage + /// dtype. This must match the validity of the `codes` array. /// - `centroids` is a non-nullable `PrimitiveArray` whose length is a power of 2 in /// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays. /// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays. @@ -124,13 +125,22 @@ impl TurboQuantData { .and_then(|ext| extension_list_size(ext).ok()) .vortex_expect("dtype must be a Vector extension type with FixedSizeList storage"); - let bit_width = derive_bit_width(¢roids); + let bit_width = if centroids.is_empty() { + 0 + } else { + // Guaranteed to be 0-8 by validate(). + #[expect(clippy::cast_possible_truncation)] + { + centroids.len().trailing_zeros() as u8 + } + }; let mut slots = vec![None; Slot::COUNT]; slots[Slot::Codes as usize] = Some(codes); slots[Slot::Norms as usize] = Some(norms); slots[Slot::Centroids as usize] = Some(centroids); slots[Slot::RotationSigns as usize] = Some(rotation_signs); + Self { dtype, slots, @@ -153,15 +163,19 @@ impl TurboQuantData { let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; - let num_rows = norms.len(); + let num_rows = codes.len(); + vortex_ensure_eq!( + norms.len(), + num_rows, + "norms length must match codes length", + ); + + // TODO(connor): Should we check that the codes and norms have the same validity? We could + // also make it so that norms holds the validity and any null vectors encoded as codes is + // just 0... - // Degenerate (empty) case: all children must be empty, bit_width is 0. + // Degenerate (empty) case: all children must be empty, and bit_width is 0. 
if num_rows == 0 { - vortex_ensure!( - codes.is_empty(), - "degenerate TurboQuant must have empty codes, got length {}", - codes.len() - ); vortex_ensure!( centroids.is_empty(), "degenerate TurboQuant must have empty centroids, got length {}", @@ -183,7 +197,7 @@ impl TurboQuantData { ); // Guaranteed to be 1-8 by the preceding power-of-2 and range checks. - #[allow(clippy::cast_possible_truncation)] + #[expect(clippy::cast_possible_truncation)] let bit_width = num_centroids.trailing_zeros() as u8; vortex_ensure!( (1..=8).contains(&bit_width), @@ -193,44 +207,34 @@ impl TurboQuantData { // Norms dtype must match the element ptype of the Vector. let element_ptype = extension_element_ptype(ext)?; let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); - vortex_ensure!( - *norms.dtype() == expected_norms_dtype, - "norms dtype {} does not match expected {expected_norms_dtype} \ + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype does not match expected {expected_norms_dtype} \ (must match Vector element type)", - norms.dtype() ); // Centroids are always f32 regardless of element type. - let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); - vortex_ensure!( - *centroids.dtype() == f32_nn, - "centroids dtype {} must be non-nullable f32", - centroids.dtype() - ); - - // Row count consistency. - vortex_ensure!( - codes.len() == num_rows, - "codes length {} does not match norms length {num_rows}", - codes.len() + let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + vortex_ensure_eq!( + *centroids.dtype(), + centroids_dtype, + "centroids dtype must be non-nullable f32", ); // Rotation signs count must be 3 * padded_dim. 
let padded_dim = dimension.next_power_of_two() as usize; - vortex_ensure!( - rotation_signs.len() == 3 * padded_dim, - "rotation_signs length {} does not match expected 3 * {padded_dim} = {}", + vortex_ensure_eq!( rotation_signs.len(), - 3 * padded_dim + 3 * padded_dim, + "rotation_signs length does not match expected 3 * {padded_dim}", ); Ok(()) } - /// The vector dimension `d`, as stored in the [`Vector`] extension dtype's - /// `FixedSizeList` storage. - /// - /// [`Vector`]: crate::vector::Vector + /// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension + /// dtype's `FixedSizeList` storage. pub fn dimension(&self) -> u32 { self.dimension } @@ -248,12 +252,6 @@ impl TurboQuantData { self.dimension.next_power_of_two() } - fn slot(&self, idx: usize) -> &ArrayRef { - self.slots[idx] - .as_ref() - .vortex_expect("required slot is None") - } - /// The quantized codes child (`FixedSizeListArray`, one row per vector). pub fn codes(&self) -> &ArrayRef { self.slot(Slot::Codes as usize) @@ -278,19 +276,10 @@ impl TurboQuantData { pub fn rotation_signs(&self) -> &ArrayRef { self.slot(Slot::RotationSigns as usize) } -} -/// Derive `bit_width` from the centroids array length. -/// -/// Returns 0 for empty centroids (degenerate array), otherwise `log2(centroids.len())`. -fn derive_bit_width(centroids: &ArrayRef) -> u8 { - if centroids.is_empty() { - 0 - } else { - // Guaranteed to be 0-8 by validate(). 
- #[allow(clippy::cast_possible_truncation)] - { - centroids.len().trailing_zeros() as u8 - } + fn slot(&self, idx: usize) -> &ArrayRef { + self.slots[idx] + .as_ref() + .vortex_expect("required slot is None") } } diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs index ba503ab6672..3a5281f767c 100644 --- a/vortex-tensor/src/encodings/turboquant/array/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/array/mod.rs @@ -8,7 +8,7 @@ pub(crate) mod data; pub(crate) mod metadata; pub(crate) mod slots; -pub(crate) mod scheme; - pub(crate) mod centroids; pub(crate) mod rotation; + +pub(crate) mod scheme; From 626bb47263255cfee60d8ab40a1915d50f5f46ee Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 16:45:07 -0400 Subject: [PATCH 07/13] fix minor things Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 80 +++++++++-------- .../src/encodings/turboquant/array/data.rs | 29 +++++-- .../src/encodings/turboquant/array/scheme.rs | 2 +- .../src/encodings/turboquant/decompress.rs | 85 ++++++++++++++----- .../src/encodings/turboquant/tests.rs | 3 + 5 files changed, 131 insertions(+), 68 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index f455a994295..5ab6d5c7e0d 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -4,46 +4,14 @@ pub mod vortex_tensor::encodings pub mod vortex_tensor::encodings::turboquant -pub mod vortex_tensor::encodings::turboquant::scheme - -pub struct vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -impl core::clone::Clone for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -impl core::cmp::Eq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -impl core::cmp::PartialEq for 
vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme) -> bool - -impl core::fmt::Debug for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl core::marker::Copy for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool - -pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::scheme_name(&self) -> &'static str - -pub static vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME: vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme - pub struct vortex_tensor::encodings::turboquant::TurboQuant impl vortex_tensor::encodings::turboquant::TurboQuant pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId +pub fn 
vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef> + impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant pub fn vortex_tensor::encodings::turboquant::TurboQuant::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuant @@ -56,7 +24,7 @@ impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquan pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData -pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::array::TurboQuantMetadata +pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::TurboQuantMetadata pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant @@ -176,9 +144,47 @@ impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::Tu pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) -> vortex_array::array::erased::ArrayRef -pub const vortex_tensor::encodings::turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str +pub struct vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub vortex_tensor::encodings::turboquant::TurboQuantMetadata::bit_width: u8 + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantMetadata + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::clone::Clone for 
vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::cmp::Eq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::cmp::PartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::TurboQuantScheme) -> bool + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool -pub const vortex_tensor::encodings::turboquant::VECTOR_EXT_ID: &str +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut 
vortex_session::VortexSession) diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs index 26064e64c3c..560148a1276 100644 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::sync::Arc; + use vortex_array::ArrayRef; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -22,7 +24,7 @@ use crate::utils::extension_list_size; /// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared /// codebook centroids and SRHT rotation signs. /// -/// See the [module docs](super) for algorithmic details. +/// See the [module docs](crate::encodings::turboquant) for algorithmic details. /// /// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. #[derive(Clone, Debug)] @@ -128,7 +130,7 @@ impl TurboQuantData { let bit_width = if centroids.is_empty() { 0 } else { - // Guaranteed to be 0-8 by validate(). + // Guaranteed to be 1-8 by validate(). #[expect(clippy::cast_possible_truncation)] { centroids.len().trailing_zeros() as u8 @@ -162,6 +164,19 @@ impl TurboQuantData { ) -> VortexResult<()> { let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; + let padded_dim = dimension.next_power_of_two(); + + // Codes must be a FixedSizeList with list_size == padded_dim. + let expected_codes_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), // FIX THIS!!! 
+ padded_dim, + dtype.nullability(), + ); + vortex_ensure_eq!( + *codes.dtype(), + expected_codes_dtype, + "codes dtype does not match expected {expected_codes_dtype}", + ); let num_rows = codes.len(); vortex_ensure_eq!( @@ -206,12 +221,11 @@ impl TurboQuantData { // Norms dtype must match the element ptype of the Vector. let element_ptype = extension_element_ptype(ext)?; - let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); // FIX THIS!!! vortex_ensure_eq!( *norms.dtype(), expected_norms_dtype, - "norms dtype does not match expected {expected_norms_dtype} \ - (must match Vector element type)", + "norms dtype does not match expected (must match Vector element type)", ); // Centroids are always f32 regardless of element type. @@ -219,14 +233,13 @@ impl TurboQuantData { vortex_ensure_eq!( *centroids.dtype(), centroids_dtype, - "centroids dtype must be non-nullable f32", + "centroids dtype must be non-nullable f32", ); // Rotation signs count must be 3 * padded_dim. 
- let padded_dim = dimension.next_power_of_two() as usize; vortex_ensure_eq!( rotation_signs.len(), - 3 * padded_dim, + 3 * padded_dim as usize, "rotation_signs length does not match expected 3 * {padded_dim}", ); diff --git a/vortex-tensor/src/encodings/turboquant/array/scheme.rs b/vortex-tensor/src/encodings/turboquant/array/scheme.rs index 74e380f0e67..2b954522676 100644 --- a/vortex-tensor/src/encodings/turboquant/array/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/array/scheme.rs @@ -28,7 +28,7 @@ use crate::utils::extension_list_size; /// use vortex_tensor::encodings::turboquant::TurboQuantScheme; /// /// let compressor = BtrBlocksCompressorBuilder::default() -/// .with_scheme(&TurboQuantScheme) +/// .with_new_scheme(&TurboQuantScheme) /// .build(); /// ``` /// diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index bf0f1fa21cb..25c38840624 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -3,6 +3,8 @@ //! TurboQuant decoding (dequantization) logic. +use num_traits::FromPrimitive; +use num_traits::Zero; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -10,12 +12,15 @@ use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::NativePType; +use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::array::rotation::RotationMatrix; +use crate::utils::extension_element_ptype; /// Decompress a `TurboQuantArray` into a [`Vector`] extension array. 
/// @@ -31,19 +36,23 @@ pub fn execute_decompress( let padded_dim = array.padded_dim() as usize; let num_rows = array.norms().len(); let ext_dtype = array.dtype.as_extension().clone(); + let element_ptype = extension_element_ptype(&ext_dtype)?; if num_rows == 0 { - let elements = PrimitiveArray::empty::(ext_dtype.storage_dtype().nullability()); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - array.dimension(), - Validity::NonNullable, - 0, - )?; - return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); + let nn = vortex_array::dtype::Nullability::NonNullable; + match_each_float_ptype!(element_ptype, |T| { + let elements = PrimitiveArray::empty::(nn); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + 0, + )?; + return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); + }) } - // Read stored centroids -- no recomputation. + // Read stored centroids (always f32). let centroids_prim = array.centroids().clone().execute::(ctx)?; let centroids = centroids_prim.as_slice::(); @@ -61,11 +70,47 @@ pub fn execute_decompress( let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); let indices = codes_prim.as_slice::(); + // Read norms in their native precision. let norms_prim = array.norms().clone().execute::(ctx)?; - let norms = norms_prim.as_slice::(); - // MSE decode: dequantize -> inverse rotate -> scale by norm. - let mut output = BufferMut::::with_capacity(num_rows * dim); + // MSE decode: dequantize (f32) -> inverse rotate (f32) -> scale by norm -> cast to T. + // The rotation and centroid lookup always happen in f32. The final output is cast to the + // Vector's element type to match the original storage dtype. 
+ match_each_float_ptype!(element_ptype, |T| { + decompress_typed::( + &norms_prim, + centroids, + &rotation, + indices, + dim, + padded_dim, + num_rows, + ) + .and_then(|elements| { + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )?; + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + }) +} + +/// Typed decompress: reads norms as `T`, dequantizes in f32, and produces output as `T`. +fn decompress_typed( + norms_prim: &PrimitiveArray, + centroids: &[f32], + rotation: &RotationMatrix, + indices: &[u8], + dim: usize, + padded_dim: usize, + num_rows: usize, +) -> VortexResult { + let norms = norms_prim.as_slice::(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); let mut dequantized = vec![0.0f32; padded_dim]; let mut unrotated = vec![0.0f32; padded_dim]; @@ -80,18 +125,14 @@ pub fn execute_decompress( rotation.inverse_rotate(&dequantized, &mut unrotated); for idx in 0..dim { - unrotated[idx] *= norm; + // Convert f32 dequantized value to T, then scale by the native-precision norm. 
+ let val = T::from_f32(unrotated[idx]).unwrap_or_else(T::zero) * norm; + output.push(val); } - - output.extend_from_slice(&unrotated[..dim]); } - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - array.dimension(), + Ok(PrimitiveArray::new::( + output.freeze(), Validity::NonNullable, - num_rows, - )?; - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + )) } diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index afdf9c46265..e7a7db4520b 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + use std::sync::LazyLock; use rand::SeedableRng; From b9c26d7ee7286086039b3d0fd90606a1a1a63cfb Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 17:20:20 -0400 Subject: [PATCH 08/13] fix nullability handling Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/array/data.rs | 37 ++-- .../src/encodings/turboquant/compress.rs | 36 ++-- .../src/encodings/turboquant/decompress.rs | 14 +- .../src/encodings/turboquant/tests.rs | 201 +++++++++++++++++- .../src/encodings/turboquant/vtable.rs | 14 +- 5 files changed, 257 insertions(+), 45 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs index 560148a1276..f369e61e35e 100644 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -37,13 +37,15 @@ pub struct TurboQuantData { /// Child arrays stored as slots. See [`Slot`] for positions: /// - /// - [`Codes`](Slot::Codes): `FixedSizeListArray` with `list_size == padded_dim`. Each row - /// holds one u8 centroid index per padded coordinate. 
The cascade compressor handles packing - /// to the actual `bit_width` on disk. The validity of the entire array is stored with this. + /// - [`Codes`](Slot::Codes): Non-nullable `FixedSizeListArray` with + /// `list_size == padded_dim`. Each row holds one u8 centroid index per padded coordinate. + /// Null vectors are represented by all-zero codes. The cascade compressor handles packing + /// to the actual `bit_width` on disk. /// /// - [`Norms`](Slot::Norms): Per-vector L2 norms, one per row. The dtype matches the element - /// type of the Vector (e.g., f64 norms for f64 vectors). Exact norms are stored during - /// compression, enabling O(1) L2 norm readthrough without decompression. + /// type of the Vector (e.g., f64 norms for f64 vectors) and carries the nullability of the + /// parent dtype. Null vectors have null norms. This child determines the validity of the + /// entire TurboQuant array, enabling O(1) L2 norm readthrough without decompression. /// /// - [`Centroids`](Slot::Centroids): `PrimitiveArray` codebook with `2^bit_width` entries /// that is shared across all rows. We always store these as f32 regardless of the input @@ -101,10 +103,11 @@ impl TurboQuantData { /// /// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size /// is >= 3. - /// - `codes` is a `FixedSizeListArray` with `list_size == padded_dim` and - /// `codes.len() == norms.len()`. + /// - `codes` is a non-nullable `FixedSizeListArray` with `list_size == padded_dim` and + /// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes. /// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage - /// dtype. This must match the validity of the `codes` array. + /// dtype. The nullability must match `dtype.nullability()`. Norms carry the validity of the + /// entire array, since null vectors have null norms. 
/// - `centroids` is a non-nullable `PrimitiveArray` whose length is a power of 2 in /// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays. /// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays. @@ -166,11 +169,12 @@ impl TurboQuantData { let dimension = extension_list_size(ext)?; let padded_dim = dimension.next_power_of_two(); - // Codes must be a FixedSizeList with list_size == padded_dim. + // Codes must be a non-nullable FixedSizeList with list_size == padded_dim. + // Null vectors are represented by all-zero codes since validity lives in the norms array. let expected_codes_dtype = DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), // FIX THIS!!! + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), padded_dim, - dtype.nullability(), + Nullability::NonNullable, ); vortex_ensure_eq!( *codes.dtype(), @@ -185,10 +189,6 @@ impl TurboQuantData { "norms length must match codes length", ); - // TODO(connor): Should we check that the codes and norms have the same validity? We could - // also make it so that norms holds the validity and any null vectors encoded as codes is - // just 0... - // Degenerate (empty) case: all children must be empty, and bit_width is 0. if num_rows == 0 { vortex_ensure!( @@ -219,13 +219,14 @@ impl TurboQuantData { "derived bit_width must be 1-8, got {bit_width}" ); - // Norms dtype must match the element ptype of the Vector. + // Norms dtype must match the element ptype of the Vector, with the parent's nullability. + // Norms carry the validity of the entire TurboQuant array. let element_ptype = extension_element_ptype(ext)?; - let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); // FIX THIS!!! 
+ let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability()); vortex_ensure_eq!( *norms.dtype(), expected_norms_dtype, - "norms dtype does not match expected (must match Vector element type)", + "norms dtype does not match expected {expected_norms_dtype}", ); // Centroids are always f32 regardless of element type. diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index d97000b3418..119485b9309 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -76,7 +76,8 @@ struct QuantizationResult { rotation: RotationMatrix, centroids: Vec, all_indices: BufferMut, - /// Native-precision norms (matching the Vector element type). + /// Native-precision norms (matching the Vector element type). Carries validity: null vectors + /// have null norms. norms_array: ArrayRef, padded_dim: usize, } @@ -85,19 +86,22 @@ struct QuantizationResult { /// normalize/rotate/quantize all rows. /// /// Norms are computed in the native element precision via the [`L2Norm`] scalar function. -/// The rotation and centroid lookup happen in f32. +/// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce +/// all-zero codes. #[allow(clippy::cast_possible_truncation)] fn turboquant_quantize_core( ext: &ExtensionArray, fsl: &FixedSizeListArray, seed: u64, bit_width: u8, + validity: &Validity, ctx: &mut ExecutionCtx, ) -> VortexResult { let dimension = fsl.list_size() as usize; let num_rows = fsl.len(); - // Compute native-precision norms via the L2Norm scalar fn. + // Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from + // the input, so null vectors get null norms automatically. 
let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive(); @@ -125,6 +129,12 @@ fn turboquant_quantize_core( let f32_slice = f32_elements.as_slice::(); for row in 0..num_rows { + // Null vectors get all-zero codes. + if !validity.is_valid(row)? { + all_indices.extend(std::iter::repeat_n(0u8, padded_dim)); + continue; + } + let x = &f32_slice[row * dimension..(row + 1) * dimension]; let norm = f32_norms[row]; @@ -189,12 +199,10 @@ fn build_turboquant( ) } -/// Encode a [`Vector`] extension array into a `TurboQuantArray`. -/// -/// The input must be a non-nullable [`Vector`] extension array. TurboQuant is a lossy encoding -/// that does not preserve null positions; callers must handle validity externally. +/// Encode a [`Vector`](crate::vector::Vector) extension array into a `TurboQuantArray`. /// -/// [`Vector`]: crate::vector::Vector +/// Nullable inputs are supported: null vectors get all-zero codes and null norms. The validity +/// of the resulting TurboQuant array is carried by the norms child. pub fn turboquant_encode( ext: &ExtensionArray, config: &TurboQuantConfig, @@ -204,10 +212,6 @@ pub fn turboquant_encode( let storage = ext.storage_array(); let fsl = storage.to_canonical()?.into_fixed_size_list(); - vortex_ensure!( - fsl.dtype().nullability() == Nullability::NonNullable, - "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" - ); vortex_ensure!( config.bit_width >= 1 && config.bit_width <= 8, "bit_width must be 1-8, got {}", @@ -228,10 +232,11 @@ pub fn turboquant_encode( 0, )?; - // Norms dtype matches the element type. + // Norms dtype matches the element type and carries the parent's nullability. 
let element_ptype = fsl.elements().dtype().as_ptype(); + let norms_nullability = ext_dtype.nullability(); let empty_norms: ArrayRef = match_each_float_ptype!(element_ptype, |T| { - PrimitiveArray::empty::(Nullability::NonNullable).into_array() + PrimitiveArray::empty::(norms_nullability).into_array() }); let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); @@ -246,8 +251,9 @@ pub fn turboquant_encode( .into_array()); } + let validity = ext.as_ref().validity()?; let seed = config.seed.unwrap_or(42); - let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, ctx)?; + let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, &validity, ctx)?; Ok(build_turboquant(&fsl, core, ext_dtype)?.into_array()) } diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 25c38840624..0828da3cc05 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -13,6 +13,7 @@ use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; @@ -39,15 +40,17 @@ pub fn execute_decompress( let element_ptype = extension_element_ptype(&ext_dtype)?; if num_rows == 0 { - let nn = vortex_array::dtype::Nullability::NonNullable; + let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability()); + match_each_float_ptype!(element_ptype, |T| { - let elements = PrimitiveArray::empty::(nn); + let elements = PrimitiveArray::empty::(Nullability::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), array.dimension(), - Validity::NonNullable, + fsl_validity, 0, )?; + return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); }) } 
@@ -70,8 +73,9 @@ pub fn execute_decompress( let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); let indices = codes_prim.as_slice::(); - // Read norms in their native precision. + // Read norms in their native precision. Norms carry the validity of the array. let norms_prim = array.norms().clone().execute::(ctx)?; + let output_validity = array.norms().validity()?; // MSE decode: dequantize (f32) -> inverse rotate (f32) -> scale by norm -> cast to T. // The rotation and centroid lookup always happen in f32. The final output is cast to the @@ -90,7 +94,7 @@ pub fn execute_decompress( let fsl = FixedSizeListArray::try_new( elements.into_array(), array.dimension(), - Validity::NonNullable, + output_validity, num_rows, )?; Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index e7a7db4520b..a91f9d2b3e6 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -26,12 +26,41 @@ use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::array::rotation::RotationMatrix; use crate::encodings::turboquant::turboquant_encode; +use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::l2_norm::L2Norm; use crate::vector::Vector; static SESSION: LazyLock = LazyLock::new(|| VortexSession::empty().with::()); -/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). +/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal) with the given +/// validity. 
+fn make_fsl_with_validity( + num_rows: usize, + dim: usize, + seed: u64, + validity: Validity, +) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + validity, + num_rows, + ) + .unwrap() +} + +/// Create a non-nullable FixedSizeListArray of random f32 vectors (i.i.d. standard normal). fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { let mut rng = StdRng::seed_from_u64(seed); let normal = Normal::new(0.0f32, 1.0).unwrap(); @@ -676,3 +705,173 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { ); Ok(()) } + +// ----------------------------------------------------------------------- +// Nullable vector tests +// ----------------------------------------------------------------------- + +/// Encode a nullable Vector array and verify roundtrip preserves validity and non-null values. +#[test] +fn nullable_vectors_roundtrip() -> VortexResult<()> { + // Rows 2, 5, 7 are null. + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + assert_eq!(encoded.len(), 10); + assert!(encoded.dtype().is_nullable()); + + // Check validity of the encoded array. 
+ let encoded_validity = encoded.validity()?; + for i in 0..10 { + let expected = ![2, 5, 7].contains(&i); + assert_eq!( + encoded_validity.is_valid(i)?, + expected, + "validity mismatch at row {i}" + ); + } + + // Decode and verify non-null rows have correct data. + let decoded_ext = encoded.execute::(&mut ctx)?; + assert_eq!(decoded_ext.len(), 10); + + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let decoded_prim = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_f32 = decoded_prim.as_slice::(); + + // Original f32 elements for non-null row comparison. + let orig_prim = fsl.elements().to_canonical()?.into_primitive(); + let orig_f32 = orig_prim.as_slice::(); + + // Non-null rows should have reasonable reconstruction (within MSE bounds). + for row in [0, 1, 3, 4, 6, 8, 9] { + let orig_vec = &orig_f32[row * 128..(row + 1) * 128]; + let dec_vec = &decoded_f32[row * 128..(row + 1) * 128]; + let norm_sq: f32 = orig_vec.iter().map(|&v| v * v).sum(); + let err_sq: f32 = orig_vec + .iter() + .zip(dec_vec.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + // 3-bit normalized MSE should be well under the theoretical bound. + assert!( + err_sq / norm_sq < 0.1, + "non-null row {row} has excessive reconstruction error" + ); + } + Ok(()) +} + +/// Verify that norms carry the validity: null vectors have null norms. 
+#[test] +fn nullable_norms_match_validity() -> VortexResult<()> { + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + let norms_validity = tq.norms().validity()?; + for i in 0..5 { + let expected = i % 2 == 0; // rows 0, 2, 4 are valid + assert_eq!( + norms_validity.is_valid(i)?, + expected, + "norms validity mismatch at row {i}" + ); + } + Ok(()) +} + +/// Verify that L2Norm readthrough works correctly on nullable TurboQuant arrays. +#[test] +fn nullable_l2_norm_readthrough() -> VortexResult<()> { + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Compute L2Norm on the encoded array. + let norm_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, encoded, 5)?; + let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; + + // Null rows should have null norms, valid rows should have correct norms. 
+ let orig_prim = fsl.elements().to_canonical()?.into_primitive(); + let orig_f32 = orig_prim.as_slice::(); + for row in 0..5 { + if row % 2 == 0 { + assert!(norms.is_valid(row)?, "row {row} should be valid"); + let expected: f32 = orig_f32[row * 64..(row + 1) * 64] + .iter() + .map(|&v| v * v) + .sum::() + .sqrt(); + let actual = norms.as_slice::()[row]; + assert!( + (actual - expected).abs() < 1e-5, + "norm mismatch at valid row {row}: actual={actual}, expected={expected}" + ); + } else { + assert!(!norms.is_valid(row)?, "row {row} should be null"); + } + } + Ok(()) +} + +/// Verify that slicing a nullable TurboQuant array preserves validity. +#[test] +fn nullable_slice_preserves_validity() -> VortexResult<()> { + // Rows 2, 5, 7 are null. + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Slice rows 1..6 -> [true, false, true, true, false]. 
+ let sliced = encoded.slice(1..6)?; + assert_eq!(sliced.len(), 5); + + let sliced_validity = sliced.validity()?; + let expected = [true, false, true, true, false]; + for (i, &exp) in expected.iter().enumerate() { + assert_eq!( + sliced_validity.is_valid(i)?, + exp, + "sliced validity mismatch at index {i}" + ); + } + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index 746135c0d77..124924aa2b1 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -212,21 +212,23 @@ impl VTable for TurboQuant { let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; let element_ptype = extension_element_ptype(ext)?; - let element_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); let padded_dim = dimension.next_power_of_two() as usize; - // Get the codes array (indices into the codebook). + // Get the codes array (indices into the codebook). Codes are always non-nullable; + // null vectors are represented by all-zero codes with a null norm. let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable); let codes_dtype = DType::FixedSizeList( Arc::new(codes_ptype), padded_dim as u32, - dtype.nullability(), + Nullability::NonNullable, ); let codes_array = children.get(0, &codes_dtype, len)?; - // Get the L2 norms array. - let norms_array = children.get(1, &element_dtype, len)?; + // Get the L2 norms array. Norms carry the validity of the entire TurboQuant array: + // null vectors have null norms. + let norms_dtype = DType::Primitive(element_ptype, dtype.nullability()); + let norms_array = children.get(1, &norms_dtype, len)?; // Get the centroids array (codebook). 
let num_centroids = if bit_width == 0 { @@ -280,6 +282,6 @@ impl VTable for TurboQuant { impl ValidityChild for TurboQuant { fn validity_child(array: &TurboQuantData) -> &ArrayRef { - array.codes() + array.norms() } } From 2113efe6e54f31be75ad07bc5f22b754635a05c4 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 18:02:02 -0400 Subject: [PATCH 09/13] rebase to new vtable world Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/array/data.rs | 20 +- .../encodings/turboquant/array/metadata.rs | 12 -- .../src/encodings/turboquant/array/mod.rs | 1 - .../src/encodings/turboquant/compress.rs | 9 +- .../src/encodings/turboquant/compute/slice.rs | 20 +- .../src/encodings/turboquant/compute/take.rs | 20 +- .../src/encodings/turboquant/decompress.rs | 2 +- vortex-tensor/src/encodings/turboquant/mod.rs | 1 - .../src/encodings/turboquant/tests.rs | 77 -------- .../src/encodings/turboquant/vtable.rs | 171 +++++++++--------- 10 files changed, 112 insertions(+), 221 deletions(-) delete mode 100644 vortex-tensor/src/encodings/turboquant/array/metadata.rs diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs index f369e61e35e..43d84777c07 100644 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -7,7 +7,6 @@ use vortex_array::ArrayRef; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; -use vortex_array::stats::ArrayStats; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; @@ -29,12 +28,6 @@ use crate::utils::extension_list_size; /// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. #[derive(Clone, Debug)] pub struct TurboQuantData { - /// The [`Vector`](crate::vector::Vector) extension dtype that this array encodes. 
- /// - /// The storage dtype within the extension determines the element type (f16, f32, or f64) and - /// the list size (dimension). - pub(crate) dtype: DType, - /// Child arrays stored as slots. See [`Slot`] for positions: /// /// - [`Codes`](Slot::Codes): Non-nullable `FixedSizeListArray` with @@ -66,9 +59,6 @@ pub struct TurboQuantData { /// /// This is 0 for degenerate empty arrays. pub(crate) bit_width: u8, - - /// The stats for this array. - pub(crate) stats_set: ArrayStats, } impl TurboQuantData { @@ -83,13 +73,13 @@ impl TurboQuantData { /// Returns an error if the provided components do not satisfy the invariants documented /// in [`new_unchecked`](Self::new_unchecked). pub fn try_new( - dtype: DType, + dtype: &DType, codes: ArrayRef, norms: ArrayRef, centroids: ArrayRef, rotation_signs: ArrayRef, ) -> VortexResult { - Self::validate(&dtype, &codes, &norms, ¢roids, &rotation_signs)?; + Self::validate(dtype, &codes, &norms, ¢roids, &rotation_signs)?; // SAFETY: we validate that the inputs are valid above. Ok(unsafe { Self::new_unchecked(dtype, codes, norms, centroids, rotation_signs) }) @@ -115,14 +105,14 @@ impl TurboQuantData { /// /// Violating these invariants may produce incorrect results during decompression. 
pub unsafe fn new_unchecked( - dtype: DType, + dtype: &DType, codes: ArrayRef, norms: ArrayRef, centroids: ArrayRef, rotation_signs: ArrayRef, ) -> Self { #[cfg(debug_assertions)] - Self::validate(&dtype, &codes, &norms, ¢roids, &rotation_signs) + Self::validate(dtype, &codes, &norms, ¢roids, &rotation_signs) .vortex_expect("[Debug Assertion]: Invalid TurboQuantData parameters"); let dimension = dtype @@ -147,11 +137,9 @@ impl TurboQuantData { slots[Slot::RotationSigns as usize] = Some(rotation_signs); Self { - dtype, slots, dimension, bit_width, - stats_set: Default::default(), } } diff --git a/vortex-tensor/src/encodings/turboquant/array/metadata.rs b/vortex-tensor/src/encodings/turboquant/array/metadata.rs deleted file mode 100644 index 2fead1db835..00000000000 --- a/vortex-tensor/src/encodings/turboquant/array/metadata.rs +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -/// Serialized metadata for TurboQuant encoding: a single byte holding the `bit_width` (0-8). -/// -/// All other fields (dimension, element type) are derived from the dtype and children. -/// A `bit_width` of 0 indicates a degenerate empty array. -#[derive(Clone, Debug)] -pub struct TurboQuantMetadata { - /// MSE bits per coordinate (0 for degenerate empty arrays, 1-8 otherwise). - pub bit_width: u8, -} diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs index 3a5281f767c..0c98c974203 100644 --- a/vortex-tensor/src/encodings/turboquant/array/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/array/mod.rs @@ -5,7 +5,6 @@ //! and rotation signs. 
pub(crate) mod data; -pub(crate) mod metadata; pub(crate) mod slots; pub(crate) mod centroids; diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 119485b9309..241c7089b61 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -20,11 +20,12 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_fastlanes::bitpack_compress::bitpack_encode; -use crate::encodings::turboquant::TurboQuantData; +use crate::encodings::turboquant::TurboQuant; use crate::encodings::turboquant::array::centroids::compute_boundaries; use crate::encodings::turboquant::array::centroids::find_nearest_centroid; use crate::encodings::turboquant::array::centroids::get_centroids; use crate::encodings::turboquant::array::rotation::RotationMatrix; +use crate::encodings::turboquant::vtable::TurboQuantArray; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; @@ -168,7 +169,7 @@ fn build_turboquant( fsl: &FixedSizeListArray, core: QuantizationResult, ext_dtype: DType, -) -> VortexResult { +) -> VortexResult { let num_rows = fsl.len(); let padded_dim = core.padded_dim; let codes_elements = @@ -190,7 +191,7 @@ fn build_turboquant( let rotation_signs = bitpack_rotation_signs(&core.rotation)?; - TurboQuantData::try_new( + TurboQuant::try_new_array( ext_dtype, codes, core.norms_array, @@ -241,7 +242,7 @@ pub fn turboquant_encode( let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); let empty_signs = PrimitiveArray::empty::(Nullability::NonNullable); - return Ok(TurboQuantData::try_new( + return Ok(TurboQuant::try_new_array( ext_dtype, empty_codes.into_array(), empty_norms, diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs index 8c6805f24ef..a8daef6466b 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/slice.rs +++ 
b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -10,7 +10,6 @@ use vortex_array::arrays::slice::SliceReduce; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantData; impl SliceReduce for TurboQuant { fn slice( @@ -20,14 +19,15 @@ impl SliceReduce for TurboQuant { let sliced_codes = array.codes().slice(range.clone())?; let sliced_norms = array.norms().slice(range)?; - let result = TurboQuantData::try_new( - array.dtype.clone(), - sliced_codes, - sliced_norms, - array.centroids().clone(), - array.rotation_signs().clone(), - )?; - - Ok(Some(result.into_array())) + Ok(Some( + TurboQuant::try_new_array( + array.dtype().clone(), + sliced_codes, + sliced_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + )? + .into_array(), + )) } } diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs index 7b52baf804e..7614f1577a7 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/take.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -9,7 +9,6 @@ use vortex_array::arrays::dict::TakeExecute; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantData; impl TakeExecute for TurboQuant { fn take( @@ -21,14 +20,15 @@ impl TakeExecute for TurboQuant { let taken_codes = array.codes().take(indices.clone())?; let taken_norms = array.norms().take(indices.clone())?; - let result = TurboQuantData::try_new( - array.dtype.clone(), - taken_codes, - taken_norms, - array.centroids().clone(), - array.rotation_signs().clone(), - )?; - - Ok(Some(result.into_array())) + Ok(Some( + TurboQuant::try_new_array( + array.dtype().clone(), + taken_codes, + taken_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + )? 
+ .into_array(), + )) } } diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 0828da3cc05..362207913a3 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -36,7 +36,7 @@ pub fn execute_decompress( let dim = array.dimension() as usize; let padded_dim = array.padded_dim() as usize; let num_rows = array.norms().len(); - let ext_dtype = array.dtype.as_extension().clone(); + let ext_dtype = array.dtype().as_extension().clone(); let element_ptype = extension_element_ptype(&ext_dtype)?; if num_rows == 0 { diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index ab311799381..4724a14b82b 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -100,7 +100,6 @@ pub fn initialize(session: &mut VortexSession) { mod array; pub use array::data::TurboQuantData; -pub use array::metadata::TurboQuantMetadata; pub use array::scheme::TurboQuantScheme; pub(crate) mod compute; diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index a91f9d2b3e6..0edaa4b5753 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -8,7 +8,6 @@ use rand::rngs::StdRng; use rand_distr::Distribution; use rand_distr::Normal; use rstest::rstest; -use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; @@ -451,82 +450,6 @@ fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { Ok(()) } -// ----------------------------------------------------------------------- -// Serde roundtrip -// ----------------------------------------------------------------------- - -#[test] -fn serde_roundtrip() -> VortexResult<()> { - use 
vortex_array::vtable::VTable; - - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(&ext, &config, &mut ctx)?; - let encoded = encoded.as_opt::().unwrap(); - - // Serialize metadata. - let metadata = ::metadata(encoded)?; - let serialized = - ::serialize(metadata)?.expect("metadata should serialize"); - - // Collect children. - let nchildren = ::nchildren(encoded); - assert_eq!(nchildren, 4); - let children: Vec = (0..nchildren) - .map(|i| ::child(encoded, i)) - .collect(); - - // Deserialize and rebuild. - let deserialized = ::deserialize( - &serialized, - encoded.dtype(), - encoded.len(), - &[], - &SESSION, - )?; - - // Verify metadata fields survived roundtrip. - assert_eq!(deserialized.bit_width, encoded.bit_width()); - - // Verify the rebuilt array decodes identically. - let mut ctx = SESSION.create_execution_ctx(); - let decoded_original = encoded - .array() - .clone() - .execute::(&mut ctx)?; - let original_fsl = decoded_original - .storage_array() - .to_canonical()? - .into_fixed_size_list(); - let original_elements = original_fsl.elements().to_canonical()?.into_primitive(); - - // Rebuild from children (simulating deserialization). - let rebuilt = crate::encodings::turboquant::TurboQuantData::try_new( - encoded.dtype().clone(), - children[0].clone(), - children[1].clone(), - children[2].clone(), - children[3].clone(), - )?; - let decoded_rebuilt = rebuilt.into_array().execute::(&mut ctx)?; - let rebuilt_fsl = decoded_rebuilt - .storage_array() - .to_canonical()? 
- .into_fixed_size_list(); - let rebuilt_elements = rebuilt_fsl.elements().to_canonical()?.into_primitive(); - - assert_eq!( - original_elements.as_slice::(), - rebuilt_elements.as_slice::() - ); - Ok(()) -} - // ----------------------------------------------------------------------- // Compute pushdown tests // ----------------------------------------------------------------------- diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index 124924aa2b1..841ecb84924 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -4,12 +4,14 @@ //! VTable implementation for TurboQuant encoding. use std::hash::Hash; +use std::hash::Hasher; use std::sync::Arc; use vortex_array::Array; use vortex_array::ArrayEq; use vortex_array::ArrayHash; use vortex_array::ArrayId; +use vortex_array::ArrayParts; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; @@ -21,7 +23,6 @@ use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDTypeRef; use vortex_array::serde::ArrayChildren; -use vortex_array::stats::ArrayStats; use vortex_array::vtable; use vortex_array::vtable::VTable; use vortex_array::vtable::ValidityChild; @@ -34,7 +35,6 @@ use vortex_error::vortex_panic; use vortex_session::VortexSession; use crate::encodings::turboquant::TurboQuantData; -use crate::encodings::turboquant::TurboQuantMetadata; use crate::encodings::turboquant::array::slots::Slot; use crate::encodings::turboquant::compute::rules::PARENT_KERNELS; use crate::encodings::turboquant::compute::rules::RULES; @@ -71,42 +71,58 @@ impl TurboQuant { Ok(ext) } + + /// Creates a new [`TurboQuantArray`]. + /// + /// Internallay calls [`TurboQuantData::try_new`]. 
+ pub fn try_new_array( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + ) -> VortexResult { + let data = TurboQuantData::try_new(&dtype, codes, norms, centroids, rotation_signs)?; + + let parts = ArrayParts::new(TurboQuant, dtype, data.norms().len(), data); + + Array::try_from_parts(parts) + } } vtable!(TurboQuant, TurboQuant, TurboQuantData); impl VTable for TurboQuant { type ArrayData = TurboQuantData; - type Metadata = TurboQuantMetadata; type OperationsVTable = TurboQuant; type ValidityVTable = ValidityVTableFromChild; - fn vtable(_array: &Self::ArrayData) -> &Self { - &TurboQuant - } - fn id(&self) -> ArrayId { Self::ID } - fn len(array: &TurboQuantData) -> usize { - array.norms().len() - } + fn validate(&self, data: &Self::ArrayData, dtype: &DType, _len: usize) -> VortexResult<()> { + let ext = dtype + .as_extension_opt() + .filter(|e| e.is::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; - fn dtype(array: &TurboQuantData) -> &DType { - &array.dtype - } + let dimension = extension_list_size(ext)?; + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + vortex_ensure_eq!(data.dimension(), dimension); - fn stats(array: &TurboQuantData) -> &ArrayStats { - &array.stats_set + // TODO(connor): In the future, we will not need to validate `len` on the array data because + // the child arrays will be located somewhere else. 
+ Ok(()) } - fn array_hash( - array: &TurboQuantData, - state: &mut H, - precision: Precision, - ) { - array.dtype.hash(state); + fn array_hash(array: &TurboQuantData, state: &mut H, precision: Precision) { array.dimension.hash(state); array.bit_width.hash(state); for slot in &array.slots { @@ -118,8 +134,7 @@ impl VTable for TurboQuant { } fn array_eq(array: &TurboQuantData, other: &TurboQuantData, precision: Precision) -> bool { - array.dtype == other.dtype - && array.dimension == other.dimension + array.dimension == other.dimension && array.bit_width == other.bit_width && array.slots.len() == other.slots.len() && array @@ -145,84 +160,45 @@ impl VTable for TurboQuant { None } - fn slots(array: ArrayView<'_, Self>) -> &[Option] { - &array.data().slots - } - - fn slot_name(_array: ArrayView, idx: usize) -> String { - Slot::from_index(idx).name().to_string() - } - - fn with_slots(array: &mut TurboQuantData, slots: Vec>) -> VortexResult<()> { - vortex_ensure!( - slots.len() == Slot::COUNT, - "TurboQuantArray expects {} slots, got {}", - Slot::COUNT, - slots.len() - ); - array.slots = slots; - Ok(()) - } - - fn metadata(array: ArrayView) -> VortexResult { - Ok(TurboQuantMetadata { - bit_width: array.bit_width, - }) - } - - fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(vec![metadata.bit_width])) + fn serialize(array: ArrayView<'_, Self>) -> VortexResult>> { + Ok(Some(vec![array.bit_width])) } fn deserialize( - bytes: &[u8], - _dtype: &DType, - _len: usize, + &self, + dtype: &DType, + len: usize, + metadata: &[u8], _buffers: &[BufferHandle], + children: &dyn ArrayChildren, _session: &VortexSession, - ) -> VortexResult { + ) -> VortexResult { vortex_ensure_eq!( - bytes.len(), + metadata.len(), 1, "TurboQuant metadata must be exactly 1 byte, got {}", - bytes.len() + metadata.len() ); vortex_ensure!( - bytes[0] <= 8, + metadata[0] <= 8, "bit_width is expected to be between 0 and 8, got {}", - bytes[0] + metadata[0] ); - Ok(TurboQuantMetadata { - 
bit_width: bytes[0], - }) - } - - #[allow(clippy::cast_possible_truncation)] - fn build( - dtype: &DType, - len: usize, - metadata: &Self::Metadata, - _buffers: &[BufferHandle], - children: &dyn ArrayChildren, - ) -> VortexResult { - let bit_width = metadata.bit_width; + let bit_width = metadata[0]; // Validate and derive dimension and element ptype from the Vector extension dtype. let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; let element_ptype = extension_element_ptype(ext)?; - let padded_dim = dimension.next_power_of_two() as usize; + let padded_dim = dimension.next_power_of_two(); // Get the codes array (indices into the codebook). Codes are always non-nullable; // null vectors are represented by all-zero codes with a null norm. let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable); - let codes_dtype = DType::FixedSizeList( - Arc::new(codes_ptype), - padded_dim as u32, - Nullability::NonNullable, - ); + let codes_dtype = + DType::FixedSizeList(Arc::new(codes_ptype), padded_dim, Nullability::NonNullable); let codes_array = children.get(0, &codes_dtype, len)?; // Get the L2 norms array. Norms carry the validity of the entire TurboQuant array: @@ -240,12 +216,11 @@ impl VTable for TurboQuant { let centroids = children.get(2, ¢roids_dtype, num_centroids)?; // Get the rotation array. 
- let signs_len = if len == 0 { 0 } else { 3 * padded_dim }; + let signs_len = if len == 0 { 0 } else { 3 * padded_dim as usize }; let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); let rotation_signs = children.get(3, &signs_dtype, signs_len)?; Ok(TurboQuantData { - dtype: dtype.clone(), slots: vec![ Some(codes_array), Some(norms_array), @@ -254,16 +229,30 @@ impl VTable for TurboQuant { ], dimension, bit_width, - stats_set: Default::default(), }) } - fn reduce_parent( - array: ArrayView, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - RULES.evaluate(array, parent, child_idx) + fn slots(array: ArrayView<'_, Self>) -> &[Option] { + &array.data().slots + } + + fn slot_name(_array: ArrayView, idx: usize) -> String { + Slot::from_index(idx).name().to_string() + } + + fn with_slots(array: &mut TurboQuantData, slots: Vec>) -> VortexResult<()> { + vortex_ensure!( + slots.len() == Slot::COUNT, + "TurboQuantArray expects {} slots, got {}", + Slot::COUNT, + slots.len() + ); + array.slots = slots; + Ok(()) + } + + fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) } fn execute_parent( @@ -275,8 +264,12 @@ impl VTable for TurboQuant { PARENT_KERNELS.execute(array, parent, child_idx, ctx) } - fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) + fn reduce_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + RULES.evaluate(array, parent, child_idx) } } From 2df38ce7ba9b045efd955d5d87745f8ed0ed7573 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 18:54:02 -0400 Subject: [PATCH 10/13] fix cosine similarity and dot product Signed-off-by: Connor Tsui --- .../turboquant/compute/cosine_similarity.rs | 118 +++++++++++------- vortex-tensor/src/encodings/turboquant/mod.rs | 8 -- .../src/encodings/turboquant/tests.rs | 111 ++++++++++++++++ 
.../src/encodings/turboquant/vtable.rs | 20 ++- 4 files changed, 198 insertions(+), 59 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs index 98935e5fb4e..8cc151f397c 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -27,51 +27,47 @@ //! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001. //! //! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is -//! usually sufficient — the relative ordering of cosine similarities is preserved +//! usually sufficient -- the relative ordering of cosine similarities is preserved //! even if the absolute values have bounded error. +use num_traits::FromPrimitive; +use num_traits::Zero; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; -use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use crate::encodings::turboquant::TurboQuant; +use crate::utils::extension_element_ptype; -/// Shared helper: read codes, norms, and centroids from two TurboQuant arrays, -/// then compute per-row quantized unit-norm dot products. +/// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails. /// -/// Both arrays must have the same dimension (vector length) and row count. -/// They may have different codebooks (e.g., different bit widths), in which -/// case each array's own centroids are used for its code lookups. 
+/// This helper exists because `half::f16` has an inherent `from_f32` method that shadows +/// the [`FromPrimitive`] trait method, causing compilation errors when used inside +/// [`match_each_float_ptype!`]. +#[inline] +fn f32_to_t(v: f32) -> T { + FromPrimitive::from_f32(v).unwrap_or_else(T::zero) +} + +/// Compute the per-row unit-norm dot products in f32 (centroids are always f32). /// -/// Returns `(norms_a, norms_b, unit_dots)` where `unit_dots[i]` is the dot product -/// of the unit-norm quantized vectors for row i. -fn quantized_unit_dots( - lhs: ArrayView, - rhs: ArrayView, +/// Returns a `Vec` of length `num_rows`. +fn compute_unit_dots( + lhs: &ArrayView, + rhs: &ArrayView, ctx: &mut ExecutionCtx, -) -> VortexResult<(Vec, Vec, Vec)> { - vortex_ensure!( - lhs.dimension() == rhs.dimension(), - "TurboQuant quantized dot product requires matching dimensions, got {} and {}", - lhs.dimension(), - rhs.dimension() - ); - +) -> VortexResult> { let pd = lhs.padded_dim() as usize; let num_rows = lhs.norms().len(); - let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?; - let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?; - let na = lhs_norms.as_slice::(); - let nb = rhs_norms.as_slice::(); - let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?; let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?; let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive(); @@ -79,8 +75,8 @@ fn quantized_unit_dots( let ca = lhs_codes.as_slice::(); let cb = rhs_codes.as_slice::(); - // Read centroids from both arrays — they may have different codebooks - // (e.g., different bit widths). + // Read centroids from both arrays. They may have different codebooks (e.g., different bit + // widths). 
let lhs_centroids: PrimitiveArray = lhs.centroids().clone().execute(ctx)?; let rhs_centroids: PrimitiveArray = rhs.centroids().clone().execute(ctx)?; let cl = lhs_centroids.as_slice::(); @@ -98,49 +94,75 @@ fn quantized_unit_dots( dots.push(dot); } - Ok((na.to_vec(), nb.to_vec(), dots)) + Ok(dots) } /// Compute approximate cosine similarity for all rows between two TurboQuant /// arrays (same rotation matrix and codebook) without full decompression. +/// +/// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized +/// codes directly approximates cosine similarity without needing the stored norms. +/// +/// The output dtype matches the Vector's element type (f16, f32, or f64). pub fn cosine_similarity_quantized_column( lhs: ArrayView, rhs: ArrayView, ctx: &mut ExecutionCtx, ) -> VortexResult { - let num_rows = lhs.norms().len(); - let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; + vortex_ensure_eq!( + lhs.dimension(), + rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions", + ); - let mut result = BufferMut::::with_capacity(num_rows); - for row in 0..num_rows { - if na[row] == 0.0 || nb[row] == 0.0 { - result.push(0.0); - } else { - // Unit-norm dot product IS the cosine similarity. - result.push(dots[row]); - } - } + let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let dots = compute_unit_dots(&lhs, &rhs, ctx)?; - Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + // The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type. 
+ match_each_float_ptype!(element_ptype, |T| { + let mut result = BufferMut::::with_capacity(dots.len()); + for &dot in &dots { + result.push(f32_to_t(dot)); + } + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + }) } /// Compute approximate dot product for all rows between two TurboQuant /// arrays (same rotation matrix and codebook) without full decompression. /// -/// `dot_product(a, b) ≈ ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` +/// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` +/// +/// The output dtype matches the Vector's element type (f16, f32, or f64). pub fn dot_product_quantized_column( lhs: ArrayView, rhs: ArrayView, ctx: &mut ExecutionCtx, ) -> VortexResult { + vortex_ensure_eq!( + lhs.dimension(), + rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions", + ); + + let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let dots = compute_unit_dots(&lhs, &rhs, ctx)?; let num_rows = lhs.norms().len(); - let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; - let mut result = BufferMut::::with_capacity(num_rows); - for row in 0..num_rows { - // Scale the unit-norm dot product by both norms to get the actual dot product. - result.push(na[row] * nb[row] * dots[row]); - } + let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?; + let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?; + + // Scale the f32 unit-norm dot product by native-precision norms. 
+ match_each_float_ptype!(element_ptype, |T| { + let na = lhs_norms.as_slice::(); + let nb = rhs_norms.as_slice::(); + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + let dot_t: T = f32_to_t(dots[row]); + result.push(na[row] * nb[row] * dot_t); + } - Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + }) } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 4724a14b82b..0bd6efff68b 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -90,14 +90,6 @@ //! assert!(encoded.nbytes() < 51200); //! ``` -use vortex_array::session::ArraySessionExt; -use vortex_session::VortexSession; - -/// Initialize the TurboQuant encoding in the given session. -pub fn initialize(session: &mut VortexSession) { - session.arrays().register(TurboQuant); -} - mod array; pub use array::data::TurboQuantData; pub use array::scheme::TurboQuantScheme; diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index 0edaa4b5753..731789bea2c 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -798,3 +798,114 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> { } Ok(()) } + +// ----------------------------------------------------------------------- +// Serde roundtrip tests +// ----------------------------------------------------------------------- + +/// Verify that a TurboQuant array survives serialize/deserialize. 
+#[test] +fn serde_roundtrip() -> VortexResult<()> { + use vortex_array::ArrayContext; + use vortex_array::ArrayEq; + use vortex_array::Precision; + use vortex_array::serde::SerializeOptions; + use vortex_array::serde::SerializedArray; + use vortex_array::session::ArraySessionExt; + use vortex_buffer::ByteBufferMut; + use vortex_fastlanes::BitPacked; + use vortex_session::registry::ReadContext; + + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + let dtype = encoded.dtype().clone(); + let len = encoded.len(); + + // Serialize. + let array_ctx = ArrayContext::empty(); + let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?; + + let mut concat = ByteBufferMut::empty(); + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + + // Deserialize. The session needs TurboQuant and BitPacked (for rotation signs) registered. + let serde_session = VortexSession::empty().with::(); + serde_session.arrays().register(TurboQuant); + serde_session.arrays().register(BitPacked); + + let parts = SerializedArray::try_from(concat.freeze())?; + let decoded = parts.decode( + &dtype, + len, + &ReadContext::new(array_ctx.to_ids()), + &serde_session, + )?; + + assert!( + decoded.array_eq(&encoded, Precision::Value), + "serde roundtrip did not preserve array equality" + ); + Ok(()) +} + +/// Verify that a degenerate (empty) TurboQuant array survives serialize/deserialize. 
+#[test] +fn serde_roundtrip_empty() -> VortexResult<()> { + use vortex_array::ArrayContext; + use vortex_array::ArrayEq; + use vortex_array::Precision; + use vortex_array::serde::SerializeOptions; + use vortex_array::serde::SerializedArray; + use vortex_array::session::ArraySessionExt; + use vortex_buffer::ByteBufferMut; + use vortex_fastlanes::BitPacked; + use vortex_session::registry::ReadContext; + + let fsl = make_fsl(0, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + assert_eq!(encoded.len(), 0); + + let dtype = encoded.dtype().clone(); + let len = encoded.len(); + + let array_ctx = ArrayContext::empty(); + let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?; + + let mut concat = ByteBufferMut::empty(); + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + + let serde_session = VortexSession::empty().with::(); + serde_session.arrays().register(TurboQuant); + serde_session.arrays().register(BitPacked); + + let parts = SerializedArray::try_from(concat.freeze())?; + let decoded = parts.decode( + &dtype, + len, + &ReadContext::new(array_ctx.to_ids()), + &serde_session, + )?; + + assert!( + decoded.array_eq(&encoded, Precision::Value), + "serde roundtrip did not preserve array equality" + ); + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs index 841ecb84924..d6f5f998041 100644 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -74,7 +74,7 @@ impl TurboQuant { /// Creates a new [`TurboQuantArray`]. /// - /// Internallay calls [`TurboQuantData::try_new`]. + /// Internally calls [`TurboQuantData::try_new`]. 
pub fn try_new_array( dtype: DType, codes: ArrayRef, @@ -101,7 +101,7 @@ impl VTable for TurboQuant { Self::ID } - fn validate(&self, data: &Self::ArrayData, dtype: &DType, _len: usize) -> VortexResult<()> { + fn validate(&self, data: &Self::ArrayData, dtype: &DType, len: usize) -> VortexResult<()> { let ext = dtype .as_extension_opt() .filter(|e| e.is::()) @@ -117,8 +117,15 @@ impl VTable for TurboQuant { vortex_ensure_eq!(data.dimension(), dimension); - // TODO(connor): In the future, we will not need to validate `len` on the array data because + // TODO(connor): In the future, we may not need to validate `len` on the array data because // the child arrays will be located somewhere else. + // bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with + // bit_width == 0 would have zero centroids while codes reference centroid indices. + vortex_ensure!( + data.bit_width > 0 || len == 0, + "bit_width == 0 is only valid for empty arrays, got len={len}" + ); + Ok(()) } @@ -187,6 +194,13 @@ impl VTable for TurboQuant { let bit_width = metadata[0]; + // bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with + // bit_width == 0 would have zero centroids while codes reference centroid indices. + vortex_ensure!( + bit_width > 0 || len == 0, + "bit_width == 0 is only valid for empty arrays, got len={len}" + ); + // Validate and derive dimension and element ptype from the Vector extension dtype. 
let ext = TurboQuant::validate_dtype(dtype)?; let dimension = extension_list_size(ext)?; From 54b59b9887a8eac618ec73802ea13f4ff04f17d9 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 3 Apr 2026 19:01:10 -0400 Subject: [PATCH 11/13] fix validity handling Signed-off-by: Connor Tsui --- .../turboquant/compute/cosine_similarity.rs | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs index 8cc151f397c..c6c6531119b 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -39,7 +39,6 @@ use vortex_array::IntoArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::match_each_float_ptype; -use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; @@ -54,6 +53,7 @@ use crate::utils::extension_element_ptype; /// [`match_each_float_ptype!`]. #[inline] fn f32_to_t(v: f32) -> T { + // TODO(connor): Is this actually correct? How should we handle f64 overflow? 
FromPrimitive::from_f32(v).unwrap_or_else(T::zero) } @@ -70,8 +70,8 @@ fn compute_unit_dots( let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?; let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?; - let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive(); - let rhs_codes = rhs_codes_fsl.elements().to_canonical()?.into_primitive(); + let lhs_codes: PrimitiveArray = lhs_codes_fsl.elements().clone().execute(ctx)?; + let rhs_codes: PrimitiveArray = rhs_codes_fsl.elements().clone().execute(ctx)?; let ca = lhs_codes.as_slice::(); let cb = rhs_codes.as_slice::(); @@ -116,15 +116,19 @@ pub fn cosine_similarity_quantized_column( ); let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?; let dots = compute_unit_dots(&lhs, &rhs, ctx)?; // The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type. match_each_float_ptype!(element_ptype, |T| { let mut result = BufferMut::::with_capacity(dots.len()); for &dot in &dots { - result.push(f32_to_t(dot)); + // SAFETY: We allocated the correct amount. + unsafe { result.push_unchecked(f32_to_t(dot)) }; } - Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + + // SAFETY: `result` has the same length as the input arrays, matching `validity`. 
+ Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array()) }) } @@ -146,6 +150,7 @@ pub fn dot_product_quantized_column( ); let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?; let dots = compute_unit_dots(&lhs, &rhs, ctx)?; let num_rows = lhs.norms().len(); @@ -160,9 +165,11 @@ pub fn dot_product_quantized_column( let mut result = BufferMut::::with_capacity(num_rows); for row in 0..num_rows { let dot_t: T = f32_to_t(dots[row]); - result.push(na[row] * nb[row] * dot_t); + // SAFETY: We allocated the correct amount. + unsafe { result.push_unchecked(na[row] * nb[row] * dot_t) }; } - Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) + // SAFETY: `result` has the same length as the input arrays, matching `validity`. + Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array()) }) } From d4c1cdf65aeaa3bbf39004ca06b1e7bf5ebec25e Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sat, 4 Apr 2026 12:28:23 -0400 Subject: [PATCH 12/13] fix casting issues and other minor things Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 48 ++++--------------- .../src/encodings/turboquant/array/mod.rs | 12 +++++ .../src/encodings/turboquant/compress.rs | 10 +++- .../turboquant/compute/cosine_similarity.rs | 18 ++----- .../src/encodings/turboquant/decompress.rs | 8 ++-- vortex-tensor/src/encodings/turboquant/mod.rs | 1 + 6 files changed, 37 insertions(+), 60 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 5ab6d5c7e0d..8f3c0f9fbdb 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -10,6 +10,8 @@ impl vortex_tensor::encodings::turboquant::TurboQuant pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId +pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: 
vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef> impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant @@ -24,8 +26,6 @@ impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquan pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData -pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::TurboQuantMetadata - pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromChild @@ -38,11 +38,7 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_a pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer_name(_array: vortex_array::array::view::ArrayView<'_, Self>, _idx: usize) -> core::option::Option -pub fn vortex_tensor::encodings::turboquant::TurboQuant::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult - -pub fn 
vortex_tensor::encodings::turboquant::TurboQuant::dtype(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::dtype::DType +pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute(array: vortex_array::array::typed::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult @@ -50,23 +46,17 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute_parent(array: v pub fn vortex_tensor::encodings::turboquant::TurboQuant::id(&self) -> vortex_array::array::ArrayId -pub fn vortex_tensor::encodings::turboquant::TurboQuant::len(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> usize - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::metadata(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult - pub fn vortex_tensor::encodings::turboquant::TurboQuant::nbuffers(_array: vortex_array::array::view::ArrayView<'_, Self>) -> usize pub fn vortex_tensor::encodings::turboquant::TurboQuant::reduce_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> -pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> +pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult>> pub fn vortex_tensor::encodings::turboquant::TurboQuant::slot_name(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> alloc::string::String pub fn 
vortex_tensor::encodings::turboquant::TurboQuant::slots(array: vortex_array::array::view::ArrayView<'_, Self>) -> &[core::option::Option] -pub fn vortex_tensor::encodings::turboquant::TurboQuant::stats(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::stats::array::ArrayStats - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::vtable(_array: &Self::ArrayData) -> &Self +pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate(&self, data: &Self::ArrayData, dtype: &vortex_array::dtype::DType, len: usize) -> vortex_error::VortexResult<()> pub fn vortex_tensor::encodings::turboquant::TurboQuant::with_slots(array: &mut vortex_tensor::encodings::turboquant::TurboQuantData, slots: alloc::vec::Vec>) -> vortex_error::VortexResult<()> @@ -116,7 +106,7 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::codes(&self) -> &vo pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32 -pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self +pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: &vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self pub fn vortex_tensor::encodings::turboquant::TurboQuantData::norms(&self) -> &vortex_array::array::erased::ArrayRef @@ -124,7 +114,7 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) - pub fn vortex_tensor::encodings::turboquant::TurboQuantData::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef -pub fn 
vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: &vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()> @@ -132,30 +122,10 @@ impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData pub fn vortex_tensor::encodings::turboquant::TurboQuantData::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantData -impl core::convert::From for vortex_array::array::erased::ArrayRef - -pub fn vortex_array::array::erased::ArrayRef::from(value: vortex_tensor::encodings::turboquant::TurboQuantData) -> vortex_array::array::erased::ArrayRef - impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantData pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) -> vortex_array::array::erased::ArrayRef - -pub struct vortex_tensor::encodings::turboquant::TurboQuantMetadata - -pub 
vortex_tensor::encodings::turboquant::TurboQuantMetadata::bit_width: u8 - -impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantMetadata - -pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantMetadata - -impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantMetadata - -pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantScheme @@ -186,10 +156,10 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, ca pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str -pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession) - pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: &vortex_array::arrays::extension::vtable::ExtensionArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::array::typed::Array + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs index 0c98c974203..62bde49cfd3 100644 --- a/vortex-tensor/src/encodings/turboquant/array/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/array/mod.rs @@ -11,3 +11,15 @@ pub(crate) mod centroids; pub(crate) mod rotation; pub(crate) mod scheme; + +use num_traits::Float; +use num_traits::FromPrimitive; +use vortex_error::VortexExpect; + +/// Convert an f32 value to a float type `T`. 
+/// +/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the +/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening. +pub(crate) fn float_from_f32(v: f32) -> T { + FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible") +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 241c7089b61..b2f682a79c3 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -3,6 +3,7 @@ //! TurboQuant encoding (quantization) logic. +use num_traits::ToPrimitive; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -15,6 +16,7 @@ use vortex_array::dtype::PType; use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; @@ -105,14 +107,18 @@ fn turboquant_quantize_core( // the input, so null vectors get null norms automatically. let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; - let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive(); + let norms_prim: PrimitiveArray = norms_array.clone().execute(ctx)?; // Extract f32 norms for the internal quantization loop. let f32_norms: Vec = match_each_float_ptype!(norms_prim.ptype(), |T| { norms_prim .as_slice::() .iter() - .map(|&v| num_traits::ToPrimitive::to_f32(&v).unwrap_or(0.0)) + .map(|&v| { + // `ToPrimitive::to_f32` is infallible for all float types: f16 -> f32 is lossless, + // f32 is identity, and f64 -> f32 saturates to +-inf. 
+ ToPrimitive::to_f32(&v).vortex_expect("float-to-f32 conversion is infallible") + }) .collect() }); diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs index c6c6531119b..a5bbc63bde2 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -30,8 +30,6 @@ //! usually sufficient -- the relative ordering of cosine similarities is preserved //! even if the absolute values have bounded error. -use num_traits::FromPrimitive; -use num_traits::Zero; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; @@ -44,19 +42,9 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::array::float_from_f32; use crate::utils::extension_element_ptype; -/// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails. -/// -/// This helper exists because `half::f16` has an inherent `from_f32` method that shadows -/// the [`FromPrimitive`] trait method, causing compilation errors when used inside -/// [`match_each_float_ptype!`]. -#[inline] -fn f32_to_t(v: f32) -> T { - // TODO(connor): Is this actually correct? How should we handle f64 overflow? - FromPrimitive::from_f32(v).unwrap_or_else(T::zero) -} - /// Compute the per-row unit-norm dot products in f32 (centroids are always f32). /// /// Returns a `Vec` of length `num_rows`. @@ -124,7 +112,7 @@ pub fn cosine_similarity_quantized_column( let mut result = BufferMut::::with_capacity(dots.len()); for &dot in &dots { // SAFETY: We allocated the correct amount. - unsafe { result.push_unchecked(f32_to_t(dot)) }; + unsafe { result.push_unchecked(float_from_f32(dot)) }; } // SAFETY: `result` has the same length as the input arrays, matching `validity`. 
@@ -164,7 +152,7 @@ pub fn dot_product_quantized_column( let mut result = BufferMut::::with_capacity(num_rows); for row in 0..num_rows { - let dot_t: T = f32_to_t(dots[row]); + let dot_t: T = float_from_f32(dots[row]); // SAFETY: We allocated the correct amount. unsafe { result.push_unchecked(na[row] * nb[row] * dot_t) }; } diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs index 362207913a3..b4b9fbaccfd 100644 --- a/vortex-tensor/src/encodings/turboquant/decompress.rs +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -3,8 +3,8 @@ //! TurboQuant decoding (dequantization) logic. +use num_traits::Float; use num_traits::FromPrimitive; -use num_traits::Zero; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -20,6 +20,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::array::float_from_f32; use crate::encodings::turboquant::array::rotation::RotationMatrix; use crate::utils::extension_element_ptype; @@ -103,7 +104,7 @@ pub fn execute_decompress( } /// Typed decompress: reads norms as `T`, dequantizes in f32, and produces output as `T`. -fn decompress_typed( +fn decompress_typed( norms_prim: &PrimitiveArray, centroids: &[f32], rotation: &RotationMatrix, @@ -129,8 +130,7 @@ fn decompress_typed( rotation.inverse_rotate(&dequantized, &mut unrotated); for idx in 0..dim { - // Convert f32 dequantized value to T, then scale by the native-precision norm. 
- let val = T::from_f32(unrotated[idx]).unwrap_or_else(T::zero) * norm; + let val = float_from_f32::(unrotated[idx]) * norm; output.push(val); } } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 0bd6efff68b..59554f398d2 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -98,6 +98,7 @@ pub(crate) mod compute; mod vtable; pub use vtable::TurboQuant; +pub use vtable::TurboQuantArray; mod compress; pub use compress::TurboQuantConfig; From 5285a1331d3b47605d938a7886a42a54c72c1917 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sat, 4 Apr 2026 12:52:33 -0400 Subject: [PATCH 13/13] change defaults and constraints and tests Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 2 + .../encodings/turboquant/array/centroids.rs | 11 +- .../src/encodings/turboquant/array/data.rs | 2 +- .../src/encodings/turboquant/array/scheme.rs | 12 +- .../src/encodings/turboquant/compress.rs | 7 +- .../turboquant/compute/cosine_similarity.rs | 22 ++- .../src/encodings/turboquant/tests.rs | 168 ++++++++++++++++-- .../src/encodings/turboquant/vtable.rs | 15 +- 8 files changed, 201 insertions(+), 38 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 8f3c0f9fbdb..349d65144f9 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -10,6 +10,8 @@ impl vortex_tensor::encodings::turboquant::TurboQuant pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId +pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32 + pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> 
vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef> diff --git a/vortex-tensor/src/encodings/turboquant/array/centroids.rs b/vortex-tensor/src/encodings/turboquant/array/centroids.rs index 85ea39fcc9e..a58945cf9fe 100644 --- a/vortex-tensor/src/encodings/turboquant/array/centroids.rs +++ b/vortex-tensor/src/encodings/turboquant/array/centroids.rs @@ -15,6 +15,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_utils::aliases::dash_map::DashMap; +use crate::encodings::turboquant::TurboQuant; + /// Number of numerical integration points for computing conditional expectations. const INTEGRATION_POINTS: usize = 1000; @@ -36,8 +38,11 @@ pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { if !(1..=8).contains(&bit_width) { vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}"); } - if dimension < 3 { - vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}"); + if dimension < TurboQuant::MIN_DIMENSION { + vortex_bail!( + "TurboQuant dimension must be >= {}, got {dimension}", + TurboQuant::MIN_DIMENSION + ); } if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { @@ -306,6 +311,6 @@ mod tests { assert!(get_centroids(128, 0).is_err()); assert!(get_centroids(128, 9).is_err()); assert!(get_centroids(1, 2).is_err()); - assert!(get_centroids(2, 2).is_err()); + assert!(get_centroids(127, 2).is_err()); } } diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs index 43d84777c07..289271d15f1 100644 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -92,7 +92,7 @@ impl TurboQuantData { /// The caller must ensure: /// /// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage 
list size - /// is >= 3. + /// is >= [`MIN_DIMENSION`](crate::encodings::turboquant::TurboQuant::MIN_DIMENSION). /// - `codes` is a non-nullable `FixedSizeListArray` with `list_size == padded_dim` and /// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes. /// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage diff --git a/vortex-tensor/src/encodings/turboquant/array/scheme.rs b/vortex-tensor/src/encodings/turboquant/array/scheme.rs index 2b954522676..de7a6a85302 100644 --- a/vortex-tensor/src/encodings/turboquant/array/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/array/scheme.rs @@ -114,12 +114,12 @@ mod tests { /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. /// f32 input at 1024-d (no padding) should give higher ratio since no waste. #[rstest] - #[case::f32_768d(32, 768, 1000, 3.5, 8.0)] - #[case::f32_1024d(32, 1024, 1000, 5.0, 9.0)] - #[case::f32_1536d(32, 1536, 1000, 3.0, 8.0)] - #[case::f32_128d(32, 128, 1000, 4.0, 8.0)] - #[case::f64_768d(64, 768, 1000, 7.0, 16.0)] - #[case::f16_768d(16, 768, 1000, 1.5, 4.5)] + #[case::f32_768d(32, 768, 1000, 2.5, 4.0)] + #[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)] + #[case::f32_1536d(32, 1536, 1000, 2.5, 4.0)] + #[case::f32_128d(32, 128, 1000, 3.0, 5.0)] + #[case::f64_768d(64, 768, 1000, 5.0, 7.0)] + #[case::f16_768d(16, 768, 1000, 1.2, 2.0)] fn compression_ratio_in_expected_range( #[case] bits_per_element: usize, #[case] dim: u32, diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index b2f682a79c3..34f2f2de8d9 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -43,7 +43,7 @@ pub struct TurboQuantConfig { impl Default for TurboQuantConfig { fn default() -> Self { Self { - bit_width: 4, + bit_width: 8, seed: Some(42), } } @@ -226,8 +226,9 @@ pub fn turboquant_encode( ); let 
dimension = fsl.list_size(); vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" + dimension >= TurboQuant::MIN_DIMENSION, + "TurboQuant requires dimension >= {}, got {dimension}", + TurboQuant::MIN_DIMENSION ); if fsl.is_empty() { diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs index a5bbc63bde2..e9bcd17ce96 100644 --- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -4,9 +4,9 @@ //! Approximate cosine similarity in the quantized domain. //! //! Since the SRHT is orthogonal, inner products are preserved in the rotated -//! domain. For two vectors from the same TurboQuant column (same rotation and -//! centroids), we can compute the dot product of their quantized representations -//! without full decompression: +//! domain. For two TurboQuant arrays that share the same SRHT rotation (i.e., +//! encoded from the same column), we can compute the dot product of their +//! quantized representations without full decompression: //! //! ```text //! cos_approx(a, b) = sum(centroids[code_a[j]] × centroids[code_b[j]]) @@ -85,8 +85,12 @@ fn compute_unit_dots( Ok(dots) } -/// Compute approximate cosine similarity for all rows between two TurboQuant -/// arrays (same rotation matrix and codebook) without full decompression. +/// Compute approximate cosine similarity for all rows between two TurboQuant arrays without +/// full decompression. +/// +/// Both arrays must share the same rotation (i.e., were encoded from the same TurboQuant +/// column). For this function, results are meaningless if the rotations differ (there are other +/// methods that can allow this, but that is future work). 
/// /// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized /// codes directly approximates cosine similarity without needing the stored norms. @@ -120,8 +124,12 @@ pub fn cosine_similarity_quantized_column( }) } -/// Compute approximate dot product for all rows between two TurboQuant -/// arrays (same rotation matrix and codebook) without full decompression. +/// Compute approximate dot product for all rows between two TurboQuant arrays without +/// full decompression. +/// +/// Both arrays must share the same SRHT rotation (i.e., were encoded from the same TurboQuant +/// column). For this function, results are meaningless if the rotations differ (there are other +/// methods that can allow this, but that is future work). /// /// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` /// diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs index 731789bea2c..ef77bb9e3ec 100644 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -152,11 +152,9 @@ fn encode_decode( // ----------------------------------------------------------------------- #[rstest] -#[case(32, 1)] -#[case(32, 2)] -#[case(32, 3)] -#[case(32, 4)] +#[case(128, 1)] #[case(128, 2)] +#[case(128, 3)] #[case(128, 4)] #[case(128, 6)] #[case(128, 8)] @@ -280,8 +278,9 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { #[rstest] #[case(1)] -#[case(2)] -fn rejects_dimension_below_3(#[case] dim: usize) { +#[case(64)] +#[case(127)] +fn rejects_dimension_below_128(#[case] dim: usize) { let fsl = make_fsl_small(dim); let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { @@ -340,7 +339,7 @@ fn all_zero_vectors_roundtrip() -> VortexResult<()> { #[test] fn f64_input_encodes_successfully() -> VortexResult<()> { let num_rows = 10; - let dim = 64; + let dim = 128; let mut rng = StdRng::seed_from_u64(99); let normal = 
Normal::new(0.0f64, 1.0).unwrap();
@@ -371,6 +370,48 @@ fn f64_input_encodes_successfully() -> VortexResult<()> {
     Ok(())
 }
 
+/// Verify that f16 input is accepted and encoded (upcast to f32 internally).
+#[test]
+fn f16_input_encodes_successfully() -> VortexResult<()> {
+    let num_rows = 10;
+    let dim = 128;
+    let mut rng = StdRng::seed_from_u64(99);
+    let normal = Normal::new(0.0f32, 1.0).unwrap();
+
+    let mut buf = BufferMut::<half::f16>::with_capacity(num_rows * dim);
+    for _ in 0..(num_rows * dim) {
+        buf.push(half::f16::from_f32(normal.sample(&mut rng)));
+    }
+    let elements = PrimitiveArray::new::<half::f16>(buf.freeze(), Validity::NonNullable);
+    let fsl = FixedSizeListArray::try_new(
+        elements.into_array(),
+        dim.try_into()
+            .expect("somehow got dimension greater than u32::MAX"),
+        Validity::NonNullable,
+        num_rows,
+    )?;
+
+    let ext = make_vector_ext(&fsl);
+    let config = TurboQuantConfig {
+        bit_width: 3,
+        seed: Some(42),
+    };
+    let mut ctx = SESSION.create_execution_ctx();
+    let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
+    let tq = encoded.as_opt::<TurboQuant>().unwrap();
+    assert_eq!(tq.norms().len(), num_rows);
+    assert_eq!(tq.dimension() as usize, dim);
+
+    // Verify roundtrip: decode and check reconstruction is reasonable.
+    let decoded_ext = encoded.execute::<ExtensionArray>(&mut ctx)?;
+    let decoded_fsl = decoded_ext
+        .storage_array()
+        .to_canonical()?
+        .into_fixed_size_list();
+    assert_eq!(decoded_fsl.len(), num_rows);
+    Ok(())
+}
+
 // -----------------------------------------------------------------------
 // Verification tests for stored metadata
 // -----------------------------------------------------------------------
@@ -494,7 +535,7 @@ fn slice_preserves_data() -> VortexResult<()> {
 
 #[test]
 fn scalar_at_matches_decompress() -> VortexResult<()> {
-    let fsl = make_fsl(10, 64, 42);
+    let fsl = make_fsl(10, 128, 42);
     let ext = make_vector_ext(&fsl);
     let config = TurboQuantConfig {
         bit_width: 3,
@@ -593,7 +634,9 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
             .sum::<f32>()
     };
 
-    // 4-bit quantization: expect reasonable accuracy.
+    // At 4-bit, the theoretical MSE bound per coordinate is ~0.0106 (Theorem 1). For cosine
+    // similarity (bounded [-1, 1]), the error is bounded roughly by 2*sqrt(MSE) ~ 0.2. We use
+    // 0.15 as a tighter empirical bound.
     let error = (exact_cos - approx_cos).abs();
     assert!(
         error < 0.15,
@@ -604,6 +647,105 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
     Ok(())
 }
 
+/// Verify approximate dot product in the quantized domain.
+///
+/// NOTE: The MSE quantizer (TurboQuant_mse) has inherent **multiplicative bias** for inner
+/// products — the quantized dot product systematically over- or under-estimates the true value.
+/// This is a fundamental property: the paper's `TurboQuant_prod` variant adds QJL specifically
+/// to debias inner products, but we only implement the MSE-only variant.
+///
+/// Even at 8-bit (near-lossless reconstruction, MSE ~4e-5), the quantized-domain dot product
+/// can have ~10-15% relative error due to this bias. This tolerance is therefore intentionally
+/// loose — we're testing that the approximation is in the right ballpark, not that it's precise.
+///
+/// TODO(connor): Revisit these tolerances when we have TurboQuant_prod (QJL debiasing).
+#[test]
+fn dot_product_quantized_accuracy() -> VortexResult<()> {
+    let fsl = make_fsl(20, 128, 42);
+    let ext = make_vector_ext(&fsl);
+    let config = TurboQuantConfig {
+        bit_width: 8,
+        seed: Some(123),
+    };
+    let mut ctx = SESSION.create_execution_ctx();
+    let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
+    let tq = encoded.as_opt::<TurboQuant>().unwrap();
+
+    let input_prim = fsl.elements().to_canonical()?.into_primitive();
+    let input_f32 = input_prim.as_slice::<f32>();
+
+    let mut ctx = SESSION.create_execution_ctx();
+    let pd = tq.padded_dim() as usize;
+    let norms_prim = tq.norms().clone().execute::<PrimitiveArray>(&mut ctx)?;
+    let norms = norms_prim.as_slice::<f32>();
+    let codes_fsl = tq.codes().clone().execute::<FixedSizeListArray>(&mut ctx)?;
+    let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive();
+    let all_codes = codes_prim.as_slice::<u8>();
+    let centroids_prim = tq.centroids().clone().execute::<PrimitiveArray>(&mut ctx)?;
+    let centroid_vals = centroids_prim.as_slice::<f32>();
+
+    for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] {
+        let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128];
+        let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128];
+
+        let exact_dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum();
+
+        let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd];
+        let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd];
+        let unit_dot: f32 = codes_a
+            .iter()
+            .zip(codes_b.iter())
+            .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize])
+            .sum();
+        let approx_dot = norms[row_a] * norms[row_b] * unit_dot;
+
+        // See doc comment above: 15% relative error is expected due to MSE quantizer bias.
+        let scale = exact_dot.abs().max(1.0);
+        let rel_error = (exact_dot - approx_dot).abs() / scale;
+        assert!(
+            rel_error < 0.15,
+            "dot product error too large for ({row_a}, {row_b}): \
+             exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}"
+        );
+    }
+    Ok(())
+}
+
+/// Roundtrip at large embedding dimensions to validate padding and SRHT at common sizes.
+/// +/// NOTE: The theoretical MSE bound (Theorem 1) is proved for Haar-distributed random orthogonal +/// matrices, not SRHT. The SRHT is a practical O(d log d) approximation that doesn't exactly +/// satisfy the Haar assumption, so empirical MSE can slightly exceed the theoretical bound. We +/// use a 2x multiplier to account for this gap. +/// +/// The 1024-d case uses 5-bit instead of 4-bit because at 4-bit the SRHT approximation error +/// at d=1024 pushes MSE ~20% above the 1x theoretical bound (0.0127 vs bound 0.0106). +/// +/// TODO(connor): Revisit after Stage 2 block decomposition — at d=768 with block_size=256, +/// the per-block SRHT will be lower-dimensional and may have different error characteristics. +#[rstest] +#[case(768, 4)] +#[case(1024, 5)] +fn large_dimension_roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 10; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + // 2x slack for the SRHT-vs-Haar gap (see doc comment above). + let bound = 2.0 * theoretical_mse_bound(bit_width); + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) +} + /// Verify that the encoded array's dtype is a Vector extension type. 
#[test]
 fn encoded_dtype_is_vector_extension() -> VortexResult<()> {
@@ -702,7 +844,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> {
 #[test]
 fn nullable_norms_match_validity() -> VortexResult<()> {
     let validity = Validity::from_iter([true, false, true, false, true]);
-    let fsl = make_fsl_with_validity(5, 64, 42, validity);
+    let fsl = make_fsl_with_validity(5, 128, 42, validity);
     let ext = make_vector_ext(&fsl);
 
     let config = TurboQuantConfig {
@@ -729,7 +871,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> {
 #[test]
 fn nullable_l2_norm_readthrough() -> VortexResult<()> {
     let validity = Validity::from_iter([true, false, true, false, true]);
-    let fsl = make_fsl_with_validity(5, 64, 42, validity);
+    let fsl = make_fsl_with_validity(5, 128, 42, validity);
     let ext = make_vector_ext(&fsl);
 
     let config = TurboQuantConfig {
@@ -749,7 +891,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> {
     for row in 0..5 {
         if row % 2 == 0 {
             assert!(norms.is_valid(row)?, "row {row} should be valid");
-            let expected: f32 = orig_f32[row * 64..(row + 1) * 64]
+            let expected: f32 = orig_f32[row * 128..(row + 1) * 128]
                 .iter()
                 .map(|&v| v * v)
                 .sum::<f32>()
@@ -773,7 +915,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
     let validity = Validity::from_iter([
         true, true, false, true, true, false, true, false, true, true,
     ]);
-    let fsl = make_fsl_with_validity(10, 64, 42, validity);
+    let fsl = make_fsl_with_validity(10, 128, 42, validity);
     let ext = make_vector_ext(&fsl);
 
     let config = TurboQuantConfig {
diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs
index d6f5f998041..1510132863c 100644
--- a/vortex-tensor/src/encodings/turboquant/vtable.rs
+++ b/vortex-tensor/src/encodings/turboquant/vtable.rs
@@ -50,8 +50,11 @@ pub struct TurboQuant;
 
 impl TurboQuant {
     pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant");
 
+    /// Minimum vector dimension for TurboQuant encoding.
+ pub const MIN_DIMENSION: u32 = 128; + /// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with - /// dimension >= 3. + /// dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION). /// /// Returns the validated [`ExtDTypeRef`] on success, which can be used to extract the /// element ptype and list size. @@ -65,8 +68,9 @@ impl TurboQuant { let dimension = extension_list_size(ext)?; vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" + dimension >= Self::MIN_DIMENSION, + "TurboQuant requires dimension >= {}, got {dimension}", + Self::MIN_DIMENSION ); Ok(ext) @@ -111,8 +115,9 @@ impl VTable for TurboQuant { let dimension = extension_list_size(ext)?; vortex_ensure!( - dimension >= 3, - "TurboQuant requires dimension >= 3, got {dimension}" + dimension >= Self::MIN_DIMENSION, + "TurboQuant requires dimension >= {}, got {dimension}", + Self::MIN_DIMENSION ); vortex_ensure_eq!(data.dimension(), dimension);