Skip to content

Commit c92a26c

Browse files
authored
feat: ✨ Support for async dependencies
1 parent 62f6ede commit c92a26c

6 files changed

Lines changed: 98 additions & 68 deletions

File tree

Makefile

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
before_commit: check lint mypy pytest
2-
3-
check:
4-
poetry check
1+
before_commit: lint mypy pytest
52

63
install:
7-
poetry install --sync
4+
uv sync
5+
6+
upgrade:
7+
uv lock --upgrade
8+
uv sync
89

910
lint:
10-
ruff format
11-
ruff check --fix
11+
uv run ruff format
12+
uv run ruff check --fix
1213

1314
mypy:
14-
mypy ./
15+
uv run mypy ./
1516

1617
pytest:
17-
pytest
18+
uv run pytest

cq/_core/dispatcher/bus.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from cq._core.dispatcher.base import BaseDispatcher, Dispatcher
1313

1414
type HandlerType[**P, T] = type[Handler[P, T]]
15-
type HandlerFactory[**P, T] = Callable[..., Handler[P, T]]
15+
type HandlerFactory[**P, T] = Callable[..., Awaitable[Handler[P, T]]]
1616

1717
type Listener[T] = Callable[[T], Awaitable[Any]]
1818

@@ -46,24 +46,27 @@ class SubscriberDecorator[I, O]:
4646
bus_type: BusType[I, O] | TypeAliasType | GenericAlias
4747
injection_module: injection.Module = field(default_factory=injection.mod)
4848

49-
def __call__(self, first_input_type: type[I], /, *input_types: type[I]): # type: ignore[no-untyped-def]
50-
def decorator(wrapped): # type: ignore[no-untyped-def]
49+
def __call__(self, first_input_type: type[I], /, *input_types: type[I]) -> Any:
50+
def decorator(wrapped: type[Handler[[I], O]]) -> type[Handler[[I], O]]:
5151
if not isclass(wrapped) or not issubclass(wrapped, Handler):
5252
raise TypeError(f"`{wrapped}` isn't a valid handler.")
5353

54-
bus = self.__find_bus()
55-
factory = self.injection_module.make_injected_function(wrapped)
54+
bus = self.injection_module.find_instance(self.bus_type)
55+
lazy_instance = self.injection_module.aget_lazy_instance(
56+
wrapped,
57+
default=NotImplemented,
58+
)
59+
60+
async def getter() -> Handler[[I], O]:
61+
return await lazy_instance
5662

5763
for input_type in (first_input_type, *input_types):
58-
bus.subscribe(input_type, factory)
64+
bus.subscribe(input_type, getter)
5965

60-
return wrapped
66+
return self.injection_module.injectable(wrapped)
6167

6268
return decorator
6369

64-
def __find_bus(self) -> Bus[I, O]:
65-
return self.injection_module.find_instance(self.bus_type)
66-
6770

6871
class BaseBus[I, O](BaseDispatcher[I, O], Bus[I, O], ABC):
6972
__slots__ = ("__listeners",)
@@ -100,10 +103,8 @@ async def dispatch(self, input_value: I, /) -> O:
100103
except KeyError:
101104
return NotImplemented
102105

103-
return await self._invoke_with_middlewares(
104-
handler_factory().handle,
105-
input_value,
106-
)
106+
handler = await handler_factory()
107+
return await self._invoke_with_middlewares(handler.handle, input_value)
107108

108109
def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
109110
if input_type in self.__handlers:
@@ -132,13 +133,13 @@ async def dispatch(self, input_value: I, /) -> None:
132133
return
133134

134135
await asyncio.gather(
135-
*(
136+
*[
136137
self._invoke_with_middlewares(
137-
handler_factory().handle,
138+
(await handler_factory()).handle,
138139
input_value,
139140
)
140141
for handler_factory in handler_factories
141-
)
142+
]
142143
)
143144

