diff --git a/Cargo.lock b/Cargo.lock index 4a4f4b14629..d67efc8f78d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10067,7 +10067,9 @@ dependencies = [ "fastlanes", "mimalloc", "parquet 58.0.0", + "paste", "rand 0.10.0", + "rand_distr 0.6.0", "serde_json", "tokio", "tracing", @@ -10097,6 +10099,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10590,6 +10593,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10958,11 +10962,20 @@ dependencies = [ name = "vortex-tensor" version = "0.1.0" dependencies = [ + "half", "itertools 0.14.0", "num-traits", "prost 0.14.3", + "rand 0.10.0", + "rand_distr 0.6.0", "rstest", - "vortex", + "vortex-array", + "vortex-buffer", + "vortex-compressor", + "vortex-error", + "vortex-fastlanes", + "vortex-session", + "vortex-utils", ] [[package]] diff --git a/_typos.toml b/_typos.toml index e9cf23d68b7..62c3b0d6358 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,5 +1,5 @@ [default] -extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui"] +extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui", "wht", "WHT"] # We support a few common special comments to tell the checker to ignore sections of code extend-ignore-re = [ "(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line diff --git a/vortex-btrblocks/public-api.lock b/vortex-btrblocks/public-api.lock index 37410874777..b1fe437a99e 100644 --- a/vortex-btrblocks/public-api.lock +++ b/vortex-btrblocks/public-api.lock @@ -618,10 +618,18 @@ pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::exclude(self, ids: impl cor pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::include(self, ids: impl core::iter::traits::collect::IntoIterator) -> Self +pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::with_scheme(self, scheme: &'static dyn 
vortex_compressor::scheme::Scheme) -> Self + impl core::clone::Clone for vortex_btrblocks::BtrBlocksCompressorBuilder pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::clone(&self) -> vortex_btrblocks::BtrBlocksCompressorBuilder +impl core::cmp::Eq for vortex_btrblocks::BtrBlocksCompressorBuilder + +impl core::cmp::PartialEq for vortex_btrblocks::BtrBlocksCompressorBuilder + +pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::eq(&self, other: &vortex_btrblocks::BtrBlocksCompressorBuilder) -> bool + impl core::default::Default for vortex_btrblocks::BtrBlocksCompressorBuilder pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::default() -> Self @@ -630,6 +638,8 @@ impl core::fmt::Debug for vortex_btrblocks::BtrBlocksCompressorBuilder pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +impl core::marker::StructuralPartialEq for vortex_btrblocks::BtrBlocksCompressorBuilder + pub const vortex_btrblocks::ALL_SCHEMES: &[&dyn vortex_compressor::scheme::Scheme] pub fn vortex_btrblocks::compress_patches(patches: vortex_array::patches::Patches) -> vortex_error::VortexResult diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 6127f1e3910..771e7fdda91 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -112,7 +112,7 @@ pub fn default_excluded() -> HashSet { /// .include([IntDictScheme.id()]) /// .build(); /// ``` -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct BtrBlocksCompressorBuilder { schemes: HashSet<&'static dyn Scheme>, } @@ -144,6 +144,15 @@ impl BtrBlocksCompressorBuilder { self } + /// Adds an external compression scheme not in [`ALL_SCHEMES`]. + /// + /// This allows encoding crates outside of `vortex-btrblocks` to register + /// their own schemes with the compressor. 
+ pub fn with_scheme(mut self, scheme: &'static dyn Scheme) -> Self { + self.schemes.insert(scheme); + self + } + /// Excludes the specified compression schemes by their [`SchemeId`]. pub fn exclude(mut self, ids: impl IntoIterator) -> Self { let ids: HashSet<_> = ids.into_iter().collect(); diff --git a/vortex-file/Cargo.toml b/vortex-file/Cargo.toml index d568328bb52..22163eb833e 100644 --- a/vortex-file/Cargo.toml +++ b/vortex-file/Cargo.toml @@ -54,6 +54,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-tensor = { workspace = true } vortex-utils = { workspace = true, features = ["dashmap"] } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } diff --git a/vortex-file/src/lib.rs b/vortex-file/src/lib.rs index d888eb88def..a33a8a1d709 100644 --- a/vortex-file/src/lib.rs +++ b/vortex-file/src/lib.rs @@ -178,4 +178,6 @@ pub fn register_default_encodings(session: &mut VortexSession) { vortex_fastlanes::initialize(session); vortex_runend::initialize(session); vortex_sequence::initialize(session); + #[cfg(feature = "unstable_encodings")] + vortex_tensor::encodings::turboquant::initialize(session); } diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index efd693c5ca1..7c5c3ed1edb 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -28,9 +28,11 @@ use vortex_array::arrays::VarBinView; use vortex_array::dtype::FieldPath; use vortex_array::session::ArrayRegistry; use vortex_array::session::ArraySession; +use vortex_btrblocks::BtrBlocksCompressorBuilder; use vortex_bytebool::ByteBool; use vortex_datetime_parts::DateTimeParts; use vortex_decimal_byte_parts::DecimalByteParts; +use vortex_error::vortex_panic; use vortex_fastlanes::BitPacked; use vortex_fastlanes::Delta; use vortex_fastlanes::FoR; @@ -53,13 +55,14 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use 
vortex_sequence::Sequence; use vortex_sparse::Sparse; +#[cfg(feature = "unstable_encodings")] +use vortex_tensor::encodings::turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; use vortex_zigzag::ZigZag; #[rustfmt::skip] #[cfg(feature = "zstd")] use vortex_btrblocks::{ - BtrBlocksCompressorBuilder, SchemeExt, schemes::float, schemes::integer, @@ -111,6 +114,8 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { session.register(RunEnd); session.register(Sequence); session.register(Sparse); + #[cfg(feature = "unstable_encodings")] + session.register(TurboQuant); session.register(ZigZag); #[cfg(feature = "zstd")] @@ -127,11 +132,13 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { /// repartitioning and compressing them to strike a balance between size on-disk, /// bulk decoding performance, and IOPS required to perform an indexed read. pub struct WriteStrategyBuilder { - compressor: Option>, row_block_size: usize, field_writers: HashMap>, allow_encodings: Option, flat_strategy: Option>, + // builder and compressor are mutually exclusive + builder: Option, + compressor: Option>, } impl Default for WriteStrategyBuilder { @@ -139,11 +146,12 @@ impl Default for WriteStrategyBuilder { /// and then finally built yielding the [`LayoutStrategy`]. fn default() -> Self { Self { - compressor: None, row_block_size: 8192, field_writers: HashMap::new(), allow_encodings: Some(ALLOWED_ENCODINGS.clone()), flat_strategy: None, + builder: None, + compressor: None, } } } @@ -154,6 +162,9 @@ impl WriteStrategyBuilder { /// If not provided, this will use a BtrBlocks-style cascading compressor that tries to balance /// total size with decoding performance. 
pub fn with_compressor(mut self, compressor: C) -> Self { + if self.builder.is_some() { + vortex_panic!("Cannot configure both a custom compressor and custom builder schemes"); + } self.compressor = Some(Arc::new(compressor)); self } @@ -198,7 +209,12 @@ impl WriteStrategyBuilder { /// GPU decompression. Without it, strings use interleaved Zstd compression. #[cfg(feature = "zstd")] pub fn with_cuda_compatible_encodings(mut self) -> Self { - let mut builder = BtrBlocksCompressorBuilder::default().exclude([ + if self.compressor.is_some() { + vortex_panic!( + "Cannot configure both a custom compressor and CUDA compatible encodings" + ); + } + let b = self.builder.take().unwrap_or_default().exclude([ integer::SparseScheme.id(), integer::RLE_INTEGER_SCHEME.id(), float::RLE_FLOAT_SCHEME.id(), @@ -209,14 +225,13 @@ impl WriteStrategyBuilder { #[cfg(feature = "unstable_encodings")] { - builder = builder.include([string::ZstdBuffersScheme.id()]); + self.builder = Some(b.include([string::ZstdBuffersScheme.id()])); } #[cfg(not(feature = "unstable_encodings"))] { - builder = builder.include([string::ZstdScheme.id()]); + self.builder = Some(b.include([string::ZstdScheme.id()])); } - self.compressor = Some(Arc::new(builder.build())); self } @@ -227,21 +242,47 @@ impl WriteStrategyBuilder { /// especially for floating-point heavy datasets. #[cfg(feature = "zstd")] pub fn with_compact_encodings(mut self) -> Self { - let btrblocks = BtrBlocksCompressorBuilder::default() - .include([ - string::ZstdScheme.id(), - integer::PcoScheme.id(), - float::PcoScheme.id(), - ]) - .build(); - - self.compressor = Some(Arc::new(btrblocks)); + if self.compressor.is_some() { + vortex_panic!("Cannot configure both a custom compressor and compact encodings"); + } + self.builder = Some(self.builder.take().unwrap_or_default().include([ + string::ZstdScheme.id(), + integer::PcoScheme.id(), + float::PcoScheme.id(), + ])); + self + } + + /// Enable TurboQuant lossy vector quantization for tensor columns. 
+ /// + /// When enabled, `Vector` and `FixedShapeTensor` extension arrays are + /// compressed using the TurboQuant algorithm with QJL correction for + /// unbiased inner product estimation. + /// + /// This augments any existing compressor configuration rather than + /// replacing it. If no compressor has been set, the default BtrBlocks + /// compressor is used with TurboQuant added. + #[cfg(feature = "unstable_encodings")] + pub fn with_vector_quantization(mut self) -> Self { + if self.compressor.is_some() { + vortex_panic!("Cannot configure both a custom compressor and vector quantization"); + } + use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; + self.builder = Some( + self.builder + .take() + .unwrap_or_default() + .with_scheme(&TURBOQUANT_SCHEME), + ); self } /// Builds the canonical [`LayoutStrategy`] implementation, with the configured overrides /// applied. pub fn build(self) -> Arc { + use vortex_btrblocks::SchemeExt as _; + use vortex_btrblocks::schemes::integer::IntDictScheme; + let flat: Arc = if let Some(flat) = self.flat_strategy { flat } else if let Some(allow_encodings) = self.allow_encodings { @@ -254,12 +295,24 @@ impl WriteStrategyBuilder { let chunked = ChunkedLayoutStrategy::new(flat.clone()); // 6. buffer chunks so they end up with closer segment ids physically let buffered = BufferedStrategy::new(chunked, 2 * ONE_MEG); // 2MB + // 5. 
compress each chunk - let compressing = if let Some(ref compressor) = self.compressor { - CompressingStrategy::new_opaque(buffered, compressor.clone()) - } else { - CompressingStrategy::new_btrblocks(buffered, true) - }; + let data_compressor: Arc = + if let Some(ref compressor) = self.compressor { + assert!( + self.builder.is_none(), + "Cannot configure both a custom compressor and custom builder schemes" + ); + compressor.clone() + } else { + Arc::new( + self.builder + .unwrap_or_default() + .exclude([IntDictScheme.id()]) + .build(), + ) + }; + let compressing = CompressingStrategy::new(buffered, data_compressor.clone()); // 4. prior to compression, coalesce up to a minimum size let coalescing = RepartitionStrategy::new( @@ -279,11 +332,12 @@ impl WriteStrategyBuilder { ); // 2.1. | 3.1. compress stats tables and dict values. - let compress_then_flat = if let Some(ref compressor) = self.compressor { - CompressingStrategy::new_opaque(flat, compressor.clone()) + let stats_compressor = if let Some(compressor) = self.compressor { + compressor.clone() } else { - CompressingStrategy::new_btrblocks(flat, false) + Arc::new(BtrBlocksCompressorBuilder::default().build()) }; + let compress_then_flat = CompressingStrategy::new(flat, stats_compressor); // 3. 
apply dict encoding or fallback let dict = DictStrategy::new( diff --git a/vortex-file/tests/test_write_table.rs b/vortex-file/tests/test_write_table.rs index 5b27f6e9026..4726cd4ff46 100644 --- a/vortex-file/tests/test_write_table.rs +++ b/vortex-file/tests/test_write_table.rs @@ -20,6 +20,7 @@ use vortex_array::field_path; use vortex_array::scalar_fn::session::ScalarFnSession; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; +use vortex_btrblocks::BtrBlocksCompressor; use vortex_buffer::ByteBuffer; use vortex_file::OpenOptionsSessionExt; use vortex_file::WriteOptionsSessionExt; @@ -67,9 +68,9 @@ async fn test_file_roundtrip() { // Create a writer which by default uses the BtrBlocks compressor for a.compressed, but leaves // the b and the a.raw columns uncompressed. - let default_strategy = Arc::new(CompressingStrategy::new_btrblocks( + let default_strategy = Arc::new(CompressingStrategy::new( FlatLayoutStrategy::default(), - false, + BtrBlocksCompressor::default(), )); let writer = Arc::new( diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 88e1dc65a40..1b31f837935 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -168,9 +168,7 @@ pub struct vortex_layout::layouts::compressed::CompressingStrategy impl vortex_layout::layouts::compressed::CompressingStrategy -pub fn vortex_layout::layouts::compressed::CompressingStrategy::new_btrblocks(child: S, exclude_int_dict_encoding: bool) -> Self - -pub fn vortex_layout::layouts::compressed::CompressingStrategy::new_opaque(child: S, compressor: C) -> Self +pub fn vortex_layout::layouts::compressed::CompressingStrategy::new(child: S, compressor: C) -> Self pub fn vortex_layout::layouts::compressed::CompressingStrategy::with_concurrency(self, concurrency: usize) -> Self diff --git a/vortex-layout/src/layouts/compressed.rs b/vortex-layout/src/layouts/compressed.rs index 539e59982d6..18cceea3f74 100644 --- 
a/vortex-layout/src/layouts/compressed.rs +++ b/vortex-layout/src/layouts/compressed.rs @@ -9,9 +9,6 @@ use vortex_array::ArrayContext; use vortex_array::ArrayRef; use vortex_array::expr::stats::Stat; use vortex_btrblocks::BtrBlocksCompressor; -use vortex_btrblocks::BtrBlocksCompressorBuilder; -use vortex_btrblocks::SchemeExt; -use vortex_btrblocks::schemes::integer::IntDictScheme; use vortex_error::VortexResult; use vortex_io::runtime::Handle; @@ -60,32 +57,11 @@ pub struct CompressingStrategy { } impl CompressingStrategy { - /// Create a new writer that uses the BtrBlocks-style cascading compressor to compress chunks. - /// - /// This provides a good balance between decoding speed and small file size. - /// - /// Set `exclude_int_dict_encoding` to true to prevent dictionary encoding of integer arrays, - /// which is useful when compressing dictionary codes to avoid recursive dictionary encoding. - pub fn new_btrblocks(child: S, exclude_int_dict_encoding: bool) -> Self { - let compressor = if exclude_int_dict_encoding { - BtrBlocksCompressorBuilder::default() - .exclude([IntDictScheme.id()]) - .build() - } else { - BtrBlocksCompressor::default() - }; - Self::new(child, Arc::new(compressor)) - } - - /// Create a new compressor from a plugin interface. - pub fn new_opaque(child: S, compressor: C) -> Self { - Self::new(child, Arc::new(compressor)) - } - - fn new(child: S, compressor: Arc) -> Self { + /// Create a new compressing strategy that wraps a child strategy with a compressor plugin. 
+ pub fn new(child: S, compressor: C) -> Self { Self { child: Arc::new(child), - compressor, + compressor: Arc::new(compressor), concurrency: std::thread::available_parallelism() .map(|v| v.get()) .unwrap_or(1), diff --git a/vortex-layout/src/layouts/table.rs b/vortex-layout/src/layouts/table.rs index 2f2bf9df863..a28a2cd9a8a 100644 --- a/vortex-layout/src/layouts/table.rs +++ b/vortex-layout/src/layouts/table.rs @@ -86,12 +86,13 @@ impl TableStrategy { /// ```ignore /// # use std::sync::Arc; /// # use vortex_array::dtype::{field_path, Field, FieldPath}; + /// # use vortex_btrblocks::BtrBlocksCompressor; /// # use vortex_layout::layouts::compressed::CompressingStrategy; /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy; /// # use vortex_layout::layouts::table::TableStrategy; /// /// // A strategy for compressing data using the balanced BtrBlocks compressor. - /// let compress = CompressingStrategy::new_btrblocks(FlatLayoutStrategy::default(), true); + /// let compress = CompressingStrategy::new(FlatLayoutStrategy::default(), BtrBlocksCompressor::default()); /// /// // Our combined strategy uses no compression for validity buffers, BtrBlocks compression /// // for most columns, and stores a nested binary column uncompressed (flat) because it diff --git a/vortex-python/src/io.rs b/vortex-python/src/io.rs index bf60a4b2d5e..2c110c1697f 100644 --- a/vortex-python/src/io.rs +++ b/vortex-python/src/io.rs @@ -291,7 +291,7 @@ impl PyVortexWriteOptions { /// ```python /// >>> vx.io.VortexWriteOptions.compact().write(sprl, "tiny.vortex") /// >>> os.path.getsize('tiny.vortex') - /// 55120 + /// 55460 /// ``` /// /// Random numbers are not (usually) composed of random bytes! 
diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 6f4fe4511af..9f94a0c2d3d 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -17,11 +17,20 @@ version = { workspace = true } workspace = true [dependencies] -vortex = { workspace = true } +vortex-array = { workspace = true } +vortex-buffer = { workspace = true } +vortex-compressor = { workspace = true } +vortex-error = { workspace = true } +vortex-fastlanes = { workspace = true } +vortex-session = { workspace = true } +vortex-utils = { workspace = true } +half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +rand = { workspace = true } [dev-dependencies] +rand_distr = { workspace = true } rstest = { workspace = true } diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 090151e9226..7c75269b632 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -7,5 +7,4 @@ // pub mod norm; // Unit-normalized vectors. // pub mod spherical; // Spherical transform on unit-normalized vectors. -// TODO(will): -// pub mod turboquant; +pub mod turboquant; diff --git a/vortex-tensor/src/encodings/turboquant/array.rs b/vortex-tensor/src/encodings/turboquant/array.rs new file mode 100644 index 00000000000..89c600853b4 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array.rs @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant array definition: stores quantized coordinate codes, norms, +//! centroids (codebook), rotation signs, and optional QJL correction fields. + +use vortex_array::ArrayId; +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// Encoding marker type for TurboQuant. 
+#[derive(Clone, Debug)] +pub struct TurboQuant; + +impl TurboQuant { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); +} + +vtable!(TurboQuant, TurboQuant, TurboQuantData); + +/// Protobuf metadata for TurboQuant encoding. +#[derive(Clone, prost::Message)] +pub struct TurboQuantMetadata { + /// Vector dimension d. + #[prost(uint32, tag = "1")] + pub dimension: u32, + /// MSE bits per coordinate (1-8). + #[prost(uint32, tag = "2")] + pub bit_width: u32, + /// Whether QJL correction children are present. + #[prost(bool, tag = "3")] + pub has_qjl: bool, +} + +/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased +/// inner product estimation. When present, adds 3 additional children. +#[derive(Clone, Debug)] +pub struct QjlCorrection { + /// Sign bits: `BoolArray`, length `num_rows * padded_dim`. + pub(crate) signs: ArrayRef, + /// Residual norms: `PrimitiveArray`, length `num_rows`. + pub(crate) residual_norms: ArrayRef, + /// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order). + pub(crate) rotation_signs: ArrayRef, +} + +impl QjlCorrection { + /// The QJL sign bits. + pub fn signs(&self) -> &ArrayRef { + &self.signs + } + + /// The residual norms. + pub fn residual_norms(&self) -> &ArrayRef { + &self.residual_norms + } + + /// The QJL rotation signs (BoolArray, inverse application order). + pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } +} + +/// Slot positions for TurboQuantArray children. 
+#[repr(usize)] +#[derive(Clone, Copy, Debug)] +pub(crate) enum Slot { + Codes = 0, + Norms = 1, + Centroids = 2, + RotationSigns = 3, + QjlSigns = 4, + QjlResidualNorms = 5, + QjlRotationSigns = 6, +} + +impl Slot { + pub(crate) const COUNT: usize = 7; + + pub(crate) fn name(self) -> &'static str { + match self { + Self::Codes => "codes", + Self::Norms => "norms", + Self::Centroids => "centroids", + Self::RotationSigns => "rotation_signs", + Self::QjlSigns => "qjl_signs", + Self::QjlResidualNorms => "qjl_residual_norms", + Self::QjlRotationSigns => "qjl_rotation_signs", + } + } + + pub(crate) fn from_index(idx: usize) -> Self { + match idx { + 0 => Self::Codes, + 1 => Self::Norms, + 2 => Self::Centroids, + 3 => Self::RotationSigns, + 4 => Self::QjlSigns, + 5 => Self::QjlResidualNorms, + 6 => Self::QjlRotationSigns, + _ => vortex_error::vortex_panic!("invalid slot index {idx}"), + } + } +} + +/// TurboQuant array. +/// +/// Slots (always present): +/// - 0: `codes` — `FixedSizeListArray` (quantized indices, list_size=padded_dim) +/// - 1: `norms` — `PrimitiveArray` (one per vector row) +/// - 2: `centroids` — `PrimitiveArray` (codebook, length 2^bit_width) +/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order) +/// +/// Optional QJL slots (None when MSE-only): +/// - 4: `qjl_signs` — `FixedSizeListArray` (num_rows * padded_dim, 1-bit) +/// - 5: `qjl_residual_norms` — `PrimitiveArray` (one per row) +/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation) +#[derive(Clone, Debug)] +pub struct TurboQuantData { + pub(crate) dtype: DType, + pub(crate) slots: Vec>, + pub(crate) dimension: u32, + pub(crate) bit_width: u8, + pub(crate) stats_set: ArrayStats, +} + +impl TurboQuantData { + /// Build a TurboQuant array with MSE-only encoding (no QJL correction). 
+ #[allow(clippy::too_many_arguments)] + pub fn try_new_mse( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + Ok(Self { + dtype, + slots, + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// Build a TurboQuant array with QJL correction (MSE + QJL). + #[allow(clippy::too_many_arguments)] + pub fn try_new_qjl( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + qjl: QjlCorrection, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + slots[Slot::QjlSigns as usize] = Some(qjl.signs); + slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); + slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); + Ok(Self { + dtype, + slots, + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// The vector dimension d. + pub fn dimension(&self) -> u32 { + self.dimension + } + + /// MSE bits per coordinate. + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension (next power of 2 >= dimension). + pub fn padded_dim(&self) -> u32 { + self.dimension.next_power_of_two() + } + + /// Whether QJL correction is present. 
+ pub fn has_qjl(&self) -> bool { + self.slots[Slot::QjlSigns as usize].is_some() + } + + fn slot(&self, idx: usize) -> &ArrayRef { + self.slots[idx] + .as_ref() + .vortex_expect("required slot is None") + } + + /// The quantized codes child (FixedSizeListArray). + pub fn codes(&self) -> &ArrayRef { + self.slot(Slot::Codes as usize) + } + + /// The norms child (`PrimitiveArray`). + pub fn norms(&self) -> &ArrayRef { + self.slot(Slot::Norms as usize) + } + + /// The centroids (codebook) child (`PrimitiveArray`). + pub fn centroids(&self) -> &ArrayRef { + self.slot(Slot::Centroids as usize) + } + + /// The MSE rotation signs child (BitPackedArray, length 3 * padded_dim). + pub fn rotation_signs(&self) -> &ArrayRef { + self.slot(Slot::RotationSigns as usize) + } + + /// The optional QJL correction fields, reconstructed from slots. + pub fn qjl(&self) -> Option { + Some(QjlCorrection { + signs: self.slots[Slot::QjlSigns as usize].clone()?, + residual_norms: self.slots[Slot::QjlResidualNorms as usize].clone()?, + rotation_signs: self.slots[Slot::QjlRotationSigns as usize].clone()?, + }) + } + + /// Set the QJL correction fields on this array. + pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) { + self.slots[Slot::QjlSigns as usize] = Some(qjl.signs); + self.slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms); + self.slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs new file mode 100644 index 00000000000..85ea39fcc9e --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates +//! 
after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly +//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, +//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids +//! that minimize MSE for this distribution. + +use std::sync::LazyLock; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_utils::aliases::dash_map::DashMap; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Max-Lloyd convergence threshold. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Maximum iterations for Max-Lloyd algorithm. +const MAX_ITERATIONS: usize = 200; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing +/// optimal scalar quantization levels for the coordinate distribution after +/// random rotation in `dimension`-dimensional space. +pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { + if !(1..=8).contains(&bit_width) { + vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}"); + } + if dimension < 3 { + vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}"); + } + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + Ok(centroids) +} + +/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. +/// +/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) +/// or a half-integer (when `d` is even). 
This type makes that invariant explicit and +/// avoids floating-point comparison in the hot path. +#[derive(Clone, Copy, Debug)] +struct HalfIntExponent { + int_part: i32, + has_half: bool, +} + +impl HalfIntExponent { + /// Compute `(numerator) / 2` as a half-integer exponent. + /// + /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. + fn from_numerator(numerator: i32) -> Self { + // Integer division truncates toward zero; for negative odd numerators + // (e.g., d=2 → num=-1) this gives int_part=0, has_half=true, + // representing -0.5 = 0 + (-0.5). The sign is handled by adjusting + // int_part: -1/2 = 0 with has_half, but we need the floor division. + // Rust's `/` truncates toward zero, so -1/2 = 0. We want floor: -1. + // Use divmod that rounds toward negative infinity. + let int_part = numerator.div_euclid(2); + let has_half = numerator.rem_euclid(2) != 0; + Self { int_part, has_half } + } +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly +/// rotated unit vector in d dimensions. The PDF is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +#[allow(clippy::cast_possible_truncation)] // f64→f32 centroid values are intentional +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec { + let num_centroids = 1usize << bit_width; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. + let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + let mut boundaries: Vec = vec![0.0; num_centroids + 1]; + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). 
+ boundaries[0] = -1.0; + for idx in 0..num_centroids - 1 { + boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; + } + boundaries[num_centroids] = 1.0; + + // Update each centroid to the conditional mean within its Voronoi cell. + let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = conditional_mean(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` +/// on [-1, 1]. +fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents +/// that arise from `(d-3)/2`. This is significantly faster than the general +/// `powf` which goes through `exp(exponent * ln(base))`. 
+#[inline] +fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + if exponent.has_half { + // Half-integer exponent: base^(int_part) * sqrt(base). + base.powi(exponent.int_part) * base.sqrt() + } else { + // Integer exponent: use powi directly. + base.powi(exponent.int_part) + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). +/// +/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps +/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, +/// and a value >= `boundaries[k-2]` maps to centroid `k-1`. +pub fn compute_boundaries(centroids: &[f32]) -> Vec { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding +/// centroids. Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop. 
+#[inline] +#[allow(clippy::cast_possible_truncation)] // bounded by num_centroids <= 256 +pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + + boundaries.partition_point(|&b| b < value) as u8 +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for &val in ¢roids { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + 
); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { + let c1 = get_centroids(128, 2)?; + let c2 = get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = get_centroids(128, 2)?; + let boundaries = compute_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(get_centroids(128, 0).is_err()); + assert!(get_centroids(128, 9).is_err()); + assert!(get_centroids(1, 2).is_err()); + assert!(get_centroids(2, 2).is_err()); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs new file mode 100644 index 00000000000..756b17cdada --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encoding (quantization) logic. 
+ +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_fastlanes::bitpack_compress::bitpack_encode; + +use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::centroids::compute_boundaries; +use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::encodings::turboquant::centroids::get_centroids; +use crate::encodings::turboquant::rotation::RotationMatrix; + +/// Configuration for TurboQuant encoding. +#[derive(Clone, Debug)] +pub struct TurboQuantConfig { + /// Bits per coordinate. + /// + /// For MSE encoding: 1-8. + /// For QJL encoding: 2-9 (the MSE component uses `bit_width - 1`). + pub bit_width: u8, + /// Optional seed for the rotation matrix. If None, the default seed is used. + pub seed: Option, +} + +impl Default for TurboQuantConfig { + fn default() -> Self { + Self { + bit_width: 5, + seed: Some(42), + } + } +} + +/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray. +#[allow(clippy::cast_possible_truncation)] +fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult { + let elements = fsl.elements(); + let primitive = elements.to_canonical()?.into_primitive(); + let ptype = primitive.ptype(); + + match ptype { + PType::F16 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(primitive), + PType::F64 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| v as f32) + .collect()), + _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"), + } +} + +/// Compute the L2 norm of a vector. 
+#[inline] +fn l2_norm(x: &[f32]) -> f32 { + x.iter().map(|&v| v * v).sum::().sqrt() +} + +/// Shared intermediate results from the MSE quantization loop. +struct MseQuantizationResult { + rotation: RotationMatrix, + f32_elements: PrimitiveArray, + centroids: Vec, + all_indices: BufferMut, + norms: BufferMut, + padded_dim: usize, +} + +/// Core quantization: extract f32 elements, build rotation, normalize/rotate/quantize all rows. +#[allow(clippy::cast_possible_truncation)] +fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + seed: u64, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size() as usize; + let num_rows = fsl.len(); + + let rotation = RotationMatrix::try_new(seed, dimension)?; + let padded_dim = rotation.padded_dim(); + + let f32_elements = extract_f32_elements(fsl)?; + + let centroids = get_centroids(padded_dim as u32, bit_width)?; + let boundaries = compute_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut norms = BufferMut::::with_capacity(num_rows); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice::(); + for row in 0..num_rows { + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let norm = l2_norm(x); + norms.push(norm); + + if norm > 0.0 { + let inv_norm = 1.0 / norm; + for (dst, &src) in padded[..dimension].iter_mut().zip(x.iter()) { + *dst = src * inv_norm; + } + } else { + padded[..dimension].fill(0.0); + } + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + Ok(MseQuantizationResult { + rotation, + f32_elements, + centroids, + all_indices, + norms, + padded_dim, + }) +} + +/// Build a `TurboQuantArray` (MSE-only) from quantization results. 
+#[allow(clippy::cast_possible_truncation)] +fn build_turboquant_mse( + fsl: &FixedSizeListArray, + core: MseQuantizationResult, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size(); + + let num_rows = fsl.len(); + let padded_dim = core.padded_dim; + let codes_elements = + PrimitiveArray::new::(core.all_indices.freeze(), Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )? + .into_array(); + let norms_array = + PrimitiveArray::new::(core.norms.freeze(), Validity::NonNullable).into_array(); + + // TODO(perf): `get_centroids` returns Vec; could avoid the copy by + // supporting Buffer::from(Vec) or caching as Buffer directly. + let mut centroids_buf = BufferMut::::with_capacity(core.centroids.len()); + centroids_buf.extend_from_slice(&core.centroids); + let centroids_array = + PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable).into_array(); + + let rotation_signs = bitpack_rotation_signs(&core.rotation)?; + + TurboQuantData::try_new_mse( + fsl.dtype().clone(), + codes, + norms_array, + centroids_array, + rotation_signs, + dimension, + bit_width, + ) +} + +/// Encode a FixedSizeListArray into a MSE-only `TurboQuantArray`. +/// +/// The input must be non-nullable. TurboQuant is a lossy encoding that does not +/// preserve null positions; callers must handle validity externally. 
+pub fn turboquant_encode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 1 && config.bit_width <= 8, + "MSE bit_width must be 1-8, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let core = turboquant_quantize_core(fsl, seed, config.bit_width)?; + + Ok(build_turboquant_mse(fsl, core, config.bit_width)?.into_array()) +} + +/// Encode a FixedSizeListArray into a `TurboQuantArray` with QJL correction. +/// +/// The QJL variant uses `bit_width - 1` MSE bits plus 1 bit of QJL residual +/// correction, giving unbiased inner product estimation. The input must be +/// non-nullable. 
+#[allow(clippy::cast_possible_truncation)] +pub fn turboquant_encode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 2 && config.bit_width <= 9, + "QJL bit_width must be 2-9, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let dim = dimension as usize; + let mse_bit_width = config.bit_width - 1; + + let core = turboquant_quantize_core(fsl, seed, mse_bit_width)?; + let padded_dim = core.padded_dim; + + // QJL uses a different rotation than the MSE stage to ensure statistical + // independence between the quantization noise and the sign projection. + let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?; + + let num_rows = fsl.len(); + let mut residual_norms_buf = BufferMut::::with_capacity(num_rows); + let mut qjl_sign_u8 = BufferMut::::with_capacity(num_rows * padded_dim); + + let mut dequantized_rotated = vec![0.0f32; padded_dim]; + let mut dequantized = vec![0.0f32; padded_dim]; + let mut residual = vec![0.0f32; padded_dim]; + let mut projected = vec![0.0f32; padded_dim]; + + // Compute QJL residuals using precomputed indices and norms from the core. + { + let f32_slice = core.f32_elements.as_slice::(); + let indices_slice: &[u8] = &core.all_indices; + let norms_slice: &[f32] = &core.norms; + + for row in 0..num_rows { + let x = &f32_slice[row * dim..(row + 1) * dim]; + let norm = norms_slice[row]; + + // Dequantize from precomputed indices. 
+ let row_indices = &indices_slice[row * padded_dim..(row + 1) * padded_dim]; + for j in 0..padded_dim { + dequantized_rotated[j] = core.centroids[row_indices[j] as usize]; + } + + core.rotation + .inverse_rotate(&dequantized_rotated, &mut dequantized); + if norm > 0.0 { + for val in dequantized[..dim].iter_mut() { + *val *= norm; + } + } + + // Compute residual: r = x_padded - x̂. + // For positions 0..dim: r[j] = x[j] - dequantized[j]. + // For pad positions dim..padded_dim: the original was zero-padded, + // so r[j] = 0 - dequantized[j]. These pad artifacts are nonzero + // because the SRHT mixes quantization error into the padded region. + // Omitting them would corrupt the QJL signs for non-power-of-2 dims. + for j in 0..dim { + residual[j] = x[j] - dequantized[j]; + } + for j in dim..padded_dim { + residual[j] = -dequantized[j]; + } + // The residual norm for QJL scaling is over the dim-dimensional + // subspace only — pad artifacts don't contribute to reconstruction + // error in the output space. The pad positions are still included + // in the sign projection to avoid corrupting the SRHT mixing. + let residual_norm = l2_norm(&residual[..dim]); + residual_norms_buf.push(residual_norm); + + // QJL: sign(S · r). + if residual_norm > 0.0 { + qjl_rotation.rotate(&residual, &mut projected); + } else { + projected.fill(0.0); + } + + for j in 0..padded_dim { + qjl_sign_u8.push(if projected[j] >= 0.0 { 1u8 } else { 0u8 }); + } + } + } + + // Build the MSE part. + let mut array = build_turboquant_mse(fsl, core, mse_bit_width)?; + + // Attach QJL correction. 
+ let residual_norms_array = + PrimitiveArray::new::(residual_norms_buf.freeze(), Validity::NonNullable); + let qjl_signs_elements = PrimitiveArray::new::(qjl_sign_u8.freeze(), Validity::NonNullable); + let qjl_signs = FixedSizeListArray::try_new( + qjl_signs_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )?; + let qjl_rotation_signs = bitpack_rotation_signs(&qjl_rotation)?; + + array.set_qjl(crate::encodings::turboquant::array::QjlCorrection { + signs: qjl_signs.into_array(), + residual_norms: residual_norms_array.into_array(), + rotation_signs: qjl_rotation_signs, + }); + + Ok(array.into_array()) +} + +/// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. +/// +/// The rotation matrix's 3 × padded_dim sign values are exported as 0/1 u8 +/// values in inverse application order, then bitpacked to 1 bit per sign. +/// On decode, FastLanes SIMD-unpacks back to `&[u8]` of 0/1 values. +fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult { + let signs_u8 = rotation.export_inverse_signs_u8(); + let mut buf = BufferMut::::with_capacity(signs_u8.len()); + buf.extend_from_slice(&signs_u8); + let prim = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + Ok(bitpack_encode(&prim, 1, None)?.into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs new file mode 100644 index 00000000000..2666270f6e8 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Approximate cosine similarity in the quantized domain. +//! +//! Since the SRHT is orthogonal, inner products are preserved in the rotated +//! domain. For two vectors from the same TurboQuant column (same rotation and +//! 
centroids), we can compute the dot product of their quantized representations +//! without full decompression: +//! +//! ```text +//! cos_approx(a, b) = sum(centroids[code_a[j]] × centroids[code_b[j]]) +//! ``` +//! +//! where `code_a` and `code_b` are the quantized coordinate indices of the +//! unit-norm rotated vectors `â_rot` and `b̂_rot`. +//! +//! # Bias and error bounds +//! +//! This estimate is **biased** — it uses only the MSE-quantized codes and does +//! not incorporate the QJL residual correction. The MSE quantizer minimizes +//! reconstruction error but does not guarantee unbiased inner products; the +//! discrete centroid grid introduces systematic bias in the dot product. +//! +//! The TurboQuant paper's Theorem 2 shows that unbiased inner product estimation +//! requires the full QJL correction term, which involves decoding the per-row +//! QJL signs and computing cross-terms — nearly as expensive as full decompression. +//! +//! The approximation error is bounded by the MSE quantization distortion. For +//! unit-norm vectors quantized at `b` bits, the per-coordinate MSE is bounded by +//! `(√3 · π / 2) / 4^b` (Theorem 1). The inner product error scales with this +//! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001. +//! +//! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is +//! usually sufficient — the relative ordering of cosine similarities is preserved +//! even if the absolute values have bounded error. 
+ +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::encodings::turboquant::TurboQuant; + +/// Shared helper: read codes, norms, and centroids from two TurboQuant arrays, +/// then compute per-row quantized unit-norm dot products. +/// +/// Both arrays must have the same dimension (vector length) and row count. +/// They may have different codebooks (e.g., different bit widths), in which +/// case each array's own centroids are used for its code lookups. +/// +/// Returns `(norms_a, norms_b, unit_dots)` where `unit_dots[i]` is the dot product +/// of the unit-norm quantized vectors for row i. +fn quantized_unit_dots( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult<(Vec, Vec, Vec)> { + vortex_ensure!( + lhs.dimension() == rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions, got {} and {}", + lhs.dimension(), + rhs.dimension() + ); + + let pd = lhs.padded_dim() as usize; + let num_rows = lhs.norms().len(); + + let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?; + let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?; + let na = lhs_norms.as_slice::(); + let nb = rhs_norms.as_slice::(); + + let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?; + let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?; + let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive(); + let rhs_codes = rhs_codes_fsl.elements().to_canonical()?.into_primitive(); + let ca = lhs_codes.as_slice::(); + let cb = rhs_codes.as_slice::(); + + // Read centroids from both arrays — they may have different codebooks + // (e.g., different bit widths). 
+ let lhs_centroids: PrimitiveArray = lhs.centroids().clone().execute(ctx)?; + let rhs_centroids: PrimitiveArray = rhs.centroids().clone().execute(ctx)?; + let cl = lhs_centroids.as_slice::(); + let cr = rhs_centroids.as_slice::(); + + let mut dots = Vec::with_capacity(num_rows); + for row in 0..num_rows { + let row_ca = &ca[row * pd..(row + 1) * pd]; + let row_cb = &cb[row * pd..(row + 1) * pd]; + let dot: f32 = row_ca + .iter() + .zip(row_cb.iter()) + .map(|(&a, &b)| cl[a as usize] * cr[b as usize]) + .sum(); + dots.push(dot); + } + + Ok((na.to_vec(), nb.to_vec(), dots)) +} + +/// Compute approximate cosine similarity for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. +pub fn cosine_similarity_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_rows = lhs.norms().len(); + let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + if na[row] == 0.0 || nb[row] == 0.0 { + result.push(0.0); + } else { + // Unit-norm dot product IS the cosine similarity. + result.push(dots[row]); + } + } + + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) +} + +/// Compute approximate dot product for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. +/// +/// `dot_product(a, b) ≈ ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` +pub fn dot_product_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_rows = lhs.norms().len(); + let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?; + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + // Scale the unit-norm dot product by both norms to get the actual dot product. 
+ result.push(na[row] * nb[row] * dots[row]); + } + + Ok(PrimitiveArray::new::(result.freeze(), Validity::NonNullable).into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/mod.rs b/vortex-tensor/src/encodings/turboquant/compute/mod.rs new file mode 100644 index 00000000000..67b4d3efb7f --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compute pushdown implementations for TurboQuant. + +pub(crate) mod cosine_similarity; +mod ops; +pub(crate) mod rules; +mod slice; +mod take; diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs new file mode 100644 index 00000000000..89f25c61018 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/ops.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_array::scalar::Scalar; +use vortex_array::vtable::OperationsVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::encodings::turboquant::array::TurboQuant; + +impl OperationsVTable for TurboQuant { + fn scalar_at( + array: ArrayView<'_, TurboQuant>, + index: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + // Slice to single row, decompress that one row. + let Some(sliced) = ::slice(array, index..index + 1)? 
else { + vortex_bail!("slice returned None for index {index}") + }; + let decoded = sliced.execute::(ctx)?; + decoded.scalar_at(0) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/rules.rs b/vortex-tensor/src/encodings/turboquant/compute/rules.rs new file mode 100644 index 00000000000..d482994f720 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/rules.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::dict::TakeExecuteAdaptor; +use vortex_array::arrays::slice::SliceReduceAdaptor; +use vortex_array::kernel::ParentKernelSet; +use vortex_array::optimizer::rules::ParentRuleSet; + +use crate::encodings::turboquant::array::TurboQuant; + +pub(crate) static RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); + +pub(crate) static PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(TurboQuant))]); diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs new file mode 100644 index 00000000000..f2f124d9dee --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::Range; + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::IntoArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::array::QjlCorrection; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; + +impl SliceReduce for TurboQuant { + fn slice( + array: ArrayView<'_, TurboQuant>, + range: Range, + ) -> VortexResult> { + let sliced_codes = array.codes().slice(range.clone())?; + let sliced_norms = 
array.norms().slice(range.clone())?; + + let sliced_qjl = array + .qjl() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.slice(range.clone())?, + residual_norms: qjl.residual_norms.slice(range.clone())?, + rotation_signs: qjl.rotation_signs.clone(), + }) + }) + .transpose()?; + + let mut result = TurboQuantData::try_new_mse( + array.dtype.clone(), + sliced_codes, + sliced_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + array.dimension, + array.bit_width, + )?; + if let Some(qjl) = sliced_qjl { + result.set_qjl(qjl); + } + + Ok(Some(result.into_array())) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs new file mode 100644 index 00000000000..d04b71e6e16 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::dict::TakeExecute; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::array::QjlCorrection; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; + +impl TakeExecute for TurboQuant { + fn take( + array: ArrayView<'_, TurboQuant>, + indices: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // FSL children handle per-row take natively. 
+ let taken_codes = array.codes().take(indices.clone())?; + let taken_norms = array.norms().take(indices.clone())?; + + let taken_qjl = array + .qjl() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.take(indices.clone())?, + residual_norms: qjl.residual_norms.take(indices.clone())?, + rotation_signs: qjl.rotation_signs.clone(), + }) + }) + .transpose()?; + + let mut result = TurboQuantData::try_new_mse( + array.dtype.clone(), + taken_codes, + taken_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + array.dimension, + array.bit_width, + )?; + if let Some(qjl) = taken_qjl { + result.set_qjl(qjl); + } + + Ok(Some(result.into_array())) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs new file mode 100644 index 00000000000..b89fc6c54ad --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decoding (dequantization) logic. + +use vortex_array::Array; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::rotation::RotationMatrix; + +/// QJL correction scale factor: `sqrt(π/2) / padded_dim`. +/// +/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform) +/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations. +#[inline] +fn qjl_correction_scale(padded_dim: usize) -> f32 { + (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32) +} + +/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats. 
+/// +/// Reads stored centroids and rotation signs from the array's children, +/// avoiding any recomputation. If QJL correction is present, applies +/// the residual correction after MSE decoding. +pub fn execute_decompress( + array: Array, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dim = array.dimension() as usize; + let padded_dim = array.padded_dim() as usize; + let num_rows = array.norms().len(); + + if num_rows == 0 { + let elements = PrimitiveArray::empty::(array.dtype.nullability()); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + 0, + )? + .into_array()); + } + + // Read stored centroids — no recomputation. + let centroids_prim = array.centroids().clone().execute::(ctx)?; + let centroids = centroids_prim.as_slice::(); + + // FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values, + // then we expand to u32 XOR masks once (amortized over all rows). This enables + // branchless XOR-based sign application in the per-row SRHT hot loop. + let signs_prim = array + .rotation_signs() + .clone() + .execute::(ctx)?; + let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim)?; + + // Unpack codes from FixedSizeListArray → flat u8 elements. + let codes_fsl = array.codes().clone().execute::(ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let indices = codes_prim.as_slice::(); + + let norms_prim = array.norms().clone().execute::(ctx)?; + let norms = norms_prim.as_slice::(); + + // MSE decode: dequantize → inverse rotate → scale by norm. 
+ let mut mse_output = BufferMut::::with_capacity(num_rows * dim); + let mut dequantized = vec![0.0f32; padded_dim]; + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; + let norm = norms[row]; + + for idx in 0..padded_dim { + dequantized[idx] = centroids[row_indices[idx] as usize]; + } + + rotation.inverse_rotate(&dequantized, &mut unrotated); + + for idx in 0..dim { + unrotated[idx] *= norm; + } + + mse_output.extend_from_slice(&unrotated[..dim]); + } + + // If no QJL correction, we're done. + let Some(qjl) = array.qjl() else { + let elements = PrimitiveArray::new::(mse_output.freeze(), Validity::NonNullable); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()); + }; + + // Apply QJL residual correction. + // Unpack QJL signs from FixedSizeListArray → flat u8 0/1 values. + let qjl_signs_fsl = qjl.signs.clone().execute::(ctx)?; + let qjl_signs_prim = qjl_signs_fsl.elements().to_canonical()?.into_primitive(); + let qjl_signs_u8 = qjl_signs_prim.as_slice::(); + + let residual_norms_prim = qjl.residual_norms.clone().execute::(ctx)?; + let residual_norms = residual_norms_prim.as_slice::(); + + let qjl_rot_signs_prim = qjl.rotation_signs.clone().execute::(ctx)?; + let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::(), dim)?; + + let qjl_scale = qjl_correction_scale(padded_dim); + let mse_elements = mse_output.as_ref(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut qjl_signs_vec = vec![0.0f32; padded_dim]; + let mut qjl_projected = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let mse_row = &mse_elements[row * dim..(row + 1) * dim]; + let residual_norm = residual_norms[row]; + + // Branchless u8 0/1 → f32 ±1.0 via XOR on the IEEE 754 sign bit. + // 1.0f32 = 0x3F800000; flipping the sign bit gives -1.0 = 0xBF800000. 
+ // For sign=0 (negative): mask = 0x80000000, 1.0 XOR mask = -1.0. + // For sign=1 (positive): mask = 0x00000000, 1.0 XOR mask = 1.0. + let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim]; + for (dst, &sign) in qjl_signs_vec.iter_mut().zip(row_signs.iter()) { + let mask = ((sign as u32) ^ 1) << 31; + *dst = f32::from_bits(0x3F80_0000 ^ mask); + } + + qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected); + let scale = qjl_scale * residual_norm; + + for idx in 0..dim { + output.push(mse_row[idx] + scale * qjl_projected[idx]); + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs new file mode 100644 index 00000000000..b537db29029 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -0,0 +1,1018 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization encoding for Vortex. +//! +//! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of +//! high-dimensional vector data. The encoding operates on `FixedSizeList` arrays of floats +//! (the storage format of `Vector` and `FixedShapeTensor` extension types). +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! +//! # Variants +//! +//! - **MSE** (`TurboQuantVariant::Mse`): Minimizes mean-squared reconstruction error +//! (1-8 bits per coordinate). +//! - **Prod** (`TurboQuantVariant::Prod`): Preserves inner products with an unbiased +//! estimator (uses `b-1` bits for MSE + 1-bit QJL residual correction, 2-9 bits). +//! At `b=9`, the MSE codes are raw int8 values suitable for direct use with +//! tensor core int8 GEMM kernels. +//! +//! # Theoretical error bounds +//! +//! 
For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 +//! guarantees normalized MSE distortion: +//! +//! > `E[||x - x̂||² / ||x||²] ≤ (√3 · π / 2) / 4^b` +//! +//! | Bits | MSE bound | Quality | +//! |------|------------|-------------------| +//! | 1 | 6.80e-01 | Poor | +//! | 2 | 1.70e-01 | Usable for ANN | +//! | 3 | 4.25e-02 | Good | +//! | 4 | 1.06e-02 | Very good | +//! | 5 | 2.66e-03 | Excellent | +//! | 6 | 6.64e-04 | Near-lossless | +//! | 7 | 1.66e-04 | Near-lossless | +//! | 8 | 4.15e-05 | Near-lossless | +//! +//! # Compression ratios +//! +//! Each vector is stored as `padded_dim × bit_width / 8` bytes of quantized codes plus a +//! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the +//! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. +//! +//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | +//! |------|--------|------|-----------|----------|--------| +//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | +//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | +//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | +//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | +//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | +//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | +//! +//! # Example +//! +//! ``` +//! use vortex_array::IntoArray; +//! use vortex_array::arrays::FixedSizeListArray; +//! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::validity::Validity; +//! use vortex_buffer::BufferMut; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_mse}; +//! +//! // Create a FixedSizeListArray of 100 random 128-d vectors. +//! let num_rows = 100; +//! let dim = 128; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim); +//! for i in 0..(num_rows * dim) { +//! buf.push((i as f32 * 0.001).sin()); +//! } +//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); +//! let fsl = FixedSizeListArray::try_new( +//! 
elements.into_array(), dim as u32, Validity::NonNullable, num_rows, +//! ).unwrap(); +//! +//! // Quantize at 2 bits per coordinate using MSE-optimal encoding. +//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; +//! let encoded = turboquant_encode_mse(&fsl, &config).unwrap(); +//! +//! // Verify compression: 100 vectors × 128 dims × 4 bytes = 51200 bytes input. +//! assert!(encoded.nbytes() < 51200); +//! ``` + +pub use array::QjlCorrection; +pub use array::TurboQuant; +pub use array::TurboQuantData; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode_mse; +pub use compress::turboquant_encode_qjl; + +mod array; +pub(crate) mod centroids; +mod compress; +pub(crate) mod compute; +pub(crate) mod decompress; +pub(crate) mod rotation; +pub mod scheme; +mod vtable; + +/// Extension ID for the `Vector` type from `vortex-tensor`. +pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector"; + +/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`. +pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor"; + +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; + +/// Initialize the TurboQuant encoding in the given session. 
+pub fn initialize(session: &mut VortexSession) { + session.arrays().register(TurboQuant); +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +mod tests { + use std::sync::LazyLock; + + use rand::RngExt; + use rand::SeedableRng; + use rand::rngs::StdRng; + use rand_distr::Distribution; + use rand_distr::Normal; + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::session::ArraySession; + use vortex_array::validity::Validity; + use vortex_buffer::BufferMut; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::encodings::turboquant::TurboQuant; + use crate::encodings::turboquant::TurboQuantConfig; + use crate::encodings::turboquant::rotation::RotationMatrix; + use crate::encodings::turboquant::turboquant_encode_mse; + use crate::encodings::turboquant::turboquant_encode_qjl; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + /// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal). 
+ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + ) + .unwrap() + } + + fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) + } + + fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 + } + + /// Encode and decode, returning (original, decoded) flat f32 slices. 
+ fn encode_decode( + fsl: &FixedSizeListArray, + encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult, + ) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let encoded = encode_fn(fsl)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + let decoded_elements: Vec = { + let prim = decoded.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) + } + + fn encode_decode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_mse(fsl, &config)) + } + + fn encode_decode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_qjl(fsl, &config)) + } + + // ----------------------------------------------------------------------- + // MSE encoding tests + // ----------------------------------------------------------------------- + + #[rstest] + #[case(32, 1)] + #[case(32, 2)] + #[case(32, 3)] + #[case(32, 4)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 2)] + fn roundtrip_mse(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 2)] + #[case(256, 4)] + fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + 
bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + #[rstest] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 6)] + #[case(256, 8)] + fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode_mse(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) + } + + #[test] + fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL encoding tests + // 
----------------------------------------------------------------------- + + #[rstest] + #[case(32, 2)] + #[case(32, 3)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(456), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + /// Compute the mean signed relative error of QJL inner product estimation + /// over random query/vector pairs. + fn qjl_mean_signed_relative_error( + original: &[f32], + decoded: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let num_pairs = 500; + let mut rng = StdRng::seed_from_u64(0); + let mut signed_errors = Vec::with_capacity(num_pairs); + + for _ in 0..num_pairs { + let qi = rng.random_range(0..num_rows); + let xi = rng.random_range(0..num_rows); + if qi == xi { + continue; + } + + let query = &original[qi * dim..(qi + 1) * dim]; + let orig_vec = &original[xi * dim..(xi + 1) * dim]; + let quant_vec = &decoded[xi * dim..(xi + 1) * dim]; + + let true_ip: f32 = query.iter().zip(orig_vec).map(|(&a, &b)| a * b).sum(); + let quant_ip: f32 = query.iter().zip(quant_vec).map(|(&a, &b)| a * b).sum(); + + if true_ip.abs() > 1e-6 { + signed_errors.push((quant_ip - true_ip) / true_ip.abs()); + } + } + + if signed_errors.is_empty() { + return 0.0; + } + + signed_errors.iter().sum::() / signed_errors.len() as f32 + } + + #[rstest] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + #[case(768, 4)] + fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 100; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = 
encode_decode_qjl(&fsl, &config)?; + + let mean_rel_error = qjl_mean_signed_relative_error(&original, &decoded, dim, num_rows); + + // Known limitation: non-power-of-2 dims have elevated QJL bias (~23% vs + // ~11%) due to distribution mismatch between the SRHT zero-padded coordinate + // distribution and the analytical (1-x^2)^((d-3)/2) model used for centroids. + // Investigated approaches: + // - Random permutation of zeros: no effect (issue is distribution shape) + // - MC empirical centroids: fixes QJL bias but regresses MSE quality + // - Analytical centroids with dim instead of padded_dim: mixed results + // The principled fix requires jointly correcting centroids and QJL scale + // factor for the actual SRHT zero-padded distribution. + let threshold = if dim.is_power_of_two() { 0.15 } else { 0.25 }; + assert!( + mean_rel_error.abs() < threshold, + "QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width} \ + (threshold={threshold})" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge cases + // ----------------------------------------------------------------------- + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_mse_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_qjl_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + 
assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(1)] + #[case(2)] + fn mse_rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + assert!(turboquant_encode_mse(&fsl, &config).is_err()); + } + + #[rstest] + #[case(1)] + #[case(2)] + fn qjl_rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(0), + }; + assert!(turboquant_encode_qjl(&fsl, &config).is_err()); + } + + fn make_fsl_small(dim: usize) -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(dim); + for i in 0..dim { + buf.push(i as f32 + 1.0); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new(elements.into_array(), dim as u32, Validity::NonNullable, 1) + .unwrap() + } + + // ----------------------------------------------------------------------- + // Verification tests for stored metadata + // ----------------------------------------------------------------------- + + /// Verify that the centroids stored in the MSE array match what get_centroids() computes. 
+ #[test] + fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::encodings::turboquant::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) + } + + /// Verify that stored rotation signs produce identical decode to seed-based decode. + /// + /// Encodes the same data twice: once with the new path (stored signs), and + /// once by manually recomputing the rotation from the seed. Both should + /// produce identical output. + #[test] + fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Decode via the stored-signs path (normal decode). + let mut ctx = SESSION.create_execution_ctx(); + let decoded_fsl = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. 
+ let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let expected_u8 = rot_from_seed.export_inverse_signs_u8(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + let stored_u8 = stored_signs.as_slice::(); + + assert_eq!(expected_u8.len(), stored_u8.len()); + for i in 0..expected_u8.len() { + assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL-specific quality tests + // ----------------------------------------------------------------------- + + /// Verify that QJL's MSE component (at bit_width-1) satisfies the theoretical bound. + #[rstest] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 3)] + fn qjl_mse_within_theoretical_bound( + #[case] dim: usize, + #[case] bit_width: u8, + ) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + // QJL at b bits uses (b-1)-bit MSE plus a correction term. + // The MSE should be at most the (b-1)-bit theoretical bound, though + // in practice the QJL correction often improves it further. + let mse_bound = theoretical_mse_bound(bit_width - 1); + assert!( + normalized_mse < mse_bound, + "QJL MSE {normalized_mse:.6} exceeds (b-1)-bit bound {mse_bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + /// Verify that high-bitwidth QJL (8-9 bits) achieves very low distortion. 
+ #[rstest] + #[case(128, 8)] + #[case(128, 9)] + fn high_bitwidth_qjl_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + // Compare against 4-bit QJL as reference ceiling. + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(789), + }; + let (original_4, decoded_4) = encode_decode_qjl(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 4-bit ({mse_4bit:.6})" + ); + assert!( + mse < 0.01, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 1%" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge case and input format tests + // ----------------------------------------------------------------------- + + /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). + #[test] + fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 → 0 * anything = 0). 
+ for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) + } + + /// Verify that f64 input is accepted and encoded (converted to f32 internally). + #[test] + fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64→f32 conversion). + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension(), dim as u32); + Ok(()) + } + + /// Verify serde roundtrip: serialize MSE array metadata + children, then rebuild. + #[test] + fn mse_serde_roundtrip() -> VortexResult<()> { + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children. + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 4); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize and rebuild. 
+ let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + // Verify metadata fields survived roundtrip. + assert_eq!(deserialized.dimension, encoded.dimension()); + assert_eq!(deserialized.bit_width, encoded.bit_width() as u32); + assert_eq!(deserialized.has_qjl, encoded.has_qjl()); + + // Verify the rebuilt array decodes identically. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild from children (simulating deserialization). + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_mse( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + /// Verify serde roundtrip for QJL: serialize metadata + children, then rebuild. + #[test] + fn qjl_serde_roundtrip() -> VortexResult<()> { + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let encoded = encoded.as_opt::().unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children — QJL has 7 (4 MSE + 3 QJL). + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 7); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize metadata. 
+ let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + assert!(deserialized.has_qjl); + assert_eq!(deserialized.dimension, encoded.dimension()); + + // Verify decode: original vs rebuilt from children. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild with QJL children. + let rebuilt = crate::encodings::turboquant::array::TurboQuantData::try_new_qjl( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + crate::encodings::turboquant::array::QjlCorrection { + signs: children[4].clone(), + residual_norms: children[5].clone(), + rotation_signs: children[6].clone(), + }, + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Compute pushdown tests + // ----------------------------------------------------------------------- + + #[test] + fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(5..10)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + // Slice then decompress. 
+ let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn slice_qjl_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(3..8)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + let sliced = encoded.slice(3..8)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 64, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) + } + + #[test] + fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = encoded.as_opt::().unwrap(); + + // Stored norms should match the actual L2 norms of the input. 
+ let norms_prim = tq.norms().to_canonical()?.into_primitive(); + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + Ok(()) + } + + #[test] + fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + use vortex_array::arrays::FixedSizeListArray; + + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = encoded.as_opt::().unwrap(); + + // Compute exact cosine similarity from original data. + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + + // Read quantized codes, norms, and centroids for approximate computation. 
+ let mut ctx = SESSION.create_execution_ctx(); + let pd = tq.padded_dim() as usize; + let norms_prim = tq.norms().clone().execute::(&mut ctx)?; + let norms = norms_prim.as_slice::(); + let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let all_codes = codes_prim.as_slice::(); + let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; + let centroid_vals = centroids_prim.as_slice::(); + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + // Approximate cosine similarity in quantized domain. + let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { + 0.0 + } else { + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum::() + }; + + // 4-bit quantization: expect reasonable accuracy. + let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/rotation.rs b/vortex-tensor/src/encodings/turboquant/rotation.rs new file mode 100644 index 00000000000..2f654349778 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/rotation.rs @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Deterministic random rotation for TurboQuant. +//! 
+//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation
+//! instead of a full d×d matrix multiply. The SRHT applies the sequence
+//! H · D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard Transform (WHT) and Dₖ are
+//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient
+//! randomness for near-uniform distribution on the sphere.
+//!
+//! For dimensions that are not powers of 2, the input is zero-padded to the
+//! next power of 2 before the transform and truncated afterward.
+//!
+//! # Sign representation
+//!
+//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op)
+//! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application
+//! function uses integer XOR instead of floating-point multiply, which avoids
+//! FP dependency chains and auto-vectorizes into `vpxor`/`veor`.
+
+use rand::RngExt;
+use rand::SeedableRng;
+use rand::rngs::StdRng;
+use vortex_error::VortexResult;
+use vortex_error::vortex_ensure;
+
+/// IEEE 754 sign bit mask for f32.
+const F32_SIGN_BIT: u32 = 0x8000_0000;
+
+/// A structured random Hadamard transform for O(d log d) pseudo-random rotation.
+pub struct RotationMatrix {
+    /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
+    /// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
+    sign_masks: [Vec<u32>; 3],
+    /// The padded dimension (next power of 2 >= dimension).
+    padded_dim: usize,
+    /// Normalization factor: 1/(padded_dim * sqrt(padded_dim)), applied once at the end.
+    norm_factor: f32,
+}
+
+impl RotationMatrix {
+    /// Create a new SRHT rotation from a deterministic seed.
+    pub fn try_new(seed: u64, dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        let mut rng = StdRng::seed_from_u64(seed);
+
+        let sign_masks = std::array::from_fn(|_| gen_random_sign_masks(&mut rng, padded_dim));
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+
+    /// Apply forward rotation: `output = SRHT(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`. The caller
+    /// is responsible for zero-padding input beyond `dim` positions.
+    pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_srht(output);
+    }
+
+    /// Apply inverse rotation: `output = SRHT⁻¹(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`.
+    pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_inverse_srht(output);
+    }
+
+    /// Returns the padded dimension (next power of 2 >= dim).
+    ///
+    /// All rotate/inverse_rotate buffers must be this length.
+    pub fn padded_dim(&self) -> usize {
+        self.padded_dim
+    }
+
+    /// Apply the SRHT: H · D₃ · H · D₂ · H · D₁ · x, with normalization.
+    fn apply_srht(&self, buf: &mut [f32]) {
+        apply_signs_xor(buf, &self.sign_masks[0]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[1]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[2]);
+        walsh_hadamard_transform(buf);
+
+        let norm = self.norm_factor;
+        buf.iter_mut().for_each(|val| *val *= norm);
+    }
+
+    /// Apply the inverse SRHT.
+    ///
+    /// Forward is: norm · H · D₃ · H · D₂ · H · D₁
+    /// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H
+    fn apply_inverse_srht(&self, buf: &mut [f32]) {
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[2]);
+
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[1]);
+
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[0]);
+
+        let norm = self.norm_factor;
+        buf.iter_mut().for_each(|val| *val *= norm);
+    }
+
+    /// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse
+    /// application order `[D₃ | D₂ | D₁]`.
+    ///
+    /// Convention: `1` = positive (+1), `0` = negative (-1).
+    /// The output has length `3 * padded_dim` and is suitable for bitpacking
+    /// via FastLanes `bitpack_encode(..., 1, None)`.
+    pub fn export_inverse_signs_u8(&self) -> Vec<u8> {
+        let total = 3 * self.padded_dim;
+        let mut out = Vec::with_capacity(total);
+
+        // Store in inverse order: sign_masks[2] (D₃), sign_masks[1] (D₂), sign_masks[0] (D₁)
+        for sign_idx in [2, 1, 0] {
+            for &mask in &self.sign_masks[sign_idx] {
+                out.push(if mask == 0 { 1u8 } else { 0u8 });
+            }
+        }
+        out
+    }
+
+    /// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values.
+    ///
+    /// The input must have length `3 * padded_dim` with signs in inverse
+    /// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]).
+    /// Convention: `1` = positive, `0` = negative.
+    ///
+    /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the
+    /// stored `BitPackedArray` into `&[u8]`, which is passed here.
+    pub fn from_u8_slice(signs_u8: &[u8], dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        vortex_ensure!(
+            signs_u8.len() == 3 * padded_dim,
+            "Expected {} sign bytes, got {}",
+            3 * padded_dim,
+            signs_u8.len()
+        );
+
+        // Reconstruct in storage order (inverse): [D₃, D₂, D₁] → sign_masks[2], [1], [0]
+        let mut sign_masks: [Vec<u32>; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim));
+
+        for (round, sign_idx) in [2, 1, 0].into_iter().enumerate() {
+            let offset = round * padded_dim;
+            sign_masks[sign_idx] = signs_u8[offset..offset + padded_dim]
+                .iter()
+                .map(|&v| if v != 0 { 0u32 } else { F32_SIGN_BIT })
+                .collect();
+        }
+
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+}
+
+/// Generate a vector of random XOR sign masks.
+fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
+    (0..len)
+        .map(|_| {
+            if rng.random_bool(0.5) {
+                0u32 // +1: no-op
+            } else {
+                F32_SIGN_BIT // -1: flip sign bit
+            }
+        })
+        .collect()
+}
+
+/// Apply sign masks via XOR on the IEEE 754 sign bit.
+///
+/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM).
+/// Equivalent to multiplying each element by ±1.0, but avoids FP dependency chains.
+#[inline]
+fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
+    for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
+        *val = f32::from_bits(val.to_bits() ^ mask);
+    }
+}
+
+/// In-place Walsh-Hadamard Transform (unnormalized, iterative).
+///
+/// Input length must be a power of 2. Runs in O(n log n).
+///
+/// Each stage splits the buffer into `stride`-element chunks and applies the
+/// butterfly to the two halves of each chunk; the separate `butterfly` helper
+/// lets LLVM prove the slice lengths match and auto-vectorize (NEON/AVX).
+fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`. +/// +/// Separate function so LLVM can see the slice lengths match and auto-vectorize. +#[inline(always)] +fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let r1 = RotationMatrix::try_new(42, 64)?; + let r2 = RotationMatrix::try_new(42, 64)?; + let pd = r1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..64 { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + r1.rotate(&input, &mut out1); + r2.rotate(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + /// Verify roundtrip is exact to f32 precision across many dimensions, + /// including non-power-of-two dimensions that require padding. 
+ #[rstest] + #[case(32)] + #[case(64)] + #[case(100)] + #[case(128)] + #[case(256)] + #[case(512)] + #[case(768)] + #[case(1024)] + fn roundtrip_exact(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut rotated = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + rot.rotate(&input, &mut rotated); + rot.inverse_rotate(&rotated, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions. + #[rstest] + #[case(128)] + #[case(768)] + fn preserves_norm(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(7, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut rotated = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut rotated); + let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - rotated_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + rotated_norm, + (input_norm - rotated_norm).abs() / input_norm + ); + Ok(()) + } + + /// Verify that export → from_u8_slice produces identical rotation output. 
+ #[rstest] + #[case(64)] + #[case(128)] + #[case(768)] + fn sign_export_import_roundtrip(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let signs_u8 = rot.export_inverse_signs_u8(); + let rot2 = RotationMatrix::from_u8_slice(&signs_u8, dim)?; + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut out1); + rot2.rotate(&input, &mut out2); + assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + + rot.inverse_rotate(&out1, &mut out2); + let mut out3 = vec![0.0f32; padded_dim]; + rot2.inverse_rotate(&out1, &mut out3); + assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs new file mode 100644 index 00000000000..6db642ae25f --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant compression scheme for the pluggable compressor. 
+ +use vortex_array::ArrayRef; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_compressor::CascadingCompressor; +use vortex_compressor::ctx::CompressorContext; +use vortex_compressor::scheme::Scheme; +use vortex_compressor::stats::ArrayAndStats; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; + +use super::FIXED_SHAPE_TENSOR_EXT_ID; +use super::TurboQuantConfig; +use super::VECTOR_EXT_ID; +use super::turboquant_encode_qjl; + +/// TurboQuant compression scheme for tensor extension types. +/// +/// Applies lossy vector quantization to `Vector` and `FixedShapeTensor` extension +/// arrays using the TurboQuant algorithm with QJL correction for unbiased inner +/// product estimation. +/// +/// Register this scheme with the compressor builder via `with_scheme`: +/// ```ignore +/// use vortex_btrblocks::BtrBlocksCompressorBuilder; +/// use vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME; +/// +/// let compressor = BtrBlocksCompressorBuilder::default() +/// .with_scheme(&TURBOQUANT_SCHEME) +/// .build(); +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct TurboQuantScheme; + +/// Static instance for registration with `BtrBlocksCompressorBuilder::with_scheme`. 
+pub static TURBOQUANT_SCHEME: TurboQuantScheme = TurboQuantScheme; + +impl Scheme for TurboQuantScheme { + fn scheme_name(&self) -> &'static str { + "vortex.tensor.turboquant" + } + + fn matches(&self, canonical: &Canonical) -> bool { + let Canonical::Extension(ext) = canonical else { + return false; + }; + + get_tensor_element_ptype_and_length(ext.dtype()).is_ok() + } + + fn expected_compression_ratio( + &self, + _compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + let dtype = data.array().dtype(); + let len = data.array().len(); + let (element_ptype, dimensions) = get_tensor_element_ptype_and_length(dtype)?; + Ok(estimate_compression_ratio( + element_ptype.bit_width(), + dimensions, + len, + )) + } + + fn compress( + &self, + _compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + let array = data.array().clone(); + let ext_array = array.to_canonical()?.into_extension(); + let storage = ext_array.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + + let config = TurboQuantConfig::default(); + let encoded = turboquant_encode_qjl(&fsl, &config)?; + + Ok(ExtensionArray::new(ext_array.ext_dtype().clone(), encoded).into_array()) + } +} + +/// Estimate the compression ratio for TurboQuant QJL encoding with the default config. +/// +/// Uses the default [`TurboQuantConfig`] (5-bit QJL = 4-bit MSE + 1-bit QJL signs). +fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 { + let config = TurboQuantConfig::default(); + let padded_dim = dimensions.next_power_of_two() as usize; + + // Per-vector: MSE codes + QJL signs per padded coordinate, + // plus two f32 values (norm and QJL residual norm). 
+ let compressed_bits_per_vector = 2 * 32 // norm + residual_norm are always f32 + + (config.bit_width as usize) * padded_dim; // MSE codes + QJL sign bits + + // Shared overhead: codebook centroids (2^mse_bit_width f32 values) and + // rotation signs (3 * padded_dim bits each for MSE and QJL rotations). + let mse_bit_width = config.bit_width - 1; // QJL uses bit_width-1 for MSE + let num_centroids = 1usize << mse_bit_width; + let overhead_bits = num_centroids * 32 // centroids are always f32 + + 2 * 3 * padded_dim; // MSE + QJL rotation signs, 1 bit each + + let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits; + let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize; + uncompressed_size_bits as f64 / compressed_size_bits as f64 +} + +fn get_tensor_element_ptype_and_length(dtype: &DType) -> VortexResult<(PType, u32)> { + let ext_id = dtype.as_extension().id(); + let is_tensor = dtype.is_extension() + && (ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID); + vortex_ensure!(is_tensor, "expected tensor extension dtype, got {}", dtype); + + let storage_dtype = dtype.as_extension().storage_dtype(); + let (element_dtype, fsl_len) = match storage_dtype { + DType::FixedSizeList(element_dtype, list_size, _) => (element_dtype, list_size), + _ => vortex_bail!( + "expected FixedSizeList storage dtype, got {}", + storage_dtype + ), + }; + + // TurboQuant requires dimension >= 3: the marginal coordinate distribution + // (1 - x^2)^((d-3)/2) has a singularity at d=2 (arcsine distribution) that + // causes NaN in the Max-Lloyd centroid computation. 
+ vortex_ensure!( + *fsl_len >= 3, + "TurboQuant requires dimension >= 3, got {}", + fsl_len + ); + + if let &DType::Primitive(ptype, Nullability::NonNullable) = element_dtype.as_ref() { + Ok((ptype, *fsl_len)) + } else { + vortex_bail!( + "expected non-nullable primitive element type, got {}", + element_dtype + ); + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + /// Verify compression ratio for typical embedding dimensions. + /// + /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. + /// f32 input at 1024-d (no padding) should give higher ratio since no waste. + #[rstest] + #[case::f32_768d(32, 768, 1000, 3.5, 6.0)] + #[case::f32_1024d(32, 1024, 1000, 4.5, 7.0)] + #[case::f32_1536d(32, 1536, 1000, 3.0, 6.0)] + #[case::f32_128d(32, 128, 1000, 4.0, 6.0)] + #[case::f64_768d(64, 768, 1000, 7.0, 12.0)] + #[case::f16_768d(16, 768, 1000, 1.5, 3.5)] + fn compression_ratio_in_expected_range( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + #[case] min_ratio: f64, + #[case] max_ratio: f64, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > min_ratio && ratio < max_ratio, + "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \ + {bits_per_element}-bit elements, dim={dim}, n={num_vectors}" + ); + } + + /// Compression ratio must always be > 1 for reasonable inputs, + /// otherwise TurboQuant makes things bigger and should not be selected. 
+ #[rstest] + #[case(32, 128, 100)] + #[case(32, 768, 10)] + #[case(64, 256, 50)] + fn ratio_always_greater_than_one( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > 1.0, + "ratio {ratio:.4} <= 1.0 for {bits_per_element}-bit, dim={dim}, n={num_vectors}" + ); + } + + /// Power-of-2 dimensions should have better ratios than their non-power-of-2 + /// predecessors due to no padding waste. + #[test] + fn power_of_two_has_better_ratio() { + let ratio_768 = estimate_compression_ratio(32, 768, 1000); + let ratio_1024 = estimate_compression_ratio(32, 1024, 1000); + assert!( + ratio_1024 > ratio_768, + "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})" + ); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs new file mode 100644 index 00000000000..f309ca8bc7c --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant encoding. 
+ +use std::hash::Hash; +use std::sync::Arc; + +use vortex_array::Array; +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayId; +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::DeserializeMetadata; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::ProstMetadata; +use vortex_array::SerializeMetadata; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::serde::ArrayChildren; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::array::Slot; +use crate::encodings::turboquant::array::TurboQuant; +use crate::encodings::turboquant::array::TurboQuantData; +use crate::encodings::turboquant::array::TurboQuantMetadata; +use crate::encodings::turboquant::decompress::execute_decompress; + +impl VTable for TurboQuant { + type ArrayData = TurboQuantData; + type Metadata = ProstMetadata; + type OperationsVTable = TurboQuant; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::ArrayData) -> &Self { + &TurboQuant + } + + fn id(&self) -> ArrayId { + Self::ID + } + + fn len(array: &TurboQuantData) -> usize { + array.norms().len() + } + + fn dtype(array: &TurboQuantData) -> &DType { + &array.dtype + } + + fn stats(array: &TurboQuantData) -> &ArrayStats { + &array.stats_set + } + + fn array_hash( + array: &TurboQuantData, + state: &mut H, + precision: Precision, + ) { + array.dtype.hash(state); + array.dimension.hash(state); + array.bit_width.hash(state); + for slot in &array.slots { + slot.is_some().hash(state); + if let Some(child) = slot { + 
child.array_hash(state, precision); + } + } + } + + fn array_eq(array: &TurboQuantData, other: &TurboQuantData, precision: Precision) -> bool { + array.dtype == other.dtype + && array.dimension == other.dimension + && array.bit_width == other.bit_width + && array.slots.len() == other.slots.len() + && array + .slots + .iter() + .zip(other.slots.iter()) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.array_eq(b, precision), + (None, None) => true, + _ => false, + }) + } + + fn nbuffers(_array: ArrayView) -> usize { + 0 + } + + fn buffer(_array: ArrayView, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: ArrayView, _idx: usize) -> Option { + None + } + + fn slots(array: ArrayView<'_, Self>) -> &[Option] { + &array.data().slots + } + + fn slot_name(_array: ArrayView, idx: usize) -> String { + Slot::from_index(idx).name().to_string() + } + + fn with_slots(array: &mut TurboQuantData, slots: Vec>) -> VortexResult<()> { + vortex_ensure!( + slots.len() == Slot::COUNT, + "TurboQuantArray expects {} slots, got {}", + Slot::COUNT, + slots.len() + ); + array.slots = slots; + Ok(()) + } + + fn metadata(array: ArrayView) -> VortexResult { + Ok(ProstMetadata(TurboQuantMetadata { + dimension: array.dimension, + bit_width: array.bit_width as u32, + has_qjl: array.has_qjl(), + })) + } + + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(metadata.serialize())) + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(ProstMetadata( + as DeserializeMetadata>::deserialize(bytes)?, + )) + } + + #[allow(clippy::cast_possible_truncation)] + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + let bit_width = u8::try_from(metadata.bit_width)?; + let padded_dim = 
metadata.dimension.next_power_of_two() as usize; + let num_centroids = 1usize << bit_width; + + let u8_nn = DType::Primitive(PType::U8, Nullability::NonNullable); + let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); + let codes_dtype = DType::FixedSizeList( + Arc::new(u8_nn.clone()), + padded_dim as u32, + Nullability::NonNullable, + ); + let codes = children.get(0, &codes_dtype, len)?; + + let norms = children.get(1, &f32_nn, len)?; + let centroids = children.get(2, &f32_nn, num_centroids)?; + + let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + let rotation_signs = children.get(3, &signs_dtype, 3 * padded_dim)?; + + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + + if metadata.has_qjl { + let qjl_signs_dtype = + DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); + slots[Slot::QjlSigns as usize] = Some(children.get(4, &qjl_signs_dtype, len)?); + slots[Slot::QjlResidualNorms as usize] = Some(children.get(5, &f32_nn, len)?); + slots[Slot::QjlRotationSigns as usize] = + Some(children.get(6, &signs_dtype, 3 * padded_dim)?); + } + + Ok(TurboQuantData { + dtype: dtype.clone(), + slots, + dimension: metadata.dimension, + bit_width, + stats_set: Default::default(), + }) + } + + fn reduce_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + crate::encodings::turboquant::compute::rules::RULES.evaluate(array, parent, child_idx) + } + + fn execute_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + crate::encodings::turboquant::compute::rules::PARENT_KERNELS + .execute(array, parent, child_idx, ctx) + } + + fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { + 
Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) + } +} + +impl ValidityChild for TurboQuant { + fn validity_child(array: &TurboQuantData) -> &ArrayRef { + array.codes() + } +} diff --git a/vortex-tensor/src/fixed_shape/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs index fb46c67d213..264d18453c4 100644 --- a/vortex-tensor/src/fixed_shape/metadata.rs +++ b/vortex-tensor/src/fixed_shape/metadata.rs @@ -4,10 +4,10 @@ use std::fmt; use itertools::Either; -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_ensure_eq; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; /// Metadata for a `FixedShapeTensor` extension type. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-tensor/src/fixed_shape/proto.rs b/vortex-tensor/src/fixed_shape/proto.rs index 06b4f45b726..9d89c56bcec 100644 --- a/vortex-tensor/src/fixed_shape/proto.rs +++ b/vortex-tensor/src/fixed_shape/proto.rs @@ -4,9 +4,9 @@ //! Protobuf serialization for [`FixedShapeTensorMetadata`]. 
use prost::Message; -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_err; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/fixed_shape/vtable.rs b/vortex-tensor/src/fixed_shape/vtable.rs index 3c0b6512a65..21ab1ef5336 100644 --- a/vortex-tensor/src/fixed_shape/vtable.rs +++ b/vortex-tensor/src/fixed_shape/vtable.rs @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::dtype::DType; -use vortex::dtype::extension::ExtDType; -use vortex::dtype::extension::ExtId; -use vortex::dtype::extension::ExtVTable; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::error::vortex_ensure_eq; -use vortex::scalar::ScalarValue; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; @@ -77,8 +77,8 @@ impl ExtVTable for FixedShapeTensor { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::dtype::extension::ExtVTable; - use vortex::error::VortexResult; + use vortex_array::dtype::extension::ExtVTable; + use vortex_error::VortexResult; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index c036b9854b2..b8b4a0ea169 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,6 +5,18 @@ //! 
including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +use vortex_array::dtype::session::DTypeSessionExt; +use vortex_array::scalar_fn::session::ScalarFnSessionExt; +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::TurboQuant; +use crate::fixed_shape::FixedShapeTensor; +use crate::scalar_fns::cosine_similarity::CosineSimilarity; +use crate::scalar_fns::dot_product::DotProduct; +use crate::scalar_fns::l2_norm::L2Norm; +use crate::vector::Vector; + pub mod matcher; pub mod scalar_fns; @@ -14,3 +26,30 @@ pub mod vector; pub mod encodings; mod utils; + +/// Initialize the Vortex tensor library with a Vortex session. +pub fn initialize(session: &VortexSession) { + session.dtypes().register(Vector); + session.dtypes().register(FixedShapeTensor); + + session.arrays().register(TurboQuant); + + session.scalar_fns().register(CosineSimilarity); + session.scalar_fns().register(DotProduct); + session.scalar_fns().register(L2Norm); +} + +#[cfg(test)] +mod test { + use std::sync::LazyLock; + + use vortex_session::VortexSession; + + use crate::initialize; + + pub(crate) static SESSION: LazyLock<VortexSession> = LazyLock::new(|| { + let session = VortexSession::empty(); + initialize(&session); + session + }); +} diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index bb79ad7447e..16bca0bb043 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -3,8 +3,8 @@ //! Matcher for tensor-like extension types. 
-use vortex::dtype::extension::ExtDTypeRef; -use vortex::dtype::extension::Matcher; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_array::dtype::extension::Matcher; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 5155a7c8f08..f9d4f638fb2 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -8,24 +8,26 @@ use std::fmt::Formatter; use num_traits::Float; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::match_each_float_ptype; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::Nullability; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; -use vortex::expr::Expression; -use vortex::scalar_fn::Arity; -use vortex::scalar_fn::ChildName; -use vortex::scalar_fn::ExecutionArgs; -use vortex::scalar_fn::ScalarFnId; -use vortex::scalar_fn::ScalarFnVTable; - +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; use crate::matcher::AnyTensor; use 
crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -113,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity { fn execute( &self, - _options: &Self::Options, + options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { @@ -135,6 +137,16 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; + // TurboQuant approximate path: compute dot product in quantized domain. + if *options == ApproxOptions::Approximate { + if let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_storage.as_opt::(), + rhs_storage.as_opt::(), + ) { + return cosine_similarity::cosine_similarity_quantized_column(lhs_tq, rhs_tq, ctx); + } + } + let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; @@ -156,7 +168,7 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_validity = expression.child(0).validity()?; let rhs_validity = expression.child(1).validity()?; - Ok(Some(vortex::expr::and(lhs_validity, rhs_validity))) + Ok(Some(vortex_array::expr::and(lhs_validity, rhs_validity))) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { @@ -188,11 +200,11 @@ fn cosine_similarity_row(a: &[T], b: &[T]) -> T { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::array::ArrayRef; - use vortex::array::ToCanonical; - use vortex::array::arrays::ScalarFnArray; - use vortex::error::VortexResult; - use vortex::scalar_fn::ScalarFn; + use vortex_array::ArrayRef; + use vortex_array::ToCanonical; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::cosine_similarity::CosineSimilarity; diff --git a/vortex-tensor/src/scalar_fns/dot_product.rs b/vortex-tensor/src/scalar_fns/dot_product.rs new file mode 100644 index 00000000000..1bfbafd7c6a --- /dev/null +++ 
b/vortex-tensor/src/scalar_fns/dot_product.rs @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Dot product (inner product) expression for tensor-like extension arrays +//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and +//! [`Vector`](crate::vector::Vector)). + +use std::fmt::Formatter; + +use num_traits::Float; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; +use crate::matcher::AnyTensor; +use crate::scalar_fns::ApproxOptions; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; + +/// Dot product (inner product) of two tensor or vector columns. +/// +/// Computes `<a, b> = sum(a_i * b_i)` over the flat backing buffers. +/// +/// Both inputs must be tensor-like extension arrays with the same float element type +/// and dimensions. The output is a float column of the same float type. 
+#[derive(Clone)] +pub struct DotProduct; + +impl ScalarFnVTable for DotProduct { + type Options = ApproxOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.tensor.dot_product") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("lhs"), + 1 => ChildName::from("rhs"), + _ => unreachable!("DotProduct must have exactly two children"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "dot_product(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; + + let lhs_ext = lhs + .as_extension_opt() + .ok_or_else(|| vortex_err!("DotProduct lhs must be an extension type, got {lhs}"))?; + + vortex_ensure!( + lhs_ext.is::(), + "DotProduct inputs must be an `AnyTensor`, got {lhs}" + ); + + let lhs_ptype = extension_element_ptype(lhs_ext)?; + vortex_ensure!( + lhs_ptype.is_float(), + "DotProduct element dtype must be a float primitive, got {lhs_ptype}" + ); + + let rhs_ext = rhs + .as_extension_opt() + .ok_or_else(|| vortex_err!("DotProduct rhs must be an extension type, got {rhs}"))?; + + vortex_ensure!( + rhs_ext.is::(), + "DotProduct inputs must be an `AnyTensor`, got {rhs}" + ); + + let rhs_ptype = extension_element_ptype(rhs_ext)?; + vortex_ensure!( + lhs_ptype == rhs_ptype, + "DotProduct inputs must have the same element type, got {lhs_ptype} and {rhs_ptype}" + ); + + let lhs_dim = extension_list_size(lhs_ext)?; + let rhs_dim = extension_list_size(rhs_ext)?; + vortex_ensure!( + lhs_dim == rhs_dim, + "DotProduct inputs must have the same dimension, got {lhs_dim} and {rhs_dim}" + ); + + let nullability = 
Nullability::from(lhs.is_nullable() || rhs.is_nullable()); + Ok(DType::Primitive(lhs_ptype, nullability)) + } + + fn execute( + &self, + options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let lhs = args.get(0)?; + let rhs = args.get(1)?; + let row_count = args.row_count(); + + let ext = lhs.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "dot_product input must be an extension type, got {}", + lhs.dtype() + ) + })?; + let list_size = extension_list_size(ext)? as usize; + + let lhs_storage = extension_storage(&lhs)?; + let rhs_storage = extension_storage(&rhs)?; + + // TurboQuant approximate path: norm_a * norm_b * quantized unit-norm dot. + if *options == ApproxOptions::Approximate { + if let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_storage.as_opt::(), + rhs_storage.as_opt::(), + ) { + return cosine_similarity::dot_product_quantized_column(lhs_tq, rhs_tq, ctx); + } + } + + let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; + + match_each_float_ptype!(lhs_flat.ptype(), |T| { + let result: PrimitiveArray = (0..row_count) + .map(|i| dot_product_row(lhs_flat.row::(i), rhs_flat.row::(i))) + .collect(); + + Ok(result.into_array()) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + let lhs_validity = expression.child(0).validity()?; + let rhs_validity = expression.child(1).validity()?; + + Ok(Some(vortex_array::expr::and(lhs_validity, rhs_validity))) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Computes the dot product (inner product) of two float slices. 
+fn dot_product_row(a: &[T], b: &[T]) -> T { + a.iter() + .zip(b.iter()) + .map(|(&x, &y)| x * y) + .fold(T::zero(), |acc, v| acc + v) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_error::VortexResult; + + use crate::scalar_fns::ApproxOptions; + use crate::scalar_fns::dot_product::DotProduct; + use crate::test::SESSION; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::vector_array; + + fn eval_dot_product( + lhs: ArrayRef, + rhs: ArrayRef, + len: usize, + options: ApproxOptions, + ) -> VortexResult> { + let mut ctx = SESSION.create_execution_ctx(); + let scalar_fn = ScalarFn::new(DotProduct, options).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let prim = result.into_array().execute::(&mut ctx)?; + Ok(prim.as_slice::().to_vec()) + } + + #[rstest] + #[case::orthogonal(&[1.0, 0.0], &[0.0, 1.0], 0.0)] + #[case::parallel(&[3.0, 4.0], &[3.0, 4.0], 25.0)] + #[case::antiparallel(&[1.0, 2.0], &[-1.0, -2.0], -5.0)] + #[case::scaled(&[2.0, 0.0], &[3.0, 0.0], 6.0)] + fn known_dot_products( + #[case] a: &[f64], + #[case] b: &[f64], + #[case] expected: f64, + ) -> VortexResult<()> { + #[allow(clippy::cast_possible_truncation)] + let dim = a.len() as u32; + let lhs = vector_array(dim, a)?; + let rhs = vector_array(dim, b)?; + assert_close( + &eval_dot_product(lhs, rhs, 1, ApproxOptions::Exact)?, + &[expected], + ); + Ok(()) + } + + #[test] + fn multiple_rows() -> VortexResult<()> { + let lhs = vector_array(2, &[1.0, 0.0, 3.0, 4.0])?; + let rhs = vector_array(2, &[0.0, 1.0, 3.0, 4.0])?; + assert_close( + &eval_dot_product(lhs, rhs, 2, ApproxOptions::Exact)?, + &[0.0, 25.0], + ); + Ok(()) + } +} diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs 
b/vortex-tensor/src/scalar_fns/l2_norm.rs index b02035d4572..e8ea1e85660 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -8,24 +8,26 @@ use std::fmt::Formatter; use num_traits::Float; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::match_each_float_ptype; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::Nullability; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; -use vortex::expr::Expression; -use vortex::scalar_fn::Arity; -use vortex::scalar_fn::ChildName; -use vortex::scalar_fn::ExecutionArgs; -use vortex::scalar_fn::ScalarFnId; -use vortex::scalar_fn::ScalarFnVTable; - +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::encodings::turboquant::TurboQuant; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -103,7 +105,7 @@ impl ScalarFnVTable for L2Norm { let input = args.get(0)?; let row_count = args.row_count(); - // Get list size (dimensions) from the dtype. + // Get element ptype and list size from the dtype. 
let ext = input.dtype().as_extension_opt().ok_or_else(|| { vortex_err!( "l2_norm input must be an extension type, got {}", @@ -111,8 +113,21 @@ impl ScalarFnVTable for L2Norm { ) })?; let list_size = extension_list_size(ext)? as usize; + let target_ptype = extension_element_ptype(ext)?; let storage = extension_storage(&input)?; + + // TurboQuant stores exact precomputed norms — no decompression needed. + // Norms are currently stored as f32; cast to the target dtype if needed + // (e.g., if the input extension has f64 elements). + { + if let Some(tq) = storage.as_opt::() { + let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?; + let target_dtype = DType::Primitive(target_ptype, input.dtype().nullability()); + return norms.into_array().cast(target_dtype); + } + } + let flat = extract_flat_elements(&storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { @@ -156,10 +171,11 @@ fn l2_norm_row(v: &[T]) -> T { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::array::ToCanonical; - use vortex::array::arrays::ScalarFnArray; - use vortex::error::VortexResult; - use vortex::scalar_fn::ScalarFn; + use vortex_array::ArrayRef; + use vortex_array::ToCanonical; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; @@ -168,7 +184,7 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. 
- fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { + fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; let prim = result.as_array().to_primitive(); diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index c699b46cfca..3b8c4429025 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -6,6 +6,7 @@ use std::fmt; pub mod cosine_similarity; +pub mod dot_product; pub mod l2_norm; /// Options for tensor-related expressions that might have error. diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 82e2d1f5b45..3ccd38cb7ea 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -1,22 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::Constant; -use vortex::array::arrays::ConstantArray; -use vortex::array::arrays::Extension; -use vortex::array::arrays::FixedSizeListArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::PType; -use vortex::dtype::extension::ExtDTypeRef; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::Constant; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::Extension; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::PType; +use 
vortex_array::dtype::extension::ExtDTypeRef; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; /// Extracts the list size from a tensor-like extension dtype. /// @@ -123,21 +123,21 @@ pub fn extract_flat_elements( #[cfg(test)] pub mod test_helpers { - use vortex::array::ArrayRef; - use vortex::array::ExecutionCtx; - use vortex::array::IntoArray; - use vortex::array::arrays::ConstantArray; - use vortex::array::arrays::ExtensionArray; - use vortex::array::arrays::FixedSizeListArray; - use vortex::array::validity::Validity; - use vortex::buffer::Buffer; - use vortex::dtype::DType; - use vortex::dtype::Nullability; - use vortex::dtype::extension::ExtDType; - use vortex::error::VortexResult; - use vortex::error::vortex_err; - use vortex::extension::EmptyMetadata; - use vortex::scalar::Scalar; + use vortex_array::ArrayRef; + use vortex_array::ExecutionCtx; + use vortex_array::IntoArray; + use vortex_array::arrays::ConstantArray; + use vortex_array::arrays::ExtensionArray; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::Scalar; + use vortex_array::validity::Validity; + use vortex_buffer::Buffer; + use vortex_error::VortexResult; + use vortex_error::vortex_err; use super::extension_list_size; use super::extension_storage; @@ -183,7 +183,8 @@ pub mod test_helpers { elements: &[f64], len: usize, ) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + let element_dtype = + DType::Primitive(vortex_array::dtype::PType::F64, Nullability::NonNullable); let children: Vec = elements .iter() @@ -204,7 +205,8 @@ pub mod test_helpers { /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a /// single query vector 
broadcast to `len` rows. pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + let element_dtype = + DType::Primitive(vortex_array::dtype::PType::F64, Nullability::NonNullable); let children: Vec = elements .iter() diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs index 61a0f35d9ff..2dda05b7363 100644 --- a/vortex-tensor/src/vector/vtable.rs +++ b/vortex-tensor/src/vector/vtable.rs @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::dtype::DType; -use vortex::dtype::extension::ExtDType; -use vortex::dtype::extension::ExtId; -use vortex::dtype::extension::ExtVTable; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::extension::EmptyMetadata; -use vortex::scalar::ScalarValue; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; use crate::vector::Vector; @@ -62,13 +62,13 @@ mod tests { use std::sync::Arc; use rstest::rstest; - use vortex::dtype::DType; - use vortex::dtype::Nullability; - use vortex::dtype::PType; - use vortex::dtype::extension::ExtDType; - use vortex::dtype::extension::ExtVTable; - use vortex::error::VortexResult; - use vortex::extension::EmptyMetadata; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::dtype::extension::ExtVTable; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; use crate::vector::Vector; 
diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index d8dc89882b0..816b11c7ea4 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -55,12 +55,15 @@ divan = { workspace = true } fastlanes = { workspace = true } mimalloc = { workspace = true } parquet = { workspace = true } +paste = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { workspace = true } vortex = { path = ".", features = ["tokio"] } +vortex-tensor = { workspace = true } [features] default = ["files", "zstd"] diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index d2928a81a8f..aa8d3efb6a5 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -11,19 +11,21 @@ use divan::Bencher; #[cfg(not(codspeed))] use divan::counter::BytesCount; use mimalloc::MiMalloc; +use paste::paste; use rand::RngExt; use rand::SeedableRng; use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use vortex::array::IntoArray; use vortex::array::ToCanonical; -use vortex::array::VortexSessionExecute; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::builders::dict::dict_encode; use vortex::array::builtins::ArrayBuiltins; use vortex::array::dtype::Nullability; use vortex::array::session::ArraySession; +use vortex::array::validity::Validity; use vortex::dtype::PType; use vortex::encodings::alp::RDEncoder; use vortex::encodings::alp::alp_encode; @@ -36,9 +38,18 @@ use vortex::encodings::pco::Pco; use vortex::encodings::runend::RunEnd; use vortex::encodings::sequence::sequence_encode; use vortex::encodings::zigzag::zigzag_encode; +use vortex::encodings::zstd::ZstdArray; use vortex::encodings::zstd::ZstdData; +use 
vortex_array::VortexSessionExecute; +use vortex_array::dtype::Nullability; +use vortex_array::session::ArraySession; +use vortex_buffer::BufferMut; use vortex_sequence::Sequence; +use vortex_sequence::SequenceArray; use vortex_session::VortexSession; +use vortex_tensor::encodings::turboquant::TurboQuantConfig; +use vortex_tensor::encodings::turboquant::turboquant_encode_mse; +use vortex_tensor::encodings::turboquant::turboquant_encode_qjl; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -417,3 +428,108 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| &compressed) .bench_refs(|a| a.to_canonical()); } + +// TurboQuant vector quantization benchmarks + +const NUM_VECTORS: usize = 1_000; + +/// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d. +/// standard normal components. This is a conservative test distribution: real +/// neural network embeddings typically have structure (clustered, anisotropic) +/// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a +/// worst-case baseline for TurboQuant. +fn setup_vector_fsl(dim: usize) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(42); + let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(NUM_VECTORS * dim); + for _ in 0..(NUM_VECTORS * dim) { + buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + NUM_VECTORS, + ) + .unwrap() +} + +fn turboquant_config(bit_width: u8) -> TurboQuantConfig { + TurboQuantConfig { + bit_width, + seed: Some(123), + } +} + +macro_rules! turboquant_bench { + (compress, $dim:literal, $bits:literal, $name:ident) => { + paste! 
{ + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_mse(a, &config).unwrap()); + } + + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_qjl(a, &config).unwrap()); + } + } + }; + (decompress, $dim:literal, $bits:literal, $name:ident) => { + paste! { + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_mse(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_qjl(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + } + }; +} + +turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4); 
+turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); +turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4); +turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4); +turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); +turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); +turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); +turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); +turboquant_bench!(compress, 1024, 8, bench_tq_compress_1024_8); +turboquant_bench!(decompress, 1024, 8, bench_tq_decompress_1024_8);