Skip to content

Commit cd54763

Browse files
committed
INTPYTHON-752 Integrate pymongo-vectorsearch-utils
1 parent a42c0b6 commit cd54763

File tree

5 files changed

+17
-220
lines changed

5 files changed

+17
-220
lines changed

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 5 additions & 174 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,17 @@
22

33
import logging
44
from time import monotonic, sleep
5-
from typing import Any, Callable, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional
66

77
from pymongo.collection import Collection
8-
from pymongo.operations import SearchIndexModel
98

109
logger = logging.getLogger(__file__)
1110

1211

12+
# Don't break imports for modules that expect these functions
13+
# to be in this module.
14+
15+
1316
def _vector_search_index_definition(
1417
dimensions: int,
1518
path: str,
@@ -34,133 +37,6 @@ def _vector_search_index_definition(
3437
return definition
3538

3639

37-
def create_vector_search_index(
38-
collection: Collection,
39-
index_name: str,
40-
dimensions: int,
41-
path: str,
42-
similarity: str,
43-
filters: Optional[List[str]] = None,
44-
*,
45-
wait_until_complete: Optional[float] = None,
46-
**kwargs: Any,
47-
) -> None:
48-
"""Experimental Utility function to create a vector search index
49-
50-
Args:
51-
collection (Collection): MongoDB Collection
52-
index_name (str): Name of Index
53-
dimensions (int): Number of dimensions in embedding
54-
path (str): field with vector embedding
55-
similarity (str): The similarity score used for the index
56-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
57-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
58-
until search index is ready.
59-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
60-
"""
61-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
62-
63-
if collection.name not in collection.database.list_collection_names():
64-
collection.database.create_collection(collection.name)
65-
66-
result = collection.create_search_index(
67-
SearchIndexModel(
68-
definition=_vector_search_index_definition(
69-
dimensions=dimensions,
70-
path=path,
71-
similarity=similarity,
72-
filters=filters,
73-
**kwargs,
74-
),
75-
name=index_name,
76-
type="vectorSearch",
77-
)
78-
)
79-
80-
if wait_until_complete:
81-
_wait_for_predicate(
82-
predicate=lambda: _is_index_ready(collection, index_name),
83-
err=f"{index_name=} did not complete in {wait_until_complete}!",
84-
timeout=wait_until_complete,
85-
)
86-
logger.info(result)
87-
88-
89-
def drop_vector_search_index(
90-
collection: Collection,
91-
index_name: str,
92-
*,
93-
wait_until_complete: Optional[float] = None,
94-
) -> None:
95-
"""Drop a created vector search index
96-
97-
Args:
98-
collection (Collection): MongoDB Collection with index to be dropped
99-
index_name (str): Name of the MongoDB index
100-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
101-
until search index is ready.
102-
"""
103-
logger.info(
104-
"Dropping Search Index %s from Collection: %s", index_name, collection.name
105-
)
106-
collection.drop_search_index(index_name)
107-
if wait_until_complete:
108-
_wait_for_predicate(
109-
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
110-
err=f"Index {index_name} did not drop in {wait_until_complete}!",
111-
timeout=wait_until_complete,
112-
)
113-
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
114-
115-
116-
def update_vector_search_index(
117-
collection: Collection,
118-
index_name: str,
119-
dimensions: int,
120-
path: str,
121-
similarity: str,
122-
filters: Optional[List[str]] = None,
123-
*,
124-
wait_until_complete: Optional[float] = None,
125-
**kwargs: Any,
126-
) -> None:
127-
"""Update a search index.
128-
129-
Replace the existing index definition with the provided definition.
130-
131-
Args:
132-
collection (Collection): MongoDB Collection
133-
index_name (str): Name of Index
134-
dimensions (int): Number of dimensions in embedding
135-
path (str): field with vector embedding
136-
similarity (str): The similarity score used for the index.
137-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
138-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
139-
until search index is ready.
140-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
141-
"""
142-
logger.info(
143-
"Updating Search Index %s from Collection: %s", index_name, collection.name
144-
)
145-
collection.update_search_index(
146-
name=index_name,
147-
definition=_vector_search_index_definition(
148-
dimensions=dimensions,
149-
path=path,
150-
similarity=similarity,
151-
filters=filters,
152-
**kwargs,
153-
),
154-
)
155-
if wait_until_complete:
156-
_wait_for_predicate(
157-
predicate=lambda: _is_index_ready(collection, index_name),
158-
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
159-
timeout=wait_until_complete,
160-
)
161-
logger.info("Update succeeded")
162-
163-
16440
def _is_index_ready(collection: Collection, index_name: str) -> bool:
16541
"""Check for the index name in the list of available search indexes to see if the
16642
specified index is of status READY
@@ -197,48 +73,3 @@ def _wait_for_predicate(
19773
if monotonic() - start > timeout:
19874
raise TimeoutError(err)
19975
sleep(interval)
200-
201-
202-
def create_fulltext_search_index(
203-
collection: Collection,
204-
index_name: str,
205-
field: Union[str, List[str]],
206-
*,
207-
wait_until_complete: Optional[float] = None,
208-
**kwargs: Any,
209-
) -> None:
210-
"""Experimental Utility function to create an Atlas Search index
211-
212-
Args:
213-
collection (Collection): MongoDB Collection
214-
index_name (str): Name of Index
215-
field (str): Field to index
216-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
217-
until search index is ready
218-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
219-
"""
220-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
221-
222-
if collection.name not in collection.database.list_collection_names():
223-
collection.database.create_collection(collection.name)
224-
225-
if isinstance(field, str):
226-
fields_definition = {field: [{"type": "string"}]}
227-
else:
228-
fields_definition = {f: [{"type": "string"}] for f in field}
229-
definition = {"mappings": {"dynamic": False, "fields": fields_definition}}
230-
result = collection.create_search_index(
231-
SearchIndexModel(
232-
definition=definition,
233-
name=index_name,
234-
type="search",
235-
**kwargs,
236-
)
237-
)
238-
if wait_until_complete:
239-
_wait_for_predicate(
240-
predicate=lambda: _is_index_ready(collection, index_name),
241-
err=f"{index_name=} did not complete in {wait_until_complete}!",
242-
timeout=wait_until_complete,
243-
)
244-
logger.info(result)

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
from typing import Any, Dict, List, Union
2525

2626
import numpy as np
27-
from pymongo import MongoClient
2827
from pymongo.driver_info import DriverInfo
2928

3029
logger = logging.getLogger(__name__)
@@ -33,11 +32,8 @@
3332

3433
DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
3534

36-
37-
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
35+
# Don't break imports for modules that expect this function
36+
# to be in this module.
4137

4238

4339
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 7 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,10 @@
2121
from langchain_core.embeddings import Embeddings
2222
from langchain_core.runnables.config import run_in_executor
2323
from langchain_core.vectorstores import VectorStore
24-
from pymongo import MongoClient, ReplaceOne
24+
from pymongo import MongoClient
2525
from pymongo.collection import Collection
2626
from pymongo.errors import CollectionInvalid
27+
from pymongo_vectorsearch_utils import bulk_embed_and_insert_texts
2728

2829
from langchain_mongodb.index import (
2930
create_vector_search_index,
@@ -360,11 +361,11 @@ def add_texts(
360361
metadatas_batch.append(metadata)
361362
if (j + 1) % batch_size == 0 or size >= 47_000_000:
362363
if ids:
363-
batch_res = self.bulk_embed_and_insert_texts(
364+
batch_res = bulk_embed_and_insert_texts(
364365
texts_batch, metadatas_batch, ids[i : j + 1]
365366
)
366367
else:
367-
batch_res = self.bulk_embed_and_insert_texts(
368+
batch_res = bulk_embed_and_insert_texts(
368369
texts_batch, metadatas_batch
369370
)
370371
result_ids.extend(batch_res)
@@ -374,13 +375,11 @@ def add_texts(
374375
i = j + 1
375376
if texts_batch:
376377
if ids:
377-
batch_res = self.bulk_embed_and_insert_texts(
378+
batch_res = bulk_embed_and_insert_texts(
378379
texts_batch, metadatas_batch, ids[i : j + 1]
379380
)
380381
else:
381-
batch_res = self.bulk_embed_and_insert_texts(
382-
texts_batch, metadatas_batch
383-
)
382+
batch_res = bulk_embed_and_insert_texts(texts_batch, metadatas_batch)
384383
result_ids.extend(batch_res)
385384
return result_ids
386385

@@ -417,37 +416,6 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
417416
docs.append(Document(page_content=text, id=oid_to_str(_id), metadata=doc))
418417
return docs
419418

420-
def bulk_embed_and_insert_texts(
421-
self,
422-
texts: Union[List[str], Iterable[str]],
423-
metadatas: Union[List[dict], Generator[dict, Any, Any]],
424-
ids: Optional[List[str]] = None,
425-
) -> List[str]:
426-
"""Bulk insert single batch of texts, embeddings, and optionally ids.
427-
428-
See add_texts for additional details.
429-
"""
430-
if not texts:
431-
return []
432-
# Compute embedding vectors
433-
embeddings = self._embedding.embed_documents(list(texts))
434-
if not ids:
435-
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
436-
docs = [
437-
{
438-
"_id": str_to_oid(i),
439-
self._text_key: t,
440-
self._embedding_key: embedding,
441-
**m,
442-
}
443-
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
444-
]
445-
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
446-
# insert the documents in MongoDB Atlas
447-
result = self._collection.bulk_write(operations)
448-
assert result.upserted_ids is not None
449-
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
450-
451419
def add_documents(
452420
self,
453421
documents: List[Document],
@@ -479,7 +447,7 @@ def add_documents(
479447
*[(doc.page_content, doc.metadata) for doc in documents[start:end]]
480448
)
481449
result_ids.extend(
482-
self.bulk_embed_and_insert_texts(
450+
bulk_embed_and_insert_texts(
483451
texts=texts, metadatas=metadatas, ids=ids[start:end]
484452
)
485453
)

libs/langchain-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"langchain-text-splitters>=0.3",
1717
"numpy>=1.26",
1818
"lark<2.0.0,>=1.1.9",
19+
# "pymongo-vectorsearch-utils",
1920
]
2021

2122
[dependency-groups]

libs/langchain-mongodb/tests/utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from pymongo.driver_info import DriverInfo
2727
from pymongo.operations import SearchIndexModel
2828
from pymongo.results import BulkWriteResult, DeleteResult, InsertManyResult
29+
from pymongo_vectorsearch_utils import bulk_embed_and_insert_texts
2930

3031
from langchain_mongodb import MongoDBAtlasVectorSearch
3132
from langchain_mongodb.agent_toolkit.database import MongoDBDatabase
@@ -63,7 +64,7 @@ def bulk_embed_and_insert_texts(
6364
ids: Optional[List[str]] = None,
6465
) -> List:
6566
"""Patched insert_texts that waits for data to be indexed before returning"""
66-
ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
67+
ids_inserted = bulk_embed_and_insert_texts(texts, metadatas, ids)
6768
n_docs = self.collection.count_documents({})
6869
start = monotonic()
6970
while monotonic() - start <= TIMEOUT:

0 commit comments

Comments
 (0)