Skip to content

Commit 221b864

Browse files
authored
INTPYTHON-764 Make usage of text_key more obvious (#220)
1 parent a3a5c9c commit 221b864

18 files changed

+226
-47
lines changed

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Annotated, Any, Dict, List, Optional, Union
23

34
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
@@ -46,7 +47,13 @@ def _get_relevant_documents(
4647
Returns:
4748
List of relevant documents
4849
"""
49-
default_k = self.k if self.k is not None else self.top_k
50+
is_top_k_set = False
51+
with warnings.catch_warnings():
52+
# Ignore warning raised by checking the value of top_k.
53+
warnings.simplefilter("ignore", DeprecationWarning)
54+
if self.top_k is not None:
55+
is_top_k_set = True
56+
default_k = self.k if not is_top_k_set else self.top_k
5057
pipeline = text_search_stage( # type: ignore
5158
query=query,
5259
search_field=self.search_field,

libs/langchain-mongodb/langchain_mongodb/retrievers/hybrid_search.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Annotated, Any, Dict, List, Optional
23

34
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
@@ -83,8 +84,14 @@ def _get_relevant_documents(
8384
pipeline: List[Any] = []
8485

8586
# Get the appropriate value for k.
86-
default_k = self.top_k if self.top_k is not None else self.k
87-
k = kwargs.get("k", default_k)
87+
is_top_k_set = False
88+
with warnings.catch_warnings():
89+
# Ignore warning raised by checking the value of top_k.
90+
warnings.simplefilter("ignore", DeprecationWarning)
91+
if self.top_k is not None:
92+
is_top_k_set = True
93+
default_k = self.k if not is_top_k_set else self.top_k
94+
k: int = kwargs.get("k", default_k) # type:ignore[assignment]
8895

8996
# First we build up the aggregation pipeline,
9097
# then it is passed to the server to execute

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import warnings
45
from typing import (
56
Any,
67
Callable,
@@ -108,6 +109,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
108109
namespace="db_name.collection_name",
109110
embedding=OpenAIEmbeddings(),
110111
index_name="vector_index",
112+
text_key="text_field"
111113
)
112114
113115
Add Documents:
@@ -807,15 +809,27 @@ def _similarity_search_with_score(
807809
docs = []
808810

809811
# Format
812+
missing_text_key = False
810813
for res in cursor:
811814
if self._text_key not in res:
815+
missing_text_key = True
812816
continue
813817
text = res.pop(self._text_key)
814818
score = res.pop("score")
815819
make_serializable(res)
816820
docs.append(
817821
(Document(page_content=text, metadata=res, id=res["_id"]), score)
818822
)
823+
824+
if (
825+
missing_text_key
826+
and not len(docs)
827+
and self._collection.count_documents({}) > 0
828+
):
829+
warnings.warn(
830+
f"Could not find any documents with the text_key: '{self._text_key}'",
831+
stacklevel=1,
832+
)
819833
return docs
820834

821835
def create_vector_search_index(

libs/langchain-mongodb/pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,20 @@ dev = [
4545
]
4646

4747
[tool.pytest.ini_options]
48+
minversion = "7"
4849
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
50+
log_cli_level = "INFO"
51+
faulthandler_timeout = 1500
52+
xfail_strict = true
4953
markers = [
5054
"requires: mark tests as requiring a specific library",
5155
"compile: mark placeholder test used to compile integration tests without running them",
5256
]
5357
asyncio_mode = "auto"
5458
asyncio_default_fixture_loop_scope = "function"
59+
filterwarnings = [
60+
"error"
61+
]
5562

5663
[tool.mypy]
5764
disallow_untyped_defs = true

libs/langchain-mongodb/tests/integration_tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from typing import Generator, List
34

45
import pytest
@@ -16,7 +17,10 @@
1617
def technical_report_pages() -> List[Document]:
1718
"""Returns a Document for each of the 100 pages of a GPT-4 Technical Report"""
1819
loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf")
19-
pages = loader.load()
20+
with warnings.catch_warnings():
21+
# Ignore warnings raised by base class.
22+
warnings.simplefilter("ignore", ResourceWarning)
23+
pages = loader.load()
2024
return pages
2125

2226

libs/langchain-mongodb/tests/integration_tests/test_agent_toolkit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,4 @@ def test_toolkit_response(db):
7474
for event in events:
7575
messages.extend(event["messages"])
7676
assert "USA" in messages[-1].content, messages[-1].content
77+
db_wrapper.close()

libs/langchain-mongodb/tests/integration_tests/test_cache.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def llm_cache(cls: Any) -> BaseCache:
7070
)
7171
)
7272
assert get_llm_cache()
73-
return get_llm_cache()
73+
return get_llm_cache() # type:ignore[return-value]
7474

7575

7676
@pytest.fixture(scope="module", autouse=True)
@@ -99,7 +99,7 @@ def _execute_test(
9999
dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)
100100

101101
# Update the cache
102-
get_llm_cache().update(dumped_prompt, llm_string, response)
102+
get_llm_cache().update(dumped_prompt, llm_string, response) # type:ignore[union-attr]
103103

104104
# Retrieve the cached result through 'generate' call
105105
output: Union[List[Generation], LLMResult, None]
@@ -156,7 +156,8 @@ def test_mongodb_cache(
156156
try:
157157
_execute_test(prompt, llm, response)
158158
finally:
159-
get_llm_cache().clear()
159+
get_llm_cache().clear() # type:ignore[union-attr]
160+
get_llm_cache().close() # type:ignore[attr-defined,union-attr]
160161

161162

162163
@pytest.mark.parametrize(
@@ -207,4 +208,5 @@ def test_mongodb_atlas_cache_matrix(
207208
assert llm.generate(prompts) == LLMResult(
208209
generations=llm_generations, llm_output={}
209210
)
210-
get_llm_cache().clear()
211+
get_llm_cache().clear() # type:ignore[union-attr]
212+
get_llm_cache().close() # type:ignore[attr-defined,union-attr]

libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import warnings
23

34
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
45
from langchain_core.messages import message_to_dict
@@ -19,9 +20,12 @@ def test_memory_with_message_store() -> None:
1920
database_name=DB_NAME,
2021
collection_name=COLLECTION,
2122
)
22-
memory = ConversationBufferMemory(
23-
memory_key="baz", chat_memory=message_history, return_messages=True
24-
)
23+
with warnings.catch_warnings():
24+
# Ignore warnings raised by base class.
25+
warnings.simplefilter("ignore", DeprecationWarning)
26+
memory = ConversationBufferMemory(
27+
memory_key="baz", chat_memory=message_history, return_messages=True
28+
)
2529

2630
# add some messages
2731
memory.chat_memory.add_ai_message("This is me, the AI")
@@ -38,3 +42,4 @@ def test_memory_with_message_store() -> None:
3842
memory.chat_memory.clear()
3943

4044
assert memory.chat_memory.messages == []
45+
memory.chat_memory.close() # type:ignore[attr-defined]

libs/langchain-mongodb/tests/integration_tests/test_parent_document.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,4 @@ def test_1clxn_retriever(
7474
assert len(responses) == 3
7575
assert all("GPT-4" in doc.page_content for doc in responses)
7676
assert {4, 5, 29} == set(doc.metadata["page"] for doc in responses)
77+
client.close()

libs/langchain-mongodb/tests/integration_tests/test_retrievers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,14 @@ def test_hybrid_retriever_deprecated_top_k(
194194
)
195195

196196
query1 = "When did I visit France?"
197-
results = retriever.invoke(query1)
197+
with pytest.warns(DeprecationWarning):
198+
results = retriever.invoke(query1)
198199
assert len(results) == 3
199200
assert "Paris" in results[0].page_content
200201

201202
query2 = "When was the last time I visited new orleans?"
202-
results = retriever.invoke(query2)
203+
with pytest.warns(DeprecationWarning):
204+
results = retriever.invoke(query2)
203205
assert "New Orleans" in results[0].page_content
204206

205207

0 commit comments

Comments (0)