diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs
index f6a7380..ad6d753 100644
--- a/crates/lance-graph-python/src/graph.rs
+++ b/crates/lance-graph-python/src/graph.rs
@@ -1056,6 +1056,55 @@ impl CypherEngine {
         }
     }
 
+    /// Execute Cypher query with vector reranking
+    ///
+    /// Convenience method combining graph traversal and vector similarity ranking.
+    /// See CypherQuery.execute_with_vector_rerank for detailed documentation.
+    ///
+    /// Parameters
+    /// ----------
+    /// query : str
+    ///     Cypher query string
+    /// vector_search : VectorSearch
+    ///     Vector search configuration
+    ///
+    /// Returns
+    /// -------
+    /// pyarrow.Table
+    ///     Results sorted by vector similarity
+    ///
+    /// Raises
+    /// ------
+    /// Errors from parsing the query, executing it, or running the vector
+    /// search are converted to Python exceptions by `graph_error_to_pyerr`.
+    fn execute_with_vector_rerank(
+        &self,
+        py: Python,
+        query: &str,
+        vector_search: &VectorSearch,
+    ) -> PyResult<PyObject> {
+        // Parse query and execute with cached catalog/context
+        let cypher_query = RustCypherQuery::new(query)
+            .map_err(graph_error_to_pyerr)?
+            .with_config(self.config.clone());
+
+        let catalog = self.catalog.clone();
+        let context = self.context.as_ref().clone();
+        let vs = vector_search.inner.clone();
+
+        // Execute query to get candidates, then apply vector reranking
+        let result_batch = RT
+            .block_on(Some(py), async move {
+                let candidates = cypher_query
+                    .execute_with_catalog_and_context(catalog, context)
+                    .await?;
+                vs.search(&candidates).await
+            })?
+            .map_err(graph_error_to_pyerr)?;
+
+        record_batch_to_python_table(py, &result_batch)
+    }
+
     fn __repr__(&self) -> String {
         format!(
             "CypherEngine(nodes={}, relationships={})",
diff --git a/python/python/tests/test_vector_search.py b/python/python/tests/test_vector_search.py
index 39add7f..0675faa 100644
--- a/python/python/tests/test_vector_search.py
+++ b/python/python/tests/test_vector_search.py
@@ -317,3 +317,86 @@ def test_vector_search_different_query_vectors(vector_env):
         .search(table)
     )
     assert results3.to_pydict()["name"][0] == "Doc4"
+
+
+def test_cypher_engine_execute_with_vector_rerank(vector_env):
+    """Test CypherEngine.execute_with_vector_rerank basic functionality."""
+    from lance_graph import CypherEngine
+
+    config, datasets, _ = vector_env
+    engine = CypherEngine(config, datasets)
+
+    results = engine.execute_with_vector_rerank(
+        "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding",
+        VectorSearch("d.embedding")
+        .query_vector([1.0, 0.0, 0.0])
+        .metric(DistanceMetric.L2)
+        .top_k(2),
+    )
+
+    data = results.to_pydict()
+    assert len(data["d.name"]) == 2
+    assert data["d.name"][0] == "Doc1"
+
+
+def test_cypher_engine_vs_cypher_query_vector_rerank_equivalence(vector_env):
+    """Test that CypherEngine produces same results as CypherQuery for vector rerank."""
+    from lance_graph import CypherEngine
+
+    config, datasets, _ = vector_env
+
+    query_text = (
+        "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding"
+    )
+    vector_search = (
+        VectorSearch("d.embedding")
+        .query_vector([1.0, 0.0, 0.0])
+        .metric(DistanceMetric.L2)
+        .top_k(2)
+    )
+
+    # Execute with CypherQuery
+    query = CypherQuery(query_text).with_config(config)
+    result_query = query.execute_with_vector_rerank(datasets, vector_search)
+
+    # Execute with CypherEngine
+    engine = CypherEngine(config, datasets)
+    result_engine = engine.execute_with_vector_rerank(query_text, vector_search)
+
+    # Results should be identical
+    assert result_query.to_pydict() == result_engine.to_pydict()
+
+
+def test_cypher_engine_vector_rerank_multiple_queries(vector_env):
+    """Test that one CypherEngine instance handles multiple vector rerank queries."""
+    from lance_graph import CypherEngine
+
+    config, datasets, _ = vector_env
+    engine = CypherEngine(config, datasets)
+
+    # Execute multiple different queries using the same cached engine
+    results1 = engine.execute_with_vector_rerank(
+        "MATCH (d:Document) RETURN d.id, d.name, d.embedding",
+        VectorSearch("d.embedding")
+        .query_vector([1.0, 0.0, 0.0])
+        .metric(DistanceMetric.L2)
+        .top_k(2),
+    )
+
+    results2 = engine.execute_with_vector_rerank(
+        "MATCH (d:Document) WHERE d.category = 'science' "
+        "RETURN d.id, d.name, d.embedding",
+        VectorSearch("d.embedding")
+        .query_vector([0.0, 1.0, 0.0])
+        .metric(DistanceMetric.Cosine)
+        .top_k(1),
+    )
+
+    data1 = results1.to_pydict()
+    data2 = results2.to_pydict()
+
+    assert len(data1["d.name"]) == 2
+    assert data1["d.name"][0] == "Doc1"
+
+    assert len(data2["d.name"]) == 1
+    assert data2["d.name"][0] == "Doc3"