Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion cq/_core/common/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
from typing import Protocol
from collections.abc import Callable
from typing import Any, Protocol, overload


class Decorator(Protocol):
def __call__[T](self, wrapped: T, /) -> T: ...


class Method[**P, T](Protocol):
@overload
def __call__(self, instance: Any, /, *args: P.args, **kwargs: P.kwargs) -> T: ...

@overload
def __call__(self, /, *args: Any, **kwargs: Any) -> T: ...

def __get__(
self,
instance: object,
owner: type | None = ...,
/,
) -> Callable[P, T]: ...
265 changes: 183 additions & 82 deletions cq/_core/dispatcher/pipe.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,102 @@
from collections import deque
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, Self, overload

from cq._core.common.typing import Decorator
from functools import partial
from inspect import iscoroutinefunction
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Protocol,
Self,
overload,
runtime_checkable,
)

from cq._core.common.typing import Decorator, Method
from cq._core.dispatcher.base import BaseDispatcher, Dispatcher
from cq._core.middleware import Middleware
from cq._core.middleware import Middleware, MiddlewareGroup

type PipeConverter[I, O] = Callable[[O], Awaitable[I]]
type ConvertAsync[**P, I, O] = Callable[Concatenate[O, P], Awaitable[I]]
type ConvertSync[**P, I, O] = Callable[Concatenate[O, P], I]
type Convert[**P, I, O] = ConvertAsync[P, I, O] | ConvertSync[P, I, O]

type ConvertMethodAsync[I, O] = Method[[O], Awaitable[I]]
type ConvertMethodSync[I, O] = Method[[O], I]
type ConvertMethod[I, O] = ConvertMethodAsync[I, O] | ConvertMethodSync[I, O]

class PipeConverterMethod[I, O](Protocol):
def __get__(
self,
instance: object,
owner: type | None = ...,
) -> PipeConverter[I, O]: ...

@runtime_checkable
class PipelineConverter[**P, I, O](Protocol):
__slots__ = ()

@abstractmethod
async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
raise NotImplementedError


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class PipeStep[I, O]:
converter: PipeConverter[I, O]
class PipelineStep[**P, I, O]:
converter: PipelineConverter[P, I, O]
dispatcher: Dispatcher[I, Any] | None = field(default=None)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class PipelineSteps[**P, I, O]:
default_dispatcher: Dispatcher[Any, Any]
__steps: list[PipelineStep[P, Any, Any]] = field(default_factory=list, init=False)

def add[T](
self,
converter: PipelineConverter[P, T, Any],
dispatcher: Dispatcher[T, Any] | None,
) -> Self:
self.__steps.append(PipelineStep(converter, dispatcher))
return self

async def execute(self, input_value: I, /, *args: P.args, **kwargs: P.kwargs) -> O:
dispatcher = self.default_dispatcher

for step in self.__steps:
output_value = await dispatcher.dispatch(input_value)
input_value = await step.converter.convert(output_value, *args, **kwargs)

if input_value is None:
return NotImplemented

dispatcher = step.dispatcher or self.default_dispatcher

return await dispatcher.dispatch(input_value)


class Pipe[I, O](BaseDispatcher[I, O]):
__slots__ = ("__dispatcher", "__steps")
__slots__ = ("__steps",)

__dispatcher: Dispatcher[Any, Any]
__steps: list[PipeStep[Any, Any]]
__steps: PipelineSteps[[], I, O]

def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None:
super().__init__()
self.__dispatcher = dispatcher
self.__steps = []
self.__steps = PipelineSteps(dispatcher)

if TYPE_CHECKING: # pragma: no cover

@overload
def step[T](
self,
wrapped: PipeConverter[T, Any],
wrapped: ConvertAsync[[], T, Any],
/,
*,
dispatcher: Dispatcher[T, Any] | None = ...,
) -> ConvertAsync[[], T, Any]: ...

@overload
def step[T](
self,
wrapped: ConvertSync[[], T, Any],
/,
*,
dispatcher: Dispatcher[T, Any] | None = ...,
) -> PipeConverter[T, Any]: ...
) -> ConvertSync[[], T, Any]: ...

