diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index a6a8a33e..9cfbfffd 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -96,6 +96,35 @@ async def get_chat_history( return out +def merge_system_prompt(system_prompt: str, last_system: str) -> str: + if not last_system: + return system_prompt + + # find the longest common prefix + min_len = min(len(system_prompt), len(last_system)) + i = 0 + while i < min_len and system_prompt[i] == last_system[i]: + i += 1 + + # prevent meaningless matching (like only matching to "You are ") + if i < 20: + return system_prompt + "\n\n" + last_system + + common_prefix = system_prompt[:i].strip() + rest_a = system_prompt[i:].strip() + rest_b = last_system[i:].strip() + + parts = [] + if common_prefix: + parts.append(common_prefix) + if rest_a: + parts.append(rest_a) + if rest_b: + parts.append(rest_b) + + return "\n\n".join(parts) + + async def call_llm( model: LLMModel, messages: list[dict], @@ -155,8 +184,16 @@ async def call_llm( tools_for_llm = await get_agent_tools_for_llm(agent_id) if agent_id else AGENT_TOOLS # Convert messages to LLMMessage format - api_messages = [LLMMessage(role="system", content=system_prompt)] + last_system = None + filtered_messages = [] for msg in messages: + if msg.get("role") == "system": + last_system = msg.get("content") + else: + filtered_messages.append(msg) + final_system = merge_system_prompt(system_prompt, last_system) + api_messages = [LLMMessage(role="system", content=final_system)] + for msg in filtered_messages: api_messages.append(LLMMessage( role=msg.get("role", "user"), content=msg.get("content"),