144145
def subscribe(

tests/core/dispatcher/test_bus.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from asyncio import all_tasks, get_running_loop
2-
from typing import Any
2+
from typing import Any, Self
33

44
import pytest
55
from injection import Module as InjectionModule
@@ -19,22 +19,22 @@ async def middleware(input_value: Any, /) -> MiddlewareResult[Any]:
1919
assert bus.add_middlewares(middleware) is bus
2020

2121
def test_subscribe_with_success_return_self(self, bus: SimpleBus[Any, Any]) -> None:
22-
assert bus.subscribe(str, _SomeHandler) is bus
22+
assert bus.subscribe(str, _SomeHandler.async_factory) is bus
2323

2424
def test_subscribe_with_already_subscribed_raise_runtime_error(
2525
self,
2626
bus: SimpleBus[Any, Any],
2727
) -> None:
28-
assert bus.subscribe(str, _SomeHandler) is bus
28+
assert bus.subscribe(str, _SomeHandler.async_factory) is bus
2929

3030
with pytest.raises(RuntimeError):
31-
bus.subscribe(str, _SomeHandler)
31+
bus.subscribe(str, _SomeHandler.async_factory)
3232

3333
async def test_dispatch_with_success_return_any(
3434
self,
3535
bus: SimpleBus[Any, str],
3636
) -> None:
37-
bus.subscribe(str, _SomeHandler)
37+
bus.subscribe(str, _SomeHandler.async_factory)
3838
input_value = "hello"
3939
assert await bus.dispatch(input_value) == f"|{input_value}|"
4040

@@ -107,15 +107,15 @@ def task_bus(self) -> TaskBus[Any]:
107107
return TaskBus()
108108

109109
def test_subscribe_with_success_return_self(self, task_bus: TaskBus[Any]) -> None:
110-
assert task_bus.subscribe(str, _SomeTaskHandler) is task_bus
110+
assert task_bus.subscribe(str, _SomeTaskHandler.async_factory) is task_bus
111111
# Checks whether several handlers can be subscribed for the same input type
112-
assert task_bus.subscribe(str, _SomeTaskHandler) is task_bus
112+
assert task_bus.subscribe(str, _SomeTaskHandler.async_factory) is task_bus
113113

114114
async def test_dispatch_with_success_return_none(
115115
self,
116116
task_bus: TaskBus[Any],
117117
) -> None:
118-
task_bus.subscribe(str, _SomeTaskHandler)
118+
task_bus.subscribe(str, _SomeTaskHandler.async_factory)
119119

120120
with pytest.raises(NotImplementedError):
121121
await task_bus.dispatch("hello")
@@ -132,7 +132,15 @@ class _SomeHandler:
132132
async def handle(self, input_value: str, /) -> str:
133133
return f"|{input_value}|"
134134

135+
@classmethod
136+
async def async_factory(cls) -> Self:
137+
return cls()
138+
135139

136140
class _SomeTaskHandler:
137141
async def handle(self, input_value: Any, /) -> None:
138142
raise NotImplementedError
143+
144+
@classmethod
145+
async def async_factory(cls) -> Self:
146+
return cls()

tests/core/dispatcher/test_pipe.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Self
22

33
from cq import Bus, Pipe
44
from cq._core.dispatcher.bus import SimpleBus
@@ -10,19 +10,31 @@ class StringLengthHandler:
1010
async def handle(self, input_value: str) -> int:
1111
return len(input_value)
1212

13+
@classmethod
14+
async def async_factory(cls) -> Self:
15+
return cls()
16+
1317
class UniformStringHandler:
1418
def __init__(self, char: str) -> None:
1519
self.char = char
1620

1721
async def handle(self, input_value: int) -> str:
1822
return self.char * input_value
1923

24+
@classmethod
25+
async def async_factory(cls, char: str) -> Self:
26+
return cls(char)
27+
2028
class ToTupleHandler:
2129
async def handle(self, input_value: str) -> tuple[str, ...]:
2230
return tuple(input_value)
2331

24-
bus.subscribe(str, StringLengthHandler)
25-
bus.subscribe(int, lambda: UniformStringHandler("*"))
32+
@classmethod
33+
async def async_factory(cls) -> Self:
34+
return cls()
35+
36+
bus.subscribe(str, StringLengthHandler.async_factory)
37+
bus.subscribe(int, lambda: UniformStringHandler.async_factory("*"))
2638

2739
pipe: Pipe[str, str | tuple[str, ...]] = Pipe(bus)
2840

@@ -35,7 +47,7 @@ async def step_converter_1(length: int) -> int:
3547
# Custom dispatcher
3648
other_bus: Bus[Any, Any] = SimpleBus()
3749

38-
other_bus.subscribe(str, ToTupleHandler)
50+
other_bus.subscribe(str, ToTupleHandler.async_factory)
3951

4052
@pipe.step(dispatcher=other_bus)
4153
async def step_converter_2(hidden_string: str) -> str:

tests/middlewares/test_retry.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Self
22

33
import pytest
44

@@ -17,8 +17,12 @@ class SomeHandler:
1717
async def handle(self, input_value: str) -> str:
1818
return input_value
1919

20+
@classmethod
21+
async def async_factory(cls) -> Self:
22+
return cls()
23+
2024
bus.add_middlewares(RetryMiddleware(3), history)
21-
bus.subscribe(str, SomeHandler)
25+
bus.subscribe(str, SomeHandler.async_factory)
2226

2327
await bus.dispatch("Hello world!")
2428
assert len(history.records) == 1
@@ -32,9 +36,13 @@ class SomeHandler:
3236
async def handle(self, input_value: str) -> None:
3337
raise ValueError(input_value)
3438

39+
@classmethod
40+
async def async_factory(cls) -> Self:
41+
return cls()
42+
3543
retry = 3
3644
bus.add_middlewares(RetryMiddleware(retry), history)
37-
bus.subscribe(str, SomeHandler)
45+
bus.subscribe(str, SomeHandler.async_factory)
3846

3947
with pytest.raises(ValueError):
4048
await bus.dispatch("Hello world!")

0 commit comments

Comments
 (0)