diff --git a/CHANGELOG.md b/CHANGELOG.md index af84cc54..37de175f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,34 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Standardized error hierarchy ([CHA-2958](https://linear.app/stream/issue/CHA-2958)). + New exception classes importable from `getstream` (also re-exported from + `getstream.exceptions`): + - `StreamException`: abstract base for every SDK error. + - `StreamApiException`: any HTTP 4xx/5xx response. Carries `status_code`, + `code`, `message`, `exception_fields`, `unrecoverable`, + `raw_response_body`, `more_info`, `details`. The `unrecoverable` flag + from `APIError` is now surfaced (was previously dropped on most paths). + - `StreamRateLimitException`: subclass of `StreamApiException` raised on + HTTP 429. Adds `retry_after: datetime.timedelta | None`, parsed from + `Retry-After` per RFC 7231 (integer seconds or HTTP-date). Missing or + unparseable headers map to `None`; past HTTP-dates clamp to `0`. + - `StreamTransportException`: raised when a network-layer failure (no + HTTP response received) propagates out of `httpx` — connection reset, + timeout, TLS handshake failure, DNS failure. Carries `error_type` + enum (`connection_reset` / `timeout` / `dns_failure` / + `tls_handshake_failed` / `unknown`). The original `httpx` exception + is preserved as `__cause__`. + - `StreamTaskException`: raised by `wait_for_task` when the polled task + ends in `status='failed'`. Carries `task_id`, `error_type`, + `description`, `stack_trace`, `version`. +- `Stream.wait_for_task(task_id, *, poll_interval=1.0, timeout=60.0)` and + the matching async coroutine on `AsyncStream`. Polls `get_task` until the + task reaches a terminal state. On `completed` returns the + `StreamResponse[GetTaskResponse]`; on `failed` raises + `StreamTaskException` populated from `ErrorResult`; on timeout raises + `StreamTransportException(error_type='timeout')`. + - Explicit HTTP connection pool configuration ([CHA-2956](https://linear.app/stream/issue/CHA-2956/connection-pooling)). Four new kwargs on `Stream(...)` and `AsyncStream(...)`: - `max_conns_per_host: int`: default `5` @@ -45,11 +73,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- HTTP request errors raised by `httpx` (the `httpx.RequestError` family — + `ConnectError`, `ReadTimeout`, etc.) are now wrapped at the SDK boundary + in `StreamTransportException` so callers handle one Stream error category + instead of catching `httpx.RequestError` separately. The original `httpx` + exception is preserved via `__cause__` (CHA-2958). - **Default `request_timeout` is now `30.0` seconds (was `6.0`).** Aligns stream-py with the cross-SDK contract in CHA-2956. Existing callers using `timeout=` are unaffected; `timeout` is kept as an alias for `request_timeout`. Callers relying on the 6s ceiling for fail-fast behavior should pass `request_timeout=6.0` (or `timeout=6.0`) explicitly. - Default HTTP transport now caps connections per host at `5` and closes idle sockets after `55.0s`. Previous default was httpx's `100` max-connections with `5.0s` keep-alive expiry. - No breaking changes. All existing webhook helpers (`verify_webhook_signature`, `parse_webhook_event`, `get_event_type`, event type constants) are preserved. +### Deprecated + +- `getstream.base.StreamAPIException` (capital `API`) is now an alias for + `getstream.exceptions.StreamApiException` (lowercase `Api`). Importing the + old name emits `DeprecationWarning`; existing `isinstance` / `except` / + `pytest.raises` checks continue to work because the alias resolves to the + same class. The legacy spelling will be removed one minor cycle after this + release (CHA-2958 §10). + ### Notes - Per-call `timeout=httpx.Timeout(...)` continues to work through `.get(...)`, `.post(...)`, etc., and pre-empts the client-level `request_timeout`. diff --git a/getstream/__init__.py b/getstream/__init__.py index 52d02d9c..6e6dea52 100644 --- a/getstream/__init__.py +++ b/getstream/__init__.py @@ -1,2 +1,9 @@ +from getstream.exceptions import ( # noqa: F401 + StreamApiException, + StreamException, + StreamRateLimitException, + StreamTaskException, + StreamTransportException, +) from getstream.stream import Stream # noqa: F401 from getstream.stream import AsyncStream # noqa: F401 diff --git a/getstream/base.py b/getstream/base.py index 27ea400a..be4a48cf 100644 --- a/getstream/base.py +++ b/getstream/base.py @@ -4,11 +4,15 @@ import os import time import uuid +import warnings import asyncio from typing import Any, Dict, List, Optional, Tuple, Type, cast, get_origin -from getstream.models import APIError -from getstream.rate_limit import extract_rate_limit +from getstream.exceptions import ( + StreamApiException, + build_api_exception, + wrap_transport_error, +) from getstream.stream_response import StreamResponse from getstream.generic import T import httpx @@ -102,9 +106,7 @@ def _parse_response( self, response: httpx.Response, data_type: Type[T] ) -> StreamResponse[T]: if response.status_code >= 399: - raise StreamAPIException( - response=response, - ) + raise build_api_exception(response) try: parsed_result = json.loads(response.text) if response.text else {} @@ -118,10 +120,8 @@ def _parse_response( else: data = cast(T, parsed_result) - except (ValueError, AttributeError): - raise StreamAPIException( - response=response, - ) + except (ValueError, AttributeError) as err: + raise StreamApiException(response=response) from err return StreamResponse(response, data) @@ -291,9 +291,12 @@ def _request_sync( ) as span: call_kwargs = dict(kwargs) call_kwargs.pop("path_params", None) - response = getattr(self.client, method.lower())( - url_path, params=query_params, *args, **call_kwargs - ) + try: + response = getattr(self.client, method.lower())( + url_path, params=query_params, *args, **call_kwargs + ) + except httpx.RequestError as err: + raise wrap_transport_error(err) from err duration = parse_duration_from_body(response.content) if duration: span.set_attribute("http.server.duration", duration) @@ -604,9 +607,12 @@ async def _request_async( call_kwargs["headers"] = call_kwargs.get("headers", {}) call_kwargs["headers"]["Content-Type"] = "application/json" - response = await getattr(self.client, method.lower())( - url_path, params=query_params, *args, **call_kwargs - ) + try: + response = await getattr(self.client, method.lower())( + url_path, params=query_params, *args, **call_kwargs + ) + except httpx.RequestError as err: + raise wrap_transport_error(err) from err duration = parse_duration_from_body(response.content) if duration: span.set_attribute("http.server.duration", duration) @@ -721,54 +727,19 @@ async def delete( ) -class StreamAPIException(Exception): - """ - A custom exception for handling errors from a Stream API response. - - This exception is raised when an API call encounters an issue, providing - detailed information from the HTTP response. It attempts to parse the response - content into a structured API error. If the response content is not JSON or - lacks the expected structure, it will simply report the HTTP status code. - - Attributes: - api_error (Optional[APIError]): An optional APIError object that is - populated if the response content contains structured error information. - rate_limit_info (RateLimitInfo): Information about the API's rate limiting - controls extracted from the response headers. - http_response (httpx.Response): The full HTTP response object from httpx. - status_code (int): The HTTP status code from the response. - - Args: - response (httpx.Response): The HTTP response received from the Stream API. - - Raises: - ValueError: If the response content cannot be parsed into JSON, indicating - that the server's response was not in the expected format. - """ - - def __init__(self, response: httpx.Response) -> None: - self.api_error: Optional[APIError] = None - self.rate_limit_info = extract_rate_limit(response) - self.http_response = response - self.status_code = response.status_code - - try: - parsed_response: Dict = json.loads(response.content) - self.api_error = APIError.from_dict(parsed_response) - except ValueError: - pass - - def __str__(self) -> str: - if self.api_error: - return f'Stream error code {self.api_error.code}: {self.api_error.message}"' - body_preview = "" - try: - text = self.http_response.text[:200] if self.http_response.text else "" - if text: - body_preview = f" body: {text}" - except Exception: - pass - return f"Stream error HTTP code: {self.status_code}{body_preview}" +def __getattr__(name: str): + """StreamApiException is exported under its new name; resolve here lazily and warn once.""" + if name == "StreamAPIException": + warnings.warn( + "getstream.base.StreamAPIException is deprecated; import " + "StreamApiException from getstream (or getstream.exceptions) " + "instead. The legacy alias will be removed one minor cycle after " + "this release.", + DeprecationWarning, + stacklevel=2, + ) + return StreamApiException + raise AttributeError(f"module 'getstream.base' has no attribute {name!r}") def parse_duration_from_body(body: bytes) -> Optional[str]: diff --git a/getstream/exceptions.py b/getstream/exceptions.py new file mode 100644 index 00000000..8b34abac --- /dev/null +++ b/getstream/exceptions.py @@ -0,0 +1,326 @@ +"""Stream SDK error hierarchy. See class docstrings for when each is raised. + +StreamException — abstract base + StreamApiException — HTTP 4xx/5xx with APIError envelope + StreamRateLimitException — HTTP 429, carries ``retry_after`` + StreamTransportException — network-layer failure (no HTTP response) + StreamTaskException — async task observed as ``status='failed'`` +""" + +from __future__ import annotations + +import email.utils +import json +import socket +import ssl +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Dict, Optional, Tuple + +import httpx + +from getstream.models import APIError +from getstream.rate_limit import RateLimitInfo, extract_rate_limit + + +TRANSPORT_ERROR_CONNECTION_RESET = "connection_reset" +TRANSPORT_ERROR_TIMEOUT = "timeout" +TRANSPORT_ERROR_DNS_FAILURE = "dns_failure" +TRANSPORT_ERROR_TLS_HANDSHAKE_FAILED = "tls_handshake_failed" +TRANSPORT_ERROR_UNKNOWN = "unknown" + + +class StreamException(Exception): + """Abstract base for all Stream SDK errors.""" + + +class StreamApiException(StreamException): + """Raised on any HTTP 4xx/5xx response from the Stream API. + + Raised with code=0 and message='failed to parse error response' when an + HTTP response was received but the body could not be parsed as an + APIError envelope; in that case ``__cause__`` carries the underlying + parse error. + """ + + def __init__( + self, + response: Optional[httpx.Response] = None, + *, + status_code: Optional[int] = None, + code: int = 0, + message: str = "", + exception_fields: Optional[Dict[str, str]] = None, + unrecoverable: bool = False, + raw_response_body: str = "", + more_info: Optional[str] = None, + details: Any = None, + api_error: Optional[APIError] = None, + rate_limit_info: Optional[RateLimitInfo] = None, + ) -> None: + body_parse_err: Optional[BaseException] = None + if response is not None: + parsed, body_parse_err = _fields_from_response(response) + status_code = parsed["status_code"] + code = parsed["code"] + message = parsed["message"] + exception_fields = parsed["exception_fields"] + unrecoverable = parsed["unrecoverable"] + raw_response_body = parsed["raw_response_body"] + more_info = parsed["more_info"] + details = parsed["details"] + api_error = parsed["api_error"] + rate_limit_info = parsed["rate_limit_info"] + + if status_code is None: + raise TypeError( + "StreamApiException requires either a response or an explicit status_code" + ) + + self.status_code = status_code + self.code = code + self.message = message + self.exception_fields = exception_fields or {} + self.unrecoverable = unrecoverable + self.raw_response_body = raw_response_body + self.more_info = more_info + self.details = details + self.api_error = api_error + self.rate_limit_info = rate_limit_info + self.http_response = response + super().__init__(str(self)) + if body_parse_err is not None and self.__cause__ is None: + self.__cause__ = body_parse_err + + def __str__(self) -> str: + if self.code or self.message: + return ( + f"Stream API error (status={self.status_code}, code={self.code}): " + f"{self.message}" + ) + return f"Stream API error (status={self.status_code})" + + +class StreamRateLimitException(StreamApiException): + """HTTP 429. Carries ``retry_after`` parsed from the ``Retry-After`` header. + + ``retry_after`` is ``None`` when the header is absent or unparseable + (graceful — never raises on a bad header). + """ + + def __init__( + self, + response: Optional[httpx.Response] = None, + *, + retry_after: Optional[timedelta] = None, + **kwargs: Any, + ) -> None: + if response is not None and retry_after is None: + retry_after = parse_retry_after(response.headers.get("Retry-After")) + super().__init__(response, **kwargs) + self.retry_after = retry_after + + +class StreamTransportException(StreamException): + """Network-layer failure: connection reset, timeout, TLS, DNS, etc. + + No HTTP response was received. ``__cause__`` carries the original httpx + exception. + """ + + def __init__(self, error_type: str, message: str = "") -> None: + self.error_type = error_type + super().__init__(message or f"Stream transport error ({error_type})") + + +class StreamTaskException(StreamException): + """Raised by ``wait_for_task`` when the polled task ends in + status='failed'. Carries task_id, error_type, description, stack_trace, + version. + """ + + def __init__( + self, + *, + task_id: str, + error_type: str, + description: str, + stack_trace: Optional[str] = None, + version: Optional[str] = None, + ) -> None: + self.task_id = task_id + self.error_type = error_type + self.description = description + self.stack_trace = stack_trace + self.version = version + super().__init__(f"Task {task_id} failed ({error_type}): {description}") + + +def _fields_from_response( + response: httpx.Response, +) -> Tuple[Dict[str, Any], Optional[BaseException]]: + """Pull the APIError envelope fields out of an httpx response. + + Returns ``(fields, parse_error)``; ``parse_error`` is None when the body + parsed cleanly or was empty. + """ + raw_body = "" + try: + raw_body = response.text or "" + except Exception: + pass + + api_error: Optional[APIError] = None + parse_error: Optional[BaseException] = None + if response.content: + try: + parsed_json: Any = json.loads(response.content) + api_error = APIError.from_dict(parsed_json) + except (ValueError, AttributeError, TypeError) as e: + api_error = None + parse_error = e + + rate_limit_info = extract_rate_limit(response) + + if api_error is not None: + unrecoverable = api_error.unrecoverable + return ( + { + "status_code": response.status_code, + "code": api_error.code, + "message": api_error.message, + "exception_fields": dict(api_error.exception_fields or {}), + "unrecoverable": bool(unrecoverable) + if unrecoverable is not None + else False, + "raw_response_body": raw_body, + "more_info": api_error.more_info or None, + "details": api_error.details, + "api_error": api_error, + "rate_limit_info": rate_limit_info, + }, + None, + ) + + return ( + { + "status_code": response.status_code, + "code": 0, + "message": "failed to parse error response", + "exception_fields": {}, + "unrecoverable": False, + "raw_response_body": raw_body, + "more_info": None, + "details": None, + "api_error": None, + "rate_limit_info": rate_limit_info, + }, + parse_error, + ) + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +def parse_retry_after( + header_value: Optional[str], + *, + now: Callable[[], datetime] = _utcnow, +) -> Optional[timedelta]: + """Parse ``Retry-After`` per RFC 7231 §7.1.3. + + Accepts integer seconds (``Retry-After: 30``) or HTTP-date + (``Retry-After: Fri, 31 Dec 2026 23:59:59 GMT``). Returns ``None`` when + the header is absent or unparseable. HTTP-date deltas in the past are + clamped to zero (never negative). + """ + if header_value is None: + return None + value = header_value.strip() + if not value: + return None + try: + seconds = int(value) + except ValueError: + pass + else: + if seconds < 0: + return None + return timedelta(seconds=seconds) + + try: + parsed = email.utils.parsedate_to_datetime(value) + except (TypeError, ValueError): + return None + if parsed is None: + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + delta = parsed - now() + if delta.total_seconds() < 0: + return timedelta(0) + return delta + + +def build_api_exception(response: httpx.Response) -> StreamApiException: + """Build the right ``StreamApiException`` subclass from an httpx response. + + Returns ``StreamRateLimitException`` for 429, else base + ``StreamApiException``. ``__cause__`` is set when the body could not be + parsed. + """ + if response.status_code == 429: + return StreamRateLimitException(response=response) + return StreamApiException(response=response) + + +def classify_transport_error(exc: BaseException) -> str: + """Map an httpx transport exception to the ``error_type`` enum: ``timeout``, + ``tls_handshake_failed``, ``dns_failure``, ``connection_reset``, or + ``unknown``. + """ + if isinstance(exc, httpx.TimeoutException): + return TRANSPORT_ERROR_TIMEOUT + + seen: set = set() + cur: Optional[BaseException] = exc + while cur is not None and id(cur) not in seen: + seen.add(id(cur)) + if isinstance(cur, ssl.SSLError): + return TRANSPORT_ERROR_TLS_HANDSHAKE_FAILED + if isinstance(cur, socket.gaierror): + return TRANSPORT_ERROR_DNS_FAILURE + cur = cur.__cause__ or cur.__context__ + + if isinstance(exc, (httpx.NetworkError, httpx.TransportError)): + return TRANSPORT_ERROR_CONNECTION_RESET + return TRANSPORT_ERROR_UNKNOWN + + +def wrap_transport_error(exc: httpx.RequestError) -> StreamTransportException: + """Build a ``StreamTransportException`` from an httpx ``RequestError``. + + Callers must use ``raise ... from exc`` to set ``__cause__``. + """ + return StreamTransportException( + error_type=classify_transport_error(exc), + message=f"transport error: {exc}", + ) + + +__all__ = [ + "StreamException", + "StreamApiException", + "StreamRateLimitException", + "StreamTransportException", + "StreamTaskException", + "TRANSPORT_ERROR_CONNECTION_RESET", + "TRANSPORT_ERROR_TIMEOUT", + "TRANSPORT_ERROR_DNS_FAILURE", + "TRANSPORT_ERROR_TLS_HANDSHAKE_FAILED", + "TRANSPORT_ERROR_UNKNOWN", + "build_api_exception", + "classify_transport_error", + "parse_retry_after", + "wrap_transport_error", +] diff --git a/getstream/stream.py b/getstream/stream.py index b1e0ff40..7ff71910 100644 --- a/getstream/stream.py +++ b/getstream/stream.py @@ -408,6 +408,24 @@ async def aclose(self): stack.push_async_callback(self.moderation.aclose) stack.push_async_callback(super().aclose) + async def wait_for_task( + self, + task_id: str, + *, + poll_interval: float = 1.0, + timeout: float = 60.0, + ): + """Poll an async task until ``completed`` (returns the response), + ``failed`` (raises :class:`StreamTaskException`), or the timeout + elapses (raises :class:`StreamTransportException` with + ``error_type='timeout'``). + """ + from .tasks import wait_for_task_async + + return await wait_for_task_async( + self, task_id, poll_interval=poll_interval, timeout=timeout + ) + @cached_property def feeds(self): raise NotImplementedError("Feeds not supported for async client") @@ -621,6 +639,24 @@ def upsert_users(self, *users: UserRequest): users_map = {u.id: u for u in users} return self.update_users(users_map) + def wait_for_task( + self, + task_id: str, + *, + poll_interval: float = 1.0, + timeout: float = 60.0, + ): + """Poll an async task until ``completed`` (returns the response), + ``failed`` (raises :class:`StreamTaskException`), or the timeout + elapses (raises :class:`StreamTransportException` with + ``error_type='timeout'``). + """ + from .tasks import wait_for_task_sync + + return wait_for_task_sync( + self, task_id, poll_interval=poll_interval, timeout=timeout + ) + def verify_signature(self, body, signature): """Verify a webhook signature using this client's API secret. diff --git a/getstream/tasks.py b/getstream/tasks.py new file mode 100644 index 00000000..78ff8da0 --- /dev/null +++ b/getstream/tasks.py @@ -0,0 +1,104 @@ +"""Polling helpers used by ``Stream.wait_for_task`` and ``AsyncStream.wait_for_task``.""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Any + +from getstream.exceptions import ( + StreamTaskException, + StreamTransportException, + TRANSPORT_ERROR_TIMEOUT, +) + +if TYPE_CHECKING: + from getstream.models import GetTaskResponse + from getstream.stream_response import StreamResponse + + +DEFAULT_POLL_INTERVAL = 1.0 +DEFAULT_TIMEOUT = 60.0 + + +def _build_task_exception(task_id: str, response_data: Any) -> StreamTaskException: + """Construct a ``StreamTaskException`` from a ``GetTaskResponse.data``. + + Missing ``error`` (shouldn't happen when status == 'failed' but defend + anyway) collapses to ``error_type='unknown'`` with an empty description. + """ + err = getattr(response_data, "error", None) + return StreamTaskException( + task_id=task_id, + error_type=err.type if err is not None else "unknown", + description=err.description if err is not None else "", + stack_trace=err.stacktrace if err is not None else None, + version=err.version if err is not None else None, + ) + + +def _timeout_exception(task_id: str, timeout: float) -> StreamTransportException: + return StreamTransportException( + error_type=TRANSPORT_ERROR_TIMEOUT, + message=( + f"wait_for_task timed out after {timeout}s waiting for task {task_id}" + ), + ) + + +def wait_for_task_sync( + client: Any, + task_id: str, + *, + poll_interval: float = DEFAULT_POLL_INTERVAL, + timeout: float = DEFAULT_TIMEOUT, +) -> "StreamResponse[GetTaskResponse]": + """Poll ``client.get_task(id=task_id)`` (sync) until terminal state. + + Returns the ``StreamResponse`` on ``completed``. Raises + ``StreamTaskException`` on ``failed``, ``StreamTransportException`` with + ``error_type='timeout'`` if ``timeout`` elapses first. + """ + deadline = time.monotonic() + timeout + while True: + response = client.get_task(id=task_id) + status = response.data.status + if status == "completed": + return response + if status == "failed": + raise _build_task_exception(task_id, response.data) + remaining = deadline - time.monotonic() + if remaining <= 0: + raise _timeout_exception(task_id, timeout) + time.sleep(min(poll_interval, remaining)) + + +async def wait_for_task_async( + client: Any, + task_id: str, + *, + poll_interval: float = DEFAULT_POLL_INTERVAL, + timeout: float = DEFAULT_TIMEOUT, +) -> "StreamResponse[GetTaskResponse]": + """Async variant of :func:`wait_for_task_sync`.""" + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while True: + response = await client.get_task(id=task_id) + status = response.data.status + if status == "completed": + return response + if status == "failed": + raise _build_task_exception(task_id, response.data) + remaining = deadline - loop.time() + if remaining <= 0: + raise _timeout_exception(task_id, timeout) + await asyncio.sleep(min(poll_interval, remaining)) + + +__all__ = [ + "DEFAULT_POLL_INTERVAL", + "DEFAULT_TIMEOUT", + "wait_for_task_sync", + "wait_for_task_async", +] diff --git a/pyproject.toml b/pyproject.toml index f8703fe1..940f21d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ dev = [ "prometheus-client>=0.23.1", "torch>=2.7.0", # Only for scripts/create_test_assets.py "torchaudio>=2.7.0", # Only for scripts/create_test_assets.py + "pytest-httpserver>=1.1.5", ] [tool.uv.workspace] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 00000000..8da6340b --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,635 @@ +"""Unit tests for the new error hierarchy (CHA-2958). + +Covers: +* The four canonical exception classes and their fields. +* APIError envelope parsing into ``StreamApiException`` (incl. spec §6.3 + unparseable-envelope fallback). +* ``Retry-After`` parsing (integer + HTTP-date forms; missing/garbage). +* Transport-error wrapping with ``__cause__`` preserved. +* ``wait_for_task`` sync + async (completed / failed / timeout). +* Back-compat: ``getstream.base.StreamAPIException`` still importable, still + catchable, emits ``DeprecationWarning``. +""" + +from __future__ import annotations + +import json +import socket +import ssl +import warnings +from datetime import datetime, timedelta, timezone +from typing import Optional + +import httpx +import pytest + +from getstream.exceptions import ( + StreamApiException, + StreamException, + StreamRateLimitException, + StreamTaskException, + StreamTransportException, + TRANSPORT_ERROR_CONNECTION_RESET, + TRANSPORT_ERROR_DNS_FAILURE, + TRANSPORT_ERROR_TIMEOUT, + TRANSPORT_ERROR_TLS_HANDSHAKE_FAILED, + TRANSPORT_ERROR_UNKNOWN, + build_api_exception, + classify_transport_error, + parse_retry_after, + wrap_transport_error, +) + + +# ── Fixtures ────────────────────────────────────────────────────────── + + +def _make_response( + status_code: int, + body: Optional[dict] = None, + headers: Optional[dict] = None, + raw_text: Optional[str] = None, +) -> httpx.Response: + """Build an ``httpx.Response`` with no real network round-trip.""" + if raw_text is not None: + content: bytes = raw_text.encode("utf-8") + elif body is not None: + content = json.dumps(body).encode("utf-8") + else: + content = b"" + request = httpx.Request("GET", "https://example.invalid/test") + return httpx.Response( + status_code=status_code, + content=content, + headers=headers or {}, + request=request, + ) + + +# ── Class hierarchy ─────────────────────────────────────────────────── + + +def test_class_hierarchy_matches_spec_9_2(): + """The four concrete classes plus the abstract base all inherit from + ``StreamException`` per §9.2.""" + assert issubclass(StreamApiException, StreamException) + assert issubclass(StreamRateLimitException, StreamApiException) + assert issubclass(StreamTransportException, StreamException) + assert issubclass(StreamTaskException, StreamException) + + +def test_api_exception_explicit_kwargs(): + exc = StreamApiException( + status_code=404, + code=16, + message="not found", + exception_fields={"id": "missing"}, + unrecoverable=True, + raw_response_body='{"code":16}', + more_info="https://example.com/docs", + details=["x"], + ) + assert exc.status_code == 404 + assert exc.code == 16 + assert exc.message == "not found" + assert exc.exception_fields == {"id": "missing"} + assert exc.unrecoverable is True + assert exc.raw_response_body == '{"code":16}' + assert exc.more_info == "https://example.com/docs" + assert exc.details == ["x"] + + +def test_api_exception_requires_status_code_without_response(): + with pytest.raises(TypeError): + StreamApiException() + + +# ── APIError envelope parsing ───────────────────────────────────────── + + +def test_api_exception_from_response_extracts_all_fields(): + body = { + "StatusCode": 422, + "code": 4, + "duration": "12ms", + "message": "validation failed", + "more_info": "https://docs/422", + "details": [1, 2, 3], + "exception_fields": {"name": "required"}, + "unrecoverable": False, + } + response = _make_response(422, body=body) + exc = build_api_exception(response) + assert isinstance(exc, StreamApiException) + assert not isinstance(exc, StreamRateLimitException) + assert exc.status_code == 422 + assert exc.code == 4 + assert exc.message == "validation failed" + assert exc.exception_fields == {"name": "required"} + assert exc.unrecoverable is False + assert exc.more_info == "https://docs/422" + assert exc.details == [1, 2, 3] + assert exc.raw_response_body == json.dumps(body) + + +def test_api_exception_exposes_unrecoverable_when_set_true(): + body = { + "StatusCode": 400, + "code": 99, + "duration": "0ms", + "message": "do not retry", + "more_info": "", + "details": [], + "exception_fields": None, + "unrecoverable": True, + } + exc = build_api_exception(_make_response(400, body=body)) + assert exc.unrecoverable is True + + +def test_api_exception_unparseable_body_falls_back_per_spec_6_3(): + response = _make_response(500, raw_text="upstream barfed") + exc = build_api_exception(response) + assert exc.status_code == 500 + assert exc.code == 0 + assert exc.message == "failed to parse error response" + assert exc.exception_fields == {} + assert exc.unrecoverable is False + assert exc.raw_response_body == "upstream barfed" + assert exc.api_error is None + + +def test_api_exception_back_compat_attrs_preserved(): + """The old API surface (api_error, http_response, rate_limit_info, + status_code) keeps working — none of the existing tests can break.""" + body = { + "StatusCode": 401, + "code": 17, + "duration": "0ms", + "message": "denied", + "more_info": "", + "details": [], + "exception_fields": None, + } + response = _make_response(401, body=body, headers={"x-ratelimit-limit": "10"}) + exc = build_api_exception(response) + assert exc.http_response is response + assert exc.status_code == 401 + assert exc.api_error is not None + assert exc.api_error.code == 17 + # rate_limit_info requires all three headers; only partial here → None. + assert exc.rate_limit_info is None + + +# ── Retry-After handling (§7) ───────────────────────────────────────── + + +def test_parse_retry_after_integer_seconds(): + assert parse_retry_after("30") == timedelta(seconds=30) + + +def test_parse_retry_after_http_date(): + fake_now = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + delta = parse_retry_after( + "Thu, 01 Jan 2026 12:00:30 GMT", + now=lambda: fake_now, + ) + assert delta == timedelta(seconds=30) + + +def test_parse_retry_after_past_http_date_clamped_to_zero(): + fake_now = datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + delta = parse_retry_after( + "Thu, 01 Jan 2026 11:59:30 GMT", + now=lambda: fake_now, + ) + assert delta == timedelta(0) + + +def test_parse_retry_after_missing_or_unparseable_returns_none(): + assert parse_retry_after(None) is None + assert parse_retry_after("") is None + assert parse_retry_after(" ") is None + assert parse_retry_after("not a date or number") is None + + +def test_parse_retry_after_negative_integer_returns_none(): + assert parse_retry_after("-1") is None + + +def test_429_builds_rate_limit_exception_with_retry_after(): + body = { + "StatusCode": 429, + "code": 9, + "duration": "0ms", + "message": "slow down", + "more_info": "", + "details": [], + "exception_fields": None, + } + response = _make_response(429, body=body, headers={"Retry-After": "5"}) + exc = build_api_exception(response) + assert isinstance(exc, StreamRateLimitException) + # Subclass of StreamApiException — existing 4xx handlers still catch it. + assert isinstance(exc, StreamApiException) + assert exc.retry_after == timedelta(seconds=5) + assert exc.status_code == 429 + assert exc.message == "slow down" + + +def test_429_without_retry_after_header_is_none(): + body = { + "StatusCode": 429, + "code": 9, + "duration": "0ms", + "message": "slow down", + "more_info": "", + "details": [], + "exception_fields": None, + } + response = _make_response(429, body=body) + exc = build_api_exception(response) + assert isinstance(exc, StreamRateLimitException) + assert exc.retry_after is None + + +def test_retry_after_exposed_only_on_rate_limit_subclass(): + body = { + "StatusCode": 500, + "code": 1, + "duration": "0ms", + "message": "boom", + "more_info": "", + "details": [], + "exception_fields": None, + } + response = _make_response(500, body=body, headers={"Retry-After": "5"}) + exc = build_api_exception(response) + assert isinstance(exc, StreamApiException) + assert not isinstance(exc, StreamRateLimitException) + assert not hasattr(exc, "retry_after") + + +# ── Transport-error wrapping (§6.1, §6.4) ───────────────────────────── + + +def _request() -> httpx.Request: + return httpx.Request("GET", "https://example.invalid/test") + + +def test_classify_timeout(): + err = httpx.ReadTimeout("timed out", request=_request()) + assert classify_transport_error(err) == TRANSPORT_ERROR_TIMEOUT + + +def test_classify_tls_via_cause_chain(): + inner = ssl.SSLError("handshake failed") + outer = httpx.ConnectError("ssl wrapper", request=_request()) + outer.__cause__ = inner + assert classify_transport_error(outer) == TRANSPORT_ERROR_TLS_HANDSHAKE_FAILED + + +def test_classify_dns_via_cause_chain(): + inner = socket.gaierror("name or service not known") + outer = httpx.ConnectError("dns wrapper", request=_request()) + outer.__cause__ = inner + assert classify_transport_error(outer) == TRANSPORT_ERROR_DNS_FAILURE + + +def test_classify_connection_reset_default(): + outer = httpx.ConnectError("connection refused", request=_request()) + assert classify_transport_error(outer) == TRANSPORT_ERROR_CONNECTION_RESET + + +def test_classify_unknown_for_non_transport_error(): + class WeirdError(Exception): + pass + + assert classify_transport_error(WeirdError("?")) == TRANSPORT_ERROR_UNKNOWN + + +def test_wrap_transport_error_builds_exception_class(): + err = httpx.ConnectTimeout("timed out", request=_request()) + wrapped = wrap_transport_error(err) + assert isinstance(wrapped, StreamTransportException) + assert wrapped.error_type == TRANSPORT_ERROR_TIMEOUT + + +def test_transport_exception_cause_chain_preserved(): + """``raise StreamTransportException(...) from err`` must keep the + underlying httpx error on ``__cause__`` per spec §6.4.""" + err = httpx.ConnectError("refused", request=_request()) + try: + raise wrap_transport_error(err) from err + except StreamTransportException as caught: + assert caught.__cause__ is err + assert caught.error_type == TRANSPORT_ERROR_CONNECTION_RESET + + +def _closed_port() -> int: + """Bind a loopback socket, close it, and return the now-closed port. + Connecting to it triggers a real ``httpx.ConnectError``.""" + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.close() + return port + + +def test_transport_wrapping_sync_connection_refused(monkeypatch): + """End-to-end: a real ``httpx.ConnectError`` (loopback port closed) is + re-raised by the SDK as ``StreamTransportException`` with the original + on ``__cause__``.""" + from getstream import Stream + + monkeypatch.delenv("STREAM_API_KEY", raising=False) + monkeypatch.delenv("STREAM_API_SECRET", raising=False) + client = Stream( + api_key="k", + api_secret="s", + base_url=f"http://127.0.0.1:{_closed_port()}/", + ) + try: + with pytest.raises(StreamTransportException) as info: + client.get_app() + assert info.value.error_type in { + TRANSPORT_ERROR_CONNECTION_RESET, + TRANSPORT_ERROR_UNKNOWN, + } + assert isinstance(info.value.__cause__, httpx.ConnectError) + finally: + client.close() + + +async def test_transport_wrapping_async_read_timeout(httpserver, monkeypatch): + """A real local server that delays the response longer than the client + read-timeout triggers ``httpx.ReadTimeout``, wrapped by the SDK as + ``StreamTransportException`` with ``error_type='timeout'``.""" + import time + + from getstream import AsyncStream + + from werkzeug.wrappers import Response as WerkzeugResponse + + def slow_handler(_request): + time.sleep(1.0) # exceed the client's read timeout below + return WerkzeugResponse("{}", status=200, content_type="application/json") + + httpserver.expect_request("/api/v2/app").respond_with_handler(slow_handler) + + monkeypatch.delenv("STREAM_API_KEY", raising=False) + monkeypatch.delenv("STREAM_API_SECRET", raising=False) + client = AsyncStream( + api_key="k", + api_secret="s", + base_url=httpserver.url_for("/"), + request_timeout=0.2, + ) + try: + with pytest.raises(StreamTransportException) as info: + await client.get_app() + assert info.value.error_type == TRANSPORT_ERROR_TIMEOUT + assert isinstance( + info.value.__cause__, (httpx.ReadTimeout, httpx.TimeoutException) + ) + finally: + await client.aclose() + + +# ── Back-compat alias ───────────────────────────────────────────────── + + +def test_legacy_alias_still_importable_with_deprecation_warning(): + """``from getstream.base import StreamAPIException`` keeps working but + emits ``DeprecationWarning`` per spec §10.""" + import getstream.base as base_mod + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + legacy = getattr(base_mod, "StreamAPIException") + assert legacy is StreamApiException + assert any( + issubclass(w.category, DeprecationWarning) + and "StreamAPIException is deprecated" in str(w.message) + for w in caught + ) + + +def test_legacy_alias_catches_new_exception(): + """Code that does ``except StreamAPIException`` keeps catching exceptions + raised by the new code path.""" + from getstream.base import StreamAPIException + + raised = StreamApiException(status_code=500, message="boom") + with pytest.raises(StreamAPIException) as info: + raise raised + assert info.value is raised + + +# ── wait_for_task (§8) ──────────────────────────────────────────────── + + +_FIXED_NS = int(datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc).timestamp() * 1e9) + + +def _task_response(status: str, error: Optional[dict] = None) -> dict: + """Build a get-task response body matching the real API shape. + `created_at` / `updated_at` are nanosecond Unix timestamps per the + backend's wire format.""" + body: dict = { + "duration": "1ms", + "task_id": "t", + "status": status, + "created_at": _FIXED_NS, + "updated_at": _FIXED_NS, + } + if error is not None: + body["error"] = error + return body + + +def _client_against(httpserver, monkeypatch): + """Real ``Stream`` instance pointing at the loopback ``httpserver``.""" + from getstream import Stream + + monkeypatch.delenv("STREAM_API_KEY", raising=False) + monkeypatch.delenv("STREAM_API_SECRET", raising=False) + return Stream( + api_key="k", + api_secret="s", + base_url=httpserver.url_for("/"), + ) + + +def _async_client_against(httpserver, monkeypatch): + from getstream import AsyncStream + + monkeypatch.delenv("STREAM_API_KEY", raising=False) + monkeypatch.delenv("STREAM_API_SECRET", raising=False) + return AsyncStream( + api_key="k", + api_secret="s", + base_url=httpserver.url_for("/"), + ) + + +def test_wait_for_task_sync_returns_on_completed(httpserver, monkeypatch): + """The helper exits the polling loop the first time the server reports + ``status='completed'``.""" + bodies = iter( + [ + _task_response("waiting"), + _task_response("completed"), + ] + ) + from werkzeug.wrappers import Response as WerkzeugResponse + + def handler(_r): + return WerkzeugResponse( + json.dumps(next(bodies)), + status=200, + content_type="application/json", + ) + + httpserver.expect_request("/api/v2/tasks/task-1").respond_with_handler(handler) + + client = _client_against(httpserver, monkeypatch) + try: + result = client.wait_for_task("task-1", poll_interval=0.0, timeout=5.0) + finally: + client.close() + assert result.data.status == "completed" + + +def test_wait_for_task_sync_raises_on_failed(httpserver, monkeypatch): + """A ``status='failed'`` response surfaces as ``StreamTaskException`` + populated from the task's ``ErrorResult``.""" + httpserver.expect_request("/api/v2/tasks/task-fail").respond_with_json( + _task_response( + "failed", + error={ + "type": "ImportFailed", + "description": "bad rows", + "stacktrace": "trace", + "version": "v1", + }, + ) + ) + + client = _client_against(httpserver, monkeypatch) + try: + with pytest.raises(StreamTaskException) as info: + client.wait_for_task("task-fail", poll_interval=0.0, timeout=5.0) + finally: + client.close() + exc = info.value + assert exc.task_id == "task-fail" + assert exc.error_type == "ImportFailed" + assert exc.description == "bad rows" + assert exc.stack_trace == "trace" + assert exc.version == "v1" + + +def test_wait_for_task_sync_times_out_raises_transport_exception( + httpserver, monkeypatch +): + """A perpetually-waiting task causes the helper to raise + ``StreamTransportException`` with ``error_type='timeout'``.""" + httpserver.expect_request("/api/v2/tasks/task-timeout").respond_with_json( + _task_response("waiting") + ) + + client = _client_against(httpserver, monkeypatch) + try: + with pytest.raises(StreamTransportException) as info: + client.wait_for_task("task-timeout", poll_interval=0.05, timeout=0.15) + finally: + client.close() + assert info.value.error_type == TRANSPORT_ERROR_TIMEOUT + + +async def test_wait_for_task_async_returns_on_completed(httpserver, monkeypatch): + bodies = iter( + [ + _task_response("waiting"), + _task_response("completed"), + ] + ) + from werkzeug.wrappers import Response as WerkzeugResponse + + def handler(_r): + return WerkzeugResponse( + json.dumps(next(bodies)), + status=200, + content_type="application/json", + ) + + httpserver.expect_request("/api/v2/tasks/task-1").respond_with_handler(handler) + + client = _async_client_against(httpserver, monkeypatch) + try: + result = await client.wait_for_task("task-1", poll_interval=0.0, timeout=5.0) + finally: + await client.aclose() + assert result.data.status == "completed" + + +async def test_wait_for_task_async_raises_on_failed(httpserver, monkeypatch): + httpserver.expect_request("/api/v2/tasks/task-fail").respond_with_json( + _task_response( + "failed", + error={ + "type": "ImportFailed", + "description": "async bad", + "stacktrace": None, + "version": None, + }, + ) + ) + + client = _async_client_against(httpserver, monkeypatch) + try: + with pytest.raises(StreamTaskException) as info: + await client.wait_for_task("task-fail", poll_interval=0.0, timeout=5.0) + finally: + await client.aclose() + assert info.value.task_id == "task-fail" + assert info.value.description == "async bad" + + +async def test_wait_for_task_async_times_out_raises_transport_exception( + httpserver, monkeypatch +): + httpserver.expect_request("/api/v2/tasks/task-timeout").respond_with_json( + _task_response("waiting") + ) + + client = _async_client_against(httpserver, monkeypatch) + try: + with pytest.raises(StreamTransportException) as info: + await client.wait_for_task("task-timeout", poll_interval=0.05, timeout=0.15) + finally: + await client.aclose() + assert info.value.error_type == TRANSPORT_ERROR_TIMEOUT + + +# ── StreamTaskException construction ────────────────────────────────── + + +def test_stream_task_exception_fields(): + exc = StreamTaskException( + task_id="t-9", + error_type="ImportFailed", + description="boom", + stack_trace="trace", + version="v3", + ) + assert exc.task_id == "t-9" + assert exc.error_type == "ImportFailed" + assert exc.description == "boom" + assert exc.stack_trace == "trace" + assert exc.version == "v3" + # Optional fields default to None when not supplied. + bare = StreamTaskException(task_id="t-0", error_type="x", description="y") + assert bare.stack_trace is None + assert bare.version is None diff --git a/uv.lock b/uv.lock index 793e21d5..91251ed2 100644 --- a/uv.lock +++ b/uv.lock @@ -951,6 +951,7 @@ dev = [ { name = "prometheus-client" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-httpserver" }, { name = "pytest-timeout" }, { name = "python-dateutil" }, { name = "python-dotenv" }, @@ -1011,6 +1012,7 @@ dev = [ { name = "prometheus-client", specifier = ">=0.23.1" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, + { name = "pytest-httpserver", specifier = ">=1.1.5" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "python-dateutil", specifier = ">=2.8.2,<3" }, @@ -2860,6 +2862,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "pytest-httpserver" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/17/ad187f46998814014f7cda309de700b87c0eb4b2e111e18bc8c819be7116/pytest_httpserver-1.1.5.tar.gz", hash = "sha256:dc3d82e1fe00e491829d8939c549bf4bd9b39a260f87113c619b9d517c2f8ff1", size = 70974, upload-time = "2026-02-14T13:27:23.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/df/0bdf90b84c6a586a9fd2b509523a3ab26b1cc1b1dba2fb62a32e4411ea9e/pytest_httpserver-1.1.5-py3-none-any.whl", hash = "sha256:ee83feb587ab652c0c6729598db2820e9048233bac8df756818b7845a1621d0a", size = 23330, upload-time = "2026-02-14T13:27:22.119Z" }, +] + [[package]] name = "pytest-timeout" version = "2.4.0" @@ -3686,6 +3700,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dd/b2/381be8cfdee792dd117872481b6e378f85c957dd7c5bca38897b08f765fd/werkzeug-3.1.8.tar.gz", hash = "sha256:9bad61a4268dac112f1c5cd4630a56ede601b6ed420300677a869083d70a4c44", size = 875852, upload-time = "2026-04-02T18:49:14.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, +] + [[package]] name = "yarl" version = "1.23.0"