From 5e33cbb6288ffe4eb91aa0a5491d56c0895a0629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:11:15 +0900 Subject: [PATCH 01/24] feat: integrate DeerFlow agent runner and improve stream handling --- .../builtin_commands/commands/conversation.py | 1 + .../runners/deerflow/deerflow_agent_runner.py | 558 ++++++++++++++++++ .../runners/deerflow/deerflow_api_client.py | 153 +++++ astrbot/core/config/default.py | 93 ++- .../method/agent_sub_stages/third_party.py | 26 +- astrbot/core/utils/migra_helper.py | 5 +- .../assets/images/platform_logos/deerflow.png | Bin 0 -> 775 bytes .../en-US/features/config-metadata.json | 47 +- .../zh-CN/features/config-metadata.json | 47 +- dashboard/src/utils/providerUtils.js | 1 + 10 files changed, 922 insertions(+), 9 deletions(-) create mode 100644 astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py create mode 100644 astrbot/core/agent/runners/deerflow/deerflow_api_client.py create mode 100644 dashboard/src/assets/images/platform_logos/deerflow.png diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index 55b75cb1bd..8ca6bfe9bf 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -12,6 +12,7 @@ "dify": "dify_conversation_id", "coze": "coze_conversation_id", "dashscope": "dashscope_conversation_id", + "deerflow": "deerflow_thread_id", } THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys()) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py new file mode 100644 index 0000000000..7b7399c1cc --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -0,0 +1,558 @@ +import asyncio +import sys +import typing as T +from collections.abc import Iterable + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core import sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) + +from ...hooks import BaseAgentRunHooks +from ...response import AgentResponseData +from ...run_context import ContextWrapper, TContext +from ..base import AgentResponse, AgentState, BaseAgentRunner +from .deerflow_api_client import DeerFlowAPIClient + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DeerFlowAgentRunner(BaseAgentRunner[TContext]): + """DeerFlow Agent Runner via LangGraph HTTP API.""" + + def _format_exception(self, err: Exception) -> str: + err_type = type(err).__name__ + detail = str(err).strip() + + if isinstance(err, (asyncio.TimeoutError, TimeoutError)): + timeout_text = ( + f"{self.timeout}s" + if isinstance(getattr(self, "timeout", None), int | float) + else "configured timeout" + ) + return ( + f"{err_type}: request timed out after {timeout_text}. " + "Please check DeerFlow service health and backend logs." + ) + + if detail: + if detail.startswith(f"{err_type}:"): + return detail + return f"{err_type}: {detail}" + + return f"{err_type}: no detailed error message provided." + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_base = provider_config.get( + "deerflow_api_base", "http://127.0.0.1:2026" + ) + if not isinstance(self.api_base, str) or not self.api_base.startswith( + ("http://", "https://"), + ): + raise Exception( + "DeerFlow API Base URL format is invalid. It must start with http:// or https://.", + ) + self.api_key = provider_config.get("deerflow_api_key", "") + self.auth_header = provider_config.get("deerflow_auth_header", "") + self.assistant_id = provider_config.get("deerflow_assistant_id", "lead_agent") + self.model_name = provider_config.get("deerflow_model_name", "") + self.thinking_enabled = bool( + provider_config.get("deerflow_thinking_enabled", False), + ) + self.plan_mode = bool(provider_config.get("deerflow_plan_mode", False)) + self.subagent_enabled = bool( + provider_config.get("deerflow_subagent_enabled", False), + ) + self.max_concurrent_subagents = provider_config.get( + "deerflow_max_concurrent_subagents", + 3, + ) + if isinstance(self.max_concurrent_subagents, str): + self.max_concurrent_subagents = int(self.max_concurrent_subagents) + if self.max_concurrent_subagents < 1: + self.max_concurrent_subagents = 1 + + self.timeout = provider_config.get("timeout", 300) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + self.recursion_limit = provider_config.get("deerflow_recursion_limit", 1000) + if isinstance(self.recursion_limit, str): + self.recursion_limit = int(self.recursion_limit) + + self.api_client = DeerFlowAPIClient( + api_base=self.api_base, + api_key=self.api_key, + auth_header=self.auth_header, + ) + + @override + async def step(self): + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + if self.done(): + return + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + self._transition_state(AgentState.RUNNING) + + try: + async for response in self._execute_deerflow_request(): + yield response + except Exception as e: + err_msg = self._format_exception(e) + logger.error(f"DeerFlow request failed: {err_msg}", exc_info=True) + self._transition_state(AgentState.ERROR) + err_chain = MessageChain().message(f"DeerFlow request failed: {err_msg}") + self.final_llm_resp = LLMResponse( + role="err", + completion_text=f"DeerFlow request failed: {err_msg}", + result_chain=err_chain, + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=err_chain, + ), + ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + def _extract_text(self, content: T.Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict): + if isinstance(content.get("text"), str): + return content["text"] + if "content" in content: + return self._extract_text(content.get("content")) + if "kwargs" in content and isinstance(content["kwargs"], dict): + return self._extract_text(content["kwargs"].get("content")) + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + item_type = item.get("type") + if item_type == "text" and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif "content" in item: + parts.append(str(item["content"])) + return "\n".join([p for p in parts if p]).strip() + return str(content) if content is not None else "" + + def _extract_messages_from_values_data(self, data: T.Any) -> list[T.Any]: + """Extract messages list from possible values event payload shapes.""" + candidates: list[T.Any] = [] + if isinstance(data, dict): + candidates.append(data) + if isinstance(data.get("values"), dict): + candidates.append(data["values"]) + elif isinstance(data, list): + candidates.extend([x for x in data if isinstance(x, dict)]) + + for item in candidates: + messages = item.get("messages") + if isinstance(messages, list): + return messages + return [] + + def _is_ai_message(self, message: dict[str, T.Any]) -> bool: + role = str(message.get("role", "")).lower() + if role in {"assistant", "ai"}: + return True + + msg_type = str(message.get("type", "")).lower() + if msg_type in {"ai", "assistant", "aimessage", "aimessagechunk"}: + return True + if "ai" in msg_type and all( + token not in msg_type for token in ("human", "tool", "system") + ): + return True + return False + + def _extract_latest_ai_text(self, messages: Iterable[T.Any]) -> str: + # Scan backwards to get the latest assistant/ai message text. + for msg in reversed(list(messages)): + if not isinstance(msg, dict): + continue + if self._is_ai_message(msg): + text = self._extract_text(msg.get("content")) + if text: + return text + return "" + + def _is_clarification_tool_message(self, message: dict[str, T.Any]) -> bool: + msg_type = str(message.get("type", "")).lower() + tool_name = str(message.get("name", "")).lower() + return msg_type == "tool" and tool_name == "ask_clarification" + + def _extract_latest_clarification_text(self, messages: Iterable[T.Any]) -> str: + for msg in reversed(list(messages)): + if not isinstance(msg, dict): + continue + if self._is_clarification_tool_message(msg): + text = self._extract_text(msg.get("content")) + if text: + return text + return "" + + def _get_message_id(self, message: T.Any) -> str: + if not isinstance(message, dict): + return "" + msg_id = message.get("id") + return msg_id if isinstance(msg_id, str) else "" + + def _extract_new_messages_from_values( + self, + values_messages: list[T.Any], + seen_message_ids: set[str], + ) -> list[dict[str, T.Any]]: + new_messages: list[dict[str, T.Any]] = [] + for msg in values_messages: + if not isinstance(msg, dict): + continue + msg_id = self._get_message_id(msg) + if not msg_id or msg_id in seen_message_ids: + continue + seen_message_ids.add(msg_id) + new_messages.append(msg) + return new_messages + + def _extract_event_message_obj(self, data: T.Any) -> dict[str, T.Any] | None: + msg_obj = data + if isinstance(data, (list, tuple)) and data: + msg_obj = data[0] + if isinstance(msg_obj, dict) and isinstance(msg_obj.get("data"), dict): + # Some servers wrap message body in {"data": {...}} + msg_obj = msg_obj["data"] + return msg_obj if isinstance(msg_obj, dict) else None + + def _extract_ai_delta_from_event_data(self, data: T.Any) -> str: + # LangGraph messages-tuple events usually carry either: + # - {"type": "ai", "content": "..."} + # - [message_obj, metadata] + msg_obj = self._extract_event_message_obj(data) + if not msg_obj: + return "" + if self._is_ai_message(msg_obj): + return self._extract_text(msg_obj.get("content")) + return "" + + def _extract_clarification_from_event_data(self, data: T.Any) -> str: + msg_obj = self._extract_event_message_obj(data) + if not msg_obj: + return "" + if self._is_clarification_tool_message(msg_obj): + return self._extract_text(msg_obj.get("content")) + return "" + + def _iter_custom_event_items(self, data: T.Any) -> list[dict[str, T.Any]]: + items: list[dict[str, T.Any]] = [] + if isinstance(data, dict): + return [data] + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + items.append(item) + elif isinstance(item, (list, tuple)): + for nested in item: + if isinstance(nested, dict): + items.append(nested) + return items + + def _extract_task_failures_from_custom_event(self, data: T.Any) -> list[str]: + failures: list[str] = [] + for item in self._iter_custom_event_items(data): + event_type = str(item.get("type", "")).lower() + if event_type not in {"task_failed", "task_timed_out"}: + continue + + task_id = str(item.get("task_id", "")).strip() + error_text = self._extract_text(item.get("error")).strip() + if task_id and error_text: + failures.append(f"{task_id}: {error_text}") + elif error_text: + failures.append(error_text) + elif task_id: + failures.append(f"{task_id}: unknown error") + else: + failures.append("unknown task failure") + return failures + + def _build_task_failure_summary(self, failures: list[str]) -> str: + if not failures: + return "" + deduped: list[str] = [] + seen: set[str] = set() + for failure in failures: + if failure not in seen: + seen.add(failure) + deduped.append(failure) + if len(deduped) == 1: + return f"DeerFlow subtask failed: {deduped[0]}" + joined = "\n".join([f"- {item}" for item in deduped[:5]]) + return f"DeerFlow subtasks failed:\n{joined}" + + def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: + if not image_urls: + return prompt + + content: list[dict[str, T.Any]] = [] + if prompt: + content.append({"type": "text", "text": prompt}) + + for image_url in image_urls: + url = image_url + if not isinstance(url, str): + continue + if not url.startswith(("http://", "https://", "data:")): + url = f"data:image/png;base64,{url}" + content.append({"type": "image_url", "image_url": {"url": url}}) + return content + + async def _ensure_thread_id(self, session_id: str) -> str: + thread_id = await sp.get_async( + scope="umo", + scope_id=session_id, + key="deerflow_thread_id", + default="", + ) + if thread_id: + return thread_id + + thread = await self.api_client.create_thread(timeout=min(30, self.timeout)) + thread_id = thread.get("thread_id", "") + if not thread_id: + raise Exception( + f"DeerFlow create thread returned invalid payload: {thread}" + ) + + await sp.put_async( + scope="umo", + scope_id=session_id, + key="deerflow_thread_id", + value=thread_id, + ) + return thread_id + + async def _execute_deerflow_request(self): + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + system_prompt = self.req.system_prompt + + thread_id = await self._ensure_thread_id(session_id) + + messages: list[dict[str, T.Any]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append( + { + "role": "user", + "content": self._build_user_content(prompt, image_urls), + }, + ) + + runtime_context: dict[str, T.Any] = { + "thread_id": thread_id, + "thinking_enabled": self.thinking_enabled, + "is_plan_mode": self.plan_mode, + "subagent_enabled": self.subagent_enabled, + } + if self.subagent_enabled: + runtime_context["max_concurrent_subagents"] = self.max_concurrent_subagents + if self.model_name: + runtime_context["model_name"] = self.model_name + + payload: dict[str, T.Any] = { + "assistant_id": self.assistant_id, + "input": {"messages": messages}, + "stream_mode": ["values", "messages-tuple", "custom"], + # LangGraph 0.6+ prefers context instead of configurable. + "context": runtime_context, + "config": { + "recursion_limit": self.recursion_limit, + }, + } + + streamed_text = "" + fallback_stream_text = "" + clarification_text = "" + task_failures: list[str] = [] + seen_message_ids: set[str] = set() + baseline_initialized = False + run_values_messages: list[dict[str, T.Any]] = [] + timed_out = False + + try: + async for event in self.api_client.stream_run( + thread_id=thread_id, + payload=payload, + timeout=self.timeout, + ): + event_type = event.get("event") + data = event.get("data") + + if event_type == "values": + values_messages = self._extract_messages_from_values_data(data) + if values_messages: + if not baseline_initialized: + baseline_initialized = True + for msg in values_messages: + msg_id = self._get_message_id(msg) + if msg_id: + seen_message_ids.add(msg_id) + continue + + new_messages = self._extract_new_messages_from_values( + values_messages, + seen_message_ids, + ) + if new_messages: + run_values_messages.extend(new_messages) + latest_text = self._extract_latest_ai_text( + run_values_messages + ) + latest_clarification = ( + self._extract_latest_clarification_text( + run_values_messages, + ) + ) + if latest_clarification: + clarification_text = latest_clarification + else: + latest_text = "" + + if self.streaming and latest_text: + if latest_text.startswith(streamed_text): + delta = latest_text[len(streamed_text) :] + if delta: + streamed_text = latest_text + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(delta), + ), + ) + elif latest_text != streamed_text: + streamed_text = latest_text + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(latest_text), + ), + ) + continue + + if event_type in {"messages-tuple", "messages", "message"}: + delta = self._extract_ai_delta_from_event_data(data) + if delta: + fallback_stream_text += delta + if self.streaming and delta and not streamed_text: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(delta)), + ) + maybe_clarification = self._extract_clarification_from_event_data( + data + ) + if maybe_clarification: + clarification_text = maybe_clarification + continue + + if event_type == "custom": + task_failures.extend( + self._extract_task_failures_from_custom_event(data), + ) + continue + + if event_type == "error": + raise Exception(f"DeerFlow stream returned error event: {data}") + + if event_type == "end": + break + except (asyncio.TimeoutError, TimeoutError): + timed_out = True + + # Clarification tool output should take precedence over partial AI/tool-call text. + if clarification_text: + final_text = clarification_text + else: + final_text = self._extract_latest_ai_text(run_values_messages) + if not final_text: + final_text = streamed_text or fallback_stream_text + if not final_text: + final_text = self._build_task_failure_summary(task_failures) + + if timed_out: + timeout_note = ( + f"DeerFlow stream timed out after {self.timeout}s. " + "Returning partial result." + ) + if final_text: + final_text = f"{final_text}\n\n{timeout_note}" + else: + raise asyncio.TimeoutError(timeout_note) + + if not final_text: + logger.warning("DeerFlow returned no text content in stream events.") + final_text = "DeerFlow returned an empty response." + + chain = MessageChain(chain=[Comp.Plain(final_text)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + @override + def done(self) -> bool: + """Check whether the agent has finished or failed.""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py new file mode 100644 index 0000000000..6c4da118ae --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -0,0 +1,153 @@ +import codecs +import json +from collections.abc import AsyncGenerator +from typing import Any + +from aiohttp import ClientResponse, ClientSession, ClientTimeout + +from astrbot.core import logger + + +def _normalize_sse_newlines(text: str) -> str: + """Normalize CRLF/CR to LF so SSE block splitting works reliably.""" + return text.replace("\r\n", "\n").replace("\r", "\n") + + +def _parse_sse_data_lines(data_lines: list[str]) -> Any: + raw_data = "\n".join(data_lines) + try: + return json.loads(raw_data) + except json.JSONDecodeError: + # Some LangGraph-compatible servers emit multiple JSON fragments + # in one SSE event using repeated data lines (e.g. tuple payloads). + parsed_lines: list[Any] = [] + can_parse_all = True + for line in data_lines: + line = line.strip() + if not line: + continue + try: + parsed_lines.append(json.loads(line)) + except json.JSONDecodeError: + can_parse_all = False + break + if can_parse_all and parsed_lines: + return parsed_lines[0] if len(parsed_lines) == 1 else parsed_lines + return raw_data + + +async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict[str, Any], None]: + """Parse SSE response blocks into event/data dictionaries.""" + decoder = codecs.getincrementaldecoder("utf-8")() + buffer = "" + + async for chunk in resp.content.iter_chunked(8192): + buffer += _normalize_sse_newlines(decoder.decode(chunk)) + + while "\n\n" in buffer: + block, buffer = buffer.split("\n\n", 1) + if not block.strip(): + continue + + event_name = "message" + data_lines: list[str] = [] + + for line in block.splitlines(): + if line.startswith("event:"): + event_name = line[6:].strip() + elif line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + + if not data_lines: + continue + + data = _parse_sse_data_lines(data_lines) + + yield {"event": event_name, "data": data} + + # flush any remaining buffered text + buffer += _normalize_sse_newlines(decoder.decode(b"", final=True)) + if not buffer.strip(): + return + + event_name = "message" + data_lines = [] + for line in buffer.splitlines(): + if line.startswith("event:"): + event_name = line[6:].strip() + elif line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + if not data_lines: + return + + data = _parse_sse_data_lines(data_lines) + yield {"event": event_name, "data": data} + + +class DeerFlowAPIClient: + def __init__( + self, + api_base: str = "http://127.0.0.1:2026", + api_key: str = "", + auth_header: str = "", + ) -> None: + self.api_base = api_base.rstrip("/") + self.session = ClientSession(trust_env=True) + self.headers: dict[str, str] = {} + if auth_header: + self.headers["Authorization"] = auth_header + elif api_key: + self.headers["Authorization"] = f"Bearer {api_key}" + + async def create_thread(self, timeout: float = 20) -> dict[str, Any]: + url = f"{self.api_base}/api/langgraph/threads" + payload = {"metadata": {}} + async with self.session.post( + url, + json=payload, + headers=self.headers, + timeout=timeout, + ) as resp: + if resp.status not in (200, 201): + text = await resp.text() + raise Exception( + f"DeerFlow create thread failed: {resp.status}. {text}", + ) + return await resp.json() + + async def stream_run( + self, + thread_id: str, + payload: dict[str, Any], + timeout: float = 120, + ) -> AsyncGenerator[dict[str, Any], None]: + url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" + logger.debug(f"deerflow stream_run payload: {payload}") + # For long-running SSE streams, avoid aiohttp total timeout. + # Use socket read timeout so active heartbeats/chunks can keep the stream alive. + stream_timeout = ClientTimeout( + total=None, + connect=min(timeout, 30), + sock_connect=min(timeout, 30), + sock_read=timeout, + ) + async with self.session.post( + url, + json=payload, + headers={ + **self.headers, + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + timeout=stream_timeout, + ) as resp: + if resp.status != 200: + text = await resp.text() + raise Exception( + f"DeerFlow runs/stream request failed: {resp.status}. {text}", + ) + async for event in _stream_sse(resp): + yield event + + async def close(self) -> None: + await self.session.close() diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index fa9d71d745..e7723150f6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -113,6 +113,7 @@ "dify_agent_runner_provider_id": "", "coze_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "", + "deerflow_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, @@ -1252,6 +1253,25 @@ class ChatProviderTemplate(TypedDict): "timeout": 60, "proxy": "", }, + "DeerFlow": { + "id": "deerflow", + "provider": "deerflow", + "type": "deerflow", + "provider_type": "agent_runner", + "enable": True, + "deerflow_api_base": "http://127.0.0.1:2026", + "deerflow_api_key": "", + "deerflow_auth_header": "", + "deerflow_assistant_id": "lead_agent", + "deerflow_model_name": "", + "deerflow_thinking_enabled": False, + "deerflow_plan_mode": False, + "deerflow_subagent_enabled": False, + "deerflow_max_concurrent_subagents": 3, + "deerflow_recursion_limit": 1000, + "timeout": 300, + "proxy": "", + }, "FastGPT": { "id": "fastgpt", "provider": "fastgpt", @@ -2258,6 +2278,55 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn", }, + "deerflow_api_base": { + "description": "API Base URL", + "type": "string", + "hint": "DeerFlow API 网关地址,默认为 http://127.0.0.1:2026", + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "type": "string", + "hint": "可选。若 DeerFlow 网关配置了 Bearer 鉴权,则在此填写。", + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "type": "string", + "hint": "可选。自定义 Authorization 请求头,优先级高于 DeerFlow API Key。", + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "type": "string", + "hint": "LangGraph assistant_id,默认为 lead_agent。", + }, + "deerflow_model_name": { + "description": "模型名称覆盖", + "type": "string", + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。", + }, + "deerflow_thinking_enabled": { + "description": "启用思考模式", + "type": "bool", + }, + "deerflow_plan_mode": { + "description": "启用计划模式", + "type": "bool", + "hint": "对应 DeerFlow 的 is_plan_mode。", + }, + "deerflow_subagent_enabled": { + "description": "启用子智能体", + "type": "bool", + "hint": "对应 DeerFlow 的 subagent_enabled。", + }, + "deerflow_max_concurrent_subagents": { + "description": "子智能体最大并发数", + "type": "int", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", + }, + "deerflow_recursion_limit": { + "description": "递归深度上限", + "type": "int", + "hint": "对应 LangGraph recursion_limit。", + }, "auto_save_history": { "description": "由 Coze 管理对话记录", "type": "bool", @@ -2335,6 +2404,9 @@ class ChatProviderTemplate(TypedDict): "dashscope_agent_runner_provider_id": { "type": "string", }, + "deerflow_agent_runner_provider_id": { + "type": "string", + }, "max_agent_step": { "type": "int", }, @@ -2543,7 +2615,7 @@ class ChatProviderTemplate(TypedDict): "metadata": { "agent_runner": { "description": "Agent 执行方式", - "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", + "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify、Coze、DeerFlow 等第三方 Agent 执行器,不需要修改此节。", "type": "object", "items": { "provider_settings.enable": { @@ -2554,8 +2626,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": { "description": "执行器", "type": "string", - "options": ["local", "dify", "coze", "dashscope"], - "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], + "options": ["local", "dify", "coze", "dashscope", "deerflow"], + "labels": [ + "内置 Agent", + "Dify", + "Coze", + "阿里云百炼应用", + "DeerFlow", + ], "condition": { "provider_settings.enable": True, }, @@ -2587,6 +2665,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.enable": True, }, }, + "provider_settings.deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider:deerflow", + "condition": { + "provider_settings.agent_runner_type": "deerflow", + "provider_settings.enable": True, + }, + }, }, }, "ai": { diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index fcc574bc4f..bc63721a30 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -7,6 +7,9 @@ from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( DashscopeAgentRunner, ) +from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( + DeerFlowAgentRunner, +) from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.message.components import Image @@ -38,6 +41,7 @@ "dify": "dify_agent_runner_provider_id", "coze": "coze_agent_runner_provider_id", "dashscope": "dashscope_agent_runner_provider_id", + "deerflow": "deerflow_agent_runner_provider_id", } @@ -59,6 +63,9 @@ async def run_third_party_agent( elif resp.type == "llm_result": if stream_to_general: yield resp.data["chain"] + elif resp.type == "err": + # Ensure caller can surface explicit runner errors. + yield resp.data["chain"] except Exception as e: logger.error(f"Third party agent runner error: {e}") err_msg = custom_error_message @@ -152,6 +159,8 @@ async def process( runner = CozeAgentRunner[AstrAgentContext]() elif self.runner_type == "dashscope": runner = DashscopeAgentRunner[AstrAgentContext]() + elif self.runner_type == "deerflow": + runner = DeerFlowAgentRunner[AstrAgentContext]() else: raise ValueError( f"Unsupported third party agent runner type: {self.runner_type}", @@ -207,16 +216,31 @@ async def process( ) else: # 非流式响应或转换为普通响应 - async for _ in run_third_party_agent( + fallback_chain: MessageChain | None = None + async for maybe_chain in run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): + if maybe_chain: + fallback_chain = maybe_chain yield final_resp = runner.get_final_llm_resp() if not final_resp or not final_resp.result_chain: + if fallback_chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + event.set_result( + MessageEventResult( + chain=fallback_chain.chain or [], + result_content_type=ResultContentType.LLM_RESULT, + ), + ) + yield + return logger.warning("Agent Runner 未返回最终结果。") return diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 6a300302d9..91c0074650 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -27,6 +27,9 @@ def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: "id" ] conf["provider_settings"]["agent_runner_type"] = "dashscope" + elif p["type"] == "deerflow": + conf["provider_settings"]["deerflow_agent_runner_provider_id"] = p["id"] + conf["provider_settings"]["agent_runner_type"] = "deerflow" conf.save_config() except Exception as e: logger.error(f"Migration for third party agent runner configs failed: {e!s}") @@ -153,7 +156,7 @@ async def migra( ids_map = {} for prov in providers: type_ = prov.get("type") - if type_ in ["dify", "coze", "dashscope"]: + if type_ in ["dify", "coze", "dashscope", "deerflow"]: prov["provider_type"] = "agent_runner" ids_map[prov["id"]] = { "type": type_, diff --git a/dashboard/src/assets/images/platform_logos/deerflow.png b/dashboard/src/assets/images/platform_logos/deerflow.png new file mode 100644 index 0000000000000000000000000000000000000000..24cd1aa3a0a3313bd8e3e2e222fb94ae38e3b013 GIT binary patch literal 775 zcmV+i1Ni)jP)VaDi)YUR&LI_cm6Rk8htWFaH5CZtmCE)D2n>{Qmb> z=+1|C86eES!0?})nfZz=FaKUe0f8L~`~v&=**ShOz&xo_`9$AfNu_=V8CpQDpC?t|0toVBCry z5GsY3KYsaX`2EKp9y+;)WH985(NPZ!G#K^3K!Z^aPy@XF{)6}H??0U6=pB%@(AVF8 zm~XuQU~v8YM~6#q-}$`w_T8q^SkF{bP~ZZoX3@<9_dk7>J^K7r;Enem93Or8YR|yH zAV>{!>E?m`PoL)>c>1E4+U5?5JU|H`%L8A3{N{fD`6nCookKbxD+xV#@x`QR(fN(? z;ynAc6a@Ba%L{_&Yup^HbRXiNBoFXEeEB5+gtwo30^+~ak{8%-q$aZ6TvK$Tm;fhT z`VrL31LD^neu!AI`_96H7heQW(OkS>WM<;Q1V%kD&|uU90}Vz!FwkJs0|O1{<^cgV zwmS?A44kLV z1YrgShPKz=e{fuR``-TI+jpK^EG%Tj&VWvEh;nlLaFmhU3Bu%>LC6Dth50#-@^G&7Us>p+ATRqL+Bg>uzA-Q`fCCBa0S_Cct)%GY;$Zn;ts~x}rzEsPOF`hMm;mPw zTDcSoa8B|u4%Vdx#7?sgS`QElMw=1?4Tii200961|NsAk^_6o^Y;^zt002ovPDHLk FV1jKqZkzxB literal 0 HcmV?d00001 diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index cad25e835c..eeb0ac8052 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -3,7 +3,7 @@ "name": "AI", "agent_runner": { "description": "Agent Runner", - "hint": "Select the runner for AI conversations. Defaults to AstrBot's built-in Agent runner, which supports knowledge base, persona, and tool calling features. You don't need to modify this section unless you plan to integrate third-party Agent runners like Dify or Coze.", + "hint": "Select the runner for AI conversations. Defaults to AstrBot's built-in Agent runner, which supports knowledge base, persona, and tool calling features. You don't need to modify this section unless you plan to integrate third-party Agent runners like Dify, Coze, or DeerFlow.", "provider_settings": { "enable": { "description": "Enable", @@ -15,7 +15,8 @@ "Built-in Agent", "Dify", "Coze", - "Alibaba Cloud Bailian Application" + "Alibaba Cloud Bailian Application", + "DeerFlow" ] }, "coze_agent_runner_provider_id": { @@ -26,6 +27,9 @@ }, "dashscope_agent_runner_provider_id": { "description": "Alibaba Cloud Bailian Application Agent Runner Provider ID" + }, + "deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent Runner Provider ID" } } }, @@ -1363,6 +1367,45 @@ "description": "API Base URL", "hint": "Base URL for the Coze API. Default: https://api.coze.cn" }, + "deerflow_api_base": { + "description": "API Base URL", + "hint": "DeerFlow API gateway URL. Default: http://127.0.0.1:2026" + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "hint": "Optional. Fill this if your DeerFlow gateway is protected by Bearer auth." + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "hint": "Optional. Custom Authorization header value; takes precedence over DeerFlow API Key." + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "hint": "LangGraph assistant_id, default is lead_agent." + }, + "deerflow_model_name": { + "description": "Model name override", + "hint": "Optional. Overrides DeerFlow default model (maps to runtime context model_name)." + }, + "deerflow_thinking_enabled": { + "description": "Enable thinking mode" + }, + "deerflow_plan_mode": { + "description": "Enable plan mode", + "hint": "Maps to DeerFlow is_plan_mode." + }, + "deerflow_subagent_enabled": { + "description": "Enable subagent", + "hint": "Maps to DeerFlow subagent_enabled." + }, + "deerflow_max_concurrent_subagents": { + "description": "Max concurrent subagents", + "hint": "Maps to DeerFlow max_concurrent_subagents. Effective only when subagent is enabled. Default: 3." + }, + "deerflow_recursion_limit": { + "description": "Recursion limit", + "hint": "Maps to LangGraph recursion_limit." + }, "auto_save_history": { "description": "Conversation history managed by Coze", "hint": "When enabled, Coze manages conversation history. AstrBot's locally saved context will not take effect (read-only), and operations on AstrBot context will not apply. If disabled, AstrBot manages the context." diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index e5eea63fd0..238ad05662 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -3,7 +3,7 @@ "name": "AI 配置", "agent_runner": { "description": "Agent 执行方式", - "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", + "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify、Coze、DeerFlow 等第三方 Agent 执行器,不需要修改此节。", "provider_settings": { "enable": { "description": "启用", @@ -15,7 +15,8 @@ "内置 Agent", "Dify", "Coze", - "阿里云百炼应用" + "阿里云百炼应用", + "DeerFlow" ] }, "coze_agent_runner_provider_id": { @@ -26,6 +27,9 @@ }, "dashscope_agent_runner_provider_id": { "description": "阿里云百炼应用 Agent 执行器提供商 ID" + }, + "deerflow_agent_runner_provider_id": { + "description": "DeerFlow Agent 执行器提供商 ID" } } }, @@ -1366,6 +1370,45 @@ "description": "API Base URL", "hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn" }, + "deerflow_api_base": { + "description": "API Base URL", + "hint": "DeerFlow API 网关地址,默认为 http://127.0.0.1:2026" + }, + "deerflow_api_key": { + "description": "DeerFlow API Key", + "hint": "可选。若 DeerFlow 网关配置了 Bearer 鉴权,则在此填写。" + }, + "deerflow_auth_header": { + "description": "Authorization Header", + "hint": "可选。自定义 Authorization 请求头,优先级高于 DeerFlow API Key。" + }, + "deerflow_assistant_id": { + "description": "Assistant ID", + "hint": "LangGraph assistant_id,默认为 lead_agent。" + }, + "deerflow_model_name": { + "description": "模型名称覆盖", + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。" + }, + "deerflow_thinking_enabled": { + "description": "启用思考模式" + }, + "deerflow_plan_mode": { + "description": "启用计划模式", + "hint": "对应 DeerFlow 的 is_plan_mode。" + }, + "deerflow_subagent_enabled": { + "description": "启用子智能体", + "hint": "对应 DeerFlow 的 subagent_enabled。" + }, + "deerflow_max_concurrent_subagents": { + "description": "子智能体最大并发数", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" + }, + "deerflow_recursion_limit": { + "description": "递归深度上限", + "hint": "对应 LangGraph recursion_limit。" + }, "auto_save_history": { "description": "由 Coze 管理对话记录", "hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。" diff --git a/dashboard/src/utils/providerUtils.js b/dashboard/src/utils/providerUtils.js index b02af6942c..b341e3a8a1 100644 --- a/dashboard/src/utils/providerUtils.js +++ b/dashboard/src/utils/providerUtils.js @@ -25,6 +25,7 @@ export function getProviderIcon(type) { 'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg', "coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg", 'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg', + 'deerflow': new URL('../assets/images/platform_logos/deerflow.png', import.meta.url).href, 'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg', 'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg', 'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg', From b90cbe25c9658a623582d746e6ad3d14f44675c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:18:28 +0900 Subject: [PATCH 02/24] refactor: split DeerFlow stream flow and close stale client on reset --- .../runners/deerflow/deerflow_agent_runner.py | 274 +++++++++++------- 1 file changed, 169 insertions(+), 105 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 7b7399c1cc..b3ee321f10 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -2,6 +2,7 @@ import sys import typing as T from collections.abc import Iterable +from dataclasses import dataclass, field import astrbot.core.message.components as Comp from astrbot import logger @@ -27,6 +28,17 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): """DeerFlow Agent Runner via LangGraph HTTP API.""" + @dataclass + class _StreamState: + streamed_text: str = "" + fallback_stream_text: str = "" + clarification_text: str = "" + task_failures: list[str] = field(default_factory=list) + seen_message_ids: set[str] = field(default_factory=set) + baseline_initialized: bool = False + run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) + timed_out: bool = False + def _format_exception(self, err: Exception) -> str: err_type = type(err).__name__ detail = str(err).strip() @@ -101,6 +113,15 @@ async def reset( if isinstance(self.recursion_limit, str): self.recursion_limit = int(self.recursion_limit) + old_client = getattr(self, "api_client", None) + if isinstance(old_client, DeerFlowAPIClient): + try: + await old_client.close() + except Exception as e: + logger.warning( + f"Failed to close previous DeerFlow API client cleanly: {e}" + ) + self.api_client = DeerFlowAPIClient( api_base=self.api_base, api_key=self.api_key, @@ -371,14 +392,12 @@ async def _ensure_thread_id(self, session_id: str) -> str: ) return thread_id - async def _execute_deerflow_request(self): - prompt = self.req.prompt or "" - session_id = self.req.session_id or "unknown" - image_urls = self.req.image_urls or [] - system_prompt = self.req.system_prompt - - thread_id = await self._ensure_thread_id(session_id) - + def _build_messages( + self, + prompt: str, + image_urls: list[str], + system_prompt: str | None, + ) -> list[dict[str, T.Any]]: messages: list[dict[str, T.Any]] = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -388,7 +407,9 @@ async def _execute_deerflow_request(self): "content": self._build_user_content(prompt, image_urls), }, ) + return messages + def _build_runtime_context(self, thread_id: str) -> dict[str, T.Any]: runtime_context: dict[str, T.Any] = { "thread_id": thread_id, "thinking_enabled": self.thinking_enabled, @@ -399,26 +420,147 @@ async def _execute_deerflow_request(self): runtime_context["max_concurrent_subagents"] = self.max_concurrent_subagents if self.model_name: runtime_context["model_name"] = self.model_name + return runtime_context - payload: dict[str, T.Any] = { + def _build_payload( + self, + thread_id: str, + prompt: str, + image_urls: list[str], + system_prompt: str | None, + ) -> dict[str, T.Any]: + return { "assistant_id": self.assistant_id, - "input": {"messages": messages}, + "input": { + "messages": self._build_messages(prompt, image_urls, system_prompt), + }, "stream_mode": ["values", "messages-tuple", "custom"], # LangGraph 0.6+ prefers context instead of configurable. - "context": runtime_context, + "context": self._build_runtime_context(thread_id), "config": { "recursion_limit": self.recursion_limit, }, } - streamed_text = "" - fallback_stream_text = "" - clarification_text = "" - task_failures: list[str] = [] - seen_message_ids: set[str] = set() - baseline_initialized = False - run_values_messages: list[dict[str, T.Any]] = [] - timed_out = False + def _handle_values_event( + self, + data: T.Any, + state: _StreamState, + ) -> list[AgentResponse]: + responses: list[AgentResponse] = [] + values_messages = self._extract_messages_from_values_data(data) + if not values_messages: + return responses + + if not state.baseline_initialized: + state.baseline_initialized = True + for msg in values_messages: + msg_id = self._get_message_id(msg) + if msg_id: + state.seen_message_ids.add(msg_id) + return responses + + new_messages = self._extract_new_messages_from_values( + values_messages, + state.seen_message_ids, + ) + if new_messages: + state.run_values_messages.extend(new_messages) + latest_text = self._extract_latest_ai_text(state.run_values_messages) + latest_clarification = self._extract_latest_clarification_text( + state.run_values_messages, + ) + if latest_clarification: + state.clarification_text = latest_clarification + else: + latest_text = "" + + if self.streaming and latest_text: + if latest_text.startswith(state.streamed_text): + delta = latest_text[len(state.streamed_text) :] + if delta: + state.streamed_text = latest_text + responses.append( + AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(delta), + ), + ), + ) + elif latest_text != state.streamed_text: + state.streamed_text = latest_text + responses.append( + AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(latest_text), + ), + ), + ) + return responses + + def _handle_message_event( + self, + data: T.Any, + state: _StreamState, + ) -> AgentResponse | None: + delta = self._extract_ai_delta_from_event_data(data) + if delta: + state.fallback_stream_text += delta + + response: AgentResponse | None = None + if self.streaming and delta and not state.streamed_text: + response = AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(delta)), + ) + + maybe_clarification = self._extract_clarification_from_event_data(data) + if maybe_clarification: + state.clarification_text = maybe_clarification + return response + + def _resolve_final_text(self, state: _StreamState) -> str: + # Clarification tool output should take precedence over partial AI/tool-call text. + if state.clarification_text: + final_text = state.clarification_text + else: + final_text = self._extract_latest_ai_text(state.run_values_messages) + if not final_text: + final_text = state.streamed_text or state.fallback_stream_text + if not final_text: + final_text = self._build_task_failure_summary(state.task_failures) + + if state.timed_out: + timeout_note = ( + f"DeerFlow stream timed out after {self.timeout}s. " + "Returning partial result." + ) + if final_text: + final_text = f"{final_text}\n\n{timeout_note}" + else: + raise asyncio.TimeoutError(timeout_note) + + if not final_text: + logger.warning("DeerFlow returned no text content in stream events.") + final_text = "DeerFlow returned an empty response." + return final_text + + async def _execute_deerflow_request(self): + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + system_prompt = self.req.system_prompt + + thread_id = await self._ensure_thread_id(session_id) + payload = self._build_payload( + thread_id=thread_id, + prompt=prompt, + image_urls=image_urls, + system_prompt=system_prompt, + ) + state = self._StreamState() try: async for event in self.api_client.stream_run( @@ -430,74 +572,18 @@ async def _execute_deerflow_request(self): data = event.get("data") if event_type == "values": - values_messages = self._extract_messages_from_values_data(data) - if values_messages: - if not baseline_initialized: - baseline_initialized = True - for msg in values_messages: - msg_id = self._get_message_id(msg) - if msg_id: - seen_message_ids.add(msg_id) - continue - - new_messages = self._extract_new_messages_from_values( - values_messages, - seen_message_ids, - ) - if new_messages: - run_values_messages.extend(new_messages) - latest_text = self._extract_latest_ai_text( - run_values_messages - ) - latest_clarification = ( - self._extract_latest_clarification_text( - run_values_messages, - ) - ) - if latest_clarification: - clarification_text = latest_clarification - else: - latest_text = "" - - if self.streaming and latest_text: - if latest_text.startswith(streamed_text): - delta = latest_text[len(streamed_text) :] - if delta: - streamed_text = latest_text - yield AgentResponse( - type="streaming_delta", - data=AgentResponseData( - chain=MessageChain().message(delta), - ), - ) - elif latest_text != streamed_text: - streamed_text = latest_text - yield AgentResponse( - type="streaming_delta", - data=AgentResponseData( - chain=MessageChain().message(latest_text), - ), - ) + for response in self._handle_values_event(data, state): + yield response continue if event_type in {"messages-tuple", "messages", "message"}: - delta = self._extract_ai_delta_from_event_data(data) - if delta: - fallback_stream_text += delta - if self.streaming and delta and not streamed_text: - yield AgentResponse( - type="streaming_delta", - data=AgentResponseData(chain=MessageChain().message(delta)), - ) - maybe_clarification = self._extract_clarification_from_event_data( - data - ) - if maybe_clarification: - clarification_text = maybe_clarification + response = self._handle_message_event(data, state) + if response: + yield response continue if event_type == "custom": - task_failures.extend( + state.task_failures.extend( self._extract_task_failures_from_custom_event(data), ) continue @@ -508,31 +594,9 @@ async def _execute_deerflow_request(self): if event_type == "end": break except (asyncio.TimeoutError, TimeoutError): - timed_out = True + state.timed_out = True - # Clarification tool output should take precedence over partial AI/tool-call text. - if clarification_text: - final_text = clarification_text - else: - final_text = self._extract_latest_ai_text(run_values_messages) - if not final_text: - final_text = streamed_text or fallback_stream_text - if not final_text: - final_text = self._build_task_failure_summary(task_failures) - - if timed_out: - timeout_note = ( - f"DeerFlow stream timed out after {self.timeout}s. " - "Returning partial result." - ) - if final_text: - final_text = f"{final_text}\n\n{timeout_note}" - else: - raise asyncio.TimeoutError(timeout_note) - - if not final_text: - logger.warning("DeerFlow returned no text content in stream events.") - final_text = "DeerFlow returned an empty response." + final_text = self._resolve_final_text(state) chain = MessageChain(chain=[Comp.Plain(final_text)]) self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) From c7da6c1d58732253ddccb3edf173efd8c93edbe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:21:54 +0900 Subject: [PATCH 03/24] fix: enforce max_step and correct timeout type check --- .../runners/deerflow/deerflow_agent_runner.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index b3ee321f10..a2c7845a17 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -46,7 +46,7 @@ def _format_exception(self, err: Exception) -> str: if isinstance(err, (asyncio.TimeoutError, TimeoutError)): timeout_text = ( f"{self.timeout}s" - if isinstance(getattr(self, "timeout", None), int | float) + if isinstance(getattr(self, "timeout", None), (int, float)) else "configured timeout" ) return ( @@ -169,10 +169,20 @@ async def step(self): async def step_until_done( self, max_step: int = 30 ) -> T.AsyncGenerator[AgentResponse, None]: - while not self.done(): + if max_step <= 0: + raise ValueError("max_step must be greater than 0") + + step_count = 0 + while not self.done() and step_count < max_step: + step_count += 1 async for resp in self.step(): yield resp + if not self.done(): + raise RuntimeError( + f"DeerFlow agent reached max_step ({max_step}) without completion." + ) + def _extract_text(self, content: T.Any) -> str: if isinstance(content, str): return content From a018684f25f34a29b284cd562f77369f3f8042a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:26:04 +0900 Subject: [PATCH 04/24] fix: harden DeerFlow config parsing and session lifecycle --- .../runners/deerflow/deerflow_agent_runner.py | 91 +++++++++++++++---- .../runners/deerflow/deerflow_api_client.py | 4 + 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index a2c7845a17..6592e23c79 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -3,6 +3,7 @@ import typing as T from collections.abc import Iterable from dataclasses import dataclass, field +from uuid import uuid4 import astrbot.core.message.components as Comp from astrbot import logger @@ -61,6 +62,48 @@ def _format_exception(self, err: Exception) -> str: return f"{err_type}: no detailed error message provided." + def _coerce_int_config( + self, + field_name: str, + value: T.Any, + default: int, + min_value: int | None = None, + ) -> int: + if isinstance(value, bool): + logger.warning( + f"DeerFlow config '{field_name}' should be numeric, got boolean. " + f"Fallback to {default}." + ) + parsed = default + elif isinstance(value, int): + parsed = value + elif isinstance(value, str): + try: + parsed = int(value.strip()) + except ValueError: + logger.warning( + f"DeerFlow config '{field_name}' value '{value}' is not numeric. " + f"Fallback to {default}." + ) + parsed = default + else: + try: + parsed = int(value) + except (TypeError, ValueError): + logger.warning( + f"DeerFlow config '{field_name}' has unsupported type " + f"{type(value).__name__}. Fallback to {default}." + ) + parsed = default + + if min_value is not None and parsed < min_value: + logger.warning( + f"DeerFlow config '{field_name}'={parsed} is below minimum {min_value}. " + f"Fallback to {min_value}." + ) + parsed = min_value + return parsed + @override async def reset( self, @@ -97,23 +140,40 @@ async def reset( self.subagent_enabled = bool( provider_config.get("deerflow_subagent_enabled", False), ) - self.max_concurrent_subagents = provider_config.get( + self.max_concurrent_subagents = self._coerce_int_config( "deerflow_max_concurrent_subagents", - 3, + provider_config.get( + "deerflow_max_concurrent_subagents", + 3, + ), + default=3, + min_value=1, ) - if isinstance(self.max_concurrent_subagents, str): - self.max_concurrent_subagents = int(self.max_concurrent_subagents) - if self.max_concurrent_subagents < 1: - self.max_concurrent_subagents = 1 - - self.timeout = provider_config.get("timeout", 300) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.recursion_limit = provider_config.get("deerflow_recursion_limit", 1000) - if isinstance(self.recursion_limit, str): - self.recursion_limit = int(self.recursion_limit) + self.timeout = self._coerce_int_config( + "timeout", + provider_config.get("timeout", 300), + default=300, + min_value=1, + ) + self.recursion_limit = self._coerce_int_config( + "deerflow_recursion_limit", + provider_config.get("deerflow_recursion_limit", 1000), + default=1000, + min_value=1, + ) + + new_client_signature = (self.api_base, self.api_key, self.auth_header) old_client = getattr(self, "api_client", None) + old_signature = getattr(self, "_api_client_signature", None) + if ( + isinstance(old_client, DeerFlowAPIClient) + and old_signature == new_client_signature + and not old_client.is_closed + ): + self.api_client = old_client + return + if isinstance(old_client, DeerFlowAPIClient): try: await old_client.close() @@ -127,6 +187,7 @@ async def reset( api_key=self.api_key, auth_header=self.auth_header, ) + self._api_client_signature = new_client_signature @override async def step(self): @@ -162,8 +223,6 @@ async def step(self): chain=err_chain, ), ) - finally: - await self.api_client.close() @override async def step_until_done( @@ -559,7 +618,7 @@ def _resolve_final_text(self, state: _StreamState) -> str: async def _execute_deerflow_request(self): prompt = self.req.prompt or "" - session_id = self.req.session_id or "unknown" + session_id = self.req.session_id or f"deerflow-ephemeral-{uuid4()}" image_urls = self.req.image_urls or [] system_prompt = self.req.system_prompt diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 6c4da118ae..f72d7faa3c 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -151,3 +151,7 @@ async def stream_run( async def close(self) -> None: await self.session.close() + + @property + def is_closed(self) -> bool: + return self.session.closed From 1a7bc0ac5cfeac1f77c31526542cf0e04c252dc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:30:30 +0900 Subject: [PATCH 05/24] fix: preserve third-party runner error semantics and harden image parsing --- .../runners/deerflow/deerflow_agent_runner.py | 33 +++++++- astrbot/core/message/message_event_result.py | 2 + .../method/agent_sub_stages/third_party.py | 78 ++++++++++++++----- 3 files changed, 91 insertions(+), 22 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 6592e23c79..339c8b4814 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -419,6 +419,19 @@ def _build_task_failure_summary(self, failures: list[str]) -> str: joined = "\n".join([f"- {item}" for item in deduped[:5]]) return f"DeerFlow subtasks failed:\n{joined}" + def _is_likely_base64_image(self, value: str) -> bool: + if " " in value: + return False + + compact = value.replace("\n", "").replace("\r", "") + if not compact or len(compact) % 4 != 0: + return False + + base64_chars = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" + ) + return all(ch in base64_chars for ch in compact) + def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: if not image_urls: return prompt @@ -431,9 +444,23 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: url = image_url if not isinstance(url, str): continue - if not url.startswith(("http://", "https://", "data:")): - url = f"data:image/png;base64,{url}" - content.append({"type": "image_url", "image_url": {"url": url}}) + url = url.strip() + if not url: + continue + if url.startswith(("http://", "https://", "data:")): + content.append({"type": "image_url", "image_url": {"url": url}}) + continue + if not self._is_likely_base64_image(url): + logger.warning( + "Skip unsupported DeerFlow image input that is neither URL/data URI nor valid base64." + ) + continue + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{url}"}, + }, + ) return content async def _ensure_thread_id(self, session_id: str) -> str: diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index eba6a4fd66..dc5ff7e65f 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -182,6 +182,8 @@ class ResultContentType(enum.Enum): LLM_RESULT = enum.auto() """调用 LLM 产生的结果""" + AGENT_RUNNER_ERROR = enum.auto() + """第三方 Agent Runner 返回的错误结果""" GENERAL_RESULT = enum.auto() """普通的消息结果""" STREAMING_RESULT = enum.auto() diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index bc63721a30..41d2a946a4 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import AsyncGenerator +from dataclasses import dataclass from typing import TYPE_CHECKING from astrbot.core import astrbot_config, logger @@ -49,7 +50,7 @@ async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, custom_error_message: str | None = None, -) -> AsyncGenerator[MessageChain | None, None]: +) -> AsyncGenerator["_ThirdPartyRunnerOutput", None]: """ 运行第三方 agent runner 并转换响应格式 类似于 run_agent 函数,但专门处理第三方 agent runner @@ -59,13 +60,21 @@ async def run_third_party_agent( if resp.type == "streaming_delta": if stream_to_general: continue - yield resp.data["chain"] + yield _ThirdPartyRunnerOutput( + chain=resp.data["chain"], + is_error=False, + ) elif resp.type == "llm_result": if stream_to_general: - yield resp.data["chain"] + yield _ThirdPartyRunnerOutput( + chain=resp.data["chain"], + is_error=False, + ) elif resp.type == "err": - # Ensure caller can surface explicit runner errors. - yield resp.data["chain"] + yield _ThirdPartyRunnerOutput( + chain=resp.data["chain"], + is_error=True, + ) except Exception as e: logger.error(f"Third party agent runner error: {e}") err_msg = custom_error_message @@ -75,7 +84,24 @@ async def run_third_party_agent( f"Error Type: {type(e).__name__} (3rd party)\n" f"Error Message: {str(e)}" ) - yield MessageChain().message(err_msg) + yield _ThirdPartyRunnerOutput( + chain=MessageChain().message(err_msg), + is_error=True, + ) + + +@dataclass +class _ThirdPartyRunnerOutput: + chain: MessageChain | None + is_error: bool = False + + +async def _iter_runner_output_chain( + output_stream: AsyncGenerator[_ThirdPartyRunnerOutput, None], +) -> AsyncGenerator[MessageChain, None]: + async for output in output_stream: + if output.chain: + yield output.chain class ThirdPartyAgentSubStage(Stage): @@ -197,17 +223,21 @@ async def process( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) .set_async_stream( - run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ), + _iter_runner_output_chain( + run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ), + ) ), ) yield if runner.done(): final_resp = runner.get_final_llm_resp() if final_resp and final_resp.result_chain: + is_runner_error = final_resp.role == "err" + event.set_extra("_third_party_runner_error", is_runner_error) event.set_result( MessageEventResult( chain=final_resp.result_chain.chain or [], @@ -216,27 +246,32 @@ async def process( ) else: # 非流式响应或转换为普通响应 - fallback_chain: MessageChain | None = None - async for maybe_chain in run_third_party_agent( + fallback_output: _ThirdPartyRunnerOutput | None = None + async for output in run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): - if maybe_chain: - fallback_chain = maybe_chain + if output.chain: + fallback_output = output yield final_resp = runner.get_final_llm_resp() if not final_resp or not final_resp.result_chain: - if fallback_chain: + if fallback_output and fallback_output.chain: logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if fallback_output.is_error + else ResultContentType.LLM_RESULT + ) event.set_result( MessageEventResult( - chain=fallback_chain.chain or [], - result_content_type=ResultContentType.LLM_RESULT, + chain=fallback_output.chain.chain or [], + result_content_type=content_type, ), ) yield @@ -244,10 +279,15 @@ async def process( logger.warning("Agent Runner 未返回最终结果。") return + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if final_resp.role == "err" + else ResultContentType.LLM_RESULT + ) event.set_result( MessageEventResult( chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.LLM_RESULT, + result_content_type=content_type, ), ) yield From 7e91cd2265d95bb8eadc54ae08f2b275fbcb367a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:34:02 +0900 Subject: [PATCH 06/24] perf: bound DeerFlow values history and seen-id cache --- .../runners/deerflow/deerflow_agent_runner.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 339c8b4814..d6903b8863 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,6 +1,7 @@ import asyncio import sys import typing as T +from collections import deque from collections.abc import Iterable from dataclasses import dataclass, field from uuid import uuid4 @@ -29,6 +30,8 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): """DeerFlow Agent Runner via LangGraph HTTP API.""" + _MAX_VALUES_HISTORY = 200 + @dataclass class _StreamState: streamed_text: str = "" @@ -36,6 +39,7 @@ class _StreamState: clarification_text: str = "" task_failures: list[str] = field(default_factory=list) seen_message_ids: set[str] = field(default_factory=set) + seen_message_order: deque[str] = field(default_factory=deque) baseline_initialized: bool = False run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) timed_out: bool = False @@ -331,19 +335,29 @@ def _get_message_id(self, message: T.Any) -> str: def _extract_new_messages_from_values( self, values_messages: list[T.Any], - seen_message_ids: set[str], + state: _StreamState, ) -> list[dict[str, T.Any]]: new_messages: list[dict[str, T.Any]] = [] for msg in values_messages: if not isinstance(msg, dict): continue msg_id = self._get_message_id(msg) - if not msg_id or msg_id in seen_message_ids: + if not msg_id or msg_id in state.seen_message_ids: continue - seen_message_ids.add(msg_id) + self._remember_seen_message_id(state, msg_id) new_messages.append(msg) return new_messages + def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None: + if not msg_id or msg_id in state.seen_message_ids: + return + + state.seen_message_ids.add(msg_id) + state.seen_message_order.append(msg_id) + while len(state.seen_message_order) > self._MAX_VALUES_HISTORY: + dropped = state.seen_message_order.popleft() + state.seen_message_ids.discard(dropped) + def _extract_event_message_obj(self, data: T.Any) -> dict[str, T.Any] | None: msg_obj = data if isinstance(data, (list, tuple)) and data: @@ -552,16 +566,19 @@ def _handle_values_event( state.baseline_initialized = True for msg in values_messages: msg_id = self._get_message_id(msg) - if msg_id: - state.seen_message_ids.add(msg_id) + self._remember_seen_message_id(state, msg_id) return responses new_messages = self._extract_new_messages_from_values( values_messages, - state.seen_message_ids, + state, ) if new_messages: state.run_values_messages.extend(new_messages) + if len(state.run_values_messages) > self._MAX_VALUES_HISTORY: + state.run_values_messages = state.run_values_messages[ + -self._MAX_VALUES_HISTORY : + ] latest_text = self._extract_latest_ai_text(state.run_values_messages) latest_clarification = self._extract_latest_clarification_text( state.run_values_messages, From c82b00966dcdef77c618555c211d5fc1cf5869a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:42:05 +0900 Subject: [PATCH 07/24] refactor: improve deerflow stream semantics and client lifecycle --- .../runners/deerflow/deerflow_agent_runner.py | 213 ++++-------------- .../runners/deerflow/deerflow_api_client.py | 21 +- .../runners/deerflow/deerflow_stream_utils.py | 176 +++++++++++++++ .../method/agent_sub_stages/third_party.py | 180 ++++++++------- 4 files changed, 341 insertions(+), 249 deletions(-) create mode 100644 astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index d6903b8863..b49f214442 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -2,7 +2,6 @@ import sys import typing as T from collections import deque -from collections.abc import Iterable from dataclasses import dataclass, field from uuid import uuid4 @@ -20,6 +19,16 @@ from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner from .deerflow_api_client import DeerFlowAPIClient +from .deerflow_stream_utils import ( + build_task_failure_summary, + extract_ai_delta_from_event_data, + extract_clarification_from_event_data, + extract_latest_ai_text, + extract_latest_clarification_text, + extract_messages_from_values_data, + extract_task_failures_from_custom_event, + get_message_id, +) if sys.version_info >= (3, 12): from typing import override @@ -44,6 +53,21 @@ class _StreamState: run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) timed_out: bool = False + def __del__(self) -> None: + """Best-effort cleanup when runner is garbage collected.""" + api_client = getattr(self, "api_client", None) + if not isinstance(api_client, DeerFlowAPIClient) or api_client.is_closed: + return + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + + if loop.is_closed(): + return + loop.create_task(api_client.close()) + def _format_exception(self, err: Exception) -> str: err_type = type(err).__name__ detail = str(err).strip() @@ -108,6 +132,12 @@ def _coerce_int_config( parsed = min_value return parsed + async def close(self) -> None: + """Explicit cleanup hook for long-lived workers.""" + api_client = getattr(self, "api_client", None) + if isinstance(api_client, DeerFlowAPIClient) and not api_client.is_closed: + await api_client.close() + @override async def reset( self, @@ -246,92 +276,6 @@ async def step_until_done( f"DeerFlow agent reached max_step ({max_step}) without completion." ) - def _extract_text(self, content: T.Any) -> str: - if isinstance(content, str): - return content - if isinstance(content, dict): - if isinstance(content.get("text"), str): - return content["text"] - if "content" in content: - return self._extract_text(content.get("content")) - if "kwargs" in content and isinstance(content["kwargs"], dict): - return self._extract_text(content["kwargs"].get("content")) - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict): - item_type = item.get("type") - if item_type == "text" and isinstance(item.get("text"), str): - parts.append(item["text"]) - elif "content" in item: - parts.append(str(item["content"])) - return "\n".join([p for p in parts if p]).strip() - return str(content) if content is not None else "" - - def _extract_messages_from_values_data(self, data: T.Any) -> list[T.Any]: - """Extract messages list from possible values event payload shapes.""" - candidates: list[T.Any] = [] - if isinstance(data, dict): - candidates.append(data) - if isinstance(data.get("values"), dict): - candidates.append(data["values"]) - elif isinstance(data, list): - candidates.extend([x for x in data if isinstance(x, dict)]) - - for item in candidates: - messages = item.get("messages") - if isinstance(messages, list): - return messages - return [] - - def _is_ai_message(self, message: dict[str, T.Any]) -> bool: - role = str(message.get("role", "")).lower() - if role in {"assistant", "ai"}: - return True - - msg_type = str(message.get("type", "")).lower() - if msg_type in {"ai", "assistant", "aimessage", "aimessagechunk"}: - return True - if "ai" in msg_type and all( - token not in msg_type for token in ("human", "tool", "system") - ): - return True - return False - - def _extract_latest_ai_text(self, messages: Iterable[T.Any]) -> str: - # Scan backwards to get the latest assistant/ai message text. - for msg in reversed(list(messages)): - if not isinstance(msg, dict): - continue - if self._is_ai_message(msg): - text = self._extract_text(msg.get("content")) - if text: - return text - return "" - - def _is_clarification_tool_message(self, message: dict[str, T.Any]) -> bool: - msg_type = str(message.get("type", "")).lower() - tool_name = str(message.get("name", "")).lower() - return msg_type == "tool" and tool_name == "ask_clarification" - - def _extract_latest_clarification_text(self, messages: Iterable[T.Any]) -> str: - for msg in reversed(list(messages)): - if not isinstance(msg, dict): - continue - if self._is_clarification_tool_message(msg): - text = self._extract_text(msg.get("content")) - if text: - return text - return "" - - def _get_message_id(self, message: T.Any) -> str: - if not isinstance(message, dict): - return "" - msg_id = message.get("id") - return msg_id if isinstance(msg_id, str) else "" - def _extract_new_messages_from_values( self, values_messages: list[T.Any], @@ -341,7 +285,7 @@ def _extract_new_messages_from_values( for msg in values_messages: if not isinstance(msg, dict): continue - msg_id = self._get_message_id(msg) + msg_id = get_message_id(msg) if not msg_id or msg_id in state.seen_message_ids: continue self._remember_seen_message_id(state, msg_id) @@ -358,81 +302,6 @@ def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None: dropped = state.seen_message_order.popleft() state.seen_message_ids.discard(dropped) - def _extract_event_message_obj(self, data: T.Any) -> dict[str, T.Any] | None: - msg_obj = data - if isinstance(data, (list, tuple)) and data: - msg_obj = data[0] - if isinstance(msg_obj, dict) and isinstance(msg_obj.get("data"), dict): - # Some servers wrap message body in {"data": {...}} - msg_obj = msg_obj["data"] - return msg_obj if isinstance(msg_obj, dict) else None - - def _extract_ai_delta_from_event_data(self, data: T.Any) -> str: - # LangGraph messages-tuple events usually carry either: - # - {"type": "ai", "content": "..."} - # - [message_obj, metadata] - msg_obj = self._extract_event_message_obj(data) - if not msg_obj: - return "" - if self._is_ai_message(msg_obj): - return self._extract_text(msg_obj.get("content")) - return "" - - def _extract_clarification_from_event_data(self, data: T.Any) -> str: - msg_obj = self._extract_event_message_obj(data) - if not msg_obj: - return "" - if self._is_clarification_tool_message(msg_obj): - return self._extract_text(msg_obj.get("content")) - return "" - - def _iter_custom_event_items(self, data: T.Any) -> list[dict[str, T.Any]]: - items: list[dict[str, T.Any]] = [] - if isinstance(data, dict): - return [data] - if isinstance(data, list): - for item in data: - if isinstance(item, dict): - items.append(item) - elif isinstance(item, (list, tuple)): - for nested in item: - if isinstance(nested, dict): - items.append(nested) - return items - - def _extract_task_failures_from_custom_event(self, data: T.Any) -> list[str]: - failures: list[str] = [] - for item in self._iter_custom_event_items(data): - event_type = str(item.get("type", "")).lower() - if event_type not in {"task_failed", "task_timed_out"}: - continue - - task_id = str(item.get("task_id", "")).strip() - error_text = self._extract_text(item.get("error")).strip() - if task_id and error_text: - failures.append(f"{task_id}: {error_text}") - elif error_text: - failures.append(error_text) - elif task_id: - failures.append(f"{task_id}: unknown error") - else: - failures.append("unknown task failure") - return failures - - def _build_task_failure_summary(self, failures: list[str]) -> str: - if not failures: - return "" - deduped: list[str] = [] - seen: set[str] = set() - for failure in failures: - if failure not in seen: - seen.add(failure) - deduped.append(failure) - if len(deduped) == 1: - return f"DeerFlow subtask failed: {deduped[0]}" - joined = "\n".join([f"- {item}" for item in deduped[:5]]) - return f"DeerFlow subtasks failed:\n{joined}" - def _is_likely_base64_image(self, value: str) -> bool: if " " in value: return False @@ -558,14 +427,14 @@ def _handle_values_event( state: _StreamState, ) -> list[AgentResponse]: responses: list[AgentResponse] = [] - values_messages = self._extract_messages_from_values_data(data) + values_messages = extract_messages_from_values_data(data) if not values_messages: return responses if not state.baseline_initialized: state.baseline_initialized = True for msg in values_messages: - msg_id = self._get_message_id(msg) + msg_id = get_message_id(msg) self._remember_seen_message_id(state, msg_id) return responses @@ -579,8 +448,8 @@ def _handle_values_event( state.run_values_messages = state.run_values_messages[ -self._MAX_VALUES_HISTORY : ] - latest_text = self._extract_latest_ai_text(state.run_values_messages) - latest_clarification = self._extract_latest_clarification_text( + latest_text = extract_latest_ai_text(state.run_values_messages) + latest_clarification = extract_latest_clarification_text( state.run_values_messages, ) if latest_clarification: @@ -618,7 +487,7 @@ def _handle_message_event( data: T.Any, state: _StreamState, ) -> AgentResponse | None: - delta = self._extract_ai_delta_from_event_data(data) + delta = extract_ai_delta_from_event_data(data) if delta: state.fallback_stream_text += delta @@ -629,7 +498,7 @@ def _handle_message_event( data=AgentResponseData(chain=MessageChain().message(delta)), ) - maybe_clarification = self._extract_clarification_from_event_data(data) + maybe_clarification = extract_clarification_from_event_data(data) if maybe_clarification: state.clarification_text = maybe_clarification return response @@ -639,11 +508,11 @@ def _resolve_final_text(self, state: _StreamState) -> str: if state.clarification_text: final_text = state.clarification_text else: - final_text = self._extract_latest_ai_text(state.run_values_messages) + final_text = extract_latest_ai_text(state.run_values_messages) if not final_text: final_text = state.streamed_text or state.fallback_stream_text if not final_text: - final_text = self._build_task_failure_summary(state.task_failures) + final_text = build_task_failure_summary(state.task_failures) if state.timed_out: timeout_note = ( @@ -697,7 +566,7 @@ async def _execute_deerflow_request(self): if event_type == "custom": state.task_failures.extend( - self._extract_task_failures_from_custom_event(data), + extract_task_failures_from_custom_event(data), ) continue diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index f72d7faa3c..aa2f61ab91 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -1,6 +1,8 @@ +import asyncio import codecs import json from collections.abc import AsyncGenerator +from contextlib import suppress from typing import Any from aiohttp import ClientResponse, ClientSession, ClientTimeout @@ -150,7 +152,24 @@ async def stream_run( yield event async def close(self) -> None: - await self.session.close() + if not self.session.closed: + await self.session.close() + + def __del__(self) -> None: + session = getattr(self, "session", None) + if session is None or session.closed: + return + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + + if loop.is_closed(): + return + + with suppress(RuntimeError): + loop.create_task(self.close()) @property def is_closed(self) -> bool: diff --git a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py new file mode 100644 index 0000000000..d8bd324d08 --- /dev/null +++ b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py @@ -0,0 +1,176 @@ +import typing as T +from collections.abc import Iterable + + +def extract_text(content: T.Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict): + if isinstance(content.get("text"), str): + return content["text"] + if "content" in content: + return extract_text(content.get("content")) + if "kwargs" in content and isinstance(content["kwargs"], dict): + return extract_text(content["kwargs"].get("content")) + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + item_type = item.get("type") + if item_type == "text" and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif "content" in item: + parts.append(str(item["content"])) + return "\n".join([p for p in parts if p]).strip() + return str(content) if content is not None else "" + + +def extract_messages_from_values_data(data: T.Any) -> list[T.Any]: + """Extract messages list from possible values event payload shapes.""" + candidates: list[T.Any] = [] + if isinstance(data, dict): + candidates.append(data) + if isinstance(data.get("values"), dict): + candidates.append(data["values"]) + elif isinstance(data, list): + candidates.extend([x for x in data if isinstance(x, dict)]) + + for item in candidates: + messages = item.get("messages") + if isinstance(messages, list): + return messages + return [] + + +def is_ai_message(message: dict[str, T.Any]) -> bool: + role = str(message.get("role", "")).lower() + if role in {"assistant", "ai"}: + return True + + msg_type = str(message.get("type", "")).lower() + if msg_type in {"ai", "assistant", "aimessage", "aimessagechunk"}: + return True + if "ai" in msg_type and all( + token not in msg_type for token in ("human", "tool", "system") + ): + return True + return False + + +def extract_latest_ai_text(messages: Iterable[T.Any]) -> str: + # Scan backwards to get the latest assistant/ai message text. + for msg in reversed(list(messages)): + if not isinstance(msg, dict): + continue + if is_ai_message(msg): + text = extract_text(msg.get("content")) + if text: + return text + return "" + + +def is_clarification_tool_message(message: dict[str, T.Any]) -> bool: + msg_type = str(message.get("type", "")).lower() + tool_name = str(message.get("name", "")).lower() + return msg_type == "tool" and tool_name == "ask_clarification" + + +def extract_latest_clarification_text(messages: Iterable[T.Any]) -> str: + for msg in reversed(list(messages)): + if not isinstance(msg, dict): + continue + if is_clarification_tool_message(msg): + text = extract_text(msg.get("content")) + if text: + return text + return "" + + +def get_message_id(message: T.Any) -> str: + if not isinstance(message, dict): + return "" + msg_id = message.get("id") + return msg_id if isinstance(msg_id, str) else "" + + +def extract_event_message_obj(data: T.Any) -> dict[str, T.Any] | None: + msg_obj = data + if isinstance(data, (list, tuple)) and data: + msg_obj = data[0] + if isinstance(msg_obj, dict) and isinstance(msg_obj.get("data"), dict): + # Some servers wrap message body in {"data": {...}} + msg_obj = msg_obj["data"] + return msg_obj if isinstance(msg_obj, dict) else None + + +def extract_ai_delta_from_event_data(data: T.Any) -> str: + # LangGraph messages-tuple events usually carry either: + # - {"type": "ai", "content": "..."} + # - [message_obj, metadata] + msg_obj = extract_event_message_obj(data) + if not msg_obj: + return "" + if is_ai_message(msg_obj): + return extract_text(msg_obj.get("content")) + return "" + + +def extract_clarification_from_event_data(data: T.Any) -> str: + msg_obj = extract_event_message_obj(data) + if not msg_obj: + return "" + if is_clarification_tool_message(msg_obj): + return extract_text(msg_obj.get("content")) + return "" + + +def _iter_custom_event_items(data: T.Any) -> list[dict[str, T.Any]]: + items: list[dict[str, T.Any]] = [] + if isinstance(data, dict): + return [data] + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + items.append(item) + elif isinstance(item, (list, tuple)): + for nested in item: + if isinstance(nested, dict): + items.append(nested) + return items + + +def extract_task_failures_from_custom_event(data: T.Any) -> list[str]: + failures: list[str] = [] + for item in _iter_custom_event_items(data): + event_type = str(item.get("type", "")).lower() + if event_type not in {"task_failed", "task_timed_out"}: + continue + + task_id = str(item.get("task_id", "")).strip() + error_text = extract_text(item.get("error")).strip() + if task_id and error_text: + failures.append(f"{task_id}: {error_text}") + elif error_text: + failures.append(error_text) + elif task_id: + failures.append(f"{task_id}: unknown error") + else: + failures.append("unknown task failure") + return failures + + +def build_task_failure_summary(failures: list[str]) -> str: + if not failures: + return "" + deduped: list[str] = [] + seen: set[str] = set() + for failure in failures: + if failure not in seen: + seen.add(failure) + deduped.append(failure) + if len(deduped) == 1: + return f"DeerFlow subtask failed: {deduped[0]}" + joined = "\n".join([f"- {item}" for item in deduped[:5]]) + return f"DeerFlow subtasks failed:\n{joined}" diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 41d2a946a4..613f827dbc 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,4 +1,5 @@ import asyncio +import inspect from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import TYPE_CHECKING @@ -98,10 +99,23 @@ class _ThirdPartyRunnerOutput: async def _iter_runner_output_chain( output_stream: AsyncGenerator[_ThirdPartyRunnerOutput, None], -) -> AsyncGenerator[MessageChain, None]: +) -> AsyncGenerator[_ThirdPartyRunnerOutput, None]: async for output in output_stream: if output.chain: - yield output.chain + yield output + + +async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: + close_callable = getattr(runner, "close", None) + if not callable(close_callable): + return + + try: + close_result = close_callable() + if inspect.isawaitable(close_result): + await close_result + except Exception as e: + logger.warning(f"Failed to close third-party runner cleanly: {e}") class ThirdPartyAgentSubStage(Stage): @@ -206,91 +220,105 @@ async def process( and not event.platform_meta.support_streaming_message ) - await runner.reset( - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, - streaming=streaming_response, - ) + try: + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + streaming=streaming_response, + ) + + if streaming_response and not stream_to_general: + # 流式响应 + stream_has_runner_error = False - if streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - _iter_runner_output_chain( + async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: + nonlocal stream_has_runner_error + async for runner_output in _iter_runner_output_chain( run_third_party_agent( runner, stream_to_general=False, custom_error_message=custom_error_message, ), - ) - ), - ) - yield - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - is_runner_error = final_resp.role == "err" - event.set_extra("_third_party_runner_error", is_runner_error) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - # 非流式响应或转换为普通响应 - fallback_output: _ThirdPartyRunnerOutput | None = None - async for output in run_third_party_agent( - runner, - stream_to_general=stream_to_general, - custom_error_message=custom_error_message, - ): - if output.chain: - fallback_output = output + ): + if runner_output.is_error: + stream_has_runner_error = True + event.set_extra("_third_party_runner_error", True) + if runner_output.chain: + yield runner_output.chain + + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(_stream_runner_chain()), + ) yield + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + is_runner_error = ( + stream_has_runner_error or final_resp.role == "err" + ) + event.set_extra("_third_party_runner_error", is_runner_error) + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + # 非流式响应或转换为普通响应 + fallback_output: _ThirdPartyRunnerOutput | None = None + async for output in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + if output.chain: + fallback_output = output + yield - final_resp = runner.get_final_llm_resp() + final_resp = runner.get_final_llm_resp() - if not final_resp or not final_resp.result_chain: - if fallback_output and fallback_output.chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if fallback_output.is_error - else ResultContentType.LLM_RESULT - ) - event.set_result( - MessageEventResult( - chain=fallback_output.chain.chain or [], - result_content_type=content_type, - ), - ) - yield + if not final_resp or not final_resp.result_chain: + if fallback_output and fallback_output.chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if fallback_output.is_error + else ResultContentType.LLM_RESULT + ) + event.set_result( + MessageEventResult( + chain=fallback_output.chain.chain or [], + result_content_type=content_type, + ), + ) + yield + return + logger.warning("Agent Runner 未返回最终结果。") return - logger.warning("Agent Runner 未返回最终结果。") - return - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if final_resp.role == "err" - else ResultContentType.LLM_RESULT - ) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=content_type, - ), - ) - yield + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if final_resp.role == "err" + else ResultContentType.LLM_RESULT + ) + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=content_type, + ), + ) + yield + finally: + await _close_runner_if_supported(runner) asyncio.create_task( Metric.upload( From 9c10db91178ef2f5a3412185c9a8795c792354cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:47:33 +0900 Subject: [PATCH 08/24] fix: harden third-party runner error semantics and fallback aggregation --- .../runners/deerflow/deerflow_agent_runner.py | 15 ------------- .../runners/deerflow/deerflow_api_client.py | 17 ++++----------- astrbot/core/message/message_event_result.py | 7 +++++++ .../method/agent_sub_stages/third_party.py | 21 ++++++++++++++----- astrbot/core/pipeline/respond/stage.py | 2 +- .../core/pipeline/result_decorate/stage.py | 2 +- 6 files changed, 29 insertions(+), 35 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index b49f214442..2d2e300cea 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -53,21 +53,6 @@ class _StreamState: run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) timed_out: bool = False - def __del__(self) -> None: - """Best-effort cleanup when runner is garbage collected.""" - api_client = getattr(self, "api_client", None) - if not isinstance(api_client, DeerFlowAPIClient) or api_client.is_closed: - return - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - - if loop.is_closed(): - return - loop.create_task(api_client.close()) - def _format_exception(self, err: Exception) -> str: err_type = type(err).__name__ detail = str(err).strip() diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index aa2f61ab91..a4fea4557d 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -1,8 +1,6 @@ -import asyncio import codecs import json from collections.abc import AsyncGenerator -from contextlib import suppress from typing import Any from aiohttp import ClientResponse, ClientSession, ClientTimeout @@ -159,17 +157,10 @@ def __del__(self) -> None: session = getattr(self, "session", None) if session is None or session.closed: return - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - - if loop.is_closed(): - return - - with suppress(RuntimeError): - loop.create_task(self.close()) + logger.warning( + "DeerFlowAPIClient garbage collected with unclosed session; " + "explicit close() should be called by runner lifecycle." + ) @property def is_closed(self) -> bool: diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index dc5ff7e65f..0965fe7f7f 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -248,6 +248,13 @@ def is_llm_result(self) -> bool: """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT + def is_model_result(self) -> bool: + """Whether result comes from model execution (including runner errors).""" + return self.result_content_type in ( + ResultContentType.LLM_RESULT, + ResultContentType.AGENT_RUNNER_ERROR, + ) + # 为了兼容旧版代码,保留 CommandResult 的别名 CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 613f827dbc..b15da7a938 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -272,31 +272,38 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: ) else: # 非流式响应或转换为普通响应 - fallback_output: _ThirdPartyRunnerOutput | None = None + fallback_chains: list[MessageChain] = [] + fallback_is_error = False async for output in run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): if output.chain: - fallback_output = output + fallback_chains.append(output.chain) + if output.is_error: + fallback_is_error = True yield final_resp = runner.get_final_llm_resp() if not final_resp or not final_resp.result_chain: - if fallback_output and fallback_output.chain: + if fallback_chains: logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) + merged_chain: list = [] + for chain in fallback_chains: + merged_chain.extend(chain.chain or []) content_type = ( ResultContentType.AGENT_RUNNER_ERROR - if fallback_output.is_error + if fallback_is_error else ResultContentType.LLM_RESULT ) + event.set_extra("_third_party_runner_error", fallback_is_error) event.set_result( MessageEventResult( - chain=fallback_output.chain.chain or [], + chain=merged_chain, result_content_type=content_type, ), ) @@ -310,6 +317,10 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: if final_resp.role == "err" else ResultContentType.LLM_RESULT ) + event.set_extra( + "_third_party_runner_error", + final_resp.role == "err", + ) event.set_result( MessageEventResult( chain=final_resp.result_chain.chain or [], diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 72e853ffcc..bd307f8b77 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -135,7 +135,7 @@ def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: if (result := event.get_result()) is None: return False - if self.only_llm_result and not result.is_llm_result(): + if self.only_llm_result and not result.is_model_result(): return False if event.get_platform_name() in [ diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 15d68fb22e..f2fe8161b5 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -209,7 +209,7 @@ async def process( "dingtalk", ]: if ( - self.only_llm_result and result.is_llm_result() + self.only_llm_result and result.is_model_result() ) or not self.only_llm_result: new_chain = [] for comp in result.chain: From fd4fe4b07af9adb5e54635e4b0d3504452e7bfe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:51:17 +0900 Subject: [PATCH 09/24] refactor: reduce deerflow image log noise and lazy-init api session --- .../runners/deerflow/deerflow_agent_runner.py | 10 +++++-- .../runners/deerflow/deerflow_api_client.py | 30 ++++++++++++++----- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 2d2e300cea..af1112006e 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -305,6 +305,7 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: return prompt content: list[dict[str, T.Any]] = [] + skipped_invalid_images = 0 if prompt: content.append({"type": "text", "text": prompt}) @@ -319,9 +320,7 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: content.append({"type": "image_url", "image_url": {"url": url}}) continue if not self._is_likely_base64_image(url): - logger.warning( - "Skip unsupported DeerFlow image input that is neither URL/data URI nor valid base64." - ) + skipped_invalid_images += 1 continue content.append( { @@ -329,6 +328,11 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: "image_url": {"url": f"data:image/png;base64,{url}"}, }, ) + if skipped_invalid_images: + logger.debug( + "Skipped %d DeerFlow image inputs that were neither URL/data URI nor valid base64.", + skipped_invalid_images, + ) return content async def _ensure_thread_id(self, session_id: str) -> str: diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index a4fea4557d..b7c16ebc1b 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -92,17 +92,26 @@ def __init__( auth_header: str = "", ) -> None: self.api_base = api_base.rstrip("/") - self.session = ClientSession(trust_env=True) + self._session: ClientSession | None = None + self._closed = False self.headers: dict[str, str] = {} if auth_header: self.headers["Authorization"] = auth_header elif api_key: self.headers["Authorization"] = f"Bearer {api_key}" + def _get_session(self) -> ClientSession: + if self._closed: + raise RuntimeError("DeerFlowAPIClient is already closed.") + if self._session is None or self._session.closed: + self._session = ClientSession(trust_env=True) + return self._session + async def create_thread(self, timeout: float = 20) -> dict[str, Any]: + session = self._get_session() url = f"{self.api_base}/api/langgraph/threads" payload = {"metadata": {}} - async with self.session.post( + async with session.post( url, json=payload, headers=self.headers, @@ -121,6 +130,7 @@ async def stream_run( payload: dict[str, Any], timeout: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: + session = self._get_session() url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" logger.debug(f"deerflow stream_run payload: {payload}") # For long-running SSE streams, avoid aiohttp total timeout. @@ -131,7 +141,7 @@ async def stream_run( sock_connect=min(timeout, 30), sock_read=timeout, ) - async with self.session.post( + async with session.post( url, json=payload, headers={ @@ -150,12 +160,16 @@ async def stream_run( yield event async def close(self) -> None: - if not self.session.closed: - await self.session.close() + self._closed = True + session = self._session + if session is not None and not session.closed: + await session.close() + self._session = None def __del__(self) -> None: - session = getattr(self, "session", None) - if session is None or session.closed: + session = getattr(self, "_session", None) + closed = bool(getattr(self, "_closed", False)) + if closed or session is None or session.closed: return logger.warning( "DeerFlowAPIClient garbage collected with unclosed session; " @@ -164,4 +178,4 @@ def __del__(self) -> None: @property def is_closed(self) -> bool: - return self.session.closed + return self._closed From 17a7ef376e2cd6221c9b50c93c667bcbbc485737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:55:01 +0900 Subject: [PATCH 10/24] perf: avoid unnecessary iterable copies in deerflow stream utils --- .../runners/deerflow/deerflow_stream_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py index d8bd324d08..8c75c66815 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py @@ -61,7 +61,13 @@ def is_ai_message(message: dict[str, T.Any]) -> bool: def extract_latest_ai_text(messages: Iterable[T.Any]) -> str: # Scan backwards to get the latest assistant/ai message text. - for msg in reversed(list(messages)): + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + # Fallback for generic iterables (e.g. generators). + iterable = reversed(list(messages)) + + for msg in iterable: if not isinstance(msg, dict): continue if is_ai_message(msg): @@ -78,7 +84,12 @@ def is_clarification_tool_message(message: dict[str, T.Any]) -> bool: def extract_latest_clarification_text(messages: Iterable[T.Any]) -> str: - for msg in reversed(list(messages)): + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + iterable = reversed(list(messages)) + + for msg in iterable: if not isinstance(msg, dict): continue if is_clarification_tool_message(msg): From 75ade39aa2c6223ca6958e36733c1df2c4e51d4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 02:58:29 +0900 Subject: [PATCH 11/24] refactor: centralize runner error key and clarify deerflow client lifecycle --- .../runners/deerflow/deerflow_api_client.py | 19 ++++++++++++++++++- .../method/agent_sub_stages/third_party.py | 15 +++++++++++---- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index b7c16ebc1b..0716ccb944 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -85,6 +85,12 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict[str, Any], No class DeerFlowAPIClient: + """HTTP client for DeerFlow LangGraph API. + + Lifecycle is explicitly managed by callers (runner/stage). `__del__` is only a + fallback diagnostic and must not be relied on for cleanup. + """ + def __init__( self, api_base: str = "http://127.0.0.1:2026", @@ -107,6 +113,17 @@ def _get_session(self) -> ClientSession: self._session = ClientSession(trust_env=True) return self._session + async def __aenter__(self) -> "DeerFlowAPIClient": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object | None, + ) -> None: + await self.close() + async def create_thread(self, timeout: float = 20) -> dict[str, Any]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads" @@ -173,7 +190,7 @@ def __del__(self) -> None: return logger.warning( "DeerFlowAPIClient garbage collected with unclosed session; " - "explicit close() should be called by runner lifecycle." + "explicit close() should be called by runner lifecycle (or `async with`)." ) @property diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index b15da7a938..7cfbaf190d 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -45,6 +45,7 @@ "dashscope": "dashscope_agent_runner_provider_id", "deerflow": "deerflow_agent_runner_provider_id", } +THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error" async def run_third_party_agent( @@ -247,7 +248,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: ): if runner_output.is_error: stream_has_runner_error = True - event.set_extra("_third_party_runner_error", True) + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) if runner_output.chain: yield runner_output.chain @@ -263,7 +264,10 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: is_runner_error = ( stream_has_runner_error or final_resp.role == "err" ) - event.set_extra("_third_party_runner_error", is_runner_error) + event.set_extra( + THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, + is_runner_error, + ) event.set_result( MessageEventResult( chain=final_resp.result_chain.chain or [], @@ -300,7 +304,10 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: if fallback_is_error else ResultContentType.LLM_RESULT ) - event.set_extra("_third_party_runner_error", fallback_is_error) + event.set_extra( + THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, + fallback_is_error, + ) event.set_result( MessageEventResult( chain=merged_chain, @@ -318,7 +325,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: else ResultContentType.LLM_RESULT ) event.set_extra( - "_third_party_runner_error", + THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, final_resp.role == "err", ) event.set_result( From 352ab50bcefa8492af3ebec253c4c86bec1a29a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:06:02 +0900 Subject: [PATCH 12/24] refactor: simplify third-party runner output flow --- .../method/agent_sub_stages/third_party.py | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 7cfbaf190d..60fd94990a 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -94,18 +94,10 @@ async def run_third_party_agent( @dataclass class _ThirdPartyRunnerOutput: - chain: MessageChain | None + chain: MessageChain is_error: bool = False -async def _iter_runner_output_chain( - output_stream: AsyncGenerator[_ThirdPartyRunnerOutput, None], -) -> AsyncGenerator[_ThirdPartyRunnerOutput, None]: - async for output in output_stream: - if output.chain: - yield output - - async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: close_callable = getattr(runner, "close", None) if not callable(close_callable): @@ -239,18 +231,15 @@ async def process( async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: nonlocal stream_has_runner_error - async for runner_output in _iter_runner_output_chain( - run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ), + async for runner_output in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, ): if runner_output.is_error: stream_has_runner_error = True event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) - if runner_output.chain: - yield runner_output.chain + yield runner_output.chain event.set_result( MessageEventResult() @@ -276,15 +265,14 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: ) else: # 非流式响应或转换为普通响应 - fallback_chains: list[MessageChain] = [] + merged_chain: list = [] fallback_is_error = False async for output in run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): - if output.chain: - fallback_chains.append(output.chain) + merged_chain.extend(output.chain.chain or []) if output.is_error: fallback_is_error = True yield @@ -292,13 +280,10 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: final_resp = runner.get_final_llm_resp() if not final_resp or not final_resp.result_chain: - if fallback_chains: + if merged_chain: logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) - merged_chain: list = [] - for chain in fallback_chains: - merged_chain.extend(chain.chain or []) content_type = ( ResultContentType.AGENT_RUNNER_ERROR if fallback_is_error From 3ae43a2ce6057cda607c85a956de47549f2d7d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:10:21 +0900 Subject: [PATCH 13/24] fix: defer streaming runner cleanup and unify error mapping --- .../method/agent_sub_stages/third_party.py | 106 +++++++++++------- 1 file changed, 64 insertions(+), 42 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 60fd94990a..58f546a79e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -48,6 +48,32 @@ THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error" +def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None: + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error) + + +def _runner_result_content_type(is_error: bool) -> ResultContentType: + return ( + ResultContentType.AGENT_RUNNER_ERROR + if is_error + else ResultContentType.LLM_RESULT + ) + + +def _set_non_stream_runner_result( + event: "AstrMessageEvent", + chain: list, + is_error: bool, +) -> None: + _set_runner_error_extra(event, is_error) + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=_runner_result_content_type(is_error), + ), + ) + + async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -213,6 +239,16 @@ async def process( and not event.platform_meta.support_streaming_message ) + runner_closed = False + defer_runner_close_to_stream = False + + async def _close_runner_once() -> None: + nonlocal runner_closed + if runner_closed: + return + runner_closed = True + await _close_runner_if_supported(runner) + try: await runner.reset( request=req, @@ -231,21 +267,27 @@ async def process( async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: nonlocal stream_has_runner_error - async for runner_output in run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ): - if runner_output.is_error: - stream_has_runner_error = True - event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) - yield runner_output.chain + try: + async for runner_output in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ): + if runner_output.is_error: + stream_has_runner_error = True + _set_runner_error_extra(event, True) + yield runner_output.chain + finally: + # Streaming runner cleanup must happen after consumer + # finishes iterating to avoid tearing down active streams. + await _close_runner_once() event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) .set_async_stream(_stream_runner_chain()), ) + defer_runner_close_to_stream = True yield if runner.done(): final_resp = runner.get_final_llm_resp() @@ -253,10 +295,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: is_runner_error = ( stream_has_runner_error or final_resp.role == "err" ) - event.set_extra( - THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, - is_runner_error, - ) + _set_runner_error_extra(event, is_runner_error) event.set_result( MessageEventResult( chain=final_resp.result_chain.chain or [], @@ -284,44 +323,27 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if fallback_is_error - else ResultContentType.LLM_RESULT - ) - event.set_extra( - THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, - fallback_is_error, - ) - event.set_result( - MessageEventResult( - chain=merged_chain, - result_content_type=content_type, - ), + _set_non_stream_runner_result( + event=event, + chain=merged_chain, + is_error=fallback_is_error, ) yield return logger.warning("Agent Runner 未返回最终结果。") return - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if final_resp.role == "err" - else ResultContentType.LLM_RESULT - ) - event.set_extra( - THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, - final_resp.role == "err", - ) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=content_type, - ), + # Preserve intermediate error signals even if final role is assistant. + is_runner_error = fallback_is_error or final_resp.role == "err" + _set_non_stream_runner_result( + event=event, + chain=final_resp.result_chain.chain or [], + is_error=is_runner_error, ) yield finally: - await _close_runner_if_supported(runner) + if not defer_runner_close_to_stream: + await _close_runner_once() asyncio.create_task( Metric.upload( From a6beab6efe1ba457aca46efe1cf6b7a1828f1716 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:15:01 +0900 Subject: [PATCH 14/24] fix: handle id-less values messages and redact stream payload logs --- .../runners/deerflow/deerflow_agent_runner.py | 41 ++++++++++++++++--- .../runners/deerflow/deerflow_api_client.py | 15 ++++++- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index af1112006e..119ab42713 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,4 +1,6 @@ import asyncio +import hashlib +import json import sys import typing as T from collections import deque @@ -49,6 +51,8 @@ class _StreamState: task_failures: list[str] = field(default_factory=list) seen_message_ids: set[str] = field(default_factory=set) seen_message_order: deque[str] = field(default_factory=deque) + # Fallback tracking for backends that omit message ids in values events. + no_id_message_fingerprints: dict[int, str] = field(default_factory=dict) baseline_initialized: bool = False run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) timed_out: bool = False @@ -267,16 +271,38 @@ def _extract_new_messages_from_values( state: _StreamState, ) -> list[dict[str, T.Any]]: new_messages: list[dict[str, T.Any]] = [] - for msg in values_messages: + no_id_indexes_seen: set[int] = set() + for idx, msg in enumerate(values_messages): if not isinstance(msg, dict): continue msg_id = get_message_id(msg) - if not msg_id or msg_id in state.seen_message_ids: + if msg_id: + if msg_id in state.seen_message_ids: + continue + self._remember_seen_message_id(state, msg_id) + new_messages.append(msg) continue - self._remember_seen_message_id(state, msg_id) + + no_id_indexes_seen.add(idx) + msg_fingerprint = self._fingerprint_message(msg) + if state.no_id_message_fingerprints.get(idx) == msg_fingerprint: + continue + state.no_id_message_fingerprints[idx] = msg_fingerprint new_messages.append(msg) + + # Keep no-id index state aligned with latest values payload shape. + for idx in list(state.no_id_message_fingerprints.keys()): + if idx not in no_id_indexes_seen: + state.no_id_message_fingerprints.pop(idx, None) return new_messages + def _fingerprint_message(self, message: dict[str, T.Any]) -> str: + try: + raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str) + except (TypeError, ValueError): + raw = repr(message) + return hashlib.sha1(raw.encode("utf-8", errors="ignore")).hexdigest() + def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None: if not msg_id or msg_id in state.seen_message_ids: return @@ -422,9 +448,14 @@ def _handle_values_event( if not state.baseline_initialized: state.baseline_initialized = True - for msg in values_messages: + for idx, msg in enumerate(values_messages): + if not isinstance(msg, dict): + continue msg_id = get_message_id(msg) - self._remember_seen_message_id(state, msg_id) + if msg_id: + self._remember_seen_message_id(state, msg_id) + continue + state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg) return responses new_messages = self._extract_new_messages_from_values( diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 0716ccb944..2922c0b291 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -149,7 +149,20 @@ async def stream_run( ) -> AsyncGenerator[dict[str, Any], None]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" - logger.debug(f"deerflow stream_run payload: {payload}") + input_payload = payload.get("input") + message_count = 0 + if isinstance(input_payload, dict) and isinstance( + input_payload.get("messages"), list + ): + message_count = len(input_payload["messages"]) + # Log only a minimal summary to avoid exposing sensitive user content. + logger.debug( + "deerflow stream_run payload summary: thread_id=%s, keys=%s, message_count=%d, stream_mode=%s", + thread_id, + list(payload.keys()), + message_count, + payload.get("stream_mode"), + ) # For long-running SSE streams, avoid aiohttp total timeout. # Use socket read timeout so active heartbeats/chunks can keep the stream alive. stream_timeout = ClientTimeout( From bd1648ff95f50b80edebe4ed652e968b632d047f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:20:47 +0900 Subject: [PATCH 15/24] fix: improve deerflow error signaling and third-party runner flow --- .../runners/deerflow/deerflow_agent_runner.py | 32 +-- .../runners/deerflow/deerflow_stream_utils.py | 2 +- .../method/agent_sub_stages/third_party.py | 208 +++++++++++------- 3 files changed, 146 insertions(+), 96 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 119ab42713..fc856caf83 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -149,7 +149,7 @@ async def reset( if not isinstance(self.api_base, str) or not self.api_base.startswith( ("http://", "https://"), ): - raise Exception( + raise ValueError( "DeerFlow API Base URL format is invalid. It must start with http:// or https://.", ) self.api_key = provider_config.get("deerflow_api_key", "") @@ -523,7 +523,8 @@ def _handle_message_event( state.clarification_text = maybe_clarification return response - def _resolve_final_text(self, state: _StreamState) -> str: + def _resolve_final_text(self, state: _StreamState) -> tuple[str, bool]: + failures_only = False # Clarification tool output should take precedence over partial AI/tool-call text. if state.clarification_text: final_text = state.clarification_text @@ -533,21 +534,12 @@ def _resolve_final_text(self, state: _StreamState) -> str: final_text = state.streamed_text or state.fallback_stream_text if not final_text: final_text = build_task_failure_summary(state.task_failures) - - if state.timed_out: - timeout_note = ( - f"DeerFlow stream timed out after {self.timeout}s. " - "Returning partial result." - ) - if final_text: - final_text = f"{final_text}\n\n{timeout_note}" - else: - raise asyncio.TimeoutError(timeout_note) + failures_only = bool(final_text) if not final_text: logger.warning("DeerFlow returned no text content in stream events.") final_text = "DeerFlow returned an empty response." - return final_text + return final_text, failures_only async def _execute_deerflow_request(self): prompt = self.req.prompt or "" @@ -598,10 +590,20 @@ async def _execute_deerflow_request(self): except (asyncio.TimeoutError, TimeoutError): state.timed_out = True - final_text = self._resolve_final_text(state) + final_text, failures_only = self._resolve_final_text(state) + if state.timed_out: + timeout_note = ( + f"DeerFlow stream timed out after {self.timeout}s. " + "Returning partial result." + ) + final_text = ( + f"{final_text}\n\n{timeout_note}" if final_text else timeout_note + ) + is_error = state.timed_out or failures_only + role = "err" if is_error else "assistant" chain = MessageChain(chain=[Comp.Plain(final_text)]) - self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self.final_llm_resp = LLMResponse(role=role, result_chain=chain) self._transition_state(AgentState.DONE) try: diff --git a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py index 8c75c66815..94f5330ea5 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py @@ -22,7 +22,7 @@ def extract_text(content: T.Any) -> str: if item_type == "text" and isinstance(item.get("text"), str): parts.append(item["text"]) elif "content" in item: - parts.append(str(item["content"])) + parts.append(extract_text(item["content"])) return "\n".join([p for p in parts if p]).strip() return str(content) if content is not None else "" diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 58f546a79e..6b69942f61 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,5 +1,6 @@ import asyncio import inspect +import typing as T from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import TYPE_CHECKING @@ -28,6 +29,7 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner + from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( @@ -74,6 +76,15 @@ def _set_non_stream_runner_result( ) +def _aggregate_runner_error( + has_intermediate_error: bool, + final_resp: "LLMResponse | None", +) -> bool: + if not final_resp: + return has_intermediate_error + return has_intermediate_error or final_resp.role == "err" + + async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -170,6 +181,101 @@ async def _resolve_persona_custom_error_message( logger.debug("Failed to resolve persona custom error message: %s", e) return None + async def _handle_streaming_runner( + self, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + custom_error_message: str | None, + close_runner_once: T.Callable[[], T.Awaitable[None]], + ) -> AsyncGenerator[None, None]: + stream_has_runner_error = False + + async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: + nonlocal stream_has_runner_error + try: + async for runner_output in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ): + if runner_output.is_error: + stream_has_runner_error = True + _set_runner_error_extra(event, True) + yield runner_output.chain + finally: + # Streaming runner cleanup must happen after consumer + # finishes iterating to avoid tearing down active streams. + await close_runner_once() + + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(_stream_runner_chain()), + ) + yield + + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + is_runner_error = _aggregate_runner_error( + has_intermediate_error=stream_has_runner_error, + final_resp=final_resp, + ) + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + + async def _handle_non_streaming_runner( + self, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + stream_to_general: bool, + custom_error_message: str | None, + ) -> AsyncGenerator[None, None]: + merged_chain: list = [] + has_intermediate_error = False + async for output in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + merged_chain.extend(output.chain.chain or []) + if output.is_error: + has_intermediate_error = True + yield + + final_resp = runner.get_final_llm_resp() + + if not final_resp or not final_resp.result_chain: + if merged_chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + _set_non_stream_runner_result( + event=event, + chain=merged_chain, + is_error=has_intermediate_error, + ) + yield + return + logger.warning("Agent Runner 未返回最终结果。") + return + + is_runner_error = _aggregate_runner_error( + has_intermediate_error=has_intermediate_error, + final_resp=final_resp, + ) + _set_non_stream_runner_result( + event=event, + chain=final_resp.result_chain.chain or [], + is_error=is_runner_error, + ) + yield + async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: @@ -240,7 +346,6 @@ async def process( ) runner_closed = False - defer_runner_close_to_stream = False async def _close_runner_once() -> None: nonlocal runner_closed @@ -260,89 +365,32 @@ async def _close_runner_once() -> None: provider_config=self.prov_cfg, streaming=streaming_response, ) - - if streaming_response and not stream_to_general: - # 流式响应 - stream_has_runner_error = False - - async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: - nonlocal stream_has_runner_error - try: - async for runner_output in run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ): - if runner_output.is_error: - stream_has_runner_error = True - _set_runner_error_extra(event, True) - yield runner_output.chain - finally: - # Streaming runner cleanup must happen after consumer - # finishes iterating to avoid tearing down active streams. - await _close_runner_once() - - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(_stream_runner_chain()), - ) - defer_runner_close_to_stream = True - yield - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - is_runner_error = ( - stream_has_runner_error or final_resp.role == "err" - ) - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - # 非流式响应或转换为普通响应 - merged_chain: list = [] - fallback_is_error = False - async for output in run_third_party_agent( - runner, - stream_to_general=stream_to_general, + except Exception: + await _close_runner_once() + raise + + if streaming_response and not stream_to_general: + try: + async for _ in self._handle_streaming_runner( + runner=runner, + event=event, custom_error_message=custom_error_message, + close_runner_once=_close_runner_once, ): - merged_chain.extend(output.chain.chain or []) - if output.is_error: - fallback_is_error = True yield - - final_resp = runner.get_final_llm_resp() - - if not final_resp or not final_resp.result_chain: - if merged_chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - _set_non_stream_runner_result( - event=event, - chain=merged_chain, - is_error=fallback_is_error, - ) - yield - return - logger.warning("Agent Runner 未返回最终结果。") - return - - # Preserve intermediate error signals even if final role is assistant. - is_runner_error = fallback_is_error or final_resp.role == "err" - _set_non_stream_runner_result( + except Exception: + await _close_runner_once() + raise + else: + try: + async for _ in self._handle_non_streaming_runner( + runner=runner, event=event, - chain=final_resp.result_chain.chain or [], - is_error=is_runner_error, - ) - yield - finally: - if not defer_runner_close_to_stream: + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + yield + finally: await _close_runner_once() asyncio.create_task( From 613559a3e3a5f080e99aa8e45e7689fea15b7735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:26:16 +0900 Subject: [PATCH 16/24] fix: support deerflow proxy and refine runner lifecycle --- .../runners/deerflow/deerflow_agent_runner.py | 55 ++++++----- .../runners/deerflow/deerflow_api_client.py | 6 ++ .../method/agent_sub_stages/third_party.py | 91 ++++++++++++------- 3 files changed, 94 insertions(+), 58 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index fc856caf83..371394402e 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -45,8 +45,8 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): @dataclass class _StreamState: - streamed_text: str = "" - fallback_stream_text: str = "" + latest_text: str = "" + prev_text_for_streaming: str = "" clarification_text: str = "" task_failures: list[str] = field(default_factory=list) seen_message_ids: set[str] = field(default_factory=set) @@ -54,6 +54,7 @@ class _StreamState: # Fallback tracking for backends that omit message ids in values events. no_id_message_fingerprints: dict[int, str] = field(default_factory=dict) baseline_initialized: bool = False + has_values_text: bool = False run_values_messages: list[dict[str, T.Any]] = field(default_factory=list) timed_out: bool = False @@ -154,6 +155,8 @@ async def reset( ) self.api_key = provider_config.get("deerflow_api_key", "") self.auth_header = provider_config.get("deerflow_auth_header", "") + proxy = provider_config.get("proxy", "") + self.proxy = proxy.strip() if isinstance(proxy, str) else "" self.assistant_id = provider_config.get("deerflow_assistant_id", "lead_agent") self.model_name = provider_config.get("deerflow_model_name", "") self.thinking_enabled = bool( @@ -186,7 +189,12 @@ async def reset( min_value=1, ) - new_client_signature = (self.api_base, self.api_key, self.auth_header) + new_client_signature = ( + self.api_base, + self.api_key, + self.auth_header, + self.proxy, + ) old_client = getattr(self, "api_client", None) old_signature = getattr(self, "_api_client_signature", None) if ( @@ -209,6 +217,7 @@ async def reset( api_base=self.api_base, api_key=self.api_key, auth_header=self.auth_header, + proxy=self.proxy, ) self._api_client_signature = new_client_signature @@ -230,6 +239,9 @@ async def step(self): try: async for response in self._execute_deerflow_request(): yield response + except asyncio.CancelledError: + # Let caller manage cancellation semantics. + raise except Exception as e: err_msg = self._format_exception(e) logger.error(f"DeerFlow request failed: {err_msg}", exc_info=True) @@ -462,6 +474,7 @@ def _handle_values_event( values_messages, state, ) + latest_text = "" if new_messages: state.run_values_messages.extend(new_messages) if len(state.run_values_messages) > self._MAX_VALUES_HISTORY: @@ -469,34 +482,28 @@ def _handle_values_event( -self._MAX_VALUES_HISTORY : ] latest_text = extract_latest_ai_text(state.run_values_messages) + if latest_text: + state.latest_text = latest_text + state.has_values_text = True latest_clarification = extract_latest_clarification_text( state.run_values_messages, ) if latest_clarification: state.clarification_text = latest_clarification - else: - latest_text = "" if self.streaming and latest_text: - if latest_text.startswith(state.streamed_text): - delta = latest_text[len(state.streamed_text) :] - if delta: - state.streamed_text = latest_text - responses.append( - AgentResponse( - type="streaming_delta", - data=AgentResponseData( - chain=MessageChain().message(delta), - ), - ), - ) - elif latest_text != state.streamed_text: - state.streamed_text = latest_text + if latest_text.startswith(state.prev_text_for_streaming): + delta = latest_text[len(state.prev_text_for_streaming) :] + else: + delta = latest_text + + if delta: + state.prev_text_for_streaming = latest_text responses.append( AgentResponse( type="streaming_delta", data=AgentResponseData( - chain=MessageChain().message(latest_text), + chain=MessageChain().message(delta), ), ), ) @@ -508,11 +515,11 @@ def _handle_message_event( state: _StreamState, ) -> AgentResponse | None: delta = extract_ai_delta_from_event_data(data) - if delta: - state.fallback_stream_text += delta response: AgentResponse | None = None - if self.streaming and delta and not state.streamed_text: + if delta and not state.has_values_text: + state.latest_text += delta + if self.streaming and delta and not state.has_values_text: response = AgentResponse( type="streaming_delta", data=AgentResponseData(chain=MessageChain().message(delta)), @@ -531,7 +538,7 @@ def _resolve_final_text(self, state: _StreamState) -> tuple[str, bool]: else: final_text = extract_latest_ai_text(state.run_values_messages) if not final_text: - final_text = state.streamed_text or state.fallback_stream_text + final_text = state.latest_text if not final_text: final_text = build_task_failure_summary(state.task_failures) failures_only = bool(final_text) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 2922c0b291..f279db2005 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -96,10 +96,14 @@ def __init__( api_base: str = "http://127.0.0.1:2026", api_key: str = "", auth_header: str = "", + proxy: str | None = None, ) -> None: self.api_base = api_base.rstrip("/") self._session: ClientSession | None = None self._closed = False + self.proxy = proxy.strip() if isinstance(proxy, str) else None + if self.proxy == "": + self.proxy = None self.headers: dict[str, str] = {} if auth_header: self.headers["Authorization"] = auth_header @@ -133,6 +137,7 @@ async def create_thread(self, timeout: float = 20) -> dict[str, Any]: json=payload, headers=self.headers, timeout=timeout, + proxy=self.proxy, ) as resp: if resp.status not in (200, 201): text = await resp.text() @@ -180,6 +185,7 @@ async def stream_run( "Content-Type": "application/json", }, timeout=stream_timeout, + proxy=self.proxy, ) as resp: if resp.status != 200: text = await resp.text() diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 6b69942f61..0a9fdabaed 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -2,6 +2,7 @@ import inspect import typing as T from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import TYPE_CHECKING @@ -148,6 +149,44 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: logger.warning(f"Failed to close third-party runner cleanly: {e}") +@asynccontextmanager +async def _runner_session( + runner: "BaseAgentRunner", + *, + request: ProviderRequest, + run_context: AgentContextWrapper, + agent_hooks: T.Any, + provider_config: dict, + streaming: bool, +): + runner_closed = False + defer_close = False + + async def close_runner_once() -> None: + nonlocal runner_closed + if runner_closed: + return + runner_closed = True + await _close_runner_if_supported(runner) + + def defer_runner_close() -> None: + nonlocal defer_close + defer_close = True + + await runner.reset( + request=request, + run_context=run_context, + agent_hooks=agent_hooks, + provider_config=provider_config, + streaming=streaming, + ) + try: + yield close_runner_once, defer_runner_close + finally: + if not defer_close: + await close_runner_once() + + class ThirdPartyAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -345,44 +384,30 @@ async def process( and not event.platform_meta.support_streaming_message ) - runner_closed = False - - async def _close_runner_once() -> None: - nonlocal runner_closed - if runner_closed: - return - runner_closed = True - await _close_runner_if_supported(runner) - - try: - await runner.reset( - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, - streaming=streaming_response, - ) - except Exception: - await _close_runner_once() - raise - - if streaming_response and not stream_to_general: - try: + async with _runner_session( + runner=runner, + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + streaming=streaming_response, + ) as (close_runner_once, defer_runner_close): + if streaming_response and not stream_to_general: + stream_started = False async for _ in self._handle_streaming_runner( runner=runner, event=event, custom_error_message=custom_error_message, - close_runner_once=_close_runner_once, + close_runner_once=close_runner_once, ): + if not stream_started: + defer_runner_close() + stream_started = True yield - except Exception: - await _close_runner_once() - raise - else: - try: + else: async for _ in self._handle_non_streaming_runner( runner=runner, event=event, @@ -390,8 +415,6 @@ async def _close_runner_once() -> None: custom_error_message=custom_error_message, ): yield - finally: - await _close_runner_once() asyncio.create_task( Metric.upload( From 98a48403f8ac147a02770ec569d9a181b313b0bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:33:20 +0900 Subject: [PATCH 17/24] fix: tighten deerflow image validation and runner lifecycle --- .../runners/deerflow/deerflow_agent_runner.py | 14 +- .../method/agent_sub_stages/third_party.py | 309 +++++++----------- 2 files changed, 128 insertions(+), 195 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 371394402e..7852da58ec 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,4 +1,5 @@ import asyncio +import base64 import hashlib import json import sys @@ -330,13 +331,19 @@ def _is_likely_base64_image(self, value: str) -> bool: return False compact = value.replace("\n", "").replace("\r", "") - if not compact or len(compact) % 4 != 0: + if not compact or len(compact) < 32 or len(compact) % 4 != 0: return False base64_chars = ( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" ) - return all(ch in base64_chars for ch in compact) + if any(ch not in base64_chars for ch in compact): + return False + try: + base64.b64decode(compact, validate=True) + except Exception: + return False + return True def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: if not image_urls: @@ -360,10 +367,11 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any: if not self._is_likely_base64_image(url): skipped_invalid_images += 1 continue + compact_base64 = url.replace("\n", "").replace("\r", "") content.append( { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{url}"}, + "image_url": {"url": f"data:image/png;base64,{compact_base64}"}, }, ) if skipped_invalid_images: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 0a9fdabaed..38011e92bc 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,8 +1,6 @@ import asyncio import inspect -import typing as T from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager from dataclasses import dataclass from typing import TYPE_CHECKING @@ -30,7 +28,6 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner - from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( @@ -55,37 +52,6 @@ def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None: event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error) -def _runner_result_content_type(is_error: bool) -> ResultContentType: - return ( - ResultContentType.AGENT_RUNNER_ERROR - if is_error - else ResultContentType.LLM_RESULT - ) - - -def _set_non_stream_runner_result( - event: "AstrMessageEvent", - chain: list, - is_error: bool, -) -> None: - _set_runner_error_extra(event, is_error) - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=_runner_result_content_type(is_error), - ), - ) - - -def _aggregate_runner_error( - has_intermediate_error: bool, - final_resp: "LLMResponse | None", -) -> bool: - if not final_resp: - return has_intermediate_error - return has_intermediate_error or final_resp.role == "err" - - async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -149,44 +115,6 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: logger.warning(f"Failed to close third-party runner cleanly: {e}") -@asynccontextmanager -async def _runner_session( - runner: "BaseAgentRunner", - *, - request: ProviderRequest, - run_context: AgentContextWrapper, - agent_hooks: T.Any, - provider_config: dict, - streaming: bool, -): - runner_closed = False - defer_close = False - - async def close_runner_once() -> None: - nonlocal runner_closed - if runner_closed: - return - runner_closed = True - await _close_runner_if_supported(runner) - - def defer_runner_close() -> None: - nonlocal defer_close - defer_close = True - - await runner.reset( - request=request, - run_context=run_context, - agent_hooks=agent_hooks, - provider_config=provider_config, - streaming=streaming, - ) - try: - yield close_runner_once, defer_runner_close - finally: - if not defer_close: - await close_runner_once() - - class ThirdPartyAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -220,101 +148,6 @@ async def _resolve_persona_custom_error_message( logger.debug("Failed to resolve persona custom error message: %s", e) return None - async def _handle_streaming_runner( - self, - runner: "BaseAgentRunner", - event: AstrMessageEvent, - custom_error_message: str | None, - close_runner_once: T.Callable[[], T.Awaitable[None]], - ) -> AsyncGenerator[None, None]: - stream_has_runner_error = False - - async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: - nonlocal stream_has_runner_error - try: - async for runner_output in run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ): - if runner_output.is_error: - stream_has_runner_error = True - _set_runner_error_extra(event, True) - yield runner_output.chain - finally: - # Streaming runner cleanup must happen after consumer - # finishes iterating to avoid tearing down active streams. - await close_runner_once() - - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(_stream_runner_chain()), - ) - yield - - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - is_runner_error = _aggregate_runner_error( - has_intermediate_error=stream_has_runner_error, - final_resp=final_resp, - ) - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - - async def _handle_non_streaming_runner( - self, - runner: "BaseAgentRunner", - event: AstrMessageEvent, - stream_to_general: bool, - custom_error_message: str | None, - ) -> AsyncGenerator[None, None]: - merged_chain: list = [] - has_intermediate_error = False - async for output in run_third_party_agent( - runner, - stream_to_general=stream_to_general, - custom_error_message=custom_error_message, - ): - merged_chain.extend(output.chain.chain or []) - if output.is_error: - has_intermediate_error = True - yield - - final_resp = runner.get_final_llm_resp() - - if not final_resp or not final_resp.result_chain: - if merged_chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - _set_non_stream_runner_result( - event=event, - chain=merged_chain, - is_error=has_intermediate_error, - ) - yield - return - logger.warning("Agent Runner 未返回最终结果。") - return - - is_runner_error = _aggregate_runner_error( - has_intermediate_error=has_intermediate_error, - final_resp=final_resp, - ) - _set_non_stream_runner_result( - event=event, - chain=final_resp.result_chain.chain or [], - is_error=is_runner_error, - ) - yield - async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: @@ -384,37 +217,129 @@ async def process( and not event.platform_meta.support_streaming_message ) - async with _runner_session( - runner=runner, - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, - streaming=streaming_response, - ) as (close_runner_once, defer_runner_close): + runner_closed = False + streaming_started = False + + async def close_runner_once() -> None: + nonlocal runner_closed + if runner_closed: + return + runner_closed = True + await _close_runner_if_supported(runner) + + try: + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + streaming=streaming_response, + ) + if streaming_response and not stream_to_general: - stream_started = False - async for _ in self._handle_streaming_runner( - runner=runner, - event=event, - custom_error_message=custom_error_message, - close_runner_once=close_runner_once, - ): - if not stream_started: - defer_runner_close() - stream_started = True - yield + stream_has_runner_error = False + + async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: + nonlocal stream_has_runner_error + try: + async for runner_output in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ): + if runner_output.is_error: + stream_has_runner_error = True + _set_runner_error_extra(event, True) + yield runner_output.chain + finally: + # Streaming runner cleanup must happen after consumer + # finishes iterating to avoid tearing down active streams. + await close_runner_once() + + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(_stream_runner_chain()), + ) + streaming_started = True + yield + + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + is_runner_error = ( + stream_has_runner_error or final_resp.role == "err" + ) + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) else: - async for _ in self._handle_non_streaming_runner( - runner=runner, - event=event, + merged_chain: list = [] + has_intermediate_error = False + async for output in run_third_party_agent( + runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): + merged_chain.extend(output.chain.chain or []) + if output.is_error: + has_intermediate_error = True + yield + + final_resp = runner.get_final_llm_resp() + if not final_resp or not final_resp.result_chain: + if merged_chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + _set_runner_error_extra(event, has_intermediate_error) + event.set_result( + MessageEventResult( + chain=merged_chain, + result_content_type=( + ResultContentType.AGENT_RUNNER_ERROR + if has_intermediate_error + else ResultContentType.LLM_RESULT + ), + ), + ) + else: + logger.warning("Agent Runner 未返回最终结果。") + fallback_error_chain = MessageChain().message( + "Agent Runner did not return any result.", + ) + _set_runner_error_extra(event, True) + event.set_result( + MessageEventResult( + chain=fallback_error_chain.chain or [], + result_content_type=ResultContentType.AGENT_RUNNER_ERROR, + ), + ) yield + else: + is_runner_error = has_intermediate_error or final_resp.role == "err" + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ), + ), + ) + yield + finally: + if not streaming_started: + await close_runner_once() asyncio.create_task( Metric.upload( From b3cab6da8dec760f9c1d7aded494df5102f57855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 03:55:53 +0900 Subject: [PATCH 18/24] feat: support deerflow image output components --- .../runners/deerflow/deerflow_agent_runner.py | 156 +++++++++++++++--- .../runners/deerflow/deerflow_stream_utils.py | 14 ++ 2 files changed, 151 insertions(+), 19 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 7852da58ec..85e313b693 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -26,10 +26,12 @@ build_task_failure_summary, extract_ai_delta_from_event_data, extract_clarification_from_event_data, + extract_latest_ai_message, extract_latest_ai_text, extract_latest_clarification_text, extract_messages_from_values_data, extract_task_failures_from_custom_event, + extract_text, get_message_id, ) @@ -423,6 +425,92 @@ def _build_messages( ) return messages + def _image_component_from_url(self, url: T.Any) -> Comp.Image | None: + if not isinstance(url, str): + return None + + normalized = url.strip() + if not normalized: + return None + + if normalized.startswith(("http://", "https://")): + try: + return Comp.Image.fromURL(normalized) + except Exception: + return None + + if not normalized.startswith("data:"): + return None + + header, sep, payload = normalized.partition(",") + if not sep: + return None + if ";base64" not in header.lower(): + return None + + compact_payload = payload.replace("\n", "").replace("\r", "").strip() + if not compact_payload: + return None + try: + base64.b64decode(compact_payload, validate=True) + except Exception: + return None + return Comp.Image.fromBase64(compact_payload) + + def _append_components_from_content( + self, + content: T.Any, + components: list[Comp.BaseMessageComponent], + ) -> None: + if isinstance(content, str): + if content: + components.append(Comp.Plain(content)) + return + + if isinstance(content, list): + for item in content: + self._append_components_from_content(item, components) + return + + if not isinstance(content, dict): + return + + item_type = str(content.get("type", "")).lower() + if item_type == "text" and isinstance(content.get("text"), str): + text = content["text"] + if text: + components.append(Comp.Plain(text)) + return + + if item_type == "image_url": + image_payload = content.get("image_url") + image_url: T.Any = image_payload + if isinstance(image_payload, dict): + image_url = image_payload.get("url") + image_comp = self._image_component_from_url(image_url) + if image_comp is not None: + components.append(image_comp) + return + + if "content" in content: + self._append_components_from_content(content.get("content"), components) + return + + kwargs = content.get("kwargs") + if isinstance(kwargs, dict) and "content" in kwargs: + self._append_components_from_content(kwargs.get("content"), components) + + def _build_chain_from_ai_content(self, content: T.Any) -> MessageChain: + components: list[Comp.BaseMessageComponent] = [] + self._append_components_from_content(content, components) + if components: + return MessageChain(chain=components) + + fallback_text = extract_text(content) + if fallback_text: + return MessageChain(chain=[Comp.Plain(fallback_text)]) + return MessageChain() + def _build_runtime_context(self, thread_id: str) -> dict[str, T.Any]: runtime_context: dict[str, T.Any] = { "thread_id": thread_id, @@ -538,23 +626,35 @@ def _handle_message_event( state.clarification_text = maybe_clarification return response - def _resolve_final_text(self, state: _StreamState) -> tuple[str, bool]: + def _resolve_final_output(self, state: _StreamState) -> tuple[MessageChain, bool]: failures_only = False + final_chain = MessageChain() + # Clarification tool output should take precedence over partial AI/tool-call text. if state.clarification_text: - final_text = state.clarification_text + final_chain = MessageChain(chain=[Comp.Plain(state.clarification_text)]) else: - final_text = extract_latest_ai_text(state.run_values_messages) - if not final_text: - final_text = state.latest_text - if not final_text: - final_text = build_task_failure_summary(state.task_failures) - failures_only = bool(final_text) - - if not final_text: + latest_ai_message = extract_latest_ai_message(state.run_values_messages) + if latest_ai_message: + final_chain = self._build_chain_from_ai_content( + latest_ai_message.get("content"), + ) + + if not final_chain.chain and state.latest_text: + final_chain = MessageChain(chain=[Comp.Plain(state.latest_text)]) + + if not final_chain.chain: + failure_text = build_task_failure_summary(state.task_failures) + if failure_text: + final_chain = MessageChain(chain=[Comp.Plain(failure_text)]) + failures_only = True + + if not final_chain.chain: logger.warning("DeerFlow returned no text content in stream events.") - final_text = "DeerFlow returned an empty response." - return final_text, failures_only + final_chain = MessageChain( + chain=[Comp.Plain("DeerFlow returned an empty response.")], + ) + return final_chain, failures_only async def _execute_deerflow_request(self): prompt = self.req.prompt or "" @@ -605,20 +705,38 @@ async def _execute_deerflow_request(self): except (asyncio.TimeoutError, TimeoutError): state.timed_out = True - final_text, failures_only = self._resolve_final_text(state) + final_chain, failures_only = self._resolve_final_output(state) if state.timed_out: timeout_note = ( f"DeerFlow stream timed out after {self.timeout}s. " "Returning partial result." ) - final_text = ( - f"{final_text}\n\n{timeout_note}" if final_text else timeout_note - ) + if final_chain.chain and isinstance(final_chain.chain[-1], Comp.Plain): + last_text = final_chain.chain[-1].text + final_chain.chain[-1].text = ( + f"{last_text}\n\n{timeout_note}" if last_text else timeout_note + ) + else: + final_chain.chain.append(Comp.Plain(timeout_note)) + + if self.streaming: + non_plain_components = [ + component + for component in final_chain.chain + if not isinstance(component, Comp.Plain) + ] + if non_plain_components: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain(chain=non_plain_components), + ), + ) + is_error = state.timed_out or failures_only role = "err" if is_error else "assistant" - chain = MessageChain(chain=[Comp.Plain(final_text)]) - self.final_llm_resp = LLMResponse(role=role, result_chain=chain) + self.final_llm_resp = LLMResponse(role=role, result_chain=final_chain) self._transition_state(AgentState.DONE) try: @@ -628,7 +746,7 @@ async def _execute_deerflow_request(self): yield AgentResponse( type="llm_result", - data=AgentResponseData(chain=chain), + data=AgentResponseData(chain=final_chain), ) @override diff --git a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py index 94f5330ea5..0c8a5bb385 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_stream_utils.py @@ -77,6 +77,20 @@ def extract_latest_ai_text(messages: Iterable[T.Any]) -> str: return "" +def extract_latest_ai_message(messages: Iterable[T.Any]) -> dict[str, T.Any] | None: + if isinstance(messages, (list, tuple)): + iterable = reversed(messages) + else: + iterable = reversed(list(messages)) + + for msg in iterable: + if not isinstance(msg, dict): + continue + if is_ai_message(msg): + return msg + return None + + def is_clarification_tool_message(message: dict[str, T.Any]) -> bool: msg_type = str(message.get("type", "")).lower() tool_name = str(message.get("name", "")).lower() From b4f7262b8fee16ad6fe262b424184a7446844fdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 04:02:12 +0900 Subject: [PATCH 19/24] fix: harden runner stream cleanup and refactor deerflow config --- .../runners/deerflow/deerflow_agent_runner.py | 158 +++++++++++------- .../method/agent_sub_stages/third_party.py | 139 ++++++++++----- 2 files changed, 194 insertions(+), 103 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 85e313b693..ad18071244 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -46,6 +46,21 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]): _MAX_VALUES_HISTORY = 200 + @dataclass(frozen=True) + class _RunnerConfig: + api_base: str + api_key: str + auth_header: str + proxy: str + assistant_id: str + model_name: str + thinking_enabled: bool + plan_mode: bool + subagent_enabled: bool + max_concurrent_subagents: int + timeout: int + recursion_limit: int + @dataclass class _StreamState: latest_text: str = "" @@ -131,75 +146,80 @@ async def close(self) -> None: if isinstance(api_client, DeerFlowAPIClient) and not api_client.is_closed: await api_client.close() - @override - async def reset( - self, - request: ProviderRequest, - run_context: ContextWrapper[TContext], - agent_hooks: BaseAgentRunHooks[TContext], - provider_config: dict, - **kwargs: T.Any, - ) -> None: - self.req = request - self.streaming = kwargs.get("streaming", False) - self.final_llm_resp = None - self._state = AgentState.IDLE - self.agent_hooks = agent_hooks - self.run_context = run_context - - self.api_base = provider_config.get( - "deerflow_api_base", "http://127.0.0.1:2026" - ) - if not isinstance(self.api_base, str) or not self.api_base.startswith( + def _parse_runner_config(self, provider_config: dict) -> _RunnerConfig: + api_base = provider_config.get("deerflow_api_base", "http://127.0.0.1:2026") + if not isinstance(api_base, str) or not api_base.startswith( ("http://", "https://"), ): raise ValueError( "DeerFlow API Base URL format is invalid. It must start with http:// or https://.", ) - self.api_key = provider_config.get("deerflow_api_key", "") - self.auth_header = provider_config.get("deerflow_auth_header", "") + proxy = provider_config.get("proxy", "") - self.proxy = proxy.strip() if isinstance(proxy, str) else "" - self.assistant_id = provider_config.get("deerflow_assistant_id", "lead_agent") - self.model_name = provider_config.get("deerflow_model_name", "") - self.thinking_enabled = bool( - provider_config.get("deerflow_thinking_enabled", False), - ) - self.plan_mode = bool(provider_config.get("deerflow_plan_mode", False)) - self.subagent_enabled = bool( - provider_config.get("deerflow_subagent_enabled", False), - ) - self.max_concurrent_subagents = self._coerce_int_config( - "deerflow_max_concurrent_subagents", - provider_config.get( + normalized_proxy = proxy.strip() if isinstance(proxy, str) else "" + + return self._RunnerConfig( + api_base=api_base, + api_key=provider_config.get("deerflow_api_key", ""), + auth_header=provider_config.get("deerflow_auth_header", ""), + proxy=normalized_proxy, + assistant_id=provider_config.get("deerflow_assistant_id", "lead_agent"), + model_name=provider_config.get("deerflow_model_name", ""), + thinking_enabled=bool( + provider_config.get("deerflow_thinking_enabled", False), + ), + plan_mode=bool(provider_config.get("deerflow_plan_mode", False)), + subagent_enabled=bool( + provider_config.get("deerflow_subagent_enabled", False), + ), + max_concurrent_subagents=self._coerce_int_config( "deerflow_max_concurrent_subagents", - 3, + provider_config.get("deerflow_max_concurrent_subagents", 3), + default=3, + min_value=1, + ), + timeout=self._coerce_int_config( + "timeout", + provider_config.get("timeout", 300), + default=300, + min_value=1, + ), + recursion_limit=self._coerce_int_config( + "deerflow_recursion_limit", + provider_config.get("deerflow_recursion_limit", 1000), + default=1000, + min_value=1, ), - default=3, - min_value=1, ) - self.timeout = self._coerce_int_config( - "timeout", - provider_config.get("timeout", 300), - default=300, - min_value=1, - ) - self.recursion_limit = self._coerce_int_config( - "deerflow_recursion_limit", - provider_config.get("deerflow_recursion_limit", 1000), - default=1000, - min_value=1, + def _apply_runner_config(self, config: _RunnerConfig) -> None: + self.api_base = config.api_base + self.api_key = config.api_key + self.auth_header = config.auth_header + self.proxy = config.proxy + self.assistant_id = config.assistant_id + self.model_name = config.model_name + self.thinking_enabled = config.thinking_enabled + self.plan_mode = config.plan_mode + self.subagent_enabled = config.subagent_enabled + self.max_concurrent_subagents = config.max_concurrent_subagents + self.timeout = config.timeout + self.recursion_limit = config.recursion_limit + + @staticmethod + def _build_client_signature(config: _RunnerConfig) -> tuple[str, str, str, str]: + return ( + config.api_base, + config.api_key, + config.auth_header, + config.proxy, ) - new_client_signature = ( - self.api_base, - self.api_key, - self.auth_header, - self.proxy, - ) + async def _refresh_api_client(self, config: _RunnerConfig) -> None: + new_client_signature = self._build_client_signature(config) old_client = getattr(self, "api_client", None) old_signature = getattr(self, "_api_client_signature", None) + if ( isinstance(old_client, DeerFlowAPIClient) and old_signature == new_client_signature @@ -217,13 +237,33 @@ async def reset( ) self.api_client = DeerFlowAPIClient( - api_base=self.api_base, - api_key=self.api_key, - auth_header=self.auth_header, - proxy=self.proxy, + api_base=config.api_base, + api_key=config.api_key, + auth_header=config.auth_header, + proxy=config.proxy, ) self._api_client_signature = new_client_signature + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + config = self._parse_runner_config(provider_config) + self._apply_runner_config(config) + await self._refresh_api_client(config) + @override async def step(self): if not self.req: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 38011e92bc..5ef96a440c 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner + from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( @@ -46,12 +47,46 @@ "deerflow": "deerflow_agent_runner_provider_id", } THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error" +STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30 def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None: event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error) +def _resolve_final_result( + merged_chain: list, + final_resp: "LLMResponse | None", + has_intermediate_error: bool, +) -> tuple[list, bool, ResultContentType]: + if not final_resp or not final_resp.result_chain: + if merged_chain: + is_error = has_intermediate_error + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_error + else ResultContentType.LLM_RESULT + ) + return merged_chain, is_error, content_type + + fallback_error_chain = MessageChain().message( + "Agent Runner did not return any result.", + ) + return ( + fallback_error_chain.chain or [], + True, + ResultContentType.AGENT_RUNNER_ERROR, + ) + + is_error = has_intermediate_error or final_resp.role == "err" + content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_error + else ResultContentType.LLM_RESULT + ) + return final_resp.result_chain.chain or [], is_error, content_type + + async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -219,6 +254,8 @@ async def process( runner_closed = False streaming_started = False + stream_consumption_started = False + stream_idle_close_task: asyncio.Task[None] | None = None async def close_runner_once() -> None: nonlocal runner_closed @@ -227,6 +264,17 @@ async def close_runner_once() -> None: runner_closed = True await _close_runner_if_supported(runner) + async def close_if_stream_never_consumed() -> None: + try: + await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC) + except asyncio.CancelledError: + return + if not stream_consumption_started: + logger.warning( + "Third-party runner stream was never consumed; closing runner to avoid resource leak.", + ) + await close_runner_once() + try: await runner.reset( request=req, @@ -243,7 +291,10 @@ async def close_runner_once() -> None: stream_has_runner_error = False async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: - nonlocal stream_has_runner_error + nonlocal stream_has_runner_error, stream_consumption_started + stream_consumption_started = True + if stream_idle_close_task and not stream_idle_close_task.done(): + stream_idle_close_task.cancel() try: async for runner_output in run_third_party_agent( runner, @@ -264,30 +315,40 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: .set_result_content_type(ResultContentType.STREAMING_RESULT) .set_async_stream(_stream_runner_chain()), ) + stream_idle_close_task = asyncio.create_task( + close_if_stream_never_consumed(), + ) streaming_started = True yield if runner.done(): final_resp = runner.get_final_llm_resp() if final_resp and final_resp.result_chain: - is_runner_error = ( - stream_has_runner_error or final_resp.role == "err" + ( + final_chain, + is_runner_error, + _, + ) = _resolve_final_result( + merged_chain=[], + final_resp=final_resp, + has_intermediate_error=stream_has_runner_error, ) _set_runner_error_extra(event, is_runner_error) event.set_result( MessageEventResult( - chain=final_resp.result_chain.chain or [], + chain=final_chain, result_content_type=ResultContentType.STREAMING_FINISH, ), ) else: - merged_chain: list = [] - has_intermediate_error = False - async for output in run_third_party_agent( + output_stream = run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, - ): + ) + merged_chain: list = [] + has_intermediate_error = False + async for output in output_stream: merged_chain.extend(output.chain.chain or []) if output.is_error: has_intermediate_error = True @@ -299,45 +360,35 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) - _set_runner_error_extra(event, has_intermediate_error) - event.set_result( - MessageEventResult( - chain=merged_chain, - result_content_type=( - ResultContentType.AGENT_RUNNER_ERROR - if has_intermediate_error - else ResultContentType.LLM_RESULT - ), - ), - ) else: logger.warning("Agent Runner 未返回最终结果。") - fallback_error_chain = MessageChain().message( - "Agent Runner did not return any result.", - ) - _set_runner_error_extra(event, True) - event.set_result( - MessageEventResult( - chain=fallback_error_chain.chain or [], - result_content_type=ResultContentType.AGENT_RUNNER_ERROR, - ), - ) - yield - else: - is_runner_error = has_intermediate_error or final_resp.role == "err" - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_resp.result_chain.chain or [], - result_content_type=( - ResultContentType.AGENT_RUNNER_ERROR - if is_runner_error - else ResultContentType.LLM_RESULT - ), - ), - ) - yield + + ( + final_chain, + is_runner_error, + result_content_type, + ) = _resolve_final_result( + merged_chain=merged_chain, + final_resp=final_resp, + has_intermediate_error=has_intermediate_error, + ) + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=result_content_type, + ), + ) + yield finally: + if ( + stream_idle_close_task + and not stream_idle_close_task.done() + and ( + not streaming_started or stream_consumption_started or runner_closed + ) + ): + stream_idle_close_task.cancel() if not streaming_started: await close_runner_once() From f0dc39af4a910fbdfa8633bde31587cd7498a6de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 04:07:29 +0900 Subject: [PATCH 20/24] fix: preserve deerflow done hook and simplify runner lifecycle --- .../runners/deerflow/deerflow_agent_runner.py | 15 +- .../method/agent_sub_stages/third_party.py | 315 +++++++++++------- 2 files changed, 196 insertions(+), 134 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index ad18071244..82aa83e35d 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -146,6 +146,14 @@ async def close(self) -> None: if isinstance(api_client, DeerFlowAPIClient) and not api_client.is_closed: await api_client.close() + async def _notify_agent_done_hook(self) -> None: + if not self.final_llm_resp: + return + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + def _parse_runner_config(self, provider_config: dict) -> _RunnerConfig: api_base = provider_config.get("deerflow_api_base", "http://127.0.0.1:2026") if not isinstance(api_base, str) or not api_base.startswith( @@ -295,6 +303,7 @@ async def step(self): completion_text=f"DeerFlow request failed: {err_msg}", result_chain=err_chain, ) + await self._notify_agent_done_hook() yield AgentResponse( type="err", data=AgentResponseData( @@ -778,11 +787,7 @@ async def _execute_deerflow_request(self): self.final_llm_resp = LLMResponse(role=role, result_chain=final_chain) self._transition_state(AgentState.DONE) - - try: - await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) - except Exception as e: - logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + await self._notify_agent_done_hook() yield AgentResponse( type="llm_result", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 5ef96a440c..f9e90eeca3 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -150,6 +150,72 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: logger.warning(f"Failed to close third-party runner cleanly: {e}") +class _RunnerLifecycle: + def __init__(self, runner: "BaseAgentRunner") -> None: + self._runner = runner + self._closed = False + self._stream_started = False + self._stream_consumed = False + self._idle_task: asyncio.Task[None] | None = None + + async def reset( + self, + *, + req: ProviderRequest, + astr_agent_ctx: AstrAgentContext, + provider_cfg: dict, + streaming: bool, + ) -> None: + await self._runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=provider_cfg, + streaming=streaming, + ) + + async def close_once(self) -> None: + if self._closed: + return + self._closed = True + await _close_runner_if_supported(self._runner) + + def mark_stream_started(self) -> None: + self._stream_started = True + self._idle_task = asyncio.create_task(self._close_if_never_consumed()) + + def mark_stream_consumed(self) -> None: + self._stream_consumed = True + if self._idle_task and not self._idle_task.done(): + self._idle_task.cancel() + + async def finalize(self) -> None: + if ( + self._idle_task + and not self._idle_task.done() + and (not self._stream_started or self._stream_consumed or self._closed) + ): + self._idle_task.cancel() + + if not self._stream_started: + await self.close_once() + + async def _close_if_never_consumed(self) -> None: + try: + await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC) + except asyncio.CancelledError: + return + + if not self._stream_consumed: + logger.warning( + "Third-party runner stream was never consumed; closing runner to avoid resource leak.", + ) + await self.close_once() + + class ThirdPartyAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -183,6 +249,109 @@ async def _resolve_persona_custom_error_message( logger.debug("Failed to resolve persona custom error message: %s", e) return None + async def _handle_streaming_response( + self, + *, + lifecycle: _RunnerLifecycle, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + custom_error_message: str | None, + ) -> AsyncGenerator[None, None]: + stream_has_runner_error = False + + async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: + nonlocal stream_has_runner_error + lifecycle.mark_stream_consumed() + try: + async for runner_output in run_third_party_agent( + runner, + stream_to_general=False, + custom_error_message=custom_error_message, + ): + if runner_output.is_error: + stream_has_runner_error = True + _set_runner_error_extra(event, True) + yield runner_output.chain + finally: + # Streaming runner cleanup must happen after consumer + # finishes iterating to avoid tearing down active streams. + await lifecycle.close_once() + + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(_stream_runner_chain()), + ) + lifecycle.mark_stream_started() + yield + + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + ( + final_chain, + is_runner_error, + _, + ) = _resolve_final_result( + merged_chain=[], + final_resp=final_resp, + has_intermediate_error=stream_has_runner_error, + ) + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + + async def _handle_non_streaming_response( + self, + *, + runner: "BaseAgentRunner", + event: AstrMessageEvent, + stream_to_general: bool, + custom_error_message: str | None, + ) -> AsyncGenerator[None, None]: + merged_chain: list = [] + has_intermediate_error = False + async for output in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + merged_chain.extend(output.chain.chain or []) + if output.is_error: + has_intermediate_error = True + yield + + final_resp = runner.get_final_llm_resp() + if not final_resp or not final_resp.result_chain: + if merged_chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + else: + logger.warning("Agent Runner 未返回最终结果。") + + ( + final_chain, + is_runner_error, + result_content_type, + ) = _resolve_final_result( + merged_chain=merged_chain, + final_resp=final_resp, + has_intermediate_error=has_intermediate_error, + ) + _set_runner_error_extra(event, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=result_content_type, + ), + ) + yield + async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: @@ -252,145 +421,33 @@ async def process( and not event.platform_meta.support_streaming_message ) - runner_closed = False - streaming_started = False - stream_consumption_started = False - stream_idle_close_task: asyncio.Task[None] | None = None - - async def close_runner_once() -> None: - nonlocal runner_closed - if runner_closed: - return - runner_closed = True - await _close_runner_if_supported(runner) - - async def close_if_stream_never_consumed() -> None: - try: - await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC) - except asyncio.CancelledError: - return - if not stream_consumption_started: - logger.warning( - "Third-party runner stream was never consumed; closing runner to avoid resource leak.", - ) - await close_runner_once() + lifecycle = _RunnerLifecycle(runner) try: - await runner.reset( - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=self.prov_cfg, + await lifecycle.reset( + req=req, + astr_agent_ctx=astr_agent_ctx, + provider_cfg=self.prov_cfg, streaming=streaming_response, ) - if streaming_response and not stream_to_general: - stream_has_runner_error = False - - async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: - nonlocal stream_has_runner_error, stream_consumption_started - stream_consumption_started = True - if stream_idle_close_task and not stream_idle_close_task.done(): - stream_idle_close_task.cancel() - try: - async for runner_output in run_third_party_agent( - runner, - stream_to_general=False, - custom_error_message=custom_error_message, - ): - if runner_output.is_error: - stream_has_runner_error = True - _set_runner_error_extra(event, True) - yield runner_output.chain - finally: - # Streaming runner cleanup must happen after consumer - # finishes iterating to avoid tearing down active streams. - await close_runner_once() - - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(_stream_runner_chain()), - ) - stream_idle_close_task = asyncio.create_task( - close_if_stream_never_consumed(), - ) - streaming_started = True - yield - - if runner.done(): - final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - ( - final_chain, - is_runner_error, - _, - ) = _resolve_final_result( - merged_chain=[], - final_resp=final_resp, - has_intermediate_error=stream_has_runner_error, - ) - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) + async for _ in self._handle_streaming_response( + lifecycle=lifecycle, + runner=runner, + event=event, + custom_error_message=custom_error_message, + ): + yield else: - output_stream = run_third_party_agent( - runner, + async for _ in self._handle_non_streaming_response( + runner=runner, + event=event, stream_to_general=stream_to_general, custom_error_message=custom_error_message, - ) - merged_chain: list = [] - has_intermediate_error = False - async for output in output_stream: - merged_chain.extend(output.chain.chain or []) - if output.is_error: - has_intermediate_error = True + ): yield - - final_resp = runner.get_final_llm_resp() - if not final_resp or not final_resp.result_chain: - if merged_chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - else: - logger.warning("Agent Runner 未返回最终结果。") - - ( - final_chain, - is_runner_error, - result_content_type, - ) = _resolve_final_result( - merged_chain=merged_chain, - final_resp=final_resp, - has_intermediate_error=has_intermediate_error, - ) - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_chain, - result_content_type=result_content_type, - ), - ) - yield finally: - if ( - stream_idle_close_task - and not stream_idle_close_task.done() - and ( - not streaming_started or stream_consumption_started or runner_closed - ) - ): - stream_idle_close_task.cancel() - if not streaming_started: - await close_runner_once() + await lifecycle.finalize() asyncio.create_task( Metric.upload( From ce1a79c7a3a5aeef633d436998384c11a6604a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 04:12:16 +0900 Subject: [PATCH 21/24] refactor: simplify third-party runner aggregation and lifecycle closing --- .../method/agent_sub_stages/third_party.py | 122 ++++++++++++------ 1 file changed, 84 insertions(+), 38 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index f9e90eeca3..2133c8d8b6 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -1,7 +1,6 @@ import asyncio import inspect from collections.abc import AsyncGenerator -from dataclasses import dataclass from typing import TYPE_CHECKING from astrbot.core import astrbot_config, logger @@ -50,6 +49,16 @@ STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30 +def _coerce_positive_int(value: object, default: int) -> int: + if isinstance(value, bool): + return default + try: + coerced = int(value) + except (TypeError, ValueError): + return default + return coerced if coerced > 0 else default + + def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None: event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error) @@ -91,7 +100,7 @@ async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, custom_error_message: str | None = None, -) -> AsyncGenerator["_ThirdPartyRunnerOutput", None]: +) -> AsyncGenerator[tuple[MessageChain, bool], None]: """ 运行第三方 agent runner 并转换响应格式 类似于 run_agent 函数,但专门处理第三方 agent runner @@ -101,21 +110,12 @@ async def run_third_party_agent( if resp.type == "streaming_delta": if stream_to_general: continue - yield _ThirdPartyRunnerOutput( - chain=resp.data["chain"], - is_error=False, - ) + yield resp.data["chain"], False elif resp.type == "llm_result": if stream_to_general: - yield _ThirdPartyRunnerOutput( - chain=resp.data["chain"], - is_error=False, - ) + yield resp.data["chain"], False elif resp.type == "err": - yield _ThirdPartyRunnerOutput( - chain=resp.data["chain"], - is_error=True, - ) + yield resp.data["chain"], True except Exception as e: logger.error(f"Third party agent runner error: {e}") err_msg = custom_error_message @@ -125,16 +125,26 @@ async def run_third_party_agent( f"Error Type: {type(e).__name__} (3rd party)\n" f"Error Message: {str(e)}" ) - yield _ThirdPartyRunnerOutput( - chain=MessageChain().message(err_msg), - is_error=True, - ) + yield MessageChain().message(err_msg), True -@dataclass -class _ThirdPartyRunnerOutput: - chain: MessageChain - is_error: bool = False +async def _consume_runner_and_aggregate( + runner: "BaseAgentRunner", + *, + stream_to_general: bool, + custom_error_message: str | None, +) -> AsyncGenerator[tuple[MessageChain, bool, list, bool], None]: + merged_chain: list = [] + has_intermediate_error = False + async for chain, is_error in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + custom_error_message=custom_error_message, + ): + merged_chain.extend(chain.chain or []) + if is_error: + has_intermediate_error = True + yield chain, is_error, merged_chain, has_intermediate_error async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: @@ -151,8 +161,15 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: class _RunnerLifecycle: - def __init__(self, runner: "BaseAgentRunner") -> None: + def __init__( + self, + runner: "BaseAgentRunner", + stream_consumption_close_timeout_sec: int, + ) -> None: self._runner = runner + self._stream_consumption_close_timeout_sec = ( + stream_consumption_close_timeout_sec + ) self._closed = False self._stream_started = False self._stream_consumed = False @@ -196,22 +213,34 @@ async def finalize(self) -> None: if ( self._idle_task and not self._idle_task.done() - and (not self._stream_started or self._stream_consumed or self._closed) + and (self._stream_consumed or self._closed) ): self._idle_task.cancel() - if not self._stream_started: + defer_close_to_watchdog = ( + self._stream_started + and not self._stream_consumed + and self._idle_task is not None + and not self._idle_task.done() + and not self._closed + ) + if defer_close_to_watchdog: + return + + if not self._closed: await self.close_once() async def _close_if_never_consumed(self) -> None: try: - await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC) + await asyncio.sleep(self._stream_consumption_close_timeout_sec) except asyncio.CancelledError: return if not self._stream_consumed: logger.warning( - "Third-party runner stream was never consumed; closing runner to avoid resource leak.", + "Third-party runner stream was never consumed in %ss; closing runner " + "to avoid resource leak.", + self._stream_consumption_close_timeout_sec, ) await self.close_once() @@ -230,6 +259,13 @@ async def initialize(self, ctx: PipelineContext) -> None: self.unsupported_streaming_strategy: str = settings[ "unsupported_streaming_strategy" ] + self.stream_consumption_close_timeout_sec: int = _coerce_positive_int( + settings.get( + "third_party_stream_consumption_close_timeout_sec", + STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + ), + STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + ) async def _resolve_persona_custom_error_message( self, event: AstrMessageEvent @@ -258,20 +294,25 @@ async def _handle_streaming_response( custom_error_message: str | None, ) -> AsyncGenerator[None, None]: stream_has_runner_error = False + merged_chain: list = [] async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: - nonlocal stream_has_runner_error + nonlocal merged_chain, stream_has_runner_error lifecycle.mark_stream_consumed() try: - async for runner_output in run_third_party_agent( + async for ( + chain, + is_error, + merged_chain, + stream_has_runner_error, + ) in _consume_runner_and_aggregate( runner, stream_to_general=False, custom_error_message=custom_error_message, ): - if runner_output.is_error: - stream_has_runner_error = True + if is_error: _set_runner_error_extra(event, True) - yield runner_output.chain + yield chain finally: # Streaming runner cleanup must happen after consumer # finishes iterating to avoid tearing down active streams. @@ -293,7 +334,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: is_runner_error, _, ) = _resolve_final_result( - merged_chain=[], + merged_chain=merged_chain, final_resp=final_resp, has_intermediate_error=stream_has_runner_error, ) @@ -315,14 +356,16 @@ async def _handle_non_streaming_response( ) -> AsyncGenerator[None, None]: merged_chain: list = [] has_intermediate_error = False - async for output in run_third_party_agent( + async for ( + _, + _, + merged_chain, + has_intermediate_error, + ) in _consume_runner_and_aggregate( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): - merged_chain.extend(output.chain.chain or []) - if output.is_error: - has_intermediate_error = True yield final_resp = runner.get_final_llm_resp() @@ -421,7 +464,10 @@ async def process( and not event.platform_meta.support_streaming_message ) - lifecycle = _RunnerLifecycle(runner) + lifecycle = _RunnerLifecycle( + runner, + stream_consumption_close_timeout_sec=self.stream_consumption_close_timeout_sec, + ) try: await lifecycle.reset( From 8a40e141089978690031f6e472aa124ba9445205 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 04:19:31 +0900 Subject: [PATCH 22/24] fix: preserve first deerflow values payload and simplify runner flow --- .../runners/deerflow/deerflow_agent_runner.py | 53 +-- .../method/agent_sub_stages/third_party.py | 309 +++++++----------- 2 files changed, 141 insertions(+), 221 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 82aa83e35d..7a84bb27f8 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -603,22 +603,23 @@ def _handle_values_event( if not values_messages: return responses + new_messages: list[dict[str, T.Any]] = [] if not state.baseline_initialized: state.baseline_initialized = True for idx, msg in enumerate(values_messages): if not isinstance(msg, dict): continue + new_messages.append(msg) msg_id = get_message_id(msg) if msg_id: self._remember_seen_message_id(state, msg_id) continue state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg) - return responses - - new_messages = self._extract_new_messages_from_values( - values_messages, - state, - ) + else: + new_messages = self._extract_new_messages_from_values( + values_messages, + state, + ) latest_text = "" if new_messages: state.run_values_messages.extend(new_messages) @@ -675,28 +676,30 @@ def _handle_message_event( state.clarification_text = maybe_clarification return response - def _resolve_final_output(self, state: _StreamState) -> tuple[MessageChain, bool]: - failures_only = False - final_chain = MessageChain() - - # Clarification tool output should take precedence over partial AI/tool-call text. + def _select_final_chain(self, state: _StreamState) -> tuple[MessageChain, bool]: if state.clarification_text: - final_chain = MessageChain(chain=[Comp.Plain(state.clarification_text)]) - else: - latest_ai_message = extract_latest_ai_message(state.run_values_messages) - if latest_ai_message: - final_chain = self._build_chain_from_ai_content( - latest_ai_message.get("content"), - ) + return MessageChain(chain=[Comp.Plain(state.clarification_text)]), False + + latest_ai_message = extract_latest_ai_message(state.run_values_messages) + if latest_ai_message: + chain_from_values = self._build_chain_from_ai_content( + latest_ai_message.get("content"), + ) + if chain_from_values.chain: + return chain_from_values, False + + if state.latest_text: + return MessageChain(chain=[Comp.Plain(state.latest_text)]), False - if not final_chain.chain and state.latest_text: - final_chain = MessageChain(chain=[Comp.Plain(state.latest_text)]) + failure_text = build_task_failure_summary(state.task_failures) + if failure_text: + return MessageChain(chain=[Comp.Plain(failure_text)]), True - if not final_chain.chain: - failure_text = build_task_failure_summary(state.task_failures) - if failure_text: - final_chain = MessageChain(chain=[Comp.Plain(failure_text)]) - failures_only = True + return MessageChain(), False + + def _resolve_final_output(self, state: _StreamState) -> tuple[MessageChain, bool]: + # Clarification and values/message-derived output share a single selection path. + final_chain, failures_only = self._select_final_chain(state) if not final_chain.chain: logger.warning("DeerFlow returned no text content in stream events.") diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 2133c8d8b6..c34ae1b499 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -27,7 +27,6 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner - from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( @@ -59,43 +58,6 @@ def _coerce_positive_int(value: object, default: int) -> int: return coerced if coerced > 0 else default -def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None: - event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error) - - -def _resolve_final_result( - merged_chain: list, - final_resp: "LLMResponse | None", - has_intermediate_error: bool, -) -> tuple[list, bool, ResultContentType]: - if not final_resp or not final_resp.result_chain: - if merged_chain: - is_error = has_intermediate_error - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if is_error - else ResultContentType.LLM_RESULT - ) - return merged_chain, is_error, content_type - - fallback_error_chain = MessageChain().message( - "Agent Runner did not return any result.", - ) - return ( - fallback_error_chain.chain or [], - True, - ResultContentType.AGENT_RUNNER_ERROR, - ) - - is_error = has_intermediate_error or final_resp.role == "err" - content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if is_error - else ResultContentType.LLM_RESULT - ) - return final_resp.result_chain.chain or [], is_error, content_type - - async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -128,25 +90,6 @@ async def run_third_party_agent( yield MessageChain().message(err_msg), True -async def _consume_runner_and_aggregate( - runner: "BaseAgentRunner", - *, - stream_to_general: bool, - custom_error_message: str | None, -) -> AsyncGenerator[tuple[MessageChain, bool, list, bool], None]: - merged_chain: list = [] - has_intermediate_error = False - async for chain, is_error in run_third_party_agent( - runner, - stream_to_general=stream_to_general, - custom_error_message=custom_error_message, - ): - merged_chain.extend(chain.chain or []) - if is_error: - has_intermediate_error = True - yield chain, is_error, merged_chain, has_intermediate_error - - async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: close_callable = getattr(runner, "close", None) if not callable(close_callable): @@ -160,91 +103,6 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None: logger.warning(f"Failed to close third-party runner cleanly: {e}") -class _RunnerLifecycle: - def __init__( - self, - runner: "BaseAgentRunner", - stream_consumption_close_timeout_sec: int, - ) -> None: - self._runner = runner - self._stream_consumption_close_timeout_sec = ( - stream_consumption_close_timeout_sec - ) - self._closed = False - self._stream_started = False - self._stream_consumed = False - self._idle_task: asyncio.Task[None] | None = None - - async def reset( - self, - *, - req: ProviderRequest, - astr_agent_ctx: AstrAgentContext, - provider_cfg: dict, - streaming: bool, - ) -> None: - await self._runner.reset( - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=60, - ), - agent_hooks=MAIN_AGENT_HOOKS, - provider_config=provider_cfg, - streaming=streaming, - ) - - async def close_once(self) -> None: - if self._closed: - return - self._closed = True - await _close_runner_if_supported(self._runner) - - def mark_stream_started(self) -> None: - self._stream_started = True - self._idle_task = asyncio.create_task(self._close_if_never_consumed()) - - def mark_stream_consumed(self) -> None: - self._stream_consumed = True - if self._idle_task and not self._idle_task.done(): - self._idle_task.cancel() - - async def finalize(self) -> None: - if ( - self._idle_task - and not self._idle_task.done() - and (self._stream_consumed or self._closed) - ): - self._idle_task.cancel() - - defer_close_to_watchdog = ( - self._stream_started - and not self._stream_consumed - and self._idle_task is not None - and not self._idle_task.done() - and not self._closed - ) - if defer_close_to_watchdog: - return - - if not self._closed: - await self.close_once() - - async def _close_if_never_consumed(self) -> None: - try: - await asyncio.sleep(self._stream_consumption_close_timeout_sec) - except asyncio.CancelledError: - return - - if not self._stream_consumed: - logger.warning( - "Third-party runner stream was never consumed in %ss; closing runner " - "to avoid resource leak.", - self._stream_consumption_close_timeout_sec, - ) - await self.close_once() - - class ThirdPartyAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -288,63 +146,68 @@ async def _resolve_persona_custom_error_message( async def _handle_streaming_response( self, *, - lifecycle: _RunnerLifecycle, runner: "BaseAgentRunner", event: AstrMessageEvent, custom_error_message: str | None, + close_runner_once, + mark_stream_consumed, ) -> AsyncGenerator[None, None]: stream_has_runner_error = False merged_chain: list = [] async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: nonlocal merged_chain, stream_has_runner_error - lifecycle.mark_stream_consumed() + mark_stream_consumed() try: - async for ( - chain, - is_error, - merged_chain, - stream_has_runner_error, - ) in _consume_runner_and_aggregate( + async for chain, is_error in run_third_party_agent( runner, stream_to_general=False, custom_error_message=custom_error_message, ): + merged_chain.extend(chain.chain or []) if is_error: - _set_runner_error_extra(event, True) + stream_has_runner_error = True + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) yield chain finally: # Streaming runner cleanup must happen after consumer # finishes iterating to avoid tearing down active streams. - await lifecycle.close_once() + await close_runner_once() event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) .set_async_stream(_stream_runner_chain()), ) - lifecycle.mark_stream_started() yield if runner.done(): final_resp = runner.get_final_llm_resp() - if final_resp and final_resp.result_chain: - ( - final_chain, - is_runner_error, - _, - ) = _resolve_final_result( - merged_chain=merged_chain, - final_resp=final_resp, - has_intermediate_error=stream_has_runner_error, - ) - _set_runner_error_extra(event, is_runner_error) - event.set_result( - MessageEventResult( - chain=final_chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) + if not final_resp or not final_resp.result_chain: + if merged_chain: + logger.warning( + "Agent Runner returned no final response, fallback to streamed error/result chain." + ) + final_chain = merged_chain + is_runner_error = stream_has_runner_error + else: + logger.warning("Agent Runner 未返回最终结果。") + fallback_error_chain = MessageChain().message( + "Agent Runner did not return any result.", + ) + final_chain = fallback_error_chain.chain or [] + is_runner_error = True + else: + final_chain = final_resp.result_chain.chain or [] + is_runner_error = stream_has_runner_error or final_resp.role == "err" + + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) + event.set_result( + MessageEventResult( + chain=final_chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) async def _handle_non_streaming_response( self, @@ -356,16 +219,15 @@ async def _handle_non_streaming_response( ) -> AsyncGenerator[None, None]: merged_chain: list = [] has_intermediate_error = False - async for ( - _, - _, - merged_chain, - has_intermediate_error, - ) in _consume_runner_and_aggregate( + async for chain, is_error in run_third_party_agent( runner, stream_to_general=stream_to_general, custom_error_message=custom_error_message, ): + merged_chain.extend(chain.chain or []) + if is_error: + has_intermediate_error = True + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) yield final_resp = runner.get_final_llm_resp() @@ -374,19 +236,31 @@ async def _handle_non_streaming_response( logger.warning( "Agent Runner returned no final response, fallback to streamed error/result chain." ) + final_chain = merged_chain + is_runner_error = has_intermediate_error + result_content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ) else: logger.warning("Agent Runner 未返回最终结果。") + fallback_error_chain = MessageChain().message( + "Agent Runner did not return any result.", + ) + final_chain = fallback_error_chain.chain or [] + is_runner_error = True + result_content_type = ResultContentType.AGENT_RUNNER_ERROR + else: + final_chain = final_resp.result_chain.chain or [] + is_runner_error = has_intermediate_error or final_resp.role == "err" + result_content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ) - ( - final_chain, - is_runner_error, - result_content_type, - ) = _resolve_final_result( - merged_chain=merged_chain, - final_resp=final_resp, - has_intermediate_error=has_intermediate_error, - ) - _set_runner_error_extra(event, is_runner_error) + event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) event.set_result( MessageEventResult( chain=final_chain, @@ -464,24 +338,59 @@ async def process( and not event.platform_meta.support_streaming_message ) - lifecycle = _RunnerLifecycle( - runner, - stream_consumption_close_timeout_sec=self.stream_consumption_close_timeout_sec, - ) + runner_closed = False + streaming_started = False + stream_consumption_started = False + stream_idle_close_task: asyncio.Task[None] | None = None + + async def close_runner_once() -> None: + nonlocal runner_closed + if runner_closed: + return + runner_closed = True + await _close_runner_if_supported(runner) + + def mark_stream_consumed() -> None: + nonlocal stream_consumption_started + stream_consumption_started = True + if stream_idle_close_task and not stream_idle_close_task.done(): + stream_idle_close_task.cancel() + + async def close_if_stream_never_consumed() -> None: + try: + await asyncio.sleep(self.stream_consumption_close_timeout_sec) + except asyncio.CancelledError: + return + if not stream_consumption_started: + logger.warning( + "Third-party runner stream was never consumed in %ss; closing runner to avoid resource leak.", + self.stream_consumption_close_timeout_sec, + ) + await close_runner_once() try: - await lifecycle.reset( - req=req, - astr_agent_ctx=astr_agent_ctx, - provider_cfg=self.prov_cfg, + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, streaming=streaming_response, ) + if streaming_response and not stream_to_general: + streaming_started = True + stream_idle_close_task = asyncio.create_task( + close_if_stream_never_consumed(), + ) async for _ in self._handle_streaming_response( - lifecycle=lifecycle, runner=runner, event=event, custom_error_message=custom_error_message, + close_runner_once=close_runner_once, + mark_stream_consumed=mark_stream_consumed, ): yield else: @@ -493,7 +402,15 @@ async def process( ): yield finally: - await lifecycle.finalize() + if ( + stream_idle_close_task + and not stream_idle_close_task.done() + and (stream_consumption_started or runner_closed) + ): + stream_idle_close_task.cancel() + + if not streaming_started: + await close_runner_once() asyncio.create_task( Metric.upload( From b510a5d8abdc13459319da4635ce1c267092f21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 10:20:43 +0900 Subject: [PATCH 23/24] refactor: unify runner final resolution and harden deerflow close state --- .../runners/deerflow/deerflow_api_client.py | 21 +++- .../method/agent_sub_stages/third_party.py | 97 ++++++++++--------- 2 files changed, 72 insertions(+), 46 deletions(-) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index f279db2005..52e31307ee 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -196,11 +196,28 @@ async def stream_run( yield event async def close(self) -> None: - self._closed = True session = self._session - if session is not None and not session.closed: + if session is None: + self._closed = True + return + + if session.closed: + self._session = None + self._closed = True + return + + try: await session.close() + except Exception as e: + logger.warning( + "Failed to close DeerFlowAPIClient session cleanly: %s", + e, + exc_info=True, + ) + raise + self._session = None + self._closed = True def __del__(self) -> None: session = getattr(self, "_session", None) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index c34ae1b499..00b40e47ee 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from astrbot.core.agent.runners.base import BaseAgentRunner + from astrbot.core.provider.entities import LLMResponse from astrbot.core.pipeline.stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ( @@ -46,6 +47,11 @@ } THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error" STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30 +RUNNER_NO_RESULT_FALLBACK_MESSAGE = "Agent Runner did not return any result." +RUNNER_NO_FINAL_RESPONSE_LOG = ( + "Agent Runner returned no final response, fallback to streamed error/result chain." +) +RUNNER_NO_RESULT_LOG = "Agent Runner 未返回最终结果。" def _coerce_positive_int(value: object, default: int) -> int: @@ -58,6 +64,42 @@ def _coerce_positive_int(value: object, default: int) -> int: return coerced if coerced > 0 else default +def _resolve_runner_final_result( + merged_chain: list, + has_intermediate_error: bool, + final_resp: "LLMResponse | None", +) -> tuple[list, bool, ResultContentType]: + if not final_resp or not final_resp.result_chain: + if merged_chain: + logger.warning(RUNNER_NO_FINAL_RESPONSE_LOG) + is_runner_error = has_intermediate_error + result_content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ) + return merged_chain, is_runner_error, result_content_type + + logger.warning(RUNNER_NO_RESULT_LOG) + fallback_error_chain = MessageChain().message( + RUNNER_NO_RESULT_FALLBACK_MESSAGE, + ) + return ( + fallback_error_chain.chain or [], + True, + ResultContentType.AGENT_RUNNER_ERROR, + ) + + final_chain = final_resp.result_chain.chain or [] + is_runner_error = has_intermediate_error or final_resp.role == "err" + result_content_type = ( + ResultContentType.AGENT_RUNNER_ERROR + if is_runner_error + else ResultContentType.LLM_RESULT + ) + return final_chain, is_runner_error, result_content_type + + async def run_third_party_agent( runner: "BaseAgentRunner", stream_to_general: bool = False, @@ -183,23 +225,11 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: if runner.done(): final_resp = runner.get_final_llm_resp() - if not final_resp or not final_resp.result_chain: - if merged_chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - final_chain = merged_chain - is_runner_error = stream_has_runner_error - else: - logger.warning("Agent Runner 未返回最终结果。") - fallback_error_chain = MessageChain().message( - "Agent Runner did not return any result.", - ) - final_chain = fallback_error_chain.chain or [] - is_runner_error = True - else: - final_chain = final_resp.result_chain.chain or [] - is_runner_error = stream_has_runner_error or final_resp.role == "err" + final_chain, is_runner_error, _ = _resolve_runner_final_result( + merged_chain=merged_chain, + has_intermediate_error=stream_has_runner_error, + final_resp=final_resp, + ) event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) event.set_result( @@ -231,34 +261,13 @@ async def _handle_non_streaming_response( yield final_resp = runner.get_final_llm_resp() - if not final_resp or not final_resp.result_chain: - if merged_chain: - logger.warning( - "Agent Runner returned no final response, fallback to streamed error/result chain." - ) - final_chain = merged_chain - is_runner_error = has_intermediate_error - result_content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if is_runner_error - else ResultContentType.LLM_RESULT - ) - else: - logger.warning("Agent Runner 未返回最终结果。") - fallback_error_chain = MessageChain().message( - "Agent Runner did not return any result.", - ) - final_chain = fallback_error_chain.chain or [] - is_runner_error = True - result_content_type = ResultContentType.AGENT_RUNNER_ERROR - else: - final_chain = final_resp.result_chain.chain or [] - is_runner_error = has_intermediate_error or final_resp.role == "err" - result_content_type = ( - ResultContentType.AGENT_RUNNER_ERROR - if is_runner_error - else ResultContentType.LLM_RESULT + final_chain, is_runner_error, result_content_type = ( + _resolve_runner_final_result( + merged_chain=merged_chain, + has_intermediate_error=has_intermediate_error, + final_resp=final_resp, ) + ) event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_runner_error) event.set_result( From 6c433978a05e07afbf431bc416be7a1d4c908c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 10:30:04 +0900 Subject: [PATCH 24/24] refactor: share int coercion and make deerflow close best effort --- .../runners/deerflow/deerflow_agent_runner.py | 58 +++-------------- .../runners/deerflow/deerflow_api_client.py | 8 +-- .../method/agent_sub_stages/third_party.py | 18 ++---- astrbot/core/utils/config_number.py | 64 +++++++++++++++++++ 4 files changed, 84 insertions(+), 64 deletions(-) create mode 100644 astrbot/core/utils/config_number.py diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 7a84bb27f8..a6cdf6559e 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -16,6 +16,7 @@ LLMResponse, ProviderRequest, ) +from astrbot.core.utils.config_number import coerce_int_config from ...hooks import BaseAgentRunHooks from ...response import AgentResponseData @@ -98,48 +99,6 @@ def _format_exception(self, err: Exception) -> str: return f"{err_type}: no detailed error message provided." - def _coerce_int_config( - self, - field_name: str, - value: T.Any, - default: int, - min_value: int | None = None, - ) -> int: - if isinstance(value, bool): - logger.warning( - f"DeerFlow config '{field_name}' should be numeric, got boolean. " - f"Fallback to {default}." - ) - parsed = default - elif isinstance(value, int): - parsed = value - elif isinstance(value, str): - try: - parsed = int(value.strip()) - except ValueError: - logger.warning( - f"DeerFlow config '{field_name}' value '{value}' is not numeric. " - f"Fallback to {default}." - ) - parsed = default - else: - try: - parsed = int(value) - except (TypeError, ValueError): - logger.warning( - f"DeerFlow config '{field_name}' has unsupported type " - f"{type(value).__name__}. Fallback to {default}." - ) - parsed = default - - if min_value is not None and parsed < min_value: - logger.warning( - f"DeerFlow config '{field_name}'={parsed} is below minimum {min_value}. " - f"Fallback to {min_value}." - ) - parsed = min_value - return parsed - async def close(self) -> None: """Explicit cleanup hook for long-lived workers.""" api_client = getattr(self, "api_client", None) @@ -180,23 +139,26 @@ def _parse_runner_config(self, provider_config: dict) -> _RunnerConfig: subagent_enabled=bool( provider_config.get("deerflow_subagent_enabled", False), ), - max_concurrent_subagents=self._coerce_int_config( - "deerflow_max_concurrent_subagents", + max_concurrent_subagents=coerce_int_config( provider_config.get("deerflow_max_concurrent_subagents", 3), default=3, min_value=1, + field_name="deerflow_max_concurrent_subagents", + source="DeerFlow config", ), - timeout=self._coerce_int_config( - "timeout", + timeout=coerce_int_config( provider_config.get("timeout", 300), default=300, min_value=1, + field_name="timeout", + source="DeerFlow config", ), - recursion_limit=self._coerce_int_config( - "deerflow_recursion_limit", + recursion_limit=coerce_int_config( provider_config.get("deerflow_recursion_limit", 1000), default=1000, min_value=1, + field_name="deerflow_recursion_limit", + source="DeerFlow config", ), ) diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 52e31307ee..3dcf06ed5a 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -214,10 +214,10 @@ async def close(self) -> None: e, exc_info=True, ) - raise - - self._session = None - self._closed = True + finally: + # Cleanup is best-effort and should not make teardown paths fail loudly. + self._session = None + self._closed = True def __del__(self) -> None: session = getattr(self, "_session", None) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 00b40e47ee..e8b6b7ff10 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -34,6 +34,7 @@ ProviderRequest, ) from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.config_number import coerce_int_config from astrbot.core.utils.metrics import Metric from .....astr_agent_context import AgentContextWrapper, AstrAgentContext @@ -54,16 +55,6 @@ RUNNER_NO_RESULT_LOG = "Agent Runner 未返回最终结果。" -def _coerce_positive_int(value: object, default: int) -> int: - if isinstance(value, bool): - return default - try: - coerced = int(value) - except (TypeError, ValueError): - return default - return coerced if coerced > 0 else default - - def _resolve_runner_final_result( merged_chain: list, has_intermediate_error: bool, @@ -159,12 +150,15 @@ async def initialize(self, ctx: PipelineContext) -> None: self.unsupported_streaming_strategy: str = settings[ "unsupported_streaming_strategy" ] - self.stream_consumption_close_timeout_sec: int = _coerce_positive_int( + self.stream_consumption_close_timeout_sec: int = coerce_int_config( settings.get( "third_party_stream_consumption_close_timeout_sec", STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, ), - STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + default=STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC, + min_value=1, + field_name="third_party_stream_consumption_close_timeout_sec", + source="Third-party runner config", ) async def _resolve_persona_custom_error_message( diff --git a/astrbot/core/utils/config_number.py b/astrbot/core/utils/config_number.py new file mode 100644 index 0000000000..f9ce138397 --- /dev/null +++ b/astrbot/core/utils/config_number.py @@ -0,0 +1,64 @@ +from astrbot.core import logger + + +def coerce_int_config( + value: object, + *, + default: int, + min_value: int | None = None, + field_name: str | None = None, + source: str = "config", + warn: bool = True, +) -> int: + label = f"'{field_name}'" if field_name else "value" + + if isinstance(value, bool): + if warn: + logger.warning( + "%s %s should be numeric, got boolean. Fallback to %s.", + source, + label, + default, + ) + parsed = default + elif isinstance(value, int): + parsed = value + elif isinstance(value, str): + try: + parsed = int(value.strip()) + except ValueError: + if warn: + logger.warning( + "%s %s value '%s' is not numeric. Fallback to %s.", + source, + label, + value, + default, + ) + parsed = default + else: + try: + parsed = int(value) + except (TypeError, ValueError): + if warn: + logger.warning( + "%s %s has unsupported type %s. Fallback to %s.", + source, + label, + type(value).__name__, + default, + ) + parsed = default + + if min_value is not None and parsed < min_value: + if warn: + logger.warning( + "%s %s=%s is below minimum %s. Fallback to %s.", + source, + label, + parsed, + min_value, + min_value, + ) + parsed = min_value + return parsed