Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions crates/lance-graph-python/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,50 @@ impl CypherEngine {
}
}

/// Run a Cypher query and rerank its results by vector similarity
///
/// Convenience method: the graph query is executed first to produce
/// candidate rows, then the vector search is applied to those candidates.
/// See CypherQuery.execute_with_vector_rerank for detailed documentation.
///
/// Parameters
/// ----------
/// query : str
///     Cypher query string
/// vector_search : VectorSearch
///     Vector search configuration
///
/// Returns
/// -------
/// pyarrow.Table
///     Results sorted by vector similarity
fn execute_with_vector_rerank(
    &self,
    py: Python,
    query: &str,
    vector_search: &VectorSearch,
) -> PyResult<PyObject> {
    // Parse the query text first so a malformed query fails before anything runs.
    let parsed = RustCypherQuery::new(query).map_err(graph_error_to_pyerr)?;
    let configured = parsed.with_config(self.config.clone());

    // Take owned copies of the engine's catalog/context and of the search
    // settings: the `async move` block below takes ownership of what it uses,
    // so it must not borrow from `self` or `vector_search`.
    let catalog = self.catalog.clone();
    let context = self.context.as_ref().clone();
    let search = vector_search.inner.clone();

    // Stage 1: graph query produces the candidates.
    // Stage 2: the vector search is run over those candidates.
    let rerank = async move {
        let candidates = configured
            .execute_with_catalog_and_context(catalog, context)
            .await?;
        search.search(&candidates).await
    };

    // The first `?` propagates a failure reported by `RT.block_on` itself;
    // the value it yields is the future's own result, whose error is
    // converted with `graph_error_to_pyerr`.
    let outcome = RT.block_on(Some(py), rerank)?;
    let reranked = outcome.map_err(graph_error_to_pyerr)?;

    record_batch_to_python_table(py, &reranked)
}

fn __repr__(&self) -> String {
format!(
"CypherEngine(nodes={}, relationships={})",
Expand Down
83 changes: 83 additions & 0 deletions python/python/tests/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,86 @@ def test_vector_search_different_query_vectors(vector_env):
.search(table)
)
assert results3.to_pydict()["name"][0] == "Doc4"


def test_cypher_engine_execute_with_vector_rerank(vector_env):
    """CypherEngine.execute_with_vector_rerank returns the top-k reranked rows."""
    from lance_graph import CypherEngine

    config, datasets, _ = vector_env
    engine = CypherEngine(config, datasets)

    # Rerank configuration: L2 distance to [1, 0, 0], keep the two best rows.
    search = (
        VectorSearch("d.embedding")
        .query_vector([1.0, 0.0, 0.0])
        .metric(DistanceMetric.L2)
        .top_k(2)
    )

    table = engine.execute_with_vector_rerank(
        "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding",
        search,
    )

    names = table.to_pydict()["d.name"]
    # top_k(2) caps the result at two rows, and Doc1 is expected to rank first.
    assert len(names) == 2
    assert names[0] == "Doc1"


def test_cypher_engine_vs_cypher_query_vector_rerank_equivalence(vector_env):
    """CypherEngine and CypherQuery must agree on vector rerank results."""
    from lance_graph import CypherEngine

    config, datasets, _ = vector_env

    cypher = (
        "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding"
    )
    # The same search object is handed to both execution paths.
    search = (
        VectorSearch("d.embedding")
        .query_vector([1.0, 0.0, 0.0])
        .metric(DistanceMetric.L2)
        .top_k(2)
    )

    # Path 1: CypherQuery, given the datasets at execution time.
    via_query = (
        CypherQuery(cypher)
        .with_config(config)
        .execute_with_vector_rerank(datasets, search)
    )

    # Path 2: CypherEngine, constructed from the same config and datasets.
    via_engine = CypherEngine(config, datasets).execute_with_vector_rerank(
        cypher, search
    )

    # Both paths must yield identical column data.
    expected = via_query.to_pydict()
    actual = via_engine.to_pydict()
    assert expected == actual


def test_cypher_engine_vector_rerank_multiple_queries(vector_env):
    """One CypherEngine instance serves several different vector rerank queries."""
    from lance_graph import CypherEngine

    config, datasets, _ = vector_env
    engine = CypherEngine(config, datasets)

    # First query: all documents, L2 distance, two best matches.
    l2_search = (
        VectorSearch("d.embedding")
        .query_vector([1.0, 0.0, 0.0])
        .metric(DistanceMetric.L2)
        .top_k(2)
    )
    all_docs = engine.execute_with_vector_rerank(
        "MATCH (d:Document) RETURN d.id, d.name, d.embedding",
        l2_search,
    )

    # Second query on the same engine: science documents only, cosine
    # distance, single best match.
    cosine_search = (
        VectorSearch("d.embedding")
        .query_vector([0.0, 1.0, 0.0])
        .metric(DistanceMetric.Cosine)
        .top_k(1)
    )
    science_docs = engine.execute_with_vector_rerank(
        "MATCH (d:Document) WHERE d.category = 'science' "
        "RETURN d.id, d.name, d.embedding",
        cosine_search,
    )

    all_names = all_docs.to_pydict()["d.name"]
    science_names = science_docs.to_pydict()["d.name"]

    assert len(all_names) == 2
    assert all_names[0] == "Doc1"

    assert len(science_names) == 1
    assert science_names[0] == "Doc3"
Loading