33"""
44
55from copy import copy , deepcopy
6- from typing import Optional
6+ from typing import Optional , List
77from pydantic import BaseModel
88
99from .base_graph import BaseGraph
1717)
1818
1919
20+
2021class SearchGraph (AbstractGraph ):
2122 """
2223 SearchGraph is a scraping pipeline that searches the internet for answers to a given prompt.
@@ -29,6 +30,7 @@ class SearchGraph(AbstractGraph):
2930 headless (bool): A flag to run the browser in headless mode.
3031 verbose (bool): A flag to display the execution information.
3132 model_token (int): The token limit for the language model.
33+ considered_urls (List[str]): A list of URLs considered during the search.
3234
3335 Args:
3436 prompt (str): The user prompt to search the internet.
@@ -41,10 +43,10 @@ class SearchGraph(AbstractGraph):
4143 ... {"llm": {"model": "gpt-3.5-turbo"}}
4244 ... )
4345 >>> result = search_graph.run()
46+ >>> print(search_graph.get_considered_urls())
4447 """
4548
4649 def __init__ (self , prompt : str , config : dict , schema : Optional [BaseModel ] = None ):
47-
4850 self .max_results = config .get ("max_results" , 3 )
4951
5052 if all (isinstance (value , str ) for value in config .values ()):
@@ -53,6 +55,7 @@ def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None
5355 self .copy_config = deepcopy (config )
5456
5557 self .copy_schema = deepcopy (schema )
58+ self .considered_urls = [] # New attribute to store URLs
5659
5760 super ().__init__ (prompt , config , schema )
5861
@@ -64,21 +67,15 @@ def _create_graph(self) -> BaseGraph:
6467 BaseGraph: A graph instance representing the web scraping and searching workflow.
6568 """
6669
67- # ************************************************
6870 # Create a SmartScraperGraph instance
69- # ************************************************
70-
7171 smart_scraper_instance = SmartScraperGraph (
7272 prompt = "" ,
7373 source = "" ,
7474 config = self .copy_config ,
7575 schema = self .copy_schema
7676 )
7777
78- # ************************************************
7978 # Define the graph nodes
80- # ************************************************
81-
8279 search_internet_node = SearchInternetNode (
8380 input = "user_prompt" ,
8481 output = ["urls" ],
@@ -128,4 +125,17 @@ def run(self) -> str:
128125 inputs = {"user_prompt" : self .prompt }
129126 self .final_state , self .execution_info = self .graph .execute (inputs )
130127
128+ # Store the URLs after execution
129+ if 'urls' in self .final_state :
130+ self .considered_urls = self .final_state ['urls' ]
131+
131132 return self .final_state .get ("answer" , "No answer found." )
133+
134+ def get_considered_urls (self ) -> List [str ]:
135+ """
136+ Returns the list of URLs considered during the search.
137+
138+ Returns:
139+ List[str]: A list of URLs considered during the search.
140+ """
141+ return self .considered_urls
0 commit comments