Skip to content

Commit ffcf7ed

Browse files
INTPYTHON-837 Integrate pymongo-search-utils (#208)
Replace functions in this library with identical functions from pymongo-search-utils:

- [x] append_client_metadata
- [x] bulk_embed_and_insert_texts
- [x] combine_pipelines
- [x] create_fulltext_search_index
- [x] create_vector_search_index
- [x] drop_vector_search_index
- [x] final_hybrid_stage
- [x] reciprocal_rank_stage
- [x] text_search_stage
- [x] update_vector_search_index
- [x] vector_search_stage

Patch build: https://spruce.mongodb.com/version/6924b537f767320007b1a18d/tasks?sorts=STATUS%3AASC%3BBASE_STATUS%3ADESC

---------

Co-authored-by: Steven Silvester <steve.silvester@mongodb.com>
1 parent 9ada053 commit ffcf7ed

File tree

7 files changed

+57
-316
lines changed

7 files changed

+57
-316
lines changed

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 7 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22

33
import logging
44
from time import monotonic, sleep
5-
from typing import Any, Callable, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional
66

77
from pymongo.collection import Collection
8-
from pymongo.operations import SearchIndexModel
8+
from pymongo_search_utils import (
9+
create_fulltext_search_index, # noqa: F401
10+
create_vector_search_index, # noqa: F401
11+
drop_vector_search_index, # noqa: F401
12+
update_vector_search_index, # noqa: F401
13+
)
914

1015
logger = logging.getLogger(__file__)
1116

@@ -34,135 +39,6 @@ def _vector_search_index_definition(
3439
return definition
3540

3641

37-
def create_vector_search_index(
38-
collection: Collection,
39-
index_name: str,
40-
dimensions: int,
41-
path: str,
42-
similarity: str,
43-
filters: Optional[List[str]] = None,
44-
*,
45-
wait_until_complete: Optional[float] = None,
46-
**kwargs: Any,
47-
) -> None:
48-
"""Experimental Utility function to create a vector search index
49-
50-
Args:
51-
collection (Collection): MongoDB Collection
52-
index_name (str): Name of Index
53-
dimensions (int): Number of dimensions in embedding
54-
path (str): field with vector embedding
55-
similarity (str): The similarity score used for the index
56-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
57-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
58-
until search index is ready.
59-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
60-
"""
61-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
62-
63-
if collection.name not in collection.database.list_collection_names(
64-
authorizedCollections=True
65-
):
66-
collection.database.create_collection(collection.name)
67-
68-
result = collection.create_search_index(
69-
SearchIndexModel(
70-
definition=_vector_search_index_definition(
71-
dimensions=dimensions,
72-
path=path,
73-
similarity=similarity,
74-
filters=filters,
75-
**kwargs,
76-
),
77-
name=index_name,
78-
type="vectorSearch",
79-
)
80-
)
81-
82-
if wait_until_complete:
83-
_wait_for_predicate(
84-
predicate=lambda: _is_index_ready(collection, index_name),
85-
err=f"{index_name=} did not complete in {wait_until_complete}!",
86-
timeout=wait_until_complete,
87-
)
88-
logger.info(result)
89-
90-
91-
def drop_vector_search_index(
92-
collection: Collection,
93-
index_name: str,
94-
*,
95-
wait_until_complete: Optional[float] = None,
96-
) -> None:
97-
"""Drop a created vector search index
98-
99-
Args:
100-
collection (Collection): MongoDB Collection with index to be dropped
101-
index_name (str): Name of the MongoDB index
102-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
103-
until search index is ready.
104-
"""
105-
logger.info(
106-
"Dropping Search Index %s from Collection: %s", index_name, collection.name
107-
)
108-
collection.drop_search_index(index_name)
109-
if wait_until_complete:
110-
_wait_for_predicate(
111-
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
112-
err=f"Index {index_name} did not drop in {wait_until_complete}!",
113-
timeout=wait_until_complete,
114-
)
115-
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
116-
117-
118-
def update_vector_search_index(
119-
collection: Collection,
120-
index_name: str,
121-
dimensions: int,
122-
path: str,
123-
similarity: str,
124-
filters: Optional[List[str]] = None,
125-
*,
126-
wait_until_complete: Optional[float] = None,
127-
**kwargs: Any,
128-
) -> None:
129-
"""Update a search index.
130-
131-
Replace the existing index definition with the provided definition.
132-
133-
Args:
134-
collection (Collection): MongoDB Collection
135-
index_name (str): Name of Index
136-
dimensions (int): Number of dimensions in embedding
137-
path (str): field with vector embedding
138-
similarity (str): The similarity score used for the index.
139-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
140-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
141-
until search index is ready.
142-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
143-
"""
144-
logger.info(
145-
"Updating Search Index %s from Collection: %s", index_name, collection.name
146-
)
147-
collection.update_search_index(
148-
name=index_name,
149-
definition=_vector_search_index_definition(
150-
dimensions=dimensions,
151-
path=path,
152-
similarity=similarity,
153-
filters=filters,
154-
**kwargs,
155-
),
156-
)
157-
if wait_until_complete:
158-
_wait_for_predicate(
159-
predicate=lambda: _is_index_ready(collection, index_name),
160-
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
161-
timeout=wait_until_complete,
162-
)
163-
logger.info("Update succeeded")
164-
165-
16642
def _is_index_ready(collection: Collection, index_name: str) -> bool:
16743
"""Check for the index name in the list of available search indexes to see if the
16844
specified index is of status READY
@@ -199,50 +75,3 @@ def _wait_for_predicate(
19975
if monotonic() - start > timeout:
20076
raise TimeoutError(err)
20177
sleep(interval)
202-
203-
204-
def create_fulltext_search_index(
205-
collection: Collection,
206-
index_name: str,
207-
field: Union[str, List[str]],
208-
*,
209-
wait_until_complete: Optional[float] = None,
210-
**kwargs: Any,
211-
) -> None:
212-
"""Experimental Utility function to create an Atlas Search index
213-
214-
Args:
215-
collection (Collection): MongoDB Collection
216-
index_name (str): Name of Index
217-
field (str): Field to index
218-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
219-
until search index is ready
220-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
221-
"""
222-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
223-
224-
if collection.name not in collection.database.list_collection_names(
225-
authorizedCollections=True
226-
):
227-
collection.database.create_collection(collection.name)
228-
229-
if isinstance(field, str):
230-
fields_definition = {field: [{"type": "string"}]}
231-
else:
232-
fields_definition = {f: [{"type": "string"}] for f in field}
233-
definition = {"mappings": {"dynamic": False, "fields": fields_definition}}
234-
result = collection.create_search_index(
235-
SearchIndexModel(
236-
definition=definition,
237-
name=index_name,
238-
type="search",
239-
**kwargs,
240-
)
241-
)
242-
if wait_until_complete:
243-
_wait_for_predicate(
244-
predicate=lambda: _is_index_ready(collection, index_name),
245-
err=f"{index_name=} did not complete in {wait_until_complete}!",
246-
timeout=wait_until_complete,
247-
)
248-
logger.info(result)

libs/langchain-mongodb/langchain_mongodb/pipelines.py

Lines changed: 7 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99

1010
from typing import Any, Dict, List, Optional, Union
1111

12+
from pymongo_search_utils import (
13+
combine_pipelines, # noqa: F401
14+
final_hybrid_stage, # noqa: F401
15+
reciprocal_rank_stage, # noqa: F401
16+
vector_search_stage, # noqa: F401
17+
)
18+
1219

1320
def text_search_stage(
1421
query: str,
@@ -48,115 +55,3 @@ def text_search_stage(
4855
pipeline.append({"$limit": limit}) # type: ignore
4956

5057
return pipeline # type: ignore
51-
52-
53-
def vector_search_stage(
54-
query_vector: List[float],
55-
search_field: str,
56-
index_name: str,
57-
top_k: int = 4,
58-
filter: Optional[Dict[str, Any]] = None,
59-
oversampling_factor: int = 10,
60-
**kwargs: Any,
61-
) -> Dict[str, Any]: # noqa: E501
62-
"""Vector Search Stage without Scores.
63-
64-
Scoring is applied later depending on strategy.
65-
vector search includes a vectorSearchScore that is typically used.
66-
hybrid uses Reciprocal Rank Fusion.
67-
68-
Args:
69-
query_vector: List of embedding vector
70-
search_field: Field in Collection containing embedding vectors
71-
index_name: Name of Atlas Vector Search Index tied to Collection
72-
top_k: Number of documents to return
73-
oversampling_factor: this times limit is the number of candidates
74-
filter: MQL match expression comparing an indexed field.
75-
Some operators are not supported.
76-
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
77-
78-
79-
Returns:
80-
Dictionary defining the $vectorSearch
81-
"""
82-
stage = {
83-
"index": index_name,
84-
"path": search_field,
85-
"queryVector": query_vector,
86-
"numCandidates": top_k * oversampling_factor,
87-
"limit": top_k,
88-
}
89-
if filter:
90-
stage["filter"] = filter
91-
return {"$vectorSearch": stage}
92-
93-
94-
def combine_pipelines(
95-
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
96-
) -> None:
97-
"""Combines two aggregations into a single result set in-place."""
98-
if pipeline:
99-
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
100-
else:
101-
pipeline.extend(stage)
102-
103-
104-
def reciprocal_rank_stage(
105-
score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
106-
) -> List[Dict[str, Any]]:
107-
"""
108-
Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring.
109-
110-
First, it groups documents into an array, assigns rank by array index,
111-
and then computes a weighted RRF score.
112-
113-
Args:
114-
score_field: A unique string to identify the search being ranked.
115-
penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator.
116-
weight: A float multiplier for this source's importance.
117-
**kwargs: Ignored; allows future extensions or passthrough args.
118-
119-
Returns:
120-
Aggregation pipeline stage for weighted RRF scoring.
121-
"""
122-
123-
return [
124-
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
125-
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
126-
{
127-
"$addFields": {
128-
f"docs.{score_field}": {
129-
"$multiply": [
130-
weight,
131-
{"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
132-
]
133-
},
134-
"docs.rank": "$rank",
135-
"_id": "$docs._id",
136-
}
137-
},
138-
{"$replaceRoot": {"newRoot": "$docs"}},
139-
]
140-
141-
142-
def final_hybrid_stage(
143-
scores_fields: List[str], limit: int, **kwargs: Any
144-
) -> List[Dict[str, Any]]:
145-
"""Sum weighted scores, sort, and apply limit.
146-
147-
Args:
148-
scores_fields: List of fields given to scores of vector and text searches
149-
limit: Number of documents to return
150-
151-
Returns:
152-
Final aggregation stages
153-
"""
154-
155-
return [
156-
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
157-
{"$replaceRoot": {"newRoot": "$docs"}},
158-
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
159-
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
160-
{"$sort": {"score": -1}},
161-
{"$limit": limit},
162-
]

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727
from pymongo import MongoClient
2828
from pymongo.driver_info import DriverInfo
29+
from pymongo_search_utils import append_client_metadata
2930

3031
logger = logging.getLogger(__name__)
3132

@@ -35,9 +36,7 @@
3536

3637

3738
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
39+
append_client_metadata(client=client, driver_info=DRIVER_METADATA)
4140

4241

4342
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:

0 commit comments

Comments
 (0)