Skip to content

Commit 5c0b2a0

Browse files
authored
feat: Add resolve_handler_source
1 parent 5fa59fb commit 5c0b2a0

File tree

6 files changed

+102
-37
lines changed

6 files changed

+102
-37
lines changed

cq/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
new_query_bus,
1818
query_handler,
1919
)
20-
from ._core.middleware import Middleware, MiddlewareResult
20+
from ._core.middleware import Middleware, MiddlewareResult, resolve_handler_source
2121
from ._core.pipetools import ContextCommandPipeline
2222
from ._core.related_events import RelatedEvents
2323
from ._core.scope import CQScope
@@ -47,4 +47,5 @@
4747
"new_event_bus",
4848
"new_query_bus",
4949
"query_handler",
50+
"resolve_handler_source",
5051
)

cq/_core/dispatcher/bus.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self:
3131
raise NotImplementedError
3232

3333
@abstractmethod
34-
def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
34+
def subscribe(
35+
self,
36+
input_type: type[I],
37+
factory: HandlerFactory[[I], O],
38+
fail_silently: bool = ...,
39+
) -> Self:
3540
raise NotImplementedError
3641

3742

@@ -50,8 +55,13 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
5055
self.__listeners.extend(listeners)
5156
return self
5257

53-
def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
54-
self.__registry.subscribe(input_type, factory)
58+
def subscribe(
59+
self,
60+
input_type: type[I],
61+
factory: HandlerFactory[[I], O],
62+
fail_silently: bool = False,
63+
) -> Self:
64+
self.__registry.subscribe(input_type, factory, fail_silently=fail_silently)
5565
return self
5666

5767
def _handlers_from(self, input_type: type[I]) -> Iterator[HandleFunction[[I], O]]:

cq/_core/handler.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Awaitable, Callable, Iterator
44
from dataclasses import dataclass, field
55
from functools import partial
6-
from inspect import Parameter, isclass
6+
from inspect import Parameter, isclass, unwrap
77
from inspect import signature as inspect_signature
88
from typing import TYPE_CHECKING, Any, Protocol, Self, overload, runtime_checkable
99

@@ -27,14 +27,23 @@ async def handle(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
2727

2828
@dataclass(repr=False, eq=False, frozen=True, slots=True)
2929
class HandleFunction[**P, T]:
30-
handler_factory: HandlerFactory[P, T]
31-
handler_type: HandlerType[P, T] | None = field(default=None)
32-
fail_silently: bool = field(default=False)
30+
factory: HandlerFactory[P, T]
31+
source: HandlerType[P, T] | Any
32+
fail_silently: bool
3333

3434
async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
35-
handler = await self.handler_factory()
35+
handler = await self.factory()
3636
return await handler.handle(*args, **kwargs)
3737

38+
@classmethod
39+
def create(
40+
cls,
41+
factory: HandlerFactory[P, T],
42+
source: HandlerType[P, T] | None = None,
43+
fail_silently: bool = False,
44+
) -> Self:
45+
return cls(factory, source or unwrap(factory), fail_silently)
46+
3847

3948
@runtime_checkable
4049
class HandlerRegistry[I, O](Protocol):
@@ -73,7 +82,7 @@ def subscribe(
7382
handler_type: HandlerType[[I], O] | None = None,
7483
fail_silently: bool = False,
7584
) -> Self:
76-
function = HandleFunction(handler_factory, handler_type, fail_silently)
85+
function = HandleFunction.create(handler_factory, handler_type, fail_silently)
7786

7887
for key_type in _build_key_types(input_type):
7988
self.__values[key_type].append(function)
@@ -101,7 +110,7 @@ def subscribe(
101110
handler_type: HandlerType[[I], O] | None = None,
102111
fail_silently: bool = False,
103112
) -> Self:
104-
function = HandleFunction(handler_factory, handler_type, fail_silently)
113+
function = HandleFunction.create(handler_factory, handler_type, fail_silently)
105114
entries = {key_type: function for key_type in _build_key_types(input_type)}
106115

107116
for key_type in entries:

cq/_core/middleware.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from collections.abc import AsyncGenerator, Awaitable, Callable
22
from dataclasses import dataclass, field
33
from inspect import isasyncgenfunction
4-
from typing import Concatenate, Self, TypeGuard
4+
from typing import Any, Concatenate, Self, TypeGuard
55

6+
from cq._core.handler import HandleFunction, HandlerType
67
from cq.exceptions import MiddlewareError
78

89
type MiddlewareResult[T] = AsyncGenerator[None, T]
@@ -63,6 +64,19 @@ async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
6364
return await self.middleware(self.call_next, *args, **kwargs)
6465

6566

67+
def resolve_handler_source[**P, T](
68+
call_next: Callable[P, Awaitable[T]]
69+
| _BoundMiddleware[P, T]
70+
| HandleFunction[P, T],
71+
/,
72+
) -> HandlerType[P, T] | Any:
73+
while True:
74+
try:
75+
call_next = call_next.call_next # type: ignore[union-attr]
76+
except AttributeError:
77+
return call_next.source # type: ignore[union-attr]
78+
79+
6680
@dataclass(repr=False, eq=False, frozen=True, slots=True)
6781
class _GeneratorMiddleware[**P, T]:
6882
middleware: GeneratorMiddleware[P, T]

tests/core/test_middleware.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,42 @@
33

44
import pytest
55

6-
from cq._core.middleware import MiddlewareGroup, MiddlewareResult
6+
from cq._core.dispatcher.bus import SimpleBus
7+
from cq._core.handler import HandlerDecorator, SingleHandlerRegistry
8+
from cq._core.middleware import (
9+
MiddlewareGroup,
10+
MiddlewareResult,
11+
resolve_handler_source,
12+
)
713
from cq.exceptions import MiddlewareError
814
from tests.helpers.history import HistoryMiddleware
915

1016

17+
async def test_resolve_handler_source_with_success() -> None:
18+
registry = SingleHandlerRegistry[Any, Any]()
19+
handler = HandlerDecorator(registry)
20+
21+
@handler
22+
class Handler:
23+
async def handle(self, message: str) -> None: ...
24+
25+
expected: Any = None
26+
27+
async def middleware(
28+
call_next: Callable[[Any], Awaitable[Any]],
29+
message: Any,
30+
) -> Any:
31+
nonlocal expected
32+
expected = resolve_handler_source(call_next)
33+
return await call_next(message)
34+
35+
bus = SimpleBus(registry)
36+
bus.add_middlewares(middleware)
37+
await bus.dispatch("hello")
38+
39+
assert expected is Handler
40+
41+
1142
class TestMiddlewareGroup:
1243
@pytest.fixture(scope="function")
1344
def group(self) -> MiddlewareGroup[..., Any]:

uv.lock

Lines changed: 24 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)