
Commit 4c84ed4

LLMCache TTL configuration (#28)
Enables TTL for loading data into Redis using RedisVL. The TTL is an optional kwarg on the underlying `load` function. This is particularly useful for clients and abstractions like SemanticCache that lean on Redis support for ephemeral data. Also adds a `clear` method to the LLM cache class that invalidates all data in the cache but does not disturb the index itself.
1 parent dcc5297 commit 4c84ed4
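
For orientation, here is a minimal usage sketch of the behavior this commit enables, pieced together from the README and the new tests further down. The constructor arguments mirror the `cache_with_ttl` test fixture; the 60-second TTL is purely illustrative.

```python
from redisvl.llmcache.semantic import SemanticCache
from redisvl.providers import HuggingfaceProvider

provider = HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")

# Entries stored through this cache now expire after the configured TTL (seconds)
cache = SemanticCache(provider=provider, threshold=0.8, ttl=60)
cache.store("What is the capital of France?", "Paris")
cache.check("What really is the capital of France?")  # ["Paris"] while the entry is live

# Invalidate every cached entry without dropping the underlying index
cache.clear()
```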

File tree: 5 files changed (+85 −29 lines)

README.md

Lines changed: 2 additions & 2 deletions

@@ -141,8 +141,8 @@ cache.store("What is the capital of France?", "Paris")
 cache.check("What is the capital of France?")
 ["Paris"]

-# Cache will return the result if the query is similar enough
-cache.get("What really is the capital of France?")
+# Cache will still return the result if the query is similar enough
+cache.check("What really is the capital of France?")
 ["Paris"]
 ```
redisvl/index.py

Lines changed: 21 additions & 12 deletions

@@ -255,15 +255,21 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
-        if not data:
-            return
-        if not isinstance(data, Iterable):
-            if not isinstance(data[0], dict):
-                raise TypeError("data must be an iterable of dictionaries")
-
-        for record in data:
-            key = f"{self._prefix}:{self._get_key_field(record)}"
-            self._redis_conn.hset(key, mapping=record)  # type: ignore
+        # TODO -- should we return a count of the upserts? or some kind of metadata?
+        if data:
+            if not isinstance(data, Iterable):
+                if not isinstance(data[0], dict):
+                    raise TypeError("data must be an iterable of dictionaries")
+
+            # Check if outer interface passes in TTL on load
+            ttl = kwargs.get("ttl")
+            pipe = self._redis_conn.pipeline(transaction=False)
+            for record in data:
+                key = f"{self._prefix}:{self._get_key_field(record)}"
+                pipe.hset(key, mapping=record)  # type: ignore
+                if ttl:
+                    pipe.expire(key, ttl)
+            pipe.execute()

     @check_connected("_redis_conn")
     def exists(self) -> bool:

@@ -338,7 +344,7 @@ async def delete(self, drop: bool = True):
         await self._redis_conn.ft(self._name).dropindex(delete_documents=drop)  # type: ignore

     @check_connected("_redis_conn")
-    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
+    async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10, **kwargs):
         """Load data into Redis and index using this SearchIndex object

         Args:

@@ -348,15 +354,18 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
+        ttl = kwargs.get("ttl")
         semaphore = asyncio.Semaphore(concurrency)

-        async def load(d: dict):
+        async def _load(d: dict):
             async with semaphore:
                 key = f"{self._prefix}:{self._get_key_field(d)}"
                 await self._redis_conn.hset(key, mapping=d)  # type: ignore
+                if ttl:
+                    await self._redis_conn.expire(key, ttl)

         # gather with concurrency
-        await asyncio.gather(*[load(d) for d in data])
+        await asyncio.gather(*[_load(d) for d in data])

     @check_connected("_redis_conn")
     async def exists(self) -> bool:
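
As a small, hypothetical sketch of the new `load` signature from a caller's perspective: the `index` instance and the record fields below are assumed for illustration; only the `ttl` keyword itself comes from this diff.

```python
# Assumes `index` is an already-created and connected SearchIndex instance;
# the record fields are made up for illustration.
records = [
    {"id": "doc:1", "text": "hello world"},
    {"id": "doc:2", "text": "goodbye world"},
]

# Each record is written with a pipelined HSET; when ttl is supplied,
# an EXPIRE is queued for the same key so entries age out automatically.
index.load(records, ttl=30)
```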

redisvl/llmcache/base.py

Lines changed: 4 additions & 0 deletions

@@ -5,6 +5,10 @@
 class BaseLLMCache:
     verbose: bool = True

+    def clear(self):
+        """Clear the LLMCache and create a new underlying index."""
+        raise NotImplementedError
+
     def check(self, prompt: str) -> Optional[List[str]]:
         raise NotImplementedError
redisvl/llmcache/semantic.py

Lines changed: 22 additions & 7 deletions

@@ -114,6 +114,17 @@ def set_threshold(self, threshold: float):
             raise ValueError("Threshold must be between 0 and 1.")
         self._threshold = float(threshold)

+    def clear(self):
+        """Clear the LLMCache of all keys in the index"""
+        client = self._index.client
+        if client:
+            pipe = client.pipeline()
+            for key in client.scan_iter(match=f"{self._index._prefix}:*"):
+                pipe.delete(key)
+            pipe.execute()
+        else:
+            raise RuntimeError("LLMCache is not connected to a Redis instance.")
+
     def check(
         self,
         prompt: Optional[str] = None,

@@ -153,9 +164,9 @@ def check(

         cache_hits = []
         for doc in results.docs:
-            self._refresh_ttl(doc.id)
             sim = similarity(doc.vector_distance)
             if sim > self.threshold:
+                self._refresh_ttl(doc.id)
                 cache_hits.append(doc.response)
         return cache_hits

@@ -179,18 +190,23 @@ def store(
         Raises:
             ValueError: If neither prompt nor vector is specified.
         """
+        # Prepare LLMCache inputs
         if not key:
             key = self.hash_input(prompt)

-        if vector:
-            vector = array_to_buffer(vector)
-        else:
+        if not vector:
             vector = self._provider.embed(prompt)  # type: ignore

-        payload = {"id": key, "prompt_vector": vector, "response": response}
+        payload = {
+            "id": key,
+            "prompt_vector": array_to_buffer(vector),
+            "response": response
+        }
         if metadata:
             payload.update(metadata)
-        self._index.load([payload])
+
+        # Load LLMCache entry with TTL
+        self._index.load([payload], ttl=self._ttl)

     def _refresh_ttl(self, key: str):
         """Refreshes the TTL for the specified key."""

@@ -201,6 +217,5 @@ def _refresh_ttl(self, key: str):
         else:
             raise RuntimeError("LLMCache is not connected to a Redis instance.")

-
 def similarity(distance: Union[float, str]) -> float:
     return 1 - float(distance)

tests/integration/test_llmcache.py

Lines changed: 36 additions & 8 deletions

@@ -1,5 +1,6 @@
 import pytest

+from time import sleep
 from redisvl.llmcache.semantic import SemanticCache
 from redisvl.providers import HuggingfaceProvider

@@ -8,48 +9,75 @@
 def provider():
     return HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2")

-
 @pytest.fixture
 def cache(provider):
     return SemanticCache(provider=provider, threshold=0.8)

+@pytest.fixture
+def cache_with_ttl(provider):
+    return SemanticCache(provider=provider, threshold=0.8, ttl=2)

 @pytest.fixture
 def vector(provider):
     return provider.embed("This is a test sentence.")


-def test_store_and_check(cache, vector):
+def test_store_and_check_and_clear(cache, vector):
     # Check that we can store and retrieve a response
     prompt = "This is a test prompt."
     response = "This is a test response."
     cache.store(prompt, response, vector=vector)
     check_result = cache.check(vector=vector)
     assert len(check_result) >= 1
     assert response in check_result
-    cache.index.delete(drop=True)
+    cache.clear()
+    check_result = cache.check(vector=vector)
+    assert len(check_result) == 0
+    cache._index.delete(True)

+def test_ttl(cache_with_ttl, vector):
+    # Check that TTL expiration kicks in after 2 seconds
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    cache_with_ttl.store(prompt, response, vector=vector)
+    sleep(3)
+    check_result = cache_with_ttl.check(vector=vector)
+    assert len(check_result) == 0
+    cache_with_ttl._index.delete(True)

 def test_check_no_match(cache, vector):
     # Check behavior when there is no match in the cache
     # In this case, we're using a vector, but the cache is empty
     check_result = cache.check(vector=vector)
     assert len(check_result) == 0
-    cache.index.delete(drop=True)
-
+    cache._index.delete(True)

 def test_store_with_vector_and_metadata(cache, vector):
     # Test storing a response with a vector and metadata
     prompt = "This is another test prompt."
     response = "This is another test response."
     metadata = {"source": "test"}
     cache.store(prompt, response, vector=vector, metadata=metadata)
-    cache.index.delete(drop=True)
-
+    check_result = cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    cache._index.delete(True)

 def test_set_threshold(cache):
     # Test the getter and setter for the threshold
     assert cache.threshold == 0.8
     cache.set_threshold(0.9)
     assert cache.threshold == 0.9
-    cache.index.delete(drop=True)
+    cache._index.delete(True)
+
+def test_from_existing(cache, vector, provider):
+    prompt = "This is another test prompt."
+    response = "This is another test response."
+    metadata = {"source": "test"}
+    cache.store(prompt, response, vector=vector, metadata=metadata)
+    # connect from existing?
+    new_cache = SemanticCache(provider=provider, threshold=0.8)
+    check_result = new_cache.check(vector=vector)
+    assert len(check_result) >= 1
+    assert response in check_result
+    new_cache._index.delete(True)
