diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index 55b75cb1bd..5190a363ee 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -2,6 +2,10 @@ from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.agent.runners.deerflow.constants import ( + DEERFLOW_PROVIDER_TYPE, + DEERFLOW_THREAD_ID_KEY, +) from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.message_type import MessageType from astrbot.core.utils.active_event_registry import active_event_registry @@ -12,6 +16,7 @@ "dify": "dify_conversation_id", "coze": "coze_conversation_id", "dashscope": "dashscope_conversation_id", + DEERFLOW_PROVIDER_TYPE: DEERFLOW_THREAD_ID_KEY, } THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys()) diff --git a/astrbot/core/agent/runners/deerflow/constants.py b/astrbot/core/agent/runners/deerflow/constants.py new file mode 100644 index 0000000000..687027efe7 --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/constants.py @@ -0,0 +1,4 @@ +DEERFLOW_PROVIDER_TYPE = "deerflow" +DEERFLOW_THREAD_ID_KEY = "deerflow_thread_id" +DEERFLOW_SESSION_PREFIX = "deerflow-ephemeral" +DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY = "deerflow_agent_runner_provider_id" diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py new file mode 100644 index 0000000000..50ec7c8262 --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -0,0 +1,693 @@ +import asyncio +import hashlib +import json +import sys +import typing as T +from collections import deque +from dataclasses import dataclass, field +from uuid import uuid4 + +import astrbot.core.message.components as Comp +from astrbot import 
logger +from astrbot.core import sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) +from astrbot.core.utils.config_number import coerce_int_config + +from ...hooks import BaseAgentRunHooks +from ...response import AgentResponseData +from ...run_context import ContextWrapper, TContext +from ..base import AgentResponse, AgentState, BaseAgentRunner +from .constants import DEERFLOW_SESSION_PREFIX, DEERFLOW_THREAD_ID_KEY +from .deerflow_api_client import DeerFlowAPIClient +from .deerflow_content_mapper import ( + build_chain_from_ai_content, + build_user_content, + image_component_from_url, +) +from .deerflow_stream_utils import ( + build_task_failure_summary, + extract_ai_delta_from_event_data, + extract_clarification_from_event_data, + extract_latest_ai_message, + extract_latest_ai_text, + extract_latest_clarification_text, + extract_messages_from_values_data, + extract_task_failures_from_custom_event, + get_message_id, +) + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DeerFlowAgentRunner(BaseAgentRunner[TContext]): + """DeerFlow Agent Runner via LangGraph HTTP API.""" + + _MAX_VALUES_HISTORY = 200 + + @dataclass(frozen=True) + class _RunnerConfig: + api_base: str + api_key: str + auth_header: str + proxy: str + assistant_id: str + model_name: str + thinking_enabled: bool + plan_mode: bool + subagent_enabled: bool + max_concurrent_subagents: int + timeout: int + recursion_limit: int + + @dataclass + class _StreamState: + latest_text: str = "" + prev_text_for_streaming: str = "" + clarification_text: str = "" + task_failures: list[str] = field(default_factory=list) + seen_message_ids: set[str] = field(default_factory=set) + seen_message_order: deque[str] = field(default_factory=deque) + # Fallback tracking for backends that omit message ids in values events. 
def _format_exception(self, err: Exception) -> str:
    """Render *err* as a one-line ``"<ExceptionType>: <detail>"`` message.

    Timeout errors get a fixed hint that mentions ``self.timeout`` when that
    attribute is present and numeric. Other errors use their stripped
    ``str()`` text, without repeating the type name when the text already
    starts with it. Errors with no text get a generic placeholder.
    """
    kind = type(err).__name__
    message = str(err).strip()

    if isinstance(err, (asyncio.TimeoutError, TimeoutError)):
        # The runner may not have loaded its config yet, so read defensively.
        configured = getattr(self, "timeout", None)
        if isinstance(configured, (int, float)):
            waited = f"{self.timeout}s"
        else:
            waited = "configured timeout"
        return (
            f"{kind}: request timed out after {waited}. "
            "Please check DeerFlow service health and backend logs."
        )

    if not message:
        return f"{kind}: no detailed error message provided."

    if message.startswith(f"{kind}:"):
        # Already prefixed with the type name; do not duplicate it.
        return message
    return f"{kind}: {message}"
