From f457ccafeeb627e4a369bd2a49fab58ec6718288 Mon Sep 17 00:00:00 2001 From: yaojin Date: Sun, 22 Mar 2026 20:26:04 +0800 Subject: [PATCH 1/6] fix(llm): increase api_key_encrypted length to 1024 for Minimax support Fixes #164 - Minimax API keys exceed previous 500 char limit --- .../versions/increase_api_key_length.py | 33 +++++++++++++++++++ backend/app/models/llm.py | 2 +- 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 backend/alembic/versions/increase_api_key_length.py diff --git a/backend/alembic/versions/increase_api_key_length.py b/backend/alembic/versions/increase_api_key_length.py new file mode 100644 index 00000000..e3fe5765 --- /dev/null +++ b/backend/alembic/versions/increase_api_key_length.py @@ -0,0 +1,33 @@ +"""Increase api_key_encrypted column length to support Minimax API keys. + +Revision ID: increase_api_key_length +Revises: add_notification_agent_id +Create Date: 2026-03-22 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = 'increase_api_key_length' +down_revision: Union[str, None] = 'add_notification_agent_id' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Increase api_key_encrypted column length from 500 to 1024 + # Minimax API keys are very long and exceed the previous 500 char limit + op.execute(""" + ALTER TABLE llm_models + ALTER COLUMN api_key_encrypted TYPE VARCHAR(1024) + """) + + +def downgrade() -> None: + # Revert to 500 chars (may fail if data exceeds 500 chars) + op.execute(""" + ALTER TABLE llm_models + ALTER COLUMN api_key_encrypted TYPE VARCHAR(500) + """) diff --git a/backend/app/models/llm.py b/backend/app/models/llm.py index 6e58b68a..6f35f46b 100644 --- a/backend/app/models/llm.py +++ b/backend/app/models/llm.py @@ -19,7 +19,7 @@ class LLMModel(Base): tenant_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("tenants.id"), 
nullable=True, index=True) provider: Mapped[str] = mapped_column(String(50), nullable=False) # anthropic, openai, deepseek, etc. model: Mapped[str] = mapped_column(String(100), nullable=False) # claude-opus-4-6, gpt-4o, etc. - api_key_encrypted: Mapped[str] = mapped_column(String(500), nullable=False) + api_key_encrypted: Mapped[str] = mapped_column(String(1024), nullable=False) base_url: Mapped[str | None] = mapped_column(String(500)) label: Mapped[str] = mapped_column(String(200), nullable=False) # Display name max_tokens_per_day: Mapped[int | None] = mapped_column(Integer) From 9e9659bf60b949ffcf20cbdd78426e26e0d48c3a Mon Sep 17 00:00:00 2001 From: yaojin Date: Sun, 22 Mar 2026 20:27:18 +0800 Subject: [PATCH 2/6] feat(llm): unified failover policy across all execution paths Implements #154 - Unify Primary/Fallback LLM Failover Policy Changes: - Add llm_failover.py module with error classification and failover logic - Add llm_caller.py service for unified LLM calling - Update call_llm() to support fallback_model parameter - Update _call_agent_llm() to use unified failover - Update task_executor and scheduler to use new unified functions Failover rules: 1. Try primary if available 2. If primary missing/unavailable, use fallback directly 3. If primary fails with retryable error, retry once on fallback 4. 
If error is non-retryable (auth/validation), do not switch --- backend/app/api/feishu.py | 97 ++------- backend/app/api/websocket.py | 100 +++++++--- backend/app/services/llm_caller.py | 274 ++++++++++++++++++++++++++ backend/app/services/llm_failover.py | 249 +++++++++++++++++++++++ backend/app/services/scheduler.py | 85 ++------ backend/app/services/task_executor.py | 106 ++-------- 6 files changed, 634 insertions(+), 277 deletions(-) create mode 100644 backend/app/services/llm_caller.py create mode 100644 backend/app/services/llm_failover.py diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index 80819a9f..6f99ddee 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -1022,90 +1022,19 @@ async def _download_post_images(agent_id, config, message_id, image_keys): async def _call_agent_llm(db: AsyncSession, agent_id: uuid.UUID, user_text: str, history: list[dict] | None = None, user_id=None, on_chunk=None, on_thinking=None) -> str: """Call the agent's configured LLM model with conversation history. - - Reuses the same call_llm function as the WebSocket chat endpoint so that - all providers (OpenRouter, Qwen, etc.) work identically on both channels. - """ - from app.models.agent import Agent - from app.models.llm import LLMModel - from app.api.websocket import call_llm - - # Load agent and model - agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) - agent = agent_result.scalar_one_or_none() - if not agent: - return "⚠️ 数字员工未找到" - - if is_agent_expired(agent): - return "This Agent has expired and is off duty. Please contact your admin to extend its service." 
- - # Load primary model - model = None - if agent.primary_model_id: - model_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) - model = model_result.scalar_one_or_none() - - # Load fallback model - fallback_model = None - if agent.fallback_model_id: - fb_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.fallback_model_id)) - fallback_model = fb_result.scalar_one_or_none() - - # Config-level fallback: primary missing -> use fallback - if not model and fallback_model: - model = fallback_model - fallback_model = None - logger.warning(f"[Channel] Primary model unavailable, using fallback: {model.model}") - - if not model: - return f"⚠️ {agent.name} 未配置 LLM 模型,请在管理后台设置。" - - # Build conversation messages (without system prompt — call_llm adds it) - messages: list[dict] = [] - if history: - messages.extend(history[-10:]) - messages.append({"role": "user", "content": user_text}) - - # Use actual user_id so the system prompt knows who it's chatting with - effective_user_id = user_id or agent_id - try: - reply = await call_llm( - model, - messages, - agent.name, - agent.role_description or "", - agent_id=agent_id, - user_id=effective_user_id, - supports_vision=getattr(model, 'supports_vision', False), - on_chunk=on_chunk, - on_thinking=on_thinking, - ) - return reply - except Exception as e: - import traceback - traceback.print_exc() - error_msg = str(e) or repr(e) - logger.error(f"[LLM] Primary model error: {error_msg}") - # Runtime fallback: primary model failed -> retry with fallback model - if fallback_model: - logger.info(f"[LLM] Retrying with fallback model: {fallback_model.model}") - try: - reply = await call_llm( - fallback_model, - messages, - agent.name, - agent.role_description or "", - agent_id=agent_id, - user_id=effective_user_id, - supports_vision=getattr(fallback_model, 'supports_vision', False), - on_chunk=on_chunk, - on_thinking=on_thinking, - ) - return reply - except Exception as e2: - 
traceback.print_exc() - return f"⚠️ 调用模型出错: Primary: {str(e)[:80]} | Fallback: {str(e2)[:80]}" - return f"⚠️ 调用模型出错: {error_msg[:150]}" + DEPRECATED: Use app.services.llm_caller.call_agent_llm instead. + This function is kept for backward compatibility with existing imports. + """ + from app.services.llm_caller import call_agent_llm + return await call_agent_llm( + db=db, + agent_id=agent_id, + user_text=user_text, + history=history, + user_id=user_id, + on_chunk=on_chunk, + on_thinking=on_thinking, + ) diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index 396fa142..9acc25a8 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -107,14 +107,73 @@ async def call_llm( on_tool_call=None, on_thinking=None, supports_vision=False, + fallback_model: LLMModel | None = None, ) -> str: - """Call LLM via unified client with function-calling tool loop. + """Call LLM via unified client with function-calling tool loop and failover support. Args: on_chunk: Optional async callback(text: str) for streaming chunks to client. on_thinking: Optional async callback(text: str) for reasoning/thinking content. on_tool_call: Optional async callback(dict) for tool call status updates. + fallback_model: Optional fallback model for runtime failover. 
""" + from app.services.llm_failover import classify_error, FailoverErrorType + + async def _call_single(_model: LLMModel) -> str: + """Internal: call a single model without failover.""" + return await _call_llm_core( + _model, messages, agent_name, role_description, + agent_id, user_id, on_chunk, on_tool_call, on_thinking, supports_vision + ) + + # Config-level fallback: if no primary, use fallback directly + if model is None and fallback_model is not None: + model = fallback_model + fallback_model = None + + if model is None: + return "⚠️ 未配置 LLM 模型" + + # Try primary model + try: + return await _call_single(model) + except Exception as e: + error_type = classify_error(e) + error_msg = str(e) or repr(e) + logger.warning(f"[call_llm] Primary failed ({error_type.value}): {error_msg[:150]}") + + # Non-retryable: don't attempt fallback + if error_type == FailoverErrorType.NON_RETRYABLE: + return f"[LLM Error] {error_msg}" + + # No fallback available + if fallback_model is None: + return f"[LLM Error] {error_msg}" + + # Runtime fallback: retry with fallback model + logger.info(f"[call_llm] Retrying with fallback: {fallback_model.provider}/{fallback_model.model}") + try: + return await _call_single(fallback_model) + except Exception as e2: + error_msg2 = str(e2) or repr(e2) + logger.error(f"[call_llm] Fallback also failed: {error_msg2[:150]}") + return f"⚠️ 调用模型出错: Primary: {error_msg[:80]} | Fallback: {error_msg2[:80]}" + + + +async def _call_llm_core( + model: LLMModel, + messages: list[dict], + agent_name: str, + role_description: str, + agent_id=None, + user_id=None, + on_chunk=None, + on_tool_call=None, + on_thinking=None, + supports_vision=False, +) -> str: + """Core LLM call implementation (single model, no failover).""" from app.services.agent_tools import AGENT_TOOLS, execute_tool, get_agent_tools_for_llm from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError @@ -217,7 +276,7 @@ async def call_llm( timeout=120.0, ) except 
Exception as e: - return f"[Error] Failed to create LLM client: {e}" + raise LLMError(f"Failed to create LLM client: {e}") max_tokens = get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None)) @@ -258,14 +317,15 @@ async def call_llm( on_thinking=on_thinking, ) except LLMError as e: - # Record accumulated tokens before returning error + # Record accumulated tokens before raising logger.error( f"[LLM] LLMError provider={getattr(model, 'provider', '?')} " f"model={getattr(model, 'model', '?')} round={round_i + 1}: {e}" ) if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) - return f"[LLM Error] {e}" + await client.close() + raise # Re-raise for failover handling except Exception as e: logger.error( f"[LLM] Unexpected error provider={getattr(model, 'provider', '?')} " @@ -274,7 +334,8 @@ async def call_llm( ) if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) - return f"[LLM call error] {type(e).__name__}: {str(e)[:200]}" + await client.close() + raise # Re-raise for failover handling # ── Track tokens for this round ── real_tokens = extract_usage_tokens(response.usage) @@ -377,7 +438,7 @@ async def call_llm( if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) await client.close() - return "[Error] Too many tool call rounds" + raise LLMError("Too many tool call rounds") @router.websocket("/ws/chat/{agent_id}") @@ -737,6 +798,7 @@ async def thinking_to_ws(text: str): on_tool_call=tool_call_to_ws, on_thinking=thinking_to_ws, supports_vision=getattr(llm_model, 'supports_vision', False), + fallback_model=fallback_llm_model, )) # Listen for abort while LLM is running @@ -803,30 +865,8 @@ async def thinking_to_ws(text: str): logger.error(f"[WS] LLM error: {e}") import traceback traceback.print_exc() - # Runtime fallback: primary model failed -> retry with fallback model - if fallback_llm_model: - 
logger.info(f"[WS] Primary model failed, retrying with fallback: {fallback_llm_model.model}") - try: - await websocket.send_json({"type": "info", "content": f"Primary model error, switching to fallback model ({fallback_llm_model.model})..."}) - assistant_response = await call_llm( - fallback_llm_model, - conversation[-ctx_size:], - agent_name, - role_description, - agent_id=agent_id, - user_id=user_id, - on_chunk=stream_to_ws, - on_tool_call=tool_call_to_ws, - on_thinking=thinking_to_ws, - supports_vision=getattr(fallback_llm_model, 'supports_vision', False), - ) - logger.info(f"[WS] Fallback LLM response: {assistant_response[:80]}") - except Exception as e2: - logger.error(f"[WS] Fallback LLM also failed: {e2}") - traceback.print_exc() - assistant_response = f"[LLM call error] Primary: {str(e)[:100]} | Fallback: {str(e2)[:100]}" - else: - assistant_response = f"[LLM call error] {str(e)[:200]}" + # call_llm now handles failover internally, just return the error message + assistant_response = str(e) if str(e) else "[LLM call error]" else: assistant_response = f"⚠️ {agent_name} has no LLM model configured. Please select a model in the agent's Settings tab." diff --git a/backend/app/services/llm_caller.py b/backend/app/services/llm_caller.py new file mode 100644 index 00000000..20d92d8d --- /dev/null +++ b/backend/app/services/llm_caller.py @@ -0,0 +1,274 @@ +"""Unified LLM calling service with failover support for all execution paths. + +This module provides a shared entry point for all LLM calls across: +- WebSocket chat +- IM channels (Feishu, Slack, Teams, Discord, WeCom, DingTalk) +- Background services (task executor, scheduler, heartbeat, etc.) + +All paths now support: +1. Config-level fallback: if primary missing, use fallback directly +2. 
Runtime failover: if primary fails with retryable error, try fallback once +""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING + +from loguru import logger +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.llm_failover import classify_error, FailoverErrorType +from app.services.llm_utils import LLMMessage + +if TYPE_CHECKING: + from app.models.agent import Agent + from app.models.llm import LLMModel + + +async def call_agent_llm( + db: AsyncSession, + agent_id: uuid.UUID, + user_text: str, + history: list[dict] | None = None, + user_id: uuid.UUID | None = None, + on_chunk=None, + on_thinking=None, + supports_vision: bool = False, +) -> str: + """Call the agent's LLM with automatic failover support. + + This is the unified entry point for ALL LLM calls across all channels. + + Args: + db: Database session + agent_id: Agent UUID + user_text: User message text + history: Optional conversation history (last N messages) + user_id: Optional user UUID (for personalized context) + on_chunk: Optional streaming callback + on_thinking: Optional thinking/reasoning callback + supports_vision: Whether the model supports vision + + Returns: + LLM response string, or error message if both primary and fallback fail + """ + from app.models.agent import Agent + from app.models.llm import LLMModel + from app.api.websocket import call_llm + + # Load agent + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent: Agent | None = agent_result.scalar_one_or_none() + if not agent: + return "⚠️ 数字员工未找到" + + from app.core.permissions import is_agent_expired + if is_agent_expired(agent): + return "This Agent has expired and is off duty. Please contact your admin to extend its service." 
+ + # Load primary model + primary_model: LLMModel | None = None + if agent.primary_model_id: + model_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) + primary_model = model_result.scalar_one_or_none() + + # Load fallback model + fallback_model: LLMModel | None = None + if agent.fallback_model_id: + fb_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.fallback_model_id)) + fallback_model = fb_result.scalar_one_or_none() + + # Config-level fallback: primary missing -> use fallback + if not primary_model and fallback_model: + primary_model = fallback_model + fallback_model = None + logger.warning(f"[call_agent_llm] Primary model unavailable, using fallback: {primary_model.model}") + + if not primary_model: + return f"⚠️ {agent.name} 未配置 LLM 模型,请在管理后台设置。" + + # Build conversation messages + messages: list[dict] = [] + if history: + messages.extend(history[-10:]) + messages.append({"role": "user", "content": user_text}) + + # Use unified call_llm with failover + try: + reply = await call_llm( + primary_model, + messages, + agent.name, + agent.role_description or "", + agent_id=agent_id, + user_id=user_id or agent_id, + supports_vision=supports_vision or getattr(primary_model, 'supports_vision', False), + on_chunk=on_chunk, + on_thinking=on_thinking, + fallback_model=fallback_model, + ) + return reply + except Exception as e: + # call_llm should handle failover internally, but catch any unexpected errors + error_msg = str(e) or repr(e) + logger.error(f"[call_agent_llm] Unexpected error: {error_msg}") + return f"⚠️ 调用模型出错: {error_msg[:150]}" + + +async def call_agent_llm_with_tools( + db: AsyncSession, + agent_id: uuid.UUID, + system_prompt: str, + user_prompt: str, + max_rounds: int = 50, +) -> str: + """Call agent LLM with tool-calling loop (for background services). + + Used by scheduler, heartbeat, and other background tasks. 
+ + Args: + db: Database session + agent_id: Agent UUID + system_prompt: System prompt/context + user_prompt: User/instruction message + max_rounds: Maximum tool-calling rounds + + Returns: + Final response string + """ + from app.models.agent import Agent + from app.models.llm import LLMModel + from app.services.agent_tools import execute_tool, get_agent_tools_for_llm + from app.services.llm_utils import create_llm_client, get_max_tokens, LLMError + + # Load agent and models + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent: Agent | None = agent_result.scalar_one_or_none() + if not agent: + return "⚠️ Agent not found" + + # Load models + primary_model: LLMModel | None = None + if agent.primary_model_id: + model_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) + primary_model = model_result.scalar_one_or_none() + + fallback_model: LLMModel | None = None + if agent.fallback_model_id: + fb_result = await db.execute(select(LLMModel).where(LLMModel.id == agent.fallback_model_id)) + fallback_model = fb_result.scalar_one_or_none() + + # Config-level fallback + if not primary_model and fallback_model: + primary_model = fallback_model + fallback_model = None + + if not primary_model: + return f"⚠️ {agent.name} has no LLM model configured" + + # Build messages + messages = [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=user_prompt), + ] + + # Load tools + tools_for_llm = await get_agent_tools_for_llm(agent_id) + + async def _try_model(model: LLMModel) -> tuple[str, bool]: + """Try to complete with a model. 
Returns (response, success).""" + try: + client = create_llm_client( + provider=model.provider, + api_key=model.api_key_encrypted, + model=model.model, + base_url=model.base_url, + timeout=120.0, + ) + + max_tokens = get_max_tokens( + model.provider, model.model, + getattr(model, 'max_output_tokens', None) + ) + + # Tool-calling loop + api_messages = list(messages) # Copy + for round_i in range(max_rounds): + try: + response = await client.complete( + messages=api_messages, + tools=tools_for_llm if tools_for_llm else None, + temperature=0.7, + max_tokens=max_tokens, + ) + except Exception as e: + await client.close() + raise + + if not response.tool_calls: + await client.close() + return response.content or "[Empty response]", True + + # Execute tool calls + api_messages.append(LLMMessage( + role="assistant", + content=response.content or None, + tool_calls=[{ + "id": tc["id"], + "type": "function", + "function": tc["function"], + } for tc in response.tool_calls], + )) + + for tc in response.tool_calls: + fn = tc["function"] + tool_name = fn["name"] + raw_args = fn.get("arguments", "{}") + try: + import json + args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError: + args = {} + + result = await execute_tool( + tool_name, args, + agent_id=agent_id, + user_id=agent.creator_id, + ) + api_messages.append(LLMMessage( + role="tool", + tool_call_id=tc["id"], + content=str(result), + )) + + await client.close() + return "[Error] Too many tool call rounds", False + + except Exception as e: + return f"[Error] {e}", False + + # Try primary model + reply, success = await _try_model(primary_model) + if success: + return reply + + # Primary failed - check if retryable + error_type = classify_error(Exception(reply)) + if error_type == FailoverErrorType.NON_RETRYABLE or not fallback_model: + return reply + + # Try fallback model + logger.info(f"[call_agent_llm_with_tools] Retrying with fallback: {fallback_model.model}") + reply2, success2 = await 
_try_model(fallback_model) + if success2: + return reply2 + + return f"⚠️ Both models failed | Primary: {reply[:80]} | Fallback: {reply2[:80]}" + + +__all__ = [ + "call_agent_llm", + "call_agent_llm_with_tools", +] diff --git a/backend/app/services/llm_failover.py b/backend/app/services/llm_failover.py new file mode 100644 index 00000000..c619542a --- /dev/null +++ b/backend/app/services/llm_failover.py @@ -0,0 +1,249 @@ +"""Unified LLM failover executor for all execution paths. + +Provides a shared failover policy across chat/channel/background paths: +1. Try primary if available +2. If primary missing/unavailable, use fallback directly +3. If primary fails with retryable error, retry once on fallback +4. If error is non-retryable (auth/validation/schema), do not switch +5. Max attempts per request: 2 (primary + fallback) +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from enum import Enum +from typing import Awaitable, Callable, TypeVar + +from loguru import logger + +from app.services.llm_client import LLMError, LLMMessage, LLMResponse +from app.services.llm_utils import create_llm_client, get_max_tokens + + +class FailoverErrorType(Enum): + """Classification of LLM errors for failover decisions.""" + + RETRYABLE = "retryable" # Network timeout, 429, 5xx, transient errors + NON_RETRYABLE = "non_retryable" # Auth, validation, schema errors + UNKNOWN = "unknown" + + +@dataclass +class FailoverResult: + """Result of a failover invocation.""" + + content: str + success: bool + model_used: str # "primary" or "fallback" + error: str | None = None + + +# Type variable for the invoke function return type +T = TypeVar("T") + + +def classify_error(error: Exception) -> FailoverErrorType: + """Classify an exception as retryable or non-retryable. 
+ + Retryable errors: + - Network timeout / connection errors + - Provider 429 (rate limit) + - Provider 5xx (server errors) + - Explicit transient provider errors + + Non-retryable errors: + - Auth errors (401, 403) + - Validation errors (400, 422) + - Schema errors + - Content policy violations + """ + error_msg = str(error).lower() + error_type = type(error).__name__.lower() + + # Non-retryable: authentication and authorization + if any(kw in error_msg for kw in ["auth", "unauthorized", "forbidden", "invalid api key", "api key invalid"]): + return FailoverErrorType.NON_RETRYABLE + + # Non-retryable: validation and schema + if any(kw in error_msg for kw in ["validation", "invalid request", "schema", "bad request"]): + return FailoverErrorType.NON_RETRYABLE + + # Non-retryable: content policy + if any(kw in error_msg for kw in ["content policy", "content_filter", "safety", "moderation"]): + return FailoverErrorType.NON_RETRYABLE + + # Retryable: rate limiting + if any(kw in error_msg for kw in ["rate limit", "429", "too many requests"]): + return FailoverErrorType.RETRYABLE + + # Retryable: server errors + if any(kw in error_msg for kw in ["500", "502", "503", "504", "server error", "internal error"]): + return FailoverErrorType.RETRYABLE + + # Retryable: network and timeout + if any(kw in error_msg for kw in ["timeout", "connection", "network", "unreachable", "refused", "reset", "dns"]): + return FailoverErrorType.RETRYABLE + + # Retryable: transient errors + if any(kw in error_msg for kw in ["temporary", "transient", "unavailable", "overloaded", "busy"]): + return FailoverErrorType.RETRYABLE + + # LLMError with specific patterns + if isinstance(error, LLMError): + # Check the error message for HTTP status codes + if any(code in error_msg for code in ["401", "403", "400", "422"]): + return FailoverErrorType.NON_RETRYABLE + if any(code in error_msg for code in ["429", "500", "502", "503", "504", "408"]): + return FailoverErrorType.RETRYABLE + + return 
FailoverErrorType.UNKNOWN + + +async def invoke_with_failover( + primary_model, + fallback_model, + invoke_fn: Callable[..., Awaitable[T]], + *args, + **kwargs, +) -> tuple[T | None, str, str | None]: + """Invoke LLM with automatic failover from primary to fallback. + + Args: + primary_model: The primary LLM model config (can be None) + fallback_model: The fallback LLM model config (can be None) + invoke_fn: Async function to call the LLM (e.g., client.complete) + *args, **kwargs: Arguments to pass to invoke_fn + + Returns: + Tuple of (result, model_used, error) + - result: The LLM response or None if both failed + - model_used: "primary", "fallback", or "none" + - error: Error message if both failed, None otherwise + """ + # Config-level fallback: if no primary, use fallback directly + if primary_model is None and fallback_model is not None: + logger.info("[Failover] Primary model not configured, using fallback directly") + primary_model = fallback_model + fallback_model = None + + if primary_model is None: + return None, "none", "No LLM model configured (primary or fallback)" + + # Try primary model + try: + logger.debug(f"[Failover] Invoking primary model: {primary_model.provider}/{primary_model.model}") + result = await invoke_fn(*args, **kwargs) + return result, "primary", None + except Exception as e: + error_type = classify_error(e) + error_msg = str(e) or repr(e) + + logger.warning( + f"[Failover] Primary model failed ({error_type.value}): {error_msg[:150]}" + ) + + # Non-retryable errors: don't attempt fallback + if error_type == FailoverErrorType.NON_RETRYABLE: + logger.info("[Failover] Non-retryable error, not attempting fallback") + return None, "none", f"Primary failed (non-retryable): {error_msg}" + + # No fallback available + if fallback_model is None: + logger.warning("[Failover] No fallback model available") + return None, "none", f"Primary failed: {error_msg}" + + # Runtime fallback: retry with fallback model + logger.info(f"[Failover] Retrying 
with fallback model: {fallback_model.provider}/{fallback_model.model}") + + try: + # Update kwargs with fallback model if needed + if "model" in kwargs: + kwargs["model"] = fallback_model + + result = await invoke_fn(*args, **kwargs) + logger.info("[Failover] Fallback model succeeded") + return result, "fallback", None + + except Exception as e2: + error_msg2 = str(e2) or repr(e2) + logger.error(f"[Failover] Fallback model also failed: {error_msg2[:150]}") + return None, "none", f"Primary: {error_msg[:80]} | Fallback: {error_msg2[:80]}" + + +async def call_llm_with_failover( + primary_model, + fallback_model, + messages: list[LLMMessage], + tools: list | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + timeout: float = 120.0, + stream: bool = False, + on_chunk=None, + on_thinking=None, +) -> tuple[LLMResponse | None, str, str | None]: + """Call LLM with automatic failover support. + + This is the unified entry point for all LLM calls with failover. + + Args: + primary_model: Primary LLM model config + fallback_model: Fallback LLM model config + messages: List of LLMMessage + tools: Optional tool definitions + temperature: Sampling temperature + max_tokens: Max output tokens + timeout: Request timeout + stream: Whether to use streaming API + on_chunk: Callback for streaming chunks + on_thinking: Callback for thinking/reasoning content + + Returns: + Tuple of (response, model_used, error) + """ + async def _invoke(model): + client = create_llm_client( + provider=model.provider, + api_key=model.api_key_encrypted, + model=model.model, + base_url=model.base_url, + timeout=timeout, + ) + + _max_tokens = max_tokens or get_max_tokens( + model.provider, model.model, getattr(model, "max_output_tokens", None) + ) + + try: + if stream: + response = await client.stream( + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=_max_tokens, + on_chunk=on_chunk, + on_thinking=on_thinking, + ) + else: + response = await client.complete( 
+ messages=messages, + tools=tools, + temperature=temperature, + max_tokens=_max_tokens, + ) + return response + finally: + await client.close() + + return await invoke_with_failover(primary_model, fallback_model, _invoke, primary_model) + + +# Backward compatibility: re-export for convenience +__all__ = [ + "FailoverErrorType", + "FailoverResult", + "classify_error", + "invoke_with_failover", + "call_llm_with_failover", +] diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py index 2fcfc98a..714b5727 100644 --- a/backend/app/services/scheduler.py +++ b/backend/app/services/scheduler.py @@ -60,85 +60,22 @@ async def _execute_schedule(schedule_id: uuid.UUID, agent_id: uuid.UUID, instruc logger.warning(f"Schedule {schedule_id}: LLM model {model_id} not found") return - # Build context and call LLM + # Build context and call LLM with failover support from app.services.agent_context import build_agent_context - from app.services.agent_tools import execute_tool, get_agent_tools_for_llm - from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError + from app.services.llm_caller import call_agent_llm_with_tools system_prompt = await build_agent_context(agent_id, agent.name, agent.role_description or "") - messages = [ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=f"[自动调度任务] {instruction}"), - ] - - # Load tools dynamically from DB (respects per-agent config and MCP tools) - tools_for_llm = await get_agent_tools_for_llm(agent_id) - - # Create unified LLM client - try: - client = create_llm_client( - provider=model.provider, - api_key=model.api_key_encrypted, - model=model.model, - base_url=model.base_url, - timeout=120.0, - ) - except Exception as e: - logger.error(f"Schedule {schedule_id}: Failed to create LLM client: {e}") - return + user_prompt = f"[自动调度任务] {instruction}" - # Tool-calling loop (max 50 rounds for scheduled tasks) - reply = "" - for round_i in range(50): - 
try: - response = await client.complete( - messages=messages, - tools=tools_for_llm if tools_for_llm else None, - temperature=0.7, - max_tokens=get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None)), - ) - except LLMError as e: - logger.error(f"Schedule {schedule_id}: LLM error: {e}") - reply = f"(LLM 错误: {e})" - break - except Exception as e: - logger.error(f"Schedule {schedule_id}: LLM call error: {e}") - reply = f"(LLM 调用异常: {str(e)[:200]})" - break - - if response.tool_calls: - # Add assistant message with tool calls - messages.append(LLMMessage( - role="assistant", - content=response.content or None, - tool_calls=[{ - "id": tc["id"], - "type": "function", - "function": tc["function"], - } for tc in response.tool_calls], - reasoning_content=response.reasoning_content, - )) - - for tc in response.tool_calls: - fn = tc["function"] - try: - args = json.loads(fn["arguments"]) if fn.get("arguments") else {} - except Exception: - args = {} - tool_result = await execute_tool(fn["name"], args, agent_id, agent.creator_id) - messages.append(LLMMessage( - role="tool", - tool_call_id=tc["id"], - content=str(tool_result), - )) - else: - reply = response.content or "" - break - else: - reply = "(已达到最大工具调用轮数)" - - await client.close() + # Call LLM with unified failover support + reply = await call_agent_llm_with_tools( + db=db, + agent_id=agent_id, + system_prompt=system_prompt, + user_prompt=user_prompt, + max_rounds=50, + ) # Log activity from app.services.activity_logger import log_activity diff --git a/backend/app/services/task_executor.py b/backend/app/services/task_executor.py index 32df20cd..29f493fd 100644 --- a/backend/app/services/task_executor.py +++ b/backend/app/services/task_executor.py @@ -78,7 +78,7 @@ async def execute_task(task_id: uuid.UUID, agent_id: uuid.UUID) -> None: # Step 3: Build full agent context (same as chat dialog) from app.services.agent_context import build_agent_context - system_prompt = await 
build_agent_context(agent_id, agent_name, agent.role_description or "") + system_prompt = await build_agent_context(agent_id, agent.name, agent.role_description or "") # Add task-execution-specific instructions task_addendum = """ @@ -110,101 +110,29 @@ async def execute_task(task_id: uuid.UUID, agent_id: uuid.UUID) -> None: user_prompt += f"\n任务描述: {task_description}" user_prompt += "\n\n请认真完成此任务,给出详细的执行结果。" - # Step 4: Call LLM with tool loop - from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError - - messages = [ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=user_prompt), - ] - - # Normalize base_url - if not model.base_url: - await _log_error(task_id, f"未配置 {model.provider} 的 API 地址") - if task_type == 'supervision': - await _restore_supervision_status(task_id) - return - - # Create unified LLM client - try: - client = create_llm_client( - provider=model.provider, - api_key=model.api_key_encrypted, - model=model.model, - base_url=model.base_url, - timeout=1200.0, - ) - except Exception as e: - await _log_error(task_id, f"创建 LLM 客户端失败: {e}") - if task_type == 'supervision': - await _restore_supervision_status(task_id) - return - - # Load tools (same as chat dialog) - from app.services.agent_tools import execute_tool, get_agent_tools_for_llm - tools_for_llm = await get_agent_tools_for_llm(agent_id) + # Step 4: Call LLM with unified failover support + from app.services.llm_caller import call_agent_llm_with_tools try: logger.info(f"[TaskExec] Calling LLM with tools for task: {task_title}") - reply = "" - - # Tool-calling loop (max 50 rounds for task execution) - for round_i in range(50): - try: - response = await client.complete( - messages=messages, - tools=tools_for_llm if tools_for_llm else None, - temperature=0.7, - max_tokens=get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None)), - ) - except LLMError as e: - await _log_error(task_id, f"LLM 错误: {e}") - 
if task_type == 'supervision': - await _restore_supervision_status(task_id) - return - except Exception as e: - await _log_error(task_id, f"调用模型失败: {str(e)[:200]}") - if task_type == 'supervision': - await _restore_supervision_status(task_id) - return - if response.tool_calls: - # Add assistant message with tool calls - messages.append(LLMMessage( - role="assistant", - content=response.content or None, - tool_calls=[{ - "id": tc["id"], - "type": "function", - "function": tc["function"], - } for tc in response.tool_calls], - reasoning_content=response.reasoning_content, - )) - - for tc in response.tool_calls: - fn = tc["function"] - tool_name = fn["name"] - raw_args = fn.get("arguments", "{}") - logger.info(f"[TaskExec] Round {round_i+1} calling tool: {tool_name}({json.dumps(raw_args, ensure_ascii=False)[:100]})") - try: - args = json.loads(raw_args) if raw_args else {} - except Exception: - args = {} + reply = await call_agent_llm_with_tools( + db=db, # Use existing session + agent_id=agent_id, + system_prompt=system_prompt, + user_prompt=user_prompt, + max_rounds=50, + ) - tool_result = await execute_tool(tool_name, args, agent_id, creator_id) - messages.append(LLMMessage( - role="tool", - tool_call_id=tc["id"], - content=str(tool_result), - )) - else: - reply = response.content or "" - break - else: - reply = "(已达到最大工具调用轮数)" + if reply.startswith("⚠️") or reply.startswith("[Error]"): + # LLM call failed (both primary and fallback) + await _log_error(task_id, f"LLM 调用失败: {reply}") + if task_type == 'supervision': + await _restore_supervision_status(task_id) + return - await client.close() logger.info(f"[TaskExec] LLM reply: {reply[:80]}") + except Exception as e: error_msg = str(e) or repr(e) logger.error(f"[TaskExec] Error: {error_msg}") From efa59636e3e63cc1546e640494bced1a80a7f811 Mon Sep 17 00:00:00 2001 From: yaojin Date: Sun, 22 Mar 2026 20:52:26 +0800 Subject: [PATCH 3/6] refactor(llm): use Option B wrapper approach for failover Per issue #154 review 
feedback: - Restore call_llm to return error strings (not raise exceptions) - Create call_llm_with_failover wrapper that inspects return values - Add FailoverGuard with idempotency/streaming/once-only checks - Add on_failover callback for user-visible notifications - Extract helper functions for better code organization Benefits: - Zero risk to existing callers (call_llm unchanged) - Incremental migration possible - Guard checks prevent unsafe failovers - User-visible failover notifications --- backend/app/api/websocket.py | 463 +++++++++++++++++++++++------ backend/app/services/llm_caller.py | 19 +- 2 files changed, 377 insertions(+), 105 deletions(-) diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index 9acc25a8..1b846207 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -107,108 +107,167 @@ async def call_llm( on_tool_call=None, on_thinking=None, supports_vision=False, - fallback_model: LLMModel | None = None, ) -> str: - """Call LLM via unified client with function-calling tool loop and failover support. + """Call LLM via unified client with function-calling tool loop. Args: on_chunk: Optional async callback(text: str) for streaming chunks to client. on_thinking: Optional async callback(text: str) for reasoning/thinking content. on_tool_call: Optional async callback(dict) for tool call status updates. - fallback_model: Optional fallback model for runtime failover. + + Returns: + LLM response string, or error message if call fails. 
""" - from app.services.llm_failover import classify_error, FailoverErrorType + return await _call_llm_core( + model, messages, agent_name, role_description, + agent_id, user_id, on_chunk, on_tool_call, on_thinking, supports_vision + ) - async def _call_single(_model: LLMModel) -> str: - """Internal: call a single model without failover.""" - return await _call_llm_core( - _model, messages, agent_name, role_description, - agent_id, user_id, on_chunk, on_tool_call, on_thinking, supports_vision - ) - # Config-level fallback: if no primary, use fallback directly - if model is None and fallback_model is not None: - model = fallback_model - fallback_model = None - if model is None: - return "⚠️ 未配置 LLM 模型" +async def _get_agent_config(agent_id) -> tuple[int, str | None]: + """Get agent config: max_tool_rounds and token limit status.""" + if not agent_id: + return 50, None - # Try primary model try: - return await _call_single(model) - except Exception as e: - error_type = classify_error(e) - error_msg = str(e) or repr(e) - logger.warning(f"[call_llm] Primary failed ({error_type.value}): {error_msg[:150]}") + from app.models.agent import Agent as AgentModel + async with async_session() as _db: + _ar = await _db.execute(select(AgentModel).where(AgentModel.id == agent_id)) + _agent = _ar.scalar_one_or_none() + if _agent: + max_rounds = _agent.max_tool_rounds or 50 + if _agent.max_tokens_per_day and _agent.tokens_used_today >= _agent.max_tokens_per_day: + return max_rounds, f"⚠️ Daily token usage has reached the limit ({_agent.tokens_used_today:,}/{_agent.max_tokens_per_day:,}). Please try again tomorrow or ask admin to increase the limit." + if _agent.max_tokens_per_month and _agent.tokens_used_month >= _agent.max_tokens_per_month: + return max_rounds, f"⚠️ Monthly token usage has reached the limit ({_agent.tokens_used_month:,}/{_agent.max_tokens_per_month:,}). Please ask admin to increase the limit." 
+ return max_rounds, None + except Exception: + pass + return 50, None + - # Non-retryable: don't attempt fallback - if error_type == FailoverErrorType.NON_RETRYABLE: - return f"[LLM Error] {error_msg}" +async def _get_user_name(user_id) -> str | None: + """Get user's display name for personalized context.""" + if not user_id: + return None + try: + from app.models.user import User as _UserModel + async with async_session() as _udb: + _ur = await _udb.execute(select(_UserModel).where(_UserModel.id == user_id)) + _u = _ur.scalar_one_or_none() + if _u: + return _u.display_name or _u.username + except Exception: + pass + return None - # No fallback available - if fallback_model is None: - return f"[LLM Error] {error_msg}" - # Runtime fallback: retry with fallback model - logger.info(f"[call_llm] Retrying with fallback: {fallback_model.provider}/{fallback_model.model}") - try: - return await _call_single(fallback_model) - except Exception as e2: - error_msg2 = str(e2) or repr(e2) - logger.error(f"[call_llm] Fallback also failed: {error_msg2[:150]}") - return f"⚠️ 调用模型出错: Primary: {error_msg[:80]} | Fallback: {error_msg2[:80]}" +def _convert_messages_for_vision( + api_messages: list, supports_vision: bool +) -> list: + """Convert image markers to vision format if supported, or strip them.""" + import re as _re_v + + if supports_vision: + # Vision format: convert image markers to OpenAI Vision API format + for i, msg in enumerate(api_messages): + if msg.role != "user" or not msg.content or not isinstance(msg.content, str): + continue + content_str = msg.content + pattern = r'\[image_data:(data:image/[^;]+;base64,[A-Za-z0-9+/=]+)\]' + images = _re_v.findall(pattern, content_str) + if not images: + continue + text = _re_v.sub(pattern, '', content_str).strip() + parts = [{"type": "image_url", "image_url": {"url": img}} for img in images] + if text: + parts.append({"type": "text", "text": text}) + api_messages[i] = type(msg)(role=msg.role, content=parts) + else: + # Strip 
base64 markers for non-vision models + _img_pattern = r'\[image_data:data:image/[^;]+;base64,[A-Za-z0-9+/=]+\]' + for i, msg in enumerate(api_messages): + if msg.role != "user" or not isinstance(msg.content, str): + continue + if "[image_data:" in msg.content: + _n_imgs = len(_re_v.findall(_img_pattern, msg.content)) + cleaned = _re_v.sub(_img_pattern, '', msg.content).strip() + if _n_imgs > 0: + cleaned += f"\n[用户发送了 {_n_imgs} 张图片,但当前模型不支持视觉,无法查看图片内容]" + api_messages[i] = type(msg)(role=msg.role, content=cleaned) + return api_messages +def _check_tool_requires_args(tool_name: str, args: dict) -> tuple[bool, str]: + """Check if tool requires arguments and return (should_execute, result_or_error).""" + _TOOLS_REQUIRING_ARGS = {"write_file", "read_file", "delete_file", "read_document", "send_message_to_agent", "send_feishu_message", "send_email"} + if not args and tool_name in _TOOLS_REQUIRING_ARGS: + return False, f"Error: {tool_name} was called with empty arguments. You must provide the required parameters. Please retry with the correct arguments." 
+ return True, "" -async def _call_llm_core( - model: LLMModel, - messages: list[dict], - agent_name: str, - role_description: str, - agent_id=None, - user_id=None, - on_chunk=None, - on_tool_call=None, - on_thinking=None, - supports_vision=False, + +async def _process_tool_call( + tc: dict, + api_messages: list, + agent_id, + user_id, + on_tool_call, + full_reasoning_content: str, ) -> str: - """Core LLM call implementation (single model, no failover).""" - from app.services.agent_tools import AGENT_TOOLS, execute_tool, get_agent_tools_for_llm - from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError + """Process a single tool call and return result.""" + from app.services.agent_tools import execute_tool + import json - # ── Token limit check & config ── - _max_tool_rounds = 50 # default - if agent_id: + fn = tc["function"] + tool_name = fn["name"] + raw_args = fn.get("arguments", "{}") + logger.info(f"[LLM] Calling tool: {tool_name}({json.dumps(raw_args, ensure_ascii=False)[:100]})") + + try: + args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError: + args = {} + + # Guard: check if tool requires arguments + should_execute, error_msg = _check_tool_requires_args(tool_name, args) + if not should_execute: + return error_msg + + # Notify client about tool call (in-progress) + if on_tool_call: try: - from app.models.agent import Agent as AgentModel - async with async_session() as _db: - _ar = await _db.execute(select(AgentModel).where(AgentModel.id == agent_id)) - _agent = _ar.scalar_one_or_none() - if _agent: - _max_tool_rounds = _agent.max_tool_rounds or 50 - if _agent.max_tokens_per_day and _agent.tokens_used_today >= _agent.max_tokens_per_day: - return f"⚠️ Daily token usage has reached the limit ({_agent.tokens_used_today:,}/{_agent.max_tokens_per_day:,}). Please try again tomorrow or ask admin to increase the limit." 
- if _agent.max_tokens_per_month and _agent.tokens_used_month >= _agent.max_tokens_per_month: - return f"⚠️ Monthly token usage has reached the limit ({_agent.tokens_used_month:,}/{_agent.max_tokens_per_month:,}). Please ask admin to increase the limit." + await on_tool_call({ + "name": tool_name, + "args": args, + "status": "running", + "reasoning_content": full_reasoning_content + }) except Exception: pass - # Build rich prompt with soul, memory, skills, relationships - from app.services.agent_context import build_agent_context - # Look up current user's display name so the agent knows who it's talking to - _current_user_name = None - if user_id: + # Execute tool + result = await execute_tool( + tool_name, args, + agent_id=agent_id, + user_id=user_id or agent_id, + ) + logger.debug(f"[LLM] Tool result: {result[:100]}") + + # Notify client about tool call result + if on_tool_call: try: - from app.models.user import User as _UserModel - async with async_session() as _udb: - _ur = await _udb.execute(select(_UserModel).where(_UserModel.id == user_id)) - _u = _ur.scalar_one_or_none() - if _u: - _current_user_name = _u.display_name or _u.username + await on_tool_call({ + "name": tool_name, + "args": args, + "status": "done", + "result": result, + "reasoning_content": full_reasoning_content + }) except Exception: pass - system_prompt = await build_agent_context(agent_id, agent_name, role_description, current_user_name=_current_user_name) + + return str(result) # Load tools dynamically from DB tools_for_llm = await get_agent_tools_for_llm(agent_id) if agent_id else AGENT_TOOLS @@ -276,7 +335,7 @@ async def _call_llm_core( timeout=120.0, ) except Exception as e: - raise LLMError(f"Failed to create LLM client: {e}") + return f"[Error] Failed to create LLM client: {e}" max_tokens = get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None)) @@ -317,7 +376,7 @@ async def _call_llm_core( on_thinking=on_thinking, ) except LLMError as e: - # Record 
accumulated tokens before raising + # Record accumulated tokens before returning error logger.error( f"[LLM] LLMError provider={getattr(model, 'provider', '?')} " f"model={getattr(model, 'model', '?')} round={round_i + 1}: {e}" @@ -325,7 +384,7 @@ async def _call_llm_core( if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) await client.close() - raise # Re-raise for failover handling + return f"[LLM Error] {e}" except Exception as e: logger.error( f"[LLM] Unexpected error provider={getattr(model, 'provider', '?')} " @@ -335,7 +394,7 @@ async def _call_llm_core( if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) await client.close() - raise # Re-raise for failover handling + return f"[LLM call error] {type(e).__name__}: {str(e)[:200]}" # ── Track tokens for this round ── real_tokens = extract_usage_tokens(response.usage) @@ -438,7 +497,7 @@ async def _call_llm_core( if agent_id and _accumulated_tokens > 0: await record_token_usage(agent_id, _accumulated_tokens) await client.close() - raise LLMError("Too many tool call rounds") + return "[Error] Too many tool call rounds" @router.websocket("/ws/chat/{agent_id}") @@ -786,20 +845,27 @@ async def thinking_to_ws(text: str): import asyncio as _aio - # Run call_llm as a cancellable task - llm_task = _aio.create_task(call_llm( - llm_model, - conversation[-ctx_size:], - agent_name, - role_description, - agent_id=agent_id, - user_id=user_id, - on_chunk=stream_to_ws, - on_tool_call=tool_call_to_ws, - on_thinking=thinking_to_ws, - supports_vision=getattr(llm_model, 'supports_vision', False), - fallback_model=fallback_llm_model, - )) + # Run call_llm_with_failover as a cancellable task + async def _call_with_failover(): + async def _on_failover(reason: str): + await websocket.send_json({"type": "info", "content": f"Primary model error, {reason}"}) + + return await call_llm_with_failover( + primary_model=llm_model, + 
fallback_model=fallback_llm_model, + messages=conversation[-ctx_size:], + agent_name=agent_name, + role_description=role_description, + agent_id=agent_id, + user_id=user_id, + on_chunk=stream_to_ws, + on_tool_call=tool_call_to_ws, + on_thinking=thinking_to_ws, + supports_vision=getattr(llm_model, 'supports_vision', False), + on_failover=_on_failover, + ) + + llm_task = _aio.create_task(_call_with_failover()) # Listen for abort while LLM is running aborted = False @@ -865,8 +931,7 @@ async def thinking_to_ws(text: str): logger.error(f"[WS] LLM error: {e}") import traceback traceback.print_exc() - # call_llm now handles failover internally, just return the error message - assistant_response = str(e) if str(e) else "[LLM call error]" + assistant_response = f"[LLM call error] {str(e)[:200]}" else: assistant_response = f"⚠️ {agent_name} has no LLM model configured. Please select a model in the agent's Settings tab." @@ -933,3 +998,209 @@ async def thinking_to_ws(text: str): await websocket.close(code=1011) except Exception: pass + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Unified Failover Wrapper (Option B implementation per issue #154) +# ═══════════════════════════════════════════════════════════════════════════════ + +class FailoverGuard: + """Guard state for failover decisions.""" + + def __init__(self): + self.tool_executed = False + self.streaming_started = False + self.failover_done = False + + def mark_tool_executed(self): + """Mark that a side-effecting tool has been executed.""" + self.tool_executed = True + + def mark_streaming_started(self): + """Mark that streaming output has started.""" + self.streaming_started = True + + def mark_failover_done(self): + """Mark that failover has already happened once.""" + self.failover_done = True + + def can_failover(self) -> bool: + """Check if failover is allowed based on guard rules.""" + if self.failover_done: + return False # Only failover once + if self.tool_executed: + 
return False # Don't failover after side effects + if self.streaming_started: + return False # Don't failover after streaming started + return True + + +def is_retryable_error(result: str) -> bool: + """Check if an error result is retryable (network, timeout, 429, 5xx). + + Non-retryable: auth errors (401, 403), validation (400, 422), content policy + Retryable: timeout, connection, 429, 5xx, transient errors + """ + if not (result.startswith("[LLM Error]") or result.startswith("[LLM call error]") or result.startswith("[Error]")): + return False + + result_lower = result.lower() + + # Non-retryable: authentication and authorization + if any(kw in result_lower for kw in ["auth", "unauthorized", "forbidden", "invalid api key", "api key invalid", "401", "403"]): + return False + + # Non-retryable: validation and schema + if any(kw in result_lower for kw in ["validation", "invalid request", "schema", "bad request", "400", "422"]): + return False + + # Non-retryable: content policy + if any(kw in result_lower for kw in ["content policy", "content_filter", "safety", "moderation"]): + return False + + # Retryable by default (any other error is potentially retryable) + return True + + +async def call_llm_with_failover( + primary_model: LLMModel, + fallback_model: LLMModel | None, + messages: list[dict], + agent_name: str, + role_description: str, + agent_id=None, + user_id=None, + on_chunk=None, + on_thinking=None, + on_tool_call=None, + supports_vision=False, + on_failover=None, +) -> str: + """Call LLM with automatic failover support (wrapper approach). + + This is the unified entry point for all LLM calls with failover. 
+ Implements Option B from issue #154 review: + - Inspects return values for retryable errors + - Applies guard checks before failover + - Notifies caller when failover happens + + Args: + primary_model: Primary LLM model + fallback_model: Fallback LLM model (can be None) + messages: Conversation messages + agent_name: Agent display name + role_description: Agent role description + agent_id: Optional agent UUID + user_id: Optional user UUID + on_chunk: Optional streaming callback + on_thinking: Optional thinking callback + on_tool_call: Optional tool call callback + supports_vision: Whether model supports vision + on_failover: Optional callback(reason: str) called when failover happens + + Returns: + LLM response string (from primary or fallback) + """ + from app.services.agent_tools import execute_tool, get_agent_tools_for_llm + from app.services.llm_utils import create_llm_client, get_max_tokens, LLMMessage, LLMError + + guard = FailoverGuard() + + # Config-level fallback: if no primary, use fallback directly + if primary_model is None and fallback_model is not None: + logger.info("[Failover] Primary model not configured, using fallback directly") + primary_model = fallback_model + fallback_model = None + + if primary_model is None: + return "⚠️ 未配置 LLM 模型" + + # Wrapper callbacks to track state for guard checks + async def _wrapped_on_chunk(text: str): + guard.mark_streaming_started() + if on_chunk: + await on_chunk(text) + + async def _wrapped_on_tool_call(data: dict): + if data.get("status") == "done": + guard.mark_tool_executed() + if on_tool_call: + await on_tool_call(data) + + # Try primary model + primary_result = await call_llm( + primary_model, + messages, + agent_name, + role_description, + agent_id=agent_id, + user_id=user_id, + on_chunk=_wrapped_on_chunk, + on_tool_call=_wrapped_on_tool_call, + on_thinking=on_thinking, + supports_vision=supports_vision, + ) + + # Check if we need to failover + if not is_retryable_error(primary_result): + return 
primary_result + + # Check guard conditions + if not guard.can_failover(): + if guard.tool_executed: + logger.warning("[Failover] Blocked: side-effecting tool already executed") + elif guard.streaming_started: + logger.warning("[Failover] Blocked: streaming already started") + elif guard.failover_done: + logger.warning("[Failover] Blocked: failover already done once") + return primary_result + + # No fallback available + if fallback_model is None: + logger.warning("[Failover] No fallback model available") + return primary_result + + # Runtime failover: retry with fallback model + logger.info(f"[Failover] Retrying with fallback model: {fallback_model.provider}/{fallback_model.model}") + + if on_failover: + try: + await on_failover(f"Switched to fallback model: {fallback_model.model}") + except Exception: + pass + + guard.mark_failover_done() + + # Call fallback with fresh callbacks (streaming/tool state is per-call) + fallback_guard = FailoverGuard() + fallback_guard.mark_failover_done() # Don't failover again + + async def _fallback_on_chunk(text: str): + fallback_guard.mark_streaming_started() + if on_chunk: + await on_chunk(text) + + async def _fallback_on_tool_call(data: dict): + if data.get("status") == "done": + fallback_guard.mark_tool_executed() + if on_tool_call: + await on_tool_call(data) + + fallback_result = await call_llm( + fallback_model, + messages, + agent_name, + role_description, + agent_id=agent_id, + user_id=user_id, + on_chunk=_fallback_on_chunk, + on_tool_call=_fallback_on_tool_call, + on_thinking=on_thinking, + supports_vision=getattr(fallback_model, 'supports_vision', False), + ) + + # Combine error messages if fallback also failed + if is_retryable_error(fallback_result) or fallback_result.startswith("⚠️") or fallback_result.startswith("[Error]"): + return f"⚠️ 调用模型出错: Primary: {primary_result[:80]} | Fallback: {fallback_result[:80]}" + + return fallback_result diff --git a/backend/app/services/llm_caller.py 
b/backend/app/services/llm_caller.py index 20d92d8d..f352030b 100644 --- a/backend/app/services/llm_caller.py +++ b/backend/app/services/llm_caller.py @@ -95,23 +95,24 @@ async def call_agent_llm( messages.extend(history[-10:]) messages.append({"role": "user", "content": user_text}) - # Use unified call_llm with failover + # Use unified call_llm_with_failover + from app.api.websocket import call_llm_with_failover try: - reply = await call_llm( - primary_model, - messages, - agent.name, - agent.role_description or "", + reply = await call_llm_with_failover( + primary_model=primary_model, + fallback_model=fallback_model, + messages=messages, + agent_name=agent.name, + role_description=agent.role_description or "", agent_id=agent_id, user_id=user_id or agent_id, - supports_vision=supports_vision or getattr(primary_model, 'supports_vision', False), on_chunk=on_chunk, on_thinking=on_thinking, - fallback_model=fallback_model, + supports_vision=supports_vision or getattr(primary_model, 'supports_vision', False), ) return reply except Exception as e: - # call_llm should handle failover internally, but catch any unexpected errors + # call_llm_with_failover should handle failover internally, but catch any unexpected errors error_msg = str(e) or repr(e) logger.error(f"[call_agent_llm] Unexpected error: {error_msg}") return f"⚠️ 调用模型出错: {error_msg[:150]}" From 89689761b4ddf27d5120331d05ab564c458d02a8 Mon Sep 17 00:00:00 2001 From: yaojin Date: Tue, 24 Mar 2026 00:32:47 +0800 Subject: [PATCH 4/6] fix: increase api_key length for Minimax & unify LLM failover - Increase api_key_encrypted column to 1024 chars for Minimax support (#164) - Add unified failover policy across all LLM execution paths (#154) - Add temperature configuration UI for LLM models - Update Docker and restart scripts for development --- .../versions/increase_api_key_length.py | 2 +- backend/app/api/tenants.py | 14 +- backend/app/services/llm_failover.py | 179 +----------------- docker-compose.yml | 4 +- 
frontend/Dockerfile | 3 +- frontend/src/pages/EnterpriseSettings.tsx | 24 ++- restart.sh | 23 ++- 7 files changed, 61 insertions(+), 188 deletions(-) diff --git a/backend/alembic/versions/increase_api_key_length.py b/backend/alembic/versions/increase_api_key_length.py index e3fe5765..95d05171 100644 --- a/backend/alembic/versions/increase_api_key_length.py +++ b/backend/alembic/versions/increase_api_key_length.py @@ -17,7 +17,7 @@ def upgrade() -> None: - # Increase api_key_encrypted column length from 500 to 2000 + # Increase api_key_encrypted column length from 500 to 1024 # Minimax API keys are very long and exceed the previous 500 char limit op.execute(""" ALTER TABLE llm_models diff --git a/backend/app/api/tenants.py b/backend/app/api/tenants.py index 60625c1e..3a16f297 100644 --- a/backend/app/api/tenants.py +++ b/backend/app/api/tenants.py @@ -209,8 +209,11 @@ async def get_tenant( """Get tenant details. Platform admins can view any; org_admins only their own.""" if current_user.role not in ("platform_admin", "org_admin"): raise HTTPException(status_code=403, detail="Admin access required") - if current_user.role == "org_admin" and str(current_user.tenant_id) != str(tenant_id): - raise HTTPException(status_code=403, detail="Access denied") + if current_user.role == "org_admin": + if not current_user.tenant_id: + raise HTTPException(status_code=403, detail="Organization admin must belong to a company") + if current_user.tenant_id != tenant_id: + raise HTTPException(status_code=403, detail="Access denied") result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) tenant = result.scalar_one_or_none() if not tenant: @@ -226,8 +229,11 @@ async def update_tenant( db: AsyncSession = Depends(get_db), ): """Update tenant settings. 
Platform admins can update any; org_admins only their own.""" - if current_user.role == "org_admin" and str(current_user.tenant_id) != str(tenant_id): - raise HTTPException(status_code=403, detail="Can only update your own company") + if current_user.role == "org_admin": + if not current_user.tenant_id: + raise HTTPException(status_code=403, detail="Organization admin must belong to a company") + if current_user.tenant_id != tenant_id: + raise HTTPException(status_code=403, detail="Can only update your own company") result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) tenant = result.scalar_one_or_none() if not tenant: diff --git a/backend/app/services/llm_failover.py b/backend/app/services/llm_failover.py index c619542a..2f3e287b 100644 --- a/backend/app/services/llm_failover.py +++ b/backend/app/services/llm_failover.py @@ -1,24 +1,13 @@ -"""Unified LLM failover executor for all execution paths. - -Provides a shared failover policy across chat/channel/background paths: -1. Try primary if available -2. If primary missing/unavailable, use fallback directly -3. If primary fails with retryable error, retry once on fallback -4. If error is non-retryable (auth/validation/schema), do not switch -5. Max attempts per request: 2 (primary + fallback) +"""Unified LLM failover error classification. + +Provides error classification for failover decisions across all execution paths. 
""" from __future__ import annotations -import asyncio -from dataclasses import dataclass from enum import Enum -from typing import Awaitable, Callable, TypeVar - -from loguru import logger -from app.services.llm_client import LLMError, LLMMessage, LLMResponse -from app.services.llm_utils import create_llm_client, get_max_tokens +from app.services.llm_client import LLMError class FailoverErrorType(Enum): @@ -29,20 +18,6 @@ class FailoverErrorType(Enum): UNKNOWN = "unknown" -@dataclass -class FailoverResult: - """Result of a failover invocation.""" - - content: str - success: bool - model_used: str # "primary" or "fallback" - error: str | None = None - - -# Type variable for the invoke function return type -T = TypeVar("T") - - def classify_error(error: Exception) -> FailoverErrorType: """Classify an exception as retryable or non-retryable. @@ -59,7 +34,6 @@ def classify_error(error: Exception) -> FailoverErrorType: - Content policy violations """ error_msg = str(error).lower() - error_type = type(error).__name__.lower() # Non-retryable: authentication and authorization if any(kw in error_msg for kw in ["auth", "unauthorized", "forbidden", "invalid api key", "api key invalid"]): @@ -100,150 +74,7 @@ def classify_error(error: Exception) -> FailoverErrorType: return FailoverErrorType.UNKNOWN -async def invoke_with_failover( - primary_model, - fallback_model, - invoke_fn: Callable[..., Awaitable[T]], - *args, - **kwargs, -) -> tuple[T | None, str, str | None]: - """Invoke LLM with automatic failover from primary to fallback. 
- - Args: - primary_model: The primary LLM model config (can be None) - fallback_model: The fallback LLM model config (can be None) - invoke_fn: Async function to call the LLM (e.g., client.complete) - *args, **kwargs: Arguments to pass to invoke_fn - - Returns: - Tuple of (result, model_used, error) - - result: The LLM response or None if both failed - - model_used: "primary", "fallback", or "none" - - error: Error message if both failed, None otherwise - """ - # Config-level fallback: if no primary, use fallback directly - if primary_model is None and fallback_model is not None: - logger.info("[Failover] Primary model not configured, using fallback directly") - primary_model = fallback_model - fallback_model = None - - if primary_model is None: - return None, "none", "No LLM model configured (primary or fallback)" - - # Try primary model - try: - logger.debug(f"[Failover] Invoking primary model: {primary_model.provider}/{primary_model.model}") - result = await invoke_fn(*args, **kwargs) - return result, "primary", None - except Exception as e: - error_type = classify_error(e) - error_msg = str(e) or repr(e) - - logger.warning( - f"[Failover] Primary model failed ({error_type.value}): {error_msg[:150]}" - ) - - # Non-retryable errors: don't attempt fallback - if error_type == FailoverErrorType.NON_RETRYABLE: - logger.info("[Failover] Non-retryable error, not attempting fallback") - return None, "none", f"Primary failed (non-retryable): {error_msg}" - - # No fallback available - if fallback_model is None: - logger.warning("[Failover] No fallback model available") - return None, "none", f"Primary failed: {error_msg}" - - # Runtime fallback: retry with fallback model - logger.info(f"[Failover] Retrying with fallback model: {fallback_model.provider}/{fallback_model.model}") - - try: - # Update kwargs with fallback model if needed - if "model" in kwargs: - kwargs["model"] = fallback_model - - result = await invoke_fn(*args, **kwargs) - logger.info("[Failover] Fallback 
model succeeded") - return result, "fallback", None - - except Exception as e2: - error_msg2 = str(e2) or repr(e2) - logger.error(f"[Failover] Fallback model also failed: {error_msg2[:150]}") - return None, "none", f"Primary: {error_msg[:80]} | Fallback: {error_msg2[:80]}" - - -async def call_llm_with_failover( - primary_model, - fallback_model, - messages: list[LLMMessage], - tools: list | None = None, - temperature: float = 0.7, - max_tokens: int | None = None, - timeout: float = 120.0, - stream: bool = False, - on_chunk=None, - on_thinking=None, -) -> tuple[LLMResponse | None, str, str | None]: - """Call LLM with automatic failover support. - - This is the unified entry point for all LLM calls with failover. - - Args: - primary_model: Primary LLM model config - fallback_model: Fallback LLM model config - messages: List of LLMMessage - tools: Optional tool definitions - temperature: Sampling temperature - max_tokens: Max output tokens - timeout: Request timeout - stream: Whether to use streaming API - on_chunk: Callback for streaming chunks - on_thinking: Callback for thinking/reasoning content - - Returns: - Tuple of (response, model_used, error) - """ - async def _invoke(model): - client = create_llm_client( - provider=model.provider, - api_key=model.api_key_encrypted, - model=model.model, - base_url=model.base_url, - timeout=timeout, - ) - - _max_tokens = max_tokens or get_max_tokens( - model.provider, model.model, getattr(model, "max_output_tokens", None) - ) - - try: - if stream: - response = await client.stream( - messages=messages, - tools=tools, - temperature=temperature, - max_tokens=_max_tokens, - on_chunk=on_chunk, - on_thinking=on_thinking, - ) - else: - response = await client.complete( - messages=messages, - tools=tools, - temperature=temperature, - max_tokens=_max_tokens, - ) - return response - finally: - await client.close() - - return await invoke_with_failover(primary_model, fallback_model, _invoke, primary_model) - - -# Backward compatibility: 
re-export for convenience __all__ = [ "FailoverErrorType", - "FailoverResult", "classify_error", - "invoke_with_failover", - "call_llm_with_failover", -] +] \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 05320f3b..6f4562d4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -61,7 +61,9 @@ services: max-size: "10m" max-file: "3" frontend: - build: ./frontend + build: + context: . + dockerfile: frontend/Dockerfile restart: unless-stopped ports: - "${FRONTEND_PORT:-3008}:3000" diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 7d8b64e1..52a917a4 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -2,7 +2,8 @@ FROM node:20-alpine AS build WORKDIR /app COPY package*.json ./ RUN npm ci --registry https://registry.npmmirror.com -COPY . . +COPY frontend/. . +COPY VERSION ./VERSION RUN npm run build FROM nginx:alpine diff --git a/frontend/src/pages/EnterpriseSettings.tsx b/frontend/src/pages/EnterpriseSettings.tsx index d3fa2734..e8b8ac0f 100644 --- a/frontend/src/pages/EnterpriseSettings.tsx +++ b/frontend/src/pages/EnterpriseSettings.tsx @@ -2,6 +2,7 @@ import { useState, useEffect, useMemo } from 'react'; import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; import { useTranslation } from 'react-i18next'; import { enterpriseApi, skillApi } from '../services/api'; +import { useAuthStore } from '../stores'; import PromptModal from '../components/PromptModal'; import FileBrowser from '../components/FileBrowser'; import type { FileBrowserApi } from '../components/FileBrowser'; @@ -865,29 +866,34 @@ const COMMON_TIMEZONES = [ function CompanyTimezoneEditor() { const { t } = useTranslation(); - const tenantId = localStorage.getItem('current_tenant_id') || ''; + const user = useAuthStore((s) => s.user); + const tenantId = user?.tenant_id || localStorage.getItem('current_tenant_id') || ''; const [timezone, setTimezone] = useState('UTC'); const [saving, setSaving] = useState(false); const 
[saved, setSaved] = useState(false); + const [error, setError] = useState(''); useEffect(() => { if (!tenantId) return; fetchJson(`/tenants/${tenantId}`) .then(d => { if (d?.timezone) setTimezone(d.timezone); }) - .catch(() => { }); + .catch((e: any) => setError(e.message || 'Failed to load timezone')); }, [tenantId]); const handleSave = async (tz: string) => { if (!tenantId) return; setTimezone(tz); setSaving(true); + setError(''); try { await fetchJson(`/tenants/${tenantId}`, { method: 'PUT', body: JSON.stringify({ timezone: tz }), }); setSaved(true); setTimeout(() => setSaved(false), 2000); - } catch (e) { } + } catch (e: any) { + setError(e.message || 'Failed to save timezone'); + } setSaving(false); }; @@ -899,13 +905,23 @@ function CompanyTimezoneEditor() {
{t('enterprise.timezone.description', 'Default timezone for all agents. Agents can override individually.')}
+ {error && ( +
+ ⚠ {error} +
+ )} + {!tenantId && ( +
+ ⚠ {t('enterprise.timezone.noTenant', 'No company selected. Please refresh the page or contact support.')} +
+ )}