Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
new_query_bus,
query_handler,
)
from ._core.middleware import Middleware, MiddlewareResult
from ._core.middleware import Middleware, MiddlewareResult, resolve_handler_source
from ._core.pipetools import ContextCommandPipeline
from ._core.related_events import RelatedEvents
from ._core.scope import CQScope
Expand Down Expand Up @@ -47,4 +47,5 @@
"new_event_bus",
"new_query_bus",
"query_handler",
"resolve_handler_source",
)
16 changes: 13 additions & 3 deletions cq/_core/dispatcher/bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self:
raise NotImplementedError

@abstractmethod
def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
def subscribe(
self,
input_type: type[I],
factory: HandlerFactory[[I], O],
fail_silently: bool = ...,
) -> Self:
raise NotImplementedError


Expand All @@ -50,8 +55,13 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
self.__listeners.extend(listeners)
return self

def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
self.__registry.subscribe(input_type, factory)
def subscribe(
self,
input_type: type[I],
factory: HandlerFactory[[I], O],
fail_silently: bool = False,
) -> Self:
self.__registry.subscribe(input_type, factory, fail_silently=fail_silently)
return self

def _handlers_from(self, input_type: type[I]) -> Iterator[HandleFunction[[I], O]]:
Expand Down
23 changes: 16 additions & 7 deletions cq/_core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Awaitable, Callable, Iterator
from dataclasses import dataclass, field
from functools import partial
from inspect import Parameter, isclass
from inspect import Parameter, isclass, unwrap
from inspect import signature as inspect_signature
from typing import TYPE_CHECKING, Any, Protocol, Self, overload, runtime_checkable

Expand All @@ -27,14 +27,23 @@ async def handle(self, /, *args: P.args, **kwargs: P.kwargs) -> T:

@dataclass(repr=False, eq=False, frozen=True, slots=True)
class HandleFunction[**P, T]:
handler_factory: HandlerFactory[P, T]
handler_type: HandlerType[P, T] | None = field(default=None)
fail_silently: bool = field(default=False)
factory: HandlerFactory[P, T]
source: HandlerType[P, T] | Any
fail_silently: bool

async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
handler = await self.handler_factory()
handler = await self.factory()
return await handler.handle(*args, **kwargs)

@classmethod
def create(
cls,
factory: HandlerFactory[P, T],
source: HandlerType[P, T] | None = None,
fail_silently: bool = False,
) -> Self:
return cls(factory, source or unwrap(factory), fail_silently)


@runtime_checkable
class HandlerRegistry[I, O](Protocol):
Expand Down Expand Up @@ -73,7 +82,7 @@ def subscribe(
handler_type: HandlerType[[I], O] | None = None,
fail_silently: bool = False,
) -> Self:
function = HandleFunction(handler_factory, handler_type, fail_silently)
function = HandleFunction.create(handler_factory, handler_type, fail_silently)

for key_type in _build_key_types(input_type):
self.__values[key_type].append(function)
Expand Down Expand Up @@ -101,7 +110,7 @@ def subscribe(
handler_type: HandlerType[[I], O] | None = None,
fail_silently: bool = False,
) -> Self:
function = HandleFunction(handler_factory, handler_type, fail_silently)
function = HandleFunction.create(handler_factory, handler_type, fail_silently)
entries = {key_type: function for key_type in _build_key_types(input_type)}

for key_type in entries:
Expand Down
16 changes: 15 additions & 1 deletion cq/_core/middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from inspect import isasyncgenfunction
from typing import Concatenate, Self, TypeGuard
from typing import Any, Concatenate, Self, TypeGuard

from cq._core.handler import HandleFunction, HandlerType
from cq.exceptions import MiddlewareError

type MiddlewareResult[T] = AsyncGenerator[None, T]
Expand Down Expand Up @@ -63,6 +64,19 @@ async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.middleware(self.call_next, *args, **kwargs)


def resolve_handler_source[**P, T](
call_next: Callable[P, Awaitable[T]]
| _BoundMiddleware[P, T]
| HandleFunction[P, T],
/,
) -> HandlerType[P, T] | Any:
while True:
try:
call_next = call_next.call_next # type: ignore[union-attr]
except AttributeError:
return call_next.source # type: ignore[union-attr]


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _GeneratorMiddleware[**P, T]:
middleware: GeneratorMiddleware[P, T]
Expand Down
33 changes: 32 additions & 1 deletion tests/core/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,42 @@

import pytest

from cq._core.middleware import MiddlewareGroup, MiddlewareResult
from cq._core.dispatcher.bus import SimpleBus
from cq._core.handler import HandlerDecorator, SingleHandlerRegistry
from cq._core.middleware import (
MiddlewareGroup,
MiddlewareResult,
resolve_handler_source,
)
from cq.exceptions import MiddlewareError
from tests.helpers.history import HistoryMiddleware


async def test_resolve_handler_source_with_success() -> None:
registry = SingleHandlerRegistry[Any, Any]()
handler = HandlerDecorator(registry)

@handler
class Handler:
async def handle(self, message: str) -> None: ...

expected: Any = None

async def middleware(
call_next: Callable[[Any], Awaitable[Any]],
message: Any,
) -> Any:
nonlocal expected
expected = resolve_handler_source(call_next)
return await call_next(message)

bus = SimpleBus(registry)
bus.add_middlewares(middleware)
await bus.dispatch("hello")

assert expected is Handler


class TestMiddlewareGroup:
@pytest.fixture(scope="function")
def group(self) -> MiddlewareGroup[..., Any]:
Expand Down
48 changes: 24 additions & 24 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.