diff --git a/cq/_core/common/typing.py b/cq/_core/common/typing.py index 0fe8b23..1deacc0 100644 --- a/cq/_core/common/typing.py +++ b/cq/_core/common/typing.py @@ -1,5 +1,21 @@ -from typing import Protocol +from collections.abc import Callable +from typing import Any, Protocol, overload class Decorator(Protocol): def __call__[T](self, wrapped: T, /) -> T: ... + + +class Method[**P, T](Protocol): + @overload + def __call__(self, instance: Any, /, *args: P.args, **kwargs: P.kwargs) -> T: ... + + @overload + def __call__(self, /, *args: Any, **kwargs: Any) -> T: ... + + def __get__( + self, + instance: object, + owner: type | None = ..., + /, + ) -> Callable[P, T]: ... diff --git a/cq/_core/dispatcher/pipe.py b/cq/_core/dispatcher/pipe.py index ce64b55..f634877 100644 --- a/cq/_core/dispatcher/pipe.py +++ b/cq/_core/dispatcher/pipe.py @@ -1,50 +1,102 @@ -from collections import deque +from abc import abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol, Self, overload - -from cq._core.common.typing import Decorator +from functools import partial +from inspect import iscoroutinefunction +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Protocol, + Self, + overload, + runtime_checkable, +) + +from cq._core.common.typing import Decorator, Method from cq._core.dispatcher.base import BaseDispatcher, Dispatcher -from cq._core.middleware import Middleware +from cq._core.middleware import Middleware, MiddlewareGroup -type PipeConverter[I, O] = Callable[[O], Awaitable[I]] +type ConvertAsync[**P, I, O] = Callable[Concatenate[O, P], Awaitable[I]] +type ConvertSync[**P, I, O] = Callable[Concatenate[O, P], I] +type Convert[**P, I, O] = ConvertAsync[P, I, O] | ConvertSync[P, I, O] +type ConvertMethodAsync[I, O] = Method[[O], Awaitable[I]] +type ConvertMethodSync[I, O] = Method[[O], I] +type ConvertMethod[I, O] = ConvertMethodAsync[I, O] | ConvertMethodSync[I, O] -class PipeConverterMethod[I, O](Protocol): - def __get__( - self, - instance: object, - owner: type | None = ..., - ) -> PipeConverter[I, O]: ... + +@runtime_checkable +class PipelineConverter[**P, I, O](Protocol): + __slots__ = () + + @abstractmethod + async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I: + raise NotImplementedError @dataclass(repr=False, eq=False, frozen=True, slots=True) -class PipeStep[I, O]: - converter: PipeConverter[I, O] +class PipelineStep[**P, I, O]: + converter: PipelineConverter[P, I, O] dispatcher: Dispatcher[I, Any] | None = field(default=None) +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class PipelineSteps[**P, I, O]: + default_dispatcher: Dispatcher[Any, Any] + __steps: list[PipelineStep[P, Any, Any]] = field(default_factory=list, init=False) + + def add[T]( + self, + converter: PipelineConverter[P, T, Any], + dispatcher: Dispatcher[T, Any] | None, + ) -> Self: + self.__steps.append(PipelineStep(converter, dispatcher)) + return self + + async def execute(self, input_value: I, /, *args: P.args, **kwargs: P.kwargs) -> O: + dispatcher = self.default_dispatcher + + for step in self.__steps: + output_value = await dispatcher.dispatch(input_value) + input_value = await step.converter.convert(output_value, *args, **kwargs) + + if input_value is None: + return NotImplemented + + dispatcher = step.dispatcher or self.default_dispatcher + + return await dispatcher.dispatch(input_value) + + class Pipe[I, O](BaseDispatcher[I, O]): - __slots__ = ("__dispatcher", "__steps") + __slots__ = ("__steps",) - __dispatcher: Dispatcher[Any, Any] - __steps: list[PipeStep[Any, Any]] + __steps: PipelineSteps[[], I, O] def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None: super().__init__() - self.__dispatcher = dispatcher - self.__steps = [] + self.__steps = PipelineSteps(dispatcher) if TYPE_CHECKING: # pragma: no cover @overload def step[T]( self, - wrapped: PipeConverter[T, Any], + wrapped: ConvertAsync[[], T, Any], + /, + *, + dispatcher: Dispatcher[T, Any] | None = ..., + ) -> ConvertAsync[[], T, Any]: ... + + @overload + def step[T]( + self, + wrapped: ConvertSync[[], T, Any], /, *, dispatcher: Dispatcher[T, Any] | None = ..., - ) -> PipeConverter[T, Any]: ... + ) -> ConvertSync[[], T, Any]: ... @overload def step( @@ -57,14 +109,18 @@ def step( def step[T]( self, - wrapped: PipeConverter[T, Any] | None = None, + wrapped: Convert[[], T, Any] | None = None, /, *, dispatcher: Dispatcher[T, Any] | None = None, ) -> Any: - def decorator(wp: PipeConverter[T, Any]) -> PipeConverter[T, Any]: - step = PipeStep(wp, dispatcher) - self.__steps.append(step) + def decorator(wp: Convert[[], T, Any]) -> Convert[[], T, Any]: + converter = ( + _AsyncPipelineConverter(wp) + if iscoroutinefunction(wp) + else _SyncPipelineConverter(wp) + ) + self.__steps.add(converter, dispatcher) return wp return decorator(wrapped) if wrapped else decorator @@ -75,47 +131,23 @@ def add_static_step[T]( *, dispatcher: Dispatcher[T, Any] | None = None, ) -> Self: - @self.step(dispatcher=dispatcher) - async def converter(_: Any) -> T: - return input_value - + converter = _StaticPipelineConverter(input_value) + self.__steps.add(converter, dispatcher) return self async def dispatch(self, input_value: I, /) -> O: - return await self._invoke_with_middlewares(self.__execute, input_value) - - async def __execute(self, input_value: I) -> O: - dispatcher = self.__dispatcher - - for step in self.__steps: - output_value = await dispatcher.dispatch(input_value) - input_value = await step.converter(output_value) - - if input_value is None: - return NotImplemented - - dispatcher = step.dispatcher or self.__dispatcher - - return await dispatcher.dispatch(input_value) - - -@dataclass(repr=False, eq=False, frozen=True, slots=True) -class ContextPipelineStep[I, O]: - converter: PipeConverterMethod[I, O] - dispatcher: Dispatcher[I, Any] | None = field(default=None) + return await self._invoke_with_middlewares(self.__steps.execute, input_value) class ContextPipeline[I]: - __slots__ = ("__dispatcher", "__middlewares", "__steps") + __slots__ = ("__middleware_group", "__steps") - __dispatcher: Dispatcher[Any, Any] - __middlewares: deque[Middleware[Any, Any]] - __steps: list[ContextPipelineStep[Any, Any]] + __middleware_group: MiddlewareGroup[[I], Any] + __steps: PipelineSteps[[object, type | None], I, Any] def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None: - self.__dispatcher = dispatcher - self.__middlewares = deque() - self.__steps = [] + self.__middleware_group = MiddlewareGroup() + self.__steps = PipelineSteps(dispatcher) if TYPE_CHECKING: # pragma: no cover @@ -145,11 +177,11 @@ def __get__[O]( instance = owner() - pipeline = self.__new_pipeline(instance, owner) - return BoundContextPipeline(instance, pipeline) + dispatch_method = partial(self.__execute, context=instance, context_type=owner) + return BoundContextPipeline(dispatch_method) def add_middlewares(self, *middlewares: Middleware[[I], Any]) -> Self: - self.__middlewares.extendleft(reversed(middlewares)) + self.__middleware_group.add(*middlewares) return self if TYPE_CHECKING: # pragma: no cover @@ -157,11 +189,20 @@ def add_middlewares(self, *middlewares: Middleware[[I], Any]) -> Self: @overload def step[T]( self, - wrapped: PipeConverterMethod[T, Any], + wrapped: ConvertMethodAsync[T, Any], + /, + *, + dispatcher: Dispatcher[T, Any] | None = ..., + ) -> ConvertMethodAsync[T, Any]: ... + + @overload + def step[T]( + self, + wrapped: ConvertMethodSync[T, Any], /, *, dispatcher: Dispatcher[T, Any] | None = ..., - ) -> PipeConverterMethod[T, Any]: ... + ) -> ConvertMethodSync[T, Any]: ... @overload def step( @@ -174,38 +215,98 @@ def step( def step[T]( self, - wrapped: PipeConverterMethod[T, Any] | None = None, + wrapped: ConvertMethod[T, Any] | None = None, /, *, dispatcher: Dispatcher[T, Any] | None = None, ) -> Any: - def decorator(wp: PipeConverterMethod[T, Any]) -> PipeConverterMethod[T, Any]: - step = ContextPipelineStep(wp, dispatcher) - self.__steps.append(step) + def decorator(wp: ConvertMethod[T, Any]) -> ConvertMethod[T, Any]: + converter = ( + _AsyncContextPipelineConverter(wp) + if iscoroutinefunction(wp) + else _SyncContextPipelineConverter(wp) + ) + self.__steps.add(converter, dispatcher) return wp return decorator(wrapped) if wrapped else decorator - def __new_pipeline[T]( + async def __execute[O]( self, - context: T, - context_type: type[T] | None, - ) -> Pipe[I, Any]: - pipeline: Pipe[I, Any] = Pipe(self.__dispatcher) - pipeline.add_middlewares(*self.__middlewares) - - for step in self.__steps: - converter = step.converter.__get__(context, context_type) - pipeline.step(converter, dispatcher=step.dispatcher) - - return pipeline + input_value: I, + /, + *, + context: O, + context_type: type[O] | None, + ) -> O: + await self.__middleware_group.invoke( + lambda i: self.__steps.execute(i, context, context_type), + input_value, + ) + return context @dataclass(repr=False, eq=False, frozen=True, slots=True) class BoundContextPipeline[I, O](Dispatcher[I, O]): - context: O - pipeline: Pipe[I, Any] + dispatch_method: Callable[[I], Awaitable[O]] async def dispatch(self, input_value: I, /) -> O: - await self.pipeline.dispatch(input_value) - return self.context + return await self.dispatch_method(input_value) + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class _AsyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]): + converter: ConvertAsync[P, I, O] + + async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I: + return await self.converter(output_value, *args, **kwargs) + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class _SyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]): + converter: ConvertSync[P, I, O] + + async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I: + return self.converter(output_value, *args, **kwargs) + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class _StaticPipelineConverter[I](PipelineConverter[..., I, Any]): + input_value: I + + async def convert(self, output_value: Any, /, *args: Any, **kwargs: Any) -> I: + return self.input_value + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class _AsyncContextPipelineConverter[I, O]( + PipelineConverter[[object, type | None], I, O], +): + converter: ConvertMethodAsync[I, O] + + async def convert( + self, + output_value: O, + /, + context: object, + context_type: type | None, + ) -> I: + method = self.converter.__get__(context, context_type) + return await method(output_value) + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class _SyncContextPipelineConverter[I, O]( + PipelineConverter[[object, type | None], I, O], +): + converter: ConvertMethodSync[I, O] + + async def convert( + self, + output_value: O, + /, + context: object, + context_type: type | None, + ) -> I: + method = self.converter.__get__(context, context_type) + return method(output_value) diff --git a/cq/_core/middleware.py b/cq/_core/middleware.py index eb37dd1..10f0196 100644 --- a/cq/_core/middleware.py +++ b/cq/_core/middleware.py @@ -59,7 +59,7 @@ class _BoundMiddleware[**P, T]: call_next: Callable[P, Awaitable[T]] middleware: ClassicMiddleware[P, T] - async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T: return await self.middleware(self.call_next, *args, **kwargs) diff --git a/cq/_core/pipetools.py b/cq/_core/pipetools.py index 216ee85..1fce29b 100644 --- a/cq/_core/pipetools.py +++ b/cq/_core/pipetools.py @@ -5,7 +5,12 @@ from cq import Dispatcher from cq._core.common.typing import Decorator from cq._core.dispatcher.lazy import LazyDispatcher -from cq._core.dispatcher.pipe import ContextPipeline, PipeConverterMethod +from cq._core.dispatcher.pipe import ( + ContextPipeline, + ConvertMethod, + ConvertMethodAsync, + ConvertMethodSync, +) from cq._core.message import AnyCommandBus, Command, Query, QueryBus from cq._core.scope import CQScope from cq.middlewares.scope import InjectionScopeMiddleware @@ -48,16 +53,23 @@ def __init__( @overload def query_step[T: Query]( self, - wrapped: PipeConverterMethod[T, Any], + wrapped: ConvertMethodAsync[T, Any], /, - ) -> PipeConverterMethod[T, Any]: ... + ) -> ConvertMethodAsync[T, Any]: ... + + @overload + def query_step[T: Query]( + self, + wrapped: ConvertMethodSync[T, Any], + /, + ) -> ConvertMethodSync[T, Any]: ... @overload def query_step(self, wrapped: None = ..., /) -> Decorator: ... - def query_step[T: Query]( + def query_step[T: Query]( # type: ignore[misc] self, - wrapped: PipeConverterMethod[T, Any] | None = None, + wrapped: ConvertMethod[T, Any] | None = None, /, ) -> Any: return self.step(wrapped, dispatcher=self.__query_dispatcher) diff --git a/docs/guides/pipeline.md b/docs/guides/pipeline.md index dc0519e..e1abea6 100644 --- a/docs/guides/pipeline.md +++ b/docs/guides/pipeline.md @@ -17,16 +17,16 @@ class PaymentContext: pipeline: ContextCommandPipeline[ValidateCartCommand] = ContextCommandPipeline() @pipeline.step - async def _(self, result: CartValidatedResult) -> CreateTransactionCommand: + def _(self, result: CartValidatedResult) -> CreateTransactionCommand: return CreateTransactionCommand(cart_id=result.cart_id, amount=result.total) @pipeline.step - async def _(self, result: TransactionCreatedResult) -> NotifyMerchantCommand: + def _(self, result: TransactionCreatedResult) -> NotifyMerchantCommand: self.transaction_id = result.transaction_id return NotifyMerchantCommand(transaction_id=self.transaction_id) @pipeline.step - async def _(self, result: MerchantNotifiedResult): + def _(self, result: MerchantNotifiedResult): ... ``` diff --git a/tests/core/dispatcher/test_pipe.py b/tests/core/dispatcher/test_pipe.py index 718c757..d3f11aa 100644 --- a/tests/core/dispatcher/test_pipe.py +++ b/tests/core/dispatcher/test_pipe.py @@ -39,7 +39,7 @@ async def async_factory(cls) -> Self: pipe: Pipe[str, str | tuple[str, ...]] = Pipe(bus) @pipe.step - async def step_converter_1(length: int) -> int: + def step_converter_1(length: int) -> int: return length assert await pipe.dispatch("hello") == "*****" diff --git a/tests/test_context_command_pipeline.py b/tests/test_context_command_pipeline.py index 2cf59ba..c7d105a 100644 --- a/tests/test_context_command_pipeline.py +++ b/tests/test_context_command_pipeline.py @@ -42,12 +42,12 @@ class Context: pipeline: ContextCommandPipeline[Command1] = ContextCommandPipeline() @pipeline.step - async def _(self, foo: Foo) -> Command2: + def _(self, foo: Foo) -> Command2: self.foo = foo return Command2() @pipeline.query_step - async def _(self, bar: Bar) -> Query: + def _(self, bar: Bar) -> Query: self.bar = bar return Query()