diff --git a/src/openai/_exceptions.py b/src/openai/_exceptions.py index 09016dfedb..93c14e08a2 100644 --- a/src/openai/_exceptions.py +++ b/src/openai/_exceptions.py @@ -2,8 +2,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, cast -from typing_extensions import Literal +from typing import TYPE_CHECKING, Any, Callable, Optional, cast +from typing_extensions import Self, Literal, override import httpx @@ -66,6 +66,14 @@ def __init__(self, message: str, request: httpx.Request, *, body: object | None) self.param = None self.type = None + @classmethod + def _reconstruct_base(cls, message: str, request: httpx.Request, body: object | None) -> Self: + return cls(message, request, body=body) + + @override + def __reduce__(self) -> tuple[Callable[..., Self], tuple[Any, ...]]: + return (self.__class__._reconstruct_base, (self.message, self.request, self.body)) + class APIResponseValidationError(APIError): response: httpx.Response @@ -76,6 +84,16 @@ def __init__(self, response: httpx.Response, body: object | None, *, message: st self.response = response self.status_code = response.status_code + @classmethod + def _reconstruct(cls, response: httpx.Response, body: object | None, message: str | None) -> Self: + return cls(response, body, message=message) + + @override + def __reduce__( + self, + ) -> tuple[Callable[..., Self], tuple[Any, ...]]: + return (self.__class__._reconstruct, (self.response, self.body, self.message)) + class APIStatusError(APIError): """Raised when an API response has a status code of 4xx or 5xx.""" @@ -90,11 +108,29 @@ def __init__(self, message: str, *, response: httpx.Response, body: object | Non self.status_code = response.status_code self.request_id = response.headers.get("x-request-id") + @classmethod + def _reconstruct(cls, message: str, response: httpx.Response, body: object | None) -> Self: + return cls(message, response=response, body=body) + + @override + def __reduce__( + self, + ) -> tuple[Callable[..., Self], tuple[Any, ...]]: + return (self.__class__._reconstruct, (self.message, self.response, self.body)) + class APIConnectionError(APIError): def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: super().__init__(message, request, body=None) + @classmethod + def _reconstruct(cls, message: str, request: httpx.Request) -> Self: + return cls(message=message, request=request) + + @override + def __reduce__(self) -> tuple[Callable[..., Self], tuple[Any, ...]]: + return (self.__class__._reconstruct, (self.message, self.request)) + class APITimeoutError(APIConnectionError): def __init__(self, request: httpx.Request) -> None: @@ -149,6 +185,14 @@ def __init__(self, *, completion: ChatCompletion) -> None: super().__init__(msg) self.completion = completion + @classmethod + def _reconstruct(cls, completion: ChatCompletion) -> Self: + return cls(completion=completion) + + @override + def __reduce__(self) -> tuple[Callable[..., Self], tuple[Any, ...]]: + return (self.__class__._reconstruct, (self.completion,)) + class ContentFilterFinishReasonError(OpenAIError): def __init__(self) -> None: