From f20371237b7fee46297cc4dabf4ab85adcae51b8 Mon Sep 17 00:00:00 2001
From: Tavily PR Agent
Date: Tue, 24 Mar 2026 20:10:43 +0000
Subject: [PATCH] feat: add TavilySearch engine to core search framework

---
 ms_agent/tools/search/search_base.py     |   2 +
 ms_agent/tools/search/tavily/__init__.py |   3 +
 ms_agent/tools/search/tavily/schema.py   | 101 ++++++++++++++++++
 ms_agent/tools/search/tavily/search.py   | 129 +++++++++++++++++++++++
 ms_agent/tools/search/websearch_tool.py  |  14 ++-
 ms_agent/tools/search_engine.py          |  12 ++-
 requirements/research.txt                |   1 +
 7 files changed, 259 insertions(+), 3 deletions(-)
 create mode 100644 ms_agent/tools/search/tavily/__init__.py
 create mode 100644 ms_agent/tools/search/tavily/schema.py
 create mode 100644 ms_agent/tools/search/tavily/search.py

diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py
index bb6952729..8a10c9d20 100644
--- a/ms_agent/tools/search/search_base.py
+++ b/ms_agent/tools/search/search_base.py
@@ -17,6 +17,7 @@ class SearchEngineType(enum.Enum):
     EXA = 'exa'
     SERPAPI = 'serpapi'
     ARXIV = 'arxiv'
+    TAVILY = 'tavily'
 
 
 # Mapping from engine type to tool name
@@ -24,6 +25,7 @@ class SearchEngineType(enum.Enum):
     'exa': 'exa_search',
     'serpapi': 'serpapi_search',
     'arxiv': 'arxiv_search',
+    'tavily': 'tavily_search',
 }
 
 