@overload
def step(
Expand All @@ -57,14 +109,18 @@ def step(

def step[T](
self,
wrapped: PipeConverter[T, Any] | None = None,
wrapped: Convert[[], T, Any] | None = None,
/,
*,
dispatcher: Dispatcher[T, Any] | None = None,
) -> Any:
def decorator(wp: PipeConverter[T, Any]) -> PipeConverter[T, Any]:
step = PipeStep(wp, dispatcher)
self.__steps.append(step)
def decorator(wp: Convert[[], T, Any]) -> Convert[[], T, Any]:
converter = (
_AsyncPipelineConverter(wp)
if iscoroutinefunction(wp)
else _SyncPipelineConverter(wp)
)
self.__steps.add(converter, dispatcher)
return wp

return decorator(wrapped) if wrapped else decorator
Expand All @@ -75,47 +131,23 @@ def add_static_step[T](
*,
dispatcher: Dispatcher[T, Any] | None = None,
) -> Self:
@self.step(dispatcher=dispatcher)
async def converter(_: Any) -> T:
return input_value

converter = _StaticPipelineConverter(input_value)
self.__steps.add(converter, dispatcher)
return self

async def dispatch(self, input_value: I, /) -> O:
return await self._invoke_with_middlewares(self.__execute, input_value)

async def __execute(self, input_value: I) -> O:
dispatcher = self.__dispatcher

for step in self.__steps:
output_value = await dispatcher.dispatch(input_value)
input_value = await step.converter(output_value)

if input_value is None:
return NotImplemented

dispatcher = step.dispatcher or self.__dispatcher

return await dispatcher.dispatch(input_value)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ContextPipelineStep[I, O]:
converter: PipeConverterMethod[I, O]
dispatcher: Dispatcher[I, Any] | None = field(default=None)
return await self._invoke_with_middlewares(self.__steps.execute, input_value)


class ContextPipeline[I]:
__slots__ = ("__dispatcher", "__middlewares", "__steps")
__slots__ = ("__middleware_group", "__steps")

__dispatcher: Dispatcher[Any, Any]
__middlewares: deque[Middleware[Any, Any]]
__steps: list[ContextPipelineStep[Any, Any]]
__middleware_group: MiddlewareGroup[[I], Any]
__steps: PipelineSteps[[object, type | None], I, Any]

def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None:
self.__dispatcher = dispatcher
self.__middlewares = deque()
self.__steps = []
self.__middleware_group = MiddlewareGroup()
self.__steps = PipelineSteps(dispatcher)

if TYPE_CHECKING: # pragma: no cover

Expand Down Expand Up @@ -145,23 +177,32 @@ def __get__[O](

instance = owner()

pipeline = self.__new_pipeline(instance, owner)
return BoundContextPipeline(instance, pipeline)
dispatch_method = partial(self.__execute, context=instance, context_type=owner)
return BoundContextPipeline(dispatch_method)

def add_middlewares(self, *middlewares: Middleware[[I], Any]) -> Self:
self.__middlewares.extendleft(reversed(middlewares))
self.__middleware_group.add(*middlewares)
return self

if TYPE_CHECKING: # pragma: no cover

@overload
def step[T](
self,
wrapped: PipeConverterMethod[T, Any],
wrapped: ConvertMethodAsync[T, Any],
/,
*,
dispatcher: Dispatcher[T, Any] | None = ...,
) -> ConvertMethodAsync[T, Any]: ...

@overload
def step[T](
self,
wrapped: ConvertMethodSync[T, Any],
/,
*,
dispatcher: Dispatcher[T, Any] | None = ...,
) -> PipeConverterMethod[T, Any]: ...
) -> ConvertMethodSync[T, Any]: ...

@overload
def step(
Expand All @@ -174,38 +215,98 @@ def step(

def step[T](
self,
wrapped: PipeConverterMethod[T, Any] | None = None,
wrapped: ConvertMethod[T, Any] | None = None,
/,
*,
dispatcher: Dispatcher[T, Any] | None = None,
) -> Any:
def decorator(wp: PipeConverterMethod[T, Any]) -> PipeConverterMethod[T, Any]:
step = ContextPipelineStep(wp, dispatcher)
self.__steps.append(step)
def decorator(wp: ConvertMethod[T, Any]) -> ConvertMethod[T, Any]:
converter = (
_AsyncContextPipelineConverter(wp)
if iscoroutinefunction(wp)
else _SyncContextPipelineConverter(wp)
)
self.__steps.add(converter, dispatcher)
return wp

return decorator(wrapped) if wrapped else decorator

def __new_pipeline[T](
async def __execute[O](
self,
context: T,
context_type: type[T] | None,
) -> Pipe[I, Any]:
pipeline: Pipe[I, Any] = Pipe(self.__dispatcher)
pipeline.add_middlewares(*self.__middlewares)

for step in self.__steps:
converter = step.converter.__get__(context, context_type)
pipeline.step(converter, dispatcher=step.dispatcher)

return pipeline
input_value: I,
/,
*,
context: O,
context_type: type[O] | None,
) -> O:
await self.__middleware_group.invoke(
lambda i: self.__steps.execute(i, context, context_type),
input_value,
)
return context


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class BoundContextPipeline[I, O](Dispatcher[I, O]):
context: O
pipeline: Pipe[I, Any]
dispatch_method: Callable[[I], Awaitable[O]]

async def dispatch(self, input_value: I, /) -> O:
await self.pipeline.dispatch(input_value)
return self.context
return await self.dispatch_method(input_value)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _AsyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]):
converter: ConvertAsync[P, I, O]

async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
return await self.converter(output_value, *args, **kwargs)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _SyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]):
converter: ConvertSync[P, I, O]

async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
return self.converter(output_value, *args, **kwargs)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _StaticPipelineConverter[I](PipelineConverter[..., I, Any]):
input_value: I

async def convert(self, output_value: Any, /, *args: Any, **kwargs: Any) -> I:
return self.input_value


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _AsyncContextPipelineConverter[I, O](
PipelineConverter[[object, type | None], I, O],
):
converter: ConvertMethodAsync[I, O]

async def convert(
self,
output_value: O,
/,
context: object,
context_type: type | None,
) -> I:
method = self.converter.__get__(context, context_type)
return await method(output_value)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _SyncContextPipelineConverter[I, O](
PipelineConverter[[object, type | None], I, O],
):
converter: ConvertMethodSync[I, O]

async def convert(
self,
output_value: O,
/,
context: object,
context_type: type | None,
) -> I:
method = self.converter.__get__(context, context_type)
return method(output_value)
2 changes: 1 addition & 1 deletion cq/_core/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class _BoundMiddleware[**P, T]:
call_next: Callable[P, Awaitable[T]]
middleware: ClassicMiddleware[P, T]

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.middleware(self.call_next, *args, **kwargs)


Expand Down
Loading