Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class CreateUser(BaseModel):
name: str
email: str


def user_type_handler(request: RequestInfo):
match request.params["type"]:
case "A": return User
case "B": return User2


# Create your API client
class MyAPIClient(RequestsWebClient):
Expand Down Expand Up @@ -98,6 +104,10 @@ class MyAPIClient(RequestsWebClient):
)
def delete_user(self, user_id: str, request_headers: Dict[str, Any]):
...

@get("/users?type={type}", response_type_handler=user_type_handler)
def get_user_type(self, type: str) -> Union[User, User2]:
...


# Use the client
Expand All @@ -107,6 +117,12 @@ user = client.get_user(user_id=123)
user_body = CreateUser(name="john", email="123@gmail.com")
user = client.create_user(user_body)

# user_1 is User
user_1 = client.get_user_type("A")
# user_2 is User2
user_2 = client.get_user_type("B")


# will update the client headers.
client.delete_user("123", {"ba": "your"})

Expand Down
15 changes: 15 additions & 0 deletions pydantic_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ async def _request(self, request_info: RequestInfo) -> Any:
request_params = self.dump_request_params(request_info)
response_model = request_params.pop("response_model")
extract_path = request_params.pop("response_extract_path", None) # Get response extraction path parameter
response_type_handler = request_params.pop(
"response_type_handler", None)

request_params = self.before_request(request_params)

Expand All @@ -46,6 +48,12 @@ async def _request(self, request_info: RequestInfo) -> Any:
async with self.session.request(**request_params) as response:
response.raise_for_status()

if (
response_type_handler is not None and
callable(response_type_handler)
):
response_model = response_type_handler(request_info)

if response_model is str:
return await response.text()
elif response_model is bytes:
Expand Down Expand Up @@ -89,12 +97,19 @@ async def _request(self, request_info: RequestInfo) -> Any:
request_params = self.dump_request_params(request_info)
response_model = request_params.pop("response_model")
extract_path = request_params.pop("response_extract_path", None) # Get response extraction path parameter
response_type_handler = request_params.pop(
"response_type_handler", None)

request_params = self.before_request(request_params)

async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.request(**request_params)
response.raise_for_status()
if (
response_type_handler is not None and
callable(response_type_handler)
):
response_model = response_type_handler(request_info)

if response_model is str:
return response.text
Expand Down
87 changes: 55 additions & 32 deletions pydantic_client/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import re
import warnings
from functools import wraps
from typing import Callable, Optional
from typing import Callable, Optional, Any

from pydantic import BaseModel

from .tools.agno import register_agno_tool
from .schema import RequestInfo
from .tools.agno import register_agno_tool


