diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index ebd13d0102..e9485b8a41 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -117,6 +117,12 @@ def _get_runtime(context) -> tuple[dict, dict, str]: return cfg, provider_settings, event.unified_msg_origin +def _validate_search_query(kwargs: dict) -> str | None: + # Keep provider behavior aligned when the model omits or blanks the required query. + query = str(kwargs.get("query") or "").strip() + return query or None + + def _cache_favicon(url: str, favicon: str | None) -> None: if favicon: sp.temporary_cache["_ws_favicon"][url] = favicon @@ -382,7 +388,10 @@ class TavilyWebSearchTool(FunctionTool[AstrAgentContext]): default_factory=lambda: { "type": "object", "properties": { - "query": {"type": "string", "description": "Required. Search query."}, + "query": { + "type": "string", + "description": "Required string: search query to execute.", + }, "max_results": { "type": "integer", "description": "Optional. The maximum number of results to return. Default is 7. Range is 5-20.", @@ -421,6 +430,10 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_tavily_key", []): return "Error: Tavily API key is not configured in AstrBot." + query = _validate_search_query(kwargs) + if not query: + return "Error: 'query' parameter is required but was not provided." + search_depth = kwargs.get("search_depth", "basic") if search_depth not in ["basic", "advanced"]: search_depth = "basic" @@ -430,7 +443,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: topic = "general" payload = { - "query": kwargs["query"], + "query": query, "max_results": kwargs.get("max_results", 7), "include_favicon": True, "search_depth": search_depth, @@ -546,8 +559,12 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_bocha_key", []): return "Error: BoCha API key is not configured in AstrBot." + query = _validate_search_query(kwargs) + if not query: + return "Error: 'query' parameter is required but was not provided." + payload = { - "query": kwargs["query"], + "query": query, "count": kwargs.get("count", 10), "summary": bool(kwargs.get("summary", False)), } @@ -600,6 +617,10 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_brave_key", []): return "Error: Brave API key is not configured in AstrBot." + query = _validate_search_query(kwargs) + if not query: + return "Error: 'query' parameter is required but was not provided." + count = int(kwargs.get("count", 10)) if count < 1: count = 1 @@ -607,7 +628,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: count = 20 payload = { - "q": kwargs["query"], + "q": query, "count": count, "country": kwargs.get("country", "US"), "search_lang": kwargs.get("search_lang", "zh-hans"), @@ -661,8 +682,12 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_firecrawl_key", []): return "Error: Firecrawl API key is not configured in AstrBot." + query = _validate_search_query(kwargs) + if not query: + return "Error: 'query' parameter is required but was not provided." + payload = { - "query": kwargs["query"], + "query": query, "limit": kwargs.get("limit", 5), "sources": ["web"], } @@ -775,6 +800,10 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_baidu_app_builder_key", ""): return "Error: Baidu AI Search API key is not configured in AstrBot." + query = _validate_search_query(kwargs) + if not query: + return "Error: 'query' parameter is required but was not provided." + top_k = int(kwargs.get("top_k", 10)) if top_k < 1: top_k = 1 @@ -782,7 +811,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: top_k = 50 payload = { - "messages": [{"role": "user", "content": str(kwargs["query"])[:72]}], + "messages": [{"role": "user", "content": query[:72]}], "search_source": "baidu_search_v2", "resource_type_filter": [{"type": "web", "top_k": top_k}], } diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index c0ac3cf800..1571de00f3 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -371,6 +371,34 @@ def post(self, url, json, headers): return self.response +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tool_cls,provider_setting,kwargs", + [ + (tools.TavilyWebSearchTool, "websearch_tavily_key", {}), + (tools.BochaWebSearchTool, "websearch_bocha_key", {"query": None}), + (tools.BraveWebSearchTool, "websearch_brave_key", {"query": " "}), + (tools.FirecrawlWebSearchTool, "websearch_firecrawl_key", {"query": ""}), + (tools.BaiduWebSearchTool, "websearch_baidu_app_builder_key", {}), + ], +) +async def test_search_tool_returns_friendly_error_when_query_missing( + tool_cls, provider_setting, kwargs +): + """Issue #7499: invalid query inputs must not crash search tools.""" + tool = tool_cls() + settings = {provider_setting: ["test-key"]} + if provider_setting == "websearch_baidu_app_builder_key": + settings = {provider_setting: "test-key"} + context = _context_with_provider_settings(settings) + + result = await tool.call(context, **kwargs) + + assert isinstance(result, str) + assert "Error:" in result + assert "query" in result.lower() + + def _context_with_provider_settings(provider_settings): config = {"provider_settings": provider_settings} agent_context = SimpleNamespace(