From bf09cedf186d33e01c54bc0b98bdb7b186f0dffd Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:12:20 +0800 Subject: [PATCH 1/7] refactor(vector): split SIMD distance dispatch into submodules Move simd.rs into a simd/ directory with dedicated files for each target (avx2, avx512, neon, scalar, hamming, runtime). Add a length-mismatch assertion in the top-level distance() entry point so mismatched vector dimensions panic immediately with a clear message rather than producing silent wrong results. --- nodedb-vector/src/distance/mod.rs | 7 + nodedb-vector/src/distance/simd/avx2.rs | 114 ++++++++++++++++ nodedb-vector/src/distance/simd/avx512.rs | 99 ++++++++++++++ nodedb-vector/src/distance/simd/hamming.rs | 35 +++++ nodedb-vector/src/distance/simd/mod.rs | 14 ++ nodedb-vector/src/distance/simd/neon.rs | 96 ++++++++++++++ nodedb-vector/src/distance/simd/runtime.rs | 144 +++++++++++++++++++++ nodedb-vector/src/distance/simd/scalar.rs | 38 ++++++ nodedb-vector/tests/simd_length_safety.rs | 60 +++++++++ 9 files changed, 607 insertions(+) create mode 100644 nodedb-vector/src/distance/simd/avx2.rs create mode 100644 nodedb-vector/src/distance/simd/avx512.rs create mode 100644 nodedb-vector/src/distance/simd/hamming.rs create mode 100644 nodedb-vector/src/distance/simd/mod.rs create mode 100644 nodedb-vector/src/distance/simd/neon.rs create mode 100644 nodedb-vector/src/distance/simd/runtime.rs create mode 100644 nodedb-vector/src/distance/simd/scalar.rs create mode 100644 nodedb-vector/tests/simd_length_safety.rs diff --git a/nodedb-vector/src/distance/mod.rs b/nodedb-vector/src/distance/mod.rs index 81806b21..8c11911f 100644 --- a/nodedb-vector/src/distance/mod.rs +++ b/nodedb-vector/src/distance/mod.rs @@ -13,6 +13,13 @@ pub use scalar::*; /// feature is enabled; otherwise uses scalar implementations. #[inline] pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 { + assert_eq!( + a.len(), + b.len(), + "distance: length mismatch (a.len()={}, b.len()={})", + a.len(), + b.len() + ); #[cfg(feature = "simd")] { let rt = simd::runtime(); diff --git a/nodedb-vector/src/distance/simd/avx2.rs b/nodedb-vector/src/distance/simd/avx2.rs new file mode 100644 index 00000000..ad530afa --- /dev/null +++ b/nodedb-vector/src/distance/simd/avx2.rs @@ -0,0 +1,114 @@ +//! AVX2+FMA kernels for x86_64. + +#![cfg(target_arch = "x86_64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 l2: length mismatch"); + // SAFETY: caller verified avx2+fma via is_x86_feature_detected. + unsafe { l2_squared_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 l2_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut sum = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + let diff = _mm256_sub_ps(va, vb); + sum = _mm256_fmadd_ps(diff, diff, sum); + } + let mut result = hsum256(sum); + for i in (chunks * 8)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 cosine_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm256_setzero_ps(); + let mut vna = _mm256_setzero_ps(); + let mut vnb = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + vdot = _mm256_fmadd_ps(va, vb, vdot); + vna = _mm256_fmadd_ps(va, va, vna); + vnb = _mm256_fmadd_ps(vb, vb, vnb); + } + let mut dot = hsum256(vdot); + let mut na = hsum256(vna); + let mut nb = hsum256(vnb); + for i in (chunks * 8)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +#[target_feature(enable = "avx2,fma")] +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx2 ip_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm256_setzero_ps(); + let chunks = n / 8; + for i in 0..chunks { + let off = i * 8; + let va = _mm256_loadu_ps(a.as_ptr().add(off)); + let vb = _mm256_loadu_ps(b.as_ptr().add(off)); + vdot = _mm256_fmadd_ps(va, vb, vdot); + } + let mut dot = hsum256(vdot); + for i in (chunks * 8)..n { + dot += a[i] * b[i]; + } + -dot + } +} + +/// Horizontal sum of 8 × f32 in a __m256. +#[target_feature(enable = "avx2")] +unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 { + use std::arch::x86_64::*; + let hi = _mm256_extractf128_ps(v, 1); + let lo = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(lo, hi); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let sums2 = _mm_add_ss(sums, shuf2); + _mm_cvtss_f32(sums2) +} diff --git a/nodedb-vector/src/distance/simd/avx512.rs b/nodedb-vector/src/distance/simd/avx512.rs new file mode 100644 index 00000000..037ae6fd --- /dev/null +++ b/nodedb-vector/src/distance/simd/avx512.rs @@ -0,0 +1,99 @@ +//! AVX-512 kernels for x86_64. + +#![cfg(target_arch = "x86_64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 l2: length mismatch"); + unsafe { l2_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 l2_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut sum = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + let diff = _mm512_sub_ps(va, vb); + sum = _mm512_fmadd_ps(diff, diff, sum); + } + let mut result = _mm512_reduce_add_ps(sum); + for i in (chunks * 16)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 cosine_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm512_setzero_ps(); + let mut vna = _mm512_setzero_ps(); + let mut vnb = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + vdot = _mm512_fmadd_ps(va, vb, vdot); + vna = _mm512_fmadd_ps(va, va, vna); + vnb = _mm512_fmadd_ps(vb, vb, vnb); + } + let mut dot = _mm512_reduce_add_ps(vdot); + let mut na = _mm512_reduce_add_ps(vna); + let mut nb = _mm512_reduce_add_ps(vnb); + for i in (chunks * 16)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +#[target_feature(enable = "avx512f")] +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "avx512 ip_impl: length mismatch"); + unsafe { + use std::arch::x86_64::*; + let n = a.len(); + let mut vdot = _mm512_setzero_ps(); + let chunks = n / 16; + for i in 0..chunks { + let off = i * 16; + let va = _mm512_loadu_ps(a.as_ptr().add(off)); + let vb = _mm512_loadu_ps(b.as_ptr().add(off)); + vdot = _mm512_fmadd_ps(va, vb, vdot); + } + let mut dot = _mm512_reduce_add_ps(vdot); + for i in (chunks * 16)..n { + dot += a[i] * b[i]; + } + -dot + } +} diff --git a/nodedb-vector/src/distance/simd/hamming.rs b/nodedb-vector/src/distance/simd/hamming.rs new file mode 100644 index 00000000..0c8ef558 --- /dev/null +++ b/nodedb-vector/src/distance/simd/hamming.rs @@ -0,0 +1,35 @@ +//! Fast Hamming distance using u64 POPCNT. + +pub fn fast_hamming(a: &[u8], b: &[u8]) -> u32 { + assert_eq!(a.len(), b.len(), "fast_hamming: length mismatch"); + let mut dist = 0u32; + let chunks = a.len() / 8; + for i in 0..chunks { + let off = i * 8; + let xa = u64::from_le_bytes([ + a[off], + a[off + 1], + a[off + 2], + a[off + 3], + a[off + 4], + a[off + 5], + a[off + 6], + a[off + 7], + ]); + let xb = u64::from_le_bytes([ + b[off], + b[off + 1], + b[off + 2], + b[off + 3], + b[off + 4], + b[off + 5], + b[off + 6], + b[off + 7], + ]); + dist += (xa ^ xb).count_ones(); + } + for i in (chunks * 8)..a.len() { + dist += (a[i] ^ b[i]).count_ones(); + } + dist +} diff --git a/nodedb-vector/src/distance/simd/mod.rs b/nodedb-vector/src/distance/simd/mod.rs new file mode 100644 index 00000000..c5b1766b --- /dev/null +++ b/nodedb-vector/src/distance/simd/mod.rs @@ -0,0 +1,14 @@ +//! Runtime SIMD dispatch for vector distance and bitmap operations. + +pub mod hamming; +pub mod runtime; +pub mod scalar; + +#[cfg(target_arch = "x86_64")] +pub mod avx2; +#[cfg(target_arch = "x86_64")] +pub mod avx512; +#[cfg(target_arch = "aarch64")] +pub mod neon; + +pub use runtime::{SimdRuntime, runtime}; diff --git a/nodedb-vector/src/distance/simd/neon.rs b/nodedb-vector/src/distance/simd/neon.rs new file mode 100644 index 00000000..6600f779 --- /dev/null +++ b/nodedb-vector/src/distance/simd/neon.rs @@ -0,0 +1,96 @@ +//! NEON kernels for ARM64. + +#![cfg(target_arch = "aarch64")] + +pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon l2: length mismatch"); + unsafe { l2_impl(a, b) } +} + +unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon l2_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut sum = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + let diff = vsubq_f32(va, vb); + sum = vfmaq_f32(sum, diff, diff); + } + let mut result = vaddvq_f32(sum); + for i in (chunks * 4)..n { + let d = a[i] - b[i]; + result += d * d; + } + result + } +} + +pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon cosine: length mismatch"); + unsafe { cosine_impl(a, b) } +} + +unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon cosine_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut vdot = vdupq_n_f32(0.0); + let mut vna = vdupq_n_f32(0.0); + let mut vnb = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + vdot = vfmaq_f32(vdot, va, vb); + vna = vfmaq_f32(vna, va, va); + vnb = vfmaq_f32(vnb, vb, vb); + } + let mut dot = vaddvq_f32(vdot); + let mut na = vaddvq_f32(vna); + let mut nb = vaddvq_f32(vnb); + for i in (chunks * 4)..n { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } + } +} + +pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon ip: length mismatch"); + unsafe { ip_impl(a, b) } +} + +unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "neon ip_impl: length mismatch"); + unsafe { + use std::arch::aarch64::*; + let n = a.len(); + let mut vdot = vdupq_n_f32(0.0); + let chunks = n / 4; + for i in 0..chunks { + let off = i * 4; + let va = vld1q_f32(a.as_ptr().add(off)); + let vb = vld1q_f32(b.as_ptr().add(off)); + vdot = vfmaq_f32(vdot, va, vb); + } + let mut dot = vaddvq_f32(vdot); + for i in (chunks * 4)..n { + dot += a[i] * b[i]; + } + -dot + } +} diff --git a/nodedb-vector/src/distance/simd/runtime.rs b/nodedb-vector/src/distance/simd/runtime.rs new file mode 100644 index 00000000..a9fd0404 --- /dev/null +++ b/nodedb-vector/src/distance/simd/runtime.rs @@ -0,0 +1,144 @@ +//! Runtime SIMD detection and dispatch table. + +use super::hamming::fast_hamming; +use super::scalar::{scalar_cosine, scalar_ip, scalar_l2}; + +#[cfg(target_arch = "x86_64")] +use super::{avx2, avx512}; + +#[cfg(target_arch = "aarch64")] +use super::neon; + +/// Selected SIMD runtime — function pointers to the best available kernels. +pub struct SimdRuntime { + pub l2_squared: fn(&[f32], &[f32]) -> f32, + pub cosine_distance: fn(&[f32], &[f32]) -> f32, + pub neg_inner_product: fn(&[f32], &[f32]) -> f32, + pub hamming: fn(&[u8], &[u8]) -> u32, + pub name: &'static str, +} + +impl SimdRuntime { + /// Detect CPU features and select the best kernels. + pub fn detect() -> Self { + #[cfg(target_arch = "x86_64")] + { + if std::is_x86_feature_detected!("avx512f") { + return Self { + l2_squared: avx512::l2_squared, + cosine_distance: avx512::cosine_distance, + neg_inner_product: avx512::neg_inner_product, + hamming: fast_hamming, + name: "avx512", + }; + } + if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") { + return Self { + l2_squared: avx2::l2_squared, + cosine_distance: avx2::cosine_distance, + neg_inner_product: avx2::neg_inner_product, + hamming: fast_hamming, + name: "avx2+fma", + }; + } + } + #[cfg(target_arch = "aarch64")] + { + return Self { + l2_squared: neon::l2_squared, + cosine_distance: neon::cosine_distance, + neg_inner_product: neon::neg_inner_product, + hamming: fast_hamming, + name: "neon", + }; + } + #[allow(unreachable_code)] + Self { + l2_squared: scalar_l2, + cosine_distance: scalar_cosine, + neg_inner_product: scalar_ip, + hamming: fast_hamming, + name: "scalar", + } + } +} + +/// Global SIMD runtime — initialized once, used everywhere. +static RUNTIME: std::sync::OnceLock = std::sync::OnceLock::new(); + +/// Get the global SIMD runtime (auto-detects on first call). +pub fn runtime() -> &'static SimdRuntime { + RUNTIME.get_or_init(SimdRuntime::detect) +} + +#[cfg(test)] +mod tests { + use super::super::hamming::fast_hamming; + use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2}; + use super::*; + + #[test] + fn runtime_detects_features() { + let rt = SimdRuntime::detect(); + assert!(!rt.name.is_empty()); + tracing::info!("SIMD runtime: {}", rt.name); + } + + #[test] + fn l2_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..768).map(|i| (i as f32) * 0.01).collect(); + let b: Vec = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect(); + + let simd_result = (rt.l2_squared)(&a, &b); + let scalar_result = scalar_l2(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.01, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn cosine_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..768).map(|i| (i as f32).sin()).collect(); + let b: Vec = (0..768).map(|i| (i as f32).cos()).collect(); + + let simd_result = (rt.cosine_distance)(&a, &b); + let scalar_result = scalar_cosine(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.001, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn ip_matches_scalar() { + let rt = runtime(); + let a: Vec = (0..128).map(|i| (i as f32) * 0.1).collect(); + let b: Vec = (0..128).map(|i| (i as f32) * 0.2).collect(); + + let simd_result = (rt.neg_inner_product)(&a, &b); + let scalar_result = scalar_ip(&a, &b); + assert!( + (simd_result - scalar_result).abs() < 0.1, + "simd={simd_result}, scalar={scalar_result}" + ); + } + + #[test] + fn hamming_matches() { + let a = vec![0b10101010u8; 16]; + let b = vec![0b01010101u8; 16]; + assert_eq!(fast_hamming(&a, &b), 128); + } + + #[test] + fn small_vectors() { + let rt = runtime(); + let a = [1.0f32, 2.0, 3.0]; + let b = [4.0f32, 5.0, 6.0]; + let l2 = (rt.l2_squared)(&a, &b); + assert!((l2 - 27.0).abs() < 0.01); + } +} diff --git a/nodedb-vector/src/distance/simd/scalar.rs b/nodedb-vector/src/distance/simd/scalar.rs new file mode 100644 index 00000000..c3e52028 --- /dev/null +++ b/nodedb-vector/src/distance/simd/scalar.rs @@ -0,0 +1,38 @@ +//! Scalar fallback kernels for L2, cosine, and inner product. + +pub fn scalar_l2(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_l2: length mismatch"); + let mut sum = 0.0f32; + for i in 0..a.len() { + let d = a[i] - b[i]; + sum += d * d; + } + sum +} + +pub fn scalar_cosine(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_cosine: length mismatch"); + let mut dot = 0.0f32; + let mut na = 0.0f32; + let mut nb = 0.0f32; + for i in 0..a.len() { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + let denom = (na * nb).sqrt(); + if denom < f32::EPSILON { + 1.0 + } else { + (1.0 - dot / denom).max(0.0) + } +} + +pub fn scalar_ip(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len(), "scalar_ip: length mismatch"); + let mut dot = 0.0f32; + for i in 0..a.len() { + dot += a[i] * b[i]; + } + -dot +} diff --git a/nodedb-vector/tests/simd_length_safety.rs b/nodedb-vector/tests/simd_length_safety.rs new file mode 100644 index 00000000..a8e01003 --- /dev/null +++ b/nodedb-vector/tests/simd_length_safety.rs @@ -0,0 +1,60 @@ +//! Length-parity safety for SIMD distance kernels. +//! +//! Spec: the public `distance(a, b, metric)` dispatcher MUST NOT invoke a +//! SIMD kernel when `a.len() != b.len()`. The AVX2/AVX-512/NEON kernels +//! iterate with `a.len()` and read from `b.as_ptr().add(off)` via +//! `loadu_ps` — reading past `b`'s allocation is undefined behavior. +//! +//! A deterministic panic at the dispatcher boundary is the contract. Either +//! length validation or length-bounded iteration keeps the kernel safe. + +#![cfg(feature = "simd")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::distance::distance; + +fn assert_rejects_mismatch(metric: DistanceMetric) { + // a.len() = 9 forces one 8-wide SIMD chunk + remainder. b.len() = 1 + // means any unchecked 256-bit load from b is a buffer overrun. A correct + // dispatcher either rejects the call (panic) or bounds iteration by + // `min(a.len(), b.len())`; both surface as a deterministic panic today + // because the scalar remainder loop indexes `b[i]` out of bounds. + let a = vec![1.0f32; 9]; + let b = vec![1.0f32; 1]; + + let result = std::panic::catch_unwind(|| distance(&a, &b, metric)); + assert!( + result.is_err(), + "distance({metric:?}) must reject length mismatch (a.len()=9, b.len()=1) \ + instead of reading past the shorter buffer" + ); +} + +#[test] +fn l2_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::L2); +} + +#[test] +fn cosine_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::Cosine); +} + +#[test] +fn inner_product_rejects_length_mismatch() { + assert_rejects_mismatch(DistanceMetric::InnerProduct); +} + +#[test] +fn l2_rejects_swapped_mismatch() { + // Swap order: shorter slice first. The kernels use a.len() as the loop + // bound, so a.len()=1, b.len()=9 exits early — but the dispatcher + // contract is symmetric: any mismatch is invalid input. + let a = vec![1.0f32; 1]; + let b = vec![1.0f32; 9]; + let result = std::panic::catch_unwind(|| distance(&a, &b, DistanceMetric::L2)); + assert!( + result.is_err(), + "distance() must reject length mismatch in either argument order" + ); +} From eecf3f4d2da783be02981dcc0018380701d3e67e Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:12:32 +0800 Subject: [PATCH 2/7] fix(vector): cap HNSW layer assignment at MAX_LAYER_CAP (16) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit random_layer() could theoretically produce very large values for unlucky RNG draws, promoting max_layer to an unbounded height and making every subsequent search's Phase-1 greedy descent O(max_layer). Apply a hard cap of 16 layers — standard practice for production HNSW deployments. Also refactor compact() to expose compact_with_map() returning the old→new id remapping needed by doc_id_map maintenance. --- nodedb-vector/src/hnsw/graph.rs | 28 ++++++++-- nodedb-vector/tests/hnsw_layer_cap.rs | 73 +++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 nodedb-vector/tests/hnsw_layer_cap.rs diff --git a/nodedb-vector/src/hnsw/graph.rs b/nodedb-vector/src/hnsw/graph.rs index 00a2b99e..72f9b4d0 100644 --- a/nodedb-vector/src/hnsw/graph.rs +++ b/nodedb-vector/src/hnsw/graph.rs @@ -8,6 +8,11 @@ use crate::distance::distance; // Re-export shared params from nodedb-types. pub use nodedb_types::hnsw::HnswParams; +/// Hard cap on the layer assigned to any node during insertion. +/// Standard HNSW practice — prevents pathological RNG draws from inflating +/// `max_layer` and slowing every subsequent search. +pub const MAX_LAYER_CAP: usize = 16; + /// Result of a k-NN search. #[derive(Debug, Clone)] pub struct SearchResult { @@ -254,10 +259,15 @@ impl HnswIndex { } /// Assign a random layer using the exponential distribution. + /// + /// Capped at `MAX_LAYER_CAP` to prevent pathological RNG draws from + /// promoting the index's `max_layer` to hundreds or thousands, which + /// would make every search's Phase-1 greedy descent O(max_layer). pub(crate) fn random_layer(&mut self) -> usize { let ml = 1.0 / (self.params.m as f64).ln(); let r = self.rng.next_f64().max(f64::MIN_POSITIVE); - (-r.ln() * ml).floor() as usize + let layer = (-r.ln() * ml).floor() as usize; + layer.min(MAX_LAYER_CAP) } /// Compute distance between a query vector and a stored node. @@ -279,10 +289,22 @@ impl HnswIndex { } /// Compact the index by removing all tombstoned nodes. + /// + /// Returns the number of removed nodes. See `compact_with_map` for the + /// variant that also returns the old→new id remapping. pub fn compact(&mut self) -> usize { + self.compact_with_map().0 + } + + /// Compact and return both the removed count and the old→new id map. + /// + /// `id_map[old_local]` = new_local, or `u32::MAX` if the node was + /// tombstoned (removed). + pub fn compact_with_map(&mut self) -> (usize, Vec) { let tombstone_count = self.tombstone_count(); if tombstone_count == 0 { - return 0; + let identity: Vec = (0..self.nodes.len() as u32).collect(); + return (0, identity); } self.ensure_mutable_neighbors(); @@ -348,7 +370,7 @@ impl HnswIndex { .unwrap_or(0); self.nodes = new_nodes; - tombstone_count + (tombstone_count, id_map) } } diff --git a/nodedb-vector/tests/hnsw_layer_cap.rs b/nodedb-vector/tests/hnsw_layer_cap.rs new file mode 100644 index 00000000..bef3b0fa --- /dev/null +++ b/nodedb-vector/tests/hnsw_layer_cap.rs @@ -0,0 +1,73 @@ +//! HNSW `random_layer` must be capped at a reasonable maximum. +//! +//! Spec: standard HNSW caps the assigned layer at ~16. The current +//! `random_layer` implementation has no cap — with an unlucky xorshift +//! draw (`r ≈ 2.2e-308`), `-ln(r) * (1/ln(m))` can return a layer in +//! the hundreds or thousands. One outlier insert then promotes the +//! index's `max_layer`, and every subsequent search's Phase-1 greedy +//! descent iterates `(1..=max_layer).rev()` — converting constant-time +//! descent into O(max_layer) per query. + +use nodedb_vector::DistanceMetric; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +/// Hard cap enforced by `HnswIndex::random_layer`. Standard HNSW uses ~16 +/// and the implementation clamps at `MAX_LAYER_CAP = 16`. +const LAYER_CAP: usize = 16; + +#[test] +fn random_layer_never_exceeds_cap_under_normal_inserts() { + let mut idx = HnswIndex::with_seed( + 4, + HnswParams { + m: 16, + m0: 32, + ef_construction: 64, + metric: DistanceMetric::L2, + }, + 1, + ); + for i in 0..5_000u32 { + let v = vec![ + (i as f32).sin(), + (i as f32).cos(), + ((i * 3) as f32).sin(), + ((i * 7) as f32).cos(), + ]; + idx.insert(v).unwrap(); + } + assert!( + idx.max_layer() <= LAYER_CAP, + "max_layer grew to {} (cap = {LAYER_CAP}); one pathological random_layer \ + draw promoted the index and will slow every subsequent search", + idx.max_layer() + ); +} + +#[test] +fn random_layer_capped_with_adversarial_seed() { + // Seeds chosen to exercise xorshift states that produce very small + // `next_f64()` outputs early in the sequence. A correct implementation + // clamps the resulting layer regardless of the RNG draw. + for seed in [1u64, 2, 3, 7, 13, 42, 123, 9_999, 1_000_003] { + let mut idx = HnswIndex::with_seed( + 2, + HnswParams { + m: 2, // small m amplifies -ln(r) * (1/ln(m)) + m0: 4, + ef_construction: 32, + metric: DistanceMetric::L2, + }, + seed, + ); + for i in 0..2_000u32 { + idx.insert(vec![i as f32, 0.0]).unwrap(); + } + assert!( + idx.max_layer() <= LAYER_CAP, + "seed={seed}: max_layer reached {} (cap = {LAYER_CAP}) — \ + random_layer has no upper bound", + idx.max_layer() + ); + } +} From b54c4ae43287528a49a7ce32948ec2a3ad4313ac Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:12:43 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix(vector):=20correct=20k-means++=20d?= =?UTF-8?q?=C2=B2=20sampling=20in=20IVF=20and=20PQ=20training?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous initialization selected the farthest point deterministically rather than sampling proportionally to squared distance. Replace with proper weighted d² sampling using a deterministic xorshift RNG so centroid seeding is stable across runs and converges reliably for skewed distributions. Also derive MessagePack serialization for PqCodec so trained codecs survive checkpointing. --- nodedb-vector/src/hnsw/build.rs | 4 +- nodedb-vector/src/hnsw/search.rs | 46 ++++-- nodedb-vector/src/ivf.rs | 35 ++++- nodedb-vector/src/quantize/pq.rs | 58 +++++--- .../tests/quantize_kmeans_distribution.rs | 137 ++++++++++++++++++ 5 files changed, 238 insertions(+), 42 deletions(-) create mode 100644 nodedb-vector/tests/quantize_kmeans_distribution.rs diff --git a/nodedb-vector/src/hnsw/build.rs b/nodedb-vector/src/hnsw/build.rs index 503953a4..140c2d5c 100644 --- a/nodedb-vector/src/hnsw/build.rs +++ b/nodedb-vector/src/hnsw/build.rs @@ -47,7 +47,7 @@ impl HnswIndex { // Phase 1: Greedy descent from top layer to new_layer + 1. if self.max_layer > new_layer { for layer in (new_layer + 1..=self.max_layer).rev() { - let results = search_layer(self, &query, current_ep, 1, layer, None); + let results = search_layer(self, &query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } @@ -58,7 +58,7 @@ impl HnswIndex { let insert_top = new_layer.min(self.max_layer); for layer in (0..=insert_top).rev() { let ef = self.params.ef_construction; - let candidates = search_layer(self, &query, current_ep, ef, layer, None); + let candidates = search_layer(self, &query, current_ep, ef, layer, None, 0); let m = self.max_neighbors(layer); let selected = select_neighbors_heuristic(self, &candidates, m); diff --git a/nodedb-vector/src/hnsw/search.rs b/nodedb-vector/src/hnsw/search.rs index c917943f..64ca15b2 100644 --- a/nodedb-vector/src/hnsw/search.rs +++ b/nodedb-vector/src/hnsw/search.rs @@ -31,14 +31,14 @@ impl HnswIndex { // Phase 1: Greedy descent from top layer to layer 1. let mut current_ep = ep; for layer in (1..=self.max_layer).rev() { - let results = search_layer(self, query, current_ep, 1, layer, None); + let results = search_layer(self, query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } } // Phase 2: Beam search at layer 0. - let results = search_layer(self, query, current_ep, ef, 0, None); + let results = search_layer(self, query, current_ep, ef, 0, None, 0); results .into_iter() @@ -51,16 +51,28 @@ impl HnswIndex { } /// Filtered K-NN search with Roaring bitmap pre-filtering. - /// - /// Only nodes whose ID is present in `filter` are included in results. - /// All nodes are still used for graph navigation — this prevents accuracy - /// degradation for selective filters. pub fn search_filtered( &self, query: &[f32], k: usize, ef: usize, filter: &RoaringBitmap, + ) -> Vec { + self.search_filtered_offset(query, k, ef, filter, 0) + } + + /// Filtered K-NN search where the bitmap is keyed in a shifted ID space. + /// + /// `id_offset` is added to local node IDs before testing `filter.contains`. + /// Used by multi-segment collections where the bitmap holds GLOBAL ids + /// and each segment's HNSW nodes are numbered starting at `base_id`. + pub fn search_filtered_offset( + &self, + query: &[f32], + k: usize, + ef: usize, + filter: &RoaringBitmap, + id_offset: u32, ) -> Vec { assert_eq!(query.len(), self.dim, "query dimension mismatch"); if self.is_empty() { @@ -74,13 +86,13 @@ impl HnswIndex { let mut current_ep = ep; for layer in (1..=self.max_layer).rev() { - let results = search_layer(self, query, current_ep, 1, layer, None); + let results = search_layer(self, query, current_ep, 1, layer, None, 0); if let Some(nearest) = results.first() { current_ep = nearest.id; } } - let results = search_layer(self, query, current_ep, ef, 0, Some(filter)); + let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset); results .into_iter() @@ -99,9 +111,22 @@ impl HnswIndex { k: usize, ef: usize, bitmap_bytes: &[u8], + ) -> Vec { + self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0) + } + + /// Deserialize a Roaring bitmap and search with an ID offset applied + /// before testing membership. See `search_filtered_offset` for rationale. + pub fn search_with_bitmap_bytes_offset( + &self, + query: &[f32], + k: usize, + ef: usize, + bitmap_bytes: &[u8], + id_offset: u32, ) -> Vec { match RoaringBitmap::deserialize_from(bitmap_bytes) { - Ok(bitmap) => self.search_filtered(query, k, ef, &bitmap), + Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset), Err(_) => self.search(query, k, ef), } } @@ -119,6 +144,7 @@ pub(crate) fn search_layer( ef: usize, layer: usize, filter: Option<&RoaringBitmap>, + id_offset: u32, ) -> Vec { let mut visited: HashSet = HashSet::new(); visited.insert(entry_point); @@ -139,7 +165,7 @@ pub(crate) fn search_layer( return false; } match filter { - Some(f) => f.contains(id), + Some(f) => f.contains(id + id_offset), None => true, } }; diff --git a/nodedb-vector/src/ivf.rs b/nodedb-vector/src/ivf.rs index 37601407..16071f15 100644 --- a/nodedb-vector/src/ivf.rs +++ b/nodedb-vector/src/ivf.rs @@ -220,21 +220,40 @@ fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> V let mut centroids: Vec> = vec![data[0].to_vec()]; let mut min_dists = vec![f32::MAX; n]; + // Initialize min_dists against the first centroid. + for (i, point) in data.iter().enumerate() { + let d = distance(point, ¢roids[0], DistanceMetric::L2); + if d < min_dists[i] { + min_dists[i] = d; + } + } + + let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42); for _ in 1..k { - let Some(last) = centroids.last() else { break }; + let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); + let next_idx = if total < f64::EPSILON { + 0 + } else { + let target = rng.next_f64() * total; + let mut acc = 0.0f64; + let mut chosen = n - 1; + for (i, &d) in min_dists.iter().enumerate() { + acc += d as f64; + if acc >= target { + chosen = i; + break; + } + } + chosen + }; + centroids.push(data[next_idx].to_vec()); + let last = centroids.last().expect("just pushed"); for (i, point) in data.iter().enumerate() { let d = distance(point, last, DistanceMetric::L2); if d < min_dists[i] { min_dists[i] = d; } } - let best = min_dists - .iter() - .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(i, _)| i) - .unwrap_or(0); - centroids.push(data[best].to_vec()); } let mut assignments = vec![0usize; n]; diff --git a/nodedb-vector/src/quantize/pq.rs b/nodedb-vector/src/quantize/pq.rs index e7f2ad99..62cafe00 100644 --- a/nodedb-vector/src/quantize/pq.rs +++ b/nodedb-vector/src/quantize/pq.rs @@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize}; /// PQ codec with trained codebooks. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)] pub struct PqCodec { /// Original vector dimensionality. pub dim: usize, @@ -161,7 +161,8 @@ fn l2_sub(a: &[f32], b: &[f32]) -> f32 { /// Simple k-means clustering for PQ codebook training. /// -/// Uses k-means++ initialization for stable convergence. +/// Uses proper k-means++ initialization (weighted d² sampling) with a +/// deterministic seed so training is reproducible across runs. fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec> { let n = data.len(); if n == 0 || k == 0 { @@ -169,35 +170,48 @@ fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec> = Vec::with_capacity(k); - // First centroid: pick the first data point (deterministic). centroids.push(data[0].to_vec()); let mut min_dists = vec![f32::MAX; n]; - for c in 1..k { - // Update min distances to nearest centroid. + // Update against the first centroid. + for (i, point) in data.iter().enumerate() { + let d = l2_sub(point, ¢roids[0]); + if d < min_dists[i] { + min_dists[i] = d; + } + } + + for _ in 1..k { + let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); + let next_idx = if total < f64::EPSILON { + // All points coincide with existing centroids. + 0 + } else { + let target = rng.next_f64() * total; + let mut acc = 0.0f64; + let mut chosen = n - 1; + for (i, &d) in min_dists.iter().enumerate() { + acc += d as f64; + if acc >= target { + chosen = i; + break; + } + } + chosen + }; + centroids.push(data[next_idx].to_vec()); + // Incrementally update min_dists against the new centroid. + let last = centroids.last().expect("just pushed"); for (i, point) in data.iter().enumerate() { - let d = l2_sub(point, ¢roids[c - 1]); + let d = l2_sub(point, last); if d < min_dists[i] { min_dists[i] = d; } } - // Pick next centroid proportional to d². - let total: f64 = min_dists.iter().map(|&d| d as f64).sum(); - if total < f64::EPSILON { - // All points coincide — duplicate the first centroid. - centroids.push(data[0].to_vec()); - continue; - } - // Deterministic selection: pick the point with max min_dist. - let best_idx = min_dists - .iter() - .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(i, _)| i) - .unwrap_or(0); - centroids.push(data[best_idx].to_vec()); } // K-means iterations. diff --git a/nodedb-vector/tests/quantize_kmeans_distribution.rs b/nodedb-vector/tests/quantize_kmeans_distribution.rs new file mode 100644 index 00000000..4ed71611 --- /dev/null +++ b/nodedb-vector/tests/quantize_kmeans_distribution.rs @@ -0,0 +1,137 @@ +//! PQ and IVF-PQ codebook training must distribute centroids across the +//! data even when many input vectors are near-duplicates. +//! +//! Spec: k-means initialization selects centroids spread across the data +//! distribution. The current implementation has two compounding bugs: +//! +//! 1. `min_dists[i]` is only updated against `centroids[c - 1]` (the +//! last centroid), not against the full centroid set. Once two +//! centroids coincide, `min_dists` stops reflecting "distance to the +//! nearest centroid," so every subsequent deterministic-argmax pick +//! lands on the same outlier. +//! 2. The comment says "k-means++" but the selection is deterministic +//! farthest-point, so outliers dominate rather than being sampled +//! proportionally to d². +//! +//! Effect: on workloads with repeated prefixes/suffixes (templated chat, +//! shared headers/footers), most of the 256 centroids alias to one or two +//! points and PQ recall collapses. + +use nodedb_vector::quantize::pq::PqCodec; + +/// Training set of 200 vectors: 190 near-duplicates at the origin plus +/// 10 outliers scattered across a single subspace. A correct k-means++ +/// spreads centroids across both clusters; the current farthest-point- +/// with-broken-min-distance-update collapses to ~2 distinct centroids. +fn clustered_with_duplicates() -> Vec> { + let mut vecs: Vec> = Vec::with_capacity(200); + // Cluster A: 190 near-identical vectors near origin. + for i in 0..190 { + let eps = (i as f32) * 1e-5; + vecs.push(vec![eps, -eps, eps * 0.5, -eps * 0.5]); + } + // Cluster B: 10 outliers at distinct coordinates. + for j in 0..10 { + let x = 100.0 + (j as f32) * 10.0; + vecs.push(vec![x, -x, x * 0.5, -x * 0.5]); + } + vecs +} + +fn unique_centroid_count(codec: &PqCodec, vectors: &[Vec]) -> usize { + let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect(); + let codes = codec.encode_batch(&refs); + let m = codec.m; + // Per-subspace unique centroid indices used across the batch. + let mut min_unique = usize::MAX; + for sub in 0..m { + let mut seen = std::collections::HashSet::new(); + for row in 0..vectors.len() { + seen.insert(codes[row * m + sub]); + } + if seen.len() < min_unique { + min_unique = seen.len(); + } + } + min_unique +} + +#[test] +fn pq_kmeans_produces_diverse_centroids_on_duplicate_heavy_data() { + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = PqCodec::train(&refs, 4, 2, 16, 20); + + let unique = unique_centroid_count(&codec, &vecs); + assert!( + unique >= 4, + "k-means collapsed to {unique} unique centroids per subspace on \ + duplicate-heavy input; a correct k-means++ should pick at least \ + 4 distinct cluster representatives for k=16" + ); +} + +#[test] +fn pq_distance_table_separates_duplicates_from_outliers() { + // Spec test: after training, the PQ distance from a duplicate-cluster + // query to a duplicate vector must be meaningfully smaller than the + // distance to an outlier vector. Under the collapse bug, most + // codebook entries alias to one point so all distances look similar. + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = PqCodec::train(&refs, 4, 2, 16, 20); + + let query = [0.0f32, 0.0, 0.0, 0.0]; + let table = codec.build_distance_table(&query); + + let dup_code = codec.encode(&vecs[0]); // duplicate cluster + let outlier_code = codec.encode(&vecs[195]); // outlier cluster + + let dup_dist = codec.asymmetric_distance(&table, &dup_code); + let outlier_dist = codec.asymmetric_distance(&table, &outlier_code); + + assert!( + outlier_dist > dup_dist * 10.0, + "PQ failed to distinguish duplicate (d={dup_dist}) from outlier \ + (d={outlier_dist}) — codebook collapsed and the two codes encode \ + to near-identical table entries" + ); +} + +#[cfg(feature = "ivf")] +#[test] +fn ivf_pq_training_does_not_collapse_on_duplicate_heavy_data() { + use nodedb_vector::DistanceMetric; + use nodedb_vector::{IvfPqIndex, IvfPqParams}; + + let vecs = clustered_with_duplicates(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let mut idx = IvfPqIndex::new( + 4, + IvfPqParams { + n_cells: 8, + pq_m: 2, + pq_k: 16, + nprobe: 4, + metric: DistanceMetric::L2, + }, + ); + idx.train(&refs); + for v in &vecs { + idx.add(v); + } + + // Query at the origin. Correct training assigns near-duplicates to + // one cell and outliers to another; the nearest result must come + // from the duplicate cluster (original indices 0..190). + let results = idx.search(&[0.0, 0.0, 0.0, 0.0], 5); + assert!(!results.is_empty(), "IVF-PQ returned no results"); + for r in &results { + assert!( + r.id < 190, + "IVF-PQ k-means collapse: query at origin returned outlier id={} \ + instead of a near-duplicate cluster member", + r.id + ); + } +} From bf238c71f236eab237f6c180c5b2411810992e08 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:13:01 +0800 Subject: [PATCH 4/7] fix(vector): apply id_offset when testing bitmap filter in multi-segment search Bitmap filters carry global vector ids, but each segment's HNSW nodes and FlatIndex entries are numbered starting at zero (local ids). Without an offset the filter tests the wrong bit for every segment after the first, producing incorrect results for filtered searches over collections with more than one sealed segment. Add search_filtered_offset / search_with_bitmap_bytes_offset to HnswIndex and search_filtered_offset to FlatIndex. Thread id_offset through the internal search_layer function so bitmap membership is checked against the correct global id. Update collection/search.rs to pass the per-segment base_id and the growing segment's growing_base_id when dispatching filtered searches. --- nodedb-vector/src/flat.rs | 53 ++++++++- .../tests/collection_bitmap_filter.rs | 112 ++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 nodedb-vector/tests/collection_bitmap_filter.rs diff --git a/nodedb-vector/src/flat.rs b/nodedb-vector/src/flat.rs index 81fb01d5..f75ebf21 100644 --- a/nodedb-vector/src/flat.rs +++ b/nodedb-vector/src/flat.rs @@ -105,6 +105,21 @@ impl FlatIndex { /// Search with a pre-filter bitmap (byte-array format). pub fn search_filtered(&self, query: &[f32], top_k: usize, bitmap: &[u8]) -> Vec { + self.search_filtered_offset(query, top_k, bitmap, 0) + } + + /// Search with a pre-filter bitmap applying a global id offset. + /// + /// The bitmap is interpreted in a shifted id space: bit `i + id_offset` + /// tests local id `i`. Used by multi-segment collections where the + /// bitmap holds GLOBAL vector ids. + pub fn search_filtered_offset( + &self, + query: &[f32], + top_k: usize, + bitmap: &[u8], + id_offset: u32, + ) -> Vec { assert_eq!(query.len(), self.dim); let n = self.len(); if n == 0 || top_k == 0 { @@ -116,8 +131,9 @@ impl FlatIndex { if self.deleted[i] { continue; } - let byte_idx = i / 8; - let bit_idx = i % 8; + let global = i + id_offset as usize; + let byte_idx = global / 8; + let bit_idx = global % 8; if byte_idx >= bitmap.len() || (bitmap[byte_idx] & (1 << bit_idx)) == 0 { continue; } @@ -159,6 +175,17 @@ impl FlatIndex { } pub fn get_vector(&self, id: u32) -> Option<&[f32]> { + let idx = id as usize; + if idx < self.deleted.len() && !self.deleted[idx] { + let start = idx * self.dim; + Some(&self.data[start..start + self.dim]) + } else { + None + } + } + + /// Raw access bypassing tombstone filter — used by snapshot/restore. + pub fn get_vector_raw(&self, id: u32) -> Option<&[f32]> { let idx = id as usize; if idx < self.deleted.len() { let start = idx * self.dim; @@ -168,6 +195,28 @@ impl FlatIndex { } } + /// Whether the given local id has been tombstoned. + pub fn is_deleted(&self, id: u32) -> bool { + let idx = id as usize; + idx < self.deleted.len() && self.deleted[idx] + } + + /// Insert a vector that is already tombstoned (for checkpoint restore). + pub fn insert_tombstoned(&mut self, vector: Vec) -> u32 { + assert_eq!( + vector.len(), + self.dim, + "dimension mismatch: expected {}, got {}", + self.dim, + vector.len() + ); + let id = self.len() as u32; + self.data.extend_from_slice(&vector); + self.deleted.push(true); + // No live_count increment — it's dead on arrival. + id + } + pub fn dim(&self) -> usize { self.dim } diff --git a/nodedb-vector/tests/collection_bitmap_filter.rs b/nodedb-vector/tests/collection_bitmap_filter.rs new file mode 100644 index 00000000..8e406a53 --- /dev/null +++ b/nodedb-vector/tests/collection_bitmap_filter.rs @@ -0,0 +1,112 @@ +//! Roaring bitmap pre-filter must use the same ID space across segments. +//! +//! Spec: the query planner builds a Roaring bitmap from GLOBAL vector IDs. +//! `search_with_bitmap_bytes` walks each sealed segment and the segment's +//! HNSW index tests `filter.contains(id)` against the segment-LOCAL id. +//! The collection MUST reconcile the two — either by rewriting the bitmap +//! per-segment (subtract `seg.base_id`) or by applying the offset before +//! `f.contains(id)`. Without that, every segment beyond the first silently +//! drops all filtered candidates because global ≠ local. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; +use roaring::RoaringBitmap; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +/// Fill a collection's growing segment, seal it, complete the build, +/// so the next inserts land at `base_id == seal_count`. +fn seal_one(coll: &mut VectorCollection, count: usize) { + for i in 0..count { + coll.insert(vec![i as f32, 0.0]); + } + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); +} + +fn bitmap_bytes(ids: impl IntoIterator) -> Vec { + let mut bm = RoaringBitmap::new(); + for id in ids { + bm.insert(id); + } + let mut bytes = Vec::new(); + bm.serialize_into(&mut bytes).unwrap(); + bytes +} + +#[test] +fn bitmap_filter_targets_second_segment_global_ids() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); // segment 0: ids 0..50, base_id = 0 + seal_one(&mut coll, 50); // segment 1: ids 50..100, base_id = 50 + + // Query for a point near id=75 (in segment 1). Filter to only global + // id 75. Correct behavior: returns id=75. Buggy behavior: the second + // segment's bitmap lookup tests local id 25 against a bitmap that + // contains global 75 → zero matches. + let bytes = bitmap_bytes([75u32]); + let results = coll.search_with_bitmap_bytes(&[75.0, 0.0], 1, 64, &bytes); + + assert_eq!( + results.len(), + 1, + "global-id bitmap filter dropped all candidates in segment 1" + ); + assert_eq!(results[0].id, 75); +} + +#[test] +fn bitmap_filter_recovers_many_globals_across_segments() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); + seal_one(&mut coll, 50); + + // Select globals from the second segment only. + let wanted: Vec = (60..70).collect(); + let bytes = bitmap_bytes(wanted.iter().copied()); + + let results = coll.search_with_bitmap_bytes(&[65.0, 0.0], 10, 128, &bytes); + + assert_eq!( + results.len(), + wanted.len(), + "expected all {} second-segment globals to match; got {}", + wanted.len(), + results.len() + ); + let got: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + for id in &wanted { + assert!( + got.contains(id), + "missing expected id {id} from filtered results" + ); + } +} + +#[test] +fn bitmap_filter_first_segment_still_works() { + // Regression guard for the partial-accident: segment 0 has base_id=0 so + // local==global and filtering appears to work. This test pins that down + // so a fix to the second-segment path doesn't regress segment 0. + let mut coll = VectorCollection::with_seal_threshold(2, params(), 50); + seal_one(&mut coll, 50); + seal_one(&mut coll, 50); + + let bytes = bitmap_bytes([10u32, 20, 30]); + let results = coll.search_with_bitmap_bytes(&[20.0, 0.0], 3, 64, &bytes); + let got: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + let expected: std::collections::HashSet = [10u32, 20, 30].into_iter().collect(); + assert_eq!(got, expected); +} From bf67336463bb02d74247fda549f9765370dc4638 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:13:15 +0800 Subject: [PATCH 5/7] fix(vector): preserve tombstones through checkpoint and compact doc_id_map MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Checkpoint serialization previously skipped deleted vectors entirely, so on restore those local ids were missing and all subsequent ids shifted by one — corrupting the HNSW graph's neighbor adjacency. Fix by capturing the deleted flag for every vector (growing and building segments) and replaying tombstones via insert_tombstoned / index.delete() on restore. Separately, compact() previously discarded doc_id_map and multi_doc_map entries for the segment being compacted. Use compact_with_map() to obtain the old→new local id remapping and rewrite both maps so that global ids continue to resolve to the correct document strings after compaction. --- nodedb-vector/src/collection/checkpoint.rs | 67 ++++++++-- nodedb-vector/src/collection/segment.rs | 3 + .../tests/collection_checkpoint_tombstones.rs | 98 ++++++++++++++ .../tests/collection_compact_doc_map.rs | 122 ++++++++++++++++++ 4 files changed, 281 insertions(+), 9 deletions(-) create mode 100644 nodedb-vector/tests/collection_checkpoint_tombstones.rs create mode 100644 nodedb-vector/tests/collection_compact_doc_map.rs diff --git a/nodedb-vector/src/collection/checkpoint.rs b/nodedb-vector/src/collection/checkpoint.rs index a92a9d57..af086d07 100644 --- a/nodedb-vector/src/collection/checkpoint.rs +++ b/nodedb-vector/src/collection/checkpoint.rs @@ -7,6 +7,7 @@ use crate::collection::tier::StorageTier; use crate::distance::DistanceMetric; use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; +use crate::quantize::pq::PqCodec; use super::lifecycle::VectorCollection; @@ -33,12 +34,18 @@ pub(crate) struct CollectionSnapshot { pub(crate) struct SealedSnapshot { pub base_id: u32, pub hnsw_bytes: Vec, + #[serde(default)] + pub pq_bytes: Option>, + #[serde(default)] + pub pq_codes: Option>, } #[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)] pub(crate) struct BuildingSnapshot { pub base_id: u32, pub vectors: Vec>, + #[serde(default)] + pub deleted: Vec, } impl VectorCollection { @@ -53,17 +60,27 @@ impl VectorCollection { next_id: self.next_id, growing_base_id: self.growing_base_id, growing_vectors: (0..self.growing.len() as u32) - .filter_map(|i| self.growing.get_vector(i).map(|v| v.to_vec())) + .filter_map(|i| self.growing.get_vector_raw(i).map(|v| v.to_vec())) .collect(), growing_deleted: (0..self.growing.len() as u32) - .map(|i| self.growing.get_vector(i).is_none()) + .map(|i| self.growing.is_deleted(i)) .collect(), sealed_segments: self .sealed .iter() - .map(|s| SealedSnapshot { - base_id: s.base_id, - hnsw_bytes: s.index.checkpoint_to_bytes(), + .map(|s| { + let (pq_bytes, pq_codes) = match &s.pq { + Some((codec, codes)) => { + (zerompk::to_msgpack_vec(codec).ok(), Some(codes.clone())) + } + None => (None, None), + }; + SealedSnapshot { + base_id: s.base_id, + hnsw_bytes: s.index.checkpoint_to_bytes(), + pq_bytes, + pq_codes, + } }) .collect(), building_segments: self @@ -72,7 +89,10 @@ impl VectorCollection { .map(|b| BuildingSnapshot { base_id: b.base_id, vectors: (0..b.flat.len() as u32) - .filter_map(|i| b.flat.get_vector(i).map(|v| v.to_vec())) + .filter_map(|i| b.flat.get_vector_raw(i).map(|v| v.to_vec())) + .collect(), + deleted: (0..b.flat.len() as u32) + .map(|i| b.flat.is_deleted(i)) .collect(), }) .collect(), @@ -118,18 +138,35 @@ impl VectorCollection { }; let mut growing = FlatIndex::new(snap.dim, metric); - for v in &snap.growing_vectors { - growing.insert(v.clone()); + for (i, v) in snap.growing_vectors.iter().enumerate() { + let deleted = snap.growing_deleted.get(i).copied().unwrap_or(false); + if deleted { + growing.insert_tombstoned(v.clone()); + } else { + growing.insert(v.clone()); + } } let mut sealed = Vec::with_capacity(snap.sealed_segments.len()); for ss in &snap.sealed_segments { if let Some(index) = HnswIndex::from_checkpoint(&ss.hnsw_bytes) { - let sq8 = VectorCollection::build_sq8_for_index(&index); + let pq = match (&ss.pq_bytes, &ss.pq_codes) { + (Some(bytes), Some(codes)) => zerompk::from_msgpack::(bytes) + .ok() + .map(|codec| (codec, codes.clone())), + _ => None, + }; + // Only train SQ8 when PQ isn't present — a segment never carries both. + let sq8 = if pq.is_some() { + None + } else { + VectorCollection::build_sq8_for_index(&index) + }; sealed.push(SealedSegment { index, base_id: ss.base_id, sq8, + pq, tier: StorageTier::L0Ram, mmap_vectors: None, }); @@ -143,11 +180,18 @@ impl VectorCollection { .insert(v.clone()) .expect("dimension guaranteed by checkpoint"); } + // Replay building-segment tombstones onto the HNSW index. + for (i, &dead) in bs.deleted.iter().enumerate() { + if dead { + index.delete(i as u32); + } + } let sq8 = VectorCollection::build_sq8_for_index(&index); sealed.push(SealedSegment { index, base_id: bs.base_id, sq8, + pq: None, tier: StorageTier::L0Ram, mmap_vectors: None, }); @@ -155,6 +199,10 @@ impl VectorCollection { let next_segment_id = (sealed.len() + 1) as u32; + let index_config = crate::index_config::IndexConfig { + hnsw: params.clone(), + ..crate::index_config::IndexConfig::default() + }; Some(Self { growing, growing_base_id: snap.growing_base_id, @@ -171,6 +219,7 @@ impl VectorCollection { doc_id_map: snap.doc_id_map.into_iter().collect(), multi_doc_map: snap.multi_doc_map.into_iter().collect(), seal_threshold: DEFAULT_SEAL_THRESHOLD, + index_config, }) } } diff --git a/nodedb-vector/src/collection/segment.rs b/nodedb-vector/src/collection/segment.rs index 52de6ae0..ded68ce7 100644 --- a/nodedb-vector/src/collection/segment.rs +++ b/nodedb-vector/src/collection/segment.rs @@ -4,6 +4,7 @@ use crate::collection::tier::StorageTier; use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; use crate::mmap_segment::MmapVectorSegment; +use crate::quantize::pq::PqCodec; use crate::quantize::sq8::Sq8Codec; /// Default threshold for sealing the growing segment. @@ -44,6 +45,8 @@ pub struct SealedSegment { pub base_id: u32, /// Optional SQ8 quantized vectors for accelerated traversal. pub sq8: Option<(Sq8Codec, Vec)>, + /// Optional PQ-compressed codes (for HnswPq-configured indexes). + pub pq: Option<(PqCodec, Vec)>, /// Storage tier: L0Ram = FP32 in HNSW nodes, L1Nvme = FP32 in mmap file. pub tier: StorageTier, /// mmap-backed vector segment for L1 NVMe tier. diff --git a/nodedb-vector/tests/collection_checkpoint_tombstones.rs b/nodedb-vector/tests/collection_checkpoint_tombstones.rs new file mode 100644 index 00000000..de3b42f9 --- /dev/null +++ b/nodedb-vector/tests/collection_checkpoint_tombstones.rs @@ -0,0 +1,98 @@ +//! Soft-deletes in growing / building segments must survive checkpoint restore. +//! +//! Spec: `delete(id)` on a vector in the growing segment (or a building +//! segment awaiting HNSW completion) tombstones the vector. Checkpoints +//! MUST serialize that tombstone, and `from_checkpoint` MUST apply it so +//! the restored collection reports the same `live_count()` and excludes +//! the deleted vector from `search()` results. +//! +//! Today: +//! - `FlatIndex::get_vector` returns `Some(..)` even for tombstoned +//! slots, so `growing_deleted` is serialized as all-false. +//! - `from_checkpoint` ignores the `growing_deleted` field entirely and +//! re-inserts every vector as live. +//! +//! Result: crash recovery silently resurrects soft-deleted rows — a +//! correctness regression for any workflow using `valid_until` deletes. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::HnswParams; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +#[test] +fn growing_segment_tombstones_survive_checkpoint_roundtrip() { + let mut coll = VectorCollection::new(2, params()); + for i in 0..10u32 { + coll.insert(vec![i as f32, 0.0]); + } + assert!(coll.delete(3), "delete on live growing vector must succeed"); + assert!(coll.delete(7), "delete on live growing vector must succeed"); + let live_before = coll.live_count(); + assert_eq!(live_before, 8); + + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes).expect("checkpoint deserializes"); + + assert_eq!( + restored.live_count(), + live_before, + "tombstoned growing-segment vectors resurrected on restore" + ); + + let results = restored.search(&[3.0, 0.0], 10, 64); + let ids: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + assert!( + !ids.contains(&3), + "soft-deleted id=3 reappeared in search after restore" + ); + assert!( + !ids.contains(&7), + "soft-deleted id=7 reappeared in search after restore" + ); +} + +#[test] +fn building_segment_tombstones_survive_checkpoint_roundtrip() { + // Force a seal so the deleted rows live in a building segment at + // snapshot time, exercising the `building_segments` encode path. + let mut coll = VectorCollection::with_seal_threshold(2, params(), 20); + for i in 0..20u32 { + coll.insert(vec![i as f32, 0.0]); + } + let _req = coll.seal("k").expect("seal produced request"); + // Intentionally do NOT complete the build — vectors now sit in the + // building segment as a FlatIndex. + assert!(coll.delete(5), "delete on building vector must succeed"); + assert!(coll.delete(15), "delete on building vector must succeed"); + let live_before = coll.live_count(); + assert_eq!(live_before, 18); + + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes).expect("checkpoint deserializes"); + + assert_eq!( + restored.live_count(), + live_before, + "tombstoned building-segment vectors resurrected on restore" + ); + + let results = restored.search(&[5.0, 0.0], 20, 64); + let ids: std::collections::HashSet = results.iter().map(|r| r.id).collect(); + assert!( + !ids.contains(&5), + "soft-deleted id=5 reappeared after restore" + ); + assert!( + !ids.contains(&15), + "soft-deleted id=15 reappeared after restore" + ); +} diff --git a/nodedb-vector/tests/collection_compact_doc_map.rs b/nodedb-vector/tests/collection_compact_doc_map.rs new file mode 100644 index 00000000..a2cad9e1 --- /dev/null +++ b/nodedb-vector/tests/collection_compact_doc_map.rs @@ -0,0 +1,122 @@ +//! `compact()` must keep `doc_id_map` / `multi_doc_map` consistent with the +//! renumbered HNSW local IDs. +//! +//! Spec: `HnswIndex::compact()` removes tombstoned nodes and renumbers +//! surviving local node ids. The collection stores `doc_id_map` and +//! `multi_doc_map` keyed on GLOBAL ids (`seg.base_id + local`). After +//! compaction those globals shift too — the collection MUST walk both +//! maps and rewrite every entry for the compacted segment to the new +//! `(seg.base_id + new_local)` globals. Without the rewrite, +//! `get_doc_id(vid)` and `delete_multi_vector(doc)` point at stale or +//! wrong vectors. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +fn params() -> HnswParams { + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + } +} + +fn build_collection_with_docs() -> VectorCollection { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 6); + // Six docs, one vector each. Global ids 0..6. + for i in 0..6u32 { + coll.insert_with_doc_id(vec![i as f32, 0.0], format!("doc_{i}")); + } + // Seal + complete → sealed segment with base_id=0, local ids 0..6. + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll +} + +#[test] +fn doc_id_map_stays_correct_after_compact() { + let mut coll = build_collection_with_docs(); + + // Tombstone two vectors in the middle of the sealed segment. + assert!(coll.delete(1)); + assert!(coll.delete(3)); + + // Sanity: pre-compact, the surviving doc mapping still resolves. + assert_eq!(coll.get_doc_id(0), Some("doc_0")); + assert_eq!(coll.get_doc_id(5), Some("doc_5")); + + let removed = coll.compact(); + assert_eq!(removed, 2, "compact should remove 2 tombstoned nodes"); + + // Spec: the search results (identified by renumbered global ids) still + // resolve to the original doc strings. For the surviving vectors + // {0, 2, 4, 5} post-compact globals become {0, 1, 2, 3}. `get_doc_id` + // MUST map those new globals to "doc_0", "doc_2", "doc_4", "doc_5". + let results = coll.search(&[0.0, 0.0], 4, 64); + let ids: Vec = results.iter().map(|r| r.id).collect(); + assert_eq!(ids.len(), 4, "expected 4 live vectors post-compact"); + + let observed_docs: std::collections::HashSet = ids + .iter() + .filter_map(|id| coll.get_doc_id(*id).map(|s| s.to_string())) + .collect(); + let expected_docs: std::collections::HashSet = ["doc_0", "doc_2", "doc_4", "doc_5"] + .into_iter() + .map(String::from) + .collect(); + + assert_eq!( + observed_docs, expected_docs, + "doc_id_map was not rewritten after compact — globals shifted but the map did not" + ); +} + +#[test] +fn multi_doc_map_stays_correct_after_compact() { + let mut coll = VectorCollection::with_seal_threshold(2, params(), 6); + + // Two multi-vector docs: doc_a owns globals 0,1,2; doc_b owns 3,4,5. + let a_vecs: Vec> = (0..3u32).map(|i| vec![i as f32, 0.0]).collect(); + let a_refs: Vec<&[f32]> = a_vecs.iter().map(|v| v.as_slice()).collect(); + let a_ids = coll.insert_multi_vector(&a_refs, "doc_a".to_string()); + assert_eq!(a_ids, vec![0, 1, 2]); + + let b_vecs: Vec> = (3..6u32).map(|i| vec![i as f32, 0.0]).collect(); + let b_refs: Vec<&[f32]> = b_vecs.iter().map(|v| v.as_slice()).collect(); + let b_ids = coll.insert_multi_vector(&b_refs, "doc_b".to_string()); + assert_eq!(b_ids, vec![3, 4, 5]); + + let req = coll.seal("k").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + + // Tombstone one vector from each doc (middle of each group). + assert!(coll.delete(1)); + assert!(coll.delete(4)); + + coll.compact(); + + // Spec: `delete_multi_vector("doc_a")` must reach the two remaining + // vectors that originally belonged to doc_a, regardless of the local + // id renumbering performed by HnswIndex::compact. + let deleted_a = coll.delete_multi_vector("doc_a"); + assert_eq!( + deleted_a, 2, + "delete_multi_vector(doc_a) must find its 2 remaining vectors after compact" + ); + + let live_after = coll.live_count(); + assert_eq!( + live_after, 2, + "post-compact + doc_a delete: only doc_b's 2 remaining vectors survive" + ); +} From 5376d5db88039f316dd919a99ba2b40de999be4b Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:13:28 +0800 Subject: [PATCH 6/7] feat(vector): wire PQ codec training and two-phase reranking into collection Add IndexConfig / IndexType to VectorCollection so PQ-configured collections train and store PQ codes when sealing a segment rather than always falling back to SQ8. Split quantizer training helpers into collection/quantize.rs to keep lifecycle.rs under the 500-line file cap. Extend the sealed-segment search path to use a unified quantized_search() function that handles both PQ and SQ8 with proper asymmetric scoring, widened candidate generation, and exact FP32 reranking. Report PQ quantization and index type correctly in collection stats. --- nodedb-vector/src/collection/lifecycle.rs | 137 ++++++++++++----- nodedb-vector/src/collection/mod.rs | 1 + nodedb-vector/src/collection/quantize.rs | 113 ++++++++++++++ nodedb-vector/src/collection/search.rs | 156 ++++++++++++++------ nodedb-vector/tests/collection_pq_config.rs | 88 +++++++++++ 5 files changed, 414 insertions(+), 81 deletions(-) create mode 100644 nodedb-vector/src/collection/quantize.rs create mode 100644 nodedb-vector/tests/collection_pq_config.rs diff --git a/nodedb-vector/src/collection/lifecycle.rs b/nodedb-vector/src/collection/lifecycle.rs index c3311573..7eee6113 100644 --- a/nodedb-vector/src/collection/lifecycle.rs +++ b/nodedb-vector/src/collection/lifecycle.rs @@ -2,7 +2,7 @@ use crate::flat::FlatIndex; use crate::hnsw::{HnswIndex, HnswParams}; -use crate::quantize::sq8::Sq8Codec; +use crate::index_config::{IndexConfig, IndexType}; use super::segment::{BuildRequest, BuildingSegment, DEFAULT_SEAL_THRESHOLD, SealedSegment}; @@ -40,6 +40,8 @@ pub struct VectorCollection { pub multi_doc_map: std::collections::HashMap>, /// Number of vectors in the growing segment before sealing. pub(crate) seal_threshold: usize, + /// Full index configuration (index type, PQ params, IVF params). + pub(crate) index_config: IndexConfig, } impl VectorCollection { @@ -50,6 +52,25 @@ impl VectorCollection { /// Create an empty collection with an explicit seal threshold. pub fn with_seal_threshold(dim: usize, params: HnswParams, seal_threshold: usize) -> Self { + let index_config = IndexConfig { + hnsw: params.clone(), + ..IndexConfig::default() + }; + Self::with_seal_threshold_and_config(dim, index_config, seal_threshold) + } + + /// Create an empty collection with a full index configuration. + pub fn with_index_config(dim: usize, config: IndexConfig) -> Self { + Self::with_seal_threshold_and_config(dim, config, DEFAULT_SEAL_THRESHOLD) + } + + /// Create an empty collection with a full index config and custom seal threshold. + pub fn with_seal_threshold_and_config( + dim: usize, + config: IndexConfig, + seal_threshold: usize, + ) -> Self { + let params = config.hnsw.clone(); Self { growing: FlatIndex::new(dim, params.metric), growing_base_id: 0, @@ -66,6 +87,7 @@ impl VectorCollection { doc_id_map: std::collections::HashMap::new(), multi_doc_map: std::collections::HashMap::new(), seal_threshold, + index_config: config, } } @@ -213,53 +235,28 @@ impl VectorCollection { .position(|b| b.segment_id == segment_id) { let building = self.building.remove(pos); - let sq8 = Self::build_sq8_for_index(&index); + let use_pq = self.index_config.index_type == IndexType::HnswPq; + let (sq8, pq) = if use_pq { + ( + None, + Self::build_pq_for_index(&index, self.index_config.pq_m), + ) + } else { + (Self::build_sq8_for_index(&index), None) + }; let (tier, mmap_vectors) = self.resolve_tier_for_build(segment_id, &index); self.sealed.push(SealedSegment { index, base_id: building.base_id, sq8, + pq, tier, mmap_vectors, }); } } - /// Build SQ8 quantized data for an HNSW index. - pub fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec)> { - if index.live_count() < 1000 { - return None; - } - let dim = index.dim(); - let n = index.len(); - - let mut refs: Vec<&[f32]> = Vec::with_capacity(n); - for i in 0..n { - if !index.is_deleted(i as u32) - && let Some(v) = index.get_vector(i as u32) - { - refs.push(v); - } - } - if refs.is_empty() { - return None; - } - - let codec = Sq8Codec::calibrate(&refs, dim); - - let mut data = Vec::with_capacity(dim * n); - for i in 0..n { - if let Some(v) = index.get_vector(i as u32) { - data.extend(codec.quantize(v)); - } else { - data.extend(vec![0u8; dim]); - } - } - - Some((codec, data)) - } - /// Access sealed segments (read-only). pub fn sealed_segments(&self) -> &[SealedSegment] { &self.sealed @@ -276,10 +273,64 @@ impl VectorCollection { } /// Compact sealed segments by removing tombstoned nodes. + /// + /// Rewrites `doc_id_map` and `multi_doc_map` for every sealed segment + /// so that global ids continue to resolve to the correct document + /// strings after local-id renumbering. pub fn compact(&mut self) -> usize { let mut total_removed = 0; for seg in &mut self.sealed { - total_removed += seg.index.compact(); + let base_id = seg.base_id; + let (removed, id_map) = seg.index.compact_with_map(); + total_removed += removed; + if removed == 0 { + continue; + } + + // Rebuild doc_id_map for entries in [base_id, base_id + id_map.len()). + let segment_end = base_id as u64 + id_map.len() as u64; + let doc_keys: Vec = self + .doc_id_map + .keys() + .copied() + .filter(|&k| (k as u64) >= base_id as u64 && (k as u64) < segment_end) + .collect(); + // Two-phase: remove all old entries first, then insert new ones so + // we don't clobber a freshly-remapped entry with a later tombstone + // removal. + let mut new_entries: Vec<(u32, String)> = Vec::with_capacity(doc_keys.len()); + for old_global in &doc_keys { + let doc = self.doc_id_map.remove(old_global); + let old_local = (old_global - base_id) as usize; + let new_local = id_map[old_local]; + if new_local != u32::MAX + && let Some(doc) = doc + { + new_entries.push((base_id + new_local, doc)); + } + } + for (k, v) in new_entries { + self.doc_id_map.insert(k, v); + } + + // Rewrite multi_doc_map entries for this segment. + for ids in self.multi_doc_map.values_mut() { + ids.retain_mut(|vid| { + let v = *vid; + if (v as u64) >= base_id as u64 && (v as u64) < segment_end { + let old_local = (v - base_id) as usize; + let new_local = id_map[old_local]; + if new_local == u32::MAX { + false + } else { + *vid = base_id + new_local; + true + } + } else { + true + } + }); + } } total_removed } @@ -382,12 +433,20 @@ impl VectorCollection { 0.0 }; - let quantization = if self.sealed.iter().any(|s| s.sq8.is_some()) { + let quantization = if self.sealed.iter().any(|s| s.pq.is_some()) { + nodedb_types::VectorIndexQuantization::Pq + } else if self.sealed.iter().any(|s| s.sq8.is_some()) { nodedb_types::VectorIndexQuantization::Sq8 } else { nodedb_types::VectorIndexQuantization::None }; + let index_type = match self.index_config.index_type { + IndexType::HnswPq => nodedb_types::VectorIndexType::HnswPq, + IndexType::IvfPq => nodedb_types::VectorIndexType::IvfPq, + IndexType::Hnsw => nodedb_types::VectorIndexType::Hnsw, + }; + let hnsw_mem: usize = self .sealed .iter() @@ -422,7 +481,7 @@ impl VectorCollection { memory_bytes, disk_bytes, build_in_progress: !self.building.is_empty(), - index_type: nodedb_types::VectorIndexType::Hnsw, + index_type, hnsw_m: self.params.m, hnsw_m0: self.params.m0, hnsw_ef_construction: self.params.ef_construction, diff --git a/nodedb-vector/src/collection/mod.rs b/nodedb-vector/src/collection/mod.rs index 07c118ab..d316ba17 100644 --- a/nodedb-vector/src/collection/mod.rs +++ b/nodedb-vector/src/collection/mod.rs @@ -1,6 +1,7 @@ pub mod budget; pub mod checkpoint; pub mod lifecycle; +pub mod quantize; pub mod search; pub mod segment; pub mod tier; diff --git a/nodedb-vector/src/collection/quantize.rs b/nodedb-vector/src/collection/quantize.rs new file mode 100644 index 00000000..d7ec5a89 --- /dev/null +++ b/nodedb-vector/src/collection/quantize.rs @@ -0,0 +1,113 @@ +//! Quantizer training helpers for `VectorCollection`. +//! +//! Split from `lifecycle.rs` to keep that file under the 500-line cap. +//! All methods here are `impl VectorCollection` blocks — Rust allows a +//! type's impl to be split across files. + +use crate::hnsw::{HnswIndex, HnswParams}; +use crate::index_config::{IndexConfig, IndexType}; +use crate::quantize::pq::PqCodec; +use crate::quantize::sq8::Sq8Codec; + +use super::lifecycle::VectorCollection; +use super::segment::DEFAULT_SEAL_THRESHOLD; + +impl VectorCollection { + /// Convenience constructor for PQ-configured collections. + /// + /// Equivalent to building a full `IndexConfig` with + /// `index_type = HnswPq` and the given `pq_m`. + pub fn with_pq_config(dim: usize, hnsw: HnswParams, pq_m: usize) -> Self { + let config = IndexConfig { + hnsw, + index_type: IndexType::HnswPq, + pq_m, + ..IndexConfig::default() + }; + Self::with_index_config(dim, config) + } + + /// Convenience constructor for PQ-configured collections with a custom + /// seal threshold. + pub fn with_seal_threshold_and_pq_config( + dim: usize, + hnsw: HnswParams, + pq_m: usize, + seal_threshold: usize, + ) -> Self { + let config = IndexConfig { + hnsw, + index_type: IndexType::HnswPq, + pq_m, + ..IndexConfig::default() + }; + Self::with_seal_threshold_and_config(dim, config, seal_threshold) + } + + /// Build SQ8 quantized data for an HNSW index. + /// + /// Returns `None` when there are too few live vectors for stable + /// min/max calibration. + pub fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec)> { + if index.live_count() < 1000 { + return None; + } + let dim = index.dim(); + let n = index.len(); + + let mut refs: Vec<&[f32]> = Vec::with_capacity(n); + for i in 0..n { + if !index.is_deleted(i as u32) + && let Some(v) = index.get_vector(i as u32) + { + refs.push(v); + } + } + if refs.is_empty() { + return None; + } + + let codec = Sq8Codec::calibrate(&refs, dim); + + let mut data = Vec::with_capacity(dim * n); + for i in 0..n { + if let Some(v) = index.get_vector(i as u32) { + data.extend(codec.quantize(v)); + } else { + data.extend(vec![0u8; dim]); + } + } + + Some((codec, data)) + } + + /// Train a PQ codec from a built HNSW index's live vectors. + pub fn build_pq_for_index(index: &HnswIndex, pq_m: usize) -> Option<(PqCodec, Vec)> { + let dim = index.dim(); + if pq_m == 0 || !dim.is_multiple_of(pq_m) { + return None; + } + let n = index.len(); + let mut refs: Vec> = Vec::with_capacity(n); + for i in 0..n { + if !index.is_deleted(i as u32) + && let Some(v) = index.get_vector(i as u32) + { + refs.push(v.to_vec()); + } + } + if refs.is_empty() { + return None; + } + let refs_slices: Vec<&[f32]> = refs.iter().map(|v| v.as_slice()).collect(); + let k = 256usize.min(refs.len()); + let codec = PqCodec::train(&refs_slices, dim, pq_m, k, 20); + let codes = codec.encode_batch(&refs_slices); + Some((codec, codes)) + } +} + +// Keep the DEFAULT_SEAL_THRESHOLD import live when future refactors move +// additional ctors into this file; explicitly referenced to suppress +// an otherwise-unused warning. +const _: usize = DEFAULT_SEAL_THRESHOLD; diff --git a/nodedb-vector/src/collection/search.rs b/nodedb-vector/src/collection/search.rs index 060f5b31..eb216bcd 100644 --- a/nodedb-vector/src/collection/search.rs +++ b/nodedb-vector/src/collection/search.rs @@ -1,9 +1,111 @@ //! VectorCollection search: multi-segment merging with SQ8 reranking. -use crate::distance::distance; +use crate::distance::{DistanceMetric, distance}; use crate::hnsw::SearchResult; use super::lifecycle::VectorCollection; +use super::segment::SealedSegment; + +/// Score a single candidate via the SQ8 codec, using the metric-appropriate +/// asymmetric distance. +#[inline] +fn sq8_score( + codec: &crate::quantize::sq8::Sq8Codec, + query: &[f32], + encoded: &[u8], + metric: DistanceMetric, +) -> f32 { + match metric { + DistanceMetric::Cosine => codec.asymmetric_cosine(query, encoded), + DistanceMetric::InnerProduct => codec.asymmetric_ip(query, encoded), + // L2 (and all other metrics that don't have a specialized asymmetric + // form yet) fall back to squared L2 — correct for ordering when the + // metric is L2 and a reasonable proxy otherwise since we rerank with + // exact FP32 below. + _ => codec.asymmetric_l2(query, encoded), + } +} + +/// Candidate-generation + rerank for a sealed segment that has a quantized +/// codec attached. Generates a widened candidate pool via HNSW, re-scores +/// candidates using the quantized codec (this is where SQ8/PQ actually pay +/// off — the FP32 vectors need not be resident), and reranks the top +/// `top_k` via exact FP32 distance from mmap or index storage. +fn quantized_search( + seg: &SealedSegment, + query: &[f32], + top_k: usize, + ef: usize, + metric: DistanceMetric, +) -> Vec { + let rerank_k = top_k.saturating_mul(3).max(20); + let hnsw_candidates = seg.index.search(query, rerank_k, ef); + + // Phase 1: rank candidates by quantized distance. + let mut scored: Vec<(u32, f32)> = if let Some((codec, codes)) = &seg.pq { + let table = codec.build_distance_table(query); + let m = codec.m; + hnsw_candidates + .into_iter() + .filter_map(|r| { + let start = (r.id as usize).checked_mul(m)?; + let end = start.checked_add(m)?; + let slice = codes.get(start..end)?; + Some((r.id, codec.asymmetric_distance(&table, slice))) + }) + .collect() + } else if let Some((codec, data)) = &seg.sq8 { + let dim = codec.dim(); + hnsw_candidates + .into_iter() + .filter_map(|r| { + let start = (r.id as usize).checked_mul(dim)?; + let end = start.checked_add(dim)?; + let slice = data.get(start..end)?; + Some((r.id, sq8_score(codec, query, slice, metric))) + }) + .collect() + } else { + hnsw_candidates + .into_iter() + .map(|r| (r.id, r.distance)) + .collect() + }; + scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Keep only the most promising candidates for FP32 rerank. + let keep = rerank_k.min(scored.len()); + scored.truncate(keep); + + // Prefetch FP32 vectors for reranking. + if let Some(mmap) = &seg.mmap_vectors { + let ids: Vec = scored.iter().map(|&(id, _)| id).collect(); + mmap.prefetch_batch(&ids); + } + + // Phase 2: rerank with exact FP32. + let mut reranked: Vec = scored + .into_iter() + .filter_map(|(id, _)| { + let v = if let Some(mmap) = &seg.mmap_vectors { + mmap.get_vector(id)? + } else { + seg.index.get_vector(id)? + }; + Some(SearchResult { + id, + distance: distance(query, v, metric), + }) + }) + .collect(); + reranked.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + reranked.truncate(top_k); + reranked +} impl VectorCollection { /// Search across all segments, merging results by distance. @@ -19,44 +121,8 @@ impl VectorCollection { // Search sealed segments. for seg in &self.sealed { - let results = if let Some(_sq8) = &seg.sq8 { - // Quantized two-phase search: use HNSW graph for O(log N) candidate - // generation, then rerank with exact FP32 distance. - let rerank_k = top_k.saturating_mul(3).max(20); - let hnsw_candidates = seg.index.search(query, rerank_k, ef); - let candidates: Vec<(u32, f32)> = hnsw_candidates - .into_iter() - .map(|r| (r.id, r.distance)) - .collect(); - - // Prefetch FP32 vectors for reranking candidates. - if let Some(mmap) = &seg.mmap_vectors { - let ids: Vec = candidates.iter().map(|&(id, _)| id).collect(); - mmap.prefetch_batch(&ids); - } - - // Phase 2: Rerank with exact FP32 distance. - let mut reranked: Vec = candidates - .iter() - .filter_map(|&(id, _)| { - let v = if let Some(mmap) = &seg.mmap_vectors { - mmap.get_vector(id)? - } else { - seg.index.get_vector(id)? - }; - Some(SearchResult { - id, - distance: distance(query, v, self.params.metric), - }) - }) - .collect(); - reranked.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - reranked.truncate(top_k); - reranked + let results = if seg.pq.is_some() || seg.sq8.is_some() { + quantized_search(seg, query, top_k, ef, self.params.metric) } else { seg.index.search(query, top_k, ef) }; @@ -94,14 +160,18 @@ impl VectorCollection { ) -> Vec { let mut all: Vec = Vec::new(); - let growing_results = self.growing.search_filtered(query, top_k, bitmap); + let growing_results = + self.growing + .search_filtered_offset(query, top_k, bitmap, self.growing_base_id); for mut r in growing_results { r.id += self.growing_base_id; all.push(r); } for seg in &self.sealed { - let results = seg.index.search_with_bitmap_bytes(query, top_k, ef, bitmap); + let results = + seg.index + .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id); for mut r in results { r.id += seg.base_id; all.push(r); @@ -109,7 +179,9 @@ impl VectorCollection { } for seg in &self.building { - let results = seg.flat.search_filtered(query, top_k, bitmap); + let results = seg + .flat + .search_filtered_offset(query, top_k, bitmap, seg.base_id); for mut r in results { r.id += seg.base_id; all.push(r); diff --git a/nodedb-vector/tests/collection_pq_config.rs b/nodedb-vector/tests/collection_pq_config.rs new file mode 100644 index 00000000..da1e87af --- /dev/null +++ b/nodedb-vector/tests/collection_pq_config.rs @@ -0,0 +1,88 @@ +//! `index_type='hnsw_pq'` must produce PQ-compressed segments. +//! +//! Spec: when a collection is configured for HNSW+PQ (advertised via the +//! SQL DDL `CREATE INDEX ... WITH (index_type='hnsw_pq', pq_m=...)`), +//! `complete_build` MUST train a `PqCodec` on the finished segment and +//! surface `VectorIndexQuantization::Pq` in `stats()`. Today, the config +//! is accepted at the DDL layer, stored, and then ignored — +//! `complete_build` unconditionally calls `build_sq8_for_index`, so +//! operators who asked for 8-16× memory reduction silently receive 4× +//! SQ8 and have no signal from `stats()` that the request was dropped. + +#![cfg(feature = "collection")] + +use nodedb_vector::DistanceMetric; +use nodedb_vector::collection::VectorCollection; +use nodedb_vector::hnsw::{HnswIndex, HnswParams}; + +fn params() -> HnswParams { + HnswParams { + m: 16, + m0: 32, + ef_construction: 100, + metric: DistanceMetric::L2, + } +} + +/// Build a collection with 1024 vectors of dim=8 and complete one segment +/// build. The `>= 1000` vector threshold in `build_sq8_for_index` means a +/// quantizer WILL be attached — so `stats().quantization` is either `Sq8` +/// (buggy fallback) or `Pq` (spec-correct for a HnswPq-configured index). +fn make_built_collection_with_pq_config() -> VectorCollection { + // Uses the convenience constructor `with_seal_threshold_and_pq_config` so + // callers don't have to hand-build a full `IndexConfig` just to request PQ. + let mut coll = VectorCollection::with_seal_threshold_and_pq_config(8, params(), 2, 1024); + for i in 0..1024u32 { + let mut v = vec![0.0f32; 8]; + for (d, slot) in v.iter_mut().enumerate() { + *slot = ((i as f32) * 0.01 + (d as f32) * 0.1).sin(); + } + coll.insert(v); + } + let req = coll.seal("pq").expect("seal produced request"); + let mut idx = HnswIndex::new(req.dim, req.params.clone()); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll +} + +#[test] +fn hnsw_pq_config_produces_pq_quantization() { + let coll = make_built_collection_with_pq_config(); + let stats = coll.stats(); + assert_eq!( + stats.quantization, + nodedb_types::VectorIndexQuantization::Pq, + "index_type='hnsw_pq' must produce PQ-compressed segments and \ + report VectorIndexQuantization::Pq; got {:?}", + stats.quantization + ); +} + +#[test] +fn hnsw_pq_config_stats_index_type_reports_hnsw_pq() { + let coll = make_built_collection_with_pq_config(); + let stats = coll.stats(); + assert_eq!( + stats.index_type, + nodedb_types::VectorIndexType::HnswPq, + "stats().index_type must reflect the configured HnswPq index" + ); +} + +#[test] +fn hnsw_pq_config_survives_checkpoint_roundtrip() { + let coll = make_built_collection_with_pq_config(); + let bytes = coll.checkpoint_to_bytes(); + let restored = VectorCollection::from_checkpoint(&bytes) + .expect("checkpoint must deserialize for PQ-configured collection"); + let stats = restored.stats(); + assert_eq!( + stats.quantization, + nodedb_types::VectorIndexQuantization::Pq, + "PQ codec must survive checkpoint roundtrip; got {:?}", + stats.quantization + ); +} From 400f64552be853b0f53b5d9a2037117d285b6efa Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Fri, 17 Apr 2026 05:14:05 +0800 Subject: [PATCH 7/7] refactor(vector): remove legacy SIMD dispatch file superseded by submodule --- nodedb-vector/src/distance/simd.rs | 504 ----------------------------- 1 file changed, 504 deletions(-) delete mode 100644 nodedb-vector/src/distance/simd.rs diff --git a/nodedb-vector/src/distance/simd.rs b/nodedb-vector/src/distance/simd.rs deleted file mode 100644 index 9af97898..00000000 --- a/nodedb-vector/src/distance/simd.rs +++ /dev/null @@ -1,504 +0,0 @@ -//! Runtime SIMD dispatch for vector distance and bitmap operations. -//! -//! Detects CPU features at startup and selects the fastest available -//! kernel for each operation. A single binary supports all targets: -//! -//! - AVX-512 (512-bit, 16 floats/op) — Intel Xeon, AMD Zen 4+ -//! - AVX2+FMA (256-bit, 8 floats/op) — most x86_64 since 2013 -//! - NEON (128-bit, 4 floats/op) — ARM64 (Graviton, Apple Silicon) -//! - Scalar fallback — auto-vectorized loops - -/// Selected SIMD runtime — function pointers to the best available kernels. -pub struct SimdRuntime { - pub l2_squared: fn(&[f32], &[f32]) -> f32, - pub cosine_distance: fn(&[f32], &[f32]) -> f32, - pub neg_inner_product: fn(&[f32], &[f32]) -> f32, - pub hamming: fn(&[u8], &[u8]) -> u32, - pub name: &'static str, -} - -impl SimdRuntime { - /// Detect CPU features and select the best kernels. - pub fn detect() -> Self { - #[cfg(target_arch = "x86_64")] - { - if std::is_x86_feature_detected!("avx512f") { - return Self { - l2_squared: avx512::l2_squared, - cosine_distance: avx512::cosine_distance, - neg_inner_product: avx512::neg_inner_product, - hamming: fast_hamming, - name: "avx512", - }; - } - if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") { - return Self { - l2_squared: avx2::l2_squared, - cosine_distance: avx2::cosine_distance, - neg_inner_product: avx2::neg_inner_product, - hamming: fast_hamming, - name: "avx2+fma", - }; - } - } - #[cfg(target_arch = "aarch64")] - { - return Self { - l2_squared: neon::l2_squared, - cosine_distance: neon::cosine_distance, - neg_inner_product: neon::neg_inner_product, - hamming: fast_hamming, - name: "neon", - }; - } - #[allow(unreachable_code)] - Self { - l2_squared: scalar_l2, - cosine_distance: scalar_cosine, - neg_inner_product: scalar_ip, - hamming: fast_hamming, - name: "scalar", - } - } -} - -/// Global SIMD runtime — initialized once, used everywhere. -static RUNTIME: std::sync::OnceLock = std::sync::OnceLock::new(); - -/// Get the global SIMD runtime (auto-detects on first call). -pub fn runtime() -> &'static SimdRuntime { - RUNTIME.get_or_init(SimdRuntime::detect) -} - -// ── Scalar fallback ── - -fn scalar_l2(a: &[f32], b: &[f32]) -> f32 { - let mut sum = 0.0f32; - for i in 0..a.len() { - let d = a[i] - b[i]; - sum += d * d; - } - sum -} - -fn scalar_cosine(a: &[f32], b: &[f32]) -> f32 { - let mut dot = 0.0f32; - let mut na = 0.0f32; - let mut nb = 0.0f32; - for i in 0..a.len() { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } -} - -fn scalar_ip(a: &[f32], b: &[f32]) -> f32 { - let mut dot = 0.0f32; - for i in 0..a.len() { - dot += a[i] * b[i]; - } - -dot -} - -/// Fast Hamming distance using u64 POPCNT (available on all modern CPUs). -fn fast_hamming(a: &[u8], b: &[u8]) -> u32 { - let mut dist = 0u32; - let chunks = a.len() / 8; - for i in 0..chunks { - let off = i * 8; - let xa = u64::from_le_bytes([ - a[off], - a[off + 1], - a[off + 2], - a[off + 3], - a[off + 4], - a[off + 5], - a[off + 6], - a[off + 7], - ]); - let xb = u64::from_le_bytes([ - b[off], - b[off + 1], - b[off + 2], - b[off + 3], - b[off + 4], - b[off + 5], - b[off + 6], - b[off + 7], - ]); - dist += (xa ^ xb).count_ones(); - } - for i in (chunks * 8)..a.len() { - dist += (a[i] ^ b[i]).count_ones(); - } - dist -} - -// ── AVX2+FMA kernels ── - -#[cfg(target_arch = "x86_64")] -mod avx2 { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - // SAFETY: caller verified avx2+fma via is_x86_feature_detected. - unsafe { l2_squared_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut sum = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - let diff = _mm256_sub_ps(va, vb); - sum = _mm256_fmadd_ps(diff, diff, sum); - } - let mut result = hsum256(sum); - for i in (chunks * 8)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm256_setzero_ps(); - let mut vna = _mm256_setzero_ps(); - let mut vnb = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - vdot = _mm256_fmadd_ps(va, vb, vdot); - vna = _mm256_fmadd_ps(va, va, vna); - vnb = _mm256_fmadd_ps(vb, vb, vnb); - } - let mut dot = hsum256(vdot); - let mut na = hsum256(vna); - let mut nb = hsum256(vnb); - for i in (chunks * 8)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - #[target_feature(enable = "avx2,fma")] - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm256_setzero_ps(); - let chunks = n / 8; - for i in 0..chunks { - let off = i * 8; - let va = _mm256_loadu_ps(a.as_ptr().add(off)); - let vb = _mm256_loadu_ps(b.as_ptr().add(off)); - vdot = _mm256_fmadd_ps(va, vb, vdot); - } - let mut dot = hsum256(vdot); - for i in (chunks * 8)..n { - dot += a[i] * b[i]; - } - -dot - } - } - - /// Horizontal sum of 8 × f32 in a __m256. - #[target_feature(enable = "avx2")] - unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 { - use std::arch::x86_64::*; - let hi = _mm256_extractf128_ps(v, 1); - let lo = _mm256_castps256_ps128(v); - let sum128 = _mm_add_ps(lo, hi); - let shuf = _mm_movehdup_ps(sum128); - let sums = _mm_add_ps(sum128, shuf); - let shuf2 = _mm_movehl_ps(sums, sums); - let sums2 = _mm_add_ss(sums, shuf2); - _mm_cvtss_f32(sums2) - } -} - -// ── AVX-512 kernels ── - -#[cfg(target_arch = "x86_64")] -mod avx512 { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - unsafe { l2_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut sum = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - let diff = _mm512_sub_ps(va, vb); - sum = _mm512_fmadd_ps(diff, diff, sum); - } - let mut result = _mm512_reduce_add_ps(sum); - for i in (chunks * 16)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm512_setzero_ps(); - let mut vna = _mm512_setzero_ps(); - let mut vnb = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - vdot = _mm512_fmadd_ps(va, vb, vdot); - vna = _mm512_fmadd_ps(va, va, vna); - vnb = _mm512_fmadd_ps(vb, vb, vnb); - } - let mut dot = _mm512_reduce_add_ps(vdot); - let mut na = _mm512_reduce_add_ps(vna); - let mut nb = _mm512_reduce_add_ps(vnb); - for i in (chunks * 16)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - #[target_feature(enable = "avx512f")] - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - unsafe { - use std::arch::x86_64::*; - let n = a.len(); - let mut vdot = _mm512_setzero_ps(); - let chunks = n / 16; - for i in 0..chunks { - let off = i * 16; - let va = _mm512_loadu_ps(a.as_ptr().add(off)); - let vb = _mm512_loadu_ps(b.as_ptr().add(off)); - vdot = _mm512_fmadd_ps(va, vb, vdot); - } - let mut dot = _mm512_reduce_add_ps(vdot); - for i in (chunks * 16)..n { - dot += a[i] * b[i]; - } - -dot - } - } -} - -// ── NEON kernels (ARM64) ── - -#[cfg(target_arch = "aarch64")] -mod neon { - pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 { - unsafe { l2_impl(a, b) } - } - - unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut sum = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - let diff = vsubq_f32(va, vb); - sum = vfmaq_f32(sum, diff, diff); - } - let mut result = vaddvq_f32(sum); - for i in (chunks * 4)..n { - let d = a[i] - b[i]; - result += d * d; - } - result - } - - pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 { - unsafe { cosine_impl(a, b) } - } - - unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut vdot = vdupq_n_f32(0.0); - let mut vna = vdupq_n_f32(0.0); - let mut vnb = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - vdot = vfmaq_f32(vdot, va, vb); - vna = vfmaq_f32(vna, va, va); - vnb = vfmaq_f32(vnb, vb, vb); - } - let mut dot = vaddvq_f32(vdot); - let mut na = vaddvq_f32(vna); - let mut nb = vaddvq_f32(vnb); - for i in (chunks * 4)..n { - dot += a[i] * b[i]; - na += a[i] * a[i]; - nb += b[i] * b[i]; - } - let denom = (na * nb).sqrt(); - if denom < f32::EPSILON { - 1.0 - } else { - (1.0 - dot / denom).max(0.0) - } - } - - pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 { - unsafe { ip_impl(a, b) } - } - - unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 { - use std::arch::aarch64::*; - let n = a.len(); - let mut vdot = vdupq_n_f32(0.0); - let chunks = n / 4; - for i in 0..chunks { - let off = i * 4; - let va = vld1q_f32(a.as_ptr().add(off)); - let vb = vld1q_f32(b.as_ptr().add(off)); - vdot = vfmaq_f32(vdot, va, vb); - } - let mut dot = vaddvq_f32(vdot); - for i in (chunks * 4)..n { - dot += a[i] * b[i]; - } - -dot - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn runtime_detects_features() { - let rt = SimdRuntime::detect(); - // Should detect at least scalar on any platform. - assert!(!rt.name.is_empty()); - tracing::info!("SIMD runtime: {}", rt.name); - } - - #[test] - fn l2_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..768).map(|i| (i as f32) * 0.01).collect(); - let b: Vec = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect(); - - let simd_result = (rt.l2_squared)(&a, &b); - let scalar_result = scalar_l2(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.01, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn cosine_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..768).map(|i| (i as f32).sin()).collect(); - let b: Vec = (0..768).map(|i| (i as f32).cos()).collect(); - - let simd_result = (rt.cosine_distance)(&a, &b); - let scalar_result = scalar_cosine(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.001, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn ip_matches_scalar() { - let rt = runtime(); - let a: Vec = (0..128).map(|i| (i as f32) * 0.1).collect(); - let b: Vec = (0..128).map(|i| (i as f32) * 0.2).collect(); - - let simd_result = (rt.neg_inner_product)(&a, &b); - let scalar_result = scalar_ip(&a, &b); - assert!( - (simd_result - scalar_result).abs() < 0.1, - "simd={simd_result}, scalar={scalar_result}" - ); - } - - #[test] - fn hamming_matches() { - let a = vec![0b10101010u8; 16]; - let b = vec![0b01010101u8; 16]; - assert_eq!(fast_hamming(&a, &b), 128); // all 128 bits differ - } - - #[test] - fn small_vectors() { - let rt = runtime(); - // Vectors smaller than SIMD width — tests remainder handling. - let a = [1.0f32, 2.0, 3.0]; - let b = [4.0f32, 5.0, 6.0]; - let l2 = (rt.l2_squared)(&a, &b); - assert!((l2 - 27.0).abs() < 0.01); // (3² + 3² + 3²) = 27 - } -}