Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions astrbot/core/tools/web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def _get_runtime(context) -> tuple[dict, dict, str]:
return cfg, provider_settings, event.unified_msg_origin


def _validate_search_query(kwargs: dict) -> str | None:
# Keep provider behavior aligned when the model omits or blanks the required query.
query = str(kwargs.get("query") or "").strip()
return query or None


def _cache_favicon(url: str, favicon: str | None) -> None:
if favicon:
sp.temporary_cache["_ws_favicon"][url] = favicon
Expand Down Expand Up @@ -382,7 +388,10 @@ class TavilyWebSearchTool(FunctionTool[AstrAgentContext]):
default_factory=lambda: {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Required. Search query."},
"query": {
"type": "string",
"description": "Required string: search query to execute.",
},
"max_results": {
"type": "integer",
"description": "Optional. The maximum number of results to return. Default is 7. Range is 5-20.",
Expand Down Expand Up @@ -421,6 +430,10 @@ async def call(self, context, **kwargs) -> ToolExecResult:
if not provider_settings.get("websearch_tavily_key", []):
return "Error: Tavily API key is not configured in AstrBot."

query = _validate_search_query(kwargs)
if not query:
return "Error: 'query' parameter is required but was not provided."

search_depth = kwargs.get("search_depth", "basic")
if search_depth not in ["basic", "advanced"]:
search_depth = "basic"
Expand All @@ -430,7 +443,7 @@ async def call(self, context, **kwargs) -> ToolExecResult:
topic = "general"

payload = {
"query": kwargs["query"],
"query": query,
"max_results": kwargs.get("max_results", 7),
"include_favicon": True,
"search_depth": search_depth,
Expand Down Expand Up @@ -546,8 +559,12 @@ async def call(self, context, **kwargs) -> ToolExecResult:
if not provider_settings.get("websearch_bocha_key", []):
return "Error: BoCha API key is not configured in AstrBot."

query = _validate_search_query(kwargs)
if not query:
return "Error: 'query' parameter is required but was not provided."

payload = {
"query": kwargs["query"],
"query": query,
"count": kwargs.get("count", 10),
"summary": bool(kwargs.get("summary", False)),
}
Expand Down Expand Up @@ -600,14 +617,18 @@ async def call(self, context, **kwargs) -> ToolExecResult:
if not provider_settings.get("websearch_brave_key", []):
return "Error: Brave API key is not configured in AstrBot."

query = _validate_search_query(kwargs)
if not query:
return "Error: 'query' parameter is required but was not provided."

count = int(kwargs.get("count", 10))
if count < 1:
count = 1
if count > 20:
count = 20

payload = {
"q": kwargs["query"],
"q": query,
"count": count,
"country": kwargs.get("country", "US"),
"search_lang": kwargs.get("search_lang", "zh-hans"),
Expand Down Expand Up @@ -661,8 +682,12 @@ async def call(self, context, **kwargs) -> ToolExecResult:
if not provider_settings.get("websearch_firecrawl_key", []):
return "Error: Firecrawl API key is not configured in AstrBot."

query = _validate_search_query(kwargs)
if not query:
return "Error: 'query' parameter is required but was not provided."

payload = {
"query": kwargs["query"],
"query": query,
"limit": kwargs.get("limit", 5),
"sources": ["web"],
}
Expand Down Expand Up @@ -775,14 +800,18 @@ async def call(self, context, **kwargs) -> ToolExecResult:
if not provider_settings.get("websearch_baidu_app_builder_key", ""):
return "Error: Baidu AI Search API key is not configured in AstrBot."

query = _validate_search_query(kwargs)
if not query:
return "Error: 'query' parameter is required but was not provided."

top_k = int(kwargs.get("top_k", 10))
if top_k < 1:
top_k = 1
if top_k > 50:
top_k = 50

payload = {
"messages": [{"role": "user", "content": str(kwargs["query"])[:72]}],
"messages": [{"role": "user", "content": query[:72]}],
"search_source": "baidu_search_v2",
"resource_type_filter": [{"type": "web", "top_k": top_k}],
}
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,34 @@ def post(self, url, json, headers):
return self.response


@pytest.mark.asyncio
@pytest.mark.parametrize(
"tool_cls,provider_setting,kwargs",
[
(tools.TavilyWebSearchTool, "websearch_tavily_key", {}),
(tools.BochaWebSearchTool, "websearch_bocha_key", {"query": None}),
(tools.BraveWebSearchTool, "websearch_brave_key", {"query": " "}),
(tools.FirecrawlWebSearchTool, "websearch_firecrawl_key", {"query": ""}),
(tools.BaiduWebSearchTool, "websearch_baidu_app_builder_key", {}),
],
)
async def test_search_tool_returns_friendly_error_when_query_missing(
tool_cls, provider_setting, kwargs
):
"""Issue #7499: invalid query inputs must not crash search tools."""
tool = tool_cls()
settings = {provider_setting: ["test-key"]}
if provider_setting == "websearch_baidu_app_builder_key":
settings = {provider_setting: "test-key"}
context = _context_with_provider_settings(settings)

result = await tool.call(context, **kwargs)

assert isinstance(result, str)
assert "Error:" in result
assert "query" in result.lower()


def _context_with_provider_settings(provider_settings):
config = {"provider_settings": provider_settings}
agent_context = SimpleNamespace(
Expand Down
Loading