AgentResponse: + err_text = f"DeerFlow request failed: {err_msg}" + err_chain = MessageChain().message(err_text) + self.final_llm_resp = LLMResponse( + role="err", + completion_text=err_text, + result_chain=err_chain, + ) + self._transition_state(AgentState.ERROR) + await self._notify_agent_done_hook() + return AgentResponse( + type="err", + data=AgentResponseData( + chain=err_chain, + ), + ) + + def _parse_runner_config(self, provider_config: dict) -> _RunnerConfig: + api_base = provider_config.get("deerflow_api_base", "http://127.0.0.1:2026") + if not isinstance(api_base, str) or not api_base.startswith( + ("http://", "https://"), + ): + raise ValueError( + "DeerFlow API Base URL format is invalid. It must start with http:// or https://.", + ) + + proxy = provider_config.get("proxy", "") + normalized_proxy = proxy.strip() if isinstance(proxy, str) else "" + + return self._RunnerConfig( + api_base=api_base, + api_key=provider_config.get("deerflow_api_key", ""), + auth_header=provider_config.get("deerflow_auth_header", ""), + proxy=normalized_proxy, + assistant_id=provider_config.get("deerflow_assistant_id", "lead_agent"), + model_name=provider_config.get("deerflow_model_name", ""), + thinking_enabled=bool( + provider_config.get("deerflow_thinking_enabled", False), + ), + plan_mode=bool(provider_config.get("deerflow_plan_mode", False)), + subagent_enabled=bool( + provider_config.get("deerflow_subagent_enabled", False), + ), + max_concurrent_subagents=coerce_int_config( + provider_config.get("deerflow_max_concurrent_subagents", 3), + default=3, + min_value=1, + field_name="deerflow_max_concurrent_subagents", + source="DeerFlow config", + ), + timeout=coerce_int_config( + provider_config.get("timeout", 300), + default=300, + min_value=1, + field_name="timeout", + source="DeerFlow config", + ), + recursion_limit=coerce_int_config( + provider_config.get("deerflow_recursion_limit", 1000), + default=1000, + min_value=1, + field_name="deerflow_recursion_limit", + 
source="DeerFlow config", + ), + ) + + async def _load_config_and_client(self, provider_config: dict) -> None: + config = self._parse_runner_config(provider_config) + + self.api_base = config.api_base + self.api_key = config.api_key + self.auth_header = config.auth_header + self.proxy = config.proxy + self.assistant_id = config.assistant_id + self.model_name = config.model_name + self.thinking_enabled = config.thinking_enabled + self.plan_mode = config.plan_mode + self.subagent_enabled = config.subagent_enabled + self.max_concurrent_subagents = config.max_concurrent_subagents + self.timeout = config.timeout + self.recursion_limit = config.recursion_limit + + new_client_signature = ( + config.api_base, + config.api_key, + config.auth_header, + config.proxy, + ) + old_client = getattr(self, "api_client", None) + old_signature = getattr(self, "_api_client_signature", None) + + if ( + isinstance(old_client, DeerFlowAPIClient) + and old_signature == new_client_signature + and not old_client.is_closed + ): + self.api_client = old_client + return + + if isinstance(old_client, DeerFlowAPIClient): + try: + await old_client.close() + except Exception as e: + logger.warning( + f"Failed to close previous DeerFlow API client cleanly: {e}" + ) + + self.api_client = DeerFlowAPIClient( + api_base=config.api_base, + api_key=config.api_key, + auth_header=config.auth_header, + proxy=config.proxy, + ) + self._api_client_signature = new_client_signature + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + await self._load_config_and_client(provider_config) + + @override + async def step(self): + if not self.req: + raise 
ValueError("Request is not set. Please call reset() first.") + if self.done(): + return + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + self._transition_state(AgentState.RUNNING) + + try: + async for response in self._execute_deerflow_request(): + yield response + except asyncio.CancelledError: + # Let caller manage cancellation semantics. + raise + except Exception as e: + err_msg = self._format_exception(e) + logger.error(f"DeerFlow request failed: {err_msg}", exc_info=True) + yield await self._finish_with_error(err_msg) + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + if max_step <= 0: + raise ValueError("max_step must be greater than 0") + + step_count = 0 + while not self.done() and step_count < max_step: + step_count += 1 + async for resp in self.step(): + yield resp + + if not self.done(): + raise RuntimeError( + f"DeerFlow agent reached max_step ({max_step}) without completion." + ) + + def _extract_new_messages_from_values( + self, + values_messages: list[T.Any], + state: _StreamState, + ) -> list[dict[str, T.Any]]: + new_messages: list[dict[str, T.Any]] = [] + no_id_indexes_seen: set[int] = set() + for idx, msg in enumerate(values_messages): + if not isinstance(msg, dict): + continue + msg_id = get_message_id(msg) + if msg_id: + if msg_id in state.seen_message_ids: + continue + self._remember_seen_message_id(state, msg_id) + new_messages.append(msg) + continue + + no_id_indexes_seen.add(idx) + msg_fingerprint = self._fingerprint_message(msg) + if state.no_id_message_fingerprints.get(idx) == msg_fingerprint: + continue + state.no_id_message_fingerprints[idx] = msg_fingerprint + new_messages.append(msg) + + # Keep no-id index state aligned with latest values payload shape. 
def _fingerprint_message(self, message: dict[str, T.Any]) -> str:
    """Return a stable hex digest identifying the content of *message*.

    The message is serialised as key-sorted JSON so that dict ordering does
    not affect the result. Values JSON cannot encode are stringified via
    ``default=str``. If serialisation still fails, ``repr(message)`` is
    hashed instead.

    The digest is only compared against earlier digests to detect changed
    messages; it is not a security primitive.
    """
    try:
        raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str)
    except (TypeError, ValueError):
        # e.g. mixed int/str keys cannot be sorted by json.dumps.
        raw = repr(message)
    # usedforsecurity=False keeps SHA-1 available on FIPS-restricted builds
    # and states the intent; the resulting digest is unchanged.
    return hashlib.sha1(
        raw.encode("utf-8", errors="ignore"),
        usedforsecurity=False,
    ).hexdigest()
def _update_text_and_maybe_stream(
    self,
    *,
    state: _StreamState,
    new_full_text: str | None = None,
    delta_text: str | None = None,
) -> list[AgentResponse]:
    """Record new assistant text on *state* and build streaming responses.

    Args:
        state: Per-run stream state; ``latest_text`` and
            ``prev_text_for_streaming`` are updated in place.
        new_full_text: A complete snapshot of the assistant text. When
            truthy it replaces ``state.latest_text`` and ``delta_text`` is
            ignored.
        delta_text: An incremental fragment appended to
            ``state.latest_text``.

    Returns:
        A list with at most one ``streaming_delta`` response. It is always
        empty when ``self.streaming`` is false or there is nothing new to
        send.
    """
    if new_full_text:
        state.latest_text = new_full_text
        if not self.streaming:
            return []

        # Send only the suffix not yet streamed. If the snapshot does not
        # extend what was sent before, send it in full.
        if new_full_text.startswith(state.prev_text_for_streaming):
            delta = new_full_text[len(state.prev_text_for_streaming) :]
        else:
            delta = new_full_text

        if not delta:
            return []

        state.prev_text_for_streaming = new_full_text
        return [
            AgentResponse(
                type="streaming_delta",
                data=AgentResponseData(chain=MessageChain().message(delta)),
            )
        ]

    if delta_text:
        state.latest_text += delta_text
        if self.streaming:
            # Advance the streaming cursor past the fragment being sent.
            # Without this, a later full-text snapshot is diffed against an
            # empty cursor and the whole text is streamed a second time.
            state.prev_text_for_streaming = state.latest_text
            return [
                AgentResponse(
                    type="streaming_delta",
                    data=AgentResponseData(
                        chain=MessageChain().message(delta_text)
                    ),
                )
            ]

    return []
