Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 69 additions & 32 deletions cq/_core/middleware.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Self
from inspect import isasyncgenfunction
from typing import Concatenate, Self, TypeGuard

from cq.exceptions import MiddlewareError

type MiddlewareResult[T] = AsyncGenerator[None, T]
type Middleware[**P, T] = Callable[P, MiddlewareResult[T]]
type GeneratorMiddleware[**P, T] = Callable[P, MiddlewareResult[T]]
type ClassicMiddleware[**P, T] = Callable[
Concatenate[Callable[P, Awaitable[T]], P], Awaitable[T]
]

type Middleware[**P, T] = ClassicMiddleware[P, T] | GeneratorMiddleware[P, T]


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class MiddlewareGroup[**P, T]:
__middlewares: list[Middleware[P, T]] = field(default_factory=list, init=False)
__middlewares: list[ClassicMiddleware[P, T]] = field(
default_factory=list,
init=False,
)

def add(self, *middlewares: Middleware[P, T]) -> Self:
self.__middlewares.extend(reversed(middlewares))
classic_middlewares = reversed(
tuple(self.__normalize(middleware) for middleware in middlewares)
)
self.__middlewares.extend(classic_middlewares)
return self

async def invoke(
Expand All @@ -30,40 +42,65 @@ def __apply_stack(
handler: Callable[P, Awaitable[T]],
) -> Callable[P, Awaitable[T]]:
for middleware in self.__middlewares:
handler = self.__apply_middleware(handler, middleware)
handler = _BoundMiddleware(handler, middleware)

return handler

@classmethod
def __apply_middleware(
cls,
handler: Callable[P, Awaitable[T]],
middleware: Middleware[P, T],
) -> Callable[P, Awaitable[T]]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
generator: MiddlewareResult[T] = middleware(*args, **kwargs)
value: T = NotImplemented
@staticmethod
def __normalize(middleware: Middleware[P, T]) -> ClassicMiddleware[P, T]:
if _is_gen_middleware(middleware):
return _GeneratorMiddleware(middleware)

return middleware # type: ignore[return-value]


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _BoundMiddleware[**P, T]:
call_next: Callable[P, Awaitable[T]]
middleware: ClassicMiddleware[P, T]

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.middleware(self.call_next, *args, **kwargs)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _GeneratorMiddleware[**P, T]:
middleware: GeneratorMiddleware[P, T]

async def __call__(
self,
call_next: Callable[P, Awaitable[T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
generator: MiddlewareResult[T] = self.middleware(*args, **kwargs)
value: T = NotImplemented

try:
await anext(generator)

try:
await anext(generator)
while True:
try:
value = await call_next(*args, **kwargs)
except BaseException as exc:
await generator.athrow(exc)
else:
await generator.asend(value)
raise MiddlewareError(
f"Too many `yield` keywords in `{self.middleware}`."
)

while True:
try:
value = await handler(*args, **kwargs)
except BaseException as exc:
await generator.athrow(exc)
else:
await generator.asend(value)
raise MiddlewareError(
f"Too many `yield` keywords in `{middleware}`."
)
except StopAsyncIteration:
...

except StopAsyncIteration:
...
finally:
await generator.aclose()

finally:
await generator.aclose()
return value

return value

return wrapper
def _is_gen_middleware[**P, T](
middleware: Middleware[P, T],
) -> TypeGuard[GeneratorMiddleware[P, T]]:
return any(map(isasyncgenfunction, (middleware, middleware.__call__))) # type: ignore[operator]
32 changes: 32 additions & 0 deletions docs/guides/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ For commands and queries, middlewares run once around the single handler. For ev
!!! note
The generator was chosen to keep both the input message and the return value read-only.

### Classic middlewares

As an alternative, classic middlewares receive `call_next` as their first argument, followed by the handler's arguments. This pattern allows you to read and modify the return value:
```python
from collections.abc import Awaitable, Callable
from typing import Any
import time

async def timing_middleware(
call_next: Callable[[Any], Awaitable[Any]],
message: Any,
) -> Any:
start = time.time()
result = await call_next(message)
print(f"Execution time: {time.time() - start}s")
return result
```

## Class-based listeners and middlewares

For more flexibility, listeners and middlewares can be defined as classes with a `__call__` method. This allows you to inject dependencies and configure their behavior.
Expand All @@ -68,4 +86,18 @@ class TimingMiddleware:
start = time.time()
yield
self.metrics.record(time.time() - start)

@dataclass
class ClassicTimingMiddleware:
metrics: MetricsService

async def __call__(
self,
call_next: Callable[[Any], Awaitable[Any]],
message: Any,
) -> Any:
start = time.time()
result = await call_next(message)
self.metrics.record(time.time() - start)
return result
```
29 changes: 28 additions & 1 deletion tests/core/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from collections.abc import Callable
from typing import Any, Awaitable

import pytest

Expand Down Expand Up @@ -95,6 +96,32 @@ async def handler() -> str:
records = history.records
assert len(records) == 2

async def test_invoke_with_classic_middleware(
self,
group: MiddlewareGroup[..., Any],
) -> None:
before = inner = after = 0

async def handler() -> None:
nonlocal inner
inner += 1

async def classic_middleware(
call_next: Callable[..., Awaitable[Any]],
*args: Any,
**kwargs: Any,
) -> Any:
nonlocal before, after
before += 1
result = await call_next(*args, **kwargs)
after += 1
return result

group.add(classic_middleware)
await group.invoke(handler)

assert before == inner == after == 1


async def _exec_2_times_middleware(*args: Any, **kwargs: Any) -> MiddlewareResult[Any]:
try:
Expand Down
Loading