Skip to content

Commit 51fae3b

Browse files
authored
feat: Add classic middleware
1 parent a8a073b commit 51fae3b

4 files changed

Lines changed: 400 additions & 287 deletions

File tree

cq/_core/middleware.py

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
from collections.abc import AsyncGenerator, Awaitable, Callable
22
from dataclasses import dataclass, field
3-
from typing import Self
3+
from inspect import isasyncgenfunction
4+
from typing import Concatenate, Self, TypeGuard
45

56
from cq.exceptions import MiddlewareError
67

78
type MiddlewareResult[T] = AsyncGenerator[None, T]
8-
type Middleware[**P, T] = Callable[P, MiddlewareResult[T]]
9+
type GeneratorMiddleware[**P, T] = Callable[P, MiddlewareResult[T]]
10+
type ClassicMiddleware[**P, T] = Callable[
11+
Concatenate[Callable[P, Awaitable[T]], P], Awaitable[T]
12+
]
13+
14+
type Middleware[**P, T] = ClassicMiddleware[P, T] | GeneratorMiddleware[P, T]
915

1016

1117
@dataclass(repr=False, eq=False, frozen=True, slots=True)
1218
class MiddlewareGroup[**P, T]:
13-
__middlewares: list[Middleware[P, T]] = field(default_factory=list, init=False)
19+
__middlewares: list[ClassicMiddleware[P, T]] = field(
20+
default_factory=list,
21+
init=False,
22+
)
1423

1524
def add(self, *middlewares: Middleware[P, T]) -> Self:
16-
self.__middlewares.extend(reversed(middlewares))
25+
classic_middlewares = reversed(
26+
tuple(self.__normalize(middleware) for middleware in middlewares)
27+
)
28+
self.__middlewares.extend(classic_middlewares)
1729
return self
1830

1931
async def invoke(
@@ -30,40 +42,65 @@ def __apply_stack(
3042
handler: Callable[P, Awaitable[T]],
3143
) -> Callable[P, Awaitable[T]]:
3244
for middleware in self.__middlewares:
33-
handler = self.__apply_middleware(handler, middleware)
45+
handler = _BoundMiddleware(handler, middleware)
3446

3547
return handler
3648

37-
@classmethod
38-
def __apply_middleware(
39-
cls,
40-
handler: Callable[P, Awaitable[T]],
41-
middleware: Middleware[P, T],
42-
) -> Callable[P, Awaitable[T]]:
43-
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
44-
generator: MiddlewareResult[T] = middleware(*args, **kwargs)
45-
value: T = NotImplemented
49+
@staticmethod
50+
def __normalize(middleware: Middleware[P, T]) -> ClassicMiddleware[P, T]:
51+
if _is_gen_middleware(middleware):
52+
return _GeneratorMiddleware(middleware)
53+
54+
return middleware # type: ignore[return-value]
55+
56+
57+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
58+
class _BoundMiddleware[**P, T]:
59+
call_next: Callable[P, Awaitable[T]]
60+
middleware: ClassicMiddleware[P, T]
61+
62+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
63+
return await self.middleware(self.call_next, *args, **kwargs)
64+
65+
66+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
67+
class _GeneratorMiddleware[**P, T]:
68+
middleware: GeneratorMiddleware[P, T]
69+
70+
async def __call__(
71+
self,
72+
call_next: Callable[P, Awaitable[T]],
73+
/,
74+
*args: P.args,
75+
**kwargs: P.kwargs,
76+
) -> T:
77+
generator: MiddlewareResult[T] = self.middleware(*args, **kwargs)
78+
value: T = NotImplemented
79+
80+
try:
81+
await anext(generator)
4682

47-
try:
48-
await anext(generator)
83+
while True:
84+
try:
85+
value = await call_next(*args, **kwargs)
86+
except BaseException as exc:
87+
await generator.athrow(exc)
88+
else:
89+
await generator.asend(value)
90+
raise MiddlewareError(
91+
f"Too many `yield` keywords in `{self.middleware}`."
92+
)
4993

50-
while True:
51-
try:
52-
value = await handler(*args, **kwargs)
53-
except BaseException as exc:
54-
await generator.athrow(exc)
55-
else:
56-
await generator.asend(value)
57-
raise MiddlewareError(
58-
f"Too many `yield` keywords in `{middleware}`."
59-
)
94+
except StopAsyncIteration:
95+
...
6096

61-
except StopAsyncIteration:
62-
...
97+
finally:
98+
await generator.aclose()
6399

64-
finally:
65-
await generator.aclose()
100+
return value
66101

67-
return value
68102

69-
return wrapper
103+
def _is_gen_middleware[**P, T](
104+
middleware: Middleware[P, T],
105+
) -> TypeGuard[GeneratorMiddleware[P, T]]:
106+
return any(map(isasyncgenfunction, (middleware, middleware.__call__))) # type: ignore[operator]

docs/guides/configuring.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ For commands and queries, middlewares run once around the single handler. For ev
4646
!!! note
4747
The generator was chosen to keep both the input message and the return value read-only.
4848

49+
### Classic middlewares
50+
51+
As an alternative, classic middlewares receive `call_next` as their first argument, followed by the handler's arguments. This pattern allows you to read and modify the return value:
52+
```python
53+
from collections.abc import Awaitable, Callable
54+
from typing import Any
55+
import time
56+
57+
async def timing_middleware(
58+
call_next: Callable[[Any], Awaitable[Any]],
59+
message: Any,
60+
) -> Any:
61+
start = time.time()
62+
result = await call_next(message)
63+
print(f"Execution time: {time.time() - start}s")
64+
return result
65+
```
66+
4967
## Class-based listeners and middlewares
5068

5169
For more flexibility, listeners and middlewares can be defined as classes with a `__call__` method. This allows you to inject dependencies and configure their behavior.
@@ -68,4 +86,18 @@ class TimingMiddleware:
6886
start = time.time()
6987
yield
7088
self.metrics.record(time.time() - start)
89+
90+
@dataclass
91+
class ClassicTimingMiddleware:
92+
metrics: MetricsService
93+
94+
async def __call__(
95+
self,
96+
call_next: Callable[[Any], Awaitable[Any]],
97+
message: Any,
98+
) -> Any:
99+
start = time.time()
100+
result = await call_next(message)
101+
self.metrics.record(time.time() - start)
102+
return result
71103
```

tests/core/test_middleware.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any
1+
from collections.abc import Callable
2+
from typing import Any, Awaitable
23

34
import pytest
45

@@ -95,6 +96,32 @@ async def handler() -> str:
9596
records = history.records
9697
assert len(records) == 2
9798

99+
async def test_invoke_with_classic_middleware(
100+
self,
101+
group: MiddlewareGroup[..., Any],
102+
) -> None:
103+
before = inner = after = 0
104+
105+
async def handler() -> None:
106+
nonlocal inner
107+
inner += 1
108+
109+
async def classic_middleware(
110+
call_next: Callable[..., Awaitable[Any]],
111+
*args: Any,
112+
**kwargs: Any,
113+
) -> Any:
114+
nonlocal before, after
115+
before += 1
116+
result = await call_next(*args, **kwargs)
117+
after += 1
118+
return result
119+
120+
group.add(classic_middleware)
121+
await group.invoke(handler)
122+
123+
assert before == inner == after == 1
124+
98125

99126
async def _exec_2_times_middleware(*args: Any, **kwargs: Any) -> MiddlewareResult[Any]:
100127
try:

0 commit comments

Comments
 (0)