From 71b3520b9d8e7ff34e0b3214e3a6b2ae25402c7d Mon Sep 17 00:00:00 2001
From: Marc Vilanova
Date: Thu, 20 Mar 2025 11:42:48 -0700
Subject: [PATCH] fix(ai): improves handling of llm context window limit

---
 src/dispatch/ai/service.py | 50 ++++++++++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/src/dispatch/ai/service.py b/src/dispatch/ai/service.py
index b3da1162b2f7..b0ff53317c6e 100644
--- a/src/dispatch/ai/service.py
+++ b/src/dispatch/ai/service.py
@@ -15,7 +15,36 @@
 log = logging.getLogger(__name__)
 
-MAX_TOKENS = 128000
+
+def get_model_token_limit(model_name: str, buffer_percentage: float = 0.05) -> int:
+    """
+    Returns the maximum token limit for a given LLM model with a safety buffer.
+
+    Args:
+        model_name (str): The name of the LLM model.
+        buffer_percentage (float): Percentage of tokens to reserve as buffer (default: 5%).
+
+    Returns:
+        int: The maximum number of tokens allowed in the context window for the specified model,
+        with a safety buffer applied.
+    """
+    default_max_tokens = 128000
+
+    model_token_limits = {
+        # OpenAI models (most recent)
+        "gpt-4o": 128000,
+        # Anthropic models (Claude 3.5 and 3.7 Sonnet variants)
+        "claude-3-5-sonnet-20241022": 200000,
+        "claude-3-7-sonnet-20250219": 200000,
+    }
+
+    # Get the raw token limit for the model
+    raw_limit = model_token_limits.get(model_name.lower(), default_max_tokens)
+
+    # Apply safety buffer
+    safe_limit = int(raw_limit * (1 - buffer_percentage))
+
+    return safe_limit
 
 
 def num_tokens_from_string(message: str, model: str) -> tuple[list[int], int, tiktoken.Encoding]:
@@ -48,6 +77,7 @@ def truncate_prompt(
     tokenized_prompt: list[int],
     num_tokens: int,
     encoding: tiktoken.Encoding,
+    model_token_limit: int,
 ) -> str:
     """
     Truncate the tokenized prompt to ensure it does not exceed the maximum number of tokens.
@@ -60,10 +90,10 @@ def truncate_prompt(
     Returns:
         str: The truncated prompt as a string.
     """
-    excess_tokens = num_tokens - MAX_TOKENS
+    excess_tokens = num_tokens - model_token_limit
     truncated_tokenized_prompt = tokenized_prompt[:-excess_tokens]
     truncated_prompt = encoding.decode(truncated_tokenized_prompt)
-    log.warning(f"GenAI prompt truncated to fit within {MAX_TOKENS} tokens.")
+    log.warning(f"GenAI prompt truncated to fit within {model_token_limit} tokens.")
     return truncated_prompt
 
 
@@ -245,8 +275,11 @@ def generate_case_signal_summary(case: Case, db_session: Session) -> dict[str, s
     )
 
     # we check if the prompt exceeds the token limit
-    if num_tokens > MAX_TOKENS:
-        prompt = truncate_prompt(tokenized_prompt, num_tokens, encoding)
+    model_token_limit = get_model_token_limit(
+        genai_plugin.instance.configuration.chat_completion_model
+    )
+    if num_tokens > model_token_limit:
+        prompt = truncate_prompt(tokenized_prompt, num_tokens, encoding, model_token_limit)
 
     # we generate the analysis
     response = genai_plugin.instance.chat_completion(prompt=prompt)
@@ -336,8 +369,11 @@ def generate_incident_summary(incident: Incident, db_session: Session) -> str:
     )
 
     # we check if the prompt exceeds the token limit
-    if num_tokens > MAX_TOKENS:
-        prompt = truncate_prompt(tokenized_prompt, num_tokens, encoding)
+    model_token_limit = get_model_token_limit(
+        genai_plugin.instance.configuration.chat_completion_model
+    )
+    if num_tokens > model_token_limit:
+        prompt = truncate_prompt(tokenized_prompt, num_tokens, encoding, model_token_limit)
 
     summary = genai_plugin.instance.chat_completion(prompt=prompt)
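
Reviewer note (not part of the patch): a minimal usage sketch of the new helper, assuming get_model_token_limit is importable from dispatch.ai.service as added above. The lookup is keyed on the lowercased chat-completion model name, unrecognized models fall back to the 128,000-token default, and the safety buffer (5% unless overridden) is applied to whichever limit is found.

    # Illustrative sketch only; model names other than those in the patch are hypothetical.
    from dispatch.ai.service import get_model_token_limit

    # Known Anthropic model: 200,000-token window minus the 5% buffer -> 190,000
    assert get_model_token_limit("claude-3-5-sonnet-20241022") == 190000

    # Unrecognized model name: falls back to the 128,000-token default -> 121,600
    assert get_model_token_limit("some-unknown-model") == 121600

    # The buffer can be widened explicitly, e.g. 10% on gpt-4o -> 115,200
    assert get_model_token_limit("gpt-4o", buffer_percentage=0.10) == 115200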