def _extract_path_and_query(path: str):
"""
Expand All @@ -29,6 +30,7 @@ def _extract_path_and_query(path: str):
else:
return path, []


def _warn_if_path_params_missing(path: str, func: Callable):
"""Check if all {var} in path appear in func parameters during registration"""
sig = inspect.signature(func)
Expand All @@ -41,17 +43,20 @@ def _warn_if_path_params_missing(path: str, func: Callable):
f"Function '{func.__name__}' missing parameters {missing} required by path '{path}'"
)


def _process_request_params(
func: Callable, method: str, path: str, form_body: bool, response_extract_path: Optional[str] = None, *args, **kwargs
func: Callable, method: str, path: str, form_body: bool,
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None,
*args, **kwargs
) -> RequestInfo:
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
params = dict(bound_args.arguments)
params.pop("self", None)
request_headers = params.pop("request_headers", None)



return_type = sig.return_annotation
if isinstance(return_type, type) and issubclass(return_type, BaseModel):
response_model = return_type
Expand All @@ -62,15 +67,16 @@ def _process_request_params(
raw_path, query_tpls = _extract_path_and_query(path)

formatted_path = raw_path.format(**{
k: params[k] for k in re.findall(r'{([a-zA-Z_][a-zA-Z0-9_]*)}', raw_path)
k: params[k] for k in
re.findall(r'{([a-zA-Z_][a-zA-Z0-9_]*)}', raw_path)
})

query_params = {}
for k, v_name in query_tpls:
v = params.pop(v_name, None)
if v is not None:
query_params[k] = v

body_data = None
for param_name, param_value in params.items():
if isinstance(param_value, BaseModel):
Expand All @@ -94,16 +100,19 @@ def _process_request_params(
"headers": request_headers,
"response_model": response_model,
"function_name": func.__name__,
"response_extract_path": response_extract_path
"response_extract_path": response_extract_path,
"response_type_handler": response_type_handler
}
return RequestInfo.model_validate(info)


def rest(
method: str,
method: str,
form_body: bool = False,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
def decorator(path: str) -> Callable:
def wrapper(func: Callable) -> Callable:
Expand All @@ -114,14 +123,18 @@ def wrapper(func: Callable) -> Callable:
@wraps(func)
async def async_wrapped(self, *args, **kwargs):
request_params = _process_request_params(
func, method, path, form_body, response_extract_path, self, *args, **kwargs
func, method, path, form_body, response_extract_path,
response_type_handler, self,
*args, **kwargs
)
return await self._request(request_params)

@wraps(func)
def sync_wrapped(self, *args, **kwargs):
request_params = _process_request_params(
func, method, path, form_body, response_extract_path, self, *args, **kwargs
func, method, path, form_body, response_extract_path,
response_type_handler, self,
*args, **kwargs
)
return self._request(request_params)

Expand All @@ -142,27 +155,31 @@ def get(
path: str,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
return rest(
"GET",
agno_tool=agno_tool,
"GET",
agno_tool=agno_tool,
tool_description=tool_description,
response_extract_path=response_extract_path
response_extract_path=response_extract_path,
response_type_handler=response_type_handler
)(path)


def delete(
path: str,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
return rest(
"DELETE",
agno_tool=agno_tool,
"DELETE",
agno_tool=agno_tool,
tool_description=tool_description,
response_extract_path=response_extract_path
response_extract_path=response_extract_path,
response_type_handler=response_type_handler
)(path)


Expand All @@ -171,14 +188,16 @@ def post(
form_body: bool = False,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
return rest(
"POST",
"POST",
form_body=form_body,
agno_tool=agno_tool,
agno_tool=agno_tool,
tool_description=tool_description,
response_extract_path=response_extract_path
response_extract_path=response_extract_path,
response_type_handler=response_type_handler
)(path)


Expand All @@ -187,14 +206,16 @@ def put(
form_body: bool = False,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
return rest(
"PUT",
"PUT",
form_body=form_body,
agno_tool=agno_tool,
agno_tool=agno_tool,
tool_description=tool_description,
response_extract_path=response_extract_path
response_extract_path=response_extract_path,
response_type_handler=response_type_handler
)(path)


Expand All @@ -203,12 +224,14 @@ def patch(
form_body: bool = False,
agno_tool: bool = False,
tool_description: Optional[str] = None,
response_extract_path: Optional[str] = None
response_extract_path: Optional[str] = None,
response_type_handler: Optional[Callable[[RequestInfo], Any]] = None
) -> Callable:
return rest(
"PATCH",
"PATCH",
form_body=form_body,
agno_tool=agno_tool,
agno_tool=agno_tool,
tool_description=tool_description,
response_extract_path=response_extract_path
response_extract_path=response_extract_path,
response_type_handler=response_type_handler
)(path)
4 changes: 3 additions & 1 deletion pydantic_client/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Optional, Any
from typing import Dict, Optional, Any, Callable

from pydantic import BaseModel


Expand All @@ -12,3 +13,4 @@ class RequestInfo(BaseModel):
response_model: Optional[Any] = None
function_name: Optional[str] = None
response_extract_path: Optional[str] = None
response_type_handler: Optional[Callable] = None
25 changes: 17 additions & 8 deletions pydantic_client/sync_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Dict, Optional, TypeVar
import logging
from typing import Any, Dict, Optional, TypeVar

import requests
from pydantic import BaseModel
Expand All @@ -16,7 +16,7 @@ def __init__(
self,
base_url: str,
headers: Optional[Dict[str, Any]] = None,
timeout: Optional[int] =30,
timeout: Optional[int] = 30,
session: Optional[requests.Session] = None,
statsd_address: Optional[str] = None
):
Expand All @@ -33,24 +33,33 @@ def _request(self, request_info: RequestInfo) -> Any:
request_params = self.dump_request_params(request_info)
response_model = request_params.pop("response_model")
extract_path = request_params.pop("response_extract_path", None)
response_type_handler = request_params.pop(
"response_type_handler", None)

request_params = self.before_request(request_params)

response = self.session.request(**request_params, timeout=self.timeout)
response.raise_for_status()


if (
response_type_handler is not None and
callable(response_type_handler)
):
response_model = response_type_handler(request_info)

if response_model is str:
return response.text
elif response_model is bytes:
return response.content
elif extract_path:
# Process nested path extraction
return self._extract_nested_data(response.json(), extract_path, response_model)
elif not response_model or response_model is dict or getattr(response_model, '__module__', None) == 'inspect':
return self._extract_nested_data(response.json(), extract_path,
response_model)
elif not response_model or response_model is dict or getattr(
response_model, '__module__', None) == 'inspect':
return response.json()
elif hasattr(response_model, 'model_validate'):
return response_model.model_validate(response.json(), from_attributes=True)
return response_model.model_validate(response.json(),
from_attributes=True)
else:
return response.json()


Loading
Loading