Skip to content

Commit def82b8

Browse files
authored
feat: ✨ Migrate asyncio to anyio
1 parent c92a26c commit def82b8

17 files changed

Lines changed: 275 additions & 122 deletions

File tree

.github/workflows/cd.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ jobs:
88
cd:
99
name: Continuous Delivery
1010
runs-on: ubuntu-latest
11+
permissions:
12+
contents: read
1113

1214
steps:
1315
- name: Run checkout

.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ jobs:
1313
matrix:
1414
python-version: ["3.12", "3.13"]
1515

16-
name: Python ${{ matrix.python-version }}
16+
name: Continuous Integration ・ Python ${{ matrix.python-version }}
1717
runs-on: ubuntu-latest
18+
permissions:
19+
contents: read
1820

1921
steps:
2022
- name: Run checkout

conftest.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from collections.abc import Iterator
12
from typing import Any
23

34
import pytest
5+
from injection.testing import load_test_profile, set_test_constant
46

5-
from cq import Bus
7+
from cq import Bus, CommandBus, EventBus, QueryBus
68
from cq._core.dispatcher.bus import SimpleBus
9+
from cq._core.message import new_command_bus, new_event_bus, new_query_bus
710
from tests.helpers.history import HistoryMiddleware
811

912

@@ -15,3 +18,17 @@ def bus() -> Bus[Any, Any]:
1518
@pytest.fixture(scope="function")
1619
def history() -> HistoryMiddleware:
1720
return HistoryMiddleware()
21+
22+
23+
@pytest.fixture(scope="function", autouse=True)
24+
def ensure_test_dependencies(history: HistoryMiddleware) -> Iterator[None]:
25+
command_bus: CommandBus[Any] = new_command_bus().add_middlewares(history)
26+
event_bus: EventBus = new_event_bus().add_middlewares(history)
27+
query_bus: QueryBus[Any] = new_query_bus().add_middlewares(history)
28+
29+
set_test_constant(command_bus, on=CommandBus, alias=True, mode="override")
30+
set_test_constant(event_bus, on=EventBus, alias=True, mode="override")
31+
set_test_constant(query_bus, on=QueryBus, alias=True, mode="override")
32+
33+
with load_test_profile():
34+
yield

cq/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
query_handler,
1919
)
2020
from ._core.middleware import Middleware, MiddlewareResult
21+
from ._core.related_events import RelatedEvents
22+
from ._core.scope import CQScope
2123