diff --git a/ms_agent/tools/search/tavily/__init__.py b/ms_agent/tools/search/tavily/__init__.py
new file mode 100644
index 000000000..974b4a914
--- /dev/null
+++ b/ms_agent/tools/search/tavily/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
+from ms_agent.tools.search.tavily.search import TavilySearch
diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py
new file mode 100644
index 000000000..2993421bc
--- /dev/null
+++ b/ms_agent/tools/search/tavily/schema.py
@@ -0,0 +1,101 @@
+# flake8: noqa
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import json
+
+from ms_agent.tools.search.search_base import BaseResult, SearchResponse
+
+
+@dataclass
+class TavilySearchRequest:
+
+    # The search query string
+    query: str
+
+    # Number of results to return, default is 5
+    num_results: Optional[int] = 5
+
+    # Search depth: 'basic' or 'advanced'
+    search_depth: Optional[str] = 'basic'
+
+    # Topic category: 'general', 'news', or 'finance'
+    topic: Optional[str] = 'general'
+
+    # Domains to include in search
+    include_domains: Optional[List[str]] = None
+
+    # Domains to exclude from search
+    exclude_domains: Optional[List[str]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the request to keyword arguments (``num_results`` is emitted as ``max_results``).
+        """
+        d = {
+            'query': self.query,
+            'max_results': self.num_results,
+            'search_depth': self.search_depth,
+            'topic': self.topic,
+        }
+        if self.include_domains:
+            d['include_domains'] = self.include_domains
+        if self.exclude_domains:
+            d['exclude_domains'] = self.exclude_domains
+        return d
+
+    def to_json(self) -> str:
+        """
+        Convert the request parameters to a JSON string.
+        """
+        return json.dumps(self.to_dict(), ensure_ascii=False)
+
+
+@dataclass
+class TavilySearchResult:
+
+    # The original search query string
+    query: str
+
+    # Optional arguments for the search request
+    arguments: Dict[str, Any] = field(default_factory=dict)
+
+    # Normalized response whose ``results`` attribute holds the result items; None until a search ran
+    response: Optional[SearchResponse] = None
+
+    def to_list(self) -> List[Dict[str, Any]]:
+        """
+        Convert the search results to a list of dictionaries (empty without a response or a query).
+        """
+        if not self.response or not self.response.results:
+            print('***Warning: No search results found.')
+            return []
+
+        if not self.query:
+            print('***Warning: No query provided for search results.')
+            return []
+
+        res_list: List[Any] = []
+        for res in self.response.results:
+            res_list.append({
+                'url': res.url,
+                'id': res.id,
+                'title': res.title,
+                'highlights': res.highlights,
+                'highlight_scores': res.highlight_scores,
+                'summary': res.summary,
+                'markdown': res.markdown,
+            })
+
+        return res_list
+
+    @staticmethod
+    def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
+        """
+        Load search results from a local file.
+        """
+        with open(file_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        print(f'Search results loaded from {file_path}')
+
+        return data
diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py
new file mode 100644
index 000000000..9ca50cfd0
--- /dev/null
+++ b/ms_agent/tools/search/tavily/search.py
@@ -0,0 +1,129 @@
+# flake8: noqa
+import os
+from typing import TYPE_CHECKING, Optional
+
+from tavily import TavilyClient
+from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
+from ms_agent.tools.search.search_base import (BaseResult, SearchEngine,
+                                               SearchEngineType,
+                                               SearchResponse)
+
+if TYPE_CHECKING:
+    from ms_agent.llm.utils import Tool
+
+
+class TavilySearch(SearchEngine):
+    """
+    Search engine using Tavily API.
+
+    Best for: AI-optimized web search, general and news queries,
+    high relevance results with built-in content extraction.
+    """
+
+    engine_type = SearchEngineType.TAVILY
+
+    def __init__(self, api_key: Optional[str] = None):
+        """Create the Tavily client from ``api_key`` or the TAVILY_API_KEY environment variable."""
+        api_key = api_key or os.getenv('TAVILY_API_KEY')
+        if not api_key:
+            # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
+            raise ValueError(
+                'TAVILY_API_KEY must be set either as an argument or as an environment variable'
+            )
+
+        self.client = TavilyClient(api_key=api_key)
+
+    def search(self, search_request: TavilySearchRequest) -> TavilySearchResult:
+        """
+        Perform a search using the Tavily API with the provided search request parameters.
+
+        :param search_request: An instance of TavilySearchRequest containing search parameters.
+        :return: An instance of TavilySearchResult containing the search results.
+        """
+        search_args: dict = search_request.to_dict()
+        search_result: TavilySearchResult = TavilySearchResult(
+            query=search_request.query,
+            arguments=search_args,
+        )
+        try:
+            raw_response = self.client.search(**search_args)
+            # Map Tavily results to BaseResult schema; a missing 'raw_content' leaves markdown as None
+            results = []
+            for item in raw_response.get('results', []):
+                results.append(
+                    BaseResult(
+                        url=item.get('url', ''),
+                        id=item.get('url', ''),
+                        title=item.get('title', ''),
+                        summary=item.get('content', ''),
+                        markdown=item.get('raw_content'),
+                    ))
+            search_result.response = SearchResponse(results=results)
+        except Exception as e:
+            raise RuntimeError(f'Failed to perform search: {e}') from e
+
+        return search_result
+
+    @classmethod
+    def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
+        """Return the tool definition for Tavily search engine."""
+        from ms_agent.llm.utils import Tool
+        return Tool(
+            tool_name=cls.get_tool_name(),
+            server_name=server_name,
+            description=(
+                'Search the web using Tavily AI-optimized search engine. '
+                'Best for: general web queries, news, and finance topics. '
+                'Returns highly relevant results with content extraction.'),
+            parameters={
+                'type': 'object',
+                'properties': {
+                    'query': {
+                        'type':
+                        'string',
+                        'description':
+                        'The search query. Use natural language for best results.',
+                    },
+                    'num_results': {
+                        'type':
+                        'integer',
+                        'minimum':
+                        1,
+                        'maximum':
+                        10,
+                        'description':
+                        'Number of results to return. Default is 5.',
+                    },
+                    'search_depth': {
+                        'type':
+                        'string',
+                        'enum': ['basic', 'advanced'],
+                        'description':
+                        ('Search depth. "basic" for fast results, '
+                         '"advanced" for higher relevance. Default is "basic".'
+                         ),
+                    },
+                    'topic': {
+                        'type':
+                        'string',
+                        'enum': ['general', 'news', 'finance'],
+                        'description':
+                        ('Topic category for the search. '
+                         'Default is "general".'),
+                    },
+                },
+                'required': ['query'],
+            },
+        )
+
+    @classmethod
+    def build_request_from_args(cls, **kwargs) -> TavilySearchRequest:
+        """Build TavilySearchRequest from tool call arguments."""
+        return TavilySearchRequest(
+            query=kwargs['query'],
+            num_results=kwargs.get('num_results', 5),
+            search_depth=kwargs.get('search_depth', 'basic'),
+            topic=kwargs.get('topic', 'general'),
+            include_domains=kwargs.get('include_domains'),
+            exclude_domains=kwargs.get('exclude_domains'),
+        )
diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py
index 7bbecfe89..3f6da2129 100644
--- a/ms_agent/tools/search/websearch_tool.py
+++ b/ms_agent/tools/search/websearch_tool.py
@@ -206,6 +206,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]:
     elif engine_type == 'arxiv':
         from ms_agent.tools.search.arxiv import ArxivSearch
         return ArxivSearch
+    elif engine_type == 'tavily':
+        from ms_agent.tools.search.tavily import TavilySearch
+        return TavilySearch
     else:
         logger.warning(
             f"Unknown search engine '{engine_type}', falling back to arxiv")
@@ -238,6 +241,9 @@ def get_search_engine(engine_type: str,
             api_key=api_key or os.getenv('SERPAPI_API_KEY'),
             provider=kwargs.get('provider', default_provider),
         )
+    elif engine_type == 'tavily':
+        from ms_agent.tools.search.tavily import TavilySearch
+        return TavilySearch(api_key=api_key or os.getenv('TAVILY_API_KEY'))
     elif engine_type == 'arxiv':
         from ms_agent.tools.search.arxiv import ArxivSearch
         return ArxivSearch()
@@ -296,7 +302,7 @@ class WebSearchTool(ToolBase):
     SERVER_NAME = 'web_search'
 
     # Registry of supported search engines
-    SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv')
+    SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily')
 
     # Process-wide (class-level) usage tracking for summarization calls.
     # This is intentionally separate from LLMAgent usage totals.
@@ -404,6 +410,9 @@ def __init__(self, config, **kwargs):
             'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None)
                         or os.getenv('SERPAPI_API_KEY'))
             if tool_cfg else os.getenv('SERPAPI_API_KEY'),
+            'tavily': (getattr(tool_cfg, 'tavily_api_key', None)
+                       or os.getenv('TAVILY_API_KEY'))
+            if tool_cfg else os.getenv('TAVILY_API_KEY'),
         }
 
         # SerpApi provider (google, bing, baidu)
@@ -508,6 +517,9 @@ async def connect(self) -> None:
                         api_key=self._api_keys.get('serpapi'),
                         provider=self._serpapi_provider,
                     )
+                elif engine_type == 'tavily':
+                    self._engines[engine_type] = engine_cls(
+                        api_key=self._api_keys.get('tavily'))
                 else:  # arxiv
                     self._engines[engine_type] = engine_cls()
 
diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py
index 6bb0be4d6..7e8799611 100644
--- a/ms_agent/tools/search_engine.py
+++ b/ms_agent/tools/search_engine.py
@@ -8,6 +8,7 @@
 from ms_agent.tools.search.exa import ExaSearch
 from ms_agent.tools.search.search_base import SearchEngineType
 from ms_agent.tools.search.serpapi import SerpApiSearch
+from ms_agent.tools.search.tavily import TavilySearch
 from ms_agent.utils.logger import get_logger
 
 logger = get_logger()
@@ -23,7 +24,8 @@ def set_search_env_overrides(env_overrides: Optional[Dict[str, str]]) -> None:
     Expected keys (all optional):
     - 'EXA_API_KEY'
     - 'SERPAPI_API_KEY'
-    - SEARCH_ENGINE_OVERRIDE_ENV (e.g. 'exa' / 'serpapi' / 'arxiv')
+    - 'TAVILY_API_KEY'
+    - SEARCH_ENGINE_OVERRIDE_ENV (e.g. 'exa' / 'serpapi' / 'arxiv' / 'tavily')
     """
     if not env_overrides:
         if hasattr(_search_env_local, 'overrides'):
@@ -135,7 +137,8 @@ def get_web_search_tool(config_file: str):
                                  or '')).strip().lower()
     if engine_override and engine_override in (SearchEngineType.EXA.value,
                                                SearchEngineType.SERPAPI.value,
-                                               SearchEngineType.ARXIV.value):
+                                               SearchEngineType.ARXIV.value,
+                                               SearchEngineType.TAVILY.value):
         search_config['engine'] = engine_override
 
     engine_name = (search_config.get('engine', '') or '').lower()
@@ -143,6 +146,7 @@ def get_web_search_tool(config_file: str):
     # Per-request API key overrides (thread-local) take precedence
     override_exa_key = local_env.get('EXA_API_KEY')
     override_serp_key = local_env.get('SERPAPI_API_KEY')
+    override_tavily_key = local_env.get('TAVILY_API_KEY')
 
     if engine_name == SearchEngineType.EXA.value:
         return ExaSearch(
@@ -153,6 +157,10 @@ def get_web_search_tool(config_file: str):
             api_key=override_serp_key or search_config.get(
                 'serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)),
             provider=search_config.get('provider', 'google').lower())
+    elif engine_name == SearchEngineType.TAVILY.value:
+        return TavilySearch(
+            api_key=override_tavily_key or search_config.get(
+                'tavily_api_key', os.getenv('TAVILY_API_KEY', None)))
     elif engine_name == SearchEngineType.ARXIV.value:
         return ArxivSearch()
     else:
diff --git a/requirements/research.txt b/requirements/research.txt
index 67ce2d3fd..86ba765de 100644
--- a/requirements/research.txt
+++ b/requirements/research.txt
@@ -14,3 +14,4 @@ Pillow
 python-dotenv
 requests
 rich
+tavily-python