Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ dev = [
"pytest-xdist>=3.6.1",
"dotenv>=0.9.9",
]
pydantic-v2 = [
"pydantic~=2.0 ; python_full_version < '3.14'",
"pydantic~=2.12 ; python_full_version >= '3.14'",
]

[build-system]
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme", "packaging"]
Expand Down
6 changes: 6 additions & 0 deletions src/stagehand/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

PYDANTIC_V1 = pydantic.VERSION.startswith("1.")

if PYDANTIC_V1:
    # Fail at import time with an actionable message rather than letting
    # v2-only APIs break later in unrelated call sites.
    _unsupported_pydantic_message = (
        f"stagehand requires Pydantic v2 or newer; found Pydantic {pydantic.VERSION}. "
        "Install `pydantic>=2,<3`."
    )
    raise ImportError(_unsupported_pydantic_message)

if TYPE_CHECKING:

def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
Expand Down
28 changes: 24 additions & 4 deletions src/stagehand/_pydantic_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import logging
from typing import Any, Dict, Type

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from ._utils import lru_cache

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +29,9 @@ def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> Dict[str, object]:
return schema.model_json_schema()


def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
def validate_extract_response(
result: object, schema: Type[BaseModel], *, strict_response_validation: bool
) -> Any:
"""Validate raw extract result data against a Pydantic model.

Tries direct validation first. On failure, falls back to normalizing
Expand All @@ -36,12 +40,13 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
Returns the validated Pydantic model instance, or the raw result if
both attempts fail.
"""
validation_schema = _validation_schema(schema, strict_response_validation)
try:
return schema.model_validate(result)
return validation_schema.model_validate(result)
except Exception:
try:
normalized = _convert_dict_keys_to_snake_case(result)
return schema.model_validate(normalized)
return validation_schema.model_validate(normalized)
except Exception:
logger.warning(
"Failed to validate extracted data against schema %s. "
Expand All @@ -51,6 +56,21 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
return result


# Bounded on purpose: the cache key includes a caller-supplied model class, so
# an unbounded cache would keep every dynamically created schema (and the
# subclass generated for it) alive for the lifetime of the process.
@lru_cache(maxsize=128)
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
    """Return a cached subclass of ``schema`` used only for extract validation.

    The subclass is named ``<schema name>ExtractValidation``, reports the same
    ``__module__`` as ``schema``, and overrides ``model_config`` so that
    ``extra`` is ``"forbid"`` when ``strict_response_validation`` is true and
    ``"allow"`` otherwise.

    Args:
        schema: The caller's Pydantic model class.
        strict_response_validation: Selects the ``extra`` policy as described
            above.

    Returns:
        The generated subclass. Instances validated through it are still
        instances of ``schema`` because it inherits from ``schema``.
    """
    extra_behavior = "forbid" if strict_response_validation else "allow"
    validation_schema = type(
        f"{schema.__name__}ExtractValidation",
        (schema,),
        {
            "__module__": schema.__module__,
            "model_config": ConfigDict(extra=extra_behavior),
        },
    )
    # Force a rebuild so the generated class reflects the overridden config.
    validation_schema.model_rebuild(force=True)
    return validation_schema


def _camel_to_snake(name: str) -> str:
"""Convert a camelCase or PascalCase string to snake_case."""
chars: list[str] = []
Expand Down
126 changes: 126 additions & 0 deletions src/stagehand/_session_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Custom extract patch installed on top of the session helpers."""

from __future__ import annotations

from typing import Any, cast

import httpx
from typing_extensions import Unpack

from ._pydantic_extract import is_pydantic_model, pydantic_model_to_json_schema, validate_extract_response
from .session import AsyncSession, Session
from ._types import Body, Headers, NotGiven, Query, not_given
from .types import session_extract_params
from .types.session_extract_response import SessionExtractResponse

_ORIGINAL_SESSION_EXTRACT = Session.extract
_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract


def install_pydantic_extract_patch() -> None:
    """Replace ``extract`` on both session classes with the wrappers in this module.

    Does nothing when ``Session.extract`` already carries a truthy
    ``__stagehand_pydantic_extract_patch__`` marker, so calling this more than
    once is harmless.
    """
    already_patched = getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False)
    if not already_patched:
        Session.extract = _sync_extract  # type: ignore[assignment]
        AsyncSession.extract = _async_extract  # type: ignore[assignment]


def _sync_extract(  # type: ignore[override, misc]
    self: Session,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    """Synchronous ``extract`` wrapper that accepts a Pydantic model as ``schema``.

    When ``is_pydantic_model`` accepts the resolved schema, it is converted
    with ``pydantic_model_to_json_schema`` before the original
    ``Session.extract`` is called. A non-``None`` ``response.data.result`` is
    then replaced by the output of ``validate_extract_response``. Any other
    schema is forwarded unchanged and the response is returned untouched.
    """
    # The explicit ``schema`` keyword wins over one supplied through ``params``.
    fallback_schema = params.pop("schema", None)  # type: ignore[misc]
    wire_schema = fallback_schema if schema is None else schema

    model_cls: type[Any] | None = None
    if is_pydantic_model(wire_schema):
        model_cls = wire_schema  # type: ignore[assignment]
        wire_schema = pydantic_model_to_json_schema(model_cls)  # type: ignore[arg-type]

    request_params = _with_schema(params, wire_schema)
    response = _ORIGINAL_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Nothing to validate without a model class or without a result payload.
    if model_cls is None or not response.data or response.data.result is None:
        return response

    response.data.result = validate_extract_response(
        response.data.result,
        model_cls,
        strict_response_validation=self._client._strict_response_validation,
    )
    return response


async def _async_extract(  # type: ignore[override, misc]
    self: AsyncSession,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    """Asynchronous ``extract`` wrapper that accepts a Pydantic model as ``schema``.

    When ``is_pydantic_model`` accepts the resolved schema, it is converted
    with ``pydantic_model_to_json_schema`` before the original
    ``AsyncSession.extract`` is awaited. A non-``None``
    ``response.data.result`` is then replaced by the output of
    ``validate_extract_response``. Any other schema is forwarded unchanged and
    the response is returned untouched.
    """
    # The explicit ``schema`` keyword wins over one supplied through ``params``.
    fallback_schema = params.pop("schema", None)  # type: ignore[misc]
    wire_schema = fallback_schema if schema is None else schema

    model_cls: type[Any] | None = None
    if is_pydantic_model(wire_schema):
        model_cls = wire_schema  # type: ignore[assignment]
        wire_schema = pydantic_model_to_json_schema(model_cls)  # type: ignore[arg-type]

    request_params = _with_schema(params, wire_schema)
    response = await _ORIGINAL_ASYNC_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Nothing to validate without a model class or without a result payload.
    if model_cls is None or not response.data or response.data.result is None:
        return response

    response.data.result = validate_extract_response(
        response.data.result,
        model_cls,
        strict_response_validation=self._client._strict_response_validation,
    )
    return response


def _with_schema(
    params: session_extract_params.SessionExtractParamsNonStreaming,
    schema: dict[str, object] | type | None,
) -> session_extract_params.SessionExtractParamsNonStreaming:
    """Return a shallow copy of ``params`` with ``schema`` set when it is not ``None``.

    The incoming ``params`` mapping is never mutated.
    """
    if schema is None:
        merged: dict[str, Any] = dict(params)
    else:
        merged = {**params, "schema": cast(Any, schema)}
    return cast(session_extract_params.SessionExtractParamsNonStreaming, merged)


# Make each wrapper present the identity (module, name, qualname, docstring)
# of the method it replaces, and tag it so install_pydantic_extract_patch()
# can tell that the patch is already in place.
for _wrapper, _original in (
    (_sync_extract, _ORIGINAL_SESSION_EXTRACT),
    (_async_extract, _ORIGINAL_ASYNC_SESSION_EXTRACT),
):
    for _attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        setattr(_wrapper, _attr, getattr(_original, _attr))
    setattr(_wrapper, "__stagehand_pydantic_extract_patch__", True)
# Do not leave the loop variables behind as module attributes.
del _wrapper, _original, _attr
27 changes: 15 additions & 12 deletions src/stagehand/resources/sessions_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,28 @@

import httpx

from ..types import session_start_params
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from .._session_extract import install_pydantic_extract_patch
from .._compat import cached_property
from .._response import (
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
to_raw_response_wrapper,
to_streamed_response_wrapper,
)
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ..session import Session, AsyncSession
from ..types import session_start_params
from ..types.session_start_response import SessionStartResponse
from .sessions import (
SessionsResource,
AsyncSessionsResource,
SessionsResourceWithRawResponse,
AsyncSessionsResourceWithRawResponse,
SessionsResourceWithStreamingResponse,
AsyncSessionsResourceWithStreamingResponse,
SessionsResource,
SessionsResourceWithRawResponse,
SessionsResourceWithStreamingResponse,
)
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..types.session_start_response import SessionStartResponse

install_pydantic_extract_patch()


class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):
Expand Down
75 changes: 18 additions & 57 deletions src/stagehand/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from ._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ._exceptions import StagehandError
from ._pydantic_extract import is_pydantic_model, validate_extract_response, pydantic_model_to_json_schema
from .types.session_act_response import SessionActResponse
from .types.session_end_response import SessionEndResponse
from .types.session_start_response import Data as SessionStartResponseData, SessionStartResponse
Expand Down Expand Up @@ -201,47 +200,28 @@ def observe(
),
)

def extract( # type: ignore[misc]
def extract(
self,
*,
schema: dict[str, object] | type | None = None,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
) -> SessionExtractResponse:
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
params_schema = params.pop("schema", None) # type: ignore[misc]
resolved_schema = schema if schema is not None else params_schema

pydantic_cls: type[Any] | None = None
if is_pydantic_model(resolved_schema):
pydantic_cls = resolved_schema # type: ignore[assignment]
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]

api_params: dict[str, Any] = _maybe_inject_frame_id(dict(params), page)
if resolved_schema is not None:
api_params["schema"] = resolved_schema

response: SessionExtractResponse = cast(
return cast(
SessionExtractResponse,
self._client.sessions.extract(
id=self.id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**api_params,
id=self.id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**_maybe_inject_frame_id(dict(params), page),
),
)

if pydantic_cls is not None and response.data and response.data.result is not None:
response.data.result = validate_extract_response(response.data.result, pydantic_cls)

return response

def execute(
self,
*,
Expand Down Expand Up @@ -355,47 +335,28 @@ async def observe(
),
)

async def extract( # type: ignore[misc]
async def extract(
self,
*,
schema: dict[str, object] | type | None = None,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
) -> SessionExtractResponse:
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
params_schema = params.pop("schema", None) # type: ignore[misc]
resolved_schema = schema if schema is not None else params_schema

pydantic_cls: type[Any] | None = None
if is_pydantic_model(resolved_schema):
pydantic_cls = resolved_schema # type: ignore[assignment]
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]

api_params: dict[str, Any] = await _maybe_inject_frame_id_async(dict(params), page)
if resolved_schema is not None:
api_params["schema"] = resolved_schema

response: SessionExtractResponse = cast(
return cast(
SessionExtractResponse,
await self._client.sessions.extract(
id=self.id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**api_params,
id=self.id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**(await _maybe_inject_frame_id_async(dict(params), page)),
),
)

if pydantic_cls is not None and response.data and response.data.result is not None:
response.data.result = validate_extract_response(response.data.result, pydantic_cls)

return response

async def execute(
self,
*,
Expand Down
Loading
Loading