2224
__all__ = (
2325
"AnyCommandBus",
2426
"Bus",
27+
"CQScope",
2528
"Command",
2629
"CommandBus",
2730
"DTO",
@@ -33,6 +36,7 @@
3336
"Pipe",
3437
"Query",
3538
"QueryBus",
39+
"RelatedEvents",
3640
"command_handler",
3741
"event_handler",
3842
"get_command_bus",

cq/_core/dispatcher/base.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
from abc import ABC, abstractmethod
32
from collections.abc import Awaitable, Callable
43
from typing import Protocol, Self, runtime_checkable
@@ -14,10 +13,6 @@ class Dispatcher[I, O](Protocol):
1413
async def dispatch(self, input_value: I, /) -> O:
1514
raise NotImplementedError
1615

17-
@abstractmethod
18-
def dispatch_no_wait(self, *input_values: I) -> None:
19-
raise NotImplementedError
20-
2116
@abstractmethod
2217
def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self:
2318
raise NotImplementedError
@@ -31,12 +26,6 @@ class BaseDispatcher[I, O](Dispatcher[I, O], ABC):
3126
def __init__(self) -> None:
3227
self.__middleware_group = MiddlewareGroup()
3328

34-
def dispatch_no_wait(self, *input_values: I) -> None:
35-
asyncio.gather(
36-
*(self.dispatch(input_value) for input_value in input_values),
37-
return_exceptions=True,
38-
)
39-
4029
def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self:
4130
self.__middleware_group.add(*middlewares)
4231
return self

cq/_core/dispatcher/bus.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import asyncio
21
from abc import ABC, abstractmethod
32
from collections import defaultdict
43
from collections.abc import Awaitable, Callable
54
from dataclasses import dataclass, field
6-
from inspect import isclass
5+
from inspect import getmro, isclass
76
from types import GenericAlias
87
from typing import Any, Protocol, Self, TypeAliasType, runtime_checkable
98

9+
import anyio
1010
import injection
1111

1212
from cq._core.dispatcher.base import BaseDispatcher, Dispatcher
@@ -52,18 +52,12 @@ def decorator(wrapped: type[Handler[[I], O]]) -> type[Handler[[I], O]]:
5252
raise TypeError(f"`{wrapped}` isn't a valid handler.")
5353

5454
bus = self.injection_module.find_instance(self.bus_type)
55-
lazy_instance = self.injection_module.aget_lazy_instance(
56-
wrapped,
57-
default=NotImplemented,
58-
)
59-
60-
async def getter() -> Handler[[I], O]:
61-
return await lazy_instance
55+
factory = self.injection_module.make_async_factory(wrapped)
6256

6357
for input_type in (first_input_type, *input_types):
64-
bus.subscribe(input_type, getter)
58+
bus.subscribe(input_type, factory)
6559

66-
return self.injection_module.injectable(wrapped)
60+
return wrapped
6761

6862
return decorator
6963

@@ -82,7 +76,24 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
8276
return self
8377

8478
async def _trigger_listeners(self, input_value: I, /) -> None:
85-
await asyncio.gather(*(listener(input_value) for listener in self.__listeners))
79+
listeners = self.__listeners
80+
81+
if not listeners:
82+
return
83+
84+
async with anyio.create_task_group() as task_group:
85+
for listener in listeners:
86+
task_group.start_soon(listener, input_value)
87+
88+
@staticmethod
89+
def _make_handle_function(
90+
handler_factory: HandlerFactory[[I], O],
91+
) -> Callable[[I], Awaitable[O]]:
92+
async def handle(input_value: I) -> O:
93+
handler = await handler_factory()
94+
return await handler.handle(input_value)
95+
96+
return handle
8697

8798

8899
class SimpleBus[I, O](BaseBus[I, O]):
@@ -96,15 +107,16 @@ def __init__(self) -> None:
96107

97108
async def dispatch(self, input_value: I, /) -> O:
98109
await self._trigger_listeners(input_value)
99-
input_type = type(input_value)
100110

101-
try:
102-
handler_factory = self.__handlers[input_type]
103-
except KeyError:
111+
for input_type in getmro(type(input_value)):
112+
if handler_factory := self.__handlers.get(input_type):
113+
break
114+
115+
else:
104116
return NotImplemented
105117

106-
handler = await handler_factory()
107-
return await self._invoke_with_middlewares(handler.handle, input_value)
118+
handler = self._make_handle_function(handler_factory)
119+
return await self._invoke_with_middlewares(handler, input_value)
108120

109121
def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
110122
if input_type in self.__handlers:
@@ -127,20 +139,22 @@ def __init__(self) -> None:
127139

128140
async def dispatch(self, input_value: I, /) -> None:
129141
await self._trigger_listeners(input_value)
130-
handler_factories = self.__handlers.get(type(input_value))
131142

132-
if not handler_factories:
143+
for input_type in getmro(type(input_value)):
144+
if handler_factories := self.__handlers.get(input_type):
145+
break
146+
147+
else:
133148
return
134149

135-
await asyncio.gather(
136-
*[
137-
self._invoke_with_middlewares(
138-
(await handler_factory()).handle,
150+
async with anyio.create_task_group() as task_group:
151+
for handler_factory in handler_factories:
152+
handler = self._make_handle_function(handler_factory)
153+
task_group.start_soon(
154+
self._invoke_with_middlewares,
155+
handler,
139156
input_value,
140157
)
141-
for handler_factory in handler_factories
142-
]
143-
)
144158

145159
def subscribe(
146160
self,

cq/_core/dispatcher/pipe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None:
2424
self.__dispatcher = dispatcher
2525
self.__steps = []
2626

27-
def step[T]( # type: ignore[no-untyped-def]
27+
def step[T](
2828
self,
2929
wrapped: PipeConverter[T, Any] | None = None,
3030
/,
3131
*,
3232
dispatcher: Dispatcher[T, Any] | None = None,
33-
):
34-
def decorator(wp): # type: ignore[no-untyped-def]
33+
) -> Any:
34+
def decorator(wp: PipeConverter[T, Any]) -> PipeConverter[T, Any]:
3535
step = PipeStep(wp, dispatcher)
3636
self.__steps.append(step)
3737
return wp

cq/_core/message.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from cq._core.dispatcher.bus import Bus, SimpleBus, SubscriberDecorator, TaskBus
77
from cq._core.dto import DTO
8+
from cq._core.scope import CQScope
9+
from cq.middlewares.scope import InjectionScopeMiddleware
810

911

1012
class Message(DTO, ABC):
@@ -34,21 +36,38 @@ class Query(Message, ABC):
3436
event_handler: SubscriberDecorator[Event, None] = SubscriberDecorator(EventBus)
3537
query_handler: SubscriberDecorator[Query, Any] = SubscriberDecorator(QueryBus)
3638

37-
injection.set_constant(SimpleBus(), CommandBus, alias=True)
38-
injection.set_constant(TaskBus(), EventBus, alias=True)
39-
injection.set_constant(SimpleBus(), QueryBus, alias=True)
40-
4139

4240
@injection.inject
4341
def get_command_bus[T](bus: CommandBus[T] = NotImplemented, /) -> CommandBus[T]:
4442
return bus
4543

4644

45+
def new_command_bus[T]() -> CommandBus[T]:
46+
bus: CommandBus[T] = SimpleBus()
47+
bus.add_middlewares(
48+
InjectionScopeMiddleware(CQScope.ON_COMMAND),
49+
)
50+
return bus
51+
52+
4753
@injection.inject
4854
def get_event_bus(bus: EventBus = NotImplemented, /) -> EventBus:
4955
return bus
5056

5157

58+
def new_event_bus() -> EventBus:
59+
return TaskBus()
60+
61+
5262
@injection.inject
5363
def get_query_bus[T](bus: QueryBus[T] = NotImplemented, /) -> QueryBus[T]:
5464
return bus
65+
66+
67+
def new_query_bus[T]() -> QueryBus[T]:
68+
return SimpleBus()
69+
70+
71+
injection.set_constant(new_command_bus(), CommandBus, alias=True)
72+
injection.set_constant(new_event_bus(), EventBus, alias=True)
73+
injection.set_constant(new_query_bus(), QueryBus, alias=True)

cq/_core/related_events.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from abc import abstractmethod
2+
from collections.abc import AsyncIterator
3+
from dataclasses import dataclass, field
4+
from typing import Protocol, runtime_checkable
5+
6+
import anyio
7+
import injection
8+
9+
from cq._core.message import Event, EventBus
10+
from cq._core.scope import CQScope
11+
12+
13+
@runtime_checkable
14+
class RelatedEvents(Protocol):
15+
__slots__ = ()
16+
17+
@abstractmethod
18+
def add(self, *events: Event) -> None:
19+
raise NotImplementedError
20+
21+
22+
@dataclass(frozen=True, slots=True)
23+
class _RelatedEvents(RelatedEvents):
24+
items: list[Event] = field(default_factory=list)
25+
26+
def add(self, *events: Event) -> None:
27+
self.items.extend(events)
28+
29+
30+
@injection.scoped(CQScope.ON_COMMAND)
31+
async def _related_events_recipe(event_bus: EventBus) -> AsyncIterator[RelatedEvents]:
32+
yield (instance := _RelatedEvents())
33+
events = instance.items
34+
35+
if not events:
36+
return
37+
38+
async with anyio.create_task_group() as task_group:
39+
for event in events:
40+
task_group.start_soon(event_bus.dispatch, event)
41+
42+
43+
del _related_events_recipe

cq/_core/scope.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from enum import StrEnum, auto
2+
3+
4+
class CQScope(StrEnum):
5+
ON_COMMAND = auto()

0 commit comments

Comments
 (0)