Skip to content

Commit f75e083

Browse files
committed
Added a function to the SearchGraph class that allows the user to retrieve the URLs considered during the search
1 parent cc4eeb9 commit f75e083

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

scrapegraphai/graphs/search_graph.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from copy import copy, deepcopy
6-
from typing import Optional
6+
from typing import Optional, List
77
from pydantic import BaseModel
88

99
from .base_graph import BaseGraph
@@ -17,6 +17,7 @@
1717
)
1818

1919

20+
2021
class SearchGraph(AbstractGraph):
2122
"""
2223
SearchGraph is a scraping pipeline that searches the internet for answers to a given prompt.
@@ -29,6 +30,7 @@ class SearchGraph(AbstractGraph):
2930
headless (bool): A flag to run the browser in headless mode.
3031
verbose (bool): A flag to display the execution information.
3132
model_token (int): The token limit for the language model.
33+
considered_urls (List[str]): A list of URLs considered during the search.
3234
3335
Args:
3436
prompt (str): The user prompt to search the internet.
@@ -41,10 +43,10 @@ class SearchGraph(AbstractGraph):
4143
... {"llm": {"model": "gpt-3.5-turbo"}}
4244
... )
4345
>>> result = search_graph.run()
46+
>>> print(search_graph.get_considered_urls())
4447
"""
4548

4649
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
47-
4850
self.max_results = config.get("max_results", 3)
4951

5052
if all(isinstance(value, str) for value in config.values()):
@@ -53,6 +55,7 @@ def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None
5355
self.copy_config = deepcopy(config)
5456

5557
self.copy_schema = deepcopy(schema)
58+
self.considered_urls = [] # New attribute to store URLs
5659

5760
super().__init__(prompt, config, schema)
5861

@@ -64,21 +67,15 @@ def _create_graph(self) -> BaseGraph:
6467
BaseGraph: A graph instance representing the web scraping and searching workflow.
6568
"""
6669

67-
# ************************************************
6870
# Create a SmartScraperGraph instance
69-
# ************************************************
70-
7171
smart_scraper_instance = SmartScraperGraph(
7272
prompt="",
7373
source="",
7474
config=self.copy_config,
7575
schema=self.copy_schema
7676
)
7777

78-
# ************************************************
7978
# Define the graph nodes
80-
# ************************************************
81-
8279
search_internet_node = SearchInternetNode(
8380
input="user_prompt",
8481
output=["urls"],
@@ -128,4 +125,17 @@ def run(self) -> str:
128125
inputs = {"user_prompt": self.prompt}
129126
self.final_state, self.execution_info = self.graph.execute(inputs)
130127

128+
# Store the URLs after execution
129+
if 'urls' in self.final_state:
130+
self.considered_urls = self.final_state['urls']
131+
131132
return self.final_state.get("answer", "No answer found.")
133+
134+
def get_considered_urls(self) -> List[str]:
    """Expose the URLs gathered while the search pipeline ran.

    The list is populated by ``run()`` from the graph's final state;
    before ``run()`` is called it is empty.

    Returns:
        List[str]: The URLs that were considered during the search.
    """
    return self.considered_urls

0 commit comments

Comments
 (0)