Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [

dependencies = [
"httpx>=0.23.0, <1",
"pydantic>=1.9.0, <3",
"pydantic>=2.0.0, <3",
"typing-extensions>=4.14, <5",
"anyio>=3.5.0, <5",
"distro>=1.7.0, <2",
Expand Down Expand Up @@ -46,12 +46,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
[tool.uv]
managed = true
required-version = ">=0.9"
conflicts = [
[
{ group = "pydantic-v1" },
{ group = "pydantic-v2" },
],
]
conflicts = []

[dependency-groups]
# version pins are in uv.lock
Expand All @@ -69,13 +64,6 @@ dev = [
"pytest-xdist>=3.6.1",
"dotenv>=0.9.9",
]
pydantic-v1 = [
"pydantic>=1.9.0,<2",
]
pydantic-v2 = [
"pydantic~=2.0 ; python_full_version < '3.14'",
"pydantic~=2.12 ; python_full_version >= '3.14'",
]

[build-system]
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme", "packaging"]
Expand Down
8 changes: 1 addition & 7 deletions scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@ PY_VERSION_MIN=">=3.9.0"
PY_VERSION_MAX=">=3.14.0"

function run_tests() {
echo "==> Running tests with Pydantic v2"
echo "==> Running tests"
uv run --isolated --all-extras pytest "$@"

# Skip Pydantic v1 tests on latest Python (not supported)
if [[ "$UV_PYTHON" != "$PY_VERSION_MAX" ]]; then
echo "==> Running tests with Pydantic v1"
uv run --isolated --all-extras --group=pydantic-v1 pytest "$@"
fi
}

# If UV_PYTHON is already set in the environment, just run the command once
Expand Down
6 changes: 6 additions & 0 deletions src/stagehand/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

# True when the installed Pydantic major version is 1.x.
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")

# Fail fast at import time: this package relies on Pydantic v2-only APIs
# (model_validate, model_json_schema, ConfigDict), so an immediate, actionable
# ImportError beats a confusing AttributeError later on.
if PYDANTIC_V1:
    raise ImportError(
        f"stagehand requires Pydantic v2 or newer; found Pydantic {pydantic.VERSION}. "
        "Install `pydantic>=2,<3`."
    )

if TYPE_CHECKING:

def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
Expand Down
193 changes: 191 additions & 2 deletions src/stagehand/resources/sessions_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

from __future__ import annotations

from typing_extensions import Literal, override
import inspect
import logging
from typing import Any, Type, Mapping, cast
from typing_extensions import Unpack, Literal, override

import httpx
from pydantic import BaseModel, ConfigDict

from ..types import session_start_params
from ..types import session_start_params, session_extract_params
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from .._utils import lru_cache
from .._compat import cached_property
from ..session import Session, AsyncSession
from .sessions import (
Expand All @@ -25,6 +30,190 @@
async_to_streamed_response_wrapper,
)
from ..types.session_start_response import SessionStartResponse
from ..types.session_extract_response import SessionExtractResponse

logger = logging.getLogger(__name__)

# Capture the unpatched extract methods before install_pydantic_extract_patch()
# replaces them; the wrapper functions below delegate to these originals.
_ORIGINAL_SESSION_EXTRACT = Session.extract
_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract


def install_pydantic_extract_patch() -> None:
    """Swap ``Session.extract`` / ``AsyncSession.extract`` for the
    pydantic-aware wrappers defined in this module.

    Idempotent: the wrappers carry a marker attribute, so calling this again
    once the patch is in place is a no-op.
    """
    already_installed = getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False)
    if already_installed:
        return

    Session.extract = _sync_extract  # type: ignore[assignment]
    AsyncSession.extract = _async_extract  # type: ignore[assignment]


def is_pydantic_model(schema: Any) -> bool:
    """Return ``True`` when *schema* is a class subclassing ``pydantic.BaseModel``."""
    if not inspect.isclass(schema):
        return False
    return issubclass(schema, BaseModel)


def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> dict[str, object]:
    """Serialize a Pydantic model class into the JSON-schema dict the API expects.

    ``model_rebuild()`` resolves outstanding forward references first so that
    schema generation cannot fail on lazily-annotated models.
    """
    schema.model_rebuild()
    json_schema = schema.model_json_schema()
    return cast(dict[str, object], json_schema)


def validate_extract_response(
    result: object, schema: Type[BaseModel], *, strict_response_validation: bool
) -> object:
    """Best-effort validation of an extract *result* against *schema*.

    Tries the raw payload first, then a snake_case-normalized copy (the API
    may return camelCase keys). If neither validates, logs a warning and
    returns the raw payload unchanged rather than raising.
    """
    model_cls = _validation_schema(schema, strict_response_validation)

    try:
        return model_cls.model_validate(result)
    except Exception:
        # Fall through to the snake_case retry below.
        pass

    try:
        snake_cased = _convert_dict_keys_to_snake_case(result)
        return model_cls.model_validate(snake_cased)
    except Exception:
        logger.warning(
            "Failed to validate extracted data against schema %s. Returning raw data.",
            schema.__name__,
        )
        return result


@lru_cache(maxsize=256)
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
    """Build (and cache) a subclass of *schema* whose ``extra`` behavior
    matches the client's strictness setting.

    Strict clients reject unknown fields (``extra="forbid"``); lenient clients
    keep them (``extra="allow"``). The cache is keyed on (model class, flag).
    """
    extra: Literal["allow", "forbid"] = "allow"
    if strict_response_validation:
        extra = "forbid"

    namespace = {
        "__module__": schema.__module__,
        "model_config": ConfigDict(extra=extra),
    }
    subclass = type(f"{schema.__name__}ExtractValidation", (schema,), namespace)
    model_cls = cast(Type[BaseModel], subclass)
    # Force-rebuild so the dynamically created subclass is fully defined.
    model_cls.model_rebuild(force=True)
    return model_cls


def _camel_to_snake(name: str) -> str:
chars: list[str] = []
for i, ch in enumerate(name):
if ch.isupper() and i != 0 and not name[i - 1].isupper():
Copy link
Copy Markdown

@cubic-dev-ai cubic-dev-ai bot Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: _camel_to_snake mishandles the boundary between an uppercase acronym and a following word (e.g., "getHTTPResponse""get_httpresponse" instead of "get_http_response"). Since this is the fallback used when direct validation fails, a wrong conversion silently drops the user back to raw data instead of a Pydantic instance.

The fix adds a lookahead check: also insert _ before an uppercase char that follows another uppercase char when the next char is lowercase.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At src/stagehand/resources/sessions_helpers.py, line 97:

<comment>`_camel_to_snake` mishandles the boundary between an uppercase acronym and a following word (e.g., `"getHTTPResponse"` → `"get_httpresponse"` instead of `"get_http_response"`). Since this is the fallback used when direct validation fails, a wrong conversion silently drops the user back to raw data instead of a Pydantic instance.

The fix adds a lookahead check: also insert `_` before an uppercase char that follows another uppercase char when the *next* char is lowercase.</comment>

<file context>
@@ -24,8 +29,189 @@
+def _camel_to_snake(name: str) -> str:
+    chars: list[str] = []
+    for i, ch in enumerate(name):
+        if ch.isupper() and i != 0 and not name[i - 1].isupper():
+            chars.append("_")
+        chars.append(ch.lower())
</file context>
Fix with Cubic

chars.append("_")
chars.append(ch.lower())
return "".join(chars)


def _convert_dict_keys_to_snake_case(data: Any) -> Any:
if isinstance(data, dict):
items = cast(dict[object, object], data).items()
return {
_camel_to_snake(k) if isinstance(k, str) else k: _convert_dict_keys_to_snake_case(v)
for k, v in items
}
if isinstance(data, list):
return [_convert_dict_keys_to_snake_case(item) for item in cast(list[object], data)]
return data


def _with_schema(
    params: Mapping[str, object],
    schema: dict[str, object] | type | None,
) -> session_extract_params.SessionExtractParamsNonStreaming:
    """Return a copy of *params* with *schema* merged in when one is provided."""
    merged: dict[str, object] = {**params}
    if schema is not None:
        merged["schema"] = cast(Any, schema)
    return cast(session_extract_params.SessionExtractParamsNonStreaming, merged)


def _sync_extract(  # type: ignore[override, misc]
    self: Session,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    """Pydantic-aware replacement for ``Session.extract``.

    If *schema* (or ``params["schema"]``) is a Pydantic model class, it is
    serialized to JSON schema before the request, and the extracted result is
    validated back into that model afterwards (best-effort).
    """
    # The explicit keyword wins over a schema smuggled in through **params.
    fallback_schema = params.pop("schema", None)  # type: ignore[misc]
    effective_schema: Any = schema if schema is not None else fallback_schema

    model_cls: Type[BaseModel] | None = None
    if is_pydantic_model(effective_schema):
        model_cls = cast(Type[BaseModel], effective_schema)
        effective_schema = pydantic_model_to_json_schema(model_cls)

    response = _ORIGINAL_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **_with_schema(params, effective_schema),
    )

    data = response.data
    if model_cls is not None and data and data.result is not None:
        data.result = validate_extract_response(
            data.result,
            model_cls,
            strict_response_validation=self._client._strict_response_validation,
        )

    return response


async def _async_extract( # type: ignore[override, misc]
self: AsyncSession,
*,
schema: dict[str, object] | type | None = None,
page: Any | None = None,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
params_schema = params.pop("schema", None) # type: ignore[misc]
resolved_schema = schema if schema is not None else params_schema

pydantic_cls: Type[BaseModel] | None = None
if is_pydantic_model(resolved_schema):
pydantic_cls = cast(Type[BaseModel], resolved_schema)
resolved_schema = pydantic_model_to_json_schema(pydantic_cls)

response = await _ORIGINAL_ASYNC_SESSION_EXTRACT(
self,
page=page,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
**_with_schema(params, resolved_schema),
)

if pydantic_cls is not None and response.data and response.data.result is not None:
response.data.result = validate_extract_response(
response.data.result,
pydantic_cls,
strict_response_validation=self._client._strict_response_validation,
)

return response


# Make the wrappers introspection-friendly: copy identity metadata from the
# originals so help(), docs tooling, and tracebacks point at the real extract.
for _wrapper, _original in (
    (_sync_extract, _ORIGINAL_SESSION_EXTRACT),
    (_async_extract, _ORIGINAL_ASYNC_SESSION_EXTRACT),
):
    for _attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        setattr(_wrapper, _attr, getattr(_original, _attr))
    # Marker checked by install_pydantic_extract_patch() to stay idempotent.
    setattr(_wrapper, "__stagehand_pydantic_extract_patch__", True)  # noqa: B010

del _wrapper, _original, _attr


install_pydantic_extract_patch()


class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):
Expand Down
Loading
Loading