True + for idx, msg in enumerate(values_messages): + if not isinstance(msg, dict): + continue + new_messages.append(msg) + msg_id = get_message_id(msg) + if msg_id: + self._remember_seen_message_id(state, msg_id) + continue + state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg) + else: + new_messages = self._extract_new_messages_from_values( + values_messages, + state, + ) + latest_text = "" + if new_messages: + state.run_values_messages.extend(new_messages) + if len(state.run_values_messages) > self._MAX_VALUES_HISTORY: + state.run_values_messages = state.run_values_messages[ + -self._MAX_VALUES_HISTORY : + ] + latest_text = extract_latest_ai_text(state.run_values_messages) + if latest_text: + state.has_values_text = True + latest_clarification = extract_latest_clarification_text( + state.run_values_messages, + ) + if latest_clarification: + state.clarification_text = latest_clarification + + responses.extend( + self._update_text_and_maybe_stream( + state=state, + new_full_text=latest_text or None, + ) + ) + return responses + + def _handle_message_event( + self, + data: T.Any, + state: _StreamState, + ) -> AgentResponse | None: + delta = extract_ai_delta_from_event_data(data) + + responses: list[AgentResponse] = [] + if delta and not state.has_values_text: + responses.extend( + self._update_text_and_maybe_stream( + state=state, + delta_text=delta, + ) + ) + + maybe_clarification = extract_clarification_from_event_data(data) + if maybe_clarification: + state.clarification_text = maybe_clarification + return responses[0] if responses else None + + def _build_final_result(self, state: _StreamState) -> _FinalResult: + failures_only = False + + if state.clarification_text: + final_chain = MessageChain(chain=[Comp.Plain(state.clarification_text)]) + else: + final_chain = MessageChain() + latest_ai_message = extract_latest_ai_message(state.run_values_messages) + if latest_ai_message: + final_chain = build_chain_from_ai_content( + 
latest_ai_message.get("content"), + image_component_from_url, + ) + + if not final_chain.chain and state.latest_text: + final_chain = MessageChain(chain=[Comp.Plain(state.latest_text)]) + + if not final_chain.chain: + failure_text = build_task_failure_summary(state.task_failures) + if failure_text: + final_chain = MessageChain(chain=[Comp.Plain(failure_text)]) + failures_only = True + + if not final_chain.chain: + logger.warning("DeerFlow returned no text content in stream events.") + final_chain = MessageChain( + chain=[Comp.Plain("DeerFlow returned an empty response.")], + ) + + if state.timed_out: + timeout_note = ( + f"DeerFlow stream timed out after {self.timeout}s. " + "Returning partial result." + ) + if final_chain.chain and isinstance(final_chain.chain[-1], Comp.Plain): + last_text = final_chain.chain[-1].text + final_chain.chain[-1].text = ( + f"{last_text}\n\n{timeout_note}" if last_text else timeout_note + ) + else: + final_chain.chain.append(Comp.Plain(timeout_note)) + + role = "err" if (state.timed_out or failures_only) else "assistant" + return self._FinalResult(chain=final_chain, role=role) + + def _emit_non_plain_components_at_end( + self, + final_chain: MessageChain, + ) -> AgentResponse | None: + non_plain_components = [ + component + for component in final_chain.chain + if not isinstance(component, Comp.Plain) + ] + if not non_plain_components: + return None + return AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain(chain=non_plain_components), + ), + ) + + async def _execute_deerflow_request(self): + prompt = self.req.prompt or "" + session_id = self.req.session_id or f"{DEERFLOW_SESSION_PREFIX}-{uuid4()}" + image_urls = self.req.image_urls or [] + system_prompt = self.req.system_prompt + + thread_id = await self._ensure_thread_id(session_id) + payload = self._build_payload( + thread_id=thread_id, + prompt=prompt, + image_urls=image_urls, + system_prompt=system_prompt, + ) + state = self._StreamState() + + 
try: + async for event in self.api_client.stream_run( + thread_id=thread_id, + payload=payload, + timeout=self.timeout, + ): + event_type = event.get("event") + data = event.get("data") + + if event_type == "values": + for response in self._handle_values_event(data, state): + yield response + continue + + if event_type in {"messages-tuple", "messages", "message"}: + response = self._handle_message_event(data, state) + if response: + yield response + continue + + if event_type == "custom": + state.task_failures.extend( + extract_task_failures_from_custom_event(data), + ) + continue + + if event_type == "error": + raise Exception(f"DeerFlow stream returned error event: {data}") + + if event_type == "end": + break + except (asyncio.TimeoutError, TimeoutError): + logger.warning( + "DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.", + self.timeout, + thread_id, + ) + state.timed_out = True + + final_result = self._build_final_result(state) + + if self.streaming: + extra_response = self._emit_non_plain_components_at_end(final_result.chain) + if extra_response: + yield extra_response + + yield await self._finish_with_result(final_result.chain, final_result.role) + + @override + def done(self) -> bool: + """Check whether the agent has finished or failed.""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py new file mode 100644 index 0000000000..37a23f2432 --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -0,0 +1,245 @@ +import codecs +import json +from collections.abc import AsyncGenerator +from typing import Any + +from aiohttp import ClientResponse, ClientSession, ClientTimeout + +from astrbot.core import logger + +SSE_MAX_BUFFER_CHARS = 1_048_576 + + +def 
def _normalize_sse_newlines(text: str) -> str:
    """Normalize CRLF/CR to LF so SSE block splitting works reliably."""
    return text.replace("\r\n", "\n").replace("\r", "\n")


def _hold_back_trailing_cr(text: str) -> tuple[str, str]:
    """Split a trailing CR off *text* so it can be joined with the next chunk.

    A CRLF pair can be cut in half by a network chunk boundary. Normalising
    the two halves separately would turn one line ending into two and create
    a false blank-line event delimiter.

    Returns:
        ``(ready, carry)``: ``ready`` is safe to normalise now and ``carry``
        is either ``""`` or the withheld ``"\\r"``.
    """
    if text.endswith("\r"):
        return text[:-1], "\r"
    return text, ""


def _parse_sse_data_lines(data_lines: list[str]) -> Any:
    """Decode the ``data:`` lines of one SSE event.

    The lines are first joined with LF and parsed as a single JSON document.
    If that fails, each non-blank line is parsed as its own JSON document:
    one result is returned as-is, several are returned as a list. If any
    line is not valid JSON, the joined raw string is returned unchanged.
    """
    raw_data = "\n".join(data_lines)
    try:
        return json.loads(raw_data)
    except json.JSONDecodeError:
        # Some LangGraph-compatible servers emit multiple JSON fragments
        # in one SSE event using repeated data lines (e.g. tuple payloads).
        parsed_lines: list[Any] = []
        can_parse_all = True
        for line in data_lines:
            line = line.strip()
            if not line:
                continue
            try:
                parsed_lines.append(json.loads(line))
            except json.JSONDecodeError:
                can_parse_all = False
                break
        if can_parse_all and parsed_lines:
            return parsed_lines[0] if len(parsed_lines) == 1 else parsed_lines
        return raw_data


def _parse_sse_block(block: str) -> dict[str, Any] | None:
    """Parse one SSE block into ``{"event": ..., "data": ...}``.

    Returns ``None`` for blank blocks and for blocks with no ``data:`` line.
    The event name defaults to ``"message"`` when no ``event:`` line exists.
    """
    if not block.strip():
        return None

    event_name = "message"
    data_lines: list[str] = []
    # Split on CRLF/LF/CR only. str.splitlines() would also break on
    # U+2028/U+2029/U+0085, which may appear unescaped inside JSON strings
    # and must stay part of the payload.
    for line in _normalize_sse_newlines(block).split("\n"):
        if line.startswith("event:"):
            event_name = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].lstrip())

    if not data_lines:
        return None
    return {"event": event_name, "data": _parse_sse_data_lines(data_lines)}


async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict[str, Any], None]:
    """Parse SSE response blocks into event/data dictionaries.

    Chunks are read from ``resp.content.iter_chunked(8192)``, decoded
    incrementally as UTF-8 and split on blank lines. An undelimited buffer
    longer than ``SSE_MAX_BUFFER_CHARS`` is flushed as one block so memory
    use stays bounded.
    """
    # Use a forgiving decoder at network boundaries so malformed bytes do not abort stream parsing.
    decoder = codecs.getincrementaldecoder("utf-8")("replace")
    buffer = ""
    # A CR withheld from the previous chunk until we know whether an LF follows.
    carry = ""

    async for chunk in resp.content.iter_chunked(8192):
        text, carry = _hold_back_trailing_cr(carry + decoder.decode(chunk))
        buffer += _normalize_sse_newlines(text)

        while "\n\n" in buffer:
            block, buffer = buffer.split("\n\n", 1)
            parsed = _parse_sse_block(block)
            if parsed is not None:
                yield parsed

        if len(buffer) > SSE_MAX_BUFFER_CHARS:
            logger.warning(
                "DeerFlow SSE parser buffer exceeded %d chars without delimiter; "
                "flushing oversized block to prevent unbounded memory growth.",
                SSE_MAX_BUFFER_CHARS,
            )
            parsed = _parse_sse_block(buffer)
            if parsed is not None:
                yield parsed
            buffer = ""

    # flush any remaining buffered text, including a CR still being withheld
    buffer += _normalize_sse_newlines(carry + decoder.decode(b"", final=True))
    while "\n\n" in buffer:
        block, buffer = buffer.split("\n\n", 1)
        parsed = _parse_sse_block(block)
        if parsed is not None:
            yield parsed

    if buffer.strip():
        parsed = _parse_sse_block(buffer)
        if parsed is not None:
            yield parsed
+ """ + + def __init__( + self, + api_base: str = "http://127.0.0.1:2026", + api_key: str = "", + auth_header: str = "", + proxy: str | None = None, + ) -> None: + self.api_base = api_base.rstrip("/") + self._session: ClientSession | None = None + self._closed = False + self.proxy = proxy.strip() if isinstance(proxy, str) else None + if self.proxy == "": + self.proxy = None + self.headers: dict[str, str] = {} + if auth_header: + self.headers["Authorization"] = auth_header + elif api_key: + self.headers["Authorization"] = f"Bearer {api_key}" + + def _get_session(self) -> ClientSession: + if self._closed: + raise RuntimeError("DeerFlowAPIClient is already closed.") + if self._session is None or self._session.closed: + self._session = ClientSession(trust_env=True) + return self._session + + async def __aenter__(self) -> "DeerFlowAPIClient": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object | None, + ) -> None: + await self.close() + + async def create_thread(self, timeout: float = 20) -> dict[str, Any]: + session = self._get_session() + url = f"{self.api_base}/api/langgraph/threads" + payload = {"metadata": {}} + async with session.post( + url, + json=payload, + headers=self.headers, + timeout=timeout, + proxy=self.proxy, + ) as resp: + if resp.status not in (200, 201): + text = await resp.text() + raise Exception( + f"DeerFlow create thread failed: {resp.status}. 
{text}", + ) + return await resp.json() + + async def stream_run( + self, + thread_id: str, + payload: dict[str, Any], + timeout: float = 120, + ) -> AsyncGenerator[dict[str, Any], None]: + session = self._get_session() + url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" + input_payload = payload.get("input") + message_count = 0 + if isinstance(input_payload, dict) and isinstance( + input_payload.get("messages"), list + ): + message_count = len(input_payload["messages"]) + # Log only a minimal summary to avoid exposing sensitive user content. + logger.debug( + "deerflow stream_run payload summary: thread_id=%s, keys=%s, message_count=%d, stream_mode=%s", + thread_id, + list(payload.keys()), + message_count, + payload.get("stream_mode"), + ) + # For long-running SSE streams, avoid aiohttp total timeout. + # Use socket read timeout so active heartbeats/chunks can keep the stream alive. + stream_timeout = ClientTimeout( + total=None, + connect=min(timeout, 30), + sock_connect=min(timeout, 30), + sock_read=timeout, + ) + async with session.post( + url, + json=payload, + headers={ + **self.headers, + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + timeout=stream_timeout, + proxy=self.proxy, + ) as resp: + if resp.status != 200: + text = await resp.text() + raise Exception( + f"DeerFlow runs/stream request failed: {resp.status}. {text}", + ) + async for event in _stream_sse(resp): + yield event + + async def close(self) -> None: + session = self._session + if session is None: + self._closed = True + return + + if session.closed: + self._session = None + self._closed = True + return + + try: + await session.close() + except Exception as e: + logger.warning( + "Failed to close DeerFlowAPIClient session cleanly: %s", + e, + exc_info=True, + ) + finally: + # Cleanup is best-effort and should not make teardown paths fail loudly. 
def is_likely_base64_image(value: str) -> bool:
    """Return True when *value* looks like a bare base64 payload.

    CR/LF characters are removed first. The remaining text must contain no
    spaces, be at least 32 characters long with a length divisible by four,
    use only the standard base64 alphabet plus ``=``, and decode under
    strict validation.
    """
    if " " in value:
        return False

    payload = value.replace("\r", "").replace("\n", "")
    size = len(payload)
    if size < 32 or size % 4:
        return False

    alphabet = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
    )
    if not alphabet.issuperset(payload):
        return False

    try:
        base64.b64decode(payload, validate=True)
    except Exception:
        return False
    else:
        return True
