diff --git a/README.md b/README.md index d24d9db..154d362 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,12 @@ class CreateUser(BaseModel): name: str email: str + +def user_type_handler(request: RequestInfo): + match request.params["type"]: + case "A": return User + case "B": return User2 + # Create your API client class MyAPIClient(RequestsWebClient): @@ -98,6 +104,10 @@ class MyAPIClient(RequestsWebClient): ) def delete_user(self, user_id: str, request_headers: Dict[str, Any]): ... + + @get("/users?type={type}", response_type_handler=user_type_handler) + def get_user_type(self, type: str) -> Union[User, User2]: + ... # Use the client @@ -107,6 +117,12 @@ user = client.get_user(user_id=123) user_body = CreateUser(name="john", email="123@gmail.com") user = client.create_user(user_body) +# user_1 is User +user_1 = client.get_user_type("A") +# user_2 is User2 +user_2 = client.get_user_type("B") + + # will update the client headers. client.delete_user("123", {"ba": "your"}) diff --git a/pydantic_client/async_client.py b/pydantic_client/async_client.py index c1a598b..d2ead49 100644 --- a/pydantic_client/async_client.py +++ b/pydantic_client/async_client.py @@ -37,6 +37,8 @@ async def _request(self, request_info: RequestInfo) -> Any: request_params = self.dump_request_params(request_info) response_model = request_params.pop("response_model") extract_path = request_params.pop("response_extract_path", None) # Get response extraction path parameter + response_type_handler = request_params.pop( + "response_type_handler", None) request_params = self.before_request(request_params) @@ -46,6 +48,12 @@ async def _request(self, request_info: RequestInfo) -> Any: async with self.session.request(**request_params) as response: response.raise_for_status() + if ( + response_type_handler is not None and + callable(response_type_handler) + ): + response_model = response_type_handler(request_info) + if response_model is str: return await response.text() elif response_model is bytes: @@ -89,12 +97,19 @@ async def _request(self, request_info: RequestInfo) -> Any: request_params = self.dump_request_params(request_info) response_model = request_params.pop("response_model") extract_path = request_params.pop("response_extract_path", None) # Get response extraction path parameter + response_type_handler = request_params.pop( + "response_type_handler", None) request_params = self.before_request(request_params) async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.request(**request_params) response.raise_for_status() + if ( + response_type_handler is not None and + callable(response_type_handler) + ): + response_model = response_type_handler(request_info) if response_model is str: return response.text diff --git a/pydantic_client/decorators.py b/pydantic_client/decorators.py index 19f49bc..dec753d 100644 --- a/pydantic_client/decorators.py +++ b/pydantic_client/decorators.py @@ -2,12 +2,13 @@ import re import warnings from functools import wraps -from typing import Callable, Optional +from typing import Callable, Optional, Any from pydantic import BaseModel -from .tools.agno import register_agno_tool from .schema import RequestInfo +from .tools.agno import register_agno_tool + def _extract_path_and_query(path: str): """ @@ -29,6 +30,7 @@ def _extract_path_and_query(path: str): else: return path, [] + def _warn_if_path_params_missing(path: str, func: Callable): """Check if all {var} in path appear in func parameters during registration""" sig = inspect.signature(func) @@ -41,8 +43,12 @@ def _warn_if_path_params_missing(path: str, func: Callable): f"Function '{func.__name__}' missing parameters {missing} required by path '{path}'" ) + def _process_request_params( - func: Callable, method: str, path: str, form_body: bool, response_extract_path: Optional[str] = None, *args, **kwargs + func: Callable, method: str, path: str, form_body: bool, + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None, + *args, **kwargs ) -> RequestInfo: sig = inspect.signature(func) bound_args = sig.bind(*args, **kwargs) @@ -50,8 +56,7 @@ def _process_request_params( params = dict(bound_args.arguments) params.pop("self", None) request_headers = params.pop("request_headers", None) - - + return_type = sig.return_annotation if isinstance(return_type, type) and issubclass(return_type, BaseModel): response_model = return_type @@ -62,7 +67,8 @@ def _process_request_params( raw_path, query_tpls = _extract_path_and_query(path) formatted_path = raw_path.format(**{ - k: params[k] for k in re.findall(r'{([a-zA-Z_][a-zA-Z0-9_]*)}', raw_path) + k: params[k] for k in + re.findall(r'{([a-zA-Z_][a-zA-Z0-9_]*)}', raw_path) }) query_params = {} @@ -70,7 +76,7 @@ def _process_request_params( v = params.pop(v_name, None) if v is not None: query_params[k] = v - + body_data = None for param_name, param_value in params.items(): if isinstance(param_value, BaseModel): @@ -94,16 +100,19 @@ def _process_request_params( "headers": request_headers, "response_model": response_model, "function_name": func.__name__, - "response_extract_path": response_extract_path + "response_extract_path": response_extract_path, + "response_type_handler": response_type_handler } return RequestInfo.model_validate(info) + def rest( - method: str, + method: str, form_body: bool = False, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: def decorator(path: str) -> Callable: def wrapper(func: Callable) -> Callable: @@ -114,14 +123,18 @@ def wrapper(func: Callable) -> Callable: @wraps(func) async def async_wrapped(self, *args, **kwargs): request_params = _process_request_params( - func, method, path, form_body, response_extract_path, self, *args, **kwargs + func, method, path, form_body, response_extract_path, + response_type_handler, self, + *args, **kwargs ) return await self._request(request_params) @wraps(func) def sync_wrapped(self, *args, **kwargs): request_params = _process_request_params( - func, method, path, form_body, response_extract_path, self, *args, **kwargs + func, method, path, form_body, response_extract_path, + response_type_handler, self, + *args, **kwargs ) return self._request(request_params) @@ -142,13 +155,15 @@ def get( path: str, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: return rest( - "GET", - agno_tool=agno_tool, + "GET", + agno_tool=agno_tool, tool_description=tool_description, - response_extract_path=response_extract_path + response_extract_path=response_extract_path, + response_type_handler=response_type_handler )(path) @@ -156,13 +171,15 @@ def delete( path: str, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: return rest( - "DELETE", - agno_tool=agno_tool, + "DELETE", + agno_tool=agno_tool, tool_description=tool_description, - response_extract_path=response_extract_path + response_extract_path=response_extract_path, + response_type_handler=response_type_handler )(path) @@ -171,14 +188,16 @@ def post( form_body: bool = False, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: return rest( - "POST", + "POST", form_body=form_body, - agno_tool=agno_tool, + agno_tool=agno_tool, tool_description=tool_description, - response_extract_path=response_extract_path + response_extract_path=response_extract_path, + response_type_handler=response_type_handler )(path) @@ -187,14 +206,16 @@ def put( form_body: bool = False, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: return rest( - "PUT", + "PUT", form_body=form_body, - agno_tool=agno_tool, + agno_tool=agno_tool, tool_description=tool_description, - response_extract_path=response_extract_path + response_extract_path=response_extract_path, + response_type_handler=response_type_handler )(path) @@ -203,12 +224,14 @@ def patch( form_body: bool = False, agno_tool: bool = False, tool_description: Optional[str] = None, - response_extract_path: Optional[str] = None + response_extract_path: Optional[str] = None, + response_type_handler: Optional[Callable[[RequestInfo], Any]] = None ) -> Callable: return rest( - "PATCH", + "PATCH", form_body=form_body, - agno_tool=agno_tool, + agno_tool=agno_tool, tool_description=tool_description, - response_extract_path=response_extract_path + response_extract_path=response_extract_path, + response_type_handler=response_type_handler )(path) diff --git a/pydantic_client/schema.py b/pydantic_client/schema.py index 11bf5ca..55b844b 100644 --- a/pydantic_client/schema.py +++ b/pydantic_client/schema.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Callable + from pydantic import BaseModel @@ -12,3 +13,4 @@ class RequestInfo(BaseModel): response_model: Optional[Any] = None function_name: Optional[str] = None response_extract_path: Optional[str] = None + response_type_handler: Optional[Callable] = None diff --git a/pydantic_client/sync_client.py b/pydantic_client/sync_client.py index c2ef109..d2d0caa 100644 --- a/pydantic_client/sync_client.py +++ b/pydantic_client/sync_client.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, Optional, TypeVar import logging +from typing import Any, Dict, Optional, TypeVar import requests from pydantic import BaseModel @@ -16,7 +16,7 @@ def __init__( self, base_url: str, headers: Optional[Dict[str, Any]] = None, - timeout: Optional[int] =30, + timeout: Optional[int] = 30, session: Optional[requests.Session] = None, statsd_address: Optional[str] = None ): @@ -33,24 +33,33 @@ def _request(self, request_info: RequestInfo) -> Any: request_params = self.dump_request_params(request_info) response_model = request_params.pop("response_model") extract_path = request_params.pop("response_extract_path", None) + response_type_handler = request_params.pop( + "response_type_handler", None) request_params = self.before_request(request_params) response = self.session.request(**request_params, timeout=self.timeout) response.raise_for_status() - + + if ( + response_type_handler is not None and + callable(response_type_handler) + ): + response_model = response_type_handler(request_info) + if response_model is str: return response.text elif response_model is bytes: return response.content elif extract_path: # Process nested path extraction - return self._extract_nested_data(response.json(), extract_path, response_model) - elif not response_model or response_model is dict or getattr(response_model, '__module__', None) == 'inspect': + return self._extract_nested_data(response.json(), extract_path, + response_model) + elif not response_model or response_model is dict or getattr( + response_model, '__module__', None) == 'inspect': return response.json() elif hasattr(response_model, 'model_validate'): - return response_model.model_validate(response.json(), from_attributes=True) + return response_model.model_validate(response.json(), + from_attributes=True) else: return response.json() - - diff --git a/tests/test_union_response_type.py b/tests/test_union_response_type.py new file mode 100644 index 0000000..8530457 --- /dev/null +++ b/tests/test_union_response_type.py @@ -0,0 +1,57 @@ +from typing import Optional, Union + +import requests_mock +from pydantic import BaseModel + +from pydantic_client import RequestsWebClient, get +from pydantic_client.schema import RequestInfo + + +class User(BaseModel): + id: str + name: Optional[str] = None + + +class User2(BaseModel): + id: str + age: Optional[int] = None + + +def get_user_type_handler(request_info: RequestInfo): + type = request_info.params["type"] + print(type) + if type == "A": + return User + elif type == "B": + return User2 + else: + return None + + +class TestClient(RequestsWebClient): + + @get("/users?type={type}", response_type_handler=get_user_type_handler) + def get_user(self, type: str) -> Union[User, User2]: + ... + + +def test_get_union(): + with requests_mock.Mocker() as m: + m.get( + 'http://example.com/users?type=A', + json={"id": "123"} + ) + + client = TestClient(base_url="http://example.com") + response = client.get_user("A") + + assert isinstance(response, User) + assert response.id == "123" + + m.get( + 'http://example.com/users?type=B', + json={"id": "123"} + ) + response = client.get_user("B") + assert isinstance(response, User2) + assert response.id == "123"