From fc92e13f5cbf76980dd527b87f29f10016fb81f9 Mon Sep 17 00:00:00 2001 From: Nick Sweeting Date: Fri, 3 Apr 2026 11:20:22 -0700 Subject: [PATCH 1/4] Move custom code and scripts --- scripts/{download-binary.py => download_binary.py} | 0 test_local_mode.py => scripts/test_local_mode.py | 0 src/stagehand/{lib => _custom}/__init__.py | 0 src/stagehand/{lib => _custom}/sea_binary.py | 0 src/stagehand/{lib => _custom}/sea_server.py | 0 src/stagehand/{ => _custom}/session.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename scripts/{download-binary.py => download_binary.py} (100%) rename test_local_mode.py => scripts/test_local_mode.py (100%) mode change 100644 => 100755 rename src/stagehand/{lib => _custom}/__init__.py (100%) rename src/stagehand/{lib => _custom}/sea_binary.py (100%) rename src/stagehand/{lib => _custom}/sea_server.py (100%) rename src/stagehand/{ => _custom}/session.py (100%) diff --git a/scripts/download-binary.py b/scripts/download_binary.py similarity index 100% rename from scripts/download-binary.py rename to scripts/download_binary.py diff --git a/test_local_mode.py b/scripts/test_local_mode.py old mode 100644 new mode 100755 similarity index 100% rename from test_local_mode.py rename to scripts/test_local_mode.py diff --git a/src/stagehand/lib/__init__.py b/src/stagehand/_custom/__init__.py similarity index 100% rename from src/stagehand/lib/__init__.py rename to src/stagehand/_custom/__init__.py diff --git a/src/stagehand/lib/sea_binary.py b/src/stagehand/_custom/sea_binary.py similarity index 100% rename from src/stagehand/lib/sea_binary.py rename to src/stagehand/_custom/sea_binary.py diff --git a/src/stagehand/lib/sea_server.py b/src/stagehand/_custom/sea_server.py similarity index 100% rename from src/stagehand/lib/sea_server.py rename to src/stagehand/_custom/sea_server.py diff --git a/src/stagehand/session.py b/src/stagehand/_custom/session.py similarity index 100% rename from src/stagehand/session.py rename to src/stagehand/_custom/session.py From ff29af44f8c35eeab4c980b8bf6ed65b65d24af2 Mon Sep 17 00:00:00 2001 From: Nick Sweeting Date: Fri, 3 Apr 2026 11:20:48 -0700 Subject: [PATCH 2/4] Centralize custom code --- CONTRIBUTING.md | 6 +- pyproject.toml | 3 +- scripts/download_binary.py | 14 +- scripts/test_local_mode.py | 114 ++-- src/stagehand/__init__.py | 11 + src/stagehand/_client.py | 245 ++++---- src/stagehand/_custom/__init__.py | 29 +- src/stagehand/_custom/sea_binary.py | 2 +- src/stagehand/_custom/sea_server.py | 138 ++++- src/stagehand/_custom/session.py | 611 +++++++++++++++++--- src/stagehand/lib/.keep | 4 - src/stagehand/resources/sessions.py | 42 +- src/stagehand/resources/sessions_helpers.py | 352 ----------- tests/test_local_server.py | 2 +- tests/test_sea_binary.py | 4 +- 15 files changed, 944 insertions(+), 633 deletions(-) delete mode 100644 src/stagehand/lib/.keep delete mode 100644 src/stagehand/resources/sessions_helpers.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1189c286..16d5fa94 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,7 +24,7 @@ uv run python script.py Most of the SDK is generated code. Modifications to code will be persisted between generations, but may result in merge conflicts between manual patches and changes from the generator. The generator will never -modify the contents of the `src/stagehand/lib/` and `examples/` directories. +modify the contents of the `src/stagehand/_custom/` and `examples/` directories. ## Setting up the local server binary (for development) @@ -35,7 +35,7 @@ The SDK supports running a local Stagehand server for development and testing. T Run the download script to automatically download the correct binary: ```sh -$ uv run python scripts/download-binary.py +$ uv run python scripts/download_binary.py ``` This will: @@ -64,7 +64,7 @@ Instead of placing the binary in `bin/sea/`, you can point to any binary locatio ```sh $ export STAGEHAND_SEA_BINARY=/path/to/your/stagehand-binary -$ uv run python test_local_mode.py +$ uv run python scripts/test_local_mode.py ``` ## Adding and running examples diff --git a/pyproject.toml b/pyproject.toml index ba705dd8..f75faef1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,6 @@ exclude = [ "hatch_build.py", "examples", "scripts", - "test_local_mode.py", ] reportImplicitOverride = true @@ -156,7 +155,7 @@ show_error_codes = true # # We also exclude our `tests` as mypy doesn't always infer # types correctly and Pyright will still catch any type errors. -exclude = ['src/stagehand/_files.py', '_dev/.*.py', 'tests/.*', 'hatch_build.py', 'examples/.*', 'scripts/.*', 'test_local_mode.py'] +exclude = ['src/stagehand/_files.py', '_dev/.*.py', 'tests/.*', 'hatch_build.py', 'examples/.*', 'scripts/.*'] strict_equality = true implicit_reexport = true diff --git a/scripts/download_binary.py b/scripts/download_binary.py index 9d88f4f8..2690c03f 100755 --- a/scripts/download_binary.py +++ b/scripts/download_binary.py @@ -6,11 +6,11 @@ and places it in bin/sea/ for use during development and testing. Usage: - python scripts/download-binary.py [--version VERSION] + python scripts/download_binary.py [--version VERSION] Examples: - python scripts/download-binary.py - python scripts/download-binary.py --version v3.2.0 + python scripts/download_binary.py + python scripts/download_binary.py --version v3.2.0 """ from __future__ import annotations @@ -179,7 +179,7 @@ def reporthook(block_num: int, block_size: int, total_size: int) -> None: size_mb = dest_path.stat().st_size / (1024 * 1024) print(f"āœ… Downloaded successfully: {dest_path} ({size_mb:.1f} MB)") - print(f"\nšŸ’” You can now run: uv run python test_local_mode.py") + print("\nšŸ’” You can now run: uv run python scripts/test_local_mode.py") except urllib.error.HTTPError as e: # type: ignore[misc] print(f"\nāŒ Error: Failed to download binary (HTTP {e.code})") # type: ignore[union-attr] @@ -197,9 +197,9 @@ def main() -> None: formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - python scripts/download-binary.py - python scripts/download-binary.py --version v3.2.0 - python scripts/download-binary.py --version stagehand-server-v3/v3.2.0 + python scripts/download_binary.py + python scripts/download_binary.py --version v3.2.0 + python scripts/download_binary.py --version stagehand-server-v3/v3.2.0 """, ) parser.add_argument( diff --git a/scripts/test_local_mode.py b/scripts/test_local_mode.py index 12267164..a775c833 100755 --- a/scripts/test_local_mode.py +++ b/scripts/test_local_mode.py @@ -3,72 +3,76 @@ import os import sys +import traceback +from pathlib import Path -# Add src to path for local testing -sys.path.insert(0, "src") +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) from stagehand import Stagehand -# Set required API key for LLM operations -if not os.environ.get("MODEL_API_KEY") and not os.environ.get("OPENAI_API_KEY"): - print("āŒ Error: MODEL_API_KEY or OPENAI_API_KEY environment variable not set") # noqa: T201 - print(" Set it with: export MODEL_API_KEY='sk-proj-...'") # noqa: T201 - sys.exit(1) -print("šŸš€ Testing local server mode...") # noqa: T201 +def main() -> None: + model_api_key = os.environ.get("MODEL_API_KEY") or os.environ.get("OPENAI_API_KEY") + if not model_api_key: + print("āŒ Error: MODEL_API_KEY or OPENAI_API_KEY environment variable not set") # noqa: T201 + print(" Set it with: export MODEL_API_KEY='sk-proj-...'") # noqa: T201 + sys.exit(1) -try: - # Create client in local mode - will use bundled binary - print("šŸ“¦ Creating Stagehand client in local mode...") # noqa: T201 - client = Stagehand( - server="local", - browserbase_api_key="local", # Dummy value - not used in local mode - browserbase_project_id="local", # Dummy value - not used in local mode - model_api_key=os.environ.get("MODEL_API_KEY") or os.environ["OPENAI_API_KEY"], - local_headless=True, - local_port=0, # Auto-pick free port - local_ready_timeout_s=15.0, # Give it time to start - ) + os.environ["BROWSERBASE_FLOW_LOGS"] = "1" - print("šŸ”§ Starting session (this will start the local server)...") # noqa: T201 - session = client.sessions.start( - model_name="openai/gpt-5-nano", - browser={ # type: ignore[arg-type] - "type": "local", - "launchOptions": {}, # Launch local Playwright browser with defaults - }, - ) - session_id = session.data.session_id + print("šŸš€ Testing local server mode...") # noqa: T201 + client = None - print(f"āœ… Session started: {session_id}") # noqa: T201 - print(f"🌐 Server running at: {client.base_url}") # noqa: T201 + try: + print("šŸ“¦ Creating Stagehand client in local mode...") # noqa: T201 + client = Stagehand( + server="local", + browserbase_api_key="local", + browserbase_project_id="local", + model_api_key=model_api_key, + local_headless=True, + local_port=0, + local_ready_timeout_s=15.0, + ) - print("\nšŸ“ Navigating to example.com...") # noqa: T201 - client.sessions.navigate( - id=session_id, - url="https://example.com", - ) - print("āœ… Navigation complete") # noqa: T201 + print("šŸ”§ Starting session (this will start the local server)...") # noqa: T201 + session = client.sessions.start( + model_name="openai/gpt-5-nano", + browser={ # type: ignore[arg-type] + "type": "local", + "launchOptions": {}, + }, + ) + session_id = session.data.session_id - print("\nšŸ” Extracting page heading...") # noqa: T201 - result = client.sessions.extract( - id=session_id, - instruction="Extract the main heading text from the page", - ) - print(f"šŸ“„ Extracted: {result.data.result}") # noqa: T201 + print(f"āœ… Session started: {session_id}") # noqa: T201 + print(f"🌐 Server running at: {client.base_url}") # noqa: T201 - print("\nšŸ›‘ Ending session...") # noqa: T201 - client.sessions.end(id=session_id) - print("āœ… Session ended") # noqa: T201 + print("\nšŸ“ Navigating to example.com...") # noqa: T201 + client.sessions.navigate(id=session_id, url="https://example.com") + print("āœ… Navigation complete") # noqa: T201 - print("\nšŸ”Œ Closing client (will shut down server)...") # noqa: T201 - client.close() - print("āœ… Server shut down successfully!") # noqa: T201 + print("\nšŸ” Extracting page heading...") # noqa: T201 + result = client.sessions.extract( + id=session_id, + instruction="Extract the main heading text from the page", + ) + print(f"šŸ“„ Extracted: {result.data.result}") # noqa: T201 - print("\nšŸŽ‰ All tests passed!") # noqa: T201 + print("\nšŸ›‘ Ending session...") # noqa: T201 + client.sessions.end(id=session_id) + print("āœ… Session ended") # noqa: T201 + print("\nšŸŽ‰ All tests passed!") # noqa: T201 + except Exception as exc: + print(f"\nāŒ Error: {exc}") # noqa: T201 + traceback.print_exc() + sys.exit(1) + finally: + if client is not None: + print("\nšŸ”Œ Closing client (will shut down server)...") # noqa: T201 + client.close() + print("āœ… Server shut down successfully!") # noqa: T201 -except Exception as e: - print(f"\nāŒ Error: {e}") # noqa: T201 - import traceback - traceback.print_exc() - sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/src/stagehand/__init__.py b/src/stagehand/__init__.py index 95c793d7..301348d2 100644 --- a/src/stagehand/__init__.py +++ b/src/stagehand/__init__.py @@ -39,6 +39,13 @@ from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging +### +# Re-export the public bound session types from `_custom` so users can type +# against `stagehand.Session` instead of importing from private modules. +from ._custom.session import Session, AsyncSession + +### + __all__ = [ "types", "__version__", @@ -73,6 +80,10 @@ "AsyncStream", "Stagehand", "AsyncStagehand", + ### + "Session", + "AsyncSession", + ### "file_from_path", "BaseModel", "DEFAULT_TIMEOUT", diff --git a/src/stagehand/_client.py b/src/stagehand/_client.py index 7204dcf1..eeb74fd3 100644 --- a/src/stagehand/_client.py +++ b/src/stagehand/_client.py @@ -30,11 +30,27 @@ SyncAPIClient, AsyncAPIClient, ) -from .lib.sea_server import SeaServerConfig, SeaServerManager + +### +# Keep the generated client thin: all runtime patch logic lives in `_custom`. +from ._custom.session import install_stainless_session_patches +from ._custom.sea_server import ( + copy_local_mode_kwargs, + configure_client_base_url, + close_sync_client_sea_server, + prepare_sync_client_base_url, + close_async_client_sea_server, + prepare_async_client_base_url, +) + +### if TYPE_CHECKING: from .resources import sessions - from .resources.sessions_helpers import SessionsResourceWithHelpers, AsyncSessionsResourceWithHelpers + + ### + from ._custom.sea_server import SeaServerManager + ### __all__ = [ "Timeout", @@ -47,12 +63,31 @@ "AsyncClient", ] +### +# Patch the generated resource classes in place so user-facing types stay on the +# original Stainless imports instead of custom wrapper classes. +install_stainless_session_patches() +### + class Stagehand(SyncAPIClient): # client options browserbase_api_key: str | None browserbase_project_id: str | None model_api_key: str | None + ### + # These are assigned indirectly by `configure_client_base_url(...)` so the + # generated class still exposes typed local-mode state for `copy()` and tests. + _server_mode: Literal["remote", "local"] + _local_stagehand_binary_path: str | os.PathLike[str] | None + _local_host: str + _local_port: int + _local_headless: bool + _local_chrome_path: str | None + _local_ready_timeout_s: float + _local_shutdown_on_close: bool + _sea_server: SeaServerManager | None + ### def __init__( self, @@ -97,15 +132,6 @@ def __init__( Pass it explicitly when you want the SDK to send `x-model-api-key` on remote requests or to forward `MODEL_API_KEY` to the local SEA child process. """ - self._server_mode: Literal["remote", "local"] = server - self._local_stagehand_binary_path = _local_stagehand_binary_path - self._local_host = local_host - self._local_port = local_port - self._local_headless = local_headless - self._local_chrome_path = local_chrome_path - self._local_ready_timeout_s = local_ready_timeout_s - self._local_shutdown_on_close = local_shutdown_on_close - if browserbase_api_key is None: browserbase_api_key = os.environ.get("BROWSERBASE_API_KEY") if browserbase_project_id is None: @@ -116,29 +142,23 @@ def __init__( self.model_api_key = model_api_key - self._sea_server: SeaServerManager | None = None - if server == "local": - # We'll switch `base_url` to the started server before the first request. - if base_url is None: - base_url = "http://127.0.0.1" - - self._sea_server = SeaServerManager( - config=SeaServerConfig( - host=local_host, - port=local_port, - headless=local_headless, - ready_timeout_s=local_ready_timeout_s, - model_api_key=model_api_key, - chrome_path=local_chrome_path, - shutdown_on_close=local_shutdown_on_close, - ), - _local_stagehand_binary_path=_local_stagehand_binary_path, - ) - else: - if base_url is None: - base_url = os.environ.get("STAGEHAND_BASE_URL") - if base_url is None: - base_url = f"https://api.stagehand.browserbase.com" + ### + # Centralize local-mode state hydration and base-url selection in `_custom` + # so no constructor branching lives in the generated client. + base_url = configure_client_base_url( + self, + server=server, + _local_stagehand_binary_path=_local_stagehand_binary_path, + local_host=local_host, + local_port=local_port, + local_headless=local_headless, + local_chrome_path=local_chrome_path, + local_ready_timeout_s=local_ready_timeout_s, + local_shutdown_on_close=local_shutdown_on_close, + base_url=base_url, + model_api_key=model_api_key, + ) + ### super().__init__( version=__version__, @@ -155,8 +175,13 @@ def __init__( @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._sea_server is not None: - self.base_url = self._sea_server.ensure_running_sync() + ### + # Start the local SEA server lazily on first request instead of at client + # construction time, then swap the base URL to the started process. + local_base_url = prepare_sync_client_base_url(self) + if local_base_url is not None: + self.base_url = local_base_url + ### return super()._prepare_options(options) @override @@ -164,14 +189,16 @@ def close(self) -> None: try: super().close() finally: - if self._sea_server is not None: - self._sea_server.close() + ### + # Tear down the managed SEA process after HTTP resources close. + close_sync_client_sea_server(self) + ### @cached_property - def sessions(self) -> SessionsResourceWithHelpers: - from .resources.sessions_helpers import SessionsResourceWithHelpers + def sessions(self) -> sessions.SessionsResource: + from .resources.sessions import SessionsResource - return SessionsResourceWithHelpers(self) + return SessionsResource(self) @cached_property def with_raw_response(self) -> StagehandWithRawResponse: @@ -267,24 +294,27 @@ def copy( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, model_api_key=model_api_key or self.model_api_key, - server=server or self._server_mode, - _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, - local_host=local_host or self._local_host, - local_port=local_port if local_port is not None else self._local_port, - local_headless=local_headless if local_headless is not None else self._local_headless, - local_chrome_path=local_chrome_path if local_chrome_path is not None else self._local_chrome_path, - local_ready_timeout_s=local_ready_timeout_s - if local_ready_timeout_s is not None - else self._local_ready_timeout_s, - local_shutdown_on_close=local_shutdown_on_close - if local_shutdown_on_close is not None - else self._local_shutdown_on_close, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + ### + # Preserve local-mode configuration when cloning the client without + # duplicating that branching logic in generated code. + **copy_local_mode_kwargs( + self, + server=server, + _local_stagehand_binary_path=_local_stagehand_binary_path, + local_host=local_host, + local_port=local_port, + local_headless=local_headless, + local_chrome_path=local_chrome_path, + local_ready_timeout_s=local_ready_timeout_s, + local_shutdown_on_close=local_shutdown_on_close, + ), + ### **_extra_kwargs, ) @@ -331,6 +361,19 @@ class AsyncStagehand(AsyncAPIClient): browserbase_api_key: str | None browserbase_project_id: str | None model_api_key: str | None + ### + # These are assigned indirectly by `configure_client_base_url(...)` so the + # generated class still exposes typed local-mode state for `copy()` and tests. + _server_mode: Literal["remote", "local"] + _local_stagehand_binary_path: str | os.PathLike[str] | None + _local_host: str + _local_port: int + _local_headless: bool + _local_chrome_path: str | None + _local_ready_timeout_s: float + _local_shutdown_on_close: bool + _sea_server: SeaServerManager | None + ### def __init__( self, @@ -375,15 +418,6 @@ def __init__( Pass it explicitly when you want the SDK to send `x-model-api-key` on remote requests or to forward `MODEL_API_KEY` to the local SEA child process. """ - self._server_mode: Literal["remote", "local"] = server - self._local_stagehand_binary_path = _local_stagehand_binary_path - self._local_host = local_host - self._local_port = local_port - self._local_headless = local_headless - self._local_chrome_path = local_chrome_path - self._local_ready_timeout_s = local_ready_timeout_s - self._local_shutdown_on_close = local_shutdown_on_close - if browserbase_api_key is None: browserbase_api_key = os.environ.get("BROWSERBASE_API_KEY") if browserbase_project_id is None: @@ -394,28 +428,23 @@ def __init__( self.model_api_key = model_api_key - self._sea_server: SeaServerManager | None = None - if server == "local": - if base_url is None: - base_url = "http://127.0.0.1" - - self._sea_server = SeaServerManager( - config=SeaServerConfig( - host=local_host, - port=local_port, - headless=local_headless, - ready_timeout_s=local_ready_timeout_s, - model_api_key=model_api_key, - chrome_path=local_chrome_path, - shutdown_on_close=local_shutdown_on_close, - ), - _local_stagehand_binary_path=_local_stagehand_binary_path, - ) - else: - if base_url is None: - base_url = os.environ.get("STAGEHAND_BASE_URL") - if base_url is None: - base_url = f"https://api.stagehand.browserbase.com" + ### + # Centralize local-mode state hydration and base-url selection in `_custom` + # so no constructor branching lives in the generated client. + base_url = configure_client_base_url( + self, + server=server, + _local_stagehand_binary_path=_local_stagehand_binary_path, + local_host=local_host, + local_port=local_port, + local_headless=local_headless, + local_chrome_path=local_chrome_path, + local_ready_timeout_s=local_ready_timeout_s, + local_shutdown_on_close=local_shutdown_on_close, + base_url=base_url, + model_api_key=model_api_key, + ) + ### super().__init__( version=__version__, @@ -432,8 +461,13 @@ def __init__( @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - if self._sea_server is not None: - self.base_url = await self._sea_server.ensure_running_async() + ### + # Start the local SEA server lazily on first request instead of at client + # construction time, then swap the base URL to the started process. + local_base_url = await prepare_async_client_base_url(self) + if local_base_url is not None: + self.base_url = local_base_url + ### return await super()._prepare_options(options) @override @@ -441,14 +475,16 @@ async def close(self) -> None: try: await super().close() finally: - if self._sea_server is not None: - await self._sea_server.aclose() + ### + # Tear down the managed SEA process after HTTP resources close. + await close_async_client_sea_server(self) + ### @cached_property - def sessions(self) -> AsyncSessionsResourceWithHelpers: - from .resources.sessions_helpers import AsyncSessionsResourceWithHelpers + def sessions(self) -> sessions.AsyncSessionsResource: + from .resources.sessions import AsyncSessionsResource - return AsyncSessionsResourceWithHelpers(self) + return AsyncSessionsResource(self) @cached_property def with_raw_response(self) -> AsyncStagehandWithRawResponse: @@ -544,24 +580,27 @@ def copy( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, model_api_key=model_api_key or self.model_api_key, - server=server or self._server_mode, - _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, - local_host=local_host or self._local_host, - local_port=local_port if local_port is not None else self._local_port, - local_headless=local_headless if local_headless is not None else self._local_headless, - local_chrome_path=local_chrome_path if local_chrome_path is not None else self._local_chrome_path, - local_ready_timeout_s=local_ready_timeout_s - if local_ready_timeout_s is not None - else self._local_ready_timeout_s, - local_shutdown_on_close=local_shutdown_on_close - if local_shutdown_on_close is not None - else self._local_shutdown_on_close, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + ### + # Preserve local-mode configuration when cloning the client without + # duplicating that branching logic in generated code. + **copy_local_mode_kwargs( + self, + server=server, + _local_stagehand_binary_path=_local_stagehand_binary_path, + local_host=local_host, + local_port=local_port, + local_headless=local_headless, + local_chrome_path=local_chrome_path, + local_ready_timeout_s=local_ready_timeout_s, + local_shutdown_on_close=local_shutdown_on_close, + ), + ### **_extra_kwargs, ) diff --git a/src/stagehand/_custom/__init__.py b/src/stagehand/_custom/__init__.py index 60fb7e11..63ea7b6a 100644 --- a/src/stagehand/_custom/__init__.py +++ b/src/stagehand/_custom/__init__.py @@ -1,11 +1,32 @@ -"""SEA binary and server management.""" - +from .session import ( + Session, + AsyncSession, + install_stainless_session_patches, +) from .sea_binary import resolve_binary_path, default_binary_filename -from .sea_server import SeaServerConfig, SeaServerManager +from .sea_server import ( + SeaServerConfig, + SeaServerManager, + copy_local_mode_kwargs, + configure_client_base_url, + close_sync_client_sea_server, + prepare_sync_client_base_url, + close_async_client_sea_server, + prepare_async_client_base_url, +) __all__ = [ - "resolve_binary_path", "default_binary_filename", + "resolve_binary_path", "SeaServerConfig", "SeaServerManager", + "close_async_client_sea_server", + "close_sync_client_sea_server", + "configure_client_base_url", + "copy_local_mode_kwargs", + "prepare_async_client_base_url", + "prepare_sync_client_base_url", + "AsyncSession", + "Session", + "install_stainless_session_patches", ] diff --git a/src/stagehand/_custom/sea_binary.py b/src/stagehand/_custom/sea_binary.py index 6d8f4eed..9c6badc8 100644 --- a/src/stagehand/_custom/sea_binary.py +++ b/src/stagehand/_custom/sea_binary.py @@ -106,7 +106,7 @@ def resolve_binary_path( # Fallback: source checkout layout (works for local dev in-repo). here = Path(__file__).resolve() - repo_root = here.parents[3] # stagehand-python/ + repo_root = here.parents[3] candidate = repo_root / "bin" / "sea" / filename if not candidate.exists(): diff --git a/src/stagehand/_custom/sea_server.py b/src/stagehand/_custom/sea_server.py index ef34aa62..e31f7b4a 100644 --- a/src/stagehand/_custom/sea_server.py +++ b/src/stagehand/_custom/sea_server.py @@ -11,6 +11,7 @@ from pathlib import Path from threading import Lock from dataclasses import dataclass +from typing_extensions import Literal, Protocol, TypedDict import httpx @@ -29,6 +30,29 @@ class SeaServerConfig: shutdown_on_close: bool +class _HasLocalModeState(Protocol): + _server_mode: Literal["remote", "local"] + _local_stagehand_binary_path: str | os.PathLike[str] | None + _local_host: str + _local_port: int + _local_headless: bool + _local_chrome_path: str | None + _local_ready_timeout_s: float + _local_shutdown_on_close: bool + _sea_server: SeaServerManager | None + + +class LocalModeKwargs(TypedDict): + server: Literal["remote", "local"] + _local_stagehand_binary_path: str | os.PathLike[str] | None + local_host: str + local_port: int + local_headless: bool + local_chrome_path: str | None + local_ready_timeout_s: float + local_shutdown_on_close: bool + + def _pick_free_port(host: str) -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind((host, 0)) @@ -105,7 +129,10 @@ def __init__( _local_stagehand_binary_path: str | os.PathLike[str] | None = None, ) -> None: self._config = config - self._binary_path: Path = resolve_binary_path(_local_stagehand_binary_path=_local_stagehand_binary_path, version=__version__) + self._binary_path: Path = resolve_binary_path( + _local_stagehand_binary_path=_local_stagehand_binary_path, + version=__version__, + ) self._lock = Lock() self._async_lock = asyncio.Lock() @@ -257,3 +284,112 @@ async def _start_async(self) -> tuple[str, subprocess.Popen[bytes]]: raise return base_url, proc + + +def configure_client_base_url( + client: _HasLocalModeState, + *, + server: Literal["remote", "local"], + _local_stagehand_binary_path: str | os.PathLike[str] | None, + local_host: str, + local_port: int, + local_headless: bool, + local_chrome_path: str | None, + local_ready_timeout_s: float, + local_shutdown_on_close: bool, + base_url: str | httpx.URL | None, + model_api_key: str | None, +) -> str | httpx.URL: + client._server_mode = server + client._local_stagehand_binary_path = _local_stagehand_binary_path + client._local_host = local_host + client._local_port = local_port + client._local_headless = local_headless + client._local_chrome_path = local_chrome_path + client._local_ready_timeout_s = local_ready_timeout_s + client._local_shutdown_on_close = local_shutdown_on_close + client._sea_server = None + + if server == "local": + if base_url is None: + base_url = "http://127.0.0.1" + + client._sea_server = SeaServerManager( + config=SeaServerConfig( + host=local_host, + port=local_port, + headless=local_headless, + ready_timeout_s=local_ready_timeout_s, + model_api_key=model_api_key, + chrome_path=local_chrome_path, + shutdown_on_close=local_shutdown_on_close, + ), + _local_stagehand_binary_path=_local_stagehand_binary_path, + ) + return base_url + + if base_url is None: + base_url = os.environ.get("STAGEHAND_BASE_URL") + if base_url is None: + base_url = "https://api.stagehand.browserbase.com" + return base_url + + +def copy_local_mode_kwargs( + client: _HasLocalModeState, + *, + server: Literal["remote", "local"] | None, + _local_stagehand_binary_path: str | os.PathLike[str] | None, + local_host: str | None, + local_port: int | None, + local_headless: bool | None, + local_chrome_path: str | None, + local_ready_timeout_s: float | None, + local_shutdown_on_close: bool | None, +) -> LocalModeKwargs: + return { + "server": server or client._server_mode, + "_local_stagehand_binary_path": ( + _local_stagehand_binary_path + if _local_stagehand_binary_path is not None + else client._local_stagehand_binary_path + ), + "local_host": local_host or client._local_host, + "local_port": local_port if local_port is not None else client._local_port, + "local_headless": local_headless if local_headless is not None else client._local_headless, + "local_chrome_path": ( + local_chrome_path if local_chrome_path is not None else client._local_chrome_path + ), + "local_ready_timeout_s": ( + local_ready_timeout_s + if local_ready_timeout_s is not None + else client._local_ready_timeout_s + ), + "local_shutdown_on_close": ( + local_shutdown_on_close + if local_shutdown_on_close is not None + else client._local_shutdown_on_close + ), + } + + +def prepare_sync_client_base_url(client: _HasLocalModeState) -> str | None: + if client._sea_server is None: + return None + return client._sea_server.ensure_running_sync() + + +async def prepare_async_client_base_url(client: _HasLocalModeState) -> str | None: + if client._sea_server is None: + return None + return await client._sea_server.ensure_running_async() + + +def close_sync_client_sea_server(client: _HasLocalModeState) -> None: + if client._sea_server is not None: + client._sea_server.close() + + +async def close_async_client_sea_server(client: _HasLocalModeState) -> None: + if client._sea_server is not None: + await client._sea_server.aclose() diff --git a/src/stagehand/_custom/session.py b/src/stagehand/_custom/session.py index 1224cb08..975b163c 100644 --- a/src/stagehand/_custom/session.py +++ b/src/stagehand/_custom/session.py @@ -1,32 +1,38 @@ -# Manually maintained helpers (not generated). - from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, cast +import logging +from typing import TYPE_CHECKING, Any, Type, Mapping, cast from typing_extensions import Unpack, Literal, Protocol import httpx +from pydantic import BaseModel, ConfigDict -from .types import ( +from ..types import ( session_act_params, + session_start_params, session_execute_params, session_extract_params, session_observe_params, session_navigate_params, ) -from ._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from ._exceptions import StagehandError -from .types.session_act_response import SessionActResponse -from .types.session_end_response import SessionEndResponse -from .types.session_start_response import Data as SessionStartResponseData, SessionStartResponse -from .types.session_execute_response import SessionExecuteResponse -from .types.session_extract_response import SessionExtractResponse -from .types.session_observe_response import SessionObserveResponse -from .types.session_navigate_response import SessionNavigateResponse +from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given +from .._utils import lru_cache +from .._constants import RAW_RESPONSE_HEADER +from .._exceptions import StagehandError +from ..resources.sessions import SessionsResource, AsyncSessionsResource +from ..types.session_act_response import SessionActResponse +from ..types.session_end_response import SessionEndResponse +from ..types.session_start_response import Data as SessionStartResponseData, SessionStartResponse +from ..types.session_execute_response import SessionExecuteResponse +from ..types.session_extract_response import SessionExtractResponse +from ..types.session_observe_response import SessionObserveResponse +from ..types.session_navigate_response import SessionNavigateResponse if TYPE_CHECKING: - from ._client import Stagehand, AsyncStagehand + from .._client import Stagehand, AsyncStagehand + +logger = logging.getLogger(__name__) class _PlaywrightCDPSession(Protocol): @@ -112,30 +118,70 @@ async def _extract_frame_id_from_playwright_page_async(page: Any) -> str: def _maybe_inject_frame_id(params: dict[str, Any], page: Any | None) -> dict[str, Any]: - if page is None: - return params - if "frame_id" in params: + if page is None or "frame_id" in params: return params return {**params, "frame_id": _extract_frame_id_from_playwright_page(page)} async def _maybe_inject_frame_id_async(params: dict[str, Any], page: Any | None) -> dict[str, Any]: - if page is None: - return params - if "frame_id" in params: + if page is None or "frame_id" in params: return params return {**params, "frame_id": await _extract_frame_id_from_playwright_page_async(page)} +def _sync_session_call( + session: Session, + method_name: str, + *, + page: Any | None, + extra_headers: Headers | None, + extra_query: Query | None, + extra_body: Body | None, + timeout: float | httpx.Timeout | None | NotGiven, + params: dict[str, Any], +) -> Any: + method = getattr(session._client.sessions, method_name) + return method( + id=session.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_maybe_inject_frame_id(params, page), + ) + + +async def _async_session_call( + session: AsyncSession, + method_name: str, + *, + page: Any | None, + extra_headers: Headers | None, + extra_query: Query | None, + extra_body: Body | None, + timeout: float | httpx.Timeout | None | NotGiven, + params: dict[str, Any], +) -> Any: + method = getattr(session._client.sessions, method_name) + return await method( + id=session.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **(await _maybe_inject_frame_id_async(params, page)), + ) + + class Session(SessionStartResponse): """A Stagehand session bound to a specific `session_id`.""" def __init__(self, client: Stagehand, id: str, data: SessionStartResponseData, success: bool) -> None: - # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ before setting attributes + # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ + # before setting attributes. super().__init__(data=data, success=success) self._client = client self.id = id - def navigate( self, @@ -147,13 +193,18 @@ def navigate( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_navigate_params.SessionNavigateParams], ) -> SessionNavigateResponse: - return self._client.sessions.navigate( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + return cast( + SessionNavigateResponse, + _sync_session_call( + self, + "navigate", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), + ), ) def act( @@ -168,13 +219,15 @@ def act( ) -> SessionActResponse: return cast( SessionActResponse, - self._client.sessions.act( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + _sync_session_call( + self, + "act", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -190,13 +243,15 @@ def observe( ) -> SessionObserveResponse: return cast( SessionObserveResponse, - self._client.sessions.observe( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + _sync_session_call( + self, + "observe", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -212,13 +267,15 @@ def extract( ) -> SessionExtractResponse: return cast( SessionExtractResponse, - self._client.sessions.extract( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + _sync_session_call( + self, + "extract", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -234,13 +291,15 @@ def execute( ) -> SessionExecuteResponse: return cast( SessionExecuteResponse, - self._client.sessions.execute( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + _sync_session_call( + self, + "execute", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -267,7 +326,8 @@ class AsyncSession(SessionStartResponse): """Async variant of `Session`.""" def __init__(self, client: AsyncStagehand, id: str, data: SessionStartResponseData, success: bool) -> None: - # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ before setting attributes + # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ + # before setting attributes. super().__init__(data=data, success=success) self._client = client self.id = id @@ -282,13 +342,18 @@ async def navigate( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_navigate_params.SessionNavigateParams], ) -> SessionNavigateResponse: - return await self._client.sessions.navigate( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + return cast( + SessionNavigateResponse, + await _async_session_call( + self, + "navigate", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), + ), ) async def act( @@ -303,13 +368,15 @@ async def act( ) -> SessionActResponse: return cast( SessionActResponse, - await self._client.sessions.act( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + await _async_session_call( + self, + "act", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -325,13 +392,15 @@ async def observe( ) -> SessionObserveResponse: return cast( SessionObserveResponse, - await self._client.sessions.observe( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + await _async_session_call( + self, + "observe", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -347,13 +416,15 @@ async def extract( ) -> SessionExtractResponse: return cast( SessionExtractResponse, - await self._client.sessions.extract( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + await _async_session_call( + self, + "extract", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -369,13 +440,15 @@ async def execute( ) -> SessionExecuteResponse: return cast( SessionExecuteResponse, - await self._client.sessions.execute( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + await _async_session_call( + self, + "execute", + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + params=dict(params), ), ) @@ -396,3 +469,355 @@ async def end( extra_body=extra_body, timeout=timeout, ) + + +def is_pydantic_model(schema: Any) -> bool: + return inspect.isclass(schema) and issubclass(schema, BaseModel) + + +def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> dict[str, object]: + schema.model_rebuild() + return cast(dict[str, object], schema.model_json_schema()) + + +def validate_extract_response( + result: object, + schema: Type[BaseModel], + *, + strict_response_validation: bool, +) -> object: + validation_schema = _validation_schema(schema, strict_response_validation) + try: + return validation_schema.model_validate(result) + except Exception: + try: + normalized = _convert_dict_keys_to_snake_case(result) + return validation_schema.model_validate(normalized) + except Exception: + logger.warning( + "Failed to validate extracted data against schema %s. Returning raw data.", + schema.__name__, + ) + return result + + +@lru_cache(maxsize=256) +def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]: + extra_behavior: Literal["allow", "forbid"] = "forbid" if strict_response_validation else "allow" + validation_schema = cast( + Type[BaseModel], + type( + f"{schema.__name__}ExtractValidation", + (schema,), + { + "__module__": schema.__module__, + "model_config": ConfigDict(extra=extra_behavior), + }, + ), + ) + validation_schema.model_rebuild(force=True) + return validation_schema + + +def _camel_to_snake(name: str) -> str: + chars: list[str] = [] + for i, ch in enumerate(name): + if ch.isupper() and i != 0 and not name[i - 1].isupper(): + chars.append("_") + chars.append(ch.lower()) + return "".join(chars) + + +def _convert_dict_keys_to_snake_case(data: Any) -> Any: + if isinstance(data, dict): + items = cast(dict[object, object], data).items() + return { + _camel_to_snake(k) if isinstance(k, str) else k: _convert_dict_keys_to_snake_case(v) + for k, v in items + } + if isinstance(data, list): + return [_convert_dict_keys_to_snake_case(item) for item in cast(list[object], data)] + return data + + +def _with_schema( + params: Mapping[str, object], + schema: dict[str, object] | type | None, +) -> session_extract_params.SessionExtractParamsNonStreaming: + api_params = dict(params) + if schema is not None: + api_params["schema"] = cast(Any, schema) + return cast(session_extract_params.SessionExtractParamsNonStreaming, api_params) + + +def _resolve_extract_schema( + *, + schema: dict[str, object] | type | None, + params: dict[str, object], +) -> tuple[Type[BaseModel] | None, dict[str, object] | type | None]: + params_schema_obj = params.pop("schema", None) + params_schema: dict[str, object] | type | None + if params_schema_obj is None: + params_schema = params_schema_obj + elif isinstance(params_schema_obj, dict): + params_schema = cast(dict[str, object], params_schema_obj) + elif isinstance(params_schema_obj, type): + params_schema = params_schema_obj + else: + params_schema = None + + resolved_schema = schema if schema is not None else params_schema + + if not is_pydantic_model(resolved_schema): + return None, resolved_schema + + pydantic_cls = cast(Type[BaseModel], resolved_schema) + return pydantic_cls, pydantic_model_to_json_schema(pydantic_cls) + + +def _apply_extract_validation( + response: SessionExtractResponse, + *, + schema: Type[BaseModel] | None, + strict_response_validation: bool, +) -> SessionExtractResponse: + if schema is not None and response.data and response.data.result is not None: + response.data.result = validate_extract_response( + response.data.result, + schema, + strict_response_validation=strict_response_validation, + ) + return response + + +_ORIGINAL_SESSION_EXTRACT = Session.extract +_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract + + +def _sync_extract( # type: ignore[override, misc] + self: Session, + *, + schema: dict[str, object] | type | None = None, + page: Any | None = None, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues] +) -> SessionExtractResponse: + raw_params = dict(params) + pydantic_cls, resolved_schema = _resolve_extract_schema(schema=schema, params=raw_params) + response = _ORIGINAL_SESSION_EXTRACT( + self, + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_with_schema(raw_params, resolved_schema), + ) + return _apply_extract_validation( + response, + schema=pydantic_cls, + strict_response_validation=self._client._strict_response_validation, + ) + + +async def _async_extract( # type: ignore[override, misc] + self: AsyncSession, + *, + schema: dict[str, object] | type | None = None, + page: Any | None = None, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues] +) -> SessionExtractResponse: + raw_params = dict(params) + pydantic_cls, resolved_schema = _resolve_extract_schema(schema=schema, params=raw_params) + response = await _ORIGINAL_ASYNC_SESSION_EXTRACT( + self, + page=page, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_with_schema(raw_params, resolved_schema), + ) + return _apply_extract_validation( + response, + schema=pydantic_cls, + strict_response_validation=self._client._strict_response_validation, + ) + + +def install_pydantic_extract_patch() -> None: + if getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False): + return + + _sync_extract.__module__ = _ORIGINAL_SESSION_EXTRACT.__module__ + _sync_extract.__name__ = _ORIGINAL_SESSION_EXTRACT.__name__ + _sync_extract.__qualname__ = _ORIGINAL_SESSION_EXTRACT.__qualname__ + _sync_extract.__doc__ = _ORIGINAL_SESSION_EXTRACT.__doc__ + setattr(_sync_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010 + + _async_extract.__module__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__module__ + _async_extract.__name__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__name__ + _async_extract.__qualname__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__qualname__ + _async_extract.__doc__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__doc__ + setattr(_async_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010 + + Session.extract = _sync_extract # type: ignore[assignment] + AsyncSession.extract = _async_extract # type: ignore[assignment] + + +def _resolve_start_browser(client: Any, browser: session_start_params.Browser | Omit) -> session_start_params.Browser | Omit: + if browser is not omit or getattr(client, "_server_mode", None) != "local": + return browser + + if client.browserbase_api_key is None or client.browserbase_project_id is None: + raise StagehandError( + "Local server mode without Browserbase credentials requires an explicit local browser, " + "e.g. browser={'type': 'local'}." + ) + + return {"type": "local"} + + +def _is_raw_or_streaming_start(extra_headers: Headers | None) -> bool: + if not extra_headers: + return False + + header_value = extra_headers.get(RAW_RESPONSE_HEADER) + return header_value in {"raw", "stream"} + + +_ORIGINAL_SESSIONS_START = SessionsResource.start +_ORIGINAL_ASYNC_SESSIONS_START = AsyncSessionsResource.start + + +def _sync_start( + self: SessionsResource, + *, + model_name: str, + act_timeout_ms: float | Omit = omit, + browser: session_start_params.Browser | Omit = omit, + browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, + browserbase_session_id: str | Omit = omit, + dom_settle_timeout_ms: float | Omit = omit, + experimental: bool | Omit = omit, + self_heal: bool | Omit = omit, + system_prompt: str | Omit = omit, + verbose: Literal[0, 1, 2] | Omit = omit, + wait_for_captcha_solves: bool | Omit = omit, + x_stream_response: Literal["true", "false"] | Omit = omit, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, +) -> object: + start_response = _ORIGINAL_SESSIONS_START( + self, + model_name=model_name, + act_timeout_ms=act_timeout_ms, + browser=_resolve_start_browser(self._client, browser), + browserbase_session_create_params=browserbase_session_create_params, + browserbase_session_id=browserbase_session_id, + dom_settle_timeout_ms=dom_settle_timeout_ms, + experimental=experimental, + self_heal=self_heal, + system_prompt=system_prompt, + verbose=verbose, + wait_for_captcha_solves=wait_for_captcha_solves, + x_stream_response=x_stream_response, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + if _is_raw_or_streaming_start(extra_headers): + return start_response + return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success) + + +async def _async_start( + self: AsyncSessionsResource, + *, + model_name: str, + act_timeout_ms: float | Omit = omit, + browser: session_start_params.Browser | Omit = omit, + browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, + browserbase_session_id: str | Omit = omit, + dom_settle_timeout_ms: float | Omit = omit, + experimental: bool | Omit = omit, + self_heal: bool | Omit = omit, + system_prompt: str | Omit = omit, + verbose: Literal[0, 1, 2] | Omit = omit, + wait_for_captcha_solves: bool | Omit = omit, + x_stream_response: Literal["true", "false"] | Omit = omit, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, +) -> object: + start_response = await _ORIGINAL_ASYNC_SESSIONS_START( + self, + model_name=model_name, + act_timeout_ms=act_timeout_ms, + browser=_resolve_start_browser(self._client, browser), + browserbase_session_create_params=browserbase_session_create_params, + browserbase_session_id=browserbase_session_id, + dom_settle_timeout_ms=dom_settle_timeout_ms, + experimental=experimental, + self_heal=self_heal, + system_prompt=system_prompt, + verbose=verbose, + wait_for_captcha_solves=wait_for_captcha_solves, + x_stream_response=x_stream_response, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + if _is_raw_or_streaming_start(extra_headers): + return start_response + return AsyncSession( + self._client, + start_response.data.session_id, + data=start_response.data, + success=start_response.success, + ) + + +def install_stainless_session_patches() -> None: + install_pydantic_extract_patch() + + if getattr(SessionsResource.start, "__stagehand_bound_session_patch__", False): + return + + _sync_start.__module__ = _ORIGINAL_SESSIONS_START.__module__ + _sync_start.__name__ = _ORIGINAL_SESSIONS_START.__name__ + _sync_start.__qualname__ = _ORIGINAL_SESSIONS_START.__qualname__ + _sync_start.__doc__ = _ORIGINAL_SESSIONS_START.__doc__ + setattr(_sync_start, "__stagehand_bound_session_patch__", True) # noqa: B010 + + _async_start.__module__ = _ORIGINAL_ASYNC_SESSIONS_START.__module__ + _async_start.__name__ = _ORIGINAL_ASYNC_SESSIONS_START.__name__ + _async_start.__qualname__ = _ORIGINAL_ASYNC_SESSIONS_START.__qualname__ + _async_start.__doc__ = _ORIGINAL_ASYNC_SESSIONS_START.__doc__ + setattr(_async_start, "__stagehand_bound_session_patch__", True) # noqa: B010 + + SessionsResource.start = _sync_start # type: ignore[assignment] + AsyncSessionsResource.start = _async_start # type: ignore[assignment] + + +install_stainless_session_patches() + + +__all__ = [ + "AsyncSession", + "Session", + "install_pydantic_extract_patch", + "install_stainless_session_patches", +] diff --git a/src/stagehand/lib/.keep b/src/stagehand/lib/.keep deleted file mode 100644 index 5e2c99fd..00000000 --- a/src/stagehand/lib/.keep +++ /dev/null @@ -1,4 +0,0 @@ -File generated from our OpenAPI spec by Stainless. - -This directory can be used to store custom files to expand the SDK. -It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/src/stagehand/resources/sessions.py b/src/stagehand/resources/sessions.py index 99ea0bd0..a3c3c4bd 100644 --- a/src/stagehand/resources/sessions.py +++ b/src/stagehand/resources/sessions.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import Dict, Optional +### +from typing import TYPE_CHECKING, Dict, Optional, cast + +### from typing_extensions import Literal, overload import httpx @@ -37,6 +40,13 @@ from ..types.session_observe_response import SessionObserveResponse from ..types.session_navigate_response import SessionNavigateResponse +### +# These imports are type-checking only. Runtime patching in `_custom.session` +# swaps `start()` to return real bound session objects. +if TYPE_CHECKING: + from .. import Session, AsyncSession +### + __all__ = ["SessionsResource", "AsyncSessionsResource"] @@ -928,7 +938,11 @@ def start( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> SessionStartResponse: + ### + # The runtime monkey patch returns a bound `Session`; mirror that public + # return type here so users see the right API surface. + ) -> Session: + ### """Creates a new browser session with the specified configuration. Returns a @@ -967,7 +981,12 @@ def start( ), **(extra_headers or {}), } - return self._post( + ### + # This cast is type-only. `install_stainless_session_patches()` replaces + # this generated method at runtime and constructs the real `Session`. + return cast( + "Session", + self._post( "/v1/sessions/start", body=maybe_transform( { @@ -989,7 +1008,9 @@ def start( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=SessionStartResponse, + ), ) + ### class AsyncSessionsResource(AsyncAPIResource): @@ -1880,7 +1901,11 @@ async def start( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> SessionStartResponse: + ### + # The runtime monkey patch returns a bound `AsyncSession`; mirror that + # public return type here so users see the right API surface. + ) -> AsyncSession: + ### """Creates a new browser session with the specified configuration. Returns a @@ -1919,7 +1944,12 @@ async def start( ), **(extra_headers or {}), } - return await self._post( + ### + # This cast is type-only. `install_stainless_session_patches()` replaces + # this generated method at runtime and constructs the real `AsyncSession`. + return cast( + "AsyncSession", + await self._post( "/v1/sessions/start", body=await async_maybe_transform( { @@ -1941,7 +1971,9 @@ async def start( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=SessionStartResponse, + ), ) + ### class SessionsResourceWithRawResponse: diff --git a/src/stagehand/resources/sessions_helpers.py b/src/stagehand/resources/sessions_helpers.py deleted file mode 100644 index c0af9ba9..00000000 --- a/src/stagehand/resources/sessions_helpers.py +++ /dev/null @@ -1,352 +0,0 @@ -# Manually maintained helpers (not generated). - -from __future__ import annotations - -import inspect -import logging -from typing import Any, Type, Mapping, cast -from typing_extensions import Unpack, Literal, override - -import httpx -from pydantic import BaseModel, ConfigDict - -from ..types import session_start_params, session_extract_params -from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from .._utils import lru_cache -from .._compat import cached_property -from ..session import Session, AsyncSession -from .sessions import ( - SessionsResource, - AsyncSessionsResource, - SessionsResourceWithRawResponse, - AsyncSessionsResourceWithRawResponse, - SessionsResourceWithStreamingResponse, - AsyncSessionsResourceWithStreamingResponse, -) -from .._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from ..types.session_start_response import SessionStartResponse -from ..types.session_extract_response import SessionExtractResponse - -logger = logging.getLogger(__name__) - -_ORIGINAL_SESSION_EXTRACT = Session.extract -_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract - - -def install_pydantic_extract_patch() -> None: - if getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False): - return - - Session.extract = _sync_extract # type: ignore[assignment] - AsyncSession.extract = _async_extract # type: ignore[assignment] - - -def is_pydantic_model(schema: Any) -> bool: - return inspect.isclass(schema) and issubclass(schema, BaseModel) - - -def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> dict[str, object]: - schema.model_rebuild() - return cast(dict[str, object], schema.model_json_schema()) - - -def validate_extract_response( - result: object, schema: Type[BaseModel], *, strict_response_validation: bool -) -> object: - validation_schema = _validation_schema(schema, strict_response_validation) - try: - return validation_schema.model_validate(result) - except Exception: - try: - normalized = _convert_dict_keys_to_snake_case(result) - return validation_schema.model_validate(normalized) - except Exception: - logger.warning( - "Failed to validate extracted data against schema %s. Returning raw data.", - schema.__name__, - ) - return result - - -@lru_cache(maxsize=256) -def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]: - extra_behavior: Literal["allow", "forbid"] = "forbid" if strict_response_validation else "allow" - validation_schema = cast( - Type[BaseModel], - type( - f"{schema.__name__}ExtractValidation", - (schema,), - { - "__module__": schema.__module__, - "model_config": ConfigDict(extra=extra_behavior), - }, - ), - ) - validation_schema.model_rebuild(force=True) - return validation_schema - - -def _camel_to_snake(name: str) -> str: - chars: list[str] = [] - for i, ch in enumerate(name): - if ch.isupper() and i != 0 and not name[i - 1].isupper(): - chars.append("_") - chars.append(ch.lower()) - return "".join(chars) - - -def _convert_dict_keys_to_snake_case(data: Any) -> Any: - if isinstance(data, dict): - items = cast(dict[object, object], data).items() - return { - _camel_to_snake(k) if isinstance(k, str) else k: _convert_dict_keys_to_snake_case(v) - for k, v in items - } - if isinstance(data, list): - return [_convert_dict_keys_to_snake_case(item) for item in cast(list[object], data)] - return data - - -def _with_schema( - params: Mapping[str, object], - schema: dict[str, object] | type | None, -) -> session_extract_params.SessionExtractParamsNonStreaming: - api_params = dict(params) - if schema is not None: - api_params["schema"] = cast(Any, schema) - return cast(session_extract_params.SessionExtractParamsNonStreaming, api_params) - - -def _sync_extract( # type: ignore[override, misc] - self: Session, - *, - schema: dict[str, object] | type | None = None, - page: Any | None = None, - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues] -) -> SessionExtractResponse: - params_schema = params.pop("schema", None) # type: ignore[misc] - resolved_schema = schema if schema is not None else params_schema - - pydantic_cls: Type[BaseModel] | None = None - if is_pydantic_model(resolved_schema): - pydantic_cls = cast(Type[BaseModel], resolved_schema) - resolved_schema = pydantic_model_to_json_schema(pydantic_cls) - - response = _ORIGINAL_SESSION_EXTRACT( - self, - page=page, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_with_schema(params, resolved_schema), - ) - - if pydantic_cls is not None and response.data and response.data.result is not None: - response.data.result = validate_extract_response( - response.data.result, - pydantic_cls, - strict_response_validation=self._client._strict_response_validation, - ) - - return response - - -async def _async_extract( # type: ignore[override, misc] - self: AsyncSession, - *, - schema: dict[str, object] | type | None = None, - page: Any | None = None, - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues] -) -> SessionExtractResponse: - params_schema = params.pop("schema", None) # type: ignore[misc] - resolved_schema = schema if schema is not None else params_schema - - pydantic_cls: Type[BaseModel] | None = None - if is_pydantic_model(resolved_schema): - pydantic_cls = cast(Type[BaseModel], resolved_schema) - resolved_schema = pydantic_model_to_json_schema(pydantic_cls) - - response = await _ORIGINAL_ASYNC_SESSION_EXTRACT( - self, - page=page, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_with_schema(params, resolved_schema), - ) - - if pydantic_cls is not None and response.data and response.data.result is not None: - response.data.result = validate_extract_response( - response.data.result, - pydantic_cls, - strict_response_validation=self._client._strict_response_validation, - ) - - return response - - -_sync_extract.__module__ = _ORIGINAL_SESSION_EXTRACT.__module__ -_sync_extract.__name__ = _ORIGINAL_SESSION_EXTRACT.__name__ -_sync_extract.__qualname__ = _ORIGINAL_SESSION_EXTRACT.__qualname__ -_sync_extract.__doc__ = _ORIGINAL_SESSION_EXTRACT.__doc__ -setattr(_sync_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010 - -_async_extract.__module__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__module__ -_async_extract.__name__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__name__ -_async_extract.__qualname__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__qualname__ -_async_extract.__doc__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__doc__ -setattr(_async_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010 - - -install_pydantic_extract_patch() - - -class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse): - def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined] - super().__init__(sessions) - self.start = to_raw_response_wrapper(super(SessionsResourceWithHelpers, sessions).start) - - -class SessionsResourceWithHelpersStreamingResponse(SessionsResourceWithStreamingResponse): - def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined] - super().__init__(sessions) - self.start = to_streamed_response_wrapper(super(SessionsResourceWithHelpers, sessions).start) - - -class AsyncSessionsResourceWithHelpersRawResponse(AsyncSessionsResourceWithRawResponse): - def __init__(self, sessions: AsyncSessionsResourceWithHelpers) -> None: # type: ignore[name-defined] - super().__init__(sessions) - self.start = async_to_raw_response_wrapper(super(AsyncSessionsResourceWithHelpers, sessions).start) - - -class AsyncSessionsResourceWithHelpersStreamingResponse(AsyncSessionsResourceWithStreamingResponse): - def __init__(self, sessions: AsyncSessionsResourceWithHelpers) -> None: # type: ignore[name-defined] - super().__init__(sessions) - self.start = async_to_streamed_response_wrapper(super(AsyncSessionsResourceWithHelpers, sessions).start) - - -class SessionsResourceWithHelpers(SessionsResource): - @cached_property - @override - def with_raw_response(self) -> SessionsResourceWithHelpersRawResponse: - return SessionsResourceWithHelpersRawResponse(self) - - @cached_property - @override - def with_streaming_response(self) -> SessionsResourceWithHelpersStreamingResponse: - return SessionsResourceWithHelpersStreamingResponse(self) - - @override - def start( - self, - *, - model_name: str, - act_timeout_ms: float | Omit = omit, - browser: session_start_params.Browser | Omit = omit, - browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, - browserbase_session_id: str | Omit = omit, - dom_settle_timeout_ms: float | Omit = omit, - experimental: bool | Omit = omit, - self_heal: bool | Omit = omit, - system_prompt: str | Omit = omit, - verbose: Literal[0, 1, 2] | Omit = omit, - wait_for_captcha_solves: bool | Omit = omit, - x_stream_response: Literal["true", "false"] | Omit = omit, - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> Session: - if browser is omit and getattr(self._client, "_server_mode", None) == "local": - browser = {"type": "local"} - - start_response = super().start( - model_name=model_name, - act_timeout_ms=act_timeout_ms, - browser=browser, - browserbase_session_create_params=browserbase_session_create_params, - browserbase_session_id=browserbase_session_id, - dom_settle_timeout_ms=dom_settle_timeout_ms, - experimental=experimental, - self_heal=self_heal, - system_prompt=system_prompt, - verbose=verbose, - wait_for_captcha_solves=wait_for_captcha_solves, - x_stream_response=x_stream_response, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - ) - return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success) - - -class AsyncSessionsResourceWithHelpers(AsyncSessionsResource): - @cached_property - @override - def with_raw_response(self) -> AsyncSessionsResourceWithHelpersRawResponse: - return AsyncSessionsResourceWithHelpersRawResponse(self) - - @cached_property - @override - def with_streaming_response(self) -> AsyncSessionsResourceWithHelpersStreamingResponse: - return AsyncSessionsResourceWithHelpersStreamingResponse(self) - - @override - async def start( - self, - *, - model_name: str, - act_timeout_ms: float | Omit = omit, - browser: session_start_params.Browser | Omit = omit, - browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, - browserbase_session_id: str | Omit = omit, - dom_settle_timeout_ms: float | Omit = omit, - experimental: bool | Omit = omit, - self_heal: bool | Omit = omit, - system_prompt: str | Omit = omit, - verbose: Literal[0, 1, 2] | Omit = omit, - wait_for_captcha_solves: bool | Omit = omit, - x_stream_response: Literal["true", "false"] | Omit = omit, - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> AsyncSession: - if browser is omit and getattr(self._client, "_server_mode", None) == "local": - browser = {"type": "local"} - - start_response: SessionStartResponse = await super().start( - model_name=model_name, - act_timeout_ms=act_timeout_ms, - browser=browser, - browserbase_session_create_params=browserbase_session_create_params, - browserbase_session_id=browserbase_session_id, - dom_settle_timeout_ms=dom_settle_timeout_ms, - experimental=experimental, - self_heal=self_heal, - system_prompt=system_prompt, - verbose=verbose, - wait_for_captcha_solves=wait_for_captcha_solves, - x_stream_response=x_stream_response, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - ) - return AsyncSession(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success) diff --git a/tests/test_local_server.py b/tests/test_local_server.py index 490ccbda..20f7ab5b 100644 --- a/tests/test_local_server.py +++ b/tests/test_local_server.py @@ -9,7 +9,7 @@ from respx import MockRouter from stagehand import Stagehand, AsyncStagehand -from stagehand.lib import sea_server +from stagehand._custom import sea_server from stagehand._exceptions import StagehandError diff --git a/tests/test_sea_binary.py b/tests/test_sea_binary.py index 300543a8..1c098ac7 100644 --- a/tests/test_sea_binary.py +++ b/tests/test_sea_binary.py @@ -5,12 +5,12 @@ import pytest -from stagehand.lib import sea_binary +from stagehand._custom import sea_binary from stagehand._version import __version__ def _load_download_binary_module(): - script_path = Path(__file__).resolve().parents[1] / "scripts" / "download-binary.py" + script_path = Path(__file__).resolve().parents[1] / "scripts" / "download_binary.py" spec = importlib.util.spec_from_file_location("download_binary_script", script_path) assert spec is not None assert spec.loader is not None From 76d71744806ecefa20bd5dc036651dbc416ebdc3 Mon Sep 17 00:00:00 2001 From: Nick Sweeting Date: Fri, 3 Apr 2026 11:22:07 -0700 Subject: [PATCH 3/4] add compat --- src/stagehand/session.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/stagehand/session.py diff --git a/src/stagehand/session.py b/src/stagehand/session.py new file mode 100644 index 00000000..fbd3ee05 --- /dev/null +++ b/src/stagehand/session.py @@ -0,0 +1,5 @@ +### +from ._custom.session import Session, AsyncSession + +__all__ = ["AsyncSession", "Session"] +### From ee9b7808903fb953ea7359c1bbaf048da39b62dd Mon Sep 17 00:00:00 2001 From: Nick Sweeting Date: Fri, 3 Apr 2026 11:25:42 -0700 Subject: [PATCH 4/4] always detatch cdp --- src/stagehand/_custom/session.py | 39 +++++++++++++++++++++++++------- tests/test_session_page_param.py | 26 +++++++++++++++++---- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/src/stagehand/_custom/session.py b/src/stagehand/_custom/session.py index 975b163c..83d5ce1a 100644 --- a/src/stagehand/_custom/session.py +++ b/src/stagehand/_custom/session.py @@ -39,6 +39,9 @@ class _PlaywrightCDPSession(Protocol): def send(self, method: str, params: Any = ...) -> Any: # noqa: ANN401 ... + def detach(self) -> Any: # noqa: ANN401 + ... + class _PlaywrightContext(Protocol): def new_cdp_session(self, page: Any) -> Any: # noqa: ANN401 @@ -71,11 +74,21 @@ def _extract_frame_id_from_playwright_page(page: Any) -> str: raise StagehandError("Playwright CDP session missing .send(...) method") pw_cdp = cast(_PlaywrightCDPSession, cdp) - result = pw_cdp.send("Page.getFrameTree") - if inspect.isawaitable(result): - raise StagehandError( - "Expected a synchronous Playwright Page, but received an async CDP session; use AsyncSession methods" - ) + try: + result = pw_cdp.send("Page.getFrameTree") + if inspect.isawaitable(result): + raise StagehandError( + "Expected a synchronous Playwright Page, but received an async CDP session; use AsyncSession methods" + ) + finally: + detach = getattr(cdp, "detach", None) + if callable(detach): + try: + detach_result = detach() + if inspect.isawaitable(detach_result): + logger.warning("Playwright sync CDP detach() returned an awaitable; session may remain open") + except Exception: # noqa: BLE001 + logger.debug("Failed to detach Playwright CDP session", exc_info=True) try: return cast(str, result["frameTree"]["frame"]["id"]) @@ -107,9 +120,19 @@ async def _extract_frame_id_from_playwright_page_async(page: Any) -> str: raise StagehandError("Playwright CDP session missing .send(...) method") pw_cdp = cast(_PlaywrightCDPSession, cdp) - result = pw_cdp.send("Page.getFrameTree") - if inspect.isawaitable(result): - result = await result + try: + result = pw_cdp.send("Page.getFrameTree") + if inspect.isawaitable(result): + result = await result + finally: + detach = getattr(cdp, "detach", None) + if callable(detach): + try: + detach_result = detach() + if inspect.isawaitable(detach_result): + await detach_result + except Exception: # noqa: BLE001 + logger.debug("Failed to detach Playwright CDP session", exc_info=True) try: return cast(str, result["frameTree"]["frame"]["id"]) diff --git a/tests/test_session_page_param.py b/tests/test_session_page_param.py index 2e9b627d..422f9640 100644 --- a/tests/test_session_page_param.py +++ b/tests/test_session_page_param.py @@ -19,18 +19,24 @@ class _SyncCDP: def __init__(self, frame_id: str) -> None: self._frame_id = frame_id + self.detached = False def send(self, method: str) -> dict[str, Any]: assert method == "Page.getFrameTree" return {"frameTree": {"frame": {"id": self._frame_id}}} + def detach(self) -> None: + self.detached = True + class _SyncContext: def __init__(self, frame_id: str) -> None: self._frame_id = frame_id + self.last_cdp: _SyncCDP | None = None def new_cdp_session(self, _page: Any) -> _SyncCDP: - return _SyncCDP(self._frame_id) + self.last_cdp = _SyncCDP(self._frame_id) + return self.last_cdp class _SyncPage: @@ -41,18 +47,24 @@ def __init__(self, frame_id: str) -> None: class _AsyncCDP: def __init__(self, frame_id: str) -> None: self._frame_id = frame_id + self.detached = False async def send(self, method: str) -> dict[str, Any]: assert method == "Page.getFrameTree" return {"frameTree": {"frame": {"id": self._frame_id}}} + async def detach(self) -> None: + self.detached = True + class _AsyncContext: def __init__(self, frame_id: str) -> None: self._frame_id = frame_id + self.last_cdp: _AsyncCDP | None = None async def new_cdp_session(self, _page: Any) -> _AsyncCDP: - return _AsyncCDP(self._frame_id) + self.last_cdp = _AsyncCDP(self._frame_id) + return self.last_cdp class _AsyncPage: @@ -64,6 +76,7 @@ def __init__(self, frame_id: str) -> None: def test_session_act_injects_frame_id_from_page(respx_mock: MockRouter, client: Stagehand) -> None: session_id = "00000000-0000-0000-0000-000000000000" frame_id = "frame-123" + page = _SyncPage(frame_id) respx_mock.post("/v1/sessions/start").mock( return_value=httpx.Response( @@ -80,9 +93,11 @@ def test_session_act_injects_frame_id_from_page(respx_mock: MockRouter, client: ) session = client.sessions.start(model_name="openai/gpt-5-nano") - session.act(input="click something", page=_SyncPage(frame_id)) + session.act(input="click something", page=page) assert act_route.called is True + assert page.context.last_cdp is not None + assert page.context.last_cdp.detached is True first_call = cast(Call, act_route.calls[0]) request_body = json.loads(first_call.request.content) assert request_body["frameId"] == frame_id @@ -129,6 +144,7 @@ async def test_async_session_act_injects_frame_id_from_page( ) -> None: session_id = "00000000-0000-0000-0000-000000000000" frame_id = "frame-async-456" + page = _AsyncPage(frame_id) respx_mock.post("/v1/sessions/start").mock( return_value=httpx.Response( @@ -145,9 +161,11 @@ async def test_async_session_act_injects_frame_id_from_page( ) session = await async_client.sessions.start(model_name="openai/gpt-5-nano") - await session.act(input="click something", page=_AsyncPage(frame_id)) + await session.act(input="click something", page=page) assert act_route.called is True + assert page.context.last_cdp is not None + assert page.context.last_cdp.detached is True first_call = cast(Call, act_route.calls[0]) request_body = json.loads(first_call.request.content) assert request_body["frameId"] == frame_id