120 changes: 35 additions & 85 deletions langchain_couchbase/cache.py
@@ -11,7 +11,7 @@
import json
import logging
from datetime import timedelta
from typing import Any, Dict, Optional, Union
from typing import Any, Optional

from couchbase.cluster import Cluster
from couchbase.search import MatchQuery
@@ -21,9 +21,14 @@
from langchain_core.load.load import loads
from langchain_core.outputs import Generation

from langchain_couchbase.utils import (
check_bucket_exists,
check_scope_and_collection_exists,
validate_ttl,
)
from langchain_couchbase.vectorstores import CouchbaseSearchVectorStore

logger = logging.getLogger(__file__)
logger = logging.getLogger(__name__)


def _hash(_input: str) -> str:
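A quick note on the logger change above: `logging.getLogger(__file__)` names the logger after the module's file path, which sits outside the package's dotted-name hierarchy, while `logging.getLogger(__name__)` yields the conventional `langchain_couchbase.cache` logger, so handlers and levels configured on the `langchain_couchbase` package logger propagate to it. A small standalone check of the two names:

```python
import logging

# Inside langchain_couchbase/cache.py, __file__ is a filesystem path while
# __name__ is the dotted module name "langchain_couchbase.cache"; run from
# any module, the prints below show the differing logger names produced.
print(logging.getLogger(__file__).name)  # a path such as ".../cache.py"
print(logging.getLogger(__name__).name)  # a dotted name (or "__main__" in a script)
```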
@@ -51,7 +56,7 @@ def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
def _loads_generations(generations_str: str) -> Optional[RETURN_VAL_TYPE]:
"""
Deserialization of a string into a generic RETURN_VAL_TYPE
(i.e. a sequence of `Generation`).
@@ -90,16 +95,6 @@ def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
return None


def _validate_ttl(ttl: Optional[timedelta]) -> None:
"""Validate the time to live"""
if not isinstance(ttl, timedelta):
raise ValueError(f"ttl should be of type timedelta but was {type(ttl)}.")
if ttl <= timedelta(seconds=0):
raise ValueError(
f"ttl must be greater than 0 but was {ttl.total_seconds()} seconds."
)


class CouchbaseCache(BaseCache):
"""Couchbase LLM Cache
LLM Cache that uses Couchbase as the backend
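The inline `_validate_ttl` deleted here moves into `langchain_couchbase.utils` (imported at the top of this diff as `validate_ttl`). The utils module itself is not part of this file's diff, so the following is only a sketch of what the shared helper presumably looks like, reconstructed from the deleted code:

```python
# Hypothetical sketch of the validate_ttl helper in langchain_couchbase/utils.py
# (not shown in this diff); reconstructed from the removed _validate_ttl above.
from datetime import timedelta
from typing import Optional


def validate_ttl(ttl: Optional[timedelta]) -> None:
    """Validate the time to live."""
    if not isinstance(ttl, timedelta):
        raise ValueError(f"ttl should be of type timedelta but was {type(ttl)}.")
    if ttl <= timedelta(seconds=0):
        raise ValueError(
            f"ttl must be greater than 0 but was {ttl.total_seconds()} seconds."
        )
```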
@@ -109,52 +104,14 @@ class CouchbaseCache(BaseCache):
LLM = "llm"
RETURN_VAL = "return_val"

def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster"""
bucket_manager = self._cluster.buckets()
try:
bucket_manager.get_bucket(self._bucket_name)
return True
except Exception:
return False

def _check_scope_and_collection_exists(self) -> bool:
"""Check if the scope and collection exists in the linked Couchbase bucket
Raises a ValueError if either is not found"""
scope_collection_map: Dict[str, Any] = {}

# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []

# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)

# Check if the scope exists
if self._scope_name not in scope_collection_map.keys():
raise ValueError(
f"Scope {self._scope_name} not found in Couchbase "
f"bucket {self._bucket_name}"
)

# Check if the collection exists in the scope
if self._collection_name not in scope_collection_map[self._scope_name]:
raise ValueError(
f"Collection {self._collection_name} not found in scope "
f"{self._scope_name} in Couchbase bucket {self._bucket_name}"
)

return True

def __init__(
self,
cluster: Cluster,
bucket_name: str,
scope_name: str,
collection_name: str,
ttl: Optional[timedelta] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Initialize the Couchbase LLM Cache
Args:
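The two existence checks removed here likewise move into `langchain_couchbase.utils` as module-level functions. Judging from the deleted methods and the new call sites below (`check_bucket_exists(cluster, bucket_name)` and `check_scope_and_collection_exists(bucket, scope_name, collection_name, bucket_name)`), the shared helpers presumably look roughly like this sketch:

```python
# Hypothetical sketch of the utils helpers (not shown in this diff); adapted
# from the removed instance methods to the module-level signatures used below.
from typing import Any, Dict

from couchbase.bucket import Bucket
from couchbase.cluster import Cluster


def check_bucket_exists(cluster: Cluster, bucket_name: str) -> bool:
    """Check if the bucket exists in the linked Couchbase cluster."""
    bucket_manager = cluster.buckets()
    try:
        bucket_manager.get_bucket(bucket_name)
        return True
    except Exception:
        return False


def check_scope_and_collection_exists(
    bucket: Bucket, scope_name: str, collection_name: str, bucket_name: str
) -> bool:
    """Check that the scope and collection exist in the linked Couchbase bucket.

    Raises a ValueError if either is not found.
    """
    # Map each scope in the bucket to the names of its collections
    scope_collection_map: Dict[str, Any] = {
        scope.name: [collection.name for collection in scope.collections]
        for scope in bucket.collections().get_all_scopes()
    }

    if scope_name not in scope_collection_map:
        raise ValueError(
            f"Scope {scope_name} not found in Couchbase bucket {bucket_name}"
        )

    if collection_name not in scope_collection_map[scope_name]:
        raise ValueError(
            f"Collection {collection_name} not found in scope "
            f"{scope_name} in Couchbase bucket {bucket_name}"
        )

    return True
```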
@@ -173,39 +130,36 @@ def __init__(
)

self._cluster = cluster

self._bucket_name = bucket_name
self._scope_name = scope_name
self._collection_name = collection_name

self._ttl = None

# Check if the bucket exists
if not self._check_bucket_exists():
if not check_bucket_exists(cluster, bucket_name):
raise ValueError(
f"Bucket {self._bucket_name} does not exist. "
" Please create the bucket before searching."
f"Bucket {bucket_name} does not exist. "
"Please create the bucket before searching."
)

try:
self._bucket = self._cluster.bucket(self._bucket_name)
self._scope = self._bucket.scope(self._scope_name)
self._collection = self._scope.collection(self._collection_name)
self._bucket = self._cluster.bucket(bucket_name)
self._scope = self._bucket.scope(scope_name)
self._collection = self._scope.collection(collection_name)
except Exception as e:
raise ValueError(
"Error connecting to couchbase. "
"Please check the connection and credentials."
) from e

# Check if the scope and collection exists. Throws ValueError if they don't
try:
self._check_scope_and_collection_exists()
except Exception as e:
raise e
check_scope_and_collection_exists(
self._bucket, scope_name, collection_name, bucket_name
)

# Check if the time to live is provided and valid
if ttl is not None:
_validate_ttl(ttl)
validate_ttl(ttl)
self._ttl = ttl

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
@@ -231,16 +185,13 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
}
document_key = self._generate_key(prompt, llm_string)
try:
if self._ttl:
self._collection.upsert(
key=document_key,
value=doc,
expiry=self._ttl,
)
else:
self._collection.upsert(key=document_key, value=doc)
self._collection.upsert(
key=document_key,
value=doc,
**({"expiry": self._ttl} if self._ttl else {}),
)
except Exception:
logger.error("Error updating cache")
logger.exception("Error updating cache")

def clear(self, **kwargs: Any) -> None:
"""Clear the cache.
@@ -251,7 +202,7 @@ def clear(self, **kwargs: Any) -> None:
query = f"DELETE FROM `{self._collection_name}`"
self._scope.query(query).execute()
except Exception:
logger.error("Error clearing cache. Please check if you have an index.")
logger.exception("Error clearing cache. Please check if you have an index.")


class CouchbaseSemanticCache(BaseCache, CouchbaseSearchVectorStore):
Expand Down Expand Up @@ -300,10 +251,10 @@ def __init__(
self._ttl = None

# Check if the bucket exists
if not self._check_bucket_exists():
if not check_bucket_exists(cluster, bucket_name):
raise ValueError(
f"Bucket {self._bucket_name} does not exist. "
" Please create the bucket before searching."
f"Bucket {bucket_name} does not exist. "
"Please create the bucket before searching."
)

try:
@@ -317,15 +268,14 @@
) from e

# Check if the scope and collection exists. Throws ValueError if they don't
try:
self._check_scope_and_collection_exists()
except Exception as e:
raise e
check_scope_and_collection_exists(
self._bucket, scope_name, collection_name, bucket_name
)

self.score_threshold = score_threshold

if ttl is not None:
_validate_ttl(ttl)
validate_ttl(ttl)
self._ttl = ttl

# Initialize the vector store
@@ -370,7 +320,7 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
ttl=self._ttl,
)
except Exception:
logger.error("Error updating cache")
logger.exception("Error updating cache")

def clear(self, **kwargs: Any) -> None:
"""Clear the cache.
@@ -381,4 +331,4 @@ def clear(self, **kwargs: Any) -> None:
query = f"DELETE FROM `{self._collection_name}`"
self._scope.query(query).execute()
except Exception:
logger.error("Error clearing cache. Please check if you have an index.")
logger.exception("Error clearing cache. Please check if you have an index.")