Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/ruvector-diskann-node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl DiskAnn {
pq_subspaces: options.pq_subspaces.unwrap_or(0) as usize,
pq_iterations: options.pq_iterations.unwrap_or(10) as usize,
storage_path: options.storage_path.map(PathBuf::from),
..Default::default()
};
let index = CoreIndex::new(config);
Ok(Self {
Expand Down
122 changes: 122 additions & 0 deletions crates/ruvector-diskann/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,108 @@ impl VamanaGraph {
self.greedy_search_fast(vectors, query, beam_width, &mut visited)
}

/// Greedy beam search driven by a caller-supplied distance function.
///
/// This is the generalised entry point used by [`crate::index::DiskAnnIndex::search`]
/// when a quantizer is wired in: the closure can read PQ / RaBitQ codes
/// (or anything else) instead of dereferencing `FlatVectors`. It is a
/// near-verbatim copy of [`Self::greedy_search_fast`] so we don't
/// regress the existing f32 hot path — that method stays the canonical
/// reference implementation; this one just swaps `l2_squared(...)` for
/// `distance_fn(node)`.
///
/// `n` is the total node count (used to size the visited set when the
/// caller wants a one-shot search; pass the index's vector count). The
/// reusable variant takes an externally allocated [`VisitedSet`].
pub fn greedy_search_with_codes<F>(
&self,
node_count: usize,
beam_width: usize,
mut distance_fn: F,
) -> (Vec<u32>, usize)
where
F: FnMut(u32) -> f32,
{
let mut visited = VisitedSet::new(node_count);
self.greedy_search_with_codes_into(beam_width, &mut visited, &mut distance_fn)
}

/// Reusable-`VisitedSet` flavour of [`Self::greedy_search_with_codes`].
/// Same shape as [`Self::greedy_search_fast`] but the per-node distance
/// is opaque — that's the whole point: graph traversal no longer needs
/// the original f32 vectors.
pub fn greedy_search_with_codes_into<F>(
&self,
beam_width: usize,
visited: &mut VisitedSet,
distance_fn: &mut F,
) -> (Vec<u32>, usize)
where
F: FnMut(u32) -> f32,
{
visited.clear();

let mut candidates = BinaryHeap::<Candidate>::new();
let mut best = BinaryHeap::<MaxCandidate>::new();

let start = self.medoid;
let start_dist = distance_fn(start);
candidates.push(Candidate {
id: start,
distance: start_dist,
});
best.push(MaxCandidate {
id: start,
distance: start_dist,
});
visited.insert(start);

let mut visit_count = 1usize;

while let Some(current) = candidates.pop() {
if best.len() >= beam_width {
if let Some(worst) = best.peek() {
if current.distance > worst.distance {
break;
}
}
}

for &neighbor in &self.neighbors[current.id as usize] {
if visited.contains(neighbor) {
continue;
}
visited.insert(neighbor);
visit_count += 1;

let dist = distance_fn(neighbor);

let dominated =
best.len() >= beam_width && best.peek().map_or(false, |w| dist >= w.distance);

if !dominated {
candidates.push(Candidate {
id: neighbor,
distance: dist,
});
best.push(MaxCandidate {
id: neighbor,
distance: dist,
});
if best.len() > beam_width {
best.pop();
}
}
}
}

let mut result: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.distance)).collect();
result.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
let ids: Vec<u32> = result.into_iter().map(|(id, _)| id).collect();

(ids, visit_count)
}

fn robust_prune(
&self,
vectors: &FlatVectors,
Expand Down Expand Up @@ -325,6 +427,26 @@ mod tests {
assert!(results.contains(&42));
}

#[test]
fn test_greedy_search_with_codes_matches_flat() {
// Sanity check: when the closure recomputes l2_squared against the
// original f32 vectors, the new entry point must return the same
// candidate set as the legacy `greedy_search` — proves we didn't
// accidentally drift the traversal logic.
let vectors = random_flat(200, 32);
let mut graph = VamanaGraph::new(200, 32, 64, 1.2);
graph.build(&vectors).unwrap();

let query = vectors.get(17).to_vec();
let (legacy, _) = graph.greedy_search(&vectors, &query, 16);

let (via_closure, _) = graph.greedy_search_with_codes(vectors.len(), 16, |id| {
l2_squared(vectors.get(id as usize), &query)
});

assert_eq!(legacy, via_closure);
}

#[test]
fn test_vamana_bounded_degree() {
let vectors = random_flat(100, 16);
Expand Down
Loading
Loading