98 changes: 97 additions & 1 deletion lisette/core.py
@@ -244,18 +244,21 @@ def __init__(
    cache=False, # Anthropic prompt caching
    cache_idxs:list=[-1], # Anthropic cache breakpoint idxs, use `0` for sys prompt if provided
    ttl=None, # Anthropic prompt caching ttl
    cached_content=None # Gemini prompt caching
):
    "LiteLLM chat client."
    self.model = model
    hist,tools = mk_msgs(hist,cache,cache_idxs,ttl),listify(tools)
    if ns is None and tools: ns = mk_ns(tools)
    elif ns is None: ns = globals()
    self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
    self.cache_name = cached_content
    store_attr()

def _prep_msg(self, msg=None, prefill=None):
    "Prepare the messages list for the API call"
    # Don't send sp when it's already stored in a Gemini cache
    sp = [{"role": "system", "content": self.sp}] if self.sp and not getattr(self, '_sp_in_cache', False) else []
    if sp:
        if 0 in self.cache_idxs: sp[0] = _add_cache_control(sp[0])
        cache_idxs = L(self.cache_idxs).filter().map(lambda o: o-1 if o>0 else o)
@@ -271,6 +274,7 @@ def _call(self, msg=None, prefill=None, temp=None, think=None, search=None, stre
    if not get_model_info(self.model).get("supports_assistant_prefill"): prefill=None
    if _has_search(self.model) and (s:=ifnone(search,self.search)): kwargs['web_search_options'] = {"search_context_size": effort[s]}
    else: _=kwargs.pop('web_search_options',None)
    # Route the request to the stored Gemini cache entry, if one exists
    if self.cache_name: kwargs['cached_content'] = self.cache_name
    res = completion(model=self.model, messages=self._prep_msg(msg, prefill), stream=stream,
                     tools=self.tool_schemas, reasoning_effort = effort.get(think), tool_choice=tool_choice,
                     # temperature is not supported when reasoning
@@ -310,6 +314,96 @@ def __call__(self,
    if stream: return result_gen # streaming
    elif return_all: return list(result_gen) # toolloop behavior
    else: return last(result_gen) # normal chat behavior

def create_cache(self, system_instruction=None, contents=None, tools=None, ttl="3600s"):
    "Create a Gemini context cache; defaults to this chat's system prompt and tool schemas."
    from google import genai
    from google.genai import types

    if self.cache_name:
        raise ValueError("Cache already exists. Delete it first with delete_cache()")
    client = genai.Client()

    # LiteLLM ids look like "gemini/gemini-2.0-flash"; the genai client wants "gemini-2.0-flash"
    model_name = self.model.split("/", 1)[1] if "/" in self.model else self.model

    # Explicit caching requires a versioned model id, e.g. "gemini-2.0-flash-001"
    if not model_name.endswith("-001"):
        model_name += "-001"

    # Fall back to the Chat's own system prompt and tool schemas
    system_instruction = system_instruction or self.sp
    tools = tools or self.tool_schemas

    # Create the cache via the google.genai client; omit `contents` when not provided
    cfg = dict(system_instruction=system_instruction, tools=tools, ttl=ttl)
    if contents: cfg['contents'] = contents
    cache = client.caches.create(model=model_name, config=types.CreateCachedContentConfig(**cfg))
    # Remember the server-side cache id; _call passes it as `cached_content`
    self.cache_name = cache.name
    # The system prompt now lives in the cache, so _prep_msg must not resend it
    self._sp_in_cache = bool(system_instruction)
    return cache


def delete_cache(self):
    "Delete the Gemini context cache and resume sending the system prompt inline."
    from google import genai

    if not self.cache_name:
        raise ValueError("No cache exists to delete.")

    client = genai.Client()
    client.caches.delete(name=self.cache_name)
    self.cache_name = None
    self._sp_in_cache = False  # sp must be sent inline again

def get_cache(self):
    "Fetch metadata for the current Gemini context cache."
    from google import genai

    if not self.cache_name:
        raise ValueError("No cache exists")

    client = genai.Client()
    return client.caches.get(name=self.cache_name)

def update_cache(self, ttl='300s'):
    "Update the cache expiry; `ttl` must be a string in seconds, e.g. '300s'."
    from google import genai
    from google.genai import types

    if not self.cache_name:
        raise ValueError("No cache exists to update")

    client = genai.Client()
    client.caches.update(name=self.cache_name,
                         config=types.UpdateCachedContentConfig(ttl=ttl))
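
# Lifecycle sketch for the remaining helpers (hypothetical ttl value):
#   chat.get_cache()               # inspect the server-side cache entry
#   chat.update_cache(ttl='600s')  # push expiry out to 10 minutes from now
#   chat.delete_cache()            # free it; sp is sent inline again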



# %% ../nbs/00_core.ipynb
@patch
@@ -467,3 +561,5 @@ async def adisplay_stream(rs):
        md+=o
        display(Markdown(md),clear=True)
    return fmt