Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion backend/app/api/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,35 @@ async def get_chat_history(
return out


def merge_system_prompt(system_prompt: str, last_system: str) -> str:
if not last_system:
return system_prompt

# find the longest common prefix
min_len = min(len(system_prompt), len(last_system))
i = 0
while i < min_len and system_prompt[i] == last_system[i]:
i += 1

# prevent meaningless matching (like only matching to "You are ")
if i < 20:
return system_prompt + "\n\n" + last_system

common_prefix = system_prompt[:i].strip()
rest_a = system_prompt[i:].strip()
rest_b = last_system[i:].strip()

parts = []
if common_prefix:
parts.append(common_prefix)
if rest_a:
parts.append(rest_a)
if rest_b:
parts.append(rest_b)

return "\n\n".join(parts)


async def call_llm(
model: LLMModel,
messages: list[dict],
Expand Down Expand Up @@ -155,8 +184,16 @@ async def call_llm(
tools_for_llm = await get_agent_tools_for_llm(agent_id) if agent_id else AGENT_TOOLS

# Convert messages to LLMMessage format
api_messages = [LLMMessage(role="system", content=system_prompt)]
last_system = None
filtered_messages = []
for msg in messages:
if msg.get("role") == "system":
last_system = msg.get("content")
else:
filtered_messages.append(msg)
final_system = merge_system_prompt(system_prompt, last_system)
api_messages = [LLMMessage(role="system", content=final_system)]
for msg in filtered_messages:
api_messages.append(LLMMessage(
role=msg.get("role", "user"),
content=msg.get("content"),
Expand Down