continue + url = url.strip() + if not url: + skipped_invalid_images += 1 + logger.debug("Skipped DeerFlow image input because value is empty.") + continue + if url.startswith(("http://", "https://", "data:")): + content.append({"type": "image_url", "image_url": {"url": url}}) + any_valid_image = True + continue + if not is_likely_base64_image(url): + skipped_invalid_images += 1 + logger.debug( + "Skipped DeerFlow image input because it is neither URL/data URI nor valid base64." + ) + continue + compact_base64 = url.replace("\n", "").replace("\r", "") + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{compact_base64}"}, + }, + ) + any_valid_image = True + + if skipped_invalid_images: + note_text = ( + "Note: some images could not be processed and were ignored." + if any_valid_image + else "Note: none of the provided images could be processed." + ) + content.insert(0, {"type": "text", "text": note_text}) + if not any_valid_image: + logger.warning( + "All %d provided DeerFlow image inputs were rejected as invalid or unsupported.", + skipped_invalid_images, + ) + else: + logger.info( + "%d DeerFlow image input(s) were rejected as invalid or unsupported.", + skipped_invalid_images, + ) + logger.debug( + "Skipped %d DeerFlow image inputs that were neither URL/data URI nor valid base64.", + skipped_invalid_images, + ) + return content + + +def image_component_from_url(url: Any) -> Comp.Image | None: + if not isinstance(url, str): + return None + + normalized = url.strip() + if not normalized: + return None + + if normalized.startswith(("http://", "https://")): + try: + return Comp.Image.fromURL(normalized) + except Exception: + return None + + if not normalized.startswith("data:"): + return None + + header, sep, payload = normalized.partition(",") + if not sep: + return None + if ";base64" not in header.lower(): + return None + + compact_payload = payload.replace("\n", "").replace("\r", "").strip() + if not compact_payload: + 
def append_components_from_content(
    content: Any,
    components: list[Comp.BaseMessageComponent],
    image_resolver: Callable[[Any], Comp.Image | None],
) -> None:
    """Walk *content* recursively and append message components in order.

    Handled shapes:

    * non-empty ``str`` becomes a ``Comp.Plain``;
    * ``list`` has each item processed in sequence;
    * ``dict`` with type ``text`` and a string ``text`` becomes a
      ``Comp.Plain`` (skipped when empty);
    * ``dict`` with type ``image_url`` has its ``image_url`` value (or that
      value's ``url`` when it is a dict) passed to *image_resolver*, and a
      non-``None`` result is appended;
    * any other ``dict`` is descended through ``content`` or, failing that,
      ``kwargs["content"]``.

    Everything else is ignored. The ``type`` comparison is case-insensitive.
    """
    if isinstance(content, list):
        for entry in content:
            append_components_from_content(entry, components, image_resolver)
        return

    if isinstance(content, str):
        if content:
            components.append(Comp.Plain(content))
        return

    if not isinstance(content, dict):
        return

    kind = str(content.get("type", "")).lower()

    if kind == "text":
        text_value = content.get("text")
        if isinstance(text_value, str):
            if text_value:
                components.append(Comp.Plain(text_value))
            return
        # A "text" item without string text falls through to the nested
        # content lookups below.
    elif kind == "image_url":
        target = content.get("image_url")
        if isinstance(target, dict):
            target = target.get("url")
        resolved = image_resolver(target)
        if resolved is not None:
            components.append(resolved)
        return

    if "content" in content:
        nested = content.get("content")
    else:
        wrapper = content.get("kwargs")
        if not (isinstance(wrapper, dict) and "content" in wrapper):
            return
        nested = wrapper.get("content")

    append_components_from_content(nested, components, image_resolver)
file mode 100644 index 0000000000..0c8a5bb385 --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py @@ -0,0 +1,201 @@ +import typing as T +from collections.abc import Iterable + + +def extract_text(content: T.Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict): + if isinstance(content.get("text"), str): + return content["text"] + if "content" in content: + return extract_text(content.get("content")) + if "kwargs" in content and isinstance(content["kwargs"], dict): + return extract_text(content["kwargs"].get("content")) + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + item_type = item.get("type") + if item_type == "text" and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif "content" in item: + parts.append(extract_text(item["content"])) + return "\n".join([p for p in parts if p]).strip() + return str(content) if content is not None else "" + + +def extract_messages_from_values_data(data: T.Any) -> list[T.Any]: + """Extract messages list from possible values event payload shapes.""" + candidates: list[T.Any] = [] + if isinstance(data, dict): + candidates.append(data) + if isinstance(data.get("values"), dict): + candidates.append(data["values"]) + elif isinstance(data, list): + candidates.extend([x for x in data if isinstance(x, dict)]) + + for item in candidates: + messages = item.get("messages") + if isinstance(messages, list): + return messages + return [] + + +def is_ai_message(message: dict[str, T.Any]) -> bool: + role = str(message.get("role", "")).lower() + if role in {"assistant", "ai"}: + return True + + msg_type = str(message.get("type", "")).lower() + if msg_type in {"ai", "assistant", "aimessage", "aimessagechunk"}: + return True + if "ai" in msg_type and all( + token not in msg_type for token in ("human", "tool", "system") + ): + return True + return False + 
+ +def extract_latest_ai_text(messages: Iterable[T.Any]) -> str: + # Scan backwards to get the latest assistant/ai message text. + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + # Fallback for generic iterables (e.g. generators). + iterable = reversed(list(messages)) + + for msg in iterable: + if not isinstance(msg, dict): + continue + if is_ai_message(msg): + text = extract_text(msg.get("content")) + if text: + return text + return "" + + +def extract_latest_ai_message(messages: Iterable[T.Any]) -> dict[str, T.Any] | None: + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + iterable = reversed(list(messages)) + + for msg in iterable: + if not isinstance(msg, dict): + continue + if is_ai_message(msg): + return msg + return None + + +def is_clarification_tool_message(message: dict[str, T.Any]) -> bool: + msg_type = str(message.get("type", "")).lower() + tool_name = str(message.get("name", "")).lower() + return msg_type == "tool" and tool_name == "ask_clarification" + + +def extract_latest_clarification_text(messages: Iterable[T.Any]) -> str: + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + iterable = reversed(list(messages)) + + for msg in iterable: + if not isinstance(msg, dict): + continue + if is_clarification_tool_message(msg): + text = extract_text(msg.get("content")) + if text: + return text + return "" + + +def get_message_id(message: T.Any) -> str: + if not isinstance(message, dict): + return "" + msg_id = message.get("id") + return msg_id if isinstance(msg_id, str) else "" + + +def extract_event_message_obj(data: T.Any) -> dict[str, T.Any] | None: + msg_obj = data + if isinstance(data, (list, tuple)) and data: + msg_obj = data[0] + if isinstance(msg_obj, dict) and isinstance(msg_obj.get("data"), dict): + # Some servers wrap message body in {"data": {...}} + msg_obj = msg_obj["data"] + return msg_obj if isinstance(msg_obj, dict) else None + + +def 
extract_ai_delta_from_event_data(data: T.Any) -> str: + # LangGraph messages-tuple events usually carry either: + # - {"type": "ai", "content": "..."} + # - [message_obj, metadata] + msg_obj = extract_event_message_obj(data) + if not msg_obj: + return "" + if is_ai_message(msg_obj): + return extract_text(msg_obj.get("content")) + return "" + + +def extract_clarification_from_event_data(data: T.Any) -> str: + msg_obj = extract_event_message_obj(data) + if not msg_obj: + return "" + if is_clarification_tool_message(msg_obj): + return extract_text(msg_obj.get("content")) + return "" + + +def _iter_custom_event_items(data: T.Any) -> list[dict[str, T.Any]]: + items: list[dict[str, T.Any]] = [] + if isinstance(data, dict): + return [data] + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + items.append(item) + elif isinstance(item, (list, tuple)): + for nested in item: + if isinstance(nested, dict): + items.append(nested) + return items + + +def extract_task_failures_from_custom_event(data: T.Any) -> list[str]: + failures: list[str] = [] + for item in _iter_custom_event_items(data): + event_type = str(item.get("type", "")).lower() + if event_type not in {"task_failed", "task_timed_out"}: + continue + + task_id = str(item.get("task_id", "")).strip() + error_text = extract_text(item.get("error")).strip() + if task_id and error_text: + failures.append(f"{task_id}: {error_text}") + elif error_text: + failures.append(error_text) + elif task_id: + failures.append(f"{task_id}: unknown error") + else: + failures.append("unknown task failure") + return failures + + +def build_task_failure_summary(failures: list[str]) -> str: + if not failures: + return "" + deduped: list[str] = [] + seen: set[str] = set() + for failure in failures: + if failure not in seen: + seen.add(failure) + deduped.append(failure) + if len(deduped) == 1: + return f"DeerFlow subtask failed: {deduped[0]}" + joined = "\n".join([f"- {item}" for item in deduped[:5]]) + return 
f"DeerFlow subtasks failed:\n{joined}" diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index fa9d71d745..e7723150f6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -113,6 +113,7 @@ "dify_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "", + "deerflow_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, @@ -1252,6 +1253,25 @@ class ChatProviderTemplate(TypedDict): "timeout": 60, "proxy": "", }, + "DeerFlow": { + "id": "deerflow", + "provider": "deerflow", + "type": "deerflow", + "provider_type": "agent_runner", + "enable": True, + "deerflow_api_base": "http://127.0.0.1:2026", + "deerflow_api_key": "", + "deerflow_auth_header": "", + "deerflow_assistant_id": "lead_agent", + "deerflow_model_name": "", + "deerflow_thinking_enabled": False, + "deerflow_plan_mode": False, + "deerflow_subagent_enabled": False, + "deerflow_max_concurrent_subagents": 3, + "deerflow_recursion_limit": 1000, + "timeout": 300, + "proxy": "", + }, "FastGPT": { "id": "fastgpt", "provider": "fastgpt", @@ -2258,6 +2278,55 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn", }, + "deerflow_api_base": { + "description": "API Base URL", + "type": "string", + "hint": "DeerFlow API 网关地址,默认为 http://127.0.0.1:2026", + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "type": "string", + "hint": "可选。若 DeerFlow 网关配置了 Bearer 鉴权,则在此填写。", + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "type": "string", + "hint": "可选。自定义 Authorization 请求头,优先级高于 DeerFlow API Key。", + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "type": "string", + "hint": "LangGraph assistant_id,默认为 lead_agent。", + }, + "deerflow_model_name": { + "description": "模型名称覆盖", + "type": "string", + "hint": "可选。覆盖 
DeerFlow 默认模型(对应 runtime context 的 model_name)。", + }, + "deerflow_thinking_enabled": { + "description": "启用思考模式", + "type": "bool", + }, + "deerflow_plan_mode": { + "description": "启用计划模式", + "type": "bool", + "hint": "对应 DeerFlow 的 is_plan_mode。", + }, + "deerflow_subagent_enabled": { + "description": "启用子智能体", + "type": "bool", + "hint": "对应 DeerFlow 的 subagent_enabled。", + }, + "deerflow_max_concurrent_subagents": { + "description": "子智能体最大并发数", + "type": "int", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", + }, + "deerflow_recursion_limit": { + "description": "递归深度上限", + "type": "int", + "hint": "对应 LangGraph recursion_limit。", + }, "auto_save_history": { "description": "由 Coze 管理对话记录", "type": "bool", @@ -2335,6 +2404,9 @@ class ChatProviderTemplate(TypedDict): "dashscope_agent_runner_provider_id": { "type": "string", }, + "deerflow_agent_runner_provider_id": { + "type": "string", + }, "max_agent_step": { "type": "int", }, @@ -2543,7 +2615,7 @@ class ChatProviderTemplate(TypedDict): "metadata": { "agent_runner": { "description": "Agent 执行方式", - "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", + "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify、Coze、DeerFlow 等第三方 Agent 执行器,不需要修改此节。", "type": "object", "items": { "provider_settings.enable": { @@ -2554,8 +2626,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": { "description": "执行器", "type": "string", - "options": ["local", "dify", "coze", "dashscope"], - "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], + "options": ["local", "dify", "coze", "dashscope", "deerflow"], + "labels": [ + "内置 Agent", + "Dify", + "Coze", + "阿里云百炼应用", + "DeerFlow", + ], "condition": { "provider_settings.enable": True, }, @@ -2587,6 +2665,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.enable": True, }, }, + 
"provider_settings.deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider:deerflow", + "condition": { + "provider_settings.agent_runner_type": "deerflow", + "provider_settings.enable": True, + }, + }, }, }, "ai": { diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index eba6a4fd66..0965fe7f7f 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -182,6 +182,8 @@ class ResultContentType(enum.Enum): LLM_RESULT = enum.auto() """调用 LLM 产生的结果""" + AGENT_RUNNER_ERROR = enum.auto() + """第三方 Agent Runner 返回的错误结果""" GENERAL_RESULT = enum.auto() """普通的消息结果""" STREAMING_RESULT = enum.auto() @@ -246,6 +248,13 @@ def is_llm_result(self) -> bool: """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT + def is_model_result(self) -> bool: + """Whether result comes from model execution (including runner errors).""" + return self.result_content_type in ( + ResultContentType.LLM_RESULT, + ResultContentType.AGENT_RUNNER_ERROR, + ) + # 为了兼容旧版代码,保留 CommandResult 的别名 CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index fcc574bc4f..ffaec00b49 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,5 +1,6 @@ import asyncio -from collections.abc import AsyncGenerator +import inspect +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING from astrbot.core import astrbot_config, logger @@ -7,6 +8,13 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( DashscopeAgentRunner, ) +from astrbot.core.agent.runners.deerflow.constants import 
( + DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, + DEERFLOW_PROVIDER_TYPE, +) +from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( + DeerFlowAgentRunner, +) from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.message.components import Image @@ -23,12 +31,14 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner + from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( ProviderRequest, ) from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.config_number import coerce_int_config from astrbot.core.utils.metrics import Metric from .....astr_agent_context import AgentContextWrapper, AstrAgentContext @@ -38,14 +48,22 @@ "dify": "dify_agent_runner_provider_id", "coze": "coze_agent_runner_provider_id", "dashscope": "dashscope_agent_runner_provider_id", + DEERFLOW_PROVIDER_TYPE: DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, } +THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error" +STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30 +RUNNER_NO_RESULT_FALLBACK_MESSAGE = "Agent Runner did not return any result." +RUNNER_NO_FINAL_RESPONSE_LOG = ( + "Agent Runner returned no final response, fallback to streamed error/result chain." +) +RUNNER_NO_RESULT_LOG = "Agent Runner did not return final result." 
async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, custom_error_message: str | None = None, -) -> AsyncGenerator[MessageChain | None, None]: +) -> AsyncGenerator[tuple[MessageChain, bool], None]: """ 运行第三方 agent runner 并转换响应格式 类似于 run_agent 函数,但专门处理第三方 agent runner @@ -55,10 +73,12 @@ async def run_third_party_agent( if resp.type == "streaming_delta": if stream_to_general: continue - yield resp.data["chain"] + yield resp.data["chain"], False elif resp.type == "llm_result": if stream_to_general: - yield resp.data["chain"] + yield resp.data["chain"], False + elif resp.type == "err": + yield resp.data["chain"], True except Exception as e: logger.error(f"Third party agent runner error: {e}") err_msg = custom_error_message @@ -68,7 +88,77 @@ async def run_third_party_agent( f"Error Type: {type(e).__name__} (3rd party)\n" f"Error Message: {str(e)}" ) - yield MessageChain().message(err_msg) + yield MessageChain().message(err_msg), True + + +class _RunnerResultAggregator: + def __init__(self) -> None: + self.merged_chain: list = [] + self.has_error = False + + def add_chunk(self, chain: MessageChain, is_error: bool) -> None: + self.merged_chain.extend(chain.chain or []) + if is_error: + self.has_error = True + + def finalize( + self, + final_resp: "LLMResponse | None", + ) -> tuple[list, bool]: + if not final_resp or not final_resp.result_chain: + if self.merged_chain: + logger.warning(RUNNER_NO_FINAL_RESPONSE_LOG) + return self.merged_chain, self.has_error + + logger.warning(RUNNER_NO_RESULT_LOG) + fallback_error_chain = MessageChain().message( + RUNNER_NO_RESULT_FALLBACK_MESSAGE, + ) + return fallback_error_chain.chain or [], True + + final_chain = final_resp.result_chain.chain or [] + is_runner_error = self.has_error or final_resp.role == "err" + return final_chain, is_runner_error + + +def _start_stream_watchdog( + *, + timeout_sec: int, + is_stream_consumed: Callable[[], bool], + close_runner_once: Callable[[], 
Awaitable[None]], +) -> asyncio.Task[None]: + async def _watchdog() -> None: + try: + await asyncio.sleep(timeout_sec) + except asyncio.CancelledError: + return + if not is_stream_consumed(): + logger.warning( + "Third-party runner stream was never consumed in %ss; closing runner to avoid resource leak.", + timeout_sec, + ) + try: + await close_runner_once() + except Exception: + logger.warning( + "Exception while closing third-party runner from stream watchdog.", + exc_info=True, + ) + + return asyncio.create_task(_watchdog()) + + +async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: + close_callable = getattr(runner, "close", None) + if not callable(close_callable): + return + + try: + close_result = close_callable() + if inspect.isawaitable(close_result): + await close_result + except Exception as e: + logger.warning(f"Failed to close third-party runner cleanly: {e}") class ThirdPartyAgentSubStage(Stage): @@ -85,6 +175,16 @@ async def initialize(self, ctx: PipelineContext) -> None: self.unsupported_streaming_strategy: str = settings[ "unsupported_streaming_strategy" ] + self.stream_consumption_close_timeout_sec: int = coerce_int_config( + settings.get( + "third_party_stream_consumption_close_timeout_sec", + STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + ), + default=STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + min_value=1, + field_name="third_party_stream_consumption_close_timeout_sec", + source="Third-party runner config", + ) async def _resolve_persona_custom_error_message( self, event: AstrMessageEvent @@ -104,6 +204,88 @@ async def _resolve_persona_custom_error_message( logger.debug("Failed to resolve persona custom error message: %s", e) return None + async def _handle_streaming_response( + self, + *, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + custom_error_message: str | None, + close_runner_once: Callable[[], Awaitable[None]], + mark_stream_consumed: Callable[[], None], + ) -> AsyncGenerator[None, None]: + aggregator = 
_RunnerResultAggregator() + + async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: + mark_stream_consumed() + try: + async for chain, is_error in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ): + aggregator.add_chunk(chain, is_error) + if is_error: + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) + yield chain + finally: + # Streaming runner cleanup must happen after consumer + # finishes iterating to avoid tearing down active streams. + await close_runner_once() + + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(_stream_runner_chain()), + ) + yield + + if runner.done(): + final_chain, is_runner_error = aggregator.finalize( + runner.get_final_llm_resp() + ) + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + + async def _handle_non_streaming_response( + self, + *, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + stream_to_general: bool, + custom_error_message: str | None, + ) -> AsyncGenerator[None, None]: + aggregator = _RunnerResultAggregator() + async for chain, is_error in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + aggregator.add_chunk(chain, is_error) + if is_error: + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) + yield + + final_chain, is_runner_error = aggregator.finalize(runner.get_final_llm_resp()) + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) + result_content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=result_content_type, + ), + ) + # Second yield keeps scheduler 
progress consistent after final result update. + yield + async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: @@ -152,6 +334,8 @@ async def process( runner = CozeAgentRunner[AstrAgentContext]() elif self.runner_type == "dashscope": runner = DashscopeAgentRunner[AstrAgentContext]() + elif self.runner_type == DEERFLOW_PROVIDER_TYPE: + runner = DeerFlowAgentRunner[AstrAgentContext]() else: raise ValueError( f"Unsupported third party agent runner type: {self.runner_type}", @@ -170,63 +354,68 @@ async def process( self.unsupported_streaming_strategy == "turn_off" and not event.platform_meta.support_streaming_message ) + streaming_used = streaming_response and not stream_to_general - await runner.reset( - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, - streaming=streaming_response, - ) + runner_closed = False + stream_consumed = False + stream_watchdog_task: asyncio.Task[None] | None = None - if streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ), - ), - ) - yield - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - # 非流式响应或转换为普通响应 - async for _ in run_third_party_agent( - runner, - stream_to_general=stream_to_general, - custom_error_message=custom_error_message, - ): - yield - - final_resp = runner.get_final_llm_resp() - - if not final_resp or not final_resp.result_chain: - logger.warning("Agent Runner 未返回最终结果。") + async def close_runner_once() 
-> None: + nonlocal runner_closed + if runner_closed: return + runner_closed = True + await _close_runner_if_supported(runner) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.LLM_RESULT, + def mark_stream_consumed() -> None: + nonlocal stream_consumed + stream_consumed = True + if stream_watchdog_task and not stream_watchdog_task.done(): + stream_watchdog_task.cancel() + + try: + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + streaming=streaming_response, ) - yield + + if streaming_used: + stream_watchdog_task = _start_stream_watchdog( + timeout_sec=self.stream_consumption_close_timeout_sec, + is_stream_consumed=lambda: stream_consumed, + close_runner_once=close_runner_once, + ) + async for _ in self._handle_streaming_response( + runner=runner, + event=event, + custom_error_message=custom_error_message, + close_runner_once=close_runner_once, + mark_stream_consumed=mark_stream_consumed, + ): + yield + else: + async for _ in self._handle_non_streaming_response( + runner=runner, + event=event, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + yield + finally: + if ( + stream_watchdog_task + and not stream_watchdog_task.done() + and (stream_consumed or runner_closed) + ): + stream_watchdog_task.cancel() + if not streaming_used: + await close_runner_once() asyncio.create_task( Metric.upload( diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 72e853ffcc..bd307f8b77 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -135,7 +135,7 @@ def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: if (result := event.get_result()) is None: return False - if self.only_llm_result and not result.is_llm_result(): + if 
self.only_llm_result and not result.is_model_result(): return False if event.get_platform_name() in [ diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 15d68fb22e..f2fe8161b5 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -209,7 +209,7 @@ async def process( "dingtalk", ]: if ( - self.only_llm_result and result.is_llm_result() + self.only_llm_result and result.is_model_result() ) or not self.only_llm_result: new_chain = [] for comp in result.chain: diff --git a/astrbot/core/utils/config_number.py b/astrbot/core/utils/config_number.py new file mode 100644 index 0000000000..f9ce138397 --- /dev/null +++ b/astrbot/core/utils/config_number.py @@ -0,0 +1,64 @@ +from astrbot.core import logger + + +def coerce_int_config( + value: object, + *, + default: int, + min_value: int | None = None, + field_name: str | None = None, + source: str = "config", + warn: bool = True, +) -> int: + label = f"'{field_name}'" if field_name else "value" + + if isinstance(value, bool): + if warn: + logger.warning( + "%s %s should be numeric, got boolean. Fallback to %s.", + source, + label, + default, + ) + parsed = default + elif isinstance(value, int): + parsed = value + elif isinstance(value, str): + try: + parsed = int(value.strip()) + except ValueError: + if warn: + logger.warning( + "%s %s value '%s' is not numeric. Fallback to %s.", + source, + label, + value, + default, + ) + parsed = default + else: + try: + parsed = int(value) + except (TypeError, ValueError): + if warn: + logger.warning( + "%s %s has unsupported type %s. Fallback to %s.", + source, + label, + type(value).__name__, + default, + ) + parsed = default + + if min_value is not None and parsed < min_value: + if warn: + logger.warning( + "%s %s=%s is below minimum %s. 
Fallback to %s.", + source, + label, + parsed, + min_value, + min_value, + ) + parsed = min_value + return parsed diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 6a300302d9..40b899620d 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -1,6 +1,10 @@ import traceback from astrbot.core import astrbot_config, logger +from astrbot.core.agent.runners.deerflow.constants import ( + DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, + DEERFLOW_PROVIDER_TYPE, +) from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 from astrbot.core.db.migration.migra_token_usage import migrate_token_usage @@ -27,6 +31,11 @@ def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: "id" ] conf["provider_settings"]["agent_runner_type"] = "dashscope" + elif p["type"] == DEERFLOW_PROVIDER_TYPE: + conf["provider_settings"][DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY] = p[ + "id" + ] + conf["provider_settings"]["agent_runner_type"] = DEERFLOW_PROVIDER_TYPE conf.save_config() except Exception as e: logger.error(f"Migration for third party agent runner configs failed: {e!s}") @@ -153,7 +162,7 @@ async def migra( ids_map = {} for prov in providers: type_ = prov.get("type") - if type_ in ["dify", "coze", "dashscope"]: + if type_ in ["dify", "coze", "dashscope", DEERFLOW_PROVIDER_TYPE]: prov["provider_type"] = "agent_runner" ids_map[prov["id"]] = { "type": type_, diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index cad25e835c..eeb0ac8052 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -3,7 +3,7 @@ "name": "AI", "agent_runner": { "description": "Agent Runner", - "hint": "Select the runner for AI conversations. 
Defaults to AstrBot's built-in Agent runner, which supports knowledge base, persona, and tool calling features. You don't need to modify this section unless you plan to integrate third-party Agent runners like Dify or Coze.", + "hint": "Select the runner for AI conversations. Defaults to AstrBot's built-in Agent runner, which supports knowledge base, persona, and tool calling features. You don't need to modify this section unless you plan to integrate third-party Agent runners like Dify, Coze, or DeerFlow.", "provider_settings": { "enable": { "description": "Enable", @@ -15,7 +15,8 @@ "Built-in Agent", "Dify", "Coze", - "Alibaba Cloud Bailian Application" + "Alibaba Cloud Bailian Application", + "DeerFlow" ] }, "coze_agent_runner_provider_id": { @@ -26,6 +27,9 @@ }, "dashscope_agent_runner_provider_id": { "description": "Alibaba Cloud Bailian Application Agent Runner Provider ID" + }, + "deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent Runner Provider ID" } } }, @@ -1363,6 +1367,45 @@ "description": "API Base URL", "hint": "Base URL for the Coze API. Default: https://api.coze.cn" }, + "deerflow_api_base": { + "description": "API Base URL", + "hint": "DeerFlow API gateway URL. Default: http://127.0.0.1:2026" + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "hint": "Optional. Fill this if your DeerFlow gateway is protected by Bearer auth." + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "hint": "Optional. Custom Authorization header value; takes precedence over DeerFlow API Key." + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "hint": "LangGraph assistant_id, default is lead_agent." + }, + "deerflow_model_name": { + "description": "Model name override", + "hint": "Optional. Overrides DeerFlow default model (maps to runtime context model_name)." 
+ }, + "deerflow_thinking_enabled": { + "description": "Enable thinking mode" + }, + "deerflow_plan_mode": { + "description": "Enable plan mode", + "hint": "Maps to DeerFlow is_plan_mode." + }, + "deerflow_subagent_enabled": { + "description": "Enable subagent", + "hint": "Maps to DeerFlow subagent_enabled." + }, + "deerflow_max_concurrent_subagents": { + "description": "Max concurrent subagents", + "hint": "Maps to DeerFlow max_concurrent_subagents. Effective only when subagent is enabled. Default: 3." + }, + "deerflow_recursion_limit": { + "description": "Recursion limit", + "hint": "Maps to LangGraph recursion_limit." + }, "auto_save_history": { "description": "Conversation history managed by Coze", "hint": "When enabled, Coze manages conversation history. AstrBot's locally saved context will not take effect (read-only), and operations on AstrBot context will not apply. If disabled, AstrBot manages the context." diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index e5eea63fd0..238ad05662 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -3,7 +3,7 @@ "name": "AI 配置", "agent_runner": { "description": "Agent 执行方式", - "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", + "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify、Coze、DeerFlow 等第三方 Agent 执行器,不需要修改此节。", "provider_settings": { "enable": { "description": "启用", @@ -15,7 +15,8 @@ "内置 Agent", "Dify", "Coze", - "阿里云百炼应用" + "阿里云百炼应用", + "DeerFlow" ] }, "coze_agent_runner_provider_id": { @@ -26,6 +27,9 @@ }, "dashscope_agent_runner_provider_id": { "description": "阿里云百炼应用 Agent 执行器提供商 ID" + }, + "deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent 执行器提供商 ID" } } }, @@ -1366,6 +1370,45 @@ "description": "API 
Base URL", "hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn" }, + "deerflow_api_base": { + "description": "API Base URL", + "hint": "DeerFlow API 网关地址,默认为 http://127.0.0.1:2026" + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "hint": "可选。若 DeerFlow 网关配置了 Bearer 鉴权,则在此填写。" + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "hint": "可选。自定义 Authorization 请求头,优先级高于 DeerFlow API Key。" + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "hint": "LangGraph assistant_id,默认为 lead_agent。" + }, + "deerflow_model_name": { + "description": "模型名称覆盖", + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。" + }, + "deerflow_thinking_enabled": { + "description": "启用思考模式" + }, + "deerflow_plan_mode": { + "description": "启用计划模式", + "hint": "对应 DeerFlow 的 is_plan_mode。" + }, + "deerflow_subagent_enabled": { + "description": "启用子智能体", + "hint": "对应 DeerFlow 的 subagent_enabled。" + }, + "deerflow_max_concurrent_subagents": { + "description": "子智能体最大并发数", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" + }, + "deerflow_recursion_limit": { + "description": "递归深度上限", + "hint": "对应 LangGraph recursion_limit。" + }, "auto_save_history": { "description": "由 Coze 管理对话记录", "hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。" diff --git a/dashboard/src/utils/providerUtils.js b/dashboard/src/utils/providerUtils.js index b02af6942c..d35941f6f7 100644 --- a/dashboard/src/utils/providerUtils.js +++ b/dashboard/src/utils/providerUtils.js @@ -25,6 +25,7 @@ export function getProviderIcon(type) { 'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg', "coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg", 'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg', + 'deerflow': 
'https://cdn.jsdelivr.net/gh/bytedance/deer-flow@main/frontend/public/images/deer.svg', 'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg', 'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg', 'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',